tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +1 -1
- tinygrad/apps/llm.py +206 -0
- tinygrad/codegen/__init__.py +116 -0
- tinygrad/codegen/devectorizer.py +315 -172
- tinygrad/codegen/expander.py +8 -16
- tinygrad/codegen/gpudims.py +89 -0
- tinygrad/codegen/linearize.py +205 -203
- tinygrad/codegen/lowerer.py +92 -139
- tinygrad/codegen/opt/__init__.py +38 -0
- tinygrad/codegen/opt/heuristic.py +125 -0
- tinygrad/codegen/opt/kernel.py +510 -0
- tinygrad/{engine → codegen/opt}/search.py +51 -35
- tinygrad/codegen/opt/swizzler.py +134 -0
- tinygrad/codegen/opt/tc.py +127 -0
- tinygrad/codegen/quantize.py +67 -0
- tinygrad/device.py +122 -132
- tinygrad/dtype.py +152 -35
- tinygrad/engine/jit.py +81 -54
- tinygrad/engine/memory.py +46 -27
- tinygrad/engine/realize.py +82 -41
- tinygrad/engine/schedule.py +70 -445
- tinygrad/frontend/__init__.py +0 -0
- tinygrad/frontend/onnx.py +1253 -0
- tinygrad/frontend/torch.py +5 -0
- tinygrad/gradient.py +19 -27
- tinygrad/helpers.py +95 -47
- tinygrad/nn/__init__.py +7 -8
- tinygrad/nn/optim.py +72 -41
- tinygrad/nn/state.py +37 -23
- tinygrad/renderer/__init__.py +40 -60
- tinygrad/renderer/cstyle.py +143 -128
- tinygrad/renderer/llvmir.py +113 -62
- tinygrad/renderer/ptx.py +50 -32
- tinygrad/renderer/wgsl.py +27 -23
- tinygrad/runtime/autogen/am/am.py +5861 -0
- tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
- tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
- tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
- tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
- tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
- tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
- tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
- tinygrad/runtime/autogen/comgr.py +35 -9
- tinygrad/runtime/autogen/comgr_3.py +906 -0
- tinygrad/runtime/autogen/cuda.py +2419 -494
- tinygrad/runtime/autogen/hsa.py +57 -16
- tinygrad/runtime/autogen/ib.py +7171 -0
- tinygrad/runtime/autogen/io_uring.py +917 -118
- tinygrad/runtime/autogen/kfd.py +748 -26
- tinygrad/runtime/autogen/libc.py +613 -218
- tinygrad/runtime/autogen/libusb.py +1643 -0
- tinygrad/runtime/autogen/nv/nv.py +8602 -0
- tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
- tinygrad/runtime/autogen/opencl.py +2 -4
- tinygrad/runtime/autogen/sqtt.py +1789 -0
- tinygrad/runtime/autogen/vfio.py +3 -3
- tinygrad/runtime/autogen/webgpu.py +273 -264
- tinygrad/runtime/graph/cuda.py +3 -3
- tinygrad/runtime/graph/hcq.py +68 -29
- tinygrad/runtime/graph/metal.py +29 -13
- tinygrad/runtime/graph/remote.py +114 -0
- tinygrad/runtime/ops_amd.py +537 -320
- tinygrad/runtime/ops_cpu.py +108 -7
- tinygrad/runtime/ops_cuda.py +12 -14
- tinygrad/runtime/ops_disk.py +13 -10
- tinygrad/runtime/ops_dsp.py +47 -40
- tinygrad/runtime/ops_gpu.py +13 -11
- tinygrad/runtime/ops_hip.py +6 -9
- tinygrad/runtime/ops_llvm.py +35 -15
- tinygrad/runtime/ops_metal.py +29 -19
- tinygrad/runtime/ops_npy.py +5 -3
- tinygrad/runtime/ops_null.py +28 -0
- tinygrad/runtime/ops_nv.py +306 -234
- tinygrad/runtime/ops_python.py +62 -52
- tinygrad/runtime/ops_qcom.py +28 -39
- tinygrad/runtime/ops_remote.py +482 -0
- tinygrad/runtime/ops_webgpu.py +28 -28
- tinygrad/runtime/support/am/amdev.py +114 -249
- tinygrad/runtime/support/am/ip.py +211 -172
- tinygrad/runtime/support/amd.py +138 -0
- tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
- tinygrad/runtime/support/compiler_cuda.py +8 -11
- tinygrad/runtime/support/elf.py +2 -1
- tinygrad/runtime/support/hcq.py +184 -97
- tinygrad/runtime/support/ib.py +172 -0
- tinygrad/runtime/support/llvm.py +3 -4
- tinygrad/runtime/support/memory.py +251 -0
- tinygrad/runtime/support/nv/__init__.py +0 -0
- tinygrad/runtime/support/nv/ip.py +581 -0
- tinygrad/runtime/support/nv/nvdev.py +183 -0
- tinygrad/runtime/support/system.py +170 -0
- tinygrad/runtime/support/usb.py +268 -0
- tinygrad/runtime/support/webgpu.py +18 -0
- tinygrad/schedule/__init__.py +0 -0
- tinygrad/schedule/grouper.py +119 -0
- tinygrad/schedule/kernelize.py +368 -0
- tinygrad/schedule/multi.py +231 -0
- tinygrad/shape/shapetracker.py +40 -46
- tinygrad/shape/view.py +88 -52
- tinygrad/tensor.py +968 -542
- tinygrad/uop/__init__.py +117 -0
- tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
- tinygrad/uop/mathtraits.py +169 -0
- tinygrad/uop/ops.py +1021 -0
- tinygrad/uop/spec.py +228 -0
- tinygrad/{codegen → uop}/symbolic.py +239 -216
- tinygrad/uop/upat.py +163 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
- tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
- tinygrad/viz/index.html +203 -403
- tinygrad/viz/js/index.js +718 -0
- tinygrad/viz/js/worker.js +29 -0
- tinygrad/viz/serve.py +224 -102
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
- tinygrad-0.11.0.dist-info/RECORD +141 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/kernel.py +0 -693
- tinygrad/engine/multi.py +0 -161
- tinygrad/ops.py +0 -1003
- tinygrad/runtime/ops_cloud.py +0 -220
- tinygrad/runtime/support/allocator.py +0 -94
- tinygrad/spec.py +0 -155
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
- tinygrad/viz/perfetto.html +0 -178
- tinygrad-0.10.2.dist-info/RECORD +0 -99
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
tinygrad/runtime/graph/cuda.py
CHANGED
@@ -4,7 +4,7 @@ import tinygrad.runtime.autogen.cuda as cuda
|
|
4
4
|
from tinygrad.helpers import init_c_var, dedup
|
5
5
|
from tinygrad.device import Buffer, Device
|
6
6
|
from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
|
7
|
-
from tinygrad.ops import Variable
|
7
|
+
from tinygrad.uop.ops import Variable
|
8
8
|
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
|
9
9
|
from tinygrad.engine.jit import MultiGraphRunner, GraphException
|
10
10
|
|
@@ -28,8 +28,8 @@ class CUDAGraph(MultiGraphRunner):
|
|
28
28
|
deps = self._access_resources([x.base for x in ji.bufs if x is not None], ji.prg.p.outs, new_dependency=new_node)
|
29
29
|
c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
|
30
30
|
|
31
|
-
c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals
|
32
|
-
kern_params = cuda.
|
31
|
+
c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals.get(x, ji.fixedvars.get(x)) for x in ji.prg.p.vars])
|
32
|
+
kern_params = cuda.CUDA_KERNEL_NODE_PARAMS_v1(ji.prg._prg.prg, *global_size, *local_size, 0, None, vargs)
|
33
33
|
check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(kern_params)))
|
34
34
|
|
35
35
|
if j in self.launch_dims_replace or j in self.var_vals_replace or j in self.jc_idx_with_updatable_rawbufs:
|
tinygrad/runtime/graph/hcq.py
CHANGED
@@ -1,11 +1,11 @@
|
|
1
1
|
import collections, time
|
2
2
|
from typing import Any, cast
|
3
|
-
from tinygrad.helpers import round_up, PROFILE
|
4
|
-
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState, BumpAllocator
|
3
|
+
from tinygrad.helpers import round_up, PROFILE, merge_dicts, getenv, dedup
|
4
|
+
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState, BumpAllocator, MMIOInterface
|
5
5
|
from tinygrad.device import Buffer, BufferSpec, Compiled, Device, ProfileGraphEntry, ProfileGraphEvent
|
6
6
|
from tinygrad.dtype import dtypes
|
7
|
-
from tinygrad.ops import UOp, Variable
|
8
|
-
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
|
7
|
+
from tinygrad.uop.ops import UOp, Variable
|
8
|
+
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner, BufferCopy
|
9
9
|
from tinygrad.engine.jit import MultiGraphRunner
|
10
10
|
|
11
11
|
class HCQGraph(MultiGraphRunner):
|
@@ -13,6 +13,9 @@ class HCQGraph(MultiGraphRunner):
|
|
13
13
|
super().__init__(jit_cache, input_rawbuffers, var_vals)
|
14
14
|
self.devices = list(set(cast(HCQCompiled, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs]))
|
15
15
|
|
16
|
+
# CPU Device is always last
|
17
|
+
self.devices = sorted(self.devices, key=lambda x: 1 if x._is_cpu() else 0)
|
18
|
+
|
16
19
|
# Replace input buffers with variables.
|
17
20
|
self.hcq_bufs = [[cast(Buffer, x)._buf for x in ji.bufs] for ji in jit_cache]
|
18
21
|
self.input_replace_to_var: dict[tuple[int, int], Variable] = {}
|
@@ -26,16 +29,17 @@ class HCQGraph(MultiGraphRunner):
|
|
26
29
|
for ji in jit_cache:
|
27
30
|
if not isinstance(ji.prg, CompiledRunner): continue
|
28
31
|
kernargs_size[ji.prg.dev] += round_up(ji.prg._prg.kernargs_alloc_size, 16)
|
29
|
-
self.kernargs_bufs: dict[Compiled, HCQBuffer] = {
|
32
|
+
self.kernargs_bufs: dict[Compiled, HCQBuffer] = {d:d.allocator._alloc(max(sz, 1), BufferSpec(cpu_access=True)) for d,sz in kernargs_size.items()}
|
30
33
|
|
31
34
|
# Fill initial arguments.
|
32
35
|
self.ji_args: dict[int, HCQArgsState] = {}
|
33
36
|
|
34
|
-
kargs_alloc: dict[Compiled, BumpAllocator] = {dev:BumpAllocator(buf.size
|
37
|
+
kargs_alloc: dict[Compiled, BumpAllocator] = {dev:BumpAllocator(buf.size) for dev,buf in self.kernargs_bufs.items()}
|
35
38
|
for j,ji in enumerate(jit_cache):
|
36
39
|
if not isinstance(ji.prg, CompiledRunner): continue
|
37
40
|
|
38
|
-
|
41
|
+
argsbuf = self.kernargs_bufs[ji.prg.dev].offset(kargs_alloc[ji.prg.dev].alloc(ji.prg._prg.kernargs_alloc_size, 16))
|
42
|
+
self.ji_args[j] = ji.prg._prg.fill_kernargs(self.hcq_bufs[j], ji.prg.p.vars, argsbuf)
|
39
43
|
|
40
44
|
# Schedule Dependencies.
|
41
45
|
# There are two types of queues on each device: copy and compute. Both must synchronize with all external operations before launching any
|
@@ -47,14 +51,15 @@ class HCQGraph(MultiGraphRunner):
|
|
47
51
|
self.comp_queues: dict[HCQCompiled, HWQueue] = {dev: dev.hw_compute_queue_t() for dev in self.devices}
|
48
52
|
self.copy_queues: dict[HCQCompiled, HWQueue] = {} # lazy allocation
|
49
53
|
|
50
|
-
self.signals: dict[Any, HCQSignal] = {**{dev: dev.
|
54
|
+
self.signals: dict[Any, HCQSignal] = {**{dev: dev.new_signal(value=0) for dev in self.devices if not dev._is_cpu()},
|
55
|
+
**{"KICK": self.devices[0].new_signal(value=0)}, **{dev: self.devices[0].new_signal(value=0) for dev in self.devices if dev._is_cpu()}}
|
51
56
|
self.kickoff_value: int = 0
|
52
57
|
self.kickoff_var = UOp.variable("kickoff_var", 0, 0xffffffff, dtype=dtypes.uint32)
|
53
58
|
|
54
59
|
# When profiling allocate 2 signals for each jit item to measure speed. The jth jit item have signals at 2*j and 2*j+1.
|
55
60
|
# TODO: This logic might allocate a few extra signals...
|
56
|
-
self.prof_signals: list[HCQSignal] = [
|
57
|
-
self.
|
61
|
+
self.prof_signals: list[HCQSignal] = []
|
62
|
+
self.prof_graph_deps: list[list[int]] = []
|
58
63
|
self.prof_graph_entries: list[ProfileGraphEntry] = []
|
59
64
|
|
60
65
|
last_j: dict[HWQueue, int|None] = collections.defaultdict(lambda: None)
|
@@ -63,8 +68,18 @@ class HCQGraph(MultiGraphRunner):
|
|
63
68
|
|
64
69
|
for dev, queue in self.comp_queues.items(): dev_access[queue].add(dev)
|
65
70
|
|
71
|
+
self.input_replace_map: dict[HCQCompiled, set[int]] = collections.defaultdict(set)
|
72
|
+
self.fixedvars: dict[HCQCompiled, dict[Variable, int]] = {}
|
73
|
+
|
66
74
|
for j,ji in enumerate(jit_cache):
|
67
|
-
|
75
|
+
if is_exec_prg:=isinstance(ji.prg, CompiledRunner): enqueue_dev: HCQCompiled = ji.prg.dev
|
76
|
+
else:
|
77
|
+
# For copy ops prioritize enqeueuing on the dest device, so reverse the buffers.
|
78
|
+
for b in cast(list[Buffer], ji.bufs[::-1]):
|
79
|
+
if (enqueue_dev:=cast(HCQCompiled, Device[b.device])).hw_copy_queue_t is not None: break
|
80
|
+
|
81
|
+
# set any fixedvars on the device
|
82
|
+
self.fixedvars[enqueue_dev] = merge_dicts([self.fixedvars.get(enqueue_dev, {}), ji.fixedvars])
|
68
83
|
|
69
84
|
if is_exec_prg:
|
70
85
|
enqueue_queue = self.comp_queues[enqueue_dev]
|
@@ -72,7 +87,7 @@ class HCQGraph(MultiGraphRunner):
|
|
72
87
|
assert (enqueue_dev.hw_copy_queue_t is not None), "device must implement a copy queue"
|
73
88
|
enqueue_queue = self.copy_queues.setdefault(enqueue_dev, enqueue_dev.hw_copy_queue_t())
|
74
89
|
|
75
|
-
out_signal = self.signals.setdefault(enqueue_queue,
|
90
|
+
out_signal = self.signals.setdefault(enqueue_queue, self.devices[0].new_signal(value=0))
|
76
91
|
|
77
92
|
# Get dependencies based on input and output buffers.
|
78
93
|
rdeps = self._access_resources(ji.bufs, ji.prg.p.outs if is_exec_prg else [0], (enqueue_queue, j + 1)) #type:ignore
|
@@ -86,9 +101,9 @@ class HCQGraph(MultiGraphRunner):
|
|
86
101
|
if (qa:=queue_access[enqueue_queue][dep_queue]) is None or qa < dep_val:
|
87
102
|
opt_deps.append((self.signals[dep_queue], dep_val))
|
88
103
|
queue_access[enqueue_queue][dep_queue] = dep_val
|
104
|
+
dev_access[enqueue_queue].update(dev_access[dep_queue])
|
89
105
|
|
90
106
|
# Ensure device is ready for use in current context: the graph has initialized the device and it's safe to operate on it within this graph.
|
91
|
-
for dep_queue, _ in opt_deps: dev_access[enqueue_queue].update(dev_access[dep_queue])
|
92
107
|
sync_signals = [(self.signals[d], self.kickoff_var) for b in ji.bufs if (d:=Device[cast(Buffer, b).device]) not in dev_access[enqueue_queue]]
|
93
108
|
dev_access[enqueue_queue].update(cast(HCQCompiled, Device[cast(Buffer, b).device]) for b in ji.bufs)
|
94
109
|
|
@@ -112,28 +127,31 @@ class HCQGraph(MultiGraphRunner):
|
|
112
127
|
prof_ji_desc = ji.prg._prg.name if is_exec_prg else f"{ji.bufs[1].device} -> {ji.bufs[0].device}" # type: ignore
|
113
128
|
|
114
129
|
self.prof_graph_entries.append(ProfileGraphEntry(enqueue_dev.device, prof_ji_desc, sig_st, j * 2 + 1, is_copy=not is_exec_prg))
|
115
|
-
self.
|
130
|
+
self.prof_graph_deps.append([d - 1 for _, d in rdeps])
|
116
131
|
|
117
132
|
last_j[enqueue_queue] = j
|
118
133
|
|
119
134
|
# Check which signals are used in the profile graph.
|
120
|
-
self.prof_signal_is_used = [any(ent.st_id == j or ent.en_id == j for ent in self.prof_graph_entries) for j in range(len(
|
135
|
+
self.prof_signal_is_used = [any(ent.st_id == j or ent.en_id == j for ent in self.prof_graph_entries) for j in range(len(jit_cache) * 2)]
|
121
136
|
|
122
137
|
# Build hardware queues.
|
123
138
|
self.copy_to_devs: dict[HCQCompiled, set[HCQCompiled]] = {dev: set() for dev in self.devices}
|
124
139
|
|
125
140
|
# Create variable timeline signals for each device.
|
126
|
-
timeline_sigaddrs = {dev: UOp.variable(f"timeline_sig_{dev
|
127
|
-
self.virt_timeline_vals = {dev: UOp.variable(f"timeline_var_{dev
|
128
|
-
self.virt_timeline_signals = {dev: dev.signal_t(
|
141
|
+
timeline_sigaddrs = {dev: UOp.variable(f"timeline_sig_{self.dev_name(dev)}", 0, 0xffffffffffffffff, dtype=dtypes.uint64) for dev in self.devices}
|
142
|
+
self.virt_timeline_vals = {dev: UOp.variable(f"timeline_var_{self.dev_name(dev)}", 0, 0xffffffff, dtype=dtypes.uint32) for dev in self.devices}
|
143
|
+
self.virt_timeline_signals = {dev: dev.signal_t(HCQBuffer(timeline_sigaddrs[dev], 16), owner=dev, is_timeline=True) for dev in self.devices}
|
129
144
|
|
130
145
|
for dev in self.devices:
|
131
146
|
self.comp_queues[dev].memory_barrier().wait(self.virt_timeline_signals[dev], self.virt_timeline_vals[dev]) \
|
132
|
-
.wait(self.signals['
|
147
|
+
.wait(self.signals['KICK'], self.kickoff_var).signal(self.signals[dev], self.kickoff_var)
|
133
148
|
|
134
149
|
for j,ji in enumerate(jit_cache):
|
135
150
|
enqueue_dev, enqueue_queue, sync_signals, deps, signal, signal_val = self.ji_schedule[j]
|
136
151
|
|
152
|
+
# Lazy allocate signals
|
153
|
+
if PROFILE: self.prof_signals += [enqueue_dev.new_signal(value=0) for _ in range(2)]
|
154
|
+
|
137
155
|
for sig, val in sync_signals + deps: enqueue_queue.wait(sig, val)
|
138
156
|
|
139
157
|
# Encode waits and start profile timestamp (if needed).
|
@@ -142,10 +160,11 @@ class HCQGraph(MultiGraphRunner):
|
|
142
160
|
# Encode main commands based on ji type.
|
143
161
|
if isinstance(ji.prg, CompiledRunner):
|
144
162
|
enqueue_queue.exec(ji.prg._prg, self.ji_args[j], tuple(ji.prg.p.global_size or (1,1,1)), tuple(ji.prg.p.local_size or (1,1,1)))
|
145
|
-
elif isinstance(ji.prg, BufferXfer):
|
163
|
+
elif isinstance(ji.prg, (BufferXfer, BufferCopy)):
|
146
164
|
dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
|
147
|
-
cast(
|
148
|
-
|
165
|
+
for bufid, src in enumerate(cast(list[Buffer], ji.bufs)):
|
166
|
+
if (inprep_idx:=self.input_replace.get((j, bufid))) is not None: self.input_replace_map[enqueue_dev].add(inprep_idx)
|
167
|
+
else: cast(HCQAllocator, enqueue_dev.allocator).map(self.hcq_bufs[j][bufid])
|
149
168
|
enqueue_queue.copy(self.hcq_bufs[j][0].va_addr, self.hcq_bufs[j][1].va_addr, dest.nbytes)
|
150
169
|
self.copy_to_devs[cast(HCQCompiled, Device[dest.device])].add(cast(HCQCompiled, Device[src.device]))
|
151
170
|
|
@@ -169,23 +188,25 @@ class HCQGraph(MultiGraphRunner):
|
|
169
188
|
self.kickoff_value += 1
|
170
189
|
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
|
171
190
|
for sig in self.queue_signals_to_reset: sig.value = 0
|
172
|
-
self.signals['
|
191
|
+
self.signals['KICK'].value = self.kickoff_value
|
192
|
+
|
193
|
+
for dev in self.devices:
|
194
|
+
for idx_to_map in self.input_replace_map[dev]: cast(HCQAllocator, dev.allocator).map(input_rawbuffers[idx_to_map]._buf)
|
173
195
|
|
174
196
|
if PROFILE and self.kickoff_value > 1: self.collect_timestamps()
|
175
197
|
|
176
198
|
hcq_var_vals = {self.kickoff_var: self.kickoff_value, **var_vals,
|
177
199
|
**{var: dev.timeline_value - 1 for dev, var in self.virt_timeline_vals.items()},
|
178
|
-
**{sig.
|
200
|
+
**{sig.base_buf.va_addr: dev.timeline_signal.base_buf.va_addr for dev, sig in self.virt_timeline_signals.items()}}
|
179
201
|
|
180
202
|
# Update rawbuffers
|
181
203
|
for (j,i),input_idx in self.input_replace.items(): hcq_var_vals[self.input_replace_to_var.get((j,i))] = input_rawbuffers[input_idx]._buf.va_addr
|
182
204
|
|
183
205
|
for dev in self.devices:
|
184
|
-
self.comp_queues[dev].submit(dev, hcq_var_vals)
|
185
|
-
if (copy_queue:=self.copy_queues.get(dev, None)) is not None: copy_queue.submit(dev,
|
206
|
+
self.comp_queues[dev].submit(dev, hcq_var_vals_local:=hcq_var_vals|self.fixedvars.get(dev, {}))
|
207
|
+
if (copy_queue:=self.copy_queues.get(dev, None)) is not None: copy_queue.submit(dev, hcq_var_vals_local)
|
186
208
|
|
187
|
-
self.last_timeline[dev] = (dev.timeline_signal, dev.
|
188
|
-
dev.timeline_value += 1
|
209
|
+
self.last_timeline[dev] = (dev.timeline_signal, dev.next_timeline())
|
189
210
|
|
190
211
|
if wait:
|
191
212
|
st = time.perf_counter()
|
@@ -195,7 +216,9 @@ class HCQGraph(MultiGraphRunner):
|
|
195
216
|
|
196
217
|
def collect_timestamps(self):
|
197
218
|
# NOTE: Append to any device is fine...
|
198
|
-
self.devices[0].profile_events += [ProfileGraphEvent(self.prof_graph_entries, self.
|
219
|
+
self.devices[0].profile_events += [ProfileGraphEvent(self.prof_graph_entries, self.prof_graph_deps, [s.timestamp for s in self.prof_signals])]
|
220
|
+
|
221
|
+
def dev_name(self, dev) -> str: return dev.device.replace(":", "_")
|
199
222
|
|
200
223
|
def __del__(self):
|
201
224
|
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
|
@@ -203,3 +226,19 @@ class HCQGraph(MultiGraphRunner):
|
|
203
226
|
if PROFILE and self.kickoff_value >= 1: self.collect_timestamps()
|
204
227
|
|
205
228
|
for fdev, buf in self.kernargs_bufs.items(): fdev.allocator._free(buf, BufferSpec(cpu_access=True))
|
229
|
+
|
230
|
+
@staticmethod
|
231
|
+
def supports_exec_item(devs:list[Compiled], ei:ExecItem) -> bool:
|
232
|
+
# Check if all devices are HCQ
|
233
|
+
all_devs = cast(list[HCQCompiled], dedup(devs + [Device[b.device] for b in ei.bufs if b]))
|
234
|
+
if not all(issubclass(type(d), HCQCompiled) for d in all_devs): return False
|
235
|
+
|
236
|
+
# If all of devices are mapped into CPU address space, can use CPU inside the peer group.
|
237
|
+
cpu_support = all(isinstance(d.timeline_signal.base_buf.view, MMIOInterface) for d in all_devs)
|
238
|
+
|
239
|
+
# Check if all devices are within the same peer group. If CPU is supported, don't count it as a separate peer group.
|
240
|
+
if len(set(d.peer_group for d in all_devs if cpu_support and not d._is_cpu())) > 1: return False
|
241
|
+
|
242
|
+
# MOCKGPU is not supported, since it can't execute commands in parallel
|
243
|
+
copy = (isinstance(ei.prg, BufferCopy) and cast(HCQCompiled, devs[0]).hw_copy_queue_t is not None) and not getenv("MOCKGPU")
|
244
|
+
return isinstance(ei.prg, (CompiledRunner, BufferXfer)) or copy
|
tinygrad/runtime/graph/metal.py
CHANGED
@@ -1,11 +1,11 @@
|
|
1
1
|
from typing import Any, cast
|
2
|
-
import ctypes
|
2
|
+
import ctypes, re, decimal
|
3
3
|
from tinygrad.dtype import dtypes
|
4
|
-
from tinygrad.helpers import dedup, getenv
|
5
|
-
from tinygrad.device import Buffer
|
4
|
+
from tinygrad.helpers import dedup, getenv, merge_dicts, PROFILE
|
5
|
+
from tinygrad.device import Buffer, ProfileGraphEntry, ProfileGraphEvent
|
6
6
|
from tinygrad.engine.realize import ExecItem, CompiledRunner
|
7
7
|
from tinygrad.engine.jit import GraphRunner, GraphException
|
8
|
-
from tinygrad.ops import Variable
|
8
|
+
from tinygrad.uop.ops import Variable
|
9
9
|
from tinygrad.runtime.ops_metal import wait_check, msg, libobjc, to_struct, objc_instance,\
|
10
10
|
MTLResourceOptions, cmdbuf_st_time, cmdbuf_en_time, objc_id, to_ns_str
|
11
11
|
|
@@ -32,11 +32,13 @@ class MetalGraph(GraphRunner):
|
|
32
32
|
icb_descriptor, len(jit_cache), MTLResourceOptions.MTLResourceCPUCacheModeDefaultCache)
|
33
33
|
if self.icb.value is None: raise GraphException("create indirect command buffer failed, does your system support this?")
|
34
34
|
icb_label = bytes(msg("UTF8String", ctypes.c_char_p)(msg("description", objc_instance)(self.icb))).decode()
|
35
|
-
self.needs_icb_fix = int(
|
35
|
+
self.needs_icb_fix = int((m := re.search(r'AGXG(\d+)XFamily', icb_label)) is None or int(m.group(1)) < 15) # not required on M3+
|
36
36
|
|
37
|
-
|
38
|
-
|
39
|
-
|
37
|
+
self.fixedvars = merge_dicts([ji.fixedvars for ji in jit_cache])
|
38
|
+
self.varlist = self.vars + list(self.fixedvars.keys())
|
39
|
+
if len(self.varlist): self.int_buf = self.dev.allocator.alloc(len(self.varlist)*dtypes.int32.itemsize)
|
40
|
+
|
41
|
+
all_pipelines, all_resources = [], [self.int_buf.buf] if len(self.varlist) else []
|
40
42
|
for j,ji in enumerate(jit_cache):
|
41
43
|
prg: CompiledRunner = cast(CompiledRunner, ji.prg)
|
42
44
|
icb_command = msg("indirectComputeCommandAtIndex:", objc_instance)(self.icb, j)
|
@@ -46,7 +48,7 @@ class MetalGraph(GraphRunner):
|
|
46
48
|
if b is not None and b not in input_rawbuffers:
|
47
49
|
msg("setKernelBuffer:offset:atIndex:")(icb_command, b._buf.buf, b._buf.offset, i)
|
48
50
|
all_resources.append(b._buf.buf)
|
49
|
-
for i,v in enumerate(prg.p.vars): msg("setKernelBuffer:offset:atIndex:")(icb_command, self.int_buf.buf, self.
|
51
|
+
for i,v in enumerate(prg.p.vars): msg("setKernelBuffer:offset:atIndex:")(icb_command, self.int_buf.buf, self.varlist.index(v)*4, len(ji.bufs)+i)
|
50
52
|
|
51
53
|
global_size, local_size = prg.p.launch_dims(var_vals)
|
52
54
|
msg("concurrentDispatchThreadgroups:threadsPerThreadgroup:")(icb_command, to_struct(*global_size), to_struct(*local_size))
|
@@ -55,14 +57,16 @@ class MetalGraph(GraphRunner):
|
|
55
57
|
self.all_resources = dedup(all_resources)
|
56
58
|
self.all_pipelines = dedup(all_pipelines)
|
57
59
|
self.command_buffer: Any = None
|
58
|
-
if len(self.
|
60
|
+
if len(self.varlist): self.int_buf_view = self.dev.allocator._as_buffer(self.int_buf).cast('i')
|
61
|
+
for var in self.fixedvars: self.int_buf_view[self.varlist.index(var)] = self.fixedvars[var]
|
59
62
|
self.range = to_struct(0, len(jit_cache))
|
60
63
|
|
61
64
|
def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], wait=False) -> float|None:
|
62
|
-
|
63
65
|
if self.command_buffer is not None and self.command_buffer in self.dev.mtl_buffers_in_flight: wait_check(self.command_buffer)
|
64
|
-
|
66
|
+
# NOTE: old command buffer may not be inflight anymore
|
67
|
+
if self.command_buffer is not None and PROFILE: self.collect_timestamps()
|
65
68
|
|
69
|
+
all_resources = dedup(self.all_resources + [input_rawbuffers[input_idx]._buf.buf for input_idx in self.input_replace.values()])
|
66
70
|
for (j,i),input_idx in self.input_replace.items():
|
67
71
|
computeCommand = msg("indirectComputeCommandAtIndex:", objc_id)(self.icb, j)
|
68
72
|
msg("setKernelBuffer:offset:atIndex:")(computeCommand, input_rawbuffers[input_idx]._buf.buf, input_rawbuffers[input_idx]._buf.offset, i)
|
@@ -70,7 +74,7 @@ class MetalGraph(GraphRunner):
|
|
70
74
|
for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
|
71
75
|
computeCommand = msg("indirectComputeCommandAtIndex:", objc_id)(self.icb, j)
|
72
76
|
msg("concurrentDispatchThreadgroups:threadsPerThreadgroup:")(computeCommand, to_struct(*global_dims), to_struct(*local_dims))
|
73
|
-
for
|
77
|
+
for var in self.vars: self.int_buf_view[self.varlist.index(var)] = var_vals[var]
|
74
78
|
|
75
79
|
command_buffer = msg("commandBuffer", objc_instance)(self.dev.mtl_queue)
|
76
80
|
encoder = msg("computeCommandEncoder", objc_instance)(command_buffer)
|
@@ -98,3 +102,15 @@ class MetalGraph(GraphRunner):
|
|
98
102
|
wait_check(command_buffer)
|
99
103
|
return cmdbuf_en_time(command_buffer) - cmdbuf_st_time(command_buffer)
|
100
104
|
return None
|
105
|
+
|
106
|
+
def collect_timestamps(self):
|
107
|
+
# create a graph event and evenly space each program
|
108
|
+
st, en = decimal.Decimal(cmdbuf_st_time(self.command_buffer)) * 1000000, decimal.Decimal(cmdbuf_en_time(self.command_buffer)) * 1000000
|
109
|
+
ents = [ProfileGraphEntry(self.device, cast(CompiledRunner, ji.prg)._prg.name, i, i+1, is_copy=False) for i,ji in enumerate(self.jit_cache)]
|
110
|
+
step = (en-st)/len(ents)
|
111
|
+
self.dev.profile_events += [ProfileGraphEvent(ents, [], [st+step*i for i in range(len(ents)+1)])]
|
112
|
+
|
113
|
+
def __del__(self):
|
114
|
+
if PROFILE and self.command_buffer is not None:
|
115
|
+
wait_check(self.command_buffer)
|
116
|
+
self.collect_timestamps()
|
@@ -0,0 +1,114 @@
|
|
1
|
+
import time, itertools
|
2
|
+
from tinygrad.uop.ops import Variable
|
3
|
+
from tinygrad.engine.jit import MultiGraphRunner
|
4
|
+
from tinygrad.engine.realize import CompiledRunner, BufferXfer, ExecItem
|
5
|
+
from tinygrad.device import Device, Compiled, Buffer
|
6
|
+
from tinygrad.runtime.ops_remote import RemoteDevice, RemoteConnection, RemoteRequest, GraphComputeItem, Transfer, GraphAlloc, GraphFree, GraphExec
|
7
|
+
from tinygrad.runtime.ops_remote import BatchTransfer, Event, Wait
|
8
|
+
from tinygrad.helpers import unwrap, flatten, dedup
|
9
|
+
from enum import Enum, auto
|
10
|
+
from dataclasses import replace
|
11
|
+
from collections import defaultdict
|
12
|
+
from typing import cast
|
13
|
+
|
14
|
+
class StagingType(Enum): NONE = auto(); GRAPH = auto(); TRANSFER = auto() # noqa: E702
|
15
|
+
|
16
|
+
def rd(dev:Compiled) -> RemoteDevice: return cast(RemoteDevice, dev)
|
17
|
+
def dev_key(dev:RemoteDevice): return dev.conn if dev.properties.graph_supports_multi else dev
|
18
|
+
def map_rawbuf(rawbuf:Buffer): return (cast(RemoteDevice, Device[rawbuf.device]).session, rawbuf._buf)
|
19
|
+
|
20
|
+
class RemoteGraph(MultiGraphRunner):
|
21
|
+
def __init__(self, jit_cache: list[ExecItem], rawbufs: list[Buffer], var_vals: dict[Variable, int]):
|
22
|
+
super().__init__(jit_cache, rawbufs, var_vals)
|
23
|
+
devices = dedup(flatten([[Device[unwrap(buf).device] for buf in ji.bufs] for ji in jit_cache]))
|
24
|
+
c2d = {device.conn: device for device in devices}
|
25
|
+
self.handle_indexes = {map_rawbuf(rawbufs[i]): i for i in sorted(dedup(self.input_replace.values()))}
|
26
|
+
|
27
|
+
self.template: list[RemoteRequest] = []
|
28
|
+
|
29
|
+
stagings: dict[RemoteDevice|RemoteConnection, list[GraphComputeItem|Transfer]] = defaultdict(list)
|
30
|
+
clobbered_buffers: set[Buffer] = set()
|
31
|
+
cur_staging_type: StagingType = StagingType.NONE
|
32
|
+
|
33
|
+
def _flush(new_staging_type:StagingType, force_break:bool=False):
|
34
|
+
nonlocal cur_staging_type
|
35
|
+
if cur_staging_type == new_staging_type and not force_break: return
|
36
|
+
# Pre-sync
|
37
|
+
if cur_staging_type == StagingType.TRANSFER:
|
38
|
+
for sdev,ddev in itertools.permutations(c2d.values(), 2):
|
39
|
+
self.template.append(Event(ddev.session, event:=next(ddev.event_num), session=sdev.session))
|
40
|
+
self.template.append(Wait(event, session=ddev.session))
|
41
|
+
# Flush
|
42
|
+
for dev in devices:
|
43
|
+
dk = dev_key(dev)
|
44
|
+
staging = stagings[dk]
|
45
|
+
if not staging: continue
|
46
|
+
match cur_staging_type:
|
47
|
+
case StagingType.GRAPH:
|
48
|
+
bufs = tuple(map_rawbuf(rawbufs[i]) for i in sorted(dedup(self.input_replace.values())) if dev_key(rd(Device[rawbufs[i].device])) == dk)
|
49
|
+
dev.q(GraphAlloc(graph_num:=next(dev.graph_num), tuple(staging), tuple(bufs), var_vals))
|
50
|
+
self.template.append(GraphExec(graph_num, bufs, var_vals, wait=False, session=dev.session))
|
51
|
+
case StagingType.TRANSFER:
|
52
|
+
st = cast(list[Transfer], staging)
|
53
|
+
for host in dedup(t.dsession.host for t in st):
|
54
|
+
sbuffer_nums = [(unwrap(t.session), t.buffer_num) for t in st if t.dsession.host == host]
|
55
|
+
dbuffer_nums = [(t.dsession, t.dbuffer_num) for t in st if t.dsession.host == host]
|
56
|
+
self.template.append(BatchTransfer(sbuffer_nums, dbuffer_nums, session=dev.session))
|
57
|
+
staging.clear()
|
58
|
+
# Post-sync
|
59
|
+
if cur_staging_type == StagingType.TRANSFER:
|
60
|
+
for sdev,ddev in itertools.permutations(c2d.values(), 2):
|
61
|
+
self.template.append(Event(ddev.session, event:=next(ddev.event_num), session=sdev.session))
|
62
|
+
self.template.append(Wait(event, session=ddev.session))
|
63
|
+
cur_staging_type = new_staging_type
|
64
|
+
clobbered_buffers.clear()
|
65
|
+
|
66
|
+
for ji in jit_cache:
|
67
|
+
match ji.prg:
|
68
|
+
case CompiledRunner():
|
69
|
+
_flush(StagingType.GRAPH)
|
70
|
+
gi = GraphComputeItem(ji.prg.dev.session, ji.prg._prg.name, ji.prg._prg.datahash, tuple(unwrap(buf)._buf for buf in ji.bufs),
|
71
|
+
tuple(ji.prg.p.vars), ji.fixedvars, tuple(ji.prg.p.ins), tuple(ji.prg.p.outs),
|
72
|
+
tuple(ji.prg.p.global_size) if ji.prg.p.global_size is not None else None,
|
73
|
+
tuple(ji.prg.p.local_size) if ji.prg.p.local_size is not None else None)
|
74
|
+
stagings[dev_key(ji.prg.dev)].append(gi)
|
75
|
+
case BufferXfer():
|
76
|
+
dest, src = ji.bufs[0:2]
|
77
|
+
dest_dev, src_dev = cast(RemoteDevice, Device[unwrap(dest).device]), cast(RemoteDevice, Device[unwrap(src).device])
|
78
|
+
assert dest is not None and src is not None, ji
|
79
|
+
ti = Transfer(session=src_dev.session, buffer_num=src._buf, dsession=dest_dev.session, dbuffer_num=dest._buf)
|
80
|
+
if dev_key(dest_dev) == dev_key(src_dev):
|
81
|
+
_flush(StagingType.GRAPH)
|
82
|
+
stagings[dev_key(src_dev)].append(ti)
|
83
|
+
elif dest_dev.conn == src_dev.conn:
|
84
|
+
_flush(StagingType.NONE)
|
85
|
+
self.template.append(ti)
|
86
|
+
else:
|
87
|
+
_flush(StagingType.TRANSFER, force_break=src in clobbered_buffers)
|
88
|
+
clobbered_buffers.add(dest)
|
89
|
+
stagings[dev_key(src_dev)].append(ti)
|
90
|
+
case _: raise NotImplementedError(ji.prg)
|
91
|
+
_flush(StagingType.NONE)
|
92
|
+
def __del__(self):
|
93
|
+
for req in self.template:
|
94
|
+
match req:
|
95
|
+
case GraphExec(): RemoteConnection(unwrap(req.session).host).q(GraphFree(req.graph_num, session=req.session))
|
96
|
+
def __call__(self, rawbufs: list[Buffer], var_vals: dict[Variable, int], wait=False):
|
97
|
+
if wait: st = time.perf_counter()
|
98
|
+
rmap = {orig: map_rawbuf(rawbufs[replace_idx]) for orig,replace_idx in self.handle_indexes.items()}
|
99
|
+
for req in self.template:
|
100
|
+
match req:
|
101
|
+
case GraphExec():
|
102
|
+
req = replace(req, bufs=tuple(rmap[buf] for buf in req.bufs), var_vals=var_vals, wait=wait)
|
103
|
+
case Transfer():
|
104
|
+
if (req.session, req.buffer_num) in rmap: req = replace(req, buffer_num=rmap[(req.session, req.buffer_num)][1])
|
105
|
+
if (req.dsession, req.dbuffer_num) in rmap: req = replace(req, dbuffer_num=rmap[(req.dsession, req.dbuffer_num)][1])
|
106
|
+
case BatchTransfer():
|
107
|
+
req = replace(req, sbuffer_nums=[rmap.get(b, b) for b in req.sbuffer_nums], dbuffer_nums=[rmap.get(b, b) for b in req.dbuffer_nums])
|
108
|
+
case Event()|Wait():
|
109
|
+
pass # event number can be reused
|
110
|
+
case _: raise NotImplementedError(req)
|
111
|
+
RemoteConnection(unwrap(req.session).host).q(req)
|
112
|
+
if wait:
|
113
|
+
RemoteConnection(unwrap(req.session).host).batch_submit()
|
114
|
+
return time.perf_counter() - st
|