triton-windows 3.5.1.post21__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +82 -0
- triton/_filecheck.py +97 -0
- triton/_internal_testing.py +255 -0
- triton/_utils.py +126 -0
- triton/backends/__init__.py +47 -0
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +461 -0
- triton/backends/amd/driver.c +283 -0
- triton/backends/amd/driver.py +724 -0
- triton/backends/amd/lib/asanrtl.bc +0 -0
- triton/backends/amd/lib/ockl.bc +0 -0
- triton/backends/amd/lib/ocml.bc +0 -0
- triton/backends/compiler.py +90 -0
- triton/backends/driver.py +66 -0
- triton/backends/nvidia/__init__.py +0 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +533 -0
- triton/backends/nvidia/driver.c +517 -0
- triton/backends/nvidia/driver.py +799 -0
- triton/backends/nvidia/include/cuda.h +26280 -0
- triton/backends/nvidia/lib/libdevice.10.bc +0 -0
- triton/backends/nvidia/lib/x64/cuda.lib +0 -0
- triton/compiler/__init__.py +7 -0
- triton/compiler/code_generator.py +1614 -0
- triton/compiler/compiler.py +509 -0
- triton/compiler/errors.py +51 -0
- triton/compiler/make_launcher.py +0 -0
- triton/errors.py +5 -0
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +5 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +102 -0
- triton/experimental/gluon/language/__init__.py +119 -0
- triton/experimental/gluon/language/_core.py +490 -0
- triton/experimental/gluon/language/_layouts.py +583 -0
- triton/experimental/gluon/language/_math.py +20 -0
- triton/experimental/gluon/language/_semantic.py +380 -0
- triton/experimental/gluon/language/_standard.py +80 -0
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +45 -0
- triton/knobs.py +546 -0
- triton/language/__init__.py +342 -0
- triton/language/core.py +3405 -0
- triton/language/extra/__init__.py +26 -0
- triton/language/extra/cuda/__init__.py +16 -0
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +1629 -0
- triton/language/extra/cuda/utils.py +109 -0
- triton/language/extra/hip/__init__.py +5 -0
- triton/language/extra/hip/libdevice.py +491 -0
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +790 -0
- triton/language/math.py +249 -0
- triton/language/random.py +218 -0
- triton/language/semantic.py +1939 -0
- triton/language/standard.py +534 -0
- triton/language/target_info.py +54 -0
- triton/runtime/__init__.py +23 -0
- triton/runtime/_allocation.py +44 -0
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +476 -0
- triton/runtime/build.py +168 -0
- triton/runtime/cache.py +317 -0
- triton/runtime/driver.py +38 -0
- triton/runtime/errors.py +36 -0
- triton/runtime/interpreter.py +1414 -0
- triton/runtime/jit.py +1107 -0
- triton/runtime/tcc/include/_mingw.h +168 -0
- triton/runtime/tcc/include/assert.h +62 -0
- triton/runtime/tcc/include/conio.h +409 -0
- triton/runtime/tcc/include/ctype.h +281 -0
- triton/runtime/tcc/include/dir.h +31 -0
- triton/runtime/tcc/include/direct.h +68 -0
- triton/runtime/tcc/include/dirent.h +135 -0
- triton/runtime/tcc/include/dos.h +55 -0
- triton/runtime/tcc/include/errno.h +75 -0
- triton/runtime/tcc/include/excpt.h +123 -0
- triton/runtime/tcc/include/fcntl.h +52 -0
- triton/runtime/tcc/include/fenv.h +108 -0
- triton/runtime/tcc/include/float.h +75 -0
- triton/runtime/tcc/include/inttypes.h +297 -0
- triton/runtime/tcc/include/io.h +418 -0
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +116 -0
- triton/runtime/tcc/include/locale.h +91 -0
- triton/runtime/tcc/include/malloc.h +181 -0
- triton/runtime/tcc/include/math.h +497 -0
- triton/runtime/tcc/include/mem.h +13 -0
- triton/runtime/tcc/include/memory.h +40 -0
- triton/runtime/tcc/include/process.h +176 -0
- triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
- triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
- triton/runtime/tcc/include/sec_api/io_s.h +33 -0
- triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
- triton/runtime/tcc/include/sec_api/search_s.h +25 -0
- triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
- triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
- triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
- triton/runtime/tcc/include/sec_api/string_s.h +41 -0
- triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
- triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
- triton/runtime/tcc/include/sec_api/time_s.h +61 -0
- triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
- triton/runtime/tcc/include/setjmp.h +160 -0
- triton/runtime/tcc/include/share.h +28 -0
- triton/runtime/tcc/include/signal.h +63 -0
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +14 -0
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stdbool.h +11 -0
- triton/runtime/tcc/include/stddef.h +42 -0
- triton/runtime/tcc/include/stdint.h +212 -0
- triton/runtime/tcc/include/stdio.h +429 -0
- triton/runtime/tcc/include/stdlib.h +591 -0
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/string.h +164 -0
- triton/runtime/tcc/include/sys/fcntl.h +13 -0
- triton/runtime/tcc/include/sys/file.h +14 -0
- triton/runtime/tcc/include/sys/locking.h +30 -0
- triton/runtime/tcc/include/sys/stat.h +290 -0
- triton/runtime/tcc/include/sys/time.h +69 -0
- triton/runtime/tcc/include/sys/timeb.h +133 -0
- triton/runtime/tcc/include/sys/types.h +123 -0
- triton/runtime/tcc/include/sys/unistd.h +14 -0
- triton/runtime/tcc/include/sys/utime.h +146 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +618 -0
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tcclib.h +80 -0
- triton/runtime/tcc/include/tchar.h +1102 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/time.h +287 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/vadefs.h +11 -0
- triton/runtime/tcc/include/values.h +4 -0
- triton/runtime/tcc/include/varargs.h +12 -0
- triton/runtime/tcc/include/wchar.h +873 -0
- triton/runtime/tcc/include/wctype.h +172 -0
- triton/runtime/tcc/include/winapi/basetsd.h +149 -0
- triton/runtime/tcc/include/winapi/basetyps.h +85 -0
- triton/runtime/tcc/include/winapi/guiddef.h +156 -0
- triton/runtime/tcc/include/winapi/poppack.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +2958 -0
- triton/runtime/tcc/include/winapi/wincon.h +309 -0
- triton/runtime/tcc/include/winapi/windef.h +293 -0
- triton/runtime/tcc/include/winapi/windows.h +127 -0
- triton/runtime/tcc/include/winapi/winerror.h +3166 -0
- triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +5837 -0
- triton/runtime/tcc/include/winapi/winreg.h +272 -0
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/winuser.h +5651 -0
- triton/runtime/tcc/include/winapi/winver.h +160 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/cuda.def +697 -0
- triton/runtime/tcc/lib/gdi32.def +337 -0
- triton/runtime/tcc/lib/kernel32.def +770 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/msvcrt.def +1399 -0
- triton/runtime/tcc/lib/python3.def +810 -0
- triton/runtime/tcc/lib/python310.def +1610 -0
- triton/runtime/tcc/lib/python311.def +1633 -0
- triton/runtime/tcc/lib/python312.def +1703 -0
- triton/runtime/tcc/lib/python313.def +1651 -0
- triton/runtime/tcc/lib/python313t.def +1656 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/lib/python39.def +1644 -0
- triton/runtime/tcc/lib/python3t.def +905 -0
- triton/runtime/tcc/lib/user32.def +658 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/testing.py +543 -0
- triton/tools/__init__.py +0 -0
- triton/tools/build_extern.py +365 -0
- triton/tools/compile.py +210 -0
- triton/tools/disasm.py +143 -0
- triton/tools/extra/cuda/compile.c +70 -0
- triton/tools/extra/cuda/compile.h +14 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/link.py +322 -0
- triton/tools/mxfp.py +301 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +34 -0
- triton/windows_utils.py +405 -0
- triton_windows-3.5.1.post21.dist-info/METADATA +46 -0
- triton_windows-3.5.1.post21.dist-info/RECORD +217 -0
- triton_windows-3.5.1.post21.dist-info/WHEEL +5 -0
- triton_windows-3.5.1.post21.dist-info/entry_points.txt +3 -0
- triton_windows-3.5.1.post21.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.5.1.post21.dist-info/top_level.txt +1 -0
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
from abc import ABCMeta, abstractmethod
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from enum import Enum
|
|
4
|
+
from typing import Dict, Union
|
|
5
|
+
from types import ModuleType
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass(frozen=True)
class GPUTarget:
    """Immutable description of the device a kernel is compiled for.

    Being a frozen dataclass, instances compare by field values, are
    hashable, and reject attribute assignment after construction.
    """

    backend: str  # target backend, e.g. "cuda" or "hip"
    arch: Union[int, str]  # e.g. 90 (CUDA compute capability) or "gfx940" (HIP)
    warp_size: int  # warp size reported for this target
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Language(Enum):
    """Identifies the input language a backend is asked to compile.

    Member values are plain integers.
    """

    TRITON = 0  # Triton source
    GLUON = 1  # Gluon source
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class BaseBackend(metaclass=ABCMeta):
    """Abstract interface implemented by each Triton compiler backend.

    A backend is bound to a single ``GPUTarget`` at construction time and
    provides option parsing, the ordered compilation stages, MLIR dialect
    loading, and the map from interface modules to device-specific
    implementations.
    """

    def __init__(self, target: GPUTarget) -> None:
        """Bind this backend to ``target``.

        Raises:
            AssertionError: if ``self.supports_target(target)`` is falsy.
        """
        self.target = target
        # An explicit check instead of a bare ``assert`` keeps the validation
        # active under ``python -O``; AssertionError is preserved so existing
        # callers observe the same exception type.
        if not self.supports_target(target):
            raise AssertionError(f"{type(self).__name__} does not support target {target!r}")

    @staticmethod
    @abstractmethod
    def supports_target(target: GPUTarget):
        """Return whether this backend can compile for ``target``."""
        raise NotImplementedError

    @abstractmethod
    def hash(self) -> str:
        """Returns a unique identifier for this backend"""
        raise NotImplementedError

    @abstractmethod
    def parse_options(self, options: dict) -> object:
        """
        Converts an `options` dictionary into an arbitrary object and returns it.
        This function may contain target-specific heuristics and check the legality of the provided options
        """
        raise NotImplementedError

    @abstractmethod
    def add_stages(self, stages: dict, options: object) -> None:
        """
        Populates `stages` dictionary with entries of the form:
        ir_name [str] => Function[(src: str, metadata: dict) -> str|bytes]
        The value of each entry may populate a `metadata` dictionary.
        Stages will be run sequentially (in insertion order) and can communicate using `metadata`.
        All stages are expected to return a `str` object, except for the last stage which returns
        a `bytes` object for execution by the launcher.
        """
        raise NotImplementedError

    @abstractmethod
    def load_dialects(self, context):
        """
        Load additional MLIR dialects into the provided `context`
        """
        raise NotImplementedError

    @abstractmethod
    def get_module_map(self) -> Dict[str, ModuleType]:
        """
        Return a map of interface modules to their device-specific implementations
        """
        raise NotImplementedError

    @staticmethod
    def parse_attr(desc):
        """Translate a specialization descriptor string into attribute pairs.

        Args:
            desc: descriptor string, e.g. the "D" marker returned by
                ``get_arg_specialization``.

        Returns:
            A list of ``[attribute_name, value]`` pairs. It contains
            ``["tt.divisibility", 16]`` when ``desc`` contains "D", and is
            empty otherwise.

        Raises:
            AssertionError: if ``desc`` is not a ``str`` (only when asserts
                are enabled).
        """
        assert isinstance(desc, str)
        ret = []
        if "D" in desc:
            ret += [["tt.divisibility", 16]]
        return ret

    @staticmethod
    def get_arg_specialization(arg, ty, **kwargs):
        """
        Return a string unique to each possible specialization of the argument

        Returns "D" when the ``align`` keyword is truthy and either ``ty`` is
        "int" and ``arg`` is divisible by 16, or ``ty`` is "tensor" and
        ``arg.data_ptr()`` is divisible by 16. Otherwise returns "".
        """
        # Evaluation order is kept as-is: the divisibility test runs before the
        # ``align`` lookup, so ``arg.data_ptr()`` is still called for tensors
        # even when alignment specialization is disabled.
        if ty == "int" and arg % 16 == 0 and kwargs.get("align", False):
            return "D"
        if ty == "tensor" and arg.data_ptr() % 16 == 0 and kwargs.get("align", False):
            return "D"
        return ""
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
from abc import ABCMeta, abstractmethod
|
|
2
|
+
from typing import Callable, List, Protocol, Sequence
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class Benchmarker(Protocol):
    """Structural type for a kernel benchmarking callable.

    Any callable that accepts ``kernel_call`` plus a keyword-only
    ``quantiles`` list (and arbitrary extra keyword arguments), and is
    annotated to return a sequence of floats, matches this protocol.
    """

    def __call__(self, kernel_call: Callable, *, quantiles: List[float], **kwargs) -> Sequence[float]:
        ...
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class DriverBase(metaclass=ABCMeta):
    """Abstract interface for a Triton runtime driver.

    A driver reports whether it is active and exposes the current target,
    the active torch device, its default benchmarking function, and a
    Triton-to-C++ type mapping.
    """

    @classmethod
    @abstractmethod
    def is_active(cls):
        """Report whether this driver is active (presumably: usable in the current environment)."""
        pass

    @abstractmethod
    def map_python_to_cpp_type(self, ty: str) -> str:
        """
        Converts a Triton type string to its corresponding C++ type string for this backend.

        Args:
            ty (str): The Triton type string. e.g., 'i32', '*fp16', 'fp32'.

        Returns:
            str: The C++ type string.
        """
        pass

    @abstractmethod
    def get_current_target(self):
        """Return the current target (presumably a ``GPUTarget`` — confirm in implementations)."""
        pass

    @abstractmethod
    def get_active_torch_device(self):
        """Return the torch device this driver currently targets."""
        pass

    @abstractmethod
    def get_benchmarker(self) -> Benchmarker:
        """
        Return the benchmarking function that this backend should use by default.
        """
        raise NotImplementedError

    def __init__(self) -> None:
        pass
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class GPUDriver(DriverBase):
    """Partial ``DriverBase`` implementation backed by ``torch.cuda``.

    It binds device-capability, stream and current-device helpers from torch.
    The abstract methods declared on ``DriverBase`` (``is_active``,
    ``get_current_target`` and so on) are not implemented here and must be
    provided by subclasses.
    """

    def __init__(self):
        """Bind device and stream helpers from the installed ``torch`` package.

        Raises:
            ImportError: if ``torch`` cannot be imported.
        """
        # Cooperative initialisation; DriverBase.__init__ currently does nothing,
        # but this keeps the MRO chain intact for future base-class state.
        super().__init__()
        # TODO: support other frameworks than torch
        import torch
        self.get_device_capability = torch.cuda.get_device_capability
        try:
            # Prefer the private torch._C accessor when this torch build provides it.
            from torch._C import _cuda_getCurrentRawStream
            self.get_current_stream = _cuda_getCurrentRawStream
        except ImportError:
            # Fall back to the public API, returning the stream's ``cuda_stream`` attribute.
            self.get_current_stream = lambda idx: torch.cuda.current_stream(idx).cuda_stream
        self.get_current_device = torch.cuda.current_device
        self.set_current_device = torch.cuda.set_device

    # TODO: remove once TMA is cleaned up
    def assemble_tensormap_to_arg(self, tensormaps_info, args):
        """Return ``args`` unchanged; ``tensormaps_info`` is ignored."""
        return args
|
|
File without changes
|
|
Binary file
|