triton-windows 3.3.1.post19__cp311-cp311-win_amd64.whl → 3.5.0.post21__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic. Click here for more details.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +11 -2
- triton/_filecheck.py +97 -0
- triton/_internal_testing.py +95 -18
- triton/_utils.py +112 -21
- triton/backends/__init__.py +20 -23
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +161 -119
- triton/backends/amd/driver.c +118 -46
- triton/backends/amd/driver.py +274 -96
- triton/backends/compiler.py +7 -21
- triton/backends/driver.py +13 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +163 -106
- triton/backends/nvidia/driver.c +166 -101
- triton/backends/nvidia/driver.py +384 -202
- triton/compiler/__init__.py +5 -2
- triton/compiler/code_generator.py +439 -231
- triton/compiler/compiler.py +152 -84
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +5 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +102 -0
- triton/experimental/gluon/language/__init__.py +119 -0
- triton/experimental/gluon/language/_core.py +490 -0
- triton/experimental/gluon/language/_layouts.py +583 -0
- triton/experimental/gluon/language/_math.py +20 -0
- triton/experimental/gluon/language/_semantic.py +380 -0
- triton/experimental/gluon/language/_standard.py +80 -0
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +45 -0
- triton/knobs.py +546 -0
- triton/language/__init__.py +50 -19
- triton/language/core.py +909 -572
- triton/language/extra/cuda/__init__.py +10 -7
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +394 -394
- triton/language/extra/cuda/utils.py +21 -21
- triton/language/extra/hip/__init__.py +3 -1
- triton/language/extra/hip/libdevice.py +120 -104
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +4 -0
- triton/language/math.py +65 -66
- triton/language/random.py +12 -2
- triton/language/semantic.py +1757 -1768
- triton/language/standard.py +127 -62
- triton/language/target_info.py +54 -0
- triton/runtime/_allocation.py +15 -3
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +117 -60
- triton/runtime/build.py +83 -17
- triton/runtime/cache.py +61 -47
- triton/runtime/driver.py +25 -47
- triton/runtime/interpreter.py +95 -50
- triton/runtime/jit.py +445 -248
- triton/runtime/tcc/include/_mingw.h +8 -10
- triton/runtime/tcc/include/assert.h +5 -0
- triton/runtime/tcc/include/errno.h +1 -1
- triton/runtime/tcc/include/float.h +21 -3
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +5 -0
- triton/runtime/tcc/include/malloc.h +2 -2
- triton/runtime/tcc/include/math.h +21 -261
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +5 -70
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stddef.h +7 -19
- triton/runtime/tcc/include/stdlib.h +15 -4
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/sys/stat.h +2 -2
- triton/runtime/tcc/include/sys/types.h +5 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +9 -2
- triton/runtime/tcc/include/winapi/wincon.h +8 -0
- triton/runtime/tcc/include/winapi/windows.h +1 -1
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +9 -7
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/testing.py +16 -12
- triton/tools/compile.py +62 -14
- triton/tools/disasm.py +3 -4
- triton/tools/extra/cuda/compile.c +1 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +34 -0
- triton/windows_utils.py +52 -81
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +8 -4
- triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
- triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
- triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
- triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
- triton/backends/amd/include/hip/channel_descriptor.h +0 -39
- triton/backends/amd/include/hip/device_functions.h +0 -38
- triton/backends/amd/include/hip/driver_types.h +0 -468
- triton/backends/amd/include/hip/hip_bf16.h +0 -36
- triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
- triton/backends/amd/include/hip/hip_common.h +0 -100
- triton/backends/amd/include/hip/hip_complex.h +0 -38
- triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
- triton/backends/amd/include/hip/hip_deprecated.h +0 -95
- triton/backends/amd/include/hip/hip_ext.h +0 -161
- triton/backends/amd/include/hip/hip_fp16.h +0 -36
- triton/backends/amd/include/hip/hip_fp8.h +0 -33
- triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
- triton/backends/amd/include/hip/hip_hcc.h +0 -24
- triton/backends/amd/include/hip/hip_math_constants.h +0 -36
- triton/backends/amd/include/hip/hip_profile.h +0 -27
- triton/backends/amd/include/hip/hip_runtime.h +0 -75
- triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
- triton/backends/amd/include/hip/hip_texture_types.h +0 -29
- triton/backends/amd/include/hip/hip_vector_types.h +0 -41
- triton/backends/amd/include/hip/hip_version.h +0 -17
- triton/backends/amd/include/hip/hiprtc.h +0 -421
- triton/backends/amd/include/hip/library_types.h +0 -78
- triton/backends/amd/include/hip/math_functions.h +0 -42
- triton/backends/amd/include/hip/surface_types.h +0 -63
- triton/backends/amd/include/hip/texture_types.h +0 -194
- triton/backends/amd/include/hsa/Brig.h +0 -1131
- triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
- triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
- triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
- triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
- triton/backends/amd/include/hsa/hsa.h +0 -5738
- triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
- triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
- triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
- triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
- triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
- triton/backends/amd/include/roctracer/roctracer.h +0 -779
- triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
- triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
- triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
- triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
- triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
- triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
- triton/backends/amd/include/roctracer/roctx.h +0 -229
- triton/language/_utils.py +0 -21
- triton/language/extra/cuda/_experimental_tma.py +0 -106
- triton/runtime/tcc/lib/libtcc1-64.a +0 -0
- triton/tools/experimental_descriptor.py +0 -32
- triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
- triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
triton/runtime/driver.py
CHANGED
|
@@ -1,60 +1,38 @@
|
|
|
1
|
-
from
|
|
2
|
-
from ..backends import DriverBase
|
|
1
|
+
from __future__ import annotations
|
|
3
2
|
|
|
3
|
+
from ..backends import backends, DriverBase
|
|
4
4
|
|
|
5
|
-
def _create_driver():
|
|
6
|
-
actives = [x.driver for x in backends.values() if x.driver.is_active()]
|
|
7
|
-
if len(actives) != 1:
|
|
8
|
-
raise RuntimeError(f"{len(actives)} active drivers ({actives}). There should only be one.")
|
|
9
|
-
return actives[0]()
|
|
10
5
|
|
|
6
|
+
def _create_driver() -> DriverBase:
|
|
7
|
+
active_drivers = [x.driver for x in backends.values() if x.driver.is_active()]
|
|
8
|
+
if len(active_drivers) != 1:
|
|
9
|
+
raise RuntimeError(f"{len(active_drivers)} active drivers ({active_drivers}). There should only be one.")
|
|
10
|
+
return active_drivers[0]()
|
|
11
11
|
|
|
12
|
-
class LazyProxy:
|
|
13
12
|
|
|
14
|
-
|
|
15
|
-
self._init_fn = init_fn
|
|
16
|
-
self._obj = None
|
|
17
|
-
|
|
18
|
-
def _initialize_obj(self):
|
|
19
|
-
if self._obj is None:
|
|
20
|
-
self._obj = self._init_fn()
|
|
21
|
-
|
|
22
|
-
def __getattr__(self, name):
|
|
23
|
-
self._initialize_obj()
|
|
24
|
-
return getattr(self._obj, name)
|
|
25
|
-
|
|
26
|
-
def __setattr__(self, name, value):
|
|
27
|
-
if name in ["_init_fn", "_obj"]:
|
|
28
|
-
super().__setattr__(name, value)
|
|
29
|
-
else:
|
|
30
|
-
self._initialize_obj()
|
|
31
|
-
setattr(self._obj, name, value)
|
|
32
|
-
|
|
33
|
-
def __delattr__(self, name):
|
|
34
|
-
self._initialize_obj()
|
|
35
|
-
delattr(self._obj, name)
|
|
36
|
-
|
|
37
|
-
def __repr__(self):
|
|
38
|
-
if self._obj is None:
|
|
39
|
-
return f"<{self.__class__.__name__} for {self._init_fn} not yet initialized>"
|
|
40
|
-
return repr(self._obj)
|
|
41
|
-
|
|
42
|
-
def __str__(self):
|
|
43
|
-
self._initialize_obj()
|
|
44
|
-
return str(self._obj)
|
|
13
|
+
class DriverConfig:
|
|
45
14
|
|
|
15
|
+
def __init__(self) -> None:
|
|
16
|
+
self._default: DriverBase | None = None
|
|
17
|
+
self._active: DriverBase | None = None
|
|
46
18
|
|
|
47
|
-
|
|
19
|
+
@property
|
|
20
|
+
def default(self) -> DriverBase:
|
|
21
|
+
if self._default is None:
|
|
22
|
+
self._default = _create_driver()
|
|
23
|
+
return self._default
|
|
48
24
|
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
self.
|
|
25
|
+
@property
|
|
26
|
+
def active(self) -> DriverBase:
|
|
27
|
+
if self._active is None:
|
|
28
|
+
self._active = self.default
|
|
29
|
+
return self._active
|
|
52
30
|
|
|
53
|
-
def set_active(self, driver: DriverBase):
|
|
54
|
-
self.
|
|
31
|
+
def set_active(self, driver: DriverBase) -> None:
|
|
32
|
+
self._active = driver
|
|
55
33
|
|
|
56
|
-
def reset_active(self):
|
|
57
|
-
self.
|
|
34
|
+
def reset_active(self) -> None:
|
|
35
|
+
self._active = self.default
|
|
58
36
|
|
|
59
37
|
|
|
60
38
|
driver = DriverConfig()
|
triton/runtime/interpreter.py
CHANGED
|
@@ -1,32 +1,36 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
1
2
|
import ast
|
|
2
3
|
import textwrap
|
|
3
4
|
import inspect
|
|
4
|
-
from typing import Tuple, List
|
|
5
|
+
from typing import Tuple, List, Dict, Callable
|
|
5
6
|
|
|
6
7
|
import math
|
|
7
8
|
import numpy as np
|
|
8
9
|
|
|
9
10
|
import triton
|
|
10
11
|
import triton.language as tl
|
|
12
|
+
import dataclasses
|
|
11
13
|
from dataclasses import dataclass
|
|
14
|
+
|
|
15
|
+
from triton.language.semantic import TritonSemantic
|
|
16
|
+
from triton.tools.tensor_descriptor import TensorDescriptor
|
|
12
17
|
from .errors import InterpreterError
|
|
13
18
|
from functools import partial
|
|
14
19
|
from .._C.libtriton import interpreter as _interpreter
|
|
15
20
|
from .._C.libtriton import ir as _ir
|
|
16
21
|
|
|
17
22
|
|
|
23
|
+
@dataclass
|
|
18
24
|
class TensorHandle:
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
self.dtype = dtype
|
|
29
|
-
self.attr = {}
|
|
25
|
+
'''
|
|
26
|
+
data: numpy array
|
|
27
|
+
dtype: triton type, either pointer_type or scalar_type.
|
|
28
|
+
we don't store block_type here because the shape information is already available in the data field
|
|
29
|
+
attr: a dictionary of attributes
|
|
30
|
+
'''
|
|
31
|
+
data: np.array
|
|
32
|
+
dtype: tl.dtype
|
|
33
|
+
attr: Dict = dataclasses.field(default_factory=dict)
|
|
30
34
|
|
|
31
35
|
def __bool__(self):
|
|
32
36
|
return bool(self.data.all())
|
|
@@ -73,17 +77,19 @@ class BlockPointerHandle:
|
|
|
73
77
|
class TensorDescHandle:
|
|
74
78
|
|
|
75
79
|
def __init__(self, base: TensorHandle, shape: List[TensorHandle], strides: List[TensorHandle],
|
|
76
|
-
block_shape: List[int]):
|
|
80
|
+
block_shape: List[int], padding):
|
|
77
81
|
self.base = base
|
|
78
82
|
self.ndim = len(shape)
|
|
79
83
|
self.shape = shape
|
|
80
84
|
self.strides = strides
|
|
81
85
|
self.block_shape = block_shape
|
|
86
|
+
self.padding = padding
|
|
82
87
|
|
|
83
88
|
def validate(self):
|
|
84
89
|
assert self.base.data.item() % 16 == 0, "base must be 16-byte aligned"
|
|
85
90
|
assert len(self.strides) == self.ndim
|
|
86
91
|
assert len(self.block_shape) == self.ndim
|
|
92
|
+
assert self.ndim >= 1, "descriptor cannot be 0 dimensional"
|
|
87
93
|
|
|
88
94
|
for stride in self.strides[:-1]:
|
|
89
95
|
assert stride.data.item() % 16 == 0, "stride must be 16-byte aligned"
|
|
@@ -103,6 +109,7 @@ class TensorDescHandle:
|
|
|
103
109
|
off = (offsets[dim].data + np.arange(self.block_shape[dim])).reshape(bcast_dims)
|
|
104
110
|
ptrs = ptrs + (itemsize * off * self.strides[dim].data).astype(np.uint64)
|
|
105
111
|
masks = masks & (0 <= off) & (off < self.shape[dim].data)
|
|
112
|
+
assert ptrs.dtype == np.uint64
|
|
106
113
|
ptrs = TensorHandle(ptrs, self.base.dtype.scalar)
|
|
107
114
|
return ptrs, masks
|
|
108
115
|
|
|
@@ -114,7 +121,7 @@ class InterpreterOptions:
|
|
|
114
121
|
sanitize_overflow: bool = True
|
|
115
122
|
arch: str = None
|
|
116
123
|
supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e5b16", "fp8e4nv", "fp8e4b8", "fp8e4b15")
|
|
117
|
-
|
|
124
|
+
deprecated_fp8_dot_operand_dtypes: Tuple[str] = ()
|
|
118
125
|
default_dot_input_precision: str = "tf32"
|
|
119
126
|
allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee")
|
|
120
127
|
max_num_imprecise_acc_default: int = 0
|
|
@@ -248,8 +255,8 @@ np_umulhi_u64 = np.vectorize(_umulhi_64, otypes=[np.uint64])
|
|
|
248
255
|
class ExtraFunctions:
|
|
249
256
|
|
|
250
257
|
@staticmethod
|
|
251
|
-
def _convert_custom_types(input, dst_ty, fp_downcast_rounding,
|
|
252
|
-
return tl.tensor(
|
|
258
|
+
def _convert_custom_types(input, dst_ty, fp_downcast_rounding, _semantic):
|
|
259
|
+
return tl.tensor(_semantic.builder.create_fp_to_fp(input.handle, dst_ty, fp_downcast_rounding), dst_ty)
|
|
253
260
|
|
|
254
261
|
|
|
255
262
|
class InterpreterBuilder:
|
|
@@ -306,6 +313,9 @@ class InterpreterBuilder:
|
|
|
306
313
|
def get_double_ty(self):
|
|
307
314
|
return tl.float64
|
|
308
315
|
|
|
316
|
+
def get_int1_ty(self):
|
|
317
|
+
return tl.int1
|
|
318
|
+
|
|
309
319
|
def get_int8_ty(self):
|
|
310
320
|
return tl.int8
|
|
311
321
|
|
|
@@ -587,11 +597,18 @@ class InterpreterBuilder:
|
|
|
587
597
|
b_data = _convert_float(b_data, b.dtype, tl.float16, None).view(np.float16)
|
|
588
598
|
return TensorHandle(np.matmul(a_data, b_data, dtype=d.data.dtype) + d.data, d.dtype.scalar)
|
|
589
599
|
|
|
590
|
-
def create_make_range(self, start, stop):
|
|
600
|
+
def create_make_range(self, ret_ty, start, stop):
|
|
591
601
|
return TensorHandle(np.arange(start, stop, dtype=np.int32), tl.int32)
|
|
592
602
|
|
|
593
|
-
def create_histogram(self, data, bins):
|
|
594
|
-
|
|
603
|
+
def create_histogram(self, data, bins, mask):
|
|
604
|
+
if mask is None:
|
|
605
|
+
mask = TensorHandle(np.ones_like(data.data, dtype=bool), tl.int1)
|
|
606
|
+
# force all masked elements to zero
|
|
607
|
+
data = np.where(mask.data, data.data, np.zeros_like(data.data))
|
|
608
|
+
histogram = np.histogram(data, bins=bins, range=(0, bins))[0]
|
|
609
|
+
# remove overcounted elements
|
|
610
|
+
histogram[0] -= np.logical_not(mask.data).sum()
|
|
611
|
+
return TensorHandle(histogram, tl.int32)
|
|
595
612
|
|
|
596
613
|
def create_gather(self, src, indices, axis):
|
|
597
614
|
return TensorHandle(np.take_along_axis(src.data, indices.data, axis=axis), src.dtype.scalar)
|
|
@@ -641,12 +658,16 @@ class InterpreterBuilder:
|
|
|
641
658
|
# Triton only supports splitting the original tensor into two along the last axis
|
|
642
659
|
return (TensorHandle(val.data[..., 0], val.dtype.scalar), TensorHandle(val.data[..., 1], val.dtype.scalar))
|
|
643
660
|
|
|
644
|
-
def create_splat(self,
|
|
661
|
+
def create_splat(self, ret_ty, arg):
|
|
662
|
+
shape = ret_ty.shape
|
|
645
663
|
if isinstance(arg.dtype, tl.block_type):
|
|
646
664
|
return TensorHandle(np.full(shape, arg.data[0], dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar)
|
|
647
665
|
else: # scalar
|
|
648
666
|
return TensorHandle(np.full(shape, arg.data, dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar)
|
|
649
667
|
|
|
668
|
+
def create_unsplat(self, arg):
|
|
669
|
+
return TensorHandle(np.full((1, ), arg.data[0], dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar)
|
|
670
|
+
|
|
650
671
|
def create_atomic_cas(self, ptr, cmp, val, sem, scope):
|
|
651
672
|
if sem not in self.ir_sem_to_interpreter_sem:
|
|
652
673
|
raise ValueError(f"unsupported semantic {sem}")
|
|
@@ -709,14 +730,9 @@ class InterpreterBuilder:
|
|
|
709
730
|
ret.offsets[i].data += offsets[i].data
|
|
710
731
|
return ret
|
|
711
732
|
|
|
712
|
-
def create_make_tensor_descriptor(
|
|
713
|
-
|
|
714
|
-
base
|
|
715
|
-
shape: List[TensorHandle],
|
|
716
|
-
strides: List[TensorHandle],
|
|
717
|
-
tensor_shape: List[int],
|
|
718
|
-
):
|
|
719
|
-
desc = TensorDescHandle(base, shape, strides, tensor_shape)
|
|
733
|
+
def create_make_tensor_descriptor(self, base: TensorHandle, shape: List[TensorHandle], strides: List[TensorHandle],
|
|
734
|
+
tensor_shape: List[int], is_signed: bool, padding: str = "zero"):
|
|
735
|
+
desc = TensorDescHandle(base, shape, strides, tensor_shape, padding)
|
|
720
736
|
desc.validate()
|
|
721
737
|
return desc
|
|
722
738
|
|
|
@@ -724,7 +740,16 @@ class InterpreterBuilder:
|
|
|
724
740
|
eviction_policy):
|
|
725
741
|
assert isinstance(desc, TensorDescHandle)
|
|
726
742
|
ptrs, mask = desc.materialize_pointers(indices)
|
|
727
|
-
|
|
743
|
+
dtype_tt = ptrs.get_element_ty()
|
|
744
|
+
dtype_np = _get_np_dtype(dtype_tt)
|
|
745
|
+
padding = desc.padding
|
|
746
|
+
if padding == _ir.PADDING_OPTION.PAD_ZERO:
|
|
747
|
+
other = TensorHandle(np.zeros_like(ptrs.data, dtype=dtype_np), dtype_tt)
|
|
748
|
+
elif padding == _ir.PADDING_OPTION.PAD_NAN:
|
|
749
|
+
other = TensorHandle(np.full_like(ptrs.data, float('nan'), dtype=dtype_np), dtype_tt)
|
|
750
|
+
else:
|
|
751
|
+
raise ValueError(f"unsupported padding {padding}")
|
|
752
|
+
return self.create_masked_load(ptrs, mask, other, cache_modifier=cache_modifier,
|
|
728
753
|
eviction_policy=eviction_policy, is_volatile=False)
|
|
729
754
|
|
|
730
755
|
def create_descriptor_store(self, desc: TensorDescHandle, value: TensorHandle, indices: List[TensorHandle]):
|
|
@@ -753,15 +778,18 @@ class InterpreterBuilder:
|
|
|
753
778
|
np_type = _get_np_dtype(type)
|
|
754
779
|
if "int" in np_type.name:
|
|
755
780
|
return TensorHandle(np.full(1, -1, dtype=np_type), type.scalar)
|
|
781
|
+
elif np_type == np.bool_:
|
|
782
|
+
return TensorHandle(np.full(1, True, dtype=np_type), type.scalar)
|
|
756
783
|
else:
|
|
757
784
|
raise TypeError(f"unsupported type {type}")
|
|
758
785
|
|
|
759
786
|
|
|
760
787
|
def _patch_attr(obj, name, member, builder):
|
|
788
|
+
semantic = TritonSemantic(builder)
|
|
761
789
|
new_member = lambda *args, member=member, **kwargs: (member(*args, **
|
|
762
790
|
{k: v
|
|
763
791
|
for k, v in kwargs.items()
|
|
764
|
-
if k != "
|
|
792
|
+
if k != "_semantic"}, _semantic=semantic))
|
|
765
793
|
setattr(obj, name, new_member)
|
|
766
794
|
|
|
767
795
|
|
|
@@ -822,12 +850,10 @@ class ReduceScanOpInterface:
|
|
|
822
850
|
|
|
823
851
|
def apply(self, input):
|
|
824
852
|
if not isinstance(input, tuple):
|
|
825
|
-
|
|
853
|
+
return self.apply((input, ))[0]
|
|
826
854
|
self.check_tensor(input)
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
def apply_impl(self, input):
|
|
830
|
-
raise NotImplementedError("apply_impl not implemented")
|
|
855
|
+
ret = self.apply_impl(input)
|
|
856
|
+
return tuple(ret) if isinstance(ret, (list, tuple)) else (ret, )
|
|
831
857
|
|
|
832
858
|
|
|
833
859
|
class ReduceOps(ReduceScanOpInterface):
|
|
@@ -887,7 +913,7 @@ class ReduceOps(ReduceScanOpInterface):
|
|
|
887
913
|
# Take a scalar
|
|
888
914
|
data = data.item()
|
|
889
915
|
ret.append(self.to_tensor(data, input[i].dtype))
|
|
890
|
-
return ret
|
|
916
|
+
return ret
|
|
891
917
|
|
|
892
918
|
def min_max(self, input, val_reduce_op, idx_reduce_op=None):
|
|
893
919
|
# If input is a tuple, it must be (val, index), and we only take val
|
|
@@ -916,9 +942,9 @@ class ReduceOps(ReduceScanOpInterface):
|
|
|
916
942
|
elif self.combine_fn == tl.standard._argmax_combine_tie_break_left:
|
|
917
943
|
return self.min_max(input[0], val_reduce_op=np.max, idx_reduce_op=np.argmax)
|
|
918
944
|
elif self.combine_fn == tl.standard._elementwise_max:
|
|
919
|
-
return self.min_max(input[0], val_reduce_op=np.
|
|
945
|
+
return self.min_max(input[0], val_reduce_op=np.nanmax, idx_reduce_op=None)
|
|
920
946
|
elif self.combine_fn == tl.standard._elementwise_min:
|
|
921
|
-
return self.min_max(input[0], val_reduce_op=np.
|
|
947
|
+
return self.min_max(input[0], val_reduce_op=np.nanmin, idx_reduce_op=None)
|
|
922
948
|
elif self.combine_fn == tl.standard._sum_combine:
|
|
923
949
|
return self.sum(input[0])
|
|
924
950
|
else:
|
|
@@ -985,7 +1011,7 @@ class ScanOps(ReduceScanOpInterface):
|
|
|
985
1011
|
if self.reverse:
|
|
986
1012
|
for arg in ret:
|
|
987
1013
|
arg.handle.data = np.flip(arg.handle.data, axis=self.axis)
|
|
988
|
-
return
|
|
1014
|
+
return ret
|
|
989
1015
|
|
|
990
1016
|
|
|
991
1017
|
def _patch_reduce_scan():
|
|
@@ -1092,7 +1118,7 @@ def _patch_lang(fn):
|
|
|
1092
1118
|
_patch_builtin(lang.math, interpreter_builder)
|
|
1093
1119
|
_patch_lang_tensor(lang.tensor)
|
|
1094
1120
|
_patch_lang_core(lang)
|
|
1095
|
-
_patch_builtin(tl.core.
|
|
1121
|
+
_patch_builtin(tl.core.tensor_descriptor_base, interpreter_builder)
|
|
1096
1122
|
|
|
1097
1123
|
|
|
1098
1124
|
def _tuple_create(arg, contents):
|
|
@@ -1107,7 +1133,7 @@ def _tuple_create(arg, contents):
|
|
|
1107
1133
|
# TODO: wrap everything in triton tensors
|
|
1108
1134
|
def _implicit_cvt(arg):
|
|
1109
1135
|
if isinstance(arg, int):
|
|
1110
|
-
ty = tl.str_to_ty(triton.runtime.jit.mangle_type(arg))
|
|
1136
|
+
ty = tl.str_to_ty(triton.runtime.jit.mangle_type(arg), None)
|
|
1111
1137
|
dtype = np.int32
|
|
1112
1138
|
if -2**31 <= arg < 2**31:
|
|
1113
1139
|
dtype = np.int32
|
|
@@ -1122,15 +1148,25 @@ def _implicit_cvt(arg):
|
|
|
1122
1148
|
handle = TensorHandle(np.array([arg], dtype=dtype), ty)
|
|
1123
1149
|
return tl.tensor(handle, ty)
|
|
1124
1150
|
if hasattr(arg, "data_ptr"):
|
|
1125
|
-
ty = tl.str_to_ty(triton.runtime.jit.mangle_type(arg))
|
|
1151
|
+
ty = tl.str_to_ty(triton.runtime.jit.mangle_type(arg), None)
|
|
1126
1152
|
handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty)
|
|
1127
1153
|
return tl.tensor(handle, ty)
|
|
1128
1154
|
elif isinstance(arg, tuple):
|
|
1129
1155
|
return _tuple_create(arg, map(_implicit_cvt, arg))
|
|
1156
|
+
elif isinstance(arg, TensorDescriptor):
|
|
1157
|
+
strides = [_implicit_cvt(s) for s in arg.strides]
|
|
1158
|
+
assert arg.strides[-1] == 1
|
|
1159
|
+
strides[-1] = tl.constexpr(1)
|
|
1160
|
+
semantic = TritonSemantic(InterpreterBuilder())
|
|
1161
|
+
return semantic.make_tensor_descriptor(base=_implicit_cvt(arg.base),
|
|
1162
|
+
shape=[_implicit_cvt(s) for s in arg.shape], strides=strides,
|
|
1163
|
+
block_shape=[tl.constexpr(b)
|
|
1164
|
+
for b in arg.block_shape], padding_option=arg.padding)
|
|
1130
1165
|
return arg
|
|
1131
1166
|
|
|
1132
1167
|
|
|
1133
1168
|
interpreter_builder = InterpreterBuilder()
|
|
1169
|
+
interpreter_semantic = TritonSemantic(interpreter_builder)
|
|
1134
1170
|
|
|
1135
1171
|
|
|
1136
1172
|
def _unwrap_tensor(t):
|
|
@@ -1162,6 +1198,14 @@ class GridExecutor:
|
|
|
1162
1198
|
def _to_cpu(arg):
|
|
1163
1199
|
if isinstance(arg, tuple):
|
|
1164
1200
|
return _tuple_create(arg, map(_to_cpu, arg))
|
|
1201
|
+
elif isinstance(arg, TensorDescriptor):
|
|
1202
|
+
return TensorDescriptor(
|
|
1203
|
+
_to_cpu(arg.base),
|
|
1204
|
+
arg.shape,
|
|
1205
|
+
arg.strides,
|
|
1206
|
+
arg.block_shape,
|
|
1207
|
+
arg.padding,
|
|
1208
|
+
)
|
|
1165
1209
|
elif not hasattr(arg, "data_ptr"):
|
|
1166
1210
|
return arg
|
|
1167
1211
|
|
|
@@ -1195,6 +1239,8 @@ class GridExecutor:
|
|
|
1195
1239
|
elif isinstance(arg_dev, tuple):
|
|
1196
1240
|
for (arg_dev, arg_hst) in zip(arg_dev, arg_hst):
|
|
1197
1241
|
_from_cpu(arg_dev, arg_hst)
|
|
1242
|
+
elif isinstance(arg_dev, TensorDescriptor):
|
|
1243
|
+
_from_cpu(arg_dev.base, arg_hst.base)
|
|
1198
1244
|
|
|
1199
1245
|
for arg_dev, arg_hst in zip(args_dev, args_hst):
|
|
1200
1246
|
_from_cpu(arg_dev, arg_hst)
|
|
@@ -1235,6 +1281,8 @@ class GridExecutor:
|
|
|
1235
1281
|
interpreter_builder.set_grid_idx(x, y, z)
|
|
1236
1282
|
self.fn(**args)
|
|
1237
1283
|
except Exception as e:
|
|
1284
|
+
if triton.knobs.compilation.front_end_debugging:
|
|
1285
|
+
raise
|
|
1238
1286
|
raise InterpreterError(repr(e)) from e
|
|
1239
1287
|
# copy arguments back to propagate side-effects
|
|
1240
1288
|
self._restore_args_dev(args_dev, args_hst, kwargs, kwargs_hst)
|
|
@@ -1249,14 +1297,10 @@ class ASTTransformer(ast.NodeTransformer):
|
|
|
1249
1297
|
if len(names) > 1:
|
|
1250
1298
|
raise ValueError("Multiple assignments are not supported")
|
|
1251
1299
|
# Modify the assignment x = value to
|
|
1252
|
-
#
|
|
1300
|
+
# interpreter_semantic.to_tensor(value, False)
|
|
1253
1301
|
node.value = ast.Call(
|
|
1254
|
-
func=ast.Attribute(
|
|
1255
|
-
|
|
1256
|
-
value=ast.Attribute(value=ast.Name(id='triton', ctx=ast.Load()), attr='language', ctx=ast.Load()),
|
|
1257
|
-
attr='semantic', ctx=ast.Load()), attr='to_tensor', ctx=ast.Load()),
|
|
1258
|
-
args=[node.value, ast.Name(id='interpreter_builder', ctx=ast.Load()),
|
|
1259
|
-
ast.Constant(value=False)], keywords=[])
|
|
1302
|
+
func=ast.Attribute(value=ast.Name(id="interpreter_semantic", ctx=ast.Load()), attr="to_tensor",
|
|
1303
|
+
ctx=ast.Load()), args=[node.value, ast.Constant(value=False)], keywords=[])
|
|
1260
1304
|
return node
|
|
1261
1305
|
|
|
1262
1306
|
|
|
@@ -1331,11 +1375,12 @@ class FunctionRewriter:
|
|
|
1331
1375
|
|
|
1332
1376
|
class InterpretedFunction:
|
|
1333
1377
|
# Cache all rewritten functions
|
|
1334
|
-
rewritten_fn = {}
|
|
1378
|
+
rewritten_fn: Dict[Callable, Callable] = {}
|
|
1335
1379
|
|
|
1336
1380
|
def __init__(self, fn, **kwargs) -> None:
|
|
1337
1381
|
self.fn = fn
|
|
1338
1382
|
self.rewriter = FunctionRewriter(fn, **kwargs)
|
|
1383
|
+
self.kwargs = kwargs
|
|
1339
1384
|
|
|
1340
1385
|
def run(*args, **kwargs):
|
|
1341
1386
|
grid = kwargs["grid"]
|