tinygrad 0.10.0__py3-none-any.whl → 0.10.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/devectorizer.py +247 -0
- tinygrad/codegen/expander.py +121 -0
- tinygrad/codegen/kernel.py +141 -201
- tinygrad/codegen/linearize.py +223 -84
- tinygrad/codegen/lowerer.py +60 -42
- tinygrad/codegen/symbolic.py +476 -0
- tinygrad/codegen/transcendental.py +22 -13
- tinygrad/device.py +187 -47
- tinygrad/dtype.py +39 -28
- tinygrad/engine/jit.py +83 -65
- tinygrad/engine/memory.py +4 -5
- tinygrad/engine/multi.py +161 -0
- tinygrad/engine/realize.py +62 -108
- tinygrad/engine/schedule.py +396 -357
- tinygrad/engine/search.py +55 -66
- tinygrad/gradient.py +73 -0
- tinygrad/helpers.py +81 -59
- tinygrad/nn/__init__.py +30 -32
- tinygrad/nn/datasets.py +1 -2
- tinygrad/nn/optim.py +22 -26
- tinygrad/nn/state.py +91 -66
- tinygrad/ops.py +492 -641
- tinygrad/renderer/__init__.py +95 -36
- tinygrad/renderer/cstyle.py +99 -92
- tinygrad/renderer/llvmir.py +83 -34
- tinygrad/renderer/ptx.py +83 -99
- tinygrad/renderer/wgsl.py +95 -0
- tinygrad/runtime/autogen/amd_gpu.py +39507 -12
- tinygrad/runtime/autogen/comgr.py +2 -0
- tinygrad/runtime/autogen/kfd.py +4 -3
- tinygrad/runtime/autogen/kgsl.py +1 -1
- tinygrad/runtime/autogen/libc.py +404 -71
- tinygrad/runtime/autogen/llvm.py +11379 -0
- tinygrad/runtime/autogen/pci.py +1333 -0
- tinygrad/runtime/autogen/vfio.py +891 -0
- tinygrad/runtime/autogen/webgpu.py +6985 -0
- tinygrad/runtime/graph/cuda.py +8 -9
- tinygrad/runtime/graph/hcq.py +84 -79
- tinygrad/runtime/graph/metal.py +40 -43
- tinygrad/runtime/ops_amd.py +498 -334
- tinygrad/runtime/ops_cloud.py +34 -34
- tinygrad/runtime/ops_cpu.py +24 -0
- tinygrad/runtime/ops_cuda.py +30 -27
- tinygrad/runtime/ops_disk.py +62 -63
- tinygrad/runtime/ops_dsp.py +159 -42
- tinygrad/runtime/ops_gpu.py +30 -30
- tinygrad/runtime/ops_hip.py +29 -31
- tinygrad/runtime/ops_llvm.py +48 -41
- tinygrad/runtime/ops_metal.py +149 -113
- tinygrad/runtime/ops_npy.py +2 -2
- tinygrad/runtime/ops_nv.py +238 -273
- tinygrad/runtime/ops_python.py +55 -50
- tinygrad/runtime/ops_qcom.py +129 -157
- tinygrad/runtime/ops_webgpu.py +225 -0
- tinygrad/runtime/support/allocator.py +94 -0
- tinygrad/runtime/support/am/__init__.py +0 -0
- tinygrad/runtime/support/am/amdev.py +396 -0
- tinygrad/runtime/support/am/ip.py +463 -0
- tinygrad/runtime/support/compiler_cuda.py +4 -2
- tinygrad/runtime/support/elf.py +28 -4
- tinygrad/runtime/support/hcq.py +256 -324
- tinygrad/runtime/support/llvm.py +26 -0
- tinygrad/shape/shapetracker.py +85 -53
- tinygrad/shape/view.py +104 -140
- tinygrad/spec.py +155 -0
- tinygrad/tensor.py +835 -527
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
- tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
- tinygrad/viz/index.html +544 -0
- tinygrad/viz/perfetto.html +178 -0
- tinygrad/viz/serve.py +205 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
- tinygrad-0.10.2.dist-info/RECORD +99 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
- tinygrad/codegen/uopgraph.py +0 -506
- tinygrad/engine/lazy.py +0 -228
- tinygrad/function.py +0 -212
- tinygrad/multi.py +0 -177
- tinygrad/runtime/graph/clang.py +0 -39
- tinygrad/runtime/ops_clang.py +0 -35
- tinygrad-0.10.0.dist-info/RECORD +0 -77
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,463 @@
|
|
1
|
+
import ctypes, time, contextlib
|
2
|
+
from typing import Literal
|
3
|
+
from tinygrad.runtime.autogen.am import am, smu_v13_0_0
|
4
|
+
from tinygrad.helpers import to_mv, data64, lo32, hi32, DEBUG
|
5
|
+
|
6
|
+
class AM_IP:
|
7
|
+
def __init__(self, adev): self.adev = adev
|
8
|
+
def init(self): raise NotImplementedError("IP block init must be implemeted")
|
9
|
+
def fini(self): pass
|
10
|
+
|
11
|
+
class AM_SOC21(AM_IP):
|
12
|
+
def init(self):
|
13
|
+
self.adev.regRCC_DEV0_EPF2_STRAP2.update(strap_no_soft_reset_dev0_f2=0x0)
|
14
|
+
self.adev.regRCC_DEV0_EPF0_RCC_DOORBELL_APER_EN.write(0x1)
|
15
|
+
|
16
|
+
class AM_GMC(AM_IP):
|
17
|
+
def __init__(self, adev):
|
18
|
+
super().__init__(adev)
|
19
|
+
|
20
|
+
# Memory controller aperture
|
21
|
+
self.mc_base = (self.adev.regMMMC_VM_FB_LOCATION_BASE.read() & 0xFFFFFF) << 24
|
22
|
+
self.mc_end = self.mc_base + self.adev.mm.vram_size - 1
|
23
|
+
|
24
|
+
# VM aperture
|
25
|
+
self.vm_base = self.adev.mm.va_allocator.base
|
26
|
+
self.vm_end = self.vm_base + self.adev.mm.va_allocator.size - 1
|
27
|
+
|
28
|
+
# GFX11 has 44-bit address space
|
29
|
+
self.address_space_mask = (1 << 44) - 1
|
30
|
+
|
31
|
+
self.memscratch_paddr = self.adev.mm.palloc(0x1000, zero=not self.adev.partial_boot, boot=True)
|
32
|
+
self.dummy_page_paddr = self.adev.mm.palloc(0x1000, zero=not self.adev.partial_boot, boot=True)
|
33
|
+
self.hub_initted = {"MM": False, "GC": False}
|
34
|
+
|
35
|
+
def init(self): self.init_hub("MM")
|
36
|
+
|
37
|
+
def flush_hdp(self): self.adev.regBIF_BX_PF0_GPU_HDP_FLUSH_REQ.write(0xffffffff)
|
38
|
+
def flush_tlb(self, ip:Literal["MM", "GC"], vmid, flush_type=0):
|
39
|
+
self.flush_hdp()
|
40
|
+
|
41
|
+
# Can't issue TLB invalidation if the hub isn't initialized.
|
42
|
+
if not self.hub_initted[ip]: return
|
43
|
+
|
44
|
+
if ip == "MM": self.adev.wait_reg(self.adev.regMMVM_INVALIDATE_ENG17_SEM, mask=0x1, value=0x1)
|
45
|
+
|
46
|
+
self.adev.reg(f"reg{ip}VM_INVALIDATE_ENG17_REQ").write(flush_type=flush_type, per_vmid_invalidate_req=(1 << vmid), invalidate_l2_ptes=1,
|
47
|
+
invalidate_l2_pde0=1, invalidate_l2_pde1=1, invalidate_l2_pde2=1, invalidate_l1_ptes=1, clear_protection_fault_status_addr=0)
|
48
|
+
|
49
|
+
self.adev.wait_reg(self.adev.reg(f"reg{ip}VM_INVALIDATE_ENG17_ACK"), mask=(1 << vmid), value=(1 << vmid))
|
50
|
+
|
51
|
+
if ip == "MM":
|
52
|
+
self.adev.regMMVM_INVALIDATE_ENG17_SEM.write(0x0)
|
53
|
+
self.adev.regMMVM_L2_BANK_SELECT_RESERVED_CID2.update(reserved_cache_private_invalidation=1)
|
54
|
+
|
55
|
+
# Read back the register to ensure the invalidation is complete
|
56
|
+
self.adev.regMMVM_L2_BANK_SELECT_RESERVED_CID2.read()
|
57
|
+
|
58
|
+
def enable_vm_addressing(self, page_table, ip:Literal["MM", "GC"], vmid):
|
59
|
+
self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_START_ADDR", "_LO32", "_HI32", self.vm_base >> 12)
|
60
|
+
self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_END_ADDR", "_LO32", "_HI32", self.vm_end >> 12)
|
61
|
+
self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_BASE_ADDR", "_LO32", "_HI32", page_table.paddr | 1)
|
62
|
+
self.adev.reg(f"reg{ip}VM_CONTEXT{vmid}_CNTL").write(0x1fffe00, enable_context=1, page_table_depth=(3 - page_table.lv))
|
63
|
+
|
64
|
+
def init_hub(self, ip:Literal["MM", "GC"]):
|
65
|
+
# Init system apertures
|
66
|
+
self.adev.reg(f"reg{ip}MC_VM_AGP_BASE").write(0)
|
67
|
+
self.adev.reg(f"reg{ip}MC_VM_AGP_BOT").write(0xffffffffffff >> 24) # disable AGP
|
68
|
+
self.adev.reg(f"reg{ip}MC_VM_AGP_TOP").write(0)
|
69
|
+
|
70
|
+
self.adev.reg(f"reg{ip}MC_VM_SYSTEM_APERTURE_LOW_ADDR").write(self.mc_base >> 18)
|
71
|
+
self.adev.reg(f"reg{ip}MC_VM_SYSTEM_APERTURE_HIGH_ADDR").write(self.mc_end >> 18)
|
72
|
+
self.adev.wreg_pair(f"reg{ip}MC_VM_SYSTEM_APERTURE_DEFAULT_ADDR", "_LSB", "_MSB", self.memscratch_paddr >> 12)
|
73
|
+
self.adev.wreg_pair(f"reg{ip}VM_L2_PROTECTION_FAULT_DEFAULT_ADDR", "_LO32", "_HI32", self.dummy_page_paddr >> 12)
|
74
|
+
|
75
|
+
self.adev.reg(f"reg{ip}VM_L2_PROTECTION_FAULT_CNTL2").update(active_page_migration_pte_read_retry=1)
|
76
|
+
|
77
|
+
# Init TLB and cache
|
78
|
+
self.adev.reg(f"reg{ip}MC_VM_MX_L1_TLB_CNTL").update(enable_l1_tlb=1, system_access_mode=3, enable_advanced_driver_model=1,
|
79
|
+
system_aperture_unmapped_access=0, eco_bits=0, mtype=am.MTYPE_UC)
|
80
|
+
|
81
|
+
self.adev.reg(f"reg{ip}VM_L2_CNTL").update(enable_l2_cache=1, enable_l2_fragment_processing=0, enable_default_page_out_to_system_memory=1,
|
82
|
+
l2_pde0_cache_tag_generation_mode=0, pde_fault_classification=0, context1_identity_access_mode=1, identity_mode_fragment_size=0)
|
83
|
+
self.adev.reg(f"reg{ip}VM_L2_CNTL2").update(invalidate_all_l1_tlbs=1, invalidate_l2_cache=1)
|
84
|
+
self.adev.reg(f"reg{ip}VM_L2_CNTL3").write(bank_select=9, l2_cache_bigk_fragment_size=6,l2_cache_4k_associativity=1,l2_cache_bigk_associativity=1)
|
85
|
+
self.adev.reg(f"reg{ip}VM_L2_CNTL4").write(l2_cache_4k_partition_count=1)
|
86
|
+
self.adev.reg(f"reg{ip}VM_L2_CNTL5").write(walker_priority_client_id=0x1ff)
|
87
|
+
|
88
|
+
self.enable_vm_addressing(self.adev.mm.root_page_table, ip, vmid=0)
|
89
|
+
|
90
|
+
# Disable identity aperture
|
91
|
+
self.adev.wreg_pair(f"reg{ip}VM_L2_CONTEXT1_IDENTITY_APERTURE_LOW_ADDR", "_LO32", "_HI32", 0xfffffffff)
|
92
|
+
self.adev.wreg_pair(f"reg{ip}VM_L2_CONTEXT1_IDENTITY_APERTURE_HIGH_ADDR", "_LO32", "_HI32", 0x0)
|
93
|
+
self.adev.wreg_pair(f"reg{ip}VM_L2_CONTEXT_IDENTITY_PHYSICAL_OFFSET", "_LO32", "_HI32", 0x0)
|
94
|
+
|
95
|
+
for eng_i in range(18): self.adev.wreg_pair(f"reg{ip}VM_INVALIDATE_ENG{eng_i}_ADDR_RANGE", "_LO32", "_HI32", 0x1fffffffff)
|
96
|
+
self.hub_initted[ip] = True
|
97
|
+
|
98
|
+
def on_interrupt(self):
|
99
|
+
for ip in ["MM", "GC"]:
|
100
|
+
st, va = self.adev.reg(f'reg{ip}VM_L2_PROTECTION_FAULT_STATUS').read(), self.adev.reg(f'reg{ip}VM_L2_PROTECTION_FAULT_ADDR_LO32').read()
|
101
|
+
va = (va | (self.adev.reg(f'reg{ip}VM_L2_PROTECTION_FAULT_ADDR_HI32').read()) << 32) << 12
|
102
|
+
if self.adev.reg(f"reg{ip}VM_L2_PROTECTION_FAULT_STATUS").read(): raise RuntimeError(f"{ip}VM_L2_PROTECTION_FAULT_STATUS: {st:#x} {va:#x}")
|
103
|
+
|
104
|
+
class AM_SMU(AM_IP):
|
105
|
+
def __init__(self, adev):
|
106
|
+
super().__init__(adev)
|
107
|
+
self.driver_table_paddr = self.adev.mm.palloc(0x4000, zero=not self.adev.partial_boot, boot=True)
|
108
|
+
|
109
|
+
def init(self):
|
110
|
+
self._send_msg(smu_v13_0_0.PPSMC_MSG_SetDriverDramAddrHigh, hi32(self.adev.paddr2mc(self.driver_table_paddr)), poll=True)
|
111
|
+
self._send_msg(smu_v13_0_0.PPSMC_MSG_SetDriverDramAddrLow, lo32(self.adev.paddr2mc(self.driver_table_paddr)), poll=True)
|
112
|
+
self._send_msg(smu_v13_0_0.PPSMC_MSG_EnableAllSmuFeatures, 0, poll=True)
|
113
|
+
|
114
|
+
def is_smu_alive(self):
|
115
|
+
with contextlib.suppress(RuntimeError): self._send_msg(smu_v13_0_0.PPSMC_MSG_GetSmuVersion, 0, timeout=100)
|
116
|
+
return self.adev.mmMP1_SMN_C2PMSG_90.read() != 0
|
117
|
+
|
118
|
+
def mode1_reset(self):
|
119
|
+
if DEBUG >= 2: print(f"am {self.adev.devfmt}: mode1 reset")
|
120
|
+
self._send_msg(smu_v13_0_0.PPSMC_MSG_Mode1Reset, 0, poll=True)
|
121
|
+
time.sleep(0.5) # 500ms
|
122
|
+
|
123
|
+
def read_table(self, table_t, cmd):
|
124
|
+
self._send_msg(smu_v13_0_0.PPSMC_MSG_TransferTableSmu2Dram, cmd, poll=True)
|
125
|
+
return table_t.from_buffer(to_mv(self.adev.paddr2cpu(self.driver_table_paddr), ctypes.sizeof(table_t)))
|
126
|
+
def read_metrics(self): return self.read_table(smu_v13_0_0.SmuMetricsExternal_t, smu_v13_0_0.TABLE_SMU_METRICS)
|
127
|
+
|
128
|
+
def set_clocks(self, level):
|
129
|
+
if not hasattr(self, 'clcks'):
|
130
|
+
self.clcks = {}
|
131
|
+
for clck in [smu_v13_0_0.PPCLK_GFXCLK, smu_v13_0_0.PPCLK_UCLK, smu_v13_0_0.PPCLK_FCLK, smu_v13_0_0.PPCLK_SOCCLK]:
|
132
|
+
cnt = self._send_msg(smu_v13_0_0.PPSMC_MSG_GetDpmFreqByIndex, (clck<<16)|0xff, read_back_arg=True)&0x7fffffff
|
133
|
+
self.clcks[clck] = [self._send_msg(smu_v13_0_0.PPSMC_MSG_GetDpmFreqByIndex, (clck<<16)|i, read_back_arg=True)&0x7fffffff for i in range(cnt)]
|
134
|
+
|
135
|
+
for clck, vals in self.clcks.items():
|
136
|
+
self._send_msg(smu_v13_0_0.PPSMC_MSG_SetSoftMinByFreq, clck << 16 | (vals[level]), poll=True)
|
137
|
+
self._send_msg(smu_v13_0_0.PPSMC_MSG_SetSoftMaxByFreq, clck << 16 | (vals[level]), poll=True)
|
138
|
+
|
139
|
+
def _smu_cmn_poll_stat(self, timeout=10000): self.adev.wait_reg(self.adev.mmMP1_SMN_C2PMSG_90, mask=0xFFFFFFFF, value=1, timeout=timeout)
|
140
|
+
def _smu_cmn_send_msg(self, msg, param=0):
|
141
|
+
self.adev.mmMP1_SMN_C2PMSG_90.write(0) # resp reg
|
142
|
+
self.adev.mmMP1_SMN_C2PMSG_82.write(param)
|
143
|
+
self.adev.mmMP1_SMN_C2PMSG_66.write(msg)
|
144
|
+
|
145
|
+
def _send_msg(self, msg, param, poll=True, read_back_arg=False, timeout=10000): # 10s
|
146
|
+
if poll: self._smu_cmn_poll_stat(timeout=timeout)
|
147
|
+
|
148
|
+
self._smu_cmn_send_msg(msg, param)
|
149
|
+
self._smu_cmn_poll_stat(timeout=timeout)
|
150
|
+
return self.adev.mmMP1_SMN_C2PMSG_82.read() if read_back_arg else None
|
151
|
+
|
152
|
+
class AM_GFX(AM_IP):
|
153
|
+
def init(self):
|
154
|
+
# Wait for RLC autoload to complete
|
155
|
+
while self.adev.regCP_STAT.read() != 0 and self.adev.regRLC_RLCS_BOOTLOAD_STATUS.read(bootload_complete=1) != 0: pass
|
156
|
+
|
157
|
+
self._config_gfx_rs64()
|
158
|
+
self.adev.gmc.init_hub("GC")
|
159
|
+
|
160
|
+
# NOTE: Golden reg for gfx11. No values for this reg provided. The kernel just ors 0x20000000 to this reg.
|
161
|
+
self.adev.regTCP_CNTL.write(self.adev.regTCP_CNTL.read() | 0x20000000)
|
162
|
+
self.adev.regRLC_SRM_CNTL.update(srm_enable=1, auto_incr_addr=1)
|
163
|
+
|
164
|
+
self.adev.regGRBM_CNTL.update(read_timeout=0xff)
|
165
|
+
for i in range(0, 16):
|
166
|
+
self._grbm_select(vmid=i)
|
167
|
+
self.adev.regSH_MEM_CONFIG.write(address_mode=am.SH_MEM_ADDRESS_MODE_64, alignment_mode=am.SH_MEM_ALIGNMENT_MODE_UNALIGNED,
|
168
|
+
initial_inst_prefetch=3)
|
169
|
+
|
170
|
+
# Configure apertures:
|
171
|
+
# LDS: 0x10000000'00000000 - 0x10000001'00000000 (4GB)
|
172
|
+
# Scratch: 0x20000000'00000000 - 0x20000001'00000000 (4GB)
|
173
|
+
self.adev.regSH_MEM_BASES.write(shared_base=0x1, private_base=0x2)
|
174
|
+
self._grbm_select()
|
175
|
+
|
176
|
+
# Configure MEC doorbell range
|
177
|
+
self.adev.regCP_MEC_DOORBELL_RANGE_LOWER.write(0x0)
|
178
|
+
self.adev.regCP_MEC_DOORBELL_RANGE_UPPER.write(0x450)
|
179
|
+
|
180
|
+
# Enable MEC
|
181
|
+
self.adev.regCP_MEC_RS64_CNTL.update(mec_invalidate_icache=0, mec_pipe0_reset=0, mec_pipe1_reset=0, mec_pipe2_reset=0, mec_pipe3_reset=0,
|
182
|
+
mec_pipe0_active=1, mec_pipe1_active=1, mec_pipe2_active=1, mec_pipe3_active=1, mec_halt=0)
|
183
|
+
|
184
|
+
# NOTE: Wait for MEC to be ready. The kernel does udelay here as well.
|
185
|
+
time.sleep(0.05)
|
186
|
+
|
187
|
+
def fini(self):
|
188
|
+
self._grbm_select(me=1, pipe=0, queue=0)
|
189
|
+
self.adev.regCP_HQD_DEQUEUE_REQUEST.write(0x2) # 1 - DRAIN_PIPE; 2 - RESET_WAVES
|
190
|
+
self.adev.regSPI_COMPUTE_QUEUE_RESET.write(1)
|
191
|
+
self._grbm_select()
|
192
|
+
self.adev.regGCVM_CONTEXT0_CNTL.write(0)
|
193
|
+
|
194
|
+
def setup_ring(self, ring_addr:int, ring_size:int, rptr_addr:int, wptr_addr:int, eop_addr:int, eop_size:int, doorbell:int, pipe:int, queue:int):
|
195
|
+
mqd = self.adev.mm.valloc(0x1000, uncached=True, contigous=True)
|
196
|
+
|
197
|
+
mqd_struct = am.struct_v11_compute_mqd(header=0xC0310800, cp_mqd_base_addr_lo=lo32(mqd.va_addr), cp_mqd_base_addr_hi=hi32(mqd.va_addr),
|
198
|
+
cp_hqd_persistent_state=self.adev.regCP_HQD_PERSISTENT_STATE.build(preload_size=0x55, preload_req=1),
|
199
|
+
cp_hqd_pipe_priority=0x2, cp_hqd_queue_priority=0xf, cp_hqd_quantum=0x111,
|
200
|
+
cp_hqd_pq_base_lo=lo32(ring_addr>>8), cp_hqd_pq_base_hi=hi32(ring_addr>>8),
|
201
|
+
cp_hqd_pq_rptr_report_addr_lo=lo32(rptr_addr), cp_hqd_pq_rptr_report_addr_hi=hi32(rptr_addr),
|
202
|
+
cp_hqd_pq_wptr_poll_addr_lo=lo32(wptr_addr), cp_hqd_pq_wptr_poll_addr_hi=hi32(wptr_addr),
|
203
|
+
cp_hqd_pq_doorbell_control=self.adev.regCP_HQD_PQ_DOORBELL_CONTROL.build(doorbell_offset=doorbell*2, doorbell_en=1),
|
204
|
+
cp_hqd_pq_control=self.adev.regCP_HQD_PQ_CONTROL.build(rptr_block_size=5, unord_dispatch=1, queue_size=(ring_size//4).bit_length()-2),
|
205
|
+
cp_hqd_ib_control=self.adev.regCP_HQD_IB_CONTROL.build(min_ib_avail_size=0x3), cp_hqd_hq_status0=0x20004000,
|
206
|
+
cp_mqd_control=self.adev.regCP_MQD_CONTROL.build(priv_state=1), cp_hqd_vmid=0,
|
207
|
+
cp_hqd_eop_base_addr_lo=lo32(eop_addr>>8), cp_hqd_eop_base_addr_hi=hi32(eop_addr>>8),
|
208
|
+
cp_hqd_eop_control=self.adev.regCP_HQD_EOP_CONTROL.build(eop_size=(eop_size//4).bit_length()-2))
|
209
|
+
|
210
|
+
# Copy mqd into memory
|
211
|
+
ctypes.memmove(self.adev.paddr2cpu(mqd.paddrs[0][0]), ctypes.addressof(mqd_struct), ctypes.sizeof(mqd_struct))
|
212
|
+
self.adev.gmc.flush_hdp()
|
213
|
+
|
214
|
+
self._grbm_select(me=1, pipe=pipe, queue=queue)
|
215
|
+
|
216
|
+
mqd_st_mv = to_mv(ctypes.addressof(mqd_struct), ctypes.sizeof(mqd_struct)).cast('I')
|
217
|
+
for i, reg in enumerate(range(self.adev.regCP_MQD_BASE_ADDR.reg_off, self.adev.regCP_HQD_PQ_WPTR_HI.reg_off + 1)):
|
218
|
+
self.adev.wreg(reg, mqd_st_mv[0x80 + i])
|
219
|
+
self.adev.regCP_HQD_ACTIVE.write(0x1)
|
220
|
+
|
221
|
+
self._grbm_select()
|
222
|
+
|
223
|
+
self.adev.reg(f"regCP_ME1_PIPE{pipe}_INT_CNTL").update(time_stamp_int_enable=1, generic0_int_enable=1)
|
224
|
+
|
225
|
+
def set_clockgating_state(self):
|
226
|
+
if hasattr(self.adev, 'regMM_ATC_L2_MISC_CG'): self.adev.regMM_ATC_L2_MISC_CG.write(enable=1, mem_ls_enable=1)
|
227
|
+
|
228
|
+
self.adev.regRLC_SAFE_MODE.write(message=1, cmd=1)
|
229
|
+
self.adev.wait_reg(self.adev.regRLC_SAFE_MODE, mask=0x1, value=0x0)
|
230
|
+
|
231
|
+
self.adev.regRLC_CGCG_CGLS_CTRL.update(cgcg_gfx_idle_threshold=0x36, cgcg_en=1, cgls_rep_compansat_delay=0xf, cgls_en=1)
|
232
|
+
|
233
|
+
self.adev.regCP_RB_WPTR_POLL_CNTL.update(poll_frequency=0x100, idle_poll_count=0x90)
|
234
|
+
self.adev.regCP_INT_CNTL.update(cntx_busy_int_enable=1, cntx_empty_int_enable=1, cmp_busy_int_enable=1, gfx_idle_int_enable=1)
|
235
|
+
self.adev.regSDMA0_RLC_CGCG_CTRL.update(cgcg_int_enable=1)
|
236
|
+
self.adev.regSDMA1_RLC_CGCG_CTRL.update(cgcg_int_enable=1)
|
237
|
+
|
238
|
+
self.adev.regRLC_CGTT_MGCG_OVERRIDE.update(perfmon_clock_state=0, gfxip_fgcg_override=0, gfxip_repeater_fgcg_override=0,
|
239
|
+
grbm_cgtt_sclk_override=0, rlc_cgtt_sclk_override=0, gfxip_mgcg_override=0, gfxip_cgls_override=0, gfxip_cgcg_override=0)
|
240
|
+
|
241
|
+
self.adev.regRLC_SAFE_MODE.write(message=0, cmd=1)
|
242
|
+
|
243
|
+
def _grbm_select(self, me=0, pipe=0, queue=0, vmid=0): self.adev.regGRBM_GFX_CNTL.write(meid=me, pipeid=pipe, vmid=vmid, queueid=queue)
|
244
|
+
|
245
|
+
def _config_gfx_rs64(self):
|
246
|
+
def _config_helper(eng_name, cntl_reg, eng_reg, pipe_cnt, me=0):
|
247
|
+
for pipe in range(pipe_cnt):
|
248
|
+
self._grbm_select(me=me, pipe=pipe)
|
249
|
+
self.adev.wreg_pair(f"regCP_{eng_reg}_PRGRM_CNTR_START", "", "_HI", self.adev.fw.ucode_start[eng_name] >> 2)
|
250
|
+
self._grbm_select()
|
251
|
+
self.adev.reg(f"regCP_{cntl_reg}_CNTL").update(**{f"{eng_name.lower()}_pipe{pipe}_reset": 1 for pipe in range(pipe_cnt)})
|
252
|
+
self.adev.reg(f"regCP_{cntl_reg}_CNTL").update(**{f"{eng_name.lower()}_pipe{pipe}_reset": 0 for pipe in range(pipe_cnt)})
|
253
|
+
|
254
|
+
_config_helper(eng_name="PFP", cntl_reg="ME", eng_reg="PFP", pipe_cnt=2)
|
255
|
+
_config_helper(eng_name="ME", cntl_reg="ME", eng_reg="ME", pipe_cnt=2)
|
256
|
+
_config_helper(eng_name="MEC", cntl_reg="MEC_RS64", eng_reg="MEC_RS64", pipe_cnt=4, me=1)
|
257
|
+
|
258
|
+
class AM_IH(AM_IP):
|
259
|
+
def __init__(self, adev):
|
260
|
+
super().__init__(adev)
|
261
|
+
self.ring_size = 512 << 10
|
262
|
+
def _alloc_ring(size): return (self.adev.mm.palloc(size, zero=not self.adev.partial_boot, boot=True),
|
263
|
+
self.adev.mm.palloc(0x1000, zero=not self.adev.partial_boot, boot=True))
|
264
|
+
self.rings = [(*_alloc_ring(self.ring_size), "", 0), (*_alloc_ring(self.ring_size), "_RING1", 1)]
|
265
|
+
|
266
|
+
def interrupt_handler(self):
|
267
|
+
_, rwptr_vm, suf, _ = self.rings[0]
|
268
|
+
wptr = to_mv(self.adev.paddr2cpu(rwptr_vm), 8).cast('Q')[0]
|
269
|
+
|
270
|
+
if self.adev.reg(f"regIH_RB_WPTR{suf}").read(rb_overflow=1):
|
271
|
+
self.adev.reg(f"regIH_RB_WPTR{suf}").update(rb_overflow=0)
|
272
|
+
self.adev.reg(f"regIH_RB_CNTL{suf}").update(wptr_overflow_clear=1)
|
273
|
+
self.adev.reg(f"regIH_RB_CNTL{suf}").update(wptr_overflow_clear=0)
|
274
|
+
self.adev.regIH_RB_RPTR.write(wptr % self.ring_size)
|
275
|
+
|
276
|
+
def init(self):
|
277
|
+
for ring_vm, rwptr_vm, suf, ring_id in self.rings:
|
278
|
+
self.adev.wreg_pair("regIH_RB_BASE", suf, f"_HI{suf}", self.adev.paddr2mc(ring_vm) >> 8)
|
279
|
+
|
280
|
+
self.adev.reg(f"regIH_RB_CNTL{suf}").write(mc_space=4, wptr_overflow_clear=1, rb_size=(self.ring_size//4).bit_length(),
|
281
|
+
mc_snoop=1, mc_ro=0, mc_vmid=0, **({'wptr_overflow_enable': 1, 'rptr_rearm': 1} if ring_id == 0 else {'rb_full_drain_enable': 1}))
|
282
|
+
|
283
|
+
if ring_id == 0: self.adev.wreg_pair("regIH_RB_WPTR_ADDR", "_LO", "_HI", self.adev.paddr2mc(rwptr_vm))
|
284
|
+
|
285
|
+
self.adev.reg(f"regIH_RB_WPTR{suf}").write(0)
|
286
|
+
self.adev.reg(f"regIH_RB_RPTR{suf}").write(0)
|
287
|
+
|
288
|
+
self.adev.reg(f"regIH_DOORBELL_RPTR{suf}").write(((am.AMDGPU_NAVI10_DOORBELL_IH + ring_id) * 2), enable=1)
|
289
|
+
|
290
|
+
self.adev.regIH_STORM_CLIENT_LIST_CNTL.update(client18_is_storm_client=1)
|
291
|
+
self.adev.regIH_INT_FLOOD_CNTL.update(flood_cntl_enable=1)
|
292
|
+
self.adev.regIH_MSI_STORM_CTRL.update(delay=3)
|
293
|
+
|
294
|
+
# toggle interrupts
|
295
|
+
for _, rwptr_vm, suf, ring_id in self.rings:
|
296
|
+
self.adev.reg(f"regIH_RB_CNTL{suf}").update(rb_enable=1, **({'enable_intr': 1} if ring_id == 0 else {}))
|
297
|
+
|
298
|
+
class AM_SDMA(AM_IP):
|
299
|
+
def setup_ring(self, ring_addr:int, ring_size:int, rptr_addr:int, wptr_addr:int, doorbell:int, pipe:int, queue:int):
|
300
|
+
# Setup the ring
|
301
|
+
self.adev.reg(f"regSDMA{pipe}_QUEUE{queue}_MINOR_PTR_UPDATE").write(0x1)
|
302
|
+
self.adev.wreg_pair(f"regSDMA{pipe}_QUEUE{queue}_RB_RPTR", "", "_HI", 0)
|
303
|
+
self.adev.wreg_pair(f"regSDMA{pipe}_QUEUE{queue}_RB_WPTR", "", "_HI", 0)
|
304
|
+
self.adev.wreg_pair(f"regSDMA{pipe}_QUEUE{queue}_RB_BASE", "", "_HI", ring_addr >> 8)
|
305
|
+
self.adev.wreg_pair(f"regSDMA{pipe}_QUEUE{queue}_RB_RPTR_ADDR", "_LO", "_HI", rptr_addr)
|
306
|
+
self.adev.wreg_pair(f"regSDMA{pipe}_QUEUE{queue}_RB_WPTR_POLL_ADDR", "_LO", "_HI", wptr_addr)
|
307
|
+
self.adev.reg(f"regSDMA{pipe}_QUEUE{queue}_DOORBELL_OFFSET").update(offset=doorbell * 2)
|
308
|
+
self.adev.reg(f"regSDMA{pipe}_QUEUE{queue}_DOORBELL").update(enable=1)
|
309
|
+
self.adev.reg(f"regSDMA{pipe}_QUEUE{queue}_MINOR_PTR_UPDATE").write(0x0)
|
310
|
+
self.adev.reg(f"regSDMA{pipe}_QUEUE{queue}_RB_CNTL").write(rb_vmid=0, rptr_writeback_enable=1, rptr_writeback_timer=4,
|
311
|
+
f32_wptr_poll_enable=1, rb_size=(ring_size//4).bit_length()-1, rb_enable=1, rb_priv=1)
|
312
|
+
self.adev.reg(f"regSDMA{pipe}_QUEUE{queue}_IB_CNTL").update(ib_enable=1)
|
313
|
+
|
314
|
+
def init(self):
|
315
|
+
for pipe in range(2):
|
316
|
+
self.adev.reg(f"regSDMA{pipe}_WATCHDOG_CNTL").update(queue_hang_count=100) # 10s, 100ms per unit
|
317
|
+
self.adev.reg(f"regSDMA{pipe}_UTCL1_CNTL").update(resp_mode=3, redo_delay=9)
|
318
|
+
self.adev.reg(f"regSDMA{pipe}_UTCL1_PAGE").update(rd_l2_policy=0x2, wr_l2_policy=0x3, llc_noalloc=1) # rd=noa, wr=bypass
|
319
|
+
self.adev.reg(f"regSDMA{pipe}_F32_CNTL").update(halt=0, th1_reset=0)
|
320
|
+
self.adev.reg(f"regSDMA{pipe}_CNTL").update(ctxempty_int_enable=1, trap_enable=1)
|
321
|
+
|
322
|
+
def fini(self):
|
323
|
+
self.adev.regSDMA0_QUEUE0_RB_CNTL.update(rb_enable=0)
|
324
|
+
self.adev.regSDMA0_QUEUE0_IB_CNTL.update(ib_enable=0)
|
325
|
+
self.adev.regGRBM_SOFT_RESET.write(soft_reset_sdma0=1)
|
326
|
+
time.sleep(0.01)
|
327
|
+
self.adev.regGRBM_SOFT_RESET.write(0x0)
|
328
|
+
|
329
|
+
class AM_PSP(AM_IP):
|
330
|
+
def __init__(self, adev):
|
331
|
+
super().__init__(adev)
|
332
|
+
|
333
|
+
self.msg1_paddr = self.adev.mm.palloc(am.PSP_1_MEG, align=am.PSP_1_MEG, zero=not self.adev.partial_boot, boot=True)
|
334
|
+
self.cmd_paddr = self.adev.mm.palloc(am.PSP_CMD_BUFFER_SIZE, zero=not self.adev.partial_boot, boot=True)
|
335
|
+
self.fence_paddr = self.adev.mm.palloc(am.PSP_FENCE_BUFFER_SIZE, zero=not self.adev.partial_boot, boot=True)
|
336
|
+
|
337
|
+
self.ring_size = 0x10000
|
338
|
+
self.ring_paddr = self.adev.mm.palloc(self.ring_size, zero=not self.adev.partial_boot, boot=True)
|
339
|
+
|
340
|
+
self.max_tmr_size = 0x1300000
|
341
|
+
self.tmr_paddr = self.adev.mm.palloc(self.max_tmr_size, align=am.PSP_TMR_ALIGNMENT, zero=not self.adev.partial_boot, boot=True)
|
342
|
+
|
343
|
+
def is_sos_alive(self): return self.adev.regMP0_SMN_C2PMSG_81.read() != 0x0
|
344
|
+
def init(self):
|
345
|
+
sos_components_load_order = [
|
346
|
+
(am.PSP_FW_TYPE_PSP_KDB, am.PSP_BL__LOAD_KEY_DATABASE), (am.PSP_FW_TYPE_PSP_KDB, am.PSP_BL__LOAD_TOS_SPL_TABLE),
|
347
|
+
(am.PSP_FW_TYPE_PSP_SYS_DRV, am.PSP_BL__LOAD_SYSDRV), (am.PSP_FW_TYPE_PSP_SOC_DRV, am.PSP_BL__LOAD_SOCDRV),
|
348
|
+
(am.PSP_FW_TYPE_PSP_INTF_DRV, am.PSP_BL__LOAD_INTFDRV), (am.PSP_FW_TYPE_PSP_DBG_DRV, am.PSP_BL__LOAD_DBGDRV),
|
349
|
+
(am.PSP_FW_TYPE_PSP_RAS_DRV, am.PSP_BL__LOAD_RASDRV), (am.PSP_FW_TYPE_PSP_SOS, am.PSP_BL__LOAD_SOSDRV)]
|
350
|
+
|
351
|
+
if not self.is_sos_alive():
|
352
|
+
for fw, compid in sos_components_load_order: self._bootloader_load_component(fw, compid)
|
353
|
+
while not self.is_sos_alive(): time.sleep(0.01)
|
354
|
+
|
355
|
+
self._ring_create()
|
356
|
+
self._tmr_init()
|
357
|
+
|
358
|
+
# SMU fw should be loaded before TMR.
|
359
|
+
self._load_ip_fw_cmd(self.adev.fw.smu_psp_desc)
|
360
|
+
self._tmr_load_cmd()
|
361
|
+
|
362
|
+
for psp_desc in self.adev.fw.descs: self._load_ip_fw_cmd(psp_desc)
|
363
|
+
self._rlc_autoload_cmd()
|
364
|
+
|
365
|
+
def _wait_for_bootloader(self): self.adev.wait_reg(self.adev.regMP0_SMN_C2PMSG_35, mask=0xFFFFFFFF, value=0x80000000)
|
366
|
+
|
367
|
+
def _prep_msg1(self, data):
|
368
|
+
ctypes.memset(cpu_addr:=self.adev.paddr2cpu(self.msg1_paddr), 0, am.PSP_1_MEG)
|
369
|
+
to_mv(cpu_addr, len(data))[:] = data
|
370
|
+
self.adev.gmc.flush_hdp()
|
371
|
+
|
372
|
+
def _bootloader_load_component(self, fw, compid):
|
373
|
+
if fw not in self.adev.fw.sos_fw: return 0
|
374
|
+
|
375
|
+
self._wait_for_bootloader()
|
376
|
+
|
377
|
+
self._prep_msg1(self.adev.fw.sos_fw[fw])
|
378
|
+
self.adev.regMP0_SMN_C2PMSG_36.write(self.adev.paddr2mc(self.msg1_paddr) >> 20)
|
379
|
+
self.adev.regMP0_SMN_C2PMSG_35.write(compid)
|
380
|
+
|
381
|
+
return self._wait_for_bootloader()
|
382
|
+
|
383
|
+
def _tmr_init(self):
|
384
|
+
# Load TOC and calculate TMR size
|
385
|
+
self._prep_msg1(fwm:=self.adev.fw.sos_fw[am.PSP_FW_TYPE_PSP_TOC])
|
386
|
+
self.tmr_size = self._load_toc_cmd(len(fwm)).resp.tmr_size
|
387
|
+
assert self.tmr_size <= self.max_tmr_size
|
388
|
+
|
389
|
+
def _ring_create(self):
|
390
|
+
# If the ring is already created, destroy it
|
391
|
+
if self.adev.regMP0_SMN_C2PMSG_71.read() != 0:
|
392
|
+
self.adev.regMP0_SMN_C2PMSG_64.write(am.GFX_CTRL_CMD_ID_DESTROY_RINGS)
|
393
|
+
|
394
|
+
# There might be handshake issue with hardware which needs delay
|
395
|
+
time.sleep(0.02)
|
396
|
+
|
397
|
+
# Wait until the sOS is ready
|
398
|
+
self.adev.wait_reg(self.adev.regMP0_SMN_C2PMSG_64, mask=0x80000000, value=0x80000000)
|
399
|
+
|
400
|
+
self.adev.wreg_pair("regMP0_SMN_C2PMSG", "_69", "_70", self.adev.paddr2mc(self.ring_paddr))
|
401
|
+
self.adev.regMP0_SMN_C2PMSG_71.write(self.ring_size)
|
402
|
+
self.adev.regMP0_SMN_C2PMSG_64.write(am.PSP_RING_TYPE__KM << 16)
|
403
|
+
|
404
|
+
# There might be handshake issue with hardware which needs delay
|
405
|
+
time.sleep(0.02)
|
406
|
+
|
407
|
+
self.adev.wait_reg(self.adev.regMP0_SMN_C2PMSG_64, mask=0x8000FFFF, value=0x80000000)
|
408
|
+
|
409
|
+
def _ring_submit(self):
|
410
|
+
prev_wptr = self.adev.regMP0_SMN_C2PMSG_67.read()
|
411
|
+
ring_entry_addr = self.adev.paddr2cpu(self.ring_paddr) + prev_wptr * 4
|
412
|
+
|
413
|
+
ctypes.memset(ring_entry_addr, 0, ctypes.sizeof(am.struct_psp_gfx_rb_frame))
|
414
|
+
write_loc = am.struct_psp_gfx_rb_frame.from_address(ring_entry_addr)
|
415
|
+
write_loc.cmd_buf_addr_hi, write_loc.cmd_buf_addr_lo = data64(self.adev.paddr2mc(self.cmd_paddr))
|
416
|
+
write_loc.fence_addr_hi, write_loc.fence_addr_lo = data64(self.adev.paddr2mc(self.fence_paddr))
|
417
|
+
write_loc.fence_value = prev_wptr
|
418
|
+
|
419
|
+
# Move the wptr
|
420
|
+
self.adev.regMP0_SMN_C2PMSG_67.write(prev_wptr + ctypes.sizeof(am.struct_psp_gfx_rb_frame) // 4)
|
421
|
+
|
422
|
+
while to_mv(self.adev.paddr2cpu(self.fence_paddr), 4).cast('I')[0] != prev_wptr: pass
|
423
|
+
time.sleep(0.005)
|
424
|
+
|
425
|
+
resp = am.struct_psp_gfx_cmd_resp.from_address(self.adev.paddr2cpu(self.cmd_paddr))
|
426
|
+
if resp.resp.status != 0: raise RuntimeError(f"PSP command failed {resp.cmd_id} {resp.resp.status}")
|
427
|
+
|
428
|
+
return resp
|
429
|
+
|
430
|
+
def _prep_ring_cmd(self, hdr):
|
431
|
+
ctypes.memset(self.adev.paddr2cpu(self.cmd_paddr), 0, 0x1000)
|
432
|
+
cmd = am.struct_psp_gfx_cmd_resp.from_address(self.adev.paddr2cpu(self.cmd_paddr))
|
433
|
+
cmd.cmd_id = hdr
|
434
|
+
return cmd
|
435
|
+
|
436
|
+
def _load_ip_fw_cmd(self, psp_desc):
|
437
|
+
if DEBUG >= 2: print(f"am {self.adev.devfmt}: loading fw: {am.psp_gfx_fw_type__enumvalues[psp_desc[0]]}")
|
438
|
+
fw_type, fw_bytes = psp_desc
|
439
|
+
|
440
|
+
self._prep_msg1(fw_bytes)
|
441
|
+
cmd = self._prep_ring_cmd(am.GFX_CMD_ID_LOAD_IP_FW)
|
442
|
+
cmd.cmd.cmd_load_ip_fw.fw_phy_addr_hi, cmd.cmd.cmd_load_ip_fw.fw_phy_addr_lo = data64(self.adev.paddr2mc(self.msg1_paddr))
|
443
|
+
cmd.cmd.cmd_load_ip_fw.fw_size = len(fw_bytes)
|
444
|
+
cmd.cmd.cmd_load_ip_fw.fw_type = fw_type
|
445
|
+
return self._ring_submit()
|
446
|
+
|
447
|
+
def _tmr_load_cmd(self):
|
448
|
+
cmd = self._prep_ring_cmd(am.GFX_CMD_ID_SETUP_TMR)
|
449
|
+
cmd.cmd.cmd_setup_tmr.buf_phy_addr_hi, cmd.cmd.cmd_setup_tmr.buf_phy_addr_lo = data64(self.adev.paddr2mc(self.tmr_paddr))
|
450
|
+
cmd.cmd.cmd_setup_tmr.system_phy_addr_hi, cmd.cmd.cmd_setup_tmr.system_phy_addr_lo = data64(self.tmr_paddr)
|
451
|
+
cmd.cmd.cmd_setup_tmr.bitfield.virt_phy_addr = 1
|
452
|
+
cmd.cmd.cmd_setup_tmr.buf_size = self.tmr_size
|
453
|
+
return self._ring_submit()
|
454
|
+
|
455
|
+
def _load_toc_cmd(self, toc_size):
|
456
|
+
cmd = self._prep_ring_cmd(am.GFX_CMD_ID_LOAD_TOC)
|
457
|
+
cmd.cmd.cmd_load_toc.toc_phy_addr_hi, cmd.cmd.cmd_load_toc.toc_phy_addr_lo = data64(self.adev.paddr2mc(self.msg1_paddr))
|
458
|
+
cmd.cmd.cmd_load_toc.toc_size = toc_size
|
459
|
+
return self._ring_submit()
|
460
|
+
|
461
|
+
def _rlc_autoload_cmd(self):
|
462
|
+
self._prep_ring_cmd(am.GFX_CMD_ID_AUTOLOAD_RLC)
|
463
|
+
return self._ring_submit()
|
@@ -62,8 +62,10 @@ class NVCompiler(CUDACompiler):
|
|
62
62
|
def __init__(self, arch:str): super().__init__(arch, cache_key="nv")
|
63
63
|
def compile(self, src:str) -> bytes: return self._compile_program(src, nvrtc.nvrtcGetCUBIN, nvrtc.nvrtcGetCUBINSize)
|
64
64
|
|
65
|
-
class PTXCompiler(
|
66
|
-
def __init__(self, arch:str, cache_key="ptx"):
|
65
|
+
class PTXCompiler(Compiler):
|
66
|
+
def __init__(self, arch:str, cache_key="ptx"):
|
67
|
+
self.arch = arch
|
68
|
+
super().__init__(f"compile_{cache_key}_{self.arch}")
|
67
69
|
def compile(self, src:str) -> bytes: return src.replace("TARGET", self.arch).replace("VERSION", "7.8" if self.arch >= "sm_89" else "7.5").encode()
|
68
70
|
|
69
71
|
class NVPTXCompiler(PTXCompiler):
|
tinygrad/runtime/support/elf.py
CHANGED
@@ -1,12 +1,11 @@
|
|
1
|
-
|
2
|
-
from typing import Tuple, List, Any
|
1
|
+
import struct, tinygrad.runtime.autogen.libc as libc
|
3
2
|
from dataclasses import dataclass
|
4
|
-
|
3
|
+
from tinygrad.helpers import getbits, i2u
|
5
4
|
|
6
5
|
@dataclass(frozen=True)
|
7
6
|
class ElfSection: name:str; header:libc.Elf64_Shdr; content:bytes # noqa: E702
|
8
7
|
|
9
|
-
def elf_loader(blob:bytes, force_section_align:int=1) ->
|
8
|
+
def elf_loader(blob:bytes, force_section_align:int=1) -> tuple[memoryview, list[ElfSection], list[tuple]]:
|
10
9
|
def _strtab(blob: bytes, idx: int) -> str: return blob[idx:blob.find(b'\x00', idx)].decode('utf-8')
|
11
10
|
|
12
11
|
header = libc.Elf64_Ehdr.from_buffer_copy(blob)
|
@@ -33,6 +32,31 @@ def elf_loader(blob:bytes, force_section_align:int=1) -> Tuple[memoryview, List[
|
|
33
32
|
for sh, trgt_sh_name, c_rels in rel + rela:
|
34
33
|
target_image_off = next(tsh for tsh in sections if tsh.name == trgt_sh_name).header.sh_addr
|
35
34
|
rels = [(r.r_offset, symtab[libc.ELF64_R_SYM(r.r_info)], libc.ELF64_R_TYPE(r.r_info), getattr(r, "r_addend", 0)) for r in c_rels]
|
35
|
+
for roff, sym, r_type_, r_addend in rels:
|
36
|
+
if sym.st_shndx == 0: raise RuntimeError(f'Attempting to relocate against an undefined symbol {repr(_strtab(sh_strtab, sym.st_name))}')
|
36
37
|
relocs += [(target_image_off + roff, sections[sym.st_shndx].header.sh_addr + sym.st_value, rtype, raddend) for roff, sym, rtype, raddend in rels]
|
37
38
|
|
38
39
|
return memoryview(image), sections, relocs
|
40
|
+
|
41
|
+
def relocate(instr: int, ploc: int, tgt: int, r_type: int):
|
42
|
+
match r_type:
|
43
|
+
# https://refspecs.linuxfoundation.org/elf/x86_64-abi-0.95.pdf
|
44
|
+
case libc.R_X86_64_PC32: return i2u(32, tgt-ploc)
|
45
|
+
# https://github.com/ARM-software/abi-aa/blob/main/aaelf64/aaelf64.rst for definitions of relocations
|
46
|
+
# https://www.scs.stanford.edu/~zyedidia/arm64/index.html for instruction encodings
|
47
|
+
case libc.R_AARCH64_ADR_PREL_PG_HI21:
|
48
|
+
rel_pg = (tgt & ~0xFFF) - (ploc & ~0xFFF)
|
49
|
+
return instr | (getbits(rel_pg, 12, 13) << 29) | (getbits(rel_pg, 14, 32) << 5)
|
50
|
+
case libc.R_AARCH64_ADD_ABS_LO12_NC: return instr | (getbits(tgt, 0, 11) << 10)
|
51
|
+
case libc.R_AARCH64_LDST16_ABS_LO12_NC: return instr | (getbits(tgt, 1, 11) << 10)
|
52
|
+
case libc.R_AARCH64_LDST32_ABS_LO12_NC: return instr | (getbits(tgt, 2, 11) << 10)
|
53
|
+
case libc.R_AARCH64_LDST64_ABS_LO12_NC: return instr | (getbits(tgt, 3, 11) << 10)
|
54
|
+
case libc.R_AARCH64_LDST128_ABS_LO12_NC: return instr | (getbits(tgt, 4, 11) << 10)
|
55
|
+
raise NotImplementedError(f"Encountered unknown relocation type {r_type}")
|
56
|
+
|
57
|
+
def jit_loader(obj: bytes) -> bytes:
|
58
|
+
image, _, relocs = elf_loader(obj)
|
59
|
+
# This is needed because we have an object file, not a .so that has all internal references (like loads of constants from .rodata) resolved.
|
60
|
+
for ploc,tgt,r_type,r_addend in relocs:
|
61
|
+
image[ploc:ploc+4] = struct.pack("<I", relocate(struct.unpack("<I", image[ploc:ploc+4])[0], ploc, tgt+r_addend, r_type))
|
62
|
+
return bytes(image)
|