triton-windows 3.2.0.post11__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. Click here for more details.

Files changed (154)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +85 -0
  3. triton/_internal_testing.py +123 -0
  4. triton/backends/__init__.py +50 -0
  5. triton/backends/amd/compiler.py +368 -0
  6. triton/backends/amd/driver.c +211 -0
  7. triton/backends/amd/driver.py +512 -0
  8. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +358 -0
  9. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +1031 -0
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +1612 -0
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +1337 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +293 -0
  13. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +32 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +174 -0
  15. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +829 -0
  16. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +1809 -0
  17. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +108 -0
  18. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +124 -0
  19. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +405 -0
  20. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +196 -0
  21. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +565 -0
  22. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +2226 -0
  23. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +104 -0
  24. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +244 -0
  25. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +494 -0
  26. triton/backends/amd/include/hip/amd_detail/concepts.hpp +30 -0
  27. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +133 -0
  28. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +218 -0
  29. triton/backends/amd/include/hip/amd_detail/grid_launch.h +67 -0
  30. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +50 -0
  31. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +26 -0
  32. triton/backends/amd/include/hip/amd_detail/helpers.hpp +137 -0
  33. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +1350 -0
  34. triton/backends/amd/include/hip/amd_detail/hip_assert.h +101 -0
  35. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +242 -0
  36. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +254 -0
  37. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +96 -0
  38. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +100 -0
  39. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +10169 -0
  40. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +77 -0
  41. triton/backends/amd/include/hip/amd_detail/host_defines.h +180 -0
  42. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +102 -0
  43. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +798 -0
  44. triton/backends/amd/include/hip/amd_detail/math_fwd.h +698 -0
  45. triton/backends/amd/include/hip/amd_detail/ockl_image.h +177 -0
  46. triton/backends/amd/include/hip/amd_detail/program_state.hpp +107 -0
  47. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +491 -0
  48. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +478 -0
  49. triton/backends/amd/include/hip/channel_descriptor.h +39 -0
  50. triton/backends/amd/include/hip/device_functions.h +38 -0
  51. triton/backends/amd/include/hip/driver_types.h +468 -0
  52. triton/backends/amd/include/hip/hip_bf16.h +36 -0
  53. triton/backends/amd/include/hip/hip_bfloat16.h +44 -0
  54. triton/backends/amd/include/hip/hip_common.h +100 -0
  55. triton/backends/amd/include/hip/hip_complex.h +38 -0
  56. triton/backends/amd/include/hip/hip_cooperative_groups.h +46 -0
  57. triton/backends/amd/include/hip/hip_deprecated.h +95 -0
  58. triton/backends/amd/include/hip/hip_ext.h +159 -0
  59. triton/backends/amd/include/hip/hip_fp16.h +36 -0
  60. triton/backends/amd/include/hip/hip_gl_interop.h +32 -0
  61. triton/backends/amd/include/hip/hip_hcc.h +24 -0
  62. triton/backends/amd/include/hip/hip_math_constants.h +36 -0
  63. triton/backends/amd/include/hip/hip_profile.h +27 -0
  64. triton/backends/amd/include/hip/hip_runtime.h +75 -0
  65. triton/backends/amd/include/hip/hip_runtime_api.h +8919 -0
  66. triton/backends/amd/include/hip/hip_texture_types.h +29 -0
  67. triton/backends/amd/include/hip/hip_vector_types.h +41 -0
  68. triton/backends/amd/include/hip/hip_version.h +17 -0
  69. triton/backends/amd/include/hip/hiprtc.h +421 -0
  70. triton/backends/amd/include/hip/library_types.h +78 -0
  71. triton/backends/amd/include/hip/math_functions.h +42 -0
  72. triton/backends/amd/include/hip/surface_types.h +63 -0
  73. triton/backends/amd/include/hip/texture_types.h +194 -0
  74. triton/backends/amd/include/hsa/Brig.h +1131 -0
  75. triton/backends/amd/include/hsa/amd_hsa_common.h +91 -0
  76. triton/backends/amd/include/hsa/amd_hsa_elf.h +436 -0
  77. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +269 -0
  78. triton/backends/amd/include/hsa/amd_hsa_queue.h +109 -0
  79. triton/backends/amd/include/hsa/amd_hsa_signal.h +80 -0
  80. triton/backends/amd/include/hsa/hsa.h +5729 -0
  81. triton/backends/amd/include/hsa/hsa_amd_tool.h +91 -0
  82. triton/backends/amd/include/hsa/hsa_api_trace.h +566 -0
  83. triton/backends/amd/include/hsa/hsa_ext_amd.h +3090 -0
  84. triton/backends/amd/include/hsa/hsa_ext_finalize.h +531 -0
  85. triton/backends/amd/include/hsa/hsa_ext_image.h +1454 -0
  86. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +488 -0
  87. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +667 -0
  88. triton/backends/amd/include/roctracer/ext/prof_protocol.h +107 -0
  89. triton/backends/amd/include/roctracer/hip_ostream_ops.h +4435 -0
  90. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +1467 -0
  91. triton/backends/amd/include/roctracer/hsa_prof_str.h +3027 -0
  92. triton/backends/amd/include/roctracer/roctracer.h +779 -0
  93. triton/backends/amd/include/roctracer/roctracer_ext.h +81 -0
  94. triton/backends/amd/include/roctracer/roctracer_hcc.h +24 -0
  95. triton/backends/amd/include/roctracer/roctracer_hip.h +37 -0
  96. triton/backends/amd/include/roctracer/roctracer_hsa.h +112 -0
  97. triton/backends/amd/include/roctracer/roctracer_plugin.h +137 -0
  98. triton/backends/amd/include/roctracer/roctracer_roctx.h +67 -0
  99. triton/backends/amd/include/roctracer/roctx.h +229 -0
  100. triton/backends/amd/lib/ockl.bc +0 -0
  101. triton/backends/amd/lib/ocml.bc +0 -0
  102. triton/backends/compiler.py +304 -0
  103. triton/backends/driver.py +48 -0
  104. triton/backends/nvidia/__init__.py +0 -0
  105. triton/backends/nvidia/bin/ptxas.exe +0 -0
  106. triton/backends/nvidia/compiler.py +410 -0
  107. triton/backends/nvidia/driver.c +451 -0
  108. triton/backends/nvidia/driver.py +524 -0
  109. triton/backends/nvidia/include/cuda.h +24359 -0
  110. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  111. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  112. triton/compiler/__init__.py +4 -0
  113. triton/compiler/code_generator.py +1303 -0
  114. triton/compiler/compiler.py +430 -0
  115. triton/compiler/errors.py +51 -0
  116. triton/compiler/make_launcher.py +0 -0
  117. triton/errors.py +5 -0
  118. triton/language/__init__.py +294 -0
  119. triton/language/_utils.py +21 -0
  120. triton/language/core.py +2694 -0
  121. triton/language/extra/__init__.py +26 -0
  122. triton/language/extra/cuda/__init__.py +13 -0
  123. triton/language/extra/cuda/_experimental_tma.py +108 -0
  124. triton/language/extra/cuda/libdevice.py +1629 -0
  125. triton/language/extra/cuda/utils.py +109 -0
  126. triton/language/extra/hip/__init__.py +3 -0
  127. triton/language/extra/hip/libdevice.py +475 -0
  128. triton/language/extra/libdevice.py +786 -0
  129. triton/language/math.py +250 -0
  130. triton/language/random.py +207 -0
  131. triton/language/semantic.py +1796 -0
  132. triton/language/standard.py +452 -0
  133. triton/runtime/__init__.py +23 -0
  134. triton/runtime/autotuner.py +408 -0
  135. triton/runtime/build.py +111 -0
  136. triton/runtime/cache.py +295 -0
  137. triton/runtime/driver.py +60 -0
  138. triton/runtime/errors.py +26 -0
  139. triton/runtime/interpreter.py +1235 -0
  140. triton/runtime/jit.py +951 -0
  141. triton/testing.py +511 -0
  142. triton/tools/__init__.py +0 -0
  143. triton/tools/build_extern.py +365 -0
  144. triton/tools/compile.c +67 -0
  145. triton/tools/compile.h +14 -0
  146. triton/tools/compile.py +155 -0
  147. triton/tools/disasm.py +144 -0
  148. triton/tools/experimental_descriptor.py +32 -0
  149. triton/tools/link.py +322 -0
  150. triton/windows_utils.py +375 -0
  151. triton_windows-3.2.0.post11.dist-info/METADATA +39 -0
  152. triton_windows-3.2.0.post11.dist-info/RECORD +154 -0
  153. triton_windows-3.2.0.post11.dist-info/WHEEL +5 -0
  154. triton_windows-3.2.0.post11.dist-info/top_level.txt +12 -0
@@ -0,0 +1,304 @@
1
+ import os
2
+ import re
3
+ import hashlib
4
+ import subprocess
5
+
6
+ from abc import ABCMeta, abstractmethod, abstractclassmethod
7
+ from dataclasses import dataclass
8
+ from typing import Dict, List, Tuple, Union
9
+ from types import ModuleType
10
+
11
# Registry of AttrsDescriptor (sub)classes keyed by class name. Storing the
# name (rather than the class) in serialized data lets `from_dict` look the
# right constructor up again at load time.
_descriptor_table = {}


def register_descriptor(cls):
    """
    Class decorator that records `cls` in the descriptor table.

    The class is stored under `cls.__name__` and returned unchanged, so the
    decorator can be stacked on a class definition without altering it.
    """
    key = cls.__name__
    _descriptor_table[key] = cls
    return cls
23
+
24
+
25
+ @register_descriptor
26
class AttrsDescriptor:
    """
    This class handles compile-time properties for specific function parameters.

    Different backends can add more properties to the common ones. The class
    contains three fields:

    `arg_properties`: a dictionary mapping property names to the indices of the
    parameters that have that property, e.g.
        {
            "prop0": (0, 2, 3),
            "prop1": (0, 4, 5),
        }
    Different backends might need different properties on those parameters to enable
    specific optimizations. The common compile-time properties contained in this class
    are:
        - "tt.divisibility" (value 16): the parameter is divisible by 16 (an int value,
          the `data_ptr()` of a tensor-like object, or `None`)
        - "tt.equal_to" (value 1): the parameter is the integer constant 1

    `property_values`: a dictionary containing the value of each compile-time property, e.g.
        {
            "prop0": val0,
            "prop1": val1,
        }

    `constant_properties`: a set containing the properties that can be used to determine
    whether a parameter is a constant.

    In addition, `_init_slots` exposes every entry of `arg_properties` as an attribute named
    `<property name without "tt.">_<property value>`, e.g. `divisibility_16` and `equal_to_1`.
    """
    __slots__ = ('divisibility_16', 'equal_to_1', 'arg_properties', 'property_values', 'constant_properties')

    def __init__(self, params=None, values=None):
        """
        Initialize the compile-time properties.

        When both `params` (the function parameters; each must expose `num`,
        `do_not_specialize` and `do_not_specialize_on_alignment`) and `values` (the
        matching argument values) are given, the properties are deduced from them and the
        matching parameter numbers are saved in `arg_properties`. When either is missing,
        only the property values are set up and `arg_properties` is expected to be filled
        by an alternative constructor (see `from_dict` or `from_hints`).
        """
        # Default initialization
        self.arg_properties = {}
        self.property_values = {}
        self.constant_properties = set()

        self._add_common_properties(params, values)
        self._add_backend_properties(params, values)
        self._init_slots()

    def _add_common_properties(self, params, values):
        """ Add common compile-time properties """
        self.property_values["tt.divisibility"] = 16
        self.property_values["tt.equal_to"] = 1
        # A parameter known to equal 1 can be treated as a compile-time constant.
        self.constant_properties.add("tt.equal_to")

        if (params is None) or (values is None):
            return

        # Compile properties deduction
        assert (len(params) == len(values))

        # Divisibility property: skipped for parameters opted out of specialization
        # entirely or of alignment-based specialization.
        self.arg_properties["tt.divisibility"] = [
            param.num for param, arg in zip(params, values) if AttrsDescriptor.is_divisible_by_16(arg)
            and not param.do_not_specialize and not param.do_not_specialize_on_alignment
        ]

        # Equal to 1 property
        self.arg_properties["tt.equal_to"] = [
            param.num
            for param, arg in zip(params, values)
            if AttrsDescriptor.is_equal_to_1(arg) and not param.do_not_specialize
        ]

    def _add_backend_properties(self, params=None, values=None):
        """ This method is for different subclasses to implement their own compile-time properties """
        pass

    def _init_slots(self):
        """
        Mirror each `arg_properties` entry onto an attribute.

        "tt.divisibility" with value 16 becomes `divisibility_16`, and so on. The generated
        name must be settable on the instance (e.g. listed in a `__slots__`).
        """
        for name, val in self.arg_properties.items():
            setattr(self, name.removeprefix('tt.') + '_' + str(self.property_values[name]), val)

    def get_fn_attrs(self) -> Dict:
        """
        Get the function attributes as a dictionary keyed by parameter index.

        The returned dictionary will look like:
        {
            arg0: [(prop_name00, val00), (prop_name01, val01), ...],
            arg1: [(prop_name10, val10), (prop_name11, val11), ...],
        }
        """
        attrs = {}
        for prop_name, arg_set in self.arg_properties.items():
            prop_val = self.property_values[prop_name]
            for arg in arg_set:
                attrs.setdefault(arg, []).append((prop_name, prop_val))
        return attrs

    def get_constants(self) -> Dict:
        """ Return a mapping of constant parameters to their values """
        constants = {}
        for prop_name in self.constant_properties:
            for p in self.arg_properties.get(prop_name, []):
                constants[p] = self.property_values[prop_name]
        return constants

    def filter_out_constants(self):
        """
        Return a deep copy of this object without the properties marked as constants.

        `self` is left unchanged. The copy's `constant_properties` is an empty set.
        """
        import copy
        c = copy.deepcopy(self)
        for prop_name in c.constant_properties:
            c.arg_properties.pop(prop_name, None)
            c.property_values.pop(prop_name, None)
        # Keep the attribute's type consistent with __init__ (a set, not a dict).
        c.constant_properties = set()
        return c

    def hash(self):
        """
        Return a SHA-256 hex digest identifying the properties of this descriptor.

        Property names are kept paired with their argument lists and values. Sorting only
        the bare values would make e.g. {divisibility: [0], equal_to: [1]} and
        {divisibility: [1], equal_to: [0]} hash identically.
        """
        values = [sorted((name, sorted(ids)) for name, ids in self.arg_properties.items())]
        values += [sorted(self.property_values.items())]
        values += [sorted(self.constant_properties)]
        key = str(values)
        return hashlib.sha256(key.encode("utf-8")).hexdigest()

    def to_dict(self):
        """
        Store the fields of this class in a serializable dictionary
        """
        # We need to only store the `arg_properties` field. To initialize the
        # other fields we rely on the class type. We store it as a string in
        # the dictionary so that we can use it to invoke the appropriate
        # (sub)class constructor in the `from_dict` method.
        return {"arg_properties": self.arg_properties, "cls": type(self).__name__}

    @staticmethod
    def from_dict(data):
        """
        Create the object from a serializable dictionary produced by `to_dict`.

        `data["cls"]` must name a class registered in `_descriptor_table`.
        """
        attrs_descriptor = _descriptor_table[data["cls"]]()
        for prop_name, param_ids in data["arg_properties"].items():
            attrs_descriptor.arg_properties[prop_name] = param_ids
        attrs_descriptor._init_slots()
        return attrs_descriptor

    @classmethod
    def from_hints(cls, hints: Union[Dict[int, int], List[Tuple[int, int]]]):
        """
        Create the class from a set of hints that are passed in.

        Instead of deducing the properties from a list of parameters and values,
        the user can pass in hints mapping `param_index -> val`, either as a dict or as a
        list of `(param_index, val)` pairs. If `val` matches one of the values of the
        properties (e.g., `property_values[prop0]`), then `param_index` is inserted into
        the corresponding list (e.g., `arg_properties[prop0]`).
        """
        attrs_descriptor = cls()
        # dict() accepts both a mapping and an iterable of pairs; the original
        # implementation only handled dicts despite documenting a list of pairs.
        hint_items = dict(hints).items()
        for prop_name, prop_val in attrs_descriptor.property_values.items():
            attrs_descriptor.arg_properties[prop_name] = [i for i, h in hint_items if h == prop_val]
        attrs_descriptor._init_slots()
        return attrs_descriptor

    @staticmethod
    def is_divisible_by_16(x):
        """
        Return whether the argument is a multiple of 16.

        Objects with `data_ptr()` are judged by that pointer, ints by their value, and
        `None` counts as divisible. Anything else is not.
        """
        # NOTE(review): bool is an int subclass, so False is reported as divisible by 16
        # (is_equal_to_1 explicitly excludes bools) -- confirm this is intended.
        if hasattr(x, "data_ptr"):
            return x.data_ptr() % 16 == 0
        elif isinstance(x, int):
            return x % 16 == 0
        if x is None:
            return True
        return False

    @staticmethod
    def is_equal_to_1(x):
        """ Return whether the argument is the (non-bool) integer constant 1 """
        return True if isinstance(x, int) and not isinstance(x, bool) and x == 1 else False

    @staticmethod
    def get_property_key(val, align):
        """
        Return a one-character specialization key for `val`.

        Returns "D" if `align` is set and `val` is divisible by 16, "1" if `val` is the
        integer 1, and "N" otherwise.
        """
        if align and AttrsDescriptor.is_divisible_by_16(val):
            return "D"
        if AttrsDescriptor.is_equal_to_1(val):
            return "1"
        return "N"

    def __repr__(self):
        return f"AttrsDescriptor.from_dict({self.to_dict()!r})"
215
+
216
+
217
@dataclass(frozen=True)
class GPUTarget:
    """
    Immutable (frozen dataclass) description of the device a kernel is compiled for.

    Being frozen with the default `eq=True`, instances compare by value and are hashable.
    """

    backend: str  # target backend, e.g. "cuda" or "hip"
    arch: Union[int, str]  # e.g. 90 (CUDA compute capability) or "gfx940" (HIP)
    warp_size: int  # warp size of the target
224
+
225
+
226
class BaseBackend(metaclass=ABCMeta):
    """
    Abstract interface implemented by each compiler backend.

    A backend is constructed for a specific `GPUTarget`; the constructor asserts that the
    backend supports that target.
    """

    def __init__(self, target: GPUTarget) -> None:
        self.target = target
        assert self.supports_target(target)

    @staticmethod
    def _path_to_binary(binary: str):
        """
        Locate the tool `binary` and return `(path, version)`.

        Candidates are tried in order:
          1. the `TRITON_<BINARY>_PATH` environment variable;
          2. `<package>/third_party/cuda/bin/<binary>`.
        A candidate is accepted when it names an existing file whose `--version` output
        contains `release X.Y`. `version` is that `X.Y` string. `path` is the candidate
        string as given.

        Raises:
            RuntimeError: if no candidate is accepted.
            subprocess.CalledProcessError / OSError: if a candidate file exists but
                running it with `--version` fails.
        """
        base_dir = os.path.join(os.path.dirname(__file__), os.pardir)
        candidates = [
            os.environ.get(f"TRITON_{binary.upper()}_PATH", ""),
            os.path.join(base_dir, "third_party", "cuda", "bin", binary),
        ]
        for candidate in candidates:
            # Only splitting on the first space breaks full paths that contain spaces
            # (e.g. C:\Program Files\... on Windows). So use the whole string when it already
            # names a file, and otherwise fall back to the first space-separated token.
            executable = candidate if os.path.isfile(candidate) else candidate.split(" ")[0]
            if not os.path.isfile(executable):
                continue
            output = subprocess.check_output([executable, "--version"], stderr=subprocess.STDOUT)
            version = re.search(r".*release (\d+\.\d+).*", output.decode("utf-8", errors="replace"),
                                flags=re.MULTILINE)
            if version is not None:
                return candidate, version.group(1)
        raise RuntimeError(f"Cannot find {binary}")

    # `abstractclassmethod` is deprecated, and the previous stub lacked `cls`, so calling it
    # would bind the class to `target`. Concrete backends override this.
    @classmethod
    @abstractmethod
    def supports_target(cls, target: GPUTarget) -> bool:
        """Return whether this backend can compile for `target`."""
        raise NotImplementedError

    @abstractmethod
    def hash(self) -> str:
        """Returns a unique identifier for this backend"""
        raise NotImplementedError

    @abstractmethod
    def parse_options(self, options: dict) -> object:
        """
        Converts an `options` dictionary into an arbitrary object and returns it.
        This function may contain target-specific heuristics and check the legality of the provided options
        """
        raise NotImplementedError

    @abstractmethod
    def add_stages(self, stages: dict, options: object) -> None:
        """
        Populates `stages` dictionary with entries of the form:
        ir_name [str] => Function[(src: str, metadata: dict) -> str|bytes]
        The value of each entry may populate a `metadata` dictionary.
        Stages will be run sequentially (in insertion order) and can communicate using `metadata`.
        All stages are expected to return a `str` object, except for the last stage which returns
        a `bytes` object for execution by the launcher.
        """
        raise NotImplementedError

    @abstractmethod
    def load_dialects(self, context):
        """
        Load additional MLIR dialects into the provided `context`
        """
        raise NotImplementedError

    @abstractmethod
    def get_module_map(self) -> Dict[str, ModuleType]:
        """
        Return a map of interface modules to their device-specific implementations
        """
        raise NotImplementedError

    def get_attrs_descriptor(self, params, args):
        """
        Return an attribute descriptor: given a set of parameters and arguments
        the descriptor stores a set of compile time properties that can improve code
        generation. Different backends might benefit from different properties
        """
        return AttrsDescriptor(params, args)

    def compute_spec_key(self, arg, align):
        """
        Return the one-character specialization key ("D", "1" or "N") for `arg`,
        considering divisibility only when `align` is truthy.
        """
        return AttrsDescriptor.get_property_key(arg, align)
@@ -0,0 +1,48 @@
1
+ from abc import ABCMeta, abstractmethod, abstractclassmethod
2
+ from typing import Callable, List, Protocol, Sequence
3
+
4
+
5
class Benchmarker(Protocol):
    """
    Structural type for a benchmarking function.

    It is called with the kernel launch callable and the requested `quantiles`, and returns
    a sequence of floats (presumably one measurement per requested quantile -- confirm
    against the concrete benchmarkers).
    """

    def __call__(self, kernel_call: Callable, *, quantiles: List[float], **kwargs) -> Sequence[float]:
        ...


class DriverBase(metaclass=ABCMeta):
    """Abstract interface that every runtime driver implements."""

    # `abstractclassmethod` is deprecated; the stub also named its first argument `self`
    # although it receives the class. Concrete drivers override this.
    @classmethod
    @abstractmethod
    def is_active(cls):
        """
        Report whether this driver is active. The exact criterion is defined by the
        concrete subclass.
        """
        pass

    @abstractmethod
    def get_current_target(self):
        """Return the target of the currently selected device (presumably a `GPUTarget`)."""
        pass

    @abstractmethod
    def get_benchmarker(self) -> Benchmarker:
        """
        Return the benchmarking function that this backend should use by default.
        """
        raise NotImplementedError

    def __init__(self) -> None:
        pass
30
+
31
+
32
class GPUDriver(DriverBase):
    """
    Driver base that delegates device and stream queries to `torch.cuda`.

    After construction the instance exposes `get_device_capability`,
    `get_current_stream`, `get_current_device` and `set_current_device` as
    plain callables bound at init time.
    """

    def __init__(self):
        # TODO: support other frameworks than torch
        import torch

        torch_cuda = torch.cuda
        self.get_device_capability = torch_cuda.get_device_capability
        try:
            # Prefer torch's private binding when this torch build provides it.
            from torch._C import _cuda_getCurrentRawStream as stream_getter
        except ImportError:
            # Otherwise read the handle from the public current-stream object.
            def stream_getter(idx):
                return torch.cuda.current_stream(idx).cuda_stream
        self.get_current_stream = stream_getter
        self.get_current_device = torch_cuda.current_device
        self.set_current_device = torch_cuda.set_device

    # TODO: remove once TMA is cleaned up
    def assemble_tensormap_to_arg(self, tensormaps_info, args):
        """Return `args` unchanged; `tensormaps_info` is ignored here."""
        return args
File without changes
Binary file