tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +1 -1
- tinygrad/apps/llm.py +206 -0
- tinygrad/codegen/__init__.py +116 -0
- tinygrad/codegen/devectorizer.py +315 -172
- tinygrad/codegen/expander.py +8 -16
- tinygrad/codegen/gpudims.py +89 -0
- tinygrad/codegen/linearize.py +205 -203
- tinygrad/codegen/lowerer.py +92 -139
- tinygrad/codegen/opt/__init__.py +38 -0
- tinygrad/codegen/opt/heuristic.py +125 -0
- tinygrad/codegen/opt/kernel.py +510 -0
- tinygrad/{engine → codegen/opt}/search.py +51 -35
- tinygrad/codegen/opt/swizzler.py +134 -0
- tinygrad/codegen/opt/tc.py +127 -0
- tinygrad/codegen/quantize.py +67 -0
- tinygrad/device.py +122 -132
- tinygrad/dtype.py +152 -35
- tinygrad/engine/jit.py +81 -54
- tinygrad/engine/memory.py +46 -27
- tinygrad/engine/realize.py +82 -41
- tinygrad/engine/schedule.py +70 -445
- tinygrad/frontend/__init__.py +0 -0
- tinygrad/frontend/onnx.py +1253 -0
- tinygrad/frontend/torch.py +5 -0
- tinygrad/gradient.py +19 -27
- tinygrad/helpers.py +95 -47
- tinygrad/nn/__init__.py +7 -8
- tinygrad/nn/optim.py +72 -41
- tinygrad/nn/state.py +37 -23
- tinygrad/renderer/__init__.py +40 -60
- tinygrad/renderer/cstyle.py +143 -128
- tinygrad/renderer/llvmir.py +113 -62
- tinygrad/renderer/ptx.py +50 -32
- tinygrad/renderer/wgsl.py +27 -23
- tinygrad/runtime/autogen/am/am.py +5861 -0
- tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
- tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
- tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
- tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
- tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
- tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
- tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
- tinygrad/runtime/autogen/comgr.py +35 -9
- tinygrad/runtime/autogen/comgr_3.py +906 -0
- tinygrad/runtime/autogen/cuda.py +2419 -494
- tinygrad/runtime/autogen/hsa.py +57 -16
- tinygrad/runtime/autogen/ib.py +7171 -0
- tinygrad/runtime/autogen/io_uring.py +917 -118
- tinygrad/runtime/autogen/kfd.py +748 -26
- tinygrad/runtime/autogen/libc.py +613 -218
- tinygrad/runtime/autogen/libusb.py +1643 -0
- tinygrad/runtime/autogen/nv/nv.py +8602 -0
- tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
- tinygrad/runtime/autogen/opencl.py +2 -4
- tinygrad/runtime/autogen/sqtt.py +1789 -0
- tinygrad/runtime/autogen/vfio.py +3 -3
- tinygrad/runtime/autogen/webgpu.py +273 -264
- tinygrad/runtime/graph/cuda.py +3 -3
- tinygrad/runtime/graph/hcq.py +68 -29
- tinygrad/runtime/graph/metal.py +29 -13
- tinygrad/runtime/graph/remote.py +114 -0
- tinygrad/runtime/ops_amd.py +537 -320
- tinygrad/runtime/ops_cpu.py +108 -7
- tinygrad/runtime/ops_cuda.py +12 -14
- tinygrad/runtime/ops_disk.py +13 -10
- tinygrad/runtime/ops_dsp.py +47 -40
- tinygrad/runtime/ops_gpu.py +13 -11
- tinygrad/runtime/ops_hip.py +6 -9
- tinygrad/runtime/ops_llvm.py +35 -15
- tinygrad/runtime/ops_metal.py +29 -19
- tinygrad/runtime/ops_npy.py +5 -3
- tinygrad/runtime/ops_null.py +28 -0
- tinygrad/runtime/ops_nv.py +306 -234
- tinygrad/runtime/ops_python.py +62 -52
- tinygrad/runtime/ops_qcom.py +28 -39
- tinygrad/runtime/ops_remote.py +482 -0
- tinygrad/runtime/ops_webgpu.py +28 -28
- tinygrad/runtime/support/am/amdev.py +114 -249
- tinygrad/runtime/support/am/ip.py +211 -172
- tinygrad/runtime/support/amd.py +138 -0
- tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
- tinygrad/runtime/support/compiler_cuda.py +8 -11
- tinygrad/runtime/support/elf.py +2 -1
- tinygrad/runtime/support/hcq.py +184 -97
- tinygrad/runtime/support/ib.py +172 -0
- tinygrad/runtime/support/llvm.py +3 -4
- tinygrad/runtime/support/memory.py +251 -0
- tinygrad/runtime/support/nv/__init__.py +0 -0
- tinygrad/runtime/support/nv/ip.py +581 -0
- tinygrad/runtime/support/nv/nvdev.py +183 -0
- tinygrad/runtime/support/system.py +170 -0
- tinygrad/runtime/support/usb.py +268 -0
- tinygrad/runtime/support/webgpu.py +18 -0
- tinygrad/schedule/__init__.py +0 -0
- tinygrad/schedule/grouper.py +119 -0
- tinygrad/schedule/kernelize.py +368 -0
- tinygrad/schedule/multi.py +231 -0
- tinygrad/shape/shapetracker.py +40 -46
- tinygrad/shape/view.py +88 -52
- tinygrad/tensor.py +968 -542
- tinygrad/uop/__init__.py +117 -0
- tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
- tinygrad/uop/mathtraits.py +169 -0
- tinygrad/uop/ops.py +1021 -0
- tinygrad/uop/spec.py +228 -0
- tinygrad/{codegen → uop}/symbolic.py +239 -216
- tinygrad/uop/upat.py +163 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
- tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
- tinygrad/viz/index.html +203 -403
- tinygrad/viz/js/index.js +718 -0
- tinygrad/viz/js/worker.js +29 -0
- tinygrad/viz/serve.py +224 -102
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
- tinygrad-0.11.0.dist-info/RECORD +141 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/kernel.py +0 -693
- tinygrad/engine/multi.py +0 -161
- tinygrad/ops.py +0 -1003
- tinygrad/runtime/ops_cloud.py +0 -220
- tinygrad/runtime/support/allocator.py +0 -94
- tinygrad/spec.py +0 -155
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
- tinygrad/viz/perfetto.html +0 -178
- tinygrad-0.10.2.dist-info/RECORD +0 -99
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
tinygrad/runtime/ops_cloud.py
DELETED
@@ -1,220 +0,0 @@
|
|
1
|
-
# the CLOUD=1 device is a process boundary between the frontend/runtime
|
2
|
-
# normally tinygrad is frontend <-> middleware <-> runtime <-> hardware
|
3
|
-
# with CLOUD tinygrad is frontend <-> middleware <-> CloudDevice ///HTTP/// cloud_server <-> runtime <-> hardware
|
4
|
-
# this client and server can be on the same machine, same network, or just same internet
|
5
|
-
# it should be a secure (example: no use of pickle) boundary. HTTP is used for RPC
|
6
|
-
|
7
|
-
from __future__ import annotations
|
8
|
-
from typing import Optional, Any
|
9
|
-
from collections import defaultdict
|
10
|
-
from dataclasses import dataclass, field
|
11
|
-
import multiprocessing, functools, http.client, hashlib, json, time, os, binascii, struct, ast, contextlib
|
12
|
-
from http.server import HTTPServer, BaseHTTPRequestHandler
|
13
|
-
from tinygrad.renderer import Renderer
|
14
|
-
from tinygrad.dtype import dtypes
|
15
|
-
from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap, Timing
|
16
|
-
from tinygrad.device import Compiled, Allocator, Compiler, Device, BufferSpec
|
17
|
-
|
18
|
-
# ***** API *****
|
19
|
-
|
20
|
-
class CloudRequest:
  # common base of every command that can be queued in a BatchRequest
  pass

# NOTE: repr() of these dataclasses is the wire format (it is parsed back by safe_eval on the server),
# so the class names, field names and field order below must not change.

@dataclass(frozen=True)
class BufferAlloc(CloudRequest):
  # allocate `size` bytes with `options` and register the result under `buffer_num`
  buffer_num: int
  size: int
  options: BufferSpec

@dataclass(frozen=True)
class BufferFree(CloudRequest):
  # free the buffer registered under `buffer_num`
  buffer_num: int

@dataclass(frozen=True)
class CopyIn(CloudRequest):
  # copy the payload identified by `datahash` into buffer `buffer_num`
  buffer_num: int
  datahash: str

@dataclass(frozen=True)
class CopyOut(CloudRequest):
  # read back the contents of buffer `buffer_num`
  buffer_num: int

@dataclass(frozen=True)
class ProgramAlloc(CloudRequest):
  # compile and load the source identified by `datahash` as program `name`
  name: str
  datahash: str

@dataclass(frozen=True)
class ProgramFree(CloudRequest):
  # drop the program registered under (`name`, `datahash`)
  name: str
  datahash: str

@dataclass(frozen=True)
class ProgramExec(CloudRequest):
  # run program (`name`, `datahash`) on buffers `bufs` with variable values `vals`
  name: str
  datahash: str
  bufs: tuple[int, ...]
  vals: tuple[int, ...]
  global_size: Optional[tuple[int, ...]]
  local_size: Optional[tuple[int, ...]]
  wait: bool
|
44
|
-
|
45
|
-
# for safe deserialization
|
46
|
-
# the only callables a serialized request is allowed to construct
whitelist = {x.__name__:x for x in [BufferAlloc, BufferFree, CopyIn, CopyOut, ProgramAlloc, ProgramFree, ProgramExec, BufferSpec]}

def _safe_negate(node):
  # repr() of a negative number parses as UnaryOp(USub, Constant), not as a Constant.
  # accept exactly that shape (a negated int/float literal) and nothing else, so e.g. vals=(-1,) can round trip.
  if isinstance(node.op, ast.USub) and isinstance(node.operand, ast.Constant) and type(node.operand.value) in (int, float):
    return -node.operand.value
  # mirror how any other unsupported node fails in safe_eval
  raise KeyError(node.__class__)

# AST node class -> evaluator. any node class missing from this table is rejected with a KeyError.
eval_fxns = {ast.Constant: lambda x: x.value, ast.Tuple: lambda x: tuple(map(safe_eval, x.elts)), ast.List: lambda x: list(map(safe_eval, x.elts)),
  ast.Call: lambda x: safe_eval(x.func)(*[safe_eval(arg) for arg in x.args], **{kwarg.arg: safe_eval(kwarg.value) for kwarg in x.keywords}),
  ast.Name: lambda x: whitelist[x.id], ast.Attribute: lambda x: {"imagef": dtypes.imagef, "imageh": dtypes.imageh}[x.attr],
  ast.UnaryOp: _safe_negate}

def safe_eval(node):
  """Evaluate a parsed expression node using only the constructs listed in `eval_fxns`.

  Raises KeyError for any unsupported node class, non-whitelisted name or unknown attribute.
  """
  return eval_fxns[node.__class__](node)
|
51
|
-
|
52
|
-
class BatchRequest:
|
53
|
-
def __init__(self):
|
54
|
-
self._q: list[CloudRequest] = []
|
55
|
-
self._h: dict[str, bytes] = {}
|
56
|
-
def h(self, d:bytes) -> str:
|
57
|
-
binhash = hashlib.sha256(d).digest()
|
58
|
-
self._h[datahash:=binascii.hexlify(binhash).decode()] = binhash+struct.pack("<Q", len(d))+d
|
59
|
-
return datahash
|
60
|
-
def q(self, x:CloudRequest): self._q.append(x)
|
61
|
-
def serialize(self) -> bytes:
|
62
|
-
self.h(repr(self._q).encode())
|
63
|
-
return b''.join(self._h.values())
|
64
|
-
def deserialize(self, dat:bytes) -> BatchRequest:
|
65
|
-
ptr = 0
|
66
|
-
while ptr < len(dat):
|
67
|
-
datahash, datalen = binascii.hexlify(dat[ptr:ptr+0x20]).decode(), struct.unpack("<Q", dat[ptr+0x20:ptr+0x28])[0]
|
68
|
-
self._h[datahash] = dat[ptr+0x28:ptr+0x28+datalen]
|
69
|
-
ptr += 0x28+datalen
|
70
|
-
self._q = safe_eval(ast.parse(self._h[datahash], mode="eval").body)
|
71
|
-
return self
|
72
|
-
|
73
|
-
# ***** backend *****
|
74
|
-
|
75
|
-
@dataclass
class CloudSession:
  # per-session server state; CloudHandler keeps one instance per session cookie value
  # loaded programs, keyed by (name, datahash)
  programs: dict[tuple[str, str], Any] = field(default_factory=dict)
  # TODO: the buffer should track this internally
  # allocated buffers: buffer_num -> (opaque buffer returned by the allocator, size, options)
  buffers: dict[int, tuple[Any, int, Optional[BufferSpec]]] = field(default_factory=dict)
|
80
|
-
|
81
|
-
class CloudHandler(BaseHTTPRequestHandler):
  """HTTP handler of the cloud server: POST /batch executes a BatchRequest, GET /renderer describes the device renderer."""
  protocol_version = 'HTTP/1.1'
  device: str
  # session cookie value -> state of that session (shared by all handler instances)
  sessions: defaultdict[str, CloudSession] = defaultdict(CloudSession)

  def setup(self):
    super().setup()
    print(f"connection established with {self.client_address}, socket: {self.connection.fileno()}")

  def _run_batch(self, session):
    """Read, deserialize and execute one batch; return the response body (the last result produced, else b"")."""
    # TODO: streaming deserialize?
    body = self.rfile.read(int(unwrap(self.headers.get('Content-Length'))))
    req = BatchRequest().deserialize(body)
    ret = b""
    # the cmds are always last (currently in datahash)
    for c in req._q:
      if DEBUG >= 1: print(c)
      # NOTE: the device is looked up per command, so a batch without commands never opens it
      if isinstance(c, BufferAlloc):
        assert c.buffer_num not in session.buffers, f"buffer {c.buffer_num} already allocated"
        session.buffers[c.buffer_num] = (Device[CloudHandler.device].allocator.alloc(c.size, c.options), c.size, c.options)
      elif isinstance(c, BufferFree):
        buf, sz, buffer_options = session.buffers[c.buffer_num]
        Device[CloudHandler.device].allocator.free(buf, sz, buffer_options)
        del session.buffers[c.buffer_num]
      elif isinstance(c, CopyIn):
        allocator = Device[CloudHandler.device].allocator
        dest = session.buffers[c.buffer_num][0]
        allocator._copyin(dest, memoryview(bytearray(req._h[c.datahash])))
      elif isinstance(c, CopyOut):
        buf, sz, _ = session.buffers[c.buffer_num]
        allocator = Device[CloudHandler.device].allocator
        ret = bytearray(sz)
        allocator._copyout(memoryview(ret), buf)
      elif isinstance(c, ProgramAlloc):
        lib = Device[CloudHandler.device].compiler.compile_cached(req._h[c.datahash].decode())
        session.programs[(c.name, c.datahash)] = Device[CloudHandler.device].runtime(c.name, lib)
      elif isinstance(c, ProgramFree):
        del session.programs[(c.name, c.datahash)]
      elif isinstance(c, ProgramExec):
        bufs = [session.buffers[x][0] for x in c.bufs]
        extra_args = {}
        if c.global_size is not None: extra_args["global_size"] = c.global_size
        if c.local_size is not None: extra_args["local_size"] = c.local_size
        r = session.programs[(c.name, c.datahash)](*bufs, vals=c.vals, wait=c.wait, **extra_args)
        if r is not None: ret = str(r).encode()
    return ret

  def _do(self, method):
    # the session is looked up (and created on first use) before the path is checked
    cookie = unwrap(self.headers.get("Cookie"))
    session = CloudHandler.sessions[cookie.split("session=")[1]]
    ret, status_code = b"", 200
    if self.path == "/batch" and method == "POST":
      ret = self._run_batch(session)
    elif self.path == "/renderer" and method == "GET":
      cls, args = Device[CloudHandler.device].renderer.__reduce__()
      ret = json.dumps((cls.__module__, cls.__name__, args)).encode()
    else:
      status_code = 404
    self.send_response(status_code)
    self.send_header('Content-Length', str(len(ret)))
    self.end_headers()
    return self.wfile.write(ret)

  def do_GET(self): return self._do("GET")
  def do_POST(self): return self._do("POST")
|
131
|
-
|
132
|
-
def cloud_server(port:int):
  """Serve CloudHandler over HTTP on `port` (all interfaces) until the process exits."""
  multiprocessing.current_process().name = "MainProcess"
  # if the default device is CLOUD itself, the backing device comes from CLOUDDEV instead
  if Device.DEFAULT == "CLOUD": CloudHandler.device = getenv("CLOUDDEV", "METAL")
  else: CloudHandler.device = Device.DEFAULT
  print(f"start cloud server on {port} with device {CloudHandler.device}")
  HTTPServer(('', port), CloudHandler).serve_forever()
|
138
|
-
|
139
|
-
# ***** frontend *****
|
140
|
-
|
141
|
-
class CloudAllocator(Allocator):
  """Allocator whose operations are queued on the owning CloudDevice's pending batch; buffers are plain ints."""
  def __init__(self, dev:CloudDevice):
    self.device = dev
    super().__init__()

  # TODO: ideally we shouldn't have to deal with images here
  def _alloc(self, size:int, options:BufferSpec) -> int:
    dev = self.device
    dev.buffer_num += 1
    dev.req.q(BufferAlloc(dev.buffer_num, size, options))
    return dev.buffer_num

  # TODO: options should not be here in any Allocator
  def _free(self, opaque:int, options):
    self.device.req.q(BufferFree(opaque))

  def _copyin(self, dest:int, src:memoryview):
    pending = self.device.req
    datahash = pending.h(bytes(src))
    pending.q(CopyIn(dest, datahash))

  def _copyout(self, dest:memoryview, src:int):
    # a read needs the data now, so it submits the whole pending batch
    self.device.req.q(CopyOut(src))
    resp = self.device.batch_submit()
    assert len(resp) == len(dest), f"buffer length mismatch {len(resp)} != {len(dest)}"
    dest[:] = resp
|
158
|
-
|
159
|
-
class CloudProgram:
  """Handle to a program on the server; allocation, execution and release are queued on the device's pending batch."""
  def __init__(self, dev:CloudDevice, name:str, lib:bytes):
    self.dev = dev
    self.name = name
    # register the payload first, then queue the command that refers to it by hash
    self.datahash = self.dev.req.h(lib)
    self.dev.req.q(ProgramAlloc(self.name, self.datahash))
    super().__init__()

  def __del__(self):
    self.dev.req.q(ProgramFree(self.name, self.datahash))

  def __call__(self, *bufs, global_size=None, local_size=None, vals:tuple[int, ...]=(), wait=False):
    exec_req = ProgramExec(self.name, self.datahash, bufs, vals, global_size, local_size, wait)
    self.dev.req.q(exec_req)
    # without wait the call stays queued and nothing is returned
    if not wait: return None
    return float(self.dev.batch_submit())
|
170
|
-
|
171
|
-
class CloudDevice(Compiled):
  """Client-side device: queues commands in a BatchRequest and sends them to a cloud server over HTTP."""
  def __init__(self, device:str):
    # use the server named by HOST if set, otherwise start a local server process on port 6667
    if (host:=getenv("HOST", "")) != "": self.host = host
    else:
      p = multiprocessing.Process(target=cloud_server, args=(6667,))
      p.daemon = True
      p.start()
      self.host = "127.0.0.1:6667"

    # state for the connection
    self.session = binascii.hexlify(os.urandom(0x10)).decode()
    self.buffer_num = 0
    self.req: BatchRequest = BatchRequest()

    if DEBUG >= 1: print(f"cloud with host {self.host}")
    # retry every 0.1s until the renderer query succeeds; there is no upper bound on the number of attempts
    while 1:
      try:
        self.conn = http.client.HTTPConnection(self.host, timeout=60.0)
        clouddev = json.loads(self.send("GET", "renderer").decode())
        break
      except Exception as e:
        print(e)
        time.sleep(0.1)
    if DEBUG >= 1: print(f"remote has device {clouddev}")
    # TODO: how do we have BEAM be cached on the backend? this should just send a specification of the compute. rethink what goes in Renderer
    # clouddev is (module name, class name, constructor args) as sent by the server; only tinygrad.renderer.* classes named *Renderer are imported
    if not clouddev[0].startswith("tinygrad.renderer.") or not clouddev[1].endswith("Renderer"): raise RuntimeError(f"bad renderer {clouddev}")
    renderer_class = fromimport(clouddev[0], clouddev[1]) # TODO: is this secure?
    if not issubclass(renderer_class, Renderer): raise RuntimeError(f"renderer isn't a Renderer {clouddev}")
    super().__init__(device, CloudAllocator(self), renderer_class(*clouddev[2]), Compiler(), functools.partial(CloudProgram, self))

  def __del__(self):
    # TODO: this is never being called
    # TODO: should close the whole session
    with contextlib.suppress(ConnectionRefusedError, http.client.CannotSendRequest, http.client.RemoteDisconnected): self.batch_submit()

  def batch_submit(self):
    """Send the pending batch, start a fresh one and return the raw response body."""
    data = self.req.serialize()
    with Timing(f"*** send {len(self.req._q):-3d} requests {len(self.req._h):-3d} hashes with len {len(data)/1024:.2f} kB in ", enabled=DEBUG>=1):
      ret = self.send("POST", "batch", data)
    self.req = BatchRequest()
    return ret

  def send(self, method, path, data:Optional[bytes]=None) -> bytes:
    """Issue one HTTP request carrying the session cookie and return the response body."""
    # TODO: retry logic
    self.conn.request(method, "/"+path, data, headers={"Cookie": f"session={self.session}"})
    response = self.conn.getresponse()
    # NOTE(review): this status check is an assert, so it is skipped when Python runs with -O
    assert response.status == 200, f"failed on {method} {path}"
    return response.read()
|
219
|
-
|
220
|
-
# when run as a script, serve on PORT (default 6667)
if __name__ == "__main__": cloud_server(getenv("PORT", 6667))
|
tinygrad/runtime/support/allocator.py
DELETED
@@ -1,94 +0,0 @@
|
|
1
|
-
import collections
|
2
|
-
from tinygrad.helpers import round_up
|
3
|
-
|
4
|
-
class TLSFAllocator:
|
5
|
-
"""
|
6
|
-
The allocator is based on the Two-Level Segregated Fit (TLSF) algorithm. The allocator maintains 2 level of buckets:
|
7
|
-
* 1st level is determined by the most significant bit of the size.
|
8
|
-
* 2nd level splits the covered memory of 1st level into @lv2_cnt entries.
|
9
|
-
|
10
|
-
For each allocation request, the allocator searches for the smallest block that can fit the requested size.
|
11
|
-
For each deallocation request, the allocator merges the block with its neighbors if they are free.
|
12
|
-
"""
|
13
|
-
|
14
|
-
def __init__(self, size:int, base:int=0, block_size:int=16, lv2_cnt:int=16):
|
15
|
-
self.size, self.base, self.block_size, self.l2_cnt = size, base, block_size, lv2_cnt.bit_length()
|
16
|
-
self.storage:list = [collections.defaultdict(list) for _ in range(size.bit_length() + 1)]
|
17
|
-
self.lv1_entries:list[int] = [0] * len(self.storage)
|
18
|
-
|
19
|
-
# self.blocks is more like a linked list, where each entry is a contigous block.
|
20
|
-
self.blocks:dict[int, tuple[int, int|None, int|None, bool]] = {0: (size, None, None, True)} # size, next, prev, is_free
|
21
|
-
self._insert_block(0, size)
|
22
|
-
|
23
|
-
def lv1(self, size): return size.bit_length()
|
24
|
-
def lv2(self, size): return (size - (1 << (size.bit_length() - 1))) // (1 << max(0, size.bit_length() - self.l2_cnt))
|
25
|
-
|
26
|
-
def _insert_block(self, start:int, size:int, prev:int|None=None):
|
27
|
-
if prev is None: prev = self.blocks[start][2]
|
28
|
-
self.storage[self.lv1(size)][self.lv2(size)].append(start)
|
29
|
-
self.lv1_entries[self.lv1(size)] += 1
|
30
|
-
self.blocks[start] = (size, start + size, prev, True)
|
31
|
-
return self
|
32
|
-
|
33
|
-
def _remove_block(self, start:int, size:int, prev:int|None=None):
|
34
|
-
if prev is None: prev = self.blocks[start][2]
|
35
|
-
self.storage[self.lv1(size)][self.lv2(size)].remove(start)
|
36
|
-
self.lv1_entries[self.lv1(size)] -= 1
|
37
|
-
self.blocks[start] = (size, start + size, prev, False)
|
38
|
-
return self
|
39
|
-
|
40
|
-
def _split_block(self, start:int, size:int, new_size:int):
|
41
|
-
nxt = self.blocks[start][1]
|
42
|
-
assert self.blocks[start][3], "block must be free"
|
43
|
-
self._remove_block(start, size)._insert_block(start, new_size)._insert_block(start + new_size, size - new_size, prev=start)
|
44
|
-
if nxt in self.blocks: self.blocks[nxt] = (self.blocks[nxt][0], self.blocks[nxt][1], start + new_size, self.blocks[nxt][3])
|
45
|
-
return self
|
46
|
-
|
47
|
-
def _merge_right(self, start:int):
|
48
|
-
size, nxt, _, is_free = self.blocks[start]
|
49
|
-
assert is_free, "block must be free"
|
50
|
-
|
51
|
-
while is_free and nxt in self.blocks:
|
52
|
-
if (blk:=self.blocks[nxt])[3] is False: break
|
53
|
-
self._remove_block(start, size)._remove_block(nxt, blk[0])._insert_block(start, size:=size + blk[0])
|
54
|
-
assert self.blocks[start][1] == blk[1]
|
55
|
-
_, nxt, _, _ = self.blocks.pop(nxt)
|
56
|
-
|
57
|
-
if nxt in self.blocks: self.blocks[nxt] = (self.blocks[nxt][0], self.blocks[nxt][1], start, self.blocks[nxt][3])
|
58
|
-
|
59
|
-
def _merge_block(self, start:int):
|
60
|
-
# Go left while blocks are free. Then merge all them right.
|
61
|
-
while (x:=self.blocks[start][2]) is not None and self.blocks[x][3] is True: start = x
|
62
|
-
self._merge_right(start)
|
63
|
-
|
64
|
-
def alloc(self, req_size:int, align:int=1) -> int:
|
65
|
-
req_size = max(self.block_size, req_size) # at least block size.
|
66
|
-
size = max(self.block_size, req_size + align - 1)
|
67
|
-
|
68
|
-
# Round up the allocation size to the next bucket, so any entry there can fit the requested size.
|
69
|
-
size = round_up(size, (1 << size.bit_length() - self.l2_cnt))
|
70
|
-
|
71
|
-
# Search for the smallest block that can fit the requested size. Start with the it's bucket and go up until any block is found.
|
72
|
-
for l1 in range(self.lv1(size), len(self.storage)):
|
73
|
-
if self.lv1_entries[l1] == 0: continue
|
74
|
-
for l2 in range(self.lv2(size) if l1 == size.bit_length() else 0, (1 << self.l2_cnt)):
|
75
|
-
if len(self.storage[l1][l2]) > 0:
|
76
|
-
nsize = self.blocks[self.storage[l1][l2][0]][0]
|
77
|
-
assert nsize >= size, "block must be larger"
|
78
|
-
|
79
|
-
# Block start address.
|
80
|
-
start = self.storage[l1][l2][0]
|
81
|
-
|
82
|
-
# If request contains alignment, split the block into two parts.
|
83
|
-
if (new_start:=round_up(start, align)) != start:
|
84
|
-
self._split_block(start, nsize, new_start - start)
|
85
|
-
start, nsize = new_start, self.blocks[new_start][0]
|
86
|
-
|
87
|
-
# If the block is larger than the requested size, split it into two parts.
|
88
|
-
if nsize > req_size: self._split_block(start, nsize, req_size)
|
89
|
-
self._remove_block(start, req_size) # Mark the block as allocated.
|
90
|
-
return start + self.base
|
91
|
-
raise MemoryError(f"Can't allocate {req_size} bytes")
|
92
|
-
|
93
|
-
def free(self, start:int):
|
94
|
-
self._insert_block(start - self.base, self.blocks[start - self.base][0])._merge_block(start - self.base)
|
tinygrad/spec.py
DELETED
@@ -1,155 +0,0 @@
|
|
1
|
-
from typing import cast
|
2
|
-
from tinygrad.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops
|
3
|
-
from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType
|
4
|
-
from tinygrad.helpers import all_same, dedup, prod
|
5
|
-
|
6
|
-
# rules for the UOps that describe a buffer: each entry is (pattern, check), where the check returns True to accept
buffer_spec = PatternMatcher([
  # UNIQUE: void dtype, no sources
  (UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True),
  # DEVICE: void dtype, no sources, arg is the device string
  (UPat(Ops.DEVICE, dtypes.void, (), name="device"), lambda device: isinstance(device.arg, str)),
  # BUFFER: sources are (DEVICE, UNIQUE), arg is an int, dtype is a DType or ImageDType
  (UPat(Ops.BUFFER, src=(UPat(Ops.DEVICE), UPat(Ops.UNIQUE)), name="buf"),
   lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
])
|
12
|
-
|
13
|
-
# *** this is the spec of a Tensor in UOp ***
|
14
|
-
|
15
|
-
# the buffer rules plus the rules for the UOps allowed at the Tensor level
tensor_uop_spec = buffer_spec+PatternMatcher([
  # movement ops: tuple arg and same dtype as the source, with one image-dtype exception
  (UPat(GroupOp.Movement, name="mv", src=(UPat.var("x"),)),
   # naturally correct
   lambda mv,x: (isinstance(mv.arg, tuple) and mv.dtype == x.dtype) or
   # "make things that can't be images not images" can change the buffer dtype
   # this is fine as long as it's a realized buffer and base dtypes match.
   ((isinstance(mv.dtype, ImageDType) or isinstance(x.dtype, ImageDType)) and x.dtype.base == mv.dtype.base and x.is_realized)),
  # a VIEW whose source is anything other than CONST or DEVICE is rejected
  (UPat(Ops.VIEW, src=(UPat(GroupOp.All-{Ops.CONST, Ops.DEVICE}),)), lambda: False),

  # Tensor variable bindings
  (UPat(Ops.BIND, dtypes.int, (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=dtypes.int)), arg=None), lambda: True),

  # Tensor const has a device and an unmasked ShapeTracker of stride 0
  (UPat(Ops.CONST, src=(UPat(Ops.VIEW, name="st", src=(UPat(Ops.DEVICE),)),)),
   lambda st: st.st.views[0].mask is None and len(st.st.views) == 1 and all(s == 0 for s in st.st.views[0].strides)),

  # DETACH and CONTIGUOUS change how we interpret the source UOp
  # CONTIGUOUS ensures the source UOp realizes
  (UPat((Ops.DETACH, Ops.CONTIGUOUS, Ops.CONTIGUOUS_BACKWARD), name="root", src=(UPat.var("x"),), arg=None), lambda root,x: root.dtype == x.dtype),

  # COPY
  # NOTE: the arg here specifies clone=True, which prevents folding same device copy
  (UPat(Ops.COPY, name="copy", src=(UPat(Ops.DEVICE), UPat.var("x"))), lambda copy,x: isinstance(copy.arg, bool) and copy.dtype == x.dtype),

  # ASSIGN changes the value of a realized buffer
  (UPat(Ops.ASSIGN, name="assign", src=(UPat.var("target"), UPat.var("new_val"))),
   lambda assign,target,new_val: target.is_realized and (assign.dtype == target.dtype == new_val.dtype)),
])
|
43
|
-
|
44
|
-
# ***** uop type spec *****
|
45
|
-
|
46
|
-
# this is the matcher for the final rendered UOps
|
47
|
-
# matcher functions returns True or False (or None to not match)
|
48
|
-
# NOTE: type_verify treats a False result as a failure and a None result as "this rule did not match"
spec = PatternMatcher([
  # DEFINE_GLOBAL is a non-local pointer or image, DEFINE_LOCAL is a local pointer
  (UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and not x.dtype.local),
  (UPat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.local),
  # DEFINE_ACC has a <value of the acc dtype, ranges...>
  (UPat(Ops.DEFINE_ACC, src=(UPat.var("c"),), name="x", allow_any_len=True),
   lambda x,c: all(y.op is Ops.RANGE for y in x.src[1:]) and c.dtype == x.dtype),
  # DEFINE_VAR: arg[1] and arg[2] are ints
  (UPat(Ops.DEFINE_VAR, name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),

  # RANGE has two sources of its own dtype and an int arg
  (UPat(Ops.RANGE, src=(UPat.var("x"), UPat.var("y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype and isinstance(rng.arg, int)),
  (UPat(Ops.SPECIAL, src=()), lambda: True),

  # TODO: confirm the args of both of these are shapetrackers
  (UPat(Ops.VIEW, dtypes.void, src=()), lambda: True),
  (UPat(Ops.VIEW, src=(UPat.var("src"),), name="x"), lambda x,src: src.op is not Ops.STORE and x.dtype.base == src.dtype.base),

  (UPat(Ops.VALID, dtypes.bool, (UPat(Ops.VIEW),)), lambda: True),
  # the python type of a CONST arg must be what dtypes.as_const produces for its dtype
  (UPat(Ops.CONST, name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),

  # early LOAD has a <buf, shapetracker, store?>
  (UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW))), lambda: True),
  (UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat(Ops.STORE))), lambda: True),

  # early STORE has a <buf, shapetracker, val>
  (UPat(Ops.STORE, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat())), lambda: True),

  # **** new style load/store ****

  # INDEX is used in new style load/store
  (UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat())), lambda: True),

  # LOAD takes a <bufidx, alt?, gate?, barrier?>
  (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)),)), lambda: True),
  (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat((Ops.IF, Ops.BARRIER)))), lambda: True),
  (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat.var("alt"), UPat(dtype=dtypes.bool)), name="ld"), lambda ld,alt: ld.dtype == alt.dtype),

  # STORE takes a <bufidx, val, gate?>
  (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat())), lambda: True),
  (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
  (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(), UPat(Ops.IF))), lambda: True),

  # most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
  (UPat(Ops.WHERE, name="w", src=(UPat(dtype=dtypes.bool), UPat.var("x"), UPat.var("y"))), lambda w,x,y: w.dtype == x.dtype == y.dtype),
  (UPat((Ops.CMPLT, Ops.CMPNE), dtype=dtypes.bool, src=(UPat.var("x"), UPat.var("y"))), lambda x,y: x.dtype.base == y.dtype.base),
  # and SHL/SHR, the shift distance can be an int
  (UPat((Ops.SHL, Ops.SHR), src=(UPat.var("x"), UPat.var("y")), name="a"), lambda a,x,y: a.dtype == x.dtype and y.dtype in (x.dtype, dtypes.uint)),
  # a non-int IDIV fails here; an int IDIV returns None and is left to the generic ALU rule below
  (UPat(Ops.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
  (UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype.base == y.dtype.base for y in x.src)),

  (UPat(Ops.ASSIGN, src=(UPat((Ops.DEFINE_ACC, Ops.DEFINE_GLOBAL)), UPat())), lambda: True),
  (UPat(Ops.ENDRANGE, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True),

  # WMMA has a <a, b, acc>
  (UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 8),
  # the vector width must equal the product of the second entry of each arg item
  (UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
  (UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),

  # if has a <gate, barrier?>
  (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(),)), lambda: True),
  (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(), UPat(Ops.BARRIER))), lambda: True),
  (UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),

  # REDUCE_AXIS arg is a 2-tuple whose first entry is ADD, MUL or MAX
  (UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}),
  # GEP produces the scalar dtype of its source
  (UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
  # VECTORIZE needs more than one source, one per lane, each of the lane dtype
  (UPat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.count and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
  (UPat((Ops.BITCAST, Ops.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None),
  (UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local

  # NOTE: for testing, we let sinks be anything
  #(UPat(Ops.SINK, src=UPat(Ops.STORE)), lambda: True),
  (UPat((Ops.NAME, Ops.SINK), dtypes.void), lambda: True),
  (UPat((Ops.NOOP, Ops.CUSTOM)), lambda: True),

  # PTX LOAD/STORE
  (UPat((Ops.LOAD, Ops.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
])
|
122
|
-
|
123
|
-
# *** this is the spec of a Kernel in UOp ***
|
124
|
-
|
125
|
-
# the buffer rules plus the rules for the kernel graph; the last rule rejects every op not accepted earlier
kernel_spec = buffer_spec+PatternMatcher([
  # a KERNEL only has BUFFER or ASSIGN sources
  (UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.ASSIGN))), lambda: True),
  # assign has a buffer view and kernel source, it can optionally depend on other assigns
  (UPat(Ops.ASSIGN, src=UPat((Ops.BUFFER, Ops.VIEW, Ops.KERNEL, Ops.ASSIGN))), lambda: True),
  # view/sink/const can also exist in the kernel graph
  (UPat((Ops.VIEW, Ops.SINK, Ops.CONST)), lambda: True),
  (UPat(GroupOp.All), lambda: False),
])
|
133
|
-
|
134
|
-
# *** this is the UOp shape spec ***
|
135
|
-
|
136
|
-
def verify_sink_dims(sink:UOp):
  """Check that all sink sources have the same st_arg size and that each axis takes one size, or 1 plus one other size."""
  # shapes of every shaped, non-SINK UOp in the graph
  shapes = [u.shape for u in sink.toposort if u.op is not Ops.SINK and u.st is not None]
  # sorted distinct sizes seen along each axis
  per_axis = [sorted(dedup(axis_sizes)) for axis_sizes in zip(*shapes)]
  if not (same_size:=all_same([out.st_arg.size for out in sink.src])): return same_size
  for sizes in per_axis:
    if len(sizes) == 1: continue
    if len(sizes) == 2 and sizes[0] == 1: continue
    return False
  return True
|
139
|
-
|
140
|
-
# shape checks, kept separate from the dtype rules in `spec`
shape_spec = PatternMatcher([
  # shapes must have either 1 or n in each dimension
  (UPat(Ops.SINK, src=UPat(Ops.STORE), allow_any_len=True, name="sink"), verify_sink_dims),
  # all parent UOps must have the same shape
  (UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: all_same([x.shape for x in root.src if x.st is not None])),
])
|
146
|
-
|
147
|
-
# ***** uop helpers *****
|
148
|
-
|
149
|
-
def type_verify(uops:list[UOp], *extra_specs:PatternMatcher):
  """Check every UOp against `spec` and `extra_specs`.

  A UOp fails if any spec returns False for it or if no spec matches it at all.
  On the first failure the whole list is printed and RuntimeError is raised.
  """
  matchers = [spec, *extra_specs]
  for idx, uop in enumerate(uops):
    results = [cast(bool|None, m.rewrite(uop)) for m in matchers]
    rejected = any(res is False for res in results)
    unmatched = all(res is None for res in results)
    if not (rejected or unmatched): continue
    print_uops(uops)
    raise RuntimeError(f"UOp verification failed at {idx} on {uop.op} {uop.dtype} {len(uop.src)} {[x.op for x in uop.src]} {uop.arg}")
|