tinygrad 0.9.1__py3-none-any.whl → 0.9.2__py3-none-any.whl

This diff compares the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (56)
  1. tinygrad/codegen/kernel.py +248 -115
  2. tinygrad/codegen/lowerer.py +215 -0
  3. tinygrad/codegen/transcendental.py +310 -0
  4. tinygrad/codegen/uopgraph.py +622 -0
  5. tinygrad/codegen/uops.py +235 -393
  6. tinygrad/device.py +428 -69
  7. tinygrad/dtype.py +18 -4
  8. tinygrad/engine/graph.py +19 -32
  9. tinygrad/engine/jit.py +148 -70
  10. tinygrad/engine/realize.py +127 -51
  11. tinygrad/engine/schedule.py +259 -216
  12. tinygrad/engine/search.py +29 -22
  13. tinygrad/function.py +9 -0
  14. tinygrad/helpers.py +87 -49
  15. tinygrad/lazy.py +34 -35
  16. tinygrad/multi.py +41 -36
  17. tinygrad/nn/__init__.py +39 -22
  18. tinygrad/nn/state.py +3 -3
  19. tinygrad/ops.py +63 -62
  20. tinygrad/renderer/__init__.py +43 -21
  21. tinygrad/renderer/assembly.py +104 -106
  22. tinygrad/renderer/cstyle.py +87 -60
  23. tinygrad/renderer/llvmir.py +21 -30
  24. tinygrad/runtime/autogen/amd_gpu.py +25208 -5753
  25. tinygrad/runtime/autogen/cuda.py +6 -162
  26. tinygrad/runtime/autogen/kfd.py +32 -0
  27. tinygrad/runtime/autogen/libc.py +4260 -0
  28. tinygrad/runtime/autogen/nvrtc.py +579 -0
  29. tinygrad/runtime/graph/clang.py +2 -2
  30. tinygrad/runtime/graph/cuda.py +8 -11
  31. tinygrad/runtime/graph/hcq.py +120 -107
  32. tinygrad/runtime/graph/metal.py +18 -15
  33. tinygrad/runtime/ops_amd.py +197 -305
  34. tinygrad/runtime/ops_clang.py +2 -2
  35. tinygrad/runtime/ops_cuda.py +36 -94
  36. tinygrad/runtime/ops_disk.py +3 -7
  37. tinygrad/runtime/ops_gpu.py +4 -2
  38. tinygrad/runtime/ops_hip.py +70 -0
  39. tinygrad/runtime/ops_metal.py +38 -27
  40. tinygrad/runtime/ops_nv.py +283 -363
  41. tinygrad/runtime/ops_python.py +26 -30
  42. tinygrad/runtime/support/compiler_cuda.py +78 -0
  43. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +15 -1
  44. tinygrad/runtime/support/elf.py +38 -0
  45. tinygrad/shape/shapetracker.py +5 -14
  46. tinygrad/shape/symbolic.py +4 -8
  47. tinygrad/shape/view.py +34 -22
  48. tinygrad/tensor.py +399 -97
  49. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/METADATA +49 -48
  50. tinygrad-0.9.2.dist-info/RECORD +70 -0
  51. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/WHEEL +1 -1
  52. tinygrad/codegen/linearizer.py +0 -528
  53. tinygrad-0.9.1.dist-info/RECORD +0 -63
  54. /tinygrad/runtime/{driver → support}/__init__.py +0 -0
  55. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/LICENSE +0 -0
  56. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/top_level.txt +0 -0
tinygrad/device.py CHANGED
@@ -1,10 +1,10 @@
  from __future__ import annotations
- import multiprocessing
+ import multiprocessing, decimal, statistics, random
  from dataclasses import dataclass
  from collections import defaultdict
- from typing import List, Optional, Dict, Tuple, Any, cast
- import importlib, inspect, functools, pathlib, os, ctypes, atexit, time, contextlib
- from tinygrad.helpers import getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, ProfileLogger, PROFILE
+ from typing import List, Optional, Dict, Tuple, Any, cast, Protocol, Type
+ import importlib, inspect, functools, pathlib, os, ctypes, atexit, time, contextlib, array
+ from tinygrad.helpers import SAVE_SCHEDULE, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, ProfileLogger, PROFILE
  from tinygrad.dtype import DType, ImageDType
  from tinygrad.renderer import Renderer

@@ -25,11 +25,12 @@ class _Device:
      ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._devices][0](ix) # noqa: E501
      if DEBUG >= 1: print(f"opened device {ix} from pid:{os.getpid()}")
      return ret
+   @property
+   def default(self) -> Compiled: return self[self.DEFAULT]
    @functools.cached_property
    def DEFAULT(self) -> str:
-     device_from_env: Optional[str] = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._devices, None) # type: ignore
-     if device_from_env: return device_from_env
-     for device in ["METAL", "AMD", "CUDA", "GPU", "CLANG", "LLVM"]:
+     if (from_env:=next((d for d in self._devices if d not in ["DISK", "NPY"] and getenv(d) == 1), None)): return from_env
+     for device in ["METAL", "AMD", "NV", "CUDA", "GPU", "CLANG", "LLVM"]:
        try:
          if self[device]:
            os.environ[device] = "1" # we set this in environment for spawned children
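
The new `DEFAULT` logic replaces the `functools.reduce` scan with a single `next(...)` over the known backends and adds NV to the probe order. A minimal sketch of the selection rule, with a hypothetical `getenv` and device list standing in for tinygrad's helpers:

```python
import os

DEVICES = ["METAL", "AMD", "NV", "CUDA", "GPU", "CLANG", "LLVM", "DISK", "NPY"]

def getenv(key: str, default: int = 0) -> int:
    # stand-in for tinygrad.helpers.getenv
    return int(os.getenv(key, default))

def pick_default() -> str:
    # an explicit env var like NV=1 wins, except for the virtual DISK/NPY devices
    if (from_env := next((d for d in DEVICES if d not in ["DISK", "NPY"] and getenv(d) == 1), None)):
        return from_env
    # the real code then probes METAL, AMD, NV, ... and returns the first device that opens
    return "CLANG"

os.environ["NV"] = "1"
assert pick_default() == "NV"
```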
@@ -90,7 +91,7 @@ class Buffer:
    if self._base is not None:
      return self.__class__, (self.device, self.size, self.dtype, None, None, None, 0, self.base, self.offset, hasattr(self, '_buf'))
    if self.device == "NPY": return self.__class__, (self.device, self.size, self.dtype, self._buf, self.options, None, self.lb_refcount)
-   if self.is_allocated():
+   if self.is_allocated() and not SAVE_SCHEDULE:
      buf = bytearray(self.nbytes)
      self.copyout(memoryview(buf))
      return self.__class__, (self.device, self.size, self.dtype, None, self.options, buf, self.lb_refcount)
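
For context on the `SAVE_SCHEDULE` guard: `__reduce__` tells pickle how to rebuild an object, and the hunk above snapshots a buffer's bytes into the pickle only when the buffer is allocated and schedule saving is off. A toy illustration of the pattern (a hypothetical `ToyBuffer`, not tinygrad's `Buffer`):

```python
import pickle
from typing import Optional

class ToyBuffer:
    def __init__(self, device: str, size: int, data: Optional[bytes] = None):
        self.device, self.size, self.data = device, size, data
    def __reduce__(self):
        # pickle calls ToyBuffer(device, size, data) on load; the real Buffer
        # only copies its contents out when is_allocated() and not SAVE_SCHEDULE
        return self.__class__, (self.device, self.size, self.data)

buf = ToyBuffer("CLANG", 4, b"\x00\x01\x02\x03")
assert pickle.loads(pickle.dumps(buf)).data == buf.data
```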
@@ -140,6 +141,10 @@ class Allocator:
    def copyout(self, dest:memoryview, src): raise NotImplementedError("need copyout")

  class LRUAllocator(Allocator): # pylint: disable=abstract-method
+   """
+   The LRU Allocator is responsible for caching buffers.
+   It ensures that buffers are not freed until it is absolutely necessary, optimizing performance.
+   """
    def __init__(self): self.cache: Dict[Tuple[int, Optional[BufferOptions]], Any] = defaultdict(list)
    def alloc(self, size:int, options:Optional[BufferOptions]=None):
      if len(c := self.cache[(size, options)]): return c.pop()
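
The new docstring documents behavior the class already had: `free()` parks a buffer under its `(size, options)` key, and a later `alloc()` with the same key pops it back instead of asking the backend again. A stripped-down sketch of that caching idea, keyed by size only:

```python
from collections import defaultdict

class CachingAllocator:
    def __init__(self):
        self.cache = defaultdict(list)
    def _backend_alloc(self, size: int) -> bytearray:
        return bytearray(size)  # stand-in for a real device allocation
    def alloc(self, size: int) -> bytearray:
        if (c := self.cache[size]):
            return c.pop()      # reuse a cached buffer of the same size
        return self._backend_alloc(size)
    def free(self, buf: bytearray) -> None:
        self.cache[len(buf)].append(buf)  # keep it around for reuse

a = CachingAllocator()
b1 = a.alloc(1024)
a.free(b1)
assert a.alloc(1024) is b1  # the freed buffer came straight back from the cache
```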
@@ -182,89 +187,439 @@ class Compiled:
    def __init__(self, device:str, allocator:Allocator, renderer:Optional[Renderer], compiler:Optional[Compiler], runtime, graph=None):
      self.dname, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler or Compiler(), runtime, graph
      self.renderer = renderer or Renderer()
-   def synchronize(self): pass # override this in your device
+   def synchronize(self):
+     """
+     Synchronize all pending operations on the device.
+
+     This method ensures that all previously queued operations on the device have been completed before proceeding.
+     """
+     # override this in your device implementation

  # **************** for HCQ Compatible Devices ****************

+ def hcq_command(func):
+   """
+   Decorator for HWCommandQueue commands. Enables command indexing and stores metadata for command updates.
+
+   For example:
+   ```python
+   @hcq_command
+   def command_method(self, ...): ...
+   ```
+   """
+   def __wrapper(self, *args, **kwargs):
+     self.cmds_offset.append(len(self.q))
+     func(self, *args, **kwargs)
+     self.cmds_len.append(len(self.q) - self.cmds_offset[-1])
+     self.cmds_meta.append(func.__name__)
+     return self
+   return __wrapper
+
+ class HWCommandQueue:
+   """
+   A base class for hardware command queues in the HCQ (Hardware Command Queue) API.
+   Both compute and copy queues should have the following commands implemented.
+   """
+
+   def __init__(self): self.q, self.binded_device, self.cmds_offset, self.cmds_len, self.cmds_meta = [], None, [], [], []
+   def __len__(self): return len(self.cmds_offset)
+   def _patch(self, cmd_idx, offset, data): self.q[(st:=self.cmds_offset[cmd_idx]+offset):st+len(data)] = array.array('I', data)
+
+   @hcq_command
+   def signal(self, signal:HCQSignal, value:int):
+     """
+     Enqueues a signal command which sets the signal to the given value, ensuring all previous operations are completed.
+
+     Args:
+       signal: The signal to set
+       value: The value to set the signal to
+     """
+     self._signal(signal, value)
+   def _signal(self, signal:HCQSignal, value:int): raise NotImplementedError("backend should overload this function")
+
+   @hcq_command
+   def wait(self, signal:HCQSignal, value:int):
+     """
+     Enqueues a wait command which halts execution until the signal is greater than or equal to a specific value.
+
+     Args:
+       signal: The signal to wait on
+       value: The value to wait for
+     """
+     self._wait(signal, value)
+   def _wait(self, signal, value): raise NotImplementedError("backend should overload this function")
+
+   @hcq_command
+   def timestamp(self, signal:HCQSignal):
+     """
+     Enqueues a timestamp command which records the current time in a signal after all previously enqueued commands are completed.
+
+     Args:
+       signal: The signal to store the timestamp
+     """
+     self._timestamp(signal)
+   def _timestamp(self, signal): raise NotImplementedError("backend should overload this function")
+
+   def update_signal(self, cmd_idx:int, signal:Optional[Any]=None, value:Optional[int]=None):
+     """
+     Updates a previously queued signal command.
+
+     Args:
+       cmd_idx: Index of the signal command to update
+       signal: New signal to set (if None, keeps the original)
+       value: New value to set (if None, keeps the original)
+     """
+     if self.cmds_meta[cmd_idx] != "signal": raise RuntimeError("called update_signal not on a signal command")
+     self._update_signal(cmd_idx, signal, value)
+     return self
+   def _update_signal(self, cmd_idx:int, signal:Optional[Any], value:Optional[int]): raise NotImplementedError("backend should overload this function")
+
+   def update_wait(self, cmd_idx:int, signal:Optional[Any]=None, value:Optional[int]=None):
+     """
+     Updates a previously queued wait command.
+
+     Args:
+       cmd_idx: Index of the wait command to update
+       signal: New signal to wait on (if None, keeps the original)
+       value: New value to wait for (if None, keeps the original)
+     """
+     if self.cmds_meta[cmd_idx] != "wait": raise RuntimeError("called update_wait not on a wait command")
+     self._update_wait(cmd_idx, signal, value)
+     return self
+   def _update_wait(self, cmd_idx:int, signal:Optional[Any], value:Optional[int]): raise NotImplementedError("backend should overload this function")
+
+   def bind(self, device:HCQCompiled):
+     """
+     Associates the queue with a specific device for optimized execution.
+
+     This optional method allows backend implementations to tailor the queue for efficient use on the given device. When implemented, it can eliminate
+     the need to copy queues into the device, thereby enhancing performance.
+
+     Args:
+       device: The target device for queue optimization.
+
+     Note:
+       Implementing this method is optional but recommended for performance gains.
+     """
+
+   def submit(self, device:HCQCompiled):
+     """
+     Submits the command queue to a specific device for execution.
+
+     Args:
+       device: The device to submit the queue to
+     """
+     self._submit(device)
+     return self
+   def _submit(self, device:HCQCompiled): raise NotImplementedError("backend should overload this function")
+
+ class HWComputeQueue(HWCommandQueue):
+   @hcq_command
+   def memory_barrier(self):
+     """
+     Enqueues a memory barrier command to ensure memory coherence between agents.
+     """
+     self._memory_barrier()
+   def _memory_barrier(self): pass
+
+   @hcq_command
+   def exec(self, prg:HCQProgram, args_state:HCQArgsState, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int]):
+     """
+     Enqueues an execution command for a kernel program.
+
+     Args:
+       prg: The program to execute
+       args_state: The args state to execute program with
+       global_size: The global work size
+       local_size: The local work size
+     """
+     self._exec(prg, args_state, global_size, local_size)
+   def _exec(self, prg, args_state, global_size, local_size): raise NotImplementedError("backend should overload this function")
+
+   def update_exec(self, cmd_idx:int, global_size:Optional[Tuple[int,int,int]]=None, local_size:Optional[Tuple[int,int,int]]=None):
+     """
+     Updates a previously queued execution command.
+
+     Args:
+       cmd_idx: Index of the execution command to update
+       global_size: New global work size (if None, keeps the original)
+       local_size: New local work size (if None, keeps the original)
+     """
+     if self.cmds_meta[cmd_idx] != "exec": raise RuntimeError("called update_exec not on an exec command")
+     self._update_exec(cmd_idx, global_size, local_size)
+     return self
+   def _update_exec(self, cmd_idx, global_size, local_size): raise NotImplementedError("backend should overload this function")
+
+ class HWCopyQueue(HWCommandQueue):
+   @hcq_command
+   def copy(self, dest:HCQBuffer, src:HCQBuffer, copy_size:int):
+     """
+     Enqueues a copy command to transfer data.
+
+     Args:
+       dest: The destination of the copy
+       src: The source of the copy
+       copy_size: The size of data to copy
+     """
+     self._copy(dest, src, copy_size)
+   def _copy(self, dest:HCQBuffer, src:HCQBuffer, copy_size:int): raise NotImplementedError("backend should overload this function")
+
+   def update_copy(self, cmd_idx:int, dest:Optional[HCQBuffer]=None, src:Optional[HCQBuffer]=None):
+     """
+     Updates a previously queued copy command.
+
+     Args:
+       cmd_idx: Index of the copy command to update
+       dest: New destination of the copy (if None, keeps the original)
+       src: New source of the copy (if None, keeps the original)
+     """
+     if self.cmds_meta[cmd_idx] != "copy": raise RuntimeError("called update_copy not on a copy command")
+     self._update_copy(cmd_idx, dest, src)
+     return self
+   def _update_copy(self, cmd_idx, dest, src): raise NotImplementedError("backend should overload this function")
+
+ class HCQSignal:
+   def __init__(self, value:int=0): self._set_value(value)
+
+   @property
+   def value(self) -> int: return self._get_value()
+
+   @value.setter
+   def value(self, new_value:int): self._set_value(new_value)
+
+   def _get_value(self) -> int: raise NotImplementedError("_get_value() method must be implemented")
+   def _set_value(self, new_value:int): raise NotImplementedError("_set_value() method must be implemented")
+
+   @property
+   def timestamp(self) -> decimal.Decimal:
+     """
+     Get the timestamp field of the signal.
+
+     This property provides read-only access to the signal's timestamp.
+
+     Returns:
+       The timestamp in microseconds.
+     """
+     return self._get_timestamp()
+   def _get_timestamp(self) -> decimal.Decimal: raise NotImplementedError("_get_timestamp() method must be implemented")
+
+   def wait(self, value:int, timeout:int=10000):
+     """
+     Waits until the signal is greater than or equal to a specific value.
+
+     Args:
+       value: The value to wait for.
+       timeout: Maximum time to wait in milliseconds. Defaults to 10s.
+     """
+     raise NotImplementedError("wait() method must be implemented")
+
  @contextlib.contextmanager
- def hcq_profile(dev, queue_type, enabled, desc):
-   st, en = (dev._get_signal(), dev._get_signal()) if enabled else (None, None)
-   if enabled: queue_type().timestamp(st).submit(dev)
+ def hcq_profile(dev, enabled, desc, queue_type=None, queue=None):
+   st, en = (dev.signal_t(), dev.signal_t()) if enabled else (None, None)
+
+   if enabled and queue is not None: queue.timestamp(st)
+   elif enabled:
+     queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(st).signal(dev.timeline_signal, dev.timeline_value).submit(dev)
+     dev.timeline_value += 1
+
    try: yield (st, en)
    finally:
-     if enabled: queue_type().timestamp(en).submit(dev)
+     if enabled and queue is not None: queue.timestamp(en)
+     elif enabled:
+       queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(en).signal(dev.timeline_signal, dev.timeline_value).submit(dev)
+       dev.timeline_value += 1
+
      if enabled and PROFILE: dev.sig_prof_records.append((st, en, desc, queue_type is dev.hw_copy_queue_t))

- class HCQCompatCompiled(Compiled):
-   def __init__(self, device:str, allocator:Allocator, renderer:Renderer, compiler:Compiler, runtime, comp_queue_t, copy_queue_t, timeline_signals):
-     self.hw_compute_queue_t, self.hw_copy_queue_t = comp_queue_t, copy_queue_t
-     self.timeline_value: int = 1
+ class HCQArgsState:
+   def __init__(self, ptr:int, prg:HCQProgram, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=()): self.ptr, self.prg = ptr, prg
+   def update_buffer(self, index:int, buf:HCQBuffer): raise NotImplementedError("need update_buffer")
+   def update_var(self, index:int, val:int): raise NotImplementedError("need update_var")
+
+ class HCQProgram:
+   def __init__(self, args_state_t:Type[HCQArgsState], device:HCQCompiled, name:str, kernargs_alloc_size:int, kernargs_args_offset:int=0):
+     self.args_state_t, self.device, self.name = args_state_t, device, name
+     self.kernargs_alloc_size, self.kernargs_args_offset = kernargs_alloc_size, kernargs_args_offset
+
+   def fill_kernargs(self, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=(), kernargs_ptr:Optional[int]=None) -> HCQArgsState:
+     """
+     Fills arguments for the kernel, optionally allocating space from the device if `kernargs_ptr` is not provided.
+     Args:
+       bufs: Buffers to be written to kernel arguments.
+       vals: Values to be written to kernel arguments.
+       kernargs_ptr: Optional pointer to pre-allocated kernel arguments memory.
+     Returns:
+       Arguments state with the given buffers and values set for the program.
+     """
+     return self.args_state_t(kernargs_ptr or self.device._alloc_kernargs(self.kernargs_alloc_size), self, bufs, vals=vals)
+
+   def __call__(self, *bufs:HCQBuffer, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1),
+                vals:Tuple[int, ...]=(), wait:bool=False) -> Optional[float]:
+     """
+     Enqueues the program for execution with the given arguments and dimensions.
+
+     Args:
+       bufs: Buffer arguments to execute the kernel with.
+       global_size: Specifies the global work size for kernel execution (equivalent to CUDA's grid size).
+       local_size: Specifies the local work size for kernel execution (equivalent to CUDA's block size).
+       vals: Value arguments to execute the kernel with.
+       wait: If True, waits for the kernel to complete execution.
+
+     Returns:
+       Execution time of the kernel if 'wait' is True, otherwise None.
+     """
+
+     q = self.device.hw_compute_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1).memory_barrier()
+
+     with hcq_profile(self.device, queue=q, desc=self.name, enabled=wait or PROFILE) as (sig_st, sig_en):
+       q.exec(self, self.fill_kernargs(bufs, vals), global_size, local_size)
+
+     q.signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
+     self.device.timeline_value += 1
+
+     if wait: self.device.timeline_signal.wait(self.device.timeline_value - 1)
+     return (float(sig_en.timestamp - sig_st.timestamp) / 1e6) if wait else None
+
+ class HCQCompiled(Compiled):
+   """
+   A base class for devices compatible with the HCQ (Hardware Command Queue) API.
+   """
+   devices: List[HCQCompiled] = []
+   gpu2cpu_copy_time_diff: decimal.Decimal = decimal.Decimal('nan')
+   gpu2cpu_compute_time_diff: decimal.Decimal = decimal.Decimal('nan')
+
+   def __init__(self, device:str, allocator:Allocator, renderer:Renderer, compiler:Compiler, runtime, signal_t:Type[HCQSignal],
+                comp_queue_t:Type[HWComputeQueue], copy_queue_t:Optional[Type[HWCopyQueue]], timeline_signals:Tuple[HCQSignal, HCQSignal]):
+     self.signal_t, self.hw_compute_queue_t, self.hw_copy_queue_t = signal_t, comp_queue_t, copy_queue_t
+     self.timeline_value:int = 1
      self.timeline_signal, self._shadow_timeline_signal = timeline_signals
-     self.sig_prof_records: List[Tuple[Any, Any, str, bool]] = []
-     self.raw_prof_records: List[Tuple[int, int, str, bool]] = []
+     self.sig_prof_records:List[Tuple[HCQSignal, HCQSignal, str, bool]] = []
+     self.raw_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, str, bool, Optional[Dict]]] = []
+     self.dep_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, HCQCompiled, bool, decimal.Decimal, decimal.Decimal, HCQCompiled, bool]] = []
      if PROFILE: self._prof_setup()

      from tinygrad.runtime.graph.hcq import HCQGraph
      super().__init__(device, allocator, renderer, compiler, runtime, HCQGraph)

-   @classmethod
-   def _read_signal(self, sig): raise NotImplementedError("need _read_signal") # reads a value for a signal
-
-   @classmethod
-   def _read_timestamp(self, sig): raise NotImplementedError("need _read_timestamp") # reads a timestamp for a signal
-
-   @classmethod
-   def _set_signal(self, sig, value): raise NotImplementedError("need _set_signal") # sets a value for a signal
-
-   @classmethod
-   def _get_signal(self, value=0, **kwargs): raise NotImplementedError("need _get_signal") # allocates a new signal
-
-   @classmethod
-   def _wait_signal(self, signal, value=0, timeout=10000): raise NotImplementedError("need _wait_signal") # waits for a signal value
-
-   def _gpu2cpu_time(self, gpu_time, is_copy): raise NotImplementedError("need _gpu2cpu_time")
+     self.kernargs_page:HCQBuffer = self.allocator.alloc(16 << 20, BufferOptions(cpu_access=True))
+     self.kernargs_ptr:int = self.kernargs_page.va_addr
+     self.devices.append(self)
+
+   def synchronize(self):
+     self.timeline_signal.wait(self.timeline_value - 1)
+
+     if self.timeline_value > (1 << 31): self._wrap_timeline_signal()
+     if PROFILE:
+       self.raw_prof_records += [(st.timestamp, en.timestamp, name, is_cp, None) for st, en, name, is_cp in self.sig_prof_records]
+       self.sig_prof_records = []
+
+   def _alloc_kernargs(self, alloc_size:int) -> int:
+     """
+     Allocates space for arguments passed to the kernel.
+     """
+     if self.kernargs_ptr >= (self.kernargs_page.va_addr + self.kernargs_page.size - alloc_size): self.kernargs_ptr = self.kernargs_page.va_addr
+     self.kernargs_ptr = (res:=self.kernargs_ptr) + alloc_size
+     return res
+
+   def _ensure_shared_time_base(self):
+     if not self.gpu2cpu_compute_time_diff.is_nan(): return
+
+     def _sync_cpu_queue(d, q_t):
+       q_t().timestamp(d.timeline_signal).signal(d.timeline_signal, d.timeline_value).submit(d)
+       d.timeline_value += 1
+       st = time.perf_counter_ns()
+       d.timeline_signal.wait(d.timeline_value - 1) # average of the two
+       et = time.perf_counter_ns()
+       return (decimal.Decimal(et+st) / 2000) - d.timeline_signal.timestamp
+
+     # randomly sample the timing from GPU to CPU
+     choices: List = [(d, d.hw_compute_queue_t, []) for d in self.devices]
+     choices += [(d, d.hw_copy_queue_t, []) for d in self.devices if d.hw_copy_queue_t is not None]
+     for _ in range(100*len(self.devices)):
+       d,q,l = random.choice(choices)
+       l.append(_sync_cpu_queue(d,q))
+     for d,q,l in choices:
+       if q == d.hw_compute_queue_t: d.gpu2cpu_compute_time_diff = statistics.median(l)
+       if q == d.hw_copy_queue_t: d.gpu2cpu_copy_time_diff = statistics.median(l)
+
+     def _sync_gpu_to_gpu_queue(d1, d2, q1_t, q2_t):
+       q1_t().signal(d1.timeline_signal, d1.timeline_value).wait(d2.timeline_signal, d2.timeline_value) \
+             .timestamp(d1.timeline_signal).signal(d1.timeline_signal, d1.timeline_value+1).submit(d1)
+       q2_t().signal(d2.timeline_signal, d2.timeline_value).wait(d1.timeline_signal, d1.timeline_value) \
+             .timestamp(d2.timeline_signal).signal(d2.timeline_signal, d2.timeline_value+1).submit(d2)
+       d1.timeline_value += 2
+       d2.timeline_value += 2
+       d1.timeline_signal.wait(d1.timeline_value - 1)
+       d2.timeline_signal.wait(d2.timeline_value - 1)
+       return d2.timeline_signal.timestamp - d1.timeline_signal.timestamp
+
+     # then test it by timing the GPU to GPU times
+     jitter_matrix = [[float('nan')]*len(self.devices) for _ in range(len(self.devices))]
+     for i1, d1 in enumerate(self.devices):
+       for i2, d2 in enumerate(self.devices):
+         if d1 == d2: continue
+         d1_to_d2 = statistics.median(_sync_gpu_to_gpu_queue(d1, d2, d1.hw_compute_queue_t, d2.hw_compute_queue_t) - \
+                                      _sync_gpu_to_gpu_queue(d2, d1, d2.hw_compute_queue_t, d1.hw_compute_queue_t) for _ in range(20)) / 2
+         jitter_matrix[i1][i2] = d1_to_d2 - (d1.gpu2cpu_compute_time_diff - d2.gpu2cpu_compute_time_diff)
+     print("pairwise clock jitter matrix (us):\n" + '\n'.join([''.join([f'{float(item):8.3f}' for item in row]) for row in jitter_matrix]))
+
+   def _gpu2cpu_time(self, gpu_time:decimal.Decimal, is_copy:bool) -> float:
+     """
+     Translates local gpu time (timestamp) into global cpu time.
+     """
+     self._ensure_shared_time_base()
+     return float(gpu_time + (self.gpu2cpu_copy_time_diff if is_copy else self.gpu2cpu_compute_time_diff))

    def _prof_setup(self):
+     if hasattr(self, 'profile_logger'): return
+     atexit.register(self._prof_finalize)
      self.profile_logger = ProfileLogger()

-     def _sync_queue(q_t):
-       q_t().timestamp(self.timeline_signal).signal(self.timeline_signal, self.timeline_value).submit(self)
-       self.timeline_value += 1
-       cpu_start_time = time.perf_counter_ns() / 1e3
-       self._wait_signal(self.timeline_signal, self.timeline_value - 1)
-       return cpu_start_time, self._read_timestamp(self.timeline_signal)
-     self.cpu_start_time, self.gpu_start_time = _sync_queue(self.hw_compute_queue_t)
-     self.copy_cpu_start_time, self.copy_gpu_start_time = _sync_queue(self.hw_copy_queue_t)
+   def _prof_finalize(self):
+     qname = ["COMPUTE", "DMA"]

-     atexit.register(self._prof_finalize)
+     # Sync to be sure all events on the device are recorded.
+     self.synchronize()

-   def _prof_process_events(self):
-     self.raw_prof_records += [(self._read_timestamp(st), self._read_timestamp(en), name, is_cp) for st, en, name, is_cp in self.sig_prof_records]
-     for st, en, _, _ in self.sig_prof_records: self.signals_pool += [st, en] # type: ignore
-     self.sig_prof_records = []
+     for st, en, name, is_cp, args in self.raw_prof_records:
+       self.profile_logger.events += [(name, self._gpu2cpu_time(st, is_cp), self._gpu2cpu_time(en, is_cp), self.dname, qname[is_cp], args)]
+     for a_st, a_en, a_dev, a_is_copy, b_st, b_en, b_dev, b_is_copy in self.dep_prof_records:
+       # Perfetto connects nodes based on timing data, ensuring every choice is valid by averaging times to a midpoint.
+       a_tm, b_tm = a_dev._gpu2cpu_time((a_st+a_en)/decimal.Decimal(2), a_is_copy), b_dev._gpu2cpu_time((b_st+b_en)/decimal.Decimal(2), b_is_copy)
+       self.profile_logger.deps += [(a_tm, b_tm, a_dev.dname, qname[a_is_copy], b_dev.dname, qname[b_is_copy])]
+     self.raw_prof_records, self.dep_prof_records = [], []

-   def _prof_finalize(self):
-     for st, en, name, is_cp in self.raw_prof_records:
-       self.profile_logger.events += [(name, self._gpu2cpu_time(st, is_cp), self._gpu2cpu_time(en, is_cp), self.dname, ["COMPUTE", "DMA"][is_cp])]
+     # Remove the logger, this flushes all data written by the device.
      del self.profile_logger

    def _wrap_timeline_signal(self):
      self.timeline_signal, self._shadow_timeline_signal, self.timeline_value = self._shadow_timeline_signal, self.timeline_signal, 1
-     self._set_signal(self.timeline_signal, 0)
-     cast(HCQCompatAllocator, self.allocator).b_timeline = [0] * len(cast(HCQCompatAllocator, self.allocator).b)
+     self.timeline_signal.value = 0
+     cast(HCQAllocator, self.allocator).b_timeline = [0] * len(cast(HCQAllocator, self.allocator).b)

- class HCQCompatAllocator(LRUAllocator): # pylint: disable=abstract-method
-   def __init__(self, device, batch_size=(2 << 20), batch_cnt=32):
-     self.device = device
+ # Protocol for hcq compatible allocators: allocated buffers must expose a VA address and a size.
+ class HCQBuffer(Protocol): va_addr:int; size:int # noqa: E702
+
+ class HCQAllocator(LRUAllocator): # pylint: disable=abstract-method
+   """
+   A base allocator class compatible with the HCQ (Hardware Command Queue) API.
+
+   This class implements basic copy operations following the HCQ API, utilizing both `HWComputeQueue` and `HWCopyQueue`.
+   """
+
+   def __init__(self, device:HCQCompiled, batch_size:int=(2 << 20), batch_cnt:int=32):
+     self.device:Any = device
      self.b = [self._alloc(batch_size, BufferOptions(host=True)) for _ in range(batch_cnt)]
      self.b_timeline, self.b_next = [0] * len(self.b), 0
      super().__init__()

-   def copyin(self, dest, src: memoryview):
-     with hcq_profile(self.device, self.device.hw_copy_queue_t, desc=f"CPU -> {self.device.dname}", enabled=PROFILE):
+   def _alloc(self, size:int, options:BufferOptions) -> HCQBuffer: raise NotImplementedError("need hcq compat alloc")
+
+   def copyin(self, dest:HCQBuffer, src:memoryview):
+     with hcq_profile(self.device, queue_type=self.device.hw_copy_queue_t, desc=f"CPU -> {self.device.dname}", enabled=PROFILE):
        for i in range(0, src.nbytes, self.b[0].size):
          self.b_next = (self.b_next + 1) % len(self.b)
-         self.device._wait_signal(self.device.timeline_signal, self.b_timeline[self.b_next])
+         self.device.timeline_signal.wait(self.b_timeline[self.b_next])
          ctypes.memmove(self.b[self.b_next].va_addr, from_mv(src[i:]), lsize:=min(self.b[self.b_next].size, src.nbytes-i))
          self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
            .copy(dest.va_addr+i, self.b[self.b_next].va_addr, lsize) \
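
Putting the pieces of this hunk together: `@hcq_command` methods return `self`, so commands chain fluently on a queue, and ordering is expressed through the device's monotonically increasing timeline signal. A sketch of the launch sequence performed by `HCQProgram.__call__` above, where `dev` and `prg` are placeholders for a real `HCQCompiled` device and `HCQProgram`:

```python
def launch(dev, prg, bufs, global_size=(1, 1, 1), local_size=(1, 1, 1)):
    # order after all previously submitted work, then fence memory
    q = dev.hw_compute_queue_t().wait(dev.timeline_signal, dev.timeline_value - 1).memory_barrier()
    # write kernel arguments and enqueue the execution
    q.exec(prg, prg.fill_kernargs(bufs), global_size, local_size)
    # bump the timeline so later queues (and the CPU) can order against this launch
    q.signal(dev.timeline_signal, dev.timeline_value).submit(dev)
    dev.timeline_value += 1
    dev.timeline_signal.wait(dev.timeline_value - 1)  # CPU-side fence, like synchronize()
```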
@@ -272,15 +627,15 @@ class HCQCompatAllocator(LRUAllocator): # pylint: disable=abstract-method
          self.b_timeline[self.b_next] = self.device.timeline_value
          self.device.timeline_value += 1

-   def copy_from_disk(self, dest, src, size):
+   def copy_from_disk(self, dest:HCQBuffer, src, size):
      def _get_temp_buf():
        # Check if the next buffer is safe to be used (its signal has passed) and reserve it.
-       if self.b_timeline[(self.b_next + 1) % len(self.b)] <= self.device._read_signal(self.device.timeline_signal):
+       if self.b_timeline[(self.b_next + 1) % len(self.b)] <= self.device.timeline_signal.value:
          self.b_timeline[(self.b_next + 1) % len(self.b)], self.b_next = (1 << 64), (self.b_next + 1) % len(self.b)
          return (self.b[self.b_next].va_addr, self.b_next)
        return None

-     with hcq_profile(self.device, self.device.hw_copy_queue_t, desc=f"DISK -> {self.device.dname}", enabled=PROFILE):
+     with hcq_profile(self.device, queue_type=self.device.hw_copy_queue_t, desc=f"DISK -> {self.device.dname}", enabled=PROFILE):
        for (batch_info, dst_off, src_off, copy_size) in src.device.allocator._copyout_sharded(src, size, _get_temp_buf, seg_len=self.b[0].size):
          self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
            .copy(dest.va_addr + dst_off, batch_info[0] + src_off, copy_size) \
@@ -288,23 +643,23 @@ class HCQCompatAllocator(LRUAllocator): # pylint: disable=abstract-method
          self.b_timeline[batch_info[1]] = self.device.timeline_value
          self.device.timeline_value += 1

-   def copyout(self, dest:memoryview, src):
+   def copyout(self, dest:memoryview, src:HCQBuffer):
      self.device.synchronize()

-     with hcq_profile(self.device, self.device.hw_copy_queue_t, desc=f"{self.device.dname} -> CPU", enabled=PROFILE):
+     with hcq_profile(self.device, queue_type=self.device.hw_copy_queue_t, desc=f"{self.device.dname} -> CPU", enabled=PROFILE):
        for i in range(0, dest.nbytes, self.b[0].size):
          self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
            .copy(self.b[0].va_addr, src.va_addr+i, lsize:=min(self.b[0].size, dest.nbytes-i)) \
            .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
-         self.device._wait_signal(self.device.timeline_signal, self.device.timeline_value)
+         self.device.timeline_signal.wait(self.device.timeline_value)
          self.device.timeline_value += 1

          ctypes.memmove(from_mv(dest[i:]), self.b[0].va_addr, lsize)

-   def transfer(self, dest, src, sz: int, src_dev, dest_dev):
-     src_dev._gpu_map(dest)
+   def transfer(self, dest:HCQBuffer, src:HCQBuffer, sz:int, src_dev, dest_dev):
+     src_dev.allocator.map(dest)

-     with hcq_profile(self.device, self.device.hw_copy_queue_t, desc=f"{src_dev.dname} -> {dest_dev.dname}", enabled=PROFILE):
+     with hcq_profile(src_dev, queue_type=src_dev.hw_copy_queue_t, desc=f"{src_dev.dname} -> {dest_dev.dname}", enabled=PROFILE):
        src_dev.hw_copy_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
          .wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
          .copy(dest.va_addr, src.va_addr, sz) \
@@ -317,4 +672,8 @@ class HCQCompatAllocator(LRUAllocator): # pylint: disable=abstract-method
          .signal(dest_dev.timeline_signal, dest_dev.timeline_value).submit(dest_dev)
      dest_dev.timeline_value += 1

-   def offset(self, buf, size:int, offset:int): return type(buf)(base=buf.base + offset, va_addr=buf.va_addr + offset, length=size, size=size)
+   def map(self, buf:HCQBuffer): pass
+
+   def offset(self, buf, size:int, offset:int) -> HCQBuffer:
+     return type(buf)(va_addr=buf.va_addr + offset, size=size, **{k:v for k,v in buf.__dict__.items() if k not in ['va_addr', 'size']},
+                      **{x[0]:getattr(buf, x[0]) for x in getattr(buf, '_fields_', []) if x[0] not in ['va_addr', 'size']}, _base=buf)
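
The copy paths above share one idea: a ring of pinned staging buffers, each stamped with the timeline value of the last transfer that used it, so a buffer is reused only after its copy has signaled. A pure-Python sketch of the `copyin` chunking, with byte slices standing in for `ctypes.memmove` and the DMA queue:

```python
def chunked_copyin(dest: bytearray, src: memoryview, ring: int = 4, chunk: int = 8) -> None:
    staging = [bytearray(chunk) for _ in range(ring)]  # stand-ins for pinned host buffers
    b_next = 0
    for i in range(0, len(src), chunk):
        b_next = (b_next + 1) % ring
        # real code first waits: dev.timeline_signal.wait(b_timeline[b_next])
        lsize = min(chunk, len(src) - i)
        staging[b_next][:lsize] = src[i:i + lsize]   # CPU -> staging (memmove in the diff)
        dest[i:i + lsize] = staging[b_next][:lsize]  # staging -> device (hw_copy_queue in the diff)
        # real code then stamps: b_timeline[b_next] = dev.timeline_value; dev.timeline_value += 1

out = bytearray(20)
chunked_copyin(out, memoryview(bytearray(range(20))))
assert out == bytearray(range(20))
```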
tinygrad/dtype.py CHANGED
@@ -30,16 +30,16 @@ class ImageDType(DType):
  # @dataclass(frozen=True, init=False, repr=False, eq=False)
  class PtrDType(DType):
    def __init__(self, dt:DType): super().__init__(dt.priority, dt.itemsize, dt.name, dt.fmt, dt.count)
-   def __repr__(self): return f"ptr.{super().__repr__()}"
    def __hash__(self): return super().__hash__()
    def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.count==dt.count
    def __ne__(self, dt): return not (self == dt)
+   def __repr__(self): return f"PtrDType({super().__repr__()})"

  class dtypes:
    @staticmethod
    def is_float(x: DType) -> bool: return x.scalar() in (dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64)
    @staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool
-   def is_int(x: DType) -> bool: return x.scalar() in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64) or dtypes.is_unsigned(x)
+   def is_int(x: DType) -> bool: return x.scalar() in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.pyint) or dtypes.is_unsigned(x)
    @staticmethod
    def is_unsigned(x: DType) -> bool: return x.scalar() in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
    @staticmethod
@@ -53,7 +53,17 @@ class dtypes:
    @staticmethod
    def as_const(val: ConstType, dtype:DType): return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)
    @staticmethod
+   def min(dtype:DType):
+     if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.itemsize*8-1)
+     return -float("inf") if dtypes.is_float(dtype) else False
+   @staticmethod
+   def max(dtype:DType):
+     if dtypes.is_int(dtype): return (2**(dtype.itemsize*8-(0 if dtypes.is_unsigned(dtype) else 1)))-1
+     return float("inf") if dtypes.is_float(dtype) else True
+   @staticmethod
    def fields() -> Dict[str, DType]: return DTYPES_DICT
+   # TODO: priority should be higher than bool
+   pyint: Final[DType] = DType(-1, 8, "pyint", None, 1) # arbitrary precision integer, same itemsize to int64 so min/max works
    bool: Final[DType] = DType(0, 1, "bool", '?', 1)
    int8: Final[DType] = DType(1, 1, "char", 'b', 1)
    uint8: Final[DType] = DType(2, 1, "unsigned char", 'B', 1)
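
The new `dtypes.min`/`dtypes.max` formulas reduce to the familiar integer limits. A quick check with plain `(itemsize, unsigned)` pairs in place of `DType` objects:

```python
def int_min(itemsize: int, unsigned: bool) -> int:
    return 0 if unsigned else -2**(itemsize*8 - 1)

def int_max(itemsize: int, unsigned: bool) -> int:
    return 2**(itemsize*8 - (0 if unsigned else 1)) - 1

assert (int_min(1, False), int_max(1, False)) == (-128, 127)          # int8
assert (int_min(1, True),  int_max(1, True))  == (0, 255)             # uint8
assert (int_min(4, False), int_max(4, False)) == (-2**31, 2**31 - 1)  # int32
```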
@@ -87,6 +97,9 @@ if (env_default_float := getenv("DEFAULT_FLOAT", "")):
    dtypes.default_float = getattr(dtypes, env_default_float.lower())
    assert dtypes.is_float(dtypes.default_float), f"{env_default_float} is not a float dtype"

+ DTypeLike = Union[str, DType]
+ def to_dtype(dtype:DTypeLike) -> DType: return dtype if isinstance(dtype, DType) else getattr(dtypes, dtype)
+
  # https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html
  # we don't support weak type and complex type
  promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64],
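
`to_dtype` lets an API accept either a `DType` or its attribute name on `dtypes`. Assuming the import path shown in this file, usage looks like:

```python
from tinygrad.dtype import dtypes, to_dtype

assert to_dtype(dtypes.float32) is dtypes.float32
assert to_dtype("float32") is dtypes.float32  # the string is looked up via getattr(dtypes, ...)
```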
@@ -103,11 +116,12 @@ def least_upper_dtype(*ds:DType) -> DType:
  def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.float32)

  # HACK: staticmethods are not callable in 3.8 so we have to compare the class
- DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not (k.startswith(('__', 'default')) or v.__class__ is staticmethod)}
+ DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not (k.startswith(('__', 'default', 'pyint')) or v.__class__ is staticmethod)}
  INVERSE_DTYPES_DICT = {v.name:k for k,v in DTYPES_DICT.items()}
+ INVERSE_DTYPES_DICT['pyint'] = 'pyint'

  def sum_acc_dtype(dt:DType):
    # default acc dtype for sum
    if dtypes.is_unsigned(dt): return least_upper_dtype(dt, dtypes.uint)
    if dtypes.is_int(dt) or dt == dtypes.bool: return least_upper_dtype(dt, dtypes.int)
-   return least_upper_dtype(dt, dtypes.float)
+   return least_upper_dtype(dt, dtypes.float)