tinygrad 0.10.0__py3-none-any.whl → 0.10.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. tinygrad/codegen/devectorizer.py +247 -0
  2. tinygrad/codegen/expander.py +121 -0
  3. tinygrad/codegen/kernel.py +141 -201
  4. tinygrad/codegen/linearize.py +223 -84
  5. tinygrad/codegen/lowerer.py +60 -42
  6. tinygrad/codegen/symbolic.py +476 -0
  7. tinygrad/codegen/transcendental.py +22 -13
  8. tinygrad/device.py +187 -47
  9. tinygrad/dtype.py +39 -28
  10. tinygrad/engine/jit.py +83 -65
  11. tinygrad/engine/memory.py +4 -5
  12. tinygrad/engine/multi.py +161 -0
  13. tinygrad/engine/realize.py +62 -108
  14. tinygrad/engine/schedule.py +396 -357
  15. tinygrad/engine/search.py +55 -66
  16. tinygrad/gradient.py +73 -0
  17. tinygrad/helpers.py +81 -59
  18. tinygrad/nn/__init__.py +30 -32
  19. tinygrad/nn/datasets.py +1 -2
  20. tinygrad/nn/optim.py +22 -26
  21. tinygrad/nn/state.py +91 -66
  22. tinygrad/ops.py +492 -641
  23. tinygrad/renderer/__init__.py +95 -36
  24. tinygrad/renderer/cstyle.py +99 -92
  25. tinygrad/renderer/llvmir.py +83 -34
  26. tinygrad/renderer/ptx.py +83 -99
  27. tinygrad/renderer/wgsl.py +95 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  29. tinygrad/runtime/autogen/comgr.py +2 -0
  30. tinygrad/runtime/autogen/kfd.py +4 -3
  31. tinygrad/runtime/autogen/kgsl.py +1 -1
  32. tinygrad/runtime/autogen/libc.py +404 -71
  33. tinygrad/runtime/autogen/llvm.py +11379 -0
  34. tinygrad/runtime/autogen/pci.py +1333 -0
  35. tinygrad/runtime/autogen/vfio.py +891 -0
  36. tinygrad/runtime/autogen/webgpu.py +6985 -0
  37. tinygrad/runtime/graph/cuda.py +8 -9
  38. tinygrad/runtime/graph/hcq.py +84 -79
  39. tinygrad/runtime/graph/metal.py +40 -43
  40. tinygrad/runtime/ops_amd.py +498 -334
  41. tinygrad/runtime/ops_cloud.py +34 -34
  42. tinygrad/runtime/ops_cpu.py +24 -0
  43. tinygrad/runtime/ops_cuda.py +30 -27
  44. tinygrad/runtime/ops_disk.py +62 -63
  45. tinygrad/runtime/ops_dsp.py +159 -42
  46. tinygrad/runtime/ops_gpu.py +30 -30
  47. tinygrad/runtime/ops_hip.py +29 -31
  48. tinygrad/runtime/ops_llvm.py +48 -41
  49. tinygrad/runtime/ops_metal.py +149 -113
  50. tinygrad/runtime/ops_npy.py +2 -2
  51. tinygrad/runtime/ops_nv.py +238 -273
  52. tinygrad/runtime/ops_python.py +55 -50
  53. tinygrad/runtime/ops_qcom.py +129 -157
  54. tinygrad/runtime/ops_webgpu.py +225 -0
  55. tinygrad/runtime/support/allocator.py +94 -0
  56. tinygrad/runtime/support/am/__init__.py +0 -0
  57. tinygrad/runtime/support/am/amdev.py +396 -0
  58. tinygrad/runtime/support/am/ip.py +463 -0
  59. tinygrad/runtime/support/compiler_cuda.py +4 -2
  60. tinygrad/runtime/support/elf.py +28 -4
  61. tinygrad/runtime/support/hcq.py +256 -324
  62. tinygrad/runtime/support/llvm.py +26 -0
  63. tinygrad/shape/shapetracker.py +85 -53
  64. tinygrad/shape/view.py +104 -140
  65. tinygrad/spec.py +155 -0
  66. tinygrad/tensor.py +835 -527
  67. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
  68. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
  69. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
  70. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
  71. tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
  72. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
  73. tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
  74. tinygrad/viz/index.html +544 -0
  75. tinygrad/viz/perfetto.html +178 -0
  76. tinygrad/viz/serve.py +205 -0
  77. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
  78. tinygrad-0.10.2.dist-info/RECORD +99 -0
  79. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
  80. tinygrad/codegen/uopgraph.py +0 -506
  81. tinygrad/engine/lazy.py +0 -228
  82. tinygrad/function.py +0 -212
  83. tinygrad/multi.py +0 -177
  84. tinygrad/runtime/graph/clang.py +0 -39
  85. tinygrad/runtime/ops_clang.py +0 -35
  86. tinygrad-0.10.0.dist-info/RECORD +0 -77
  87. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
  88. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
@@ -1,48 +1,106 @@
1
1
  from __future__ import annotations
2
- from typing import List, Optional, Dict, Tuple, Any, cast, Protocol, Type, Union
3
- import contextlib, decimal, statistics, random, json, atexit, time, array, ctypes
4
- from tinygrad.helpers import PROFILEPATH, PROFILE, from_mv, getenv
2
+ from typing import cast, Type, TypeVar, Generic, Any
3
+ import contextlib, decimal, statistics, time, ctypes, array, os, fcntl
4
+ from tinygrad.helpers import PROFILE, from_mv, getenv, to_mv, round_up
5
5
  from tinygrad.renderer import Renderer
6
- from tinygrad.device import BufferOptions, Allocator, Compiler, Compiled, LRUAllocator
6
+ from tinygrad.device import BufferSpec, Compiler, Compiled, LRUAllocator, ProfileRangeEvent, ProfileDeviceEvent
7
+ from tinygrad.ops import sym_infer, sint, Variable, UOp
8
+ from tinygrad.runtime.autogen import libc
7
9
 
8
- # **************** for HCQ Compatible Devices ****************
9
-
10
- def hcq_command(func):
10
+ class HWInterface:
11
11
  """
12
- Decorator for HWCommandQueue commands. Enables command indexing and stores metadata for command updates.
13
-
14
- For example:
15
- ```python
16
- @hcq_command
17
- def command_method(self, ...): ...
18
- ```
12
+ Hardware Abstraction Layer for HCQ devices. The class provides a unified interface for interacting with hardware devices.
19
13
  """
20
- def __wrapper(self, *args, **kwargs):
21
- self.cmds_offset.append(len(self.q))
22
- func(self, *args, **kwargs)
23
- self.cmds_len.append(len(self.q) - self.cmds_offset[-1])
24
- self.cmds_meta.append(func.__name__)
25
- return self
26
- return __wrapper
27
14
 
28
- class HWCommandQueue:
15
+ def __init__(self, path:str="", flags:int=os.O_RDONLY, fd:int|None=None):
16
+ self.path:str = path
17
+ self.fd:int = fd or os.open(path, flags)
18
+ def __del__(self):
19
+ if hasattr(self, 'fd'): os.close(self.fd)
20
+ def ioctl(self, request, arg): return fcntl.ioctl(self.fd, request, arg)
21
+ def mmap(self, start, sz, prot, flags, offset): return libc.mmap(start, sz, prot, flags, self.fd, offset)
22
+ def read(self, size=None, binary=False, offset=None):
23
+ if offset is not None: self.seek(offset)
24
+ with open(self.fd, "rb" if binary else "r", closefd=False) as file: return file.read(size)
25
+ def write(self, content, binary=False, offset=None):
26
+ if offset is not None: self.seek(offset)
27
+ with open(self.fd, "wb" if binary else "w", closefd=False) as file: file.write(content)
28
+ def listdir(self): return os.listdir(self.path)
29
+ def seek(self, offset): os.lseek(self.fd, offset, os.SEEK_SET)
30
+ @staticmethod
31
+ def anon_mmap(start, sz, prot, flags, offset): return libc.mmap(start, sz, prot, flags, -1, offset)
32
+ @staticmethod
33
+ def munmap(buf, sz): return libc.munmap(buf, sz)
34
+ @staticmethod
35
+ def exists(path): return os.path.exists(path)
36
+ @staticmethod
37
+ def readlink(path): return os.readlink(path)
38
+ @staticmethod
39
+ def eventfd(initval, flags=None): return HWInterface(fd=os.eventfd(initval, flags)) # type: ignore[attr-defined]
40
+
41
+ if MOCKGPU:=getenv("MOCKGPU"): from test.mockgpu.mockgpu import MockHWInterface as HWInterface # noqa: F401 # pylint: disable=unused-import
42
+
43
+ # **************** for HCQ Compatible Devices ****************
44
+
45
+ SignalType = TypeVar('SignalType', bound='HCQSignal')
46
+ DeviceType = TypeVar('DeviceType', bound='HCQCompiled')
47
+ ProgramType = TypeVar('ProgramType', bound='HCQProgram')
48
+ ArgsStateType = TypeVar('ArgsStateType', bound='HCQArgsState')
49
+ QueueType = TypeVar('QueueType', bound='HWQueue')
50
+
51
+ class BumpAllocator:
52
+ def __init__(self, size:int, base:int=0, wrap:bool=True): self.size, self.ptr, self.base, self.wrap = size, 0, base, wrap
53
+ def alloc(self, size:int, alignment:int=1) -> int:
54
+ if round_up(self.ptr, alignment) + size > self.size:
55
+ if not self.wrap: raise RuntimeError("Out of memory")
56
+ self.ptr = 0
57
+ self.ptr = (res:=round_up(self.ptr, alignment)) + size
58
+ return res + self.base
59
+
60
+ class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
29
61
  """
30
62
  A base class for hardware command queues in the HCQ (Hardware Command Queue) API.
31
- Both compute and copy queues should have the following commands implemented.
32
63
  """
33
64
 
34
- def __init__(self): self.q, self.binded_device, self.cmds_offset, self.cmds_len, self.cmds_meta = [], None, [], [], []
35
- def __len__(self): return len(self.cmds_offset)
36
- def _patch(self, cmd_idx, offset, data): self.q[(st:=self.cmds_offset[cmd_idx]+offset):st+len(data)] = array.array('I', data)
37
- def _cur_cmd_idx(self) -> int:
65
+ def __init__(self):
66
+ self._q:Any = []
67
+ self.binded_device:DeviceType|None = None
68
+ self.q_sints:list[tuple[int, int]] = []
69
+ self.mv_sints:list[tuple[memoryview, int, int, int|None]] = []
70
+ self.syms:list[sint] = []
71
+ self._prev_resolved_syms:list[int|None] = []
72
+
73
+ def _new_sym(self, sym:sint) -> int:
74
+ if sym not in self.syms:
75
+ self.syms.append(sym)
76
+ self._prev_resolved_syms.append(None)
77
+ return self.syms.index(sym)
78
+
79
+ def q(self, *values):
38
80
  """
39
- Returns the index of the command currently being enqueued.
40
- Should be called only within functions that enqueue commands and are decorated with `@hcq_command`.
81
+ Enqueues values in the queue.
82
+
83
+ Args:
84
+ values: The values to enqueue in the queue.
41
85
  """
42
- return len(self) - 1
43
86
 
44
- @hcq_command
45
- def signal(self, signal:HCQSignal, value:int):
87
+ for v in values:
88
+ if isinstance(v, UOp):
89
+ self.q_sints.append((len(self._q), self._new_sym(v)))
90
+ self._q.append(0xbadc0ded)
91
+ else: self._q.append(v)
92
+
93
+ # *** common commands ***
94
+
95
+ def timestamp(self, signal:SignalType):
96
+ """
97
+ Enqueues a timestamp command which records the current time in a signal after all previously enqueued commands are completed.
98
+
99
+ Args:
100
+ signal: The signal to store the timestamp
101
+ """
102
+
103
+ def signal(self, signal:SignalType, value:sint):
46
104
  """
47
105
  Enqueues a signal command which sets the signal to the given value, ensuring all previous operations are completed.
48
106
 
@@ -50,11 +108,8 @@ class HWCommandQueue:
50
108
  signal: The signal to set
51
109
  value: The value to set the signal to
52
110
  """
53
- self._signal(signal, value)
54
- def _signal(self, signal:HCQSignal, value:int): raise NotImplementedError("backend should overload this function")
55
111
 
56
- @hcq_command
57
- def wait(self, signal:HCQSignal, value:int):
112
+ def wait(self, signal:SignalType, value:sint):
58
113
  """
59
114
  Enqueues a wait command which halts execution until the signal is greater than or equal to a specific value.
60
115
 
@@ -62,49 +117,40 @@ class HWCommandQueue:
62
117
  signal: The signal to wait on
63
118
  value: The value to wait for
64
119
  """
65
- self._wait(signal, value)
66
- def _wait(self, signal, value): raise NotImplementedError("backend should overload this function")
67
120
 
68
- @hcq_command
69
- def timestamp(self, signal:HCQSignal):
70
- """
71
- Enqueues a timestamp command which records the current time in a signal after all previously enqueued commands are completed.
121
+ # *** commands for compute queues ***
72
122
 
73
- Args:
74
- signal: The signal to store the timestamp
123
+ def memory_barrier(self):
124
+ """
125
+ Enqueues a memory barrier command to ensure memory coherence between agents. Only on compute queues.
75
126
  """
76
- self._timestamp(signal)
77
- def _timestamp(self, signal): raise NotImplementedError("backend should overload this function")
78
127
 
79
- def update_signal(self, cmd_idx:int, signal:Optional[Any]=None, value:Optional[int]=None):
128
+ def exec(self, prg:ProgramType, args_state:ArgsStateType, global_size:tuple[sint, ...], local_size:tuple[sint, ...]):
80
129
  """
81
- Updates a previously queued signal command.
130
+ Enqueues an execution command for a kernel program. Only on compute queues.
82
131
 
83
132
  Args:
84
- cmd_idx: Index of the signal command to update
85
- signal: New signal to set (if None, keeps the original)
86
- value: New value to set (if None, keeps the original)
133
+ prg: The program to execute
134
+ args_state: The args state to execute program with
135
+ global_size: The global work size
136
+ local_size: The local work size
87
137
  """
88
- if self.cmds_meta[cmd_idx] != "signal": raise RuntimeError("called update_signal not on a signal command")
89
- self._update_signal(cmd_idx, signal, value)
90
- return self
91
- def _update_signal(self, cmd_idx:int, signal:Optional[Any], value:Optional[int]): raise NotImplementedError("backend should overload this function")
92
138
 
93
- def update_wait(self, cmd_idx:int, signal:Optional[Any]=None, value:Optional[int]=None):
139
+ # *** commands for copy queues ***
140
+
141
+ def copy(self, dest:sint, src:sint, copy_size:int):
94
142
  """
95
- Updates a previously queued wait command.
143
+ Enqueues a copy command to transfer data. Only on copy queues.
96
144
 
97
145
  Args:
98
- cmd_idx: Index of the wait command to update
99
- signal: New signal to wait on (if None, keeps the original)
100
- value: New value to wait for (if None, keeps the original)
146
+ dest: The destination of the copy
147
+ src: The source of the copy
148
+ copy_size: The size of data to copy
101
149
  """
102
- if self.cmds_meta[cmd_idx] != "wait": raise RuntimeError("called update_wait not on a wait command")
103
- self._update_wait(cmd_idx, signal, value)
104
- return self
105
- def _update_wait(self, cmd_idx:int, signal:Optional[Any], value:Optional[int]): raise NotImplementedError("backend should overload this function")
106
150
 
107
- def bind(self, device:HCQCompiled):
151
+ # *** submit and bind commands ***
152
+
153
+ def bind(self, dev:DeviceType):
108
154
  """
109
155
  Associates the queue with a specific device for optimized execution.
110
156
 
@@ -112,99 +158,65 @@ class HWCommandQueue:
112
158
  the need to copy queues into the device, thereby enhancing performance.
113
159
 
114
160
  Args:
115
- device: The target device for queue optimization.
161
+ dev: The target device for queue optimization.
116
162
 
117
163
  Note:
118
164
  Implementing this method is optional but recommended for performance gains.
119
165
  """
120
166
 
121
- def submit(self, device:HCQCompiled):
122
- """
123
- Submits the command queue to a specific device for execution.
167
+ def bind_args_state(self, args_state:ArgsStateType):
168
+ for vals, ptr, fmt in args_state.bind_data: self.bind_sints_to_ptr(*vals, ptr=ptr, fmt=fmt)
124
169
 
125
- Args:
126
- device: The device to submit the queue to
127
- """
128
- if self.q: self._submit(device)
129
- return self
130
- def _submit(self, device:HCQCompiled): raise NotImplementedError("backend should overload this function")
170
+ def bind_sints(self, *vals:sint, struct:ctypes.Structure, start_field:str, fmt, mask:int|None=None):
171
+ self.bind_sints_to_ptr(*vals, ptr=ctypes.addressof(struct) + getattr(type(struct), start_field).offset, fmt=fmt, mask=mask)
131
172
 
132
- class HWComputeQueue(HWCommandQueue):
133
- @hcq_command
134
- def memory_barrier(self):
135
- """
136
- Enqueues a memory barrier command to ensure memory coherence between agents.
137
- """
138
- self._memory_barrier()
139
- def _memory_barrier(self): pass
173
+ def bind_sints_to_ptr(self, *vals:sint, ptr:int, fmt, mask:int|None=None):
174
+ mv = to_mv(ptr, 8*len(vals)).cast(fmt)
175
+ for i, val in enumerate(vals):
176
+ if isinstance(val, int): mv[i] = val if mask is None else ((mv[i] & ~mask) | val)
177
+ else: self.mv_sints.append((mv, i, self._new_sym(val), mask))
140
178
 
141
- @hcq_command
142
- def exec(self, prg:HCQProgram, args_state:HCQArgsState, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int]):
143
- """
144
- Enqueues an execution command for a kernel program.
179
+ def _apply_var_vals(self, var_vals:dict[Variable, int]):
180
+ resolved_syms = [sym_infer(sym, var_vals) for sym in self.syms]
145
181
 
146
- Args:
147
- prg: The program to execute
148
- args_state: The args state to execute program with
149
- global_size: The global work size
150
- local_size: The local work size
151
- """
152
- self._exec(prg, args_state, global_size, local_size)
153
- def _exec(self, prg, args_state, global_size, local_size): raise NotImplementedError("backend should overload this function")
182
+ for off, sym_idx in self.q_sints:
183
+ if self._prev_resolved_syms[sym_idx] == resolved_syms[sym_idx]: continue
184
+ self._q[off] = resolved_syms[sym_idx]
154
185
 
155
- def update_exec(self, cmd_idx:int, global_size:Optional[Tuple[int,int,int]]=None, local_size:Optional[Tuple[int,int,int]]=None):
156
- """
157
- Updates a previously queued execution command.
186
+ for mv, off, sym_idx, mask in self.mv_sints:
187
+ if self._prev_resolved_syms[sym_idx] == resolved_syms[sym_idx]: continue
188
+ mv[off] = resolved_syms[sym_idx] if mask is None else ((mv[off] & ~mask) | resolved_syms[sym_idx])
158
189
 
159
- Args:
160
- cmd_idx: Index of the execution command to update
161
- global_size: New global work size (if None, keeps the original)
162
- local_size: New local work size (if None, keeps the original)
163
- """
164
- if self.cmds_meta[cmd_idx] != "exec": raise RuntimeError("called update_exec not on an exec command")
165
- self._update_exec(cmd_idx, global_size, local_size)
166
- return self
167
- def _update_exec(self, cmd_idx, global_size, local_size): raise NotImplementedError("backend should overload this function")
190
+ self._prev_resolved_syms = cast(list[int|None], resolved_syms)
168
191
 
169
- class HWCopyQueue(HWCommandQueue):
170
- @hcq_command
171
- def copy(self, dest:HCQBuffer, src:HCQBuffer, copy_size:int):
192
+ def submit(self, dev:DeviceType, var_vals:dict[Variable, int]|None=None):
172
193
  """
173
- Enqueues a copy command to transfer data.
194
+ Submits the command queue to a specific device for execution.
174
195
 
175
196
  Args:
176
- dest: The destination of the copy
177
- src: The source of the copy
178
- copy_size: The size of data to copy
197
+ dev: The device to submit the queue to
179
198
  """
180
- self._copy(dest, src, copy_size)
181
- def _copy(self, dest:HCQBuffer, src:HCQBuffer, copy_size:int): raise NotImplementedError("backend should overload this function")
182
199
 
183
- def update_copy(self, cmd_idx:int, dest:Optional[HCQBuffer]=None, src:Optional[HCQBuffer]=None):
184
- """
185
- Updates a previously queued copy command.
186
-
187
- Args:
188
- cmd_idx: Index of the copy command to update
189
- dest: New destination of the copy (if None, keeps the original)
190
- src: New source of the copy (if None, keeps the original)
191
- """
192
- if self.cmds_meta[cmd_idx] != "copy": raise RuntimeError("called update_copy not on an copy command")
193
- self._update_copy(cmd_idx, dest, src)
200
+ if var_vals is not None: self._apply_var_vals(var_vals)
201
+ self._submit(dev)
194
202
  return self
195
- def _update_copy(self, cmd_idx, dest, src): raise NotImplementedError("backend should overload this function")
203
+ def _submit(self, dev:DeviceType): raise NotImplementedError("need _submit")
196
204
 
197
- class HCQSignal:
198
- def __init__(self, value:int=0, is_timeline:bool=False): self._set_value(value)
205
+ class HCQSignal(Generic[DeviceType]):
206
+ def __init__(self, base_addr:sint=0, value:int=0, timeline_for_device:DeviceType|None=None, timestamp_divider=1, value_off=0, timestamp_off=8):
207
+ self.base_addr, self.value_addr, self.timestamp_addr = base_addr, base_addr+value_off, base_addr+timestamp_off
208
+ self.timestamp_divider:decimal.Decimal = decimal.Decimal(timestamp_divider)
209
+ self.timeline_for_device:DeviceType|None = timeline_for_device
210
+
211
+ if isinstance(base_addr, int):
212
+ self.value_mv, self.timestamp_mv = to_mv(self.value_addr, 8).cast('Q'), to_mv(self.timestamp_addr, 8).cast('Q')
213
+ self.value_mv[0] = value
199
214
 
200
215
  @property
201
- def value(self) -> int: return self._get_value()
216
+ def value(self) -> int: return self.value_mv[0]
202
217
 
203
218
  @value.setter
204
- def value(self, new_value:int): self._set_value(new_value)
205
-
206
- def _get_value(self) -> int: raise NotImplementedError("_get_value() method must be implemented")
207
- def _set_value(self, new_value:int): raise NotImplementedError("_set_value() method must be implemented")
219
+ def value(self, new_value:int): self.value_mv[0] = new_value
208
220
 
209
221
  @property
210
222
  def timestamp(self) -> decimal.Decimal:
@@ -216,8 +228,12 @@ class HCQSignal:
216
228
  Returns:
217
229
  The timestamp in microseconds.
218
230
  """
219
- return self._get_timestamp()
220
- def _get_timestamp(self) -> decimal.Decimal: raise NotImplementedError("_get_timestamp() method must be implemented")
231
+ return self.timestamp_mv[0] / self.timestamp_divider
232
+
233
+ def _sleep(self, time_spent_waiting_ms:int):
234
+ """
235
+ Optional function which can implement sleep functionality for the signal.
236
+ """
221
237
 
222
238
  def wait(self, value:int, timeout:int=getenv("HCQDEV_WAIT_TIMEOUT_MS", 30000)):
223
239
  """
@@ -227,17 +243,18 @@ class HCQSignal:
227
243
  value: The value to wait for.
228
244
  timeout: Maximum time to wait in milliseconds. Defaults to 10s.
229
245
  """
230
- start_time = time.time() * 1000
231
- while time.time() * 1000 - start_time < timeout:
232
- if self.value >= value: return
233
- raise RuntimeError(f"Wait timeout: {timeout} ms! (the signal is not set to {value}, but {self.value})")
246
+ start_time = int(time.perf_counter() * 1000)
247
+ while self.value < value and (time_spent:=int(time.perf_counter() * 1000) - start_time) < timeout:
248
+ self._sleep(time_spent)
249
+ if self.value < value: raise RuntimeError(f"Wait timeout: {timeout} ms! (the signal is not set to {value}, but {self.value})")
234
250
 
235
251
  @contextlib.contextmanager
236
- def hcq_profile(dev, enabled, desc, queue_type=None, queue=None):
252
+ def hcq_profile(dev:HCQCompiled, enabled, desc, queue_type:Type[HWQueue]|None=None, queue:HWQueue|None=None):
237
253
  st, en = (dev.signal_t(), dev.signal_t()) if enabled else (None, None)
238
254
 
239
255
  if enabled and queue is not None: queue.timestamp(st)
240
256
  elif enabled:
257
+ assert queue_type is not None
241
258
  queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(st).signal(dev.timeline_signal, dev.timeline_value).submit(dev)
242
259
  dev.timeline_value += 1
243
260
 
@@ -245,21 +262,33 @@ def hcq_profile(dev, enabled, desc, queue_type=None, queue=None):
245
262
  finally:
246
263
  if enabled and queue is not None: queue.timestamp(en)
247
264
  elif enabled:
265
+ assert queue_type is not None
248
266
  queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(en).signal(dev.timeline_signal, dev.timeline_value).submit(dev)
249
267
  dev.timeline_value += 1
250
268
 
251
- if enabled and PROFILE: dev.sig_prof_records.append((st, en, desc, queue_type is dev.hw_copy_queue_t))
269
+ if enabled and PROFILE: dev.sig_prof_records.append((cast(HCQSignal, st), cast(HCQSignal, en), desc, queue_type is dev.hw_copy_queue_t))
270
+
271
+ class HCQArgsState(Generic[ProgramType]):
272
+ def __init__(self, ptr:int, prg:ProgramType, bufs:tuple[HCQBuffer, ...], vals:tuple[sint, ...]=()):
273
+ self.ptr, self.prg = ptr, prg
274
+ self.bind_data:list[tuple[tuple[sint, ...], int, str]] = []
275
+
276
+ def bind_sints_to_ptr(self, *vals:sint, ptr:int, fmt): self.bind_data.append((vals, ptr, fmt))
277
+
278
+ class CLikeArgsState(HCQArgsState[ProgramType]):
279
+ def __init__(self, ptr:int, prg:ProgramType, bufs:tuple[HCQBuffer, ...], vals:tuple[sint, ...]=(), prefix:list[int]|None=None):
280
+ super().__init__(ptr, prg, bufs, vals=vals)
281
+
282
+ if prefix is not None: to_mv(self.ptr, len(prefix) * 4).cast('I')[:] = array.array('I', prefix)
252
283
 
253
- class HCQArgsState:
254
- def __init__(self, ptr:int, prg:HCQProgram, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=()): self.ptr, self.prg = ptr, prg
255
- def update_buffer(self, index:int, buf:HCQBuffer): raise NotImplementedError("need update_buffer")
256
- def update_var(self, index:int, val:int): raise NotImplementedError("need update_var")
284
+ self.bind_sints_to_ptr(*[b.va_addr for b in bufs], ptr=self.ptr + len(prefix or []) * 4, fmt='Q')
285
+ self.bind_sints_to_ptr(*vals, ptr=self.ptr + len(prefix or []) * 4 + len(bufs) * 8, fmt='I')
257
286
 
258
- class HCQProgram:
259
- def __init__(self, args_state_t:Type[HCQArgsState], device:HCQCompiled, name:str, kernargs_alloc_size:int):
260
- self.args_state_t, self.device, self.name, self.kernargs_alloc_size = args_state_t, device, name, kernargs_alloc_size
287
+ class HCQProgram(Generic[DeviceType]):
288
+ def __init__(self, args_state_t:Type[HCQArgsState], dev:DeviceType, name:str, kernargs_alloc_size:int):
289
+ self.args_state_t, self.dev, self.name, self.kernargs_alloc_size = args_state_t, dev, name, kernargs_alloc_size
261
290
 
262
- def fill_kernargs(self, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=(), kernargs_ptr:Optional[int]=None) -> HCQArgsState:
291
+ def fill_kernargs(self, bufs:tuple[HCQBuffer, ...], vals:tuple[int, ...]=(), kernargs_ptr:int|None=None) -> HCQArgsState:
263
292
  """
264
293
  Fills arguments for the kernel, optionally allocating space from the device if `kernargs_ptr` is not provided.
265
294
  Args:
@@ -269,10 +298,10 @@ class HCQProgram:
269
298
  Returns:
270
299
  Arguments state with the given buffers and values set for the program.
271
300
  """
272
- return self.args_state_t(kernargs_ptr or self.device._alloc_kernargs(self.kernargs_alloc_size), self, bufs, vals=vals)
301
+ return self.args_state_t(kernargs_ptr or self.dev.kernargs_allocator.alloc(self.kernargs_alloc_size), self, bufs, vals=vals)
273
302
 
274
- def __call__(self, *bufs:HCQBuffer, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1),
275
- vals:Tuple[int, ...]=(), wait:bool=False) -> Optional[float]:
303
+ def __call__(self, *bufs:HCQBuffer, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1),
304
+ vals:tuple[int, ...]=(), wait:bool=False) -> float|None:
276
305
  """
277
306
  Enqueues the program for execution with the given arguments and dimensions.
278
307
 
@@ -288,103 +317,52 @@ class HCQProgram:
288
317
  """
289
318
 
290
319
  kernargs = self.fill_kernargs(bufs, vals)
291
- q = self.device.hw_compute_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1).memory_barrier()
320
+ q = self.dev.hw_compute_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1).memory_barrier()
292
321
 
293
- with hcq_profile(self.device, queue=q, desc=self.name, enabled=wait or PROFILE) as (sig_st, sig_en):
322
+ with hcq_profile(self.dev, queue=q, desc=self.name, enabled=wait or PROFILE) as (sig_st, sig_en):
294
323
  q.exec(self, kernargs, global_size, local_size)
295
324
 
296
- q.signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
297
- self.device.timeline_value += 1
325
+ q.signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
326
+ self.dev.timeline_value += 1
298
327
 
299
- if wait: self.device.synchronize()
328
+ if wait: self.dev.synchronize()
300
329
  return (float(sig_en.timestamp - sig_st.timestamp) / 1e6) if wait else None
301
330
 
302
- class ProfileLogger:
303
- writers: int = 0
304
- mjson: List[Dict] = []
305
- actors: Dict[Union[str, Tuple[str, str]], int] = {}
306
-
307
- def __init__(self): self.events, self.deps, ProfileLogger.writers = [], [], ProfileLogger.writers + 1
308
-
309
- def add_event(self, ev_name, ev_start, ev_end, actor, subactor=None, args=None): self.events += [(ev_name, ev_start, ev_end, actor, subactor, args)]
310
-
311
- def _ensure_actor(self, actor_name, subactor_name):
312
- if actor_name not in self.actors:
313
- self.actors[actor_name] = (pid:=len(self.actors))
314
- self.mjson.append({"name": "process_name", "ph": "M", "pid": pid, "args": {"name": actor_name}})
315
-
316
- if (subactor_key:=(actor_name,subactor_name)) not in self.actors:
317
- self.actors[subactor_key] = (tid:=len(self.actors))
318
- self.mjson.append({"name": "thread_name", "ph": "M", "pid": self.actors[actor_name], "tid":tid, "args": {"name": subactor_name}})
319
-
320
- return self.actors[actor_name], self.actors.get(subactor_key, -1)
321
-
322
- def __del__(self):
323
- # perfetto json docs: https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview
324
- for name, st, et, actor_name, subactor_name, args in self.events:
325
- pid, tid = self._ensure_actor(actor_name,subactor_name)
326
- args = {k: (v if v.__class__ is str else v(et-st)) for k, v in args.items()} if args is not None else None
327
- self.mjson.append({"name": name, "ph": "X", "pid": pid, "tid": tid, "ts": st, "dur": et-st, "args": args})
328
-
329
- for en,st,dep_actor_name,dep_subactor_name,actor_name,subactor_name in self.deps:
330
- dep_pid, dep_tid = self._ensure_actor(dep_actor_name,dep_subactor_name)
331
- pid, tid = self._ensure_actor(actor_name,subactor_name)
332
- self.mjson.append({"ph": "s", "pid": dep_pid, "tid": dep_tid, "id": len(self.mjson), "ts": en, "bp": "e"})
333
- self.mjson.append({"ph": "f", "pid": pid, "tid": tid, "id": len(self.mjson)-1, "ts": st, "bp": "e"})
334
-
335
- ProfileLogger.writers -= 1
336
- if ProfileLogger.writers == 0 and len(self.mjson) > 0:
337
- with open(PROFILEPATH.value, "w") as f: f.write(json.dumps({"traceEvents": self.mjson}))
338
- print(f"Saved profile to {PROFILEPATH.value}. Use https://ui.perfetto.dev/ to open it.")
339
-
340
- class HCQCompiled(Compiled):
331
+ class HCQCompiled(Compiled, Generic[SignalType]):
341
332
  """
342
333
  A base class for devices compatible with the HCQ (Hardware Command Queue) API.
343
334
  """
344
- devices: List[HCQCompiled] = []
345
- gpu2cpu_copy_time_diff: decimal.Decimal = decimal.Decimal('nan')
346
- gpu2cpu_compute_time_diff: decimal.Decimal = decimal.Decimal('nan')
335
+ devices: list[HCQCompiled] = []
347
336
 
348
- def __init__(self, device:str, allocator:Allocator, renderer:Renderer, compiler:Compiler, runtime, signal_t:Type[HCQSignal],
349
- comp_queue_t:Type[HWComputeQueue], copy_queue_t:Optional[Type[HWCopyQueue]]):
337
+ def __init__(self, device:str, allocator:HCQAllocatorBase, renderer:Renderer, compiler:Compiler, runtime, signal_t:Type[SignalType],
338
+ comp_queue_t:Type[HWQueue], copy_queue_t:Type[HWQueue]|None):
339
+ self.device_id:int = int(device.split(":")[1]) if ":" in device else 0
350
340
  self.signal_t, self.hw_compute_queue_t, self.hw_copy_queue_t = signal_t, comp_queue_t, copy_queue_t
351
341
  self.timeline_value:int = 1
352
- self.timeline_signal, self._shadow_timeline_signal = self.signal_t(0, is_timeline=True), self.signal_t(0, is_timeline=True)
353
- self.sig_prof_records:List[Tuple[HCQSignal, HCQSignal, str, bool]] = []
354
- self.raw_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, str, bool, Optional[Dict]]] = []
355
- self.dep_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, HCQCompiled, bool, decimal.Decimal, decimal.Decimal, HCQCompiled, bool]] = []
356
- if PROFILE: self._prof_setup()
342
+ self.timeline_signal:SignalType = self.signal_t(value=0, timeline_for_device=self)
343
+ self._shadow_timeline_signal:SignalType = self.signal_t(value=0, timeline_for_device=self)
344
+ self.sig_prof_records:list[tuple[HCQSignal, HCQSignal, str, bool]] = []
357
345
 
358
346
  from tinygrad.runtime.graph.hcq import HCQGraph
359
347
  super().__init__(device, allocator, renderer, compiler, runtime, HCQGraph)
360
348
 
361
- self.kernargs_page:HCQBuffer = self.allocator.alloc(16 << 20, BufferOptions(cpu_access=True))
362
- self.kernargs_ptr:int = self.kernargs_page.va_addr
349
+ self.kernargs_page:HCQBuffer = self.allocator.alloc(16 << 20, BufferSpec(cpu_access=True))
350
+ self.kernargs_allocator:BumpAllocator = BumpAllocator(self.kernargs_page.size, base=cast(int, self.kernargs_page.va_addr), wrap=True)
363
351
  self.devices.append(self)
364
352
 
365
353
  def synchronize(self):
366
- try: self.timeline_signal.wait(self.timeline_value - 1) if not hasattr(self, '_syncdev') else self._syncdev()
354
+ try: self.timeline_signal.wait(self.timeline_value - 1)
367
355
  except RuntimeError as e:
368
356
  if hasattr(self, 'on_device_hang'): self.on_device_hang()
369
357
  else: raise e
370
358
 
371
359
  if self.timeline_value > (1 << 31): self._wrap_timeline_signal()
372
360
  if PROFILE:
373
- self.raw_prof_records += [(st.timestamp, en.timestamp, name, is_cp, None) for st, en, name, is_cp in self.sig_prof_records]
361
+ Compiled.profile_events += [ProfileRangeEvent(self.device, name, st.timestamp, en.timestamp, cp) for st,en,name,cp in self.sig_prof_records]
374
362
  self.sig_prof_records = []
375
363
 
376
- def _alloc_kernargs(self, alloc_size:int) -> int:
377
- """
378
- Allocates space for arguments passed to the kernel.
379
- """
380
- if self.kernargs_ptr >= (self.kernargs_page.va_addr + self.kernargs_page.size - alloc_size): self.kernargs_ptr = self.kernargs_page.va_addr
381
- self.kernargs_ptr = (res:=self.kernargs_ptr) + alloc_size
382
- return res
383
-
384
- def _ensure_shared_time_base(self):
385
- if not self.gpu2cpu_compute_time_diff.is_nan(): return
386
-
387
- def _sync_cpu_queue(d, q_t):
364
+ def _at_profile_finalize(self):
365
+ def _sync(d:HCQCompiled, q_t:Type[HWQueue]):
388
366
  q_t().timestamp(d.timeline_signal).signal(d.timeline_signal, d.timeline_value).submit(d)
389
367
  d.timeline_value += 1
390
368
  st = time.perf_counter_ns()
@@ -392,134 +370,94 @@ class HCQCompiled(Compiled):
392
370
  et = time.perf_counter_ns()
393
371
  return (decimal.Decimal(et+st) / 2000) - d.timeline_signal.timestamp
394
372
 
395
- # randomly sample the timing from GPU to CPU
396
- choices: List = [(d, d.hw_compute_queue_t, []) for d in self.devices]
397
- choices += [(d, d.hw_copy_queue_t, []) for d in self.devices if d.hw_copy_queue_t is not None]
398
- for _ in range(100*len(self.devices)):
399
- d,q,l = random.choice(choices)
400
- l.append(_sync_cpu_queue(d,q))
401
- for d,q,l in choices:
402
- if q == d.hw_compute_queue_t: d.gpu2cpu_compute_time_diff = statistics.median(l)
403
- if q == d.hw_copy_queue_t: d.gpu2cpu_copy_time_diff = statistics.median(l)
404
-
405
- def _sync_gpu_to_gpu_queue(d1, d2, q1_t, q2_t):
406
- q1_t().signal(d1.timeline_signal, d1.timeline_value).wait(d2.timeline_signal, d2.timeline_value) \
407
- .timestamp(d1.timeline_signal).signal(d1.timeline_signal, d1.timeline_value+1).submit(d1)
408
- q2_t().signal(d2.timeline_signal, d2.timeline_value).wait(d1.timeline_signal, d1.timeline_value) \
409
- .timestamp(d2.timeline_signal).signal(d2.timeline_signal, d2.timeline_value+1).submit(d2)
410
- d1.timeline_value += 2
411
- d2.timeline_value += 2
412
- d1.timeline_signal.wait(d1.timeline_value - 1)
413
- d2.timeline_signal.wait(d2.timeline_value - 1)
414
- return d2.timeline_signal.timestamp - d1.timeline_signal.timestamp
415
-
416
- # then test it by timing the GPU to GPU times
417
- jitter_matrix = [[float('nan')]*len(self.devices) for _ in range(len(self.devices))]
418
- for i1, d1 in enumerate(self.devices):
419
- for i2, d2 in enumerate(self.devices):
420
- if d1 == d2: continue
421
- d1_to_d2 = statistics.median(_sync_gpu_to_gpu_queue(d1, d2, d1.hw_compute_queue_t, d2.hw_compute_queue_t) - \
422
- _sync_gpu_to_gpu_queue(d2, d1, d2.hw_compute_queue_t, d1.hw_compute_queue_t) for _ in range(20)) / 2
423
- jitter_matrix[i1][i2] = d1_to_d2 - (d1.gpu2cpu_compute_time_diff - d2.gpu2cpu_compute_time_diff)
424
- print("pairwise clock jitter matrix (us):\n" + '\n'.join([''.join([f'{float(item):8.3f}' for item in row]) for row in jitter_matrix]))
425
-
426
- def _gpu2cpu_time(self, gpu_time:decimal.Decimal, is_copy:bool) -> float:
427
- """
428
- Translates local gpu time (timestamp) into global cpu time.
429
- """
430
- self._ensure_shared_time_base()
431
- return float(gpu_time + (self.gpu2cpu_copy_time_diff if is_copy else self.gpu2cpu_compute_time_diff))
432
-
433
- def _prof_setup(self):
434
- if hasattr(self, 'profile_logger'): return
435
- atexit.register(self._prof_finalize)
436
- self.profile_logger = ProfileLogger()
437
-
438
- def _prof_finalize(self):
439
- qname = ["COMPUTE", "DMA"]
440
-
441
- # Sync to be sure all events on the device are recorded.
442
- self.synchronize()
443
-
444
- for st, en, name, is_cp, args in self.raw_prof_records:
445
- self.profile_logger.events += [(name, self._gpu2cpu_time(st, is_cp), self._gpu2cpu_time(en, is_cp), self.dname, qname[is_cp], args)]
446
- for a_st, a_en, a_dev, a_is_copy, b_st, b_en, b_dev, b_is_copy in self.dep_prof_records:
447
- # Perfetto connects nodes based on timing data, ensuring every choice is valid by averaging times to a midpoint.
448
- a_tm, b_tm = a_dev._gpu2cpu_time((a_st+a_en)/decimal.Decimal(2), a_is_copy), b_dev._gpu2cpu_time((b_st+b_en)/decimal.Decimal(2), b_is_copy)
449
- self.profile_logger.deps += [(a_tm, b_tm, a_dev.dname, qname[a_is_copy], b_dev.dname, qname[b_is_copy])]
450
- self.raw_prof_records, self.dep_prof_records = [], []
451
-
452
- # Remove the logger, this flushes all data written by the device.
453
- del self.profile_logger
373
+ gpu2cpu_compute_time_diff = statistics.median([_sync(self, self.hw_compute_queue_t) for _ in range(40)])
374
+ if self.hw_copy_queue_t is None: gpu2cpu_copy_time_diff = decimal.Decimal(0)
375
+ else: gpu2cpu_copy_time_diff = statistics.median([_sync(self, self.hw_copy_queue_t) for _ in range(40)])
376
+ Compiled.profile_events += [ProfileDeviceEvent(self.device, gpu2cpu_compute_time_diff, gpu2cpu_copy_time_diff)]
454
377
 
455
378
  def _wrap_timeline_signal(self):
456
379
  self.timeline_signal, self._shadow_timeline_signal, self.timeline_value = self._shadow_timeline_signal, self.timeline_signal, 1
457
380
  self.timeline_signal.value = 0
458
- cast(HCQAllocator, self.allocator).b_timeline = [0] * len(cast(HCQAllocator, self.allocator).b)
381
+ cast(HCQAllocatorBase, self.allocator).b_timeline = [0] * len(cast(HCQAllocatorBase, self.allocator).b)
459
382
 
460
- # Protocol for hcq compatible allocators for allocated buffers to contain VA address and it's size.
461
- class HCQBuffer(Protocol): va_addr:int; size:int # noqa: E702
383
+ def _realloc(self, oldbuf:HCQBuffer|None, new_size:int, options:BufferSpec|None=None) -> tuple[HCQBuffer, bool]:
384
+ if oldbuf is not None: self.allocator.free(oldbuf, oldbuf.size, options=options)
385
+ try: buf, realloced = self.allocator.alloc(new_size, options=options), True
386
+ except MemoryError: buf, realloced = self.allocator.alloc(oldbuf.size if oldbuf is not None else new_size, options=options), False
387
+ return buf, realloced
462
388
 
463
- class HCQAllocator(LRUAllocator): # pylint: disable=abstract-method
389
+ class HCQBuffer:
390
+ def __init__(self, va_addr:sint, size:int, texture_info:Any=None, meta:Any=None, _base:HCQBuffer|None=None):
391
+ self.va_addr, self.size, self.texture_info, self.meta, self._base = va_addr, size, texture_info, meta, _base
392
+
393
+ class HCQAllocatorBase(LRUAllocator, Generic[DeviceType]):
464
394
  """
465
395
  A base allocator class compatible with the HCQ (Hardware Command Queue) API.
466
396
 
467
- This class implements basic copy operations following the HCQ API, utilizing both `HWComputeQueue` and `HWCopyQueue`.
397
+ This class implements basic copy operations following the HCQ API, utilizing both types of `HWQueue`.
468
398
  """
469
399
 
470
- def __init__(self, device:HCQCompiled, batch_size:int=(2 << 20), batch_cnt:int=32):
471
- self.device:Any = device
472
- self.b = [self._alloc(batch_size, BufferOptions(host=True)) for _ in range(batch_cnt)]
400
+ def __init__(self, dev:DeviceType, batch_size:int=(2 << 20), batch_cnt:int=32):
401
+ self.dev:DeviceType = dev
402
+ self.b = [self._alloc(batch_size, BufferSpec(host=True)) for _ in range(batch_cnt)]
473
403
  self.b_timeline, self.b_next = [0] * len(self.b), 0
474
404
  super().__init__()
475
405
 
476
- def _alloc(self, size:int, options:BufferOptions) -> HCQBuffer: raise NotImplementedError("need hcq compat alloc")
406
+ def map(self, buf:HCQBuffer): pass
407
+
408
+ def _offset(self, buf, size:int, offset:int) -> HCQBuffer:
409
+ return HCQBuffer(va_addr=buf.va_addr + offset, size=size, texture_info=buf.texture_info, meta=buf.meta, _base=buf._base or buf)
477
410
 
478
- def copyin(self, dest:HCQBuffer, src:memoryview):
479
- with hcq_profile(self.device, queue_type=self.device.hw_copy_queue_t, desc=f"CPU -> {self.device.dname}", enabled=PROFILE):
411
+ class HCQAllocator(HCQAllocatorBase, Generic[DeviceType]):
412
+ def _copyin(self, dest:HCQBuffer, src:memoryview):
413
+ assert self.dev.hw_copy_queue_t is not None
414
+ with hcq_profile(self.dev, queue_type=self.dev.hw_copy_queue_t, desc=f"CPU -> {self.dev.device}", enabled=PROFILE):
480
415
  for i in range(0, src.nbytes, self.b[0].size):
481
416
  self.b_next = (self.b_next + 1) % len(self.b)
482
- self.device.timeline_signal.wait(self.b_timeline[self.b_next])
417
+ self.dev.timeline_signal.wait(self.b_timeline[self.b_next])
483
418
  ctypes.memmove(self.b[self.b_next].va_addr, from_mv(src[i:]), lsize:=min(self.b[self.b_next].size, src.nbytes-i))
484
- self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
485
- .copy(dest.va_addr+i, self.b[self.b_next].va_addr, lsize) \
486
- .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
487
- self.b_timeline[self.b_next] = self.device.timeline_value
488
- self.device.timeline_value += 1
419
+ self.dev.hw_copy_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1) \
420
+ .copy(dest.va_addr+i, self.b[self.b_next].va_addr, lsize) \
421
+ .signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
422
+ self.b_timeline[self.b_next] = self.dev.timeline_value
423
+ self.dev.timeline_value += 1
489
424
 
490
425
  def copy_from_disk(self, dest:HCQBuffer, src, size):
491
426
  def _get_temp_buf():
492
427
  # Check if the next buffer is safe to be used (its signal has passed) and reserve it.
493
- if self.b_timeline[(self.b_next + 1) % len(self.b)] <= self.device.timeline_signal.value:
428
+ if self.b_timeline[(self.b_next + 1) % len(self.b)] <= self.dev.timeline_signal.value:
494
429
  self.b_timeline[(self.b_next + 1) % len(self.b)], self.b_next = (1 << 64), (self.b_next + 1) % len(self.b)
495
430
  return (self.b[self.b_next].va_addr, self.b_next)
496
431
  return None
497
432
 
498
- with hcq_profile(self.device, queue_type=self.device.hw_copy_queue_t, desc=f"DISK -> {self.device.dname}", enabled=PROFILE):
433
+ assert self.dev.hw_copy_queue_t is not None
434
+ with hcq_profile(self.dev, queue_type=self.dev.hw_copy_queue_t, desc=f"DISK -> {self.dev.device}", enabled=PROFILE):
499
435
  for (batch_info, dst_off, src_off, copy_size) in src.device.allocator._copyout_sharded(src, size, _get_temp_buf, seg_len=self.b[0].size):
500
- self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
501
- .copy(dest.va_addr + dst_off, batch_info[0] + src_off, copy_size) \
502
- .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
503
- self.b_timeline[batch_info[1]] = self.device.timeline_value
504
- self.device.timeline_value += 1
436
+ self.dev.hw_copy_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1) \
437
+ .copy(dest.va_addr + dst_off, batch_info[0] + src_off, copy_size) \
438
+ .signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
439
+ self.b_timeline[batch_info[1]] = self.dev.timeline_value
440
+ self.dev.timeline_value += 1
505
441
 
506
- def copyout(self, dest:memoryview, src:HCQBuffer):
507
- self.device.synchronize()
442
+ def _copyout(self, dest:memoryview, src:HCQBuffer):
443
+ self.dev.synchronize()
508
444
 
509
- with hcq_profile(self.device, queue_type=self.device.hw_copy_queue_t, desc=f"{self.device.dname} -> CPU", enabled=PROFILE):
445
+ assert self.dev.hw_copy_queue_t is not None
446
+ with hcq_profile(self.dev, queue_type=self.dev.hw_copy_queue_t, desc=f"{self.dev.device} -> CPU", enabled=PROFILE):
510
447
  for i in range(0, dest.nbytes, self.b[0].size):
511
- self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
512
- .copy(self.b[0].va_addr, src.va_addr+i, lsize:=min(self.b[0].size, dest.nbytes-i)) \
513
- .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
514
- self.device.timeline_signal.wait(self.device.timeline_value)
515
- self.device.timeline_value += 1
448
+ self.dev.hw_copy_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1) \
449
+ .copy(self.b[0].va_addr, src.va_addr+i, lsize:=min(self.b[0].size, dest.nbytes-i)) \
450
+ .signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
451
+ self.dev.timeline_signal.wait(self.dev.timeline_value)
452
+ self.dev.timeline_value += 1
516
453
 
517
454
  ctypes.memmove(from_mv(dest[i:]), self.b[0].va_addr, lsize)
518
455
 
519
- def transfer(self, dest:HCQBuffer, src:HCQBuffer, sz:int, src_dev, dest_dev):
520
- src_dev.allocator.map(dest)
456
+ def _transfer(self, dest:HCQBuffer, src:HCQBuffer, sz:int, src_dev:DeviceType, dest_dev:DeviceType):
457
+ cast(HCQAllocator, src_dev.allocator).map(dest)
521
458
 
522
- with hcq_profile(src_dev, queue_type=src_dev.hw_copy_queue_t, desc=f"{src_dev.dname} -> {dest_dev.dname}", enabled=PROFILE):
459
+ assert src_dev.hw_copy_queue_t is not None
460
+ with hcq_profile(src_dev, queue_type=src_dev.hw_copy_queue_t, desc=f"{src_dev.device} -> {dest_dev.device}", enabled=PROFILE):
523
461
  src_dev.hw_copy_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
524
462
  .wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
525
463
  .copy(dest.va_addr, src.va_addr, sz) \
@@ -531,9 +469,3 @@ class HCQAllocator(LRUAllocator): # pylint: disable=abstract-method
531
469
  .wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
532
470
  .signal(dest_dev.timeline_signal, dest_dev.timeline_value).submit(dest_dev)
533
471
  dest_dev.timeline_value += 1
534
-
535
- def map(self, buf:HCQBuffer): pass
536
-
537
- def offset(self, buf, size:int, offset:int) -> HCQBuffer:
538
- return type(buf)(va_addr=buf.va_addr + offset, size=size, **{k:v for k,v in buf.__dict__.items() if k not in ['va_addr', 'size']},
539
- **{x[0]:getattr(buf, x[0]) for x in getattr(buf, '_fields_', []) if x[0] not in ['va_addr', 'size']}, _base=buf)