tinygrad 0.9.1-py3-none-any.whl → 0.9.2-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
- tinygrad/codegen/kernel.py +248 -115
- tinygrad/codegen/lowerer.py +215 -0
- tinygrad/codegen/transcendental.py +310 -0
- tinygrad/codegen/uopgraph.py +622 -0
- tinygrad/codegen/uops.py +235 -393
- tinygrad/device.py +428 -69
- tinygrad/dtype.py +18 -4
- tinygrad/engine/graph.py +19 -32
- tinygrad/engine/jit.py +148 -70
- tinygrad/engine/realize.py +127 -51
- tinygrad/engine/schedule.py +259 -216
- tinygrad/engine/search.py +29 -22
- tinygrad/function.py +9 -0
- tinygrad/helpers.py +87 -49
- tinygrad/lazy.py +34 -35
- tinygrad/multi.py +41 -36
- tinygrad/nn/__init__.py +39 -22
- tinygrad/nn/state.py +3 -3
- tinygrad/ops.py +63 -62
- tinygrad/renderer/__init__.py +43 -21
- tinygrad/renderer/assembly.py +104 -106
- tinygrad/renderer/cstyle.py +87 -60
- tinygrad/renderer/llvmir.py +21 -30
- tinygrad/runtime/autogen/amd_gpu.py +25208 -5753
- tinygrad/runtime/autogen/cuda.py +6 -162
- tinygrad/runtime/autogen/kfd.py +32 -0
- tinygrad/runtime/autogen/libc.py +4260 -0
- tinygrad/runtime/autogen/nvrtc.py +579 -0
- tinygrad/runtime/graph/clang.py +2 -2
- tinygrad/runtime/graph/cuda.py +8 -11
- tinygrad/runtime/graph/hcq.py +120 -107
- tinygrad/runtime/graph/metal.py +18 -15
- tinygrad/runtime/ops_amd.py +197 -305
- tinygrad/runtime/ops_clang.py +2 -2
- tinygrad/runtime/ops_cuda.py +36 -94
- tinygrad/runtime/ops_disk.py +3 -7
- tinygrad/runtime/ops_gpu.py +4 -2
- tinygrad/runtime/ops_hip.py +70 -0
- tinygrad/runtime/ops_metal.py +38 -27
- tinygrad/runtime/ops_nv.py +283 -363
- tinygrad/runtime/ops_python.py +26 -30
- tinygrad/runtime/support/compiler_cuda.py +78 -0
- tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +15 -1
- tinygrad/runtime/support/elf.py +38 -0
- tinygrad/shape/shapetracker.py +5 -14
- tinygrad/shape/symbolic.py +4 -8
- tinygrad/shape/view.py +34 -22
- tinygrad/tensor.py +399 -97
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/METADATA +49 -48
- tinygrad-0.9.2.dist-info/RECORD +70 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/WHEEL +1 -1
- tinygrad/codegen/linearizer.py +0 -528
- tinygrad-0.9.1.dist-info/RECORD +0 -63
- /tinygrad/runtime/{driver → support}/__init__.py +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/top_level.txt +0 -0
tinygrad/runtime/graph/hcq.py
CHANGED
```diff
@@ -1,7 +1,8 @@
-import collections,
+import collections, time
 from typing import List, Any, Dict, cast, Optional, Tuple, Set
-from tinygrad.helpers import round_up,
-from tinygrad.device import
+from tinygrad.helpers import round_up, PROFILE, memsize_to_str
+from tinygrad.device import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWCommandQueue, HWComputeQueue, HWCopyQueue, HCQArgsState, \
+  Buffer, BufferOptions, Compiled, Device
 from tinygrad.shape.symbolic import Variable
 from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
 from tinygrad.engine.jit import MultiGraphRunner
@@ -9,179 +10,191 @@ from tinygrad.engine.jit import MultiGraphRunner
 class HCQGraph(MultiGraphRunner):
   def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
     super().__init__(jit_cache, input_rawbuffers, var_vals)
-    self.devices = list(set(cast(
+    self.devices = list(set(cast(HCQCompiled, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs]))
 
     # Allocate kernel args.
     kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
     for ji in self.jit_cache:
       if not isinstance(ji.prg, CompiledRunner): continue
       kernargs_size[ji.prg.device] += round_up(ji.prg.clprg.kernargs_alloc_size, 16)
-    self.kernargs_bufs: Dict[Compiled,
-    kernargs_ptrs: Dict[Compiled, int] = {dev:buf.va_addr for dev,buf in self.kernargs_bufs.items()}
+    self.kernargs_bufs: Dict[Compiled, HCQBuffer] = {dev:dev.allocator._alloc(sz, BufferOptions(cpu_access=True)) for dev,sz in kernargs_size.items()}
 
     # Fill initial arguments.
-    self.
-
-
+    self.ji_args: Dict[int, HCQArgsState] = {}
+
+    kargs_ptrs: Dict[Compiled, int] = {dev:buf.va_addr for dev,buf in self.kernargs_bufs.items()}
     for j,ji in enumerate(self.jit_cache):
       if not isinstance(ji.prg, CompiledRunner): continue
-
-
-
-      self.ji_args_bufs[j] = to_mv(self.kargs_addrs[j] + ji.prg.clprg.kernargs_offset, len(ji.bufs) * 8).cast('Q')
-      self.ji_args_vars[j] = to_mv(self.kargs_addrs[j] + ji.prg.clprg.kernargs_offset + len(ji.bufs) * 8, len(ji.prg.p.vars) * 4).cast('I')
-      for i in range(len(ji.bufs)): self.ji_args_bufs[j][i] = cast(Buffer, ji.bufs[i])._buf.va_addr
-      for i in range(len(ji.prg.p.vars)): self.ji_args_vars[j][i] = var_vals[ji.prg.p.vars[i]]
-
-      # NV needs constbuffer to be set
-      if ji.prg.device.dname.startswith("NV"): to_mv(self.kargs_addrs[j], 0x160).cast('I')[:] = array.array('I', ji.prg.clprg.constbuffer_0)
+      kargs_ptrs[ji.prg.device] = (kargs_ptr:=kargs_ptrs[ji.prg.device]) + round_up(ji.prg.clprg.kernargs_alloc_size, 16)
+      self.ji_args[j] = ji.prg.clprg.fill_kernargs([cast(Buffer, b)._buf for b in ji.bufs], [var_vals[v] for v in ji.prg.p.vars], kargs_ptr)
 
     # Schedule Dependencies.
     # There are two types of queues on each device: copy and compute. Both must synchronize with all external operations before launching any
     # graph-related tasks. This synchronization uses a global timeline signal per device. Within the graph, the compute queue coordinates with
     # global operations and sets a kickoff signal. Any queue accessing a buffer from another device waits for this signal from the device’s
     # compute queue to ensure exclusive access. The compute queue signals the completion of the graph, synchronizing with the device's copy queue.
-    self.
-
+    self.ji_schedule: Dict[int, Tuple[HCQCompiled, HWCommandQueue, List, List, HCQSignal, Optional[int]]] = {}
+
+    self.comp_queues: Dict[HCQCompiled, HWComputeQueue] = {dev: dev.hw_compute_queue_t() for dev in self.devices}
+    self.copy_queues: Dict[HCQCompiled, HWCopyQueue] = {} # lazy allocation
 
-    self.
-    self.
-    self.dev_kickoff_signal = {dev: self.devices[0]._get_signal(value=0) for dev in self.devices + ['CPU']} # Dict[dev, signal]
-    self.kickoff_value = 0
+    self.signals: Dict[Any, HCQSignal] = {**{dev: dev.signal_t(value=0) for dev in self.devices}, **{"CPU": self.devices[0].signal_t(value=0)}}
+    self.kickoff_value: int = 0
 
-    self.
-
+    self.prof_signals: List[HCQSignal] = [self.devices[0].signal_t() for i in range(len(self.jit_cache) * 2)] if PROFILE else []
+    self.prof_records: List[Tuple[Tuple[int, bool], Tuple[int, bool], HCQCompiled, str, bool, List[int], Optional[Dict]]] = []
 
-
-
+    last_j: Dict[HWCommandQueue, Optional[int]] = collections.defaultdict(lambda: None)
+    queue_access: Dict[HWCommandQueue, Dict[HWCommandQueue, Optional[int]]] = collections.defaultdict(lambda: collections.defaultdict(lambda: None))
+    dev_access: Dict[HWCommandQueue, Set[HCQCompiled]] = collections.defaultdict(set)
+
+    for dev, queue in self.comp_queues.items(): dev_access[queue].add(dev)
 
     for j,ji in enumerate(self.jit_cache):
-      enqueue_dev = ji.prg.device if isinstance(ji.prg, CompiledRunner) else Device[ji.bufs[1].device] #type:ignore
-      enqueue_queue = self.comp_queues[enqueue_dev] if
-      out_signal = self.signals
-      writable_buffers = ji.prg.p.outcount if isinstance(ji.prg, CompiledRunner) else 1
-      deps = self.access_resources(enqueue_queue, ji.bufs[writable_buffers:], ji.bufs[:writable_buffers], j + 1)
+      enqueue_dev = ji.prg.device if (is_exec_prg:=isinstance(ji.prg, CompiledRunner)) else Device[ji.bufs[1].device] #type:ignore
+      enqueue_queue = self.comp_queues[enqueue_dev] if is_exec_prg else self.copy_queues.setdefault(enqueue_dev, enqueue_dev.hw_copy_queue_t())
+      out_signal = self.signals.setdefault(enqueue_queue, enqueue_dev.signal_t(value=0))
 
-
-
-
+      # Get dependencies based on input and output buffers.
+      rdeps = self._access_resources(ji.bufs[(wb:=ji.prg.p.outcount if is_exec_prg else 1):], ji.bufs[:wb], (enqueue_queue, j + 1)) #type:ignore
+
+      # Update dependencies to include previous kernel in queue. This is required for timeline signals.
+      opt_deps, deps = [], rdeps + ([(enqueue_queue, prev_ji + 1)] if (prev_ji:=last_j[enqueue_queue]) is not None else [])
+
+      # Optimize dependencies by removing redundant ones. Remove waiting for the value of the queue which is known to be already
+      # synced with the current queue.
+      for dep_queue, dep_val in sorted(deps, key=lambda x: x[1], reverse=True):
+        if (qa:=queue_access[enqueue_queue][dep_queue]) is None or qa < dep_val:
+          opt_deps.append((self.signals[dep_queue], dep_val))
+          queue_access[enqueue_queue][dep_queue] = dep_val
 
-
-
-
-
+      # Ensure device is ready for use in current context: the graph has initialized the device and it's safe to operate on it within this graph.
+      for dep_queue, _ in opt_deps: dev_access[enqueue_queue].update(dev_access[dep_queue])
+      sync_signals = [(self.signals[d], self.kickoff_value) for b in ji.bufs if (d:=Device[cast(Buffer, b).device]) not in dev_access[enqueue_queue]]
+      dev_access[enqueue_queue].update(cast(HCQCompiled, Device[cast(Buffer, b).device]) for b in ji.bufs)
 
-      #
-
-
-
+      # Remove self-dependency for compute and copy queues.
+      # For compute, in case of NV, optimize when only 1 same-queue dependency exists, since NV chains 2+ executions in this case,
+      # eliminating dependency need.
+      dname = enqueue_dev.dname.split(":", 1)[0]
+      can_opt = (dname == "AMD" or (dname == "NV" and len(sync_signals) == 0 and len(opt_deps) == 1 and id(opt_deps[0][0]) == id(out_signal)))
+      if can_opt or isinstance(ji.prg, BufferXfer): opt_deps = [x for x in opt_deps if id(x[0]) != id(out_signal)]
 
-
-
-      self.
-
+      # Enable necessary signals in the schedule by setting the signal value.
+      for sig, val in opt_deps: self.ji_schedule[val - 1] = self.ji_schedule[val - 1][:5] + (val,)
+      self.ji_schedule[j] = (enqueue_dev, enqueue_queue, sync_signals, opt_deps[::-1], out_signal, None if is_exec_prg else (j + 1))
+
+      # Collect profile information if profiling is enabled.
+      if PROFILE:
+        prof_ji_desc = ji.prg.clprg.name if is_exec_prg else f"{ji.bufs[1].device} -> {ji.bufs[0].device}" # type: ignore
+
+        sig_st, sig_en = (j * 2, True), (j * 2 + 1, True)
+        if len(opt_deps) == 0 and (prev_ji:=last_j[enqueue_queue]) is not None: sig_st = (prev_ji * 2 + 1, False)
+
+        if is_exec_prg: prof_args = None
+        else: prof_args = {"Size": memsize_to_str(ji.bufs[0].nbytes), "GB/S": lambda dur, b=ji.bufs[0].nbytes: f"{b/1e3/dur:.2f}"} # type: ignore
+
+        self.prof_records.append((sig_st, sig_en, enqueue_dev, prof_ji_desc, not is_exec_prg, [d - 1 for _, d in rdeps], prof_args))
+
+      last_j[enqueue_queue] = j
 
     # Build hardware queues.
-    self.
-    self.copy_to_devs: Dict[
-    self.kickoff_wait_cmds: Dict[
+    self.op_cmd_idx: Dict[int, Tuple[Any, int]] = {}
+    self.copy_to_devs: Dict[HCQCompiled, Set[HCQCompiled]] = {dev: set() for dev in self.devices}
+    self.kickoff_wait_cmds: Dict[HWCommandQueue, List] = {q: list() for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}
 
     for dev in self.devices:
       self.comp_queues[dev].memory_barrier().wait(dev.timeline_signal, dev.timeline_value - 1) \
-                           .wait(self.
+                           .wait(self.signals['CPU'], self.kickoff_value).signal(self.signals[dev], self.kickoff_value)
 
     for j,ji in enumerate(self.jit_cache):
-      deps,
-
+      enqueue_dev, enqueue_queue, sync_signals, deps, signal, signal_val = self.ji_schedule[j]
+
+      for i in range(len(sync_signals)): self.kickoff_wait_cmds[enqueue_queue].append(len(enqueue_queue) + i)
+      for sig, val in sync_signals + deps: enqueue_queue.wait(sig, val)
 
       # Encode waits and start profile timestamp (if needed).
-
-        enqueue_queue.wait(sig, val)
-        if id(sig) in [id(x) for x in self.dev_kickoff_signal.values()]: self.kickoff_wait_cmds[enqueue_queue].append(len(enqueue_queue) - 1)
-      if prof_info: enqueue_queue.timestamp(prof_info[0])
+      if PROFILE and self.prof_records[j][0][1]: enqueue_queue.timestamp(self.prof_signals[self.prof_records[j][0][0]])
 
       # Encode main commands based on ji type.
       if isinstance(ji.prg, CompiledRunner):
-        enqueue_queue.exec(ji.prg.clprg, self.
-                           signal=self.signals[enqueue_queue] if signal_value is not None else None, signal_value=signal_value)
-        self.exec_ptrs[j] = (enqueue_queue, len(enqueue_queue) - 1)
+        cast(HWComputeQueue, enqueue_queue).exec(ji.prg.clprg, self.ji_args[j], *ji.prg.p.launch_dims(var_vals))
       elif isinstance(ji.prg, BufferXfer):
         dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
-        Device[src.device].
-        enqueue_queue.copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes)
-        self.copy_to_devs[Device[dest.device]].add(Device[src.device])
+        cast(HCQAllocator, Device[src.device].allocator).map(dest._buf)
+        cast(HWCopyQueue, enqueue_queue).copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes)
+        self.copy_to_devs[cast(HCQCompiled, Device[dest.device])].add(cast(HCQCompiled, Device[src.device]))
+      self.op_cmd_idx[j] = (enqueue_queue, len(enqueue_queue) - 1)
 
       # Encode finish profile timestamp (if needed).
-      if
+      if PROFILE and self.prof_records[j][1][1]: enqueue_queue.timestamp(self.prof_signals[self.prof_records[j][1][0]])
+
+      if signal_val is not None: enqueue_queue.signal(signal, signal_val)
 
     for dev in self.devices:
       for dep_dev in list(self.copy_to_devs[dev]) + [dev]:
-        if (
-
+        if dep_dev in self.copy_queues: self.comp_queues[dev].wait(self.signals[(copy_q:=self.copy_queues[dep_dev])], cast(int, last_j[copy_q]) + 1)
+
+      self.comp_queues[dev].signal(dev.timeline_signal, dev.timeline_value).bind(dev)
+      if dev in self.copy_queues: self.copy_queues[dev].bind(dev)
 
-
-
-      if hasattr(self.copy_queues[dev], 'bind') and self.last_ji[self.copy_queues[dev]] is not None: self.copy_queues[dev].bind(dev)
+    self.last_timeline: Dict[HCQCompiled, Tuple[HCQSignal, int]] = {dev: (dev.timeline_signal, 0) for dev in self.devices}
+    self.queue_signals_to_reset = [self.signals[q] for q in list(self.comp_queues.values()) + list(self.copy_queues.values()) if q in self.signals]
 
   def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
     # Wait and restore signals
     self.kickoff_value += 1
-    for dev in self.devices: dev.
-    for
-
-    self.devices[0]._set_signal(self.dev_kickoff_signal['CPU'], self.kickoff_value)
+    for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
+    for sig in self.queue_signals_to_reset: sig.value = 0
+    self.signals['CPU'].value = self.kickoff_value
 
-    if PROFILE and self.kickoff_value > 1:
-      for _,_,(st,en,dev,desc,is_cp) in self.signal_sched.values(): #type: ignore
-        dev.raw_prof_records += [(dev._read_timestamp(st), dev._read_timestamp(en), desc, is_cp)]
+    if PROFILE and self.kickoff_value > 1: self.collect_timestamps()
 
     # Update rawbuffers
-    for (j,i),input_idx in self.input_replace.items():
+    for (j,i),input_idx in self.input_replace.items():
+      if j in self.ji_args: self.ji_args[j].update_buffer(i, input_rawbuffers[input_idx]._buf)
+      else: self.op_cmd_idx[j][0].update_copy(self.op_cmd_idx[j][1], **{('dest' if i == 0 else 'src'): input_rawbuffers[input_idx]._buf.va_addr})
 
     # Update var_vals
-    for j in self.
-      for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars): self.ji_args_vars[j][i] = var_vals[v]
+    for j, i, v in self.updated_vars(var_vals): self.ji_args[j].update_var(i, v)
 
-
-
-      queue
+    # Update launch dims
+    for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
+      queue, cmd_ptr = self.op_cmd_idx[j]
+      queue.update_exec(cmd_ptr, global_dims, local_dims)
 
     for dev in self.devices:
-      self.comp_queues[dev].
-
-
+      comp_queue, copy_queue, need_sig_upd = self.comp_queues[dev], self.copy_queues.get(dev, None), dev.timeline_signal != self.last_timeline[dev][0]
+      comp_queue.update_wait(1, dev.timeline_signal if need_sig_upd else None, dev.timeline_value - 1) \
+                .update_wait(2, value=self.kickoff_value).update_signal(3, value=self.kickoff_value) \
+                .update_signal(len(comp_queue)-1, dev.timeline_signal if need_sig_upd else None, dev.timeline_value).submit(dev)
 
-      if
-        for cmd_idx in self.kickoff_wait_cmds[
-
+      if copy_queue is not None:
+        for cmd_idx in self.kickoff_wait_cmds[copy_queue]: copy_queue.update_wait(cmd_idx, value=self.kickoff_value)
+        copy_queue.submit(dev)
 
-      self.
+      self.last_timeline[dev] = (dev.timeline_signal, dev.timeline_value)
       dev.timeline_value += 1
 
     if wait:
       st = time.perf_counter()
-      for dev in self.devices: dev.
+      for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
       return time.perf_counter() - st
     return None
 
-  def
-
+  def collect_timestamps(self):
+    timestamps = [s.timestamp for s in self.prof_signals]
 
-
-
-    for buf in read+write:
-      if buf.device not in self.save_devs[queue]:
-        self.save_devs[queue].add(buf.device)
-        sync_signals += [(self.dev_kickoff_signal[Device[buf.device]], self.kickoff_value)]
+    for (st,_), (en,_), dev, desc, is_cp, deps, args in self.prof_records:
+      dev.raw_prof_records += [(timestamps[st], timestamps[en], desc, is_cp, args)]
 
-
+      for x in deps:
+        (b_st,_), (b_en,_), b_dev, _, b_is_cp, _, _ = self.prof_records[x]
+        dev.dep_prof_records += [(timestamps[b_st], timestamps[b_en], b_dev, b_is_cp, timestamps[st], timestamps[en], dev, is_cp)]
 
   def __del__(self):
-    for dev in self.devices: dev.
+    for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
 
-
-    if PROFILE and self.kickoff_value > 1:
-      for _,_,(st,en,dev,desc,is_cp) in self.signal_sched.values(): dev.sig_prof_records += [(st, en, desc, is_cp)] #type: ignore
+    if PROFILE and self.kickoff_value >= 1: self.collect_timestamps()
 
-
-    for dev, buf in self.kernargs_bufs.items(): dev.allocator._free(buf, BufferOptions(cpu_access=True))
+    for fdev, buf in self.kernargs_bufs.items(): fdev.allocator._free(buf, BufferOptions(cpu_access=True))
```
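The scheduling pass above keeps a `queue_access` map so a wait is only encoded when the target queue is not already known to be synchronized past the required signal value. A minimal standalone sketch of that pruning step, using plain dictionaries and illustrative names rather than the tinygrad HCQ classes:

```python
# Sketch of the redundant-wait pruning in HCQGraph.__init__ above, assuming
# deps are (queue, signal_value) pairs. Names here are illustrative only.
from collections import defaultdict
from typing import Dict, Hashable, List, Optional, Tuple

Queue = Hashable

def prune_waits(enqueue_queue: Queue, deps: List[Tuple[Queue, int]],
                queue_access: Dict[Queue, Dict[Queue, Optional[int]]]) -> List[Tuple[Queue, int]]:
  kept: List[Tuple[Queue, int]] = []
  # Visit higher signal values first: one kept wait subsumes every
  # lower-valued wait on the same queue, so the rest can be dropped.
  for dep_queue, dep_val in sorted(deps, key=lambda x: x[1], reverse=True):
    if (known := queue_access[enqueue_queue][dep_queue]) is None or known < dep_val:
      kept.append((dep_queue, dep_val))
      queue_access[enqueue_queue][dep_queue] = dep_val
  return kept

queue_access: Dict[Queue, Dict[Queue, Optional[int]]] = defaultdict(lambda: defaultdict(lambda: None))
print(prune_waits("comp0", [("copy0", 3), ("copy0", 1), ("comp1", 2)], queue_access))  # [('copy0', 3), ('comp1', 2)]
print(prune_waits("comp0", [("copy0", 2)], queue_access))  # [] -- comp0 already synced past copy0=3
```

This mirrors why `HCQGraph` sorts `deps` by value with `reverse=True` before filtering against `queue_access`.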
tinygrad/runtime/graph/metal.py
CHANGED
```diff
@@ -1,10 +1,10 @@
 from typing import List, Any, Dict, cast, Optional
 import Metal
 from tinygrad.dtype import dtypes
-from tinygrad.helpers import dedup, unwrap2
+from tinygrad.helpers import dedup, unwrap2
 from tinygrad.device import Buffer
 from tinygrad.engine.realize import ExecItem, CompiledRunner
-from tinygrad.engine.jit import GraphRunner
+from tinygrad.engine.jit import GraphRunner, GraphException
 from tinygrad.shape.symbolic import Variable
 from tinygrad.runtime.ops_metal import wait_check
 
@@ -24,7 +24,7 @@ class MetalGraph(GraphRunner):
     if self.icb is None: raise GraphException("create indirect command buffer failed, does your system support this?")
 
     if len(self.vars): self.int_buf = self.device.allocator.alloc(len(self.vars)*dtypes.int32.itemsize)
-    all_resources = [self.int_buf] if len(self.vars) else []
+    all_resources = [self.int_buf.buf] if len(self.vars) else []
 
     for j,ji in enumerate(self.jit_cache):
       prg: CompiledRunner = cast(CompiledRunner, ji.prg)
@@ -35,27 +35,30 @@ class MetalGraph(GraphRunner):
       icb_command.setComputePipelineState_(unwrap2(
         self.device.device.newComputePipelineStateWithDescriptor_options_reflection_error_(descriptor, Metal.MTLPipelineOption(0), None, None)))
       for i,b in enumerate(ji.bufs):
-        if b is not None:
-          icb_command.setKernelBuffer_offset_atIndex_(b._buf,
-          all_resources.append(b._buf)
-      for i,v in enumerate(prg.p.vars): icb_command.setKernelBuffer_offset_atIndex_(self.int_buf, self.vars.index(v)*4, len(ji.bufs)+i)
-
-
-
+        if b is not None and b not in input_rawbuffers:
+          icb_command.setKernelBuffer_offset_atIndex_(b._buf.buf, b._buf.offset, i)
+          all_resources.append(b._buf.buf)
+      for i,v in enumerate(prg.p.vars): icb_command.setKernelBuffer_offset_atIndex_(self.int_buf.buf, self.vars.index(v)*4, len(ji.bufs)+i)
+
+      global_size, local_size = prg.p.launch_dims(var_vals)
+      icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
       icb_command.setBarrier()
 
     self.all_resources = dedup(all_resources)
     self.command_buffer: Any = None
-    if len(self.vars): self.int_buf_view = self.int_buf.contents().as_buffer(self.int_buf.length()).cast('i')
+    if len(self.vars): self.int_buf_view = self.int_buf.buf.contents().as_buffer(self.int_buf.buf.length()).cast('i')
 
   def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
     if self.command_buffer is not None and self.command_buffer in self.device.mtl_buffers_in_flight: wait_check(self.command_buffer)
-    all_resources = dedup(self.all_resources + [x._buf for x in input_rawbuffers])
+    all_resources = dedup(self.all_resources + [x._buf.buf for x in input_rawbuffers])
 
     for (j,i),input_idx in self.input_replace.items():
-      self.icb.indirectComputeCommandAtIndex_(j).setKernelBuffer_offset_atIndex_(input_rawbuffers[input_idx]._buf,
-
-
+      self.icb.indirectComputeCommandAtIndex_(j).setKernelBuffer_offset_atIndex_(input_rawbuffers[input_idx]._buf.buf,
+                                                                                 input_rawbuffers[input_idx]._buf.offset, i)
+
+    for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
+      prg = cast(CompiledRunner, self.jit_cache[j].prg)
+      global_size, local_size = global_dims or prg.p.global_size, local_dims or prg.p.local_size
       self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size),
                                                                                                        Metal.MTLSize(*local_size))
     for j, var in enumerate(self.vars): self.int_buf_view[j] = var_vals[var]
```