triton-windows 3.5.1.post21__cp313-cp313-win_amd64.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +82 -0
- triton/_filecheck.py +97 -0
- triton/_internal_testing.py +255 -0
- triton/_utils.py +126 -0
- triton/backends/__init__.py +47 -0
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +461 -0
- triton/backends/amd/driver.c +283 -0
- triton/backends/amd/driver.py +724 -0
- triton/backends/amd/lib/asanrtl.bc +0 -0
- triton/backends/amd/lib/ockl.bc +0 -0
- triton/backends/amd/lib/ocml.bc +0 -0
- triton/backends/compiler.py +90 -0
- triton/backends/driver.py +66 -0
- triton/backends/nvidia/__init__.py +0 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +533 -0
- triton/backends/nvidia/driver.c +517 -0
- triton/backends/nvidia/driver.py +799 -0
- triton/backends/nvidia/include/cuda.h +26280 -0
- triton/backends/nvidia/lib/libdevice.10.bc +0 -0
- triton/backends/nvidia/lib/x64/cuda.lib +0 -0
- triton/compiler/__init__.py +7 -0
- triton/compiler/code_generator.py +1614 -0
- triton/compiler/compiler.py +509 -0
- triton/compiler/errors.py +51 -0
- triton/compiler/make_launcher.py +0 -0
- triton/errors.py +5 -0
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +5 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +102 -0
- triton/experimental/gluon/language/__init__.py +119 -0
- triton/experimental/gluon/language/_core.py +490 -0
- triton/experimental/gluon/language/_layouts.py +583 -0
- triton/experimental/gluon/language/_math.py +20 -0
- triton/experimental/gluon/language/_semantic.py +380 -0
- triton/experimental/gluon/language/_standard.py +80 -0
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +45 -0
- triton/knobs.py +546 -0
- triton/language/__init__.py +342 -0
- triton/language/core.py +3405 -0
- triton/language/extra/__init__.py +26 -0
- triton/language/extra/cuda/__init__.py +16 -0
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +1629 -0
- triton/language/extra/cuda/utils.py +109 -0
- triton/language/extra/hip/__init__.py +5 -0
- triton/language/extra/hip/libdevice.py +491 -0
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +790 -0
- triton/language/math.py +249 -0
- triton/language/random.py +218 -0
- triton/language/semantic.py +1939 -0
- triton/language/standard.py +534 -0
- triton/language/target_info.py +54 -0
- triton/runtime/__init__.py +23 -0
- triton/runtime/_allocation.py +44 -0
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +476 -0
- triton/runtime/build.py +168 -0
- triton/runtime/cache.py +317 -0
- triton/runtime/driver.py +38 -0
- triton/runtime/errors.py +36 -0
- triton/runtime/interpreter.py +1414 -0
- triton/runtime/jit.py +1107 -0
- triton/runtime/tcc/include/_mingw.h +168 -0
- triton/runtime/tcc/include/assert.h +62 -0
- triton/runtime/tcc/include/conio.h +409 -0
- triton/runtime/tcc/include/ctype.h +281 -0
- triton/runtime/tcc/include/dir.h +31 -0
- triton/runtime/tcc/include/direct.h +68 -0
- triton/runtime/tcc/include/dirent.h +135 -0
- triton/runtime/tcc/include/dos.h +55 -0
- triton/runtime/tcc/include/errno.h +75 -0
- triton/runtime/tcc/include/excpt.h +123 -0
- triton/runtime/tcc/include/fcntl.h +52 -0
- triton/runtime/tcc/include/fenv.h +108 -0
- triton/runtime/tcc/include/float.h +75 -0
- triton/runtime/tcc/include/inttypes.h +297 -0
- triton/runtime/tcc/include/io.h +418 -0
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +116 -0
- triton/runtime/tcc/include/locale.h +91 -0
- triton/runtime/tcc/include/malloc.h +181 -0
- triton/runtime/tcc/include/math.h +497 -0
- triton/runtime/tcc/include/mem.h +13 -0
- triton/runtime/tcc/include/memory.h +40 -0
- triton/runtime/tcc/include/process.h +176 -0
- triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
- triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
- triton/runtime/tcc/include/sec_api/io_s.h +33 -0
- triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
- triton/runtime/tcc/include/sec_api/search_s.h +25 -0
- triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
- triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
- triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
- triton/runtime/tcc/include/sec_api/string_s.h +41 -0
- triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
- triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
- triton/runtime/tcc/include/sec_api/time_s.h +61 -0
- triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
- triton/runtime/tcc/include/setjmp.h +160 -0
- triton/runtime/tcc/include/share.h +28 -0
- triton/runtime/tcc/include/signal.h +63 -0
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +14 -0
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stdbool.h +11 -0
- triton/runtime/tcc/include/stddef.h +42 -0
- triton/runtime/tcc/include/stdint.h +212 -0
- triton/runtime/tcc/include/stdio.h +429 -0
- triton/runtime/tcc/include/stdlib.h +591 -0
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/string.h +164 -0
- triton/runtime/tcc/include/sys/fcntl.h +13 -0
- triton/runtime/tcc/include/sys/file.h +14 -0
- triton/runtime/tcc/include/sys/locking.h +30 -0
- triton/runtime/tcc/include/sys/stat.h +290 -0
- triton/runtime/tcc/include/sys/time.h +69 -0
- triton/runtime/tcc/include/sys/timeb.h +133 -0
- triton/runtime/tcc/include/sys/types.h +123 -0
- triton/runtime/tcc/include/sys/unistd.h +14 -0
- triton/runtime/tcc/include/sys/utime.h +146 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +618 -0
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tcclib.h +80 -0
- triton/runtime/tcc/include/tchar.h +1102 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/time.h +287 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/vadefs.h +11 -0
- triton/runtime/tcc/include/values.h +4 -0
- triton/runtime/tcc/include/varargs.h +12 -0
- triton/runtime/tcc/include/wchar.h +873 -0
- triton/runtime/tcc/include/wctype.h +172 -0
- triton/runtime/tcc/include/winapi/basetsd.h +149 -0
- triton/runtime/tcc/include/winapi/basetyps.h +85 -0
- triton/runtime/tcc/include/winapi/guiddef.h +156 -0
- triton/runtime/tcc/include/winapi/poppack.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +2958 -0
- triton/runtime/tcc/include/winapi/wincon.h +309 -0
- triton/runtime/tcc/include/winapi/windef.h +293 -0
- triton/runtime/tcc/include/winapi/windows.h +127 -0
- triton/runtime/tcc/include/winapi/winerror.h +3166 -0
- triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +5837 -0
- triton/runtime/tcc/include/winapi/winreg.h +272 -0
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/winuser.h +5651 -0
- triton/runtime/tcc/include/winapi/winver.h +160 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/cuda.def +697 -0
- triton/runtime/tcc/lib/gdi32.def +337 -0
- triton/runtime/tcc/lib/kernel32.def +770 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/msvcrt.def +1399 -0
- triton/runtime/tcc/lib/python3.def +810 -0
- triton/runtime/tcc/lib/python310.def +1610 -0
- triton/runtime/tcc/lib/python311.def +1633 -0
- triton/runtime/tcc/lib/python312.def +1703 -0
- triton/runtime/tcc/lib/python313.def +1651 -0
- triton/runtime/tcc/lib/python313t.def +1656 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/lib/python39.def +1644 -0
- triton/runtime/tcc/lib/python3t.def +905 -0
- triton/runtime/tcc/lib/user32.def +658 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/testing.py +543 -0
- triton/tools/__init__.py +0 -0
- triton/tools/build_extern.py +365 -0
- triton/tools/compile.py +210 -0
- triton/tools/disasm.py +143 -0
- triton/tools/extra/cuda/compile.c +70 -0
- triton/tools/extra/cuda/compile.h +14 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/link.py +322 -0
- triton/tools/mxfp.py +301 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +34 -0
- triton/windows_utils.py +405 -0
- triton_windows-3.5.1.post21.dist-info/METADATA +46 -0
- triton_windows-3.5.1.post21.dist-info/RECORD +217 -0
- triton_windows-3.5.1.post21.dist-info/WHEEL +5 -0
- triton_windows-3.5.1.post21.dist-info/entry_points.txt +3 -0
- triton_windows-3.5.1.post21.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.5.1.post21.dist-info/top_level.txt +1 -0
triton/experimental/gluon/language/nvidia/hopper/tma.py ADDED
@@ -0,0 +1,97 @@
+from __future__ import annotations
+from typing import List, Tuple, TYPE_CHECKING
+from dataclasses import dataclass
+from triton.language.core import base_type, base_value
+import triton.experimental.gluon.language._core as ttgl
+from triton.experimental.gluon.language._layouts import NVMMASharedLayout
+from triton.experimental.gluon.language._core import builtin, _unwrap_if_constexpr
+
+if TYPE_CHECKING:
+    from triton._C import ir
+
+__all__ = ["async_copy_global_to_shared", "async_copy_shared_to_global", "store_wait"]
+
+
+@dataclass(eq=True)
+class tensor_descriptor_type(base_type):
+    block_type: ttgl.block_type
+    shape_type: ttgl.tuple_type
+    strides_type: ttgl.tuple_type
+    layout: NVMMASharedLayout
+
+    def __str__(self) -> str:
+        return f"tensor_descriptor<{self.block_type}, {self.layout}>"
+
+    def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor, int]:
+        handle = handles[cursor]
+        cursor += 1
+        shape, cursor = self.shape_type._unflatten_ir(handles, cursor)
+        strides, cursor = self.strides_type._unflatten_ir(handles, cursor)
+        value = tensor_descriptor(handle, shape, strides, self.block_type, layout=self.layout)
+        return value, cursor
+
+    def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
+        is_signed = self.block_type.element_ty.is_int_signed()
+        ty = builder.get_tensor_descriptor_layout_type(
+            self.block_type.to_ir(builder),
+            is_signed,
+            self.layout._to_ir(builder),
+        )
+        out.append(ty)
+        self.shape_type._flatten_ir_types(builder, out)
+        self.strides_type._flatten_ir_types(builder, out)
+
+    def mangle(self) -> str:
+        return f"TD{self.block_type.mangle()}_{self.layout.mangle()}TD"
+
+
+class tensor_descriptor(base_value):
+
+    def __init__(self, handle, shape: List[ttgl.tensor], strides: List[ttgl.tensor], block_type: ttgl.block_type,
+                 layout: NVMMASharedLayout):
+        self.handle = handle
+        self.shape = ttgl.tuple(shape)
+        self.strides = ttgl.tuple(strides)
+        self.type = tensor_descriptor_type(block_type, shape_type=self.shape.type, strides_type=self.strides.type,
+                                           layout=layout)
+
+    def _flatten_ir(self, handles: List[ir.value]) -> None:
+        handles.append(self.handle)
+        self.shape._flatten_ir(handles)
+        self.strides._flatten_ir(handles)
+
+    @property
+    def block_type(self):
+        return self.type.block_type
+
+    @property
+    def block_shape(self):
+        return self.type.block_type.shape
+
+    @property
+    def dtype(self):
+        return self.type.block_type.element_ty
+
+    @property
+    def layout(self):
+        return self.type.layout
+
+
+@builtin
+def async_copy_global_to_shared(tensor_desc, coord, barrier, result, pred=True, _semantic=None):
+    coord = _semantic._convert_to_ir_values(coord, require_i64=False)
+    pred = _semantic.to_tensor(pred)
+    _semantic.builder.create_async_tma_copy_global_to_local(tensor_desc.handle, coord, barrier.handle, result.handle,
+                                                            pred.handle)
+
+
+@builtin
+def async_copy_shared_to_global(tensor_desc, coord, src, _semantic=None):
+    coord = _semantic._convert_to_ir_values(coord, require_i64=False)
+    _semantic.builder.create_async_tma_copy_local_to_global(tensor_desc.handle, coord, src.handle)
+
+
+@builtin
+def store_wait(pendings, _semantic=None):
+    pendings = _unwrap_if_constexpr(pendings)
+    _semantic.builder.create_async_tma_store_wait(pendings)
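The `_flatten_ir`/`_unflatten_ir` pair above implements a cursor-based (de)serialization protocol: a compound value is flattened into one flat list of IR handles, and on the way back each type consumes exactly the handles it owns and returns the advanced cursor. A minimal standalone sketch of the idea, using toy integer handles rather than the Gluon API:

# Toy illustration of the cursor protocol used by tensor_descriptor_type.
# A descriptor flattens to [desc_handle, *shape_handles, *stride_handles].
from typing import List, Tuple

def flatten(desc_handle: int, shape: List[int], strides: List[int]) -> List[int]:
    # Mirrors tensor_descriptor._flatten_ir: the descriptor handle first,
    # then shape and stride handles appended in order.
    return [desc_handle, *shape, *strides]

def unflatten(handles: List[int], cursor: int, rank: int) -> Tuple[tuple, int]:
    # Mirrors tensor_descriptor_type._unflatten_ir: consume the descriptor
    # handle, then let each sub-type take its own slice, advancing the cursor.
    desc_handle = handles[cursor]
    cursor += 1
    shape = handles[cursor:cursor + rank]
    cursor += rank
    strides = handles[cursor:cursor + rank]
    cursor += rank
    return (desc_handle, shape, strides), cursor

flat = flatten(7, [128, 64], [64, 1])
value, end = unflatten(flat, 0, rank=2)
assert end == len(flat) and value == (7, [128, 64], [64, 1])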
triton/experimental/gluon/nvidia/hopper.py ADDED
@@ -0,0 +1,45 @@
+from dataclasses import dataclass
+from typing import List, Any
+from triton._utils import validate_block_shape, canonicalize_dtype, get_primitive_bitwidth
+from triton.experimental.gluon.language._layouts import NVMMASharedLayout
+
+__all__ = ["TensorDescriptor"]
+
+
+@dataclass
+class TensorDescriptor:
+    base: Any
+    shape: List[int]
+    strides: List[int]
+    block_shape: List[int]
+    layout: NVMMASharedLayout
+    padding: str = "zero"
+
+    def __post_init__(self):
+        rank = len(self.shape)
+        assert len(self.strides) == rank, f"rank mismatch: {self}"
+        assert len(self.block_shape) == rank, f"rank mismatch: {self}"
+        assert rank > 0, "rank must not be zero"
+        assert rank <= 5, "rank cannot be more than 5"
+        assert self.base.data_ptr() % 16 == 0, "base must be 16-byte aligned"
+        validate_block_shape(self.block_shape)
+        dtype_str = canonicalize_dtype(self.base.dtype)
+        elem_bytes = get_primitive_bitwidth(dtype_str) // 8
+        for stride in self.strides[:-1]:
+            assert (stride * elem_bytes) % 16 == 0, "strides must be 16-byte aligned"
+        assert self.strides[-1] == 1, "Last dimension must be contiguous"
+        assert isinstance(self.layout, NVMMASharedLayout), "Layout must be NVMMASharedLayout"
+        assert self.padding == "zero" or self.padding == "nan", "Illegal value for padding"
+        if self.padding == "nan":
+            assert self.base.dtype.is_floating_point, "Padding option `nan` is only supported for floating point tensors"
+
+    @staticmethod
+    def from_tensor(tensor: Any, block_shape: List[int], layout: NVMMASharedLayout, padding="zero"):
+        return TensorDescriptor(
+            tensor,
+            tensor.shape,
+            tensor.stride(),
+            block_shape,
+            layout,
+            padding,
+        )
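`TensorDescriptor.from_tensor` is the host-side convenience constructor, and `__post_init__` enforces the TMA constraints visible above (rank 1 to 5, 16-byte base and stride alignment, contiguous last dimension, an NVMMASharedLayout, zero/nan padding). A hedged usage sketch, assuming PyTorch; the NVMMASharedLayout arguments shown are illustrative guesses and should be checked against triton/experimental/gluon/language/_layouts.py:

# Hypothetical usage of TensorDescriptor.from_tensor as defined above.
import torch
from triton.experimental.gluon.language._layouts import NVMMASharedLayout
from triton.experimental.gluon.nvidia.hopper import TensorDescriptor

# Contiguous fp16 tensor on CUDA; torch allocations are 16-byte aligned.
x = torch.empty(1024, 512, dtype=torch.float16, device="cuda")
# Layout arguments are assumptions, not taken from this diff.
layout = NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2)
desc = TensorDescriptor.from_tensor(x, block_shape=[64, 64], layout=layout)
# __post_init__ has already validated rank, alignment, strides, and padding.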
triton/knobs.py ADDED
@@ -0,0 +1,546 @@
+from __future__ import annotations
+
+import functools
+import importlib
+import os
+import re
+import subprocess
+import sysconfig
+import warnings
+
+from dataclasses import dataclass
+from contextlib import contextmanager
+from typing import cast, Any, Callable, Generator, Generic, Optional, Protocol, Type, TypeVar, TypedDict, TYPE_CHECKING, Union
+
+from triton._C.libtriton import getenv, getenv_bool  # type: ignore
+
+if TYPE_CHECKING:
+    from .runtime.cache import CacheManager, RemoteCacheBackend
+    from .runtime.jit import JitFunctionInfo, KernelParam
+    from .compiler.compiler import ASTSource, LazyDict, IRSource
+
+
+class Env:
+    pass
+
+
+env = Env()
+
+propagate_env: bool = True
+
+
+def setenv(key: str, value: Optional[str]) -> None:
+    if not propagate_env:
+        return
+
+    if value is not None:
+        os.environ[key] = value
+    elif key in os.environ:
+        del os.environ[key]
+
+
+def toenv(val: Any) -> Union[None, tuple[Optional[str]]]:
+    if val is None:
+        return (None, )
+
+    t = type(val)
+    if t is bool:
+        return ("1" if val else "0", )
+
+    if t is str:
+        return (val, )
+
+    if t is int:
+        return (str(val), )
+
+    return None
+
+
+# There's an asymmetry here so that e.g. env_nvidia_tool can be specified with
+# a string but return an NvidiaTool.
+SetType = TypeVar("SetType")
+GetType = TypeVar("GetType")
+
+_NOTHING = object()
+
+
+class env_base(Generic[SetType, GetType]):
+
+    def __init__(self, key: str) -> None:
+        self.key = key
+
+    def __set_name__(self, objclass: Type[object], name: str) -> None:
+        self.name = name
+
+    def __get__(self, obj: Optional[object], objclass: Optional[Type[object]]) -> GetType:
+        py_val = obj.__dict__.get(self.name, _NOTHING)
+        if py_val is _NOTHING:
+            return self.get()
+        return self.transform(py_val)
+
+    def get(self) -> GetType:
+        raise NotImplementedError()
+
+    def __set__(self, obj: object, value: Union[SetType, Env]) -> None:
+        if isinstance(value, Env):
+            obj.__dict__.pop(self.name, None)
+        else:
+            obj.__dict__[self.name] = value
+        if env_val := toenv(value):
+            setenv(self.key, env_val[0])
+
+    def __delete__(self, obj: object) -> None:
+        obj.__dict__.pop(self.name, None)
+
+    def transform(self, val: SetType) -> GetType:
+        # See comment about GetType/SetType in their definition above. Only needed
+        # if GetType != SetType.
+        return cast(GetType, val)
+
+
+class env_str(env_base[str, str]):
+
+    def __init__(self, key: str, default: str):
+        super().__init__(key)
+        self.default = default
+
+    def get(self) -> str:
+        return getenv(self.key, self.default)
+
+
+class env_str_callable_default(env_base[str, str]):
+
+    def __init__(self, key: str, default_factory: Callable[[], str]):
+        super().__init__(key)
+        self.default_factory = default_factory
+
+    def get(self) -> str:
+        env_val = getenv(self.key)
+        if env_val is None:
+            return self.default_factory()
+        return env_val
+
+
+class env_bool(env_base[bool, bool]):
+
+    def __init__(self, key: str, default: bool = False) -> None:
+        super().__init__(key)
+        self.default = default
+
+    def get(self) -> bool:
+        return getenv_bool(self.key, self.default)
+
+
+class env_int(env_base[int, int]):
+
+    def __init__(self, key: str, default: int = 0) -> None:
+        super().__init__(key)
+        self.default = default
+
+    def get(self) -> int:
+        val = getenv(self.key)
+        if val is None:
+            return self.default
+        try:
+            return int(val)
+        except ValueError as exc:
+            raise RuntimeError(f"Unable to use {self.key}={val}: expected int") from exc
+
+
+ClassType = TypeVar("ClassType")
+
+
+class env_class(Generic[ClassType], env_base[Optional[Type[ClassType]], Optional[Type[ClassType]]]):
+
+    def __init__(self, key: str, type: str) -> None:
+        super().__init__(key)
+        # We can't pass the type directly to avoid import cycles
+        self.type = type
+
+    def get(self) -> Optional[Type[ClassType]]:
+        val = getenv(self.key)
+        if val is None:
+            return None
+        comps = val.split(":", 1)
+        if len(comps) != 2:
+            raise RuntimeError(f"Unable to read {self.key}: '{val}' isn't of the form MODULE:CLASS")
+        cls = getattr(importlib.import_module(comps[0]), comps[1])
+
+        if not any((c.__name__ == self.type for c in cls.mro())):
+            raise RuntimeError(f"Unable to use '{val}' from {self.key}: not of type '{self.type}'")
+
+        return cast(Type[ClassType], cls)
+
+
+@dataclass
+class NvidiaTool:
+    path: str
+    version: str
+
+    @staticmethod
+    @functools.lru_cache
+    def from_path(path: str) -> Optional[NvidiaTool]:
+        try:
+            result = subprocess.check_output([path, "--version"], stderr=subprocess.STDOUT)
+            version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE)
+            if version is None:
+                return None
+            return NvidiaTool(path, version.group(1))
+        except (subprocess.CalledProcessError, FileNotFoundError):
+            return None
+
+
+def find_nvidia_tool(binary: str) -> str:
+    path = os.path.join(
+        os.path.dirname(__file__),
+        "backends",
+        "nvidia",
+        "bin",
+        binary,
+    )
+    if os.access(path, os.X_OK):
+        return path
+
+    if os.name == "nt":
+        from triton.windows_utils import find_cuda
+        cuda_bin_path, _, _ = find_cuda()
+        if cuda_bin_path:
+            path = os.path.join(cuda_bin_path, binary)
+            if os.access(path, os.X_OK):
+                return path
+
+    warnings.warn(f"Failed to find {binary}")
+    return ""
+
+
+class env_nvidia_tool(env_base[str, NvidiaTool]):
+
+    def __init__(self, binary: str) -> None:
+        binary += sysconfig.get_config_var("EXE")
+        self.binary = binary
+        self.default_path = find_nvidia_tool(binary)
+        super().__init__(f"TRITON_{binary.upper()}_PATH")
+
+    def get(self) -> NvidiaTool:
+        return self.transform(getenv(self.key))
+
+    def transform(self, path: str) -> NvidiaTool:
+        # We still add default as fallback in case the pointed binary isn't
+        # accessible.
+        if path is not None:
+            paths = [path, self.default_path]
+        else:
+            paths = [self.default_path]
+
+        for path in paths:
+            if tool := NvidiaTool.from_path(path):
+                return tool
+
+        raise RuntimeError(f"Cannot find {self.binary}")
+
+
+# Separate classes so that types are correct
+class env_opt_str(env_base[Optional[str], Optional[str]]):
+
+    def get(self) -> Optional[str]:
+        return getenv(self.key)
+
+
+class env_opt_bool(env_base):
+
+    def get(self) -> Optional[str]:
+        return getenv_bool(self.key, None)
+
+
+@dataclass(frozen=True)
+class CompileTimes:
+    """
+    Model holding timing information for an invocation of the compiler.
+
+    All times in microseconds.
+    """
+
+    # Duration of make_ir
+    ir_initialization: int
+
+    # Ordered mapping from lowering stage to duration spent in that stage.
+    # Keyed by stage extension, e.g. ttir, ttgir
+    lowering_stages: list[tuple[str, int]]
+
+    # Duration of saving artifacts/metadata to cache
+    store_results: int
+
+    @property
+    def total_lowering(self) -> int:
+        return sum((stage[1] for stage in self.lowering_stages))
+
+    @property
+    def total(self) -> int:
+        return self.ir_initialization + self.total_lowering + self.store_results
+
+
+class CompilationListener(Protocol):
+
+    def __call__(self, *, src: Union[ASTSource, IRSource], metadata: dict[str, Any], metadata_group: dict[str, str],
+                 times: CompileTimes, cache_hit: bool) -> None:
+        ...
+
+
+knobs_type = TypeVar("knobs_type", bound='base_knobs')
+
+
+class base_knobs:
+
+    @property
+    def knob_descriptors(self) -> dict[str, env_base]:
+        return {
+            k: v
+            # data descriptors live on the class object
+            for k, v in type(self).__dict__.items()
+            if isinstance(v, env_base)
+        }
+
+    @property
+    def knobs(self) -> dict[str, Any]:
+        return {k: getattr(self, k) for k in self.knob_descriptors.keys()}
+
+    def copy(self: knobs_type) -> knobs_type:
+        res = type(self)()
+        res.__dict__.update(self.__dict__)
+        return res
+
+    def reset(self: knobs_type) -> knobs_type:
+        for knob in self.knob_descriptors.keys():
+            delattr(self, knob)
+        return self
+
+    @contextmanager
+    def scope(self) -> Generator[None, None, None]:
+        try:
+            initial_env = {knob.key: getenv(knob.key) for knob in self.knob_descriptors.values()}
+            orig = dict(self.__dict__)
+            yield
+        finally:
+            self.__dict__.clear()
+            self.__dict__.update(orig)
+
+            for k, v in initial_env.items():
+                if v is not None:
+                    os.environ[k] = v
+                elif k in os.environ:
+                    del os.environ[k]
+
+
+class BuildImpl(Protocol):
+
+    def __call__(self, name: str, src: str, srcdir: str, library_dirs: list[str], include_dirs: list[str],
+                 libraries: list[str], /) -> str:
+        ...
+
+
+class build_knobs(base_knobs):
+    """Configuration controlling how the native compiler is invoked"""
+    cc: env_opt_str = env_opt_str("CC")
+
+    cudacrt_path: env_opt_str = env_opt_str("TRITON_CUDACRT_PATH")
+    cudart_path: env_opt_str = env_opt_str("TRITON_CUDART_PATH")
+
+    impl: Optional[BuildImpl] = None
+
+    @property
+    def backend_dirs(self) -> set[str]:
+        return {path for path in (self.cudacrt_path, self.cudart_path) if path is not None}
+
+
+class redis_knobs(base_knobs):
+    key_format: env_str = env_str("TRITON_REDIS_KEY_FORMAT", "triton:{key}:{filename}")
+    host: env_str = env_str("TRITON_REDIS_HOST", "localhost")
+    port: env_int = env_int("TRITON_REDIS_PORT", 6379)
+
+
+cache: cache_knobs
+
+
+class cache_knobs(base_knobs):
+    home_dir: env_str = env_str("TRITON_HOME", os.path.expanduser("~/"))
+
+    dump_dir = env_str_callable_default("TRITON_DUMP_DIR", lambda: cache.get_triton_dir("dump"))
+    override_dir = env_str_callable_default("TRITON_OVERRIDE_DIR", lambda: cache.get_triton_dir("override"))
+    dir = env_str_callable_default("TRITON_CACHE_DIR", lambda: cache.get_triton_dir("cache"))
+
+    manager_class: env_class[CacheManager] = env_class("TRITON_CACHE_MANAGER", "CacheManager")
+    remote_manager_class: env_class[RemoteCacheBackend] = env_class("TRITON_REMOTE_CACHE_BACKEND", "RemoteCacheBackend")
+
+    def get_triton_dir(self, dirname: str) -> str:
+        return os.path.join(self.home_dir, ".triton", dirname)
+
+
+class compilation_knobs(base_knobs):
+    override: env_bool = env_bool("TRITON_KERNEL_OVERRIDE")
+    dump_ir: env_bool = env_bool("TRITON_KERNEL_DUMP")
+    store_binary_only: env_bool = env_bool("TRITON_STORE_BINARY_ONLY")
+    always_compile: env_bool = env_bool("TRITON_ALWAYS_COMPILE")
+    # TODO: Use enum to constrain / 'typecheck' the values
+    use_ir_loc: env_opt_str = env_opt_str("USE_IR_LOC")
+    enable_asan: env_bool = env_bool("TRITON_ENABLE_ASAN")
+    disable_line_info: env_bool = env_bool("TRITON_DISABLE_LINE_INFO")
+    front_end_debugging: env_bool = env_bool("TRITON_FRONT_END_DEBUGGING")
+    allow_non_constexpr_globals: env_bool = env_bool("TRITON_ALLOW_NON_CONSTEXPR_GLOBALS")
+    enable_experimental_consan: env_bool = env_bool("TRITON_ENABLE_EXPERIMENTAL_CONSAN")
+    listener: Union[CompilationListener, None] = None
+
+
+class autotuning_knobs(base_knobs):
+    cache: env_bool = env_bool("TRITON_CACHE_AUTOTUNING")
+    print: env_bool = env_bool("TRITON_PRINT_AUTOTUNING")
+
+
+class LaunchHook(Protocol):
+    """Hook invoked before and after kernel launching
+    """
+
+    def __call__(self, metadata: LazyDict) -> None:
+        ...
+
+
+class InitHandleHook(Protocol):
+    """Hook invoked around kernel binary/module loading.
+    module/function can be None for the *start* hook (before loading).
+    """
+
+    def __call__(
+        self,
+        module: Optional[object],
+        function: Optional[Callable],
+        name: str,
+        metadata_group: dict[str, str],
+        hash: str,
+    ) -> None:
+        ...
+
+
+F = TypeVar("F", bound=Callable)
+
+
+class HookChain(Generic[F]):
+    """A chain of hooks of the same type F to be called in order.
+    """
+
+    def __init__(self, reversed: bool = False):
+        self.calls: list[F] = []
+        self.reversed = reversed
+
+    def add(self, func: F) -> None:
+        if func not in self.calls:
+            self.calls.append(func)
+
+    def remove(self, func: F) -> None:
+        if func in self.calls:
+            self.calls.remove(func)
+
+    def __call__(self, *args, **kwargs):
+        for call in self.calls if not self.reversed else reversed(self.calls):
+            call(*args, **kwargs)
+
+
+# This is of the form [attr_name, attr_val]
+# TODO: Use tuple instead of list for better typing.
+KernelAttr = list[Union[str, int]]
+
+
+class JITHookCompileInfo(TypedDict):
+    key: str
+    signature: dict[KernelParam, str]
+    device: int
+    constants: None
+    num_warps: int
+    num_ctas: int
+    num_stages: int
+    enable_fp_fusion: bool
+    launch_cooperative_grid: bool
+    extern_libs: tuple[tuple[str, str], ...]
+    configs: list[dict[tuple[int, ...], list[KernelAttr]]]
+    specialization_data: str
+    is_warmup: bool
+
+
+class JITHook(Protocol):
+
+    def __call__(self, *, key: str, repr: str, fn: JitFunctionInfo, compile: JITHookCompileInfo, is_manual_warmup: bool,
+                 already_compiled: bool) -> Optional[bool]:
+        ...
+
+
+class runtime_knobs(base_knobs):
+    interpret: env_bool = env_bool("TRITON_INTERPRET")
+    # debug is on critical path for kernel launches
+    # avoid repeated reads from env-var by calling get directly
+    debug: bool = env_bool("TRITON_DEBUG").get()
+    override_arch: env_opt_str = env_opt_str("TRITON_OVERRIDE_ARCH")
+
+    launch_enter_hook: HookChain[LaunchHook] = HookChain()
+    launch_exit_hook: HookChain[LaunchHook] = HookChain(reversed=True)
+    kernel_load_start_hook: HookChain[InitHandleHook] = HookChain()
+    kernel_load_end_hook: HookChain[InitHandleHook] = HookChain(reversed=True)
+
+    # Hook for inspecting compiled functions and modules
+    jit_cache_hook: Optional[JITHook] = None
+    # Hook to signal that a kernel is done compiling and inspect compiled function.
+    # jit_cache_hook will always be called before compilation and jit_post_compile_hook after.
+    jit_post_compile_hook: Optional[JITHook] = None
+
+
+class language_knobs(base_knobs):
+    fp32_default: env_opt_str = env_opt_str("TRITON_F32_DEFAULT")
+    default_fp_fusion: env_bool = env_bool("TRITON_DEFAULT_FP_FUSION", True)
+
+
+class nvidia_knobs(base_knobs):
+    cuobjdump: env_nvidia_tool = env_nvidia_tool("cuobjdump")
+    nvdisasm: env_nvidia_tool = env_nvidia_tool("nvdisasm")
+    ptxas: env_nvidia_tool = env_nvidia_tool("ptxas")
+
+    dump_nvptx: env_bool = env_bool("NVPTX_ENABLE_DUMP")
+    disable_ptxas_opt: env_bool = env_bool("DISABLE_PTXAS_OPT")
+    mock_ptx_version: env_opt_str = env_opt_str("TRITON_MOCK_PTX_VERSION")
+    dump_ptxas_log: env_bool = env_bool("TRITON_DUMP_PTXAS_LOG")
+
+    libdevice_path: env_opt_str = env_opt_str("TRITON_LIBDEVICE_PATH")
+    libcuda_path: env_opt_str = env_opt_str("TRITON_LIBCUDA_PATH")
+
+
+class amd_knobs(base_knobs):
+    use_buffer_ops: env_bool = env_bool("AMDGCN_USE_BUFFER_OPS")
+    # Note: This requires use_buffer_ops be true to have any effect
+    use_buffer_atomics: env_bool = env_bool("AMDGCN_USE_BUFFER_ATOMICS", True)
+    dump_amdgcn: env_bool = env_bool("AMDGCN_ENABLE_DUMP")
+    libhip_path: env_opt_str = env_opt_str("TRITON_LIBHIP_PATH")
+
+    # We use strs so that we can have a default value based on other runtime info
+    use_block_pingpong: env_opt_bool = env_opt_bool("TRITON_HIP_USE_BLOCK_PINGPONG")
+    use_in_thread_transpose: env_opt_bool = env_opt_bool("TRITON_HIP_USE_IN_THREAD_TRANSPOSE")
+
+    global_prefetch: env_int = env_int("TRITON_HIP_GLOBAL_PREFETCH")
+    local_prefetch: env_int = env_int("TRITON_HIP_LOCAL_PREFETCH")
+    use_async_copy: env_bool = env_bool("TRITON_HIP_USE_ASYNC_COPY")
+    scalarize_packed_fops: env_bool = env_bool("AMDGCN_SCALARIZE_PACKED_FOPS")
+
+
+class proton_knobs(base_knobs):
+    cupti_dir: env_opt_str = env_opt_str("TRITON_CUPTI_LIB_PATH")
+
+
+build = build_knobs()
+redis = redis_knobs()
+cache = cache_knobs()
+compilation = compilation_knobs()
+autotuning = autotuning_knobs()
+runtime = runtime_knobs()
+language = language_knobs()
+nvidia = nvidia_knobs()
+amd = amd_knobs()
+proton = proton_knobs()
+
+
+def refresh_knobs():
+    runtime.debug = env_bool("TRITON_DEBUG").get()