triton-windows 3.3.0.post19__cp312-cp312-win_amd64.whl → 3.4.0.post20__cp312-cp312-win_amd64.whl
This diff reflects the changes between publicly released versions of the package as they appear in their respective public registries; it is provided for informational purposes only.
Potentially problematic release: this version of triton-windows might be problematic; see the full report for details.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +4 -1
- triton/_filecheck.py +87 -0
- triton/_internal_testing.py +26 -15
- triton/_utils.py +110 -21
- triton/backends/__init__.py +20 -23
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +112 -78
- triton/backends/amd/driver.c +5 -2
- triton/backends/amd/driver.py +149 -47
- triton/backends/compiler.py +7 -21
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +92 -93
- triton/backends/nvidia/driver.c +90 -98
- triton/backends/nvidia/driver.py +303 -128
- triton/compiler/code_generator.py +212 -111
- triton/compiler/compiler.py +110 -25
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +4 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +99 -0
- triton/experimental/gluon/language/__init__.py +18 -0
- triton/experimental/gluon/language/_core.py +312 -0
- triton/experimental/gluon/language/_layouts.py +230 -0
- triton/experimental/gluon/language/_math.py +12 -0
- triton/experimental/gluon/language/_semantic.py +287 -0
- triton/experimental/gluon/language/_standard.py +47 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +202 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +32 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +11 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +51 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +96 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +40 -0
- triton/knobs.py +481 -0
- triton/language/__init__.py +39 -14
- triton/language/core.py +794 -537
- triton/language/extra/cuda/__init__.py +10 -7
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +394 -394
- triton/language/extra/cuda/utils.py +21 -21
- triton/language/extra/hip/libdevice.py +113 -104
- triton/language/math.py +65 -66
- triton/language/random.py +12 -2
- triton/language/semantic.py +1706 -1770
- triton/language/standard.py +116 -51
- triton/runtime/autotuner.py +117 -59
- triton/runtime/build.py +76 -12
- triton/runtime/cache.py +18 -47
- triton/runtime/driver.py +32 -29
- triton/runtime/interpreter.py +72 -35
- triton/runtime/jit.py +146 -110
- triton/runtime/tcc/lib/python310.def +1610 -0
- triton/runtime/tcc/lib/python311.def +1633 -0
- triton/runtime/tcc/lib/python312.def +1703 -0
- triton/runtime/tcc/lib/python313.def +1651 -0
- triton/runtime/tcc/lib/python313t.def +1656 -0
- triton/runtime/tcc/lib/python39.def +1644 -0
- triton/runtime/tcc/lib/python3t.def +905 -0
- triton/testing.py +16 -12
- triton/tools/disasm.py +3 -4
- triton/tools/tensor_descriptor.py +36 -0
- triton/windows_utils.py +14 -6
- {triton_windows-3.3.0.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/METADATA +7 -2
- triton_windows-3.4.0.post20.dist-info/RECORD +186 -0
- {triton_windows-3.3.0.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/WHEEL +1 -1
- triton_windows-3.4.0.post20.dist-info/entry_points.txt +3 -0
- triton_windows-3.4.0.post20.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.4.0.post20.dist-info/top_level.txt +1 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
- triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
- triton/backends/amd/include/hip/channel_descriptor.h +0 -39
- triton/backends/amd/include/hip/device_functions.h +0 -38
- triton/backends/amd/include/hip/driver_types.h +0 -468
- triton/backends/amd/include/hip/hip_bf16.h +0 -36
- triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
- triton/backends/amd/include/hip/hip_common.h +0 -100
- triton/backends/amd/include/hip/hip_complex.h +0 -38
- triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
- triton/backends/amd/include/hip/hip_deprecated.h +0 -95
- triton/backends/amd/include/hip/hip_ext.h +0 -161
- triton/backends/amd/include/hip/hip_fp16.h +0 -36
- triton/backends/amd/include/hip/hip_fp8.h +0 -33
- triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
- triton/backends/amd/include/hip/hip_hcc.h +0 -24
- triton/backends/amd/include/hip/hip_math_constants.h +0 -36
- triton/backends/amd/include/hip/hip_profile.h +0 -27
- triton/backends/amd/include/hip/hip_runtime.h +0 -75
- triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
- triton/backends/amd/include/hip/hip_texture_types.h +0 -29
- triton/backends/amd/include/hip/hip_vector_types.h +0 -41
- triton/backends/amd/include/hip/hip_version.h +0 -17
- triton/backends/amd/include/hip/hiprtc.h +0 -421
- triton/backends/amd/include/hip/library_types.h +0 -78
- triton/backends/amd/include/hip/math_functions.h +0 -42
- triton/backends/amd/include/hip/surface_types.h +0 -63
- triton/backends/amd/include/hip/texture_types.h +0 -194
- triton/backends/amd/include/hsa/Brig.h +0 -1131
- triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
- triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
- triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
- triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
- triton/backends/amd/include/hsa/hsa.h +0 -5738
- triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
- triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
- triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
- triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
- triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
- triton/backends/amd/include/roctracer/roctracer.h +0 -779
- triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
- triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
- triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
- triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
- triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
- triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
- triton/backends/amd/include/roctracer/roctx.h +0 -229
- triton/language/_utils.py +0 -21
- triton/language/extra/cuda/_experimental_tma.py +0 -106
- triton/tools/experimental_descriptor.py +0 -32
- triton_windows-3.3.0.post19.dist-info/RECORD +0 -253
- triton_windows-3.3.0.post19.dist-info/top_level.txt +0 -14
triton/experimental/gluon/language/nvidia/hopper/tma.py ADDED
@@ -0,0 +1,96 @@
+from __future__ import annotations
+from typing import List, Tuple, TYPE_CHECKING
+from dataclasses import dataclass
+import triton.experimental.gluon.language._core as ttgl
+from triton.experimental.gluon.language._layouts import NVMMASharedLayout
+from triton.experimental.gluon.language._core import builtin, _unwrap_if_constexpr
+
+if TYPE_CHECKING:
+    from triton._C import ir
+
+__all__ = ["async_copy_global_to_shared", "async_copy_shared_to_global", "store_wait"]
+
+
+@dataclass(eq=True)
+class tensor_descriptor_type:
+    block_type: ttgl.block_type
+    shape_type: ttgl.tuple_type
+    strides_type: ttgl.tuple_type
+    layout: NVMMASharedLayout
+
+    def __str__(self) -> str:
+        return f"tensor_descriptor<{self.block_type}, {self.layout}>"
+
+    def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor, int]:
+        handle = handles[cursor]
+        cursor += 1
+        shape, cursor = self.shape_type._unflatten_ir(handles, cursor)
+        strides, cursor = self.strides_type._unflatten_ir(handles, cursor)
+        value = tensor_descriptor(handle, shape, strides, self.block_type, layout=self.layout)
+        return value, cursor
+
+    def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
+        is_signed = self.block_type.element_ty.is_int_signed()
+        ty = builder.get_tensor_descriptor_layout_type(
+            self.block_type.to_ir(builder),
+            is_signed,
+            self.layout._to_ir(builder),
+        )
+        out.append(ty)
+        self.shape_type._flatten_ir_types(builder, out)
+        self.strides_type._flatten_ir_types(builder, out)
+
+    def mangle(self) -> str:
+        return f"TD{self.block_type.mangle}_{self.layout.mangle()}TD"
+
+
+class tensor_descriptor:
+
+    def __init__(self, handle, shape: List[ttgl.tensor], strides: List[ttgl.tensor], block_type: ttgl.block_type,
+                 layout: NVMMASharedLayout):
+        self.handle = handle
+        self.shape = ttgl.tuple(shape)
+        self.strides = ttgl.tuple(strides)
+        self.type = tensor_descriptor_type(block_type, shape_type=self.shape.type, strides_type=self.strides.type,
+                                           layout=layout)
+
+    def _flatten_ir(self, handles: List[ir.value]) -> None:
+        handles.append(self.handle)
+        self.shape._flatten_ir(handles)
+        self.strides._flatten_ir(handles)
+
+    @property
+    def block_type(self):
+        return self.type.block_type
+
+    @property
+    def block_shape(self):
+        return self.type.block_type.shape
+
+    @property
+    def dtype(self):
+        return self.type.block_type.element_ty
+
+    @property
+    def layout(self):
+        return self.type.layout
+
+
+@builtin
+def async_copy_global_to_shared(tensor_desc, coord, barrier, result, pred=True, _semantic=None):
+    coord = _semantic._convert_to_ir_values(coord, require_i64=False)
+    pred = _semantic.to_tensor(pred)
+    _semantic.builder.create_async_tma_copy_global_to_local(tensor_desc.handle, coord, barrier.handle, result.handle,
+                                                            pred.handle)
+
+
+@builtin
+def async_copy_shared_to_global(tensor_desc, coord, src, _semantic=None):
+    coord = _semantic._convert_to_ir_values(coord, require_i64=False)
+    _semantic.builder.create_async_tma_copy_local_to_global(tensor_desc.handle, coord, src.handle)
+
+
+@builtin
+def store_wait(pendings, _semantic=None):
+    pendings = _unwrap_if_constexpr(pendings)
+    _semantic.builder.create_async_tma_store_wait(pendings)
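Reviewer note: the three builtins above are device-side and only meaningful inside a Gluon kernel, where the compiler injects the _semantic argument. The sketch below shows roughly how they compose with the shared-memory and mbarrier helpers added elsewhere in this release (_core.py, hopper/mbarrier.py). The helper signatures assumed here (gluon.jit, ttgl.allocate_shared_memory, mbarrier.init/expect/wait, dtype.primitive_bitwidth) are not verified against those files, so treat this as an illustrative sketch, not a working example.

    # Illustrative sketch only; helper signatures are assumed, not confirmed by this diff.
    from triton.experimental import gluon
    from triton.experimental.gluon import language as ttgl
    from triton.experimental.gluon.language.nvidia.hopper import mbarrier, tma

    @gluon.jit
    def load_one_tile(desc, bar_layout: ttgl.constexpr):
        # Stage one block_shape-sized tile from global memory into shared memory.
        smem = ttgl.allocate_shared_memory(desc.dtype, desc.block_shape, desc.layout)
        bar = ttgl.allocate_shared_memory(ttgl.int64, [1], bar_layout)
        mbarrier.init(bar, count=1)
        # Tell the barrier how many bytes the TMA engine will deposit.
        nbytes = desc.block_shape[0] * desc.block_shape[1] * desc.dtype.primitive_bitwidth // 8
        mbarrier.expect(bar, nbytes)
        tma.async_copy_global_to_shared(desc, [0, 0], bar, smem)
        mbarrier.wait(bar, phase=0)  # past this point the tile is resident in shared memory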
triton/experimental/gluon/nvidia/hopper.py ADDED
@@ -0,0 +1,40 @@
+from dataclasses import dataclass
+from typing import List, Any
+from triton._utils import validate_block_shape, canonicalize_dtype, get_primitive_bitwidth
+from triton.experimental.gluon.language._layouts import NVMMASharedLayout
+
+__all__ = ["TensorDescriptor"]
+
+
+@dataclass
+class TensorDescriptor:
+    base: Any
+    shape: List[int]
+    strides: List[int]
+    block_shape: List[int]
+    layout: NVMMASharedLayout
+
+    def __post_init__(self):
+        rank = len(self.shape)
+        assert len(self.strides) == rank, f"rank mismatch: {self}"
+        assert len(self.block_shape) == rank, f"rank mismatch: {self}"
+        assert rank > 0, "rank must not be zero"
+        assert rank <= 5, "rank cannot be more than 5"
+        assert self.base.data_ptr() % 16 == 0, "base must be 16-byte aligned"
+        validate_block_shape(self.block_shape)
+        dtype_str = canonicalize_dtype(self.base.dtype)
+        elem_bytes = get_primitive_bitwidth(dtype_str) // 8
+        for stride in self.strides[:-1]:
+            assert (stride * elem_bytes) % 16 == 0, "strides must be 16-byte aligned"
+        assert self.strides[-1] == 1, "Last dimension must be contiguous"
+        assert isinstance(self.layout, NVMMASharedLayout), "Layout must be NVMMASharedLayout"
+
+    @staticmethod
+    def from_tensor(tensor: Any, block_shape: List[int], layout: NVMMASharedLayout):
+        return TensorDescriptor(
+            tensor,
+            tensor.shape,
+            tensor.stride(),
+            block_shape,
+            layout,
+        )
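Reviewer note: this host-side TensorDescriptor is a plain dataclass, so usage follows directly from from_tensor and the __post_init__ checks above. In the sketch below, the use of a torch tensor for base and the NVMMASharedLayout constructor arguments are assumptions (the layout's definition lives in _layouts.py, which is not reproduced in this excerpt); the class itself only requires that base expose data_ptr(), dtype, shape, and stride().

    # Sketch under stated assumptions; the layout arguments are not confirmed by this diff.
    import torch
    from triton.experimental.gluon.language._layouts import NVMMASharedLayout
    from triton.experimental.gluon.nvidia.hopper import TensorDescriptor

    x = torch.empty(1024, 1024, device="cuda", dtype=torch.float16)
    layout = NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2)  # assumed fields

    # __post_init__ enforces: 1 <= rank <= 5, a 16-byte-aligned base pointer,
    # 16-byte-aligned leading strides, and a contiguous last dimension.
    desc = TensorDescriptor.from_tensor(x, block_shape=[128, 64], layout=layout)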
triton/knobs.py ADDED
@@ -0,0 +1,481 @@
+from __future__ import annotations
+
+import importlib
+import os
+import re
+import subprocess
+import sysconfig
+
+from dataclasses import dataclass
+from contextlib import contextmanager
+from typing import cast, Any, Callable, Generator, Generic, Optional, Protocol, Type, TypeVar, TypedDict, TYPE_CHECKING, Union
+
+if TYPE_CHECKING:
+    from .runtime.cache import CacheManager, RemoteCacheBackend
+    from .runtime.jit import JitFunctionInfo, KernelParam
+    from .compiler.compiler import ASTSource, LazyDict, IRSource
+
+
+class Env:
+    pass
+
+
+env = Env()
+
+propagate_env: bool = True
+
+
+def getenv(key: str) -> Optional[str]:
+    res = os.getenv(key)
+    return res.strip() if res is not None else res
+
+
+def setenv(key: str, value: Optional[str]) -> None:
+    if not propagate_env:
+        return
+
+    if value is not None:
+        os.environ[key] = value
+    elif key in os.environ:
+        del os.environ[key]
+
+
+def toenv(val: Any) -> Union[None, tuple[Optional[str]]]:
+    if val is None:
+        return (None, )
+
+    t = type(val)
+    if t is bool:
+        return ("1" if val else "0", )
+
+    if t is str:
+        return (val, )
+
+    if t is int:
+        return (str(val), )
+
+    return None
+
+
+# There's an asymmetry here so that e.g. env_nvidia_tool can be specified with a
+# a string but return an NvidiaTool.
+SetType = TypeVar("SetType")
+GetType = TypeVar("GetType")
+
+
+class env_base(Generic[SetType, GetType]):
+
+    def __init__(self, key: str, default: Union[SetType, Callable[[], SetType]]) -> None:
+        self.key = key
+        self.default: Callable[[], SetType] = default if callable(default) else lambda: default
+
+    def __set_name__(self, objclass: Type[object], name: str) -> None:
+        self.name = name
+
+    def __get__(self, obj: Optional[object], objclass: Optional[Type[object]]) -> GetType:
+        if obj is None:
+            raise AttributeError(f"Cannot access {type(self)} on non-instance")
+
+        if self.name in obj.__dict__:
+            return self.transform(obj.__dict__[self.name])
+        else:
+            return self.get()
+
+    @property
+    def env_val(self) -> str | None:
+        return getenv(self.key)
+
+    def get(self) -> GetType:
+        env = self.env_val
+        return self.transform(self.default() if env is None else self.from_env(env))
+
+    def __set__(self, obj: object, value: Union[SetType, Env]) -> None:
+        if isinstance(value, Env):
+            obj.__dict__.pop(self.name, None)
+        else:
+            obj.__dict__[self.name] = value
+            if env_val := toenv(value):
+                setenv(self.key, env_val[0])
+
+    def __delete__(self, obj: object) -> None:
+        obj.__dict__.pop(self.name, None)
+
+    def transform(self, val: SetType) -> GetType:
+        # See comment about GetType/SetType in their definition above. Only needed
+        # if GetType != SetType.
+        return cast(GetType, val)
+
+    def from_env(self, val: str) -> SetType:
+        raise NotImplementedError()
+
+
+class env_str(env_base[str, str]):
+
+    def from_env(self, val: str) -> str:
+        return val
+
+
+class env_bool(env_base[bool, bool]):
+
+    def __init__(self, key: str, default: Union[bool, Callable[[], bool]] = False) -> None:
+        super().__init__(key, default)
+
+    def from_env(self, val: str) -> bool:
+        return val.lower() in ("1", "true", "yes", "on", "y")
+
+
+class env_int(env_base[int, int]):
+
+    def __init__(self, key: str, default: Union[int, Callable[[], int]] = 0) -> None:
+        super().__init__(key, default)
+
+    def from_env(self, val: str) -> int:
+        try:
+            return int(val)
+        except ValueError as exc:
+            raise RuntimeError(f"Unable to use {self.key}={val}: expected int") from exc
+
+
+class env_opt_base(Generic[GetType, SetType], env_base[Optional[GetType], Optional[SetType]]):
+
+    def __init__(self, key: str) -> None:
+        super().__init__(key, None)
+
+
+ClassType = TypeVar("ClassType")
+
+
+class env_class(Generic[ClassType], env_opt_base[Type[ClassType], Type[ClassType]]):
+
+    def __init__(self, key: str, type: str) -> None:
+        super().__init__(key)
+        # We can't pass the type directly to avoid import cycles
+        self.type = type
+
+    def from_env(self, val: str) -> Type[ClassType]:
+        comps = val.split(":", 1)
+        if len(comps) != 2:
+            raise RuntimeError(f"Unable to read {self.key}: '{val}' isn't of the form MODULE:CLASS")
+        cls = getattr(importlib.import_module(comps[0]), comps[1])
+
+        if not any((c.__name__ == self.type for c in cls.mro())):
+            raise RuntimeError(f"Unable to use '{val}' from {self.key}: not of type '{self.type}'")
+
+        return cast(Type[ClassType], cls)
+
+
+@dataclass
+class NvidiaTool:
+    path: str
+    version: str
+
+    @staticmethod
+    def from_path(path: str) -> Optional[NvidiaTool]:
+        try:
+            result = subprocess.check_output([path, "--version"], stderr=subprocess.STDOUT)
+            if result is None:
+                return None
+            version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE)
+            if version is None:
+                return None
+            return NvidiaTool(path, version.group(1))
+        except subprocess.CalledProcessError:
+            return None
+
+
+def find_nvidia_tool(binary: str) -> str:
+    path = os.path.join(
+        os.path.dirname(__file__),
+        "backends",
+        "nvidia",
+        "bin",
+        binary,
+    )
+    if os.access(path, os.X_OK):
+        return path
+
+    if os.name == "nt":
+        from triton.windows_utils import find_cuda
+        cuda_bin_path, _, _ = find_cuda()
+        if cuda_bin_path:
+            path = os.path.join(cuda_bin_path, binary)
+            if os.access(path, os.X_OK):
+                return path
+
+    return ""
+
+
+class env_nvidia_tool(env_base[str, NvidiaTool]):
+
+    def __init__(self, binary: str) -> None:
+        binary += sysconfig.get_config_var("EXE")
+        self.binary = binary
+        super().__init__(f"TRITON_{binary.upper()}_PATH", lambda: find_nvidia_tool(self.binary))
+
+    def transform(self, path: str) -> NvidiaTool:
+        paths = [
+            path,
+            # We still add default as fallback in case the pointed binary isn't
+            # accessible.
+            self.default(),
+        ]
+        for path in paths:
+            if not path or not os.access(path, os.X_OK):
+                continue
+            if tool := NvidiaTool.from_path(path):
+                return tool
+
+        raise RuntimeError(f"Cannot find {self.binary}")
+
+    def from_env(self, val: str) -> str:
+        return val
+
+
+# Separate classes so that types are correct
+class env_opt_str(env_opt_base[str, str], env_str):
+    pass
+
+
+class env_opt_bool(env_opt_base[bool, bool], env_bool):
+    pass
+
+
+@dataclass(frozen=True)
+class CompileTimes:
+    """
+    Model holding timing information for an invocation of the compiler.
+
+    All times in microseconds.
+    """
+
+    # Duration of make_ir
+    ir_initialization: int
+
+    # Ordered mapping from lowering stage to duration spent in that stage.
+    # Keyed by stage extension, e.g. ttir, ttgir
+    lowering_stages: list[tuple[str, int]]
+
+    # Duration of saving artifacts/metadata to cache
+    store_results: int
+
+    @property
+    def total_lowering(self) -> int:
+        return sum((stage[1] for stage in self.lowering_stages))
+
+    @property
+    def total(self) -> int:
+        return self.ir_initialization + self.total_lowering + self.store_results
+
+
+class CompilationListener(Protocol):
+
+    def __call__(self, *, src: Union[ASTSource, IRSource], metadata: dict[str, Any], metadata_group: dict[str, str],
+                 times: CompileTimes, cache_hit: bool) -> None:
+        ...
+
+
+knobs_type = TypeVar("knobs_type", bound='base_knobs')
+
+
+class base_knobs:
+
+    @property
+    def knob_descriptors(self) -> dict[str, env_base]:
+        return {
+            k: v
+            # data descriptors live on the class object
+            for k, v in type(self).__dict__.items()
+            if isinstance(v, env_base)
+        }
+
+    @property
+    def knobs(self) -> dict[str, Any]:
+        return {k: getattr(self, k) for k in self.knob_descriptors.keys()}
+
+    def copy(self: knobs_type) -> knobs_type:
+        res = type(self)()
+        res.__dict__.update(self.__dict__)
+        return res
+
+    def reset(self: knobs_type) -> knobs_type:
+        for knob in self.knob_descriptors.keys():
+            delattr(self, knob)
+        return self
+
+    @contextmanager
+    def scope(self) -> Generator[None, None, None]:
+        try:
+            initial_env = {knob.key: knob.env_val for knob in self.knob_descriptors.values()}
+            orig = dict(self.__dict__)
+            yield
+        finally:
+            self.__dict__.clear()
+            self.__dict__.update(orig)
+
+            for k, v in initial_env.items():
+                if v is not None:
+                    os.environ[k] = v
+                elif k in os.environ:
+                    del os.environ[k]
+
+
+class BuildImpl(Protocol):
+
+    def __call__(self, name: str, src: str, srcdir: str, library_dirs: list[str], include_dirs: list[str],
+                 libraries: list[str], /) -> str:
+        ...
+
+
+class build_knobs(base_knobs):
+    """Configuration controlling how the native compiler is invoked"""
+    cc: env_opt_str = env_opt_str("CC")
+
+    cudacrt_path: env_opt_str = env_opt_str("TRITON_CUDACRT_PATH")
+    cudart_path: env_opt_str = env_opt_str("TRITON_CUDART_PATH")
+
+    impl: Optional[BuildImpl] = None
+
+    @property
+    def backend_dirs(self) -> set[str]:
+        return {path for path in (self.cudacrt_path, self.cudart_path) if path is not None}
+
+
+class redis_knobs(base_knobs):
+    key_format: env_str = env_str("TRITON_REDIS_KEY_FORMAT", "triton:{key}:{filename}")
+    host: env_str = env_str("TRITON_REDIS_HOST", "localhost")
+    port: env_int = env_int("TRITON_REDIS_PORT", 6379)
+
+
+cache: cache_knobs
+
+
+class cache_knobs(base_knobs):
+    home_dir: env_str = env_str("TRITON_HOME", lambda: os.path.expanduser("~/"))
+
+    dump_dir: env_str = env_str("TRITON_DUMP_DIR", lambda: cache.get_triton_dir("dump"))
+    override_dir: env_str = env_str("TRITON_OVERRIDE_DIR", lambda: cache.get_triton_dir("override"))
+    dir: env_str = env_str("TRITON_CACHE_DIR", lambda: cache.get_triton_dir("cache"))
+
+    manager_class: env_class[CacheManager] = env_class("TRITON_CACHE_MANAGER", "CacheManager")
+    remote_manager_class: env_class[RemoteCacheBackend] = env_class("TRITON_REMOTE_CACHE_BACKEND", "RemoteCacheBackend")
+
+    def get_triton_dir(self, dirname: str) -> str:
+        return os.path.join(self.home_dir, ".triton", dirname)
+
+
+class compilation_knobs(base_knobs):
+    override: env_bool = env_bool("TRITON_KERNEL_OVERRIDE")
+    dump_ir: env_bool = env_bool("TRITON_KERNEL_DUMP")
+    store_binary_only: env_bool = env_bool("TRITON_STORE_BINARY_ONLY")
+    always_compile: env_bool = env_bool("TRITON_ALWAYS_COMPILE")
+    # TODO: Use enum to constrain / 'typecheck' the values
+    use_ir_loc: env_opt_str = env_opt_str("USE_IR_LOC")
+    enable_asan: env_bool = env_bool("TRITON_ENABLE_ASAN")
+    disable_line_info: env_bool = env_bool("TRITON_DISABLE_LINE_INFO")
+    front_end_debugging: env_bool = env_bool("TRITON_FRONT_END_DEBUGGING")
+    allow_non_constexpr_globals: env_bool = env_bool("TRITON_ALLOW_NON_CONSTEXPR_GLOBALS")
+    listener: Union[CompilationListener, None] = None
+
+
+class autotuning_knobs(base_knobs):
+    cache: env_bool = env_bool("TRITON_CACHE_AUTOTUNING")
+    print: env_bool = env_bool("TRITON_PRINT_AUTOTUNING")
+
+
+class LaunchHook(Protocol):
+
+    def __call__(self, metadata: LazyDict) -> None:
+        ...
+
+
+# This is of the form [attr_name, attr_val]
+# TODO: Use tuple instead of list for better typing.
+KernelAttr = list[Union[str, int]]
+
+
+class JITHookCompileInfo(TypedDict):
+    key: str
+    signature: dict[KernelParam, str]
+    device: int
+    constants: None
+    num_warps: int
+    num_ctas: int
+    num_stages: int
+    enable_fp_fusion: bool
+    launch_cooperative_grid: bool
+    extern_libs: tuple[tuple[str, str], ...]
+    configs: list[dict[tuple[int, ...], list[KernelAttr]]]
+    specialization_data: str
+    is_warmup: bool
+
+
+class JITHook(Protocol):
+
+    def __call__(self, *, key: str, repr: str, fn: JitFunctionInfo, compile: JITHookCompileInfo, is_manual_warmup: bool,
+                 already_compiled: bool) -> Optional[bool]:
+        ...
+
+
+class runtime_knobs(base_knobs):
+    interpret: env_bool = env_bool("TRITON_INTERPRET")
+    debug: env_bool = env_bool("TRITON_DEBUG")
+    override_arch: env_opt_str = env_opt_str("TRITON_OVERRIDE_ARCH")
+
+    launch_enter_hook: Optional[LaunchHook] = None
+    launch_exit_hook: Optional[LaunchHook] = None
+
+    # Hook for inspecting compiled functions and modules
+    jit_cache_hook: Optional[JITHook] = None
+    # Hook to signal that a kernel is done compiling and inspect compiled function.
+    # jit_cache_hook will always be called before compilation and jit_post_compile_hook after.
+    jit_post_compile_hook: Optional[JITHook] = None
+
+
+class language_knobs(base_knobs):
+    fp32_default: env_opt_str = env_opt_str("TRITON_F32_DEFAULT")
+    default_fp_fusion: env_bool = env_bool("TRITON_DEFAULT_FP_FUSION", True)
+
+
+class nvidia_knobs(base_knobs):
+    cuobjdump: env_nvidia_tool = env_nvidia_tool("cuobjdump")
+    nvdisasm: env_nvidia_tool = env_nvidia_tool("nvdisasm")
+    ptxas: env_nvidia_tool = env_nvidia_tool("ptxas")
+
+    dump_nvptx: env_bool = env_bool("NVPTX_ENABLE_DUMP")
+    disable_ptxas_opt: env_bool = env_bool("DISABLE_PTXAS_OPT")
+    mock_ptx_version: env_opt_str = env_opt_str("TRITON_MOCK_PTX_VERSION")
+
+    libdevice_path: env_opt_str = env_opt_str("TRITON_LIBDEVICE_PATH")
+    libcuda_path: env_opt_str = env_opt_str("TRITON_LIBCUDA_PATH")
+
+
+class amd_knobs(base_knobs):
+    use_buffer_ops: env_bool = env_bool("AMDGCN_USE_BUFFER_OPS")
+    dump_amdgcn: env_bool = env_bool("AMDGCN_ENABLE_DUMP")
+    libhip_path: env_opt_str = env_opt_str("TRITON_LIBHIP_PATH")
+    lld_path: env_opt_str = env_opt_str("TRITON_HIP_LLD_PATH")
+
+    # We use strs so that we can have a default value based on other runtime info
+    use_block_pingpong: env_opt_bool = env_opt_bool("TRITON_HIP_USE_BLOCK_PINGPONG")
+    use_in_thread_transpose: env_opt_bool = env_opt_bool("TRITON_HIP_USE_IN_THREAD_TRANSPOSE")
+
+    global_prefetch: env_int = env_int("TRITON_HIP_GLOBAL_PREFETCH")
+    local_prefetch: env_int = env_int("TRITON_HIP_LOCAL_PREFETCH")
+    use_async_copy: env_bool = env_bool("TRITON_HIP_USE_ASYNC_COPY")
+    scalarize_packed_fops: env_bool = env_bool("AMDGCN_SCALARIZE_PACKED_FOPS")
+
+
+class proton_knobs(base_knobs):
+    cupti_dir: env_opt_str = env_opt_str("TRITON_CUPTI_LIB_PATH")
+
+
+build = build_knobs()
+redis = redis_knobs()
+cache = cache_knobs()
+compilation = compilation_knobs()
+autotuning = autotuning_knobs()
+runtime = runtime_knobs()
+language = language_knobs()
+nvidia = nvidia_knobs()
+amd = amd_knobs()
+proton = proton_knobs()
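Reviewer note: knobs.py centralizes what were previously scattered os.environ lookups behind descriptor-backed singletons, so each knob can be read with a default, overridden in Python (propagating back into the environment while propagate_env is True), or temporarily rebound with scope(). The sketch below uses only names defined in the file above; the asserts assume the corresponding environment variables are unset before it runs.

    # Usage sketch based on the definitions above; assumes a clean environment.
    import os
    import triton.knobs as knobs

    # Reads fall back: instance override -> environment variable -> default.
    print(knobs.redis.port)  # 6379 unless TRITON_REDIS_PORT is set

    # Writes propagate into os.environ via setenv(), so subprocesses see them too.
    knobs.autotuning.print = True
    assert os.environ["TRITON_PRINT_AUTOTUNING"] == "1"

    # scope() snapshots both the overrides and the env vars, restoring them on exit.
    with knobs.compilation.scope():
        knobs.compilation.always_compile = True
        ...  # compile kernels here; always_compile forces recompilation
    assert not knobs.compilation.always_compile

    # Assigning the module-level `env` sentinel drops an override so reads
    # consult the environment (and default) again.
    knobs.autotuning.print = knobs.env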