tinygrad-0.10.0-py3-none-any.whl → tinygrad-0.10.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/kernel.py +114 -172
- tinygrad/codegen/linearize.py +211 -81
- tinygrad/codegen/lowerer.py +30 -35
- tinygrad/codegen/{uopgraph.py → rewriter.py} +69 -59
- tinygrad/codegen/transcendental.py +12 -13
- tinygrad/device.py +170 -47
- tinygrad/dtype.py +28 -26
- tinygrad/engine/jit.py +80 -63
- tinygrad/engine/memory.py +4 -5
- tinygrad/engine/multi.py +162 -0
- tinygrad/engine/realize.py +58 -107
- tinygrad/engine/schedule.py +381 -314
- tinygrad/engine/search.py +40 -44
- tinygrad/gradient.py +70 -0
- tinygrad/helpers.py +77 -58
- tinygrad/nn/__init__.py +30 -32
- tinygrad/nn/datasets.py +1 -2
- tinygrad/nn/optim.py +22 -26
- tinygrad/nn/state.py +89 -64
- tinygrad/ops.py +562 -446
- tinygrad/renderer/__init__.py +79 -36
- tinygrad/renderer/cstyle.py +70 -84
- tinygrad/renderer/llvmir.py +32 -20
- tinygrad/renderer/ptx.py +79 -99
- tinygrad/renderer/wgsl.py +87 -0
- tinygrad/runtime/autogen/amd_gpu.py +39507 -12
- tinygrad/runtime/autogen/comgr.py +2 -0
- tinygrad/runtime/autogen/kfd.py +4 -3
- tinygrad/runtime/autogen/kgsl.py +1 -1
- tinygrad/runtime/autogen/libpciaccess.py +2023 -0
- tinygrad/runtime/autogen/llvm.py +11379 -0
- tinygrad/runtime/autogen/vfio.py +891 -0
- tinygrad/runtime/graph/cuda.py +8 -9
- tinygrad/runtime/graph/hcq.py +84 -79
- tinygrad/runtime/graph/metal.py +19 -21
- tinygrad/runtime/ops_amd.py +488 -327
- tinygrad/runtime/ops_clang.py +15 -28
- tinygrad/runtime/ops_cloud.py +34 -34
- tinygrad/runtime/ops_cuda.py +30 -27
- tinygrad/runtime/ops_disk.py +62 -63
- tinygrad/runtime/ops_dsp.py +129 -38
- tinygrad/runtime/ops_gpu.py +30 -30
- tinygrad/runtime/ops_hip.py +29 -31
- tinygrad/runtime/ops_llvm.py +45 -40
- tinygrad/runtime/ops_metal.py +93 -73
- tinygrad/runtime/ops_npy.py +2 -2
- tinygrad/runtime/ops_nv.py +232 -270
- tinygrad/runtime/ops_python.py +51 -46
- tinygrad/runtime/ops_qcom.py +129 -157
- tinygrad/runtime/ops_webgpu.py +63 -0
- tinygrad/runtime/support/allocator.py +94 -0
- tinygrad/runtime/support/am/__init__.py +0 -0
- tinygrad/runtime/support/am/amdev.py +384 -0
- tinygrad/runtime/support/am/ip.py +463 -0
- tinygrad/runtime/support/compiler_cuda.py +4 -2
- tinygrad/runtime/support/elf.py +26 -4
- tinygrad/runtime/support/hcq.py +254 -324
- tinygrad/runtime/support/llvm.py +32 -0
- tinygrad/shape/shapetracker.py +84 -53
- tinygrad/shape/view.py +103 -138
- tinygrad/spec.py +154 -0
- tinygrad/tensor.py +744 -496
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/METADATA +32 -21
- tinygrad-0.10.1.dist-info/RECORD +86 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/WHEEL +1 -1
- tinygrad/engine/lazy.py +0 -228
- tinygrad/function.py +0 -212
- tinygrad/multi.py +0 -177
- tinygrad/runtime/graph/clang.py +0 -39
- tinygrad-0.10.0.dist-info/RECORD +0 -77
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/top_level.txt +0 -0
tinygrad/runtime/support/hcq.py
CHANGED
@@ -1,48 +1,104 @@
|
|
1
1
|
from __future__ import annotations
|
2
|
-
from typing import
|
3
|
-
import contextlib, decimal, statistics,
|
4
|
-
from tinygrad.helpers import
|
2
|
+
from typing import cast, Type, TypeVar, Generic, Any
|
3
|
+
import contextlib, decimal, statistics, time, ctypes, array, os, fcntl
|
4
|
+
from tinygrad.helpers import PROFILE, from_mv, getenv, to_mv, round_up
|
5
5
|
from tinygrad.renderer import Renderer
|
6
|
-
from tinygrad.device import
|
6
|
+
from tinygrad.device import BufferSpec, Compiler, Compiled, LRUAllocator, ProfileRangeEvent, ProfileDeviceEvent
|
7
|
+
from tinygrad.ops import sym_infer, sint, Variable
|
8
|
+
from tinygrad.runtime.autogen import libc
|
7
9
|
|
8
|
-
|
9
|
-
|
10
|
-
def hcq_command(func):
|
10
|
+
class HWInterface:
|
11
11
|
"""
|
12
|
-
|
13
|
-
|
14
|
-
For example:
|
15
|
-
```python
|
16
|
-
@hcq_command
|
17
|
-
def command_method(self, ...): ...
|
18
|
-
```
|
12
|
+
Hardware Abstraction Layer for HCQ devices. The class provides a unified interface for interacting with hardware devices.
|
19
13
|
"""
|
20
|
-
def __wrapper(self, *args, **kwargs):
|
21
|
-
self.cmds_offset.append(len(self.q))
|
22
|
-
func(self, *args, **kwargs)
|
23
|
-
self.cmds_len.append(len(self.q) - self.cmds_offset[-1])
|
24
|
-
self.cmds_meta.append(func.__name__)
|
25
|
-
return self
|
26
|
-
return __wrapper
|
27
14
|
|
28
|
-
|
15
|
+
def __init__(self, path:str="", flags:int=os.O_RDONLY, fd:int|None=None):
|
16
|
+
self.path:str = path
|
17
|
+
self.fd:int = fd or os.open(path, flags)
|
18
|
+
def __del__(self):
|
19
|
+
if hasattr(self, 'fd'): os.close(self.fd)
|
20
|
+
def ioctl(self, request, arg): return fcntl.ioctl(self.fd, request, arg)
|
21
|
+
def mmap(self, start, sz, prot, flags, offset): return libc.mmap(start, sz, prot, flags, self.fd, offset)
|
22
|
+
def read(self, size=None, binary=False):
|
23
|
+
with open(self.fd, "rb" if binary else "r", closefd=False) as file: return file.read(size)
|
24
|
+
def write(self, content, binary=False):
|
25
|
+
with open(self.fd, "wb" if binary else "w", closefd=False) as file: file.write(content)
|
26
|
+
def listdir(self): return os.listdir(self.path)
|
27
|
+
def seek(self, offset): os.lseek(self.fd, offset, os.SEEK_SET)
|
28
|
+
@staticmethod
|
29
|
+
def anon_mmap(start, sz, prot, flags, offset): return libc.mmap(start, sz, prot, flags, -1, offset)
|
30
|
+
@staticmethod
|
31
|
+
def munmap(buf, sz): return libc.munmap(buf, sz)
|
32
|
+
@staticmethod
|
33
|
+
def exists(path): return os.path.exists(path)
|
34
|
+
@staticmethod
|
35
|
+
def readlink(path): return os.readlink(path)
|
36
|
+
@staticmethod
|
37
|
+
def eventfd(initval, flags=None): return HWInterface(fd=os.eventfd(initval, flags)) # type: ignore[attr-defined]
|
38
|
+
|
39
|
+
if MOCKGPU:=getenv("MOCKGPU"): from test.mockgpu.mockgpu import MockHWInterface as HWInterface # noqa: F401 # pylint: disable=unused-import
|
40
|
+
|
41
|
+
# **************** for HCQ Compatible Devices ****************
|
42
|
+
|
43
|
+
SignalType = TypeVar('SignalType', bound='HCQSignal')
|
44
|
+
DeviceType = TypeVar('DeviceType', bound='HCQCompiled')
|
45
|
+
ProgramType = TypeVar('ProgramType', bound='HCQProgram')
|
46
|
+
ArgsStateType = TypeVar('ArgsStateType', bound='HCQArgsState')
|
47
|
+
QueueType = TypeVar('QueueType', bound='HWQueue')
|
48
|
+
|
49
|
+
class BumpAllocator:
|
50
|
+
def __init__(self, size:int, base:int=0, wrap:bool=True): self.size, self.ptr, self.base, self.wrap = size, 0, base, wrap
|
51
|
+
def alloc(self, size:int, alignment:int=1) -> int:
|
52
|
+
if round_up(self.ptr, alignment) + size > self.size:
|
53
|
+
if not self.wrap: raise RuntimeError("Out of memory")
|
54
|
+
self.ptr = 0
|
55
|
+
self.ptr = (res:=round_up(self.ptr, alignment)) + size
|
56
|
+
return res + self.base
|
57
|
+
|
58
|
+
class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
|
29
59
|
"""
|
30
60
|
A base class for hardware command queues in the HCQ (Hardware Command Queue) API.
|
31
|
-
Both compute and copy queues should have the following commands implemented.
|
32
61
|
"""
|
33
62
|
|
34
|
-
def __init__(self):
|
35
|
-
|
36
|
-
|
37
|
-
|
63
|
+
def __init__(self):
|
64
|
+
self._q:Any = []
|
65
|
+
self.binded_device:DeviceType|None = None
|
66
|
+
self.q_sints:list[tuple[int, int]] = []
|
67
|
+
self.mv_sints:list[tuple[memoryview, int, int, int|None]] = []
|
68
|
+
self.syms:list[sint] = []
|
69
|
+
self._prev_resolved_syms:list[int|None] = []
|
70
|
+
|
71
|
+
def _new_sym(self, sym:sint) -> int:
|
72
|
+
if sym not in self.syms:
|
73
|
+
self.syms.append(sym)
|
74
|
+
self._prev_resolved_syms.append(None)
|
75
|
+
return self.syms.index(sym)
|
76
|
+
|
77
|
+
def q(self, *values):
|
38
78
|
"""
|
39
|
-
|
40
|
-
|
79
|
+
Enqueues values in the queue.
|
80
|
+
|
81
|
+
Args:
|
82
|
+
values: The values to enqueue in the queue.
|
41
83
|
"""
|
42
|
-
return len(self) - 1
|
43
84
|
|
44
|
-
|
45
|
-
|
85
|
+
for v in values:
|
86
|
+
if isinstance(v, int): self._q.append(v)
|
87
|
+
else:
|
88
|
+
self.q_sints.append((len(self._q), self._new_sym(v)))
|
89
|
+
self._q.append(0xbadc0ded)
|
90
|
+
|
91
|
+
# *** common commands ***
|
92
|
+
|
93
|
+
def timestamp(self, signal:SignalType):
|
94
|
+
"""
|
95
|
+
Enqueues a timestamp command which records the current time in a signal after all previously enqueued commands are completed.
|
96
|
+
|
97
|
+
Args:
|
98
|
+
signal: The signal to store the timestamp
|
99
|
+
"""
|
100
|
+
|
101
|
+
def signal(self, signal:SignalType, value:sint):
|
46
102
|
"""
|
47
103
|
Enqueues a signal command which sets the signal to the given value, ensuring all previous operations are completed.
|
48
104
|
|
@@ -50,11 +106,8 @@ class HWCommandQueue:
|
|
50
106
|
signal: The signal to set
|
51
107
|
value: The value to set the signal to
|
52
108
|
"""
|
53
|
-
self._signal(signal, value)
|
54
|
-
def _signal(self, signal:HCQSignal, value:int): raise NotImplementedError("backend should overload this function")
|
55
109
|
|
56
|
-
|
57
|
-
def wait(self, signal:HCQSignal, value:int):
|
110
|
+
def wait(self, signal:SignalType, value:sint):
|
58
111
|
"""
|
59
112
|
Enqueues a wait command which halts execution until the signal is greater than or equal to a specific value.
|
60
113
|
|
@@ -62,49 +115,40 @@ class HWCommandQueue:
|
|
62
115
|
signal: The signal to wait on
|
63
116
|
value: The value to wait for
|
64
117
|
"""
|
65
|
-
self._wait(signal, value)
|
66
|
-
def _wait(self, signal, value): raise NotImplementedError("backend should overload this function")
|
67
118
|
|
68
|
-
|
69
|
-
def timestamp(self, signal:HCQSignal):
|
70
|
-
"""
|
71
|
-
Enqueues a timestamp command which records the current time in a signal after all previously enqueued commands are completed.
|
119
|
+
# *** commands for compute queues ***
|
72
120
|
|
73
|
-
|
74
|
-
|
121
|
+
def memory_barrier(self):
|
122
|
+
"""
|
123
|
+
Enqueues a memory barrier command to ensure memory coherence between agents. Only on compute queues.
|
75
124
|
"""
|
76
|
-
self._timestamp(signal)
|
77
|
-
def _timestamp(self, signal): raise NotImplementedError("backend should overload this function")
|
78
125
|
|
79
|
-
def
|
126
|
+
def exec(self, prg:ProgramType, args_state:ArgsStateType, global_size:tuple[sint, ...], local_size:tuple[sint, ...]):
|
80
127
|
"""
|
81
|
-
|
128
|
+
Enqueues an execution command for a kernel program. Only on compute queues.
|
82
129
|
|
83
130
|
Args:
|
84
|
-
|
85
|
-
|
86
|
-
|
131
|
+
prg: The program to execute
|
132
|
+
args_state: The args state to execute program with
|
133
|
+
global_size: The global work size
|
134
|
+
local_size: The local work size
|
87
135
|
"""
|
88
|
-
if self.cmds_meta[cmd_idx] != "signal": raise RuntimeError("called update_signal not on a signal command")
|
89
|
-
self._update_signal(cmd_idx, signal, value)
|
90
|
-
return self
|
91
|
-
def _update_signal(self, cmd_idx:int, signal:Optional[Any], value:Optional[int]): raise NotImplementedError("backend should overload this function")
|
92
136
|
|
93
|
-
|
137
|
+
# *** commands for copy queues ***
|
138
|
+
|
139
|
+
def copy(self, dest:sint, src:sint, copy_size:int):
|
94
140
|
"""
|
95
|
-
|
141
|
+
Enqueues a copy command to transfer data. Only on copy queues.
|
96
142
|
|
97
143
|
Args:
|
98
|
-
|
99
|
-
|
100
|
-
|
144
|
+
dest: The destination of the copy
|
145
|
+
src: The source of the copy
|
146
|
+
copy_size: The size of data to copy
|
101
147
|
"""
|
102
|
-
if self.cmds_meta[cmd_idx] != "wait": raise RuntimeError("called update_wait not on a wait command")
|
103
|
-
self._update_wait(cmd_idx, signal, value)
|
104
|
-
return self
|
105
|
-
def _update_wait(self, cmd_idx:int, signal:Optional[Any], value:Optional[int]): raise NotImplementedError("backend should overload this function")
|
106
148
|
|
107
|
-
|
149
|
+
# *** submit and bind commands ***
|
150
|
+
|
151
|
+
def bind(self, dev:DeviceType):
|
108
152
|
"""
|
109
153
|
Associates the queue with a specific device for optimized execution.
|
110
154
|
|
@@ -112,99 +156,65 @@ class HWCommandQueue:
|
|
112
156
|
the need to copy queues into the device, thereby enhancing performance.
|
113
157
|
|
114
158
|
Args:
|
115
|
-
|
159
|
+
dev: The target device for queue optimization.
|
116
160
|
|
117
161
|
Note:
|
118
162
|
Implementing this method is optional but recommended for performance gains.
|
119
163
|
"""
|
120
164
|
|
121
|
-
def
|
122
|
-
|
123
|
-
Submits the command queue to a specific device for execution.
|
165
|
+
def bind_args_state(self, args_state:ArgsStateType):
|
166
|
+
for vals, ptr, fmt in args_state.bind_data: self.bind_sints_to_ptr(*vals, ptr=ptr, fmt=fmt)
|
124
167
|
|
125
|
-
|
126
|
-
|
127
|
-
"""
|
128
|
-
if self.q: self._submit(device)
|
129
|
-
return self
|
130
|
-
def _submit(self, device:HCQCompiled): raise NotImplementedError("backend should overload this function")
|
168
|
+
def bind_sints(self, *vals:sint, struct:ctypes.Structure, start_field:str, fmt, mask:int|None=None):
|
169
|
+
self.bind_sints_to_ptr(*vals, ptr=ctypes.addressof(struct) + getattr(type(struct), start_field).offset, fmt=fmt, mask=mask)
|
131
170
|
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
"""
|
138
|
-
self._memory_barrier()
|
139
|
-
def _memory_barrier(self): pass
|
171
|
+
def bind_sints_to_ptr(self, *vals:sint, ptr:int, fmt, mask:int|None=None):
|
172
|
+
mv = to_mv(ptr, 8*len(vals)).cast(fmt)
|
173
|
+
for i, val in enumerate(vals):
|
174
|
+
if isinstance(val, int): mv[i] = val if mask is None else ((mv[i] & ~mask) | val)
|
175
|
+
else: self.mv_sints.append((mv, i, self._new_sym(val), mask))
|
140
176
|
|
141
|
-
|
142
|
-
|
143
|
-
"""
|
144
|
-
Enqueues an execution command for a kernel program.
|
177
|
+
def _apply_var_vals(self, var_vals:dict[Variable, int]):
|
178
|
+
resolved_syms = [sym_infer(sym, var_vals) for sym in self.syms]
|
145
179
|
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
global_size: The global work size
|
150
|
-
local_size: The local work size
|
151
|
-
"""
|
152
|
-
self._exec(prg, args_state, global_size, local_size)
|
153
|
-
def _exec(self, prg, args_state, global_size, local_size): raise NotImplementedError("backend should overload this function")
|
180
|
+
for off, sym_idx in self.q_sints:
|
181
|
+
if self._prev_resolved_syms[sym_idx] == resolved_syms[sym_idx]: continue
|
182
|
+
self._q[off] = resolved_syms[sym_idx]
|
154
183
|
|
155
|
-
|
156
|
-
|
157
|
-
|
184
|
+
for mv, off, sym_idx, mask in self.mv_sints:
|
185
|
+
if self._prev_resolved_syms[sym_idx] == resolved_syms[sym_idx]: continue
|
186
|
+
mv[off] = resolved_syms[sym_idx] if mask is None else ((mv[off] & ~mask) | resolved_syms[sym_idx])
|
158
187
|
|
159
|
-
|
160
|
-
cmd_idx: Index of the execution command to update
|
161
|
-
global_size: New global work size (if None, keeps the original)
|
162
|
-
local_size: New local work size (if None, keeps the original)
|
163
|
-
"""
|
164
|
-
if self.cmds_meta[cmd_idx] != "exec": raise RuntimeError("called update_exec not on an exec command")
|
165
|
-
self._update_exec(cmd_idx, global_size, local_size)
|
166
|
-
return self
|
167
|
-
def _update_exec(self, cmd_idx, global_size, local_size): raise NotImplementedError("backend should overload this function")
|
188
|
+
self._prev_resolved_syms = cast(list[int|None], resolved_syms)
|
168
189
|
|
169
|
-
|
170
|
-
@hcq_command
|
171
|
-
def copy(self, dest:HCQBuffer, src:HCQBuffer, copy_size:int):
|
190
|
+
def submit(self, dev:DeviceType, var_vals:dict[Variable, int]|None=None):
|
172
191
|
"""
|
173
|
-
|
192
|
+
Submits the command queue to a specific device for execution.
|
174
193
|
|
175
194
|
Args:
|
176
|
-
|
177
|
-
src: The source of the copy
|
178
|
-
copy_size: The size of data to copy
|
195
|
+
dev: The device to submit the queue to
|
179
196
|
"""
|
180
|
-
self._copy(dest, src, copy_size)
|
181
|
-
def _copy(self, dest:HCQBuffer, src:HCQBuffer, copy_size:int): raise NotImplementedError("backend should overload this function")
|
182
197
|
|
183
|
-
|
184
|
-
|
185
|
-
Updates a previously queued copy command.
|
186
|
-
|
187
|
-
Args:
|
188
|
-
cmd_idx: Index of the copy command to update
|
189
|
-
dest: New destination of the copy (if None, keeps the original)
|
190
|
-
src: New source of the copy (if None, keeps the original)
|
191
|
-
"""
|
192
|
-
if self.cmds_meta[cmd_idx] != "copy": raise RuntimeError("called update_copy not on an copy command")
|
193
|
-
self._update_copy(cmd_idx, dest, src)
|
198
|
+
if var_vals is not None: self._apply_var_vals(var_vals)
|
199
|
+
self._submit(dev)
|
194
200
|
return self
|
195
|
-
def
|
201
|
+
def _submit(self, dev:DeviceType): raise NotImplementedError("need _submit")
|
196
202
|
|
197
|
-
class HCQSignal:
|
198
|
-
def __init__(self, value:int=0,
|
203
|
+
class HCQSignal(Generic[DeviceType]):
|
204
|
+
def __init__(self, base_addr:sint=0, value:int=0, timeline_for_device:DeviceType|None=None, timestamp_divider=1, value_off=0, timestamp_off=8):
|
205
|
+
self.base_addr, self.value_addr, self.timestamp_addr = base_addr, base_addr+value_off, base_addr+timestamp_off
|
206
|
+
self.timestamp_divider:decimal.Decimal = decimal.Decimal(timestamp_divider)
|
207
|
+
self.timeline_for_device:DeviceType|None = timeline_for_device
|
208
|
+
|
209
|
+
if isinstance(base_addr, int):
|
210
|
+
self.value_mv, self.timestamp_mv = to_mv(self.value_addr, 8).cast('Q'), to_mv(self.timestamp_addr, 8).cast('Q')
|
211
|
+
self.value_mv[0] = value
|
199
212
|
|
200
213
|
@property
|
201
|
-
def value(self) -> int: return self.
|
214
|
+
def value(self) -> int: return self.value_mv[0]
|
202
215
|
|
203
216
|
@value.setter
|
204
|
-
def value(self, new_value:int): self.
|
205
|
-
|
206
|
-
def _get_value(self) -> int: raise NotImplementedError("_get_value() method must be implemented")
|
207
|
-
def _set_value(self, new_value:int): raise NotImplementedError("_set_value() method must be implemented")
|
217
|
+
def value(self, new_value:int): self.value_mv[0] = new_value
|
208
218
|
|
209
219
|
@property
|
210
220
|
def timestamp(self) -> decimal.Decimal:
|
@@ -216,8 +226,12 @@ class HCQSignal:
|
|
216
226
|
Returns:
|
217
227
|
The timestamp in microseconds.
|
218
228
|
"""
|
219
|
-
return self.
|
220
|
-
|
229
|
+
return self.timestamp_mv[0] / self.timestamp_divider
|
230
|
+
|
231
|
+
def _sleep(self, time_spent_waiting_ms:int):
|
232
|
+
"""
|
233
|
+
Optional function which can implement sleep functionality for the signal.
|
234
|
+
"""
|
221
235
|
|
222
236
|
def wait(self, value:int, timeout:int=getenv("HCQDEV_WAIT_TIMEOUT_MS", 30000)):
|
223
237
|
"""
|
@@ -227,17 +241,18 @@ class HCQSignal:
|
|
227
241
|
value: The value to wait for.
|
228
242
|
timeout: Maximum time to wait in milliseconds. Defaults to 10s.
|
229
243
|
"""
|
230
|
-
start_time = time.
|
231
|
-
while time.
|
232
|
-
|
233
|
-
raise RuntimeError(f"Wait timeout: {timeout} ms! (the signal is not set to {value}, but {self.value})")
|
244
|
+
start_time = int(time.perf_counter() * 1000)
|
245
|
+
while self.value < value and (time_spent:=int(time.perf_counter() * 1000) - start_time) < timeout:
|
246
|
+
self._sleep(time_spent)
|
247
|
+
if self.value < value: raise RuntimeError(f"Wait timeout: {timeout} ms! (the signal is not set to {value}, but {self.value})")
|
234
248
|
|
235
249
|
@contextlib.contextmanager
|
236
|
-
def hcq_profile(dev, enabled, desc, queue_type=None, queue=None):
|
250
|
+
def hcq_profile(dev:HCQCompiled, enabled, desc, queue_type:Type[HWQueue]|None=None, queue:HWQueue|None=None):
|
237
251
|
st, en = (dev.signal_t(), dev.signal_t()) if enabled else (None, None)
|
238
252
|
|
239
253
|
if enabled and queue is not None: queue.timestamp(st)
|
240
254
|
elif enabled:
|
255
|
+
assert queue_type is not None
|
241
256
|
queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(st).signal(dev.timeline_signal, dev.timeline_value).submit(dev)
|
242
257
|
dev.timeline_value += 1
|
243
258
|
|
@@ -245,21 +260,33 @@ def hcq_profile(dev, enabled, desc, queue_type=None, queue=None):
|
|
245
260
|
finally:
|
246
261
|
if enabled and queue is not None: queue.timestamp(en)
|
247
262
|
elif enabled:
|
263
|
+
assert queue_type is not None
|
248
264
|
queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(en).signal(dev.timeline_signal, dev.timeline_value).submit(dev)
|
249
265
|
dev.timeline_value += 1
|
250
266
|
|
251
|
-
if enabled and PROFILE: dev.sig_prof_records.append((st, en, desc, queue_type is dev.hw_copy_queue_t))
|
267
|
+
if enabled and PROFILE: dev.sig_prof_records.append((cast(HCQSignal, st), cast(HCQSignal, en), desc, queue_type is dev.hw_copy_queue_t))
|
268
|
+
|
269
|
+
class HCQArgsState(Generic[ProgramType]):
|
270
|
+
def __init__(self, ptr:int, prg:ProgramType, bufs:tuple[HCQBuffer, ...], vals:tuple[sint, ...]=()):
|
271
|
+
self.ptr, self.prg = ptr, prg
|
272
|
+
self.bind_data:list[tuple[tuple[sint, ...], int, str]] = []
|
273
|
+
|
274
|
+
def bind_sints_to_ptr(self, *vals:sint, ptr:int, fmt): self.bind_data.append((vals, ptr, fmt))
|
275
|
+
|
276
|
+
class CLikeArgsState(HCQArgsState[ProgramType]):
|
277
|
+
def __init__(self, ptr:int, prg:ProgramType, bufs:tuple[HCQBuffer, ...], vals:tuple[sint, ...]=(), prefix:list[int]|None=None):
|
278
|
+
super().__init__(ptr, prg, bufs, vals=vals)
|
279
|
+
|
280
|
+
if prefix is not None: to_mv(self.ptr, len(prefix) * 4).cast('I')[:] = array.array('I', prefix)
|
252
281
|
|
253
|
-
|
254
|
-
|
255
|
-
def update_buffer(self, index:int, buf:HCQBuffer): raise NotImplementedError("need update_buffer")
|
256
|
-
def update_var(self, index:int, val:int): raise NotImplementedError("need update_var")
|
282
|
+
self.bind_sints_to_ptr(*[b.va_addr for b in bufs], ptr=self.ptr + len(prefix or []) * 4, fmt='Q')
|
283
|
+
self.bind_sints_to_ptr(*vals, ptr=self.ptr + len(prefix or []) * 4 + len(bufs) * 8, fmt='I')
|
257
284
|
|
258
|
-
class HCQProgram:
|
259
|
-
def __init__(self, args_state_t:Type[HCQArgsState],
|
260
|
-
self.args_state_t, self.
|
285
|
+
class HCQProgram(Generic[DeviceType]):
|
286
|
+
def __init__(self, args_state_t:Type[HCQArgsState], dev:DeviceType, name:str, kernargs_alloc_size:int):
|
287
|
+
self.args_state_t, self.dev, self.name, self.kernargs_alloc_size = args_state_t, dev, name, kernargs_alloc_size
|
261
288
|
|
262
|
-
def fill_kernargs(self, bufs:
|
289
|
+
def fill_kernargs(self, bufs:tuple[HCQBuffer, ...], vals:tuple[int, ...]=(), kernargs_ptr:int|None=None) -> HCQArgsState:
|
263
290
|
"""
|
264
291
|
Fills arguments for the kernel, optionally allocating space from the device if `kernargs_ptr` is not provided.
|
265
292
|
Args:
|
@@ -269,10 +296,10 @@ class HCQProgram:
|
|
269
296
|
Returns:
|
270
297
|
Arguments state with the given buffers and values set for the program.
|
271
298
|
"""
|
272
|
-
return self.args_state_t(kernargs_ptr or self.
|
299
|
+
return self.args_state_t(kernargs_ptr or self.dev.kernargs_allocator.alloc(self.kernargs_alloc_size), self, bufs, vals=vals)
|
273
300
|
|
274
|
-
def __call__(self, *bufs:HCQBuffer, global_size:
|
275
|
-
vals:
|
301
|
+
def __call__(self, *bufs:HCQBuffer, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1),
|
302
|
+
vals:tuple[int, ...]=(), wait:bool=False) -> float|None:
|
276
303
|
"""
|
277
304
|
Enqueues the program for execution with the given arguments and dimensions.
|
278
305
|
|
@@ -288,103 +315,52 @@ class HCQProgram:
|
|
288
315
|
"""
|
289
316
|
|
290
317
|
kernargs = self.fill_kernargs(bufs, vals)
|
291
|
-
q = self.
|
318
|
+
q = self.dev.hw_compute_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1).memory_barrier()
|
292
319
|
|
293
|
-
with hcq_profile(self.
|
320
|
+
with hcq_profile(self.dev, queue=q, desc=self.name, enabled=wait or PROFILE) as (sig_st, sig_en):
|
294
321
|
q.exec(self, kernargs, global_size, local_size)
|
295
322
|
|
296
|
-
q.signal(self.
|
297
|
-
self.
|
323
|
+
q.signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
|
324
|
+
self.dev.timeline_value += 1
|
298
325
|
|
299
|
-
if wait: self.
|
326
|
+
if wait: self.dev.synchronize()
|
300
327
|
return (float(sig_en.timestamp - sig_st.timestamp) / 1e6) if wait else None
|
301
328
|
|
302
|
-
class
|
303
|
-
writers: int = 0
|
304
|
-
mjson: List[Dict] = []
|
305
|
-
actors: Dict[Union[str, Tuple[str, str]], int] = {}
|
306
|
-
|
307
|
-
def __init__(self): self.events, self.deps, ProfileLogger.writers = [], [], ProfileLogger.writers + 1
|
308
|
-
|
309
|
-
def add_event(self, ev_name, ev_start, ev_end, actor, subactor=None, args=None): self.events += [(ev_name, ev_start, ev_end, actor, subactor, args)]
|
310
|
-
|
311
|
-
def _ensure_actor(self, actor_name, subactor_name):
|
312
|
-
if actor_name not in self.actors:
|
313
|
-
self.actors[actor_name] = (pid:=len(self.actors))
|
314
|
-
self.mjson.append({"name": "process_name", "ph": "M", "pid": pid, "args": {"name": actor_name}})
|
315
|
-
|
316
|
-
if (subactor_key:=(actor_name,subactor_name)) not in self.actors:
|
317
|
-
self.actors[subactor_key] = (tid:=len(self.actors))
|
318
|
-
self.mjson.append({"name": "thread_name", "ph": "M", "pid": self.actors[actor_name], "tid":tid, "args": {"name": subactor_name}})
|
319
|
-
|
320
|
-
return self.actors[actor_name], self.actors.get(subactor_key, -1)
|
321
|
-
|
322
|
-
def __del__(self):
|
323
|
-
# perfetto json docs: https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview
|
324
|
-
for name, st, et, actor_name, subactor_name, args in self.events:
|
325
|
-
pid, tid = self._ensure_actor(actor_name,subactor_name)
|
326
|
-
args = {k: (v if v.__class__ is str else v(et-st)) for k, v in args.items()} if args is not None else None
|
327
|
-
self.mjson.append({"name": name, "ph": "X", "pid": pid, "tid": tid, "ts": st, "dur": et-st, "args": args})
|
328
|
-
|
329
|
-
for en,st,dep_actor_name,dep_subactor_name,actor_name,subactor_name in self.deps:
|
330
|
-
dep_pid, dep_tid = self._ensure_actor(dep_actor_name,dep_subactor_name)
|
331
|
-
pid, tid = self._ensure_actor(actor_name,subactor_name)
|
332
|
-
self.mjson.append({"ph": "s", "pid": dep_pid, "tid": dep_tid, "id": len(self.mjson), "ts": en, "bp": "e"})
|
333
|
-
self.mjson.append({"ph": "f", "pid": pid, "tid": tid, "id": len(self.mjson)-1, "ts": st, "bp": "e"})
|
334
|
-
|
335
|
-
ProfileLogger.writers -= 1
|
336
|
-
if ProfileLogger.writers == 0 and len(self.mjson) > 0:
|
337
|
-
with open(PROFILEPATH.value, "w") as f: f.write(json.dumps({"traceEvents": self.mjson}))
|
338
|
-
print(f"Saved profile to {PROFILEPATH.value}. Use https://ui.perfetto.dev/ to open it.")
|
339
|
-
|
340
|
-
class HCQCompiled(Compiled):
|
329
|
+
class HCQCompiled(Compiled, Generic[SignalType]):
|
341
330
|
"""
|
342
331
|
A base class for devices compatible with the HCQ (Hardware Command Queue) API.
|
343
332
|
"""
|
344
|
-
devices:
|
345
|
-
gpu2cpu_copy_time_diff: decimal.Decimal = decimal.Decimal('nan')
|
346
|
-
gpu2cpu_compute_time_diff: decimal.Decimal = decimal.Decimal('nan')
|
333
|
+
devices: list[HCQCompiled] = []
|
347
334
|
|
348
|
-
def __init__(self, device:str, allocator:
|
349
|
-
comp_queue_t:Type[
|
335
|
+
def __init__(self, device:str, allocator:HCQAllocatorBase, renderer:Renderer, compiler:Compiler, runtime, signal_t:Type[SignalType],
|
336
|
+
comp_queue_t:Type[HWQueue], copy_queue_t:Type[HWQueue]|None):
|
337
|
+
self.device_id:int = int(device.split(":")[1]) if ":" in device else 0
|
350
338
|
self.signal_t, self.hw_compute_queue_t, self.hw_copy_queue_t = signal_t, comp_queue_t, copy_queue_t
|
351
339
|
self.timeline_value:int = 1
|
352
|
-
self.timeline_signal
|
353
|
-
self.
|
354
|
-
self.
|
355
|
-
self.dep_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, HCQCompiled, bool, decimal.Decimal, decimal.Decimal, HCQCompiled, bool]] = []
|
356
|
-
if PROFILE: self._prof_setup()
|
340
|
+
self.timeline_signal:SignalType = self.signal_t(value=0, timeline_for_device=self)
|
341
|
+
self._shadow_timeline_signal:SignalType = self.signal_t(value=0, timeline_for_device=self)
|
342
|
+
self.sig_prof_records:list[tuple[HCQSignal, HCQSignal, str, bool]] = []
|
357
343
|
|
358
344
|
from tinygrad.runtime.graph.hcq import HCQGraph
|
359
345
|
super().__init__(device, allocator, renderer, compiler, runtime, HCQGraph)
|
360
346
|
|
361
|
-
self.kernargs_page:HCQBuffer = self.allocator.alloc(16 << 20,
|
362
|
-
self.
|
347
|
+
self.kernargs_page:HCQBuffer = self.allocator.alloc(16 << 20, BufferSpec(cpu_access=True))
|
348
|
+
self.kernargs_allocator:BumpAllocator = BumpAllocator(self.kernargs_page.size, base=cast(int, self.kernargs_page.va_addr), wrap=True)
|
363
349
|
self.devices.append(self)
|
364
350
|
|
365
351
|
def synchronize(self):
|
366
|
-
try: self.timeline_signal.wait(self.timeline_value - 1)
|
352
|
+
try: self.timeline_signal.wait(self.timeline_value - 1)
|
367
353
|
except RuntimeError as e:
|
368
354
|
if hasattr(self, 'on_device_hang'): self.on_device_hang()
|
369
355
|
else: raise e
|
370
356
|
|
371
357
|
if self.timeline_value > (1 << 31): self._wrap_timeline_signal()
|
372
358
|
if PROFILE:
|
373
|
-
|
359
|
+
Compiled.profile_events += [ProfileRangeEvent(self.device, name, st.timestamp, en.timestamp, cp) for st,en,name,cp in self.sig_prof_records]
|
374
360
|
self.sig_prof_records = []
|
375
361
|
|
376
|
-
def
|
377
|
-
|
378
|
-
Allocates space for arguments passed to the kernel.
|
379
|
-
"""
|
380
|
-
if self.kernargs_ptr >= (self.kernargs_page.va_addr + self.kernargs_page.size - alloc_size): self.kernargs_ptr = self.kernargs_page.va_addr
|
381
|
-
self.kernargs_ptr = (res:=self.kernargs_ptr) + alloc_size
|
382
|
-
return res
|
383
|
-
|
384
|
-
def _ensure_shared_time_base(self):
|
385
|
-
if not self.gpu2cpu_compute_time_diff.is_nan(): return
|
386
|
-
|
387
|
-
def _sync_cpu_queue(d, q_t):
|
362
|
+
def _at_profile_finalize(self):
|
363
|
+
def _sync(d:HCQCompiled, q_t:Type[HWQueue]):
|
388
364
|
q_t().timestamp(d.timeline_signal).signal(d.timeline_signal, d.timeline_value).submit(d)
|
389
365
|
d.timeline_value += 1
|
390
366
|
st = time.perf_counter_ns()
|
@@ -392,134 +368,94 @@ class HCQCompiled(Compiled):
|
|
392
368
|
et = time.perf_counter_ns()
|
393
369
|
return (decimal.Decimal(et+st) / 2000) - d.timeline_signal.timestamp
|
394
370
|
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
d,q,l = random.choice(choices)
|
400
|
-
l.append(_sync_cpu_queue(d,q))
|
401
|
-
for d,q,l in choices:
|
402
|
-
if q == d.hw_compute_queue_t: d.gpu2cpu_compute_time_diff = statistics.median(l)
|
403
|
-
if q == d.hw_copy_queue_t: d.gpu2cpu_copy_time_diff = statistics.median(l)
|
404
|
-
|
405
|
-
def _sync_gpu_to_gpu_queue(d1, d2, q1_t, q2_t):
|
406
|
-
q1_t().signal(d1.timeline_signal, d1.timeline_value).wait(d2.timeline_signal, d2.timeline_value) \
|
407
|
-
.timestamp(d1.timeline_signal).signal(d1.timeline_signal, d1.timeline_value+1).submit(d1)
|
408
|
-
q2_t().signal(d2.timeline_signal, d2.timeline_value).wait(d1.timeline_signal, d1.timeline_value) \
|
409
|
-
.timestamp(d2.timeline_signal).signal(d2.timeline_signal, d2.timeline_value+1).submit(d2)
|
410
|
-
d1.timeline_value += 2
|
411
|
-
d2.timeline_value += 2
|
412
|
-
d1.timeline_signal.wait(d1.timeline_value - 1)
|
413
|
-
d2.timeline_signal.wait(d2.timeline_value - 1)
|
414
|
-
return d2.timeline_signal.timestamp - d1.timeline_signal.timestamp
|
415
|
-
|
416
|
-
# then test it by timing the GPU to GPU times
|
417
|
-
jitter_matrix = [[float('nan')]*len(self.devices) for _ in range(len(self.devices))]
|
418
|
-
for i1, d1 in enumerate(self.devices):
|
419
|
-
for i2, d2 in enumerate(self.devices):
|
420
|
-
if d1 == d2: continue
|
421
|
-
d1_to_d2 = statistics.median(_sync_gpu_to_gpu_queue(d1, d2, d1.hw_compute_queue_t, d2.hw_compute_queue_t) - \
|
422
|
-
_sync_gpu_to_gpu_queue(d2, d1, d2.hw_compute_queue_t, d1.hw_compute_queue_t) for _ in range(20)) / 2
|
423
|
-
jitter_matrix[i1][i2] = d1_to_d2 - (d1.gpu2cpu_compute_time_diff - d2.gpu2cpu_compute_time_diff)
|
424
|
-
print("pairwise clock jitter matrix (us):\n" + '\n'.join([''.join([f'{float(item):8.3f}' for item in row]) for row in jitter_matrix]))
|
425
|
-
|
426
|
-
def _gpu2cpu_time(self, gpu_time:decimal.Decimal, is_copy:bool) -> float:
|
427
|
-
"""
|
428
|
-
Translates local gpu time (timestamp) into global cpu time.
|
429
|
-
"""
|
430
|
-
self._ensure_shared_time_base()
|
431
|
-
return float(gpu_time + (self.gpu2cpu_copy_time_diff if is_copy else self.gpu2cpu_compute_time_diff))
|
432
|
-
|
433
|
-
def _prof_setup(self):
|
434
|
-
if hasattr(self, 'profile_logger'): return
|
435
|
-
atexit.register(self._prof_finalize)
|
436
|
-
self.profile_logger = ProfileLogger()
|
437
|
-
|
438
|
-
def _prof_finalize(self):
|
439
|
-
qname = ["COMPUTE", "DMA"]
|
440
|
-
|
441
|
-
# Sync to be sure all events on the device are recorded.
|
442
|
-
self.synchronize()
|
443
|
-
|
444
|
-
for st, en, name, is_cp, args in self.raw_prof_records:
|
445
|
-
self.profile_logger.events += [(name, self._gpu2cpu_time(st, is_cp), self._gpu2cpu_time(en, is_cp), self.dname, qname[is_cp], args)]
|
446
|
-
for a_st, a_en, a_dev, a_is_copy, b_st, b_en, b_dev, b_is_copy in self.dep_prof_records:
|
447
|
-
# Perfetto connects nodes based on timing data, ensuring every choice is valid by averaging times to a midpoint.
|
448
|
-
a_tm, b_tm = a_dev._gpu2cpu_time((a_st+a_en)/decimal.Decimal(2), a_is_copy), b_dev._gpu2cpu_time((b_st+b_en)/decimal.Decimal(2), b_is_copy)
|
449
|
-
self.profile_logger.deps += [(a_tm, b_tm, a_dev.dname, qname[a_is_copy], b_dev.dname, qname[b_is_copy])]
|
450
|
-
self.raw_prof_records, self.dep_prof_records = [], []
|
451
|
-
|
452
|
-
# Remove the logger, this flushes all data written by the device.
|
453
|
-
del self.profile_logger
|
371
|
+
gpu2cpu_compute_time_diff = statistics.median([_sync(self, self.hw_compute_queue_t) for _ in range(40)])
|
372
|
+
if self.hw_copy_queue_t is None: gpu2cpu_copy_time_diff = decimal.Decimal(0)
|
373
|
+
else: gpu2cpu_copy_time_diff = statistics.median([_sync(self, self.hw_copy_queue_t) for _ in range(40)])
|
374
|
+
Compiled.profile_events += [ProfileDeviceEvent(self.device, gpu2cpu_compute_time_diff, gpu2cpu_copy_time_diff)]
|
454
375
|
|
455
376
|
def _wrap_timeline_signal(self):
|
456
377
|
self.timeline_signal, self._shadow_timeline_signal, self.timeline_value = self._shadow_timeline_signal, self.timeline_signal, 1
|
457
378
|
self.timeline_signal.value = 0
|
458
|
-
cast(
|
379
|
+
cast(HCQAllocatorBase, self.allocator).b_timeline = [0] * len(cast(HCQAllocatorBase, self.allocator).b)
|
459
380
|
|
460
|
-
|
461
|
-
|
381
|
+
def _realloc(self, oldbuf:HCQBuffer|None, new_size:int, options:BufferSpec|None=None) -> tuple[HCQBuffer, bool]:
|
382
|
+
if oldbuf is not None: self.allocator.free(oldbuf, oldbuf.size, options=options)
|
383
|
+
try: buf, realloced = self.allocator.alloc(new_size, options=options), True
|
384
|
+
except MemoryError: buf, realloced = self.allocator.alloc(oldbuf.size if oldbuf is not None else new_size, options=options), False
|
385
|
+
return buf, realloced
|
462
386
|
|
463
|
-
class
|
387
|
+
class HCQBuffer:
|
388
|
+
def __init__(self, va_addr:sint, size:int, texture_info:Any=None, meta:Any=None, _base:HCQBuffer|None=None):
|
389
|
+
self.va_addr, self.size, self.texture_info, self.meta, self._base = va_addr, size, texture_info, meta, _base
|
390
|
+
|
391
|
+
class HCQAllocatorBase(LRUAllocator, Generic[DeviceType]):
|
464
392
|
"""
|
465
393
|
A base allocator class compatible with the HCQ (Hardware Command Queue) API.
|
466
394
|
|
467
|
-
This class implements basic copy operations following the HCQ API, utilizing both
|
395
|
+
This class implements basic copy operations following the HCQ API, utilizing both types of `HWQueue`.
|
468
396
|
"""
|
469
397
|
|
470
|
-
def __init__(self,
|
471
|
-
self.
|
472
|
-
self.b = [self._alloc(batch_size,
|
398
|
+
def __init__(self, dev:DeviceType, batch_size:int=(2 << 20), batch_cnt:int=32):
|
399
|
+
self.dev:DeviceType = dev
|
400
|
+
self.b = [self._alloc(batch_size, BufferSpec(host=True)) for _ in range(batch_cnt)]
|
473
401
|
self.b_timeline, self.b_next = [0] * len(self.b), 0
|
474
402
|
super().__init__()
|
475
403
|
|
476
|
-
def
|
404
|
+
def map(self, buf:HCQBuffer): pass
|
405
|
+
|
406
|
+
def _offset(self, buf, size:int, offset:int) -> HCQBuffer:
|
407
|
+
return HCQBuffer(va_addr=buf.va_addr + offset, size=size, texture_info=buf.texture_info, meta=buf.meta, _base=buf._base or buf)
|
477
408
|
|
478
|
-
|
479
|
-
|
409
|
+
class HCQAllocator(HCQAllocatorBase, Generic[DeviceType]):
|
410
|
+
def _copyin(self, dest:HCQBuffer, src:memoryview):
|
411
|
+
assert self.dev.hw_copy_queue_t is not None
|
412
|
+
with hcq_profile(self.dev, queue_type=self.dev.hw_copy_queue_t, desc=f"CPU -> {self.dev.device}", enabled=PROFILE):
|
480
413
|
for i in range(0, src.nbytes, self.b[0].size):
|
481
414
|
self.b_next = (self.b_next + 1) % len(self.b)
|
482
|
-
self.
|
415
|
+
self.dev.timeline_signal.wait(self.b_timeline[self.b_next])
|
483
416
|
ctypes.memmove(self.b[self.b_next].va_addr, from_mv(src[i:]), lsize:=min(self.b[self.b_next].size, src.nbytes-i))
|
484
|
-
self.
|
485
|
-
|
486
|
-
|
487
|
-
self.b_timeline[self.b_next] = self.
|
488
|
-
self.
|
417
|
+
self.dev.hw_copy_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1) \
|
418
|
+
.copy(dest.va_addr+i, self.b[self.b_next].va_addr, lsize) \
|
419
|
+
.signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
|
420
|
+
self.b_timeline[self.b_next] = self.dev.timeline_value
|
421
|
+
self.dev.timeline_value += 1
|
489
422
|
|
490
423
|
def copy_from_disk(self, dest:HCQBuffer, src, size):
|
491
424
|
def _get_temp_buf():
|
492
425
|
# Check if the next buffer is safe to be used (its signal has passed) and reserve it.
|
493
|
-
if self.b_timeline[(self.b_next + 1) % len(self.b)] <= self.
|
426
|
+
if self.b_timeline[(self.b_next + 1) % len(self.b)] <= self.dev.timeline_signal.value:
|
494
427
|
self.b_timeline[(self.b_next + 1) % len(self.b)], self.b_next = (1 << 64), (self.b_next + 1) % len(self.b)
|
495
428
|
return (self.b[self.b_next].va_addr, self.b_next)
|
496
429
|
return None
|
497
430
|
|
498
|
-
|
431
|
+
assert self.dev.hw_copy_queue_t is not None
|
432
|
+
with hcq_profile(self.dev, queue_type=self.dev.hw_copy_queue_t, desc=f"DISK -> {self.dev.device}", enabled=PROFILE):
|
499
433
|
for (batch_info, dst_off, src_off, copy_size) in src.device.allocator._copyout_sharded(src, size, _get_temp_buf, seg_len=self.b[0].size):
|
500
|
-
self.
|
501
|
-
|
502
|
-
|
503
|
-
self.b_timeline[batch_info[1]] = self.
|
504
|
-
self.
|
434
|
+
self.dev.hw_copy_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1) \
|
435
|
+
.copy(dest.va_addr + dst_off, batch_info[0] + src_off, copy_size) \
|
436
|
+
.signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
|
437
|
+
self.b_timeline[batch_info[1]] = self.dev.timeline_value
|
438
|
+
self.dev.timeline_value += 1
|
505
439
|
|
506
|
-
def
|
507
|
-
self.
|
440
|
+
def _copyout(self, dest:memoryview, src:HCQBuffer):
|
441
|
+
self.dev.synchronize()
|
508
442
|
|
509
|
-
|
443
|
+
assert self.dev.hw_copy_queue_t is not None
|
444
|
+
with hcq_profile(self.dev, queue_type=self.dev.hw_copy_queue_t, desc=f"{self.dev.device} -> CPU", enabled=PROFILE):
|
510
445
|
for i in range(0, dest.nbytes, self.b[0].size):
|
511
|
-
self.
|
512
|
-
|
513
|
-
|
514
|
-
self.
|
515
|
-
self.
|
446
|
+
self.dev.hw_copy_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1) \
|
447
|
+
.copy(self.b[0].va_addr, src.va_addr+i, lsize:=min(self.b[0].size, dest.nbytes-i)) \
|
448
|
+
.signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
|
449
|
+
self.dev.timeline_signal.wait(self.dev.timeline_value)
|
450
|
+
self.dev.timeline_value += 1
|
516
451
|
|
517
452
|
ctypes.memmove(from_mv(dest[i:]), self.b[0].va_addr, lsize)
|
518
453
|
|
519
|
-
def
|
520
|
-
src_dev.allocator.map(dest)
|
454
|
+
def _transfer(self, dest:HCQBuffer, src:HCQBuffer, sz:int, src_dev:DeviceType, dest_dev:DeviceType):
|
455
|
+
cast(HCQAllocator, src_dev.allocator).map(dest)
|
521
456
|
|
522
|
-
|
457
|
+
assert src_dev.hw_copy_queue_t is not None
|
458
|
+
with hcq_profile(src_dev, queue_type=src_dev.hw_copy_queue_t, desc=f"{src_dev.device} -> {dest_dev.device}", enabled=PROFILE):
|
523
459
|
src_dev.hw_copy_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
|
524
460
|
.wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
|
525
461
|
.copy(dest.va_addr, src.va_addr, sz) \
|
@@ -531,9 +467,3 @@ class HCQAllocator(LRUAllocator): # pylint: disable=abstract-method
|
|
531
467
|
.wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
|
532
468
|
.signal(dest_dev.timeline_signal, dest_dev.timeline_value).submit(dest_dev)
|
533
469
|
dest_dev.timeline_value += 1
|
534
|
-
|
535
|
-
def map(self, buf:HCQBuffer): pass
|
536
|
-
|
537
|
-
def offset(self, buf, size:int, offset:int) -> HCQBuffer:
|
538
|
-
return type(buf)(va_addr=buf.va_addr + offset, size=size, **{k:v for k,v in buf.__dict__.items() if k not in ['va_addr', 'size']},
|
539
|
-
**{x[0]:getattr(buf, x[0]) for x in getattr(buf, '_fields_', []) if x[0] not in ['va_addr', 'size']}, _base=buf)
|