triton-windows 3.2.0.post11__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic. Click here for more details.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +85 -0
- triton/_internal_testing.py +123 -0
- triton/backends/__init__.py +50 -0
- triton/backends/amd/compiler.py +368 -0
- triton/backends/amd/driver.c +211 -0
- triton/backends/amd/driver.py +512 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +358 -0
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +1031 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +1612 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +1337 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +293 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +32 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +174 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +829 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +1809 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +108 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +124 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +405 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +196 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +565 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +2226 -0
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +104 -0
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +244 -0
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +494 -0
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +30 -0
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +133 -0
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +218 -0
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +67 -0
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +50 -0
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +26 -0
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +137 -0
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +1350 -0
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +101 -0
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +242 -0
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +254 -0
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +96 -0
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +100 -0
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +10169 -0
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +77 -0
- triton/backends/amd/include/hip/amd_detail/host_defines.h +180 -0
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +102 -0
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +798 -0
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +698 -0
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +177 -0
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +107 -0
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +491 -0
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +478 -0
- triton/backends/amd/include/hip/channel_descriptor.h +39 -0
- triton/backends/amd/include/hip/device_functions.h +38 -0
- triton/backends/amd/include/hip/driver_types.h +468 -0
- triton/backends/amd/include/hip/hip_bf16.h +36 -0
- triton/backends/amd/include/hip/hip_bfloat16.h +44 -0
- triton/backends/amd/include/hip/hip_common.h +100 -0
- triton/backends/amd/include/hip/hip_complex.h +38 -0
- triton/backends/amd/include/hip/hip_cooperative_groups.h +46 -0
- triton/backends/amd/include/hip/hip_deprecated.h +95 -0
- triton/backends/amd/include/hip/hip_ext.h +159 -0
- triton/backends/amd/include/hip/hip_fp16.h +36 -0
- triton/backends/amd/include/hip/hip_gl_interop.h +32 -0
- triton/backends/amd/include/hip/hip_hcc.h +24 -0
- triton/backends/amd/include/hip/hip_math_constants.h +36 -0
- triton/backends/amd/include/hip/hip_profile.h +27 -0
- triton/backends/amd/include/hip/hip_runtime.h +75 -0
- triton/backends/amd/include/hip/hip_runtime_api.h +8919 -0
- triton/backends/amd/include/hip/hip_texture_types.h +29 -0
- triton/backends/amd/include/hip/hip_vector_types.h +41 -0
- triton/backends/amd/include/hip/hip_version.h +17 -0
- triton/backends/amd/include/hip/hiprtc.h +421 -0
- triton/backends/amd/include/hip/library_types.h +78 -0
- triton/backends/amd/include/hip/math_functions.h +42 -0
- triton/backends/amd/include/hip/surface_types.h +63 -0
- triton/backends/amd/include/hip/texture_types.h +194 -0
- triton/backends/amd/include/hsa/Brig.h +1131 -0
- triton/backends/amd/include/hsa/amd_hsa_common.h +91 -0
- triton/backends/amd/include/hsa/amd_hsa_elf.h +436 -0
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +269 -0
- triton/backends/amd/include/hsa/amd_hsa_queue.h +109 -0
- triton/backends/amd/include/hsa/amd_hsa_signal.h +80 -0
- triton/backends/amd/include/hsa/hsa.h +5729 -0
- triton/backends/amd/include/hsa/hsa_amd_tool.h +91 -0
- triton/backends/amd/include/hsa/hsa_api_trace.h +566 -0
- triton/backends/amd/include/hsa/hsa_ext_amd.h +3090 -0
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +531 -0
- triton/backends/amd/include/hsa/hsa_ext_image.h +1454 -0
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +488 -0
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +667 -0
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +107 -0
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +4435 -0
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +1467 -0
- triton/backends/amd/include/roctracer/hsa_prof_str.h +3027 -0
- triton/backends/amd/include/roctracer/roctracer.h +779 -0
- triton/backends/amd/include/roctracer/roctracer_ext.h +81 -0
- triton/backends/amd/include/roctracer/roctracer_hcc.h +24 -0
- triton/backends/amd/include/roctracer/roctracer_hip.h +37 -0
- triton/backends/amd/include/roctracer/roctracer_hsa.h +112 -0
- triton/backends/amd/include/roctracer/roctracer_plugin.h +137 -0
- triton/backends/amd/include/roctracer/roctracer_roctx.h +67 -0
- triton/backends/amd/include/roctracer/roctx.h +229 -0
- triton/backends/amd/lib/ockl.bc +0 -0
- triton/backends/amd/lib/ocml.bc +0 -0
- triton/backends/compiler.py +304 -0
- triton/backends/driver.py +48 -0
- triton/backends/nvidia/__init__.py +0 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +410 -0
- triton/backends/nvidia/driver.c +451 -0
- triton/backends/nvidia/driver.py +524 -0
- triton/backends/nvidia/include/cuda.h +24359 -0
- triton/backends/nvidia/lib/libdevice.10.bc +0 -0
- triton/backends/nvidia/lib/x64/cuda.lib +0 -0
- triton/compiler/__init__.py +4 -0
- triton/compiler/code_generator.py +1303 -0
- triton/compiler/compiler.py +430 -0
- triton/compiler/errors.py +51 -0
- triton/compiler/make_launcher.py +0 -0
- triton/errors.py +5 -0
- triton/language/__init__.py +294 -0
- triton/language/_utils.py +21 -0
- triton/language/core.py +2694 -0
- triton/language/extra/__init__.py +26 -0
- triton/language/extra/cuda/__init__.py +13 -0
- triton/language/extra/cuda/_experimental_tma.py +108 -0
- triton/language/extra/cuda/libdevice.py +1629 -0
- triton/language/extra/cuda/utils.py +109 -0
- triton/language/extra/hip/__init__.py +3 -0
- triton/language/extra/hip/libdevice.py +475 -0
- triton/language/extra/libdevice.py +786 -0
- triton/language/math.py +250 -0
- triton/language/random.py +207 -0
- triton/language/semantic.py +1796 -0
- triton/language/standard.py +452 -0
- triton/runtime/__init__.py +23 -0
- triton/runtime/autotuner.py +408 -0
- triton/runtime/build.py +111 -0
- triton/runtime/cache.py +295 -0
- triton/runtime/driver.py +60 -0
- triton/runtime/errors.py +26 -0
- triton/runtime/interpreter.py +1235 -0
- triton/runtime/jit.py +951 -0
- triton/testing.py +511 -0
- triton/tools/__init__.py +0 -0
- triton/tools/build_extern.py +365 -0
- triton/tools/compile.c +67 -0
- triton/tools/compile.h +14 -0
- triton/tools/compile.py +155 -0
- triton/tools/disasm.py +144 -0
- triton/tools/experimental_descriptor.py +32 -0
- triton/tools/link.py +322 -0
- triton/windows_utils.py +375 -0
- triton_windows-3.2.0.post11.dist-info/METADATA +39 -0
- triton_windows-3.2.0.post11.dist-info/RECORD +154 -0
- triton_windows-3.2.0.post11.dist-info/WHEEL +5 -0
- triton_windows-3.2.0.post11.dist-info/top_level.txt +12 -0
|
@@ -0,0 +1,304 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import re
|
|
3
|
+
import hashlib
|
|
4
|
+
import subprocess
|
|
5
|
+
|
|
6
|
+
from abc import ABCMeta, abstractmethod, abstractclassmethod
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from typing import Dict, List, Tuple, Union
|
|
9
|
+
from types import ModuleType
|
|
10
|
+
|
|
11
|
+
# Registry mapping a class name (string) to the AttrsDescriptor (sub)class
# carrying that name, so the right constructor can be picked dynamically
# from a serialized class name.
_descriptor_table = {}


def register_descriptor(cls):
    """
    Class decorator: record `cls` in the descriptor table under `cls.__name__`.

    The class itself is returned unchanged so the function can be used as a
    decorator. Registering a second class with the same name replaces the
    earlier entry.
    """
    key = cls.__name__
    _descriptor_table[key] = cls
    return cls
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@register_descriptor
class AttrsDescriptor:
    """
    This class handles compile-time properties for specific function parameters.

    Different backends can add more properties to the common ones. The class
    contains three fields:

    `arg_properties`: a dictionary containing the different compile-time properties for different
        parameters. I.e., the dictionary is a map from property names to parameter indices
        {
        "prop0": (0, 2, 3)
        "prop1": (0, 4, 5)
        }
        Different backends might need different properties on those parameters to enable
        specific optimizations. The common compile time properties contained in this class
        are :
        - "tt.divisibility", i.e., is the given parameter divisible by 16
        - "tt.equal_to", i.e., is the given parameter an integer constant 1

    `property_values`: a dictionary containing the value of the different compile-time properties, like:
        {
            "prop0": val0
            "prop1": val1
        }

    `constant_properties`: a set containing the properties that can be used to determine if a parameter is constant

    """
    __slots__ = ('divisibility_16', 'equal_to_1', 'arg_properties', 'property_values', 'constant_properties')

    def __init__(self, params=None, values=None):
        """
        Initialize the compile-time properties

        We can initialize the AttrsDescriptor class by passing the list of params
        of the function and their `values`. The function will try to apply the properties
        to the values and save the parameters in the `arg_properties` list. If we don't pass
        either the `params` or the `values` we should initialize the class via an alternative method
        (see `from_dict` or `from_hints`)
        """
        # Default initialization
        self.arg_properties = {}
        self.property_values = {}
        self.constant_properties = set()

        self._add_common_properties(params, values)
        self._add_backend_properties(params, values)
        self._init_slots()

    def _add_common_properties(self, params, values):
        """
        Add common compile-time properties

        The property values (16 for "tt.divisibility", 1 for "tt.equal_to") are
        always registered. The per-parameter lists are only deduced when both
        `params` and `values` are given.
        """
        self.property_values["tt.divisibility"] = 16
        self.property_values["tt.equal_to"] = 1
        self.constant_properties.add("tt.equal_to")

        if (params is None) or (values is None):
            return

        # Compile properties deduction
        assert (len(params) == len(values))

        # Divisibility property
        self.arg_properties["tt.divisibility"] = [
            param.num for param, arg in zip(params, values) if AttrsDescriptor.is_divisible_by_16(arg)
            and not param.do_not_specialize and not param.do_not_specialize_on_alignment
        ]

        # Equal to 1 property
        self.arg_properties["tt.equal_to"] = [
            param.num
            for param, arg in zip(params, values)
            if AttrsDescriptor.is_equal_to_1(arg) and not param.do_not_specialize
        ]

    def _add_backend_properties(self, params=None, values=None):
        """ This method is for different subclasses to implement their own compile-time properties """
        pass

    def _init_slots(self):
        """
        Initialize the slots of this class

        Every entry of `arg_properties` is mirrored into an attribute named
        `<property name without the 'tt.' prefix>_<property value>`, e.g.
        "tt.divisibility" with value 16 becomes `divisibility_16`.
        """
        for name, val in self.arg_properties.items():
            setattr(self, name.removeprefix('tt.') + '_' + str(self.property_values[name]), val)

    def get_fn_attrs(self) -> Dict:
        """
        Get the function attributes as a dictionary.

        The returned dictionary will look like :
            {
            "arg0" : [(prop_name00, val00), (prop_name01, val01), ...],
            "arg1" : [(prop_name10, val10), (prop_name11, val11), ...],
            }
        """
        attrs = {}
        for prop_name, arg_set in self.arg_properties.items():
            prop_val = self.property_values[prop_name]
            for arg in arg_set:
                attrs.setdefault(arg, []).append((prop_name, prop_val))
        return attrs

    def get_constants(self) -> Dict:
        """ Return a mapping of constant parameters to their values """
        constants = {}
        for prop_name in self.constant_properties:
            for p in self.arg_properties.get(prop_name, []):
                constants[p] = self.property_values[prop_name]
        return constants

    def filter_out_constants(self):
        """ Return a deep copy of this object, without properties marked as constants"""
        import copy
        c = copy.deepcopy(self)
        for prop_name in c.constant_properties:
            c.arg_properties.pop(prop_name, None)
            c.property_values.pop(prop_name, None)
        # Keep the field a set (as in `__init__`); `{}` would silently turn it into a dict.
        c.constant_properties = set()
        return c

    def hash(self):
        """
        Return a SHA-256 hex digest of this descriptor.

        The digest is computed over the string form of the sorted values of
        `arg_properties`, the sorted values of `property_values` and the sorted
        `constant_properties`.
        """
        values = [sorted(self.arg_properties.values())]
        values += [sorted(self.property_values.values())]
        values += [sorted(self.constant_properties)]
        key = str(values)
        return hashlib.sha256(key.encode("utf-8")).hexdigest()

    def to_dict(self):
        """
        Store the fields of this class in a serializable dictionary
        """
        # We need to only store the `arg_properties` field. To initialize the
        # other fields we rely on the class type. We store it as a string in
        # the dictionary so that we can use it to invoke the appropriate
        # (sub)class constructor in the `from_dict` method.
        return {"arg_properties": self.arg_properties, "cls": type(self).__name__}

    @staticmethod
    def from_dict(data):
        """
        Create the object from a serializable dictionary

        `data["cls"]` selects the registered (sub)class to instantiate and
        `data["arg_properties"]` provides the per-property parameter indices.
        """
        attrs_descriptor = _descriptor_table[data["cls"]]()
        for prop_name, param_ids in data["arg_properties"].items():
            attrs_descriptor.arg_properties[prop_name] = param_ids
        attrs_descriptor._init_slots()
        return attrs_descriptor

    @classmethod
    def from_hints(cls, hints: Union[Dict[int, int], List[Tuple[int, int]]]):
        """
        Create the class from a set of hints that are passed in.

        Instead of deducing the properties from a list of parameters and values,
        the user can pass in hints, either as a mapping `{param_index: val}` or as
        a list `[(param_index, val)]`. If `val` matches one of the values of the
        properties (e.g., `prop_val[prop0]`), then we insert `param_index` into the
        correct list (e.g., in `arg_properties[prop0]`)
        """
        # The original code only worked with mappings (it called `.items()`),
        # while the annotation and docstring promised a list of pairs: accept both.
        hint_pairs = list(hints.items()) if hasattr(hints, "items") else list(hints)
        attrs_descriptor = cls()
        for prop_name, prop_val in attrs_descriptor.property_values.items():
            attrs_descriptor.arg_properties[prop_name] = [i for i, h in hint_pairs if h == prop_val]
        attrs_descriptor._init_slots()
        return attrs_descriptor

    @staticmethod
    def is_divisible_by_16(x):
        """
        Return if the argument is a multiple of 16

        Objects exposing `data_ptr()` are judged by that value, integers by
        their own value, `None` counts as divisible and anything else does not.
        """
        if hasattr(x, "data_ptr"):
            return x.data_ptr() % 16 == 0
        elif isinstance(x, int):
            return x % 16 == 0
        if x is None:
            return True
        return False

    @staticmethod
    def is_equal_to_1(x):
        """ Return if the argument is a constant 1 (booleans are excluded)"""
        return isinstance(x, int) and not isinstance(x, bool) and x == 1

    @staticmethod
    def get_property_key(val, align):
        """
        Return a one-character key describing `val`.

        "D" when `align` is truthy and `val` is divisible by 16, otherwise "1"
        when `val` is the integer constant 1, otherwise "N".
        """
        if align and AttrsDescriptor.is_divisible_by_16(val):
            return "D"
        if AttrsDescriptor.is_equal_to_1(val):
            return "1"
        return "N"

    def __repr__(self):
        return f"AttrsDescriptor.from_dict({self.to_dict()!r})"
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
@dataclass(frozen=True)
class GPUTarget:
    """Immutable description of the GPU a kernel is compiled for."""

    # Name of the target backend, e.g., cuda, hip
    backend: str
    # Architecture of the target, e.g., 90 (for cuda compute capability), gfx940 (for hip)
    arch: Union[int, str]
    # Warp size of the target
    warp_size: int
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
class BaseBackend(metaclass=ABCMeta):
    """
    Abstract base class of a compiler backend bound to one `GPUTarget`.

    Subclasses must implement `supports_target`, `hash`, `parse_options`,
    `add_stages`, `load_dialects` and `get_module_map`.
    """

    def __init__(self, target: "GPUTarget") -> None:
        """Store `target` and assert that this backend supports it."""
        self.target = target
        assert self.supports_target(target)

    @staticmethod
    def _path_to_binary(binary: str):
        """
        Locate the tool `binary` and return `(path, version)`.

        Two candidates are tried in order: the value of the environment variable
        `TRITON_<BINARY>_PATH` and `third_party/cuda/bin/<binary>` under the parent
        of this file's directory. For a candidate, only the text before the first
        space is treated as the executable; it is run with `--version` and the
        first `release X.Y` match in its output gives the version. The returned
        path is the full candidate string.

        Raises:
            RuntimeError: if no candidate exists or reports a version.
        """
        base_dir = os.path.join(os.path.dirname(__file__), os.pardir)
        candidates = [
            os.environ.get(f"TRITON_{binary.upper()}_PATH", ""),
            os.path.join(base_dir, "third_party", "cuda", "bin", binary),
        ]
        for candidate in candidates:
            executable = candidate.split(" ")[0]
            if not os.path.isfile(executable):
                continue
            output = subprocess.check_output([executable, "--version"], stderr=subprocess.STDOUT)
            version = re.search(r".*release (\d+\.\d+).*", output.decode("utf-8"), flags=re.MULTILINE)
            if version is not None:
                return candidate, version.group(1)
        raise RuntimeError(f"Cannot find {binary}")

    # `abstractclassmethod` is deprecated and, combined with a signature that has
    # no `cls` parameter, could not be called with a target. The method takes only
    # the target, so it is declared as an abstract static method.
    @staticmethod
    @abstractmethod
    def supports_target(target: "GPUTarget"):
        """Return whether this backend can compile for `target`."""
        raise NotImplementedError

    @abstractmethod
    def hash(self) -> str:
        """Returns a unique identifier for this backend"""
        raise NotImplementedError

    @abstractmethod
    def parse_options(self, options: dict) -> object:
        """
        Converts an `options` dictionary into an arbitrary object and returns it.
        This function may contain target-specific heuristics and check the legality of the provided options
        """
        raise NotImplementedError

    @abstractmethod
    def add_stages(self, stages: dict, options: object) -> None:
        """
        Populates `stages` dictionary with entries of the form:
        ir_name [str] => Function[(src: str, metadata: dict) -> str|bytes]
        The value of each entry may populate a `metadata` dictionary.
        Stages will be run sequentially (in insertion order) and can communicate using `metadata`.
        All stages are expected to return a `str` object, except for the last stage which returns
        a `bytes` object for execution by the launcher.
        """
        raise NotImplementedError

    @abstractmethod
    def load_dialects(self, context):
        """
        Load additional MLIR dialects into the provided `context`
        """
        raise NotImplementedError

    @abstractmethod
    def get_module_map(self) -> Dict[str, ModuleType]:
        """
        Return a map of interface modules to their device-specific implementations
        """
        raise NotImplementedError

    def get_attrs_descriptor(self, params, args):
        """
        Return an attribute descriptor: given a set of parameters and arguments
        the descriptor stores a set of compile time properties that can improve code
        generation. Different backends might benefit from different properties
        """
        return AttrsDescriptor(params, args)

    def compute_spec_key(self, arg, align):
        """
        Return the ascii key for a given argument with a given set of properties
        """
        return AttrsDescriptor.get_property_key(arg, align)
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
from abc import ABCMeta, abstractmethod, abstractclassmethod
|
|
2
|
+
from typing import Callable, List, Protocol, Sequence
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class Benchmarker(Protocol):
    """
    Structural type of a benchmarking callable.

    A benchmarker is called with the kernel invocation to measure, a
    keyword-only list of `quantiles` and optional extra keyword arguments, and
    is annotated as returning a sequence of floats.
    """

    def __call__(self, kernel_call: Callable, *, quantiles: List[float], **kwargs) -> Sequence[float]:
        ...
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class DriverBase(metaclass=ABCMeta):
    """
    Abstract base class of a driver.

    Subclasses must implement `is_active`, `get_current_target` and
    `get_benchmarker`.
    """

    # `abstractclassmethod` is deprecated; stacking `classmethod` on top of
    # `abstractmethod` is the supported spelling. The first parameter receives
    # the class, so it is named `cls`.
    @classmethod
    @abstractmethod
    def is_active(cls):
        """Abstract hook with no base implementation; the name suggests it reports driver availability."""
        pass

    @abstractmethod
    def get_current_target(self):
        """Abstract hook with no base implementation; the name suggests it returns the current target."""
        pass

    @abstractmethod
    def get_benchmarker(self) -> "Benchmarker":
        """
        Return the benchmarking function that this backend should use by default.
        """
        raise NotImplementedError

    def __init__(self) -> None:
        pass
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class GPUDriver(DriverBase):
    """Driver base that takes its device and stream accessors from `torch.cuda`."""

    def __init__(self):
        """Bind the capability, stream and device accessors as instance attributes."""
        # TODO: support other frameworks than torch
        import torch

        cuda = torch.cuda
        self.get_device_capability = cuda.get_device_capability
        try:
            from torch._C import _cuda_getCurrentRawStream
        except ImportError:
            # Fall back to the public API when the private raw-stream getter cannot be imported.
            self.get_current_stream = lambda idx: torch.cuda.current_stream(idx).cuda_stream
        else:
            self.get_current_stream = _cuda_getCurrentRawStream
        self.get_current_device = cuda.current_device
        self.set_current_device = cuda.set_device

    # TODO: remove once TMA is cleaned up
    def assemble_tensormap_to_arg(self, tensormaps_info, args):
        """Return `args` unchanged; `tensormaps_info` is ignored."""
        return args
|
|
File without changes
|
|
Binary file
|