tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. tinygrad/__init__.py +1 -1
  2. tinygrad/apps/llm.py +206 -0
  3. tinygrad/codegen/__init__.py +116 -0
  4. tinygrad/codegen/devectorizer.py +315 -172
  5. tinygrad/codegen/expander.py +8 -16
  6. tinygrad/codegen/gpudims.py +89 -0
  7. tinygrad/codegen/linearize.py +205 -203
  8. tinygrad/codegen/lowerer.py +92 -139
  9. tinygrad/codegen/opt/__init__.py +38 -0
  10. tinygrad/codegen/opt/heuristic.py +125 -0
  11. tinygrad/codegen/opt/kernel.py +510 -0
  12. tinygrad/{engine → codegen/opt}/search.py +51 -35
  13. tinygrad/codegen/opt/swizzler.py +134 -0
  14. tinygrad/codegen/opt/tc.py +127 -0
  15. tinygrad/codegen/quantize.py +67 -0
  16. tinygrad/device.py +122 -132
  17. tinygrad/dtype.py +152 -35
  18. tinygrad/engine/jit.py +81 -54
  19. tinygrad/engine/memory.py +46 -27
  20. tinygrad/engine/realize.py +82 -41
  21. tinygrad/engine/schedule.py +70 -445
  22. tinygrad/frontend/__init__.py +0 -0
  23. tinygrad/frontend/onnx.py +1253 -0
  24. tinygrad/frontend/torch.py +5 -0
  25. tinygrad/gradient.py +19 -27
  26. tinygrad/helpers.py +95 -47
  27. tinygrad/nn/__init__.py +7 -8
  28. tinygrad/nn/optim.py +72 -41
  29. tinygrad/nn/state.py +37 -23
  30. tinygrad/renderer/__init__.py +40 -60
  31. tinygrad/renderer/cstyle.py +143 -128
  32. tinygrad/renderer/llvmir.py +113 -62
  33. tinygrad/renderer/ptx.py +50 -32
  34. tinygrad/renderer/wgsl.py +27 -23
  35. tinygrad/runtime/autogen/am/am.py +5861 -0
  36. tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
  37. tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
  38. tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
  39. tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
  40. tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
  41. tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
  42. tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
  43. tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
  44. tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
  45. tinygrad/runtime/autogen/comgr.py +35 -9
  46. tinygrad/runtime/autogen/comgr_3.py +906 -0
  47. tinygrad/runtime/autogen/cuda.py +2419 -494
  48. tinygrad/runtime/autogen/hsa.py +57 -16
  49. tinygrad/runtime/autogen/ib.py +7171 -0
  50. tinygrad/runtime/autogen/io_uring.py +917 -118
  51. tinygrad/runtime/autogen/kfd.py +748 -26
  52. tinygrad/runtime/autogen/libc.py +613 -218
  53. tinygrad/runtime/autogen/libusb.py +1643 -0
  54. tinygrad/runtime/autogen/nv/nv.py +8602 -0
  55. tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
  56. tinygrad/runtime/autogen/opencl.py +2 -4
  57. tinygrad/runtime/autogen/sqtt.py +1789 -0
  58. tinygrad/runtime/autogen/vfio.py +3 -3
  59. tinygrad/runtime/autogen/webgpu.py +273 -264
  60. tinygrad/runtime/graph/cuda.py +3 -3
  61. tinygrad/runtime/graph/hcq.py +68 -29
  62. tinygrad/runtime/graph/metal.py +29 -13
  63. tinygrad/runtime/graph/remote.py +114 -0
  64. tinygrad/runtime/ops_amd.py +537 -320
  65. tinygrad/runtime/ops_cpu.py +108 -7
  66. tinygrad/runtime/ops_cuda.py +12 -14
  67. tinygrad/runtime/ops_disk.py +13 -10
  68. tinygrad/runtime/ops_dsp.py +47 -40
  69. tinygrad/runtime/ops_gpu.py +13 -11
  70. tinygrad/runtime/ops_hip.py +6 -9
  71. tinygrad/runtime/ops_llvm.py +35 -15
  72. tinygrad/runtime/ops_metal.py +29 -19
  73. tinygrad/runtime/ops_npy.py +5 -3
  74. tinygrad/runtime/ops_null.py +28 -0
  75. tinygrad/runtime/ops_nv.py +306 -234
  76. tinygrad/runtime/ops_python.py +62 -52
  77. tinygrad/runtime/ops_qcom.py +28 -39
  78. tinygrad/runtime/ops_remote.py +482 -0
  79. tinygrad/runtime/ops_webgpu.py +28 -28
  80. tinygrad/runtime/support/am/amdev.py +114 -249
  81. tinygrad/runtime/support/am/ip.py +211 -172
  82. tinygrad/runtime/support/amd.py +138 -0
  83. tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
  84. tinygrad/runtime/support/compiler_cuda.py +8 -11
  85. tinygrad/runtime/support/elf.py +2 -1
  86. tinygrad/runtime/support/hcq.py +184 -97
  87. tinygrad/runtime/support/ib.py +172 -0
  88. tinygrad/runtime/support/llvm.py +3 -4
  89. tinygrad/runtime/support/memory.py +251 -0
  90. tinygrad/runtime/support/nv/__init__.py +0 -0
  91. tinygrad/runtime/support/nv/ip.py +581 -0
  92. tinygrad/runtime/support/nv/nvdev.py +183 -0
  93. tinygrad/runtime/support/system.py +170 -0
  94. tinygrad/runtime/support/usb.py +268 -0
  95. tinygrad/runtime/support/webgpu.py +18 -0
  96. tinygrad/schedule/__init__.py +0 -0
  97. tinygrad/schedule/grouper.py +119 -0
  98. tinygrad/schedule/kernelize.py +368 -0
  99. tinygrad/schedule/multi.py +231 -0
  100. tinygrad/shape/shapetracker.py +40 -46
  101. tinygrad/shape/view.py +88 -52
  102. tinygrad/tensor.py +968 -542
  103. tinygrad/uop/__init__.py +117 -0
  104. tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
  105. tinygrad/uop/mathtraits.py +169 -0
  106. tinygrad/uop/ops.py +1021 -0
  107. tinygrad/uop/spec.py +228 -0
  108. tinygrad/{codegen → uop}/symbolic.py +239 -216
  109. tinygrad/uop/upat.py +163 -0
  110. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
  111. tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
  112. tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
  113. tinygrad/viz/index.html +203 -403
  114. tinygrad/viz/js/index.js +718 -0
  115. tinygrad/viz/js/worker.js +29 -0
  116. tinygrad/viz/serve.py +224 -102
  117. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
  118. tinygrad-0.11.0.dist-info/RECORD +141 -0
  119. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
  120. tinygrad/codegen/kernel.py +0 -693
  121. tinygrad/engine/multi.py +0 -161
  122. tinygrad/ops.py +0 -1003
  123. tinygrad/runtime/ops_cloud.py +0 -220
  124. tinygrad/runtime/support/allocator.py +0 -94
  125. tinygrad/spec.py +0 -155
  126. tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
  127. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
  128. tinygrad/viz/perfetto.html +0 -178
  129. tinygrad-0.10.2.dist-info/RECORD +0 -99
  130. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
  131. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
tinygrad/runtime/support/ib.py (new file)
@@ -0,0 +1,172 @@
+ from __future__ import annotations
+ import resource, ctypes, weakref, functools, itertools, tinygrad.runtime.autogen.ib as ib
+ from typing import Iterator
+ from dataclasses import dataclass
+ from weakref import WeakKeyDictionary
+ from tinygrad.device import Buffer, DMACPURef, DMAFdRef
+ from tinygrad.helpers import getenv, round_up, DEBUG
+
+ DEFAULT_PORT, DEFAULT_GID = getenv("DEFAULT_PORT", 1), getenv("DEFAULT_GID", 3) # DEFAULT_GID=0 for RXE
+ IOVA_ALIGN = resource.getpagesize()
+
+ def checkz(x, ret=None):
+   assert x == 0, f'{x} != 0 (errno {ctypes.get_errno()})'
+   return ret
+
+ @dataclass(frozen=True)
+ class SGE:
+   dst_iova: int
+   dst_key: int
+   src_iova: int
+   src_key: int
+   size: int
+
+ class IBCtx:
+   def __init__(self, idx:int):
+     # Open the device (aka Host Channel Adapter in ib-speak)
+     devs = ib.ibv_get_device_list(ctypes.byref(ndevs:=ctypes.c_int32()))
+     if idx >= ndevs.value: raise IndexError(f"{idx} > {ndevs.value}")
+     self.ctx = ib.ibv_open_device(devs[idx])
+     ib.ibv_free_device_list(devs)
+
+     # HACK: remove this (and all usage of `ctx.contents.ops`) when clang2py can deal with `static inline` wrapper-functions
+     self.vctx = ctypes.cast(ctypes.addressof(self.ctx.contents) - ib.struct_verbs_context.context.offset, ctypes.POINTER(ib.struct_verbs_context))
+
+     # Get attributes. Something like port_attr.max_msg_sz sounds like it might require taking the min of the host's and remote's attributes if they differ
+     self.device_attr = checkz(ib.ibv_query_device(self.ctx, ctypes.byref(da:=ib.struct_ibv_device_attr())), da)
+     self.port_attr = checkz(self.vctx.contents.query_port(self.ctx, DEFAULT_PORT, ctypes.byref(pa:=ib.struct_ibv_port_attr()), ctypes.sizeof(pa)), pa)
+     self.gid_attr = checkz(ib.ibv_query_gid(self.ctx, DEFAULT_PORT, DEFAULT_GID, ctypes.byref(ga:=ib.union_ibv_gid())), ga)
+
+     # Allocate protection domain
+     self.pd = ib.ibv_alloc_pd(self.ctx)
+     self.next_iova: int = IOVA_ALIGN # don't start at zero (nullptr)
+
+     # weakref(buf) => (iova, mr, mr_dealloc). mr_dealloc is kept here to avoid double freeing mrs that are deallocated in __del__
+     self.mrs: WeakKeyDictionary[Buffer, tuple[int, ctypes._Pointer[ib.struct_ibv_mr], weakref.finalize]] = WeakKeyDictionary()
+
+     # Default soft fd limit is 1024, which is not enough; set soft to hard (maximum allowed by the os)
+     IBCtx.rlimit_fix()
+
+   def __del__(self):
+     # must deallocate all mrs in protection domain before deallocating the protection domain
+     if hasattr(self, "mrs"): [fin() for _,_,fin in self.mrs.values()]
+     if hasattr(self, "pd"): ib.ibv_dealloc_pd(self.pd)
+     if hasattr(self, "ctx"): ib.ibv_close_device(self.ctx)
+
+   @functools.cache # run once
+   @staticmethod
+   def rlimit_fix():
+     soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
+     resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
+     if DEBUG>=2: print(f"IB: Increased fd limit from {soft} to {hard}")
+
+   def alloc_iova(self, size:int, required_offset:int):
+     iova = round_up(self.next_iova - required_offset, IOVA_ALIGN) + required_offset
+     self.next_iova = iova + size
+     return iova
+
+   def reg(self, buf:Buffer) -> tuple[int, ctypes._Pointer[ib.struct_ibv_mr]]:
+     buf = buf.base
+     if buf not in self.mrs:
+       if buf.nbytes > self.device_attr.max_mr_size: raise RuntimeError(f"Buffer too big: {buf.nbytes:#x} > {self.device_attr.max_mr_size:#x}")
+       if len(self.mrs) >= self.device_attr.max_mr: raise RuntimeError(f"Out of memory region cap: {len(self.mrs)} >= {self.device_attr.max_mr}")
+       # Local read is implied (but still have to create the memory region, except for short sends/writes with IBV_SEND_INLINE that are inlined by cpu)
+       mr_flags = ib.IBV_ACCESS_LOCAL_WRITE | ib.IBV_ACCESS_REMOTE_READ | ib.IBV_ACCESS_REMOTE_WRITE
+       match (dmaref:=buf.as_dmaref()):
+         case DMACPURef():
+           iova = self.alloc_iova(dmaref.size, dmaref.addr % IOVA_ALIGN)
+           mr = ib.ibv_reg_mr_iova2(self.pd, ctypes.c_void_p(dmaref.addr), dmaref.size, iova, mr_flags)
+         case DMAFdRef():
+           iova = self.alloc_iova(dmaref.size, dmaref.offset % IOVA_ALIGN)
+           mr = ib.ibv_reg_dmabuf_mr(self.pd, dmaref.offset, dmaref.size, iova, dmaref.fd, mr_flags)
+         case _: raise RuntimeError(f"Unknown type of dma ref: {dmaref}")
+       if not mr: raise RuntimeError(f"Couldn't register memory region for {buf} {dmaref} (errno={ctypes.get_errno()})")
+       self.mrs[buf] = (iova, mr, weakref.finalize(buf, ib.ibv_dereg_mr, mr))
+     return self.mrs[buf][0:2]
+
+ class IBConn:
+   def __init__(self, ctx:IBCtx):
+     self.ctx = ctx
+
+     # Create Completion Channel. It is a file descriptor that the kernel sends notifications through, not a thing in the infiniband spec, just a linux-ism
+     self.comp_channel = ib.ibv_create_comp_channel(self.ctx.ctx)
+     # Create Completion Queue. When a Work Request with the signaled flag is completed, a Completion Queue Entry is pushed onto this queue
+     self.cq = ib.ibv_create_cq(self.ctx.ctx, _capacity:=256, _cq_context:=None, self.comp_channel, _comp_vector:=0)
+     self.pending_wrids: set[int] = set()
+     self.wrid_num: Iterator[int] = itertools.count(0) # wc_id is uint64, this will never overflow
+
+     # Create Queue Pair. It's the closest thing to a socket in infiniband, with the QP num being the closest thing to a port, except it's allocated by the hca
+     qp_init_attrs_cap = ib.struct_ibv_qp_cap(max_send_wr=1024, max_recv_wr=64, max_send_sge=8, max_recv_sge=8, max_inline_data=64)
+     qp_init_attrs = ib.struct_ibv_qp_init_attr(send_cq=self.cq, recv_cq=self.cq, cap=qp_init_attrs_cap, qp_type=ib.IBV_QPT_RC) # Reliable Connection
+     self.qp = ib.ibv_create_qp(self.ctx.pd, ctypes.byref(qp_init_attrs))
+     self.qp_cap = qp_init_attrs.cap
+
+     # The most important thing about QPs is their state: a new QP starts in the RESET state, and before it can be properly used it has to go
+     # through Init, Ready To Receive, and Ready To Send. Good docs on the QP state machine: https://www.rdmamojo.com/2012/05/05/qp-state-machine/
+
+     # INIT
+     qp_access_flags = ib.IBV_ACCESS_REMOTE_WRITE | ib.IBV_ACCESS_REMOTE_READ
+     qpa = ib.struct_ibv_qp_attr(qp_state=ib.IBV_QPS_INIT, port_num=DEFAULT_PORT, qp_access_flags=qp_access_flags)
+     checkz(ib.ibv_modify_qp(self.qp, qpa, ib.IBV_QP_STATE | ib.IBV_QP_PORT | ib.IBV_QP_ACCESS_FLAGS | ib.IBV_QP_PKEY_INDEX))
+
+     self.gid, self.qp_num = bytes(self.ctx.gid_attr.raw), self.qp.contents.qp_num
+
+     # Exchange GID and QP num with the remote. At least in RoCEv2 the gid can be guessed from the remote's ip, the QP num can't.
+
+   def connect(self, remote_gid:bytes, remote_qp_num:int):
+     # RTR
+     qp_ah_attr_grh = ib.struct_ibv_global_route(hop_limit=1, dgid=ib.union_ibv_gid(raw=(ctypes.c_ubyte * 16)(*remote_gid)), sgid_index=DEFAULT_GID)
+     qp_ah_attr = ib.struct_ibv_ah_attr(is_global=1, port_num=DEFAULT_PORT, grh=qp_ah_attr_grh)
+     qpa = ib.struct_ibv_qp_attr(qp_state=ib.IBV_QPS_RTR, path_mtu=ib.IBV_MTU_4096, dest_qp_num=remote_qp_num, rq_psn=0, max_dest_rd_atomic=1,
+                                 min_rnr_timer=12, ah_attr=qp_ah_attr)
+     checkz(ib.ibv_modify_qp(self.qp, qpa, ib.IBV_QP_STATE | ib.IBV_QP_PATH_MTU | ib.IBV_QP_DEST_QPN | ib.IBV_QP_RQ_PSN | \
+                             ib.IBV_QP_MAX_DEST_RD_ATOMIC | ib.IBV_QP_MIN_RNR_TIMER | ib.IBV_QP_AV))
+
+     # RTS
+     qpa = ib.struct_ibv_qp_attr(qp_state=ib.IBV_QPS_RTS, timeout=14, retry_cnt=7, rnr_retry=7, sq_psn=0, max_rd_atomic=1)
+     checkz(ib.ibv_modify_qp(self.qp, qpa, ib.IBV_QP_STATE | ib.IBV_QP_TIMEOUT | ib.IBV_QP_RETRY_CNT | ib.IBV_QP_RNR_RETRY | ib.IBV_QP_SQ_PSN | \
+                             ib.IBV_QP_MAX_QP_RD_ATOMIC))
+
+   def __del__(self):
+     self.wait_cq() # need to wait for **everything** to complete before it's safe to dealloc queues and stuff
+     ib.ibv_destroy_qp(self.qp)
+     ib.ibv_destroy_cq(self.cq)
+     ib.ibv_destroy_comp_channel(self.comp_channel)
+
+   def next_wrid(self):
+     self.pending_wrids.add(wrid:=next(self.wrid_num))
+     return wrid
+
+   def wait_cq(self, wr_id: int|None=None):
+     while (wr_id in self.pending_wrids) if wr_id is not None else self.pending_wrids:
+       if self.ctx.ctx.contents.ops.poll_cq(self.cq, _num_entries:=1, ctypes.byref(wc:=ib.struct_ibv_wc())):
+         if wc.status != ib.IBV_WC_SUCCESS:
+           raise RuntimeError(f'Work Request completed with error: wr_id={wc.wr_id} status={ib.ibv_wc_status__enumvalues.get(wc.status, wc.status)}')
+         self.pending_wrids.remove(wc.wr_id)
+
+   def rdma_write(self, sgl:list[SGE]):
+     swr: ctypes._Pointer[ib.struct_ibv_send_wr]|None = None
+     swr_cnt, wr_id = 0, self.next_wrid()
+     def _post():
+       nonlocal swr, swr_cnt, wr_id
+       if swr is not None:
+         # The swr can be freed when this returns, the memory that sge points to can be unmapped after work completion is retrieved from cq
+         checkz(self.ctx.ctx.contents.ops.post_send(self.qp, swr, ctypes.byref(_bad_wr:=ctypes.POINTER(ib.struct_ibv_send_wr)())))
+         # TODO: async
+         self.wait_cq(wr_id)
+         swr, swr_cnt, wr_id = None, 0, self.next_wrid()
+     # Everything is in reverse for elegant chaining
+     for sg in reversed(sgl):
+       # Message size limit (max 2GB per ib spec, 1GB on tinybox mellanoxes) applies to both scatter-gather entries and entire wrs
+       for off in reversed(range(0, sg.size, self.ctx.port_attr.max_msg_sz)):
+         # Scatter-Gather Entry for local memory
+         sge = ctypes.pointer(ib.struct_ibv_sge(addr=sg.src_iova+off, length=min(sg.size-off, self.ctx.port_attr.max_msg_sz), lkey=sg.src_key))
+         # RDMA struct for remote memory
+         wr = ib.union_ibv_send_wr_wr(rdma=ib.struct_ibv_send_wr_1_rdma(remote_addr=sg.dst_iova+off, rkey=sg.dst_key))
+         # Signal (with the chosen work request id) if it's the last wr (first in the loop since it's reversed)
+         wid, flags = (wr_id, ib.IBV_SEND_SIGNALED) if swr is None else (0, 0)
+         # Create Send Request
+         swr = ctypes.pointer(ib.struct_ibv_send_wr(opcode=ib.IBV_WR_RDMA_WRITE, sg_list=sge, num_sge=1, wr=wr, wr_id=wid, send_flags=flags, next=swr))
+         # Flush if the queue is being overrun
+         if (swr_cnt:=swr_cnt + 1) >= self.qp_cap.max_send_wr: _post()
+     _post()
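
The comments above sketch the intended flow: open an HCA with IBCtx, create an IBConn per queue pair, exchange gid/qp_num out of band, register buffers with IBCtx.reg, then issue IBConn.rdma_write with a scatter-gather list. A minimal, hypothetical loopback sketch (not part of the diff; the lkey/rkey field access on the returned memory region is an assumption about the autogen struct_ibv_mr):

    # Hypothetical sketch: two queue pairs on one HCA, copying one registered Buffer into another.
    from tinygrad import Device, dtypes
    from tinygrad.device import Buffer
    from tinygrad.runtime.support.ib import IBCtx, IBConn, SGE

    ctx = IBCtx(0)                        # first InfiniBand device
    a, b = IBConn(ctx), IBConn(ctx)       # in real use, gid/qp_num travel over a side channel (e.g. a TCP socket)
    a.connect(b.gid, b.qp_num)
    b.connect(a.gid, a.qp_num)

    src = Buffer(Device.DEFAULT, 4096, dtypes.uint8).allocate()
    dst = Buffer(Device.DEFAULT, 4096, dtypes.uint8).allocate()
    src_iova, src_mr = ctx.reg(src)       # pins the buffer and assigns it an iova
    dst_iova, dst_mr = ctx.reg(dst)
    # lkey/rkey fields assumed to be exposed by the autogen struct_ibv_mr
    a.rdma_write([SGE(dst_iova, dst_mr.contents.rkey, src_iova, src_mr.contents.lkey, 4096)])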
tinygrad/runtime/support/llvm.py
@@ -9,15 +9,14 @@ if sys.platform == 'win32':
     raise FileNotFoundError('LLVM not found, you can install it with `winget install LLVM.LLVM` or point at a custom dll with LLVM_PATH')
 elif OSX:
   # Will raise FileNotFoundError if brew is not installed
-  brew_prefix = subprocess.check_output(['brew', '--prefix', 'llvm']).decode().strip()
   # `brew --prefix` will return even if formula is not installed
-  if not os.path.exists(brew_prefix):
-    raise FileNotFoundError('LLVM not found, you can install it with `brew install llvm`')
+  if not os.path.exists(brew_prefix:=subprocess.check_output(['brew', '--prefix', 'llvm@20']).decode().strip()):
+    raise FileNotFoundError('LLVM not found, you can install it with `brew install llvm@20`')
   LLVM_PATH: str|None = os.path.join(brew_prefix, 'lib', 'libLLVM.dylib')
 else:
   LLVM_PATH = ctypes.util.find_library('LLVM')
   # use newer LLVM if possible
-  for ver in reversed(range(14, 19+1)):
+  for ver in reversed(range(14, 20+1)):
     if LLVM_PATH is not None: break
     LLVM_PATH = ctypes.util.find_library(f'LLVM-{ver}')
   if LLVM_PATH is None:
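
If you want to check which libLLVM the search above settled on (assuming, as in this hunk, that the module keeps exposing a module-level LLVM_PATH), importing it is enough:

    # Prints the resolved libLLVM path; raises the FileNotFoundError from above if none was found.
    from tinygrad.runtime.support.llvm import LLVM_PATH
    print(LLVM_PATH)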
tinygrad/runtime/support/memory.py (new file)
@@ -0,0 +1,251 @@
+ import collections, functools, dataclasses
+ from typing import Any, ClassVar
+ from tinygrad.helpers import round_up, getenv
+
+ class TLSFAllocator:
+   """
+   The allocator is based on the Two-Level Segregated Fit (TLSF) algorithm. The allocator maintains 2 levels of buckets:
+     * 1st level is determined by the most significant bit of the size.
+     * 2nd level splits the covered memory of the 1st level into @lv2_cnt entries.
+
+   For each allocation request, the allocator searches for the smallest block that can fit the requested size.
+   For each deallocation request, the allocator merges the block with its neighbors if they are free.
+   """
+
+   def __init__(self, size:int, base:int=0, block_size:int=16, lv2_cnt:int=16):
+     self.size, self.base, self.block_size, self.l2_cnt = size, base, block_size, lv2_cnt.bit_length()
+     self.storage:list = [collections.defaultdict(list) for _ in range(size.bit_length() + 1)]
+     self.lv1_entries:list[int] = [0] * len(self.storage)
+
+     # self.blocks is more like a linked list, where each entry is a contiguous block.
+     self.blocks:dict[int, tuple[int, int|None, int|None, bool]] = {0: (size, None, None, True)} # size, next, prev, is_free
+     self._insert_block(0, size)
+
+   @functools.cache
+   def lv1(self, size): return size.bit_length()
+
+   @functools.cache
+   def lv2(self, size): return (size - (1 << (size.bit_length() - 1))) // (1 << max(0, size.bit_length() - self.l2_cnt))
+
+   def _insert_block(self, start:int, size:int, prev:int|None=None):
+     if prev is None: prev = self.blocks[start][2]
+     self.storage[self.lv1(size)][self.lv2(size)].append(start)
+     self.lv1_entries[self.lv1(size)] += 1
+     self.blocks[start] = (size, start + size, prev, True)
+     return self
+
+   def _remove_block(self, start:int, size:int, prev:int|None=None):
+     if prev is None: prev = self.blocks[start][2]
+     self.storage[self.lv1(size)][self.lv2(size)].remove(start)
+     self.lv1_entries[self.lv1(size)] -= 1
+     self.blocks[start] = (size, start + size, prev, False)
+     return self
+
+   def _split_block(self, start:int, size:int, new_size:int):
+     nxt = self.blocks[start][1]
+     assert self.blocks[start][3], "block must be free"
+     self._remove_block(start, size)._insert_block(start, new_size)._insert_block(start + new_size, size - new_size, prev=start)
+     if nxt in self.blocks: self.blocks[nxt] = (self.blocks[nxt][0], self.blocks[nxt][1], start + new_size, self.blocks[nxt][3])
+     return self
+
+   def _merge_right(self, start:int):
+     size, nxt, _, is_free = self.blocks[start]
+     assert is_free, "block must be free"
+
+     while is_free and nxt in self.blocks:
+       if (blk:=self.blocks[nxt])[3] is False: break
+       self._remove_block(start, size)._remove_block(nxt, blk[0])._insert_block(start, size:=size + blk[0])
+       assert self.blocks[start][1] == blk[1]
+       _, nxt, _, _ = self.blocks.pop(nxt)
+
+     if nxt in self.blocks: self.blocks[nxt] = (self.blocks[nxt][0], self.blocks[nxt][1], start, self.blocks[nxt][3])
+
+   def _merge_block(self, start:int):
+     # Go left while blocks are free, then merge them all to the right.
+     while (x:=self.blocks[start][2]) is not None and self.blocks[x][3] is True: start = x
+     self._merge_right(start)
+
+   def alloc(self, req_size:int, align:int=1) -> int:
+     req_size = max(self.block_size, req_size) # at least block size.
+     size = max(self.block_size, req_size + align - 1)
+
+     # Round up the allocation size to the next bucket, so any entry there can fit the requested size.
+     size = round_up(size, (1 << size.bit_length() - self.l2_cnt))
+
+     # Search for the smallest block that can fit the requested size. Start with its bucket and go up until any block is found.
+     for l1 in range(self.lv1(size), len(self.storage)):
+       if self.lv1_entries[l1] == 0: continue
+       for l2 in range(self.lv2(size) if l1 == size.bit_length() else 0, (1 << self.l2_cnt)):
+         if len(self.storage[l1][l2]) > 0:
+           nsize = self.blocks[self.storage[l1][l2][0]][0]
+           assert nsize >= size, "block must be larger"
+
+           # Block start address.
+           start = self.storage[l1][l2][0]
+
+           # If the request has an alignment, split the block into two parts.
+           if (new_start:=round_up(start, align)) != start:
+             self._split_block(start, nsize, new_start - start)
+             start, nsize = new_start, self.blocks[new_start][0]
+
+           # If the block is larger than the requested size, split it into two parts.
+           if nsize > req_size: self._split_block(start, nsize, req_size)
+           self._remove_block(start, req_size) # Mark the block as allocated.
+           return start + self.base
+     raise MemoryError(f"Can't allocate {req_size} bytes")
+
+   def free(self, start:int):
+     self._insert_block(start - self.base, self.blocks[start - self.base][0])._merge_block(start - self.base)
+
+ # Memory Management
+
+ @dataclasses.dataclass(frozen=True)
+ class VirtMapping: va_addr:int; size:int; paddrs:list[tuple[int, int]]; uncached:bool=False; system:bool=False; snooped:bool=False # noqa: E702
+
+ class PageTableTraverseContext:
+   def __init__(self, dev, pt, vaddr, create_pts=False, free_pts=False, boot=False):
+     self.dev, self.vaddr, self.create_pts, self.free_pts, self.boot = dev, vaddr - dev.mm.va_base, create_pts, free_pts, boot
+     self.pt_stack:list[tuple[Any, int, int]] = [(pt, self._pt_pte_idx(pt, self.vaddr), self._pt_pte_size(pt))]
+
+   def _pt_pte_cnt(self, lv): return self.dev.mm.pte_cnt[lv]
+   def _pt_pte_size(self, pt): return self.dev.mm.pte_covers[pt.lv]
+   def _pt_pte_idx(self, pt, va): return (va // self._pt_pte_size(pt)) % self._pt_pte_cnt(pt.lv)
+
+   def level_down(self):
+     pt, pte_idx, _ = self.pt_stack[-1]
+
+     if not pt.valid(pte_idx):
+       assert self.create_pts, "Not allowed to create new page table"
+       pt.set_entry(pte_idx, self.dev.mm.palloc(0x1000, zero=True, boot=self.boot), table=True, valid=True)
+
+     assert not pt.is_huge_page(pte_idx), f"Must be table pt={pt.paddr:#x}, {pt.lv=} {pte_idx=} {pt.read_fields(pte_idx)}"
+     child_page_table = self.dev.mm.pt_t(self.dev, pt.address(pte_idx), lv=pt.lv+1)
+
+     self.pt_stack.append((child_page_table, self._pt_pte_idx(child_page_table, self.vaddr), self._pt_pte_size(child_page_table)))
+     return self.pt_stack[-1]
+
+   def _try_free_pt(self) -> bool:
+     pt, _, _ = self.pt_stack[-1]
+     if self.free_pts and pt != self.dev.mm.root_page_table and all(not pt.valid(i) for i in range(self._pt_pte_cnt(self.pt_stack[-1][0].lv))):
+       self.dev.mm.pfree(pt.paddr)
+       parent_pt, parent_pte_idx, _ = self.pt_stack[-2]
+       parent_pt.set_entry(parent_pte_idx, 0x0, valid=False)
+       return True
+     return False
+
+   def level_up(self):
+     while self._try_free_pt() or self.pt_stack[-1][1] == self._pt_pte_cnt(self.pt_stack[-1][0].lv):
+       pt, pt_cnt, _ = self.pt_stack.pop()
+       if pt_cnt == self._pt_pte_cnt(pt.lv): self.pt_stack[-1] = (self.pt_stack[-1][0], self.pt_stack[-1][1] + 1, self.pt_stack[-1][2])
+
+   def next(self, size:int, paddr:int|None=None, off:int=0):
+     while size > 0:
+       pt, pte_idx, pte_covers = self.pt_stack[-1]
+       if self.create_pts:
+         assert paddr is not None, "paddr must be provided when allocating new page tables"
+         while pte_covers > size or not pt.supports_huge_page(paddr+off) or self.vaddr&(pte_covers-1) != 0: pt, pte_idx, pte_covers = self.level_down()
+       else:
+         while not pt.is_huge_page(pte_idx): pt, pte_idx, pte_covers = self.level_down()
+
+       entries = min(size // pte_covers, self._pt_pte_cnt(pt.lv) - pte_idx)
+       assert entries > 0, f"Invalid entries {size=:#x}, {pte_covers=:#x}"
+       yield off, pt, pte_idx, entries, pte_covers
+
+       size, off, self.vaddr = size - entries * pte_covers, off + entries * pte_covers, self.vaddr + entries * pte_covers
+       self.pt_stack[-1] = (pt, pte_idx + entries, pte_covers)
+       self.level_up()
+
+ class MemoryManager:
+   va_allocator: ClassVar[TLSFAllocator|None] = None
+
+   def __init__(self, dev, vram_size:int, boot_size:int, pt_t, va_bits:int, va_shifts:list[int], va_base:int,
+                palloc_ranges:list[tuple[int, int]], first_lv:int=0):
+     self.dev, self.vram_size, self.va_shifts, self.va_base, lvl_msb = dev, vram_size, va_shifts, va_base, va_shifts + [va_bits + 1]
+     self.pte_covers, self.pte_cnt = [1 << x for x in va_shifts][::-1], [1 << (lvl_msb[i+1] - lvl_msb[i]) for i in range(len(lvl_msb) - 1)][::-1]
+     self.pt_t, self.palloc_ranges, self.level_cnt, self.va_bits = pt_t, palloc_ranges, len(va_shifts), va_bits
+
+     self.boot_allocator = TLSFAllocator(boot_size, base=0) # per device
+     self.pa_allocator = TLSFAllocator(vram_size - (64 << 20), base=self.boot_allocator.size) # per device
+     self.root_page_table = pt_t(self.dev, self.palloc(0x1000, zero=not self.dev.smi_dev, boot=True), lv=first_lv)
+
+   def _frag_size(self, va, sz, must_cover=True):
+     """
+     Calculate the tlb fragment size for a given virtual address and size.
+     If must_cover is True, the fragment size must cover the size, otherwise the biggest fragment size that fits the size is returned.
+     Fragment 0 is 4KB, 1 is 8KB and so on.
+     """
+     va_pwr2_div, sz_pwr2_div, sz_pwr2_max = va & -(va) if va > 0 else (1 << 63), sz & -(sz), (1 << (sz.bit_length() - 1))
+     return (min(va_pwr2_div, sz_pwr2_div) if must_cover else min(va_pwr2_div, sz_pwr2_max)).bit_length() - 1 - 12
+
+   def page_tables(self, vaddr:int, size:int):
+     ctx = PageTableTraverseContext(self.dev, self.root_page_table, vaddr, create_pts=True)
+     for _ in ctx.next(size, paddr=0): return [pt for pt, _, _ in ctx.pt_stack]
+
+   def map_range(self, vaddr:int, size:int, paddrs:list[tuple[int, int]], uncached=False, system=False, snooped=False, boot=False) -> VirtMapping:
+     if getenv("MM_DEBUG", 0): print(f"mm {self.dev.devfmt}: mapping {vaddr=:#x} ({size=:#x})")
+
+     assert size == sum(p[1] for p in paddrs), f"Size mismatch {size=} {sum(p[1] for p in paddrs)=}"
+
+     ctx = PageTableTraverseContext(self.dev, self.root_page_table, vaddr, create_pts=True, boot=boot)
+     for paddr, psize in paddrs:
+       for off, pt, pte_idx, pte_cnt, pte_covers in ctx.next(psize, paddr=paddr):
+         for pte_off in range(pte_cnt):
+           assert not pt.valid(pte_idx + pte_off), f"PTE already mapped: {pt.entry(pte_idx + pte_off):#x}"
+           pt.set_entry(pte_idx + pte_off, paddr + off + pte_off * pte_covers, uncached=uncached, system=system, snooped=snooped,
+                        frag=self._frag_size(ctx.vaddr+off, pte_cnt * pte_covers), valid=True)
+
+     self.on_range_mapped()
+     return VirtMapping(vaddr, size, paddrs, uncached=uncached, system=system, snooped=snooped)
+
+   def unmap_range(self, vaddr:int, size:int):
+     if getenv("MM_DEBUG", 0): print(f"mm {self.dev.devfmt}: unmapping {vaddr=:#x} ({size=:#x})")
+
+     ctx = PageTableTraverseContext(self.dev, self.root_page_table, vaddr, free_pts=True)
+     for off, pt, pte_idx, pte_cnt, pte_covers in ctx.next(size):
+       for pte_id in range(pte_idx, pte_idx + pte_cnt):
+         assert pt.valid(pte_id), f"PTE not mapped: {pt.entry(pte_id):#x}"
+         pt.set_entry(pte_id, paddr=0x0, valid=False)
+
+   def on_range_mapped(self): pass
+
+   @classmethod
+   def alloc_vaddr(cls, size:int, align=0x1000) -> int:
+     assert cls.va_allocator is not None, "va_allocator must be set"
+     return cls.va_allocator.alloc(size, max((1 << (size.bit_length() - 1)), align))
+
+   def valloc(self, size:int, align=0x1000, uncached=False, contiguous=False) -> VirtMapping:
+     # Alloc physical memory and map it to the virtual address
+     va = self.alloc_vaddr(size:=round_up(size, 0x1000), align)
+
+     if contiguous: paddrs = [(self.palloc(size, zero=True), size)]
+     else:
+       # Traverse the PT to find the largest contiguous sizes we need to allocate. Try to allocate the longest segment to reduce TLB pressure.
+       nxt_range, rem_size, paddrs = 0, size, []
+       while rem_size > 0:
+         while self.palloc_ranges[nxt_range][0] > rem_size: nxt_range += 1
+
+         try: paddrs += [(self.palloc(try_sz:=self.palloc_ranges[nxt_range][0], self.palloc_ranges[nxt_range][1], zero=False), try_sz)]
+         except MemoryError:
+           # Move to a smaller size and try again.
+           nxt_range += 1
+           if nxt_range == len(self.palloc_ranges):
+             for paddr, _ in paddrs: self.pa_allocator.free(paddr)
+             raise MemoryError(f"Failed to allocate memory. (total allocation size={size:#x}, current try={self.palloc_ranges[nxt_range-1]})")
+           continue
+         rem_size -= self.palloc_ranges[nxt_range][0]
+
+     return self.map_range(va, size, paddrs, uncached=uncached)
+
+   def vfree(self, vm:VirtMapping):
+     assert self.va_allocator is not None, "va_allocator must be set"
+     self.unmap_range(vm.va_addr, vm.size)
+     self.va_allocator.free(vm.va_addr)
+     for paddr, _ in vm.paddrs: self.pa_allocator.free(paddr)
+
+   def palloc(self, size:int, align:int=0x1000, zero=True, boot=False) -> int:
+     assert self.dev.is_booting == boot, "During booting, only boot memory can be allocated"
+     paddr = (self.boot_allocator if boot else self.pa_allocator).alloc(round_up(size, 0x1000), align)
+     if zero: self.dev.vram[paddr:paddr+size] = bytes(size)
+     return paddr
+
+   def pfree(self, paddr:int): self.pa_allocator.free(paddr)
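
The TLSF docstring above describes the two-level bucket search; a standalone sketch of driving the allocator (sizes and base address are made up for illustration) looks like this:

    # Carve a 1 MiB arena whose addresses start at 0x100000, then allocate and free a few blocks.
    from tinygrad.runtime.support.memory import TLSFAllocator

    arena = TLSFAllocator(1 << 20, base=0x100000)
    a = arena.alloc(0x3000)                # 12 KiB block
    b = arena.alloc(0x1000, align=0x1000)  # page-aligned 4 KiB block
    arena.free(a)                          # freed block is merged with any free neighbors
    c = arena.alloc(0x2000)                # may land in the range `a` used to occupy
    print(hex(a), hex(b), hex(c))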
File without changes