tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +1 -1
- tinygrad/apps/llm.py +206 -0
- tinygrad/codegen/__init__.py +116 -0
- tinygrad/codegen/devectorizer.py +315 -172
- tinygrad/codegen/expander.py +8 -16
- tinygrad/codegen/gpudims.py +89 -0
- tinygrad/codegen/linearize.py +205 -203
- tinygrad/codegen/lowerer.py +92 -139
- tinygrad/codegen/opt/__init__.py +38 -0
- tinygrad/codegen/opt/heuristic.py +125 -0
- tinygrad/codegen/opt/kernel.py +510 -0
- tinygrad/{engine → codegen/opt}/search.py +51 -35
- tinygrad/codegen/opt/swizzler.py +134 -0
- tinygrad/codegen/opt/tc.py +127 -0
- tinygrad/codegen/quantize.py +67 -0
- tinygrad/device.py +122 -132
- tinygrad/dtype.py +152 -35
- tinygrad/engine/jit.py +81 -54
- tinygrad/engine/memory.py +46 -27
- tinygrad/engine/realize.py +82 -41
- tinygrad/engine/schedule.py +70 -445
- tinygrad/frontend/__init__.py +0 -0
- tinygrad/frontend/onnx.py +1253 -0
- tinygrad/frontend/torch.py +5 -0
- tinygrad/gradient.py +19 -27
- tinygrad/helpers.py +95 -47
- tinygrad/nn/__init__.py +7 -8
- tinygrad/nn/optim.py +72 -41
- tinygrad/nn/state.py +37 -23
- tinygrad/renderer/__init__.py +40 -60
- tinygrad/renderer/cstyle.py +143 -128
- tinygrad/renderer/llvmir.py +113 -62
- tinygrad/renderer/ptx.py +50 -32
- tinygrad/renderer/wgsl.py +27 -23
- tinygrad/runtime/autogen/am/am.py +5861 -0
- tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
- tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
- tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
- tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
- tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
- tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
- tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
- tinygrad/runtime/autogen/comgr.py +35 -9
- tinygrad/runtime/autogen/comgr_3.py +906 -0
- tinygrad/runtime/autogen/cuda.py +2419 -494
- tinygrad/runtime/autogen/hsa.py +57 -16
- tinygrad/runtime/autogen/ib.py +7171 -0
- tinygrad/runtime/autogen/io_uring.py +917 -118
- tinygrad/runtime/autogen/kfd.py +748 -26
- tinygrad/runtime/autogen/libc.py +613 -218
- tinygrad/runtime/autogen/libusb.py +1643 -0
- tinygrad/runtime/autogen/nv/nv.py +8602 -0
- tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
- tinygrad/runtime/autogen/opencl.py +2 -4
- tinygrad/runtime/autogen/sqtt.py +1789 -0
- tinygrad/runtime/autogen/vfio.py +3 -3
- tinygrad/runtime/autogen/webgpu.py +273 -264
- tinygrad/runtime/graph/cuda.py +3 -3
- tinygrad/runtime/graph/hcq.py +68 -29
- tinygrad/runtime/graph/metal.py +29 -13
- tinygrad/runtime/graph/remote.py +114 -0
- tinygrad/runtime/ops_amd.py +537 -320
- tinygrad/runtime/ops_cpu.py +108 -7
- tinygrad/runtime/ops_cuda.py +12 -14
- tinygrad/runtime/ops_disk.py +13 -10
- tinygrad/runtime/ops_dsp.py +47 -40
- tinygrad/runtime/ops_gpu.py +13 -11
- tinygrad/runtime/ops_hip.py +6 -9
- tinygrad/runtime/ops_llvm.py +35 -15
- tinygrad/runtime/ops_metal.py +29 -19
- tinygrad/runtime/ops_npy.py +5 -3
- tinygrad/runtime/ops_null.py +28 -0
- tinygrad/runtime/ops_nv.py +306 -234
- tinygrad/runtime/ops_python.py +62 -52
- tinygrad/runtime/ops_qcom.py +28 -39
- tinygrad/runtime/ops_remote.py +482 -0
- tinygrad/runtime/ops_webgpu.py +28 -28
- tinygrad/runtime/support/am/amdev.py +114 -249
- tinygrad/runtime/support/am/ip.py +211 -172
- tinygrad/runtime/support/amd.py +138 -0
- tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
- tinygrad/runtime/support/compiler_cuda.py +8 -11
- tinygrad/runtime/support/elf.py +2 -1
- tinygrad/runtime/support/hcq.py +184 -97
- tinygrad/runtime/support/ib.py +172 -0
- tinygrad/runtime/support/llvm.py +3 -4
- tinygrad/runtime/support/memory.py +251 -0
- tinygrad/runtime/support/nv/__init__.py +0 -0
- tinygrad/runtime/support/nv/ip.py +581 -0
- tinygrad/runtime/support/nv/nvdev.py +183 -0
- tinygrad/runtime/support/system.py +170 -0
- tinygrad/runtime/support/usb.py +268 -0
- tinygrad/runtime/support/webgpu.py +18 -0
- tinygrad/schedule/__init__.py +0 -0
- tinygrad/schedule/grouper.py +119 -0
- tinygrad/schedule/kernelize.py +368 -0
- tinygrad/schedule/multi.py +231 -0
- tinygrad/shape/shapetracker.py +40 -46
- tinygrad/shape/view.py +88 -52
- tinygrad/tensor.py +968 -542
- tinygrad/uop/__init__.py +117 -0
- tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
- tinygrad/uop/mathtraits.py +169 -0
- tinygrad/uop/ops.py +1021 -0
- tinygrad/uop/spec.py +228 -0
- tinygrad/{codegen → uop}/symbolic.py +239 -216
- tinygrad/uop/upat.py +163 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
- tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
- tinygrad/viz/index.html +203 -403
- tinygrad/viz/js/index.js +718 -0
- tinygrad/viz/js/worker.js +29 -0
- tinygrad/viz/serve.py +224 -102
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
- tinygrad-0.11.0.dist-info/RECORD +141 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/kernel.py +0 -693
- tinygrad/engine/multi.py +0 -161
- tinygrad/ops.py +0 -1003
- tinygrad/runtime/ops_cloud.py +0 -220
- tinygrad/runtime/support/allocator.py +0 -94
- tinygrad/spec.py +0 -155
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
- tinygrad/viz/perfetto.html +0 -178
- tinygrad-0.10.2.dist-info/RECORD +0 -99
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,172 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
import resource, ctypes, weakref, functools, itertools, tinygrad.runtime.autogen.ib as ib
|
3
|
+
from typing import Iterator
|
4
|
+
from dataclasses import dataclass
|
5
|
+
from weakref import WeakKeyDictionary
|
6
|
+
from tinygrad.device import Buffer, DMACPURef, DMAFdRef
|
7
|
+
from tinygrad.helpers import getenv, round_up, DEBUG
|
8
|
+
|
9
|
+
# Port number and GID table index used for every query and queue pair; both can be overridden from the environment.
DEFAULT_PORT = getenv("DEFAULT_PORT", 1)
DEFAULT_GID = getenv("DEFAULT_GID", 3) # DEFAULT_GID=0 for RXE
# I/O virtual addresses are handed out in units of the host page size.
IOVA_ALIGN = resource.getpagesize()
|
11
|
+
|
12
|
+
def checkz(x, ret=None):
  """
  Check a C-style status code and pass a value through.

  Asserts that `x` equals zero (the failure message carries `x` and the current ctypes errno) and returns `ret`.
  This lets a call that fills an out-parameter be written as `value = checkz(call(..., byref(value)), value)`.
  """
  # spelled-out form of `assert x == 0, msg`: like an assert it is skipped entirely when python runs with -O
  if __debug__ and not (x == 0):
    raise AssertionError(f'{x} != 0 (errno {ctypes.get_errno()})')
  return ret
|
15
|
+
|
16
|
+
@dataclass(frozen=True)
class SGE:
  """One contiguous copy for `IBConn.rdma_write`: `size` bytes read from a local region and written to a remote region."""
  dst_iova: int # remote address the data is written to (passed as the RDMA `remote_addr`)
  dst_key: int # key of the destination memory region (passed as `rkey`)
  src_iova: int # local address the data is read from (passed as the scatter-gather `addr`)
  src_key: int # key of the source memory region (passed as `lkey`)
  size: int # number of bytes to transfer; rdma_write splits it into chunks of at most `max_msg_sz`
|
23
|
+
|
24
|
+
class IBCtx:
|
25
|
+
def __init__(self, idx:int):
|
26
|
+
# Open the device (aka Host Channel Adapter in ib-speak)
|
27
|
+
devs = ib.ibv_get_device_list(ctypes.byref(ndevs:=ctypes.c_int32()))
|
28
|
+
if idx >= ndevs.value: raise IndexError(f"{idx} > {ndevs.value}")
|
29
|
+
self.ctx = ib.ibv_open_device(devs[idx])
|
30
|
+
ib.ibv_free_device_list(devs)
|
31
|
+
|
32
|
+
# HACK: remove this (and all usage of `ctx.contents.ops`) when clang2py can deal with `static inline` wrapper-functions
|
33
|
+
self.vctx = ctypes.cast(ctypes.addressof(self.ctx.contents) - ib.struct_verbs_context.context.offset, ctypes.POINTER(ib.struct_verbs_context))
|
34
|
+
|
35
|
+
# Get attributes. Something like port_attr.max_msg_sz sound like it might requre taking the min of host's and remote's attributes if they differ
|
36
|
+
self.device_attr = checkz(ib.ibv_query_device(self.ctx, ctypes.byref(da:=ib.struct_ibv_device_attr())), da)
|
37
|
+
self.port_attr = checkz(self.vctx.contents.query_port(self.ctx, DEFAULT_PORT, ctypes.byref(pa:=ib.struct_ibv_port_attr()), ctypes.sizeof(pa)), pa)
|
38
|
+
self.gid_attr = checkz(ib.ibv_query_gid(self.ctx, DEFAULT_PORT, DEFAULT_GID, ctypes.byref(ga:=ib.union_ibv_gid())), ga)
|
39
|
+
|
40
|
+
# Allocate protection domain
|
41
|
+
self.pd = ib.ibv_alloc_pd(self.ctx)
|
42
|
+
self.next_iova: int = IOVA_ALIGN # don't start at zero (nullptr)
|
43
|
+
|
44
|
+
# weakref(buf) => (iova, mr, mr_dealloc). mr_dealloc is kept here to avoid double freeing mrs that are deallocated in __del__
|
45
|
+
self.mrs: WeakKeyDictionary[Buffer, tuple[int, ctypes._Pointer[ib.struct_ibv_mr], weakref.finalize]] = WeakKeyDictionary()
|
46
|
+
|
47
|
+
# Default soft fd limit is 1024, which is not enough, set soft to hard (maximum allowed by the os)
|
48
|
+
IBCtx.rlimit_fix()
|
49
|
+
|
50
|
+
def __del__(self):
|
51
|
+
# must deallocate all mrs in protection domain before deallocating the protection domain
|
52
|
+
if hasattr(self, "mrs"): [fin() for _,_,fin in self.mrs.values()]
|
53
|
+
if hasattr(self, "pd"): ib.ibv_dealloc_pd(self.pd)
|
54
|
+
if hasattr(self, "ctx"): ib.ibv_close_device(self.ctx)
|
55
|
+
|
56
|
+
@functools.cache # run once
|
57
|
+
@staticmethod
|
58
|
+
def rlimit_fix():
|
59
|
+
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
60
|
+
resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
|
61
|
+
if DEBUG>=2: print(f"IB: Increased fd limit from {soft} to {hard}")
|
62
|
+
|
63
|
+
def alloc_iova(self, size:int, required_offset:int):
|
64
|
+
iova = round_up(self.next_iova - required_offset, IOVA_ALIGN) + required_offset
|
65
|
+
self.next_iova = iova + size
|
66
|
+
return iova
|
67
|
+
|
68
|
+
def reg(self, buf:Buffer) -> tuple[int, ctypes._Pointer[ib.struct_ibv_mr]]:
|
69
|
+
buf = buf.base
|
70
|
+
if buf not in self.mrs:
|
71
|
+
if buf.nbytes > self.device_attr.max_mr_size: raise RuntimeError(f"Buffer too big: {buf.nbytes:#x} > {self.device_attr.max_mr_size:#x}")
|
72
|
+
if len(self.mrs) >= self.device_attr.max_mr: raise RuntimeError(f"Out of memory region cap: {len(self.mrs)} >= {self.device_attr.max_mr}")
|
73
|
+
# Local read is implied (but still have to create the memory region, except for short sends/writes with IBV_SEND_INLINE that are inlined by cpu)
|
74
|
+
mr_flags = ib.IBV_ACCESS_LOCAL_WRITE | ib.IBV_ACCESS_REMOTE_READ | ib.IBV_ACCESS_REMOTE_WRITE
|
75
|
+
match (dmaref:=buf.as_dmaref()):
|
76
|
+
case DMACPURef():
|
77
|
+
iova = self.alloc_iova(dmaref.size, dmaref.addr % IOVA_ALIGN)
|
78
|
+
mr = ib.ibv_reg_mr_iova2(self.pd, ctypes.c_void_p(dmaref.addr), dmaref.size, iova, mr_flags)
|
79
|
+
case DMAFdRef():
|
80
|
+
iova = self.alloc_iova(dmaref.size, dmaref.offset % IOVA_ALIGN)
|
81
|
+
mr = ib.ibv_reg_dmabuf_mr(self.pd, dmaref.offset, dmaref.size, iova, dmaref.fd, mr_flags)
|
82
|
+
case _: raise RuntimeError(f"Unknown type of dma ref: {dmaref}")
|
83
|
+
if not mr: raise RuntimeError(f"Couldn't register memory region for {buf} {dmaref} (errno={ctypes.get_errno()})")
|
84
|
+
self.mrs[buf] = (iova, mr, weakref.finalize(buf, ib.ibv_dereg_mr, mr))
|
85
|
+
return self.mrs[buf][0:2]
|
86
|
+
|
87
|
+
class IBConn:
  """
  One Reliable Connection queue pair on an `IBCtx`, with its own completion queue and completion channel.

  After construction the queue pair is in the INIT state and `gid` / `qp_num` hold what the remote side needs to know.
  `connect` takes the remote's values and brings the queue pair to Ready To Send; `rdma_write` then copies data to the remote.
  """
  def __init__(self, ctx:IBCtx):
    """
    Create the completion channel, completion queue and queue pair and move the queue pair to INIT.

    Raises:
      RuntimeError: if any of the three objects can't be created.
    """
    self.ctx = ctx
    # ids of posted work requests whose completion hasn't been polled yet
    self.pending_wrids: set[int] = set()
    self.wrid_num: Iterator[int] = itertools.count(0) # wc_id is uint64, this will never overflow

    # Create Completion Channel. It is a file descriptor that kernel sends notifications through, not a thing in infiniband spec, just linux-ism
    self.comp_channel = ib.ibv_create_comp_channel(self.ctx.ctx)
    if not self.comp_channel: raise RuntimeError(f"Couldn't create completion channel (errno={ctypes.get_errno()})")
    # Create Completion Queue. When a Work Request with signaled flag is completed a Completion Queue Entry is pushed onto this queue
    self.cq = ib.ibv_create_cq(self.ctx.ctx, _capacity:=256, _cq_context:=None, self.comp_channel, _comp_vector:=0)
    if not self.cq: raise RuntimeError(f"Couldn't create completion queue (errno={ctypes.get_errno()})")

    # Create Queue Pair. It's the closest thing to a socket in infiniband with QP num being the closest thing to a port, except it's allocated by hca
    qp_init_attrs_cap = ib.struct_ibv_qp_cap(max_send_wr=1024, max_recv_wr=64, max_send_sge=8, max_recv_sge=8, max_inline_data=64)
    qp_init_attrs = ib.struct_ibv_qp_init_attr(send_cq=self.cq, recv_cq=self.cq, cap=qp_init_attrs_cap, qp_type=ib.IBV_QPT_RC) # Reliable Connection
    self.qp = ib.ibv_create_qp(self.ctx.pd, ctypes.byref(qp_init_attrs))
    if not self.qp: raise RuntimeError(f"Couldn't create queue pair (errno={ctypes.get_errno()})")
    # read the capabilities back from the init struct that was handed to ibv_create_qp
    self.qp_cap = qp_init_attrs.cap

    # The most important thing about QPs is their state, when a new QP is created it's in the RESET state, before it can be properly used it has to go
    # through Init, Ready To Receive, Ready To Send. A good doc on the QP state machine: https://www.rdmamojo.com/2012/05/05/qp-state-machine/

    # INIT
    qp_access_flags = ib.IBV_ACCESS_REMOTE_WRITE | ib.IBV_ACCESS_REMOTE_READ
    qpa = ib.struct_ibv_qp_attr(qp_state=ib.IBV_QPS_INIT, port_num=DEFAULT_PORT, qp_access_flags=qp_access_flags)
    checkz(ib.ibv_modify_qp(self.qp, qpa, ib.IBV_QP_STATE | ib.IBV_QP_PORT | ib.IBV_QP_ACCESS_FLAGS | ib.IBV_QP_PKEY_INDEX))

    self.gid, self.qp_num = bytes(self.ctx.gid_attr.raw), self.qp.contents.qp_num

    # Exchange GID and QP num with remote. At least in RoCEv2 gid can be guessed from remote's ip, QP num can't.

  def connect(self, remote_gid:bytes, remote_qp_num:int):
    """
    Connect to the remote queue pair identified by `remote_gid` (16 raw bytes) and `remote_qp_num`.

    Moves the queue pair through Ready To Receive to Ready To Send. Both packet sequence numbers start at 0.
    """
    # RTR
    qp_ah_attr_grh = ib.struct_ibv_global_route(hop_limit=1, dgid=ib.union_ibv_gid(raw=(ctypes.c_ubyte * 16)(*remote_gid)), sgid_index=DEFAULT_GID)
    qp_ah_attr = ib.struct_ibv_ah_attr(is_global=1, port_num=DEFAULT_PORT, grh=qp_ah_attr_grh)
    qpa = ib.struct_ibv_qp_attr(qp_state=ib.IBV_QPS_RTR, path_mtu=ib.IBV_MTU_4096, dest_qp_num=remote_qp_num, rq_psn=0, max_dest_rd_atomic=1,
                                min_rnr_timer=12, ah_attr=qp_ah_attr)
    checkz(ib.ibv_modify_qp(self.qp, qpa, ib.IBV_QP_STATE | ib.IBV_QP_PATH_MTU | ib.IBV_QP_DEST_QPN | ib.IBV_QP_RQ_PSN | \
                                          ib.IBV_QP_MAX_DEST_RD_ATOMIC | ib.IBV_QP_MIN_RNR_TIMER | ib.IBV_QP_AV))

    # RTS
    qpa = ib.struct_ibv_qp_attr(qp_state=ib.IBV_QPS_RTS, timeout=14, retry_cnt=7, rnr_retry=7, sq_psn=0, max_rd_atomic=1)
    checkz(ib.ibv_modify_qp(self.qp, qpa, ib.IBV_QP_STATE | ib.IBV_QP_TIMEOUT | ib.IBV_QP_RETRY_CNT | ib.IBV_QP_RNR_RETRY | ib.IBV_QP_SQ_PSN | \
                                          ib.IBV_QP_MAX_QP_RD_ATOMIC))

  def __del__(self):
    """Wait for outstanding work, then destroy queue pair, completion queue and channel. Tolerates a partially constructed object."""
    try:
      # need to wait for **everything** to complete before it's safe to dealloc queues and stuff
      if getattr(self, "qp", None): self.wait_cq()
    finally:
      # __init__ may have failed half-way: only destroy what was actually created
      if getattr(self, "qp", None): ib.ibv_destroy_qp(self.qp)
      if getattr(self, "cq", None): ib.ibv_destroy_cq(self.cq)
      if getattr(self, "comp_channel", None): ib.ibv_destroy_comp_channel(self.comp_channel)

  def next_wrid(self):
    """Allocate a new work request id, mark it as pending and return it. The caller is expected to post a signaled request with this id."""
    self.pending_wrids.add(wrid:=next(self.wrid_num))
    return wrid

  def wait_cq(self, wr_id: int|None=None):
    """
    Busy-poll the completion queue until `wr_id` has completed, or until nothing is pending if `wr_id` is None.

    Raises:
      RuntimeError: if polling fails or a work request completes with an error status.
    """
    while (wr_id in self.pending_wrids) if wr_id is not None else self.pending_wrids:
      polled = self.ctx.ctx.contents.ops.poll_cq(self.cq, _num_entries:=1, ctypes.byref(wc:=ib.struct_ibv_wc()))
      # a negative result means the poll itself failed and `wc` wasn't filled in
      if polled < 0: raise RuntimeError(f'Polling the completion queue failed: {polled}')
      if polled == 0: continue
      if wc.status != ib.IBV_WC_SUCCESS:
        # a failed request has completed too. Leaving it pending would make a later wait_cq() (e.g. from __del__) spin forever
        self.pending_wrids.discard(wc.wr_id)
        raise RuntimeError(f'Work Request completed with error: wr_id={wc.wr_id} status={ib.ibv_wc_status__enumvalues.get(wc.status, wc.status)}')
      self.pending_wrids.remove(wc.wr_id)

  def rdma_write(self, sgl:list[SGE]):
    """
    Write every entry of `sgl` to the remote side and return once all of it has completed.

    Entries larger than the port's `max_msg_sz` are split into several work requests. Work requests are chained and posted in batches
    of at most `max_send_wr`. Only the last request of a batch is signaled and each batch is waited for before the next one is built.
    An empty `sgl` posts nothing.
    """
    swr: ctypes._Pointer[ib.struct_ibv_send_wr]|None = None
    swr_cnt, wr_id = 0, -1
    def _post():
      nonlocal swr, swr_cnt
      if swr is None: return
      # The swr can be freed when this returns, the memory that sge points to can be unmapped after work completion is retrieved from cq
      checkz(self.ctx.ctx.contents.ops.post_send(self.qp, swr, ctypes.byref(_bad_wr:=ctypes.POINTER(ib.struct_ibv_send_wr)())))
      # The id becomes pending only once its signaled request is really queued. An id that is reserved but never posted would never
      # complete and would block every later wait_cq() without arguments
      self.pending_wrids.add(wr_id)
      # TODO: async
      self.wait_cq(wr_id)
      swr, swr_cnt = None, 0
    # Everything is in reverse for elegant chaining
    for sg in reversed(sgl):
      # Message size limit (max 2GB per ib spec, 1GB on tinybox mellanoxes) applies to both scatter-gather entries and entire wrs
      for off in reversed(range(0, sg.size, self.ctx.port_attr.max_msg_sz)):
        # Scatter-Gather Entry for local memory
        sge = ctypes.pointer(ib.struct_ibv_sge(addr=sg.src_iova+off, length=min(sg.size-off, self.ctx.port_attr.max_msg_sz), lkey=sg.src_key))
        # RDMA struct for remote memory
        wr = ib.union_ibv_send_wr_wr(rdma=ib.struct_ibv_send_wr_1_rdma(remote_addr=sg.dst_iova+off, rkey=sg.dst_key))
        # Signal if it's the last wr of the batch (first in the loop since it's reversed), this is also where the batch gets its id
        if swr is None: wr_id, flags = next(self.wrid_num), ib.IBV_SEND_SIGNALED
        else: flags = 0
        # Create Send Request. Unsignaled requests carry the id of their batch too, so that an error completion can be attributed to it
        swr = ctypes.pointer(ib.struct_ibv_send_wr(opcode=ib.IBV_WR_RDMA_WRITE, sg_list=sge, num_sge=1, wr=wr, wr_id=wr_id, send_flags=flags, next=swr))
        # Flush if queue is being overrun
        if (swr_cnt:=swr_cnt + 1) >= self.qp_cap.max_send_wr: _post()
    _post()
|
tinygrad/runtime/support/llvm.py
CHANGED
@@ -9,15 +9,14 @@ if sys.platform == 'win32':
|
|
9
9
|
raise FileNotFoundError('LLVM not found, you can install it with `winget install LLVM.LLVM` or point at a custom dll with LLVM_PATH')
|
10
10
|
elif OSX:
|
11
11
|
# Will raise FileNotFoundError if brew is not installed
|
12
|
-
brew_prefix = subprocess.check_output(['brew', '--prefix', 'llvm']).decode().strip()
|
13
12
|
# `brew --prefix` will return even if formula is not installed
|
14
|
-
if not os.path.exists(brew_prefix):
|
15
|
-
raise FileNotFoundError('LLVM not found, you can install it with `brew install llvm`')
|
13
|
+
if not os.path.exists(brew_prefix:=subprocess.check_output(['brew', '--prefix', 'llvm@20']).decode().strip()):
|
14
|
+
raise FileNotFoundError('LLVM not found, you can install it with `brew install llvm@20`')
|
16
15
|
LLVM_PATH: str|None = os.path.join(brew_prefix, 'lib', 'libLLVM.dylib')
|
17
16
|
else:
|
18
17
|
LLVM_PATH = ctypes.util.find_library('LLVM')
|
19
18
|
# use newer LLVM if possible
|
20
|
-
for ver in reversed(range(14,
|
19
|
+
for ver in reversed(range(14, 20+1)):
|
21
20
|
if LLVM_PATH is not None: break
|
22
21
|
LLVM_PATH = ctypes.util.find_library(f'LLVM-{ver}')
|
23
22
|
if LLVM_PATH is None:
|
@@ -0,0 +1,251 @@
|
|
1
|
+
import collections, functools, dataclasses
|
2
|
+
from typing import Any, ClassVar
|
3
|
+
from tinygrad.helpers import round_up, getenv
|
4
|
+
|
5
|
+
class TLSFAllocator:
|
6
|
+
"""
|
7
|
+
The allocator is based on the Two-Level Segregated Fit (TLSF) algorithm. The allocator maintains 2 level of buckets:
|
8
|
+
* 1st level is determined by the most significant bit of the size.
|
9
|
+
* 2nd level splits the covered memory of 1st level into @lv2_cnt entries.
|
10
|
+
|
11
|
+
For each allocation request, the allocator searches for the smallest block that can fit the requested size.
|
12
|
+
For each deallocation request, the allocator merges the block with its neighbors if they are free.
|
13
|
+
"""
|
14
|
+
|
15
|
+
def __init__(self, size:int, base:int=0, block_size:int=16, lv2_cnt:int=16):
|
16
|
+
self.size, self.base, self.block_size, self.l2_cnt = size, base, block_size, lv2_cnt.bit_length()
|
17
|
+
self.storage:list = [collections.defaultdict(list) for _ in range(size.bit_length() + 1)]
|
18
|
+
self.lv1_entries:list[int] = [0] * len(self.storage)
|
19
|
+
|
20
|
+
# self.blocks is more like a linked list, where each entry is a contiguous block.
|
21
|
+
self.blocks:dict[int, tuple[int, int|None, int|None, bool]] = {0: (size, None, None, True)} # size, next, prev, is_free
|
22
|
+
self._insert_block(0, size)
|
23
|
+
|
24
|
+
@functools.cache
|
25
|
+
def lv1(self, size): return size.bit_length()
|
26
|
+
|
27
|
+
@functools.cache
|
28
|
+
def lv2(self, size): return (size - (1 << (size.bit_length() - 1))) // (1 << max(0, size.bit_length() - self.l2_cnt))
|
29
|
+
|
30
|
+
def _insert_block(self, start:int, size:int, prev:int|None=None):
|
31
|
+
if prev is None: prev = self.blocks[start][2]
|
32
|
+
self.storage[self.lv1(size)][self.lv2(size)].append(start)
|
33
|
+
self.lv1_entries[self.lv1(size)] += 1
|
34
|
+
self.blocks[start] = (size, start + size, prev, True)
|
35
|
+
return self
|
36
|
+
|
37
|
+
def _remove_block(self, start:int, size:int, prev:int|None=None):
|
38
|
+
if prev is None: prev = self.blocks[start][2]
|
39
|
+
self.storage[self.lv1(size)][self.lv2(size)].remove(start)
|
40
|
+
self.lv1_entries[self.lv1(size)] -= 1
|
41
|
+
self.blocks[start] = (size, start + size, prev, False)
|
42
|
+
return self
|
43
|
+
|
44
|
+
def _split_block(self, start:int, size:int, new_size:int):
|
45
|
+
nxt = self.blocks[start][1]
|
46
|
+
assert self.blocks[start][3], "block must be free"
|
47
|
+
self._remove_block(start, size)._insert_block(start, new_size)._insert_block(start + new_size, size - new_size, prev=start)
|
48
|
+
if nxt in self.blocks: self.blocks[nxt] = (self.blocks[nxt][0], self.blocks[nxt][1], start + new_size, self.blocks[nxt][3])
|
49
|
+
return self
|
50
|
+
|
51
|
+
def _merge_right(self, start:int):
|
52
|
+
size, nxt, _, is_free = self.blocks[start]
|
53
|
+
assert is_free, "block must be free"
|
54
|
+
|
55
|
+
while is_free and nxt in self.blocks:
|
56
|
+
if (blk:=self.blocks[nxt])[3] is False: break
|
57
|
+
self._remove_block(start, size)._remove_block(nxt, blk[0])._insert_block(start, size:=size + blk[0])
|
58
|
+
assert self.blocks[start][1] == blk[1]
|
59
|
+
_, nxt, _, _ = self.blocks.pop(nxt)
|
60
|
+
|
61
|
+
if nxt in self.blocks: self.blocks[nxt] = (self.blocks[nxt][0], self.blocks[nxt][1], start, self.blocks[nxt][3])
|
62
|
+
|
63
|
+
def _merge_block(self, start:int):
|
64
|
+
# Go left while blocks are free. Then merge all them right.
|
65
|
+
while (x:=self.blocks[start][2]) is not None and self.blocks[x][3] is True: start = x
|
66
|
+
self._merge_right(start)
|
67
|
+
|
68
|
+
def alloc(self, req_size:int, align:int=1) -> int:
|
69
|
+
req_size = max(self.block_size, req_size) # at least block size.
|
70
|
+
size = max(self.block_size, req_size + align - 1)
|
71
|
+
|
72
|
+
# Round up the allocation size to the next bucket, so any entry there can fit the requested size.
|
73
|
+
size = round_up(size, (1 << size.bit_length() - self.l2_cnt))
|
74
|
+
|
75
|
+
# Search for the smallest block that can fit the requested size. Start with the it's bucket and go up until any block is found.
|
76
|
+
for l1 in range(self.lv1(size), len(self.storage)):
|
77
|
+
if self.lv1_entries[l1] == 0: continue
|
78
|
+
for l2 in range(self.lv2(size) if l1 == size.bit_length() else 0, (1 << self.l2_cnt)):
|
79
|
+
if len(self.storage[l1][l2]) > 0:
|
80
|
+
nsize = self.blocks[self.storage[l1][l2][0]][0]
|
81
|
+
assert nsize >= size, "block must be larger"
|
82
|
+
|
83
|
+
# Block start address.
|
84
|
+
start = self.storage[l1][l2][0]
|
85
|
+
|
86
|
+
# If request contains alignment, split the block into two parts.
|
87
|
+
if (new_start:=round_up(start, align)) != start:
|
88
|
+
self._split_block(start, nsize, new_start - start)
|
89
|
+
start, nsize = new_start, self.blocks[new_start][0]
|
90
|
+
|
91
|
+
# If the block is larger than the requested size, split it into two parts.
|
92
|
+
if nsize > req_size: self._split_block(start, nsize, req_size)
|
93
|
+
self._remove_block(start, req_size) # Mark the block as allocated.
|
94
|
+
return start + self.base
|
95
|
+
raise MemoryError(f"Can't allocate {req_size} bytes")
|
96
|
+
|
97
|
+
def free(self, start:int):
|
98
|
+
self._insert_block(start - self.base, self.blocks[start - self.base][0])._merge_block(start - self.base)
|
99
|
+
|
100
|
+
# Memory Management
|
101
|
+
|
102
|
+
@dataclasses.dataclass(frozen=True)
class VirtMapping:
  """Immutable description of one mapped virtual address range, as returned by `MemoryManager.map_range`."""
  va_addr: int # first virtual address of the range
  size: int # length of the range in bytes
  paddrs: list[tuple[int, int]] # backing physical segments as (paddr, size) pairs
  # flags the range was mapped with
  uncached: bool = False
  system: bool = False
  snooped: bool = False
|
104
|
+
|
105
|
+
class PageTableTraverseContext:
  """
  Cursor that walks a multi-level page table over a virtual address range.

  `pt_stack` holds one `(page_table, pte_idx, pte_covers)` tuple per level the cursor has descended into, starting table first.
  `vaddr` is the current position, stored relative to `dev.mm.va_base`.
  With `create_pts` missing page tables are allocated on the way down (from boot memory if `boot` is set),
  with `free_pts` page tables that have no valid entry left are released on the way up.
  """
  def __init__(self, dev, pt, vaddr, create_pts=False, free_pts=False, boot=False):
    self.dev, self.vaddr, self.create_pts, self.free_pts, self.boot = dev, vaddr - dev.mm.va_base, create_pts, free_pts, boot
    # start in the given table, positioned on the entry that covers vaddr
    self.pt_stack:list[tuple[Any, int, int]] = [(pt, self._pt_pte_idx(pt, self.vaddr), self._pt_pte_size(pt))]

  # Geometry lookups, all taken from the tables of the device's memory manager:
  # the number of entries of a table at level lv, the bytes covered by one entry of pt, and the index of the entry of pt that covers va.
  def _pt_pte_cnt(self, lv): return self.dev.mm.pte_cnt[lv]
  def _pt_pte_size(self, pt): return self.dev.mm.pte_covers[pt.lv]
  def _pt_pte_idx(self, pt, va): return (va // self._pt_pte_size(pt)) % self._pt_pte_cnt(pt.lv)

  def level_down(self):
    """
    Descend into the table referenced by the current entry, push it onto the stack and return the new top of the stack.

    If the current entry is not valid a new 0x1000 byte table is allocated for it, which is only allowed with `create_pts`.
    """
    pt, pte_idx, _ = self.pt_stack[-1]

    if not pt.valid(pte_idx):
      assert self.create_pts, "Not allowed to create new page table"
      pt.set_entry(pte_idx, self.dev.mm.palloc(0x1000, zero=True, boot=self.boot), table=True, valid=True)

    # an entry flagged as huge page can't be descended into
    assert not pt.is_huge_page(pte_idx), f"Must be table pt={pt.paddr:#x}, {pt.lv=} {pte_idx=} {pt.read_fields(pte_idx)}"
    child_page_table = self.dev.mm.pt_t(self.dev, pt.address(pte_idx), lv=pt.lv+1)

    self.pt_stack.append((child_page_table, self._pt_pte_idx(child_page_table, self.vaddr), self._pt_pte_size(child_page_table)))
    return self.pt_stack[-1]

  def _try_free_pt(self) -> bool:
    """
    Release the table on top of the stack if `free_pts` is set, it's not the root table and none of its entries is valid.

    The entry in the parent table is invalidated as well. Returns whether the table was released, the stack itself is not modified.
    """
    pt, _, _ = self.pt_stack[-1]
    if self.free_pts and pt != self.dev.mm.root_page_table and all(not pt.valid(i) for i in range(self._pt_pte_cnt(self.pt_stack[-1][0].lv))):
      self.dev.mm.pfree(pt.paddr)
      parent_pt, parent_pte_idx, _ = self.pt_stack[-2]
      parent_pt.set_entry(parent_pte_idx, 0x0, valid=False)
      return True
    return False

  def level_up(self):
    """Pop tables from the stack while the top one was just released or the cursor has moved past its last entry."""
    while self._try_free_pt() or self.pt_stack[-1][1] == self._pt_pte_cnt(self.pt_stack[-1][0].lv):
      pt, pt_cnt, _ = self.pt_stack.pop()
      # the popped table was walked to its end: continue with the next entry of the parent
      if pt_cnt == self._pt_pte_cnt(pt.lv): self.pt_stack[-1] = (self.pt_stack[-1][0], self.pt_stack[-1][1] + 1, self.pt_stack[-1][2])

  def next(self, size:int, paddr:int|None=None, off:int=0):
    """
    Generator that advances the cursor by `size` bytes and yields `(off, pt, pte_idx, entries, pte_covers)` for every run of entries.

    A run is `entries` consecutive entries of table `pt` starting at `pte_idx`, each covering `pte_covers` bytes.
    `off` is the byte offset of the run, counted from the `off` argument. The cursor moves past a run only when the generator
    is resumed, so the caller sees the position of the run that was just yielded.

    With `create_pts`, `paddr` is required and the walk descends until one entry fits into the remaining size, the table supports
    a huge page at `paddr+off` and the current address is aligned to the entry size. Otherwise it descends until the current entry
    is reported as a huge page by the table.
    """
    while size > 0:
      pt, pte_idx, pte_covers = self.pt_stack[-1]
      if self.create_pts:
        assert paddr is not None, "paddr must be provided when allocating new page tables"
        while pte_covers > size or not pt.supports_huge_page(paddr+off) or self.vaddr&(pte_covers-1) != 0: pt, pte_idx, pte_covers = self.level_down()
      else:
        while not pt.is_huge_page(pte_idx): pt, pte_idx, pte_covers = self.level_down()

      # as many whole entries as fit into the remaining size without running past the end of this table
      entries = min(size // pte_covers, self._pt_pte_cnt(pt.lv) - pte_idx)
      assert entries > 0, f"Invalid entries {size=:#x}, {pte_covers=:#x}"
      yield off, pt, pte_idx, entries, pte_covers

      # move past the run that was just handed out, then leave all tables that are finished (or empty when freeing)
      size, off, self.vaddr = size - entries * pte_covers, off + entries * pte_covers, self.vaddr + entries * pte_covers
      self.pt_stack[-1] = (pt, pte_idx + entries, pte_covers)
      self.level_up()
|
157
|
+
|
158
|
+
class MemoryManager:
  """
  Per-device manager of physical memory and of the page tables that map virtual addresses onto it.

  Physical memory is split into a boot region `[0, boot_size)` and a general region that starts right behind it, each with its own
  allocator. Virtual addresses come from `va_allocator`, which is a class attribute that has to be set before
  `alloc_vaddr` / `valloc` / `vfree` are used.
  """
  va_allocator: ClassVar[TLSFAllocator|None] = None

  def __init__(self, dev, vram_size:int, boot_size:int, pt_t, va_bits:int, va_shifts:list[int], va_base:int,
               palloc_ranges:list[tuple[int, int]], first_lv:int=0):
    """
    Set up the allocators and allocate the root page table (of type `pt_t`, at level `first_lv`) from boot memory.

    `va_shifts` holds one bit shift per page table level, `palloc_ranges` holds the `(size, align)` pairs that `valloc` tries in order.
    """
    self.dev, self.vram_size, self.va_shifts, self.va_base = dev, vram_size, va_shifts, va_base
    self.pt_t, self.palloc_ranges, self.level_cnt, self.va_bits = pt_t, palloc_ranges, len(va_shifts), va_bits

    # Per level: the bytes covered by one entry and the number of entries, both stored in the reverse order of va_shifts
    lvl_msb = va_shifts + [va_bits + 1]
    self.pte_covers = [1 << x for x in va_shifts][::-1]
    self.pte_cnt = [1 << (lvl_msb[i+1] - lvl_msb[i]) for i in range(len(lvl_msb) - 1)][::-1]

    self.boot_allocator = TLSFAllocator(boot_size, base=0) # per device
    self.pa_allocator = TLSFAllocator(vram_size - (64 << 20), base=self.boot_allocator.size) # per device
    self.root_page_table = pt_t(self.dev, self.palloc(0x1000, zero=not self.dev.smi_dev, boot=True), lv=first_lv)

  def _frag_size(self, va, sz, must_cover=True):
    """
    Calculate the tlb fragment size for a given virtual address and size.
    If must_cover is True, the fragment size must cover the size, otherwise the biggest fragment size that fits the size is returned.
    Fragment 0 is 4KB, 1 is 8KB and so on.
    """
    # largest power of two that divides va (va == 0 is divisible by everything), that divides sz, and that is not bigger than sz
    va_pwr2_div, sz_pwr2_div, sz_pwr2_max = va & -(va) if va > 0 else (1 << 63), sz & -(sz), (1 << (sz.bit_length() - 1))
    return (min(va_pwr2_div, sz_pwr2_div) if must_cover else min(va_pwr2_div, sz_pwr2_max)).bit_length() - 1 - 12

  def page_tables(self, vaddr:int, size:int):
    """Return the page tables on the path from the root to the first entry run of the range, creating missing ones. None if `size` is not positive."""
    ctx = PageTableTraverseContext(self.dev, self.root_page_table, vaddr, create_pts=True)
    for _ in ctx.next(size, paddr=0): return [pt for pt, _, _ in ctx.pt_stack]

  def map_range(self, vaddr:int, size:int, paddrs:list[tuple[int, int]], uncached=False, system=False, snooped=False, boot=False) -> VirtMapping:
    """
    Map `size` bytes at `vaddr` onto the physical segments `paddrs` (pairs of `(paddr, size)` whose sizes add up to `size`).

    Missing page tables are created (from boot memory if `boot` is set). Mapping over an entry that is already valid is an error.
    `on_range_mapped` is called once all entries are written.
    """
    if getenv("MM_DEBUG", 0): print(f"mm {self.dev.devfmt}: mapping {vaddr=:#x} ({size=:#x})")

    assert size == sum(p[1] for p in paddrs), f"Size mismatch {size=} {sum(p[1] for p in paddrs)=}"

    ctx = PageTableTraverseContext(self.dev, self.root_page_table, vaddr, create_pts=True, boot=boot)
    for paddr, psize in paddrs:
      for off, pt, pte_idx, pte_cnt, pte_covers in ctx.next(psize, paddr=paddr):
        for pte_off in range(pte_cnt):
          assert not pt.valid(pte_idx + pte_off), f"PTE already mapped: {pt.entry(pte_idx + pte_off):#x}"
          pt.set_entry(pte_idx + pte_off, paddr + off + pte_off * pte_covers, uncached=uncached, system=system, snooped=snooped,
                       frag=self._frag_size(ctx.vaddr+off, pte_cnt * pte_covers), valid=True)

    self.on_range_mapped()
    return VirtMapping(vaddr, size, paddrs, uncached=uncached, system=system, snooped=snooped)

  def unmap_range(self, vaddr:int, size:int):
    """Invalidate all entries that map `size` bytes at `vaddr` and release page tables that become empty. Every entry must be mapped."""
    if getenv("MM_DEBUG", 0): print(f"mm {self.dev.devfmt}: unmapping {vaddr=:#x} ({size=:#x})")

    ctx = PageTableTraverseContext(self.dev, self.root_page_table, vaddr, free_pts=True)
    for off, pt, pte_idx, pte_cnt, pte_covers in ctx.next(size):
      for pte_id in range(pte_idx, pte_idx + pte_cnt):
        assert pt.valid(pte_id), f"PTE not mapped: {pt.entry(pte_id):#x}"
        pt.set_entry(pte_id, paddr=0x0, valid=False)

  def on_range_mapped(self):
    """Hook that is called at the end of `map_range`. Does nothing here."""
    pass

  @classmethod
  def alloc_vaddr(cls, size:int, align=0x1000) -> int:
    """Reserve `size` bytes of virtual address space, aligned to `align` or to the largest power of two not above `size`, whichever is bigger."""
    assert cls.va_allocator is not None, "must be set it"
    return cls.va_allocator.alloc(size, max((1 << (size.bit_length() - 1)), align))

  def valloc(self, size:int, align=0x1000, uncached=False, contiguous=False) -> VirtMapping:
    """
    Allocate `size` bytes (rounded up to 0x1000) of physical memory, map them to a fresh virtual address and return the mapping.

    With `contiguous` the memory is one zeroed physical segment. Otherwise it is assembled from segments of the sizes listed in
    `palloc_ranges`, tried in order, and not zeroed.

    Raises:
      MemoryError: if the physical memory can't be allocated. Nothing stays allocated in this case.
    """
    # Alloc physical memory and map it to the virtual address
    va = self.alloc_vaddr(size:=round_up(size, 0x1000), align)

    paddrs:list[tuple[int, int]] = []
    try:
      if contiguous: paddrs = [(self.palloc(size, zero=True), size)]
      else:
        # Traverse the PT to find the largest contiguous sizes we need to allocate. Try to allocate the longest segment to reduce TLB pressure.
        nxt_range, rem_size = 0, size
        while rem_size > 0:
          while self.palloc_ranges[nxt_range][0] > rem_size: nxt_range += 1

          try: paddrs += [(self.palloc(try_sz:=self.palloc_ranges[nxt_range][0], self.palloc_ranges[nxt_range][1], zero=False), try_sz)]
          except MemoryError:
            # Move to a smaller size and try again.
            nxt_range += 1
            if nxt_range == len(self.palloc_ranges):
              raise MemoryError(f"Failed to allocate memory. (total allocation size={size:#x}, current try={self.palloc_ranges[nxt_range-1]})")
            continue
          rem_size -= try_sz
    except MemoryError:
      # Give back what was taken so far. This includes the virtual range, which would otherwise stay reserved forever.
      for paddr, _ in paddrs: self.pa_allocator.free(paddr)
      assert self.va_allocator is not None, "must be set it"
      self.va_allocator.free(va)
      raise

    return self.map_range(va, size, paddrs, uncached=uncached)

  def vfree(self, vm:VirtMapping):
    """Unmap `vm` and release its virtual range and its physical segments. Counterpart of `valloc`."""
    assert self.va_allocator is not None, "must be set it"
    self.unmap_range(vm.va_addr, vm.size)
    self.va_allocator.free(vm.va_addr)
    for paddr, _ in vm.paddrs: self.pa_allocator.free(paddr)

  def palloc(self, size:int, align:int=0x1000, zero=True, boot=False) -> int:
    """
    Allocate `size` bytes (rounded up to 0x1000) of physical memory and return the address. With `zero`, the first `size` bytes are cleared.

    `boot` selects the boot region and has to match the device's `is_booting` state.

    Raises:
      MemoryError: if the selected region has no free block that is large enough.
    """
    assert self.dev.is_booting == boot, "During booting, only boot memory can be allocated"
    paddr = (self.boot_allocator if boot else self.pa_allocator).alloc(round_up(size, 0x1000), align)
    if zero: self.dev.vram[paddr:paddr+size] = bytes(size)
    return paddr

  def pfree(self, paddr:int):
    """Release physical memory that was returned by `palloc`, from whichever of the two regions it came from."""
    # The boot allocator hands out [0, boot size) and the general allocator starts right behind it, so the address identifies the owner.
    # Sending a boot address (e.g. a page table that was created during boot) to the general allocator fails with a KeyError.
    (self.boot_allocator if paddr < self.boot_allocator.size else self.pa_allocator).free(paddr)
|
File without changes
|