tinygrad 0.10.0__py3-none-any.whl → 0.10.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/devectorizer.py +247 -0
- tinygrad/codegen/expander.py +121 -0
- tinygrad/codegen/kernel.py +141 -201
- tinygrad/codegen/linearize.py +223 -84
- tinygrad/codegen/lowerer.py +60 -42
- tinygrad/codegen/symbolic.py +476 -0
- tinygrad/codegen/transcendental.py +22 -13
- tinygrad/device.py +187 -47
- tinygrad/dtype.py +39 -28
- tinygrad/engine/jit.py +83 -65
- tinygrad/engine/memory.py +4 -5
- tinygrad/engine/multi.py +161 -0
- tinygrad/engine/realize.py +62 -108
- tinygrad/engine/schedule.py +396 -357
- tinygrad/engine/search.py +55 -66
- tinygrad/gradient.py +73 -0
- tinygrad/helpers.py +81 -59
- tinygrad/nn/__init__.py +30 -32
- tinygrad/nn/datasets.py +1 -2
- tinygrad/nn/optim.py +22 -26
- tinygrad/nn/state.py +91 -66
- tinygrad/ops.py +492 -641
- tinygrad/renderer/__init__.py +95 -36
- tinygrad/renderer/cstyle.py +99 -92
- tinygrad/renderer/llvmir.py +83 -34
- tinygrad/renderer/ptx.py +83 -99
- tinygrad/renderer/wgsl.py +95 -0
- tinygrad/runtime/autogen/amd_gpu.py +39507 -12
- tinygrad/runtime/autogen/comgr.py +2 -0
- tinygrad/runtime/autogen/kfd.py +4 -3
- tinygrad/runtime/autogen/kgsl.py +1 -1
- tinygrad/runtime/autogen/libc.py +404 -71
- tinygrad/runtime/autogen/llvm.py +11379 -0
- tinygrad/runtime/autogen/pci.py +1333 -0
- tinygrad/runtime/autogen/vfio.py +891 -0
- tinygrad/runtime/autogen/webgpu.py +6985 -0
- tinygrad/runtime/graph/cuda.py +8 -9
- tinygrad/runtime/graph/hcq.py +84 -79
- tinygrad/runtime/graph/metal.py +40 -43
- tinygrad/runtime/ops_amd.py +498 -334
- tinygrad/runtime/ops_cloud.py +34 -34
- tinygrad/runtime/ops_cpu.py +24 -0
- tinygrad/runtime/ops_cuda.py +30 -27
- tinygrad/runtime/ops_disk.py +62 -63
- tinygrad/runtime/ops_dsp.py +159 -42
- tinygrad/runtime/ops_gpu.py +30 -30
- tinygrad/runtime/ops_hip.py +29 -31
- tinygrad/runtime/ops_llvm.py +48 -41
- tinygrad/runtime/ops_metal.py +149 -113
- tinygrad/runtime/ops_npy.py +2 -2
- tinygrad/runtime/ops_nv.py +238 -273
- tinygrad/runtime/ops_python.py +55 -50
- tinygrad/runtime/ops_qcom.py +129 -157
- tinygrad/runtime/ops_webgpu.py +225 -0
- tinygrad/runtime/support/allocator.py +94 -0
- tinygrad/runtime/support/am/__init__.py +0 -0
- tinygrad/runtime/support/am/amdev.py +396 -0
- tinygrad/runtime/support/am/ip.py +463 -0
- tinygrad/runtime/support/compiler_cuda.py +4 -2
- tinygrad/runtime/support/elf.py +28 -4
- tinygrad/runtime/support/hcq.py +256 -324
- tinygrad/runtime/support/llvm.py +26 -0
- tinygrad/shape/shapetracker.py +85 -53
- tinygrad/shape/view.py +104 -140
- tinygrad/spec.py +155 -0
- tinygrad/tensor.py +835 -527
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
- tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
- tinygrad/viz/index.html +544 -0
- tinygrad/viz/perfetto.html +178 -0
- tinygrad/viz/serve.py +205 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
- tinygrad-0.10.2.dist-info/RECORD +99 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
- tinygrad/codegen/uopgraph.py +0 -506
- tinygrad/engine/lazy.py +0 -228
- tinygrad/function.py +0 -212
- tinygrad/multi.py +0 -177
- tinygrad/runtime/graph/clang.py +0 -39
- tinygrad/runtime/ops_clang.py +0 -35
- tinygrad-0.10.0.dist-info/RECORD +0 -77
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/runtime/graph/cuda.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1
1
|
import ctypes
|
2
|
-
from typing import Any,
|
2
|
+
from typing import Any, cast
|
3
3
|
import tinygrad.runtime.autogen.cuda as cuda
|
4
4
|
from tinygrad.helpers import init_c_var, dedup
|
5
5
|
from tinygrad.device import Buffer, Device
|
@@ -9,18 +9,18 @@ from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
|
|
9
9
|
from tinygrad.engine.jit import MultiGraphRunner, GraphException
|
10
10
|
|
11
11
|
class CUDAGraph(MultiGraphRunner):
|
12
|
-
def __init__(self, jit_cache:
|
12
|
+
def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]):
|
13
13
|
super().__init__(jit_cache, input_rawbuffers, var_vals)
|
14
14
|
|
15
15
|
# Check all jit items are compatible.
|
16
16
|
if not all(isinstance(ji.prg, (CompiledRunner, BufferXfer)) for ji in jit_cache): raise GraphException
|
17
17
|
|
18
18
|
self.jc_idx_with_updatable_rawbufs = dedup([x[0] for x in self.input_replace.keys()])
|
19
|
-
self.updatable_nodes:
|
19
|
+
self.updatable_nodes: dict[int, tuple[Any, Any, Any, bool]] = {} # dict[jc index] = tuple(graph node, node params, input kernel params, is memcpy)
|
20
20
|
|
21
21
|
self.graph = init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)))
|
22
22
|
|
23
|
-
for j,ji in enumerate(
|
23
|
+
for j,ji in enumerate(jit_cache):
|
24
24
|
if isinstance(ji.prg, CompiledRunner):
|
25
25
|
global_size, local_size = ji.prg.p.launch_dims(var_vals)
|
26
26
|
|
@@ -29,7 +29,7 @@ class CUDAGraph(MultiGraphRunner):
|
|
29
29
|
c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
|
30
30
|
|
31
31
|
c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals[x] for x in ji.prg.p.vars])
|
32
|
-
kern_params = cuda.CUDA_KERNEL_NODE_PARAMS(ji.prg.
|
32
|
+
kern_params = cuda.CUDA_KERNEL_NODE_PARAMS(ji.prg._prg.prg, *global_size, *local_size, 0, None, vargs)
|
33
33
|
check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(kern_params)))
|
34
34
|
|
35
35
|
if j in self.launch_dims_replace or j in self.var_vals_replace or j in self.jc_idx_with_updatable_rawbufs:
|
@@ -48,7 +48,7 @@ class CUDAGraph(MultiGraphRunner):
|
|
48
48
|
|
49
49
|
self.instance = init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), self.graph, None, None, 0)))
|
50
50
|
|
51
|
-
def __call__(self, input_rawbuffers:
|
51
|
+
def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], wait=False) -> float|None:
|
52
52
|
# Update rawbuffers in the c_args struct.
|
53
53
|
for (j,i),input_idx in self.input_replace.items():
|
54
54
|
if not self.updatable_nodes[j][3]: setattr(self.updatable_nodes[j][2], f'f{i}', input_rawbuffers[input_idx]._buf)
|
@@ -61,9 +61,8 @@ class CUDAGraph(MultiGraphRunner):
|
|
61
61
|
|
62
62
|
# Update launch dims in the kern_params struct.
|
63
63
|
for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
|
64
|
-
|
65
|
-
node,
|
66
|
-
node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = *local_size, *global_size # type: ignore[misc]
|
64
|
+
node = self.updatable_nodes[j][1]
|
65
|
+
node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = *local_dims, *global_dims # type: ignore[misc]
|
67
66
|
|
68
67
|
# Update graph nodes with the updated structs.
|
69
68
|
for node, c_node_params, c_args, is_copy in self.updatable_nodes.values():
|
tinygrad/runtime/graph/hcq.py
CHANGED
@@ -1,58 +1,77 @@
|
|
1
1
|
import collections, time
|
2
|
-
from typing import
|
3
|
-
from tinygrad.helpers import round_up, PROFILE
|
4
|
-
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer,
|
5
|
-
from tinygrad.device import Buffer,
|
6
|
-
from tinygrad.
|
2
|
+
from typing import Any, cast
|
3
|
+
from tinygrad.helpers import round_up, PROFILE
|
4
|
+
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState, BumpAllocator
|
5
|
+
from tinygrad.device import Buffer, BufferSpec, Compiled, Device, ProfileGraphEntry, ProfileGraphEvent
|
6
|
+
from tinygrad.dtype import dtypes
|
7
|
+
from tinygrad.ops import UOp, Variable
|
7
8
|
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
|
8
9
|
from tinygrad.engine.jit import MultiGraphRunner
|
9
10
|
|
10
11
|
class HCQGraph(MultiGraphRunner):
|
11
|
-
def __init__(self, jit_cache:
|
12
|
+
def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]):
|
12
13
|
super().__init__(jit_cache, input_rawbuffers, var_vals)
|
13
14
|
self.devices = list(set(cast(HCQCompiled, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs]))
|
14
15
|
|
16
|
+
# Replace input buffers with variables.
|
17
|
+
self.hcq_bufs = [[cast(Buffer, x)._buf for x in ji.bufs] for ji in jit_cache]
|
18
|
+
self.input_replace_to_var: dict[tuple[int, int], Variable] = {}
|
19
|
+
|
20
|
+
for (j,i), input_idx in self.input_replace.items():
|
21
|
+
x = self.input_replace_to_var.setdefault((j,i), UOp.variable(f"input_{input_idx}", 0, 0xffffffffffffffff, dtype=dtypes.uint64))
|
22
|
+
self.hcq_bufs[j][i] = HCQBuffer(x, self.hcq_bufs[j][i].size, texture_info=self.hcq_bufs[j][i].texture_info) # Create fake buffer with variable
|
23
|
+
|
15
24
|
# Allocate kernel args.
|
16
|
-
kernargs_size:
|
17
|
-
for ji in
|
25
|
+
kernargs_size: dict[Compiled, int] = collections.defaultdict(int)
|
26
|
+
for ji in jit_cache:
|
18
27
|
if not isinstance(ji.prg, CompiledRunner): continue
|
19
|
-
kernargs_size[ji.prg.
|
20
|
-
self.kernargs_bufs:
|
28
|
+
kernargs_size[ji.prg.dev] += round_up(ji.prg._prg.kernargs_alloc_size, 16)
|
29
|
+
self.kernargs_bufs: dict[Compiled, HCQBuffer] = {dev:dev.allocator._alloc(sz, BufferSpec(cpu_access=True)) for dev,sz in kernargs_size.items()}
|
21
30
|
|
22
31
|
# Fill initial arguments.
|
23
|
-
self.ji_args:
|
32
|
+
self.ji_args: dict[int, HCQArgsState] = {}
|
24
33
|
|
25
|
-
|
26
|
-
for j,ji in enumerate(
|
34
|
+
kargs_alloc: dict[Compiled, BumpAllocator] = {dev:BumpAllocator(buf.size, base=cast(int, buf.va_addr)) for dev,buf in self.kernargs_bufs.items()}
|
35
|
+
for j,ji in enumerate(jit_cache):
|
27
36
|
if not isinstance(ji.prg, CompiledRunner): continue
|
28
|
-
|
29
|
-
self.ji_args[j] = ji.prg.
|
37
|
+
|
38
|
+
self.ji_args[j] = ji.prg._prg.fill_kernargs(self.hcq_bufs[j], ji.prg.p.vars, kargs_alloc[ji.prg.dev].alloc(ji.prg._prg.kernargs_alloc_size, 16))
|
30
39
|
|
31
40
|
# Schedule Dependencies.
|
32
41
|
# There are two types of queues on each device: copy and compute. Both must synchronize with all external operations before launching any
|
33
42
|
# graph-related tasks. This synchronization uses a global timeline signal per device. Within the graph, the compute queue coordinates with
|
34
43
|
# global operations and sets a kickoff signal. Any queue accessing a buffer from another device waits for this signal from the device’s
|
35
44
|
# compute queue to ensure exclusive access. The compute queue signals the completion of the graph, synchronizing with the device's copy queue.
|
36
|
-
self.ji_schedule:
|
45
|
+
self.ji_schedule: dict[int, tuple[HCQCompiled, HWQueue, list, list, HCQSignal, int|None]] = {}
|
37
46
|
|
38
|
-
self.comp_queues:
|
39
|
-
self.copy_queues:
|
47
|
+
self.comp_queues: dict[HCQCompiled, HWQueue] = {dev: dev.hw_compute_queue_t() for dev in self.devices}
|
48
|
+
self.copy_queues: dict[HCQCompiled, HWQueue] = {} # lazy allocation
|
40
49
|
|
41
|
-
self.signals:
|
50
|
+
self.signals: dict[Any, HCQSignal] = {**{dev: dev.signal_t(value=0) for dev in self.devices}, **{"CPU": self.devices[0].signal_t(value=0)}}
|
42
51
|
self.kickoff_value: int = 0
|
52
|
+
self.kickoff_var = UOp.variable("kickoff_var", 0, 0xffffffff, dtype=dtypes.uint32)
|
43
53
|
|
44
|
-
|
45
|
-
|
54
|
+
# When profiling allocate 2 signals for each jit item to measure speed. The jth jit item have signals at 2*j and 2*j+1.
|
55
|
+
# TODO: This logic might allocate a few extra signals...
|
56
|
+
self.prof_signals: list[HCQSignal] = [self.devices[0].signal_t() for i in range(len(jit_cache) * 2)] if PROFILE else []
|
57
|
+
self.prog_graph_deps: list[list[int]] = []
|
58
|
+
self.prof_graph_entries: list[ProfileGraphEntry] = []
|
46
59
|
|
47
|
-
last_j:
|
48
|
-
queue_access:
|
49
|
-
dev_access:
|
60
|
+
last_j: dict[HWQueue, int|None] = collections.defaultdict(lambda: None)
|
61
|
+
queue_access: dict[HWQueue, dict[HWQueue, int|None]] = collections.defaultdict(lambda: collections.defaultdict(lambda: None))
|
62
|
+
dev_access: dict[HWQueue, set[HCQCompiled]] = collections.defaultdict(set)
|
50
63
|
|
51
64
|
for dev, queue in self.comp_queues.items(): dev_access[queue].add(dev)
|
52
65
|
|
53
|
-
for j,ji in enumerate(
|
54
|
-
enqueue_dev = ji.prg.
|
55
|
-
|
66
|
+
for j,ji in enumerate(jit_cache):
|
67
|
+
enqueue_dev: HCQCompiled = ji.prg.dev if (is_exec_prg:=isinstance(ji.prg, CompiledRunner)) else Device[ji.bufs[1].device] #type:ignore
|
68
|
+
|
69
|
+
if is_exec_prg:
|
70
|
+
enqueue_queue = self.comp_queues[enqueue_dev]
|
71
|
+
else:
|
72
|
+
assert (enqueue_dev.hw_copy_queue_t is not None), "device must implement a copy queue"
|
73
|
+
enqueue_queue = self.copy_queues.setdefault(enqueue_dev, enqueue_dev.hw_copy_queue_t())
|
74
|
+
|
56
75
|
out_signal = self.signals.setdefault(enqueue_queue, enqueue_dev.signal_t(value=0))
|
57
76
|
|
58
77
|
# Get dependencies based on input and output buffers.
|
@@ -70,13 +89,13 @@ class HCQGraph(MultiGraphRunner):
|
|
70
89
|
|
71
90
|
# Ensure device is ready for use in current context: the graph has initialized the device and it's safe to operate on it within this graph.
|
72
91
|
for dep_queue, _ in opt_deps: dev_access[enqueue_queue].update(dev_access[dep_queue])
|
73
|
-
sync_signals = [(self.signals[d], self.
|
92
|
+
sync_signals = [(self.signals[d], self.kickoff_var) for b in ji.bufs if (d:=Device[cast(Buffer, b).device]) not in dev_access[enqueue_queue]]
|
74
93
|
dev_access[enqueue_queue].update(cast(HCQCompiled, Device[cast(Buffer, b).device]) for b in ji.bufs)
|
75
94
|
|
76
95
|
# Remove self-dependency for compute and copy queues.
|
77
96
|
# For compute, in case of NV, optimize when only 1 same-queue dependency exists, since NV chains 2+ executions in this case,
|
78
97
|
# eliminating dependency need.
|
79
|
-
dname = enqueue_dev.
|
98
|
+
dname = enqueue_dev.device.split(":", 1)[0]
|
80
99
|
can_opt = dname in {"AMD", "QCOM"} or (dname == "NV" and len(sync_signals) == 0 and len(opt_deps) == 1 and id(opt_deps[0][0]) == id(out_signal))
|
81
100
|
if can_opt or isinstance(ji.prg, BufferXfer): opt_deps = [x for x in opt_deps if id(x[0]) != id(out_signal)]
|
82
101
|
|
@@ -86,48 +105,52 @@ class HCQGraph(MultiGraphRunner):
|
|
86
105
|
|
87
106
|
# Collect profile information if profiling is enabled.
|
88
107
|
if PROFILE:
|
89
|
-
|
90
|
-
|
91
|
-
sig_st, sig_en = (j * 2, True), (j * 2 + 1, True)
|
92
|
-
if len(opt_deps) == 0 and (prev_ji:=last_j[enqueue_queue]) is not None: sig_st = (prev_ji * 2 + 1, False)
|
108
|
+
# When execution are chained, we can reuse the end timestamp from the previous command as the start timestamp for the current command.
|
109
|
+
sig_st = prev_ji * 2 + 1 if len(opt_deps) == 0 and (prev_ji:=last_j[enqueue_queue]) is not None else j * 2
|
93
110
|
|
94
|
-
|
95
|
-
|
111
|
+
# Description based on the command.
|
112
|
+
prof_ji_desc = ji.prg._prg.name if is_exec_prg else f"{ji.bufs[1].device} -> {ji.bufs[0].device}" # type: ignore
|
96
113
|
|
97
|
-
self.
|
114
|
+
self.prof_graph_entries.append(ProfileGraphEntry(enqueue_dev.device, prof_ji_desc, sig_st, j * 2 + 1, is_copy=not is_exec_prg))
|
115
|
+
self.prog_graph_deps.append([d - 1 for _, d in rdeps])
|
98
116
|
|
99
117
|
last_j[enqueue_queue] = j
|
100
118
|
|
119
|
+
# Check which signals are used in the profile graph.
|
120
|
+
self.prof_signal_is_used = [any(ent.st_id == j or ent.en_id == j for ent in self.prof_graph_entries) for j in range(len(self.prof_signals))]
|
121
|
+
|
101
122
|
# Build hardware queues.
|
102
|
-
self.
|
103
|
-
|
104
|
-
|
123
|
+
self.copy_to_devs: dict[HCQCompiled, set[HCQCompiled]] = {dev: set() for dev in self.devices}
|
124
|
+
|
125
|
+
# Create variable timeline signals for each device.
|
126
|
+
timeline_sigaddrs = {dev: UOp.variable(f"timeline_sig_{dev.device_id}", 0, 0xffffffffffffffff, dtype=dtypes.uint64) for dev in self.devices}
|
127
|
+
self.virt_timeline_vals = {dev: UOp.variable(f"timeline_var_{dev.device_id}", 0, 0xffffffff, dtype=dtypes.uint32) for dev in self.devices}
|
128
|
+
self.virt_timeline_signals = {dev: dev.signal_t(base_addr=timeline_sigaddrs[dev], timeline_for_device=dev) for dev in self.devices}
|
105
129
|
|
106
130
|
for dev in self.devices:
|
107
|
-
self.comp_queues[dev].memory_barrier().wait(dev
|
108
|
-
.wait(self.signals['CPU'], self.
|
131
|
+
self.comp_queues[dev].memory_barrier().wait(self.virt_timeline_signals[dev], self.virt_timeline_vals[dev]) \
|
132
|
+
.wait(self.signals['CPU'], self.kickoff_var).signal(self.signals[dev], self.kickoff_var)
|
109
133
|
|
110
|
-
for j,ji in enumerate(
|
134
|
+
for j,ji in enumerate(jit_cache):
|
111
135
|
enqueue_dev, enqueue_queue, sync_signals, deps, signal, signal_val = self.ji_schedule[j]
|
112
136
|
|
113
|
-
for i in range(len(sync_signals)): self.kickoff_wait_cmds[enqueue_queue].append(len(enqueue_queue) + i)
|
114
137
|
for sig, val in sync_signals + deps: enqueue_queue.wait(sig, val)
|
115
138
|
|
116
139
|
# Encode waits and start profile timestamp (if needed).
|
117
|
-
if PROFILE and self.
|
140
|
+
if PROFILE and self.prof_signal_is_used[j * 2]: enqueue_queue.timestamp(self.prof_signals[j * 2])
|
118
141
|
|
119
142
|
# Encode main commands based on ji type.
|
120
143
|
if isinstance(ji.prg, CompiledRunner):
|
121
|
-
|
144
|
+
enqueue_queue.exec(ji.prg._prg, self.ji_args[j], tuple(ji.prg.p.global_size or (1,1,1)), tuple(ji.prg.p.local_size or (1,1,1)))
|
122
145
|
elif isinstance(ji.prg, BufferXfer):
|
123
146
|
dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
|
124
147
|
cast(HCQAllocator, Device[src.device].allocator).map(dest._buf)
|
125
|
-
|
148
|
+
|
149
|
+
enqueue_queue.copy(self.hcq_bufs[j][0].va_addr, self.hcq_bufs[j][1].va_addr, dest.nbytes)
|
126
150
|
self.copy_to_devs[cast(HCQCompiled, Device[dest.device])].add(cast(HCQCompiled, Device[src.device]))
|
127
|
-
self.op_cmd_idx[j] = (enqueue_queue, len(enqueue_queue) - 1)
|
128
151
|
|
129
152
|
# Encode finish profile timestamp (if needed).
|
130
|
-
if PROFILE and self.
|
153
|
+
if PROFILE and self.prof_signal_is_used[j * 2 + 1]: enqueue_queue.timestamp(self.prof_signals[j * 2 + 1])
|
131
154
|
|
132
155
|
if signal_val is not None: enqueue_queue.signal(signal, signal_val)
|
133
156
|
|
@@ -135,13 +158,13 @@ class HCQGraph(MultiGraphRunner):
|
|
135
158
|
for dep_dev in list(self.copy_to_devs[dev]) + [dev]:
|
136
159
|
if dep_dev in self.copy_queues: self.comp_queues[dev].wait(self.signals[(copy_q:=self.copy_queues[dep_dev])], cast(int, last_j[copy_q]) + 1)
|
137
160
|
|
138
|
-
self.comp_queues[dev].signal(dev
|
161
|
+
self.comp_queues[dev].signal(self.virt_timeline_signals[dev], self.virt_timeline_vals[dev] + 1).bind(dev)
|
139
162
|
if dev in self.copy_queues: self.copy_queues[dev].bind(dev)
|
140
163
|
|
141
|
-
self.last_timeline:
|
164
|
+
self.last_timeline: dict[HCQCompiled, tuple[HCQSignal, int]] = {dev: (dev.timeline_signal, 0) for dev in self.devices}
|
142
165
|
self.queue_signals_to_reset = [self.signals[q] for q in list(self.comp_queues.values()) + list(self.copy_queues.values()) if q in self.signals]
|
143
166
|
|
144
|
-
def __call__(self, input_rawbuffers:
|
167
|
+
def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], wait=False) -> float|None:
|
145
168
|
# Wait and restore signals
|
146
169
|
self.kickoff_value += 1
|
147
170
|
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
|
@@ -150,28 +173,16 @@ class HCQGraph(MultiGraphRunner):
|
|
150
173
|
|
151
174
|
if PROFILE and self.kickoff_value > 1: self.collect_timestamps()
|
152
175
|
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
else: self.op_cmd_idx[j][0].update_copy(self.op_cmd_idx[j][1], **{('dest' if i == 0 else 'src'): input_rawbuffers[input_idx]._buf.va_addr})
|
157
|
-
|
158
|
-
# Update var_vals
|
159
|
-
for j, i, v in self.updated_vars(var_vals): self.ji_args[j].update_var(i, v)
|
176
|
+
hcq_var_vals = {self.kickoff_var: self.kickoff_value, **var_vals,
|
177
|
+
**{var: dev.timeline_value - 1 for dev, var in self.virt_timeline_vals.items()},
|
178
|
+
**{sig.base_addr: dev.timeline_signal.base_addr for dev, sig in self.virt_timeline_signals.items()}}
|
160
179
|
|
161
|
-
# Update
|
162
|
-
for j,
|
163
|
-
queue, cmd_ptr = self.op_cmd_idx[j]
|
164
|
-
queue.update_exec(cmd_ptr, global_dims, local_dims)
|
180
|
+
# Update rawbuffers
|
181
|
+
for (j,i),input_idx in self.input_replace.items(): hcq_var_vals[self.input_replace_to_var.get((j,i))] = input_rawbuffers[input_idx]._buf.va_addr
|
165
182
|
|
166
183
|
for dev in self.devices:
|
167
|
-
|
168
|
-
|
169
|
-
.update_wait(2, value=self.kickoff_value).update_signal(3, value=self.kickoff_value) \
|
170
|
-
.update_signal(len(comp_queue)-1, dev.timeline_signal if need_sig_upd else None, dev.timeline_value).submit(dev)
|
171
|
-
|
172
|
-
if copy_queue is not None:
|
173
|
-
for cmd_idx in self.kickoff_wait_cmds[copy_queue]: copy_queue.update_wait(cmd_idx, value=self.kickoff_value)
|
174
|
-
copy_queue.submit(dev)
|
184
|
+
self.comp_queues[dev].submit(dev, hcq_var_vals)
|
185
|
+
if (copy_queue:=self.copy_queues.get(dev, None)) is not None: copy_queue.submit(dev, hcq_var_vals)
|
175
186
|
|
176
187
|
self.last_timeline[dev] = (dev.timeline_signal, dev.timeline_value)
|
177
188
|
dev.timeline_value += 1
|
@@ -183,18 +194,12 @@ class HCQGraph(MultiGraphRunner):
|
|
183
194
|
return None
|
184
195
|
|
185
196
|
def collect_timestamps(self):
|
186
|
-
|
187
|
-
|
188
|
-
for (st,_), (en,_), dev, desc, is_cp, deps, args in self.prof_records:
|
189
|
-
dev.raw_prof_records += [(timestamps[st], timestamps[en], desc, is_cp, args)]
|
190
|
-
|
191
|
-
for x in deps:
|
192
|
-
(b_st,_), (b_en,_), b_dev, _, b_is_cp, _, _ = self.prof_records[x]
|
193
|
-
dev.dep_prof_records += [(timestamps[b_st], timestamps[b_en], b_dev, b_is_cp, timestamps[st], timestamps[en], dev, is_cp)]
|
197
|
+
# NOTE: Append to any device is fine...
|
198
|
+
self.devices[0].profile_events += [ProfileGraphEvent(self.prof_graph_entries, self.prog_graph_deps, [s.timestamp for s in self.prof_signals])]
|
194
199
|
|
195
200
|
def __del__(self):
|
196
201
|
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
|
197
202
|
|
198
203
|
if PROFILE and self.kickoff_value >= 1: self.collect_timestamps()
|
199
204
|
|
200
|
-
for fdev, buf in self.kernargs_bufs.items(): fdev.allocator._free(buf,
|
205
|
+
for fdev, buf in self.kernargs_bufs.items(): fdev.allocator._free(buf, BufferSpec(cpu_access=True))
|
tinygrad/runtime/graph/metal.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import
|
1
|
+
from typing import Any, cast
|
2
2
|
import ctypes
|
3
3
|
from tinygrad.dtype import dtypes
|
4
4
|
from tinygrad.helpers import dedup, getenv
|
@@ -7,7 +7,7 @@ from tinygrad.engine.realize import ExecItem, CompiledRunner
|
|
7
7
|
from tinygrad.engine.jit import GraphRunner, GraphException
|
8
8
|
from tinygrad.ops import Variable
|
9
9
|
from tinygrad.runtime.ops_metal import wait_check, msg, libobjc, to_struct, objc_instance,\
|
10
|
-
MTLResourceOptions,
|
10
|
+
MTLResourceOptions, cmdbuf_st_time, cmdbuf_en_time, objc_id, to_ns_str
|
11
11
|
|
12
12
|
class MTLIndirectCommandType:
|
13
13
|
MTLIndirectCommandTypeConcurrentDispatch = (1 << 5)
|
@@ -17,68 +17,64 @@ class MTLResourceUsage:
|
|
17
17
|
MTLResourceUsageWrite = 0b10
|
18
18
|
|
19
19
|
class MetalGraph(GraphRunner):
|
20
|
-
def __init__(self, jit_cache:
|
20
|
+
def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]):
|
21
21
|
super().__init__(jit_cache, input_rawbuffers, var_vals)
|
22
22
|
if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException
|
23
23
|
|
24
24
|
# create metal batch exec
|
25
|
-
icb_descriptor = msg(libobjc.objc_getClass(b"MTLIndirectCommandBufferDescriptor")
|
26
|
-
msg(
|
27
|
-
msg(
|
28
|
-
msg(
|
29
|
-
msg(
|
30
|
-
|
31
|
-
self.icb = msg(
|
32
|
-
icb_descriptor, len(
|
25
|
+
icb_descriptor = msg("new", objc_instance)(libobjc.objc_getClass(b"MTLIndirectCommandBufferDescriptor"))
|
26
|
+
msg("setCommandTypes:")(icb_descriptor, MTLIndirectCommandType.MTLIndirectCommandTypeConcurrentDispatch)
|
27
|
+
msg("setInheritBuffers:")(icb_descriptor, False)
|
28
|
+
msg("setInheritPipelineState:")(icb_descriptor, False)
|
29
|
+
msg("setMaxKernelBufferBindCount:")(icb_descriptor, 31)
|
30
|
+
|
31
|
+
self.icb = msg("newIndirectCommandBufferWithDescriptor:maxCommandCount:options:", objc_instance)(self.dev.sysdevice,
|
32
|
+
icb_descriptor, len(jit_cache), MTLResourceOptions.MTLResourceCPUCacheModeDefaultCache)
|
33
33
|
if self.icb.value is None: raise GraphException("create indirect command buffer failed, does your system support this?")
|
34
|
-
icb_label = bytes(msg(msg(
|
34
|
+
icb_label = bytes(msg("UTF8String", ctypes.c_char_p)(msg("description", objc_instance)(self.icb))).decode()
|
35
35
|
self.needs_icb_fix = int("AGXG15XFamilyIndirectCommandBuffer" not in icb_label) # not required on M3
|
36
36
|
|
37
|
-
if len(self.vars): self.int_buf = self.
|
37
|
+
if len(self.vars): self.int_buf = self.dev.allocator.alloc(len(self.vars)*dtypes.int32.itemsize)
|
38
38
|
all_resources = [self.int_buf.buf] if len(self.vars) else []
|
39
39
|
all_pipelines = []
|
40
|
-
for j,ji in enumerate(
|
40
|
+
for j,ji in enumerate(jit_cache):
|
41
41
|
prg: CompiledRunner = cast(CompiledRunner, ji.prg)
|
42
|
-
icb_command = msg(
|
43
|
-
all_pipelines.append(prg.
|
44
|
-
msg(
|
42
|
+
icb_command = msg("indirectComputeCommandAtIndex:", objc_instance)(self.icb, j)
|
43
|
+
all_pipelines.append(prg._prg.pipeline_state)
|
44
|
+
msg("setComputePipelineState:")(icb_command, prg._prg.pipeline_state)
|
45
45
|
for i,b in enumerate(ji.bufs):
|
46
46
|
if b is not None and b not in input_rawbuffers:
|
47
|
-
msg(
|
47
|
+
msg("setKernelBuffer:offset:atIndex:")(icb_command, b._buf.buf, b._buf.offset, i)
|
48
48
|
all_resources.append(b._buf.buf)
|
49
|
-
for i,v in enumerate(prg.p.vars): msg(
|
49
|
+
for i,v in enumerate(prg.p.vars): msg("setKernelBuffer:offset:atIndex:")(icb_command, self.int_buf.buf, self.vars.index(v)*4, len(ji.bufs)+i)
|
50
50
|
|
51
51
|
global_size, local_size = prg.p.launch_dims(var_vals)
|
52
|
-
msg(
|
53
|
-
msg(
|
52
|
+
msg("concurrentDispatchThreadgroups:threadsPerThreadgroup:")(icb_command, to_struct(*global_size), to_struct(*local_size))
|
53
|
+
msg("setBarrier")(icb_command)
|
54
54
|
|
55
55
|
self.all_resources = dedup(all_resources)
|
56
56
|
self.all_pipelines = dedup(all_pipelines)
|
57
57
|
self.command_buffer: Any = None
|
58
|
-
if len(self.vars): self.int_buf_view = self.
|
59
|
-
self.range = to_struct(0, len(
|
58
|
+
if len(self.vars): self.int_buf_view = self.dev.allocator._as_buffer(self.int_buf).cast('i')
|
59
|
+
self.range = to_struct(0, len(jit_cache))
|
60
60
|
|
61
|
-
def __call__(self, input_rawbuffers:
|
61
|
+
def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], wait=False) -> float|None:
|
62
62
|
|
63
|
-
if self.command_buffer is not None and self.command_buffer in self.
|
63
|
+
if self.command_buffer is not None and self.command_buffer in self.dev.mtl_buffers_in_flight: wait_check(self.command_buffer)
|
64
64
|
all_resources = dedup(self.all_resources + [x._buf.buf for x in input_rawbuffers])
|
65
65
|
|
66
66
|
for (j,i),input_idx in self.input_replace.items():
|
67
|
-
computeCommand = msg(
|
68
|
-
msg(
|
69
|
-
input_rawbuffers[input_idx]._buf.offset, i)
|
67
|
+
computeCommand = msg("indirectComputeCommandAtIndex:", objc_id)(self.icb, j)
|
68
|
+
msg("setKernelBuffer:offset:atIndex:")(computeCommand, input_rawbuffers[input_idx]._buf.buf, input_rawbuffers[input_idx]._buf.offset, i)
|
70
69
|
|
71
70
|
for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
|
72
|
-
|
73
|
-
|
74
|
-
computeCommand = msg(self.icb, "indirectComputeCommandAtIndex:", j)
|
75
|
-
msg(computeCommand, "concurrentDispatchThreadgroups:threadsPerThreadgroup:",
|
76
|
-
to_struct(*cast(tuple, global_size)), to_struct(*cast(tuple, local_size)))
|
71
|
+
computeCommand = msg("indirectComputeCommandAtIndex:", objc_id)(self.icb, j)
|
72
|
+
msg("concurrentDispatchThreadgroups:threadsPerThreadgroup:")(computeCommand, to_struct(*global_dims), to_struct(*local_dims))
|
77
73
|
for j, var in enumerate(self.vars): self.int_buf_view[j] = var_vals[var]
|
78
74
|
|
79
|
-
command_buffer = msg(
|
80
|
-
encoder = msg(
|
81
|
-
msg(
|
75
|
+
command_buffer = msg("commandBuffer", objc_instance)(self.dev.mtl_queue)
|
76
|
+
encoder = msg("computeCommandEncoder", objc_instance)(command_buffer)
|
77
|
+
msg("useResources:count:usage:")(encoder, (objc_id * len(all_resources))(*all_resources), len(all_resources),
|
82
78
|
MTLResourceUsage.MTLResourceUsageRead | MTLResourceUsage.MTLResourceUsageWrite)
|
83
79
|
|
84
80
|
# NOTE: the pipelines likely need to be added to the used resources to fix the crash on M1/M2, but I haven't figured out how
|
@@ -88,16 +84,17 @@ class MetalGraph(GraphRunner):
|
|
88
84
|
# to repro the crash (which can also crash other running GPU apps), run with FIX_METAL_ICB=0
|
89
85
|
if getenv("FIX_METAL_ICB", self.needs_icb_fix):
|
90
86
|
for ps in self.all_pipelines:
|
91
|
-
msg(
|
92
|
-
msg(
|
87
|
+
msg("setComputePipelineState:")(encoder, ps)
|
88
|
+
msg("dispatchThreadgroups:threadsPerThreadgroup:")(encoder, to_struct(0,0,0), to_struct(0,0,0))
|
93
89
|
|
94
|
-
msg(
|
95
|
-
msg(
|
96
|
-
msg(command_buffer, "
|
90
|
+
msg("executeCommandsInBuffer:withRange:")(encoder, self.icb, self.range)
|
91
|
+
msg("endEncoding")(encoder)
|
92
|
+
msg("setLabel:")(command_buffer, to_ns_str(f"batched {len(self.jit_cache)}"))
|
93
|
+
msg("commit")(command_buffer)
|
97
94
|
self.command_buffer = command_buffer
|
98
95
|
|
96
|
+
self.dev.mtl_buffers_in_flight.append(command_buffer)
|
99
97
|
if wait:
|
100
98
|
wait_check(command_buffer)
|
101
|
-
return
|
102
|
-
self.device.mtl_buffers_in_flight.append(command_buffer)
|
99
|
+
return cmdbuf_en_time(command_buffer) - cmdbuf_st_time(command_buffer)
|
103
100
|
return None
|