triton-windows 3.3.1.post19__cp311-cp311-win_amd64.whl → 3.5.0.post21__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic. Click here for more details.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +11 -2
- triton/_filecheck.py +97 -0
- triton/_internal_testing.py +95 -18
- triton/_utils.py +112 -21
- triton/backends/__init__.py +20 -23
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +161 -119
- triton/backends/amd/driver.c +118 -46
- triton/backends/amd/driver.py +274 -96
- triton/backends/compiler.py +7 -21
- triton/backends/driver.py +13 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +163 -106
- triton/backends/nvidia/driver.c +166 -101
- triton/backends/nvidia/driver.py +384 -202
- triton/compiler/__init__.py +5 -2
- triton/compiler/code_generator.py +439 -231
- triton/compiler/compiler.py +152 -84
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +5 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +102 -0
- triton/experimental/gluon/language/__init__.py +119 -0
- triton/experimental/gluon/language/_core.py +490 -0
- triton/experimental/gluon/language/_layouts.py +583 -0
- triton/experimental/gluon/language/_math.py +20 -0
- triton/experimental/gluon/language/_semantic.py +380 -0
- triton/experimental/gluon/language/_standard.py +80 -0
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +45 -0
- triton/knobs.py +546 -0
- triton/language/__init__.py +50 -19
- triton/language/core.py +909 -572
- triton/language/extra/cuda/__init__.py +10 -7
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +394 -394
- triton/language/extra/cuda/utils.py +21 -21
- triton/language/extra/hip/__init__.py +3 -1
- triton/language/extra/hip/libdevice.py +120 -104
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +4 -0
- triton/language/math.py +65 -66
- triton/language/random.py +12 -2
- triton/language/semantic.py +1757 -1768
- triton/language/standard.py +127 -62
- triton/language/target_info.py +54 -0
- triton/runtime/_allocation.py +15 -3
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +117 -60
- triton/runtime/build.py +83 -17
- triton/runtime/cache.py +61 -47
- triton/runtime/driver.py +25 -47
- triton/runtime/interpreter.py +95 -50
- triton/runtime/jit.py +445 -248
- triton/runtime/tcc/include/_mingw.h +8 -10
- triton/runtime/tcc/include/assert.h +5 -0
- triton/runtime/tcc/include/errno.h +1 -1
- triton/runtime/tcc/include/float.h +21 -3
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +5 -0
- triton/runtime/tcc/include/malloc.h +2 -2
- triton/runtime/tcc/include/math.h +21 -261
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +5 -70
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stddef.h +7 -19
- triton/runtime/tcc/include/stdlib.h +15 -4
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/sys/stat.h +2 -2
- triton/runtime/tcc/include/sys/types.h +5 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +9 -2
- triton/runtime/tcc/include/winapi/wincon.h +8 -0
- triton/runtime/tcc/include/winapi/windows.h +1 -1
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +9 -7
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/testing.py +16 -12
- triton/tools/compile.py +62 -14
- triton/tools/disasm.py +3 -4
- triton/tools/extra/cuda/compile.c +1 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +34 -0
- triton/windows_utils.py +52 -81
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +8 -4
- triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
- triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
- triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
- triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
- triton/backends/amd/include/hip/channel_descriptor.h +0 -39
- triton/backends/amd/include/hip/device_functions.h +0 -38
- triton/backends/amd/include/hip/driver_types.h +0 -468
- triton/backends/amd/include/hip/hip_bf16.h +0 -36
- triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
- triton/backends/amd/include/hip/hip_common.h +0 -100
- triton/backends/amd/include/hip/hip_complex.h +0 -38
- triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
- triton/backends/amd/include/hip/hip_deprecated.h +0 -95
- triton/backends/amd/include/hip/hip_ext.h +0 -161
- triton/backends/amd/include/hip/hip_fp16.h +0 -36
- triton/backends/amd/include/hip/hip_fp8.h +0 -33
- triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
- triton/backends/amd/include/hip/hip_hcc.h +0 -24
- triton/backends/amd/include/hip/hip_math_constants.h +0 -36
- triton/backends/amd/include/hip/hip_profile.h +0 -27
- triton/backends/amd/include/hip/hip_runtime.h +0 -75
- triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
- triton/backends/amd/include/hip/hip_texture_types.h +0 -29
- triton/backends/amd/include/hip/hip_vector_types.h +0 -41
- triton/backends/amd/include/hip/hip_version.h +0 -17
- triton/backends/amd/include/hip/hiprtc.h +0 -421
- triton/backends/amd/include/hip/library_types.h +0 -78
- triton/backends/amd/include/hip/math_functions.h +0 -42
- triton/backends/amd/include/hip/surface_types.h +0 -63
- triton/backends/amd/include/hip/texture_types.h +0 -194
- triton/backends/amd/include/hsa/Brig.h +0 -1131
- triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
- triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
- triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
- triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
- triton/backends/amd/include/hsa/hsa.h +0 -5738
- triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
- triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
- triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
- triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
- triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
- triton/backends/amd/include/roctracer/roctracer.h +0 -779
- triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
- triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
- triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
- triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
- triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
- triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
- triton/backends/amd/include/roctracer/roctx.h +0 -229
- triton/language/_utils.py +0 -21
- triton/language/extra/cuda/_experimental_tma.py +0 -106
- triton/runtime/tcc/lib/libtcc1-64.a +0 -0
- triton/tools/experimental_descriptor.py +0 -32
- triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
- triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
triton/compiler/compiler.py
CHANGED
|
@@ -3,19 +3,19 @@ import hashlib
|
|
|
3
3
|
import json
|
|
4
4
|
from .._C.libtriton import get_cache_invalidating_env_vars, ir
|
|
5
5
|
from ..backends import backends
|
|
6
|
-
from ..backends.compiler import
|
|
7
|
-
from .. import
|
|
6
|
+
from ..backends.compiler import Language
|
|
7
|
+
from ..backends.compiler import BaseBackend, GPUTarget
|
|
8
|
+
from .. import __version__, knobs
|
|
8
9
|
from ..runtime.autotuner import OutOfResources
|
|
9
|
-
from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager
|
|
10
|
+
from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager, get_cache_key
|
|
10
11
|
from ..runtime.driver import driver
|
|
11
12
|
from ..tools.disasm import get_sass
|
|
12
|
-
# TODO: this shouldn't be here
|
|
13
|
-
from .code_generator import ast_to_ttir
|
|
14
13
|
from pathlib import Path
|
|
15
14
|
import re
|
|
16
15
|
import functools
|
|
17
16
|
import os
|
|
18
|
-
import
|
|
17
|
+
import time
|
|
18
|
+
import copy
|
|
19
19
|
|
|
20
20
|
# - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
|
|
21
21
|
# and any following whitespace
|
|
@@ -53,6 +53,7 @@ class ASTSource:
|
|
|
53
53
|
|
|
54
54
|
def __init__(self, fn, signature, constexprs=None, attrs=None) -> None:
|
|
55
55
|
self.fn = fn
|
|
56
|
+
self.language = Language.TRITON
|
|
56
57
|
self.ext = "ttir"
|
|
57
58
|
self.name = fn.__name__
|
|
58
59
|
self.signature = signature
|
|
@@ -63,12 +64,9 @@ class ASTSource:
|
|
|
63
64
|
assert isinstance(k, tuple)
|
|
64
65
|
self.constants[k] = v
|
|
65
66
|
self.attrs = attrs or dict()
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
for k in self.signature.keys():
|
|
70
|
-
if not isinstance(k, str):
|
|
71
|
-
raise TypeError("Signature keys must be string")
|
|
67
|
+
for k in self.signature.keys():
|
|
68
|
+
if not isinstance(k, str):
|
|
69
|
+
raise TypeError("Signature keys must be string")
|
|
72
70
|
|
|
73
71
|
def hash(self):
|
|
74
72
|
sorted_sig = [v for k, v in sorted(self.signature.items())]
|
|
@@ -77,7 +75,8 @@ class ASTSource:
|
|
|
77
75
|
key = f"{self.fn.cache_key}-{str(self.attrs)}-{sorted_sig}-{constants_key}"
|
|
78
76
|
return hashlib.sha256(key.encode("utf-8")).hexdigest()
|
|
79
77
|
|
|
80
|
-
def make_ir(self, options, codegen_fns, module_map, context):
|
|
78
|
+
def make_ir(self, target: GPUTarget, options, codegen_fns, module_map, context):
|
|
79
|
+
from .code_generator import ast_to_ttir
|
|
81
80
|
return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
|
|
82
81
|
module_map=module_map)
|
|
83
82
|
|
|
@@ -91,6 +90,7 @@ class IRSource:
|
|
|
91
90
|
self.path = path
|
|
92
91
|
path = Path(path)
|
|
93
92
|
self.ext = path.suffix[1:]
|
|
93
|
+
self.language = Language.TRITON
|
|
94
94
|
self.src = path.read_text()
|
|
95
95
|
ir.load_dialects(context)
|
|
96
96
|
backend.load_dialects(context)
|
|
@@ -114,7 +114,7 @@ class IRSource:
|
|
|
114
114
|
def hash(self):
|
|
115
115
|
return hashlib.sha256(self.src.encode("utf-8")).hexdigest()
|
|
116
116
|
|
|
117
|
-
def make_ir(self, options, codegen_fns, module_map, context):
|
|
117
|
+
def make_ir(self, target: GPUTarget, options, codegen_fns, module_map, context):
|
|
118
118
|
self.module.context = context
|
|
119
119
|
return self.module
|
|
120
120
|
|
|
@@ -127,39 +127,8 @@ class IRSource:
|
|
|
127
127
|
|
|
128
128
|
|
|
129
129
|
@functools.lru_cache()
|
|
130
|
-
def
|
|
131
|
-
|
|
132
|
-
TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
133
|
-
contents = []
|
|
134
|
-
# frontend
|
|
135
|
-
with open(__file__, "rb") as f:
|
|
136
|
-
contents += [hashlib.sha256(f.read()).hexdigest()]
|
|
137
|
-
# compiler
|
|
138
|
-
path_prefixes = [
|
|
139
|
-
(os.path.join(TRITON_PATH, "compiler"), "triton.compiler."),
|
|
140
|
-
(os.path.join(TRITON_PATH, "backends"), "triton.backends."),
|
|
141
|
-
]
|
|
142
|
-
for path, prefix in path_prefixes:
|
|
143
|
-
for lib in pkgutil.walk_packages([path], prefix=prefix):
|
|
144
|
-
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
|
|
145
|
-
contents += [hashlib.sha256(f.read()).hexdigest()]
|
|
146
|
-
|
|
147
|
-
# backend
|
|
148
|
-
libtriton_hash = hashlib.sha256()
|
|
149
|
-
ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
|
|
150
|
-
with open(os.path.join(TRITON_PATH, "_C", f"libtriton.{ext}"), "rb") as f:
|
|
151
|
-
while True:
|
|
152
|
-
chunk = f.read(1024**2)
|
|
153
|
-
if not chunk:
|
|
154
|
-
break
|
|
155
|
-
libtriton_hash.update(chunk)
|
|
156
|
-
contents.append(libtriton_hash.hexdigest())
|
|
157
|
-
# language
|
|
158
|
-
language_path = os.path.join(TRITON_PATH, 'language')
|
|
159
|
-
for lib in pkgutil.walk_packages([language_path], prefix="triton.language."):
|
|
160
|
-
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
|
|
161
|
-
contents += [hashlib.sha256(f.read()).hexdigest()]
|
|
162
|
-
return f'{__version__}' + '-'.join(contents)
|
|
130
|
+
def max_shared_mem(device):
|
|
131
|
+
return driver.active.utils.get_device_properties(device)["max_shared_mem"]
|
|
163
132
|
|
|
164
133
|
|
|
165
134
|
def parse(full_name, ext, context):
|
|
@@ -179,7 +148,7 @@ def filter_traceback(e: BaseException):
|
|
|
179
148
|
|
|
180
149
|
These are uninteresting to the user -- "just show me *my* code!"
|
|
181
150
|
"""
|
|
182
|
-
if
|
|
151
|
+
if knobs.compilation.front_end_debugging:
|
|
183
152
|
return
|
|
184
153
|
|
|
185
154
|
if e.__cause__ is not None:
|
|
@@ -211,7 +180,50 @@ def filter_traceback(e: BaseException):
|
|
|
211
180
|
e.__traceback__ = frames[0]
|
|
212
181
|
|
|
213
182
|
|
|
214
|
-
|
|
183
|
+
class CompileTimer:
|
|
184
|
+
|
|
185
|
+
def __init__(self) -> None:
|
|
186
|
+
self.start: float = time.perf_counter()
|
|
187
|
+
self.ir_initialization_end: float | None = None
|
|
188
|
+
self.lowering_stage_ends: list[tuple[str, float]] = []
|
|
189
|
+
self.store_results_end: float | None = None
|
|
190
|
+
|
|
191
|
+
def finished_ir_initialization(self) -> None:
|
|
192
|
+
self.ir_initialization_end = time.perf_counter()
|
|
193
|
+
|
|
194
|
+
def stage_finished(self, stage_name: str) -> None:
|
|
195
|
+
self.lowering_stage_ends.append((stage_name, time.perf_counter()))
|
|
196
|
+
|
|
197
|
+
def end(self) -> knobs.CompileTimes:
|
|
198
|
+
timestamp = time.perf_counter()
|
|
199
|
+
if self.ir_initialization_end is None:
|
|
200
|
+
self.ir_initialization_end = timestamp
|
|
201
|
+
else:
|
|
202
|
+
self.store_results_end = timestamp
|
|
203
|
+
|
|
204
|
+
def delta(start: float, end: float | None) -> int:
|
|
205
|
+
if end is None:
|
|
206
|
+
return 0
|
|
207
|
+
return int((end - start) * 1000000)
|
|
208
|
+
|
|
209
|
+
lowering_stage_durations = []
|
|
210
|
+
stage_start = self.ir_initialization_end
|
|
211
|
+
for stage_name, stage_end in self.lowering_stage_ends:
|
|
212
|
+
lowering_stage_durations.append((stage_name, delta(stage_start, stage_end)))
|
|
213
|
+
stage_start = stage_end
|
|
214
|
+
|
|
215
|
+
return knobs.CompileTimes(
|
|
216
|
+
ir_initialization=delta(self.start, self.ir_initialization_end),
|
|
217
|
+
lowering_stages=lowering_stage_durations,
|
|
218
|
+
store_results=delta(stage_start, self.store_results_end),
|
|
219
|
+
)
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def compile(src, target=None, options=None, _env_vars=None):
|
|
223
|
+
compilation_listener = knobs.compilation.listener
|
|
224
|
+
if compilation_listener:
|
|
225
|
+
timer = CompileTimer()
|
|
226
|
+
|
|
215
227
|
if target is None:
|
|
216
228
|
target = driver.active.get_current_target()
|
|
217
229
|
assert isinstance(target, GPUTarget), "target must be of GPUTarget type"
|
|
@@ -226,15 +238,15 @@ def compile(src, target=None, options=None):
|
|
|
226
238
|
extra_options = src.parse_options()
|
|
227
239
|
options = backend.parse_options(dict(options or dict(), **extra_options))
|
|
228
240
|
# create cache manager
|
|
229
|
-
env_vars = get_cache_invalidating_env_vars()
|
|
230
|
-
key =
|
|
241
|
+
env_vars = get_cache_invalidating_env_vars() if _env_vars is None else _env_vars
|
|
242
|
+
key = get_cache_key(src, backend, options, env_vars=env_vars)
|
|
231
243
|
hash = hashlib.sha256(key.encode("utf-8")).hexdigest()
|
|
232
244
|
fn_cache_manager = get_cache_manager(hash)
|
|
233
245
|
# For dumping/overriding only hash the source as we want it to be independent of triton
|
|
234
246
|
# core changes to make it easier to track kernels by hash.
|
|
235
|
-
enable_override =
|
|
236
|
-
enable_ir_dump =
|
|
237
|
-
store_only_binary =
|
|
247
|
+
enable_override = knobs.compilation.override
|
|
248
|
+
enable_ir_dump = knobs.compilation.dump_ir
|
|
249
|
+
store_only_binary = knobs.compilation.store_binary_only
|
|
238
250
|
fn_override_manager = get_override_manager(src.hash()) if enable_override else None
|
|
239
251
|
fn_dump_manager = get_dump_manager(src.hash()) if enable_ir_dump else None
|
|
240
252
|
# Pre-truncate the file name here to avoid hitting the 255 character limit on common platforms.
|
|
@@ -245,10 +257,20 @@ def compile(src, target=None, options=None):
|
|
|
245
257
|
metadata_filename = f"{file_name}.json"
|
|
246
258
|
metadata_group = fn_cache_manager.get_group(metadata_filename) or {}
|
|
247
259
|
metadata_path = metadata_group.get(metadata_filename)
|
|
248
|
-
always_compile =
|
|
260
|
+
always_compile = knobs.compilation.always_compile
|
|
249
261
|
if not always_compile and metadata_path is not None:
|
|
250
262
|
# cache hit!
|
|
251
|
-
|
|
263
|
+
res = CompiledKernel(src, metadata_group, hash)
|
|
264
|
+
if compilation_listener:
|
|
265
|
+
compilation_listener(
|
|
266
|
+
src=src,
|
|
267
|
+
metadata=res.metadata._asdict(),
|
|
268
|
+
metadata_group=metadata_group,
|
|
269
|
+
times=timer.end(),
|
|
270
|
+
cache_hit=True,
|
|
271
|
+
)
|
|
272
|
+
return res
|
|
273
|
+
|
|
252
274
|
# initialize metadata
|
|
253
275
|
metadata = {
|
|
254
276
|
"hash": hash,
|
|
@@ -259,7 +281,7 @@ def compile(src, target=None, options=None):
|
|
|
259
281
|
metadata["triton_version"] = __version__
|
|
260
282
|
# run compilation pipeline and populate metadata
|
|
261
283
|
stages = dict()
|
|
262
|
-
backend.add_stages(stages, options)
|
|
284
|
+
backend.add_stages(stages, options, src.language)
|
|
263
285
|
first_stage = list(stages.keys()).index(src.ext)
|
|
264
286
|
# when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests.
|
|
265
287
|
if ir_source:
|
|
@@ -275,15 +297,34 @@ def compile(src, target=None, options=None):
|
|
|
275
297
|
codegen_fns = backend.get_codegen_implementation(options)
|
|
276
298
|
module_map = backend.get_module_map()
|
|
277
299
|
try:
|
|
278
|
-
module = src.make_ir(options, codegen_fns, module_map, context)
|
|
300
|
+
module = src.make_ir(target, options, codegen_fns, module_map, context)
|
|
279
301
|
except Exception as e:
|
|
280
302
|
filter_traceback(e)
|
|
281
303
|
raise
|
|
282
|
-
|
|
304
|
+
|
|
305
|
+
if ir_source:
|
|
306
|
+
ir_filename = f"{file_name}.{src.ext}"
|
|
307
|
+
metadata_group[ir_filename] = fn_cache_manager.put(module, ir_filename)
|
|
308
|
+
else:
|
|
309
|
+
ir_filename = f"{file_name}.source"
|
|
310
|
+
metadata_group[ir_filename] = fn_cache_manager.put(module, ir_filename)
|
|
311
|
+
|
|
312
|
+
use_ir_loc = knobs.compilation.use_ir_loc
|
|
313
|
+
if ir_source and use_ir_loc:
|
|
314
|
+
module.create_location_snapshot(src.path)
|
|
315
|
+
print(f"Creating new locations for {src.path}")
|
|
316
|
+
|
|
317
|
+
if compilation_listener:
|
|
318
|
+
timer.finished_ir_initialization()
|
|
283
319
|
for ext, compile_ir in list(stages.items())[first_stage:]:
|
|
284
320
|
next_module = compile_ir(module, metadata)
|
|
285
321
|
ir_filename = f"{file_name}.{ext}"
|
|
286
|
-
if
|
|
322
|
+
if fn_override_manager is None:
|
|
323
|
+
# Users can override kernels at scale by setting `ir_override` in autotune config
|
|
324
|
+
# without TRITON_KERNEL_OVERRIDE
|
|
325
|
+
if (ir_override := metadata.get("ir_override", None)) and ir_override.endswith(f".{ext}"):
|
|
326
|
+
next_module = parse(ir_override, ext, context)
|
|
327
|
+
elif full_name := fn_override_manager.get_file(ir_filename):
|
|
287
328
|
print(f"\nOverriding kernel with file {full_name}")
|
|
288
329
|
next_module = parse(full_name, ext, context)
|
|
289
330
|
# If TRITON_STORE_BINARY_ONLY is 1, only store cubin/hsaco/json
|
|
@@ -291,12 +332,17 @@ def compile(src, target=None, options=None):
|
|
|
291
332
|
metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)
|
|
292
333
|
if fn_dump_manager is not None:
|
|
293
334
|
fn_dump_manager.put(next_module, ir_filename)
|
|
335
|
+
if ext == "cubin":
|
|
336
|
+
sass = get_sass(next_module)
|
|
337
|
+
fn_dump_manager.put(sass, file_name + ".sass")
|
|
294
338
|
# use an env variable to parse ir from file
|
|
295
339
|
if use_ir_loc == ext:
|
|
296
340
|
ir_full_name = fn_cache_manager.get_file(ir_filename)
|
|
297
341
|
next_module.create_location_snapshot(ir_full_name)
|
|
298
342
|
print(f"Creating new locations for {ir_full_name}")
|
|
299
343
|
module = next_module
|
|
344
|
+
if compilation_listener:
|
|
345
|
+
timer.stage_finished(ext)
|
|
300
346
|
# write-back metadata
|
|
301
347
|
metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename,
|
|
302
348
|
binary=False)
|
|
@@ -310,13 +356,18 @@ def compile(src, target=None, options=None):
|
|
|
310
356
|
# this is likely due to the llvm-symbolizer forking a process
|
|
311
357
|
# TODO: Reconcile the difference here between the ASAN and non-ASAN path with enabling
|
|
312
358
|
# multithreading in the MLIR context
|
|
313
|
-
if not
|
|
359
|
+
if not knobs.compilation.enable_asan:
|
|
314
360
|
context.disable_multithreading()
|
|
361
|
+
|
|
362
|
+
# notify any listener
|
|
363
|
+
if compilation_listener:
|
|
364
|
+
compilation_listener(src=src, metadata=metadata, metadata_group=metadata_group, times=timer.end(),
|
|
365
|
+
cache_hit=False)
|
|
315
366
|
# return handle to compiled kernel
|
|
316
367
|
return CompiledKernel(src, metadata_group, hash)
|
|
317
368
|
|
|
318
369
|
|
|
319
|
-
def make_backend(target):
|
|
370
|
+
def make_backend(target: GPUTarget) -> BaseBackend:
|
|
320
371
|
actives = [x.compiler for x in backends.values() if x.compiler.supports_target(target)]
|
|
321
372
|
if len(actives) != 1:
|
|
322
373
|
raise RuntimeError(
|
|
@@ -330,7 +381,7 @@ class LazyDict:
|
|
|
330
381
|
self.data = data
|
|
331
382
|
self.extras = []
|
|
332
383
|
|
|
333
|
-
def get(self)
|
|
384
|
+
def get(self):
|
|
334
385
|
for func, args in self.extras:
|
|
335
386
|
self.data = self.data | func(*args)
|
|
336
387
|
self.extras.clear()
|
|
@@ -353,12 +404,11 @@ class AsmDict(dict):
|
|
|
353
404
|
return value
|
|
354
405
|
|
|
355
406
|
|
|
356
|
-
|
|
407
|
+
def _raise_error(err, *args, **kwargs):
|
|
408
|
+
raise copy.deepcopy(err)
|
|
357
409
|
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
launch_enter_hook = None
|
|
361
|
-
launch_exit_hook = None
|
|
410
|
+
|
|
411
|
+
class CompiledKernel:
|
|
362
412
|
|
|
363
413
|
def __init__(self, src, metadata_group, hash):
|
|
364
414
|
from collections import namedtuple
|
|
@@ -382,48 +432,66 @@ class CompiledKernel:
|
|
|
382
432
|
file.suffix[1:]: file.read_bytes() if file.suffix[1:] == binary_ext else file.read_text()
|
|
383
433
|
for file in asm_files
|
|
384
434
|
})
|
|
435
|
+
self.metadata_group = metadata_group
|
|
385
436
|
self.kernel = self.asm[binary_ext]
|
|
386
437
|
# binaries are lazily initialized
|
|
387
438
|
# because it involves doing runtime things
|
|
388
439
|
# (e.g., checking amount of shared memory on current device)
|
|
389
440
|
self.module = None
|
|
390
441
|
self.function = None
|
|
442
|
+
self._run = None
|
|
391
443
|
|
|
392
444
|
def _init_handles(self):
|
|
393
445
|
if self.module is not None:
|
|
394
446
|
return
|
|
447
|
+
|
|
448
|
+
def raise_(err):
|
|
449
|
+
# clone the exception object so that the one saved in the closure
|
|
450
|
+
# of the partial function below doesn't get assigned a stack trace
|
|
451
|
+
# after the subsequent raise. otherwise, the CompiledKernel instance
|
|
452
|
+
# saved in the (global) kernel cache will keep references to all the
|
|
453
|
+
# locals in the traceback via the exception instance in the closure.
|
|
454
|
+
cloned_err = copy.deepcopy(err)
|
|
455
|
+
self._run = functools.partial(_raise_error, cloned_err)
|
|
456
|
+
raise err
|
|
457
|
+
|
|
395
458
|
device = driver.active.get_current_device()
|
|
396
459
|
# create launcher
|
|
397
|
-
self.
|
|
460
|
+
self._run = driver.active.launcher_cls(self.src, self.metadata)
|
|
398
461
|
# not enough shared memory to run the kernel
|
|
399
|
-
max_shared =
|
|
462
|
+
max_shared = max_shared_mem(device)
|
|
400
463
|
if self.metadata.shared > max_shared:
|
|
401
|
-
|
|
464
|
+
raise_(OutOfResources(self.metadata.shared, max_shared, "shared memory"))
|
|
402
465
|
if hasattr(self.metadata, "tmem_size") and self.metadata.tmem_size is not None:
|
|
403
466
|
# Use blackwell max tmem size for now, this should be moved in device properties
|
|
404
467
|
max_tmem_size = 512 # tmem size in number of columns
|
|
405
468
|
if self.metadata.tmem_size > max_tmem_size:
|
|
406
|
-
|
|
469
|
+
raise_(OutOfResources(self.metadata.tmem_size, max_tmem_size, "tensor memory"))
|
|
470
|
+
if knobs.runtime.kernel_load_start_hook is not None:
|
|
471
|
+
knobs.runtime.kernel_load_start_hook(self.module, self.function, self.name, self.metadata_group, self.hash)
|
|
407
472
|
# TODO: n_regs, n_spills should be metadata generated when calling `ptxas`
|
|
408
|
-
self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary(
|
|
473
|
+
self.module, self.function, self.n_regs, self.n_spills, self.n_max_threads = driver.active.utils.load_binary(
|
|
409
474
|
self.name, self.kernel, self.metadata.shared, device)
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
475
|
+
warp_size = driver.active.get_current_target().warp_size
|
|
476
|
+
if self.metadata.num_warps * warp_size > self.n_max_threads:
|
|
477
|
+
raise_(OutOfResources(self.metadata.num_warps * warp_size, self.n_max_threads, "threads"))
|
|
478
|
+
if knobs.runtime.kernel_load_end_hook is not None:
|
|
479
|
+
knobs.runtime.kernel_load_end_hook(self.module, self.function, self.name, self.metadata_group, self.hash)
|
|
480
|
+
|
|
481
|
+
@property
|
|
482
|
+
def run(self):
|
|
483
|
+
if self._run is None:
|
|
413
484
|
self._init_handles()
|
|
414
|
-
return
|
|
485
|
+
return self._run
|
|
415
486
|
|
|
416
487
|
def launch_metadata(self, grid, stream, *args):
|
|
417
|
-
if
|
|
488
|
+
if knobs.runtime.launch_enter_hook is None:
|
|
418
489
|
return None
|
|
490
|
+
self._init_handles()
|
|
419
491
|
ret = LazyDict({"name": self.name, "function": self.function, "stream": stream})
|
|
420
492
|
if not isinstance(self.src, ASTSource) or self.src.fn.launch_metadata is None:
|
|
421
493
|
return ret
|
|
422
|
-
arg_dict = {}
|
|
423
|
-
arg_idx = 0
|
|
424
|
-
for i, arg_name in enumerate(self.src.fn.arg_names):
|
|
425
|
-
arg_dict[arg_name] = args[arg_idx]
|
|
426
|
-
arg_idx += 1
|
|
494
|
+
arg_dict = {name: arg for name, arg in zip(self.src.fn.arg_names, args)}
|
|
427
495
|
ret.add(self.src.fn.launch_metadata, (grid, self.metadata, arg_dict))
|
|
428
496
|
return ret
|
|
429
497
|
|
|
@@ -436,6 +504,6 @@ class CompiledKernel:
|
|
|
436
504
|
stream = driver.active.get_current_stream(device)
|
|
437
505
|
launch_metadata = self.launch_metadata(grid, stream, *args)
|
|
438
506
|
self.run(grid[0], grid[1], grid[2], stream, self.function, self.packed_metadata, launch_metadata,
|
|
439
|
-
|
|
507
|
+
knobs.runtime.launch_enter_hook, knobs.runtime.launch_exit_hook, *args)
|
|
440
508
|
|
|
441
509
|
return runner
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from triton.compiler.compiler import ASTSource
|
|
3
|
+
from triton.backends.compiler import Language
|
|
4
|
+
from triton.runtime.jit import JITFunction, constexpr_function
|
|
5
|
+
from typing import TypeVar, Optional, Callable, Iterable, Union
|
|
6
|
+
from triton._C.libtriton import ir
|
|
7
|
+
|
|
8
|
+
T = TypeVar("T")
|
|
9
|
+
|
|
10
|
+
__all__ = ["constexpr_function", "jit"]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class GluonASTSource(ASTSource):
|
|
14
|
+
|
|
15
|
+
def __init__(self, fn, signature, constexprs=None, attrs=None) -> None:
|
|
16
|
+
super().__init__(fn, signature, constexprs, attrs)
|
|
17
|
+
self.language = Language.GLUON
|
|
18
|
+
self.ext = "ttgir"
|
|
19
|
+
|
|
20
|
+
def make_ir(self, target, options, codegen_fns, module_map, context):
|
|
21
|
+
from triton.compiler.compiler import make_backend
|
|
22
|
+
from triton.compiler.code_generator import ast_to_ttir
|
|
23
|
+
|
|
24
|
+
builder = ir.builder(context)
|
|
25
|
+
module = builder.create_module()
|
|
26
|
+
|
|
27
|
+
# Assign module attributes eagerly, as they are needed to verify layouts
|
|
28
|
+
backend = make_backend(target)
|
|
29
|
+
target = backend.get_target_name(options)
|
|
30
|
+
|
|
31
|
+
module.set_attr("ttg.target", builder.get_string_attr(target))
|
|
32
|
+
module.set_attr("ttg.num-warps", builder.get_int32_attr(options.num_warps))
|
|
33
|
+
module.set_attr("ttg.num-ctas", builder.get_int32_attr(options.num_ctas))
|
|
34
|
+
module.set_attr("ttg.threads-per-warp", builder.get_int32_attr(options.warp_size))
|
|
35
|
+
|
|
36
|
+
is_cuda = options.backend_name == "cuda"
|
|
37
|
+
if is_cuda and options.maxnreg is not None:
|
|
38
|
+
module.set_attr("ttg.maxnreg", builder.get_int32_attr(options.maxnreg))
|
|
39
|
+
|
|
40
|
+
module = ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
|
|
41
|
+
module_map=module_map, module=module)
|
|
42
|
+
return module
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class GluonJITFunction(JITFunction[T]):
|
|
46
|
+
|
|
47
|
+
def create_binder(self):
|
|
48
|
+
result = super().create_binder()
|
|
49
|
+
self.ASTSource = GluonASTSource
|
|
50
|
+
return result
|
|
51
|
+
|
|
52
|
+
def is_gluon(self):
|
|
53
|
+
return True
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def jit(
|
|
57
|
+
fn: Optional[T] = None,
|
|
58
|
+
*,
|
|
59
|
+
version=None,
|
|
60
|
+
repr: Optional[Callable] = None,
|
|
61
|
+
launch_metadata: Optional[Callable] = None,
|
|
62
|
+
do_not_specialize: Optional[Iterable[int | str]] = None,
|
|
63
|
+
do_not_specialize_on_alignment: Optional[Iterable[int | str]] = None,
|
|
64
|
+
debug: Optional[bool] = None,
|
|
65
|
+
noinline: Optional[bool] = None,
|
|
66
|
+
) -> Union[GluonJITFunction[T], Callable[[T], JITFunction[T]]]:
|
|
67
|
+
"""
|
|
68
|
+
Decorator for JIT-compiling a function using the Triton compiler.
|
|
69
|
+
|
|
70
|
+
:note: When a jit'd function is called, arguments are
|
|
71
|
+
implicitly converted to pointers if they have a :code:`.data_ptr()` method
|
|
72
|
+
and a `.dtype` attribute.
|
|
73
|
+
|
|
74
|
+
:note: This function will be compiled and run on the GPU. It will only have access to:
|
|
75
|
+
|
|
76
|
+
* python primitives,
|
|
77
|
+
* builtins within the triton package,
|
|
78
|
+
* arguments to this function,
|
|
79
|
+
* other jit'd functions
|
|
80
|
+
|
|
81
|
+
:param fn: the function to be jit-compiled
|
|
82
|
+
:type fn: Callable
|
|
83
|
+
"""
|
|
84
|
+
|
|
85
|
+
def decorator(fn: T) -> JITFunction[T]:
|
|
86
|
+
assert callable(fn)
|
|
87
|
+
return GluonJITFunction(
|
|
88
|
+
fn,
|
|
89
|
+
version=version,
|
|
90
|
+
do_not_specialize=do_not_specialize,
|
|
91
|
+
do_not_specialize_on_alignment=do_not_specialize_on_alignment,
|
|
92
|
+
debug=debug,
|
|
93
|
+
noinline=noinline,
|
|
94
|
+
repr=repr,
|
|
95
|
+
launch_metadata=launch_metadata,
|
|
96
|
+
)
|
|
97
|
+
|
|
98
|
+
if fn is not None:
|
|
99
|
+
return decorator(fn)
|
|
100
|
+
|
|
101
|
+
else:
|
|
102
|
+
return decorator
|
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
from ._core import (
|
|
2
|
+
base_value,
|
|
3
|
+
base_type,
|
|
4
|
+
block_type,
|
|
5
|
+
broadcast,
|
|
6
|
+
constexpr,
|
|
7
|
+
dtype,
|
|
8
|
+
void,
|
|
9
|
+
int1,
|
|
10
|
+
int8,
|
|
11
|
+
int16,
|
|
12
|
+
int32,
|
|
13
|
+
int64,
|
|
14
|
+
uint8,
|
|
15
|
+
uint16,
|
|
16
|
+
uint32,
|
|
17
|
+
uint64,
|
|
18
|
+
float8e5,
|
|
19
|
+
float8e5b16,
|
|
20
|
+
float8e4nv,
|
|
21
|
+
float8e4b8,
|
|
22
|
+
float8e4b15,
|
|
23
|
+
float16,
|
|
24
|
+
bfloat16,
|
|
25
|
+
float32,
|
|
26
|
+
float64,
|
|
27
|
+
pointer_type,
|
|
28
|
+
shared_memory_descriptor,
|
|
29
|
+
tensor,
|
|
30
|
+
tuple,
|
|
31
|
+
tuple_type,
|
|
32
|
+
_unwrap_if_constexpr,
|
|
33
|
+
# API Functions
|
|
34
|
+
allocate_shared_memory,
|
|
35
|
+
arange,
|
|
36
|
+
associative_scan,
|
|
37
|
+
atomic_add,
|
|
38
|
+
atomic_and,
|
|
39
|
+
atomic_cas,
|
|
40
|
+
atomic_max,
|
|
41
|
+
atomic_min,
|
|
42
|
+
atomic_or,
|
|
43
|
+
atomic_xchg,
|
|
44
|
+
atomic_xor,
|
|
45
|
+
convert_layout,
|
|
46
|
+
device_assert,
|
|
47
|
+
expand_dims,
|
|
48
|
+
full,
|
|
49
|
+
histogram,
|
|
50
|
+
inline_asm_elementwise,
|
|
51
|
+
join,
|
|
52
|
+
load,
|
|
53
|
+
map_elementwise,
|
|
54
|
+
max_constancy,
|
|
55
|
+
max_contiguous,
|
|
56
|
+
maximum,
|
|
57
|
+
minimum,
|
|
58
|
+
multiple_of,
|
|
59
|
+
num_programs,
|
|
60
|
+
permute,
|
|
61
|
+
program_id,
|
|
62
|
+
reduce,
|
|
63
|
+
reshape,
|
|
64
|
+
set_auto_layout,
|
|
65
|
+
split,
|
|
66
|
+
static_assert,
|
|
67
|
+
static_print,
|
|
68
|
+
static_range,
|
|
69
|
+
store,
|
|
70
|
+
thread_barrier,
|
|
71
|
+
to_tensor,
|
|
72
|
+
warp_specialize,
|
|
73
|
+
where,
|
|
74
|
+
)
|
|
75
|
+
from ._layouts import (
|
|
76
|
+
AutoLayout,
|
|
77
|
+
BlockedLayout,
|
|
78
|
+
SliceLayout,
|
|
79
|
+
DistributedLinearLayout,
|
|
80
|
+
DotOperandLayout,
|
|
81
|
+
NVMMADistributedLayout,
|
|
82
|
+
NVMMASharedLayout,
|
|
83
|
+
SwizzledSharedLayout,
|
|
84
|
+
PaddedSharedLayout,
|
|
85
|
+
)
|
|
86
|
+
from ._math import (
|
|
87
|
+
umulhi,
|
|
88
|
+
exp,
|
|
89
|
+
exp2,
|
|
90
|
+
fma,
|
|
91
|
+
log,
|
|
92
|
+
log2,
|
|
93
|
+
cos,
|
|
94
|
+
rsqrt,
|
|
95
|
+
sin,
|
|
96
|
+
sqrt,
|
|
97
|
+
sqrt_rn,
|
|
98
|
+
abs,
|
|
99
|
+
fdiv,
|
|
100
|
+
div_rn,
|
|
101
|
+
erf,
|
|
102
|
+
floor,
|
|
103
|
+
ceil,
|
|
104
|
+
)
|
|
105
|
+
from ._standard import (
|
|
106
|
+
cdiv,
|
|
107
|
+
full_like,
|
|
108
|
+
max,
|
|
109
|
+
min,
|
|
110
|
+
reduce_or,
|
|
111
|
+
sum,
|
|
112
|
+
xor_sum,
|
|
113
|
+
zeros,
|
|
114
|
+
zeros_like,
|
|
115
|
+
)
|
|
116
|
+
|
|
117
|
+
from . import nvidia
|
|
118
|
+
from . import amd
|
|
119
|
+
from . import extra
|