tinygrad 0.10.0__py3-none-any.whl → 0.10.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88) hide show
  1. tinygrad/codegen/devectorizer.py +247 -0
  2. tinygrad/codegen/expander.py +121 -0
  3. tinygrad/codegen/kernel.py +141 -201
  4. tinygrad/codegen/linearize.py +223 -84
  5. tinygrad/codegen/lowerer.py +60 -42
  6. tinygrad/codegen/symbolic.py +476 -0
  7. tinygrad/codegen/transcendental.py +22 -13
  8. tinygrad/device.py +187 -47
  9. tinygrad/dtype.py +39 -28
  10. tinygrad/engine/jit.py +83 -65
  11. tinygrad/engine/memory.py +4 -5
  12. tinygrad/engine/multi.py +161 -0
  13. tinygrad/engine/realize.py +62 -108
  14. tinygrad/engine/schedule.py +396 -357
  15. tinygrad/engine/search.py +55 -66
  16. tinygrad/gradient.py +73 -0
  17. tinygrad/helpers.py +81 -59
  18. tinygrad/nn/__init__.py +30 -32
  19. tinygrad/nn/datasets.py +1 -2
  20. tinygrad/nn/optim.py +22 -26
  21. tinygrad/nn/state.py +91 -66
  22. tinygrad/ops.py +492 -641
  23. tinygrad/renderer/__init__.py +95 -36
  24. tinygrad/renderer/cstyle.py +99 -92
  25. tinygrad/renderer/llvmir.py +83 -34
  26. tinygrad/renderer/ptx.py +83 -99
  27. tinygrad/renderer/wgsl.py +95 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  29. tinygrad/runtime/autogen/comgr.py +2 -0
  30. tinygrad/runtime/autogen/kfd.py +4 -3
  31. tinygrad/runtime/autogen/kgsl.py +1 -1
  32. tinygrad/runtime/autogen/libc.py +404 -71
  33. tinygrad/runtime/autogen/llvm.py +11379 -0
  34. tinygrad/runtime/autogen/pci.py +1333 -0
  35. tinygrad/runtime/autogen/vfio.py +891 -0
  36. tinygrad/runtime/autogen/webgpu.py +6985 -0
  37. tinygrad/runtime/graph/cuda.py +8 -9
  38. tinygrad/runtime/graph/hcq.py +84 -79
  39. tinygrad/runtime/graph/metal.py +40 -43
  40. tinygrad/runtime/ops_amd.py +498 -334
  41. tinygrad/runtime/ops_cloud.py +34 -34
  42. tinygrad/runtime/ops_cpu.py +24 -0
  43. tinygrad/runtime/ops_cuda.py +30 -27
  44. tinygrad/runtime/ops_disk.py +62 -63
  45. tinygrad/runtime/ops_dsp.py +159 -42
  46. tinygrad/runtime/ops_gpu.py +30 -30
  47. tinygrad/runtime/ops_hip.py +29 -31
  48. tinygrad/runtime/ops_llvm.py +48 -41
  49. tinygrad/runtime/ops_metal.py +149 -113
  50. tinygrad/runtime/ops_npy.py +2 -2
  51. tinygrad/runtime/ops_nv.py +238 -273
  52. tinygrad/runtime/ops_python.py +55 -50
  53. tinygrad/runtime/ops_qcom.py +129 -157
  54. tinygrad/runtime/ops_webgpu.py +225 -0
  55. tinygrad/runtime/support/allocator.py +94 -0
  56. tinygrad/runtime/support/am/__init__.py +0 -0
  57. tinygrad/runtime/support/am/amdev.py +396 -0
  58. tinygrad/runtime/support/am/ip.py +463 -0
  59. tinygrad/runtime/support/compiler_cuda.py +4 -2
  60. tinygrad/runtime/support/elf.py +28 -4
  61. tinygrad/runtime/support/hcq.py +256 -324
  62. tinygrad/runtime/support/llvm.py +26 -0
  63. tinygrad/shape/shapetracker.py +85 -53
  64. tinygrad/shape/view.py +104 -140
  65. tinygrad/spec.py +155 -0
  66. tinygrad/tensor.py +835 -527
  67. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
  68. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
  69. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
  70. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
  71. tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
  72. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
  73. tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
  74. tinygrad/viz/index.html +544 -0
  75. tinygrad/viz/perfetto.html +178 -0
  76. tinygrad/viz/serve.py +205 -0
  77. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
  78. tinygrad-0.10.2.dist-info/RECORD +99 -0
  79. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
  80. tinygrad/codegen/uopgraph.py +0 -506
  81. tinygrad/engine/lazy.py +0 -228
  82. tinygrad/function.py +0 -212
  83. tinygrad/multi.py +0 -177
  84. tinygrad/runtime/graph/clang.py +0 -39
  85. tinygrad/runtime/ops_clang.py +0 -35
  86. tinygrad-0.10.0.dist-info/RECORD +0 -77
  87. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
  88. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,225 @@
1
+ import functools, struct
2
+ from tinygrad.device import Compiled, Allocator, Compiler
3
+ from tinygrad.renderer.wgsl import WGSLRenderer
4
+ from tinygrad.helpers import round_up, OSX
5
+ from tinygrad.runtime.autogen import webgpu
6
+ from typing import List, Any
7
+ import ctypes
8
+ import os
9
+
10
# Reverse lookup of the autogenerated WGPUBackendType enum: enum name -> numeric value.
# WebGpuDevice uses it to translate the WEBGPU_BACKEND environment variable.
backend_types = {name: value for value, name in webgpu.WGPUBackendType__enumvalues.items()}

# Create the process-wide WebGPU instance. timedWaitAnyEnable is requested because _wait()
# blocks on futures through wgpuInstanceWaitAny with a timeout value.
try:
  instance = webgpu.wgpuCreateInstance(webgpu.WGPUInstanceDescriptor(features = webgpu.WGPUInstanceFeatures(timedWaitAnyEnable = True)))
except AttributeError as err:
  # NOTE(review): presumably the autogen bindings lack wgpuCreateInstance when the dawn shared library
  # could not be loaded -- confirm against tinygrad.runtime.autogen.webgpu.
  raise RuntimeError("Cannot find dawn library. Install it with: " + ("brew tap wpmed92/dawn && brew install dawn" if OSX else
  "sudo curl -L https://github.com/wpmed92/pydawn/releases/download/v0.1.6/libwebgpu_dawn.so -o /usr/lib/libwebgpu_dawn.so")) from err
17
+
18
def to_c_string(_str):
  """Return a ctypes char buffer holding `_str` encoded as UTF-8 (ctypes appends a NUL terminator)."""
  encoded = _str.encode('utf-8')
  return ctypes.create_string_buffer(encoded)
19
+
20
def from_wgpu_str(string_view):
  """Decode a string view (an object with `data` and `length` fields) into a Python str, reading `length` bytes as UTF-8."""
  raw = ctypes.string_at(string_view.data, string_view.length)
  return raw.decode("utf-8")
21
+
22
def to_wgpu_str(_str):
  """
  Wrap a Python str in a WGPUStringView pointing at its UTF-8 encoding.

  The view's length is the number of encoded bytes, not the number of characters, so strings
  containing non-ASCII characters are described correctly.
  """
  c_buf = to_c_string(_str)
  # to_c_string() builds the buffer with create_string_buffer, which appends one NUL byte; exclude it.
  byte_len = len(c_buf) - 1
  return webgpu.WGPUStringView(data=ctypes.cast(ctypes.pointer(c_buf), ctypes.POINTER(ctypes.c_char)), length=byte_len)
24
+
25
def _wait(future):
  """
  Block until `future` completes on the module-level WebGPU instance.

  The wait is issued as a plain statement (not inside an `assert`) so it still runs when Python is
  started with -O, which strips assert statements together with the expressions inside them.

  Raises:
    RuntimeError: if wgpuInstanceWaitAny reports anything other than WGPUWaitStatus_Success.
  """
  # 2**64-1 is passed as the timeout argument (the largest 64-bit value).
  status = webgpu.wgpuInstanceWaitAny(instance, 1, webgpu.WGPUFutureWaitInfo(future=future), 2**64-1)
  if status != webgpu.WGPUWaitStatus_Success: raise RuntimeError("Future failed")
27
+
28
def write_buffer(device, buf, offset, src):
  """Upload the bytes of `src` into WebGPU buffer `buf` at byte `offset` through the device's queue."""
  # Copy into a bytearray first: from_buffer() needs a writable object exposing the buffer protocol.
  payload = bytearray(src)
  nbytes = len(payload)
  queue = webgpu.wgpuDeviceGetQueue(device)
  c_payload = (ctypes.c_uint8 * nbytes).from_buffer(payload)
  webgpu.wgpuQueueWriteBuffer(queue, buf, offset, c_payload, nbytes)
31
+
32
def _run(async_fun, cb_info_type, cb_type, status_enum, res_idx, msg_idx, *params):
  """
  Call an asynchronous WebGPU function and wait for its callback synchronously.

  Args:
    async_fun: binding to call as `async_fun(*params, cb_info)`; its return value is waited on via _wait().
    cb_info_type: callback-info struct type, built with nextInChain/mode/callback fields.
    cb_type: ctypes function type used to wrap the Python callback.
    status_enum: mapping from status code to name used in the error text, or None for a generic 'ERROR' tag.
    res_idx: index of the callback argument to return. Falsy (None or 0) means return None.
    msg_idx: index of the callback argument holding a string view to decode. Falsy (None or 0) means
      there is no message.
    *params: arguments forwarded to `async_fun` ahead of the callback info.

  Returns:
    The callback argument at `res_idx`, or None when `res_idx` is falsy.

  Raises:
    RuntimeError: when the first callback argument (the status) is not 1.
  """
  captured: List[Any] = []

  def on_done(*cb_args):
    # Keep every callback argument; decode the message in place so it is a Python str afterwards.
    captured[:] = cb_args
    if msg_idx: captured[msg_idx] = from_wgpu_str(captured[msg_idx])

  cb_info = cb_info_type(nextInChain=None, mode=webgpu.WGPUCallbackMode_WaitAnyOnly, callback=cb_type(on_done))
  future = async_fun(*params, cb_info)
  _wait(future)

  status = captured[0]
  if status != 1:
    label = status_enum[status] if status_enum else 'ERROR'
    detail = captured[msg_idx] if msg_idx else ''
    raise RuntimeError(f"[{label}]{detail}")
  if res_idx: return captured[res_idx]
  return None
44
+
45
def copy_buffer_to_buffer(dev, src, src_offset, dst, dst_offset, size):
  """Record and submit a copy of `size` bytes from `src` (at `src_offset`) to `dst` (at `dst_offset`)."""
  enc = webgpu.wgpuDeviceCreateCommandEncoder(dev, webgpu.WGPUCommandEncoderDescriptor())
  webgpu.wgpuCommandEncoderCopyBufferToBuffer(enc, src, src_offset, dst, dst_offset, size)
  cmd = webgpu.wgpuCommandEncoderFinish(enc, webgpu.WGPUCommandBufferDescriptor())
  queue = webgpu.wgpuDeviceGetQueue(dev)
  cmd_array = (webgpu.WGPUCommandBuffer*1)(cmd)
  webgpu.wgpuQueueSubmit(queue, 1, cmd_array)
  # The command buffer and encoder are released only after the submit call.
  webgpu.wgpuCommandBufferRelease(cmd)
  webgpu.wgpuCommandEncoderRelease(enc)
52
+
53
def read_buffer(dev, buf):
  """
  Read back the full contents of WebGPU buffer `buf`.

  The data is copied into a temporary CopyDst|MapRead buffer, which is mapped, copied to host memory and
  then destroyed.

  Returns:
    A byte-format memoryview over a private bytearray copy of the buffer contents.
  """
  nbytes = webgpu.wgpuBufferGetSize(buf)
  staging_desc = webgpu.WGPUBufferDescriptor(size=nbytes,
    usage=webgpu.WGPUBufferUsage_CopyDst | webgpu.WGPUBufferUsage_MapRead, mappedAtCreation=False)
  staging = webgpu.wgpuDeviceCreateBuffer(dev, staging_desc)
  copy_buffer_to_buffer(dev, buf, 0, staging, 0, nbytes)
  # res_idx=None and msg_idx=0: nothing is returned and no message argument is decoded for this callback.
  _run(webgpu.wgpuBufferMapAsync2, webgpu.WGPUBufferMapCallbackInfo2, webgpu.WGPUBufferMapCallback2, webgpu.WGPUBufferMapAsyncStatus__enumvalues,
    None, 0, staging, webgpu.WGPUMapMode_Read, 0, nbytes)
  mapped = ctypes.cast(webgpu.wgpuBufferGetConstMappedRange(staging, 0, nbytes), ctypes.c_void_p)
  # Copy out of the mapped range before unmapping, so the result does not refer to the mapping.
  host_copy = bytearray((ctypes.c_uint8 * nbytes).from_address(mapped.value))
  webgpu.wgpuBufferUnmap(staging)
  webgpu.wgpuBufferDestroy(staging)
  return memoryview(host_copy).cast("B")
65
+
66
def pop_error(device):
  """
  Pop the device's current error scope and return its message.

  Callback argument 2 is decoded as the message and returned; an empty message is falsy, which callers
  in this file treat as "no error".
  """
  message_idx = 2
  return _run(webgpu.wgpuDevicePopErrorScopeF, webgpu.WGPUPopErrorScopeCallbackInfo, webgpu.WGPUPopErrorScopeCallback,
    None, message_idx, message_idx, device)
68
+
69
def create_uniform(wgpu_device, val):
  """
  Create a 4-byte uniform buffer on `wgpu_device` initialised with `val`.

  Ints are stored as 4 little-endian bytes: two's complement when negative, unsigned otherwise, so
  negative values no longer raise OverflowError while values up to 2**32-1 keep working.
  Any other value is packed as a little-endian 32-bit float.

  Returns:
    The newly created buffer (usage Uniform | CopyDst).
  """
  buf = webgpu.wgpuDeviceCreateBuffer(wgpu_device,
    webgpu.WGPUBufferDescriptor(size=4, usage=webgpu.WGPUBufferUsage_Uniform | webgpu.WGPUBufferUsage_CopyDst))
  if isinstance(val, int): data = val.to_bytes(4, "little", signed=val < 0)
  else: data = struct.pack('<f', val)
  write_buffer(wgpu_device, buf, 0, data)
  return buf
74
+
75
class WebGPUProgram:
  """
  A WGSL compute shader compiled into a shader module, callable to dispatch it.

  `dev` is a `(device, timestamp_supported)` pair; `lib` holds the WGSL source as bytes and `name`
  is the shader entry point.
  """
  def __init__(self, dev, name:str, lib:bytes):
    (self.dev, self.timestamp_supported) = dev

    # Creating shader module: the WGSL source descriptor is chained onto the generic module descriptor.
    shader = webgpu.WGPUShaderModuleWGSLDescriptor(code=to_wgpu_str(lib.decode()),
      chain=webgpu.WGPUChainedStruct(sType=webgpu.WGPUSType_ShaderSourceWGSL))
    module = webgpu.WGPUShaderModuleDescriptor()
    module.nextInChain = ctypes.cast(ctypes.pointer(shader), ctypes.POINTER(webgpu.struct_WGPUChainedStruct))

    # Check compiler error: validation errors raised while creating the module are collected in an error scope.
    webgpu.wgpuDevicePushErrorScope(self.dev, webgpu.WGPUErrorFilter_Validation)
    shader_module = webgpu.wgpuDeviceCreateShaderModule(self.dev, module)

    if err := pop_error(self.dev): raise RuntimeError(f"Shader compilation failed: {err}")

    self.name, self.lib, self.prg = name, lib, shader_module

  def __call__(self, *bufs, global_size=(1,1,1), local_size=(1,1,1), vals=(), wait=False):
    """
    Dispatch the shader over `global_size` workgroups.

    Args:
      *bufs: WebGPU buffers bound as storage buffers at bindings 1..len(bufs); bufs[0] is treated as the output.
      global_size: workgroup counts passed to the dispatch call.
      local_size: accepted but not used by this method.
      vals: scalar values, each bound as a 4-byte uniform after the buffers.
      wait: when true and timestamp queries are supported, measure the pass with timestamp queries.

    Returns:
      The difference of the two pass timestamps divided by 1e9 when timing was done, otherwise None.

    Raises:
      RuntimeError: if creating the bind group layout, pipeline layout, bind group or pipeline fails.
    """
    wait = wait and self.timestamp_supported
    tmp_bufs = [*bufs]

    # WebGPU does not allow using the same buffer for input and output.
    # Create at most ONE temporary output buffer, however many inputs alias bufs[0]; allocating one per
    # aliasing input would orphan every temporary except the last.
    buf_patch = any(other == bufs[0] for other in bufs[1:])
    if buf_patch:
      tmp_bufs[0] = webgpu.wgpuDeviceCreateBuffer(self.dev,
        webgpu.WGPUBufferDescriptor(size=webgpu.wgpuBufferGetSize(bufs[0]), usage=webgpu.wgpuBufferGetUsage(bufs[0])))

    n_bufs = len(tmp_bufs)

    # Creating bind group layout: binding 0 is a uniform, then storage buffers, then one uniform per value.
    binding_layouts = [webgpu.WGPUBindGroupLayoutEntry(binding=0, visibility= webgpu.WGPUShaderStage_Compute,
      buffer=webgpu.WGPUBufferBindingLayout(type=webgpu.WGPUBufferBindingType_Uniform))]
    binding_layouts += [webgpu.WGPUBindGroupLayoutEntry(binding=i+1, visibility=webgpu.WGPUShaderStage_Compute,
      buffer=webgpu.WGPUBufferBindingLayout(type=webgpu.WGPUBufferBindingType_Uniform if i >= n_bufs
      else webgpu.WGPUBufferBindingType_Storage)) for i in range(n_bufs+len(vals))]

    bl_arr_type = webgpu.WGPUBindGroupLayoutEntry * len(binding_layouts)
    webgpu.wgpuDevicePushErrorScope(self.dev, webgpu.WGPUErrorFilter_Validation)
    bind_group_layouts = [webgpu.wgpuDeviceCreateBindGroupLayout(self.dev, webgpu.WGPUBindGroupLayoutDescriptor(
      entryCount=len(binding_layouts), entries=ctypes.cast(bl_arr_type(*binding_layouts), ctypes.POINTER(webgpu.WGPUBindGroupLayoutEntry))))]

    if bg_layout_err := pop_error(self.dev): raise RuntimeError(f"Error creating bind group layout: {bg_layout_err}")

    # Creating pipeline layout
    pipeline_layout_desc = webgpu.WGPUPipelineLayoutDescriptor(bindGroupLayoutCount=len(bind_group_layouts),
      bindGroupLayouts = (webgpu.WGPUBindGroupLayout * len(bind_group_layouts))(*bind_group_layouts))

    webgpu.wgpuDevicePushErrorScope(self.dev, webgpu.WGPUErrorFilter_Validation)
    pipeline_layout = webgpu.wgpuDeviceCreatePipelineLayout(self.dev, pipeline_layout_desc)

    if pipe_err := pop_error(self.dev): raise RuntimeError(f"Error creating pipeline layout: {pipe_err}")

    # Creating bind group
    # NOTE(review): binding 0 carries float('inf'); presumably the WGSL renderer reads infinity from it -- confirm.
    bindings = [webgpu.WGPUBindGroupEntry(binding=0, buffer=create_uniform(self.dev, float('inf')), offset=0, size=4)]
    bindings += [webgpu.WGPUBindGroupEntry(binding=i+1, buffer=create_uniform(self.dev, x) if i >= n_bufs else x, offset=0,
      size=4 if i >= n_bufs else webgpu.wgpuBufferGetSize(x)) for i,x in enumerate(tuple(tmp_bufs)+vals)]

    bg_arr_type = webgpu.WGPUBindGroupEntry * len(bindings)
    bind_group_desc = webgpu.WGPUBindGroupDescriptor(layout=bind_group_layouts[0], entryCount=len(bindings), entries=bg_arr_type(*bindings))
    webgpu.wgpuDevicePushErrorScope(self.dev, webgpu.WGPUErrorFilter_Validation)
    bind_group = webgpu.wgpuDeviceCreateBindGroup(self.dev, bind_group_desc)

    if bind_err := pop_error(self.dev): raise RuntimeError(f"Error creating bind group: {bind_err}")

    # Creating compute pipeline
    compute_desc = webgpu.WGPUComputePipelineDescriptor(layout=pipeline_layout,
      compute=webgpu.WGPUComputeState(module=self.prg, entryPoint=to_wgpu_str(self.name)))
    pipeline_result = _run(webgpu.wgpuDeviceCreateComputePipelineAsync2, webgpu.WGPUCreateComputePipelineAsyncCallbackInfo2,
      webgpu.WGPUCreateComputePipelineAsyncCallback2, webgpu.WGPUCreatePipelineAsyncStatus__enumvalues, 1, None, self.dev, compute_desc)

    command_encoder = webgpu.wgpuDeviceCreateCommandEncoder(self.dev, webgpu.WGPUCommandEncoderDescriptor())
    comp_pass_desc = webgpu.WGPUComputePassDescriptor(nextInChain=None)

    if wait:
      # Two timestamp queries (written at the beginning and end of the pass), resolved into a 16-byte buffer.
      query_set = webgpu.wgpuDeviceCreateQuerySet(self.dev, webgpu.WGPUQuerySetDescriptor(type=webgpu.WGPUQueryType_Timestamp, count=2))
      query_buf = webgpu.wgpuDeviceCreateBuffer(self.dev,
        webgpu.WGPUBufferDescriptor(size=16, usage=webgpu.WGPUBufferUsage_QueryResolve | webgpu.WGPUBufferUsage_CopySrc))
      comp_pass_desc.timestampWrites = ctypes.pointer(webgpu.WGPUComputePassTimestampWrites(
        querySet=query_set, beginningOfPassWriteIndex=0, endOfPassWriteIndex=1))

    # Begin compute pass
    compute_pass = webgpu.wgpuCommandEncoderBeginComputePass(command_encoder, comp_pass_desc)
    webgpu.wgpuComputePassEncoderSetPipeline(compute_pass, pipeline_result)
    webgpu.wgpuComputePassEncoderSetBindGroup(compute_pass, 0, bind_group, 0, None)
    webgpu.wgpuComputePassEncoderDispatchWorkgroups(compute_pass, *global_size)
    webgpu.wgpuComputePassEncoderEnd(compute_pass)

    if wait: webgpu.wgpuCommandEncoderResolveQuerySet(command_encoder, query_set, 0, 2, query_buf, 0)

    cmd_buf = webgpu.wgpuCommandEncoderFinish(command_encoder, webgpu.WGPUCommandBufferDescriptor())
    webgpu.wgpuQueueSubmit(webgpu.wgpuDeviceGetQueue(self.dev), 1, (webgpu.WGPUCommandBuffer*1)(cmd_buf))

    if buf_patch:
      # Copy the result from the temporary output back into the caller's (aliased) buffer.
      copy_buffer_to_buffer(self.dev, tmp_bufs[0], 0, bufs[0], 0, webgpu.wgpuBufferGetSize(bufs[0]))
      webgpu.wgpuBufferDestroy(tmp_bufs[0])

    if wait:
      # The resolve buffer holds two unsigned 64-bit values: [begin, end].
      time = ((timestamps:=read_buffer(self.dev, query_buf).cast("Q").tolist())[1] - timestamps[0]) / 1e9
      webgpu.wgpuBufferDestroy(query_buf)
      webgpu.wgpuQuerySetDestroy(query_set)
      return time
176
+
177
class WebGpuAllocator(Allocator):
  """Allocator that backs buffers with WebGPU storage buffers on a single device."""
  def __init__(self, dev): self.dev = dev

  def _alloc(self, size: int, options):
    # WebGPU buffers have to be 4-byte aligned, so the size is rounded up to a multiple of 4.
    usage = webgpu.WGPUBufferUsage_Storage | webgpu.WGPUBufferUsage_CopyDst | webgpu.WGPUBufferUsage_CopySrc
    desc = webgpu.WGPUBufferDescriptor(size=round_up(size, 4), usage=usage)
    return webgpu.wgpuDeviceCreateBuffer(self.dev, desc)

  def _copyin(self, dest, src: memoryview):
    # Zero-pad the payload to a multiple of 4 bytes when needed; otherwise upload `src` as is.
    payload = src
    if src.nbytes % 4:
      payload = bytearray(round_up(src.nbytes, 4))
      payload[:src.nbytes] = src
    write_buffer(self.dev, dest, 0, payload)

  def _copyout(self, dest: memoryview, src):
    data = read_buffer(self.dev, src)
    # The device buffer may be larger than `dest` (sizes are rounded up in _alloc); drop the excess bytes.
    if webgpu.wgpuBufferGetSize(src) > dest.nbytes: data = data[:dest.nbytes]
    dest[:] = data

  def _free(self, opaque, options):
    webgpu.wgpuBufferDestroy(opaque)
193
+
194
class WebGpuDevice(Compiled):
  """Compiled device that requests a WebGPU adapter and device and wires up the allocator, renderer and program type."""
  def __init__(self, device:str):
    # Requesting an adapter. The backend comes from WEBGPU_BACKEND; unknown or unset names map to 0.
    adapter_opts = webgpu.WGPURequestAdapterOptions(powerPreference=webgpu.WGPUPowerPreference_HighPerformance,
      backendType=backend_types.get(os.getenv("WEBGPU_BACKEND", ""), 0))
    adapter_res = _run(webgpu.wgpuInstanceRequestAdapterF, webgpu.WGPURequestAdapterCallbackInfo, webgpu.WGPURequestAdapterCallback,
      webgpu.WGPURequestAdapterStatus__enumvalues, 1, 2, instance, adapter_opts)

    # Get supported features and request only the wanted ones that the adapter reports.
    supported_features = webgpu.WGPUSupportedFeatures()
    webgpu.wgpuAdapterGetFeatures(adapter_res, supported_features)
    supported = [supported_features.features[idx] for idx in range(supported_features.featureCount)]
    wanted = [webgpu.WGPUFeatureName_TimestampQuery, webgpu.WGPUFeatureName_ShaderF16]
    features = [feat for feat in wanted if feat in supported]
    feature_array = (webgpu.WGPUFeatureName * len(features))(*features)
    dev_desc = webgpu.WGPUDeviceDescriptor(requiredFeatureCount=len(features),requiredFeatures=feature_array)

    # Limits: require exactly the limits the adapter reports as supported.
    supported_limits = webgpu.WGPUSupportedLimits()
    webgpu.wgpuAdapterGetLimits(adapter_res, ctypes.cast(ctypes.pointer(supported_limits),ctypes.POINTER(webgpu.struct_WGPUSupportedLimits)))
    limits = webgpu.WGPURequiredLimits(limits=supported_limits.limits)
    dev_desc.requiredLimits = ctypes.cast(ctypes.pointer(limits),ctypes.POINTER(webgpu.struct_WGPURequiredLimits))

    # Requesting a device
    device_res = _run(webgpu.wgpuAdapterRequestDeviceF, webgpu.WGPURequestDeviceCallbackInfo, webgpu.WGPURequestDeviceCallback,
      webgpu.WGPURequestDeviceStatus__enumvalues, 1, 2, adapter_res, dev_desc)

    # Programs receive (device, timestamp_supported) as their first argument.
    timestamp_supported = webgpu.WGPUFeatureName_TimestampQuery in supported
    runtime = functools.partial(WebGPUProgram, (device_res, timestamp_supported))
    super().__init__(device, WebGpuAllocator(device_res), WGSLRenderer(), Compiler(), runtime)

  def synchronize(self):
    """Block until the work submitted to the device queue is reported done."""
    # The WebGPU device handle is the first element of the tuple bound into the runtime partial.
    queue = webgpu.wgpuDeviceGetQueue(self.runtime.args[0][0])
    _run(webgpu.wgpuQueueOnSubmittedWorkDone2, webgpu.WGPUQueueWorkDoneCallbackInfo2, webgpu.WGPUQueueWorkDoneCallback2,
      webgpu.WGPUQueueWorkDoneStatus__enumvalues, None, None, queue)
@@ -0,0 +1,94 @@
1
+ import collections
2
+ from tinygrad.helpers import round_up
3
+
4
+ class TLSFAllocator:
5
+ """
6
+ The allocator is based on the Two-Level Segregated Fit (TLSF) algorithm. The allocator maintains 2 level of buckets:
7
+ * 1st level is determined by the most significant bit of the size.
8
+ * 2nd level splits the covered memory of 1st level into @lv2_cnt entries.
9
+
10
+ For each allocation request, the allocator searches for the smallest block that can fit the requested size.
11
+ For each deallocation request, the allocator merges the block with its neighbors if they are free.
12
+ """
13
+
14
+ def __init__(self, size:int, base:int=0, block_size:int=16, lv2_cnt:int=16):
15
+ self.size, self.base, self.block_size, self.l2_cnt = size, base, block_size, lv2_cnt.bit_length()
16
+ self.storage:list = [collections.defaultdict(list) for _ in range(size.bit_length() + 1)]
17
+ self.lv1_entries:list[int] = [0] * len(self.storage)
18
+
19
+ # self.blocks is more like a linked list, where each entry is a contigous block.
20
+ self.blocks:dict[int, tuple[int, int|None, int|None, bool]] = {0: (size, None, None, True)} # size, next, prev, is_free
21
+ self._insert_block(0, size)
22
+
23
+ def lv1(self, size): return size.bit_length()
24
+ def lv2(self, size): return (size - (1 << (size.bit_length() - 1))) // (1 << max(0, size.bit_length() - self.l2_cnt))
25
+
26
+ def _insert_block(self, start:int, size:int, prev:int|None=None):
27
+ if prev is None: prev = self.blocks[start][2]
28
+ self.storage[self.lv1(size)][self.lv2(size)].append(start)
29
+ self.lv1_entries[self.lv1(size)] += 1
30
+ self.blocks[start] = (size, start + size, prev, True)
31
+ return self
32
+
33
+ def _remove_block(self, start:int, size:int, prev:int|None=None):
34
+ if prev is None: prev = self.blocks[start][2]
35
+ self.storage[self.lv1(size)][self.lv2(size)].remove(start)
36
+ self.lv1_entries[self.lv1(size)] -= 1
37
+ self.blocks[start] = (size, start + size, prev, False)
38
+ return self
39
+
40
+ def _split_block(self, start:int, size:int, new_size:int):
41
+ nxt = self.blocks[start][1]
42
+ assert self.blocks[start][3], "block must be free"
43
+ self._remove_block(start, size)._insert_block(start, new_size)._insert_block(start + new_size, size - new_size, prev=start)
44
+ if nxt in self.blocks: self.blocks[nxt] = (self.blocks[nxt][0], self.blocks[nxt][1], start + new_size, self.blocks[nxt][3])
45
+ return self
46
+
47
+ def _merge_right(self, start:int):
48
+ size, nxt, _, is_free = self.blocks[start]
49
+ assert is_free, "block must be free"
50
+
51
+ while is_free and nxt in self.blocks:
52
+ if (blk:=self.blocks[nxt])[3] is False: break
53
+ self._remove_block(start, size)._remove_block(nxt, blk[0])._insert_block(start, size:=size + blk[0])
54
+ assert self.blocks[start][1] == blk[1]
55
+ _, nxt, _, _ = self.blocks.pop(nxt)
56
+
57
+ if nxt in self.blocks: self.blocks[nxt] = (self.blocks[nxt][0], self.blocks[nxt][1], start, self.blocks[nxt][3])
58
+
59
+ def _merge_block(self, start:int):
60
+ # Go left while blocks are free. Then merge all them right.
61
+ while (x:=self.blocks[start][2]) is not None and self.blocks[x][3] is True: start = x
62
+ self._merge_right(start)
63
+
64
+ def alloc(self, req_size:int, align:int=1) -> int:
65
+ req_size = max(self.block_size, req_size) # at least block size.
66
+ size = max(self.block_size, req_size + align - 1)
67
+
68
+ # Round up the allocation size to the next bucket, so any entry there can fit the requested size.
69
+ size = round_up(size, (1 << size.bit_length() - self.l2_cnt))
70
+
71
+ # Search for the smallest block that can fit the requested size. Start with the it's bucket and go up until any block is found.
72
+ for l1 in range(self.lv1(size), len(self.storage)):
73
+ if self.lv1_entries[l1] == 0: continue
74
+ for l2 in range(self.lv2(size) if l1 == size.bit_length() else 0, (1 << self.l2_cnt)):
75
+ if len(self.storage[l1][l2]) > 0:
76
+ nsize = self.blocks[self.storage[l1][l2][0]][0]
77
+ assert nsize >= size, "block must be larger"
78
+
79
+ # Block start address.
80
+ start = self.storage[l1][l2][0]
81
+
82
+ # If request contains alignment, split the block into two parts.
83
+ if (new_start:=round_up(start, align)) != start:
84
+ self._split_block(start, nsize, new_start - start)
85
+ start, nsize = new_start, self.blocks[new_start][0]
86
+
87
+ # If the block is larger than the requested size, split it into two parts.
88
+ if nsize > req_size: self._split_block(start, nsize, req_size)
89
+ self._remove_block(start, req_size) # Mark the block as allocated.
90
+ return start + self.base
91
+ raise MemoryError(f"Can't allocate {req_size} bytes")
92
+
93
+ def free(self, start:int):
94
+ self._insert_block(start - self.base, self.blocks[start - self.base][0])._merge_block(start - self.base)
File without changes