tinygrad 0.9.1__py3-none-any.whl → 0.9.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/kernel.py +248 -115
- tinygrad/codegen/lowerer.py +215 -0
- tinygrad/codegen/transcendental.py +310 -0
- tinygrad/codegen/uopgraph.py +622 -0
- tinygrad/codegen/uops.py +235 -393
- tinygrad/device.py +428 -69
- tinygrad/dtype.py +18 -4
- tinygrad/engine/graph.py +19 -32
- tinygrad/engine/jit.py +148 -70
- tinygrad/engine/realize.py +127 -51
- tinygrad/engine/schedule.py +259 -216
- tinygrad/engine/search.py +29 -22
- tinygrad/function.py +9 -0
- tinygrad/helpers.py +87 -49
- tinygrad/lazy.py +34 -35
- tinygrad/multi.py +41 -36
- tinygrad/nn/__init__.py +39 -22
- tinygrad/nn/state.py +3 -3
- tinygrad/ops.py +63 -62
- tinygrad/renderer/__init__.py +43 -21
- tinygrad/renderer/assembly.py +104 -106
- tinygrad/renderer/cstyle.py +87 -60
- tinygrad/renderer/llvmir.py +21 -30
- tinygrad/runtime/autogen/amd_gpu.py +25208 -5753
- tinygrad/runtime/autogen/cuda.py +6 -162
- tinygrad/runtime/autogen/kfd.py +32 -0
- tinygrad/runtime/autogen/libc.py +4260 -0
- tinygrad/runtime/autogen/nvrtc.py +579 -0
- tinygrad/runtime/graph/clang.py +2 -2
- tinygrad/runtime/graph/cuda.py +8 -11
- tinygrad/runtime/graph/hcq.py +120 -107
- tinygrad/runtime/graph/metal.py +18 -15
- tinygrad/runtime/ops_amd.py +197 -305
- tinygrad/runtime/ops_clang.py +2 -2
- tinygrad/runtime/ops_cuda.py +36 -94
- tinygrad/runtime/ops_disk.py +3 -7
- tinygrad/runtime/ops_gpu.py +4 -2
- tinygrad/runtime/ops_hip.py +70 -0
- tinygrad/runtime/ops_metal.py +38 -27
- tinygrad/runtime/ops_nv.py +283 -363
- tinygrad/runtime/ops_python.py +26 -30
- tinygrad/runtime/support/compiler_cuda.py +78 -0
- tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +15 -1
- tinygrad/runtime/support/elf.py +38 -0
- tinygrad/shape/shapetracker.py +5 -14
- tinygrad/shape/symbolic.py +4 -8
- tinygrad/shape/view.py +34 -22
- tinygrad/tensor.py +399 -97
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/METADATA +49 -48
- tinygrad-0.9.2.dist-info/RECORD +70 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/WHEEL +1 -1
- tinygrad/codegen/linearizer.py +0 -528
- tinygrad-0.9.1.dist-info/RECORD +0 -63
- /tinygrad/runtime/{driver → support}/__init__.py +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/top_level.txt +0 -0
tinygrad/device.py
CHANGED
@@ -1,10 +1,10 @@
 from __future__ import annotations
-import multiprocessing
+import multiprocessing, decimal, statistics, random
 from dataclasses import dataclass
 from collections import defaultdict
-from typing import List, Optional, Dict, Tuple, Any, cast
-import importlib, inspect, functools, pathlib, os, ctypes, atexit, time, contextlib
-from tinygrad.helpers import getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, ProfileLogger, PROFILE
+from typing import List, Optional, Dict, Tuple, Any, cast, Protocol, Type
+import importlib, inspect, functools, pathlib, os, ctypes, atexit, time, contextlib, array
+from tinygrad.helpers import SAVE_SCHEDULE, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, ProfileLogger, PROFILE
 from tinygrad.dtype import DType, ImageDType
 from tinygrad.renderer import Renderer
 
@@ -25,11 +25,12 @@ class _Device:
     ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._devices][0](ix) # noqa: E501
     if DEBUG >= 1: print(f"opened device {ix} from pid:{os.getpid()}")
     return ret
+  @property
+  def default(self) -> Compiled: return self[self.DEFAULT]
   @functools.cached_property
   def DEFAULT(self) -> str:
-    …
-    …
-    for device in ["METAL", "AMD", "CUDA", "GPU", "CLANG", "LLVM"]:
+    if (from_env:=next((d for d in self._devices if d not in ["DISK", "NPY"] and getenv(d) == 1), None)): return from_env
+    for device in ["METAL", "AMD", "NV", "CUDA", "GPU", "CLANG", "LLVM"]:
       try:
         if self[device]:
           os.environ[device] = "1" # we set this in environment for spawned children
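The rewritten `DEFAULT` first honors an explicit `<DEVICE>=1` environment variable (ignoring `DISK` and `NPY`), then probes a fixed priority list that now includes `NV`; the companion `default` property returns the opened `Compiled` instance. A minimal sketch of exercising it (the variable must be set before tinygrad opens a device):

```python
# Sketch: forcing the default backend via the environment. DEFAULT returns the
# backend name; the new Device.default property opens and returns the Compiled
# instance for it.
import os
os.environ["CLANG"] = "1"  # any registered backend name works the same way

from tinygrad import Device
print(Device.DEFAULT)   # -> "CLANG"
print(Device.default)   # new in 0.9.2: the Compiled device object for DEFAULT
```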
@@ -90,7 +91,7 @@ class Buffer:
     if self._base is not None:
       return self.__class__, (self.device, self.size, self.dtype, None, None, None, 0, self.base, self.offset, hasattr(self, '_buf'))
     if self.device == "NPY": return self.__class__, (self.device, self.size, self.dtype, self._buf, self.options, None, self.lb_refcount)
-    if self.is_allocated():
+    if self.is_allocated() and not SAVE_SCHEDULE:
       buf = bytearray(self.nbytes)
       self.copyout(memoryview(buf))
       return self.__class__, (self.device, self.size, self.dtype, None, self.options, buf, self.lb_refcount)
@@ -140,6 +141,10 @@ class Allocator:
   def copyout(self, dest:memoryview, src): raise NotImplementedError("need copyout")
 
 class LRUAllocator(Allocator):  # pylint: disable=abstract-method
+  """
+  The LRU Allocator is responsible for caching buffers.
+  It ensures that buffers are not freed until it is absolutely necessary, optimizing performance.
+  """
   def __init__(self): self.cache: Dict[Tuple[int, Optional[BufferOptions]], Any] = defaultdict(list)
   def alloc(self, size:int, options:Optional[BufferOptions]=None):
     if len(c := self.cache[(size, options)]): return c.pop()
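The new docstring documents behavior the code already had: freed buffers are parked in a `(size, options)`-keyed cache and handed back on the next matching `alloc`. A toy sketch of that contract (`ToyLRUAllocator` is a hypothetical stand-in, not the tinygrad class):

```python
# Toy model of the LRUAllocator caching contract: free() parks the buffer, and
# a later alloc() with the same (size, options) key returns it instead of
# allocating on the device again.
from collections import defaultdict

class ToyLRUAllocator:
  def __init__(self): self.cache = defaultdict(list)
  def alloc(self, size, options=None):
    if len(c := self.cache[(size, options)]): return c.pop()  # cache hit: reuse
    return bytearray(size)                                    # stand-in for a device alloc
  def free(self, buf, size, options=None): self.cache[(size, options)].append(buf)

a = ToyLRUAllocator()
b1 = a.alloc(1024); a.free(b1, 1024)
assert a.alloc(1024) is b1  # the cached buffer comes back
```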
@@ -182,7 +187,13 @@ class Compiled:
   def __init__(self, device:str, allocator:Allocator, renderer:Optional[Renderer], compiler:Optional[Compiler], runtime, graph=None):
     self.dname, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler or Compiler(), runtime, graph
     self.renderer = renderer or Renderer()
-  def synchronize(self): …
+  def synchronize(self):
+    """
+    Synchronize all pending operations on the device.
+
+    This method ensures that all previously queued operations on the device have been completed before proceeding.
+    """
+    # override this in your device implementation
 
 # **************** for HCQ Compatible Devices ****************
 
@@ -189,9 +200,234 @@
+def hcq_command(func):
+  """
+  Decorator for HWCommandQueue commands. Enables command indexing and stores metadata for command updates.
+
+  For example:
+  ```python
+  @hcq_command
+  def command_method(self, ...): ...
+  ```
+  """
+  def __wrapper(self, *args, **kwargs):
+    self.cmds_offset.append(len(self.q))
+    func(self, *args, **kwargs)
+    self.cmds_len.append(len(self.q) - self.cmds_offset[-1])
+    self.cmds_meta.append(func.__name__)
+    return self
+  return __wrapper
+
+class HWCommandQueue:
+  """
+  A base class for hardware command queues in the HCQ (Hardware Command Queue) API.
+  Both compute and copy queues should have the following commands implemented.
+  """
+
+  def __init__(self): self.q, self.binded_device, self.cmds_offset, self.cmds_len, self.cmds_meta = [], None, [], [], []
+  def __len__(self): return len(self.cmds_offset)
+  def _patch(self, cmd_idx, offset, data): self.q[(st:=self.cmds_offset[cmd_idx]+offset):st+len(data)] = array.array('I', data)
+
+  @hcq_command
+  def signal(self, signal:HCQSignal, value:int):
+    """
+    Enqueues a signal command which sets the signal to the given value, ensuring all previous operations are completed.
+
+    Args:
+      signal: The signal to set
+      value: The value to set the signal to
+    """
+    self._signal(signal, value)
+  def _signal(self, signal:HCQSignal, value:int): raise NotImplementedError("backend should overload this function")
+
+  @hcq_command
+  def wait(self, signal:HCQSignal, value:int):
+    """
+    Enqueues a wait command which halts execution until the signal is greater than or equal to a specific value.
+
+    Args:
+      signal: The signal to wait on
+      value: The value to wait for
+    """
+    self._wait(signal, value)
+  def _wait(self, signal, value): raise NotImplementedError("backend should overload this function")
+
+  @hcq_command
+  def timestamp(self, signal:HCQSignal):
+    """
+    Enqueues a timestamp command which records the current time in a signal after all previously enqueued commands are completed.
+
+    Args:
+      signal: The signal to store the timestamp
+    """
+    self._timestamp(signal)
+  def _timestamp(self, signal): raise NotImplementedError("backend should overload this function")
+
+  def update_signal(self, cmd_idx:int, signal:Optional[Any]=None, value:Optional[int]=None):
+    """
+    Updates a previously queued signal command.
+
+    Args:
+      cmd_idx: Index of the signal command to update
+      signal: New signal to set (if None, keeps the original)
+      value: New value to set (if None, keeps the original)
+    """
+    if self.cmds_meta[cmd_idx] != "signal": raise RuntimeError("called update_signal not on a signal command")
+    self._update_signal(cmd_idx, signal, value)
+    return self
+  def _update_signal(self, cmd_idx:int, signal:Optional[Any], value:Optional[int]): raise NotImplementedError("backend should overload this function")
+
+  def update_wait(self, cmd_idx:int, signal:Optional[Any]=None, value:Optional[int]=None):
+    """
+    Updates a previously queued wait command.
+
+    Args:
+      cmd_idx: Index of the wait command to update
+      signal: New signal to wait on (if None, keeps the original)
+      value: New value to wait for (if None, keeps the original)
+    """
+    if self.cmds_meta[cmd_idx] != "wait": raise RuntimeError("called update_wait not on a wait command")
+    self._update_wait(cmd_idx, signal, value)
+    return self
+  def _update_wait(self, cmd_idx:int, signal:Optional[Any], value:Optional[int]): raise NotImplementedError("backend should overload this function")
+
+  def bind(self, device:HCQCompiled):
+    """
+    Associates the queue with a specific device for optimized execution.
+
+    This optional method allows backend implementations to tailor the queue for efficient use on the given device. When implemented, it can eliminate
+    the need to copy queues into the device, thereby enhancing performance.
+
+    Args:
+      device: The target device for queue optimization.
+
+    Note:
+      Implementing this method is optional but recommended for performance gains.
+    """
+
+  def submit(self, device:HCQCompiled):
+    """
+    Submits the command queue to a specific device for execution.
+
+    Args:
+      device: The device to submit the queue to
+    """
+    self._submit(device)
+    return self
+  def _submit(self, device:HCQCompiled): raise NotImplementedError("backend should overload this function")
+
+class HWComputeQueue(HWCommandQueue):
+  @hcq_command
+  def memory_barrier(self):
+    """
+    Enqueues a memory barrier command to ensure memory coherence between agents.
+    """
+    self._memory_barrier()
+  def _memory_barrier(self): pass
+
+  @hcq_command
+  def exec(self, prg:HCQProgram, args_state:HCQArgsState, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int]):
+    """
+    Enqueues an execution command for a kernel program.
+
+    Args:
+      prg: The program to execute
+      args_state: The args state to execute program with
+      global_size: The global work size
+      local_size: The local work size
+    """
+    self._exec(prg, args_state, global_size, local_size)
+  def _exec(self, prg, args_state, global_size, local_size): raise NotImplementedError("backend should overload this function")
+
+  def update_exec(self, cmd_idx:int, global_size:Optional[Tuple[int,int,int]]=None, local_size:Optional[Tuple[int,int,int]]=None):
+    """
+    Updates a previously queued execution command.
+
+    Args:
+      cmd_idx: Index of the execution command to update
+      global_size: New global work size (if None, keeps the original)
+      local_size: New local work size (if None, keeps the original)
+    """
+    if self.cmds_meta[cmd_idx] != "exec": raise RuntimeError("called update_exec not on an exec command")
+    self._update_exec(cmd_idx, global_size, local_size)
+    return self
+  def _update_exec(self, cmd_idx, global_size, local_size): raise NotImplementedError("backend should overload this function")
+
+class HWCopyQueue(HWCommandQueue):
+  @hcq_command
+  def copy(self, dest:HCQBuffer, src:HCQBuffer, copy_size:int):
+    """
+    Enqueues a copy command to transfer data.
+
+    Args:
+      dest: The destination of the copy
+      src: The source of the copy
+      copy_size: The size of data to copy
+    """
+    self._copy(dest, src, copy_size)
+  def _copy(self, dest:HCQBuffer, src:HCQBuffer, copy_size:int): raise NotImplementedError("backend should overload this function")
+
+  def update_copy(self, cmd_idx:int, dest:Optional[HCQBuffer]=None, src:Optional[HCQBuffer]=None):
+    """
+    Updates a previously queued copy command.
+
+    Args:
+      cmd_idx: Index of the copy command to update
+      dest: New destination of the copy (if None, keeps the original)
+      src: New source of the copy (if None, keeps the original)
+    """
+    if self.cmds_meta[cmd_idx] != "copy": raise RuntimeError("called update_copy not on an copy command")
+    self._update_copy(cmd_idx, dest, src)
+    return self
+  def _update_copy(self, cmd_idx, dest, src): raise NotImplementedError("backend should overload this function")
+
+class HCQSignal:
+  def __init__(self, value:int=0): self._set_value(value)
+
+  @property
+  def value(self) -> int: return self._get_value()
+
+  @value.setter
+  def value(self, new_value:int): self._set_value(new_value)
+
+  def _get_value(self) -> int: raise NotImplementedError("_get_value() method must be implemented")
+  def _set_value(self, new_value:int): raise NotImplementedError("_set_value() method must be implemented")
+
+  @property
+  def timestamp(self) -> decimal.Decimal:
+    """
+    Get the timestamp field of the signal.
+
+    This property provides read-only access to the signal's timestamp.
+
+    Returns:
+      The timestamp in microseconds.
+    """
+    return self._get_timestamp()
+  def _get_timestamp(self) -> decimal.Decimal: raise NotImplementedError("_get_timestamp() method must be implemented")
+
+  def wait(self, value:int, timeout:int=10000):
+    """
+    Waits the signal is greater than or equal to a specific value.
+
+    Args:
+      value: The value to wait for.
+      timeout: Maximum time to wait in milliseconds. Defaults to 10s.
+    """
+    raise NotImplementedError("wait() method must be implemented")
+
 @contextlib.contextmanager
-def hcq_profile(dev, …
-  st, en = (dev.…
-  …
+def hcq_profile(dev, enabled, desc, queue_type=None, queue=None):
+  st, en = (dev.signal_t(), dev.signal_t()) if enabled else (None, None)
+
+  if enabled and queue is not None: queue.timestamp(st)
+  elif enabled:
+    queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(st).signal(dev.timeline_signal, dev.timeline_value).submit(dev)
+    dev.timeline_value += 1
+
   try: yield (st, en)
   finally:
-    if enabled: …
+    if enabled and queue is not None: queue.timestamp(en)
+    elif enabled:
+      queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(en).signal(dev.timeline_signal, dev.timeline_value).submit(dev)
+      dev.timeline_value += 1
+
     if enabled and PROFILE: dev.sig_prof_records.append((st, en, desc, queue_type is dev.hw_copy_queue_t))
 
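Because `hcq_command` returns `self` from every wrapped method, backends build command packets by chaining, exactly as `hcq_profile` above does with `wait(...).timestamp(...).signal(...).submit(...)`. A minimal sketch of a backend against this API; `ToySignal` and `ToyComputeQueue` are hypothetical in-memory stand-ins that only fill in the underscore hooks:

```python
# Hypothetical toy backend for the HCQ queue API: the base classes provide
# chaining, command counting, and cmds_meta bookkeeping via @hcq_command; the
# backend only implements the _-prefixed hooks.
from tinygrad.device import HCQSignal, HWComputeQueue

class ToySignal(HCQSignal):
  def __init__(self, value=0): self._v = 0; super().__init__(value)
  def _get_value(self): return self._v
  def _set_value(self, new_value): self._v = new_value

class ToyComputeQueue(HWComputeQueue):
  def _signal(self, signal, value): self.q.append(("signal", signal, value))
  def _wait(self, signal, value): self.q.append(("wait", signal, value))
  def _submit(self, device):
    for op, sig, val in self.q:          # "execute" the queue on the spot
      if op == "signal": sig.value = val

sig = ToySignal()
q = ToyComputeQueue().wait(sig, 0).signal(sig, 1)   # chaining via hcq_command
assert len(q) == 2 and q.cmds_meta == ["wait", "signal"]
q.submit(None)
assert sig.value == 1
```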
@@ -198,4 +434,49 @@
-class …
-  def __init__(self, …
-    …
-    …
+class HCQArgsState:
+  def __init__(self, ptr:int, prg:HCQProgram, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=()): self.ptr, self.prg = ptr, prg
+  def update_buffer(self, index:int, buf:HCQBuffer): raise NotImplementedError("need update_buffer")
+  def update_var(self, index:int, val:int): raise NotImplementedError("need update_var")
+
+class HCQProgram:
+  def __init__(self, args_state_t:Type[HCQArgsState], device:HCQCompiled, name:str, kernargs_alloc_size:int, kernargs_args_offset:int=0):
+    self.args_state_t, self.device, self.name = args_state_t, device, name
+    self.kernargs_alloc_size, self.kernargs_args_offset = kernargs_alloc_size, kernargs_args_offset
+
+  def fill_kernargs(self, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=(), kernargs_ptr:Optional[int]=None) -> HCQArgsState:
+    """
+    Fills arguments for the kernel, optionally allocating space from the device if `kernargs_ptr` is not provided.
+    Args:
+      bufs: Buffers to be written to kernel arguments.
+      vals: Values to be written to kernel arguments.
+      kernargs_ptr: Optional pointer to pre-allocated kernel arguments memory.
+    Returns:
+      Arguments state with the given buffers and values set for the program.
+    """
+    return self.args_state_t(kernargs_ptr or self.device._alloc_kernargs(self.kernargs_alloc_size), self, bufs, vals=vals)
+
+  def __call__(self, *bufs:HCQBuffer, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1),
+               vals:Tuple[int, ...]=(), wait:bool=False) -> Optional[float]:
+    """
+    Enqueues the program for execution with the given arguments and dimensions.
+
+    Args:
+      bufs: Buffer arguments to execute the kernel with.
+      global_size: Specifies the global work size for kernel execution (equivalent to CUDA's grid size).
+      local_size: Specifies the local work size for kernel execution (equivalent to CUDA's block size).
+      vals: Value arguments to execute the kernel with.
+      wait: If True, waits for the kernel to complete execution.
+
+    Returns:
+      Execution time of the kernel if 'wait' is True, otherwise None.
+    """
+
+    q = self.device.hw_compute_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1).memory_barrier()
+
+    with hcq_profile(self.device, queue=q, desc=self.name, enabled=wait or PROFILE) as (sig_st, sig_en):
+      q.exec(self, self.fill_kernargs(bufs, vals), global_size, local_size)
+
+    q.signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
+    self.device.timeline_value += 1
+
+    if wait: self.device.timeline_signal.wait(self.device.timeline_value - 1)
+    return (float(sig_en.timestamp - sig_st.timestamp) / 1e6) if wait else None
+
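`HCQProgram.__call__` is the reference submission pattern: one compute queue that waits on the device timeline, a `memory_barrier`, the profiled `exec`, then a timeline `signal` and bump. Assuming `prg` is an `HCQProgram` on an HCQ backend and `a`, `b` are its buffer arguments (hypothetical names), timing a launch looks like:

```python
# Hypothetical launch on an HCQ backend: wait=True makes __call__ read the two
# profile signals and return (end - start), converted from us to seconds.
et = prg(a, b, global_size=(128, 1, 1), local_size=(32, 1, 1), wait=True)
if et is not None: print(f"kernel time: {et*1e6:.2f} us")
```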
@@ -202,57 +483,119 @@
+class HCQCompiled(Compiled):
+  """
+  A base class for devices compatible with the HCQ (Hardware Command Queue) API.
+  """
+  devices: List[HCQCompiled] = []
+  gpu2cpu_copy_time_diff: decimal.Decimal = decimal.Decimal('nan')
+  gpu2cpu_compute_time_diff: decimal.Decimal = decimal.Decimal('nan')
+
+  def __init__(self, device:str, allocator:Allocator, renderer:Renderer, compiler:Compiler, runtime, signal_t:Type[HCQSignal],
+               comp_queue_t:Type[HWComputeQueue], copy_queue_t:Optional[Type[HWCopyQueue]], timeline_signals:Tuple[HCQSignal, HCQSignal]):
+    self.signal_t, self.hw_compute_queue_t, self.hw_copy_queue_t = signal_t, comp_queue_t, copy_queue_t
+    self.timeline_value:int = 1
     self.timeline_signal, self._shadow_timeline_signal = timeline_signals
-    self.sig_prof_records: …
-    self.raw_prof_records: …
+    self.sig_prof_records:List[Tuple[HCQSignal, HCQSignal, str, bool]] = []
+    self.raw_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, str, bool, Optional[Dict]]] = []
+    self.dep_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, HCQCompiled, bool, decimal.Decimal, decimal.Decimal, HCQCompiled, bool]] = []
     if PROFILE: self._prof_setup()
 
     from tinygrad.runtime.graph.hcq import HCQGraph
     super().__init__(device, allocator, renderer, compiler, runtime, HCQGraph)
 
-    …
-    …
-    …
-    …
-  def …
-    …
-    …
-    …
-    …
-    …
-    …
-    …
-    …
-    …
-    …
-    …
+    self.kernargs_page:HCQBuffer = self.allocator.alloc(16 << 20, BufferOptions(cpu_access=True))
+    self.kernargs_ptr:int = self.kernargs_page.va_addr
+    self.devices.append(self)
+
+  def synchronize(self):
+    self.timeline_signal.wait(self.timeline_value - 1)
+
+    if self.timeline_value > (1 << 31): self._wrap_timeline_signal()
+    if PROFILE:
+      self.raw_prof_records += [(st.timestamp, en.timestamp, name, is_cp, None) for st, en, name, is_cp in self.sig_prof_records]
+      self.sig_prof_records = []
+
+  def _alloc_kernargs(self, alloc_size:int) -> int:
+    """
+    Allocates space for arguments passed to the kernel.
+    """
+    if self.kernargs_ptr >= (self.kernargs_page.va_addr + self.kernargs_page.size - alloc_size): self.kernargs_ptr = self.kernargs_page.va_addr
+    self.kernargs_ptr = (res:=self.kernargs_ptr) + alloc_size
+    return res
+
+  def _ensure_shared_time_base(self):
+    if not self.gpu2cpu_compute_time_diff.is_nan(): return
+
+    def _sync_cpu_queue(d, q_t):
+      q_t().timestamp(d.timeline_signal).signal(d.timeline_signal, d.timeline_value).submit(d)
+      d.timeline_value += 1
+      st = time.perf_counter_ns()
+      d.timeline_signal.wait(d.timeline_value - 1) # average of the two
+      et = time.perf_counter_ns()
+      return (decimal.Decimal(et+st) / 2000) - d.timeline_signal.timestamp
+
+    # randomly sample the timing from GPU to CPU
+    choices: List = [(d, d.hw_compute_queue_t, []) for d in self.devices]
+    choices += [(d, d.hw_copy_queue_t, []) for d in self.devices if d.hw_copy_queue_t is not None]
+    for _ in range(100*len(self.devices)):
+      d,q,l = random.choice(choices)
+      l.append(_sync_cpu_queue(d,q))
+    for d,q,l in choices:
+      if q == d.hw_compute_queue_t: d.gpu2cpu_compute_time_diff = statistics.median(l)
+      if q == d.hw_copy_queue_t: d.gpu2cpu_copy_time_diff = statistics.median(l)
+
+    def _sync_gpu_to_gpu_queue(d1, d2, q1_t, q2_t):
+      q1_t().signal(d1.timeline_signal, d1.timeline_value).wait(d2.timeline_signal, d2.timeline_value) \
+            .timestamp(d1.timeline_signal).signal(d1.timeline_signal, d1.timeline_value+1).submit(d1)
+      q2_t().signal(d2.timeline_signal, d2.timeline_value).wait(d1.timeline_signal, d1.timeline_value) \
+            .timestamp(d2.timeline_signal).signal(d2.timeline_signal, d2.timeline_value+1).submit(d2)
+      d1.timeline_value += 2
+      d2.timeline_value += 2
+      d1.timeline_signal.wait(d1.timeline_value - 1)
+      d2.timeline_signal.wait(d2.timeline_value - 1)
+      return d2.timeline_signal.timestamp - d1.timeline_signal.timestamp
+
+    # then test it by timing the GPU to GPU times
+    jitter_matrix = [[float('nan')]*len(self.devices) for _ in range(len(self.devices))]
+    for i1, d1 in enumerate(self.devices):
+      for i2, d2 in enumerate(self.devices):
+        if d1 == d2: continue
+        d1_to_d2 = statistics.median(_sync_gpu_to_gpu_queue(d1, d2, d1.hw_compute_queue_t, d2.hw_compute_queue_t) - \
+                                     _sync_gpu_to_gpu_queue(d2, d1, d2.hw_compute_queue_t, d1.hw_compute_queue_t) for _ in range(20)) / 2
+        jitter_matrix[i1][i2] = d1_to_d2 - (d1.gpu2cpu_compute_time_diff - d2.gpu2cpu_compute_time_diff)
+    print("pairwise clock jitter matrix (us):\n" + '\n'.join([''.join([f'{float(item):8.3f}' for item in row]) for row in jitter_matrix]))
+
+  def _gpu2cpu_time(self, gpu_time:decimal.Decimal, is_copy:bool) -> float:
+    """
+    Translates local gpu time (timestamp) into global cpu time.
+    """
+    self._ensure_shared_time_base()
+    return float(gpu_time + (self.gpu2cpu_copy_time_diff if is_copy else self.gpu2cpu_compute_time_diff))
 
   def _prof_setup(self):
+    if hasattr(self, 'profile_logger'): return
+    atexit.register(self._prof_finalize)
     self.profile_logger = ProfileLogger()
 
-    …
-    …
-      self.timeline_value += 1
-      cpu_start_time = time.perf_counter_ns() / 1e3
-      self._wait_signal(self.timeline_signal, self.timeline_value - 1)
-      return cpu_start_time, self._read_timestamp(self.timeline_signal)
-    self.cpu_start_time, self.gpu_start_time = _sync_queue(self.hw_compute_queue_t)
-    self.copy_cpu_start_time, self.copy_gpu_start_time = _sync_queue(self.hw_copy_queue_t)
+  def _prof_finalize(self):
+    qname = ["COMPUTE", "DMA"]
 
-    …
+    # Sync to be sure all events on the device are recorded.
+    self.synchronize()
 
-    …
-    …
-    for …
-    …
+    for st, en, name, is_cp, args in self.raw_prof_records:
+      self.profile_logger.events += [(name, self._gpu2cpu_time(st, is_cp), self._gpu2cpu_time(en, is_cp), self.dname, qname[is_cp], args)]
+    for a_st, a_en, a_dev, a_is_copy, b_st, b_en, b_dev, b_is_copy in self.dep_prof_records:
+      # Perfetto connects nodes based on timing data, ensuring every choice is valid by averaging times to a midpoint.
+      a_tm, b_tm = a_dev._gpu2cpu_time((a_st+a_en)/decimal.Decimal(2), a_is_copy), b_dev._gpu2cpu_time((b_st+b_en)/decimal.Decimal(2), b_is_copy)
+      self.profile_logger.deps += [(a_tm, b_tm, a_dev.dname, qname[a_is_copy], b_dev.dname, qname[b_is_copy])]
+    self.raw_prof_records, self.dep_prof_records = [], []
 
-    …
-    for st, en, name, is_cp in self.raw_prof_records:
-      self.profile_logger.events += [(name, self._gpu2cpu_time(st, is_cp), self._gpu2cpu_time(en, is_cp), self.dname, ["COMPUTE", "DMA"][is_cp])]
+    # Remove the logger, this flushes all data written by the device.
     del self.profile_logger
 
   def _wrap_timeline_signal(self):
     self.timeline_signal, self._shadow_timeline_signal, self.timeline_value = self._shadow_timeline_signal, self.timeline_signal, 1
-    self.…
-    cast(…
+    self.timeline_signal.value = 0
+    cast(HCQAllocator, self.allocator).b_timeline = [0] * len(cast(HCQAllocator, self.allocator).b)
 
-…
-…
-…
+# Protocol for hcq compatible allocators for allocated buffers to contain VA address and it's size.
+class HCQBuffer(Protocol): va_addr:int; size:int # noqa: E702
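`_ensure_shared_time_base` estimates one GPU-to-CPU clock offset per device and queue type: each sample brackets a GPU `timestamp` with two CPU readings and takes the CPU midpoint minus the GPU timestamp, the offset is the median over the samples, and `_gpu2cpu_time` then simply adds it. The arithmetic in isolation, with fabricated numbers:

```python
# Sketch of the clock-offset math (values in microseconds, fabricated):
# one offset sample = midpoint of the CPU wait window minus the GPU timestamp.
import statistics
from decimal import Decimal

samples = []
for cpu_st, cpu_en, gpu_ts in [(1000, 1040, 520), (2000, 2050, 1530), (3000, 3030, 2512)]:
  cpu_mid = (Decimal(cpu_st) + Decimal(cpu_en)) / 2
  samples.append(cpu_mid - Decimal(gpu_ts))
gpu2cpu_diff = statistics.median(samples)    # per-device, per-queue offset

gpu_event = Decimal(2600)                    # some later GPU timestamp
cpu_event = float(gpu_event + gpu2cpu_diff)  # what _gpu2cpu_time returns
```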
@@ -259,12 +602,24 @@
+
+class HCQAllocator(LRUAllocator): # pylint: disable=abstract-method
+  """
+  A base allocator class compatible with the HCQ (Hardware Command Queue) API.
+
+  This class implements basic copy operations following the HCQ API, utilizing both `HWComputeQueue` and `HWCopyQueue`.
+  """
+
+  def __init__(self, device:HCQCompiled, batch_size:int=(2 << 20), batch_cnt:int=32):
+    self.device:Any = device
     self.b = [self._alloc(batch_size, BufferOptions(host=True)) for _ in range(batch_cnt)]
     self.b_timeline, self.b_next = [0] * len(self.b), 0
     super().__init__()
 
-  def …
-    …
+  def _alloc(self, size:int, options:BufferOptions) -> HCQBuffer: raise NotImplementedError("need hcq compat alloc")
+
+  def copyin(self, dest:HCQBuffer, src:memoryview):
+    with hcq_profile(self.device, queue_type=self.device.hw_copy_queue_t, desc=f"CPU -> {self.device.dname}", enabled=PROFILE):
       for i in range(0, src.nbytes, self.b[0].size):
         self.b_next = (self.b_next + 1) % len(self.b)
-        self.device.…
+        self.device.timeline_signal.wait(self.b_timeline[self.b_next])
         ctypes.memmove(self.b[self.b_next].va_addr, from_mv(src[i:]), lsize:=min(self.b[self.b_next].size, src.nbytes-i))
         self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
                                      .copy(dest.va_addr+i, self.b[self.b_next].va_addr, lsize) \
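`copyin` streams through a ring of `batch_cnt` host-visible staging buffers: before reusing slot `i` it waits until the timeline has passed `b_timeline[i]`, so the `memmove` can never overwrite data a still-in-flight DMA is reading. The ring bookkeeping in isolation (a sketch, not the tinygrad class):

```python
# Sketch of the staging-ring bookkeeping behind HCQAllocator.copyin.
class StagingRing:
  def __init__(self, n): self.b_timeline, self.b_next, self.timeline_value = [0]*n, 0, 1
  def acquire(self, wait_fn):
    self.b_next = (self.b_next + 1) % len(self.b_timeline)
    wait_fn(self.b_timeline[self.b_next])               # slot is free once its copy signaled
    self.b_timeline[self.b_next] = self.timeline_value  # reserved by the copy queued next
    self.timeline_value += 1
    return self.b_next
```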
@@ -272,15 +627,15 @@ class HCQCompatAllocator(LRUAllocator): # pylint: disable=abstract-method
         self.b_timeline[self.b_next] = self.device.timeline_value
         self.device.timeline_value += 1
 
-  def copy_from_disk(self, dest, src, size):
+  def copy_from_disk(self, dest:HCQBuffer, src, size):
     def _get_temp_buf():
       # Check if the next buffer is safe to be used (its signal has passed) and reserve it.
-      if self.b_timeline[(self.b_next + 1) % len(self.b)] <= self.device.…
+      if self.b_timeline[(self.b_next + 1) % len(self.b)] <= self.device.timeline_signal.value:
         self.b_timeline[(self.b_next + 1) % len(self.b)], self.b_next = (1 << 64), (self.b_next + 1) % len(self.b)
         return (self.b[self.b_next].va_addr, self.b_next)
       return None
 
-    with hcq_profile(self.device, self.device.hw_copy_queue_t, desc=f"DISK -> {self.device.dname}", enabled=PROFILE):
+    with hcq_profile(self.device, queue_type=self.device.hw_copy_queue_t, desc=f"DISK -> {self.device.dname}", enabled=PROFILE):
       for (batch_info, dst_off, src_off, copy_size) in src.device.allocator._copyout_sharded(src, size, _get_temp_buf, seg_len=self.b[0].size):
         self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
                                      .copy(dest.va_addr + dst_off, batch_info[0] + src_off, copy_size) \
@@ -288,23 +643,23 @@ class HCQCompatAllocator(LRUAllocator): # pylint: disable=abstract-method
         self.b_timeline[batch_info[1]] = self.device.timeline_value
         self.device.timeline_value += 1
 
-  def copyout(self, dest:memoryview, src):
+  def copyout(self, dest:memoryview, src:HCQBuffer):
     self.device.synchronize()
 
-    with hcq_profile(self.device, self.device.hw_copy_queue_t, desc=f"{self.device.dname} -> CPU", enabled=PROFILE):
+    with hcq_profile(self.device, queue_type=self.device.hw_copy_queue_t, desc=f"{self.device.dname} -> CPU", enabled=PROFILE):
       for i in range(0, dest.nbytes, self.b[0].size):
         self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
                                      .copy(self.b[0].va_addr, src.va_addr+i, lsize:=min(self.b[0].size, dest.nbytes-i)) \
                                      .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
-        self.device.…
+        self.device.timeline_signal.wait(self.device.timeline_value)
         self.device.timeline_value += 1
 
         ctypes.memmove(from_mv(dest[i:]), self.b[0].va_addr, lsize)
 
-  def transfer(self, dest, src, sz:…
-    src_dev.…
+  def transfer(self, dest:HCQBuffer, src:HCQBuffer, sz:int, src_dev, dest_dev):
+    src_dev.allocator.map(dest)
 
-    with hcq_profile(…
+    with hcq_profile(src_dev, queue_type=src_dev.hw_copy_queue_t, desc=f"{src_dev.dname} -> {dest_dev.dname}", enabled=PROFILE):
       src_dev.hw_copy_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
                                .wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
                                .copy(dest.va_addr, src.va_addr, sz) \
@@ -317,4 +672,8 @@ class HCQCompatAllocator(LRUAllocator): # pylint: disable=abstract-method
                                .signal(dest_dev.timeline_signal, dest_dev.timeline_value).submit(dest_dev)
     dest_dev.timeline_value += 1
 
-  def …
+  def map(self, buf:HCQBuffer): pass
+
+  def offset(self, buf, size:int, offset:int) -> HCQBuffer:
+    return type(buf)(va_addr=buf.va_addr + offset, size=size, **{k:v for k,v in buf.__dict__.items() if k not in ['va_addr', 'size']},
+                     **{x[0]:getattr(buf, x[0]) for x in getattr(buf, '_fields_', []) if x[0] not in ['va_addr', 'size']}, _base=buf)
tinygrad/dtype.py
CHANGED
@@ -30,16 +30,16 @@ class ImageDType(DType):
 # @dataclass(frozen=True, init=False, repr=False, eq=False)
 class PtrDType(DType):
   def __init__(self, dt:DType): super().__init__(dt.priority, dt.itemsize, dt.name, dt.fmt, dt.count)
-  def __repr__(self): return f"ptr.{super().__repr__()}"
   def __hash__(self): return super().__hash__()
   def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.count==dt.count
   def __ne__(self, dt): return not (self == dt)
+  def __repr__(self): return f"PtrDType({super().__repr__()})"
 
 class dtypes:
   @staticmethod
   def is_float(x: DType) -> bool: return x.scalar() in (dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64)
   @staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool
-  def is_int(x: DType) -> bool: return x.scalar() in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64) or dtypes.is_unsigned(x)
+  def is_int(x: DType) -> bool: return x.scalar() in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.pyint) or dtypes.is_unsigned(x)
   @staticmethod
   def is_unsigned(x: DType) -> bool: return x.scalar() in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
   @staticmethod
@@ -53,7 +53,17 @@ class dtypes:
   @staticmethod
   def as_const(val: ConstType, dtype:DType): return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)
   @staticmethod
+  def min(dtype:DType):
+    if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.itemsize*8-1)
+    return -float("inf") if dtypes.is_float(dtype) else False
+  @staticmethod
+  def max(dtype:DType):
+    if dtypes.is_int(dtype): return (2**(dtype.itemsize*8-(0 if dtypes.is_unsigned(dtype) else 1)))-1
+    return float("inf") if dtypes.is_float(dtype) else True
+  @staticmethod
   def fields() -> Dict[str, DType]: return DTYPES_DICT
+  # TODO: priority should be higher than bool
+  pyint: Final[DType] = DType(-1, 8, "pyint", None, 1) # arbitrary precision integer, same itemsize to int64 so min/max works
   bool: Final[DType] = DType(0, 1, "bool", '?', 1)
   int8: Final[DType] = DType(1, 1, "char", 'b', 1)
   uint8: Final[DType] = DType(2, 1, "unsigned char", 'B', 1)
@@ -87,6 +97,9 @@ if (env_default_float := getenv("DEFAULT_FLOAT", "")):
   dtypes.default_float = getattr(dtypes, env_default_float.lower())
   assert dtypes.is_float(dtypes.default_float), f"{env_default_float} is not a float dtype"
 
+DTypeLike = Union[str, DType]
+def to_dtype(dtype:DTypeLike) -> DType: return dtype if isinstance(dtype, DType) else getattr(dtypes, dtype)
+
 # https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html
 # we don't support weak type and complex type
 promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64],
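The new `DTypeLike`/`to_dtype` pair lets APIs accept either a `DType` or the name of an attribute on `dtypes`:

```python
# to_dtype resolves strings through getattr(dtypes, name) and passes DTypes through:
from tinygrad.dtype import dtypes, to_dtype

assert to_dtype("float32") is dtypes.float32
assert to_dtype(dtypes.int32) is dtypes.int32
```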
@@ -103,11 +116,12 @@ def least_upper_dtype(*ds:DType) -> DType:
 def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.float32)
 
 # HACK: staticmethods are not callable in 3.8 so we have to compare the class
-DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not (k.startswith(('__', 'default')) or v.__class__ is staticmethod)}
+DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not (k.startswith(('__', 'default', 'pyint')) or v.__class__ is staticmethod)}
 INVERSE_DTYPES_DICT = {v.name:k for k,v in DTYPES_DICT.items()}
+INVERSE_DTYPES_DICT['pyint'] = 'pyint'
 
 def sum_acc_dtype(dt:DType):
   # default acc dtype for sum
   if dtypes.is_unsigned(dt): return least_upper_dtype(dt, dtypes.uint)
   if dtypes.is_int(dt) or dt == dtypes.bool: return least_upper_dtype(dt, dtypes.int)
-  return least_upper_dtype(dt, dtypes.float)
+  return least_upper_dtype(dt, dtypes.float)