tinygrad 0.9.1__py3-none-any.whl → 0.9.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/kernel.py +248 -115
- tinygrad/codegen/lowerer.py +215 -0
- tinygrad/codegen/transcendental.py +310 -0
- tinygrad/codegen/uopgraph.py +622 -0
- tinygrad/codegen/uops.py +235 -393
- tinygrad/device.py +428 -69
- tinygrad/dtype.py +18 -4
- tinygrad/engine/graph.py +19 -32
- tinygrad/engine/jit.py +148 -70
- tinygrad/engine/realize.py +127 -51
- tinygrad/engine/schedule.py +259 -216
- tinygrad/engine/search.py +29 -22
- tinygrad/function.py +9 -0
- tinygrad/helpers.py +87 -49
- tinygrad/lazy.py +34 -35
- tinygrad/multi.py +41 -36
- tinygrad/nn/__init__.py +39 -22
- tinygrad/nn/state.py +3 -3
- tinygrad/ops.py +63 -62
- tinygrad/renderer/__init__.py +43 -21
- tinygrad/renderer/assembly.py +104 -106
- tinygrad/renderer/cstyle.py +87 -60
- tinygrad/renderer/llvmir.py +21 -30
- tinygrad/runtime/autogen/amd_gpu.py +25208 -5753
- tinygrad/runtime/autogen/cuda.py +6 -162
- tinygrad/runtime/autogen/kfd.py +32 -0
- tinygrad/runtime/autogen/libc.py +4260 -0
- tinygrad/runtime/autogen/nvrtc.py +579 -0
- tinygrad/runtime/graph/clang.py +2 -2
- tinygrad/runtime/graph/cuda.py +8 -11
- tinygrad/runtime/graph/hcq.py +120 -107
- tinygrad/runtime/graph/metal.py +18 -15
- tinygrad/runtime/ops_amd.py +197 -305
- tinygrad/runtime/ops_clang.py +2 -2
- tinygrad/runtime/ops_cuda.py +36 -94
- tinygrad/runtime/ops_disk.py +3 -7
- tinygrad/runtime/ops_gpu.py +4 -2
- tinygrad/runtime/ops_hip.py +70 -0
- tinygrad/runtime/ops_metal.py +38 -27
- tinygrad/runtime/ops_nv.py +283 -363
- tinygrad/runtime/ops_python.py +26 -30
- tinygrad/runtime/support/compiler_cuda.py +78 -0
- tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +15 -1
- tinygrad/runtime/support/elf.py +38 -0
- tinygrad/shape/shapetracker.py +5 -14
- tinygrad/shape/symbolic.py +4 -8
- tinygrad/shape/view.py +34 -22
- tinygrad/tensor.py +399 -97
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/METADATA +49 -48
- tinygrad-0.9.2.dist-info/RECORD +70 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/WHEEL +1 -1
- tinygrad/codegen/linearizer.py +0 -528
- tinygrad-0.9.1.dist-info/RECORD +0 -63
- /tinygrad/runtime/{driver → support}/__init__.py +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,579 @@
|
|
1
|
+
# mypy: ignore-errors
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
#
|
4
|
+
# TARGET arch is: []
|
5
|
+
# WORD_SIZE is: 8
|
6
|
+
# POINTER_SIZE is: 8
|
7
|
+
# LONGDOUBLE_SIZE is: 16
|
8
|
+
#
|
9
|
+
import ctypes, ctypes.util
|
10
|
+
|
11
|
+
|
12
|
+
# Registry of the shared libraries used by the bindings below, keyed by library file name.
_libraries = {}
# ctypes.util.find_library() returns None when the library cannot be located;
# whatever it returns is handed straight to ctypes.CDLL here, at import time.
_libraries['libnvrtc.so'] = ctypes.CDLL(ctypes.util.find_library('nvrtc'))
|
14
|
+
def string_cast(char_pointer, encoding='utf-8', errors='strict'):
    """Read the NUL-terminated C string that ``char_pointer`` points to.

    The argument is reinterpreted as ``ctypes.c_char_p`` and its value is read.

    Returns ``None`` for a NULL pointer, the raw ``bytes`` when ``encoding``
    is ``None``, and otherwise the bytes decoded with ``encoding`` and the
    given ``errors`` policy.
    """
    raw = ctypes.cast(char_pointer, ctypes.c_char_p).value
    # Nothing to decode: either a NULL pointer or the caller asked for bytes.
    if raw is None or encoding is None:
        return raw
    return raw.decode(encoding, errors=errors)
|
19
|
+
|
20
|
+
|
21
|
+
def char_pointer_cast(string, encoding='utf-8'):
    """Turn a Python string into a ``POINTER(c_char)``.

    When ``encoding`` is not ``None`` the value's ``encode(encoding)`` method
    is used; objects without an ``encode`` attribute (for example ``bytes``)
    are passed through unchanged. The result is wrapped in ``c_char_p`` and
    cast to ``ctypes.POINTER(ctypes.c_char)``.
    """
    payload = string
    if encoding is not None:
        try:
            payload = string.encode(encoding)
        except AttributeError:
            # In Python3, bytes has no encode attribute: keep the value as is
            payload = string
    return ctypes.cast(ctypes.c_char_p(payload), ctypes.POINTER(ctypes.c_char))
|
30
|
+
|
31
|
+
|
32
|
+
|
33
|
+
class AsDictMixin:
|
34
|
+
@classmethod
|
35
|
+
def as_dict(cls, self):
|
36
|
+
result = {}
|
37
|
+
if not isinstance(self, AsDictMixin):
|
38
|
+
# not a structure, assume it's already a python object
|
39
|
+
return self
|
40
|
+
if not hasattr(cls, "_fields_"):
|
41
|
+
return result
|
42
|
+
# sys.version_info >= (3, 5)
|
43
|
+
# for (field, *_) in cls._fields_: # noqa
|
44
|
+
for field_tuple in cls._fields_: # noqa
|
45
|
+
field = field_tuple[0]
|
46
|
+
if field.startswith('PADDING_'):
|
47
|
+
continue
|
48
|
+
value = getattr(self, field)
|
49
|
+
type_ = type(value)
|
50
|
+
if hasattr(value, "_length_") and hasattr(value, "_type_"):
|
51
|
+
# array
|
52
|
+
if not hasattr(type_, "as_dict"):
|
53
|
+
value = [v for v in value]
|
54
|
+
else:
|
55
|
+
type_ = type_._type_
|
56
|
+
value = [type_.as_dict(v) for v in value]
|
57
|
+
elif hasattr(value, "contents") and hasattr(value, "_type_"):
|
58
|
+
# pointer
|
59
|
+
try:
|
60
|
+
if not hasattr(type_, "as_dict"):
|
61
|
+
value = value.contents
|
62
|
+
else:
|
63
|
+
type_ = type_._type_
|
64
|
+
value = type_.as_dict(value.contents)
|
65
|
+
except ValueError:
|
66
|
+
# nullptr
|
67
|
+
value = None
|
68
|
+
elif isinstance(value, AsDictMixin):
|
69
|
+
# other structure
|
70
|
+
value = type_.as_dict(value)
|
71
|
+
result[field] = value
|
72
|
+
return result
|
73
|
+
|
74
|
+
|
75
|
+
class Structure(ctypes.Structure, AsDictMixin):
    """``ctypes.Structure`` base that skips padding fields and can bind callbacks."""

    def __init__(self, *args, **kwds):
        """Initialise fields from positional and keyword arguments.

        Positional arguments are paired, in order, with the names returned by
        ``_field_names_()`` so they never fill ``PADDING*`` fields. Keyword
        arguments take precedence over positional ones for the same field.
        """
        # We don't want to use positional arguments fill PADDING_* fields
        # NOTE: zip() stops at the shorter input, so surplus positional
        # arguments are dropped here rather than forwarded to ctypes.
        args = dict(zip(self.__class__._field_names_(), args))
        args.update(kwds)
        super(Structure, self).__init__(**args)

    @classmethod
    def _field_names_(cls):
        """Return the field names whose name does not start with ``PADDING``.

        Yields nothing when the class declares no ``_fields_``.
        """
        if hasattr(cls, '_fields_'):
            return (f[0] for f in cls._fields_ if not f[0].startswith('PADDING'))
        else:
            return ()

    @classmethod
    def get_type(cls, field):
        """Return the declared ctypes type of ``field``, or ``None`` if unknown.

        A class without ``_fields_`` is treated as having no fields, in line
        with ``_field_names_``.
        """
        for f in getattr(cls, '_fields_', ()):
            if f[0] == field:
                return f[1]
        return None

    @classmethod
    def bind(cls, bound_fields):
        """Build an instance with callbacks and values taken from ``bound_fields``.

        For each field whose type has a ``restype`` attribute (a function
        pointer type):

        * a value of ``None`` in ``bound_fields`` gives a default-constructed
          function pointer;
        * any other value is wrapped in a forwarding lambda and converted to
          the field type;
        * a missing entry gets a callback that ignores its arguments and
          returns the default value of the return type (``None`` if that
          cannot be built).

        Other fields take the value from ``bound_fields`` when present and
        ``type_()`` otherwise.

        The caller's mapping is not modified.

        Raises:
            ValueError: if ``bound_fields`` contains names that are not fields
                of this structure.
        """
        # Work on a copy so the caller's dict is left intact.
        remaining = dict(bound_fields)
        fields = {}
        for field in getattr(cls, '_fields_', ()):
            # Index instead of unpacking: an entry may carry a third element
            # (bit width), which would break a two-name unpack.
            name, type_ = field[0], field[1]
            if hasattr(type_, "restype"):
                if name in remaining:
                    callback = remaining.pop(name)
                    if callback is None:
                        fields[name] = type_()
                    else:
                        # use a closure to capture the callback from the loop scope
                        fields[name] = (
                            type_((lambda callback: lambda *args: callback(*args))(
                                callback))
                        )
                else:
                    # default callback implementation (does nothing)
                    try:
                        default_ = type_(0).restype().value
                    except TypeError:
                        default_ = None
                    fields[name] = type_((
                        lambda default_: lambda *args: default_)(default_))
            else:
                # not a callback function, use default initialization
                if name in remaining:
                    fields[name] = remaining.pop(name)
                else:
                    fields[name] = type_()
        if len(remaining) != 0:
            raise ValueError(
                "Cannot bind the following unknown callback(s) {}.{}".format(
                    cls.__name__, remaining.keys()
            ))
        return cls(**fields)
|
134
|
+
|
135
|
+
|
136
|
+
class Union(ctypes.Union, AsDictMixin):
    # ctypes.Union combined with AsDictMixin, so union types built on this
    # base inherit as_dict(); the class adds no members of its own.
    pass
|
138
|
+
|
139
|
+
|
140
|
+
|
141
|
+
# Second library used by this module; located and loaded the same way as nvrtc above.
_libraries['libnvJitLink.so'] = ctypes.CDLL(ctypes.util.find_library('nvJitLink'))
# 128-bit integers are represented as opaque 16-byte arrays; signed and unsigned share one type.
c_int128 = ctypes.c_ubyte*16
c_uint128 = c_int128
void = None
# Use the native long double only when it is 16 bytes wide on this platform;
# otherwise fall back to an opaque 16-byte array.
if ctypes.sizeof(ctypes.c_longdouble) == 16:
    c_long_double_t = ctypes.c_longdouble
else:
    c_long_double_t = ctypes.c_ubyte*16
|
149
|
+
|
150
|
+
|
151
|
+
|
152
|
+
|
153
|
+
# values for enumeration 'c__EA_nvrtcResult'
# Lookup table from numeric result code to its symbolic name.
c__EA_nvrtcResult__enumvalues = {
    0: 'NVRTC_SUCCESS',
    1: 'NVRTC_ERROR_OUT_OF_MEMORY',
    2: 'NVRTC_ERROR_PROGRAM_CREATION_FAILURE',
    3: 'NVRTC_ERROR_INVALID_INPUT',
    4: 'NVRTC_ERROR_INVALID_PROGRAM',
    5: 'NVRTC_ERROR_INVALID_OPTION',
    6: 'NVRTC_ERROR_COMPILATION',
    7: 'NVRTC_ERROR_BUILTIN_OPERATION_FAILURE',
    8: 'NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION',
    9: 'NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION',
    10: 'NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID',
    11: 'NVRTC_ERROR_INTERNAL_ERROR',
    12: 'NVRTC_ERROR_TIME_FILE_WRITE_FAILED',
}
# The same codes as module-level integer constants.
NVRTC_SUCCESS = 0
NVRTC_ERROR_OUT_OF_MEMORY = 1
NVRTC_ERROR_PROGRAM_CREATION_FAILURE = 2
NVRTC_ERROR_INVALID_INPUT = 3
NVRTC_ERROR_INVALID_PROGRAM = 4
NVRTC_ERROR_INVALID_OPTION = 5
NVRTC_ERROR_COMPILATION = 6
NVRTC_ERROR_BUILTIN_OPERATION_FAILURE = 7
NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION = 8
NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION = 9
NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID = 10
NVRTC_ERROR_INTERNAL_ERROR = 11
NVRTC_ERROR_TIME_FILE_WRITE_FAILED = 12
# The enum is carried as a 32-bit unsigned integer; nvrtcResult is the alias used in the signatures below.
c__EA_nvrtcResult = ctypes.c_uint32 # enum
nvrtcResult = c__EA_nvrtcResult
nvrtcResult__enumvalues = c__EA_nvrtcResult__enumvalues
|
185
|
+
# Function bindings for the library registered as 'libnvrtc.so'.
# Every binding follows the same pattern: look the symbol up on the loaded
# library, then declare its return type and argument types. The lookup raises
# AttributeError when the symbol is absent; that is swallowed on purpose, so a
# missing symbol just leaves the Python name undefined instead of failing the
# import of this module.
try:
    nvrtcGetErrorString = _libraries['libnvrtc.so'].nvrtcGetErrorString
    # Unlike the other bindings, this one returns a char pointer rather than an nvrtcResult.
    nvrtcGetErrorString.restype = ctypes.POINTER(ctypes.c_char)
    nvrtcGetErrorString.argtypes = [nvrtcResult]
except AttributeError:
    pass
try:
    nvrtcVersion = _libraries['libnvrtc.so'].nvrtcVersion
    nvrtcVersion.restype = nvrtcResult
    nvrtcVersion.argtypes = [ctypes.POINTER(ctypes.c_int32), ctypes.POINTER(ctypes.c_int32)]
except AttributeError:
    pass
try:
    nvrtcGetNumSupportedArchs = _libraries['libnvrtc.so'].nvrtcGetNumSupportedArchs
    nvrtcGetNumSupportedArchs.restype = nvrtcResult
    nvrtcGetNumSupportedArchs.argtypes = [ctypes.POINTER(ctypes.c_int32)]
except AttributeError:
    pass
try:
    nvrtcGetSupportedArchs = _libraries['libnvrtc.so'].nvrtcGetSupportedArchs
    nvrtcGetSupportedArchs.restype = nvrtcResult
    nvrtcGetSupportedArchs.argtypes = [ctypes.POINTER(ctypes.c_int32)]
except AttributeError:
    pass
# Opaque structure (no fields declared); it is only ever used through the pointer type below.
class struct__nvrtcProgram(Structure):
    pass

nvrtcProgram = ctypes.POINTER(struct__nvrtcProgram)
try:
    nvrtcCreateProgram = _libraries['libnvrtc.so'].nvrtcCreateProgram
    nvrtcCreateProgram.restype = nvrtcResult
    # First argument is a pointer to an nvrtcProgram handle (pointer to pointer to the opaque struct).
    nvrtcCreateProgram.argtypes = [ctypes.POINTER(ctypes.POINTER(struct__nvrtcProgram)), ctypes.POINTER(ctypes.c_char), ctypes.POINTER(ctypes.c_char), ctypes.c_int32, ctypes.POINTER(ctypes.POINTER(ctypes.c_char)), ctypes.POINTER(ctypes.POINTER(ctypes.c_char))]
except AttributeError:
    pass
try:
    nvrtcDestroyProgram = _libraries['libnvrtc.so'].nvrtcDestroyProgram
    nvrtcDestroyProgram.restype = nvrtcResult
    nvrtcDestroyProgram.argtypes = [ctypes.POINTER(ctypes.POINTER(struct__nvrtcProgram))]
except AttributeError:
    pass
try:
    nvrtcCompileProgram = _libraries['libnvrtc.so'].nvrtcCompileProgram
    nvrtcCompileProgram.restype = nvrtcResult
    nvrtcCompileProgram.argtypes = [nvrtcProgram, ctypes.c_int32, ctypes.POINTER(ctypes.POINTER(ctypes.c_char))]
except AttributeError:
    pass
# The following bindings come in pairs: an "...Size" function taking a pointer
# to a 64-bit unsigned integer, and a matching function taking a char pointer.
try:
    nvrtcGetPTXSize = _libraries['libnvrtc.so'].nvrtcGetPTXSize
    nvrtcGetPTXSize.restype = nvrtcResult
    nvrtcGetPTXSize.argtypes = [nvrtcProgram, ctypes.POINTER(ctypes.c_uint64)]
except AttributeError:
    pass
try:
    nvrtcGetPTX = _libraries['libnvrtc.so'].nvrtcGetPTX
    nvrtcGetPTX.restype = nvrtcResult
    nvrtcGetPTX.argtypes = [nvrtcProgram, ctypes.POINTER(ctypes.c_char)]
except AttributeError:
    pass
try:
    nvrtcGetCUBINSize = _libraries['libnvrtc.so'].nvrtcGetCUBINSize
    nvrtcGetCUBINSize.restype = nvrtcResult
    nvrtcGetCUBINSize.argtypes = [nvrtcProgram, ctypes.POINTER(ctypes.c_uint64)]
except AttributeError:
    pass
try:
    nvrtcGetCUBIN = _libraries['libnvrtc.so'].nvrtcGetCUBIN
    nvrtcGetCUBIN.restype = nvrtcResult
    nvrtcGetCUBIN.argtypes = [nvrtcProgram, ctypes.POINTER(ctypes.c_char)]
except AttributeError:
    pass
try:
    nvrtcGetNVVMSize = _libraries['libnvrtc.so'].nvrtcGetNVVMSize
    nvrtcGetNVVMSize.restype = nvrtcResult
    nvrtcGetNVVMSize.argtypes = [nvrtcProgram, ctypes.POINTER(ctypes.c_uint64)]
except AttributeError:
    pass
try:
    nvrtcGetNVVM = _libraries['libnvrtc.so'].nvrtcGetNVVM
    nvrtcGetNVVM.restype = nvrtcResult
    nvrtcGetNVVM.argtypes = [nvrtcProgram, ctypes.POINTER(ctypes.c_char)]
except AttributeError:
    pass
try:
    nvrtcGetLTOIRSize = _libraries['libnvrtc.so'].nvrtcGetLTOIRSize
    nvrtcGetLTOIRSize.restype = nvrtcResult
    nvrtcGetLTOIRSize.argtypes = [nvrtcProgram, ctypes.POINTER(ctypes.c_uint64)]
except AttributeError:
    pass
try:
    nvrtcGetLTOIR = _libraries['libnvrtc.so'].nvrtcGetLTOIR
    nvrtcGetLTOIR.restype = nvrtcResult
    nvrtcGetLTOIR.argtypes = [nvrtcProgram, ctypes.POINTER(ctypes.c_char)]
except AttributeError:
    pass
try:
    nvrtcGetOptiXIRSize = _libraries['libnvrtc.so'].nvrtcGetOptiXIRSize
    nvrtcGetOptiXIRSize.restype = nvrtcResult
    nvrtcGetOptiXIRSize.argtypes = [nvrtcProgram, ctypes.POINTER(ctypes.c_uint64)]
except AttributeError:
    pass
try:
    nvrtcGetOptiXIR = _libraries['libnvrtc.so'].nvrtcGetOptiXIR
    nvrtcGetOptiXIR.restype = nvrtcResult
    nvrtcGetOptiXIR.argtypes = [nvrtcProgram, ctypes.POINTER(ctypes.c_char)]
except AttributeError:
    pass
try:
    nvrtcGetProgramLogSize = _libraries['libnvrtc.so'].nvrtcGetProgramLogSize
    nvrtcGetProgramLogSize.restype = nvrtcResult
    nvrtcGetProgramLogSize.argtypes = [nvrtcProgram, ctypes.POINTER(ctypes.c_uint64)]
except AttributeError:
    pass
try:
    nvrtcGetProgramLog = _libraries['libnvrtc.so'].nvrtcGetProgramLog
    nvrtcGetProgramLog.restype = nvrtcResult
    nvrtcGetProgramLog.argtypes = [nvrtcProgram, ctypes.POINTER(ctypes.c_char)]
except AttributeError:
    pass
try:
    nvrtcAddNameExpression = _libraries['libnvrtc.so'].nvrtcAddNameExpression
    nvrtcAddNameExpression.restype = nvrtcResult
    nvrtcAddNameExpression.argtypes = [nvrtcProgram, ctypes.POINTER(ctypes.c_char)]
except AttributeError:
    pass
try:
    nvrtcGetLoweredName = _libraries['libnvrtc.so'].nvrtcGetLoweredName
    nvrtcGetLoweredName.restype = nvrtcResult
    nvrtcGetLoweredName.argtypes = [nvrtcProgram, ctypes.POINTER(ctypes.c_char), ctypes.POINTER(ctypes.POINTER(ctypes.c_char))]
except AttributeError:
    pass
|
315
|
+
|
316
|
+
# values for enumeration 'c__EA_nvJitLinkResult'
# Lookup table from numeric result code to its symbolic name.
c__EA_nvJitLinkResult__enumvalues = {
    0: 'NVJITLINK_SUCCESS',
    1: 'NVJITLINK_ERROR_UNRECOGNIZED_OPTION',
    2: 'NVJITLINK_ERROR_MISSING_ARCH',
    3: 'NVJITLINK_ERROR_INVALID_INPUT',
    4: 'NVJITLINK_ERROR_PTX_COMPILE',
    5: 'NVJITLINK_ERROR_NVVM_COMPILE',
    6: 'NVJITLINK_ERROR_INTERNAL',
    7: 'NVJITLINK_ERROR_THREADPOOL',
    8: 'NVJITLINK_ERROR_UNRECOGNIZED_INPUT',
}
# The same codes as module-level integer constants.
NVJITLINK_SUCCESS = 0
NVJITLINK_ERROR_UNRECOGNIZED_OPTION = 1
NVJITLINK_ERROR_MISSING_ARCH = 2
NVJITLINK_ERROR_INVALID_INPUT = 3
NVJITLINK_ERROR_PTX_COMPILE = 4
NVJITLINK_ERROR_NVVM_COMPILE = 5
NVJITLINK_ERROR_INTERNAL = 6
NVJITLINK_ERROR_THREADPOOL = 7
NVJITLINK_ERROR_UNRECOGNIZED_INPUT = 8
c__EA_nvJitLinkResult = ctypes.c_uint32 # enum
nvJitLinkResult = c__EA_nvJitLinkResult
nvJitLinkResult__enumvalues = c__EA_nvJitLinkResult__enumvalues

# values for enumeration 'c__EA_nvJitLinkInputType'
# Note the gap in the numbering: values 0-6 are followed directly by 10.
c__EA_nvJitLinkInputType__enumvalues = {
    0: 'NVJITLINK_INPUT_NONE',
    1: 'NVJITLINK_INPUT_CUBIN',
    2: 'NVJITLINK_INPUT_PTX',
    3: 'NVJITLINK_INPUT_LTOIR',
    4: 'NVJITLINK_INPUT_FATBIN',
    5: 'NVJITLINK_INPUT_OBJECT',
    6: 'NVJITLINK_INPUT_LIBRARY',
    10: 'NVJITLINK_INPUT_ANY',
}
NVJITLINK_INPUT_NONE = 0
NVJITLINK_INPUT_CUBIN = 1
NVJITLINK_INPUT_PTX = 2
NVJITLINK_INPUT_LTOIR = 3
NVJITLINK_INPUT_FATBIN = 4
NVJITLINK_INPUT_OBJECT = 5
NVJITLINK_INPUT_LIBRARY = 6
NVJITLINK_INPUT_ANY = 10
c__EA_nvJitLinkInputType = ctypes.c_uint32 # enum
nvJitLinkInputType = c__EA_nvJitLinkInputType
nvJitLinkInputType__enumvalues = c__EA_nvJitLinkInputType__enumvalues
# Opaque structure (no fields declared); it is only ever used through the handle pointer type below.
class struct_nvJitLink(Structure):
    pass

nvJitLinkHandle = ctypes.POINTER(struct_nvJitLink)
uint32_t = ctypes.c_uint32
|
368
|
+
# Function bindings for the library registered as 'libnvJitLink.so'.
# Same pattern as the nvrtc bindings: the symbol lookup raises AttributeError
# when the library does not export it, and that is deliberately ignored so the
# name simply stays undefined.
# Most entry points are bound twice, once under a "__<name>_12_4" symbol and
# once under the plain name, with identical signatures; whichever symbols the
# loaded library exports end up defined.
try:
    __nvJitLinkCreate_12_4 = _libraries['libnvJitLink.so'].__nvJitLinkCreate_12_4
    __nvJitLinkCreate_12_4.restype = nvJitLinkResult
    __nvJitLinkCreate_12_4.argtypes = [ctypes.POINTER(ctypes.POINTER(struct_nvJitLink)), uint32_t, ctypes.POINTER(ctypes.POINTER(ctypes.c_char))]
except AttributeError:
    pass
try:
    nvJitLinkCreate = _libraries['libnvJitLink.so'].nvJitLinkCreate
    nvJitLinkCreate.restype = nvJitLinkResult
    nvJitLinkCreate.argtypes = [ctypes.POINTER(ctypes.POINTER(struct_nvJitLink)), uint32_t, ctypes.POINTER(ctypes.POINTER(ctypes.c_char))]
except AttributeError:
    pass
try:
    __nvJitLinkDestroy_12_4 = _libraries['libnvJitLink.so'].__nvJitLinkDestroy_12_4
    __nvJitLinkDestroy_12_4.restype = nvJitLinkResult
    __nvJitLinkDestroy_12_4.argtypes = [ctypes.POINTER(ctypes.POINTER(struct_nvJitLink))]
except AttributeError:
    pass
try:
    nvJitLinkDestroy = _libraries['libnvJitLink.so'].nvJitLinkDestroy
    nvJitLinkDestroy.restype = nvJitLinkResult
    nvJitLinkDestroy.argtypes = [ctypes.POINTER(ctypes.POINTER(struct_nvJitLink))]
except AttributeError:
    pass
# size_t is fixed to a 64-bit unsigned integer in this binding.
size_t = ctypes.c_uint64
try:
    __nvJitLinkAddData_12_4 = _libraries['libnvJitLink.so'].__nvJitLinkAddData_12_4
    __nvJitLinkAddData_12_4.restype = nvJitLinkResult
    __nvJitLinkAddData_12_4.argtypes = [nvJitLinkHandle, nvJitLinkInputType, ctypes.POINTER(None), size_t, ctypes.POINTER(ctypes.c_char)]
except AttributeError:
    pass
try:
    nvJitLinkAddData = _libraries['libnvJitLink.so'].nvJitLinkAddData
    nvJitLinkAddData.restype = nvJitLinkResult
    nvJitLinkAddData.argtypes = [nvJitLinkHandle, nvJitLinkInputType, ctypes.POINTER(None), size_t, ctypes.POINTER(ctypes.c_char)]
except AttributeError:
    pass
try:
    __nvJitLinkAddFile_12_4 = _libraries['libnvJitLink.so'].__nvJitLinkAddFile_12_4
    __nvJitLinkAddFile_12_4.restype = nvJitLinkResult
    __nvJitLinkAddFile_12_4.argtypes = [nvJitLinkHandle, nvJitLinkInputType, ctypes.POINTER(ctypes.c_char)]
except AttributeError:
    pass
try:
    nvJitLinkAddFile = _libraries['libnvJitLink.so'].nvJitLinkAddFile
    nvJitLinkAddFile.restype = nvJitLinkResult
    nvJitLinkAddFile.argtypes = [nvJitLinkHandle, nvJitLinkInputType, ctypes.POINTER(ctypes.c_char)]
except AttributeError:
    pass
try:
    __nvJitLinkComplete_12_4 = _libraries['libnvJitLink.so'].__nvJitLinkComplete_12_4
    __nvJitLinkComplete_12_4.restype = nvJitLinkResult
    __nvJitLinkComplete_12_4.argtypes = [nvJitLinkHandle]
except AttributeError:
    pass
try:
    nvJitLinkComplete = _libraries['libnvJitLink.so'].nvJitLinkComplete
    nvJitLinkComplete.restype = nvJitLinkResult
    nvJitLinkComplete.argtypes = [nvJitLinkHandle]
except AttributeError:
    pass
try:
    __nvJitLinkGetLinkedCubinSize_12_4 = _libraries['libnvJitLink.so'].__nvJitLinkGetLinkedCubinSize_12_4
    __nvJitLinkGetLinkedCubinSize_12_4.restype = nvJitLinkResult
    __nvJitLinkGetLinkedCubinSize_12_4.argtypes = [nvJitLinkHandle, ctypes.POINTER(ctypes.c_uint64)]
except AttributeError:
    pass
try:
    nvJitLinkGetLinkedCubinSize = _libraries['libnvJitLink.so'].nvJitLinkGetLinkedCubinSize
    nvJitLinkGetLinkedCubinSize.restype = nvJitLinkResult
    nvJitLinkGetLinkedCubinSize.argtypes = [nvJitLinkHandle, ctypes.POINTER(ctypes.c_uint64)]
except AttributeError:
    pass
try:
    __nvJitLinkGetLinkedCubin_12_4 = _libraries['libnvJitLink.so'].__nvJitLinkGetLinkedCubin_12_4
    __nvJitLinkGetLinkedCubin_12_4.restype = nvJitLinkResult
    __nvJitLinkGetLinkedCubin_12_4.argtypes = [nvJitLinkHandle, ctypes.POINTER(None)]
except AttributeError:
    pass
try:
    nvJitLinkGetLinkedCubin = _libraries['libnvJitLink.so'].nvJitLinkGetLinkedCubin
    nvJitLinkGetLinkedCubin.restype = nvJitLinkResult
    nvJitLinkGetLinkedCubin.argtypes = [nvJitLinkHandle, ctypes.POINTER(None)]
except AttributeError:
    pass
try:
    __nvJitLinkGetLinkedPtxSize_12_4 = _libraries['libnvJitLink.so'].__nvJitLinkGetLinkedPtxSize_12_4
    __nvJitLinkGetLinkedPtxSize_12_4.restype = nvJitLinkResult
    __nvJitLinkGetLinkedPtxSize_12_4.argtypes = [nvJitLinkHandle, ctypes.POINTER(ctypes.c_uint64)]
except AttributeError:
    pass
try:
    nvJitLinkGetLinkedPtxSize = _libraries['libnvJitLink.so'].nvJitLinkGetLinkedPtxSize
    nvJitLinkGetLinkedPtxSize.restype = nvJitLinkResult
    nvJitLinkGetLinkedPtxSize.argtypes = [nvJitLinkHandle, ctypes.POINTER(ctypes.c_uint64)]
except AttributeError:
    pass
try:
    __nvJitLinkGetLinkedPtx_12_4 = _libraries['libnvJitLink.so'].__nvJitLinkGetLinkedPtx_12_4
    __nvJitLinkGetLinkedPtx_12_4.restype = nvJitLinkResult
    __nvJitLinkGetLinkedPtx_12_4.argtypes = [nvJitLinkHandle, ctypes.POINTER(ctypes.c_char)]
except AttributeError:
    pass
try:
    nvJitLinkGetLinkedPtx = _libraries['libnvJitLink.so'].nvJitLinkGetLinkedPtx
    nvJitLinkGetLinkedPtx.restype = nvJitLinkResult
    nvJitLinkGetLinkedPtx.argtypes = [nvJitLinkHandle, ctypes.POINTER(ctypes.c_char)]
except AttributeError:
    pass
try:
    __nvJitLinkGetErrorLogSize_12_4 = _libraries['libnvJitLink.so'].__nvJitLinkGetErrorLogSize_12_4
    __nvJitLinkGetErrorLogSize_12_4.restype = nvJitLinkResult
    __nvJitLinkGetErrorLogSize_12_4.argtypes = [nvJitLinkHandle, ctypes.POINTER(ctypes.c_uint64)]
except AttributeError:
    pass
try:
    nvJitLinkGetErrorLogSize = _libraries['libnvJitLink.so'].nvJitLinkGetErrorLogSize
    nvJitLinkGetErrorLogSize.restype = nvJitLinkResult
    nvJitLinkGetErrorLogSize.argtypes = [nvJitLinkHandle, ctypes.POINTER(ctypes.c_uint64)]
except AttributeError:
    pass
try:
    __nvJitLinkGetErrorLog_12_4 = _libraries['libnvJitLink.so'].__nvJitLinkGetErrorLog_12_4
    __nvJitLinkGetErrorLog_12_4.restype = nvJitLinkResult
    __nvJitLinkGetErrorLog_12_4.argtypes = [nvJitLinkHandle, ctypes.POINTER(ctypes.c_char)]
except AttributeError:
    pass
try:
    nvJitLinkGetErrorLog = _libraries['libnvJitLink.so'].nvJitLinkGetErrorLog
    nvJitLinkGetErrorLog.restype = nvJitLinkResult
    nvJitLinkGetErrorLog.argtypes = [nvJitLinkHandle, ctypes.POINTER(ctypes.c_char)]
except AttributeError:
    pass
try:
    __nvJitLinkGetInfoLogSize_12_4 = _libraries['libnvJitLink.so'].__nvJitLinkGetInfoLogSize_12_4
    __nvJitLinkGetInfoLogSize_12_4.restype = nvJitLinkResult
    __nvJitLinkGetInfoLogSize_12_4.argtypes = [nvJitLinkHandle, ctypes.POINTER(ctypes.c_uint64)]
except AttributeError:
    pass
try:
    nvJitLinkGetInfoLogSize = _libraries['libnvJitLink.so'].nvJitLinkGetInfoLogSize
    nvJitLinkGetInfoLogSize.restype = nvJitLinkResult
    nvJitLinkGetInfoLogSize.argtypes = [nvJitLinkHandle, ctypes.POINTER(ctypes.c_uint64)]
except AttributeError:
    pass
try:
    __nvJitLinkGetInfoLog_12_4 = _libraries['libnvJitLink.so'].__nvJitLinkGetInfoLog_12_4
    __nvJitLinkGetInfoLog_12_4.restype = nvJitLinkResult
    __nvJitLinkGetInfoLog_12_4.argtypes = [nvJitLinkHandle, ctypes.POINTER(ctypes.c_char)]
except AttributeError:
    pass
try:
    nvJitLinkGetInfoLog = _libraries['libnvJitLink.so'].nvJitLinkGetInfoLog
    nvJitLinkGetInfoLog.restype = nvJitLinkResult
    nvJitLinkGetInfoLog.argtypes = [nvJitLinkHandle, ctypes.POINTER(ctypes.c_char)]
except AttributeError:
    pass
# Only bound under its plain name; no "_12_4" variant is declared for it here.
try:
    nvJitLinkVersion = _libraries['libnvJitLink.so'].nvJitLinkVersion
    nvJitLinkVersion.restype = nvJitLinkResult
    nvJitLinkVersion.argtypes = [ctypes.POINTER(ctypes.c_uint32), ctypes.POINTER(ctypes.c_uint32)]
except AttributeError:
    pass
|
531
|
+
# Public names of this module: enum constants, enum types and lookup tables,
# function bindings, and the handle/struct types.
# NOTE(review): the function names are listed unconditionally, although each
# binding above is only defined when its symbol was found in the loaded
# library, so a star-import presumably fails when any of them is missing.
# Confirm against how this module is consumed.
__all__ = \
    ['NVJITLINK_ERROR_INTERNAL', 'NVJITLINK_ERROR_INVALID_INPUT',
    'NVJITLINK_ERROR_MISSING_ARCH', 'NVJITLINK_ERROR_NVVM_COMPILE',
    'NVJITLINK_ERROR_PTX_COMPILE', 'NVJITLINK_ERROR_THREADPOOL',
    'NVJITLINK_ERROR_UNRECOGNIZED_INPUT',
    'NVJITLINK_ERROR_UNRECOGNIZED_OPTION', 'NVJITLINK_INPUT_ANY',
    'NVJITLINK_INPUT_CUBIN', 'NVJITLINK_INPUT_FATBIN',
    'NVJITLINK_INPUT_LIBRARY', 'NVJITLINK_INPUT_LTOIR',
    'NVJITLINK_INPUT_NONE', 'NVJITLINK_INPUT_OBJECT',
    'NVJITLINK_INPUT_PTX', 'NVJITLINK_SUCCESS',
    'NVRTC_ERROR_BUILTIN_OPERATION_FAILURE',
    'NVRTC_ERROR_COMPILATION', 'NVRTC_ERROR_INTERNAL_ERROR',
    'NVRTC_ERROR_INVALID_INPUT', 'NVRTC_ERROR_INVALID_OPTION',
    'NVRTC_ERROR_INVALID_PROGRAM',
    'NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID',
    'NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION',
    'NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION',
    'NVRTC_ERROR_OUT_OF_MEMORY',
    'NVRTC_ERROR_PROGRAM_CREATION_FAILURE',
    'NVRTC_ERROR_TIME_FILE_WRITE_FAILED', 'NVRTC_SUCCESS',
    '__nvJitLinkAddData_12_4', '__nvJitLinkAddFile_12_4',
    '__nvJitLinkComplete_12_4', '__nvJitLinkCreate_12_4',
    '__nvJitLinkDestroy_12_4', '__nvJitLinkGetErrorLogSize_12_4',
    '__nvJitLinkGetErrorLog_12_4', '__nvJitLinkGetInfoLogSize_12_4',
    '__nvJitLinkGetInfoLog_12_4',
    '__nvJitLinkGetLinkedCubinSize_12_4',
    '__nvJitLinkGetLinkedCubin_12_4',
    '__nvJitLinkGetLinkedPtxSize_12_4',
    '__nvJitLinkGetLinkedPtx_12_4', 'c__EA_nvJitLinkInputType',
    'c__EA_nvJitLinkResult', 'c__EA_nvrtcResult', 'nvJitLinkAddData',
    'nvJitLinkAddFile', 'nvJitLinkComplete', 'nvJitLinkCreate',
    'nvJitLinkDestroy', 'nvJitLinkGetErrorLog',
    'nvJitLinkGetErrorLogSize', 'nvJitLinkGetInfoLog',
    'nvJitLinkGetInfoLogSize', 'nvJitLinkGetLinkedCubin',
    'nvJitLinkGetLinkedCubinSize', 'nvJitLinkGetLinkedPtx',
    'nvJitLinkGetLinkedPtxSize', 'nvJitLinkHandle',
    'nvJitLinkInputType', 'nvJitLinkInputType__enumvalues',
    'nvJitLinkResult', 'nvJitLinkResult__enumvalues',
    'nvJitLinkVersion', 'nvrtcAddNameExpression',
    'nvrtcCompileProgram', 'nvrtcCreateProgram',
    'nvrtcDestroyProgram', 'nvrtcGetCUBIN', 'nvrtcGetCUBINSize',
    'nvrtcGetErrorString', 'nvrtcGetLTOIR', 'nvrtcGetLTOIRSize',
    'nvrtcGetLoweredName', 'nvrtcGetNVVM', 'nvrtcGetNVVMSize',
    'nvrtcGetNumSupportedArchs', 'nvrtcGetOptiXIR',
    'nvrtcGetOptiXIRSize', 'nvrtcGetPTX', 'nvrtcGetPTXSize',
    'nvrtcGetProgramLog', 'nvrtcGetProgramLogSize',
    'nvrtcGetSupportedArchs', 'nvrtcProgram', 'nvrtcResult',
    'nvrtcResult__enumvalues', 'nvrtcVersion', 'size_t',
    'struct__nvrtcProgram', 'struct_nvJitLink', 'uint32_t']
|
tinygrad/runtime/graph/clang.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
from typing import List, Dict, cast
|
2
2
|
import ctypes
|
3
|
-
from tinygrad.helpers import dedup, cpu_time_execution,
|
4
|
-
from tinygrad.engine.jit import GraphRunner
|
3
|
+
from tinygrad.helpers import dedup, cpu_time_execution, DEBUG
|
4
|
+
from tinygrad.engine.jit import GraphRunner, GraphException
|
5
5
|
from tinygrad.device import Buffer, Device
|
6
6
|
from tinygrad.engine.realize import ExecItem, CompiledRunner
|
7
7
|
from tinygrad.shape.symbolic import Variable
|
tinygrad/runtime/graph/cuda.py
CHANGED
@@ -1,12 +1,12 @@
|
|
1
1
|
import ctypes
|
2
2
|
from typing import Any, Optional, Tuple, Dict, List, cast
|
3
3
|
import tinygrad.runtime.autogen.cuda as cuda
|
4
|
-
from tinygrad.helpers import init_c_var,
|
4
|
+
from tinygrad.helpers import init_c_var, dedup
|
5
5
|
from tinygrad.device import Buffer, Device
|
6
6
|
from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
|
7
7
|
from tinygrad.shape.symbolic import Variable
|
8
8
|
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
|
9
|
-
from tinygrad.engine.jit import MultiGraphRunner
|
9
|
+
from tinygrad.engine.jit import MultiGraphRunner, GraphException
|
10
10
|
|
11
11
|
class CUDAGraph(MultiGraphRunner):
|
12
12
|
def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
|
@@ -33,7 +33,7 @@ class CUDAGraph(MultiGraphRunner):
|
|
33
33
|
kern_params = cuda.CUDA_KERNEL_NODE_PARAMS(ji.prg.clprg.prg, *global_size, *local_size, 0, None, vargs)
|
34
34
|
check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(kern_params)))
|
35
35
|
|
36
|
-
if j in self.
|
36
|
+
if j in self.launch_dims_replace or j in self.var_vals_replace or j in self.jc_idx_with_updatable_rawbufs:
|
37
37
|
self.updatable_nodes[j] = (new_node, kern_params, c_args, False)
|
38
38
|
elif isinstance(ji.prg, BufferXfer):
|
39
39
|
dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
|
@@ -58,13 +58,13 @@ class CUDAGraph(MultiGraphRunner):
|
|
58
58
|
elif i == 1: self.updatable_nodes[j][1].srcDevice = input_rawbuffers[input_idx]._buf
|
59
59
|
|
60
60
|
# Update var_vals in the c_args struct.
|
61
|
-
for j in self.
|
62
|
-
for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars):
|
63
|
-
setattr(self.updatable_nodes[j][2], f'v{i}', var_vals[v])
|
61
|
+
for j, i, v in self.updated_vars(var_vals): setattr(self.updatable_nodes[j][2], f'v{i}', v)
|
64
62
|
|
65
63
|
# Update launch dims in the kern_params struct.
|
66
|
-
for j in self.
|
67
|
-
|
64
|
+
for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
|
65
|
+
prg = cast(CompiledRunner, self.jit_cache[j].prg)
|
66
|
+
node, global_size, local_size = self.updatable_nodes[j][1], global_dims or prg.p.global_size, local_dims or prg.p.local_size
|
67
|
+
node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = *local_size, *global_size # type: ignore[misc]
|
68
68
|
|
69
69
|
# Update graph nodes with the updated structs.
|
70
70
|
for node, c_node_params, c_args, is_copy in self.updatable_nodes.values():
|
@@ -76,6 +76,3 @@ class CUDAGraph(MultiGraphRunner):
|
|
76
76
|
def __del__(self):
|
77
77
|
if hasattr(self, 'graph'): check(cuda.cuGraphDestroy(self.graph))
|
78
78
|
if hasattr(self, 'instance'): check(cuda.cuGraphExecDestroy(self.instance))
|
79
|
-
|
80
|
-
def set_kernel_node_launch_dims(self, node, global_size: Tuple[int, int, int], local_size: Tuple[int, int, int]):
|
81
|
-
node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = *local_size, *global_size
|