triton-windows 3.5.1.post21__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +82 -0
- triton/_filecheck.py +97 -0
- triton/_internal_testing.py +255 -0
- triton/_utils.py +126 -0
- triton/backends/__init__.py +47 -0
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +461 -0
- triton/backends/amd/driver.c +283 -0
- triton/backends/amd/driver.py +724 -0
- triton/backends/amd/lib/asanrtl.bc +0 -0
- triton/backends/amd/lib/ockl.bc +0 -0
- triton/backends/amd/lib/ocml.bc +0 -0
- triton/backends/compiler.py +90 -0
- triton/backends/driver.py +66 -0
- triton/backends/nvidia/__init__.py +0 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +533 -0
- triton/backends/nvidia/driver.c +517 -0
- triton/backends/nvidia/driver.py +799 -0
- triton/backends/nvidia/include/cuda.h +26280 -0
- triton/backends/nvidia/lib/libdevice.10.bc +0 -0
- triton/backends/nvidia/lib/x64/cuda.lib +0 -0
- triton/compiler/__init__.py +7 -0
- triton/compiler/code_generator.py +1614 -0
- triton/compiler/compiler.py +509 -0
- triton/compiler/errors.py +51 -0
- triton/compiler/make_launcher.py +0 -0
- triton/errors.py +5 -0
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +5 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +102 -0
- triton/experimental/gluon/language/__init__.py +119 -0
- triton/experimental/gluon/language/_core.py +490 -0
- triton/experimental/gluon/language/_layouts.py +583 -0
- triton/experimental/gluon/language/_math.py +20 -0
- triton/experimental/gluon/language/_semantic.py +380 -0
- triton/experimental/gluon/language/_standard.py +80 -0
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +45 -0
- triton/knobs.py +546 -0
- triton/language/__init__.py +342 -0
- triton/language/core.py +3405 -0
- triton/language/extra/__init__.py +26 -0
- triton/language/extra/cuda/__init__.py +16 -0
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +1629 -0
- triton/language/extra/cuda/utils.py +109 -0
- triton/language/extra/hip/__init__.py +5 -0
- triton/language/extra/hip/libdevice.py +491 -0
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +790 -0
- triton/language/math.py +249 -0
- triton/language/random.py +218 -0
- triton/language/semantic.py +1939 -0
- triton/language/standard.py +534 -0
- triton/language/target_info.py +54 -0
- triton/runtime/__init__.py +23 -0
- triton/runtime/_allocation.py +44 -0
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +476 -0
- triton/runtime/build.py +168 -0
- triton/runtime/cache.py +317 -0
- triton/runtime/driver.py +38 -0
- triton/runtime/errors.py +36 -0
- triton/runtime/interpreter.py +1414 -0
- triton/runtime/jit.py +1107 -0
- triton/runtime/tcc/include/_mingw.h +168 -0
- triton/runtime/tcc/include/assert.h +62 -0
- triton/runtime/tcc/include/conio.h +409 -0
- triton/runtime/tcc/include/ctype.h +281 -0
- triton/runtime/tcc/include/dir.h +31 -0
- triton/runtime/tcc/include/direct.h +68 -0
- triton/runtime/tcc/include/dirent.h +135 -0
- triton/runtime/tcc/include/dos.h +55 -0
- triton/runtime/tcc/include/errno.h +75 -0
- triton/runtime/tcc/include/excpt.h +123 -0
- triton/runtime/tcc/include/fcntl.h +52 -0
- triton/runtime/tcc/include/fenv.h +108 -0
- triton/runtime/tcc/include/float.h +75 -0
- triton/runtime/tcc/include/inttypes.h +297 -0
- triton/runtime/tcc/include/io.h +418 -0
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +116 -0
- triton/runtime/tcc/include/locale.h +91 -0
- triton/runtime/tcc/include/malloc.h +181 -0
- triton/runtime/tcc/include/math.h +497 -0
- triton/runtime/tcc/include/mem.h +13 -0
- triton/runtime/tcc/include/memory.h +40 -0
- triton/runtime/tcc/include/process.h +176 -0
- triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
- triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
- triton/runtime/tcc/include/sec_api/io_s.h +33 -0
- triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
- triton/runtime/tcc/include/sec_api/search_s.h +25 -0
- triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
- triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
- triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
- triton/runtime/tcc/include/sec_api/string_s.h +41 -0
- triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
- triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
- triton/runtime/tcc/include/sec_api/time_s.h +61 -0
- triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
- triton/runtime/tcc/include/setjmp.h +160 -0
- triton/runtime/tcc/include/share.h +28 -0
- triton/runtime/tcc/include/signal.h +63 -0
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +14 -0
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stdbool.h +11 -0
- triton/runtime/tcc/include/stddef.h +42 -0
- triton/runtime/tcc/include/stdint.h +212 -0
- triton/runtime/tcc/include/stdio.h +429 -0
- triton/runtime/tcc/include/stdlib.h +591 -0
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/string.h +164 -0
- triton/runtime/tcc/include/sys/fcntl.h +13 -0
- triton/runtime/tcc/include/sys/file.h +14 -0
- triton/runtime/tcc/include/sys/locking.h +30 -0
- triton/runtime/tcc/include/sys/stat.h +290 -0
- triton/runtime/tcc/include/sys/time.h +69 -0
- triton/runtime/tcc/include/sys/timeb.h +133 -0
- triton/runtime/tcc/include/sys/types.h +123 -0
- triton/runtime/tcc/include/sys/unistd.h +14 -0
- triton/runtime/tcc/include/sys/utime.h +146 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +618 -0
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tcclib.h +80 -0
- triton/runtime/tcc/include/tchar.h +1102 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/time.h +287 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/vadefs.h +11 -0
- triton/runtime/tcc/include/values.h +4 -0
- triton/runtime/tcc/include/varargs.h +12 -0
- triton/runtime/tcc/include/wchar.h +873 -0
- triton/runtime/tcc/include/wctype.h +172 -0
- triton/runtime/tcc/include/winapi/basetsd.h +149 -0
- triton/runtime/tcc/include/winapi/basetyps.h +85 -0
- triton/runtime/tcc/include/winapi/guiddef.h +156 -0
- triton/runtime/tcc/include/winapi/poppack.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +2958 -0
- triton/runtime/tcc/include/winapi/wincon.h +309 -0
- triton/runtime/tcc/include/winapi/windef.h +293 -0
- triton/runtime/tcc/include/winapi/windows.h +127 -0
- triton/runtime/tcc/include/winapi/winerror.h +3166 -0
- triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +5837 -0
- triton/runtime/tcc/include/winapi/winreg.h +272 -0
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/winuser.h +5651 -0
- triton/runtime/tcc/include/winapi/winver.h +160 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/cuda.def +697 -0
- triton/runtime/tcc/lib/gdi32.def +337 -0
- triton/runtime/tcc/lib/kernel32.def +770 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/msvcrt.def +1399 -0
- triton/runtime/tcc/lib/python3.def +810 -0
- triton/runtime/tcc/lib/python310.def +1610 -0
- triton/runtime/tcc/lib/python311.def +1633 -0
- triton/runtime/tcc/lib/python312.def +1703 -0
- triton/runtime/tcc/lib/python313.def +1651 -0
- triton/runtime/tcc/lib/python313t.def +1656 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/lib/python39.def +1644 -0
- triton/runtime/tcc/lib/python3t.def +905 -0
- triton/runtime/tcc/lib/user32.def +658 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/testing.py +543 -0
- triton/tools/__init__.py +0 -0
- triton/tools/build_extern.py +365 -0
- triton/tools/compile.py +210 -0
- triton/tools/disasm.py +143 -0
- triton/tools/extra/cuda/compile.c +70 -0
- triton/tools/extra/cuda/compile.h +14 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/link.py +322 -0
- triton/tools/mxfp.py +301 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +34 -0
- triton/windows_utils.py +405 -0
- triton_windows-3.5.1.post21.dist-info/METADATA +46 -0
- triton_windows-3.5.1.post21.dist-info/RECORD +217 -0
- triton_windows-3.5.1.post21.dist-info/WHEEL +5 -0
- triton_windows-3.5.1.post21.dist-info/entry_points.txt +3 -0
- triton_windows-3.5.1.post21.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.5.1.post21.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
from ._core import (
|
|
2
|
+
base_value,
|
|
3
|
+
base_type,
|
|
4
|
+
block_type,
|
|
5
|
+
broadcast,
|
|
6
|
+
constexpr,
|
|
7
|
+
dtype,
|
|
8
|
+
void,
|
|
9
|
+
int1,
|
|
10
|
+
int8,
|
|
11
|
+
int16,
|
|
12
|
+
int32,
|
|
13
|
+
int64,
|
|
14
|
+
uint8,
|
|
15
|
+
uint16,
|
|
16
|
+
uint32,
|
|
17
|
+
uint64,
|
|
18
|
+
float8e5,
|
|
19
|
+
float8e5b16,
|
|
20
|
+
float8e4nv,
|
|
21
|
+
float8e4b8,
|
|
22
|
+
float8e4b15,
|
|
23
|
+
float16,
|
|
24
|
+
bfloat16,
|
|
25
|
+
float32,
|
|
26
|
+
float64,
|
|
27
|
+
pointer_type,
|
|
28
|
+
shared_memory_descriptor,
|
|
29
|
+
tensor,
|
|
30
|
+
tuple,
|
|
31
|
+
tuple_type,
|
|
32
|
+
_unwrap_if_constexpr,
|
|
33
|
+
# API Functions
|
|
34
|
+
allocate_shared_memory,
|
|
35
|
+
arange,
|
|
36
|
+
associative_scan,
|
|
37
|
+
atomic_add,
|
|
38
|
+
atomic_and,
|
|
39
|
+
atomic_cas,
|
|
40
|
+
atomic_max,
|
|
41
|
+
atomic_min,
|
|
42
|
+
atomic_or,
|
|
43
|
+
atomic_xchg,
|
|
44
|
+
atomic_xor,
|
|
45
|
+
convert_layout,
|
|
46
|
+
device_assert,
|
|
47
|
+
expand_dims,
|
|
48
|
+
full,
|
|
49
|
+
histogram,
|
|
50
|
+
inline_asm_elementwise,
|
|
51
|
+
join,
|
|
52
|
+
load,
|
|
53
|
+
map_elementwise,
|
|
54
|
+
max_constancy,
|
|
55
|
+
max_contiguous,
|
|
56
|
+
maximum,
|
|
57
|
+
minimum,
|
|
58
|
+
multiple_of,
|
|
59
|
+
num_programs,
|
|
60
|
+
permute,
|
|
61
|
+
program_id,
|
|
62
|
+
reduce,
|
|
63
|
+
reshape,
|
|
64
|
+
set_auto_layout,
|
|
65
|
+
split,
|
|
66
|
+
static_assert,
|
|
67
|
+
static_print,
|
|
68
|
+
static_range,
|
|
69
|
+
store,
|
|
70
|
+
thread_barrier,
|
|
71
|
+
to_tensor,
|
|
72
|
+
warp_specialize,
|
|
73
|
+
where,
|
|
74
|
+
)
|
|
75
|
+
from ._layouts import (
|
|
76
|
+
AutoLayout,
|
|
77
|
+
BlockedLayout,
|
|
78
|
+
SliceLayout,
|
|
79
|
+
DistributedLinearLayout,
|
|
80
|
+
DotOperandLayout,
|
|
81
|
+
NVMMADistributedLayout,
|
|
82
|
+
NVMMASharedLayout,
|
|
83
|
+
SwizzledSharedLayout,
|
|
84
|
+
PaddedSharedLayout,
|
|
85
|
+
)
|
|
86
|
+
from ._math import (
|
|
87
|
+
umulhi,
|
|
88
|
+
exp,
|
|
89
|
+
exp2,
|
|
90
|
+
fma,
|
|
91
|
+
log,
|
|
92
|
+
log2,
|
|
93
|
+
cos,
|
|
94
|
+
rsqrt,
|
|
95
|
+
sin,
|
|
96
|
+
sqrt,
|
|
97
|
+
sqrt_rn,
|
|
98
|
+
abs,
|
|
99
|
+
fdiv,
|
|
100
|
+
div_rn,
|
|
101
|
+
erf,
|
|
102
|
+
floor,
|
|
103
|
+
ceil,
|
|
104
|
+
)
|
|
105
|
+
from ._standard import (
|
|
106
|
+
cdiv,
|
|
107
|
+
full_like,
|
|
108
|
+
max,
|
|
109
|
+
min,
|
|
110
|
+
reduce_or,
|
|
111
|
+
sum,
|
|
112
|
+
xor_sum,
|
|
113
|
+
zeros,
|
|
114
|
+
zeros_like,
|
|
115
|
+
)
|
|
116
|
+
|
|
117
|
+
from . import nvidia
|
|
118
|
+
from . import amd
|
|
119
|
+
from . import extra
|
|
@@ -0,0 +1,490 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
import math
|
|
3
|
+
from typing import TypeVar, List, TYPE_CHECKING, Tuple
|
|
4
|
+
from functools import wraps
|
|
5
|
+
|
|
6
|
+
if TYPE_CHECKING:
|
|
7
|
+
from triton._C.libtriton.gluon_ir import GluonOpBuilder
|
|
8
|
+
from ._semantic import GluonSemantic
|
|
9
|
+
|
|
10
|
+
from ._layouts import SharedLayout, DistributedLayout
|
|
11
|
+
from triton._C.libtriton import ir
|
|
12
|
+
import triton.language.core as tl_core
|
|
13
|
+
from triton.language.core import (
|
|
14
|
+
constexpr,
|
|
15
|
+
base_value,
|
|
16
|
+
base_type,
|
|
17
|
+
dtype,
|
|
18
|
+
block_type, # TODO: block type with layout info
|
|
19
|
+
pointer_type,
|
|
20
|
+
void,
|
|
21
|
+
int1,
|
|
22
|
+
int8,
|
|
23
|
+
int16,
|
|
24
|
+
int32,
|
|
25
|
+
int64,
|
|
26
|
+
uint8,
|
|
27
|
+
uint16,
|
|
28
|
+
uint32,
|
|
29
|
+
uint64,
|
|
30
|
+
float8e5,
|
|
31
|
+
float8e5b16,
|
|
32
|
+
float8e4nv,
|
|
33
|
+
float8e4b8,
|
|
34
|
+
float8e4b15,
|
|
35
|
+
float16,
|
|
36
|
+
bfloat16,
|
|
37
|
+
float32,
|
|
38
|
+
float64,
|
|
39
|
+
_unwrap_if_constexpr,
|
|
40
|
+
_unwrap_shape,
|
|
41
|
+
static_range,
|
|
42
|
+
tensor,
|
|
43
|
+
tuple,
|
|
44
|
+
tuple_type,
|
|
45
|
+
)
|
|
46
|
+
|
|
47
|
+
# We define __all__ only to appease the Python linter; these names are not used in
|
|
48
|
+
# this file but we want to import them anyway so they are importable from here.
|
|
49
|
+
__all__ = [
|
|
50
|
+
"constexpr",
|
|
51
|
+
"pointer_type",
|
|
52
|
+
"void",
|
|
53
|
+
"int1",
|
|
54
|
+
"int8",
|
|
55
|
+
"int16",
|
|
56
|
+
"int32",
|
|
57
|
+
"int64",
|
|
58
|
+
"uint8",
|
|
59
|
+
"uint16",
|
|
60
|
+
"uint32",
|
|
61
|
+
"uint64",
|
|
62
|
+
"float8e5",
|
|
63
|
+
"float8e5b16",
|
|
64
|
+
"float8e4nv",
|
|
65
|
+
"float8e4b8",
|
|
66
|
+
"float8e4b15",
|
|
67
|
+
"float16",
|
|
68
|
+
"bfloat16",
|
|
69
|
+
"float32",
|
|
70
|
+
"float64",
|
|
71
|
+
"static_range",
|
|
72
|
+
"tuple",
|
|
73
|
+
"tuple_type",
|
|
74
|
+
]
|
|
75
|
+
|
|
76
|
+
T = TypeVar("T")
|
|
77
|
+
|
|
78
|
+
# TODO: split these
|
|
79
|
+
GLUON_BUILTIN = "__triton_builtin__"


def builtin(fn: T) -> T:
    """Wrap ``fn`` so it can only be called with a non-None ``_semantic``.

    The returned wrapper forwards every positional and keyword argument to
    ``fn`` unchanged, and is tagged with the ``GLUON_BUILTIN`` attribute set
    to ``True``.

    Raises:
        ValueError: When the wrapper is called without a ``_semantic``
            keyword argument, or with ``_semantic=None``.
    """
    assert callable(fn)

    @wraps(fn)
    def wrapper(*args, **kwargs):
        # A missing key and an explicit None are rejected the same way.
        if kwargs.get("_semantic") is None:
            raise ValueError("Did you forget to add @triton.gluon.jit ? "
                             "(`_semantic` argument must be provided outside of JIT functions.)")
        return fn(*args, **kwargs)

    setattr(wrapper, GLUON_BUILTIN, True)
    return wrapper
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
# Explicitly import forwarded Triton language symbols so mypy sees them.
|
|
99
|
+
associative_scan = builtin(tl_core.associative_scan)
|
|
100
|
+
atomic_add = builtin(tl_core.atomic_add)
|
|
101
|
+
atomic_and = builtin(tl_core.atomic_and)
|
|
102
|
+
atomic_cas = builtin(tl_core.atomic_cas)
|
|
103
|
+
atomic_max = builtin(tl_core.atomic_max)
|
|
104
|
+
atomic_min = builtin(tl_core.atomic_min)
|
|
105
|
+
atomic_or = builtin(tl_core.atomic_or)
|
|
106
|
+
atomic_xchg = builtin(tl_core.atomic_xchg)
|
|
107
|
+
atomic_xor = builtin(tl_core.atomic_xor)
|
|
108
|
+
broadcast = builtin(tl_core.broadcast)
|
|
109
|
+
device_assert = builtin(tl_core.device_assert)
|
|
110
|
+
expand_dims = builtin(tl_core.expand_dims)
|
|
111
|
+
inline_asm_elementwise = builtin(tl_core.inline_asm_elementwise)
|
|
112
|
+
join = builtin(tl_core.join)
|
|
113
|
+
load = builtin(tl_core.load)
|
|
114
|
+
map_elementwise = builtin(tl_core.map_elementwise)
|
|
115
|
+
max_constancy = builtin(tl_core.max_constancy)
|
|
116
|
+
max_contiguous = builtin(tl_core.max_contiguous)
|
|
117
|
+
maximum = builtin(tl_core.maximum)
|
|
118
|
+
minimum = builtin(tl_core.minimum)
|
|
119
|
+
multiple_of = builtin(tl_core.multiple_of)
|
|
120
|
+
num_programs = builtin(tl_core.num_programs)
|
|
121
|
+
permute = builtin(tl_core.permute)
|
|
122
|
+
program_id = builtin(tl_core.program_id)
|
|
123
|
+
reduce = builtin(tl_core.reduce)
|
|
124
|
+
reshape = builtin(tl_core.reshape)
|
|
125
|
+
split = builtin(tl_core.split)
|
|
126
|
+
static_assert = builtin(tl_core.static_assert)
|
|
127
|
+
static_print = builtin(tl_core.static_print)
|
|
128
|
+
store = builtin(tl_core.store)
|
|
129
|
+
to_tensor = builtin(tl_core.to_tensor)
|
|
130
|
+
where = builtin(tl_core.where)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
class distributed_type(block_type):
    """Block type that additionally carries a distributed layout."""

    def __init__(self, element_ty: dtype, shape: List[int], layout):
        """Initialise from an element type, a shape and a layout.

        ``layout`` must be a ``DistributedLayout``; this is asserted after the
        attributes have been assigned.
        """
        super().__init__(element_ty, shape)
        self.layout = layout
        # ``self.shape`` / ``self.element_ty`` are read back from the instance
        # (presumably populated by ``block_type.__init__`` -- TODO confirm).
        self.name = f"<{self.shape}, {self.element_ty}, {self.layout}>"
        assert isinstance(layout, DistributedLayout)

    def to_ir(self, builder: ir.builder) -> ir.type:
        """Return the result of ``builder.get_distributed_ty`` for this type."""
        ir_element = self.element_ty.to_ir(builder)
        ir_layout = self.layout._to_ir(builder)
        return builder.get_distributed_ty(ir_element, self.shape, ir_layout)

    def mangle(self) -> str:
        """Return a name built from the mangled scalar type, the shape and the mangled layout."""
        scalar_tag = self.scalar.mangle()
        dims_tag = "_".join(str(extent) for extent in self.shape)
        layout_tag = self.layout.mangle()
        return f"{scalar_tag}S{dims_tag}SL{layout_tag}L"

    def with_element_ty(self, scalar_ty: dtype) -> block_type:
        """Return a new ``distributed_type`` with the same shape and layout but ``scalar_ty`` elements."""
        return distributed_type(element_ty=scalar_ty, shape=self.shape, layout=self.layout)

    def __eq__(self, other) -> bool:
        """Equal when ``other`` is a ``distributed_type``, the base comparison holds and the layouts match."""
        return (isinstance(other, distributed_type) and super().__eq__(other) and self.layout == other.layout)
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
class shared_memory_descriptor_type(base_type):
    """Type of a shared memory descriptor.

    Holds the element type, the shape, the shared layout and the allocation
    shape that together identify the descriptor type.
    """

    def __init__(self, element_ty, shape, layout, alloc_shape):
        """Store the four components; ``layout`` must be a ``SharedLayout``."""
        self.element_ty = element_ty
        self.shape = shape
        self.layout = layout
        self.alloc_shape = alloc_shape
        assert isinstance(layout, SharedLayout)

    def to_ir(self, builder: GluonOpBuilder) -> ir.type:
        """Return the result of ``builder.get_shared_mem_desc_ty`` for this type."""
        return builder.get_shared_mem_desc_ty(
            self.element_ty.to_ir(builder),
            self.shape,
            self.layout._to_ir(builder),
            self.alloc_shape,
        )

    def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[shared_memory_descriptor, int]:
        """Wrap ``handles[cursor]`` in a descriptor of this type and return it with the advanced cursor."""
        value = shared_memory_descriptor(handles[cursor], self.element_ty, self.shape, self.layout, self.alloc_shape)
        return value, cursor + 1

    def _flatten_ir_types(self, builder: GluonOpBuilder, out: List[ir.type]) -> None:
        """Append this type's IR type to ``out``."""
        out.append(self.to_ir(builder))

    def __str__(self) -> str:
        return f"shared_memory_descriptor<{self.element_ty}, {self.shape}, {self.layout}, {self.alloc_shape}>"

    def __eq__(self, other) -> bool:
        """Compare every component of the type, including the element type.

        The element type takes part in ``__str__`` and ``mangle``, so two types
        that differ only in element type must not compare equal.
        """
        return (type(self) is type(other) and self.element_ty == other.element_ty and self.shape == other.shape
                and self.layout == other.layout and self.alloc_shape == other.alloc_shape)

    def __ne__(self, other) -> bool:
        """Inverse of ``__eq__``; ``__ne__`` is the hook Python uses for ``!=``."""
        return not (self == other)

    # Backward-compatible alias for the previous, misspelled method name.
    __neq__ = __ne__

    def mangle(self) -> str:
        """Return a name built from the mangled element type, shape, mangled layout and allocation shape."""
        shape_str = "_".join(str(s) for s in self.shape)
        return f"MD{self.element_ty.mangle()}S{shape_str}SL{self.layout.mangle()}LAS{self.alloc_shape}ASMD"
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
class shared_memory_descriptor(base_value):
    """
    Value wrapping a handle to a shared memory allocation in Gluon IR.
    """

    def __init__(self, handle, element_ty, shape, layout, alloc_shape):
        """Keep ``handle`` and build the matching ``shared_memory_descriptor_type``."""
        self.handle = handle
        self.type = shared_memory_descriptor_type(element_ty, shape, layout, alloc_shape)

    def _flatten_ir(self, handles: List[ir.value]) -> None:
        """Append this descriptor's handle to ``handles``."""
        handles.append(self.handle)

    @property
    def dtype(self):
        """Element type recorded on the descriptor type."""
        return self.type.element_ty

    @property
    def shape(self):
        """Shape recorded on the descriptor type."""
        return self.type.shape

    @property
    def rank(self):
        """Number of dimensions, i.e. the length of ``shape``."""
        return len(self.shape)

    @property
    def numel(self) -> int:
        """Product of all entries of ``shape``."""
        return math.prod(self.shape)

    @property
    def layout(self):
        """Layout recorded on the descriptor type."""
        return self.type.layout

    def __str__(self) -> str:
        return str(self.type)

    @builtin
    def load(self, layout, _semantic: GluonSemantic = None) -> tensor:
        """
        Read the contents of this shared memory into a tensor.

        Args:
            layout (DistributedLayout): Layout requested for the resulting tensor.

        Returns:
            tensor: The value produced by ``_semantic.shared_load``.
        """
        return _semantic.shared_load(self, _unwrap_if_constexpr(layout))

    @builtin
    def store(self, value, _semantic: GluonSemantic = None) -> None:
        """
        Write a tensor into this shared memory.

        Args:
            value (tensor): The tensor to write.
        """
        return _semantic.shared_store(self, value)

    @builtin
    def slice(self, start, length, dim=0, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
        """
        Take a subview of this shared memory along one dimension.

        Args:
            start (int): First index of the subview.
            length (int): Extent of the subview.
            dim (int): Dimension along which to slice (default: 0).

        Returns:
            shared_memory_descriptor: The value produced by ``_semantic.memdesc_slice``.
        """
        return _semantic.memdesc_slice(
            self,
            _unwrap_if_constexpr(start),
            _unwrap_if_constexpr(length),
            _unwrap_if_constexpr(dim),
        )

    @builtin
    def index(self, index, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
        """
        Take a subview of this shared memory by indexing the first dimension.

        Args:
            index (int): Position at which to take the subview.

        Returns:
            shared_memory_descriptor: The value produced by ``_semantic.memdesc_index``.
        """
        return _semantic.memdesc_index(self, _unwrap_if_constexpr(index))

    @builtin
    def permute(self, order, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
        """
        Reorder the dimensions of this shared memory descriptor.

        Args:
            order (List[int]): New ordering of the dimensions.

        Returns:
            shared_memory_descriptor: The value produced by ``_semantic.memdesc_trans``.
        """
        new_order = list(map(_unwrap_if_constexpr, order))
        return _semantic.memdesc_trans(self, new_order)

    @builtin
    def reshape(self, shape, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
        """
        View this shared memory descriptor with a different shape.

        Args:
            shape (List[int]): The requested shape.

        Returns:
            shared_memory_descriptor: The value produced by ``_semantic.memdesc_reshape``.
        """
        new_shape = list(map(_unwrap_if_constexpr, shape))
        return _semantic.memdesc_reshape(self, new_shape)

    @builtin
    def _reinterpret(self, dtype, shape, layout, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
        """
        View this shared memory descriptor with a different dtype, shape, or layout.

        Args:
            dtype (dtype): The requested data type.
            shape (List[int]): The requested shape.
            layout (SharedLayout): The requested layout.

        Returns:
            shared_memory_descriptor: The value produced by ``_semantic.memdesc_reinterpret``.
        """
        return _semantic.memdesc_reinterpret(
            self,
            _unwrap_if_constexpr(dtype),
            [_unwrap_if_constexpr(extent) for extent in shape],
            _unwrap_if_constexpr(layout),
        )

    @builtin
    def _keep_alive(self, _semantic: GluonSemantic = None) -> None:
        """
        Emit a dummy use of the descriptor (via ``_semantic.shared_dealloc``) to keep it alive.
        """
        return _semantic.shared_dealloc(self)
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
@builtin
def arange(start, end, layout=None, _semantic=None):
    """
    Build a sequence tensor covering [start, end) with the given layout.

    Args:
        start (int): Inclusive lower bound of the sequence.
        end (int): Exclusive upper bound of the sequence.
        layout (DistributedLayout): Layout of the result. When omitted, ``None`` is
            forwarded to ``_semantic.arange`` (presumably selecting AutoLayout there).

    Returns:
        tensor: The value produced by ``_semantic.arange``.
    """
    return _semantic.arange(
        _unwrap_if_constexpr(start),
        _unwrap_if_constexpr(end),
        _unwrap_if_constexpr(layout),
    )
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
@builtin
def convert_layout(value, layout, assert_trivial=False, _semantic=None):
    """
    Change the distributed layout of a tensor.

    Args:
        value (tensor): Tensor to convert; forwarded unchanged.
        layout (DistributedLayout): Layout to convert to.
        assert_trivial (bool): Forwarded unchanged; when True the conversion is
            asserted to be trivial (no data movement).

    Returns:
        tensor: The value produced by ``_semantic.convert_layout``.
    """
    target_layout = _unwrap_if_constexpr(layout)
    return _semantic.convert_layout(value, target_layout, assert_trivial)
|
|
381
|
+
|
|
382
|
+
|
|
383
|
+
@builtin
def full(shape, value, dtype, layout=None, _semantic=None):
    """
    Build a tensor of the given shape, dtype and layout filled with one scalar.

    Args:
        shape (Sequence[int]): Shape of the result.
        value (int or float): Scalar used for every element.
        dtype (dtype): Element data type of the result.
        layout (Optional[DistributedLayout]): Layout of the result. When omitted,
            ``None`` is forwarded to ``_semantic.full`` (presumably selecting
            AutoLayout() there).

    Returns:
        tensor: The value produced by ``_semantic.full``.
    """
    return _semantic.full(
        _unwrap_shape(shape),
        _unwrap_if_constexpr(value),
        _unwrap_if_constexpr(dtype),
        _unwrap_if_constexpr(layout),
    )
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
@builtin
def histogram(input, num_bins, mask=None, layout=None, _semantic=None, _generator=None):
    """
    Build a histogram from a 1D integer tensor.

    Args:
        input (tensor): 1D tensor of integer values; forwarded unchanged.
        num_bins (int): Number of bins. Bins have width 1 and start at 0.
        mask (Optional[tensor]): Boolean mask; elements are excluded where it is False.
            A non-None mask is first converted with ``_semantic.to_tensor``.
        layout (DistributedLayout): Layout requested for the resulting histogram.

    Returns:
        tensor: The value produced by ``_semantic.histogram``.

    Note:
        ``_generator`` is accepted but not used in this function.
    """
    bin_count = _unwrap_if_constexpr(num_bins)
    out_layout = _unwrap_if_constexpr(layout)
    mask_tensor = None if mask is None else _semantic.to_tensor(mask)
    return _semantic.histogram(input, bin_count, mask_tensor, out_layout)
|
|
423
|
+
|
|
424
|
+
|
|
425
|
+
@builtin
def allocate_shared_memory(element_ty, shape, layout, value=None, _semantic=None) -> shared_memory_descriptor:
    """
    Allocate shared memory with the given element type, shape and layout.

    Args:
        element_ty (dtype): Element data type of the allocation.
        shape (Sequence[int]): Dimensions of the allocation.
        layout (SharedLayout): Shared memory layout of the allocation.
        value (tensor, optional): Initial contents; forwarded unchanged.

    Returns:
        shared_memory_descriptor: The value produced by ``_semantic.allocate_shared``.
    """
    elem = _unwrap_if_constexpr(element_ty)
    # The shape may be wrapped as a whole and may also hold wrapped entries,
    # so unwrap the container first and then each extent.
    dims = [_unwrap_if_constexpr(extent) for extent in _unwrap_if_constexpr(shape)]
    shared_layout = _unwrap_if_constexpr(layout)
    return _semantic.allocate_shared(elem, dims, shared_layout, value)
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
@builtin
def set_auto_layout(value, layout, _semantic=None):
    """
    Give a tensor that has AutoLayout a concrete layout.

    Args:
        value (tensor): Tensor to update; forwarded unchanged.
        layout (DistributedLayout): Concrete layout to apply.

    Returns:
        tensor: The value produced by ``_semantic.set_auto_layout``.
    """
    concrete_layout = _unwrap_if_constexpr(layout)
    return _semantic.set_auto_layout(value, concrete_layout)
|
|
460
|
+
|
|
461
|
+
|
|
462
|
+
@builtin
def warp_specialize(default_args, default_partition, worker_args, worker_partitions, worker_num_warps, worker_num_regs,
                    _semantic=None, _generator=None):
    """
    Build a warp-specialized execution region that splits work across warps.

    Args:
        default_args (List[Any]): Arguments passed to the default region.
        default_partition (callable): Function that builds the default region.
        worker_args (List[Any]): Arguments passed to each warp partition.
        worker_partitions (List[callable]): One function per warp partition.
        worker_num_warps (List[int]): Warp count for each partition.
        worker_num_regs (List[int]): Register count for each partition.

    Returns:
        Tuple[Any, ...]: The value produced by ``_semantic.warp_specialize``
        (the results of the default region).
    """
    warps_per_worker = list(map(_unwrap_if_constexpr, worker_num_warps))
    regs_per_worker = list(map(_unwrap_if_constexpr, worker_num_regs))
    return _semantic.warp_specialize(
        default_args,
        default_partition,
        worker_args,
        worker_partitions,
        warps_per_worker,
        regs_per_worker,
        _generator,
    )
|
|
483
|
+
|
|
484
|
+
|
|
485
|
+
@builtin
def thread_barrier(_semantic=None):
    """
    Emit a barrier that synchronizes the threads of a CTA.

    Returns whatever ``_semantic.debug_barrier`` returns.
    """
    barrier = _semantic.debug_barrier()
    return barrier
|