tinygrad 0.10.0__py3-none-any.whl → 0.10.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/devectorizer.py +247 -0
- tinygrad/codegen/expander.py +121 -0
- tinygrad/codegen/kernel.py +141 -201
- tinygrad/codegen/linearize.py +223 -84
- tinygrad/codegen/lowerer.py +60 -42
- tinygrad/codegen/symbolic.py +476 -0
- tinygrad/codegen/transcendental.py +22 -13
- tinygrad/device.py +187 -47
- tinygrad/dtype.py +39 -28
- tinygrad/engine/jit.py +83 -65
- tinygrad/engine/memory.py +4 -5
- tinygrad/engine/multi.py +161 -0
- tinygrad/engine/realize.py +62 -108
- tinygrad/engine/schedule.py +396 -357
- tinygrad/engine/search.py +55 -66
- tinygrad/gradient.py +73 -0
- tinygrad/helpers.py +81 -59
- tinygrad/nn/__init__.py +30 -32
- tinygrad/nn/datasets.py +1 -2
- tinygrad/nn/optim.py +22 -26
- tinygrad/nn/state.py +91 -66
- tinygrad/ops.py +492 -641
- tinygrad/renderer/__init__.py +95 -36
- tinygrad/renderer/cstyle.py +99 -92
- tinygrad/renderer/llvmir.py +83 -34
- tinygrad/renderer/ptx.py +83 -99
- tinygrad/renderer/wgsl.py +95 -0
- tinygrad/runtime/autogen/amd_gpu.py +39507 -12
- tinygrad/runtime/autogen/comgr.py +2 -0
- tinygrad/runtime/autogen/kfd.py +4 -3
- tinygrad/runtime/autogen/kgsl.py +1 -1
- tinygrad/runtime/autogen/libc.py +404 -71
- tinygrad/runtime/autogen/llvm.py +11379 -0
- tinygrad/runtime/autogen/pci.py +1333 -0
- tinygrad/runtime/autogen/vfio.py +891 -0
- tinygrad/runtime/autogen/webgpu.py +6985 -0
- tinygrad/runtime/graph/cuda.py +8 -9
- tinygrad/runtime/graph/hcq.py +84 -79
- tinygrad/runtime/graph/metal.py +40 -43
- tinygrad/runtime/ops_amd.py +498 -334
- tinygrad/runtime/ops_cloud.py +34 -34
- tinygrad/runtime/ops_cpu.py +24 -0
- tinygrad/runtime/ops_cuda.py +30 -27
- tinygrad/runtime/ops_disk.py +62 -63
- tinygrad/runtime/ops_dsp.py +159 -42
- tinygrad/runtime/ops_gpu.py +30 -30
- tinygrad/runtime/ops_hip.py +29 -31
- tinygrad/runtime/ops_llvm.py +48 -41
- tinygrad/runtime/ops_metal.py +149 -113
- tinygrad/runtime/ops_npy.py +2 -2
- tinygrad/runtime/ops_nv.py +238 -273
- tinygrad/runtime/ops_python.py +55 -50
- tinygrad/runtime/ops_qcom.py +129 -157
- tinygrad/runtime/ops_webgpu.py +225 -0
- tinygrad/runtime/support/allocator.py +94 -0
- tinygrad/runtime/support/am/__init__.py +0 -0
- tinygrad/runtime/support/am/amdev.py +396 -0
- tinygrad/runtime/support/am/ip.py +463 -0
- tinygrad/runtime/support/compiler_cuda.py +4 -2
- tinygrad/runtime/support/elf.py +28 -4
- tinygrad/runtime/support/hcq.py +256 -324
- tinygrad/runtime/support/llvm.py +26 -0
- tinygrad/shape/shapetracker.py +85 -53
- tinygrad/shape/view.py +104 -140
- tinygrad/spec.py +155 -0
- tinygrad/tensor.py +835 -527
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
- tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
- tinygrad/viz/index.html +544 -0
- tinygrad/viz/perfetto.html +178 -0
- tinygrad/viz/serve.py +205 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
- tinygrad-0.10.2.dist-info/RECORD +99 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
- tinygrad/codegen/uopgraph.py +0 -506
- tinygrad/engine/lazy.py +0 -228
- tinygrad/function.py +0 -212
- tinygrad/multi.py +0 -177
- tinygrad/runtime/graph/clang.py +0 -39
- tinygrad/runtime/ops_clang.py +0 -35
- tinygrad-0.10.0.dist-info/RECORD +0 -77
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/runtime/support/hcq.py
CHANGED
@@ -1,48 +1,106 @@
|
|
1
1
|
from __future__ import annotations
|
2
|
-
from typing import
|
3
|
-
import contextlib, decimal, statistics,
|
4
|
-
from tinygrad.helpers import
|
2
|
+
from typing import cast, Type, TypeVar, Generic, Any
|
3
|
+
import contextlib, decimal, statistics, time, ctypes, array, os, fcntl
|
4
|
+
from tinygrad.helpers import PROFILE, from_mv, getenv, to_mv, round_up
|
5
5
|
from tinygrad.renderer import Renderer
|
6
|
-
from tinygrad.device import
|
6
|
+
from tinygrad.device import BufferSpec, Compiler, Compiled, LRUAllocator, ProfileRangeEvent, ProfileDeviceEvent
|
7
|
+
from tinygrad.ops import sym_infer, sint, Variable, UOp
|
8
|
+
from tinygrad.runtime.autogen import libc
|
7
9
|
|
8
|
-
|
9
|
-
|
10
|
-
def hcq_command(func):
|
10
|
+
class HWInterface:
|
11
11
|
"""
|
12
|
-
|
13
|
-
|
14
|
-
For example:
|
15
|
-
```python
|
16
|
-
@hcq_command
|
17
|
-
def command_method(self, ...): ...
|
18
|
-
```
|
12
|
+
Hardware Abstraction Layer for HCQ devices. The class provides a unified interface for interacting with hardware devices.
|
19
13
|
"""
|
20
|
-
def __wrapper(self, *args, **kwargs):
|
21
|
-
self.cmds_offset.append(len(self.q))
|
22
|
-
func(self, *args, **kwargs)
|
23
|
-
self.cmds_len.append(len(self.q) - self.cmds_offset[-1])
|
24
|
-
self.cmds_meta.append(func.__name__)
|
25
|
-
return self
|
26
|
-
return __wrapper
|
27
14
|
|
28
|
-
|
15
|
+
def __init__(self, path:str="", flags:int=os.O_RDONLY, fd:int|None=None):
|
16
|
+
self.path:str = path
|
17
|
+
self.fd:int = fd or os.open(path, flags)
|
18
|
+
def __del__(self):
|
19
|
+
if hasattr(self, 'fd'): os.close(self.fd)
|
20
|
+
def ioctl(self, request, arg): return fcntl.ioctl(self.fd, request, arg)
|
21
|
+
def mmap(self, start, sz, prot, flags, offset): return libc.mmap(start, sz, prot, flags, self.fd, offset)
|
22
|
+
def read(self, size=None, binary=False, offset=None):
|
23
|
+
if offset is not None: self.seek(offset)
|
24
|
+
with open(self.fd, "rb" if binary else "r", closefd=False) as file: return file.read(size)
|
25
|
+
def write(self, content, binary=False, offset=None):
|
26
|
+
if offset is not None: self.seek(offset)
|
27
|
+
with open(self.fd, "wb" if binary else "w", closefd=False) as file: file.write(content)
|
28
|
+
def listdir(self): return os.listdir(self.path)
|
29
|
+
def seek(self, offset): os.lseek(self.fd, offset, os.SEEK_SET)
|
30
|
+
@staticmethod
|
31
|
+
def anon_mmap(start, sz, prot, flags, offset): return libc.mmap(start, sz, prot, flags, -1, offset)
|
32
|
+
@staticmethod
|
33
|
+
def munmap(buf, sz): return libc.munmap(buf, sz)
|
34
|
+
@staticmethod
|
35
|
+
def exists(path): return os.path.exists(path)
|
36
|
+
@staticmethod
|
37
|
+
def readlink(path): return os.readlink(path)
|
38
|
+
@staticmethod
|
39
|
+
def eventfd(initval, flags=None): return HWInterface(fd=os.eventfd(initval, flags)) # type: ignore[attr-defined]
|
40
|
+
|
41
|
+
if MOCKGPU:=getenv("MOCKGPU"): from test.mockgpu.mockgpu import MockHWInterface as HWInterface # noqa: F401 # pylint: disable=unused-import
|
42
|
+
|
43
|
+
# **************** for HCQ Compatible Devices ****************
|
44
|
+
|
45
|
+
SignalType = TypeVar('SignalType', bound='HCQSignal')
|
46
|
+
DeviceType = TypeVar('DeviceType', bound='HCQCompiled')
|
47
|
+
ProgramType = TypeVar('ProgramType', bound='HCQProgram')
|
48
|
+
ArgsStateType = TypeVar('ArgsStateType', bound='HCQArgsState')
|
49
|
+
QueueType = TypeVar('QueueType', bound='HWQueue')
|
50
|
+
|
51
|
+
class BumpAllocator:
|
52
|
+
def __init__(self, size:int, base:int=0, wrap:bool=True): self.size, self.ptr, self.base, self.wrap = size, 0, base, wrap
|
53
|
+
def alloc(self, size:int, alignment:int=1) -> int:
|
54
|
+
if round_up(self.ptr, alignment) + size > self.size:
|
55
|
+
if not self.wrap: raise RuntimeError("Out of memory")
|
56
|
+
self.ptr = 0
|
57
|
+
self.ptr = (res:=round_up(self.ptr, alignment)) + size
|
58
|
+
return res + self.base
|
59
|
+
|
60
|
+
class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
|
29
61
|
"""
|
30
62
|
A base class for hardware command queues in the HCQ (Hardware Command Queue) API.
|
31
|
-
Both compute and copy queues should have the following commands implemented.
|
32
63
|
"""
|
33
64
|
|
34
|
-
def __init__(self):
|
35
|
-
|
36
|
-
|
37
|
-
|
65
|
+
def __init__(self):
|
66
|
+
self._q:Any = []
|
67
|
+
self.binded_device:DeviceType|None = None
|
68
|
+
self.q_sints:list[tuple[int, int]] = []
|
69
|
+
self.mv_sints:list[tuple[memoryview, int, int, int|None]] = []
|
70
|
+
self.syms:list[sint] = []
|
71
|
+
self._prev_resolved_syms:list[int|None] = []
|
72
|
+
|
73
|
+
def _new_sym(self, sym:sint) -> int:
|
74
|
+
if sym not in self.syms:
|
75
|
+
self.syms.append(sym)
|
76
|
+
self._prev_resolved_syms.append(None)
|
77
|
+
return self.syms.index(sym)
|
78
|
+
|
79
|
+
def q(self, *values):
|
38
80
|
"""
|
39
|
-
|
40
|
-
|
81
|
+
Enqueues values in the queue.
|
82
|
+
|
83
|
+
Args:
|
84
|
+
values: The values to enqueue in the queue.
|
41
85
|
"""
|
42
|
-
return len(self) - 1
|
43
86
|
|
44
|
-
|
45
|
-
|
87
|
+
for v in values:
|
88
|
+
if isinstance(v, UOp):
|
89
|
+
self.q_sints.append((len(self._q), self._new_sym(v)))
|
90
|
+
self._q.append(0xbadc0ded)
|
91
|
+
else: self._q.append(v)
|
92
|
+
|
93
|
+
# *** common commands ***
|
94
|
+
|
95
|
+
def timestamp(self, signal:SignalType):
|
96
|
+
"""
|
97
|
+
Enqueues a timestamp command which records the current time in a signal after all previously enqueued commands are completed.
|
98
|
+
|
99
|
+
Args:
|
100
|
+
signal: The signal to store the timestamp
|
101
|
+
"""
|
102
|
+
|
103
|
+
def signal(self, signal:SignalType, value:sint):
|
46
104
|
"""
|
47
105
|
Enqueues a signal command which sets the signal to the given value, ensuring all previous operations are completed.
|
48
106
|
|
@@ -50,11 +108,8 @@ class HWCommandQueue:
|
|
50
108
|
signal: The signal to set
|
51
109
|
value: The value to set the signal to
|
52
110
|
"""
|
53
|
-
self._signal(signal, value)
|
54
|
-
def _signal(self, signal:HCQSignal, value:int): raise NotImplementedError("backend should overload this function")
|
55
111
|
|
56
|
-
|
57
|
-
def wait(self, signal:HCQSignal, value:int):
|
112
|
+
def wait(self, signal:SignalType, value:sint):
|
58
113
|
"""
|
59
114
|
Enqueues a wait command which halts execution until the signal is greater than or equal to a specific value.
|
60
115
|
|
@@ -62,49 +117,40 @@ class HWCommandQueue:
|
|
62
117
|
signal: The signal to wait on
|
63
118
|
value: The value to wait for
|
64
119
|
"""
|
65
|
-
self._wait(signal, value)
|
66
|
-
def _wait(self, signal, value): raise NotImplementedError("backend should overload this function")
|
67
120
|
|
68
|
-
|
69
|
-
def timestamp(self, signal:HCQSignal):
|
70
|
-
"""
|
71
|
-
Enqueues a timestamp command which records the current time in a signal after all previously enqueued commands are completed.
|
121
|
+
# *** commands for compute queues ***
|
72
122
|
|
73
|
-
|
74
|
-
|
123
|
+
def memory_barrier(self):
|
124
|
+
"""
|
125
|
+
Enqueues a memory barrier command to ensure memory coherence between agents. Only on compute queues.
|
75
126
|
"""
|
76
|
-
self._timestamp(signal)
|
77
|
-
def _timestamp(self, signal): raise NotImplementedError("backend should overload this function")
|
78
127
|
|
79
|
-
def
|
128
|
+
def exec(self, prg:ProgramType, args_state:ArgsStateType, global_size:tuple[sint, ...], local_size:tuple[sint, ...]):
|
80
129
|
"""
|
81
|
-
|
130
|
+
Enqueues an execution command for a kernel program. Only on compute queues.
|
82
131
|
|
83
132
|
Args:
|
84
|
-
|
85
|
-
|
86
|
-
|
133
|
+
prg: The program to execute
|
134
|
+
args_state: The args state to execute program with
|
135
|
+
global_size: The global work size
|
136
|
+
local_size: The local work size
|
87
137
|
"""
|
88
|
-
if self.cmds_meta[cmd_idx] != "signal": raise RuntimeError("called update_signal not on a signal command")
|
89
|
-
self._update_signal(cmd_idx, signal, value)
|
90
|
-
return self
|
91
|
-
def _update_signal(self, cmd_idx:int, signal:Optional[Any], value:Optional[int]): raise NotImplementedError("backend should overload this function")
|
92
138
|
|
93
|
-
|
139
|
+
# *** commands for copy queues ***
|
140
|
+
|
141
|
+
def copy(self, dest:sint, src:sint, copy_size:int):
|
94
142
|
"""
|
95
|
-
|
143
|
+
Enqueues a copy command to transfer data. Only on copy queues.
|
96
144
|
|
97
145
|
Args:
|
98
|
-
|
99
|
-
|
100
|
-
|
146
|
+
dest: The destination of the copy
|
147
|
+
src: The source of the copy
|
148
|
+
copy_size: The size of data to copy
|
101
149
|
"""
|
102
|
-
if self.cmds_meta[cmd_idx] != "wait": raise RuntimeError("called update_wait not on a wait command")
|
103
|
-
self._update_wait(cmd_idx, signal, value)
|
104
|
-
return self
|
105
|
-
def _update_wait(self, cmd_idx:int, signal:Optional[Any], value:Optional[int]): raise NotImplementedError("backend should overload this function")
|
106
150
|
|
107
|
-
|
151
|
+
# *** submit and bind commands ***
|
152
|
+
|
153
|
+
def bind(self, dev:DeviceType):
|
108
154
|
"""
|
109
155
|
Associates the queue with a specific device for optimized execution.
|
110
156
|
|
@@ -112,99 +158,65 @@ class HWCommandQueue:
|
|
112
158
|
the need to copy queues into the device, thereby enhancing performance.
|
113
159
|
|
114
160
|
Args:
|
115
|
-
|
161
|
+
dev: The target device for queue optimization.
|
116
162
|
|
117
163
|
Note:
|
118
164
|
Implementing this method is optional but recommended for performance gains.
|
119
165
|
"""
|
120
166
|
|
121
|
-
def
|
122
|
-
|
123
|
-
Submits the command queue to a specific device for execution.
|
167
|
+
def bind_args_state(self, args_state:ArgsStateType):
|
168
|
+
for vals, ptr, fmt in args_state.bind_data: self.bind_sints_to_ptr(*vals, ptr=ptr, fmt=fmt)
|
124
169
|
|
125
|
-
|
126
|
-
|
127
|
-
"""
|
128
|
-
if self.q: self._submit(device)
|
129
|
-
return self
|
130
|
-
def _submit(self, device:HCQCompiled): raise NotImplementedError("backend should overload this function")
|
170
|
+
def bind_sints(self, *vals:sint, struct:ctypes.Structure, start_field:str, fmt, mask:int|None=None):
|
171
|
+
self.bind_sints_to_ptr(*vals, ptr=ctypes.addressof(struct) + getattr(type(struct), start_field).offset, fmt=fmt, mask=mask)
|
131
172
|
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
"""
|
138
|
-
self._memory_barrier()
|
139
|
-
def _memory_barrier(self): pass
|
173
|
+
def bind_sints_to_ptr(self, *vals:sint, ptr:int, fmt, mask:int|None=None):
|
174
|
+
mv = to_mv(ptr, 8*len(vals)).cast(fmt)
|
175
|
+
for i, val in enumerate(vals):
|
176
|
+
if isinstance(val, int): mv[i] = val if mask is None else ((mv[i] & ~mask) | val)
|
177
|
+
else: self.mv_sints.append((mv, i, self._new_sym(val), mask))
|
140
178
|
|
141
|
-
|
142
|
-
|
143
|
-
"""
|
144
|
-
Enqueues an execution command for a kernel program.
|
179
|
+
def _apply_var_vals(self, var_vals:dict[Variable, int]):
|
180
|
+
resolved_syms = [sym_infer(sym, var_vals) for sym in self.syms]
|
145
181
|
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
global_size: The global work size
|
150
|
-
local_size: The local work size
|
151
|
-
"""
|
152
|
-
self._exec(prg, args_state, global_size, local_size)
|
153
|
-
def _exec(self, prg, args_state, global_size, local_size): raise NotImplementedError("backend should overload this function")
|
182
|
+
for off, sym_idx in self.q_sints:
|
183
|
+
if self._prev_resolved_syms[sym_idx] == resolved_syms[sym_idx]: continue
|
184
|
+
self._q[off] = resolved_syms[sym_idx]
|
154
185
|
|
155
|
-
|
156
|
-
|
157
|
-
|
186
|
+
for mv, off, sym_idx, mask in self.mv_sints:
|
187
|
+
if self._prev_resolved_syms[sym_idx] == resolved_syms[sym_idx]: continue
|
188
|
+
mv[off] = resolved_syms[sym_idx] if mask is None else ((mv[off] & ~mask) | resolved_syms[sym_idx])
|
158
189
|
|
159
|
-
|
160
|
-
cmd_idx: Index of the execution command to update
|
161
|
-
global_size: New global work size (if None, keeps the original)
|
162
|
-
local_size: New local work size (if None, keeps the original)
|
163
|
-
"""
|
164
|
-
if self.cmds_meta[cmd_idx] != "exec": raise RuntimeError("called update_exec not on an exec command")
|
165
|
-
self._update_exec(cmd_idx, global_size, local_size)
|
166
|
-
return self
|
167
|
-
def _update_exec(self, cmd_idx, global_size, local_size): raise NotImplementedError("backend should overload this function")
|
190
|
+
self._prev_resolved_syms = cast(list[int|None], resolved_syms)
|
168
191
|
|
169
|
-
|
170
|
-
@hcq_command
|
171
|
-
def copy(self, dest:HCQBuffer, src:HCQBuffer, copy_size:int):
|
192
|
+
def submit(self, dev:DeviceType, var_vals:dict[Variable, int]|None=None):
|
172
193
|
"""
|
173
|
-
|
194
|
+
Submits the command queue to a specific device for execution.
|
174
195
|
|
175
196
|
Args:
|
176
|
-
|
177
|
-
src: The source of the copy
|
178
|
-
copy_size: The size of data to copy
|
197
|
+
dev: The device to submit the queue to
|
179
198
|
"""
|
180
|
-
self._copy(dest, src, copy_size)
|
181
|
-
def _copy(self, dest:HCQBuffer, src:HCQBuffer, copy_size:int): raise NotImplementedError("backend should overload this function")
|
182
199
|
|
183
|
-
|
184
|
-
|
185
|
-
Updates a previously queued copy command.
|
186
|
-
|
187
|
-
Args:
|
188
|
-
cmd_idx: Index of the copy command to update
|
189
|
-
dest: New destination of the copy (if None, keeps the original)
|
190
|
-
src: New source of the copy (if None, keeps the original)
|
191
|
-
"""
|
192
|
-
if self.cmds_meta[cmd_idx] != "copy": raise RuntimeError("called update_copy not on an copy command")
|
193
|
-
self._update_copy(cmd_idx, dest, src)
|
200
|
+
if var_vals is not None: self._apply_var_vals(var_vals)
|
201
|
+
self._submit(dev)
|
194
202
|
return self
|
195
|
-
def
|
203
|
+
def _submit(self, dev:DeviceType): raise NotImplementedError("need _submit")
|
196
204
|
|
197
|
-
class HCQSignal:
|
198
|
-
def __init__(self, value:int=0,
|
205
|
+
class HCQSignal(Generic[DeviceType]):
|
206
|
+
def __init__(self, base_addr:sint=0, value:int=0, timeline_for_device:DeviceType|None=None, timestamp_divider=1, value_off=0, timestamp_off=8):
|
207
|
+
self.base_addr, self.value_addr, self.timestamp_addr = base_addr, base_addr+value_off, base_addr+timestamp_off
|
208
|
+
self.timestamp_divider:decimal.Decimal = decimal.Decimal(timestamp_divider)
|
209
|
+
self.timeline_for_device:DeviceType|None = timeline_for_device
|
210
|
+
|
211
|
+
if isinstance(base_addr, int):
|
212
|
+
self.value_mv, self.timestamp_mv = to_mv(self.value_addr, 8).cast('Q'), to_mv(self.timestamp_addr, 8).cast('Q')
|
213
|
+
self.value_mv[0] = value
|
199
214
|
|
200
215
|
@property
|
201
|
-
def value(self) -> int: return self.
|
216
|
+
def value(self) -> int: return self.value_mv[0]
|
202
217
|
|
203
218
|
@value.setter
|
204
|
-
def value(self, new_value:int): self.
|
205
|
-
|
206
|
-
def _get_value(self) -> int: raise NotImplementedError("_get_value() method must be implemented")
|
207
|
-
def _set_value(self, new_value:int): raise NotImplementedError("_set_value() method must be implemented")
|
219
|
+
def value(self, new_value:int): self.value_mv[0] = new_value
|
208
220
|
|
209
221
|
@property
|
210
222
|
def timestamp(self) -> decimal.Decimal:
|
@@ -216,8 +228,12 @@ class HCQSignal:
|
|
216
228
|
Returns:
|
217
229
|
The timestamp in microseconds.
|
218
230
|
"""
|
219
|
-
return self.
|
220
|
-
|
231
|
+
return self.timestamp_mv[0] / self.timestamp_divider
|
232
|
+
|
233
|
+
def _sleep(self, time_spent_waiting_ms:int):
|
234
|
+
"""
|
235
|
+
Optional function which can implement sleep functionality for the signal.
|
236
|
+
"""
|
221
237
|
|
222
238
|
def wait(self, value:int, timeout:int=getenv("HCQDEV_WAIT_TIMEOUT_MS", 30000)):
|
223
239
|
"""
|
@@ -227,17 +243,18 @@ class HCQSignal:
|
|
227
243
|
value: The value to wait for.
|
228
244
|
timeout: Maximum time to wait in milliseconds. Defaults to 10s.
|
229
245
|
"""
|
230
|
-
start_time = time.
|
231
|
-
while time.
|
232
|
-
|
233
|
-
raise RuntimeError(f"Wait timeout: {timeout} ms! (the signal is not set to {value}, but {self.value})")
|
246
|
+
start_time = int(time.perf_counter() * 1000)
|
247
|
+
while self.value < value and (time_spent:=int(time.perf_counter() * 1000) - start_time) < timeout:
|
248
|
+
self._sleep(time_spent)
|
249
|
+
if self.value < value: raise RuntimeError(f"Wait timeout: {timeout} ms! (the signal is not set to {value}, but {self.value})")
|
234
250
|
|
235
251
|
@contextlib.contextmanager
|
236
|
-
def hcq_profile(dev, enabled, desc, queue_type=None, queue=None):
|
252
|
+
def hcq_profile(dev:HCQCompiled, enabled, desc, queue_type:Type[HWQueue]|None=None, queue:HWQueue|None=None):
|
237
253
|
st, en = (dev.signal_t(), dev.signal_t()) if enabled else (None, None)
|
238
254
|
|
239
255
|
if enabled and queue is not None: queue.timestamp(st)
|
240
256
|
elif enabled:
|
257
|
+
assert queue_type is not None
|
241
258
|
queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(st).signal(dev.timeline_signal, dev.timeline_value).submit(dev)
|
242
259
|
dev.timeline_value += 1
|
243
260
|
|
@@ -245,21 +262,33 @@ def hcq_profile(dev, enabled, desc, queue_type=None, queue=None):
|
|
245
262
|
finally:
|
246
263
|
if enabled and queue is not None: queue.timestamp(en)
|
247
264
|
elif enabled:
|
265
|
+
assert queue_type is not None
|
248
266
|
queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(en).signal(dev.timeline_signal, dev.timeline_value).submit(dev)
|
249
267
|
dev.timeline_value += 1
|
250
268
|
|
251
|
-
if enabled and PROFILE: dev.sig_prof_records.append((st, en, desc, queue_type is dev.hw_copy_queue_t))
|
269
|
+
if enabled and PROFILE: dev.sig_prof_records.append((cast(HCQSignal, st), cast(HCQSignal, en), desc, queue_type is dev.hw_copy_queue_t))
|
270
|
+
|
271
|
+
class HCQArgsState(Generic[ProgramType]):
|
272
|
+
def __init__(self, ptr:int, prg:ProgramType, bufs:tuple[HCQBuffer, ...], vals:tuple[sint, ...]=()):
|
273
|
+
self.ptr, self.prg = ptr, prg
|
274
|
+
self.bind_data:list[tuple[tuple[sint, ...], int, str]] = []
|
275
|
+
|
276
|
+
def bind_sints_to_ptr(self, *vals:sint, ptr:int, fmt): self.bind_data.append((vals, ptr, fmt))
|
277
|
+
|
278
|
+
class CLikeArgsState(HCQArgsState[ProgramType]):
|
279
|
+
def __init__(self, ptr:int, prg:ProgramType, bufs:tuple[HCQBuffer, ...], vals:tuple[sint, ...]=(), prefix:list[int]|None=None):
|
280
|
+
super().__init__(ptr, prg, bufs, vals=vals)
|
281
|
+
|
282
|
+
if prefix is not None: to_mv(self.ptr, len(prefix) * 4).cast('I')[:] = array.array('I', prefix)
|
252
283
|
|
253
|
-
|
254
|
-
|
255
|
-
def update_buffer(self, index:int, buf:HCQBuffer): raise NotImplementedError("need update_buffer")
|
256
|
-
def update_var(self, index:int, val:int): raise NotImplementedError("need update_var")
|
284
|
+
self.bind_sints_to_ptr(*[b.va_addr for b in bufs], ptr=self.ptr + len(prefix or []) * 4, fmt='Q')
|
285
|
+
self.bind_sints_to_ptr(*vals, ptr=self.ptr + len(prefix or []) * 4 + len(bufs) * 8, fmt='I')
|
257
286
|
|
258
|
-
class HCQProgram:
|
259
|
-
def __init__(self, args_state_t:Type[HCQArgsState],
|
260
|
-
self.args_state_t, self.
|
287
|
+
class HCQProgram(Generic[DeviceType]):
|
288
|
+
def __init__(self, args_state_t:Type[HCQArgsState], dev:DeviceType, name:str, kernargs_alloc_size:int):
|
289
|
+
self.args_state_t, self.dev, self.name, self.kernargs_alloc_size = args_state_t, dev, name, kernargs_alloc_size
|
261
290
|
|
262
|
-
def fill_kernargs(self, bufs:
|
291
|
+
def fill_kernargs(self, bufs:tuple[HCQBuffer, ...], vals:tuple[int, ...]=(), kernargs_ptr:int|None=None) -> HCQArgsState:
|
263
292
|
"""
|
264
293
|
Fills arguments for the kernel, optionally allocating space from the device if `kernargs_ptr` is not provided.
|
265
294
|
Args:
|
@@ -269,10 +298,10 @@ class HCQProgram:
|
|
269
298
|
Returns:
|
270
299
|
Arguments state with the given buffers and values set for the program.
|
271
300
|
"""
|
272
|
-
return self.args_state_t(kernargs_ptr or self.
|
301
|
+
return self.args_state_t(kernargs_ptr or self.dev.kernargs_allocator.alloc(self.kernargs_alloc_size), self, bufs, vals=vals)
|
273
302
|
|
274
|
-
def __call__(self, *bufs:HCQBuffer, global_size:
|
275
|
-
vals:
|
303
|
+
def __call__(self, *bufs:HCQBuffer, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1),
|
304
|
+
vals:tuple[int, ...]=(), wait:bool=False) -> float|None:
|
276
305
|
"""
|
277
306
|
Enqueues the program for execution with the given arguments and dimensions.
|
278
307
|
|
@@ -288,103 +317,52 @@ class HCQProgram:
|
|
288
317
|
"""
|
289
318
|
|
290
319
|
kernargs = self.fill_kernargs(bufs, vals)
|
291
|
-
q = self.
|
320
|
+
q = self.dev.hw_compute_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1).memory_barrier()
|
292
321
|
|
293
|
-
with hcq_profile(self.
|
322
|
+
with hcq_profile(self.dev, queue=q, desc=self.name, enabled=wait or PROFILE) as (sig_st, sig_en):
|
294
323
|
q.exec(self, kernargs, global_size, local_size)
|
295
324
|
|
296
|
-
q.signal(self.
|
297
|
-
self.
|
325
|
+
q.signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
|
326
|
+
self.dev.timeline_value += 1
|
298
327
|
|
299
|
-
if wait: self.
|
328
|
+
if wait: self.dev.synchronize()
|
300
329
|
return (float(sig_en.timestamp - sig_st.timestamp) / 1e6) if wait else None
|
301
330
|
|
302
|
-
class
|
303
|
-
writers: int = 0
|
304
|
-
mjson: List[Dict] = []
|
305
|
-
actors: Dict[Union[str, Tuple[str, str]], int] = {}
|
306
|
-
|
307
|
-
def __init__(self): self.events, self.deps, ProfileLogger.writers = [], [], ProfileLogger.writers + 1
|
308
|
-
|
309
|
-
def add_event(self, ev_name, ev_start, ev_end, actor, subactor=None, args=None): self.events += [(ev_name, ev_start, ev_end, actor, subactor, args)]
|
310
|
-
|
311
|
-
def _ensure_actor(self, actor_name, subactor_name):
|
312
|
-
if actor_name not in self.actors:
|
313
|
-
self.actors[actor_name] = (pid:=len(self.actors))
|
314
|
-
self.mjson.append({"name": "process_name", "ph": "M", "pid": pid, "args": {"name": actor_name}})
|
315
|
-
|
316
|
-
if (subactor_key:=(actor_name,subactor_name)) not in self.actors:
|
317
|
-
self.actors[subactor_key] = (tid:=len(self.actors))
|
318
|
-
self.mjson.append({"name": "thread_name", "ph": "M", "pid": self.actors[actor_name], "tid":tid, "args": {"name": subactor_name}})
|
319
|
-
|
320
|
-
return self.actors[actor_name], self.actors.get(subactor_key, -1)
|
321
|
-
|
322
|
-
def __del__(self):
|
323
|
-
# perfetto json docs: https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview
|
324
|
-
for name, st, et, actor_name, subactor_name, args in self.events:
|
325
|
-
pid, tid = self._ensure_actor(actor_name,subactor_name)
|
326
|
-
args = {k: (v if v.__class__ is str else v(et-st)) for k, v in args.items()} if args is not None else None
|
327
|
-
self.mjson.append({"name": name, "ph": "X", "pid": pid, "tid": tid, "ts": st, "dur": et-st, "args": args})
|
328
|
-
|
329
|
-
for en,st,dep_actor_name,dep_subactor_name,actor_name,subactor_name in self.deps:
|
330
|
-
dep_pid, dep_tid = self._ensure_actor(dep_actor_name,dep_subactor_name)
|
331
|
-
pid, tid = self._ensure_actor(actor_name,subactor_name)
|
332
|
-
self.mjson.append({"ph": "s", "pid": dep_pid, "tid": dep_tid, "id": len(self.mjson), "ts": en, "bp": "e"})
|
333
|
-
self.mjson.append({"ph": "f", "pid": pid, "tid": tid, "id": len(self.mjson)-1, "ts": st, "bp": "e"})
|
334
|
-
|
335
|
-
ProfileLogger.writers -= 1
|
336
|
-
if ProfileLogger.writers == 0 and len(self.mjson) > 0:
|
337
|
-
with open(PROFILEPATH.value, "w") as f: f.write(json.dumps({"traceEvents": self.mjson}))
|
338
|
-
print(f"Saved profile to {PROFILEPATH.value}. Use https://ui.perfetto.dev/ to open it.")
|
339
|
-
|
340
|
-
class HCQCompiled(Compiled):
|
331
|
+
class HCQCompiled(Compiled, Generic[SignalType]):
|
341
332
|
"""
|
342
333
|
A base class for devices compatible with the HCQ (Hardware Command Queue) API.
|
343
334
|
"""
|
344
|
-
devices:
|
345
|
-
gpu2cpu_copy_time_diff: decimal.Decimal = decimal.Decimal('nan')
|
346
|
-
gpu2cpu_compute_time_diff: decimal.Decimal = decimal.Decimal('nan')
|
335
|
+
devices: list[HCQCompiled] = []
|
347
336
|
|
348
|
-
def __init__(self, device:str, allocator:
|
349
|
-
comp_queue_t:Type[
|
337
|
+
def __init__(self, device:str, allocator:HCQAllocatorBase, renderer:Renderer, compiler:Compiler, runtime, signal_t:Type[SignalType],
|
338
|
+
comp_queue_t:Type[HWQueue], copy_queue_t:Type[HWQueue]|None):
|
339
|
+
self.device_id:int = int(device.split(":")[1]) if ":" in device else 0
|
350
340
|
self.signal_t, self.hw_compute_queue_t, self.hw_copy_queue_t = signal_t, comp_queue_t, copy_queue_t
|
351
341
|
self.timeline_value:int = 1
|
352
|
-
self.timeline_signal
|
353
|
-
self.
|
354
|
-
self.
|
355
|
-
self.dep_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, HCQCompiled, bool, decimal.Decimal, decimal.Decimal, HCQCompiled, bool]] = []
|
356
|
-
if PROFILE: self._prof_setup()
|
342
|
+
self.timeline_signal:SignalType = self.signal_t(value=0, timeline_for_device=self)
|
343
|
+
self._shadow_timeline_signal:SignalType = self.signal_t(value=0, timeline_for_device=self)
|
344
|
+
self.sig_prof_records:list[tuple[HCQSignal, HCQSignal, str, bool]] = []
|
357
345
|
|
358
346
|
from tinygrad.runtime.graph.hcq import HCQGraph
|
359
347
|
super().__init__(device, allocator, renderer, compiler, runtime, HCQGraph)
|
360
348
|
|
361
|
-
self.kernargs_page:HCQBuffer = self.allocator.alloc(16 << 20,
|
362
|
-
self.
|
349
|
+
self.kernargs_page:HCQBuffer = self.allocator.alloc(16 << 20, BufferSpec(cpu_access=True))
|
350
|
+
self.kernargs_allocator:BumpAllocator = BumpAllocator(self.kernargs_page.size, base=cast(int, self.kernargs_page.va_addr), wrap=True)
|
363
351
|
self.devices.append(self)
|
364
352
|
|
365
353
|
def synchronize(self):
|
366
|
-
try: self.timeline_signal.wait(self.timeline_value - 1)
|
354
|
+
try: self.timeline_signal.wait(self.timeline_value - 1)
|
367
355
|
except RuntimeError as e:
|
368
356
|
if hasattr(self, 'on_device_hang'): self.on_device_hang()
|
369
357
|
else: raise e
|
370
358
|
|
371
359
|
if self.timeline_value > (1 << 31): self._wrap_timeline_signal()
|
372
360
|
if PROFILE:
|
373
|
-
|
361
|
+
Compiled.profile_events += [ProfileRangeEvent(self.device, name, st.timestamp, en.timestamp, cp) for st,en,name,cp in self.sig_prof_records]
|
374
362
|
self.sig_prof_records = []
|
375
363
|
|
376
|
-
def
|
377
|
-
|
378
|
-
Allocates space for arguments passed to the kernel.
|
379
|
-
"""
|
380
|
-
if self.kernargs_ptr >= (self.kernargs_page.va_addr + self.kernargs_page.size - alloc_size): self.kernargs_ptr = self.kernargs_page.va_addr
|
381
|
-
self.kernargs_ptr = (res:=self.kernargs_ptr) + alloc_size
|
382
|
-
return res
|
383
|
-
|
384
|
-
def _ensure_shared_time_base(self):
|
385
|
-
if not self.gpu2cpu_compute_time_diff.is_nan(): return
|
386
|
-
|
387
|
-
def _sync_cpu_queue(d, q_t):
|
364
|
+
def _at_profile_finalize(self):
|
365
|
+
def _sync(d:HCQCompiled, q_t:Type[HWQueue]):
|
388
366
|
q_t().timestamp(d.timeline_signal).signal(d.timeline_signal, d.timeline_value).submit(d)
|
389
367
|
d.timeline_value += 1
|
390
368
|
st = time.perf_counter_ns()
|
@@ -392,134 +370,94 @@ class HCQCompiled(Compiled):
|
|
392
370
|
et = time.perf_counter_ns()
|
393
371
|
return (decimal.Decimal(et+st) / 2000) - d.timeline_signal.timestamp
|
394
372
|
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
d,q,l = random.choice(choices)
|
400
|
-
l.append(_sync_cpu_queue(d,q))
|
401
|
-
for d,q,l in choices:
|
402
|
-
if q == d.hw_compute_queue_t: d.gpu2cpu_compute_time_diff = statistics.median(l)
|
403
|
-
if q == d.hw_copy_queue_t: d.gpu2cpu_copy_time_diff = statistics.median(l)
|
404
|
-
|
405
|
-
def _sync_gpu_to_gpu_queue(d1, d2, q1_t, q2_t):
|
406
|
-
q1_t().signal(d1.timeline_signal, d1.timeline_value).wait(d2.timeline_signal, d2.timeline_value) \
|
407
|
-
.timestamp(d1.timeline_signal).signal(d1.timeline_signal, d1.timeline_value+1).submit(d1)
|
408
|
-
q2_t().signal(d2.timeline_signal, d2.timeline_value).wait(d1.timeline_signal, d1.timeline_value) \
|
409
|
-
.timestamp(d2.timeline_signal).signal(d2.timeline_signal, d2.timeline_value+1).submit(d2)
|
410
|
-
d1.timeline_value += 2
|
411
|
-
d2.timeline_value += 2
|
412
|
-
d1.timeline_signal.wait(d1.timeline_value - 1)
|
413
|
-
d2.timeline_signal.wait(d2.timeline_value - 1)
|
414
|
-
return d2.timeline_signal.timestamp - d1.timeline_signal.timestamp
|
415
|
-
|
416
|
-
# then test it by timing the GPU to GPU times
|
417
|
-
jitter_matrix = [[float('nan')]*len(self.devices) for _ in range(len(self.devices))]
|
418
|
-
for i1, d1 in enumerate(self.devices):
|
419
|
-
for i2, d2 in enumerate(self.devices):
|
420
|
-
if d1 == d2: continue
|
421
|
-
d1_to_d2 = statistics.median(_sync_gpu_to_gpu_queue(d1, d2, d1.hw_compute_queue_t, d2.hw_compute_queue_t) - \
|
422
|
-
_sync_gpu_to_gpu_queue(d2, d1, d2.hw_compute_queue_t, d1.hw_compute_queue_t) for _ in range(20)) / 2
|
423
|
-
jitter_matrix[i1][i2] = d1_to_d2 - (d1.gpu2cpu_compute_time_diff - d2.gpu2cpu_compute_time_diff)
|
424
|
-
print("pairwise clock jitter matrix (us):\n" + '\n'.join([''.join([f'{float(item):8.3f}' for item in row]) for row in jitter_matrix]))
|
425
|
-
|
426
|
-
def _gpu2cpu_time(self, gpu_time:decimal.Decimal, is_copy:bool) -> float:
|
427
|
-
"""
|
428
|
-
Translates local gpu time (timestamp) into global cpu time.
|
429
|
-
"""
|
430
|
-
self._ensure_shared_time_base()
|
431
|
-
return float(gpu_time + (self.gpu2cpu_copy_time_diff if is_copy else self.gpu2cpu_compute_time_diff))
|
432
|
-
|
433
|
-
def _prof_setup(self):
|
434
|
-
if hasattr(self, 'profile_logger'): return
|
435
|
-
atexit.register(self._prof_finalize)
|
436
|
-
self.profile_logger = ProfileLogger()
|
437
|
-
|
438
|
-
def _prof_finalize(self):
|
439
|
-
qname = ["COMPUTE", "DMA"]
|
440
|
-
|
441
|
-
# Sync to be sure all events on the device are recorded.
|
442
|
-
self.synchronize()
|
443
|
-
|
444
|
-
for st, en, name, is_cp, args in self.raw_prof_records:
|
445
|
-
self.profile_logger.events += [(name, self._gpu2cpu_time(st, is_cp), self._gpu2cpu_time(en, is_cp), self.dname, qname[is_cp], args)]
|
446
|
-
for a_st, a_en, a_dev, a_is_copy, b_st, b_en, b_dev, b_is_copy in self.dep_prof_records:
|
447
|
-
# Perfetto connects nodes based on timing data, ensuring every choice is valid by averaging times to a midpoint.
|
448
|
-
a_tm, b_tm = a_dev._gpu2cpu_time((a_st+a_en)/decimal.Decimal(2), a_is_copy), b_dev._gpu2cpu_time((b_st+b_en)/decimal.Decimal(2), b_is_copy)
|
449
|
-
self.profile_logger.deps += [(a_tm, b_tm, a_dev.dname, qname[a_is_copy], b_dev.dname, qname[b_is_copy])]
|
450
|
-
self.raw_prof_records, self.dep_prof_records = [], []
|
451
|
-
|
452
|
-
# Remove the logger, this flushes all data written by the device.
|
453
|
-
del self.profile_logger
|
373
|
+
gpu2cpu_compute_time_diff = statistics.median([_sync(self, self.hw_compute_queue_t) for _ in range(40)])
|
374
|
+
if self.hw_copy_queue_t is None: gpu2cpu_copy_time_diff = decimal.Decimal(0)
|
375
|
+
else: gpu2cpu_copy_time_diff = statistics.median([_sync(self, self.hw_copy_queue_t) for _ in range(40)])
|
376
|
+
Compiled.profile_events += [ProfileDeviceEvent(self.device, gpu2cpu_compute_time_diff, gpu2cpu_copy_time_diff)]
|
454
377
|
|
455
378
|
def _wrap_timeline_signal(self):
|
456
379
|
self.timeline_signal, self._shadow_timeline_signal, self.timeline_value = self._shadow_timeline_signal, self.timeline_signal, 1
|
457
380
|
self.timeline_signal.value = 0
|
458
|
-
cast(
|
381
|
+
cast(HCQAllocatorBase, self.allocator).b_timeline = [0] * len(cast(HCQAllocatorBase, self.allocator).b)
|
459
382
|
|
460
|
-
|
461
|
-
|
383
|
+
def _realloc(self, oldbuf:HCQBuffer|None, new_size:int, options:BufferSpec|None=None) -> tuple[HCQBuffer, bool]:
|
384
|
+
if oldbuf is not None: self.allocator.free(oldbuf, oldbuf.size, options=options)
|
385
|
+
try: buf, realloced = self.allocator.alloc(new_size, options=options), True
|
386
|
+
except MemoryError: buf, realloced = self.allocator.alloc(oldbuf.size if oldbuf is not None else new_size, options=options), False
|
387
|
+
return buf, realloced
|
462
388
|
|
463
|
-
class
|
389
|
+
class HCQBuffer:
|
390
|
+
def __init__(self, va_addr:sint, size:int, texture_info:Any=None, meta:Any=None, _base:HCQBuffer|None=None):
|
391
|
+
self.va_addr, self.size, self.texture_info, self.meta, self._base = va_addr, size, texture_info, meta, _base
|
392
|
+
|
393
|
+
class HCQAllocatorBase(LRUAllocator, Generic[DeviceType]):
|
464
394
|
"""
|
465
395
|
A base allocator class compatible with the HCQ (Hardware Command Queue) API.
|
466
396
|
|
467
|
-
This class implements basic copy operations following the HCQ API, utilizing both
|
397
|
+
This class implements basic copy operations following the HCQ API, utilizing both types of `HWQueue`.
|
468
398
|
"""
|
469
399
|
|
470
|
-
def __init__(self,
|
471
|
-
self.
|
472
|
-
self.b = [self._alloc(batch_size,
|
400
|
+
def __init__(self, dev:DeviceType, batch_size:int=(2 << 20), batch_cnt:int=32):
|
401
|
+
self.dev:DeviceType = dev
|
402
|
+
self.b = [self._alloc(batch_size, BufferSpec(host=True)) for _ in range(batch_cnt)]
|
473
403
|
self.b_timeline, self.b_next = [0] * len(self.b), 0
|
474
404
|
super().__init__()
|
475
405
|
|
476
|
-
def
|
406
|
+
def map(self, buf:HCQBuffer): pass
|
407
|
+
|
408
|
+
def _offset(self, buf, size:int, offset:int) -> HCQBuffer:
|
409
|
+
return HCQBuffer(va_addr=buf.va_addr + offset, size=size, texture_info=buf.texture_info, meta=buf.meta, _base=buf._base or buf)
|
477
410
|
|
478
|
-
|
479
|
-
|
411
|
+
class HCQAllocator(HCQAllocatorBase, Generic[DeviceType]):
|
412
|
+
def _copyin(self, dest:HCQBuffer, src:memoryview):
|
413
|
+
assert self.dev.hw_copy_queue_t is not None
|
414
|
+
with hcq_profile(self.dev, queue_type=self.dev.hw_copy_queue_t, desc=f"CPU -> {self.dev.device}", enabled=PROFILE):
|
480
415
|
for i in range(0, src.nbytes, self.b[0].size):
|
481
416
|
self.b_next = (self.b_next + 1) % len(self.b)
|
482
|
-
self.
|
417
|
+
self.dev.timeline_signal.wait(self.b_timeline[self.b_next])
|
483
418
|
ctypes.memmove(self.b[self.b_next].va_addr, from_mv(src[i:]), lsize:=min(self.b[self.b_next].size, src.nbytes-i))
|
484
|
-
self.
|
485
|
-
|
486
|
-
|
487
|
-
self.b_timeline[self.b_next] = self.
|
488
|
-
self.
|
419
|
+
self.dev.hw_copy_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1) \
|
420
|
+
.copy(dest.va_addr+i, self.b[self.b_next].va_addr, lsize) \
|
421
|
+
.signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
|
422
|
+
self.b_timeline[self.b_next] = self.dev.timeline_value
|
423
|
+
self.dev.timeline_value += 1
|
489
424
|
|
490
425
|
def copy_from_disk(self, dest:HCQBuffer, src, size):
|
491
426
|
def _get_temp_buf():
|
492
427
|
# Check if the next buffer is safe to be used (its signal has passed) and reserve it.
|
493
|
-
if self.b_timeline[(self.b_next + 1) % len(self.b)] <= self.
|
428
|
+
if self.b_timeline[(self.b_next + 1) % len(self.b)] <= self.dev.timeline_signal.value:
|
494
429
|
self.b_timeline[(self.b_next + 1) % len(self.b)], self.b_next = (1 << 64), (self.b_next + 1) % len(self.b)
|
495
430
|
return (self.b[self.b_next].va_addr, self.b_next)
|
496
431
|
return None
|
497
432
|
|
498
|
-
|
433
|
+
assert self.dev.hw_copy_queue_t is not None
|
434
|
+
with hcq_profile(self.dev, queue_type=self.dev.hw_copy_queue_t, desc=f"DISK -> {self.dev.device}", enabled=PROFILE):
|
499
435
|
for (batch_info, dst_off, src_off, copy_size) in src.device.allocator._copyout_sharded(src, size, _get_temp_buf, seg_len=self.b[0].size):
|
500
|
-
self.
|
501
|
-
|
502
|
-
|
503
|
-
self.b_timeline[batch_info[1]] = self.
|
504
|
-
self.
|
436
|
+
self.dev.hw_copy_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1) \
|
437
|
+
.copy(dest.va_addr + dst_off, batch_info[0] + src_off, copy_size) \
|
438
|
+
.signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
|
439
|
+
self.b_timeline[batch_info[1]] = self.dev.timeline_value
|
440
|
+
self.dev.timeline_value += 1
|
505
441
|
|
506
|
-
def
|
507
|
-
self.
|
442
|
+
def _copyout(self, dest:memoryview, src:HCQBuffer):
|
443
|
+
self.dev.synchronize()
|
508
444
|
|
509
|
-
|
445
|
+
assert self.dev.hw_copy_queue_t is not None
|
446
|
+
with hcq_profile(self.dev, queue_type=self.dev.hw_copy_queue_t, desc=f"{self.dev.device} -> CPU", enabled=PROFILE):
|
510
447
|
for i in range(0, dest.nbytes, self.b[0].size):
|
511
|
-
self.
|
512
|
-
|
513
|
-
|
514
|
-
self.
|
515
|
-
self.
|
448
|
+
self.dev.hw_copy_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1) \
|
449
|
+
.copy(self.b[0].va_addr, src.va_addr+i, lsize:=min(self.b[0].size, dest.nbytes-i)) \
|
450
|
+
.signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
|
451
|
+
self.dev.timeline_signal.wait(self.dev.timeline_value)
|
452
|
+
self.dev.timeline_value += 1
|
516
453
|
|
517
454
|
ctypes.memmove(from_mv(dest[i:]), self.b[0].va_addr, lsize)
|
518
455
|
|
519
|
-
def
|
520
|
-
src_dev.allocator.map(dest)
|
456
|
+
def _transfer(self, dest:HCQBuffer, src:HCQBuffer, sz:int, src_dev:DeviceType, dest_dev:DeviceType):
|
457
|
+
cast(HCQAllocator, src_dev.allocator).map(dest)
|
521
458
|
|
522
|
-
|
459
|
+
assert src_dev.hw_copy_queue_t is not None
|
460
|
+
with hcq_profile(src_dev, queue_type=src_dev.hw_copy_queue_t, desc=f"{src_dev.device} -> {dest_dev.device}", enabled=PROFILE):
|
523
461
|
src_dev.hw_copy_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
|
524
462
|
.wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
|
525
463
|
.copy(dest.va_addr, src.va_addr, sz) \
|
@@ -531,9 +469,3 @@ class HCQAllocator(LRUAllocator): # pylint: disable=abstract-method
|
|
531
469
|
.wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
|
532
470
|
.signal(dest_dev.timeline_signal, dest_dev.timeline_value).submit(dest_dev)
|
533
471
|
dest_dev.timeline_value += 1
|
534
|
-
|
535
|
-
def map(self, buf:HCQBuffer): pass
|
536
|
-
|
537
|
-
def offset(self, buf, size:int, offset:int) -> HCQBuffer:
|
538
|
-
return type(buf)(va_addr=buf.va_addr + offset, size=size, **{k:v for k,v in buf.__dict__.items() if k not in ['va_addr', 'size']},
|
539
|
-
**{x[0]:getattr(buf, x[0]) for x in getattr(buf, '_fields_', []) if x[0] not in ['va_addr', 'size']}, _base=buf)
|