tinygrad 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +11 -6
- tinygrad/codegen/kernel.py +308 -175
- tinygrad/codegen/linearize.py +95 -0
- tinygrad/codegen/lowerer.py +143 -0
- tinygrad/codegen/transcendental.py +257 -0
- tinygrad/codegen/uopgraph.py +506 -0
- tinygrad/device.py +72 -171
- tinygrad/dtype.py +122 -47
- tinygrad/engine/jit.py +184 -87
- tinygrad/{lazy.py → engine/lazy.py} +74 -66
- tinygrad/engine/memory.py +51 -0
- tinygrad/engine/realize.py +86 -61
- tinygrad/engine/schedule.py +366 -317
- tinygrad/engine/search.py +58 -47
- tinygrad/function.py +59 -58
- tinygrad/helpers.py +120 -102
- tinygrad/multi.py +82 -78
- tinygrad/nn/__init__.py +116 -67
- tinygrad/nn/datasets.py +12 -5
- tinygrad/nn/optim.py +1 -1
- tinygrad/nn/state.py +91 -6
- tinygrad/ops.py +1126 -143
- tinygrad/renderer/__init__.py +47 -23
- tinygrad/renderer/cstyle.py +338 -265
- tinygrad/renderer/llvmir.py +125 -143
- tinygrad/renderer/ptx.py +225 -0
- tinygrad/runtime/autogen/adreno.py +17904 -0
- tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
- tinygrad/runtime/autogen/cuda.py +6 -162
- tinygrad/runtime/autogen/io_uring.py +97 -63
- tinygrad/runtime/autogen/kfd.py +60 -47
- tinygrad/runtime/autogen/kgsl.py +1386 -0
- tinygrad/runtime/autogen/libc.py +5462 -0
- tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
- tinygrad/runtime/autogen/nvrtc.py +579 -0
- tinygrad/runtime/autogen/opencl.py +11 -11
- tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
- tinygrad/runtime/graph/clang.py +3 -3
- tinygrad/runtime/graph/cuda.py +11 -15
- tinygrad/runtime/graph/hcq.py +120 -107
- tinygrad/runtime/graph/metal.py +71 -43
- tinygrad/runtime/ops_amd.py +244 -323
- tinygrad/runtime/ops_clang.py +12 -5
- tinygrad/runtime/ops_cloud.py +220 -0
- tinygrad/runtime/ops_cuda.py +42 -99
- tinygrad/runtime/ops_disk.py +25 -26
- tinygrad/runtime/ops_dsp.py +181 -0
- tinygrad/runtime/ops_gpu.py +29 -16
- tinygrad/runtime/ops_hip.py +68 -0
- tinygrad/runtime/ops_llvm.py +15 -10
- tinygrad/runtime/ops_metal.py +147 -64
- tinygrad/runtime/ops_nv.py +356 -397
- tinygrad/runtime/ops_python.py +78 -79
- tinygrad/runtime/ops_qcom.py +405 -0
- tinygrad/runtime/support/__init__.py +0 -0
- tinygrad/runtime/support/compiler_cuda.py +77 -0
- tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
- tinygrad/runtime/support/elf.py +38 -0
- tinygrad/runtime/support/hcq.py +539 -0
- tinygrad/shape/shapetracker.py +40 -50
- tinygrad/shape/view.py +102 -63
- tinygrad/tensor.py +1109 -365
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
- tinygrad-0.10.0.dist-info/RECORD +77 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/linearizer.py +0 -528
- tinygrad/codegen/uops.py +0 -451
- tinygrad/engine/graph.py +0 -100
- tinygrad/renderer/assembly.py +0 -269
- tinygrad/shape/symbolic.py +0 -327
- tinygrad-0.9.1.dist-info/RECORD +0 -63
- /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
tinygrad/runtime/graph/clang.py
CHANGED
@@ -1,10 +1,10 @@
 from typing import List, Dict, cast
 import ctypes
-from tinygrad.helpers import dedup, cpu_time_execution,
-from tinygrad.engine.jit import GraphRunner
+from tinygrad.helpers import dedup, cpu_time_execution, DEBUG
+from tinygrad.engine.jit import GraphRunner, GraphException
 from tinygrad.device import Buffer, Device
 from tinygrad.engine.realize import ExecItem, CompiledRunner
-from tinygrad.
+from tinygrad.ops import Variable
 from tinygrad.runtime.ops_clang import ClangProgram
 from tinygrad.renderer.cstyle import ClangRenderer
 render_dtype = ClangRenderer().render_dtype
tinygrad/runtime/graph/cuda.py
CHANGED
@@ -1,12 +1,12 @@
 import ctypes
 from typing import Any, Optional, Tuple, Dict, List, cast
 import tinygrad.runtime.autogen.cuda as cuda
-from tinygrad.helpers import init_c_var,
+from tinygrad.helpers import init_c_var, dedup
 from tinygrad.device import Buffer, Device
 from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
-from tinygrad.
+from tinygrad.ops import Variable
 from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
-from tinygrad.engine.jit import MultiGraphRunner
+from tinygrad.engine.jit import MultiGraphRunner, GraphException

 class CUDAGraph(MultiGraphRunner):
 def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
@@ -25,21 +25,20 @@ class CUDAGraph(MultiGraphRunner):
 global_size, local_size = ji.prg.p.launch_dims(var_vals)

 new_node = cuda.CUgraphNode()
-deps = self._access_resources([x.base for x in ji.bufs
-[x.base for x in ji.bufs[:ji.prg.p.outcount] if x is not None], new_dependency=new_node)
+deps = self._access_resources([x.base for x in ji.bufs if x is not None], ji.prg.p.outs, new_dependency=new_node)
 c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None

 c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals[x] for x in ji.prg.p.vars])
 kern_params = cuda.CUDA_KERNEL_NODE_PARAMS(ji.prg.clprg.prg, *global_size, *local_size, 0, None, vargs)
 check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(kern_params)))

-if j in self.
+if j in self.launch_dims_replace or j in self.var_vals_replace or j in self.jc_idx_with_updatable_rawbufs:
 self.updatable_nodes[j] = (new_node, kern_params, c_args, False)
 elif isinstance(ji.prg, BufferXfer):
 dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
 src_dev = cast(CUDADevice, Device[src.device])
 node_from = cuda.CUgraphNode()
-deps = self._access_resources(
+deps = self._access_resources(rawbufs=[dest.base, src.base], write=[0], new_dependency=node_from)
 c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
 cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_DEVICE, srcDevice=src._buf, srcPitch=src.nbytes, srcHeight=1,
 dstMemoryType=cuda.CU_MEMORYTYPE_DEVICE, dstDevice=dest._buf, dstPitch=dest.nbytes, dstHeight=1,
@@ -58,13 +57,13 @@ class CUDAGraph(MultiGraphRunner):
 elif i == 1: self.updatable_nodes[j][1].srcDevice = input_rawbuffers[input_idx]._buf

 # Update var_vals in the c_args struct.
-for j in self.
-for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars):
-setattr(self.updatable_nodes[j][2], f'v{i}', var_vals[v])
+for j, i, v in self.updated_vars(var_vals): setattr(self.updatable_nodes[j][2], f'v{i}', v)

 # Update launch dims in the kern_params struct.
-for j in self.
-
+for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
+prg = cast(CompiledRunner, self.jit_cache[j].prg)
+node, global_size, local_size = self.updatable_nodes[j][1], global_dims or prg.p.global_size, local_dims or prg.p.local_size
+node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = *local_size, *global_size # type: ignore[misc]

 # Update graph nodes with the updated structs.
 for node, c_node_params, c_args, is_copy in self.updatable_nodes.values():
@@ -76,6 +75,3 @@ class CUDAGraph(MultiGraphRunner):
 def __del__(self):
 if hasattr(self, 'graph'): check(cuda.cuGraphDestroy(self.graph))
 if hasattr(self, 'instance'): check(cuda.cuGraphExecDestroy(self.instance))
-
-def set_kernel_node_launch_dims(self, node, global_size: Tuple[int, int, int], local_size: Tuple[int, int, int]):
-node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = *local_size, *global_size
tinygrad/runtime/graph/hcq.py
CHANGED
@@ -1,187 +1,200 @@
-import collections,
+import collections, time
 from typing import List, Any, Dict, cast, Optional, Tuple, Set
-from tinygrad.helpers import round_up,
+from tinygrad.helpers import round_up, PROFILE, memsize_to_str
+from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWCommandQueue, HWComputeQueue, HWCopyQueue, HCQArgsState
 from tinygrad.device import Buffer, BufferOptions, Compiled, Device
-from tinygrad.
+from tinygrad.ops import Variable
 from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
 from tinygrad.engine.jit import MultiGraphRunner

 class HCQGraph(MultiGraphRunner):
 def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
 super().__init__(jit_cache, input_rawbuffers, var_vals)
-self.devices = list(set(cast(
+self.devices = list(set(cast(HCQCompiled, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs]))

 # Allocate kernel args.
 kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
 for ji in self.jit_cache:
 if not isinstance(ji.prg, CompiledRunner): continue
 kernargs_size[ji.prg.device] += round_up(ji.prg.clprg.kernargs_alloc_size, 16)
-self.kernargs_bufs: Dict[Compiled,
-kernargs_ptrs: Dict[Compiled, int] = {dev:buf.va_addr for dev,buf in self.kernargs_bufs.items()}
+self.kernargs_bufs: Dict[Compiled, HCQBuffer] = {dev:dev.allocator._alloc(sz, BufferOptions(cpu_access=True)) for dev,sz in kernargs_size.items()}

 # Fill initial arguments.
-self.
-
-
+self.ji_args: Dict[int, HCQArgsState] = {}
+
+kargs_ptrs: Dict[Compiled, int] = {dev:buf.va_addr for dev,buf in self.kernargs_bufs.items()}
 for j,ji in enumerate(self.jit_cache):
 if not isinstance(ji.prg, CompiledRunner): continue
-
-
-
-self.ji_args_bufs[j] = to_mv(self.kargs_addrs[j] + ji.prg.clprg.kernargs_offset, len(ji.bufs) * 8).cast('Q')
-self.ji_args_vars[j] = to_mv(self.kargs_addrs[j] + ji.prg.clprg.kernargs_offset + len(ji.bufs) * 8, len(ji.prg.p.vars) * 4).cast('I')
-for i in range(len(ji.bufs)): self.ji_args_bufs[j][i] = cast(Buffer, ji.bufs[i])._buf.va_addr
-for i in range(len(ji.prg.p.vars)): self.ji_args_vars[j][i] = var_vals[ji.prg.p.vars[i]]
-
-# NV needs constbuffer to be set
-if ji.prg.device.dname.startswith("NV"): to_mv(self.kargs_addrs[j], 0x160).cast('I')[:] = array.array('I', ji.prg.clprg.constbuffer_0)
+kargs_ptrs[ji.prg.device] = (kargs_ptr:=kargs_ptrs[ji.prg.device]) + round_up(ji.prg.clprg.kernargs_alloc_size, 16)
+self.ji_args[j] = ji.prg.clprg.fill_kernargs([cast(Buffer, b)._buf for b in ji.bufs], [var_vals[v] for v in ji.prg.p.vars], kargs_ptr)

 # Schedule Dependencies.
 # There are two types of queues on each device: copy and compute. Both must synchronize with all external operations before launching any
 # graph-related tasks. This synchronization uses a global timeline signal per device. Within the graph, the compute queue coordinates with
 # global operations and sets a kickoff signal. Any queue accessing a buffer from another device waits for this signal from the device’s
 # compute queue to ensure exclusive access. The compute queue signals the completion of the graph, synchronizing with the device's copy queue.
-self.
-
+self.ji_schedule: Dict[int, Tuple[HCQCompiled, HWCommandQueue, List, List, HCQSignal, Optional[int]]] = {}
+
+self.comp_queues: Dict[HCQCompiled, HWComputeQueue] = {dev: dev.hw_compute_queue_t() for dev in self.devices}
+self.copy_queues: Dict[HCQCompiled, HWCopyQueue] = {} # lazy allocation

-self.
-self.
-self.dev_kickoff_signal = {dev: self.devices[0]._get_signal(value=0) for dev in self.devices + ['CPU']} # Dict[dev, signal]
-self.kickoff_value = 0
+self.signals: Dict[Any, HCQSignal] = {**{dev: dev.signal_t(value=0) for dev in self.devices}, **{"CPU": self.devices[0].signal_t(value=0)}}
+self.kickoff_value: int = 0

-self.
-
+self.prof_signals: List[HCQSignal] = [self.devices[0].signal_t() for i in range(len(self.jit_cache) * 2)] if PROFILE else []
+self.prof_records: List[Tuple[Tuple[int, bool], Tuple[int, bool], HCQCompiled, str, bool, List[int], Optional[Dict]]] = []

-
-
+last_j: Dict[HWCommandQueue, Optional[int]] = collections.defaultdict(lambda: None)
+queue_access: Dict[HWCommandQueue, Dict[HWCommandQueue, Optional[int]]] = collections.defaultdict(lambda: collections.defaultdict(lambda: None))
+dev_access: Dict[HWCommandQueue, Set[HCQCompiled]] = collections.defaultdict(set)
+
+for dev, queue in self.comp_queues.items(): dev_access[queue].add(dev)

 for j,ji in enumerate(self.jit_cache):
-enqueue_dev = ji.prg.device if isinstance(ji.prg, CompiledRunner) else Device[ji.bufs[1].device] #type:ignore
-enqueue_queue = self.comp_queues[enqueue_dev] if
-out_signal = self.signals
-writable_buffers = ji.prg.p.outcount if isinstance(ji.prg, CompiledRunner) else 1
-deps = self.access_resources(enqueue_queue, ji.bufs[writable_buffers:], ji.bufs[:writable_buffers], j + 1)
+enqueue_dev = ji.prg.device if (is_exec_prg:=isinstance(ji.prg, CompiledRunner)) else Device[ji.bufs[1].device] #type:ignore
+enqueue_queue = self.comp_queues[enqueue_dev] if is_exec_prg else self.copy_queues.setdefault(enqueue_dev, enqueue_dev.hw_copy_queue_t())
+out_signal = self.signals.setdefault(enqueue_queue, enqueue_dev.signal_t(value=0))

-
-
-
+# Get dependencies based on input and output buffers.
+rdeps = self._access_resources(ji.bufs, ji.prg.p.outs if is_exec_prg else [0], (enqueue_queue, j + 1)) #type:ignore
+
+# Update dependencies to include previous kernel in queue. This is required for timeline signals.
+opt_deps, deps = [], rdeps + ([(enqueue_queue, prev_ji + 1)] if (prev_ji:=last_j[enqueue_queue]) is not None else [])
+
+# Optimize dependencies by removing redundant ones. Remove waiting for the value of the queue which is known to be already
+# synced with the current queue.
+for dep_queue, dep_val in sorted(deps, key=lambda x: x[1], reverse=True):
+if (qa:=queue_access[enqueue_queue][dep_queue]) is None or qa < dep_val:
+opt_deps.append((self.signals[dep_queue], dep_val))
+queue_access[enqueue_queue][dep_queue] = dep_val

-
-
-
-
+# Ensure device is ready for use in current context: the graph has initialized the device and it's safe to operate on it within this graph.
+for dep_queue, _ in opt_deps: dev_access[enqueue_queue].update(dev_access[dep_queue])
+sync_signals = [(self.signals[d], self.kickoff_value) for b in ji.bufs if (d:=Device[cast(Buffer, b).device]) not in dev_access[enqueue_queue]]
+dev_access[enqueue_queue].update(cast(HCQCompiled, Device[cast(Buffer, b).device]) for b in ji.bufs)

-#
-
-
-
+# Remove self-dependency for compute and copy queues.
+# For compute, in case of NV, optimize when only 1 same-queue dependency exists, since NV chains 2+ executions in this case,
+# eliminating dependency need.
+dname = enqueue_dev.dname.split(":", 1)[0]
+can_opt = dname in {"AMD", "QCOM"} or (dname == "NV" and len(sync_signals) == 0 and len(opt_deps) == 1 and id(opt_deps[0][0]) == id(out_signal))
+if can_opt or isinstance(ji.prg, BufferXfer): opt_deps = [x for x in opt_deps if id(x[0]) != id(out_signal)]

-
-
-self.
-
+# Enable necessary signals in the schedule by setting the signal value.
+for sig, val in opt_deps: self.ji_schedule[val - 1] = self.ji_schedule[val - 1][:5] + (val,)
+self.ji_schedule[j] = (enqueue_dev, enqueue_queue, sync_signals, opt_deps[::-1], out_signal, None if is_exec_prg else (j + 1))
+
+# Collect profile information if profiling is enabled.
+if PROFILE:
+prof_ji_desc = ji.prg.clprg.name if is_exec_prg else f"{ji.bufs[1].device} -> {ji.bufs[0].device}" # type: ignore
+
+sig_st, sig_en = (j * 2, True), (j * 2 + 1, True)
+if len(opt_deps) == 0 and (prev_ji:=last_j[enqueue_queue]) is not None: sig_st = (prev_ji * 2 + 1, False)
+
+if is_exec_prg: prof_args = None
+else: prof_args = {"Size": memsize_to_str(ji.bufs[0].nbytes), "GB/S": lambda dur, b=ji.bufs[0].nbytes: f"{b/1e3/dur:.2f}"} # type: ignore
+
+self.prof_records.append((sig_st, sig_en, enqueue_dev, prof_ji_desc, not is_exec_prg, [d - 1 for _, d in rdeps], prof_args))
+
+last_j[enqueue_queue] = j

 # Build hardware queues.
-self.
-self.copy_to_devs: Dict[
-self.kickoff_wait_cmds: Dict[
+self.op_cmd_idx: Dict[int, Tuple[Any, int]] = {}
+self.copy_to_devs: Dict[HCQCompiled, Set[HCQCompiled]] = {dev: set() for dev in self.devices}
+self.kickoff_wait_cmds: Dict[HWCommandQueue, List] = {q: list() for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}

 for dev in self.devices:
 self.comp_queues[dev].memory_barrier().wait(dev.timeline_signal, dev.timeline_value - 1) \
-.wait(self.
+.wait(self.signals['CPU'], self.kickoff_value).signal(self.signals[dev], self.kickoff_value)

 for j,ji in enumerate(self.jit_cache):
-deps,
-
+enqueue_dev, enqueue_queue, sync_signals, deps, signal, signal_val = self.ji_schedule[j]
+
+for i in range(len(sync_signals)): self.kickoff_wait_cmds[enqueue_queue].append(len(enqueue_queue) + i)
+for sig, val in sync_signals + deps: enqueue_queue.wait(sig, val)

 # Encode waits and start profile timestamp (if needed).
-
-enqueue_queue.wait(sig, val)
-if id(sig) in [id(x) for x in self.dev_kickoff_signal.values()]: self.kickoff_wait_cmds[enqueue_queue].append(len(enqueue_queue) - 1)
-if prof_info: enqueue_queue.timestamp(prof_info[0])
+if PROFILE and self.prof_records[j][0][1]: enqueue_queue.timestamp(self.prof_signals[self.prof_records[j][0][0]])

 # Encode main commands based on ji type.
 if isinstance(ji.prg, CompiledRunner):
-enqueue_queue.exec(ji.prg.clprg, self.
-signal=self.signals[enqueue_queue] if signal_value is not None else None, signal_value=signal_value)
-self.exec_ptrs[j] = (enqueue_queue, len(enqueue_queue) - 1)
+cast(HWComputeQueue, enqueue_queue).exec(ji.prg.clprg, self.ji_args[j], *ji.prg.p.launch_dims(var_vals))
 elif isinstance(ji.prg, BufferXfer):
 dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
-Device[src.device].
-enqueue_queue.copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes)
-self.copy_to_devs[Device[dest.device]].add(Device[src.device])
+cast(HCQAllocator, Device[src.device].allocator).map(dest._buf)
+cast(HWCopyQueue, enqueue_queue).copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes)
+self.copy_to_devs[cast(HCQCompiled, Device[dest.device])].add(cast(HCQCompiled, Device[src.device]))
+self.op_cmd_idx[j] = (enqueue_queue, len(enqueue_queue) - 1)

 # Encode finish profile timestamp (if needed).
-if
+if PROFILE and self.prof_records[j][1][1]: enqueue_queue.timestamp(self.prof_signals[self.prof_records[j][1][0]])
+
+if signal_val is not None: enqueue_queue.signal(signal, signal_val)

 for dev in self.devices:
 for dep_dev in list(self.copy_to_devs[dev]) + [dev]:
-if (
-
+if dep_dev in self.copy_queues: self.comp_queues[dev].wait(self.signals[(copy_q:=self.copy_queues[dep_dev])], cast(int, last_j[copy_q]) + 1)
+
+self.comp_queues[dev].signal(dev.timeline_signal, dev.timeline_value).bind(dev)
+if dev in self.copy_queues: self.copy_queues[dev].bind(dev)

-
-
-if hasattr(self.copy_queues[dev], 'bind') and self.last_ji[self.copy_queues[dev]] is not None: self.copy_queues[dev].bind(dev)
+self.last_timeline: Dict[HCQCompiled, Tuple[HCQSignal, int]] = {dev: (dev.timeline_signal, 0) for dev in self.devices}
+self.queue_signals_to_reset = [self.signals[q] for q in list(self.comp_queues.values()) + list(self.copy_queues.values()) if q in self.signals]

 def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
 # Wait and restore signals
 self.kickoff_value += 1
-for dev in self.devices: dev.
-for
-
-self.devices[0]._set_signal(self.dev_kickoff_signal['CPU'], self.kickoff_value)
+for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
+for sig in self.queue_signals_to_reset: sig.value = 0
+self.signals['CPU'].value = self.kickoff_value

-if PROFILE and self.kickoff_value > 1:
-for _,_,(st,en,dev,desc,is_cp) in self.signal_sched.values(): #type: ignore
-dev.raw_prof_records += [(dev._read_timestamp(st), dev._read_timestamp(en), desc, is_cp)]
+if PROFILE and self.kickoff_value > 1: self.collect_timestamps()

 # Update rawbuffers
-for (j,i),input_idx in self.input_replace.items():
+for (j,i),input_idx in self.input_replace.items():
+if j in self.ji_args: self.ji_args[j].update_buffer(i, input_rawbuffers[input_idx]._buf)
+else: self.op_cmd_idx[j][0].update_copy(self.op_cmd_idx[j][1], **{('dest' if i == 0 else 'src'): input_rawbuffers[input_idx]._buf.va_addr})

 # Update var_vals
-for j in self.
-for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars): self.ji_args_vars[j][i] = var_vals[v]
+for j, i, v in self.updated_vars(var_vals): self.ji_args[j].update_var(i, v)

-
-
-queue
+# Update launch dims
+for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
+queue, cmd_ptr = self.op_cmd_idx[j]
+queue.update_exec(cmd_ptr, global_dims, local_dims)

 for dev in self.devices:
-self.comp_queues[dev].
-
-
+comp_queue, copy_queue, need_sig_upd = self.comp_queues[dev], self.copy_queues.get(dev, None), dev.timeline_signal != self.last_timeline[dev][0]
+comp_queue.update_wait(1, dev.timeline_signal if need_sig_upd else None, dev.timeline_value - 1) \
+.update_wait(2, value=self.kickoff_value).update_signal(3, value=self.kickoff_value) \
+.update_signal(len(comp_queue)-1, dev.timeline_signal if need_sig_upd else None, dev.timeline_value).submit(dev)

-if
-for cmd_idx in self.kickoff_wait_cmds[
-
+if copy_queue is not None:
+for cmd_idx in self.kickoff_wait_cmds[copy_queue]: copy_queue.update_wait(cmd_idx, value=self.kickoff_value)
+copy_queue.submit(dev)

-self.
+self.last_timeline[dev] = (dev.timeline_signal, dev.timeline_value)
 dev.timeline_value += 1

 if wait:
 st = time.perf_counter()
-for dev in self.devices: dev.
+for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
 return time.perf_counter() - st
 return None

-def
-
+def collect_timestamps(self):
+timestamps = [s.timestamp for s in self.prof_signals]

-
-
-for buf in read+write:
-if buf.device not in self.save_devs[queue]:
-self.save_devs[queue].add(buf.device)
-sync_signals += [(self.dev_kickoff_signal[Device[buf.device]], self.kickoff_value)]
+for (st,_), (en,_), dev, desc, is_cp, deps, args in self.prof_records:
+dev.raw_prof_records += [(timestamps[st], timestamps[en], desc, is_cp, args)]

-
+for x in deps:
+(b_st,_), (b_en,_), b_dev, _, b_is_cp, _, _ = self.prof_records[x]
+dev.dep_prof_records += [(timestamps[b_st], timestamps[b_en], b_dev, b_is_cp, timestamps[st], timestamps[en], dev, is_cp)]

 def __del__(self):
-for dev in self.devices: dev.
+for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])

-
-if PROFILE and self.kickoff_value > 1:
-for _,_,(st,en,dev,desc,is_cp) in self.signal_sched.values(): dev.sig_prof_records += [(st, en, desc, is_cp)] #type: ignore
+if PROFILE and self.kickoff_value >= 1: self.collect_timestamps()

-
-for dev, buf in self.kernargs_bufs.items(): dev.allocator._free(buf, BufferOptions(cpu_access=True))
+for fdev, buf in self.kernargs_bufs.items(): fdev.allocator._free(buf, BufferOptions(cpu_access=True))
tinygrad/runtime/graph/metal.py
CHANGED
@@ -1,12 +1,20 @@
 from typing import List, Any, Dict, cast, Optional
-import
+import ctypes
 from tinygrad.dtype import dtypes
-from tinygrad.helpers import dedup,
+from tinygrad.helpers import dedup, getenv
 from tinygrad.device import Buffer
 from tinygrad.engine.realize import ExecItem, CompiledRunner
-from tinygrad.engine.jit import GraphRunner
-from tinygrad.
-from tinygrad.runtime.ops_metal import wait_check
+from tinygrad.engine.jit import GraphRunner, GraphException
+from tinygrad.ops import Variable
+from tinygrad.runtime.ops_metal import wait_check, msg, libobjc, to_struct, objc_instance,\
+MTLResourceOptions, elapsed_time, objc_id
+
+class MTLIndirectCommandType:
+MTLIndirectCommandTypeConcurrentDispatch = (1 << 5)
+
+class MTLResourceUsage:
+MTLResourceUsageRead = 0b01
+MTLResourceUsageWrite = 0b10

 class MetalGraph(GraphRunner):
 def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
@@ -14,62 +22,82 @@ class MetalGraph(GraphRunner):
 if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException

 # create metal batch exec
-icb_descriptor =
-icb_descriptor
-icb_descriptor
-icb_descriptor
-icb_descriptor
-self.icb = self.device.device.newIndirectCommandBufferWithDescriptor_maxCommandCount_options_(icb_descriptor, len(self.jit_cache),
-Metal.MTLResourceOptions(0))
-if self.icb is None: raise GraphException("create indirect command buffer failed, does your system support this?")
+icb_descriptor = msg(libobjc.objc_getClass(b"MTLIndirectCommandBufferDescriptor"), "new", restype=objc_instance)
+msg(icb_descriptor, "setCommandTypes:", MTLIndirectCommandType.MTLIndirectCommandTypeConcurrentDispatch)
+msg(icb_descriptor, "setInheritBuffers:", False)
+msg(icb_descriptor, "setInheritPipelineState:", False)
+msg(icb_descriptor, "setMaxKernelBufferBindCount:", 31)

-
-
+self.icb = msg(self.device.device, "newIndirectCommandBufferWithDescriptor:maxCommandCount:options:",
+icb_descriptor, len(self.jit_cache), MTLResourceOptions.MTLResourceCPUCacheModeDefaultCache, restype=objc_instance)
+if self.icb.value is None: raise GraphException("create indirect command buffer failed, does your system support this?")
+icb_label = bytes(msg(msg(self.icb, "description", restype=objc_instance), "UTF8String", restype=ctypes.c_char_p)).decode()
+self.needs_icb_fix = int("AGXG15XFamilyIndirectCommandBuffer" not in icb_label) # not required on M3

+if len(self.vars): self.int_buf = self.device.allocator.alloc(len(self.vars)*dtypes.int32.itemsize)
+all_resources = [self.int_buf.buf] if len(self.vars) else []
+all_pipelines = []
 for j,ji in enumerate(self.jit_cache):
 prg: CompiledRunner = cast(CompiledRunner, ji.prg)
-
-
-
-icb_command = self.icb.indirectComputeCommandAtIndex_(j)
-icb_command.setComputePipelineState_(unwrap2(
-self.device.device.newComputePipelineStateWithDescriptor_options_reflection_error_(descriptor, Metal.MTLPipelineOption(0), None, None)))
+icb_command = msg(self.icb, "indirectComputeCommandAtIndex:", j, restype=objc_instance)
+all_pipelines.append(prg.clprg.pipeline_state)
+msg(icb_command, "setComputePipelineState:", prg.clprg.pipeline_state)
 for i,b in enumerate(ji.bufs):
-if b is not None:
-icb_command
-all_resources.append(b._buf)
-for i,v in enumerate(prg.p.vars): icb_command
-
-
-
-icb_command
+if b is not None and b not in input_rawbuffers:
+msg(icb_command, "setKernelBuffer:offset:atIndex:", b._buf.buf, b._buf.offset, i)
+all_resources.append(b._buf.buf)
+for i,v in enumerate(prg.p.vars): msg(icb_command, "setKernelBuffer:offset:atIndex:", self.int_buf.buf, self.vars.index(v)*4, len(ji.bufs)+i)
+
+global_size, local_size = prg.p.launch_dims(var_vals)
+msg(icb_command, "concurrentDispatchThreadgroups:threadsPerThreadgroup:", to_struct(*global_size), to_struct(*local_size))
+msg(icb_command, "setBarrier")

 self.all_resources = dedup(all_resources)
+self.all_pipelines = dedup(all_pipelines)
 self.command_buffer: Any = None
-if len(self.vars): self.int_buf_view = self.
+if len(self.vars): self.int_buf_view = self.device.allocator.as_buffer(self.int_buf).cast('i')
+self.range = to_struct(0, len(self.jit_cache))

 def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
+
 if self.command_buffer is not None and self.command_buffer in self.device.mtl_buffers_in_flight: wait_check(self.command_buffer)
-all_resources = dedup(self.all_resources + [x._buf for x in input_rawbuffers])
+all_resources = dedup(self.all_resources + [x._buf.buf for x in input_rawbuffers])

 for (j,i),input_idx in self.input_replace.items():
-self.icb
-
-
-
-
+computeCommand = msg(self.icb, "indirectComputeCommandAtIndex:", j, restype=objc_id)
+msg(computeCommand, "setKernelBuffer:offset:atIndex:", input_rawbuffers[input_idx]._buf.buf,
+input_rawbuffers[input_idx]._buf.offset, i)
+
+for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
+prg = cast(CompiledRunner, self.jit_cache[j].prg)
+global_size, local_size = global_dims or prg.p.global_size, local_dims or prg.p.local_size
+computeCommand = msg(self.icb, "indirectComputeCommandAtIndex:", j)
+msg(computeCommand, "concurrentDispatchThreadgroups:threadsPerThreadgroup:",
+to_struct(*cast(tuple, global_size)), to_struct(*cast(tuple, local_size)))
 for j, var in enumerate(self.vars): self.int_buf_view[j] = var_vals[var]

-command_buffer = self.device.mtl_queue
-encoder = command_buffer
-encoder
-
-
-
+command_buffer = msg(self.device.mtl_queue, "commandBuffer", restype=objc_instance)
+encoder = msg(command_buffer, "computeCommandEncoder", restype=objc_instance)
+msg(encoder, "useResources:count:usage:", (objc_id * len(all_resources))(*all_resources), len(all_resources),
+MTLResourceUsage.MTLResourceUsageRead | MTLResourceUsage.MTLResourceUsageWrite)
+
+# NOTE: the pipelines likely need to be added to the used resources to fix the crash on M1/M2, but I haven't figured out how
+# this is a O(n) hack to get them used. what should work is:
+#encoder.useResources_count_usage_(self.all_pipelines, len(self.all_pipelines), Metal.MTLResourceUsageRead)
+# but it fails with "Invalid Resource (00000009:kIOGPUCommandBufferCallbackErrorInvalidResource)"
+# to repro the crash (which can also crash other running GPU apps), run with FIX_METAL_ICB=0
+if getenv("FIX_METAL_ICB", self.needs_icb_fix):
+for ps in self.all_pipelines:
+msg(encoder, "setComputePipelineState:", ps)
+msg(encoder, "dispatchThreadgroups:threadsPerThreadgroup:", to_struct(0,0,0), to_struct(0,0,0))
+
+msg(encoder, "executeCommandsInBuffer:withRange:", self.icb, self.range)
+msg(encoder, "endEncoding")
+msg(command_buffer, "commit")
 self.command_buffer = command_buffer

 if wait:
 wait_check(command_buffer)
-return
+return elapsed_time(command_buffer)
 self.device.mtl_buffers_in_flight.append(command_buffer)
 return None