tinygrad 0.10.0__py3-none-any.whl → 0.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72) hide show
  1. tinygrad/codegen/kernel.py +114 -172
  2. tinygrad/codegen/linearize.py +211 -81
  3. tinygrad/codegen/lowerer.py +30 -35
  4. tinygrad/codegen/{uopgraph.py → rewriter.py} +69 -59
  5. tinygrad/codegen/transcendental.py +12 -13
  6. tinygrad/device.py +170 -47
  7. tinygrad/dtype.py +28 -26
  8. tinygrad/engine/jit.py +80 -63
  9. tinygrad/engine/memory.py +4 -5
  10. tinygrad/engine/multi.py +162 -0
  11. tinygrad/engine/realize.py +58 -107
  12. tinygrad/engine/schedule.py +381 -314
  13. tinygrad/engine/search.py +40 -44
  14. tinygrad/gradient.py +70 -0
  15. tinygrad/helpers.py +77 -58
  16. tinygrad/nn/__init__.py +30 -32
  17. tinygrad/nn/datasets.py +1 -2
  18. tinygrad/nn/optim.py +22 -26
  19. tinygrad/nn/state.py +89 -64
  20. tinygrad/ops.py +562 -446
  21. tinygrad/renderer/__init__.py +79 -36
  22. tinygrad/renderer/cstyle.py +70 -84
  23. tinygrad/renderer/llvmir.py +32 -20
  24. tinygrad/renderer/ptx.py +79 -99
  25. tinygrad/renderer/wgsl.py +87 -0
  26. tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  27. tinygrad/runtime/autogen/comgr.py +2 -0
  28. tinygrad/runtime/autogen/kfd.py +4 -3
  29. tinygrad/runtime/autogen/kgsl.py +1 -1
  30. tinygrad/runtime/autogen/libpciaccess.py +2023 -0
  31. tinygrad/runtime/autogen/llvm.py +11379 -0
  32. tinygrad/runtime/autogen/vfio.py +891 -0
  33. tinygrad/runtime/graph/cuda.py +8 -9
  34. tinygrad/runtime/graph/hcq.py +84 -79
  35. tinygrad/runtime/graph/metal.py +19 -21
  36. tinygrad/runtime/ops_amd.py +488 -327
  37. tinygrad/runtime/ops_clang.py +15 -28
  38. tinygrad/runtime/ops_cloud.py +34 -34
  39. tinygrad/runtime/ops_cuda.py +30 -27
  40. tinygrad/runtime/ops_disk.py +62 -63
  41. tinygrad/runtime/ops_dsp.py +129 -38
  42. tinygrad/runtime/ops_gpu.py +30 -30
  43. tinygrad/runtime/ops_hip.py +29 -31
  44. tinygrad/runtime/ops_llvm.py +45 -40
  45. tinygrad/runtime/ops_metal.py +93 -73
  46. tinygrad/runtime/ops_npy.py +2 -2
  47. tinygrad/runtime/ops_nv.py +232 -270
  48. tinygrad/runtime/ops_python.py +51 -46
  49. tinygrad/runtime/ops_qcom.py +129 -157
  50. tinygrad/runtime/ops_webgpu.py +63 -0
  51. tinygrad/runtime/support/allocator.py +94 -0
  52. tinygrad/runtime/support/am/__init__.py +0 -0
  53. tinygrad/runtime/support/am/amdev.py +384 -0
  54. tinygrad/runtime/support/am/ip.py +463 -0
  55. tinygrad/runtime/support/compiler_cuda.py +4 -2
  56. tinygrad/runtime/support/elf.py +26 -4
  57. tinygrad/runtime/support/hcq.py +254 -324
  58. tinygrad/runtime/support/llvm.py +32 -0
  59. tinygrad/shape/shapetracker.py +84 -53
  60. tinygrad/shape/view.py +103 -138
  61. tinygrad/spec.py +154 -0
  62. tinygrad/tensor.py +744 -496
  63. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/METADATA +32 -21
  64. tinygrad-0.10.1.dist-info/RECORD +86 -0
  65. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/WHEEL +1 -1
  66. tinygrad/engine/lazy.py +0 -228
  67. tinygrad/function.py +0 -212
  68. tinygrad/multi.py +0 -177
  69. tinygrad/runtime/graph/clang.py +0 -39
  70. tinygrad-0.10.0.dist-info/RECORD +0 -77
  71. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/LICENSE +0 -0
  72. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/top_level.txt +0 -0
@@ -1,48 +1,104 @@
1
1
  from __future__ import annotations
2
- from typing import List, Optional, Dict, Tuple, Any, cast, Protocol, Type, Union
3
- import contextlib, decimal, statistics, random, json, atexit, time, array, ctypes
4
- from tinygrad.helpers import PROFILEPATH, PROFILE, from_mv, getenv
2
+ from typing import cast, Type, TypeVar, Generic, Any
3
+ import contextlib, decimal, statistics, time, ctypes, array, os, fcntl
4
+ from tinygrad.helpers import PROFILE, from_mv, getenv, to_mv, round_up
5
5
  from tinygrad.renderer import Renderer
6
- from tinygrad.device import BufferOptions, Allocator, Compiler, Compiled, LRUAllocator
6
+ from tinygrad.device import BufferSpec, Compiler, Compiled, LRUAllocator, ProfileRangeEvent, ProfileDeviceEvent
7
+ from tinygrad.ops import sym_infer, sint, Variable
8
+ from tinygrad.runtime.autogen import libc
7
9
 
8
- # **************** for HCQ Compatible Devices ****************
9
-
10
- def hcq_command(func):
10
+ class HWInterface:
11
11
  """
12
- Decorator for HWCommandQueue commands. Enables command indexing and stores metadata for command updates.
13
-
14
- For example:
15
- ```python
16
- @hcq_command
17
- def command_method(self, ...): ...
18
- ```
12
+ Hardware Abstraction Layer for HCQ devices. The class provides a unified interface for interacting with hardware devices.
19
13
  """
20
- def __wrapper(self, *args, **kwargs):
21
- self.cmds_offset.append(len(self.q))
22
- func(self, *args, **kwargs)
23
- self.cmds_len.append(len(self.q) - self.cmds_offset[-1])
24
- self.cmds_meta.append(func.__name__)
25
- return self
26
- return __wrapper
27
14
 
28
- class HWCommandQueue:
15
+ def __init__(self, path:str="", flags:int=os.O_RDONLY, fd:int|None=None):
16
+ self.path:str = path
17
+ self.fd:int = fd or os.open(path, flags)
18
+ def __del__(self):
19
+ if hasattr(self, 'fd'): os.close(self.fd)
20
+ def ioctl(self, request, arg): return fcntl.ioctl(self.fd, request, arg)
21
+ def mmap(self, start, sz, prot, flags, offset): return libc.mmap(start, sz, prot, flags, self.fd, offset)
22
+ def read(self, size=None, binary=False):
23
+ with open(self.fd, "rb" if binary else "r", closefd=False) as file: return file.read(size)
24
+ def write(self, content, binary=False):
25
+ with open(self.fd, "wb" if binary else "w", closefd=False) as file: file.write(content)
26
+ def listdir(self): return os.listdir(self.path)
27
+ def seek(self, offset): os.lseek(self.fd, offset, os.SEEK_SET)
28
+ @staticmethod
29
+ def anon_mmap(start, sz, prot, flags, offset): return libc.mmap(start, sz, prot, flags, -1, offset)
30
+ @staticmethod
31
+ def munmap(buf, sz): return libc.munmap(buf, sz)
32
+ @staticmethod
33
+ def exists(path): return os.path.exists(path)
34
+ @staticmethod
35
+ def readlink(path): return os.readlink(path)
36
+ @staticmethod
37
+ def eventfd(initval, flags=None): return HWInterface(fd=os.eventfd(initval, flags)) # type: ignore[attr-defined]
38
+
39
+ if MOCKGPU:=getenv("MOCKGPU"): from test.mockgpu.mockgpu import MockHWInterface as HWInterface # noqa: F401 # pylint: disable=unused-import
40
+
41
+ # **************** for HCQ Compatible Devices ****************
42
+
43
+ SignalType = TypeVar('SignalType', bound='HCQSignal')
44
+ DeviceType = TypeVar('DeviceType', bound='HCQCompiled')
45
+ ProgramType = TypeVar('ProgramType', bound='HCQProgram')
46
+ ArgsStateType = TypeVar('ArgsStateType', bound='HCQArgsState')
47
+ QueueType = TypeVar('QueueType', bound='HWQueue')
48
+
49
+ class BumpAllocator:
50
+ def __init__(self, size:int, base:int=0, wrap:bool=True): self.size, self.ptr, self.base, self.wrap = size, 0, base, wrap
51
+ def alloc(self, size:int, alignment:int=1) -> int:
52
+ if round_up(self.ptr, alignment) + size > self.size:
53
+ if not self.wrap: raise RuntimeError("Out of memory")
54
+ self.ptr = 0
55
+ self.ptr = (res:=round_up(self.ptr, alignment)) + size
56
+ return res + self.base
57
+
58
+ class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
29
59
  """
30
60
  A base class for hardware command queues in the HCQ (Hardware Command Queue) API.
31
- Both compute and copy queues should have the following commands implemented.
32
61
  """
33
62
 
34
- def __init__(self): self.q, self.binded_device, self.cmds_offset, self.cmds_len, self.cmds_meta = [], None, [], [], []
35
- def __len__(self): return len(self.cmds_offset)
36
- def _patch(self, cmd_idx, offset, data): self.q[(st:=self.cmds_offset[cmd_idx]+offset):st+len(data)] = array.array('I', data)
37
- def _cur_cmd_idx(self) -> int:
63
+ def __init__(self):
64
+ self._q:Any = []
65
+ self.binded_device:DeviceType|None = None
66
+ self.q_sints:list[tuple[int, int]] = []
67
+ self.mv_sints:list[tuple[memoryview, int, int, int|None]] = []
68
+ self.syms:list[sint] = []
69
+ self._prev_resolved_syms:list[int|None] = []
70
+
71
+ def _new_sym(self, sym:sint) -> int:
72
+ if sym not in self.syms:
73
+ self.syms.append(sym)
74
+ self._prev_resolved_syms.append(None)
75
+ return self.syms.index(sym)
76
+
77
+ def q(self, *values):
38
78
  """
39
- Returns the index of the command currently being enqueued.
40
- Should be called only within functions that enqueue commands and are decorated with `@hcq_command`.
79
+ Enqueues values in the queue.
80
+
81
+ Args:
82
+ values: The values to enqueue in the queue.
41
83
  """
42
- return len(self) - 1
43
84
 
44
- @hcq_command
45
- def signal(self, signal:HCQSignal, value:int):
85
+ for v in values:
86
+ if isinstance(v, int): self._q.append(v)
87
+ else:
88
+ self.q_sints.append((len(self._q), self._new_sym(v)))
89
+ self._q.append(0xbadc0ded)
90
+
91
+ # *** common commands ***
92
+
93
+ def timestamp(self, signal:SignalType):
94
+ """
95
+ Enqueues a timestamp command which records the current time in a signal after all previously enqueued commands are completed.
96
+
97
+ Args:
98
+ signal: The signal to store the timestamp
99
+ """
100
+
101
+ def signal(self, signal:SignalType, value:sint):
46
102
  """
47
103
  Enqueues a signal command which sets the signal to the given value, ensuring all previous operations are completed.
48
104
 
@@ -50,11 +106,8 @@ class HWCommandQueue:
50
106
  signal: The signal to set
51
107
  value: The value to set the signal to
52
108
  """
53
- self._signal(signal, value)
54
- def _signal(self, signal:HCQSignal, value:int): raise NotImplementedError("backend should overload this function")
55
109
 
56
- @hcq_command
57
- def wait(self, signal:HCQSignal, value:int):
110
+ def wait(self, signal:SignalType, value:sint):
58
111
  """
59
112
  Enqueues a wait command which halts execution until the signal is greater than or equal to a specific value.
60
113
 
@@ -62,49 +115,40 @@ class HWCommandQueue:
62
115
  signal: The signal to wait on
63
116
  value: The value to wait for
64
117
  """
65
- self._wait(signal, value)
66
- def _wait(self, signal, value): raise NotImplementedError("backend should overload this function")
67
118
 
68
- @hcq_command
69
- def timestamp(self, signal:HCQSignal):
70
- """
71
- Enqueues a timestamp command which records the current time in a signal after all previously enqueued commands are completed.
119
+ # *** commands for compute queues ***
72
120
 
73
- Args:
74
- signal: The signal to store the timestamp
121
+ def memory_barrier(self):
122
+ """
123
+ Enqueues a memory barrier command to ensure memory coherence between agents. Only on compute queues.
75
124
  """
76
- self._timestamp(signal)
77
- def _timestamp(self, signal): raise NotImplementedError("backend should overload this function")
78
125
 
79
- def update_signal(self, cmd_idx:int, signal:Optional[Any]=None, value:Optional[int]=None):
126
+ def exec(self, prg:ProgramType, args_state:ArgsStateType, global_size:tuple[sint, ...], local_size:tuple[sint, ...]):
80
127
  """
81
- Updates a previously queued signal command.
128
+ Enqueues an execution command for a kernel program. Only on compute queues.
82
129
 
83
130
  Args:
84
- cmd_idx: Index of the signal command to update
85
- signal: New signal to set (if None, keeps the original)
86
- value: New value to set (if None, keeps the original)
131
+ prg: The program to execute
132
+ args_state: The args state to execute program with
133
+ global_size: The global work size
134
+ local_size: The local work size
87
135
  """
88
- if self.cmds_meta[cmd_idx] != "signal": raise RuntimeError("called update_signal not on a signal command")
89
- self._update_signal(cmd_idx, signal, value)
90
- return self
91
- def _update_signal(self, cmd_idx:int, signal:Optional[Any], value:Optional[int]): raise NotImplementedError("backend should overload this function")
92
136
 
93
- def update_wait(self, cmd_idx:int, signal:Optional[Any]=None, value:Optional[int]=None):
137
+ # *** commands for copy queues ***
138
+
139
+ def copy(self, dest:sint, src:sint, copy_size:int):
94
140
  """
95
- Updates a previously queued wait command.
141
+ Enqueues a copy command to transfer data. Only on copy queues.
96
142
 
97
143
  Args:
98
- cmd_idx: Index of the wait command to update
99
- signal: New signal to wait on (if None, keeps the original)
100
- value: New value to wait for (if None, keeps the original)
144
+ dest: The destination of the copy
145
+ src: The source of the copy
146
+ copy_size: The size of data to copy
101
147
  """
102
- if self.cmds_meta[cmd_idx] != "wait": raise RuntimeError("called update_wait not on a wait command")
103
- self._update_wait(cmd_idx, signal, value)
104
- return self
105
- def _update_wait(self, cmd_idx:int, signal:Optional[Any], value:Optional[int]): raise NotImplementedError("backend should overload this function")
106
148
 
107
- def bind(self, device:HCQCompiled):
149
+ # *** submit and bind commands ***
150
+
151
+ def bind(self, dev:DeviceType):
108
152
  """
109
153
  Associates the queue with a specific device for optimized execution.
110
154
 
@@ -112,99 +156,65 @@ class HWCommandQueue:
112
156
  the need to copy queues into the device, thereby enhancing performance.
113
157
 
114
158
  Args:
115
- device: The target device for queue optimization.
159
+ dev: The target device for queue optimization.
116
160
 
117
161
  Note:
118
162
  Implementing this method is optional but recommended for performance gains.
119
163
  """
120
164
 
121
- def submit(self, device:HCQCompiled):
122
- """
123
- Submits the command queue to a specific device for execution.
165
+ def bind_args_state(self, args_state:ArgsStateType):
166
+ for vals, ptr, fmt in args_state.bind_data: self.bind_sints_to_ptr(*vals, ptr=ptr, fmt=fmt)
124
167
 
125
- Args:
126
- device: The device to submit the queue to
127
- """
128
- if self.q: self._submit(device)
129
- return self
130
- def _submit(self, device:HCQCompiled): raise NotImplementedError("backend should overload this function")
168
+ def bind_sints(self, *vals:sint, struct:ctypes.Structure, start_field:str, fmt, mask:int|None=None):
169
+ self.bind_sints_to_ptr(*vals, ptr=ctypes.addressof(struct) + getattr(type(struct), start_field).offset, fmt=fmt, mask=mask)
131
170
 
132
- class HWComputeQueue(HWCommandQueue):
133
- @hcq_command
134
- def memory_barrier(self):
135
- """
136
- Enqueues a memory barrier command to ensure memory coherence between agents.
137
- """
138
- self._memory_barrier()
139
- def _memory_barrier(self): pass
171
+ def bind_sints_to_ptr(self, *vals:sint, ptr:int, fmt, mask:int|None=None):
172
+ mv = to_mv(ptr, 8*len(vals)).cast(fmt)
173
+ for i, val in enumerate(vals):
174
+ if isinstance(val, int): mv[i] = val if mask is None else ((mv[i] & ~mask) | val)
175
+ else: self.mv_sints.append((mv, i, self._new_sym(val), mask))
140
176
 
141
- @hcq_command
142
- def exec(self, prg:HCQProgram, args_state:HCQArgsState, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int]):
143
- """
144
- Enqueues an execution command for a kernel program.
177
+ def _apply_var_vals(self, var_vals:dict[Variable, int]):
178
+ resolved_syms = [sym_infer(sym, var_vals) for sym in self.syms]
145
179
 
146
- Args:
147
- prg: The program to execute
148
- args_state: The args state to execute program with
149
- global_size: The global work size
150
- local_size: The local work size
151
- """
152
- self._exec(prg, args_state, global_size, local_size)
153
- def _exec(self, prg, args_state, global_size, local_size): raise NotImplementedError("backend should overload this function")
180
+ for off, sym_idx in self.q_sints:
181
+ if self._prev_resolved_syms[sym_idx] == resolved_syms[sym_idx]: continue
182
+ self._q[off] = resolved_syms[sym_idx]
154
183
 
155
- def update_exec(self, cmd_idx:int, global_size:Optional[Tuple[int,int,int]]=None, local_size:Optional[Tuple[int,int,int]]=None):
156
- """
157
- Updates a previously queued execution command.
184
+ for mv, off, sym_idx, mask in self.mv_sints:
185
+ if self._prev_resolved_syms[sym_idx] == resolved_syms[sym_idx]: continue
186
+ mv[off] = resolved_syms[sym_idx] if mask is None else ((mv[off] & ~mask) | resolved_syms[sym_idx])
158
187
 
159
- Args:
160
- cmd_idx: Index of the execution command to update
161
- global_size: New global work size (if None, keeps the original)
162
- local_size: New local work size (if None, keeps the original)
163
- """
164
- if self.cmds_meta[cmd_idx] != "exec": raise RuntimeError("called update_exec not on an exec command")
165
- self._update_exec(cmd_idx, global_size, local_size)
166
- return self
167
- def _update_exec(self, cmd_idx, global_size, local_size): raise NotImplementedError("backend should overload this function")
188
+ self._prev_resolved_syms = cast(list[int|None], resolved_syms)
168
189
 
169
- class HWCopyQueue(HWCommandQueue):
170
- @hcq_command
171
- def copy(self, dest:HCQBuffer, src:HCQBuffer, copy_size:int):
190
+ def submit(self, dev:DeviceType, var_vals:dict[Variable, int]|None=None):
172
191
  """
173
- Enqueues a copy command to transfer data.
192
+ Submits the command queue to a specific device for execution.
174
193
 
175
194
  Args:
176
- dest: The destination of the copy
177
- src: The source of the copy
178
- copy_size: The size of data to copy
195
+ dev: The device to submit the queue to
179
196
  """
180
- self._copy(dest, src, copy_size)
181
- def _copy(self, dest:HCQBuffer, src:HCQBuffer, copy_size:int): raise NotImplementedError("backend should overload this function")
182
197
 
183
- def update_copy(self, cmd_idx:int, dest:Optional[HCQBuffer]=None, src:Optional[HCQBuffer]=None):
184
- """
185
- Updates a previously queued copy command.
186
-
187
- Args:
188
- cmd_idx: Index of the copy command to update
189
- dest: New destination of the copy (if None, keeps the original)
190
- src: New source of the copy (if None, keeps the original)
191
- """
192
- if self.cmds_meta[cmd_idx] != "copy": raise RuntimeError("called update_copy not on an copy command")
193
- self._update_copy(cmd_idx, dest, src)
198
+ if var_vals is not None: self._apply_var_vals(var_vals)
199
+ self._submit(dev)
194
200
  return self
195
- def _update_copy(self, cmd_idx, dest, src): raise NotImplementedError("backend should overload this function")
201
+ def _submit(self, dev:DeviceType): raise NotImplementedError("need _submit")
196
202
 
197
- class HCQSignal:
198
- def __init__(self, value:int=0, is_timeline:bool=False): self._set_value(value)
203
+ class HCQSignal(Generic[DeviceType]):
204
+ def __init__(self, base_addr:sint=0, value:int=0, timeline_for_device:DeviceType|None=None, timestamp_divider=1, value_off=0, timestamp_off=8):
205
+ self.base_addr, self.value_addr, self.timestamp_addr = base_addr, base_addr+value_off, base_addr+timestamp_off
206
+ self.timestamp_divider:decimal.Decimal = decimal.Decimal(timestamp_divider)
207
+ self.timeline_for_device:DeviceType|None = timeline_for_device
208
+
209
+ if isinstance(base_addr, int):
210
+ self.value_mv, self.timestamp_mv = to_mv(self.value_addr, 8).cast('Q'), to_mv(self.timestamp_addr, 8).cast('Q')
211
+ self.value_mv[0] = value
199
212
 
200
213
  @property
201
- def value(self) -> int: return self._get_value()
214
+ def value(self) -> int: return self.value_mv[0]
202
215
 
203
216
  @value.setter
204
- def value(self, new_value:int): self._set_value(new_value)
205
-
206
- def _get_value(self) -> int: raise NotImplementedError("_get_value() method must be implemented")
207
- def _set_value(self, new_value:int): raise NotImplementedError("_set_value() method must be implemented")
217
+ def value(self, new_value:int): self.value_mv[0] = new_value
208
218
 
209
219
  @property
210
220
  def timestamp(self) -> decimal.Decimal:
@@ -216,8 +226,12 @@ class HCQSignal:
216
226
  Returns:
217
227
  The timestamp in microseconds.
218
228
  """
219
- return self._get_timestamp()
220
- def _get_timestamp(self) -> decimal.Decimal: raise NotImplementedError("_get_timestamp() method must be implemented")
229
+ return self.timestamp_mv[0] / self.timestamp_divider
230
+
231
+ def _sleep(self, time_spent_waiting_ms:int):
232
+ """
233
+ Optional function which can implement sleep functionality for the signal.
234
+ """
221
235
 
222
236
  def wait(self, value:int, timeout:int=getenv("HCQDEV_WAIT_TIMEOUT_MS", 30000)):
223
237
  """
@@ -227,17 +241,18 @@ class HCQSignal:
227
241
  value: The value to wait for.
228
242
  timeout: Maximum time to wait in milliseconds. Defaults to 10s.
229
243
  """
230
- start_time = time.time() * 1000
231
- while time.time() * 1000 - start_time < timeout:
232
- if self.value >= value: return
233
- raise RuntimeError(f"Wait timeout: {timeout} ms! (the signal is not set to {value}, but {self.value})")
244
+ start_time = int(time.perf_counter() * 1000)
245
+ while self.value < value and (time_spent:=int(time.perf_counter() * 1000) - start_time) < timeout:
246
+ self._sleep(time_spent)
247
+ if self.value < value: raise RuntimeError(f"Wait timeout: {timeout} ms! (the signal is not set to {value}, but {self.value})")
234
248
 
235
249
  @contextlib.contextmanager
236
- def hcq_profile(dev, enabled, desc, queue_type=None, queue=None):
250
+ def hcq_profile(dev:HCQCompiled, enabled, desc, queue_type:Type[HWQueue]|None=None, queue:HWQueue|None=None):
237
251
  st, en = (dev.signal_t(), dev.signal_t()) if enabled else (None, None)
238
252
 
239
253
  if enabled and queue is not None: queue.timestamp(st)
240
254
  elif enabled:
255
+ assert queue_type is not None
241
256
  queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(st).signal(dev.timeline_signal, dev.timeline_value).submit(dev)
242
257
  dev.timeline_value += 1
243
258
 
@@ -245,21 +260,33 @@ def hcq_profile(dev, enabled, desc, queue_type=None, queue=None):
245
260
  finally:
246
261
  if enabled and queue is not None: queue.timestamp(en)
247
262
  elif enabled:
263
+ assert queue_type is not None
248
264
  queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(en).signal(dev.timeline_signal, dev.timeline_value).submit(dev)
249
265
  dev.timeline_value += 1
250
266
 
251
- if enabled and PROFILE: dev.sig_prof_records.append((st, en, desc, queue_type is dev.hw_copy_queue_t))
267
+ if enabled and PROFILE: dev.sig_prof_records.append((cast(HCQSignal, st), cast(HCQSignal, en), desc, queue_type is dev.hw_copy_queue_t))
268
+
269
+ class HCQArgsState(Generic[ProgramType]):
270
+ def __init__(self, ptr:int, prg:ProgramType, bufs:tuple[HCQBuffer, ...], vals:tuple[sint, ...]=()):
271
+ self.ptr, self.prg = ptr, prg
272
+ self.bind_data:list[tuple[tuple[sint, ...], int, str]] = []
273
+
274
+ def bind_sints_to_ptr(self, *vals:sint, ptr:int, fmt): self.bind_data.append((vals, ptr, fmt))
275
+
276
+ class CLikeArgsState(HCQArgsState[ProgramType]):
277
+ def __init__(self, ptr:int, prg:ProgramType, bufs:tuple[HCQBuffer, ...], vals:tuple[sint, ...]=(), prefix:list[int]|None=None):
278
+ super().__init__(ptr, prg, bufs, vals=vals)
279
+
280
+ if prefix is not None: to_mv(self.ptr, len(prefix) * 4).cast('I')[:] = array.array('I', prefix)
252
281
 
253
- class HCQArgsState:
254
- def __init__(self, ptr:int, prg:HCQProgram, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=()): self.ptr, self.prg = ptr, prg
255
- def update_buffer(self, index:int, buf:HCQBuffer): raise NotImplementedError("need update_buffer")
256
- def update_var(self, index:int, val:int): raise NotImplementedError("need update_var")
282
+ self.bind_sints_to_ptr(*[b.va_addr for b in bufs], ptr=self.ptr + len(prefix or []) * 4, fmt='Q')
283
+ self.bind_sints_to_ptr(*vals, ptr=self.ptr + len(prefix or []) * 4 + len(bufs) * 8, fmt='I')
257
284
 
258
- class HCQProgram:
259
- def __init__(self, args_state_t:Type[HCQArgsState], device:HCQCompiled, name:str, kernargs_alloc_size:int):
260
- self.args_state_t, self.device, self.name, self.kernargs_alloc_size = args_state_t, device, name, kernargs_alloc_size
285
+ class HCQProgram(Generic[DeviceType]):
286
+ def __init__(self, args_state_t:Type[HCQArgsState], dev:DeviceType, name:str, kernargs_alloc_size:int):
287
+ self.args_state_t, self.dev, self.name, self.kernargs_alloc_size = args_state_t, dev, name, kernargs_alloc_size
261
288
 
262
- def fill_kernargs(self, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=(), kernargs_ptr:Optional[int]=None) -> HCQArgsState:
289
+ def fill_kernargs(self, bufs:tuple[HCQBuffer, ...], vals:tuple[int, ...]=(), kernargs_ptr:int|None=None) -> HCQArgsState:
263
290
  """
264
291
  Fills arguments for the kernel, optionally allocating space from the device if `kernargs_ptr` is not provided.
265
292
  Args:
@@ -269,10 +296,10 @@ class HCQProgram:
269
296
  Returns:
270
297
  Arguments state with the given buffers and values set for the program.
271
298
  """
272
- return self.args_state_t(kernargs_ptr or self.device._alloc_kernargs(self.kernargs_alloc_size), self, bufs, vals=vals)
299
+ return self.args_state_t(kernargs_ptr or self.dev.kernargs_allocator.alloc(self.kernargs_alloc_size), self, bufs, vals=vals)
273
300
 
274
- def __call__(self, *bufs:HCQBuffer, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1),
275
- vals:Tuple[int, ...]=(), wait:bool=False) -> Optional[float]:
301
+ def __call__(self, *bufs:HCQBuffer, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1),
302
+ vals:tuple[int, ...]=(), wait:bool=False) -> float|None:
276
303
  """
277
304
  Enqueues the program for execution with the given arguments and dimensions.
278
305
 
@@ -288,103 +315,52 @@ class HCQProgram:
288
315
  """
289
316
 
290
317
  kernargs = self.fill_kernargs(bufs, vals)
291
- q = self.device.hw_compute_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1).memory_barrier()
318
+ q = self.dev.hw_compute_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1).memory_barrier()
292
319
 
293
- with hcq_profile(self.device, queue=q, desc=self.name, enabled=wait or PROFILE) as (sig_st, sig_en):
320
+ with hcq_profile(self.dev, queue=q, desc=self.name, enabled=wait or PROFILE) as (sig_st, sig_en):
294
321
  q.exec(self, kernargs, global_size, local_size)
295
322
 
296
- q.signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
297
- self.device.timeline_value += 1
323
+ q.signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
324
+ self.dev.timeline_value += 1
298
325
 
299
- if wait: self.device.synchronize()
326
+ if wait: self.dev.synchronize()
300
327
  return (float(sig_en.timestamp - sig_st.timestamp) / 1e6) if wait else None
301
328
 
302
- class ProfileLogger:
303
- writers: int = 0
304
- mjson: List[Dict] = []
305
- actors: Dict[Union[str, Tuple[str, str]], int] = {}
306
-
307
- def __init__(self): self.events, self.deps, ProfileLogger.writers = [], [], ProfileLogger.writers + 1
308
-
309
- def add_event(self, ev_name, ev_start, ev_end, actor, subactor=None, args=None): self.events += [(ev_name, ev_start, ev_end, actor, subactor, args)]
310
-
311
- def _ensure_actor(self, actor_name, subactor_name):
312
- if actor_name not in self.actors:
313
- self.actors[actor_name] = (pid:=len(self.actors))
314
- self.mjson.append({"name": "process_name", "ph": "M", "pid": pid, "args": {"name": actor_name}})
315
-
316
- if (subactor_key:=(actor_name,subactor_name)) not in self.actors:
317
- self.actors[subactor_key] = (tid:=len(self.actors))
318
- self.mjson.append({"name": "thread_name", "ph": "M", "pid": self.actors[actor_name], "tid":tid, "args": {"name": subactor_name}})
319
-
320
- return self.actors[actor_name], self.actors.get(subactor_key, -1)
321
-
322
- def __del__(self):
323
- # perfetto json docs: https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview
324
- for name, st, et, actor_name, subactor_name, args in self.events:
325
- pid, tid = self._ensure_actor(actor_name,subactor_name)
326
- args = {k: (v if v.__class__ is str else v(et-st)) for k, v in args.items()} if args is not None else None
327
- self.mjson.append({"name": name, "ph": "X", "pid": pid, "tid": tid, "ts": st, "dur": et-st, "args": args})
328
-
329
- for en,st,dep_actor_name,dep_subactor_name,actor_name,subactor_name in self.deps:
330
- dep_pid, dep_tid = self._ensure_actor(dep_actor_name,dep_subactor_name)
331
- pid, tid = self._ensure_actor(actor_name,subactor_name)
332
- self.mjson.append({"ph": "s", "pid": dep_pid, "tid": dep_tid, "id": len(self.mjson), "ts": en, "bp": "e"})
333
- self.mjson.append({"ph": "f", "pid": pid, "tid": tid, "id": len(self.mjson)-1, "ts": st, "bp": "e"})
334
-
335
- ProfileLogger.writers -= 1
336
- if ProfileLogger.writers == 0 and len(self.mjson) > 0:
337
- with open(PROFILEPATH.value, "w") as f: f.write(json.dumps({"traceEvents": self.mjson}))
338
- print(f"Saved profile to {PROFILEPATH.value}. Use https://ui.perfetto.dev/ to open it.")
339
-
340
- class HCQCompiled(Compiled):
329
+ class HCQCompiled(Compiled, Generic[SignalType]):
341
330
  """
342
331
  A base class for devices compatible with the HCQ (Hardware Command Queue) API.
343
332
  """
344
- devices: List[HCQCompiled] = []
345
- gpu2cpu_copy_time_diff: decimal.Decimal = decimal.Decimal('nan')
346
- gpu2cpu_compute_time_diff: decimal.Decimal = decimal.Decimal('nan')
333
+ devices: list[HCQCompiled] = []
347
334
 
348
- def __init__(self, device:str, allocator:Allocator, renderer:Renderer, compiler:Compiler, runtime, signal_t:Type[HCQSignal],
349
- comp_queue_t:Type[HWComputeQueue], copy_queue_t:Optional[Type[HWCopyQueue]]):
335
+ def __init__(self, device:str, allocator:HCQAllocatorBase, renderer:Renderer, compiler:Compiler, runtime, signal_t:Type[SignalType],
336
+ comp_queue_t:Type[HWQueue], copy_queue_t:Type[HWQueue]|None):
337
+ self.device_id:int = int(device.split(":")[1]) if ":" in device else 0
350
338
  self.signal_t, self.hw_compute_queue_t, self.hw_copy_queue_t = signal_t, comp_queue_t, copy_queue_t
351
339
  self.timeline_value:int = 1
352
- self.timeline_signal, self._shadow_timeline_signal = self.signal_t(0, is_timeline=True), self.signal_t(0, is_timeline=True)
353
- self.sig_prof_records:List[Tuple[HCQSignal, HCQSignal, str, bool]] = []
354
- self.raw_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, str, bool, Optional[Dict]]] = []
355
- self.dep_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, HCQCompiled, bool, decimal.Decimal, decimal.Decimal, HCQCompiled, bool]] = []
356
- if PROFILE: self._prof_setup()
340
+ self.timeline_signal:SignalType = self.signal_t(value=0, timeline_for_device=self)
341
+ self._shadow_timeline_signal:SignalType = self.signal_t(value=0, timeline_for_device=self)
342
+ self.sig_prof_records:list[tuple[HCQSignal, HCQSignal, str, bool]] = []
357
343
 
358
344
  from tinygrad.runtime.graph.hcq import HCQGraph
359
345
  super().__init__(device, allocator, renderer, compiler, runtime, HCQGraph)
360
346
 
361
- self.kernargs_page:HCQBuffer = self.allocator.alloc(16 << 20, BufferOptions(cpu_access=True))
362
- self.kernargs_ptr:int = self.kernargs_page.va_addr
347
+ self.kernargs_page:HCQBuffer = self.allocator.alloc(16 << 20, BufferSpec(cpu_access=True))
348
+ self.kernargs_allocator:BumpAllocator = BumpAllocator(self.kernargs_page.size, base=cast(int, self.kernargs_page.va_addr), wrap=True)
363
349
  self.devices.append(self)
364
350
 
365
351
  def synchronize(self):
366
- try: self.timeline_signal.wait(self.timeline_value - 1) if not hasattr(self, '_syncdev') else self._syncdev()
352
+ try: self.timeline_signal.wait(self.timeline_value - 1)
367
353
  except RuntimeError as e:
368
354
  if hasattr(self, 'on_device_hang'): self.on_device_hang()
369
355
  else: raise e
370
356
 
371
357
  if self.timeline_value > (1 << 31): self._wrap_timeline_signal()
372
358
  if PROFILE:
373
- self.raw_prof_records += [(st.timestamp, en.timestamp, name, is_cp, None) for st, en, name, is_cp in self.sig_prof_records]
359
+ Compiled.profile_events += [ProfileRangeEvent(self.device, name, st.timestamp, en.timestamp, cp) for st,en,name,cp in self.sig_prof_records]
374
360
  self.sig_prof_records = []
375
361
 
376
- def _alloc_kernargs(self, alloc_size:int) -> int:
377
- """
378
- Allocates space for arguments passed to the kernel.
379
- """
380
- if self.kernargs_ptr >= (self.kernargs_page.va_addr + self.kernargs_page.size - alloc_size): self.kernargs_ptr = self.kernargs_page.va_addr
381
- self.kernargs_ptr = (res:=self.kernargs_ptr) + alloc_size
382
- return res
383
-
384
- def _ensure_shared_time_base(self):
385
- if not self.gpu2cpu_compute_time_diff.is_nan(): return
386
-
387
- def _sync_cpu_queue(d, q_t):
362
+ def _at_profile_finalize(self):
363
+ def _sync(d:HCQCompiled, q_t:Type[HWQueue]):
388
364
  q_t().timestamp(d.timeline_signal).signal(d.timeline_signal, d.timeline_value).submit(d)
389
365
  d.timeline_value += 1
390
366
  st = time.perf_counter_ns()
@@ -392,134 +368,94 @@ class HCQCompiled(Compiled):
392
368
  et = time.perf_counter_ns()
393
369
  return (decimal.Decimal(et+st) / 2000) - d.timeline_signal.timestamp
394
370
 
395
- # randomly sample the timing from GPU to CPU
396
- choices: List = [(d, d.hw_compute_queue_t, []) for d in self.devices]
397
- choices += [(d, d.hw_copy_queue_t, []) for d in self.devices if d.hw_copy_queue_t is not None]
398
- for _ in range(100*len(self.devices)):
399
- d,q,l = random.choice(choices)
400
- l.append(_sync_cpu_queue(d,q))
401
- for d,q,l in choices:
402
- if q == d.hw_compute_queue_t: d.gpu2cpu_compute_time_diff = statistics.median(l)
403
- if q == d.hw_copy_queue_t: d.gpu2cpu_copy_time_diff = statistics.median(l)
404
-
405
- def _sync_gpu_to_gpu_queue(d1, d2, q1_t, q2_t):
406
- q1_t().signal(d1.timeline_signal, d1.timeline_value).wait(d2.timeline_signal, d2.timeline_value) \
407
- .timestamp(d1.timeline_signal).signal(d1.timeline_signal, d1.timeline_value+1).submit(d1)
408
- q2_t().signal(d2.timeline_signal, d2.timeline_value).wait(d1.timeline_signal, d1.timeline_value) \
409
- .timestamp(d2.timeline_signal).signal(d2.timeline_signal, d2.timeline_value+1).submit(d2)
410
- d1.timeline_value += 2
411
- d2.timeline_value += 2
412
- d1.timeline_signal.wait(d1.timeline_value - 1)
413
- d2.timeline_signal.wait(d2.timeline_value - 1)
414
- return d2.timeline_signal.timestamp - d1.timeline_signal.timestamp
415
-
416
- # then test it by timing the GPU to GPU times
417
- jitter_matrix = [[float('nan')]*len(self.devices) for _ in range(len(self.devices))]
418
- for i1, d1 in enumerate(self.devices):
419
- for i2, d2 in enumerate(self.devices):
420
- if d1 == d2: continue
421
- d1_to_d2 = statistics.median(_sync_gpu_to_gpu_queue(d1, d2, d1.hw_compute_queue_t, d2.hw_compute_queue_t) - \
422
- _sync_gpu_to_gpu_queue(d2, d1, d2.hw_compute_queue_t, d1.hw_compute_queue_t) for _ in range(20)) / 2
423
- jitter_matrix[i1][i2] = d1_to_d2 - (d1.gpu2cpu_compute_time_diff - d2.gpu2cpu_compute_time_diff)
424
- print("pairwise clock jitter matrix (us):\n" + '\n'.join([''.join([f'{float(item):8.3f}' for item in row]) for row in jitter_matrix]))
425
-
426
- def _gpu2cpu_time(self, gpu_time:decimal.Decimal, is_copy:bool) -> float:
427
- """
428
- Translates local gpu time (timestamp) into global cpu time.
429
- """
430
- self._ensure_shared_time_base()
431
- return float(gpu_time + (self.gpu2cpu_copy_time_diff if is_copy else self.gpu2cpu_compute_time_diff))
432
-
433
- def _prof_setup(self):
434
- if hasattr(self, 'profile_logger'): return
435
- atexit.register(self._prof_finalize)
436
- self.profile_logger = ProfileLogger()
437
-
438
- def _prof_finalize(self):
439
- qname = ["COMPUTE", "DMA"]
440
-
441
- # Sync to be sure all events on the device are recorded.
442
- self.synchronize()
443
-
444
- for st, en, name, is_cp, args in self.raw_prof_records:
445
- self.profile_logger.events += [(name, self._gpu2cpu_time(st, is_cp), self._gpu2cpu_time(en, is_cp), self.dname, qname[is_cp], args)]
446
- for a_st, a_en, a_dev, a_is_copy, b_st, b_en, b_dev, b_is_copy in self.dep_prof_records:
447
- # Perfetto connects nodes based on timing data, ensuring every choice is valid by averaging times to a midpoint.
448
- a_tm, b_tm = a_dev._gpu2cpu_time((a_st+a_en)/decimal.Decimal(2), a_is_copy), b_dev._gpu2cpu_time((b_st+b_en)/decimal.Decimal(2), b_is_copy)
449
- self.profile_logger.deps += [(a_tm, b_tm, a_dev.dname, qname[a_is_copy], b_dev.dname, qname[b_is_copy])]
450
- self.raw_prof_records, self.dep_prof_records = [], []
451
-
452
- # Remove the logger, this flushes all data written by the device.
453
- del self.profile_logger
371
+ gpu2cpu_compute_time_diff = statistics.median([_sync(self, self.hw_compute_queue_t) for _ in range(40)])
372
+ if self.hw_copy_queue_t is None: gpu2cpu_copy_time_diff = decimal.Decimal(0)
373
+ else: gpu2cpu_copy_time_diff = statistics.median([_sync(self, self.hw_copy_queue_t) for _ in range(40)])
374
+ Compiled.profile_events += [ProfileDeviceEvent(self.device, gpu2cpu_compute_time_diff, gpu2cpu_copy_time_diff)]
454
375
 
455
376
  def _wrap_timeline_signal(self):
456
377
  self.timeline_signal, self._shadow_timeline_signal, self.timeline_value = self._shadow_timeline_signal, self.timeline_signal, 1
457
378
  self.timeline_signal.value = 0
458
- cast(HCQAllocator, self.allocator).b_timeline = [0] * len(cast(HCQAllocator, self.allocator).b)
379
+ cast(HCQAllocatorBase, self.allocator).b_timeline = [0] * len(cast(HCQAllocatorBase, self.allocator).b)
459
380
 
460
- # Protocol for hcq compatible allocators for allocated buffers to contain VA address and it's size.
461
- class HCQBuffer(Protocol): va_addr:int; size:int # noqa: E702
381
+ def _realloc(self, oldbuf:HCQBuffer|None, new_size:int, options:BufferSpec|None=None) -> tuple[HCQBuffer, bool]:
382
+ if oldbuf is not None: self.allocator.free(oldbuf, oldbuf.size, options=options)
383
+ try: buf, realloced = self.allocator.alloc(new_size, options=options), True
384
+ except MemoryError: buf, realloced = self.allocator.alloc(oldbuf.size if oldbuf is not None else new_size, options=options), False
385
+ return buf, realloced
462
386
 
463
- class HCQAllocator(LRUAllocator): # pylint: disable=abstract-method
387
+ class HCQBuffer:
388
+ def __init__(self, va_addr:sint, size:int, texture_info:Any=None, meta:Any=None, _base:HCQBuffer|None=None):
389
+ self.va_addr, self.size, self.texture_info, self.meta, self._base = va_addr, size, texture_info, meta, _base
390
+
391
+ class HCQAllocatorBase(LRUAllocator, Generic[DeviceType]):
464
392
  """
465
393
  A base allocator class compatible with the HCQ (Hardware Command Queue) API.
466
394
 
467
- This class implements basic copy operations following the HCQ API, utilizing both `HWComputeQueue` and `HWCopyQueue`.
395
+ This class implements basic copy operations following the HCQ API, utilizing both types of `HWQueue`.
468
396
  """
469
397
 
470
- def __init__(self, device:HCQCompiled, batch_size:int=(2 << 20), batch_cnt:int=32):
471
- self.device:Any = device
472
- self.b = [self._alloc(batch_size, BufferOptions(host=True)) for _ in range(batch_cnt)]
398
+ def __init__(self, dev:DeviceType, batch_size:int=(2 << 20), batch_cnt:int=32):
399
+ self.dev:DeviceType = dev
400
+ self.b = [self._alloc(batch_size, BufferSpec(host=True)) for _ in range(batch_cnt)]
473
401
  self.b_timeline, self.b_next = [0] * len(self.b), 0
474
402
  super().__init__()
475
403
 
476
- def _alloc(self, size:int, options:BufferOptions) -> HCQBuffer: raise NotImplementedError("need hcq compat alloc")
404
+ def map(self, buf:HCQBuffer): pass
405
+
406
+ def _offset(self, buf, size:int, offset:int) -> HCQBuffer:
407
+ return HCQBuffer(va_addr=buf.va_addr + offset, size=size, texture_info=buf.texture_info, meta=buf.meta, _base=buf._base or buf)
477
408
 
478
- def copyin(self, dest:HCQBuffer, src:memoryview):
479
- with hcq_profile(self.device, queue_type=self.device.hw_copy_queue_t, desc=f"CPU -> {self.device.dname}", enabled=PROFILE):
409
+ class HCQAllocator(HCQAllocatorBase, Generic[DeviceType]):
410
+ def _copyin(self, dest:HCQBuffer, src:memoryview):
411
+ assert self.dev.hw_copy_queue_t is not None
412
+ with hcq_profile(self.dev, queue_type=self.dev.hw_copy_queue_t, desc=f"CPU -> {self.dev.device}", enabled=PROFILE):
480
413
  for i in range(0, src.nbytes, self.b[0].size):
481
414
  self.b_next = (self.b_next + 1) % len(self.b)
482
- self.device.timeline_signal.wait(self.b_timeline[self.b_next])
415
+ self.dev.timeline_signal.wait(self.b_timeline[self.b_next])
483
416
  ctypes.memmove(self.b[self.b_next].va_addr, from_mv(src[i:]), lsize:=min(self.b[self.b_next].size, src.nbytes-i))
484
- self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
485
- .copy(dest.va_addr+i, self.b[self.b_next].va_addr, lsize) \
486
- .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
487
- self.b_timeline[self.b_next] = self.device.timeline_value
488
- self.device.timeline_value += 1
417
+ self.dev.hw_copy_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1) \
418
+ .copy(dest.va_addr+i, self.b[self.b_next].va_addr, lsize) \
419
+ .signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
420
+ self.b_timeline[self.b_next] = self.dev.timeline_value
421
+ self.dev.timeline_value += 1
489
422
 
490
423
  def copy_from_disk(self, dest:HCQBuffer, src, size):
491
424
  def _get_temp_buf():
492
425
  # Check if the next buffer is safe to be used (its signal has passed) and reserve it.
493
- if self.b_timeline[(self.b_next + 1) % len(self.b)] <= self.device.timeline_signal.value:
426
+ if self.b_timeline[(self.b_next + 1) % len(self.b)] <= self.dev.timeline_signal.value:
494
427
  self.b_timeline[(self.b_next + 1) % len(self.b)], self.b_next = (1 << 64), (self.b_next + 1) % len(self.b)
495
428
  return (self.b[self.b_next].va_addr, self.b_next)
496
429
  return None
497
430
 
498
- with hcq_profile(self.device, queue_type=self.device.hw_copy_queue_t, desc=f"DISK -> {self.device.dname}", enabled=PROFILE):
431
+ assert self.dev.hw_copy_queue_t is not None
432
+ with hcq_profile(self.dev, queue_type=self.dev.hw_copy_queue_t, desc=f"DISK -> {self.dev.device}", enabled=PROFILE):
499
433
  for (batch_info, dst_off, src_off, copy_size) in src.device.allocator._copyout_sharded(src, size, _get_temp_buf, seg_len=self.b[0].size):
500
- self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
501
- .copy(dest.va_addr + dst_off, batch_info[0] + src_off, copy_size) \
502
- .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
503
- self.b_timeline[batch_info[1]] = self.device.timeline_value
504
- self.device.timeline_value += 1
434
+ self.dev.hw_copy_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1) \
435
+ .copy(dest.va_addr + dst_off, batch_info[0] + src_off, copy_size) \
436
+ .signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
437
+ self.b_timeline[batch_info[1]] = self.dev.timeline_value
438
+ self.dev.timeline_value += 1
505
439
 
506
- def copyout(self, dest:memoryview, src:HCQBuffer):
507
- self.device.synchronize()
440
+ def _copyout(self, dest:memoryview, src:HCQBuffer):
441
+ self.dev.synchronize()
508
442
 
509
- with hcq_profile(self.device, queue_type=self.device.hw_copy_queue_t, desc=f"{self.device.dname} -> CPU", enabled=PROFILE):
443
+ assert self.dev.hw_copy_queue_t is not None
444
+ with hcq_profile(self.dev, queue_type=self.dev.hw_copy_queue_t, desc=f"{self.dev.device} -> CPU", enabled=PROFILE):
510
445
  for i in range(0, dest.nbytes, self.b[0].size):
511
- self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
512
- .copy(self.b[0].va_addr, src.va_addr+i, lsize:=min(self.b[0].size, dest.nbytes-i)) \
513
- .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
514
- self.device.timeline_signal.wait(self.device.timeline_value)
515
- self.device.timeline_value += 1
446
+ self.dev.hw_copy_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1) \
447
+ .copy(self.b[0].va_addr, src.va_addr+i, lsize:=min(self.b[0].size, dest.nbytes-i)) \
448
+ .signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
449
+ self.dev.timeline_signal.wait(self.dev.timeline_value)
450
+ self.dev.timeline_value += 1
516
451
 
517
452
  ctypes.memmove(from_mv(dest[i:]), self.b[0].va_addr, lsize)
518
453
 
519
- def transfer(self, dest:HCQBuffer, src:HCQBuffer, sz:int, src_dev, dest_dev):
520
- src_dev.allocator.map(dest)
454
+ def _transfer(self, dest:HCQBuffer, src:HCQBuffer, sz:int, src_dev:DeviceType, dest_dev:DeviceType):
455
+ cast(HCQAllocator, src_dev.allocator).map(dest)
521
456
 
522
- with hcq_profile(src_dev, queue_type=src_dev.hw_copy_queue_t, desc=f"{src_dev.dname} -> {dest_dev.dname}", enabled=PROFILE):
457
+ assert src_dev.hw_copy_queue_t is not None
458
+ with hcq_profile(src_dev, queue_type=src_dev.hw_copy_queue_t, desc=f"{src_dev.device} -> {dest_dev.device}", enabled=PROFILE):
523
459
  src_dev.hw_copy_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
524
460
  .wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
525
461
  .copy(dest.va_addr, src.va_addr, sz) \
@@ -531,9 +467,3 @@ class HCQAllocator(LRUAllocator): # pylint: disable=abstract-method
531
467
  .wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
532
468
  .signal(dest_dev.timeline_signal, dest_dev.timeline_value).submit(dest_dev)
533
469
  dest_dev.timeline_value += 1
534
-
535
- def map(self, buf:HCQBuffer): pass
536
-
537
- def offset(self, buf, size:int, offset:int) -> HCQBuffer:
538
- return type(buf)(va_addr=buf.va_addr + offset, size=size, **{k:v for k,v in buf.__dict__.items() if k not in ['va_addr', 'size']},
539
- **{x[0]:getattr(buf, x[0]) for x in getattr(buf, '_fields_', []) if x[0] not in ['va_addr', 'size']}, _base=buf)