tinygrad 0.10.0__py3-none-any.whl → 0.10.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/devectorizer.py +247 -0
- tinygrad/codegen/expander.py +121 -0
- tinygrad/codegen/kernel.py +141 -201
- tinygrad/codegen/linearize.py +223 -84
- tinygrad/codegen/lowerer.py +60 -42
- tinygrad/codegen/symbolic.py +476 -0
- tinygrad/codegen/transcendental.py +22 -13
- tinygrad/device.py +187 -47
- tinygrad/dtype.py +39 -28
- tinygrad/engine/jit.py +83 -65
- tinygrad/engine/memory.py +4 -5
- tinygrad/engine/multi.py +161 -0
- tinygrad/engine/realize.py +62 -108
- tinygrad/engine/schedule.py +396 -357
- tinygrad/engine/search.py +55 -66
- tinygrad/gradient.py +73 -0
- tinygrad/helpers.py +81 -59
- tinygrad/nn/__init__.py +30 -32
- tinygrad/nn/datasets.py +1 -2
- tinygrad/nn/optim.py +22 -26
- tinygrad/nn/state.py +91 -66
- tinygrad/ops.py +492 -641
- tinygrad/renderer/__init__.py +95 -36
- tinygrad/renderer/cstyle.py +99 -92
- tinygrad/renderer/llvmir.py +83 -34
- tinygrad/renderer/ptx.py +83 -99
- tinygrad/renderer/wgsl.py +95 -0
- tinygrad/runtime/autogen/amd_gpu.py +39507 -12
- tinygrad/runtime/autogen/comgr.py +2 -0
- tinygrad/runtime/autogen/kfd.py +4 -3
- tinygrad/runtime/autogen/kgsl.py +1 -1
- tinygrad/runtime/autogen/libc.py +404 -71
- tinygrad/runtime/autogen/llvm.py +11379 -0
- tinygrad/runtime/autogen/pci.py +1333 -0
- tinygrad/runtime/autogen/vfio.py +891 -0
- tinygrad/runtime/autogen/webgpu.py +6985 -0
- tinygrad/runtime/graph/cuda.py +8 -9
- tinygrad/runtime/graph/hcq.py +84 -79
- tinygrad/runtime/graph/metal.py +40 -43
- tinygrad/runtime/ops_amd.py +498 -334
- tinygrad/runtime/ops_cloud.py +34 -34
- tinygrad/runtime/ops_cpu.py +24 -0
- tinygrad/runtime/ops_cuda.py +30 -27
- tinygrad/runtime/ops_disk.py +62 -63
- tinygrad/runtime/ops_dsp.py +159 -42
- tinygrad/runtime/ops_gpu.py +30 -30
- tinygrad/runtime/ops_hip.py +29 -31
- tinygrad/runtime/ops_llvm.py +48 -41
- tinygrad/runtime/ops_metal.py +149 -113
- tinygrad/runtime/ops_npy.py +2 -2
- tinygrad/runtime/ops_nv.py +238 -273
- tinygrad/runtime/ops_python.py +55 -50
- tinygrad/runtime/ops_qcom.py +129 -157
- tinygrad/runtime/ops_webgpu.py +225 -0
- tinygrad/runtime/support/allocator.py +94 -0
- tinygrad/runtime/support/am/__init__.py +0 -0
- tinygrad/runtime/support/am/amdev.py +396 -0
- tinygrad/runtime/support/am/ip.py +463 -0
- tinygrad/runtime/support/compiler_cuda.py +4 -2
- tinygrad/runtime/support/elf.py +28 -4
- tinygrad/runtime/support/hcq.py +256 -324
- tinygrad/runtime/support/llvm.py +26 -0
- tinygrad/shape/shapetracker.py +85 -53
- tinygrad/shape/view.py +104 -140
- tinygrad/spec.py +155 -0
- tinygrad/tensor.py +835 -527
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
- tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
- tinygrad/viz/index.html +544 -0
- tinygrad/viz/perfetto.html +178 -0
- tinygrad/viz/serve.py +205 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
- tinygrad-0.10.2.dist-info/RECORD +99 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
- tinygrad/codegen/uopgraph.py +0 -506
- tinygrad/engine/lazy.py +0 -228
- tinygrad/function.py +0 -212
- tinygrad/multi.py +0 -177
- tinygrad/runtime/graph/clang.py +0 -39
- tinygrad/runtime/ops_clang.py +0 -35
- tinygrad-0.10.0.dist-info/RECORD +0 -77
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,396 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
import ctypes, collections, time, dataclasses, pathlib, fcntl, os
|
3
|
+
from tinygrad.helpers import to_mv, mv_address, getenv, round_up, DEBUG, temp
|
4
|
+
from tinygrad.runtime.autogen.am import am, mp_11_0
|
5
|
+
from tinygrad.runtime.support.allocator import TLSFAllocator
|
6
|
+
from tinygrad.runtime.support.am.ip import AM_SOC21, AM_GMC, AM_IH, AM_PSP, AM_SMU, AM_GFX, AM_SDMA
|
7
|
+
|
8
|
+
# Debug verbosity for this module, read from the AM_DEBUG environment variable (default 0).
# In this file: >= 2 prints page-table map/unmap calls, >= 4 prints MMIO register reads/writes.
AM_DEBUG = getenv("AM_DEBUG", 0)
|
9
|
+
|
10
|
+
@dataclasses.dataclass(frozen=True)
class AMRegister:
  """One MMIO register of an AMDev together with its named bit fields.

  reg_off is the register index handed to adev.rreg / adev.wreg. reg_fields maps a field name to (mask, shift); field
  values passed as keyword arguments to the methods below are range-checked against mask >> shift.
  """
  adev:AMDev; reg_off:int; reg_fields:dict[str, tuple[int, int]] # noqa: E702

  def _parse_kwargs(self, **kwargs):
    """Translate field=value keyword arguments into (keep_mask, field_bits).

    keep_mask is 0xffffffff with the bits of every named field cleared; field_bits is the OR of each value shifted into
    its field position. Raises ValueError for an unknown field name or a value that does not fit in its field.
    """
    keep_mask, field_bits = 0xffffffff, 0
    for field, val in kwargs.items():
      if field not in self.reg_fields: raise ValueError(f"Unknown register field: {field}. {self.reg_fields.keys()}")
      m, s = self.reg_fields[field]
      fits = (val & (m >> s)) == val
      if not fits: raise ValueError(f"Value {val} for {field} is out of range {m=} {s=}")
      keep_mask = keep_mask & ~m
      field_bits = field_bits | (val << s)
    return keep_mask, field_bits

  def build(self, **kwargs) -> int:
    """Return only the encoded bits of the given fields; the register itself is not accessed."""
    _, field_bits = self._parse_kwargs(**kwargs)
    return field_bits

  def update(self, **kwargs):
    """Read-modify-write: replace the given fields and keep every other bit of the current register value."""
    current = self.read()
    self.write(value=current, **kwargs)

  def write(self, value=0, **kwargs):
    """Write `value` to the register, with the bits of the named fields replaced by the supplied field values."""
    keep_mask, field_bits = self._parse_kwargs(**kwargs)
    self.adev.wreg(self.reg_off, field_bits | (value & keep_mask))

  def read(self, **kwargs):
    """Read the register; bits of any fields named in kwargs are cleared from the result (their values are only range-checked)."""
    raw = self.adev.rreg(self.reg_off)
    keep_mask = self._parse_kwargs(**kwargs)[0]
    return raw & keep_mask
|
33
|
+
|
34
|
+
class AMFirmware:
  """Loads the amdgpu firmware images needed to boot the device and splits them into per-ucode pieces.

  After construction:
    sos_fw: PSP SOS firmware pieces, keyed by the fw_type of each psp_fw_bin_desc entry in the SOS image.
    smu_psp_desc: (GFX_FW_TYPE_SMU, SMU ucode) descriptor.
    descs: list of (GFX_FW_TYPE_*, ucode bytes) descriptors for SDMA, PFP/ME/MEC, IMU and RLC firmware.
    ucode_start: 64-bit ucode start address of the PFP, ME and MEC firmware, keyed by that name.
  Firmware file names are derived from the IP versions discovered on `adev` (adev.ip_versions).

  Raises FileNotFoundError if a required firmware file is not installed.
  """
  def __init__(self, adev):
    # ip_versions values are encoded as decimal MMmmrr; firmware file names use "M_m_r".
    def fmt_ver(hwip): return f"{adev.ip_versions[hwip]//10000}_{(adev.ip_versions[hwip]//100)%100}_{adev.ip_versions[hwip]%100}"

    # Load SOS firmware: the image bundles several binaries, each described by a psp_fw_bin_desc following the header.
    self.sos_fw = {}

    blob, sos_hdr = self.load_fw(f"psp_{fmt_ver(am.MP0_HWIP)}_sos.bin", am.struct_psp_firmware_header_v2_0)
    fw_bin = sos_hdr.psp_fw_bin

    for fw_i in range(sos_hdr.psp_fw_bin_count):
      fw_bin_desc = am.struct_psp_fw_bin_desc.from_address(ctypes.addressof(fw_bin) + fw_i * ctypes.sizeof(am.struct_psp_fw_bin_desc))
      # Descriptor offsets are relative to the start of the ucode array.
      ucode_start_offset = fw_bin_desc.offset_bytes + sos_hdr.header.ucode_array_offset_bytes
      self.sos_fw[fw_bin_desc.fw_type] = blob[ucode_start_offset:ucode_start_offset+fw_bin_desc.size_bytes]

    # Load other fw
    self.ucode_start: dict[str, int] = {}
    self.descs: list[tuple[int, memoryview]] = []

    blob, hdr = self.load_fw(f"smu_{fmt_ver(am.MP1_HWIP)}.bin", am.struct_smc_firmware_header_v1_0)
    self.smu_psp_desc = self.desc(am.GFX_FW_TYPE_SMU, blob, hdr.header.ucode_array_offset_bytes, hdr.header.ucode_size_bytes)

    # SDMA firmware: two ucode pieces (TH0 = ctx ucode, TH1 = ctl ucode) in one file.
    blob, hdr = self.load_fw(f"sdma_{fmt_ver(am.SDMA0_HWIP)}.bin", am.struct_sdma_firmware_header_v2_0)
    self.descs += [self.desc(am.GFX_FW_TYPE_SDMA_UCODE_TH0, blob, hdr.header.ucode_array_offset_bytes, hdr.ctx_ucode_size_bytes)]
    self.descs += [self.desc(am.GFX_FW_TYPE_SDMA_UCODE_TH1, blob, hdr.ctl_ucode_offset, hdr.ctl_ucode_size_bytes)]

    # PFP, ME, MEC firmware (fw_cnt = number of per-pipe stack descriptors to register)
    for (fw_name, fw_cnt) in [('PFP', 2), ('ME', 2), ('MEC', 4)]:
      blob, hdr = self.load_fw(f"gc_{fmt_ver(am.GC_HWIP)}_{fw_name.lower()}.bin", am.struct_gfx_firmware_header_v2_0)

      # Code part
      self.descs += [self.desc(getattr(am, f'GFX_FW_TYPE_RS64_{fw_name}'), blob, hdr.header.ucode_array_offset_bytes, hdr.ucode_size_bytes)]

      # Stack: the same data section is registered once per P<n>_STACK firmware type.
      fw_types = [getattr(am, f'GFX_FW_TYPE_RS64_{fw_name}_P{fwnun}_STACK') for fwnun in range(fw_cnt)]
      self.descs += [self.desc(typ, blob, hdr.data_offset_bytes, hdr.data_size_bytes) for typ in fw_types]
      self.ucode_start[fw_name] = hdr.ucode_start_addr_lo | (hdr.ucode_start_addr_hi << 32)

    # IMU firmware: IRAM ucode immediately followed by DRAM ucode.
    blob, hdr = self.load_fw(f"gc_{fmt_ver(am.GC_HWIP)}_imu.bin", am.struct_imu_firmware_header_v1_0)
    imu_i_off, imu_i_sz, imu_d_sz = hdr.header.ucode_array_offset_bytes, hdr.imu_iram_ucode_size_bytes, hdr.imu_dram_ucode_size_bytes
    self.descs += [self.desc(am.GFX_FW_TYPE_IMU_I, blob, imu_i_off, imu_i_sz), self.desc(am.GFX_FW_TYPE_IMU_D, blob, imu_i_off + imu_i_sz, imu_d_sz)]

    # RLC firmware: the v2_0..v2_3 headers are overlaid at the start of the same file, each adding fields.
    blob, hdr0, hdr1, hdr2, hdr3 = self.load_fw(f"gc_{fmt_ver(am.GC_HWIP)}_rlc.bin", am.struct_rlc_firmware_header_v2_0,
      am.struct_rlc_firmware_header_v2_1, am.struct_rlc_firmware_header_v2_2, am.struct_rlc_firmware_header_v2_3)

    for mem in ['GPM', 'SRM']:
      off, sz = getattr(hdr1, f'save_restore_list_{mem.lower()}_offset_bytes'), getattr(hdr1, f'save_restore_list_{mem.lower()}_size_bytes')
      self.descs += [self.desc(getattr(am, f'GFX_FW_TYPE_RLC_RESTORE_LIST_{mem}_MEM'), blob, off, sz)]

    for mem,fmem in [('IRAM', 'iram'), ('DRAM_BOOT', 'dram')]:
      off, sz = getattr(hdr2, f'rlc_{fmem}_ucode_offset_bytes'), getattr(hdr2, f'rlc_{fmem}_ucode_size_bytes')
      self.descs += [self.desc(getattr(am, f'GFX_FW_TYPE_RLC_{mem}'), blob, off, sz)]

    for mem in ['P', 'V']:
      off, sz = getattr(hdr3, f'rlc{mem.lower()}_ucode_offset_bytes'), getattr(hdr3, f'rlc{mem.lower()}_ucode_size_bytes')
      self.descs += [self.desc(getattr(am, f'GFX_FW_TYPE_RLC_{mem}'), blob, off, sz)]

    self.descs += [self.desc(am.GFX_FW_TYPE_RLC_G, blob, hdr0.header.ucode_array_offset_bytes, hdr0.header.ucode_size_bytes)]

  def load_fw(self, fname:str, *headers):
    """Read firmware file `fname` from the amdgpu firmware directories (the updates/ directory takes precedence).

    Returns (blob, *hdrs): blob is a writable memoryview of the file contents, and each entry of hdrs is the matching
    ctypes header struct from `headers` overlaid at offset 0 of blob. The structs do not own the buffer, so the caller
    must keep blob alive while using them. Raises FileNotFoundError if the file is in none of the directories.
    """
    search_dirs = ["/lib/firmware/updates/amdgpu/", "/lib/firmware/amdgpu/"]
    fpath = next((f for loc in search_dirs if (f:=pathlib.Path(loc + fname)).exists()), None)
    if fpath is None: raise FileNotFoundError(f"am: firmware {fname} not found in {search_dirs}")
    blob = memoryview(bytearray(fpath.read_bytes()))
    return tuple([blob] + [hdr.from_address(mv_address(blob)) for hdr in headers])

  def desc(self, typ:int, blob:memoryview, offset:int, size:int) -> tuple[int, memoryview]:
    """Return (typ, zero-copy slice blob[offset:offset+size])."""
    return (typ, blob[offset:offset+size])
|
102
|
+
|
103
|
+
@dataclasses.dataclass(frozen=True)
class AMMapping:
  """Result of AMMemoryManager.map_range: a GPU virtual range and the physical segments backing it.

  paddrs is a list of (physical address, size) segments whose sizes sum to `size`; the boolean flags record the PTE
  attributes the range was mapped with.
  """
  va_addr: int
  size: int
  paddrs: list[tuple[int, int]]
  uncached: bool = False
  system: bool = False
  snooped: bool = False
|
105
|
+
|
106
|
+
class AMPageTableEntry:
  """CPU view of one 4KB page table located at VRAM physical address `paddr`, at page-table level `lv`.

  `entries` exposes the table as 512 64-bit words through the CPU mapping of VRAM (adev.paddr2cpu).
  """
  def __init__(self, adev, paddr, lv):
    table_words = to_mv(adev.paddr2cpu(paddr), 0x1000).cast('Q')
    self.adev, self.paddr, self.entries, self.lv = adev, paddr, table_words, lv

  def set_entry(self, entry_id:int, paddr:int, table=False, uncached=False, system=False, snooped=False, frag=0, valid=True):
    """Encode and store entry `entry_id` as (paddr & 0x0000FFFFFFFFF000) | flags.

    table=True writes a pointer to a lower-level table (no R/W/X permission bits). Otherwise the entry maps memory with
    R/W/X permissions, and above the PTB level it also gets AMDGPU_PDE_PTE so it is treated as a leaf, not a table pointer.
    """
    assert paddr & self.adev.gmc.address_space_mask == paddr, f"Invalid physical address {paddr:#x}"

    flags = am.AMDGPU_PTE_FRAG(frag)
    if valid: flags |= am.AMDGPU_PTE_VALID
    if not table:
      flags |= am.AMDGPU_PTE_WRITEABLE | am.AMDGPU_PTE_READABLE | am.AMDGPU_PTE_EXECUTABLE
      if self.lv != am.AMDGPU_VM_PTB: flags |= am.AMDGPU_PDE_PTE
    if system: flags |= am.AMDGPU_PTE_SYSTEM
    if snooped: flags |= am.AMDGPU_PTE_SNOOPED
    if uncached: flags |= am.AMDGPU_PTE_MTYPE_NV10(0, am.MTYPE_UC)

    self.entries[entry_id] = (paddr & 0x0000FFFFFFFFF000) | flags
|
117
|
+
|
118
|
+
class AMPageTableTraverseContext:
  """Stateful walker over the GPU page-table tree rooted at `pt`, starting at GPU virtual address `vaddr`.

  create_pts: allocate (zeroed) missing lower-level tables while descending; without it, descending into an invalid
              entry fails an assertion.
  free_pts:   when leaving a non-root table whose entries are all invalid, free it and invalidate its parent entry.
  pt_stack holds (table, next entry index, bytes covered by one entry) from the root down to the current table.
  """
  def __init__(self, adev, pt, vaddr, create_pts=False, free_pts=False):
    # Addresses are tracked relative to the VM aperture base (gmc.vm_base). The root index must use that same offset,
    # exactly as level_down() does for every child level, instead of the absolute address.
    self.adev, self.vaddr, self.create_pts, self.free_pts = adev, vaddr - adev.gmc.vm_base, create_pts, free_pts
    self.pt_stack:list[tuple[AMPageTableEntry, int, int]] = [(pt, self._pt_pte_idx(pt, self.vaddr), self._pt_pte_size(pt))]

  # Bytes covered by one entry of `pt`: 4KB at level 3, and 512x more for each level above it.
  def _pt_pte_size(self, pt): return (1 << ((9 * (3-pt.lv)) + 12))
  # Index (0..511) of the entry of `pt` that covers offset `va`.
  def _pt_pte_idx(self, pt, va): return (va // self._pt_pte_size(pt)) % 512

  def level_down(self):
    """Descend into the table referenced by the current entry (creating it if allowed) and return its stack tuple."""
    pt, pte_idx, _ = self.pt_stack[-1]
    if (entry:=pt.entries[pte_idx]) & am.AMDGPU_PTE_VALID == 0:
      assert self.create_pts, "Not allowed to create new page table"
      pt.set_entry(pte_idx, self.adev.mm.palloc(0x1000, zero=True), table=True, valid=True)
      entry = pt.entries[pte_idx]

    assert entry & am.AMDGPU_PDE_PTE == 0, f"Must be table pt={pt.paddr:#x}, {pte_idx=} {entry=:#x}"
    child_page_table = AMPageTableEntry(self.adev, entry & 0x0000FFFFFFFFF000, lv=pt.lv+1)

    self.pt_stack.append((child_page_table, self._pt_pte_idx(child_page_table, self.vaddr), self._pt_pte_size(child_page_table)))
    return self.pt_stack[-1]

  def _try_free_pt(self) -> bool:
    """If free_pts is set and the current non-root table has no valid entries, free it, invalidate its parent entry and return True."""
    pt, _, _ = self.pt_stack[-1]
    if self.free_pts and pt != self.adev.mm.root_page_table and all(pt.entries[i] & am.AMDGPU_PTE_VALID == 0 for i in range(512)):
      self.adev.mm.pfree(pt.paddr)
      parent_pt, parent_pte_idx, _ = self.pt_stack[-2]
      parent_pt.set_entry(parent_pte_idx, 0x0, valid=False)
      return True
    return False

  def level_up(self):
    """Pop tables that were freed or fully consumed (index reached 512); a fully consumed child advances its parent's index."""
    while self._try_free_pt() or self.pt_stack[-1][1] == 512:
      _, pt_cnt, _ = self.pt_stack.pop()
      if pt_cnt == 512: self.pt_stack[-1] = (self.pt_stack[-1][0], self.pt_stack[-1][1] + 1, self.pt_stack[-1][2])

  def next(self, size:int, off=0):
    """Walk `size` bytes from the current position, yielding (off, pt, pte_idx, entry_count, bytes_per_entry) runs.

    With create_pts, it descends until one entry covers no more than the remaining size. Otherwise it descends until it
    reaches the PTB level or an entry marked AMDGPU_PDE_PTE. `off` is the byte offset of each run within the walked range.
    """
    while size > 0:
      pt, pte_idx, pte_covers = self.pt_stack[-1]
      if self.create_pts:
        while pte_covers > size: pt, pte_idx, pte_covers = self.level_down()
      else:
        while pt.lv!=am.AMDGPU_VM_PTB and (pt.entries[pte_idx] & am.AMDGPU_PDE_PTE != am.AMDGPU_PDE_PTE): pt, pte_idx, pte_covers = self.level_down()

      entries = min(size // pte_covers, 512 - pte_idx)
      assert entries > 0, "Invalid entries"
      yield off, pt, pte_idx, entries, pte_covers

      size, off, self.vaddr = size - entries * pte_covers, off + entries * pte_covers, self.vaddr + entries * pte_covers
      self.pt_stack[-1] = (pt, pte_idx + entries, pte_covers)
      self.level_up()
|
168
|
+
|
169
|
+
class AMMemoryManager:
  """Manages VRAM physical allocations, GPU virtual addresses and the GPU page tables of one AMDev."""
  va_allocator = TLSFAllocator(512 * (1 << 30), base=0x7F0000000000) # global for all devices.

  def __init__(self, adev:AMDev, vram_size:int):
    self.adev, self.vram_size = adev, vram_size
    # The top 64MB of VRAM is excluded from the general pool; 32MB of it, starting at vram_size - 64MB, backs boot-time allocations.
    self.boot_allocator = TLSFAllocator(32 << 20, base=vram_size - (64 << 20)) # per device
    self.pa_allocator = TLSFAllocator(vram_size - (64 << 20)) # per device
    self.root_page_table = AMPageTableEntry(self.adev, self.palloc(0x1000, zero=not self.adev.smi_dev, boot=True), lv=am.AMDGPU_VM_PDB1)

  def map_range(self, vaddr:int, size:int, paddrs:list[tuple[int, int]], uncached=False, system=False, snooped=False) -> AMMapping:
    """Map GPU virtual range [vaddr, vaddr+size) onto the (paddr, size) segments in order, then flush the GC and MM TLBs.

    Creates missing page tables as needed. Asserts that the segment sizes sum to `size` and that no target PTE is already valid.
    """
    if AM_DEBUG >= 2: print(f"am {self.adev.devfmt}: mapping {vaddr=:#x} ({size=:#x})")

    assert size == sum(p[1] for p in paddrs), f"Size mismatch {size=} {sum(p[1] for p in paddrs)=}"

    ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, vaddr, create_pts=True)
    for paddr, psize in paddrs:
      for off, pt, pte_idx, pte_cnt, pte_covers in ctx.next(psize):
        for pte_off in range(pte_cnt):
          assert pt.entries[pte_idx + pte_off] & am.AMDGPU_PTE_VALID == 0, f"PTE already mapped: {pt.entries[pte_idx + pte_off]:#x}"
          # 4KB entries use fragment 0; larger entries use fragment 9.
          pt.set_entry(pte_idx + pte_off, paddr + off + pte_off * pte_covers,
            uncached=uncached, system=system, snooped=snooped, frag=0 if pte_covers == 0x1000 else 0x9, valid=True)

    # Invalidate TLB after mappings.
    self.adev.gmc.flush_tlb(ip='GC', vmid=0)
    self.adev.gmc.flush_tlb(ip='MM', vmid=0)
    return AMMapping(vaddr, size, paddrs, uncached=uncached, system=system, snooped=snooped)

  def unmap_range(self, vaddr:int, size:int):
    """Invalidate every PTE covering [vaddr, vaddr+size), freeing page tables that become empty. Asserts each PTE was mapped."""
    if AM_DEBUG >= 2: print(f"am {self.adev.devfmt}: unmapping {vaddr=:#x} ({size=:#x})")

    ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, vaddr, free_pts=True)
    for off, pt, pte_idx, pte_cnt, pte_covers in ctx.next(size):
      for pte_id in range(pte_idx, pte_idx + pte_cnt):
        assert pt.entries[pte_id] & am.AMDGPU_PTE_VALID == am.AMDGPU_PTE_VALID, f"PTE not mapped: {pt.entries[pte_id]:#x}"
        pt.set_entry(pte_id, paddr=0x0, valid=False)

  @staticmethod
  def alloc_vaddr(size:int, align=0x1000) -> int:
    """Allocate a GPU virtual range from the shared allocator, aligned to max(align, largest power of two <= size)."""
    return AMMemoryManager.va_allocator.alloc(size, max((1 << (size.bit_length() - 1)), align))

  def valloc(self, size:int, align=0x1000, uncached=False, contigous=False) -> AMMapping:
    """Allocate `size` bytes of VRAM and map it at a freshly allocated GPU virtual address.

    contigous=True backs the mapping with one zeroed physical allocation. Otherwise, physical memory (not zeroed) is
    allocated one segment per page-table entry, using the entry sizes the page-table walk chooses for the virtual range.
    On MemoryError, every physical segment allocated so far and the virtual range are released before re-raising.
    """
    va = self.alloc_vaddr(size, align)

    paddrs:list[tuple[int, int]] = []
    try:
      if contigous: paddrs.append((self.palloc(size, zero=True), size))
      else:
        ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, va, create_pts=True)
        for _, _, _, seg_cnt, seg_size in ctx.next(size):
          # Append one segment at a time so a MemoryError part-way through a run keeps every allocated segment in paddrs.
          for _ in range(seg_cnt): paddrs.append((self.palloc(seg_size, zero=False), seg_size))
    except MemoryError:
      for paddr, _ in paddrs: self.pa_allocator.free(paddr)
      self.va_allocator.free(va)
      raise

    return self.map_range(va, size, paddrs, uncached=uncached)

  def vfree(self, vm:AMMapping):
    """Unmap `vm` and release its virtual range and physical segments."""
    self.unmap_range(vm.va_addr, vm.size)
    self.va_allocator.free(vm.va_addr)
    for paddr, _ in vm.paddrs: self.pa_allocator.free(paddr)

  def palloc(self, size:int, align:int=0x1000, zero=True, boot=False) -> int:
    """Allocate `size` bytes (rounded up to 4KB) of VRAM, from boot memory when `boot`; optionally zero it via the CPU mapping."""
    assert self.adev.is_booting == boot, "During booting, only boot memory can be allocated"
    paddr = (self.boot_allocator if boot else self.pa_allocator).alloc(round_up(size, 0x1000), align)
    if zero: ctypes.memset(self.adev.paddr2cpu(paddr), 0, size)
    return paddr

  def pfree(self, paddr:int):
    """Free a physical allocation made from the general (non-boot) pool."""
    self.pa_allocator.free(paddr)
|
236
|
+
|
237
|
+
class AMDev:
  """Host-side driver state for one AMD GPU accessed through CPU mappings of its PCI BARs.

  vram_bar, doorbell_bar and mmio_bar are the VRAM, doorbell and MMIO apertures. Construction takes an exclusive
  per-device lock file, discovers IP blocks and builds register accessors, sets up memory management and firmware,
  and initializes the IP blocks (see the boot-process notes in __init__).
  """
  def __init__(self, devfmt, vram_bar:memoryview, doorbell_bar:memoryview, mmio_bar:memoryview):
    self.devfmt = devfmt
    self.vram, self.doorbell64, self.mmio = vram_bar, doorbell_bar, mmio_bar

    # The lock file must be creatable with 0666 permissions so other users can open it, so umask is cleared while opening.
    # umask is process-wide state: restore it afterwards rather than leaving every later file created by this process world-writable.
    lock_name = temp(f"am_{self.devfmt}.lock")
    prev_umask = os.umask(0)
    try:
      # Avoid O_CREAT because we don't want to re-create/replace an existing file (triggers extra perms checks) when opening as non-owner.
      if os.path.exists(lock_name): self.lock_fd = os.open(lock_name, os.O_RDWR)
      else: self.lock_fd = os.open(lock_name, os.O_RDWR | os.O_CREAT, 0o666)
    finally: os.umask(prev_umask)

    try: fcntl.flock(self.lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
    except OSError as e: raise RuntimeError(f"Failed to open AM device {self.devfmt}. It's already in use.") from e

    self._run_discovery()
    self._build_regs()

    # AM boot Process:
    # The GPU being passed can be in one of several states: 1. Not initialized. 2. Initialized by amdgpu. 3. Initialized by AM.
    # The 1st and 2nd states require a full GPU setup since their states are unknown. The 2nd state also requires a mode1 reset to
    # reinitialize all components.
    #
    # The 3rd state can be set up partially to optimize boot time. In this case, only the GFX and SDMA IPs need to be initialized.
    # To enable this, AM uses a separate boot memory that is guaranteed not to be overwritten. This physical memory is utilized for
    # all blocks that are initialized only during the initial AM boot.
    # To determine if the GPU is in the third state, AM uses regSCRATCH_REG7 as a flag.
    self.is_booting, self.smi_dev = True, False # During boot only boot memory can be allocated. This flag is to validate this.
    self.partial_boot = (self.reg("regSCRATCH_REG7").read() == (am_version:=0xA0000002)) and (getenv("AM_RESET", 0) != 1)

    # Memory manager & firmware
    self.mm = AMMemoryManager(self, self.vram_size)
    self.fw = AMFirmware(self)

    # Initialize IP blocks
    self.soc21:AM_SOC21 = AM_SOC21(self)
    self.gmc:AM_GMC = AM_GMC(self)
    self.ih:AM_IH = AM_IH(self)
    self.psp:AM_PSP = AM_PSP(self)
    self.smu:AM_SMU = AM_SMU(self)
    self.gfx:AM_GFX = AM_GFX(self)
    self.sdma:AM_SDMA = AM_SDMA(self)

    if self.partial_boot and (self.reg("regGCVM_CONTEXT0_CNTL").read() != 0):
      if DEBUG >= 2: print(f"am {self.devfmt}: MEC is active. Issue a full reset.")
      self.partial_boot = False

    if not self.partial_boot:
      if self.psp.is_sos_alive() and self.smu.is_smu_alive(): self.smu.mode1_reset()
      for ip in [self.soc21, self.gmc, self.ih, self.psp, self.smu]:
        ip.init()
        if DEBUG >= 2: print(f"am {self.devfmt}: {ip.__class__.__name__} initialized")

    # Booting done
    self.is_booting = False

    # Re-initialize main blocks
    for ip in [self.gfx, self.sdma]:
      ip.init()
      if DEBUG >= 2: print(f"am {self.devfmt}: {ip.__class__.__name__} initialized")

    self.smu.set_clocks(level=-1) # last level, max perf.
    self.gfx.set_clockgating_state()
    self.reg("regSCRATCH_REG7").write(am_version)
    if DEBUG >= 2: print(f"am {self.devfmt}: boot done")

  def fini(self):
    """Tear down SDMA and GFX, drop clocks to level 0 and run the interrupt handler once more."""
    if DEBUG >= 2: print(f"am {self.devfmt}: Finalizing")
    for ip in [self.sdma, self.gfx]: ip.fini()
    self.smu.set_clocks(level=0)
    self.ih.interrupt_handler()

  # VRAM physical address -> CPU address inside the VRAM BAR mapping.
  def paddr2cpu(self, paddr:int) -> int: return mv_address(self.vram) + paddr
  # VRAM physical address -> GPU memory-controller address.
  def paddr2mc(self, paddr:int) -> int: return self.gmc.mc_base + paddr

  # Base offset of segment `seg` for instance `inst` of IP `ip` (e.g. "GC"), as found by _run_discovery.
  def ip_base(self, ip:str, inst:int, seg:int) -> int: return self.regs_offset[am.__dict__[f"{ip}_HWIP"]][inst][seg]

  # Registers are stored as instance attributes by _build_regs (e.g. "regSCRATCH_REG7").
  def reg(self, reg:str) -> AMRegister: return self.__dict__[reg]

  def rreg(self, reg:int) -> int:
    """Read 32-bit register index `reg`, through the RSMU index/data pair if it lies beyond the MMIO mapping."""
    # Valid direct indices are 0..len(mmio)-1; anything else goes through the indirect path (byte address = reg * 4).
    val = self.indirect_rreg(reg * 4) if reg >= len(self.mmio) else self.mmio[reg]
    if AM_DEBUG >= 4 and getattr(self, '_prev_rreg', None) != (reg, val): print(f"am {self.devfmt}: Reading register {reg:#x} with value {val:#x}")
    self._prev_rreg = (reg, val)
    return val

  def wreg(self, reg:int, val:int):
    """Write 32-bit register index `reg`, through the RSMU index/data pair if it lies beyond the MMIO mapping."""
    if AM_DEBUG >= 4: print(f"am {self.devfmt}: Writing register {reg:#x} with value {val:#x}")
    if reg >= len(self.mmio): self.indirect_wreg(reg * 4, val)
    else: self.mmio[reg] = val

  def wreg_pair(self, reg_base:str, lo_suffix:str, hi_suffix:str, val:int):
    """Write a 64-bit value split across the `{reg_base}{lo_suffix}` (low 32 bits) and `{reg_base}{hi_suffix}` (high bits) registers."""
    self.reg(f"{reg_base}{lo_suffix}").write(val & 0xffffffff)
    self.reg(f"{reg_base}{hi_suffix}").write(val >> 32)

  def indirect_rreg(self, reg:int) -> int:
    self.reg("regBIF_BX_PF0_RSMU_INDEX").write(reg)
    return self.reg("regBIF_BX_PF0_RSMU_DATA").read()

  def indirect_wreg(self, reg:int, val:int):
    self.reg("regBIF_BX_PF0_RSMU_INDEX").write(reg)
    self.reg("regBIF_BX_PF0_RSMU_DATA").write(val)

  def wait_reg(self, reg:AMRegister, value:int, mask=0xffffffff, timeout=10000) -> int:
    """Poll `reg` about every 1ms, up to `timeout` times, until (reg & mask) == value; return the full register value.

    Raises RuntimeError on timeout, reporting the last value read.
    """
    rval = None
    for _ in range(timeout):
      if ((rval:=reg.read()) & mask) == value: return rval
      time.sleep(0.001)
    last_val = "n/a" if rval is None else f"0x{rval:X}"
    raise RuntimeError(f'wait_reg timeout reg=0x{reg.reg_off:X} mask=0x{mask:X} value=0x{value:X} last_val={last_val}')

  def _run_discovery(self):
    """Read the VRAM size and parse the IP discovery table into regs_offset, ip_versions and gc_info."""
    # NOTE: Fixed register to query memory size without known ip bases to find the discovery table.
    # The table is located at the end of VRAM - 64KB and is 10KB in size.
    mmRCC_CONFIG_MEMSIZE = 0xde3
    self.vram_size = self.rreg(mmRCC_CONFIG_MEMSIZE) << 20

    bhdr = am.struct_binary_header.from_address(self.paddr2cpu(self.vram_size - (64 << 10)))
    ihdr = am.struct_ip_discovery_header.from_address(ctypes.addressof(bhdr) + bhdr.table_list[am.IP_DISCOVERY].offset)
    assert ihdr.signature == am.DISCOVERY_TABLE_SIGNATURE and not ihdr.base_addr_64_bit, f"0x{ihdr.signature:X} != 0x{am.DISCOVERY_TABLE_SIGNATURE:X}"

    # Mapping of HW IP to Discovery HW IP
    hw_id_map = {am.__dict__[x]: int(y) for x,y in am.hw_id_map}
    self.regs_offset:dict[int, dict[int, list]] = collections.defaultdict(dict)
    self.ip_versions:dict[int, int] = {}

    for num_die in range(ihdr.num_dies):
      dhdr = am.struct_die_header.from_address(ctypes.addressof(bhdr) + ihdr.die_info[num_die].die_offset)

      ip_offset = ctypes.addressof(bhdr) + ctypes.sizeof(dhdr) + ihdr.die_info[num_die].die_offset
      for _ in range(dhdr.num_ips):
        ip = am.struct_ip_v4.from_address(ip_offset)
        ba = (ctypes.c_uint32 * ip.num_base_address).from_address(ip_offset + 8)
        for hw_ip in range(1, am.MAX_HWIP):
          if hw_ip in hw_id_map and hw_id_map[hw_ip] == ip.hw_id:
            self.regs_offset[hw_ip][ip.instance_number] = list(ba)
            # Version encoded as decimal MMmmrr (major, minor, revision).
            self.ip_versions[hw_ip] = int(f"{ip.major:02d}{ip.minor:02d}{ip.revision:02d}")

        # Each IP entry is an 8-byte header followed by its base addresses.
        ip_offset += 8 + (8 if ihdr.base_addr_64_bit else 4) * ip.num_base_address

    gc_info = am.struct_gc_info_v1_0.from_address(gc_addr:=ctypes.addressof(bhdr) + bhdr.table_list[am.GC].offset)
    self.gc_info = getattr(am, f"struct_gc_info_v{gc_info.header.version_major}_{gc_info.header.version_minor}").from_address(gc_addr)

  def _ip_module(self, prefix:str, hwip):
    """Import the autogen register module for `hwip`, falling back from major_minor_rev to major_minor_0 to major_0_0."""
    version = [self.ip_versions[hwip]//10000, (self.ip_versions[hwip]//100)%100, self.ip_versions[hwip]%100]
    for ver in [version, version[:2]+[0], version[:1]+[0, 0]]:
      try: return __import__(f"tinygrad.runtime.autogen.am.{prefix}_{ver[0]}_{ver[1]}_{ver[2]}", fromlist=[f"{prefix}_{ver[0]}_{ver[1]}_{ver[2]}"])
      except ImportError: pass
    raise ImportError(f"am {self.devfmt}: failed to load {prefix} module with version {version}")

  def _build_regs(self):
    """Create an AMRegister attribute for every register (with a *_BASE_IDX) in each IP's autogen module, including its bit fields."""
    mods = [("MP0", self._ip_module("mp", am.MP0_HWIP)), ("NBIO", self._ip_module("nbio", am.NBIO_HWIP)), ("GC", self._ip_module("gc", am.GC_HWIP)),
      ("MP1", mp_11_0), ("MMHUB", self._ip_module("mmhub", am.MMHUB_HWIP)), ("OSSSYS", self._ip_module("osssys", am.OSSSYS_HWIP))]
    for base, module in mods:
      rpref = "mm" if base == "MP1" else "reg" # MP1 regs starts with mm
      reg_names: set[str] = set(k[len(rpref):] for k in module.__dict__.keys() if k.startswith(rpref) and not k.endswith("_BASE_IDX"))
      reg_fields: dict[str, dict[str, tuple]] = collections.defaultdict(dict)
      for k, val in module.__dict__.items():
        # Field masks are named <REG>__<FIELD>_MASK; the shift comes from <REG>__<FIELD>__SHIFT when present.
        if k.endswith("_MASK") and ((rname:=k.split("__")[0]) in reg_names):
          reg_fields[rname][k[2+len(rname):-5].lower()] = (val, module.__dict__.get(f"{k[:-5]}__SHIFT", val.bit_length() - 1))

      for k, regval in module.__dict__.items():
        if k.startswith(rpref) and not k.endswith("_BASE_IDX") and (base_idx:=getattr(module, f"{k}_BASE_IDX", None)) is not None:
          setattr(self, k, AMRegister(self, self.ip_base(base, 0, base_idx) + regval, reg_fields.get(k[len(rpref):], {})))
|