tinygrad-0.10.2-py3-none-any.whl → tinygrad-0.11.0-py3-none-any.whl

This diff shows the contents of two publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (131)
  1. tinygrad/__init__.py +1 -1
  2. tinygrad/apps/llm.py +206 -0
  3. tinygrad/codegen/__init__.py +116 -0
  4. tinygrad/codegen/devectorizer.py +315 -172
  5. tinygrad/codegen/expander.py +8 -16
  6. tinygrad/codegen/gpudims.py +89 -0
  7. tinygrad/codegen/linearize.py +205 -203
  8. tinygrad/codegen/lowerer.py +92 -139
  9. tinygrad/codegen/opt/__init__.py +38 -0
  10. tinygrad/codegen/opt/heuristic.py +125 -0
  11. tinygrad/codegen/opt/kernel.py +510 -0
  12. tinygrad/{engine → codegen/opt}/search.py +51 -35
  13. tinygrad/codegen/opt/swizzler.py +134 -0
  14. tinygrad/codegen/opt/tc.py +127 -0
  15. tinygrad/codegen/quantize.py +67 -0
  16. tinygrad/device.py +122 -132
  17. tinygrad/dtype.py +152 -35
  18. tinygrad/engine/jit.py +81 -54
  19. tinygrad/engine/memory.py +46 -27
  20. tinygrad/engine/realize.py +82 -41
  21. tinygrad/engine/schedule.py +70 -445
  22. tinygrad/frontend/__init__.py +0 -0
  23. tinygrad/frontend/onnx.py +1253 -0
  24. tinygrad/frontend/torch.py +5 -0
  25. tinygrad/gradient.py +19 -27
  26. tinygrad/helpers.py +95 -47
  27. tinygrad/nn/__init__.py +7 -8
  28. tinygrad/nn/optim.py +72 -41
  29. tinygrad/nn/state.py +37 -23
  30. tinygrad/renderer/__init__.py +40 -60
  31. tinygrad/renderer/cstyle.py +143 -128
  32. tinygrad/renderer/llvmir.py +113 -62
  33. tinygrad/renderer/ptx.py +50 -32
  34. tinygrad/renderer/wgsl.py +27 -23
  35. tinygrad/runtime/autogen/am/am.py +5861 -0
  36. tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
  37. tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
  38. tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
  39. tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
  40. tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
  41. tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
  42. tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
  43. tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
  44. tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
  45. tinygrad/runtime/autogen/comgr.py +35 -9
  46. tinygrad/runtime/autogen/comgr_3.py +906 -0
  47. tinygrad/runtime/autogen/cuda.py +2419 -494
  48. tinygrad/runtime/autogen/hsa.py +57 -16
  49. tinygrad/runtime/autogen/ib.py +7171 -0
  50. tinygrad/runtime/autogen/io_uring.py +917 -118
  51. tinygrad/runtime/autogen/kfd.py +748 -26
  52. tinygrad/runtime/autogen/libc.py +613 -218
  53. tinygrad/runtime/autogen/libusb.py +1643 -0
  54. tinygrad/runtime/autogen/nv/nv.py +8602 -0
  55. tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
  56. tinygrad/runtime/autogen/opencl.py +2 -4
  57. tinygrad/runtime/autogen/sqtt.py +1789 -0
  58. tinygrad/runtime/autogen/vfio.py +3 -3
  59. tinygrad/runtime/autogen/webgpu.py +273 -264
  60. tinygrad/runtime/graph/cuda.py +3 -3
  61. tinygrad/runtime/graph/hcq.py +68 -29
  62. tinygrad/runtime/graph/metal.py +29 -13
  63. tinygrad/runtime/graph/remote.py +114 -0
  64. tinygrad/runtime/ops_amd.py +537 -320
  65. tinygrad/runtime/ops_cpu.py +108 -7
  66. tinygrad/runtime/ops_cuda.py +12 -14
  67. tinygrad/runtime/ops_disk.py +13 -10
  68. tinygrad/runtime/ops_dsp.py +47 -40
  69. tinygrad/runtime/ops_gpu.py +13 -11
  70. tinygrad/runtime/ops_hip.py +6 -9
  71. tinygrad/runtime/ops_llvm.py +35 -15
  72. tinygrad/runtime/ops_metal.py +29 -19
  73. tinygrad/runtime/ops_npy.py +5 -3
  74. tinygrad/runtime/ops_null.py +28 -0
  75. tinygrad/runtime/ops_nv.py +306 -234
  76. tinygrad/runtime/ops_python.py +62 -52
  77. tinygrad/runtime/ops_qcom.py +28 -39
  78. tinygrad/runtime/ops_remote.py +482 -0
  79. tinygrad/runtime/ops_webgpu.py +28 -28
  80. tinygrad/runtime/support/am/amdev.py +114 -249
  81. tinygrad/runtime/support/am/ip.py +211 -172
  82. tinygrad/runtime/support/amd.py +138 -0
  83. tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
  84. tinygrad/runtime/support/compiler_cuda.py +8 -11
  85. tinygrad/runtime/support/elf.py +2 -1
  86. tinygrad/runtime/support/hcq.py +184 -97
  87. tinygrad/runtime/support/ib.py +172 -0
  88. tinygrad/runtime/support/llvm.py +3 -4
  89. tinygrad/runtime/support/memory.py +251 -0
  90. tinygrad/runtime/support/nv/__init__.py +0 -0
  91. tinygrad/runtime/support/nv/ip.py +581 -0
  92. tinygrad/runtime/support/nv/nvdev.py +183 -0
  93. tinygrad/runtime/support/system.py +170 -0
  94. tinygrad/runtime/support/usb.py +268 -0
  95. tinygrad/runtime/support/webgpu.py +18 -0
  96. tinygrad/schedule/__init__.py +0 -0
  97. tinygrad/schedule/grouper.py +119 -0
  98. tinygrad/schedule/kernelize.py +368 -0
  99. tinygrad/schedule/multi.py +231 -0
  100. tinygrad/shape/shapetracker.py +40 -46
  101. tinygrad/shape/view.py +88 -52
  102. tinygrad/tensor.py +968 -542
  103. tinygrad/uop/__init__.py +117 -0
  104. tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
  105. tinygrad/uop/mathtraits.py +169 -0
  106. tinygrad/uop/ops.py +1021 -0
  107. tinygrad/uop/spec.py +228 -0
  108. tinygrad/{codegen → uop}/symbolic.py +239 -216
  109. tinygrad/uop/upat.py +163 -0
  110. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
  111. tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
  112. tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
  113. tinygrad/viz/index.html +203 -403
  114. tinygrad/viz/js/index.js +718 -0
  115. tinygrad/viz/js/worker.js +29 -0
  116. tinygrad/viz/serve.py +224 -102
  117. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
  118. tinygrad-0.11.0.dist-info/RECORD +141 -0
  119. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
  120. tinygrad/codegen/kernel.py +0 -693
  121. tinygrad/engine/multi.py +0 -161
  122. tinygrad/ops.py +0 -1003
  123. tinygrad/runtime/ops_cloud.py +0 -220
  124. tinygrad/runtime/support/allocator.py +0 -94
  125. tinygrad/spec.py +0 -155
  126. tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
  127. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
  128. tinygrad/viz/perfetto.html +0 -178
  129. tinygrad-0.10.2.dist-info/RECORD +0 -99
  130. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
  131. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
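Note on the module reorganization visible in the list above: tinygrad/ops.py (item 122, removed) is superseded by tinygrad/uop/ops.py (item 106, added), and tinygrad/codegen/symbolic.py moves to tinygrad/uop/symbolic.py (item 108). A minimal sketch of what that move would mean for imports, assuming UOp and Ops keep their names in the relocated module (an assumption from the file list, not verified against the diff contents):

    # hypothetical import shim for the tinygrad/ops.py -> tinygrad/uop/ops.py move shown above
    try:
        from tinygrad.uop.ops import UOp, Ops   # 0.11.0 layout (tinygrad/uop/ops.py added here)
    except ImportError:
        from tinygrad.ops import UOp, Ops       # 0.10.2 layout (tinygrad/ops.py removed here)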
tinygrad/runtime/support/nv/ip.py (new file, item 91 in the list above)
@@ -0,0 +1,581 @@
+ from __future__ import annotations
+ import ctypes, time, array, struct, itertools, dataclasses
+ from typing import cast, Any
+ from tinygrad.runtime.autogen.nv import nv
+ from tinygrad.helpers import to_mv, lo32, hi32, DEBUG, round_up, round_down, mv_address, fetch, wait_cond
+ from tinygrad.runtime.support.system import System
+ from tinygrad.runtime.support.elf import elf_loader
+ from tinygrad.runtime.autogen import nv_gpu
+
+ @dataclasses.dataclass(frozen=True)
+ class GRBufDesc: size:int; virt:bool; phys:bool; local:bool=False # noqa: E702
+
+ class NV_IP:
+   def __init__(self, nvdev): self.nvdev = nvdev
+   def init_sw(self): pass # Prepare sw/allocations for this IP
+   def init_hw(self): pass # Initialize hw for this IP
+   def fini_hw(self): pass # Finalize hw for this IP
+
+ class NVRpcQueue:
+   def __init__(self, gsp:NV_GSP, va:int, completion_q_va:int|None=None):
+     self.tx = nv.msgqTxHeader.from_address(va)
+     wait_cond(lambda: self.tx.entryOff, value=0x1000, msg="RPC queue not initialized")
+
+     if completion_q_va is not None: self.rx = nv.msgqRxHeader.from_address(completion_q_va + nv.msgqTxHeader.from_address(completion_q_va).rxHdrOff)
+
+     self.gsp, self.va, self.queue_va, self.seq = gsp, va, va + self.tx.entryOff, 0
+     self.queue_mv = to_mv(self.queue_va, self.tx.msgSize * self.tx.msgCount)
+
+   def _checksum(self, data:bytes):
+     if (pad_len:=(-len(data)) % 8): data += b'\x00' * pad_len
+     checksum = 0
+     for offset in range(0, len(data), 8): checksum ^= struct.unpack_from('Q', data, offset)[0]
+     return hi32(checksum) ^ lo32(checksum)
+
+   def send_rpc(self, func:int, msg:bytes, wait=False):
+     header = nv.rpc_message_header_v(signature=nv.NV_VGPU_MSG_SIGNATURE_VALID, rpc_result=nv.NV_VGPU_MSG_RESULT_RPC_PENDING,
+       rpc_result_private=nv.NV_VGPU_MSG_RESULT_RPC_PENDING, header_version=(3<<24), function=func, length=len(msg) + 0x20)
+
+     msg = bytes(header) + msg
+     phdr = nv.GSP_MSG_QUEUE_ELEMENT(elemCount=round_up(len(msg), self.tx.msgSize) // self.tx.msgSize, seqNum=self.seq)
+     phdr.checkSum = self._checksum(bytes(phdr) + msg)
+     msg = bytes(phdr) + msg
+
+     off = self.tx.writePtr * self.tx.msgSize
+     self.queue_mv[off:off+len(msg)] = msg
+     self.tx.writePtr = (self.tx.writePtr + round_up(len(msg), self.tx.msgSize) // self.tx.msgSize) % self.tx.msgCount
+     System.memory_barrier()
+
+     self.seq += 1
+     self.gsp.nvdev.NV_PGSP_QUEUE_HEAD[0].write(0x0)
+
+   def wait_resp(self, cmd:int) -> memoryview:
+     while True:
+       System.memory_barrier()
+       if self.rx.readPtr == self.tx.writePtr: continue
+
+       off = self.rx.readPtr * self.tx.msgSize
+       hdr = nv.rpc_message_header_v.from_address(self.queue_va + off + 0x30)
+       msg = self.queue_mv[off + 0x50 : off + 0x50 + hdr.length]
+
+       # Handling special functions
+       if hdr.function == nv.NV_VGPU_MSG_EVENT_GSP_RUN_CPU_SEQUENCER: self.gsp.run_cpu_seq(msg)
+       elif hdr.function == nv.NV_VGPU_MSG_EVENT_OS_ERROR_LOG:
+         print(f"nv {self.gsp.nvdev.devfmt}: GSP LOG: {msg[12:].tobytes().rstrip(bytes([0])).decode('utf-8')}")
+
+       # Update the read pointer
+       self.rx.readPtr = (self.rx.readPtr + round_up(hdr.length, self.tx.msgSize) // self.tx.msgSize) % self.tx.msgCount
+       System.memory_barrier()
+
+       if DEBUG >= 3:
+         rpc_names = {**nv.c__Ea_NV_VGPU_MSG_FUNCTION_NOP__enumvalues, **nv.c__Ea_NV_VGPU_MSG_EVENT_FIRST_EVENT__enumvalues}
+         print(f"nv {self.gsp.nvdev.devfmt}: in RPC: {rpc_names.get(hdr.function, f'ev:{hdr.function:x}')}, res:{hdr.rpc_result:#x}")
+
+       if hdr.rpc_result != 0: raise RuntimeError(f"RPC call {hdr.function} failed with result {hdr.rpc_result}")
+       if hdr.function == cmd: return msg
+
+ class NV_FLCN(NV_IP):
+   def init_sw(self):
+     self.nvdev.include("src/common/inc/swref/published/ampere/ga102/dev_gsp.h")
+     self.nvdev.include("src/common/inc/swref/published/ampere/ga102/dev_falcon_v4.h")
+     self.nvdev.include("src/common/inc/swref/published/ampere/ga102/dev_falcon_v4_addendum.h")
+     self.nvdev.include("src/common/inc/swref/published/ampere/ga102/dev_riscv_pri.h")
+     self.nvdev.include("src/common/inc/swref/published/ampere/ga102/dev_fbif_v4.h")
+     self.nvdev.include("src/common/inc/swref/published/ampere/ga102/dev_falcon_second_pri.h")
+     self.nvdev.include("src/common/inc/swref/published/ampere/ga102/dev_sec_pri.h")
+     self.nvdev.include("src/common/inc/swref/published/turing/tu102/dev_bus.h")
+
+     self.prep_ucode()
+     self.prep_booter()
+
+   def prep_ucode(self):
+     expansion_rom_off, bit_addr = {"GA": 0x16600, "AD": 0x14e00}[self.nvdev.chip_name[:2]], 0x1b0
+     vbios_bytes = bytes(array.array('I', self.nvdev.mmio[0x00300000//4:(0x00300000+0x98e00)//4]))
+
+     bit_header = nv.BIT_HEADER_V1_00.from_buffer_copy(vbios_bytes[bit_addr:bit_addr + ctypes.sizeof(nv.BIT_HEADER_V1_00)])
+     assert bit_header.Signature == 0x00544942, f"Invalid BIT header signature {hex(bit_header.Signature)}"
+
+     for i in range(bit_header.TokenEntries):
+       bit = nv.BIT_TOKEN_V1_00.from_buffer_copy(vbios_bytes[bit_addr + bit_header.HeaderSize + i * bit_header.TokenSize:])
+       if bit.TokenId != nv.BIT_TOKEN_FALCON_DATA or bit.DataVersion != 2 or bit.DataSize < nv.BIT_DATA_FALCON_DATA_V2_SIZE_4: continue
+
+       falcon_data = nv.BIT_DATA_FALCON_DATA_V2.from_buffer_copy(vbios_bytes[bit.DataPtr & 0xffff:])
+       ucode_hdr = nv.FALCON_UCODE_TABLE_HDR_V1.from_buffer_copy(vbios_bytes[(table_ptr:=expansion_rom_off + falcon_data.FalconUcodeTablePtr):])
+       for j in range(ucode_hdr.EntryCount):
+         ucode_entry = nv.FALCON_UCODE_TABLE_ENTRY_V1.from_buffer_copy(vbios_bytes[table_ptr + ucode_hdr.HeaderSize + j * ucode_hdr.EntrySize:])
+         if ucode_entry.ApplicationID != nv.FALCON_UCODE_ENTRY_APPID_FWSEC_PROD: continue
+
+         ucode_desc_hdr = nv.FALCON_UCODE_DESC_HEADER.from_buffer_copy(vbios_bytes[expansion_rom_off + ucode_entry.DescPtr:])
+         ucode_desc_off = expansion_rom_off + ucode_entry.DescPtr
+         ucode_desc_size = ucode_desc_hdr.vDesc >> 16
+
+         self.desc_v3 = nv.FALCON_UCODE_DESC_V3.from_buffer_copy(vbios_bytes[ucode_desc_off:ucode_desc_off + ucode_desc_size])
+
+         sig_total_size = ucode_desc_size - nv.FALCON_UCODE_DESC_V3_SIZE_44
+         signature = vbios_bytes[ucode_desc_off + nv.FALCON_UCODE_DESC_V3_SIZE_44:][:sig_total_size]
+         image = vbios_bytes[ucode_desc_off + ucode_desc_size:][:round_up(self.desc_v3.StoredSize, 256)]
+
+     self.frts_offset = self.nvdev.vram_size - 0x100000 - 0x100000
+     read_vbios_desc = nv.FWSECLIC_READ_VBIOS_DESC(version=0x1, size=ctypes.sizeof(nv.FWSECLIC_READ_VBIOS_DESC), flags=2)
+     frst_reg_desc = nv.FWSECLIC_FRTS_REGION_DESC(version=0x1, size=ctypes.sizeof(nv.FWSECLIC_FRTS_REGION_DESC),
+       frtsRegionOffset4K=self.frts_offset >> 12, frtsRegionSize=0x100, frtsRegionMediaType=2)
+     frts_cmd = nv.FWSECLIC_FRTS_CMD(readVbiosDesc=read_vbios_desc, frtsRegionDesc=frst_reg_desc)
+
+     def __patch(cmd_id, cmd):
+       patched_image = bytearray(image)
+
+       hdr = nv.FALCON_APPLICATION_INTERFACE_HEADER_V1.from_buffer_copy(image[(app_hdr_off:=self.desc_v3.IMEMLoadSize+self.desc_v3.InterfaceOffset):])
+       ents = (nv.FALCON_APPLICATION_INTERFACE_ENTRY_V1 * hdr.entryCount).from_buffer_copy(image[app_hdr_off + ctypes.sizeof(hdr):])
+       for i in range(hdr.entryCount):
+         if ents[i].id == nv.FALCON_APPLICATION_INTERFACE_ENTRY_ID_DMEMMAPPER: dmem_offset = ents[i].dmemOffset
+
+       # Patch image
+       dmem = nv.FALCON_APPLICATION_INTERFACE_DMEM_MAPPER_V3.from_buffer_copy(image[(dmem_mapper_offset:=self.desc_v3.IMEMLoadSize+dmem_offset):])
+       dmem.init_cmd = cmd_id
+       patched_image[dmem_mapper_offset : dmem_mapper_offset+len(bytes(dmem))] = bytes(dmem)
+       patched_image[(cmd_off:=self.desc_v3.IMEMLoadSize+dmem.cmd_in_buffer_offset) : cmd_off+len(cmd)] = cmd
+       patched_image[(sig_off:=self.desc_v3.IMEMLoadSize+self.desc_v3.PKCDataOffset) : sig_off+0x180] = signature[-0x180:]
+
+       return System.alloc_sysmem(len(patched_image), contiguous=True, data=patched_image)
+
+     self.frts_image_va, self.frts_image_sysmem = __patch(0x15, bytes(frts_cmd))
+
+   def prep_booter(self):
+     image = self.nvdev.extract_fw("kgspBinArchiveBooterLoadUcode", "image_prod_data")
+     sig = self.nvdev.extract_fw("kgspBinArchiveBooterLoadUcode", "sig_prod_data")
+     header = self.nvdev.extract_fw("kgspBinArchiveBooterLoadUcode", "header_prod_data")
+     patch_loc = int.from_bytes(self.nvdev.extract_fw("kgspBinArchiveBooterLoadUcode", "patch_loc_data"), 'little')
+     sig_len = len(sig) // int.from_bytes(self.nvdev.extract_fw("kgspBinArchiveBooterLoadUcode", "num_sigs_data"), 'little')
+
+     patched_image = bytearray(image)
+     patched_image[patch_loc:patch_loc+sig_len] = sig[:sig_len]
+     self.booter_image_va, self.booter_image_sysmem = System.alloc_sysmem(len(patched_image), contiguous=True, data=patched_image)
+     _, _, self.booter_data_off, self.booter_data_sz, _, self.booter_code_off, self.booter_code_sz, _, _ = struct.unpack("9I", header)
+
+   def init_hw(self):
+     self.falcon, self.sec2 = 0x00110000, 0x00840000
+
+     self.reset(self.falcon)
+     self.execute_hs(self.falcon, self.frts_image_sysmem[0], code_off=0x0, data_off=self.desc_v3.IMEMLoadSize,
+       imemPa=self.desc_v3.IMEMPhysBase, imemVa=self.desc_v3.IMEMVirtBase, imemSz=self.desc_v3.IMEMLoadSize,
+       dmemPa=self.desc_v3.DMEMPhysBase, dmemVa=0x0, dmemSz=self.desc_v3.DMEMLoadSize,
+       pkc_off=self.desc_v3.PKCDataOffset, engid=self.desc_v3.EngineIdMask, ucodeid=self.desc_v3.UcodeId)
+     assert self.nvdev.NV_PFB_PRI_MMU_WPR2_ADDR_HI.read() != 0, "WPR2 is not initialized"
+
+     self.reset(self.falcon, riscv=True)
+
+     # set up the mailbox
+     self.nvdev.NV_PGSP_FALCON_MAILBOX0.write(lo32(self.nvdev.gsp.libos_args_sysmem[0]))
+     self.nvdev.NV_PGSP_FALCON_MAILBOX1.write(hi32(self.nvdev.gsp.libos_args_sysmem[0]))
+
+     # booter
+     self.reset(self.sec2)
+     mbx = self.execute_hs(self.sec2, self.booter_image_sysmem[0], code_off=self.booter_code_off, data_off=self.booter_data_off,
+       imemPa=0x0, imemVa=self.booter_code_off, imemSz=self.booter_code_sz, dmemPa=0x0, dmemVa=0x0, dmemSz=self.booter_data_sz,
+       pkc_off=0x10, engid=1, ucodeid=3, mailbox=self.nvdev.gsp.wpr_meta_sysmem)
+     assert mbx[0] == 0x0, f"Booter failed to execute, mailbox is {mbx[0]:08x}, {mbx[1]:08x}"
+
+     self.nvdev.NV_PFALCON_FALCON_OS.with_base(self.falcon).write(0x0)
+     assert self.nvdev.NV_PRISCV_RISCV_CPUCTL.with_base(self.falcon).read_bitfields()['active_stat'] == 1, "GSP Core is not active"
+
+   def execute_dma(self, base:int, cmd:int, dest:int, mem_off:int, sysmem:int, size:int):
+     wait_cond(lambda: self.nvdev.NV_PFALCON_FALCON_DMATRFCMD.with_base(base).read_bitfields()['full'], value=0, msg="DMA does not progress")
+
+     self.nvdev.NV_PFALCON_FALCON_DMATRFBASE.with_base(base).write(lo32(sysmem >> 8))
+     self.nvdev.NV_PFALCON_FALCON_DMATRFBASE1.with_base(base).write(hi32(sysmem >> 8) & 0x1ff)
+
+     xfered = 0
+     while xfered < size:
+       wait_cond(lambda: self.nvdev.NV_PFALCON_FALCON_DMATRFCMD.with_base(base).read_bitfields()['full'], value=0, msg="DMA does not progress")
+
+       self.nvdev.NV_PFALCON_FALCON_DMATRFMOFFS.with_base(base).write(dest + xfered)
+       self.nvdev.NV_PFALCON_FALCON_DMATRFFBOFFS.with_base(base).write(mem_off + xfered)
+       self.nvdev.NV_PFALCON_FALCON_DMATRFCMD.with_base(base).write(cmd)
+       xfered += 256
+
+     wait_cond(lambda: self.nvdev.NV_PFALCON_FALCON_DMATRFCMD.with_base(base).read_bitfields()['idle'], msg="DMA does not complete")
+
+   def start_cpu(self, base:int):
+     if self.nvdev.NV_PFALCON_FALCON_CPUCTL.with_base(base).read_bitfields()['alias_en'] == 1:
+       self.nvdev.wreg(base + self.nvdev.NV_PFALCON_FALCON_CPUCTL_ALIAS, 0x2)
+     else: self.nvdev.NV_PFALCON_FALCON_CPUCTL.with_base(base).write(startcpu=1)
+
+   def wait_cpu_halted(self, base): wait_cond(lambda: self.nvdev.NV_PFALCON_FALCON_CPUCTL.with_base(base).read_bitfields()['halted'], msg="not halted")
+
+   def execute_hs(self, base, img_sysmem, code_off, data_off, imemPa, imemVa, imemSz, dmemPa, dmemVa, dmemSz, pkc_off, engid, ucodeid, mailbox=None):
+     self.disable_ctx_req(base)
+
+     self.nvdev.NV_PFALCON_FBIF_TRANSCFG.with_base(base)[ctx_dma:=0].update(target=self.nvdev.NV_PFALCON_FBIF_TRANSCFG_TARGET_COHERENT_SYSMEM,
+       mem_type=self.nvdev.NV_PFALCON_FBIF_TRANSCFG_MEM_TYPE_PHYSICAL)
+
+     cmd = self.nvdev.NV_PFALCON_FALCON_DMATRFCMD.with_base(base).encode(write=0, size=self.nvdev.NV_PFALCON_FALCON_DMATRFCMD_SIZE_256B,
+       ctxdma=ctx_dma, imem=1, sec=1)
+     self.execute_dma(base, cmd, dest=imemPa, mem_off=imemVa, sysmem=img_sysmem+code_off-imemVa, size=imemSz)
+
+     cmd = self.nvdev.NV_PFALCON_FALCON_DMATRFCMD.with_base(base).encode(write=0, size=self.nvdev.NV_PFALCON_FALCON_DMATRFCMD_SIZE_256B,
+       ctxdma=ctx_dma, imem=0, sec=0)
+     self.execute_dma(base, cmd, dest=dmemPa, mem_off=dmemVa, sysmem=img_sysmem+data_off-dmemVa, size=dmemSz)
+
+     self.nvdev.NV_PFALCON2_FALCON_BROM_PARAADDR.with_base(base)[0].write(pkc_off)
+     self.nvdev.NV_PFALCON2_FALCON_BROM_ENGIDMASK.with_base(base).write(engid)
+     self.nvdev.NV_PFALCON2_FALCON_BROM_CURR_UCODE_ID.with_base(base).write(val=ucodeid)
+     self.nvdev.NV_PFALCON2_FALCON_MOD_SEL.with_base(base).write(algo=self.nvdev.NV_PFALCON2_FALCON_MOD_SEL_ALGO_RSA3K)
+
+     self.nvdev.NV_PFALCON_FALCON_BOOTVEC.with_base(base).write(imemVa)
+
+     if mailbox is not None:
+       self.nvdev.NV_PFALCON_FALCON_MAILBOX0.with_base(base).write(lo32(mailbox))
+       self.nvdev.NV_PFALCON_FALCON_MAILBOX1.with_base(base).write(hi32(mailbox))
+
+     self.start_cpu(base)
+     self.wait_cpu_halted(base)
+
+     if mailbox is not None:
+       return self.nvdev.NV_PFALCON_FALCON_MAILBOX0.with_base(base).read(), self.nvdev.NV_PFALCON_FALCON_MAILBOX1.with_base(base).read()
+
+   def disable_ctx_req(self, base:int):
+     self.nvdev.NV_PFALCON_FBIF_CTL.with_base(base).update(allow_phys_no_ctx=1)
+     self.nvdev.NV_PFALCON_FALCON_DMACTL.with_base(base).write(0x0)
+
+   def reset(self, base:int, riscv=False):
+     engine_reg = self.nvdev.NV_PGSP_FALCON_ENGINE if base == self.falcon else self.nvdev.NV_PSEC_FALCON_ENGINE
+     engine_reg.write(reset=1)
+     time.sleep(0.1)
+     engine_reg.write(reset=0)
+
+     wait_cond(lambda: self.nvdev.NV_PFALCON_FALCON_HWCFG2.with_base(base).read_bitfields()['mem_scrubbing'], value=0, msg="Scrubbing not completed")
+
+     if riscv: self.nvdev.NV_PRISCV_RISCV_BCR_CTRL.with_base(base).write(core_select=1, valid=0, brfetch=1)
+     elif self.nvdev.NV_PFALCON_FALCON_HWCFG2.with_base(base).read_bitfields()['riscv'] == 1:
+       self.nvdev.NV_PRISCV_RISCV_BCR_CTRL.with_base(base).write(core_select=0)
+       wait_cond(lambda: self.nvdev.NV_PRISCV_RISCV_BCR_CTRL.with_base(base).read_bitfields()['valid'], msg="RISCV core not booted")
+     self.nvdev.NV_PFALCON_FALCON_RM.with_base(base).write(self.nvdev.chip_id)
+
+ class NV_FLCN_COT(NV_IP):
+   def init_sw(self):
+     self.nvdev.include("src/common/inc/swref/published/ampere/ga102/dev_gsp.h")
+     self.nvdev.include("src/common/inc/swref/published/hopper/gh100/dev_falcon_v4.h")
+     self.nvdev.include("src/common/inc/swref/published/hopper/gh100/dev_vm.h")
+     self.nvdev.include("src/common/inc/swref/published/hopper/gh100/dev_fsp_pri.h")
+     self.nvdev.include("src/common/inc/swref/published/turing/tu102/dev_bus.h")
+     self.nvdev.include("src/nvidia/arch/nvalloc/common/inc/fsp/fsp_mctp_format.h")
+     self.nvdev.include("src/nvidia/arch/nvalloc/common/inc/fsp/fsp_emem_channels.h")
+
+     self.fmc_boot_args, self.fmc_boot_args_sysmem = self.nvdev._alloc_boot_struct(nv.GSP_FMC_BOOT_PARAMS())
+     self.init_fmc_image()
+
+   def init_fmc_image(self):
+     self.fmc_booter_image = self.nvdev.extract_fw("kgspBinArchiveGspRmFmcGfwProdSigned", "ucode_image_data")
+     self.fmc_booter_hash = memoryview(self.nvdev.extract_fw("kgspBinArchiveGspRmFmcGfwProdSigned", "ucode_hash_data")).cast('I')
+     self.fmc_booter_sig = memoryview(self.nvdev.extract_fw("kgspBinArchiveGspRmFmcGfwProdSigned", "ucode_sig_data")).cast('I')
+     self.fmc_booter_pkey = memoryview(self.nvdev.extract_fw("kgspBinArchiveGspRmFmcGfwProdSigned", "ucode_pkey_data") + b'\x00\x00\x00').cast('I')
+     _, self.fmc_booter_sysmem = System.alloc_sysmem(len(self.fmc_booter_image), contiguous=True, data=self.fmc_booter_image)
+
+   def init_hw(self):
+     self.falcon = 0x00110000
+
+     self.fmc_boot_args.bootGspRmParams = nv.GSP_ACR_BOOT_GSP_RM_PARAMS(gspRmDescOffset=self.nvdev.gsp.wpr_meta_sysmem,
+       gspRmDescSize=ctypes.sizeof(nv.GspFwWprMeta), target=nv.GSP_DMA_TARGET_COHERENT_SYSTEM, bIsGspRmBoot=True)
+     self.fmc_boot_args.gspRmParams = nv.GSP_RM_PARAMS(bootArgsOffset=self.nvdev.gsp.libos_args_sysmem[0], target=nv.GSP_DMA_TARGET_COHERENT_SYSTEM)
+
+     cot_payload = nv.NVDM_PAYLOAD_COT(version=0x2, size=ctypes.sizeof(nv.NVDM_PAYLOAD_COT), frtsVidmemOffset=0x1c00000, frtsVidmemSize=0x100000,
+       gspBootArgsSysmemOffset=self.fmc_boot_args_sysmem, gspFmcSysmemOffset=self.fmc_booter_sysmem[0])
+     for i,x in enumerate(self.fmc_booter_hash): cot_payload.hash384[i] = x
+     for i,x in enumerate(self.fmc_booter_sig): cot_payload.signature[i] = x
+     for i,x in enumerate(self.fmc_booter_pkey): cot_payload.publicKey[i] = x
+
+     self.kfsp_send_msg(nv.NVDM_TYPE_COT, bytes(cot_payload))
+     wait_cond(lambda: self.nvdev.NV_PFALCON_FALCON_HWCFG2.with_base(self.falcon).read_bitfields()['riscv_br_priv_lockdown'], value=0)
+
+   def kfsp_send_msg(self, nvmd:int, buf:bytes):
+     # All single-packets go to seid 0
+     headers = int.to_bytes((1 << 31) | (1 << 30), 4, 'little') + int.to_bytes((0x7e << 0) | (0x10de << 8) | (nvmd << 24), 4, 'little')
+     buf = headers + buf + (4 - (len(buf) % 4)) * b'\x00'
+     assert len(buf) < 0x400, f"FSP message too long: {len(buf)} bytes, max 1024 bytes"
+
+     self.nvdev.NV_PFSP_EMEMC[0].write(offs=0, blk=0, aincw=1, aincr=0)
+     for i in range(0, len(buf), 4): self.nvdev.NV_PFSP_EMEMD[0].write(int.from_bytes(buf[i:i+4], 'little'))
+
+     self.nvdev.NV_PFSP_QUEUE_TAIL[0].write(len(buf) - 4)
+     self.nvdev.NV_PFSP_QUEUE_HEAD[0].write(0)
+
+     # Waiting for a response
+     wait_cond(lambda: self.nvdev.NV_PFSP_MSGQ_HEAD[0].read() != self.nvdev.NV_PFSP_MSGQ_TAIL[0].read(), msg="FSP didn't respond to message")
+
+     self.nvdev.NV_PFSP_EMEMC[0].write(offs=0, blk=0, aincw=0, aincr=1)
+     self.nvdev.NV_PFSP_MSGQ_TAIL[0].write(self.nvdev.NV_PFSP_MSGQ_HEAD[0].read())
+
+ class NV_GSP(NV_IP):
+   def init_sw(self):
+     self.handle_gen = itertools.count(0xcf000000)
+     self.init_rm_args()
+     self.init_libos_args()
+     self.init_wpr_meta()
+
+     # Prefill cmd queue with info for gsp to start.
+     self.rpc_set_gsp_system_info()
+     self.rpc_set_registry_table()
+
+     self.gpfifo_class, self.compute_class, self.dma_class = nv_gpu.AMPERE_CHANNEL_GPFIFO_A, nv_gpu.AMPERE_COMPUTE_B, nv_gpu.AMPERE_DMA_COPY_B
+     match self.nvdev.chip_name[:2]:
+       case "AD": self.compute_class = nv_gpu.ADA_COMPUTE_A
+       case "GB":
+         self.gpfifo_class,self.compute_class,self.dma_class=nv_gpu.BLACKWELL_CHANNEL_GPFIFO_A,nv_gpu.BLACKWELL_COMPUTE_B,nv_gpu.BLACKWELL_DMA_COPY_B
+
+   def init_rm_args(self, queue_size=0x40000):
+     # Alloc queues
+     pte_cnt = ((queue_pte_cnt:=(queue_size * 2) // 0x1000)) + round_up(queue_pte_cnt * 8, 0x1000) // 0x1000
+     pt_size = round_up(pte_cnt * 8, 0x1000)
+     queues_va, queues_sysmem = System.alloc_sysmem(pt_size + queue_size * 2, contiguous=False)
+
+     # Fill up ptes
+     for i, sysmem in enumerate(queues_sysmem): to_mv(queues_va + i * 0x8, 0x8).cast('Q')[0] = sysmem
+
+     # Fill up arguments
+     queue_args = nv.MESSAGE_QUEUE_INIT_ARGUMENTS(sharedMemPhysAddr=queues_sysmem[0], pageTableEntryCount=pte_cnt, cmdQueueOffset=pt_size,
+       statQueueOffset=pt_size + queue_size)
+     rm_args, self.rm_args_sysmem = self.nvdev._alloc_boot_struct(nv.GSP_ARGUMENTS_CACHED(bDmemStack=True, messageQueueInitArguments=queue_args))
+
+     # Build command queue header
+     self.cmd_q_va, self.stat_q_va = queues_va + pt_size, queues_va + pt_size + queue_size
+
+     cmd_q_tx = nv.msgqTxHeader(version=0, size=queue_size, entryOff=0x1000, msgSize=0x1000, msgCount=(queue_size - 0x1000) // 0x1000,
+       writePtr=0, flags=1, rxHdrOff=ctypes.sizeof(nv.msgqTxHeader))
+     to_mv(self.cmd_q_va, ctypes.sizeof(nv.msgqTxHeader))[:] = bytes(cmd_q_tx)
+
+     self.cmd_q = NVRpcQueue(self, self.cmd_q_va, None)
+
+   def init_libos_args(self):
+     _, logbuf_sysmem = System.alloc_sysmem((2 << 20), contiguous=True)
+     libos_args_va, self.libos_args_sysmem = System.alloc_sysmem(0x1000, contiguous=True)
+
+     libos_structs = (nv.LibosMemoryRegionInitArgument * 6).from_address(libos_args_va)
+     for i, name in enumerate(["INIT", "INTR", "RM", "MNOC", "KRNL"]):
+       libos_structs[i] = nv.LibosMemoryRegionInitArgument(kind=nv.LIBOS_MEMORY_REGION_CONTIGUOUS, loc=nv.LIBOS_MEMORY_REGION_LOC_SYSMEM, size=0x10000,
+         id8=int.from_bytes(bytes(f"LOG{name}", 'utf-8'), 'big'), pa=logbuf_sysmem[0] + 0x10000 * i)
+
+     libos_structs[5] = nv.LibosMemoryRegionInitArgument(kind=nv.LIBOS_MEMORY_REGION_CONTIGUOUS, loc=nv.LIBOS_MEMORY_REGION_LOC_SYSMEM, size=0x1000,
+       id8=int.from_bytes(bytes("RMARGS", 'utf-8'), 'big'), pa=self.rm_args_sysmem)
+
+   def init_gsp_image(self):
+     fw = fetch("https://github.com/NVIDIA/linux-firmware/raw/refs/heads/nvidia-staging/nvidia/ga102/gsp/gsp-570.144.bin", subdir="fw").read_bytes()
+
+     _, sections, _ = elf_loader(fw)
+     self.gsp_image = next((sh.content for sh in sections if sh.name == ".fwimage"))
+     signature = next((sh.content for sh in sections if sh.name == (f".fwsignature_{self.nvdev.chip_name[:4].lower()}x")))
+
+     # Build radix3
+     npages = [0, 0, 0, round_up(len(self.gsp_image), 0x1000) // 0x1000]
+     for i in range(3, 0, -1): npages[i-1] = ((npages[i] - 1) >> (nv.LIBOS_MEMORY_REGION_RADIX_PAGE_LOG2 - 3)) + 1
+
+     offsets = [sum(npages[:i]) * 0x1000 for i in range(4)]
+     radix_va, self.gsp_radix3_sysmem = System.alloc_sysmem(offsets[-1] + len(self.gsp_image), contiguous=False)
+
+     # Copy image
+     to_mv(radix_va + offsets[-1], len(self.gsp_image))[:] = self.gsp_image
+
+     # Copy level and image pages.
+     for i in range(0, 3):
+       cur_offset = sum(npages[:i+1])
+       to_mv(radix_va + offsets[i], npages[i+1] * 8).cast('Q')[:] = array.array('Q', self.gsp_radix3_sysmem[cur_offset:cur_offset+npages[i+1]])
+
+     # Copy signature
+     self.gsp_signature_va, self.gsp_signature_sysmem = System.alloc_sysmem(len(signature), contiguous=True, data=signature)
+
+   def init_boot_binary_image(self):
+     self.booter_image = self.nvdev.extract_fw("kgspBinArchiveGspRmBoot", "ucode_image_prod_data")
+     self.booter_desc = nv.RM_RISCV_UCODE_DESC.from_buffer_copy(self.nvdev.extract_fw("kgspBinArchiveGspRmBoot", "ucode_desc_prod_data"))
+     _, self.booter_sysmem = System.alloc_sysmem(len(self.booter_image), contiguous=True, data=self.booter_image)
+
+   def init_wpr_meta(self):
+     self.init_gsp_image()
+     self.init_boot_binary_image()
+
+     common = {'sizeOfBootloader':(boot_sz:=len(self.booter_image)), 'sysmemAddrOfBootloader':self.booter_sysmem[0],
+       'sizeOfRadix3Elf':(radix3_sz:=len(self.gsp_image)), 'sysmemAddrOfRadix3Elf': self.gsp_radix3_sysmem[0],
+       'sizeOfSignature': 0x1000, 'sysmemAddrOfSignature': self.gsp_signature_sysmem[0],
+       'bootloaderCodeOffset': self.booter_desc.monitorCodeOffset, 'bootloaderDataOffset': self.booter_desc.monitorDataOffset,
+       'bootloaderManifestOffset': self.booter_desc.manifestOffset, 'revision':nv.GSP_FW_WPR_META_REVISION, 'magic':nv.GSP_FW_WPR_META_MAGIC}
+
+     if self.nvdev.fmc_boot:
+       m = nv.GspFwWprMeta(**common, vgaWorkspaceSize=0x20000, pmuReservedSize=0x1820000, nonWprHeapSize=0x220000, gspFwHeapSize=0x8700000,
+         frtsSize=0x100000)
+     else:
+       m = nv.GspFwWprMeta(**common, vgaWorkspaceSize=(vga_sz:=0x100000), vgaWorkspaceOffset=(vga_off:=self.nvdev.vram_size-vga_sz),
+         gspFwWprEnd=vga_off, frtsSize=(frts_sz:=0x100000), frtsOffset=(frts_off:=vga_off-frts_sz), bootBinOffset=(boot_off:=frts_off-boot_sz),
+         gspFwOffset=(gsp_off:=round_down(boot_off-radix3_sz, 0x10000)), gspFwHeapSize=(gsp_heap_sz:=0x8100000), fbSize=self.nvdev.vram_size,
+         gspFwHeapOffset=(gsp_heap_off:=round_down(gsp_off-gsp_heap_sz, 0x100000)), gspFwWprStart=(wpr_st:=round_down(gsp_heap_off-0x1000, 0x100000)),
+         nonWprHeapSize=(non_wpr_sz:=0x100000), nonWprHeapOffset=(non_wpr_off:=round_down(wpr_st-non_wpr_sz, 0x100000)), gspFwRsvdStart=non_wpr_off)
+       assert self.nvdev.flcn.frts_offset == m.frtsOffset, f"FRTS mismatch: {self.nvdev.flcn.frts_offset} != {m.frtsOffset}"
+     self.wpr_meta, self.wpr_meta_sysmem = self.nvdev._alloc_boot_struct(m)
+
+   def promote_ctx(self, client:int, subdevice:int, obj:int, ctxbufs:dict[int, GRBufDesc], bufs=None, virt=None, phys=None):
+     res, prom = {}, nv_gpu.NV2080_CTRL_GPU_PROMOTE_CTX_PARAMS(entryCount=len(ctxbufs), engineType=0x1, hChanClient=client, hObject=obj)
+     for i,(buf,desc) in enumerate(ctxbufs.items()):
+       use_v, use_p = (desc.virt if virt is None else virt), (desc.phys if phys is None else phys)
+       x = (bufs or {}).get(buf, self.nvdev.mm.valloc(desc.size, contiguous=True)) # allocate buffers
+       prom.promoteEntry[i] = nv_gpu.NV2080_CTRL_GPU_PROMOTE_CTX_BUFFER_ENTRY(bufferId=buf, gpuVirtAddr=x.va_addr if use_v else 0, bInitialize=use_p,
+         gpuPhysAddr=x.paddrs[0][0] if use_p else 0, size=desc.size if use_p else 0, physAttr=0x4 if use_p else 0, bNonmapped=(use_p and not use_v))
+       res[buf] = x
+     self.rpc_rm_control(hObject=subdevice, cmd=nv_gpu.NV2080_CTRL_CMD_GPU_PROMOTE_CTX, params=prom, client=client)
+     return res
+
+   def init_golden_image(self):
+     self.rpc_rm_alloc(hParent=0x0, hClass=0x0, params=nv_gpu.NV0000_ALLOC_PARAMETERS())
+     dev = self.rpc_rm_alloc(hParent=self.priv_root, hClass=nv_gpu.NV01_DEVICE_0, params=nv_gpu.NV0080_ALLOC_PARAMETERS(hClientShare=self.priv_root))
+     subdev = self.rpc_rm_alloc(hParent=dev, hClass=nv_gpu.NV20_SUBDEVICE_0, params=nv_gpu.NV2080_ALLOC_PARAMETERS())
+     vaspace = self.rpc_rm_alloc(hParent=dev, hClass=nv_gpu.FERMI_VASPACE_A, params=nv_gpu.NV_VASPACE_ALLOCATION_PARAMETERS())
+
+     # reserve 512MB for the reserved PDES
+     res_va = self.nvdev.mm.alloc_vaddr(res_sz:=(512 << 20))
+
+     bufs_p = nv_gpu.struct_NV90F1_CTRL_VASPACE_COPY_SERVER_RESERVED_PDES_PARAMS(pageSize=res_sz, numLevelsToCopy=3,
+       virtAddrLo=res_va, virtAddrHi=res_va + res_sz - 1)
+     for i,pt in enumerate(self.nvdev.mm.page_tables(res_va, size=res_sz)):
+       bufs_p.levels[i] = nv_gpu.struct_NV90F1_CTRL_VASPACE_COPY_SERVER_RESERVED_PDES_PARAMS_0(physAddress=pt.paddr,
+         size=self.nvdev.mm.pte_cnt[0] * 8 if i == 0 else 0x1000, pageShift=self.nvdev.mm.pte_covers[i].bit_length() - 1, aperture=1)
+     self.rpc_rm_control(hObject=vaspace, cmd=nv_gpu.NV90F1_CTRL_CMD_VASPACE_COPY_SERVER_RESERVED_PDES, params=bufs_p)
+
+     gpfifo_area = self.nvdev.mm.valloc(4 << 10, contiguous=True)
+     userd = nv_gpu.NV_MEMORY_DESC_PARAMS(base=gpfifo_area.paddrs[0][0] + 0x20 * 8, size=0x20, addressSpace=2, cacheAttrib=0)
+     gg_params = nv_gpu.NV_CHANNELGPFIFO_ALLOCATION_PARAMETERS(gpFifoOffset=gpfifo_area.va_addr, gpFifoEntries=32, engineType=0x1, cid=3,
+       hVASpace=vaspace, userdOffset=(ctypes.c_uint64*8)(0x20 * 8), userdMem=userd, internalFlags=0x1a, flags=0x200320)
+     ch_gpfifo = self.rpc_rm_alloc(hParent=dev, hClass=self.gpfifo_class, params=gg_params)
+
+     gr_ctx_bufs_info = self.rpc_rm_control(hObject=subdev, cmd=nv_gpu.NV2080_CTRL_CMD_INTERNAL_STATIC_KGR_GET_CONTEXT_BUFFERS_INFO,
+       params=nv_gpu.NV2080_CTRL_INTERNAL_STATIC_KGR_GET_CONTEXT_BUFFERS_INFO_PARAMS()).engineContextBuffersInfo[0]
+     def _ctx_info(idx, add=0, align=None): return round_up(gr_ctx_bufs_info.engine[idx].size + add, align or gr_ctx_bufs_info.engine[idx].alignment)
+
+     # Setup graphics context
+     gr_size = _ctx_info(nv_gpu.NV0080_CTRL_FIFO_GET_ENGINE_CONTEXT_PROPERTIES_ENGINE_ID_GRAPHICS, add=0x40000)
+     patch_size = _ctx_info(nv_gpu.NV0080_CTRL_FIFO_GET_ENGINE_CONTEXT_PROPERTIES_ENGINE_ID_GRAPHICS_PATCH)
+     cfgs_sizes = {x: _ctx_info(x + 14, align=(2 << 20) if x == 5 else None) for x in range(3, 11)} # indices 3-10 are mapped to 17-24
+     self.grctx_bufs = {0: GRBufDesc(gr_size, phys=True, virt=True), 1: GRBufDesc(patch_size, phys=True, virt=True, local=True),
+       2: GRBufDesc(patch_size, phys=True, virt=True), **{x: GRBufDesc(cfgs_sizes[x], phys=False, virt=True) for x in range(3, 7)},
+       9: GRBufDesc(cfgs_sizes[9], phys=True, virt=True), 10: GRBufDesc(cfgs_sizes[10], phys=True, virt=False),
+       11: GRBufDesc(cfgs_sizes[10], phys=True, virt=True)} # NOTE: 11 reuses cfgs_sizes[10]
+     self.promote_ctx(self.priv_root, subdev, ch_gpfifo, {k:v for k, v in self.grctx_bufs.items() if not v.local})
+
+     self.rpc_rm_alloc(hParent=ch_gpfifo, hClass=self.compute_class, params=None)
+     self.rpc_rm_alloc(hParent=ch_gpfifo, hClass=self.dma_class, params=None)
+
+   def init_hw(self):
+     self.stat_q = NVRpcQueue(self, self.stat_q_va, self.cmd_q_va)
+     self.cmd_q.rx = nv.msgqRxHeader.from_address(self.stat_q.va + self.stat_q.tx.rxHdrOff)
+
+     self.stat_q.wait_resp(nv.NV_VGPU_MSG_EVENT_GSP_INIT_DONE)
+
+     self.nvdev.NV_PBUS_BAR1_BLOCK.write(mode=0, target=0, ptr=0)
+     if self.nvdev.fmc_boot: self.nvdev.NV_VIRTUAL_FUNCTION_PRIV_FUNC_BAR1_BLOCK_LOW_ADDR.write(mode=0, target=0, ptr=0)
+
+     self.priv_root = 0xc1e00004
+     self.init_golden_image()
+
+   def fini_hw(self): self.rpc_unloading_guest_driver()
+
+   ### RPCs
+
+   def rpc_rm_alloc(self, hParent:int, hClass:int, params:Any, client=None) -> int:
+     if hClass == self.gpfifo_class:
+       ramfc_alloc = self.nvdev.mm.valloc(0x1000, contiguous=True)
+       params.ramfcMem = nv_gpu.NV_MEMORY_DESC_PARAMS(base=ramfc_alloc.paddrs[0][0], size=0x200, addressSpace=2, cacheAttrib=0)
+       params.instanceMem = nv_gpu.NV_MEMORY_DESC_PARAMS(base=ramfc_alloc.paddrs[0][0], size=0x1000, addressSpace=2, cacheAttrib=0)
+
+       method_va, method_sysmem = System.alloc_sysmem(0x5000, contiguous=True)
+       params.mthdbufMem = nv_gpu.NV_MEMORY_DESC_PARAMS(base=method_sysmem[0], size=0x5000, addressSpace=1, cacheAttrib=0)
+
+       if client is not None and client != self.priv_root and params.hObjectError != 0:
+         params.errorNotifierMem = nv_gpu.NV_MEMORY_DESC_PARAMS(base=0, size=0xecc, addressSpace=0, cacheAttrib=0)
+         params.userdMem = nv_gpu.NV_MEMORY_DESC_PARAMS(base=params.hUserdMemory[0] + params.userdOffset[0], size=0x400, addressSpace=2, cacheAttrib=0)
+
+     alloc_args = nv.rpc_gsp_rm_alloc_v(hClient=(client:=client or self.priv_root), hParent=hParent, hObject=(obj:=next(self.handle_gen)),
+       hClass=hClass, flags=0x0, paramsSize=ctypes.sizeof(params) if params is not None else 0x0)
+     self.cmd_q.send_rpc(nv.NV_VGPU_MSG_FUNCTION_GSP_RM_ALLOC, bytes(alloc_args) + (bytes(params) if params is not None else b''))
+     self.stat_q.wait_resp(nv.NV_VGPU_MSG_FUNCTION_GSP_RM_ALLOC)
+
+     if hClass == nv_gpu.FERMI_VASPACE_A and client != self.priv_root:
+       self.rpc_set_page_directory(device=hParent, hVASpace=obj, pdir_paddr=self.nvdev.mm.root_page_table.paddr, client=client)
+     if hClass == nv_gpu.NV20_SUBDEVICE_0: self.subdevice = obj # save subdevice handle
+     if hClass == self.compute_class and client != self.priv_root:
+       phys_gr_ctx = self.promote_ctx(client, self.subdevice, hParent, {k:v for k,v in self.grctx_bufs.items() if k in [0, 1, 2]}, virt=False)
+       self.promote_ctx(client, self.subdevice, hParent, {k:v for k,v in self.grctx_bufs.items() if k in [0, 1, 2]}, phys_gr_ctx, phys=False)
+     return obj if hClass != nv_gpu.NV1_ROOT else client
+
+   def rpc_rm_control(self, hObject:int, cmd:int, params:Any, client=None):
+     control_args = nv.rpc_gsp_rm_control_v(hClient=(client:=client or self.priv_root), hObject=hObject, cmd=cmd, flags=0x0,
+       paramsSize=ctypes.sizeof(params) if params is not None else 0x0)
+     self.cmd_q.send_rpc(nv.NV_VGPU_MSG_FUNCTION_GSP_RM_CONTROL, bytes(control_args) + (bytes(params) if params is not None else b''))
+     res = self.stat_q.wait_resp(nv.NV_VGPU_MSG_FUNCTION_GSP_RM_CONTROL)
+     st = type(params).from_buffer_copy(res[len(bytes(control_args)):]) if params is not None else None
+
+     # NOTE: gb20x requires the enable bit for token submission. Patch workSubmitToken here to maintain userspace compatibility.
+     if self.nvdev.chip_name.startswith("GB2") and cmd == nv_gpu.NVC36F_CTRL_CMD_GPFIFO_GET_WORK_SUBMIT_TOKEN:
+       cast(nv_gpu.NVC36F_CTRL_CMD_GPFIFO_GET_WORK_SUBMIT_TOKEN_PARAMS, st).workSubmitToken |= (1 << 30)
+     return st
+
+   def rpc_set_page_directory(self, device:int, hVASpace:int, pdir_paddr:int, client=None, pasid=0xffffffff):
+     params = nv.struct_NV0080_CTRL_DMA_SET_PAGE_DIRECTORY_PARAMS_v1E_05(physAddress=pdir_paddr,
+       numEntries=self.nvdev.mm.pte_cnt[0], flags=0x8, hVASpace=hVASpace, pasid=pasid, subDeviceId=1, chId=0) # flags field is all channels.
+     alloc_args = nv.rpc_set_page_directory_v(hClient=client or self.priv_root, hDevice=device, pasid=pasid, params=params)
+     self.cmd_q.send_rpc(nv.NV_VGPU_MSG_FUNCTION_SET_PAGE_DIRECTORY, bytes(alloc_args))
+     self.stat_q.wait_resp(nv.NV_VGPU_MSG_FUNCTION_SET_PAGE_DIRECTORY)
+
+   def rpc_set_gsp_system_info(self):
+     def bdf_as_int(s): return (int(s[5:7],16)<<8) | (int(s[8:10],16)<<3) | int(s[-1],16)
+
+     data = nv.GspSystemInfo(gpuPhysAddr=self.nvdev.bars[0][0], gpuPhysFbAddr=self.nvdev.bars[1][0], gpuPhysInstAddr=self.nvdev.bars[3][0],
+       pciConfigMirrorBase=[0x88000, 0x92000][self.nvdev.fmc_boot], pciConfigMirrorSize=0x1000, nvDomainBusDeviceFunc=bdf_as_int(self.nvdev.devfmt),
+       bIsPassthru=1, PCIDeviceID=self.nvdev.venid, PCISubDeviceID=self.nvdev.subvenid, PCIRevisionID=self.nvdev.rev, maxUserVa=0x7ffffffff000)
+     self.cmd_q.send_rpc(nv.NV_VGPU_MSG_FUNCTION_GSP_SET_SYSTEM_INFO, bytes(data))
+
+   def rpc_unloading_guest_driver(self):
+     data = nv.rpc_unloading_guest_driver_v(bInPMTransition=0, bGc6Entering=0, newLevel=(__GPU_STATE_FLAGS_FAST_UNLOAD:=1 << 6))
+     self.cmd_q.send_rpc(nv.NV_VGPU_MSG_FUNCTION_UNLOADING_GUEST_DRIVER, bytes(data))
+     self.stat_q.wait_resp(nv.NV_VGPU_MSG_FUNCTION_UNLOADING_GUEST_DRIVER)
+
+   def rpc_set_registry_table(self):
+     table = {'RMForcePcieConfigSave': 0x1, 'RMSecBusResetEnable': 0x1}
+     entries_bytes, data_bytes = bytes(), bytes()
+     hdr_size, entries_size = ctypes.sizeof(nv.PACKED_REGISTRY_TABLE), ctypes.sizeof(nv.PACKED_REGISTRY_ENTRY) * len(table)
+
+     for k,v in table.items():
+       entries_bytes += bytes(nv.PACKED_REGISTRY_ENTRY(nameOffset=hdr_size + entries_size + len(data_bytes),
+         type=nv.REGISTRY_TABLE_ENTRY_TYPE_DWORD, data=v, length=4))
+       data_bytes += k.encode('utf-8') + b'\x00'
+
+     header = nv.PACKED_REGISTRY_TABLE(size=hdr_size + len(entries_bytes) + len(data_bytes), numEntries=len(table))
+     self.cmd_q.send_rpc(nv.NV_VGPU_MSG_FUNCTION_SET_REGISTRY, bytes(header) + entries_bytes + data_bytes)
+
+   def run_cpu_seq(self, seq_buf:memoryview):
+     hdr = nv.rpc_run_cpu_sequencer_v17_00.from_address(mv_address(seq_buf))
+     cmd_iter = iter(seq_buf[ctypes.sizeof(nv.rpc_run_cpu_sequencer_v17_00):].cast('I')[:hdr.cmdIndex])
+
+     for op in cmd_iter:
+       if op == 0x0: self.nvdev.wreg(next(cmd_iter), next(cmd_iter)) # reg write
+       elif op == 0x1: # reg modify
+         addr, val, mask = next(cmd_iter), next(cmd_iter), next(cmd_iter)
+         self.nvdev.wreg(addr, (self.nvdev.rreg(addr) & ~mask) | (val & mask))
+       elif op == 0x2: # reg poll
+         addr, mask, val, _, _ = next(cmd_iter), next(cmd_iter), next(cmd_iter), next(cmd_iter), next(cmd_iter)
+         wait_cond(lambda: (self.nvdev.rreg(addr) & mask), value=val, msg=f"Register {addr:#x} not equal to {val:#x} after polling")
+       elif op == 0x3: time.sleep(next(cmd_iter) / 1e6) # delay us
+       elif op == 0x4: # save reg
+         addr, index = next(cmd_iter), next(cmd_iter)
+         hdr.regSaveArea[index] = self.nvdev.rreg(addr)
+       elif op == 0x5: # core reset
+         self.nvdev.flcn.reset(self.nvdev.flcn.falcon)
+         self.nvdev.flcn.disable_ctx_req(self.nvdev.flcn.falcon)
+       elif op == 0x6: self.nvdev.flcn.start_cpu(self.nvdev.flcn.falcon)
+       elif op == 0x7: self.nvdev.flcn.wait_cpu_halted(self.nvdev.flcn.falcon)
+       elif op == 0x8: # core resume
+         self.nvdev.flcn.reset(self.nvdev.flcn.falcon, riscv=True)
+
+         self.nvdev.NV_PGSP_FALCON_MAILBOX0.write(lo32(self.libos_args_sysmem[0]))
+         self.nvdev.NV_PGSP_FALCON_MAILBOX1.write(hi32(self.libos_args_sysmem[0]))
+
+         self.nvdev.flcn.start_cpu(self.nvdev.flcn.sec2)
+         wait_cond(lambda: self.nvdev.NV_PGC6_BSI_SECURE_SCRATCH_14.read_bitfields()['boot_stage_3_handoff'], msg="SEC2 didn't hand off")
+
+         mailbox = self.nvdev.NV_PFALCON_FALCON_MAILBOX0.with_base(self.nvdev.flcn.sec2).read()
+         assert mailbox == 0x0, f"Falcon SEC2 failed to execute, mailbox is {mailbox:08x}"
+       else: raise ValueError(f"Unknown op code {op} in run_cpu_seq")
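The NV_IP base class at the top of this file defines a three-stage lifecycle (init_sw, init_hw, fini_hw) that NV_FLCN, NV_FLCN_COT and NV_GSP each implement. A minimal sketch of how a device object might drive these IP blocks; the bring_up/tear_down helpers, the nvdev argument and the IP ordering are assumptions for illustration, not code from this diff:

    # hypothetical driver loop over the IP classes defined above
    def bring_up(nvdev, ip_classes):
        ips = [cls(nvdev) for cls in ip_classes]   # e.g. NV_FLCN (or NV_FLCN_COT) followed by NV_GSP
        for ip in ips: ip.init_sw()                # software prep: firmware extraction, RPC queues, WPR meta
        for ip in ips: ip.init_hw()                # hardware init: NV_GSP.init_hw waits for GSP_INIT_DONE
        return ips

    def tear_down(ips):
        for ip in reversed(ips): ip.fini_hw()      # NV_GSP.fini_hw sends the UNLOADING_GUEST_DRIVER RPC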