triton-windows 3.2.0.post12__cp39-cp39-win_amd64.whl → 3.3.0a0.post12__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. Click here for more details.

Files changed (68) hide show
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +3 -3
  3. triton/_internal_testing.py +59 -4
  4. triton/_utils.py +35 -0
  5. triton/backends/amd/compiler.py +121 -74
  6. triton/backends/amd/driver.py +77 -43
  7. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +28 -49
  8. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +35 -9
  9. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +761 -284
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +9 -3
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +1391 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +3 -3
  13. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +44 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +288 -0
  15. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +110 -14
  16. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +504 -103
  17. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +2 -1
  18. triton/backends/amd/include/hip/amd_detail/host_defines.h +4 -0
  19. triton/backends/amd/include/hip/hip_ext.h +4 -2
  20. triton/backends/amd/include/hip/hip_fp8.h +33 -0
  21. triton/backends/amd/include/hip/hip_runtime_api.h +375 -33
  22. triton/backends/amd/include/hip/hip_version.h +3 -3
  23. triton/backends/amd/include/hip/hiprtc.h +25 -25
  24. triton/backends/amd/include/hsa/amd_hsa_elf.h +40 -14
  25. triton/backends/amd/include/hsa/hsa.h +11 -2
  26. triton/backends/amd/include/hsa/hsa_api_trace.h +30 -17
  27. triton/backends/amd/include/hsa/hsa_api_trace_version.h +68 -0
  28. triton/backends/amd/include/hsa/hsa_ext_amd.h +83 -27
  29. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +46 -46
  30. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +416 -0
  31. triton/backends/amd/include/roctracer/hip_ostream_ops.h +84 -4
  32. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +260 -0
  33. triton/backends/amd/include/roctracer/hsa_prof_str.h +51 -19
  34. triton/backends/amd/lib/asanrtl.bc +0 -0
  35. triton/backends/compiler.py +25 -225
  36. triton/backends/driver.py +7 -2
  37. triton/backends/nvidia/bin/ptxas.exe +0 -0
  38. triton/backends/nvidia/compiler.py +135 -90
  39. triton/backends/nvidia/driver.c +0 -1
  40. triton/backends/nvidia/driver.py +135 -49
  41. triton/backends/nvidia/include/cuda.h +2162 -241
  42. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  43. triton/compiler/__init__.py +2 -2
  44. triton/compiler/code_generator.py +334 -231
  45. triton/compiler/compiler.py +77 -66
  46. triton/language/__init__.py +22 -5
  47. triton/language/core.py +448 -74
  48. triton/language/extra/cuda/_experimental_tma.py +3 -5
  49. triton/language/math.py +1 -1
  50. triton/language/random.py +2 -1
  51. triton/language/semantic.py +206 -52
  52. triton/language/standard.py +35 -18
  53. triton/runtime/_allocation.py +32 -0
  54. triton/runtime/autotuner.py +27 -32
  55. triton/runtime/build.py +1 -48
  56. triton/runtime/cache.py +6 -6
  57. triton/runtime/errors.py +10 -0
  58. triton/runtime/interpreter.py +179 -45
  59. triton/runtime/jit.py +149 -190
  60. triton/testing.py +39 -11
  61. triton/tools/compile.py +27 -20
  62. triton/tools/{compile.c → extra/cuda/compile.c} +1 -0
  63. triton/tools/mxfp.py +301 -0
  64. {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/METADATA +5 -2
  65. {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/RECORD +68 -59
  66. {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/top_level.txt +2 -0
  67. /triton/tools/{compile.h → extra/cuda/compile.h} +0 -0
  68. {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/WHEEL +0 -0
@@ -1,218 +1,12 @@
1
1
  import os
2
2
  import re
3
- import hashlib
4
3
  import subprocess
5
-
6
- from abc import ABCMeta, abstractmethod, abstractclassmethod
4
+ import sysconfig
5
+ from abc import ABCMeta, abstractmethod
7
6
  from dataclasses import dataclass
8
- from typing import Dict, List, Tuple, Union
7
+ from typing import Dict, Union
9
8
  from types import ModuleType
10
9
 
11
- # Table that associates strings to AttrsDescriptor (sub)classes.
12
- # In this way we can dynamically select the correct class
13
- # constructor
14
- _descriptor_table = {}
15
-
16
-
17
- def register_descriptor(cls):
18
- """
19
- Register a descriptor into the descriptor table
20
- """
21
- _descriptor_table[cls.__name__] = cls
22
- return cls
23
-
24
-
25
- @register_descriptor
26
- class AttrsDescriptor:
27
- """
28
- This class handles compile-time properties for specific function parameters.
29
-
30
- Different backends can add more properties to the common ones. The class
31
- contains two fields:
32
-
33
- `arg_properties`: a dictionary containing the different compile-time properties for different
34
- parameters. I.e., the dictionary is a map from property names to parameter indices
35
- {
36
- "prop0": (0, 2, 3)
37
- "prop1": (0, 4, 5)
38
- }
39
- Different backends might need different properties on those paraemters to enable
40
- specific optimizations. The common compile time properties contained in this class
41
- are :
42
- - "tt.divisibility", i.e., is the given parameter divisible by 16
43
- - "tt.equal_to_1", i.e., is the given parameter an integer constant 1
44
-
45
- `property_values`: a dictionary containing the value of the different compile-time properties, like:
46
- {
47
- "prop0": val0
48
- "prop1": val1
49
- }
50
-
51
- `constant_properties`: a set containing the properties that can be used to determine if a parameter is constant
52
-
53
- """
54
- __slots__ = ('divisibility_16', 'equal_to_1', 'arg_properties', 'property_values', 'constant_properties')
55
-
56
- def __init__(self, params=None, values=None):
57
- """
58
- Initialize the compile-time properties
59
-
60
- We can initialize the AttrsDescriptor class by passing the list of params
61
- of the function and their `values`. The function will try to apply the properties
62
- to the values and save the parameters in the `arg_properties` list. If we don't pass
63
- either the `params` or the `values` we should initialize the class via an alternative method
64
- (see `from_dict` or `from_hints`)
65
- """
66
- # Default initialization
67
- self.arg_properties = {}
68
- self.property_values = {}
69
- self.constant_properties = set()
70
-
71
- self._add_common_properties(params, values)
72
- self._add_backend_properties(params, values)
73
- self._init_slots()
74
-
75
- def _add_common_properties(self, params, values):
76
- """ Add common compile-time properties """
77
- self.property_values["tt.divisibility"] = 16
78
- self.property_values["tt.equal_to"] = 1
79
- self.constant_properties.add("tt.equal_to")
80
-
81
- if (params is None) or (values is None):
82
- return
83
-
84
- # Compile properties deduction
85
- assert (len(params) == len(values))
86
-
87
- # Divisibility property
88
- self.arg_properties["tt.divisibility"] = [
89
- param.num for param, arg in zip(params, values) if AttrsDescriptor.is_divisible_by_16(arg)
90
- and not param.do_not_specialize and not param.do_not_specialize_on_alignment
91
- ]
92
-
93
- # Equal to 1 property
94
- self.arg_properties["tt.equal_to"] = [
95
- param.num
96
- for param, arg in zip(params, values)
97
- if AttrsDescriptor.is_equal_to_1(arg) and not param.do_not_specialize
98
- ]
99
-
100
- def _add_backend_properties(self, params=None, values=None):
101
- """ This method is for different subclasses to implement their own compile-time properties """
102
- pass
103
-
104
- def _init_slots(self):
105
- """ Initialize the slots of this class """
106
- for name, val in self.arg_properties.items():
107
- setattr(self, name.removeprefix('tt.') + '_' + str(self.property_values[name]), val)
108
-
109
- def get_fn_attrs(self) -> Dict:
110
- """
111
- Get the function attributes as a dictionary.
112
-
113
- The returned dictionary will look like :
114
- {
115
- "arg0" : [(prop_name00, val00), (prop_name01, val01), ...)]}
116
- "arg1" : [(prop_name10, val10), (prop_name11, val11), ...)]}
117
- }
118
- """
119
- attrs = {}
120
- for prop_name, arg_set in self.arg_properties.items():
121
- prop_val = self.property_values[prop_name]
122
- for arg in arg_set:
123
- attrs[arg] = attrs.get(arg, []) + [(prop_name, prop_val)]
124
- return attrs
125
-
126
- def get_constants(self) -> Dict:
127
- """ Return a mapping of constant parameters to their values """
128
- constants = {}
129
- for prop_name in self.constant_properties:
130
- for p in self.arg_properties.get(prop_name, []):
131
- constants[p] = self.property_values[prop_name]
132
- return constants
133
-
134
- def filter_out_constants(self):
135
- """ Return the same object, without properties marked as constants"""
136
- import copy
137
- c = copy.deepcopy(self)
138
- for prop_name in c.constant_properties:
139
- c.arg_properties.pop(prop_name, None)
140
- c.property_values.pop(prop_name, None)
141
- c.constant_properties = {}
142
- return c
143
-
144
- def hash(self):
145
- values = [sorted(self.arg_properties.values())]
146
- values += [sorted(self.property_values.values())]
147
- values += [sorted(self.constant_properties)]
148
- key = str(values)
149
- return hashlib.sha256(key.encode("utf-8")).hexdigest()
150
-
151
- def to_dict(self):
152
- """
153
- Store the fields of this class in a serializable dictionary
154
- """
155
- # We need to only store the `arg_properties` field. To initialize the
156
- # other fields we relay on the class type. We store it as a string in
157
- # the dictionary so that we can use it to invoke the appropriate
158
- # (sub)class constructor in the `from_dict` method.
159
- return {"arg_properties": self.arg_properties, "cls": type(self).__name__}
160
-
161
- @staticmethod
162
- def from_dict(data):
163
- """
164
- Create the object from a serializable dictionary
165
- """
166
- attrs_descriptor = _descriptor_table[data["cls"]]()
167
- for prop_name, param_ids in data["arg_properties"].items():
168
- attrs_descriptor.arg_properties[prop_name] = param_ids
169
- attrs_descriptor._init_slots()
170
- return attrs_descriptor
171
-
172
- @classmethod
173
- def from_hints(cls, hints: List[Tuple[int, int]]):
174
- """
175
- Create the class from a set of hints that are passed in.
176
-
177
- Instead of deducing the properties from a list of paramaters and values,
178
- the user can pass in a list of `hints=[(param_index, val)]` and if `val`
179
- matches one of the values of the properties (e.g., `prop_val[prop0]`),
180
- then we insert `param_index` into the correct list (e.g., in
181
- `arg_properties[prop0]`)
182
- """
183
- attrs_descriptor = cls()
184
- for prop_name, prop_val in attrs_descriptor.property_values.items():
185
- attrs_descriptor.arg_properties[prop_name] = [i for i, h in hints.items() if h == prop_val]
186
- attrs_descriptor._init_slots()
187
- return attrs_descriptor
188
-
189
- @staticmethod
190
- def is_divisible_by_16(x):
191
- """ Return if the argument is a multiple of 16"""
192
- if hasattr(x, "data_ptr"):
193
- return x.data_ptr() % 16 == 0
194
- elif isinstance(x, int):
195
- return x % 16 == 0
196
- if x is None:
197
- return True
198
- return False
199
-
200
- @staticmethod
201
- def is_equal_to_1(x):
202
- """ Return if the argument is a constant 1"""
203
- return True if isinstance(x, int) and not isinstance(x, bool) and x == 1 else False
204
-
205
- @staticmethod
206
- def get_property_key(val, align):
207
- if align and AttrsDescriptor.is_divisible_by_16(val):
208
- return "D"
209
- if AttrsDescriptor.is_equal_to_1(val):
210
- return "1"
211
- return "N"
212
-
213
- def __repr__(self):
214
- return f"AttrsDescriptor.from_dict({self.to_dict()!r})"
215
-
216
10
 
217
11
  @dataclass(frozen=True)
218
12
  class GPUTarget(object):
@@ -231,22 +25,23 @@ class BaseBackend(metaclass=ABCMeta):
231
25
 
232
26
  @staticmethod
233
27
  def _path_to_binary(binary: str):
28
+ binary += sysconfig.get_config_var("EXE")
234
29
  base_dir = os.path.join(os.path.dirname(__file__), os.pardir)
235
30
  paths = [
236
31
  os.environ.get(f"TRITON_{binary.upper()}_PATH", ""),
237
32
  os.path.join(base_dir, "third_party", "cuda", "bin", binary),
238
33
  ]
239
- for p in paths:
240
- bin = p.split(" ")[0]
241
- if os.path.exists(bin) and os.path.isfile(bin):
242
- result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT)
34
+ for path in paths:
35
+ if os.path.exists(path) and os.path.isfile(path):
36
+ result = subprocess.check_output([path, "--version"], stderr=subprocess.STDOUT)
243
37
  if result is not None:
244
38
  version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE)
245
39
  if version is not None:
246
- return p, version.group(1)
40
+ return path, version.group(1)
247
41
  raise RuntimeError(f"Cannot find {binary}")
248
42
 
249
- @abstractclassmethod
43
+ @classmethod
44
+ @abstractmethod
250
45
  def supports_target(target: GPUTarget):
251
46
  raise NotImplementedError
252
47
 
@@ -289,16 +84,21 @@ class BaseBackend(metaclass=ABCMeta):
289
84
  """
290
85
  raise NotImplementedError
291
86
 
292
- def get_attrs_descriptor(self, params, args):
293
- """
294
- Return an attribute descriptor: given a set of parameters and arguments
295
- the descriptor stores a set of compile time properties that can improve code
296
- generation. Different backends might benefit from different properties
297
- """
298
- return AttrsDescriptor(params, args)
87
+ @staticmethod
88
+ def parse_attr(desc):
89
+ assert isinstance(desc, str)
90
+ ret = []
91
+ if "D" in desc:
92
+ ret += [["tt.divisibility", 16]]
93
+ return ret
299
94
 
300
- def compute_spec_key(self, arg, align):
95
+ @staticmethod
96
+ def get_arg_specialization(arg, ty, **kwargs):
301
97
  """
302
- Return the ascii key for a given argument with a given set of properties
98
+ Return a string unique to each possible specialization of the argument
303
99
  """
304
- return AttrsDescriptor.get_property_key(arg, align)
100
+ if ty == "int" and arg % 16 == 0 and kwargs.get("align", False):
101
+ return "D"
102
+ if ty == "tensor" and arg.data_ptr() % 16 == 0 and kwargs.get("align", False):
103
+ return "D"
104
+ return ""
triton/backends/driver.py CHANGED
@@ -1,4 +1,4 @@
1
- from abc import ABCMeta, abstractmethod, abstractclassmethod
1
+ from abc import ABCMeta, abstractmethod
2
2
  from typing import Callable, List, Protocol, Sequence
3
3
 
4
4
 
@@ -10,7 +10,8 @@ class Benchmarker(Protocol):
10
10
 
11
11
  class DriverBase(metaclass=ABCMeta):
12
12
 
13
- @abstractclassmethod
13
+ @classmethod
14
+ @abstractmethod
14
15
  def is_active(self):
15
16
  pass
16
17
 
@@ -18,6 +19,10 @@ class DriverBase(metaclass=ABCMeta):
18
19
  def get_current_target(self):
19
20
  pass
20
21
 
22
+ @abstractmethod
23
+ def get_active_torch_device(self):
24
+ pass
25
+
21
26
  @abstractmethod
22
27
  def get_benchmarker(self) -> Benchmarker:
23
28
  """
Binary file