triton-windows 3.2.0.post12__cp312-cp312-win_amd64.whl → 3.3.0a0.post12__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic. Click here for more details.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +3 -3
- triton/_internal_testing.py +59 -4
- triton/_utils.py +35 -0
- triton/backends/amd/compiler.py +121 -74
- triton/backends/amd/driver.py +77 -43
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +28 -49
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +35 -9
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +761 -284
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +9 -3
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +1391 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +3 -3
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +44 -0
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +288 -0
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +110 -14
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +504 -103
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +2 -1
- triton/backends/amd/include/hip/amd_detail/host_defines.h +4 -0
- triton/backends/amd/include/hip/hip_ext.h +4 -2
- triton/backends/amd/include/hip/hip_fp8.h +33 -0
- triton/backends/amd/include/hip/hip_runtime_api.h +375 -33
- triton/backends/amd/include/hip/hip_version.h +3 -3
- triton/backends/amd/include/hip/hiprtc.h +25 -25
- triton/backends/amd/include/hsa/amd_hsa_elf.h +40 -14
- triton/backends/amd/include/hsa/hsa.h +11 -2
- triton/backends/amd/include/hsa/hsa_api_trace.h +30 -17
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +68 -0
- triton/backends/amd/include/hsa/hsa_ext_amd.h +83 -27
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +46 -46
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +416 -0
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +84 -4
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +260 -0
- triton/backends/amd/include/roctracer/hsa_prof_str.h +51 -19
- triton/backends/amd/lib/asanrtl.bc +0 -0
- triton/backends/compiler.py +25 -225
- triton/backends/driver.py +7 -2
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +135 -90
- triton/backends/nvidia/driver.c +0 -1
- triton/backends/nvidia/driver.py +135 -49
- triton/backends/nvidia/include/cuda.h +2162 -241
- triton/backends/nvidia/lib/x64/cuda.lib +0 -0
- triton/compiler/__init__.py +2 -2
- triton/compiler/code_generator.py +334 -231
- triton/compiler/compiler.py +77 -66
- triton/language/__init__.py +22 -5
- triton/language/core.py +448 -74
- triton/language/extra/cuda/_experimental_tma.py +3 -5
- triton/language/math.py +1 -1
- triton/language/random.py +2 -1
- triton/language/semantic.py +206 -52
- triton/language/standard.py +35 -18
- triton/runtime/_allocation.py +32 -0
- triton/runtime/autotuner.py +27 -32
- triton/runtime/build.py +1 -48
- triton/runtime/cache.py +6 -6
- triton/runtime/errors.py +10 -0
- triton/runtime/interpreter.py +179 -45
- triton/runtime/jit.py +149 -190
- triton/testing.py +39 -11
- triton/tools/compile.py +27 -20
- triton/tools/{compile.c → extra/cuda/compile.c} +1 -0
- triton/tools/mxfp.py +301 -0
- {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/METADATA +5 -2
- {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/RECORD +68 -59
- {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/top_level.txt +2 -0
- /triton/tools/{compile.h → extra/cuda/compile.h} +0 -0
- {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/WHEEL +0 -0
triton/backends/compiler.py
CHANGED
|
@@ -1,218 +1,12 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import re
|
|
3
|
-
import hashlib
|
|
4
3
|
import subprocess
|
|
5
|
-
|
|
6
|
-
from abc import ABCMeta, abstractmethod
|
|
4
|
+
import sysconfig
|
|
5
|
+
from abc import ABCMeta, abstractmethod
|
|
7
6
|
from dataclasses import dataclass
|
|
8
|
-
from typing import Dict,
|
|
7
|
+
from typing import Dict, Union
|
|
9
8
|
from types import ModuleType
|
|
10
9
|
|
|
11
|
-
# Table that associates strings to AttrsDescriptor (sub)classes.
|
|
12
|
-
# In this way we can dynamically select the correct class
|
|
13
|
-
# constructor
|
|
14
|
-
_descriptor_table = {}
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
def register_descriptor(cls):
|
|
18
|
-
"""
|
|
19
|
-
Register a descriptor into the descriptor table
|
|
20
|
-
"""
|
|
21
|
-
_descriptor_table[cls.__name__] = cls
|
|
22
|
-
return cls
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
@register_descriptor
|
|
26
|
-
class AttrsDescriptor:
|
|
27
|
-
"""
|
|
28
|
-
This class handles compile-time properties for specific function parameters.
|
|
29
|
-
|
|
30
|
-
Different backends can add more properties to the common ones. The class
|
|
31
|
-
contains two fields:
|
|
32
|
-
|
|
33
|
-
`arg_properties`: a dictionary containing the different compile-time properties for different
|
|
34
|
-
parameters. I.e., the dictionary is a map from property names to parameter indices
|
|
35
|
-
{
|
|
36
|
-
"prop0": (0, 2, 3)
|
|
37
|
-
"prop1": (0, 4, 5)
|
|
38
|
-
}
|
|
39
|
-
Different backends might need different properties on those paraemters to enable
|
|
40
|
-
specific optimizations. The common compile time properties contained in this class
|
|
41
|
-
are :
|
|
42
|
-
- "tt.divisibility", i.e., is the given parameter divisible by 16
|
|
43
|
-
- "tt.equal_to_1", i.e., is the given parameter an integer constant 1
|
|
44
|
-
|
|
45
|
-
`property_values`: a dictionary containing the value of the different compile-time properties, like:
|
|
46
|
-
{
|
|
47
|
-
"prop0": val0
|
|
48
|
-
"prop1": val1
|
|
49
|
-
}
|
|
50
|
-
|
|
51
|
-
`constant_properties`: a set containing the properties that can be used to determine if a parameter is constant
|
|
52
|
-
|
|
53
|
-
"""
|
|
54
|
-
__slots__ = ('divisibility_16', 'equal_to_1', 'arg_properties', 'property_values', 'constant_properties')
|
|
55
|
-
|
|
56
|
-
def __init__(self, params=None, values=None):
|
|
57
|
-
"""
|
|
58
|
-
Initialize the compile-time properties
|
|
59
|
-
|
|
60
|
-
We can initialize the AttrsDescriptor class by passing the list of params
|
|
61
|
-
of the function and their `values`. The function will try to apply the properties
|
|
62
|
-
to the values and save the parameters in the `arg_properties` list. If we don't pass
|
|
63
|
-
either the `params` or the `values` we should initialize the class via an alternative method
|
|
64
|
-
(see `from_dict` or `from_hints`)
|
|
65
|
-
"""
|
|
66
|
-
# Default initialization
|
|
67
|
-
self.arg_properties = {}
|
|
68
|
-
self.property_values = {}
|
|
69
|
-
self.constant_properties = set()
|
|
70
|
-
|
|
71
|
-
self._add_common_properties(params, values)
|
|
72
|
-
self._add_backend_properties(params, values)
|
|
73
|
-
self._init_slots()
|
|
74
|
-
|
|
75
|
-
def _add_common_properties(self, params, values):
|
|
76
|
-
""" Add common compile-time properties """
|
|
77
|
-
self.property_values["tt.divisibility"] = 16
|
|
78
|
-
self.property_values["tt.equal_to"] = 1
|
|
79
|
-
self.constant_properties.add("tt.equal_to")
|
|
80
|
-
|
|
81
|
-
if (params is None) or (values is None):
|
|
82
|
-
return
|
|
83
|
-
|
|
84
|
-
# Compile properties deduction
|
|
85
|
-
assert (len(params) == len(values))
|
|
86
|
-
|
|
87
|
-
# Divisibility property
|
|
88
|
-
self.arg_properties["tt.divisibility"] = [
|
|
89
|
-
param.num for param, arg in zip(params, values) if AttrsDescriptor.is_divisible_by_16(arg)
|
|
90
|
-
and not param.do_not_specialize and not param.do_not_specialize_on_alignment
|
|
91
|
-
]
|
|
92
|
-
|
|
93
|
-
# Equal to 1 property
|
|
94
|
-
self.arg_properties["tt.equal_to"] = [
|
|
95
|
-
param.num
|
|
96
|
-
for param, arg in zip(params, values)
|
|
97
|
-
if AttrsDescriptor.is_equal_to_1(arg) and not param.do_not_specialize
|
|
98
|
-
]
|
|
99
|
-
|
|
100
|
-
def _add_backend_properties(self, params=None, values=None):
|
|
101
|
-
""" This method is for different subclasses to implement their own compile-time properties """
|
|
102
|
-
pass
|
|
103
|
-
|
|
104
|
-
def _init_slots(self):
|
|
105
|
-
""" Initialize the slots of this class """
|
|
106
|
-
for name, val in self.arg_properties.items():
|
|
107
|
-
setattr(self, name.removeprefix('tt.') + '_' + str(self.property_values[name]), val)
|
|
108
|
-
|
|
109
|
-
def get_fn_attrs(self) -> Dict:
|
|
110
|
-
"""
|
|
111
|
-
Get the function attributes as a dictionary.
|
|
112
|
-
|
|
113
|
-
The returned dictionary will look like :
|
|
114
|
-
{
|
|
115
|
-
"arg0" : [(prop_name00, val00), (prop_name01, val01), ...)]}
|
|
116
|
-
"arg1" : [(prop_name10, val10), (prop_name11, val11), ...)]}
|
|
117
|
-
}
|
|
118
|
-
"""
|
|
119
|
-
attrs = {}
|
|
120
|
-
for prop_name, arg_set in self.arg_properties.items():
|
|
121
|
-
prop_val = self.property_values[prop_name]
|
|
122
|
-
for arg in arg_set:
|
|
123
|
-
attrs[arg] = attrs.get(arg, []) + [(prop_name, prop_val)]
|
|
124
|
-
return attrs
|
|
125
|
-
|
|
126
|
-
def get_constants(self) -> Dict:
|
|
127
|
-
""" Return a mapping of constant parameters to their values """
|
|
128
|
-
constants = {}
|
|
129
|
-
for prop_name in self.constant_properties:
|
|
130
|
-
for p in self.arg_properties.get(prop_name, []):
|
|
131
|
-
constants[p] = self.property_values[prop_name]
|
|
132
|
-
return constants
|
|
133
|
-
|
|
134
|
-
def filter_out_constants(self):
|
|
135
|
-
""" Return the same object, without properties marked as constants"""
|
|
136
|
-
import copy
|
|
137
|
-
c = copy.deepcopy(self)
|
|
138
|
-
for prop_name in c.constant_properties:
|
|
139
|
-
c.arg_properties.pop(prop_name, None)
|
|
140
|
-
c.property_values.pop(prop_name, None)
|
|
141
|
-
c.constant_properties = {}
|
|
142
|
-
return c
|
|
143
|
-
|
|
144
|
-
def hash(self):
|
|
145
|
-
values = [sorted(self.arg_properties.values())]
|
|
146
|
-
values += [sorted(self.property_values.values())]
|
|
147
|
-
values += [sorted(self.constant_properties)]
|
|
148
|
-
key = str(values)
|
|
149
|
-
return hashlib.sha256(key.encode("utf-8")).hexdigest()
|
|
150
|
-
|
|
151
|
-
def to_dict(self):
|
|
152
|
-
"""
|
|
153
|
-
Store the fields of this class in a serializable dictionary
|
|
154
|
-
"""
|
|
155
|
-
# We need to only store the `arg_properties` field. To initialize the
|
|
156
|
-
# other fields we relay on the class type. We store it as a string in
|
|
157
|
-
# the dictionary so that we can use it to invoke the appropriate
|
|
158
|
-
# (sub)class constructor in the `from_dict` method.
|
|
159
|
-
return {"arg_properties": self.arg_properties, "cls": type(self).__name__}
|
|
160
|
-
|
|
161
|
-
@staticmethod
|
|
162
|
-
def from_dict(data):
|
|
163
|
-
"""
|
|
164
|
-
Create the object from a serializable dictionary
|
|
165
|
-
"""
|
|
166
|
-
attrs_descriptor = _descriptor_table[data["cls"]]()
|
|
167
|
-
for prop_name, param_ids in data["arg_properties"].items():
|
|
168
|
-
attrs_descriptor.arg_properties[prop_name] = param_ids
|
|
169
|
-
attrs_descriptor._init_slots()
|
|
170
|
-
return attrs_descriptor
|
|
171
|
-
|
|
172
|
-
@classmethod
|
|
173
|
-
def from_hints(cls, hints: List[Tuple[int, int]]):
|
|
174
|
-
"""
|
|
175
|
-
Create the class from a set of hints that are passed in.
|
|
176
|
-
|
|
177
|
-
Instead of deducing the properties from a list of paramaters and values,
|
|
178
|
-
the user can pass in a list of `hints=[(param_index, val)]` and if `val`
|
|
179
|
-
matches one of the values of the properties (e.g., `prop_val[prop0]`),
|
|
180
|
-
then we insert `param_index` into the correct list (e.g., in
|
|
181
|
-
`arg_properties[prop0]`)
|
|
182
|
-
"""
|
|
183
|
-
attrs_descriptor = cls()
|
|
184
|
-
for prop_name, prop_val in attrs_descriptor.property_values.items():
|
|
185
|
-
attrs_descriptor.arg_properties[prop_name] = [i for i, h in hints.items() if h == prop_val]
|
|
186
|
-
attrs_descriptor._init_slots()
|
|
187
|
-
return attrs_descriptor
|
|
188
|
-
|
|
189
|
-
@staticmethod
|
|
190
|
-
def is_divisible_by_16(x):
|
|
191
|
-
""" Return if the argument is a multiple of 16"""
|
|
192
|
-
if hasattr(x, "data_ptr"):
|
|
193
|
-
return x.data_ptr() % 16 == 0
|
|
194
|
-
elif isinstance(x, int):
|
|
195
|
-
return x % 16 == 0
|
|
196
|
-
if x is None:
|
|
197
|
-
return True
|
|
198
|
-
return False
|
|
199
|
-
|
|
200
|
-
@staticmethod
|
|
201
|
-
def is_equal_to_1(x):
|
|
202
|
-
""" Return if the argument is a constant 1"""
|
|
203
|
-
return True if isinstance(x, int) and not isinstance(x, bool) and x == 1 else False
|
|
204
|
-
|
|
205
|
-
@staticmethod
|
|
206
|
-
def get_property_key(val, align):
|
|
207
|
-
if align and AttrsDescriptor.is_divisible_by_16(val):
|
|
208
|
-
return "D"
|
|
209
|
-
if AttrsDescriptor.is_equal_to_1(val):
|
|
210
|
-
return "1"
|
|
211
|
-
return "N"
|
|
212
|
-
|
|
213
|
-
def __repr__(self):
|
|
214
|
-
return f"AttrsDescriptor.from_dict({self.to_dict()!r})"
|
|
215
|
-
|
|
216
10
|
|
|
217
11
|
@dataclass(frozen=True)
|
|
218
12
|
class GPUTarget(object):
|
|
@@ -231,22 +25,23 @@ class BaseBackend(metaclass=ABCMeta):
|
|
|
231
25
|
|
|
232
26
|
@staticmethod
|
|
233
27
|
def _path_to_binary(binary: str):
|
|
28
|
+
binary += sysconfig.get_config_var("EXE")
|
|
234
29
|
base_dir = os.path.join(os.path.dirname(__file__), os.pardir)
|
|
235
30
|
paths = [
|
|
236
31
|
os.environ.get(f"TRITON_{binary.upper()}_PATH", ""),
|
|
237
32
|
os.path.join(base_dir, "third_party", "cuda", "bin", binary),
|
|
238
33
|
]
|
|
239
|
-
for
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT)
|
|
34
|
+
for path in paths:
|
|
35
|
+
if os.path.exists(path) and os.path.isfile(path):
|
|
36
|
+
result = subprocess.check_output([path, "--version"], stderr=subprocess.STDOUT)
|
|
243
37
|
if result is not None:
|
|
244
38
|
version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE)
|
|
245
39
|
if version is not None:
|
|
246
|
-
return
|
|
40
|
+
return path, version.group(1)
|
|
247
41
|
raise RuntimeError(f"Cannot find {binary}")
|
|
248
42
|
|
|
249
|
-
@
|
|
43
|
+
@classmethod
|
|
44
|
+
@abstractmethod
|
|
250
45
|
def supports_target(target: GPUTarget):
|
|
251
46
|
raise NotImplementedError
|
|
252
47
|
|
|
@@ -289,16 +84,21 @@ class BaseBackend(metaclass=ABCMeta):
|
|
|
289
84
|
"""
|
|
290
85
|
raise NotImplementedError
|
|
291
86
|
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
return
|
|
87
|
+
@staticmethod
|
|
88
|
+
def parse_attr(desc):
|
|
89
|
+
assert isinstance(desc, str)
|
|
90
|
+
ret = []
|
|
91
|
+
if "D" in desc:
|
|
92
|
+
ret += [["tt.divisibility", 16]]
|
|
93
|
+
return ret
|
|
299
94
|
|
|
300
|
-
|
|
95
|
+
@staticmethod
|
|
96
|
+
def get_arg_specialization(arg, ty, **kwargs):
|
|
301
97
|
"""
|
|
302
|
-
Return
|
|
98
|
+
Return a string unique to each possible specialization of the argument
|
|
303
99
|
"""
|
|
304
|
-
|
|
100
|
+
if ty == "int" and arg % 16 == 0 and kwargs.get("align", False):
|
|
101
|
+
return "D"
|
|
102
|
+
if ty == "tensor" and arg.data_ptr() % 16 == 0 and kwargs.get("align", False):
|
|
103
|
+
return "D"
|
|
104
|
+
return ""
|
triton/backends/driver.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from abc import ABCMeta, abstractmethod
|
|
1
|
+
from abc import ABCMeta, abstractmethod
|
|
2
2
|
from typing import Callable, List, Protocol, Sequence
|
|
3
3
|
|
|
4
4
|
|
|
@@ -10,7 +10,8 @@ class Benchmarker(Protocol):
|
|
|
10
10
|
|
|
11
11
|
class DriverBase(metaclass=ABCMeta):
|
|
12
12
|
|
|
13
|
-
@
|
|
13
|
+
@classmethod
|
|
14
|
+
@abstractmethod
|
|
14
15
|
def is_active(self):
|
|
15
16
|
pass
|
|
16
17
|
|
|
@@ -18,6 +19,10 @@ class DriverBase(metaclass=ABCMeta):
|
|
|
18
19
|
def get_current_target(self):
|
|
19
20
|
pass
|
|
20
21
|
|
|
22
|
+
@abstractmethod
|
|
23
|
+
def get_active_torch_device(self):
|
|
24
|
+
pass
|
|
25
|
+
|
|
21
26
|
@abstractmethod
|
|
22
27
|
def get_benchmarker(self) -> Benchmarker:
|
|
23
28
|
"""
|
|
Binary file
|