tinygrad 0.9.1__py3-none-any.whl → 0.9.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. tinygrad/codegen/kernel.py +248 -115
  2. tinygrad/codegen/lowerer.py +215 -0
  3. tinygrad/codegen/transcendental.py +310 -0
  4. tinygrad/codegen/uopgraph.py +622 -0
  5. tinygrad/codegen/uops.py +235 -393
  6. tinygrad/device.py +428 -69
  7. tinygrad/dtype.py +18 -4
  8. tinygrad/engine/graph.py +19 -32
  9. tinygrad/engine/jit.py +148 -70
  10. tinygrad/engine/realize.py +127 -51
  11. tinygrad/engine/schedule.py +259 -216
  12. tinygrad/engine/search.py +29 -22
  13. tinygrad/function.py +9 -0
  14. tinygrad/helpers.py +87 -49
  15. tinygrad/lazy.py +34 -35
  16. tinygrad/multi.py +41 -36
  17. tinygrad/nn/__init__.py +39 -22
  18. tinygrad/nn/state.py +3 -3
  19. tinygrad/ops.py +63 -62
  20. tinygrad/renderer/__init__.py +43 -21
  21. tinygrad/renderer/assembly.py +104 -106
  22. tinygrad/renderer/cstyle.py +87 -60
  23. tinygrad/renderer/llvmir.py +21 -30
  24. tinygrad/runtime/autogen/amd_gpu.py +25208 -5753
  25. tinygrad/runtime/autogen/cuda.py +6 -162
  26. tinygrad/runtime/autogen/kfd.py +32 -0
  27. tinygrad/runtime/autogen/libc.py +4260 -0
  28. tinygrad/runtime/autogen/nvrtc.py +579 -0
  29. tinygrad/runtime/graph/clang.py +2 -2
  30. tinygrad/runtime/graph/cuda.py +8 -11
  31. tinygrad/runtime/graph/hcq.py +120 -107
  32. tinygrad/runtime/graph/metal.py +18 -15
  33. tinygrad/runtime/ops_amd.py +197 -305
  34. tinygrad/runtime/ops_clang.py +2 -2
  35. tinygrad/runtime/ops_cuda.py +36 -94
  36. tinygrad/runtime/ops_disk.py +3 -7
  37. tinygrad/runtime/ops_gpu.py +4 -2
  38. tinygrad/runtime/ops_hip.py +70 -0
  39. tinygrad/runtime/ops_metal.py +38 -27
  40. tinygrad/runtime/ops_nv.py +283 -363
  41. tinygrad/runtime/ops_python.py +26 -30
  42. tinygrad/runtime/support/compiler_cuda.py +78 -0
  43. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +15 -1
  44. tinygrad/runtime/support/elf.py +38 -0
  45. tinygrad/shape/shapetracker.py +5 -14
  46. tinygrad/shape/symbolic.py +4 -8
  47. tinygrad/shape/view.py +34 -22
  48. tinygrad/tensor.py +399 -97
  49. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/METADATA +49 -48
  50. tinygrad-0.9.2.dist-info/RECORD +70 -0
  51. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/WHEEL +1 -1
  52. tinygrad/codegen/linearizer.py +0 -528
  53. tinygrad-0.9.1.dist-info/RECORD +0 -63
  54. /tinygrad/runtime/{driver → support}/__init__.py +0 -0
  55. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/LICENSE +0 -0
  56. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,579 @@
1
+ # mypy: ignore-errors
2
+ # -*- coding: utf-8 -*-
3
+ #
4
+ # TARGET arch is: []
5
+ # WORD_SIZE is: 8
6
+ # POINTER_SIZE is: 8
7
+ # LONGDOUBLE_SIZE is: 16
8
+ #
9
+ import ctypes, ctypes.util
10
+
11
+
12
+ _libraries = {}
13
+ _libraries['libnvrtc.so'] = ctypes.CDLL(ctypes.util.find_library('nvrtc'))
14
+ def string_cast(char_pointer, encoding='utf-8', errors='strict'):
15
+ value = ctypes.cast(char_pointer, ctypes.c_char_p).value
16
+ if value is not None and encoding is not None:
17
+ value = value.decode(encoding, errors=errors)
18
+ return value
19
+
20
+
21
+ def char_pointer_cast(string, encoding='utf-8'):
22
+ if encoding is not None:
23
+ try:
24
+ string = string.encode(encoding)
25
+ except AttributeError:
26
+ # In Python3, bytes has no encode attribute
27
+ pass
28
+ string = ctypes.c_char_p(string)
29
+ return ctypes.cast(string, ctypes.POINTER(ctypes.c_char))
30
+
31
+
32
+
33
+ class AsDictMixin:
34
+ @classmethod
35
+ def as_dict(cls, self):
36
+ result = {}
37
+ if not isinstance(self, AsDictMixin):
38
+ # not a structure, assume it's already a python object
39
+ return self
40
+ if not hasattr(cls, "_fields_"):
41
+ return result
42
+ # sys.version_info >= (3, 5)
43
+ # for (field, *_) in cls._fields_: # noqa
44
+ for field_tuple in cls._fields_: # noqa
45
+ field = field_tuple[0]
46
+ if field.startswith('PADDING_'):
47
+ continue
48
+ value = getattr(self, field)
49
+ type_ = type(value)
50
+ if hasattr(value, "_length_") and hasattr(value, "_type_"):
51
+ # array
52
+ if not hasattr(type_, "as_dict"):
53
+ value = [v for v in value]
54
+ else:
55
+ type_ = type_._type_
56
+ value = [type_.as_dict(v) for v in value]
57
+ elif hasattr(value, "contents") and hasattr(value, "_type_"):
58
+ # pointer
59
+ try:
60
+ if not hasattr(type_, "as_dict"):
61
+ value = value.contents
62
+ else:
63
+ type_ = type_._type_
64
+ value = type_.as_dict(value.contents)
65
+ except ValueError:
66
+ # nullptr
67
+ value = None
68
+ elif isinstance(value, AsDictMixin):
69
+ # other structure
70
+ value = type_.as_dict(value)
71
+ result[field] = value
72
+ return result
73
+
74
+
75
+ class Structure(ctypes.Structure, AsDictMixin):
76
+
77
+ def __init__(self, *args, **kwds):
78
+ # We don't want to use positional arguments fill PADDING_* fields
79
+
80
+ args = dict(zip(self.__class__._field_names_(), args))
81
+ args.update(kwds)
82
+ super(Structure, self).__init__(**args)
83
+
84
+ @classmethod
85
+ def _field_names_(cls):
86
+ if hasattr(cls, '_fields_'):
87
+ return (f[0] for f in cls._fields_ if not f[0].startswith('PADDING'))
88
+ else:
89
+ return ()
90
+
91
+ @classmethod
92
+ def get_type(cls, field):
93
+ for f in cls._fields_:
94
+ if f[0] == field:
95
+ return f[1]
96
+ return None
97
+
98
+ @classmethod
99
+ def bind(cls, bound_fields):
100
+ fields = {}
101
+ for name, type_ in cls._fields_:
102
+ if hasattr(type_, "restype"):
103
+ if name in bound_fields:
104
+ if bound_fields[name] is None:
105
+ fields[name] = type_()
106
+ else:
107
+ # use a closure to capture the callback from the loop scope
108
+ fields[name] = (
109
+ type_((lambda callback: lambda *args: callback(*args))(
110
+ bound_fields[name]))
111
+ )
112
+ del bound_fields[name]
113
+ else:
114
+ # default callback implementation (does nothing)
115
+ try:
116
+ default_ = type_(0).restype().value
117
+ except TypeError:
118
+ default_ = None
119
+ fields[name] = type_((
120
+ lambda default_: lambda *args: default_)(default_))
121
+ else:
122
+ # not a callback function, use default initialization
123
+ if name in bound_fields:
124
+ fields[name] = bound_fields[name]
125
+ del bound_fields[name]
126
+ else:
127
+ fields[name] = type_()
128
+ if len(bound_fields) != 0:
129
+ raise ValueError(
130
+ "Cannot bind the following unknown callback(s) {}.{}".format(
131
+ cls.__name__, bound_fields.keys()
132
+ ))
133
+ return cls(**fields)
134
+
135
+
136
+ class Union(ctypes.Union, AsDictMixin):
137
+ pass
138
+
139
+
140
+
141
+ _libraries['libnvJitLink.so'] = ctypes.CDLL(ctypes.util.find_library('nvJitLink'))
142
+ c_int128 = ctypes.c_ubyte*16
143
+ c_uint128 = c_int128
144
+ void = None
145
+ if ctypes.sizeof(ctypes.c_longdouble) == 16:
146
+ c_long_double_t = ctypes.c_longdouble
147
+ else:
148
+ c_long_double_t = ctypes.c_ubyte*16
149
+
150
+
151
+
152
+
153
+ # values for enumeration 'c__EA_nvrtcResult'
154
+ c__EA_nvrtcResult__enumvalues = {
155
+ 0: 'NVRTC_SUCCESS',
156
+ 1: 'NVRTC_ERROR_OUT_OF_MEMORY',
157
+ 2: 'NVRTC_ERROR_PROGRAM_CREATION_FAILURE',
158
+ 3: 'NVRTC_ERROR_INVALID_INPUT',
159
+ 4: 'NVRTC_ERROR_INVALID_PROGRAM',
160
+ 5: 'NVRTC_ERROR_INVALID_OPTION',
161
+ 6: 'NVRTC_ERROR_COMPILATION',
162
+ 7: 'NVRTC_ERROR_BUILTIN_OPERATION_FAILURE',
163
+ 8: 'NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION',
164
+ 9: 'NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION',
165
+ 10: 'NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID',
166
+ 11: 'NVRTC_ERROR_INTERNAL_ERROR',
167
+ 12: 'NVRTC_ERROR_TIME_FILE_WRITE_FAILED',
168
+ }
169
+ NVRTC_SUCCESS = 0
170
+ NVRTC_ERROR_OUT_OF_MEMORY = 1
171
+ NVRTC_ERROR_PROGRAM_CREATION_FAILURE = 2
172
+ NVRTC_ERROR_INVALID_INPUT = 3
173
+ NVRTC_ERROR_INVALID_PROGRAM = 4
174
+ NVRTC_ERROR_INVALID_OPTION = 5
175
+ NVRTC_ERROR_COMPILATION = 6
176
+ NVRTC_ERROR_BUILTIN_OPERATION_FAILURE = 7
177
+ NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION = 8
178
+ NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION = 9
179
+ NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID = 10
180
+ NVRTC_ERROR_INTERNAL_ERROR = 11
181
+ NVRTC_ERROR_TIME_FILE_WRITE_FAILED = 12
182
+ c__EA_nvrtcResult = ctypes.c_uint32 # enum
183
+ nvrtcResult = c__EA_nvrtcResult
184
+ nvrtcResult__enumvalues = c__EA_nvrtcResult__enumvalues
185
+ try:
186
+ nvrtcGetErrorString = _libraries['libnvrtc.so'].nvrtcGetErrorString
187
+ nvrtcGetErrorString.restype = ctypes.POINTER(ctypes.c_char)
188
+ nvrtcGetErrorString.argtypes = [nvrtcResult]
189
+ except AttributeError:
190
+ pass
191
+ try:
192
+ nvrtcVersion = _libraries['libnvrtc.so'].nvrtcVersion
193
+ nvrtcVersion.restype = nvrtcResult
194
+ nvrtcVersion.argtypes = [ctypes.POINTER(ctypes.c_int32), ctypes.POINTER(ctypes.c_int32)]
195
+ except AttributeError:
196
+ pass
197
+ try:
198
+ nvrtcGetNumSupportedArchs = _libraries['libnvrtc.so'].nvrtcGetNumSupportedArchs
199
+ nvrtcGetNumSupportedArchs.restype = nvrtcResult
200
+ nvrtcGetNumSupportedArchs.argtypes = [ctypes.POINTER(ctypes.c_int32)]
201
+ except AttributeError:
202
+ pass
203
+ try:
204
+ nvrtcGetSupportedArchs = _libraries['libnvrtc.so'].nvrtcGetSupportedArchs
205
+ nvrtcGetSupportedArchs.restype = nvrtcResult
206
+ nvrtcGetSupportedArchs.argtypes = [ctypes.POINTER(ctypes.c_int32)]
207
+ except AttributeError:
208
+ pass
209
+ class struct__nvrtcProgram(Structure):
210
+ pass
211
+
212
+ nvrtcProgram = ctypes.POINTER(struct__nvrtcProgram)
213
+ try:
214
+ nvrtcCreateProgram = _libraries['libnvrtc.so'].nvrtcCreateProgram
215
+ nvrtcCreateProgram.restype = nvrtcResult
216
+ nvrtcCreateProgram.argtypes = [ctypes.POINTER(ctypes.POINTER(struct__nvrtcProgram)), ctypes.POINTER(ctypes.c_char), ctypes.POINTER(ctypes.c_char), ctypes.c_int32, ctypes.POINTER(ctypes.POINTER(ctypes.c_char)), ctypes.POINTER(ctypes.POINTER(ctypes.c_char))]
217
+ except AttributeError:
218
+ pass
219
+ try:
220
+ nvrtcDestroyProgram = _libraries['libnvrtc.so'].nvrtcDestroyProgram
221
+ nvrtcDestroyProgram.restype = nvrtcResult
222
+ nvrtcDestroyProgram.argtypes = [ctypes.POINTER(ctypes.POINTER(struct__nvrtcProgram))]
223
+ except AttributeError:
224
+ pass
225
+ try:
226
+ nvrtcCompileProgram = _libraries['libnvrtc.so'].nvrtcCompileProgram
227
+ nvrtcCompileProgram.restype = nvrtcResult
228
+ nvrtcCompileProgram.argtypes = [nvrtcProgram, ctypes.c_int32, ctypes.POINTER(ctypes.POINTER(ctypes.c_char))]
229
+ except AttributeError:
230
+ pass
231
+ try:
232
+ nvrtcGetPTXSize = _libraries['libnvrtc.so'].nvrtcGetPTXSize
233
+ nvrtcGetPTXSize.restype = nvrtcResult
234
+ nvrtcGetPTXSize.argtypes = [nvrtcProgram, ctypes.POINTER(ctypes.c_uint64)]
235
+ except AttributeError:
236
+ pass
237
+ try:
238
+ nvrtcGetPTX = _libraries['libnvrtc.so'].nvrtcGetPTX
239
+ nvrtcGetPTX.restype = nvrtcResult
240
+ nvrtcGetPTX.argtypes = [nvrtcProgram, ctypes.POINTER(ctypes.c_char)]
241
+ except AttributeError:
242
+ pass
243
+ try:
244
+ nvrtcGetCUBINSize = _libraries['libnvrtc.so'].nvrtcGetCUBINSize
245
+ nvrtcGetCUBINSize.restype = nvrtcResult
246
+ nvrtcGetCUBINSize.argtypes = [nvrtcProgram, ctypes.POINTER(ctypes.c_uint64)]
247
+ except AttributeError:
248
+ pass
249
+ try:
250
+ nvrtcGetCUBIN = _libraries['libnvrtc.so'].nvrtcGetCUBIN
251
+ nvrtcGetCUBIN.restype = nvrtcResult
252
+ nvrtcGetCUBIN.argtypes = [nvrtcProgram, ctypes.POINTER(ctypes.c_char)]
253
+ except AttributeError:
254
+ pass
255
+ try:
256
+ nvrtcGetNVVMSize = _libraries['libnvrtc.so'].nvrtcGetNVVMSize
257
+ nvrtcGetNVVMSize.restype = nvrtcResult
258
+ nvrtcGetNVVMSize.argtypes = [nvrtcProgram, ctypes.POINTER(ctypes.c_uint64)]
259
+ except AttributeError:
260
+ pass
261
+ try:
262
+ nvrtcGetNVVM = _libraries['libnvrtc.so'].nvrtcGetNVVM
263
+ nvrtcGetNVVM.restype = nvrtcResult
264
+ nvrtcGetNVVM.argtypes = [nvrtcProgram, ctypes.POINTER(ctypes.c_char)]
265
+ except AttributeError:
266
+ pass
267
+ try:
268
+ nvrtcGetLTOIRSize = _libraries['libnvrtc.so'].nvrtcGetLTOIRSize
269
+ nvrtcGetLTOIRSize.restype = nvrtcResult
270
+ nvrtcGetLTOIRSize.argtypes = [nvrtcProgram, ctypes.POINTER(ctypes.c_uint64)]
271
+ except AttributeError:
272
+ pass
273
+ try:
274
+ nvrtcGetLTOIR = _libraries['libnvrtc.so'].nvrtcGetLTOIR
275
+ nvrtcGetLTOIR.restype = nvrtcResult
276
+ nvrtcGetLTOIR.argtypes = [nvrtcProgram, ctypes.POINTER(ctypes.c_char)]
277
+ except AttributeError:
278
+ pass
279
+ try:
280
+ nvrtcGetOptiXIRSize = _libraries['libnvrtc.so'].nvrtcGetOptiXIRSize
281
+ nvrtcGetOptiXIRSize.restype = nvrtcResult
282
+ nvrtcGetOptiXIRSize.argtypes = [nvrtcProgram, ctypes.POINTER(ctypes.c_uint64)]
283
+ except AttributeError:
284
+ pass
285
+ try:
286
+ nvrtcGetOptiXIR = _libraries['libnvrtc.so'].nvrtcGetOptiXIR
287
+ nvrtcGetOptiXIR.restype = nvrtcResult
288
+ nvrtcGetOptiXIR.argtypes = [nvrtcProgram, ctypes.POINTER(ctypes.c_char)]
289
+ except AttributeError:
290
+ pass
291
+ try:
292
+ nvrtcGetProgramLogSize = _libraries['libnvrtc.so'].nvrtcGetProgramLogSize
293
+ nvrtcGetProgramLogSize.restype = nvrtcResult
294
+ nvrtcGetProgramLogSize.argtypes = [nvrtcProgram, ctypes.POINTER(ctypes.c_uint64)]
295
+ except AttributeError:
296
+ pass
297
+ try:
298
+ nvrtcGetProgramLog = _libraries['libnvrtc.so'].nvrtcGetProgramLog
299
+ nvrtcGetProgramLog.restype = nvrtcResult
300
+ nvrtcGetProgramLog.argtypes = [nvrtcProgram, ctypes.POINTER(ctypes.c_char)]
301
+ except AttributeError:
302
+ pass
303
+ try:
304
+ nvrtcAddNameExpression = _libraries['libnvrtc.so'].nvrtcAddNameExpression
305
+ nvrtcAddNameExpression.restype = nvrtcResult
306
+ nvrtcAddNameExpression.argtypes = [nvrtcProgram, ctypes.POINTER(ctypes.c_char)]
307
+ except AttributeError:
308
+ pass
309
+ try:
310
+ nvrtcGetLoweredName = _libraries['libnvrtc.so'].nvrtcGetLoweredName
311
+ nvrtcGetLoweredName.restype = nvrtcResult
312
+ nvrtcGetLoweredName.argtypes = [nvrtcProgram, ctypes.POINTER(ctypes.c_char), ctypes.POINTER(ctypes.POINTER(ctypes.c_char))]
313
+ except AttributeError:
314
+ pass
315
+
316
+ # values for enumeration 'c__EA_nvJitLinkResult'
317
+ c__EA_nvJitLinkResult__enumvalues = {
318
+ 0: 'NVJITLINK_SUCCESS',
319
+ 1: 'NVJITLINK_ERROR_UNRECOGNIZED_OPTION',
320
+ 2: 'NVJITLINK_ERROR_MISSING_ARCH',
321
+ 3: 'NVJITLINK_ERROR_INVALID_INPUT',
322
+ 4: 'NVJITLINK_ERROR_PTX_COMPILE',
323
+ 5: 'NVJITLINK_ERROR_NVVM_COMPILE',
324
+ 6: 'NVJITLINK_ERROR_INTERNAL',
325
+ 7: 'NVJITLINK_ERROR_THREADPOOL',
326
+ 8: 'NVJITLINK_ERROR_UNRECOGNIZED_INPUT',
327
+ }
328
+ NVJITLINK_SUCCESS = 0
329
+ NVJITLINK_ERROR_UNRECOGNIZED_OPTION = 1
330
+ NVJITLINK_ERROR_MISSING_ARCH = 2
331
+ NVJITLINK_ERROR_INVALID_INPUT = 3
332
+ NVJITLINK_ERROR_PTX_COMPILE = 4
333
+ NVJITLINK_ERROR_NVVM_COMPILE = 5
334
+ NVJITLINK_ERROR_INTERNAL = 6
335
+ NVJITLINK_ERROR_THREADPOOL = 7
336
+ NVJITLINK_ERROR_UNRECOGNIZED_INPUT = 8
337
+ c__EA_nvJitLinkResult = ctypes.c_uint32 # enum
338
+ nvJitLinkResult = c__EA_nvJitLinkResult
339
+ nvJitLinkResult__enumvalues = c__EA_nvJitLinkResult__enumvalues
340
+
341
+ # values for enumeration 'c__EA_nvJitLinkInputType'
342
+ c__EA_nvJitLinkInputType__enumvalues = {
343
+ 0: 'NVJITLINK_INPUT_NONE',
344
+ 1: 'NVJITLINK_INPUT_CUBIN',
345
+ 2: 'NVJITLINK_INPUT_PTX',
346
+ 3: 'NVJITLINK_INPUT_LTOIR',
347
+ 4: 'NVJITLINK_INPUT_FATBIN',
348
+ 5: 'NVJITLINK_INPUT_OBJECT',
349
+ 6: 'NVJITLINK_INPUT_LIBRARY',
350
+ 10: 'NVJITLINK_INPUT_ANY',
351
+ }
352
+ NVJITLINK_INPUT_NONE = 0
353
+ NVJITLINK_INPUT_CUBIN = 1
354
+ NVJITLINK_INPUT_PTX = 2
355
+ NVJITLINK_INPUT_LTOIR = 3
356
+ NVJITLINK_INPUT_FATBIN = 4
357
+ NVJITLINK_INPUT_OBJECT = 5
358
+ NVJITLINK_INPUT_LIBRARY = 6
359
+ NVJITLINK_INPUT_ANY = 10
360
+ c__EA_nvJitLinkInputType = ctypes.c_uint32 # enum
361
+ nvJitLinkInputType = c__EA_nvJitLinkInputType
362
+ nvJitLinkInputType__enumvalues = c__EA_nvJitLinkInputType__enumvalues
363
+ class struct_nvJitLink(Structure):
364
+ pass
365
+
366
+ nvJitLinkHandle = ctypes.POINTER(struct_nvJitLink)
367
+ uint32_t = ctypes.c_uint32
368
+ try:
369
+ __nvJitLinkCreate_12_4 = _libraries['libnvJitLink.so'].__nvJitLinkCreate_12_4
370
+ __nvJitLinkCreate_12_4.restype = nvJitLinkResult
371
+ __nvJitLinkCreate_12_4.argtypes = [ctypes.POINTER(ctypes.POINTER(struct_nvJitLink)), uint32_t, ctypes.POINTER(ctypes.POINTER(ctypes.c_char))]
372
+ except AttributeError:
373
+ pass
374
+ try:
375
+ nvJitLinkCreate = _libraries['libnvJitLink.so'].nvJitLinkCreate
376
+ nvJitLinkCreate.restype = nvJitLinkResult
377
+ nvJitLinkCreate.argtypes = [ctypes.POINTER(ctypes.POINTER(struct_nvJitLink)), uint32_t, ctypes.POINTER(ctypes.POINTER(ctypes.c_char))]
378
+ except AttributeError:
379
+ pass
380
+ try:
381
+ __nvJitLinkDestroy_12_4 = _libraries['libnvJitLink.so'].__nvJitLinkDestroy_12_4
382
+ __nvJitLinkDestroy_12_4.restype = nvJitLinkResult
383
+ __nvJitLinkDestroy_12_4.argtypes = [ctypes.POINTER(ctypes.POINTER(struct_nvJitLink))]
384
+ except AttributeError:
385
+ pass
386
+ try:
387
+ nvJitLinkDestroy = _libraries['libnvJitLink.so'].nvJitLinkDestroy
388
+ nvJitLinkDestroy.restype = nvJitLinkResult
389
+ nvJitLinkDestroy.argtypes = [ctypes.POINTER(ctypes.POINTER(struct_nvJitLink))]
390
+ except AttributeError:
391
+ pass
392
+ size_t = ctypes.c_uint64
393
+ try:
394
+ __nvJitLinkAddData_12_4 = _libraries['libnvJitLink.so'].__nvJitLinkAddData_12_4
395
+ __nvJitLinkAddData_12_4.restype = nvJitLinkResult
396
+ __nvJitLinkAddData_12_4.argtypes = [nvJitLinkHandle, nvJitLinkInputType, ctypes.POINTER(None), size_t, ctypes.POINTER(ctypes.c_char)]
397
+ except AttributeError:
398
+ pass
399
+ try:
400
+ nvJitLinkAddData = _libraries['libnvJitLink.so'].nvJitLinkAddData
401
+ nvJitLinkAddData.restype = nvJitLinkResult
402
+ nvJitLinkAddData.argtypes = [nvJitLinkHandle, nvJitLinkInputType, ctypes.POINTER(None), size_t, ctypes.POINTER(ctypes.c_char)]
403
+ except AttributeError:
404
+ pass
405
+ try:
406
+ __nvJitLinkAddFile_12_4 = _libraries['libnvJitLink.so'].__nvJitLinkAddFile_12_4
407
+ __nvJitLinkAddFile_12_4.restype = nvJitLinkResult
408
+ __nvJitLinkAddFile_12_4.argtypes = [nvJitLinkHandle, nvJitLinkInputType, ctypes.POINTER(ctypes.c_char)]
409
+ except AttributeError:
410
+ pass
411
+ try:
412
+ nvJitLinkAddFile = _libraries['libnvJitLink.so'].nvJitLinkAddFile
413
+ nvJitLinkAddFile.restype = nvJitLinkResult
414
+ nvJitLinkAddFile.argtypes = [nvJitLinkHandle, nvJitLinkInputType, ctypes.POINTER(ctypes.c_char)]
415
+ except AttributeError:
416
+ pass
417
+ try:
418
+ __nvJitLinkComplete_12_4 = _libraries['libnvJitLink.so'].__nvJitLinkComplete_12_4
419
+ __nvJitLinkComplete_12_4.restype = nvJitLinkResult
420
+ __nvJitLinkComplete_12_4.argtypes = [nvJitLinkHandle]
421
+ except AttributeError:
422
+ pass
423
+ try:
424
+ nvJitLinkComplete = _libraries['libnvJitLink.so'].nvJitLinkComplete
425
+ nvJitLinkComplete.restype = nvJitLinkResult
426
+ nvJitLinkComplete.argtypes = [nvJitLinkHandle]
427
+ except AttributeError:
428
+ pass
429
+ try:
430
+ __nvJitLinkGetLinkedCubinSize_12_4 = _libraries['libnvJitLink.so'].__nvJitLinkGetLinkedCubinSize_12_4
431
+ __nvJitLinkGetLinkedCubinSize_12_4.restype = nvJitLinkResult
432
+ __nvJitLinkGetLinkedCubinSize_12_4.argtypes = [nvJitLinkHandle, ctypes.POINTER(ctypes.c_uint64)]
433
+ except AttributeError:
434
+ pass
435
+ try:
436
+ nvJitLinkGetLinkedCubinSize = _libraries['libnvJitLink.so'].nvJitLinkGetLinkedCubinSize
437
+ nvJitLinkGetLinkedCubinSize.restype = nvJitLinkResult
438
+ nvJitLinkGetLinkedCubinSize.argtypes = [nvJitLinkHandle, ctypes.POINTER(ctypes.c_uint64)]
439
+ except AttributeError:
440
+ pass
441
+ try:
442
+ __nvJitLinkGetLinkedCubin_12_4 = _libraries['libnvJitLink.so'].__nvJitLinkGetLinkedCubin_12_4
443
+ __nvJitLinkGetLinkedCubin_12_4.restype = nvJitLinkResult
444
+ __nvJitLinkGetLinkedCubin_12_4.argtypes = [nvJitLinkHandle, ctypes.POINTER(None)]
445
+ except AttributeError:
446
+ pass
447
+ try:
448
+ nvJitLinkGetLinkedCubin = _libraries['libnvJitLink.so'].nvJitLinkGetLinkedCubin
449
+ nvJitLinkGetLinkedCubin.restype = nvJitLinkResult
450
+ nvJitLinkGetLinkedCubin.argtypes = [nvJitLinkHandle, ctypes.POINTER(None)]
451
+ except AttributeError:
452
+ pass
453
+ try:
454
+ __nvJitLinkGetLinkedPtxSize_12_4 = _libraries['libnvJitLink.so'].__nvJitLinkGetLinkedPtxSize_12_4
455
+ __nvJitLinkGetLinkedPtxSize_12_4.restype = nvJitLinkResult
456
+ __nvJitLinkGetLinkedPtxSize_12_4.argtypes = [nvJitLinkHandle, ctypes.POINTER(ctypes.c_uint64)]
457
+ except AttributeError:
458
+ pass
459
+ try:
460
+ nvJitLinkGetLinkedPtxSize = _libraries['libnvJitLink.so'].nvJitLinkGetLinkedPtxSize
461
+ nvJitLinkGetLinkedPtxSize.restype = nvJitLinkResult
462
+ nvJitLinkGetLinkedPtxSize.argtypes = [nvJitLinkHandle, ctypes.POINTER(ctypes.c_uint64)]
463
+ except AttributeError:
464
+ pass
465
+ try:
466
+ __nvJitLinkGetLinkedPtx_12_4 = _libraries['libnvJitLink.so'].__nvJitLinkGetLinkedPtx_12_4
467
+ __nvJitLinkGetLinkedPtx_12_4.restype = nvJitLinkResult
468
+ __nvJitLinkGetLinkedPtx_12_4.argtypes = [nvJitLinkHandle, ctypes.POINTER(ctypes.c_char)]
469
+ except AttributeError:
470
+ pass
471
+ try:
472
+ nvJitLinkGetLinkedPtx = _libraries['libnvJitLink.so'].nvJitLinkGetLinkedPtx
473
+ nvJitLinkGetLinkedPtx.restype = nvJitLinkResult
474
+ nvJitLinkGetLinkedPtx.argtypes = [nvJitLinkHandle, ctypes.POINTER(ctypes.c_char)]
475
+ except AttributeError:
476
+ pass
477
+ try:
478
+ __nvJitLinkGetErrorLogSize_12_4 = _libraries['libnvJitLink.so'].__nvJitLinkGetErrorLogSize_12_4
479
+ __nvJitLinkGetErrorLogSize_12_4.restype = nvJitLinkResult
480
+ __nvJitLinkGetErrorLogSize_12_4.argtypes = [nvJitLinkHandle, ctypes.POINTER(ctypes.c_uint64)]
481
+ except AttributeError:
482
+ pass
483
+ try:
484
+ nvJitLinkGetErrorLogSize = _libraries['libnvJitLink.so'].nvJitLinkGetErrorLogSize
485
+ nvJitLinkGetErrorLogSize.restype = nvJitLinkResult
486
+ nvJitLinkGetErrorLogSize.argtypes = [nvJitLinkHandle, ctypes.POINTER(ctypes.c_uint64)]
487
+ except AttributeError:
488
+ pass
489
+ try:
490
+ __nvJitLinkGetErrorLog_12_4 = _libraries['libnvJitLink.so'].__nvJitLinkGetErrorLog_12_4
491
+ __nvJitLinkGetErrorLog_12_4.restype = nvJitLinkResult
492
+ __nvJitLinkGetErrorLog_12_4.argtypes = [nvJitLinkHandle, ctypes.POINTER(ctypes.c_char)]
493
+ except AttributeError:
494
+ pass
495
+ try:
496
+ nvJitLinkGetErrorLog = _libraries['libnvJitLink.so'].nvJitLinkGetErrorLog
497
+ nvJitLinkGetErrorLog.restype = nvJitLinkResult
498
+ nvJitLinkGetErrorLog.argtypes = [nvJitLinkHandle, ctypes.POINTER(ctypes.c_char)]
499
+ except AttributeError:
500
+ pass
501
+ try:
502
+ __nvJitLinkGetInfoLogSize_12_4 = _libraries['libnvJitLink.so'].__nvJitLinkGetInfoLogSize_12_4
503
+ __nvJitLinkGetInfoLogSize_12_4.restype = nvJitLinkResult
504
+ __nvJitLinkGetInfoLogSize_12_4.argtypes = [nvJitLinkHandle, ctypes.POINTER(ctypes.c_uint64)]
505
+ except AttributeError:
506
+ pass
507
+ try:
508
+ nvJitLinkGetInfoLogSize = _libraries['libnvJitLink.so'].nvJitLinkGetInfoLogSize
509
+ nvJitLinkGetInfoLogSize.restype = nvJitLinkResult
510
+ nvJitLinkGetInfoLogSize.argtypes = [nvJitLinkHandle, ctypes.POINTER(ctypes.c_uint64)]
511
+ except AttributeError:
512
+ pass
513
+ try:
514
+ __nvJitLinkGetInfoLog_12_4 = _libraries['libnvJitLink.so'].__nvJitLinkGetInfoLog_12_4
515
+ __nvJitLinkGetInfoLog_12_4.restype = nvJitLinkResult
516
+ __nvJitLinkGetInfoLog_12_4.argtypes = [nvJitLinkHandle, ctypes.POINTER(ctypes.c_char)]
517
+ except AttributeError:
518
+ pass
519
+ try:
520
+ nvJitLinkGetInfoLog = _libraries['libnvJitLink.so'].nvJitLinkGetInfoLog
521
+ nvJitLinkGetInfoLog.restype = nvJitLinkResult
522
+ nvJitLinkGetInfoLog.argtypes = [nvJitLinkHandle, ctypes.POINTER(ctypes.c_char)]
523
+ except AttributeError:
524
+ pass
525
+ try:
526
+ nvJitLinkVersion = _libraries['libnvJitLink.so'].nvJitLinkVersion
527
+ nvJitLinkVersion.restype = nvJitLinkResult
528
+ nvJitLinkVersion.argtypes = [ctypes.POINTER(ctypes.c_uint32), ctypes.POINTER(ctypes.c_uint32)]
529
+ except AttributeError:
530
+ pass
531
+ __all__ = \
532
+ ['NVJITLINK_ERROR_INTERNAL', 'NVJITLINK_ERROR_INVALID_INPUT',
533
+ 'NVJITLINK_ERROR_MISSING_ARCH', 'NVJITLINK_ERROR_NVVM_COMPILE',
534
+ 'NVJITLINK_ERROR_PTX_COMPILE', 'NVJITLINK_ERROR_THREADPOOL',
535
+ 'NVJITLINK_ERROR_UNRECOGNIZED_INPUT',
536
+ 'NVJITLINK_ERROR_UNRECOGNIZED_OPTION', 'NVJITLINK_INPUT_ANY',
537
+ 'NVJITLINK_INPUT_CUBIN', 'NVJITLINK_INPUT_FATBIN',
538
+ 'NVJITLINK_INPUT_LIBRARY', 'NVJITLINK_INPUT_LTOIR',
539
+ 'NVJITLINK_INPUT_NONE', 'NVJITLINK_INPUT_OBJECT',
540
+ 'NVJITLINK_INPUT_PTX', 'NVJITLINK_SUCCESS',
541
+ 'NVRTC_ERROR_BUILTIN_OPERATION_FAILURE',
542
+ 'NVRTC_ERROR_COMPILATION', 'NVRTC_ERROR_INTERNAL_ERROR',
543
+ 'NVRTC_ERROR_INVALID_INPUT', 'NVRTC_ERROR_INVALID_OPTION',
544
+ 'NVRTC_ERROR_INVALID_PROGRAM',
545
+ 'NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID',
546
+ 'NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION',
547
+ 'NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION',
548
+ 'NVRTC_ERROR_OUT_OF_MEMORY',
549
+ 'NVRTC_ERROR_PROGRAM_CREATION_FAILURE',
550
+ 'NVRTC_ERROR_TIME_FILE_WRITE_FAILED', 'NVRTC_SUCCESS',
551
+ '__nvJitLinkAddData_12_4', '__nvJitLinkAddFile_12_4',
552
+ '__nvJitLinkComplete_12_4', '__nvJitLinkCreate_12_4',
553
+ '__nvJitLinkDestroy_12_4', '__nvJitLinkGetErrorLogSize_12_4',
554
+ '__nvJitLinkGetErrorLog_12_4', '__nvJitLinkGetInfoLogSize_12_4',
555
+ '__nvJitLinkGetInfoLog_12_4',
556
+ '__nvJitLinkGetLinkedCubinSize_12_4',
557
+ '__nvJitLinkGetLinkedCubin_12_4',
558
+ '__nvJitLinkGetLinkedPtxSize_12_4',
559
+ '__nvJitLinkGetLinkedPtx_12_4', 'c__EA_nvJitLinkInputType',
560
+ 'c__EA_nvJitLinkResult', 'c__EA_nvrtcResult', 'nvJitLinkAddData',
561
+ 'nvJitLinkAddFile', 'nvJitLinkComplete', 'nvJitLinkCreate',
562
+ 'nvJitLinkDestroy', 'nvJitLinkGetErrorLog',
563
+ 'nvJitLinkGetErrorLogSize', 'nvJitLinkGetInfoLog',
564
+ 'nvJitLinkGetInfoLogSize', 'nvJitLinkGetLinkedCubin',
565
+ 'nvJitLinkGetLinkedCubinSize', 'nvJitLinkGetLinkedPtx',
566
+ 'nvJitLinkGetLinkedPtxSize', 'nvJitLinkHandle',
567
+ 'nvJitLinkInputType', 'nvJitLinkInputType__enumvalues',
568
+ 'nvJitLinkResult', 'nvJitLinkResult__enumvalues',
569
+ 'nvJitLinkVersion', 'nvrtcAddNameExpression',
570
+ 'nvrtcCompileProgram', 'nvrtcCreateProgram',
571
+ 'nvrtcDestroyProgram', 'nvrtcGetCUBIN', 'nvrtcGetCUBINSize',
572
+ 'nvrtcGetErrorString', 'nvrtcGetLTOIR', 'nvrtcGetLTOIRSize',
573
+ 'nvrtcGetLoweredName', 'nvrtcGetNVVM', 'nvrtcGetNVVMSize',
574
+ 'nvrtcGetNumSupportedArchs', 'nvrtcGetOptiXIR',
575
+ 'nvrtcGetOptiXIRSize', 'nvrtcGetPTX', 'nvrtcGetPTXSize',
576
+ 'nvrtcGetProgramLog', 'nvrtcGetProgramLogSize',
577
+ 'nvrtcGetSupportedArchs', 'nvrtcProgram', 'nvrtcResult',
578
+ 'nvrtcResult__enumvalues', 'nvrtcVersion', 'size_t',
579
+ 'struct__nvrtcProgram', 'struct_nvJitLink', 'uint32_t']
@@ -1,7 +1,7 @@
1
1
  from typing import List, Dict, cast
2
2
  import ctypes
3
- from tinygrad.helpers import dedup, cpu_time_execution, GraphException, DEBUG
4
- from tinygrad.engine.jit import GraphRunner
3
+ from tinygrad.helpers import dedup, cpu_time_execution, DEBUG
4
+ from tinygrad.engine.jit import GraphRunner, GraphException
5
5
  from tinygrad.device import Buffer, Device
6
6
  from tinygrad.engine.realize import ExecItem, CompiledRunner
7
7
  from tinygrad.shape.symbolic import Variable
@@ -1,12 +1,12 @@
1
1
  import ctypes
2
2
  from typing import Any, Optional, Tuple, Dict, List, cast
3
3
  import tinygrad.runtime.autogen.cuda as cuda
4
- from tinygrad.helpers import init_c_var, GraphException, dedup
4
+ from tinygrad.helpers import init_c_var, dedup
5
5
  from tinygrad.device import Buffer, Device
6
6
  from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
7
7
  from tinygrad.shape.symbolic import Variable
8
8
  from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
9
- from tinygrad.engine.jit import MultiGraphRunner
9
+ from tinygrad.engine.jit import MultiGraphRunner, GraphException
10
10
 
11
11
  class CUDAGraph(MultiGraphRunner):
12
12
  def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
@@ -33,7 +33,7 @@ class CUDAGraph(MultiGraphRunner):
33
33
  kern_params = cuda.CUDA_KERNEL_NODE_PARAMS(ji.prg.clprg.prg, *global_size, *local_size, 0, None, vargs)
34
34
  check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(kern_params)))
35
35
 
36
- if j in self.jc_idx_with_updatable_launch_dims or j in self.jc_idx_with_updatable_var_vals or j in self.jc_idx_with_updatable_rawbufs:
36
+ if j in self.launch_dims_replace or j in self.var_vals_replace or j in self.jc_idx_with_updatable_rawbufs:
37
37
  self.updatable_nodes[j] = (new_node, kern_params, c_args, False)
38
38
  elif isinstance(ji.prg, BufferXfer):
39
39
  dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
@@ -58,13 +58,13 @@ class CUDAGraph(MultiGraphRunner):
58
58
  elif i == 1: self.updatable_nodes[j][1].srcDevice = input_rawbuffers[input_idx]._buf
59
59
 
60
60
  # Update var_vals in the c_args struct.
61
- for j in self.jc_idx_with_updatable_var_vals:
62
- for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars):
63
- setattr(self.updatable_nodes[j][2], f'v{i}', var_vals[v])
61
+ for j, i, v in self.updated_vars(var_vals): setattr(self.updatable_nodes[j][2], f'v{i}', v)
64
62
 
65
63
  # Update launch dims in the kern_params struct.
66
- for j in self.jc_idx_with_updatable_launch_dims:
67
- self.set_kernel_node_launch_dims(self.updatable_nodes[j][1], *cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals))
64
+ for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
65
+ prg = cast(CompiledRunner, self.jit_cache[j].prg)
66
+ node, global_size, local_size = self.updatable_nodes[j][1], global_dims or prg.p.global_size, local_dims or prg.p.local_size
67
+ node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = *local_size, *global_size # type: ignore[misc]
68
68
 
69
69
  # Update graph nodes with the updated structs.
70
70
  for node, c_node_params, c_args, is_copy in self.updatable_nodes.values():
@@ -76,6 +76,3 @@ class CUDAGraph(MultiGraphRunner):
76
76
  def __del__(self):
77
77
  if hasattr(self, 'graph'): check(cuda.cuGraphDestroy(self.graph))
78
78
  if hasattr(self, 'instance'): check(cuda.cuGraphExecDestroy(self.instance))
79
-
80
- def set_kernel_node_launch_dims(self, node, global_size: Tuple[int, int, int], local_size: Tuple[int, int, int]):
81
- node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = *local_size, *global_size