tinygrad 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff shows the contents of publicly available package versions as published to their respective public registries. It is provided for informational purposes only.
Files changed (74)
  1. tinygrad/__init__.py +11 -6
  2. tinygrad/codegen/kernel.py +308 -175
  3. tinygrad/codegen/linearize.py +95 -0
  4. tinygrad/codegen/lowerer.py +143 -0
  5. tinygrad/codegen/transcendental.py +257 -0
  6. tinygrad/codegen/uopgraph.py +506 -0
  7. tinygrad/device.py +72 -171
  8. tinygrad/dtype.py +122 -47
  9. tinygrad/engine/jit.py +184 -87
  10. tinygrad/{lazy.py → engine/lazy.py} +74 -66
  11. tinygrad/engine/memory.py +51 -0
  12. tinygrad/engine/realize.py +86 -61
  13. tinygrad/engine/schedule.py +366 -317
  14. tinygrad/engine/search.py +58 -47
  15. tinygrad/function.py +59 -58
  16. tinygrad/helpers.py +120 -102
  17. tinygrad/multi.py +82 -78
  18. tinygrad/nn/__init__.py +116 -67
  19. tinygrad/nn/datasets.py +12 -5
  20. tinygrad/nn/optim.py +1 -1
  21. tinygrad/nn/state.py +91 -6
  22. tinygrad/ops.py +1126 -143
  23. tinygrad/renderer/__init__.py +47 -23
  24. tinygrad/renderer/cstyle.py +338 -265
  25. tinygrad/renderer/llvmir.py +125 -143
  26. tinygrad/renderer/ptx.py +225 -0
  27. tinygrad/runtime/autogen/adreno.py +17904 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
  29. tinygrad/runtime/autogen/cuda.py +6 -162
  30. tinygrad/runtime/autogen/io_uring.py +97 -63
  31. tinygrad/runtime/autogen/kfd.py +60 -47
  32. tinygrad/runtime/autogen/kgsl.py +1386 -0
  33. tinygrad/runtime/autogen/libc.py +5462 -0
  34. tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
  35. tinygrad/runtime/autogen/nvrtc.py +579 -0
  36. tinygrad/runtime/autogen/opencl.py +11 -11
  37. tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
  38. tinygrad/runtime/graph/clang.py +3 -3
  39. tinygrad/runtime/graph/cuda.py +11 -15
  40. tinygrad/runtime/graph/hcq.py +120 -107
  41. tinygrad/runtime/graph/metal.py +71 -43
  42. tinygrad/runtime/ops_amd.py +244 -323
  43. tinygrad/runtime/ops_clang.py +12 -5
  44. tinygrad/runtime/ops_cloud.py +220 -0
  45. tinygrad/runtime/ops_cuda.py +42 -99
  46. tinygrad/runtime/ops_disk.py +25 -26
  47. tinygrad/runtime/ops_dsp.py +181 -0
  48. tinygrad/runtime/ops_gpu.py +29 -16
  49. tinygrad/runtime/ops_hip.py +68 -0
  50. tinygrad/runtime/ops_llvm.py +15 -10
  51. tinygrad/runtime/ops_metal.py +147 -64
  52. tinygrad/runtime/ops_nv.py +356 -397
  53. tinygrad/runtime/ops_python.py +78 -79
  54. tinygrad/runtime/ops_qcom.py +405 -0
  55. tinygrad/runtime/support/__init__.py +0 -0
  56. tinygrad/runtime/support/compiler_cuda.py +77 -0
  57. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
  58. tinygrad/runtime/support/elf.py +38 -0
  59. tinygrad/runtime/support/hcq.py +539 -0
  60. tinygrad/shape/shapetracker.py +40 -50
  61. tinygrad/shape/view.py +102 -63
  62. tinygrad/tensor.py +1109 -365
  63. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
  64. tinygrad-0.10.0.dist-info/RECORD +77 -0
  65. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
  66. tinygrad/codegen/linearizer.py +0 -528
  67. tinygrad/codegen/uops.py +0 -451
  68. tinygrad/engine/graph.py +0 -100
  69. tinygrad/renderer/assembly.py +0 -269
  70. tinygrad/shape/symbolic.py +0 -327
  71. tinygrad-0.9.1.dist-info/RECORD +0 -63
  72. /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
  73. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
  74. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
tinygrad/runtime/ops_qcom.py (new file)
@@ -0,0 +1,405 @@
+ from __future__ import annotations
+ import os, ctypes, functools, mmap, struct, array, decimal, math, sys
+ assert sys.platform != 'win32'
+ from types import SimpleNamespace
+ from typing import Tuple, List, Any, cast
+ from tinygrad.device import BufferOptions
+ from tinygrad.runtime.support.hcq import HCQBuffer, HWComputeQueue, HCQProgram, HCQCompiled, HCQSignal, HCQAllocator, HCQArgsState
+ from tinygrad.runtime.autogen import kgsl, adreno, libc
+ from tinygrad.runtime.ops_gpu import CLCompiler, CLDevice
+ from tinygrad.renderer.cstyle import QCOMRenderer
+ from tinygrad.helpers import getenv, from_mv, mv_address, to_mv, round_up, data64_le, prod, fromimport
+ if getenv("IOCTL"): import extra.qcom_gpu_driver.opencl_ioctl # noqa: F401 # pylint: disable=unused-import
+
+ BUFTYPE_BUF, BUFTYPE_TEX, BUFTYPE_IBO = 0, 1, 2
+
+ #Parse C-style defines: <regname>_<field_x>__SHIFT and <regname>_<field_y>__MASK from the adreno module into the following format:
+ # qreg.<regname>(<field_x>=..., <field_y>=..., ..., <field_n>=...)
+ def _qreg_exec(reg, __val=0, **kwargs):
+   for k, v in kwargs.items():
+     __val |= (getattr(adreno, f'{reg[4:]}_{k.upper()}') if v else 0) if type(v) is bool else (v << getattr(adreno, f'{reg[4:]}_{k.upper()}__SHIFT'))
+   return __val
+ qreg: Any = type("QREG", (object,), {name[4:].lower(): functools.partial(_qreg_exec, name) for name in adreno.__dict__.keys() if name[:4] == 'REG_'})
+
+ def next_power2(x): return 1 if x == 0 else 1 << (x - 1).bit_length()
+
+ def parity(val: int):
+   for i in range(4,1,-1): val ^= val >> (1 << i)
+   return (~0x6996 >> (val & 0xf)) & 1
+
+ def pkt7_hdr(opcode: int, cnt: int): return adreno.CP_TYPE7_PKT | cnt & 0x3FFF | parity(cnt) << 15 | (opcode & 0x7F) << 16 | parity(opcode) << 23
+
+ def pkt4_hdr(reg: int, cnt: int): return adreno.CP_TYPE4_PKT | cnt & 0x7F | parity(cnt) << 7 | (reg & 0x3FFFF) << 8 | parity(reg) << 27
+
+ class QCOMCompiler(CLCompiler):
+   def __init__(self, device:str=""): super().__init__(CLDevice(device), 'compile_qcom')
+   def disassemble(self, lib:bytes): fromimport('extra.disassemblers.adreno', 'disasm')(lib)
+
+ class QCOMSignal(HCQSignal):
+   def __init__(self, value=0, is_timeline=False):
+     self._signal = QCOMDevice.signals_pool.pop()
+     super().__init__(value)
+   def __del__(self): QCOMDevice.signals_pool.append(self._signal)
+   def _get_value(self) -> int: return self._signal[0]
+   def _get_timestamp(self) -> decimal.Decimal: return decimal.Decimal(self._signal[1]) / decimal.Decimal(19.2) # based on the 19.2MHz always-on timer
+   def _set_value(self, new_value:int): self._signal[0] = new_value
+
+ class QCOMComputeQueue(HWComputeQueue):
+   def __init__(self):
+     self.cmd_idx_to_dims = {}
+     super().__init__()
+
+   def __del__(self):
+     if self.binded_device is not None: self.binded_device.allocator.free(self.hw_page, self.hw_page.size, BufferOptions(cpu_access=True, nolru=True))
+
+   def cmd(self, opcode: int, *vals: int): self.q += [pkt7_hdr(opcode, len(vals)), *vals]
+
+   def reg(self, reg: int, *vals: int): self.q += [pkt4_hdr(reg, len(vals)), *vals]
+
+   def _cache_flush(self, write_back=True, invalidate=False, sync=True, memsync=False):
+     # TODO: 7xx support.
+     if write_back: self.cmd(adreno.CP_EVENT_WRITE, adreno.CACHE_FLUSH_TS, *data64_le(QCOMDevice.dummy_addr), 0) # dirty cache write-back.
+     if invalidate: self.cmd(adreno.CP_EVENT_WRITE, adreno.CACHE_INVALIDATE) # invalidate cache lines (following reads from RAM).
+     if memsync: self.cmd(adreno.CP_WAIT_MEM_WRITES)
+     if sync: self.cmd(adreno.CP_WAIT_FOR_IDLE)
+
+   def _memory_barrier(self): self._cache_flush(write_back=True, invalidate=True, sync=True, memsync=True)
+
+   def _signal(self, signal, value=0, ts=False):
+     self.cmd(adreno.CP_WAIT_FOR_IDLE)
+     if QCOMDevice.gpu_id < 700:
+       self.cmd(adreno.CP_EVENT_WRITE, qreg.cp_event_write_0(event=adreno.CACHE_FLUSH_TS, timestamp=ts),
+                *data64_le(mv_address(signal._signal) + (0 if not ts else 8)), qreg.cp_event_write_3(value & 0xFFFFFFFF))
+       self._cache_flush(write_back=True, invalidate=False, sync=False, memsync=False)
+     else:
+       # TODO: support devices starting with 8 Gen 1. Also, 700th series have convenient CP_GLOBAL_TIMESTAMP and CP_LOCAL_TIMESTAMP
+       raise RuntimeError('CP_EVENT_WRITE7 is not supported')
+
+   def _timestamp(self, signal): return self._signal(signal, 0, ts=True)
+
+   def _wait(self, signal, value=0):
+     self.cmd(adreno.CP_WAIT_REG_MEM, qreg.cp_wait_reg_mem_0(function=adreno.WRITE_GE, poll=adreno.POLL_MEMORY),*data64_le(mv_address(signal._signal)),
+              qreg.cp_wait_reg_mem_3(ref=value&0xFFFFFFFF), qreg.cp_wait_reg_mem_4(mask=0xFFFFFFFF), qreg.cp_wait_reg_mem_5(delay_loop_cycles=32))
+
+   def _update_signal(self, cmd_idx, signal, value):
+     if signal is not None: self._patch(cmd_idx, offset=3, data=data64_le(mv_address(signal._signal)))
+     if value is not None: self._patch(cmd_idx, offset=5, data=[value & 0xFFFFFFFF])
+
+   def _update_wait(self, cmd_idx, signal, value):
+     if signal is not None: self._patch(cmd_idx, offset=2, data=data64_le(mv_address(signal._signal)))
+     if value is not None: self._patch(cmd_idx, offset=4, data=[value & 0xFFFFFFFF])
+
+   def _build_gpu_command(self, device, hw_addr=None):
+     to_mv((hw_page_addr:=hw_addr or device._alloc_cmd_buf(len(self.q) * 4)), len(self.q) * 4).cast('I')[:] = array.array('I', self.q)
+     obj = kgsl.struct_kgsl_command_object(gpuaddr=hw_page_addr, size=len(self.q) * 4, flags=kgsl.KGSL_CMDLIST_IB)
+     submit_req = kgsl.struct_kgsl_gpu_command(cmdlist=ctypes.addressof(obj), numcmds=1, context_id=device.ctx,
+                                               cmdsize=ctypes.sizeof(kgsl.struct_kgsl_command_object))
+     return submit_req, obj
+
+   def bind(self, device):
+     self.binded_device = device
+     self.hw_page = device.allocator.alloc(len(self.q) * 4, BufferOptions(cpu_access=True, nolru=True))
+     self.submit_req, self.obj = self._build_gpu_command(self.binded_device, self.hw_page.va_addr)
+     # From now on, the queue is on the device for faster submission.
+     self.q = to_mv(self.obj.gpuaddr, len(self.q) * 4).cast("I") # type: ignore
+
+   def _submit(self, device):
+     if self.binded_device == device: submit_req = self.submit_req
+     else: submit_req, _ = self._build_gpu_command(device)
+     device.last_cmd = kgsl.IOCTL_KGSL_GPU_COMMAND(device.fd, __payload=submit_req).timestamp
+
+   def _exec(self, prg, args_state, global_size, local_size):
+     global_size_mp = [int(g*l) for g,l in zip(global_size, local_size)]
+     self.cmd_idx_to_dims[self._cur_cmd_idx()] = [global_size, local_size]
+
+     self.cmd(adreno.CP_SET_MARKER, qreg.a6xx_cp_set_marker_0(mode=adreno.RM6_COMPUTE))
+     self.reg(adreno.REG_A6XX_HLSQ_INVALIDATE_CMD, qreg.a6xx_hlsq_invalidate_cmd(cs_state=True, cs_ibo=True))
+     self.reg(adreno.REG_A6XX_HLSQ_INVALIDATE_CMD, 0x0)
+     self.reg(adreno.REG_A6XX_SP_CS_TEX_COUNT, qreg.a6xx_sp_cs_tex_count(0x80))
+     self.reg(adreno.REG_A6XX_SP_CS_IBO_COUNT, qreg.a6xx_sp_cs_ibo_count(0x40))
+     self.reg(adreno.REG_A6XX_SP_MODE_CONTROL, qreg.a6xx_sp_mode_control(isammode=adreno.ISAMMODE_CL))
+     self.reg(adreno.REG_A6XX_SP_PERFCTR_ENABLE, qreg.a6xx_sp_perfctr_enable(cs=True))
+     self.reg(adreno.REG_A6XX_SP_TP_MODE_CNTL, qreg.a6xx_sp_tp_mode_cntl(isammode=adreno.ISAMMODE_CL, unk3=2))
+     self.reg(adreno.REG_A6XX_TPL1_DBG_ECO_CNTL, 0)
+     self.cmd(adreno.CP_WAIT_FOR_IDLE)
+
+     self.reg(adreno.REG_A6XX_HLSQ_CS_NDRANGE_0,
+              qreg.a6xx_hlsq_cs_ndrange_0(kerneldim=3, localsizex=local_size[0] - 1, localsizey=local_size[1] - 1, localsizez=local_size[2] - 1),
+              global_size_mp[0], 0, global_size_mp[1], 0, global_size_mp[2], 0, 0xccc0cf, 0xfc | qreg.a6xx_hlsq_cs_cntl_1(threadsize=adreno.THREAD64),
+              int(math.ceil(global_size[0])), int(math.ceil(global_size[1])), int(math.ceil(global_size[2])))
+
+     self.reg(adreno.REG_A6XX_SP_CS_CTRL_REG0,
+              qreg.a6xx_sp_cs_ctrl_reg0(threadsize=adreno.THREAD64, halfregfootprint=prg.hregs, fullregfootprint=prg.fregs, branchstack=prg.brnchstck),
+              qreg.a6xx_sp_cs_unknown_a9b1(unk6=True, shared_size=prg.shared_size), 0, prg.prg_offset, *data64_le(prg.lib_gpu.va_addr),
+              qreg.a6xx_sp_cs_pvt_mem_param(memsizeperitem=prg.pvtmem_size_per_item), *data64_le(prg.device._stack.va_addr),
+              qreg.a6xx_sp_cs_pvt_mem_size(totalpvtmemsize=prg.pvtmem_size_total))
+
+     self.cmd(adreno.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=adreno.ST_CONSTANTS, state_src=adreno.SS6_INDIRECT,
+                                                                state_block=adreno.SB6_CS_SHADER, num_unit=1024 // 4),
+              *data64_le(args_state.ptr))
+     self.cmd(adreno.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=adreno.ST_SHADER, state_src=adreno.SS6_INDIRECT,
+                                                                state_block=adreno.SB6_CS_SHADER, num_unit=round_up(prg.image_size, 128) // 128),
+              *data64_le(prg.lib_gpu.va_addr))
+
+     self.reg(adreno.REG_A6XX_HLSQ_CONTROL_2_REG, 0xfcfcfcfc, 0xfcfcfcfc, 0xfcfcfcfc, 0xfc, qreg.a6xx_hlsq_cs_cntl(constlen=1024 // 4, enabled=True))
+
+     self.reg(adreno.REG_A6XX_SP_CS_PVT_MEM_HW_STACK_OFFSET, qreg.a6xx_sp_cs_pvt_mem_hw_stack_offset(prg.hw_stack_offset))
+     self.reg(adreno.REG_A6XX_SP_CS_INSTRLEN, qreg.a6xx_sp_cs_instrlen(prg.image_size // 4))
+
+     if args_state.prg.samp_cnt > 0:
+       self.cmd(adreno.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=adreno.ST_SHADER, state_src=adreno.SS6_INDIRECT,
+                                                                  state_block=adreno.SB6_CS_TEX, num_unit=args_state.prg.samp_cnt),
+                *data64_le(args_state.ptr + args_state.prg.samp_off))
+       self.reg(adreno.REG_A6XX_SP_CS_TEX_SAMP, *data64_le(args_state.ptr + args_state.prg.samp_off))
+       self.reg(adreno.REG_A6XX_SP_PS_TP_BORDER_COLOR_BASE_ADDR, *data64_le(prg.device._border_color_base()))
+
+     if args_state.prg.tex_cnt > 0:
+       self.cmd(adreno.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=adreno.ST_CONSTANTS, state_src=adreno.SS6_INDIRECT,
+                                                                  state_block=adreno.SB6_CS_TEX, num_unit=min(16, args_state.prg.tex_cnt)),
+                *data64_le(args_state.ptr + args_state.prg.tex_off))
+       self.reg(adreno.REG_A6XX_SP_CS_TEX_CONST, *data64_le(args_state.ptr + args_state.prg.tex_off))
+
+     if args_state.prg.ibo_cnt > 0:
+       self.cmd(adreno.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=adreno.ST6_IBO, state_src=adreno.SS6_INDIRECT,
+                                                                  state_block=adreno.SB6_CS_SHADER, num_unit=args_state.prg.ibo_cnt),
+                *data64_le(args_state.ptr + args_state.prg.ibo_off))
+       self.reg(adreno.REG_A6XX_SP_CS_IBO, *data64_le(args_state.ptr + args_state.prg.ibo_off))
+
+     self.reg(adreno.REG_A6XX_SP_CS_CONFIG,
+              qreg.a6xx_sp_cs_config(enabled=True, nsamp=args_state.prg.samp_cnt, ntex=args_state.prg.tex_cnt, nibo=args_state.prg.ibo_cnt))
+     self.cmd(adreno.CP_RUN_OPENCL, 0)
+     self._cache_flush(write_back=True, invalidate=False, sync=False, memsync=False)
+
+   def _update_exec(self, cmd_idx, global_size, local_size):
+     if global_size is not None:
+       self._patch(cmd_idx, offset=29, data=[int(math.ceil(global_size[0])), int(math.ceil(global_size[1])), int(math.ceil(global_size[2]))])
+       self.cmd_idx_to_dims[cmd_idx][0] = global_size
+
+     if local_size is not None:
+       payload = qreg.a6xx_hlsq_cs_ndrange_0(kerneldim=3, localsizex=local_size[0] - 1, localsizey=local_size[1] - 1, localsizez=local_size[2] - 1)
+       self._patch(cmd_idx, offset=20, data=[payload])
+       self.cmd_idx_to_dims[cmd_idx][1] = local_size
+
+     global_size_mp = [int(g*l) for g,l in zip(self.cmd_idx_to_dims[cmd_idx][0], self.cmd_idx_to_dims[cmd_idx][1])]
+     self._patch(cmd_idx, offset=21, data=[global_size_mp[0], 0, global_size_mp[1], 0, global_size_mp[2], 0])
+
+ class QCOMArgsState(HCQArgsState):
+   def __init__(self, ptr:int, prg:QCOMProgram, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=()):
+     super().__init__(ptr, prg, bufs, vals=vals)
+
+     if len(bufs) + len(vals) != len(prg.buf_info): raise RuntimeError(f'incorrect args size given={len(bufs)+len(vals)} != want={len(prg.buf_info)}')
+
+     self.buf_info, self.args_info, self.args_view = prg.buf_info[:len(bufs)], prg.buf_info[len(bufs):], to_mv(ptr, prg.kernargs_alloc_size).cast('Q')
+
+     ctypes.memset(self.ptr, 0, prg.kernargs_alloc_size)
+     for cnst_val, cnst_off, cnst_sz in prg.consts_info: to_mv(self.ptr + cnst_off, cnst_sz)[:] = cnst_val.to_bytes(cnst_sz, byteorder='little')
+
+     if prg.samp_cnt > 0: to_mv(self.ptr + prg.samp_off, len(prg.samplers) * 4).cast('I')[:] = array.array('I', prg.samplers)
+     for i, b in enumerate(cast(List[QCOMBuffer], bufs)):
+       if prg.buf_info[i].type is BUFTYPE_TEX: to_mv(self.ptr + prg.buf_info[i].offset, len(b.desc) * 4).cast('I')[:] = array.array('I', b.desc)
+       elif prg.buf_info[i].type is BUFTYPE_IBO: to_mv(self.ptr + prg.buf_info[i].offset, len(b.ibo) * 4).cast('I')[:] = array.array('I', b.ibo)
+       else: self.update_buffer(i, b)
+     for i, v in enumerate(vals): self.update_var(i, v)
+
+   def update_buffer(self, index:int, buf:HCQBuffer):
+     if self.buf_info[index].type is not BUFTYPE_BUF: self.args_view[self.buf_info[index].offset//8 + 2] = buf.va_addr
+     else: self.args_view[self.buf_info[index].offset//8] = buf.va_addr
+
+   def update_var(self, index:int, val:int): self.args_view[self.args_info[index].offset//8] = val
+
+ class QCOMProgram(HCQProgram):
+   def __init__(self, device: QCOMDevice, name: str, lib: bytes):
+     self.device, self.name, self.lib = device, name, lib
+     self._parse_lib()
+
+     self.lib_gpu = self.device.allocator.alloc(self.image_size, options=BufferOptions(cpu_access=True, nolru=True))
+     to_mv(self.lib_gpu.va_addr, self.image_size)[:] = self.image
+
+     self.pvtmem_size_per_item = round_up(self.pvtmem, 512) >> 9
+     self.pvtmem_size_total = self.pvtmem_size_per_item * 128 * 2
+     self.hw_stack_offset = round_up(next_power2(round_up(self.pvtmem, 512)) * 128 * 16, 0x1000)
+     self.shared_size = max(1, (self.shmem - 1) // 1024)
+     self.max_threads = min(1024, ((384 * 32) // (max(1, (self.fregs + round_up(self.hregs, 2) // 2)) * 128)) * 128)
+     device._ensure_stack_size(self.hw_stack_offset * 4)
+
+     kernargs_alloc_size = round_up(2048 + (self.tex_cnt + self.ibo_cnt) * 0x40 + self.samp_cnt * 0x10, 0x100)
+     super().__init__(QCOMArgsState, self.device, self.name, kernargs_alloc_size=kernargs_alloc_size)
+
+   def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
+     if self.max_threads < prod(local_size): raise RuntimeError("Too many resources requested for launch")
+     if any(g*l>mx for g,l,mx in zip(global_size, local_size, [65536, 65536, 65536])) and any(l>mx for l,mx in zip(local_size, [1024, 1024, 1024])):
+       raise RuntimeError(f"Invalid global/local dims {global_size=}, {local_size=}")
+     return super().__call__(*bufs, global_size=global_size, local_size=local_size, vals=vals, wait=wait)
+
+   def _parse_lib(self):
+     def _read_lib(off): return struct.unpack("I", self.lib[off:off+4])[0]
+
+     # Extract image binary
+     self.image_size = _read_lib(0x100)
+     self.image = bytearray(self.lib[(image_offset:=_read_lib(0xc0)):image_offset+self.image_size])
+
+     # Parse image descriptors
+     image_desc_off = _read_lib(0x110)
+     self.prg_offset, self.brnchstck = _read_lib(image_desc_off+0xc4), _read_lib(image_desc_off+0x108) // 2
+     self.pvtmem, self.shmem = _read_lib(image_desc_off+0xc8), _read_lib(image_desc_off+0xd8)
+
+     # Fill up constants and buffers info
+     self.buf_info, self.consts_info = [], []
+
+     # Collect sampler info.
+     self.samp_cnt = samp_cnt_in_file = _read_lib(image_desc_off + 0xdc)
+     assert self.samp_cnt <= 1, "Up to one sampler supported"
+     if self.samp_cnt:
+       self.samp_cnt += 1
+       self.samplers = [qreg.a6xx_tex_samp_0(wrap_s=(clamp_mode:=adreno.A6XX_TEX_CLAMP_TO_BORDER), wrap_t=clamp_mode, wrap_r=clamp_mode),
+                        qreg.a6xx_tex_samp_1(unnorm_coords=True, cubemapseamlessfiltoff=True), 0, 0, 0, 0, 0, 0]
+
+     # Collect kernel arguments (buffers) info.
+     bdoff = round_up(image_desc_off + 0x158 + len(self.name), 4) + 8 * samp_cnt_in_file
+     while bdoff + 32 <= len(self.lib):
+       length, _, _, offset_words, _, _, _, typ = struct.unpack("IIIIIIII", self.lib[bdoff:bdoff+32])
+       if length == 0: break
+       self.buf_info.append(SimpleNamespace(offset=offset_words * 4, type=typ))
+       bdoff += length
+
+     # Setting correct offsets to textures/ibos.
+     self.tex_cnt, self.ibo_cnt = sum(x.type is BUFTYPE_TEX for x in self.buf_info), sum(x.type is BUFTYPE_IBO for x in self.buf_info)
+     self.ibo_off, self.tex_off, self.samp_off = 2048, 2048 + 0x40 * self.ibo_cnt, 2048 + 0x40 * self.tex_cnt + 0x40 * self.ibo_cnt
+     cur_ibo_off, cur_tex_off = self.ibo_off, self.tex_off
+     for x in self.buf_info:
+       if x.type is BUFTYPE_IBO: x.offset, cur_ibo_off = cur_ibo_off, cur_ibo_off + 0x40
+       elif x.type is BUFTYPE_TEX: x.offset, cur_tex_off = cur_tex_off, cur_tex_off + 0x40
+
+     if _read_lib(0xb0) != 0: # check if we have constants.
+       cdoff = _read_lib(0xac)
+       while cdoff + 40 <= image_offset:
+         cnst, offset_words, _, is32 = struct.unpack("I", self.lib[cdoff:cdoff+4])[0], *struct.unpack("III", self.lib[cdoff+16:cdoff+28])
+         self.consts_info.append((cnst, offset_words * (sz_bytes:=(2 << is32)), sz_bytes))
+         cdoff += 40
+
+     # Registers info
+     reg_desc_off = _read_lib(0x34)
+     self.fregs, self.hregs = _read_lib(reg_desc_off + 0x14), _read_lib(reg_desc_off + 0x18)
+
+   def __del__(self):
+     if hasattr(self, 'lib_gpu'): self.device.allocator.free(self.lib_gpu, self.lib_gpu.size, options=BufferOptions(cpu_access=True, nolru=True))
+
+ class QCOMBuffer(HCQBuffer):
+   def __init__(self, va_addr:int, size:int, info=None, mapped=False, desc=None, ibo=None, pitch=None, real_stride=None, **kwargs):
+     self.va_addr, self.size, self.info, self.mapped = va_addr, size, info, mapped
+
+     # Texture specific definitions
+     self.desc, self.ibo, self.pitch, self.real_stride = [0] * 16, [0] * 16, pitch, real_stride
+
+ class QCOMAllocator(HCQAllocator):
+   def _alloc(self, size:int, options:BufferOptions) -> HCQBuffer:
+     if options.image is not None:
+       imgw, imgh, itemsize_log = options.image.shape[1], options.image.shape[0], int(math.log2(options.image.itemsize))
+       pitchalign = max(6, 11 - int(math.log2(imgh))) if imgh > 1 else 6
+       align_up = max(1, (8 // itemsize_log + 1) - imgh // 32) if pitchalign == 6 else (2 ** (pitchalign - itemsize_log - 2))
+
+       granularity = 128 if options.image.itemsize == 4 else 256
+       pitch_add = (1 << pitchalign) if min(next_power2(imgw), round_up(imgw, granularity)) - align_up + 1 <= imgw and imgw > granularity//2 else 0
+       pitch = round_up((real_stride:=imgw * 4 * options.image.itemsize), 1 << pitchalign) + pitch_add
+
+       if options.external_ptr: texture = QCOMBuffer(options.external_ptr, size)
+       else: texture = self.device._gpu_alloc(pitch * imgh, kgsl.KGSL_MEMTYPE_TEXTURE)
+
+       texture.pitch, texture.real_stride = pitch, real_stride
+
+       tex_fmt = adreno.FMT6_32_32_32_32_FLOAT if options.image.itemsize == 4 else adreno.FMT6_16_16_16_16_FLOAT
+       texture.desc[0] = qreg.a6xx_tex_const_0(0x8, swiz_x=0, swiz_y=1, swiz_z=2, swiz_w=3, fmt=tex_fmt)
+       texture.desc[1] = qreg.a6xx_tex_const_1(width=imgw, height=imgh)
+       texture.desc[2] = qreg.a6xx_tex_const_2(type=adreno.A6XX_TEX_2D, pitch=texture.pitch, pitchalign=pitchalign-6)
+       texture.desc[4:8] = [*data64_le(texture.va_addr), qreg.a6xx_tex_const_6(plane_pitch=0x400000), qreg.a6xx_tex_const_7(13)]
+       texture.ibo = [texture.desc[0] & (~0xffff), *texture.desc[1:len(texture.desc)]]
+
+       return texture
+
+     return QCOMBuffer(options.external_ptr, size) if options.external_ptr else self.device._gpu_alloc(size)
+
+   def _do_copy(self, src_addr, dest_addr, src_size, real_size, src_stride, dest_stride, dest_off=0, src_off=0):
+     while src_off < src_size:
+       ctypes.memmove(dest_addr+dest_off, src_addr+src_off, real_size)
+       src_off, dest_off = src_off+src_stride, dest_off+dest_stride
+
+   def copyin(self, dest:HCQBuffer, src:memoryview):
+     if (qd:=cast(QCOMBuffer, dest)).pitch is not None: self._do_copy(mv_address(src), qd.va_addr, len(src), qd.real_stride, qd.real_stride, qd.pitch)
+     else: ctypes.memmove(dest.va_addr, mv_address(src), src.nbytes)
+
+   def copyout(self, dest:memoryview, src:HCQBuffer):
+     self.device.synchronize()
+     if (qs:=cast(QCOMBuffer, src)).pitch is not None: self._do_copy(qs.va_addr, mv_address(dest), qs.size, qs.real_stride, qs.pitch, qs.real_stride)
+     else: ctypes.memmove(from_mv(dest), src.va_addr, dest.nbytes)
+
+   def as_buffer(self, src:HCQBuffer) -> memoryview:
+     self.device.synchronize()
+     return to_mv(src.va_addr, src.size)
+
+   def _free(self, opaque, options:BufferOptions):
+     self.device.synchronize()
+     self.device._gpu_free(opaque)
+
+ class QCOMDevice(HCQCompiled):
+   signals_page: Any = None
+   signals_pool: List[Any] = []
+   gpu_id: int = 0
+   dummy_addr: int = 0
+
+   def __init__(self, device:str=""):
+     self.fd = os.open('/dev/kgsl-3d0', os.O_RDWR)
+     QCOMDevice.dummy_addr = self._gpu_alloc(0x1000).va_addr
+     QCOMDevice.signals_page = self._gpu_alloc(16 * 65536, uncached=True)
+     QCOMDevice.signals_pool = [to_mv(self.signals_page.va_addr + off, 16).cast("Q") for off in range(0, self.signals_page.size, 16)]
+     info, self.ctx, self.cmd_buf, self.cmd_buf_ptr, self.last_cmd = self._info(), self._ctx_create(), self._gpu_alloc(16 << 20), 0,0
+     QCOMDevice.gpu_id = ((info.chip_id >> 24) & 0xFF) * 100 + ((info.chip_id >> 16) & 0xFF) * 10 + ((info.chip_id >> 8) & 0xFF)
+     if QCOMDevice.gpu_id >= 700: raise RuntimeError(f"Unsupported GPU: {QCOMDevice.gpu_id}")
+
+     super().__init__(device, QCOMAllocator(self), QCOMRenderer(), QCOMCompiler(device), functools.partial(QCOMProgram, self),
+                      QCOMSignal, QCOMComputeQueue, None)
+
+   def _ctx_create(self):
+     cr = kgsl.IOCTL_KGSL_DRAWCTXT_CREATE(self.fd, flags=(kgsl.KGSL_CONTEXT_PREAMBLE | kgsl.KGSL_CONTEXT_PWR_CONSTRAINT |
+       kgsl.KGSL_CONTEXT_NO_FAULT_TOLERANCE | kgsl.KGSL_CONTEXT_NO_GMEM_ALLOC | kgsl.KGSL_CONTEXT_PRIORITY(8) |
+       kgsl.KGSL_CONTEXT_PREEMPT_STYLE(kgsl.KGSL_CONTEXT_PREEMPT_STYLE_FINEGRAIN)))
+
+     # Set power to maximum.
+     struct.pack_into('IIQQ', pwr:=memoryview(bytearray(0x18)), 0, 1, cr.drawctxt_id, mv_address(_:=memoryview(array.array('I', [1]))), 4)
+     kgsl.IOCTL_KGSL_SETPROPERTY(self.fd, type=kgsl.KGSL_PROP_PWR_CONSTRAINT, value=mv_address(pwr), sizebytes=pwr.nbytes)
+     return cr.drawctxt_id
+
+   def _info(self):
+     info = kgsl.struct_kgsl_devinfo()
+     kgsl.IOCTL_KGSL_DEVICE_GETPROPERTY(self.fd, type=kgsl.KGSL_PROP_DEVICE_INFO, value=ctypes.addressof(info), sizebytes=ctypes.sizeof(info))
+     return info
+
+   def _gpu_alloc(self, size:int, flags:int=0, uncached=False, fill_zeroes=False):
+     flags |= kgsl.KGSL_MEMALIGN(alignment_hint:=12) | kgsl.KGSL_MEMFLAGS_USE_CPU_MAP
+     if uncached: flags |= kgsl.KGSL_CACHEMODE(kgsl.KGSL_CACHEMODE_UNCACHED)
+
+     alloc = kgsl.IOCTL_KGSL_GPUOBJ_ALLOC(self.fd, size=(bosz:=round_up(size, 1<<alignment_hint)), flags=flags, mmapsize=bosz)
+     va_addr = libc.mmap(0, bosz, mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED, self.fd, alloc.id * 0x1000)
+
+     if fill_zeroes: ctypes.memset(va_addr, 0, size)
+     return QCOMBuffer(va_addr=va_addr, size=size, info=alloc)
+
+   def _gpu_free(self, mem):
+     kgsl.IOCTL_KGSL_GPUOBJ_FREE(self.fd, id=mem.info.id)
+     libc.munmap(mem.va_addr, mem.info.mmapsize)
+
+   def _alloc_cmd_buf(self, sz: int):
+     self.cmd_buf_ptr = (cur_ptr:=self.cmd_buf_ptr if self.cmd_buf_ptr + sz < self.cmd_buf.size else 0) + sz
+     return self.cmd_buf.va_addr + cur_ptr
+
+   def _border_color_base(self):
+     if not hasattr(self, '_border_color_gpu'): self._border_color_gpu = self._gpu_alloc(0x1000, fill_zeroes=True)
+     return self._border_color_gpu.va_addr
+
+   def _ensure_stack_size(self, sz):
+     if not hasattr(self, '_stack'): self._stack = self._gpu_alloc(sz)
+     elif self._stack.size < sz:
+       self.synchronize()
+       self._gpu_free(self._stack)
+       self._stack = self._gpu_alloc(sz)
+
+   def _syncdev(self): kgsl.IOCTL_KGSL_DEVICE_WAITTIMESTAMP_CTXTID(self.fd, context_id=self.ctx, timestamp=self.last_cmd, timeout=0xffffffff)
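Note: pkt7_hdr/pkt4_hdr in the same file guard the opcode and count fields with an odd-parity bit. parity() XOR-folds the value down to one nibble and indexes the constant 0x6996 (bit n of 0x6996 is the parity of n); the ~ inverts the lookup, so the function returns 1 exactly when the popcount is even. A self-contained cross-check against a naive popcount:

  def parity(val: int):
    for i in range(4, 1, -1): val ^= val >> (1 << i)  # fold 32 bits down to one nibble
    return (~0x6996 >> (val & 0xf)) & 1               # inverted nibble-parity table

  for v in range(1 << 12):
    assert parity(v) == (bin(v).count("1") + 1) % 2   # 1 iff popcount(v) is even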
tinygrad/runtime/support/compiler_cuda.py (new file)
@@ -0,0 +1,77 @@
+ import subprocess, hashlib, tempfile, ctypes, ctypes.util, re, pathlib
+ from typing import Callable
+ from tinygrad.helpers import to_char_p_p, colored, init_c_var, getenv
+ import tinygrad.runtime.autogen.nvrtc as nvrtc
+ from tinygrad.device import Compiler, CompileError
+
+ PTX = getenv("PTX") # this shouldn't be here, in fact, it shouldn't exist
+
+ def _get_bytes(arg, get_str, get_sz, check) -> bytes:
+   sz = init_c_var(ctypes.c_size_t(), lambda x: check(get_sz(arg, ctypes.byref(x))))
+   return ctypes.string_at(init_c_var(ctypes.create_string_buffer(sz.value), lambda x: check(get_str(arg, x))), size=sz.value)
+
+ def nvrtc_check(status, ctx=None):
+   if status != 0:
+     err_log = _get_bytes(ctx, nvrtc.nvrtcGetProgramLog, nvrtc.nvrtcGetProgramLogSize, lambda _: None).decode() if ctx else ""
+     raise CompileError(f"Nvrtc Error {status}, {ctypes.string_at(nvrtc.nvrtcGetErrorString(status)).decode()}\n{err_log}")
+
+ def jitlink_check(status, ctx=None):
+   if status != 0:
+     err_log = _get_bytes(ctx, nvrtc.nvJitLinkGetErrorLog, nvrtc.nvJitLinkGetErrorLogSize, lambda _: None).decode() if ctx else ""
+     raise CompileError(f"NvJitLink Error {status}, {nvrtc.nvJitLinkResult__enumvalues.get(status, 'Unknown')}\n{err_log}")
+
+ def pretty_ptx(s):
+   # all expressions match `<valid_before><expr><valid_after>` and replace it with `<valid_before>color(<expr>)<valid_after>`
+   s = re.sub(r'([!@<\[\s,\+\-;\n])((?:[_%$][\w%\$_]+(?:\.[xyz])?\:?)|(?:buf\d+))([<>\]\s,\+\-;\n\)])', lambda m:m[1]+colored(m[2], "blue")+m[3], s, flags=re.M) # identifiers # noqa: E501
+   s = re.sub(r'(.)((?:b|s|u|f)(?:8|16|32|64)|pred)([\.\s])', lambda m:m[1]+colored(m[2], "green")+m[3], s, flags=re.M) # types
+   s = re.sub(r'^(\s*)([\w]+)(.*?;$)', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # instructions
+   s = re.sub(r'([<>\[\]\s,\+\-;])((?:0[fF][0-9a-fA-F]{8})|(?:[0-9]+)|(?:0[xX][0-9a-fA-F]+))([<>\[\]\s,\+\-;])', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # numbers # noqa: E501
+   s = re.sub(r'(\.)(param|reg|global)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # space
+   s = re.sub(r'(\.)(version|target|address_size|visible|entry)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # derivatives
+   return s
+
+ def cuda_disassemble(lib, arch):
+   try:
+     fn = (pathlib.Path(tempfile.gettempdir()) / f"tinycuda_{hashlib.md5(lib).hexdigest()}").as_posix()
+     with open(fn + ".ptx", "wb") as f: f.write(lib)
+     subprocess.run(["ptxas", f"-arch={arch}", "-o", fn, fn+".ptx"], check=True)
+     print(subprocess.check_output(['nvdisasm', fn]).decode('utf-8'))
+   except Exception as e: print("Failed to generate SASS", str(e), "Make sure your PATH contains ptxas/nvdisasm binary of compatible version.")
+
+ class CUDACompiler(Compiler):
+   def __init__(self, arch:str, cache_key:str="cuda"):
+     self.arch, self.compile_options = arch, [f'--gpu-architecture={arch}', "-I/usr/local/cuda/include", "-I/usr/include", "-I/opt/cuda/include/"]
+     nvrtc_check(nvrtc.nvrtcVersion((nvrtcMajor := ctypes.c_int()), (nvrtcMinor := ctypes.c_int())))
+     if (nvrtcMajor.value, nvrtcMinor.value) >= (12, 4): self.compile_options.append("--minimal")
+     super().__init__(f"compile_{cache_key}_{self.arch}")
+   def _compile_program(self, src:str, nvrtc_get_content:Callable, nvrtc_get_size:Callable) -> bytes:
+     nvrtc_check(nvrtc.nvrtcCreateProgram(ctypes.byref(prog := nvrtc.nvrtcProgram()), src.encode(), "<null>".encode(), 0, None, None))
+     nvrtc_check(nvrtc.nvrtcCompileProgram(prog, len(self.compile_options), to_char_p_p([o.encode() for o in self.compile_options])), prog)
+     data = _get_bytes(prog, nvrtc_get_content, nvrtc_get_size, nvrtc_check)
+     nvrtc_check(nvrtc.nvrtcDestroyProgram(ctypes.byref(prog)))
+     return data
+   def compile(self, src:str) -> bytes: return self._compile_program(src, nvrtc.nvrtcGetPTX, nvrtc.nvrtcGetPTXSize)
+   def disassemble(self, lib:bytes):
+     try:
+       fn = (pathlib.Path(tempfile.gettempdir()) / f"tinycuda_{hashlib.md5(lib).hexdigest()}").as_posix()
+       with open(fn + ".cubin", "wb") as f: f.write(lib)
+       print(subprocess.check_output(["nvdisasm", fn+".cubin"]).decode('utf-8'))
+     except Exception as e: print("Failed to disasm cubin:", str(e), "Make sure your PATH contains nvdisasm binary of compatible version.")
+
+ class NVCompiler(CUDACompiler):
+   def __init__(self, arch:str): super().__init__(arch, cache_key="nv")
+   def compile(self, src:str) -> bytes: return self._compile_program(src, nvrtc.nvrtcGetCUBIN, nvrtc.nvrtcGetCUBINSize)
+
+ class PTXCompiler(CUDACompiler):
+   def __init__(self, arch:str, cache_key="ptx"): super().__init__(arch, cache_key=cache_key)
+   def compile(self, src:str) -> bytes: return src.replace("TARGET", self.arch).replace("VERSION", "7.8" if self.arch >= "sm_89" else "7.5").encode()
+
+ class NVPTXCompiler(PTXCompiler):
+   def __init__(self, arch:str): super().__init__(arch, cache_key="nv_ptx")
+   def compile(self, src:str) -> bytes:
+     jitlink_check(nvrtc.nvJitLinkCreate(handle := nvrtc.nvJitLinkHandle(), 1, to_char_p_p([f'-arch={self.arch}'.encode()])), handle)
+     jitlink_check(nvrtc.nvJitLinkAddData(handle, nvrtc.NVJITLINK_INPUT_PTX, ptxsrc:=super().compile(src), len(ptxsrc), "<null>".encode()), handle)
+     jitlink_check(nvrtc.nvJitLinkComplete(handle), handle)
+     data = _get_bytes(handle, nvrtc.nvJitLinkGetLinkedCubin, nvrtc.nvJitLinkGetLinkedCubinSize, jitlink_check)
+     jitlink_check(nvrtc.nvJitLinkDestroy(handle))
+     return data
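
Note: these classes implement tinygrad's Compiler interface, where compile(src) returns bytes and the subclass picks the nvrtc output format. A usage sketch (assumes a working nvrtc install; the arch string and kernel source below are made up for illustration):

  from tinygrad.runtime.support.compiler_cuda import CUDACompiler, NVCompiler

  src = 'extern "C" __global__ void add1(float *x) { x[blockIdx.x] += 1.0f; }'
  ptx = CUDACompiler("sm_86").compile(src)    # nvrtcGetPTX: PTX text as bytes
  cubin = NVCompiler("sm_86").compile(src)    # nvrtcGetCUBIN: loadable binary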
tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py}
@@ -1,5 +1,6 @@
- import ctypes
+ import ctypes, subprocess
  import tinygrad.runtime.autogen.comgr as comgr
+ from tinygrad.device import Compiler, CompileError

  def check(status):
    if status != 0:
@@ -54,3 +55,14 @@ def compile_hip(prg:str, arch="gfx1100", asm=False) -> bytes:
    for x in [data_set_src, data_set_bc, data_set_reloc, data_set_exec]: check(comgr.amd_comgr_destroy_data_set(x))
    check(comgr.amd_comgr_destroy_action_info(action_info))
    return ret
+
+ class AMDCompiler(Compiler):
+   def __init__(self, arch:str):
+     self.arch = arch
+     super().__init__(f"compile_hip_{self.arch}")
+   def compile(self, src:str) -> bytes:
+     try: return compile_hip(src, self.arch)
+     except RuntimeError as e: raise CompileError(e) from e
+   def disassemble(self, lib:bytes):
+     asm = subprocess.check_output(["/opt/rocm/llvm/bin/llvm-objdump", '-d', '-'], input=lib)
+     print('\n'.join([x for x in asm.decode('utf-8').split("\n") if 's_code_end' not in x]))
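
Note: AMDCompiler wraps the existing compile_hip() entry point, re-raising RuntimeError as CompileError so callers can handle compiler failures uniformly. A usage sketch, assuming a ROCm/comgr install and the renamed module path from the file list above (gfx1100 is just an example arch):

  from tinygrad.runtime.support.compiler_hip import AMDCompiler

  lib = AMDCompiler("gfx1100").compile('extern "C" __global__ void noop() {}')
  # same as compile_hip(src, arch="gfx1100"), but errors surface as CompileError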
tinygrad/runtime/support/elf.py (new file)
@@ -0,0 +1,38 @@
+ from __future__ import annotations
+ from typing import Tuple, List, Any
+ from dataclasses import dataclass
+ import tinygrad.runtime.autogen.libc as libc
+
+ @dataclass(frozen=True)
+ class ElfSection: name:str; header:libc.Elf64_Shdr; content:bytes # noqa: E702
+
+ def elf_loader(blob:bytes, force_section_align:int=1) -> Tuple[memoryview, List[ElfSection], Any]:
+   def _strtab(blob: bytes, idx: int) -> str: return blob[idx:blob.find(b'\x00', idx)].decode('utf-8')
+
+   header = libc.Elf64_Ehdr.from_buffer_copy(blob)
+   section_headers = (libc.Elf64_Shdr * header.e_shnum).from_buffer_copy(blob[header.e_shoff:])
+   sh_strtab = blob[(shstrst:=section_headers[header.e_shstrndx].sh_offset):shstrst+section_headers[header.e_shstrndx].sh_size]
+   sections = [ElfSection(_strtab(sh_strtab, sh.sh_name), sh, blob[sh.sh_offset:sh.sh_offset+sh.sh_size]) for sh in section_headers]
+
+   def _to_carray(sh, ctype): return (ctype * (sh.header.sh_size // sh.header.sh_entsize)).from_buffer_copy(sh.content)
+   rel = [(sh, sh.name[4:], _to_carray(sh, libc.Elf64_Rel)) for sh in sections if sh.header.sh_type == libc.SHT_REL]
+   rela = [(sh, sh.name[5:], _to_carray(sh, libc.Elf64_Rela)) for sh in sections if sh.header.sh_type == libc.SHT_RELA]
+   symtab = [_to_carray(sh, libc.Elf64_Sym) for sh in sections if sh.header.sh_type == libc.SHT_SYMTAB][0]
+   progbits = [sh for sh in sections if sh.header.sh_type == libc.SHT_PROGBITS]
+
+   # Prealloc image for all fixed addresses.
+   image = bytearray(max([sh.header.sh_addr + sh.header.sh_size for sh in progbits if sh.header.sh_addr != 0] + [0]))
+   for sh in progbits:
+     if sh.header.sh_addr != 0: image[sh.header.sh_addr:sh.header.sh_addr+sh.header.sh_size] = sh.content
+     else:
+       image += b'\0' * (((align:=max(sh.header.sh_addralign, force_section_align)) - len(image) % align) % align) + sh.content
+       sh.header.sh_addr = len(image) - len(sh.content)
+
+   # Relocations
+   relocs = []
+   for sh, trgt_sh_name, c_rels in rel + rela:
+     target_image_off = next(tsh for tsh in sections if tsh.name == trgt_sh_name).header.sh_addr
+     rels = [(r.r_offset, symtab[libc.ELF64_R_SYM(r.r_info)], libc.ELF64_R_TYPE(r.r_info), getattr(r, "r_addend", 0)) for r in c_rels]
+     relocs += [(target_image_off + roff, sections[sym.st_shndx].header.sh_addr + sym.st_value, rtype, raddend) for roff, sym, rtype, raddend in rels]
+
+   return memoryview(image), sections, relocs
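
Note: elf_loader() flattens a relocatable ELF into a single image: PROGBITS sections with a fixed sh_addr keep their addresses, the rest are appended with alignment padding, and relocations come back as (offset, target, r_type, addend) tuples for the caller to patch. A sketch of the patch loop a backend might run on top of it (kernel.o is a placeholder input, and treating every entry as a 64-bit absolute write is illustrative, not tinygrad's per-arch relocation handling):

  import struct
  from tinygrad.runtime.support.elf import elf_loader

  image, sections, relocs = elf_loader(open("kernel.o", "rb").read())
  patched = bytearray(image)
  for offset, target, r_type, addend in relocs:
    struct.pack_into("<Q", patched, offset, target + addend)  # write absolute address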