tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. tinygrad/__init__.py +1 -1
  2. tinygrad/apps/llm.py +206 -0
  3. tinygrad/codegen/__init__.py +116 -0
  4. tinygrad/codegen/devectorizer.py +315 -172
  5. tinygrad/codegen/expander.py +8 -16
  6. tinygrad/codegen/gpudims.py +89 -0
  7. tinygrad/codegen/linearize.py +205 -203
  8. tinygrad/codegen/lowerer.py +92 -139
  9. tinygrad/codegen/opt/__init__.py +38 -0
  10. tinygrad/codegen/opt/heuristic.py +125 -0
  11. tinygrad/codegen/opt/kernel.py +510 -0
  12. tinygrad/{engine → codegen/opt}/search.py +51 -35
  13. tinygrad/codegen/opt/swizzler.py +134 -0
  14. tinygrad/codegen/opt/tc.py +127 -0
  15. tinygrad/codegen/quantize.py +67 -0
  16. tinygrad/device.py +122 -132
  17. tinygrad/dtype.py +152 -35
  18. tinygrad/engine/jit.py +81 -54
  19. tinygrad/engine/memory.py +46 -27
  20. tinygrad/engine/realize.py +82 -41
  21. tinygrad/engine/schedule.py +70 -445
  22. tinygrad/frontend/__init__.py +0 -0
  23. tinygrad/frontend/onnx.py +1253 -0
  24. tinygrad/frontend/torch.py +5 -0
  25. tinygrad/gradient.py +19 -27
  26. tinygrad/helpers.py +95 -47
  27. tinygrad/nn/__init__.py +7 -8
  28. tinygrad/nn/optim.py +72 -41
  29. tinygrad/nn/state.py +37 -23
  30. tinygrad/renderer/__init__.py +40 -60
  31. tinygrad/renderer/cstyle.py +143 -128
  32. tinygrad/renderer/llvmir.py +113 -62
  33. tinygrad/renderer/ptx.py +50 -32
  34. tinygrad/renderer/wgsl.py +27 -23
  35. tinygrad/runtime/autogen/am/am.py +5861 -0
  36. tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
  37. tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
  38. tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
  39. tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
  40. tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
  41. tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
  42. tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
  43. tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
  44. tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
  45. tinygrad/runtime/autogen/comgr.py +35 -9
  46. tinygrad/runtime/autogen/comgr_3.py +906 -0
  47. tinygrad/runtime/autogen/cuda.py +2419 -494
  48. tinygrad/runtime/autogen/hsa.py +57 -16
  49. tinygrad/runtime/autogen/ib.py +7171 -0
  50. tinygrad/runtime/autogen/io_uring.py +917 -118
  51. tinygrad/runtime/autogen/kfd.py +748 -26
  52. tinygrad/runtime/autogen/libc.py +613 -218
  53. tinygrad/runtime/autogen/libusb.py +1643 -0
  54. tinygrad/runtime/autogen/nv/nv.py +8602 -0
  55. tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
  56. tinygrad/runtime/autogen/opencl.py +2 -4
  57. tinygrad/runtime/autogen/sqtt.py +1789 -0
  58. tinygrad/runtime/autogen/vfio.py +3 -3
  59. tinygrad/runtime/autogen/webgpu.py +273 -264
  60. tinygrad/runtime/graph/cuda.py +3 -3
  61. tinygrad/runtime/graph/hcq.py +68 -29
  62. tinygrad/runtime/graph/metal.py +29 -13
  63. tinygrad/runtime/graph/remote.py +114 -0
  64. tinygrad/runtime/ops_amd.py +537 -320
  65. tinygrad/runtime/ops_cpu.py +108 -7
  66. tinygrad/runtime/ops_cuda.py +12 -14
  67. tinygrad/runtime/ops_disk.py +13 -10
  68. tinygrad/runtime/ops_dsp.py +47 -40
  69. tinygrad/runtime/ops_gpu.py +13 -11
  70. tinygrad/runtime/ops_hip.py +6 -9
  71. tinygrad/runtime/ops_llvm.py +35 -15
  72. tinygrad/runtime/ops_metal.py +29 -19
  73. tinygrad/runtime/ops_npy.py +5 -3
  74. tinygrad/runtime/ops_null.py +28 -0
  75. tinygrad/runtime/ops_nv.py +306 -234
  76. tinygrad/runtime/ops_python.py +62 -52
  77. tinygrad/runtime/ops_qcom.py +28 -39
  78. tinygrad/runtime/ops_remote.py +482 -0
  79. tinygrad/runtime/ops_webgpu.py +28 -28
  80. tinygrad/runtime/support/am/amdev.py +114 -249
  81. tinygrad/runtime/support/am/ip.py +211 -172
  82. tinygrad/runtime/support/amd.py +138 -0
  83. tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
  84. tinygrad/runtime/support/compiler_cuda.py +8 -11
  85. tinygrad/runtime/support/elf.py +2 -1
  86. tinygrad/runtime/support/hcq.py +184 -97
  87. tinygrad/runtime/support/ib.py +172 -0
  88. tinygrad/runtime/support/llvm.py +3 -4
  89. tinygrad/runtime/support/memory.py +251 -0
  90. tinygrad/runtime/support/nv/__init__.py +0 -0
  91. tinygrad/runtime/support/nv/ip.py +581 -0
  92. tinygrad/runtime/support/nv/nvdev.py +183 -0
  93. tinygrad/runtime/support/system.py +170 -0
  94. tinygrad/runtime/support/usb.py +268 -0
  95. tinygrad/runtime/support/webgpu.py +18 -0
  96. tinygrad/schedule/__init__.py +0 -0
  97. tinygrad/schedule/grouper.py +119 -0
  98. tinygrad/schedule/kernelize.py +368 -0
  99. tinygrad/schedule/multi.py +231 -0
  100. tinygrad/shape/shapetracker.py +40 -46
  101. tinygrad/shape/view.py +88 -52
  102. tinygrad/tensor.py +968 -542
  103. tinygrad/uop/__init__.py +117 -0
  104. tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
  105. tinygrad/uop/mathtraits.py +169 -0
  106. tinygrad/uop/ops.py +1021 -0
  107. tinygrad/uop/spec.py +228 -0
  108. tinygrad/{codegen → uop}/symbolic.py +239 -216
  109. tinygrad/uop/upat.py +163 -0
  110. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
  111. tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
  112. tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
  113. tinygrad/viz/index.html +203 -403
  114. tinygrad/viz/js/index.js +718 -0
  115. tinygrad/viz/js/worker.js +29 -0
  116. tinygrad/viz/serve.py +224 -102
  117. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
  118. tinygrad-0.11.0.dist-info/RECORD +141 -0
  119. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
  120. tinygrad/codegen/kernel.py +0 -693
  121. tinygrad/engine/multi.py +0 -161
  122. tinygrad/ops.py +0 -1003
  123. tinygrad/runtime/ops_cloud.py +0 -220
  124. tinygrad/runtime/support/allocator.py +0 -94
  125. tinygrad/spec.py +0 -155
  126. tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
  127. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
  128. tinygrad/viz/perfetto.html +0 -178
  129. tinygrad-0.10.2.dist-info/RECORD +0 -99
  130. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
  131. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
@@ -1,52 +1,63 @@
1
- import ctypes, time, contextlib
1
+ import ctypes, time, contextlib, functools
2
2
  from typing import Literal
3
- from tinygrad.runtime.autogen.am import am, smu_v13_0_0
4
- from tinygrad.helpers import to_mv, data64, lo32, hi32, DEBUG
3
+ from tinygrad.helpers import to_mv, data64, lo32, hi32, DEBUG, wait_cond
4
+ from tinygrad.runtime.autogen.am import am
5
+ from tinygrad.runtime.support.amd import import_soc
5
6
 
6
7
  class AM_IP:
7
8
  def __init__(self, adev): self.adev = adev
8
- def init(self): raise NotImplementedError("IP block init must be implemeted")
9
- def fini(self): pass
9
+ def init_sw(self): pass # Prepare sw/allocations for this IP
10
+ def init_hw(self): pass # Initialize hw for this IP
11
+ def fini_hw(self): pass # Finalize hw for this IP
12
+ def set_clockgating_state(self): pass # Set clockgating state for this IP
10
13
 
11
- class AM_SOC21(AM_IP):
12
- def init(self):
14
+ class AM_SOC(AM_IP):
15
+ def init_sw(self): self.module = import_soc(self.adev.ip_ver[am.GC_HWIP])
16
+
17
+ def init_hw(self):
13
18
  self.adev.regRCC_DEV0_EPF2_STRAP2.update(strap_no_soft_reset_dev0_f2=0x0)
14
19
  self.adev.regRCC_DEV0_EPF0_RCC_DOORBELL_APER_EN.write(0x1)
20
+ def set_clockgating_state(self): self.adev.regHDP_MEM_POWER_CTRL.update(atomic_mem_power_ctrl_en=1, atomic_mem_power_ds_en=1)
15
21
 
16
- class AM_GMC(AM_IP):
17
- def __init__(self, adev):
18
- super().__init__(adev)
22
+ def doorbell_enable(self, port, awid=0, awaddr_31_28_value=0, offset=0, size=0):
23
+ self.adev.reg(f"{'regGDC_S2A0_S2A' if self.adev.ip_ver[am.GC_HWIP] >= (12,0,0) else 'regS2A'}_DOORBELL_ENTRY_{port}_CTRL").update(
24
+ **{f"s2a_doorbell_port{port}_enable":1, f"s2a_doorbell_port{port}_awid":awid, f"s2a_doorbell_port{port}_awaddr_31_28_value":awaddr_31_28_value,
25
+ f"s2a_doorbell_port{port}_range_offset":offset, f"s2a_doorbell_port{port}_range_size":size})
19
26
 
27
+ class AM_GMC(AM_IP):
28
+ def init_sw(self):
20
29
  # Memory controller aperture
21
30
  self.mc_base = (self.adev.regMMMC_VM_FB_LOCATION_BASE.read() & 0xFFFFFF) << 24
22
31
  self.mc_end = self.mc_base + self.adev.mm.vram_size - 1
23
32
 
24
33
  # VM aperture
25
34
  self.vm_base = self.adev.mm.va_allocator.base
26
- self.vm_end = self.vm_base + self.adev.mm.va_allocator.size - 1
35
+ self.vm_end = min(self.vm_base + (1 << self.adev.mm.va_bits) - 1, 0x7fffffffffff)
27
36
 
28
- # GFX11 has 44-bit address space
37
+ # GFX11/GFX12 has 44-bit address space
29
38
  self.address_space_mask = (1 << 44) - 1
30
39
 
31
- self.memscratch_paddr = self.adev.mm.palloc(0x1000, zero=not self.adev.partial_boot, boot=True)
32
- self.dummy_page_paddr = self.adev.mm.palloc(0x1000, zero=not self.adev.partial_boot, boot=True)
40
+ self.memscratch_paddr = self.adev.mm.palloc(0x1000, zero=False, boot=True)
41
+ self.dummy_page_paddr = self.adev.mm.palloc(0x1000, zero=False, boot=True)
33
42
  self.hub_initted = {"MM": False, "GC": False}
34
43
 
35
- def init(self): self.init_hub("MM")
44
+ self.pf_status_reg = lambda ip: f"reg{ip}VM_L2_PROTECTION_FAULT_STATUS{'_LO32' if self.adev.ip_ver[am.GC_HWIP] >= (12,0,0) else ''}"
45
+
46
+ def init_hw(self): self.init_hub("MM")
36
47
 
37
- def flush_hdp(self): self.adev.regBIF_BX_PF0_GPU_HDP_FLUSH_REQ.write(0xffffffff)
48
+ def flush_hdp(self): self.adev.wreg(self.adev.reg("regBIF_BX0_REMAP_HDP_MEM_FLUSH_CNTL").read() // 4, 0x0)
38
49
  def flush_tlb(self, ip:Literal["MM", "GC"], vmid, flush_type=0):
39
50
  self.flush_hdp()
40
51
 
41
52
  # Can't issue TLB invalidation if the hub isn't initialized.
42
53
  if not self.hub_initted[ip]: return
43
54
 
44
- if ip == "MM": self.adev.wait_reg(self.adev.regMMVM_INVALIDATE_ENG17_SEM, mask=0x1, value=0x1)
55
+ if ip == "MM": wait_cond(lambda: self.adev.regMMVM_INVALIDATE_ENG17_SEM.read() & 0x1, value=1, msg="mm flush_tlb timeout")
45
56
 
46
57
  self.adev.reg(f"reg{ip}VM_INVALIDATE_ENG17_REQ").write(flush_type=flush_type, per_vmid_invalidate_req=(1 << vmid), invalidate_l2_ptes=1,
47
58
  invalidate_l2_pde0=1, invalidate_l2_pde1=1, invalidate_l2_pde2=1, invalidate_l1_ptes=1, clear_protection_fault_status_addr=0)
48
59
 
49
- self.adev.wait_reg(self.adev.reg(f"reg{ip}VM_INVALIDATE_ENG17_ACK"), mask=(1 << vmid), value=(1 << vmid))
60
+ wait_cond(lambda: self.adev.reg(f"reg{ip}VM_INVALIDATE_ENG17_ACK").read() & (1 << vmid), value=(1 << vmid), msg="flush_tlb timeout")
50
61
 
51
62
  if ip == "MM":
52
63
  self.adev.regMMVM_INVALIDATE_ENG17_SEM.write(0x0)
@@ -59,7 +70,14 @@ class AM_GMC(AM_IP):
59
70
  self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_START_ADDR", "_LO32", "_HI32", self.vm_base >> 12)
60
71
  self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_END_ADDR", "_LO32", "_HI32", self.vm_end >> 12)
61
72
  self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_BASE_ADDR", "_LO32", "_HI32", page_table.paddr | 1)
62
- self.adev.reg(f"reg{ip}VM_CONTEXT{vmid}_CNTL").write(0x1fffe00, enable_context=1, page_table_depth=(3 - page_table.lv))
73
+ self.adev.reg(f"reg{ip}VM_CONTEXT{vmid}_CNTL").write(0x1800000, pde0_protection_fault_enable_interrupt=1, pde0_protection_fault_enable_default=1,
74
+ dummy_page_protection_fault_enable_interrupt=1, dummy_page_protection_fault_enable_default=1,
75
+ range_protection_fault_enable_interrupt=1, range_protection_fault_enable_default=1,
76
+ valid_protection_fault_enable_interrupt=1, valid_protection_fault_enable_default=1,
77
+ read_protection_fault_enable_interrupt=1, read_protection_fault_enable_default=1,
78
+ write_protection_fault_enable_interrupt=1, write_protection_fault_enable_default=1,
79
+ execute_protection_fault_enable_interrupt=1, execute_protection_fault_enable_default=1,
80
+ enable_context=1, page_table_depth=(3 - page_table.lv))
63
81
 
64
82
  def init_hub(self, ip:Literal["MM", "GC"]):
65
83
  # Init system apertures
@@ -76,7 +94,7 @@ class AM_GMC(AM_IP):
76
94
 
77
95
  # Init TLB and cache
78
96
  self.adev.reg(f"reg{ip}MC_VM_MX_L1_TLB_CNTL").update(enable_l1_tlb=1, system_access_mode=3, enable_advanced_driver_model=1,
79
- system_aperture_unmapped_access=0, eco_bits=0, mtype=am.MTYPE_UC)
97
+ system_aperture_unmapped_access=0, eco_bits=0, mtype=self.adev.soc.module.MTYPE_UC)
80
98
 
81
99
  self.adev.reg(f"reg{ip}VM_L2_CNTL").update(enable_l2_cache=1, enable_l2_fragment_processing=0, enable_default_page_out_to_system_memory=1,
82
100
  l2_pde0_cache_tag_generation_mode=0, pde_fault_classification=0, context1_identity_access_mode=1, identity_mode_fragment_size=0)
@@ -95,77 +113,93 @@ class AM_GMC(AM_IP):
95
113
  for eng_i in range(18): self.adev.wreg_pair(f"reg{ip}VM_INVALIDATE_ENG{eng_i}_ADDR_RANGE", "_LO32", "_HI32", 0x1fffffffff)
96
114
  self.hub_initted[ip] = True
97
115
 
116
+ @functools.cache
117
+ def get_pte_flags(self, pte_lv, is_table, frag, uncached, system, snooped, valid, extra=0):
118
+ extra |= (am.AMDGPU_PTE_SYSTEM * system) | (am.AMDGPU_PTE_SNOOPED * snooped) | (am.AMDGPU_PTE_VALID * valid) | am.AMDGPU_PTE_FRAG(frag)
119
+ if not is_table: extra |= (am.AMDGPU_PTE_WRITEABLE | am.AMDGPU_PTE_READABLE | am.AMDGPU_PTE_EXECUTABLE)
120
+ if self.adev.ip_ver[am.GC_HWIP] >= (12,0,0):
121
+ extra |= am.AMDGPU_PTE_MTYPE_GFX12(0, self.adev.soc.module.MTYPE_UC if uncached else 0)
122
+ extra |= (am.AMDGPU_PDE_PTE_GFX12 if not is_table and pte_lv != am.AMDGPU_VM_PTB else (am.AMDGPU_PTE_IS_PTE if not is_table else 0))
123
+ else:
124
+ extra |= am.AMDGPU_PTE_MTYPE_NV10(0, self.adev.soc.module.MTYPE_UC if uncached else 0)
125
+ extra |= (am.AMDGPU_PDE_PTE if not is_table and pte_lv != am.AMDGPU_VM_PTB else 0)
126
+ return extra
127
+ def is_pte_huge_page(self, pte): return pte & (am.AMDGPU_PDE_PTE_GFX12 if self.adev.ip_ver[am.GC_HWIP] >= (12,0,0) else am.AMDGPU_PDE_PTE)
128
+
98
129
  def on_interrupt(self):
99
130
  for ip in ["MM", "GC"]:
100
- st, va = self.adev.reg(f'reg{ip}VM_L2_PROTECTION_FAULT_STATUS').read(), self.adev.reg(f'reg{ip}VM_L2_PROTECTION_FAULT_ADDR_LO32').read()
101
- va = (va | (self.adev.reg(f'reg{ip}VM_L2_PROTECTION_FAULT_ADDR_HI32').read()) << 32) << 12
102
- if self.adev.reg(f"reg{ip}VM_L2_PROTECTION_FAULT_STATUS").read(): raise RuntimeError(f"{ip}VM_L2_PROTECTION_FAULT_STATUS: {st:#x} {va:#x}")
131
+ va = (self.adev.reg(f'reg{ip}VM_L2_PROTECTION_FAULT_ADDR_HI32').read()<<32) | self.adev.reg(f'reg{ip}VM_L2_PROTECTION_FAULT_ADDR_LO32').read()
132
+ if self.adev.reg(self.pf_status_reg(ip)).read():
133
+ raise RuntimeError(f"{ip}VM_L2_PROTECTION_FAULT_STATUS: {self.adev.reg(self.pf_status_reg(ip)).read_bitfields()} {va<<12:#x}")
103
134
 
104
135
  class AM_SMU(AM_IP):
105
- def __init__(self, adev):
106
- super().__init__(adev)
107
- self.driver_table_paddr = self.adev.mm.palloc(0x4000, zero=not self.adev.partial_boot, boot=True)
136
+ def init_sw(self):
137
+ self.smu_mod = self.adev._ip_module("smu", am.MP1_HWIP, prever_prefix='v')
138
+ self.driver_table_paddr = self.adev.mm.palloc(0x4000, zero=False, boot=True)
108
139
 
109
- def init(self):
110
- self._send_msg(smu_v13_0_0.PPSMC_MSG_SetDriverDramAddrHigh, hi32(self.adev.paddr2mc(self.driver_table_paddr)), poll=True)
111
- self._send_msg(smu_v13_0_0.PPSMC_MSG_SetDriverDramAddrLow, lo32(self.adev.paddr2mc(self.driver_table_paddr)), poll=True)
112
- self._send_msg(smu_v13_0_0.PPSMC_MSG_EnableAllSmuFeatures, 0, poll=True)
140
+ def init_hw(self):
141
+ self._send_msg(self.smu_mod.PPSMC_MSG_SetDriverDramAddrHigh, hi32(self.adev.paddr2mc(self.driver_table_paddr)))
142
+ self._send_msg(self.smu_mod.PPSMC_MSG_SetDriverDramAddrLow, lo32(self.adev.paddr2mc(self.driver_table_paddr)))
143
+ self._send_msg(self.smu_mod.PPSMC_MSG_EnableAllSmuFeatures, 0)
113
144
 
114
145
  def is_smu_alive(self):
115
- with contextlib.suppress(RuntimeError): self._send_msg(smu_v13_0_0.PPSMC_MSG_GetSmuVersion, 0, timeout=100)
146
+ with contextlib.suppress(RuntimeError): self._send_msg(self.smu_mod.PPSMC_MSG_GetSmuVersion, 0, timeout=100)
116
147
  return self.adev.mmMP1_SMN_C2PMSG_90.read() != 0
117
148
 
118
149
  def mode1_reset(self):
119
150
  if DEBUG >= 2: print(f"am {self.adev.devfmt}: mode1 reset")
120
- self._send_msg(smu_v13_0_0.PPSMC_MSG_Mode1Reset, 0, poll=True)
151
+ if self.adev.ip_ver[am.MP0_HWIP] >= (14,0,0): self._send_msg(__DEBUGSMC_MSG_Mode1Reset:=2, 0, debug=True)
152
+ else: self._send_msg(self.smu_mod.PPSMC_MSG_Mode1Reset, 0)
121
153
  time.sleep(0.5) # 500ms
122
154
 
123
155
  def read_table(self, table_t, cmd):
124
- self._send_msg(smu_v13_0_0.PPSMC_MSG_TransferTableSmu2Dram, cmd, poll=True)
125
- return table_t.from_buffer(to_mv(self.adev.paddr2cpu(self.driver_table_paddr), ctypes.sizeof(table_t)))
126
- def read_metrics(self): return self.read_table(smu_v13_0_0.SmuMetricsExternal_t, smu_v13_0_0.TABLE_SMU_METRICS)
156
+ self._send_msg(self.smu_mod.PPSMC_MSG_TransferTableSmu2Dram, cmd)
157
+ return table_t.from_buffer(bytearray(self.adev.vram.view(self.driver_table_paddr, ctypes.sizeof(table_t))[:]))
158
+ def read_metrics(self): return self.read_table(self.smu_mod.SmuMetricsExternal_t, self.smu_mod.TABLE_SMU_METRICS)
127
159
 
128
160
  def set_clocks(self, level):
129
161
  if not hasattr(self, 'clcks'):
130
162
  self.clcks = {}
131
- for clck in [smu_v13_0_0.PPCLK_GFXCLK, smu_v13_0_0.PPCLK_UCLK, smu_v13_0_0.PPCLK_FCLK, smu_v13_0_0.PPCLK_SOCCLK]:
132
- cnt = self._send_msg(smu_v13_0_0.PPSMC_MSG_GetDpmFreqByIndex, (clck<<16)|0xff, read_back_arg=True)&0x7fffffff
133
- self.clcks[clck] = [self._send_msg(smu_v13_0_0.PPSMC_MSG_GetDpmFreqByIndex, (clck<<16)|i, read_back_arg=True)&0x7fffffff for i in range(cnt)]
163
+ for clck in [self.smu_mod.PPCLK_GFXCLK, self.smu_mod.PPCLK_UCLK, self.smu_mod.PPCLK_FCLK, self.smu_mod.PPCLK_SOCCLK]:
164
+ cnt = self._send_msg(self.smu_mod.PPSMC_MSG_GetDpmFreqByIndex, (clck<<16)|0xff, read_back_arg=True)&0x7fffffff
165
+ self.clcks[clck] = [self._send_msg(self.smu_mod.PPSMC_MSG_GetDpmFreqByIndex, (clck<<16)|i, read_back_arg=True)&0x7fffffff for i in range(cnt)]
134
166
 
135
167
  for clck, vals in self.clcks.items():
136
- self._send_msg(smu_v13_0_0.PPSMC_MSG_SetSoftMinByFreq, clck << 16 | (vals[level]), poll=True)
137
- self._send_msg(smu_v13_0_0.PPSMC_MSG_SetSoftMaxByFreq, clck << 16 | (vals[level]), poll=True)
138
-
139
- def _smu_cmn_poll_stat(self, timeout=10000): self.adev.wait_reg(self.adev.mmMP1_SMN_C2PMSG_90, mask=0xFFFFFFFF, value=1, timeout=timeout)
140
- def _smu_cmn_send_msg(self, msg, param=0):
141
- self.adev.mmMP1_SMN_C2PMSG_90.write(0) # resp reg
142
- self.adev.mmMP1_SMN_C2PMSG_82.write(param)
143
- self.adev.mmMP1_SMN_C2PMSG_66.write(msg)
168
+ self._send_msg(self.smu_mod.PPSMC_MSG_SetSoftMinByFreq, clck << 16 | (vals[level]))
169
+ self._send_msg(self.smu_mod.PPSMC_MSG_SetSoftMaxByFreq, clck << 16 | (vals[level]))
144
170
 
145
- def _send_msg(self, msg, param, poll=True, read_back_arg=False, timeout=10000): # 10s
146
- if poll: self._smu_cmn_poll_stat(timeout=timeout)
171
+ def _smu_cmn_send_msg(self, msg:int, param=0, debug=False):
172
+ (self.adev.mmMP1_SMN_C2PMSG_90 if not debug else self.adev.mmMP1_SMN_C2PMSG_54).write(0) # resp reg
173
+ (self.adev.mmMP1_SMN_C2PMSG_82 if not debug else self.adev.mmMP1_SMN_C2PMSG_53).write(param)
174
+ (self.adev.mmMP1_SMN_C2PMSG_66 if not debug else self.adev.mmMP1_SMN_C2PMSG_75).write(msg)
147
175
 
148
- self._smu_cmn_send_msg(msg, param)
149
- self._smu_cmn_poll_stat(timeout=timeout)
150
- return self.adev.mmMP1_SMN_C2PMSG_82.read() if read_back_arg else None
176
+ def _send_msg(self, msg:int, param:int, read_back_arg=False, timeout=10000, debug=False): # default timeout is 10 seconds
177
+ self._smu_cmn_send_msg(msg, param, debug=debug)
178
+ wait_cond(lambda: (self.adev.mmMP1_SMN_C2PMSG_90 if not debug else self.adev.mmMP1_SMN_C2PMSG_54).read(), value=1, timeout_ms=timeout,
179
+ msg=f"SMU msg {msg:#x} timeout")
180
+ return (self.adev.mmMP1_SMN_C2PMSG_82 if not debug else self.adev.mmMP1_SMN_C2PMSG_53).read() if read_back_arg else None
151
181
 
152
182
  class AM_GFX(AM_IP):
153
- def init(self):
183
+ def init_hw(self):
154
184
  # Wait for RLC autoload to complete
155
- while self.adev.regCP_STAT.read() != 0 and self.adev.regRLC_RLCS_BOOTLOAD_STATUS.read(bootload_complete=1) != 0: pass
185
+ while self.adev.regCP_STAT.read() != 0 and self.adev.regRLC_RLCS_BOOTLOAD_STATUS.read_bitfields()['bootload_complete'] != 0: pass
156
186
 
157
187
  self._config_gfx_rs64()
158
188
  self.adev.gmc.init_hub("GC")
159
189
 
160
190
  # NOTE: Golden reg for gfx11. No values for this reg provided. The kernel just ors 0x20000000 to this reg.
161
191
  self.adev.regTCP_CNTL.write(self.adev.regTCP_CNTL.read() | 0x20000000)
192
+
162
193
  self.adev.regRLC_SRM_CNTL.update(srm_enable=1, auto_incr_addr=1)
163
194
 
195
+ self.adev.soc.doorbell_enable(port=0, awid=0x3, awaddr_31_28_value=0x3)
196
+ self.adev.soc.doorbell_enable(port=3, awid=0x6, awaddr_31_28_value=0x3)
197
+
164
198
  self.adev.regGRBM_CNTL.update(read_timeout=0xff)
165
199
  for i in range(0, 16):
166
200
  self._grbm_select(vmid=i)
167
- self.adev.regSH_MEM_CONFIG.write(address_mode=am.SH_MEM_ADDRESS_MODE_64, alignment_mode=am.SH_MEM_ALIGNMENT_MODE_UNALIGNED,
168
- initial_inst_prefetch=3)
201
+ self.adev.regSH_MEM_CONFIG.write(address_mode=self.adev.soc.module.SH_MEM_ADDRESS_MODE_64,
202
+ alignment_mode=self.adev.soc.module.SH_MEM_ALIGNMENT_MODE_UNALIGNED, initial_inst_prefetch=3)
169
203
 
170
204
  # Configure apertures:
171
205
  # LDS: 0x10000000'00000000 - 0x10000001'00000000 (4GB)
@@ -178,13 +212,12 @@ class AM_GFX(AM_IP):
178
212
  self.adev.regCP_MEC_DOORBELL_RANGE_UPPER.write(0x450)
179
213
 
180
214
  # Enable MEC
181
- self.adev.regCP_MEC_RS64_CNTL.update(mec_invalidate_icache=0, mec_pipe0_reset=0, mec_pipe1_reset=0, mec_pipe2_reset=0, mec_pipe3_reset=0,
182
- mec_pipe0_active=1, mec_pipe1_active=1, mec_pipe2_active=1, mec_pipe3_active=1, mec_halt=0)
215
+ self.adev.regCP_MEC_RS64_CNTL.update(mec_invalidate_icache=0, mec_pipe0_reset=0, mec_pipe0_active=1, mec_halt=0)
183
216
 
184
217
  # NOTE: Wait for MEC to be ready. The kernel does udelay here as well.
185
218
  time.sleep(0.05)
186
219
 
187
- def fini(self):
220
+ def fini_hw(self):
188
221
  self._grbm_select(me=1, pipe=0, queue=0)
189
222
  self.adev.regCP_HQD_DEQUEUE_REQUEST.write(0x2) # 1 - DRAIN_PIPE; 2 - RESET_WAVES
190
223
  self.adev.regSPI_COMPUTE_QUEUE_RESET.write(1)
@@ -192,29 +225,30 @@ class AM_GFX(AM_IP):
192
225
  self.adev.regGCVM_CONTEXT0_CNTL.write(0)
193
226
 
194
227
  def setup_ring(self, ring_addr:int, ring_size:int, rptr_addr:int, wptr_addr:int, eop_addr:int, eop_size:int, doorbell:int, pipe:int, queue:int):
195
- mqd = self.adev.mm.valloc(0x1000, uncached=True, contigous=True)
228
+ mqd = self.adev.mm.valloc(0x1000, uncached=True, contiguous=True)
196
229
 
197
- mqd_struct = am.struct_v11_compute_mqd(header=0xC0310800, cp_mqd_base_addr_lo=lo32(mqd.va_addr), cp_mqd_base_addr_hi=hi32(mqd.va_addr),
198
- cp_hqd_persistent_state=self.adev.regCP_HQD_PERSISTENT_STATE.build(preload_size=0x55, preload_req=1),
230
+ struct_t = getattr(am, f"struct_v{self.adev.ip_ver[am.GC_HWIP][0]}_compute_mqd")
231
+ mqd_struct = struct_t(header=0xC0310800, cp_mqd_base_addr_lo=lo32(mqd.va_addr), cp_mqd_base_addr_hi=hi32(mqd.va_addr),
232
+ cp_hqd_persistent_state=self.adev.regCP_HQD_PERSISTENT_STATE.encode(preload_size=0x55, preload_req=1),
199
233
  cp_hqd_pipe_priority=0x2, cp_hqd_queue_priority=0xf, cp_hqd_quantum=0x111,
200
234
  cp_hqd_pq_base_lo=lo32(ring_addr>>8), cp_hqd_pq_base_hi=hi32(ring_addr>>8),
201
235
  cp_hqd_pq_rptr_report_addr_lo=lo32(rptr_addr), cp_hqd_pq_rptr_report_addr_hi=hi32(rptr_addr),
202
236
  cp_hqd_pq_wptr_poll_addr_lo=lo32(wptr_addr), cp_hqd_pq_wptr_poll_addr_hi=hi32(wptr_addr),
203
- cp_hqd_pq_doorbell_control=self.adev.regCP_HQD_PQ_DOORBELL_CONTROL.build(doorbell_offset=doorbell*2, doorbell_en=1),
204
- cp_hqd_pq_control=self.adev.regCP_HQD_PQ_CONTROL.build(rptr_block_size=5, unord_dispatch=1, queue_size=(ring_size//4).bit_length()-2),
205
- cp_hqd_ib_control=self.adev.regCP_HQD_IB_CONTROL.build(min_ib_avail_size=0x3), cp_hqd_hq_status0=0x20004000,
206
- cp_mqd_control=self.adev.regCP_MQD_CONTROL.build(priv_state=1), cp_hqd_vmid=0,
237
+ cp_hqd_pq_doorbell_control=self.adev.regCP_HQD_PQ_DOORBELL_CONTROL.encode(doorbell_offset=doorbell*2, doorbell_en=1),
238
+ cp_hqd_pq_control=self.adev.regCP_HQD_PQ_CONTROL.encode(rptr_block_size=5, unord_dispatch=0, queue_size=(ring_size//4).bit_length()-2),
239
+ cp_hqd_ib_control=self.adev.regCP_HQD_IB_CONTROL.encode(min_ib_avail_size=0x3), cp_hqd_hq_status0=0x20004000,
240
+ cp_mqd_control=self.adev.regCP_MQD_CONTROL.encode(priv_state=1), cp_hqd_vmid=0,
207
241
  cp_hqd_eop_base_addr_lo=lo32(eop_addr>>8), cp_hqd_eop_base_addr_hi=hi32(eop_addr>>8),
208
- cp_hqd_eop_control=self.adev.regCP_HQD_EOP_CONTROL.build(eop_size=(eop_size//4).bit_length()-2))
242
+ cp_hqd_eop_control=self.adev.regCP_HQD_EOP_CONTROL.encode(eop_size=(eop_size//4).bit_length()-2))
209
243
 
210
244
  # Copy mqd into memory
211
- ctypes.memmove(self.adev.paddr2cpu(mqd.paddrs[0][0]), ctypes.addressof(mqd_struct), ctypes.sizeof(mqd_struct))
245
+ self.adev.vram.view(mqd.paddrs[0][0], ctypes.sizeof(mqd_struct))[:] = memoryview(mqd_struct).cast('B')
212
246
  self.adev.gmc.flush_hdp()
213
247
 
214
248
  self._grbm_select(me=1, pipe=pipe, queue=queue)
215
249
 
216
250
  mqd_st_mv = to_mv(ctypes.addressof(mqd_struct), ctypes.sizeof(mqd_struct)).cast('I')
217
- for i, reg in enumerate(range(self.adev.regCP_MQD_BASE_ADDR.reg_off, self.adev.regCP_HQD_PQ_WPTR_HI.reg_off + 1)):
251
+ for i, reg in enumerate(range(self.adev.regCP_MQD_BASE_ADDR.addr[0], self.adev.regCP_HQD_PQ_WPTR_HI.addr[0] + 1)):
218
252
  self.adev.wreg(reg, mqd_st_mv[0x80 + i])
219
253
  self.adev.regCP_HQD_ACTIVE.write(0x1)
220
254
 
@@ -226,7 +260,7 @@ class AM_GFX(AM_IP):
226
260
  if hasattr(self.adev, 'regMM_ATC_L2_MISC_CG'): self.adev.regMM_ATC_L2_MISC_CG.write(enable=1, mem_ls_enable=1)
227
261
 
228
262
  self.adev.regRLC_SAFE_MODE.write(message=1, cmd=1)
229
- self.adev.wait_reg(self.adev.regRLC_SAFE_MODE, mask=0x1, value=0x0)
263
+ wait_cond(lambda: self.adev.regRLC_SAFE_MODE.read() & 0x1, value=0, msg="RLC safe mode timeout")
230
264
 
231
265
  self.adev.regRLC_CGCG_CGLS_CTRL.update(cgcg_gfx_idle_threshold=0x36, cgcg_en=1, cgls_rep_compansat_delay=0xf, cgls_en=1)
232
266
 
@@ -251,29 +285,18 @@ class AM_GFX(AM_IP):
251
285
  self.adev.reg(f"regCP_{cntl_reg}_CNTL").update(**{f"{eng_name.lower()}_pipe{pipe}_reset": 1 for pipe in range(pipe_cnt)})
252
286
  self.adev.reg(f"regCP_{cntl_reg}_CNTL").update(**{f"{eng_name.lower()}_pipe{pipe}_reset": 0 for pipe in range(pipe_cnt)})
253
287
 
254
- _config_helper(eng_name="PFP", cntl_reg="ME", eng_reg="PFP", pipe_cnt=2)
255
- _config_helper(eng_name="ME", cntl_reg="ME", eng_reg="ME", pipe_cnt=2)
256
- _config_helper(eng_name="MEC", cntl_reg="MEC_RS64", eng_reg="MEC_RS64", pipe_cnt=4, me=1)
288
+ if self.adev.ip_ver[am.GC_HWIP] >= (12,0,0):
289
+ _config_helper(eng_name="PFP", cntl_reg="ME", eng_reg="PFP", pipe_cnt=1)
290
+ _config_helper(eng_name="ME", cntl_reg="ME", eng_reg="ME", pipe_cnt=1)
291
+ _config_helper(eng_name="MEC", cntl_reg="MEC_RS64", eng_reg="MEC_RS64", pipe_cnt=1, me=1)
257
292
 
258
293
  class AM_IH(AM_IP):
259
- def __init__(self, adev):
260
- super().__init__(adev)
294
+ def init_sw(self):
261
295
  self.ring_size = 512 << 10
262
- def _alloc_ring(size): return (self.adev.mm.palloc(size, zero=not self.adev.partial_boot, boot=True),
263
- self.adev.mm.palloc(0x1000, zero=not self.adev.partial_boot, boot=True))
296
+ def _alloc_ring(size): return (self.adev.mm.palloc(size, zero=False, boot=True), self.adev.mm.palloc(0x1000, zero=False, boot=True))
264
297
  self.rings = [(*_alloc_ring(self.ring_size), "", 0), (*_alloc_ring(self.ring_size), "_RING1", 1)]
265
298
 
266
- def interrupt_handler(self):
267
- _, rwptr_vm, suf, _ = self.rings[0]
268
- wptr = to_mv(self.adev.paddr2cpu(rwptr_vm), 8).cast('Q')[0]
269
-
270
- if self.adev.reg(f"regIH_RB_WPTR{suf}").read(rb_overflow=1):
271
- self.adev.reg(f"regIH_RB_WPTR{suf}").update(rb_overflow=0)
272
- self.adev.reg(f"regIH_RB_CNTL{suf}").update(wptr_overflow_clear=1)
273
- self.adev.reg(f"regIH_RB_CNTL{suf}").update(wptr_overflow_clear=0)
274
- self.adev.regIH_RB_RPTR.write(wptr % self.ring_size)
275
-
276
- def init(self):
299
+ def init_hw(self):
277
300
  for ring_vm, rwptr_vm, suf, ring_id in self.rings:
278
301
  self.adev.wreg_pair("regIH_RB_BASE", suf, f"_HI{suf}", self.adev.paddr2mc(ring_vm) >> 8)
279
302
 
@@ -285,7 +308,7 @@ class AM_IH(AM_IP):
285
308
  self.adev.reg(f"regIH_RB_WPTR{suf}").write(0)
286
309
  self.adev.reg(f"regIH_RB_RPTR{suf}").write(0)
287
310
 
288
- self.adev.reg(f"regIH_DOORBELL_RPTR{suf}").write(((am.AMDGPU_NAVI10_DOORBELL_IH + ring_id) * 2), enable=1)
311
+ self.adev.reg(f"regIH_DOORBELL_RPTR{suf}").write(offset=(am.AMDGPU_NAVI10_DOORBELL_IH + ring_id) * 2, enable=1)
289
312
 
290
313
  self.adev.regIH_STORM_CLIENT_LIST_CNTL.update(client18_is_storm_client=1)
291
314
  self.adev.regIH_INT_FLOOD_CNTL.update(flood_cntl_enable=1)
@@ -295,7 +318,38 @@ class AM_IH(AM_IP):
295
318
  for _, rwptr_vm, suf, ring_id in self.rings:
296
319
  self.adev.reg(f"regIH_RB_CNTL{suf}").update(rb_enable=1, **({'enable_intr': 1} if ring_id == 0 else {}))
297
320
 
321
+ self.adev.soc.doorbell_enable(port=1, awid=0x0, awaddr_31_28_value=0x0, offset=am.AMDGPU_NAVI10_DOORBELL_IH*2, size=2)
322
+
323
+ def interrupt_handler(self):
324
+ _, rwptr_vm, suf, _ = self.rings[0]
325
+ wptr = self.adev.vram.view(offset=rwptr_vm, size=8, fmt='Q')[0]
326
+
327
+ if self.adev.reg(f"regIH_RB_WPTR{suf}").read_bitfields()['rb_overflow']:
328
+ self.adev.reg(f"regIH_RB_WPTR{suf}").update(rb_overflow=0)
329
+ self.adev.reg(f"regIH_RB_CNTL{suf}").update(wptr_overflow_clear=1)
330
+ self.adev.reg(f"regIH_RB_CNTL{suf}").update(wptr_overflow_clear=0)
331
+ self.adev.regIH_RB_RPTR.write(wptr % self.ring_size)
332
+
298
333
  class AM_SDMA(AM_IP):
334
+ def init_sw(self): self.sdma_name = "F32" if self.adev.ip_ver[am.SDMA0_HWIP] < (7,0,0) else "MCU"
335
+ def init_hw(self):
336
+ for pipe in range(2):
337
+ self.adev.reg(f"regSDMA{pipe}_WATCHDOG_CNTL").update(queue_hang_count=100) # 10s, 100ms per unit
338
+ self.adev.reg(f"regSDMA{pipe}_UTCL1_CNTL").update(resp_mode=3, redo_delay=9)
339
+
340
+ # rd=noa, wr=bypass
341
+ self.adev.reg(f"regSDMA{pipe}_UTCL1_PAGE").update(rd_l2_policy=0x2, wr_l2_policy=0x3, **({'llc_noalloc':1} if self.sdma_name == "F32" else {}))
342
+ self.adev.reg(f"regSDMA{pipe}_{self.sdma_name}_CNTL").update(halt=0, **{f"{'th1_' if self.sdma_name == 'F32' else ''}reset":0})
343
+ self.adev.reg(f"regSDMA{pipe}_CNTL").update(ctxempty_int_enable=1, trap_enable=1)
344
+ self.adev.soc.doorbell_enable(port=2, awid=0xe, awaddr_31_28_value=0x3, offset=am.AMDGPU_NAVI10_DOORBELL_sDMA_ENGINE0*2, size=4)
345
+
346
+ def fini_hw(self):
347
+ self.adev.regSDMA0_QUEUE0_RB_CNTL.update(rb_enable=0)
348
+ self.adev.regSDMA0_QUEUE0_IB_CNTL.update(ib_enable=0)
349
+ self.adev.regGRBM_SOFT_RESET.write(soft_reset_sdma0=1)
350
+ time.sleep(0.01)
351
+ self.adev.regGRBM_SOFT_RESET.write(0x0)
352
+
299
353
  def setup_ring(self, ring_addr:int, ring_size:int, rptr_addr:int, wptr_addr:int, doorbell:int, pipe:int, queue:int):
300
354
  # Setup the ring
301
355
  self.adev.reg(f"regSDMA{pipe}_QUEUE{queue}_MINOR_PTR_UPDATE").write(0x1)
@@ -308,77 +362,74 @@ class AM_SDMA(AM_IP):
308
362
  self.adev.reg(f"regSDMA{pipe}_QUEUE{queue}_DOORBELL").update(enable=1)
309
363
  self.adev.reg(f"regSDMA{pipe}_QUEUE{queue}_MINOR_PTR_UPDATE").write(0x0)
310
364
  self.adev.reg(f"regSDMA{pipe}_QUEUE{queue}_RB_CNTL").write(rb_vmid=0, rptr_writeback_enable=1, rptr_writeback_timer=4,
311
- f32_wptr_poll_enable=1, rb_size=(ring_size//4).bit_length()-1, rb_enable=1, rb_priv=1)
365
+ **{f'{self.sdma_name.lower()}_wptr_poll_enable':1}, rb_size=(ring_size//4).bit_length()-1, rb_enable=1, rb_priv=1)
312
366
  self.adev.reg(f"regSDMA{pipe}_QUEUE{queue}_IB_CNTL").update(ib_enable=1)
313
367
 
314
- def init(self):
315
- for pipe in range(2):
316
- self.adev.reg(f"regSDMA{pipe}_WATCHDOG_CNTL").update(queue_hang_count=100) # 10s, 100ms per unit
317
- self.adev.reg(f"regSDMA{pipe}_UTCL1_CNTL").update(resp_mode=3, redo_delay=9)
318
- self.adev.reg(f"regSDMA{pipe}_UTCL1_PAGE").update(rd_l2_policy=0x2, wr_l2_policy=0x3, llc_noalloc=1) # rd=noa, wr=bypass
319
- self.adev.reg(f"regSDMA{pipe}_F32_CNTL").update(halt=0, th1_reset=0)
320
- self.adev.reg(f"regSDMA{pipe}_CNTL").update(ctxempty_int_enable=1, trap_enable=1)
321
-
322
- def fini(self):
323
- self.adev.regSDMA0_QUEUE0_RB_CNTL.update(rb_enable=0)
324
- self.adev.regSDMA0_QUEUE0_IB_CNTL.update(ib_enable=0)
325
- self.adev.regGRBM_SOFT_RESET.write(soft_reset_sdma0=1)
326
- time.sleep(0.01)
327
- self.adev.regGRBM_SOFT_RESET.write(0x0)
328
-
329
368
  class AM_PSP(AM_IP):
330
- def __init__(self, adev):
331
- super().__init__(adev)
369
+ def init_sw(self):
370
+ self.reg_pref = "regMP0_SMN_C2PMSG" if self.adev.ip_ver[am.MP0_HWIP] < (14,0,0) else "regMPASP_SMN_C2PMSG"
371
+
372
+ msg1_region = next((reg for reg in self.adev.dma_regions or [] if reg[1].nbytes >= (512 << 10)), None)
373
+ if msg1_region is not None:
374
+ self.msg1_addr, self.msg1_view = self.adev.mm.alloc_vaddr(size=msg1_region[1].nbytes, align=am.PSP_1_MEG), msg1_region[1]
375
+ self.adev.mm.map_range(self.msg1_addr, msg1_region[1].nbytes, [(msg1_region[0], msg1_region[1].nbytes)], system=True, uncached=True, boot=True)
376
+ else:
377
+ self.msg1_paddr = self.adev.mm.palloc(am.PSP_1_MEG, align=am.PSP_1_MEG, zero=False, boot=True)
378
+ self.msg1_addr, self.msg1_view = self.adev.paddr2mc(self.msg1_paddr), self.adev.vram.view(self.msg1_paddr, am.PSP_1_MEG, 'B')
332
379
 
333
- self.msg1_paddr = self.adev.mm.palloc(am.PSP_1_MEG, align=am.PSP_1_MEG, zero=not self.adev.partial_boot, boot=True)
334
- self.cmd_paddr = self.adev.mm.palloc(am.PSP_CMD_BUFFER_SIZE, zero=not self.adev.partial_boot, boot=True)
335
- self.fence_paddr = self.adev.mm.palloc(am.PSP_FENCE_BUFFER_SIZE, zero=not self.adev.partial_boot, boot=True)
380
+ self.cmd_paddr = self.adev.mm.palloc(am.PSP_CMD_BUFFER_SIZE, zero=False, boot=True)
381
+ self.fence_paddr = self.adev.mm.palloc(am.PSP_FENCE_BUFFER_SIZE, zero=True, boot=True)
336
382
 
337
383
  self.ring_size = 0x10000
338
- self.ring_paddr = self.adev.mm.palloc(self.ring_size, zero=not self.adev.partial_boot, boot=True)
384
+ self.ring_paddr = self.adev.mm.palloc(self.ring_size, zero=False, boot=True)
339
385
 
340
386
  self.max_tmr_size = 0x1300000
341
- self.tmr_paddr = self.adev.mm.palloc(self.max_tmr_size, align=am.PSP_TMR_ALIGNMENT, zero=not self.adev.partial_boot, boot=True)
387
+ self.boot_time_tmr = self.adev.ip_ver[am.GC_HWIP] >= (12,0,0)
388
+ if not self.boot_time_tmr:
389
+ self.tmr_paddr = self.adev.mm.palloc(self.max_tmr_size, align=am.PSP_TMR_ALIGNMENT, zero=False, boot=True)
342
390
 
343
- def is_sos_alive(self): return self.adev.regMP0_SMN_C2PMSG_81.read() != 0x0
344
- def init(self):
345
- sos_components_load_order = [
346
- (am.PSP_FW_TYPE_PSP_KDB, am.PSP_BL__LOAD_KEY_DATABASE), (am.PSP_FW_TYPE_PSP_KDB, am.PSP_BL__LOAD_TOS_SPL_TABLE),
391
+ def init_hw(self):
392
+ spl_key = am.PSP_FW_TYPE_PSP_SPL if self.adev.ip_ver[am.MP0_HWIP] >= (14,0,0) else am.PSP_FW_TYPE_PSP_KDB
393
+ sos_components = [(am.PSP_FW_TYPE_PSP_KDB, am.PSP_BL__LOAD_KEY_DATABASE), (spl_key, am.PSP_BL__LOAD_TOS_SPL_TABLE),
347
394
  (am.PSP_FW_TYPE_PSP_SYS_DRV, am.PSP_BL__LOAD_SYSDRV), (am.PSP_FW_TYPE_PSP_SOC_DRV, am.PSP_BL__LOAD_SOCDRV),
348
395
  (am.PSP_FW_TYPE_PSP_INTF_DRV, am.PSP_BL__LOAD_INTFDRV), (am.PSP_FW_TYPE_PSP_DBG_DRV, am.PSP_BL__LOAD_DBGDRV),
349
396
  (am.PSP_FW_TYPE_PSP_RAS_DRV, am.PSP_BL__LOAD_RASDRV), (am.PSP_FW_TYPE_PSP_SOS, am.PSP_BL__LOAD_SOSDRV)]
350
397
 
351
398
  if not self.is_sos_alive():
352
- for fw, compid in sos_components_load_order: self._bootloader_load_component(fw, compid)
399
+ for fw, compid in sos_components: self._bootloader_load_component(fw, compid)
353
400
  while not self.is_sos_alive(): time.sleep(0.01)
354
401
 
355
402
  self._ring_create()
356
403
  self._tmr_init()
357
404
 
358
405
  # SMU fw should be loaded before TMR.
359
- self._load_ip_fw_cmd(self.adev.fw.smu_psp_desc)
360
- self._tmr_load_cmd()
406
+ self._load_ip_fw_cmd(*self.adev.fw.smu_psp_desc)
407
+ if not self.boot_time_tmr: self._tmr_load_cmd()
361
408
 
362
- for psp_desc in self.adev.fw.descs: self._load_ip_fw_cmd(psp_desc)
409
+ for psp_desc in self.adev.fw.descs: self._load_ip_fw_cmd(*psp_desc)
363
410
  self._rlc_autoload_cmd()
364
411
 
365
- def _wait_for_bootloader(self): self.adev.wait_reg(self.adev.regMP0_SMN_C2PMSG_35, mask=0xFFFFFFFF, value=0x80000000)
412
+ def is_sos_alive(self): return self.adev.reg(f"{self.reg_pref}_81").read() != 0x0
413
+
414
+ def _wait_for_bootloader(self): wait_cond(lambda: self.adev.reg(f"{self.reg_pref}_35").read() & 0x80000000, value=0x80000000, msg="BL not ready")
366
415
 
367
- def _prep_msg1(self, data):
368
- ctypes.memset(cpu_addr:=self.adev.paddr2cpu(self.msg1_paddr), 0, am.PSP_1_MEG)
369
- to_mv(cpu_addr, len(data))[:] = data
416
+ def _prep_msg1(self, data:memoryview):
417
+ assert len(data) <= self.msg1_view.nbytes, f"msg1 buffer is too small {len(data):#x} > {self.msg1_view.nbytes:#x}"
418
+ self.msg1_view[:len(data)+4] = bytes(data) + b'\x00' * 4
370
419
  self.adev.gmc.flush_hdp()
371
420
 
372
- def _bootloader_load_component(self, fw, compid):
421
+ def _bootloader_load_component(self, fw:int, compid:int):
373
422
  if fw not in self.adev.fw.sos_fw: return 0
374
423
 
375
424
  self._wait_for_bootloader()
376
425
 
426
+ if DEBUG >= 2: print(f"am {self.adev.devfmt}: loading sos component: {am.psp_fw_type__enumvalues[fw]}")
427
+
377
428
  self._prep_msg1(self.adev.fw.sos_fw[fw])
378
- self.adev.regMP0_SMN_C2PMSG_36.write(self.adev.paddr2mc(self.msg1_paddr) >> 20)
379
- self.adev.regMP0_SMN_C2PMSG_35.write(compid)
429
+ self.adev.reg(f"{self.reg_pref}_36").write(self.msg1_addr >> 20)
430
+ self.adev.reg(f"{self.reg_pref}_35").write(compid)
380
431
 
381
- return self._wait_for_bootloader()
432
+ return self._wait_for_bootloader() if compid != am.PSP_BL__LOAD_SOSDRV else 0
382
433
 
383
434
  def _tmr_init(self):
384
435
  # Load TOC and calculate TMR size
@@ -388,76 +439,64 @@ class AM_PSP(AM_IP):
388
439
 
389
440
  def _ring_create(self):
390
441
  # If the ring is already created, destroy it
391
- if self.adev.regMP0_SMN_C2PMSG_71.read() != 0:
392
- self.adev.regMP0_SMN_C2PMSG_64.write(am.GFX_CTRL_CMD_ID_DESTROY_RINGS)
442
+ if self.adev.reg(f"{self.reg_pref}_71").read() != 0:
443
+ self.adev.reg(f"{self.reg_pref}_64").write(am.GFX_CTRL_CMD_ID_DESTROY_RINGS)
393
444
 
394
445
  # There might be handshake issue with hardware which needs delay
395
446
  time.sleep(0.02)
396
447
 
397
448
  # Wait until the sOS is ready
398
- self.adev.wait_reg(self.adev.regMP0_SMN_C2PMSG_64, mask=0x80000000, value=0x80000000)
449
+ wait_cond(lambda: self.adev.reg(f"{self.reg_pref}_64").read() & 0x80000000, value=0x80000000, msg="sOS not ready")
399
450
 
400
- self.adev.wreg_pair("regMP0_SMN_C2PMSG", "_69", "_70", self.adev.paddr2mc(self.ring_paddr))
401
- self.adev.regMP0_SMN_C2PMSG_71.write(self.ring_size)
402
- self.adev.regMP0_SMN_C2PMSG_64.write(am.PSP_RING_TYPE__KM << 16)
451
+ self.adev.wreg_pair(self.reg_pref, "_69", "_70", self.adev.paddr2mc(self.ring_paddr))
452
+ self.adev.reg(f"{self.reg_pref}_71").write(self.ring_size)
453
+ self.adev.reg(f"{self.reg_pref}_64").write(am.PSP_RING_TYPE__KM << 16)
403
454
 
404
455
  # There might be handshake issue with hardware which needs delay
405
456
  time.sleep(0.02)
406
457
 
407
- self.adev.wait_reg(self.adev.regMP0_SMN_C2PMSG_64, mask=0x8000FFFF, value=0x80000000)
458
+ wait_cond(lambda: self.adev.reg(f"{self.reg_pref}_64").read() & 0x8000FFFF, value=0x80000000, msg="sOS ring not created")
408
459
 
409
- def _ring_submit(self):
410
- prev_wptr = self.adev.regMP0_SMN_C2PMSG_67.read()
411
- ring_entry_addr = self.adev.paddr2cpu(self.ring_paddr) + prev_wptr * 4
460
+ def _ring_submit(self, cmd:am.struct_psp_gfx_cmd_resp) -> am.struct_psp_gfx_cmd_resp:
461
+ msg = am.struct_psp_gfx_rb_frame(fence_value=(prev_wptr:=self.adev.reg(f"{self.reg_pref}_67").read()) + 1,
462
+ cmd_buf_addr_lo=lo32(self.adev.paddr2mc(self.cmd_paddr)), cmd_buf_addr_hi=hi32(self.adev.paddr2mc(self.cmd_paddr)),
463
+ fence_addr_lo=lo32(self.adev.paddr2mc(self.fence_paddr)), fence_addr_hi=hi32(self.adev.paddr2mc(self.fence_paddr)))
412
464
 
413
- ctypes.memset(ring_entry_addr, 0, ctypes.sizeof(am.struct_psp_gfx_rb_frame))
414
- write_loc = am.struct_psp_gfx_rb_frame.from_address(ring_entry_addr)
415
- write_loc.cmd_buf_addr_hi, write_loc.cmd_buf_addr_lo = data64(self.adev.paddr2mc(self.cmd_paddr))
416
- write_loc.fence_addr_hi, write_loc.fence_addr_lo = data64(self.adev.paddr2mc(self.fence_paddr))
417
- write_loc.fence_value = prev_wptr
465
+ self.adev.vram.view(self.cmd_paddr, ctypes.sizeof(cmd))[:] = memoryview(cmd).cast('B')
466
+ self.adev.vram.view(self.ring_paddr + prev_wptr * 4, ctypes.sizeof(msg))[:] = memoryview(msg).cast('B')
418
467
 
419
468
  # Move the wptr
420
- self.adev.regMP0_SMN_C2PMSG_67.write(prev_wptr + ctypes.sizeof(am.struct_psp_gfx_rb_frame) // 4)
469
+ self.adev.reg(f"{self.reg_pref}_67").write(prev_wptr + ctypes.sizeof(am.struct_psp_gfx_rb_frame) // 4)
421
470
 
422
- while to_mv(self.adev.paddr2cpu(self.fence_paddr), 4).cast('I')[0] != prev_wptr: pass
423
- time.sleep(0.005)
471
+ wait_cond(lambda: self.adev.vram.view(self.fence_paddr, 4, 'I')[0], value=msg.fence_value, msg="sOS ring not responding")
424
472
 
425
- resp = am.struct_psp_gfx_cmd_resp.from_address(self.adev.paddr2cpu(self.cmd_paddr))
473
+ resp = type(cmd).from_buffer(bytearray(self.adev.vram.view(self.cmd_paddr, ctypes.sizeof(cmd))[:]))
426
474
  if resp.resp.status != 0: raise RuntimeError(f"PSP command failed {resp.cmd_id} {resp.resp.status}")
427
475
 
428
476
  return resp
429
477
 
430
- def _prep_ring_cmd(self, hdr):
431
- ctypes.memset(self.adev.paddr2cpu(self.cmd_paddr), 0, 0x1000)
432
- cmd = am.struct_psp_gfx_cmd_resp.from_address(self.adev.paddr2cpu(self.cmd_paddr))
433
- cmd.cmd_id = hdr
434
- return cmd
435
-
436
- def _load_ip_fw_cmd(self, psp_desc):
437
- if DEBUG >= 2: print(f"am {self.adev.devfmt}: loading fw: {am.psp_gfx_fw_type__enumvalues[psp_desc[0]]}")
438
- fw_type, fw_bytes = psp_desc
439
-
478
+ def _load_ip_fw_cmd(self, fw_types:list[int], fw_bytes:memoryview):
440
479
  self._prep_msg1(fw_bytes)
441
- cmd = self._prep_ring_cmd(am.GFX_CMD_ID_LOAD_IP_FW)
442
- cmd.cmd.cmd_load_ip_fw.fw_phy_addr_hi, cmd.cmd.cmd_load_ip_fw.fw_phy_addr_lo = data64(self.adev.paddr2mc(self.msg1_paddr))
443
- cmd.cmd.cmd_load_ip_fw.fw_size = len(fw_bytes)
444
- cmd.cmd.cmd_load_ip_fw.fw_type = fw_type
445
- return self._ring_submit()
446
-
447
- def _tmr_load_cmd(self):
448
- cmd = self._prep_ring_cmd(am.GFX_CMD_ID_SETUP_TMR)
480
+ for fw_type in fw_types:
481
+ if DEBUG >= 2: print(f"am {self.adev.devfmt}: loading fw: {am.psp_gfx_fw_type__enumvalues[fw_type]}")
482
+ cmd = am.struct_psp_gfx_cmd_resp(cmd_id=am.GFX_CMD_ID_LOAD_IP_FW)
483
+ cmd.cmd.cmd_load_ip_fw.fw_phy_addr_hi, cmd.cmd.cmd_load_ip_fw.fw_phy_addr_lo = data64(self.msg1_addr)
484
+ cmd.cmd.cmd_load_ip_fw.fw_size = len(fw_bytes)
485
+ cmd.cmd.cmd_load_ip_fw.fw_type = fw_type
486
+ self._ring_submit(cmd)
487
+
488
+ def _tmr_load_cmd(self) -> am.struct_psp_gfx_cmd_resp:
489
+ cmd = am.struct_psp_gfx_cmd_resp(cmd_id=am.GFX_CMD_ID_SETUP_TMR)
449
490
  cmd.cmd.cmd_setup_tmr.buf_phy_addr_hi, cmd.cmd.cmd_setup_tmr.buf_phy_addr_lo = data64(self.adev.paddr2mc(self.tmr_paddr))
450
491
  cmd.cmd.cmd_setup_tmr.system_phy_addr_hi, cmd.cmd.cmd_setup_tmr.system_phy_addr_lo = data64(self.tmr_paddr)
451
492
  cmd.cmd.cmd_setup_tmr.bitfield.virt_phy_addr = 1
452
493
  cmd.cmd.cmd_setup_tmr.buf_size = self.tmr_size
453
- return self._ring_submit()
494
+ return self._ring_submit(cmd)
454
495
 
455
- def _load_toc_cmd(self, toc_size):
456
- cmd = self._prep_ring_cmd(am.GFX_CMD_ID_LOAD_TOC)
457
- cmd.cmd.cmd_load_toc.toc_phy_addr_hi, cmd.cmd.cmd_load_toc.toc_phy_addr_lo = data64(self.adev.paddr2mc(self.msg1_paddr))
496
+ def _load_toc_cmd(self, toc_size:int) -> am.struct_psp_gfx_cmd_resp:
497
+ cmd = am.struct_psp_gfx_cmd_resp(cmd_id=am.GFX_CMD_ID_LOAD_TOC)
498
+ cmd.cmd.cmd_load_toc.toc_phy_addr_hi, cmd.cmd.cmd_load_toc.toc_phy_addr_lo = data64(self.msg1_addr)
458
499
  cmd.cmd.cmd_load_toc.toc_size = toc_size
459
- return self._ring_submit()
500
+ return self._ring_submit(cmd)
460
501
 
461
- def _rlc_autoload_cmd(self):
462
- self._prep_ring_cmd(am.GFX_CMD_ID_AUTOLOAD_RLC)
463
- return self._ring_submit()
502
+ def _rlc_autoload_cmd(self): return self._ring_submit(am.struct_psp_gfx_cmd_resp(cmd_id=am.GFX_CMD_ID_AUTOLOAD_RLC))