tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +1 -1
- tinygrad/apps/llm.py +206 -0
- tinygrad/codegen/__init__.py +116 -0
- tinygrad/codegen/devectorizer.py +315 -172
- tinygrad/codegen/expander.py +8 -16
- tinygrad/codegen/gpudims.py +89 -0
- tinygrad/codegen/linearize.py +205 -203
- tinygrad/codegen/lowerer.py +92 -139
- tinygrad/codegen/opt/__init__.py +38 -0
- tinygrad/codegen/opt/heuristic.py +125 -0
- tinygrad/codegen/opt/kernel.py +510 -0
- tinygrad/{engine → codegen/opt}/search.py +51 -35
- tinygrad/codegen/opt/swizzler.py +134 -0
- tinygrad/codegen/opt/tc.py +127 -0
- tinygrad/codegen/quantize.py +67 -0
- tinygrad/device.py +122 -132
- tinygrad/dtype.py +152 -35
- tinygrad/engine/jit.py +81 -54
- tinygrad/engine/memory.py +46 -27
- tinygrad/engine/realize.py +82 -41
- tinygrad/engine/schedule.py +70 -445
- tinygrad/frontend/__init__.py +0 -0
- tinygrad/frontend/onnx.py +1253 -0
- tinygrad/frontend/torch.py +5 -0
- tinygrad/gradient.py +19 -27
- tinygrad/helpers.py +95 -47
- tinygrad/nn/__init__.py +7 -8
- tinygrad/nn/optim.py +72 -41
- tinygrad/nn/state.py +37 -23
- tinygrad/renderer/__init__.py +40 -60
- tinygrad/renderer/cstyle.py +143 -128
- tinygrad/renderer/llvmir.py +113 -62
- tinygrad/renderer/ptx.py +50 -32
- tinygrad/renderer/wgsl.py +27 -23
- tinygrad/runtime/autogen/am/am.py +5861 -0
- tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
- tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
- tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
- tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
- tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
- tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
- tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
- tinygrad/runtime/autogen/comgr.py +35 -9
- tinygrad/runtime/autogen/comgr_3.py +906 -0
- tinygrad/runtime/autogen/cuda.py +2419 -494
- tinygrad/runtime/autogen/hsa.py +57 -16
- tinygrad/runtime/autogen/ib.py +7171 -0
- tinygrad/runtime/autogen/io_uring.py +917 -118
- tinygrad/runtime/autogen/kfd.py +748 -26
- tinygrad/runtime/autogen/libc.py +613 -218
- tinygrad/runtime/autogen/libusb.py +1643 -0
- tinygrad/runtime/autogen/nv/nv.py +8602 -0
- tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
- tinygrad/runtime/autogen/opencl.py +2 -4
- tinygrad/runtime/autogen/sqtt.py +1789 -0
- tinygrad/runtime/autogen/vfio.py +3 -3
- tinygrad/runtime/autogen/webgpu.py +273 -264
- tinygrad/runtime/graph/cuda.py +3 -3
- tinygrad/runtime/graph/hcq.py +68 -29
- tinygrad/runtime/graph/metal.py +29 -13
- tinygrad/runtime/graph/remote.py +114 -0
- tinygrad/runtime/ops_amd.py +537 -320
- tinygrad/runtime/ops_cpu.py +108 -7
- tinygrad/runtime/ops_cuda.py +12 -14
- tinygrad/runtime/ops_disk.py +13 -10
- tinygrad/runtime/ops_dsp.py +47 -40
- tinygrad/runtime/ops_gpu.py +13 -11
- tinygrad/runtime/ops_hip.py +6 -9
- tinygrad/runtime/ops_llvm.py +35 -15
- tinygrad/runtime/ops_metal.py +29 -19
- tinygrad/runtime/ops_npy.py +5 -3
- tinygrad/runtime/ops_null.py +28 -0
- tinygrad/runtime/ops_nv.py +306 -234
- tinygrad/runtime/ops_python.py +62 -52
- tinygrad/runtime/ops_qcom.py +28 -39
- tinygrad/runtime/ops_remote.py +482 -0
- tinygrad/runtime/ops_webgpu.py +28 -28
- tinygrad/runtime/support/am/amdev.py +114 -249
- tinygrad/runtime/support/am/ip.py +211 -172
- tinygrad/runtime/support/amd.py +138 -0
- tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
- tinygrad/runtime/support/compiler_cuda.py +8 -11
- tinygrad/runtime/support/elf.py +2 -1
- tinygrad/runtime/support/hcq.py +184 -97
- tinygrad/runtime/support/ib.py +172 -0
- tinygrad/runtime/support/llvm.py +3 -4
- tinygrad/runtime/support/memory.py +251 -0
- tinygrad/runtime/support/nv/__init__.py +0 -0
- tinygrad/runtime/support/nv/ip.py +581 -0
- tinygrad/runtime/support/nv/nvdev.py +183 -0
- tinygrad/runtime/support/system.py +170 -0
- tinygrad/runtime/support/usb.py +268 -0
- tinygrad/runtime/support/webgpu.py +18 -0
- tinygrad/schedule/__init__.py +0 -0
- tinygrad/schedule/grouper.py +119 -0
- tinygrad/schedule/kernelize.py +368 -0
- tinygrad/schedule/multi.py +231 -0
- tinygrad/shape/shapetracker.py +40 -46
- tinygrad/shape/view.py +88 -52
- tinygrad/tensor.py +968 -542
- tinygrad/uop/__init__.py +117 -0
- tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
- tinygrad/uop/mathtraits.py +169 -0
- tinygrad/uop/ops.py +1021 -0
- tinygrad/uop/spec.py +228 -0
- tinygrad/{codegen → uop}/symbolic.py +239 -216
- tinygrad/uop/upat.py +163 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
- tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
- tinygrad/viz/index.html +203 -403
- tinygrad/viz/js/index.js +718 -0
- tinygrad/viz/js/worker.js +29 -0
- tinygrad/viz/serve.py +224 -102
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
- tinygrad-0.11.0.dist-info/RECORD +141 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/kernel.py +0 -693
- tinygrad/engine/multi.py +0 -161
- tinygrad/ops.py +0 -1003
- tinygrad/runtime/ops_cloud.py +0 -220
- tinygrad/runtime/support/allocator.py +0 -94
- tinygrad/spec.py +0 -155
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
- tinygrad/viz/perfetto.html +0 -178
- tinygrad-0.10.2.dist-info/RECORD +0 -99
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
@@ -1,44 +1,35 @@
|
|
1
1
|
from __future__ import annotations
|
2
|
-
import ctypes, collections,
|
3
|
-
from tinygrad.helpers import
|
4
|
-
from tinygrad.runtime.autogen.am import am
|
5
|
-
from tinygrad.runtime.support.
|
6
|
-
from tinygrad.runtime.support.
|
2
|
+
import ctypes, collections, dataclasses, functools, os, hashlib
|
3
|
+
from tinygrad.helpers import mv_address, getenv, DEBUG, fetch
|
4
|
+
from tinygrad.runtime.autogen.am import am
|
5
|
+
from tinygrad.runtime.support.hcq import MMIOInterface
|
6
|
+
from tinygrad.runtime.support.amd import AMDReg, import_module, import_asic_regs
|
7
|
+
from tinygrad.runtime.support.memory import TLSFAllocator, MemoryManager
|
8
|
+
from tinygrad.runtime.support.system import System, PCIDevImplBase
|
9
|
+
from tinygrad.runtime.support.am.ip import AM_SOC, AM_GMC, AM_IH, AM_PSP, AM_SMU, AM_GFX, AM_SDMA
|
7
10
|
|
8
11
|
AM_DEBUG = getenv("AM_DEBUG", 0)
|
9
12
|
|
10
|
-
@dataclasses.dataclass
|
11
|
-
class AMRegister:
|
12
|
-
adev:AMDev
|
13
|
+
@dataclasses.dataclass
|
14
|
+
class AMRegister(AMDReg):
|
15
|
+
adev:AMDev
|
13
16
|
|
14
|
-
def
|
15
|
-
|
16
|
-
for k, v in kwargs.items():
|
17
|
-
if k not in self.reg_fields: raise ValueError(f"Unknown register field: {k}. {self.reg_fields.keys()}")
|
18
|
-
m, s = self.reg_fields[k]
|
19
|
-
if v & (m>>s) != v: raise ValueError(f"Value {v} for {k} is out of range {m=} {s=}")
|
20
|
-
mask &= ~m
|
21
|
-
values |= v << s
|
22
|
-
return mask, values
|
17
|
+
def read(self, inst=0): return self.adev.rreg(self.addr[inst])
|
18
|
+
def read_bitfields(self, inst=0) -> dict[str, int]: return self.decode(self.read(inst=inst))
|
23
19
|
|
24
|
-
def
|
20
|
+
def write(self, _am_val:int=0, inst=0, **kwargs): self.adev.wreg(self.addr[inst], _am_val | self.encode(**kwargs))
|
25
21
|
|
26
|
-
def update(self, **kwargs): self.write(
|
27
|
-
|
28
|
-
def write(self, value=0, **kwargs):
|
29
|
-
mask, values = self._parse_kwargs(**kwargs)
|
30
|
-
self.adev.wreg(self.reg_off, (value & mask) | values)
|
31
|
-
|
32
|
-
def read(self, **kwargs): return self.adev.rreg(self.reg_off) & self._parse_kwargs(**kwargs)[0]
|
22
|
+
def update(self, inst=0, **kwargs): self.write(self.read(inst=inst) & ~self.fields_mask(*kwargs.keys()), inst=inst, **kwargs)
|
33
23
|
|
34
24
|
class AMFirmware:
|
35
25
|
def __init__(self, adev):
|
36
|
-
|
26
|
+
self.adev = adev
|
27
|
+
def fmt_ver(hwip): return '_'.join(map(str, adev.ip_ver[hwip]))
|
37
28
|
|
38
29
|
# Load SOS firmware
|
39
30
|
self.sos_fw = {}
|
40
31
|
|
41
|
-
blob, sos_hdr = self.load_fw(f"psp_{fmt_ver(am.MP0_HWIP)}_sos.bin",
|
32
|
+
blob, sos_hdr = self.load_fw(f"psp_{fmt_ver(am.MP0_HWIP)}_sos.bin", versioned_header='struct_psp_firmware_header')
|
42
33
|
fw_bin = sos_hdr.psp_fw_bin
|
43
34
|
|
44
35
|
for fw_i in range(sos_hdr.psp_fw_bin_count):
|
@@ -48,205 +39,88 @@ class AMFirmware:
|
|
48
39
|
|
49
40
|
# Load other fw
|
50
41
|
self.ucode_start: dict[str, int] = {}
|
51
|
-
self.descs: list[tuple[int, memoryview]] = []
|
42
|
+
self.descs: list[tuple[list[int], memoryview]] = []
|
52
43
|
|
53
44
|
blob, hdr = self.load_fw(f"smu_{fmt_ver(am.MP1_HWIP)}.bin", am.struct_smc_firmware_header_v1_0)
|
54
|
-
self.smu_psp_desc = self.desc(
|
45
|
+
self.smu_psp_desc = self.desc(blob, hdr.header.ucode_array_offset_bytes, hdr.header.ucode_size_bytes, am.GFX_FW_TYPE_SMU)
|
55
46
|
|
56
47
|
# SDMA firmware
|
57
|
-
blob, hdr = self.load_fw(f"sdma_{fmt_ver(am.SDMA0_HWIP)}.bin",
|
58
|
-
|
59
|
-
|
48
|
+
blob, hdr = self.load_fw(f"sdma_{fmt_ver(am.SDMA0_HWIP)}.bin", versioned_header='struct_sdma_firmware_header')
|
49
|
+
if hdr.header.header_version_major < 3:
|
50
|
+
self.descs += [self.desc(blob, hdr.ctl_ucode_offset, hdr.ctl_ucode_size_bytes, am.GFX_FW_TYPE_SDMA_UCODE_TH1)]
|
51
|
+
self.descs += [self.desc(blob, hdr.header.ucode_array_offset_bytes, hdr.ctx_ucode_size_bytes, am.GFX_FW_TYPE_SDMA_UCODE_TH0)]
|
52
|
+
else: self.descs += [self.desc(blob, hdr.header.ucode_array_offset_bytes, hdr.ucode_size_bytes, am.GFX_FW_TYPE_SDMA_UCODE_TH0)]
|
60
53
|
|
61
54
|
# PFP, ME, MEC firmware
|
62
|
-
for (fw_name, fw_cnt) in [('PFP',
|
55
|
+
for (fw_name, fw_cnt) in ([('PFP', 1), ('ME', 1)] if self.adev.ip_ver[am.GC_HWIP] >= (12,0,0) else []) + [('MEC', 1)]:
|
63
56
|
blob, hdr = self.load_fw(f"gc_{fmt_ver(am.GC_HWIP)}_{fw_name.lower()}.bin", am.struct_gfx_firmware_header_v2_0)
|
64
57
|
|
65
58
|
# Code part
|
66
|
-
self.descs += [self.desc(
|
59
|
+
self.descs += [self.desc(blob, hdr.header.ucode_array_offset_bytes, hdr.ucode_size_bytes, getattr(am, f'GFX_FW_TYPE_RS64_{fw_name}'))]
|
67
60
|
|
68
61
|
# Stack
|
69
|
-
|
70
|
-
self.descs += [self.desc(
|
62
|
+
stack_fws = [getattr(am, f'GFX_FW_TYPE_RS64_{fw_name}_P{fwnum}_STACK') for fwnum in range(fw_cnt)]
|
63
|
+
self.descs += [self.desc(blob, hdr.data_offset_bytes, hdr.data_size_bytes, *stack_fws)]
|
71
64
|
self.ucode_start[fw_name] = hdr.ucode_start_addr_lo | (hdr.ucode_start_addr_hi << 32)
|
72
65
|
|
73
66
|
# IMU firmware
|
74
67
|
blob, hdr = self.load_fw(f"gc_{fmt_ver(am.GC_HWIP)}_imu.bin", am.struct_imu_firmware_header_v1_0)
|
75
68
|
imu_i_off, imu_i_sz, imu_d_sz = hdr.header.ucode_array_offset_bytes, hdr.imu_iram_ucode_size_bytes, hdr.imu_dram_ucode_size_bytes
|
76
|
-
self.descs += [self.desc(
|
69
|
+
self.descs += [self.desc(blob, imu_i_off, imu_i_sz, am.GFX_FW_TYPE_IMU_I), self.desc(blob, imu_i_off + imu_i_sz, imu_d_sz, am.GFX_FW_TYPE_IMU_D)]
|
77
70
|
|
78
71
|
# RLC firmware
|
79
|
-
blob, hdr0,
|
72
|
+
blob, hdr0, _hdr1, hdr2, hdr3 = self.load_fw(f"gc_{fmt_ver(am.GC_HWIP)}_rlc.bin", am.struct_rlc_firmware_header_v2_0,
|
80
73
|
am.struct_rlc_firmware_header_v2_1, am.struct_rlc_firmware_header_v2_2, am.struct_rlc_firmware_header_v2_3)
|
81
74
|
|
82
|
-
for mem in ['GPM', 'SRM']:
|
83
|
-
off, sz = getattr(hdr1, f'save_restore_list_{mem.lower()}_offset_bytes'), getattr(hdr1, f'save_restore_list_{mem.lower()}_size_bytes')
|
84
|
-
self.descs += [self.desc(getattr(am, f'GFX_FW_TYPE_RLC_RESTORE_LIST_{mem}_MEM'), blob, off, sz)]
|
85
|
-
|
86
75
|
for mem,fmem in [('IRAM', 'iram'), ('DRAM_BOOT', 'dram')]:
|
87
76
|
off, sz = getattr(hdr2, f'rlc_{fmem}_ucode_offset_bytes'), getattr(hdr2, f'rlc_{fmem}_ucode_size_bytes')
|
88
|
-
self.descs += [self.desc(getattr(am, f'GFX_FW_TYPE_RLC_{mem}')
|
77
|
+
self.descs += [self.desc(blob, off, sz, getattr(am, f'GFX_FW_TYPE_RLC_{mem}'))]
|
89
78
|
|
90
|
-
|
91
|
-
|
92
|
-
|
79
|
+
if hdr0.header.header_version_minor == 3:
|
80
|
+
for mem in ['P', 'V']:
|
81
|
+
off, sz = getattr(hdr3, f'rlc{mem.lower()}_ucode_offset_bytes'), getattr(hdr3, f'rlc{mem.lower()}_ucode_size_bytes')
|
82
|
+
self.descs += [self.desc(blob, off, sz, getattr(am, f'GFX_FW_TYPE_RLC_{mem}'))]
|
93
83
|
|
94
|
-
self.descs += [self.desc(
|
84
|
+
self.descs += [self.desc(blob, hdr0.header.ucode_array_offset_bytes, hdr0.header.ucode_size_bytes, am.GFX_FW_TYPE_RLC_G)]
|
95
85
|
|
96
|
-
def load_fw(self, fname:str, *headers):
|
97
|
-
fpath =
|
86
|
+
def load_fw(self, fname:str, *headers, versioned_header:str|None=None):
|
87
|
+
fpath = fetch(f"https://gitlab.com/kernel-firmware/linux-firmware/-/raw/45f59212aebd226c7630aff4b58598967c0c8c91/amdgpu/{fname}", subdir="fw")
|
98
88
|
blob = memoryview(bytearray(fpath.read_bytes()))
|
89
|
+
if AM_DEBUG >= 1: print(f"am {self.adev.devfmt}: loading firmware {fname}: {hashlib.sha256(blob).hexdigest()}")
|
90
|
+
if versioned_header:
|
91
|
+
chdr = am.struct_common_firmware_header.from_address(mv_address(blob))
|
92
|
+
headers += (getattr(am, versioned_header + f"_v{chdr.header_version_major}_{chdr.header_version_minor}"),)
|
99
93
|
return tuple([blob] + [hdr.from_address(mv_address(blob)) for hdr in headers])
|
100
94
|
|
101
|
-
def desc(self,
|
102
|
-
|
103
|
-
@dataclasses.dataclass(frozen=True)
|
104
|
-
class AMMapping: va_addr:int; size:int; paddrs:list[tuple[int, int]]; uncached:bool=False; system:bool=False; snooped:bool=False # noqa: E702
|
95
|
+
def desc(self, blob:memoryview, offset:int, size:int, *types:int) -> tuple[list[int], memoryview]: return (list(types), blob[offset:offset+size])
|
105
96
|
|
106
97
|
class AMPageTableEntry:
|
107
|
-
def __init__(self, adev, paddr, lv): self.adev, self.paddr, self.
|
98
|
+
def __init__(self, adev, paddr, lv): self.adev, self.paddr, self.lv, self.entries = adev, paddr, lv, adev.vram.view(paddr, 0x1000, fmt='Q')
|
108
99
|
|
109
100
|
def set_entry(self, entry_id:int, paddr:int, table=False, uncached=False, system=False, snooped=False, frag=0, valid=True):
|
110
101
|
assert paddr & self.adev.gmc.address_space_mask == paddr, f"Invalid physical address {paddr:#x}"
|
102
|
+
self.entries[entry_id] = self.adev.gmc.get_pte_flags(self.lv, table, frag, uncached, system, snooped, valid) | (paddr & 0x0000FFFFFFFFF000)
|
111
103
|
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
class AMPageTableTraverseContext:
|
119
|
-
def __init__(self, adev, pt, vaddr, create_pts=False, free_pts=False):
|
120
|
-
self.adev, self.vaddr, self.create_pts, self.free_pts = adev, vaddr - adev.gmc.vm_base, create_pts, free_pts
|
121
|
-
self.pt_stack:list[tuple[AMPageTableEntry, int, int]] = [(pt, self._pt_pte_idx(pt, vaddr), self._pt_pte_size(pt))]
|
122
|
-
|
123
|
-
def _pt_pte_size(self, pt): return (1 << ((9 * (3-pt.lv)) + 12))
|
124
|
-
def _pt_pte_idx(self, pt, va): return (va // self._pt_pte_size(pt)) % 512
|
125
|
-
|
126
|
-
def level_down(self):
|
127
|
-
pt, pte_idx, _ = self.pt_stack[-1]
|
128
|
-
if (entry:=pt.entries[pte_idx]) & am.AMDGPU_PTE_VALID == 0:
|
129
|
-
assert self.create_pts, "Not allowed to create new page table"
|
130
|
-
pt.set_entry(pte_idx, self.adev.mm.palloc(0x1000, zero=True), table=True, valid=True)
|
131
|
-
entry = pt.entries[pte_idx]
|
132
|
-
|
133
|
-
assert entry & am.AMDGPU_PDE_PTE == 0, f"Must be table pt={pt.paddr:#x}, {pte_idx=} {entry=:#x}"
|
134
|
-
child_page_table = AMPageTableEntry(self.adev, entry & 0x0000FFFFFFFFF000, lv=pt.lv+1)
|
135
|
-
|
136
|
-
self.pt_stack.append((child_page_table, self._pt_pte_idx(child_page_table, self.vaddr), self._pt_pte_size(child_page_table)))
|
137
|
-
return self.pt_stack[-1]
|
138
|
-
|
139
|
-
def _try_free_pt(self) -> bool:
|
140
|
-
pt, _, _ = self.pt_stack[-1]
|
141
|
-
if self.free_pts and pt != self.adev.mm.root_page_table and all(pt.entries[i] & am.AMDGPU_PTE_VALID == 0 for i in range(512)):
|
142
|
-
self.adev.mm.pfree(pt.paddr)
|
143
|
-
parent_pt, parent_pte_idx, _ = self.pt_stack[-2]
|
144
|
-
parent_pt.set_entry(parent_pte_idx, 0x0, valid=False)
|
145
|
-
return True
|
146
|
-
return False
|
147
|
-
|
148
|
-
def level_up(self):
|
149
|
-
while self._try_free_pt() or self.pt_stack[-1][1] == 512:
|
150
|
-
_, pt_cnt, _ = self.pt_stack.pop()
|
151
|
-
if pt_cnt == 512: self.pt_stack[-1] = (self.pt_stack[-1][0], self.pt_stack[-1][1] + 1, self.pt_stack[-1][2])
|
152
|
-
|
153
|
-
def next(self, size:int, off=0):
|
154
|
-
while size > 0:
|
155
|
-
pt, pte_idx, pte_covers = self.pt_stack[-1]
|
156
|
-
if self.create_pts:
|
157
|
-
while pte_covers > size: pt, pte_idx, pte_covers = self.level_down()
|
158
|
-
else:
|
159
|
-
while pt.lv!=am.AMDGPU_VM_PTB and (pt.entries[pte_idx] & am.AMDGPU_PDE_PTE != am.AMDGPU_PDE_PTE): pt, pte_idx, pte_covers = self.level_down()
|
160
|
-
|
161
|
-
entries = min(size // pte_covers, 512 - pte_idx)
|
162
|
-
assert entries > 0, "Invalid entries"
|
163
|
-
yield off, pt, pte_idx, entries, pte_covers
|
164
|
-
|
165
|
-
size, off, self.vaddr = size - entries * pte_covers, off + entries * pte_covers, self.vaddr + entries * pte_covers
|
166
|
-
self.pt_stack[-1] = (pt, pte_idx + entries, pte_covers)
|
167
|
-
self.level_up()
|
168
|
-
|
169
|
-
class AMMemoryManager:
|
170
|
-
va_allocator = TLSFAllocator(512 * (1 << 30), base=0x7F0000000000) # global for all devices.
|
171
|
-
|
172
|
-
def __init__(self, adev:AMDev, vram_size:int):
|
173
|
-
self.adev, self.vram_size = adev, vram_size
|
174
|
-
self.boot_allocator = TLSFAllocator(32 << 20, base=vram_size - (64 << 20)) # per device
|
175
|
-
self.pa_allocator = TLSFAllocator(vram_size - (64 << 20)) # per device
|
176
|
-
self.root_page_table = AMPageTableEntry(self.adev, self.palloc(0x1000, zero=not self.adev.smi_dev, boot=True), lv=am.AMDGPU_VM_PDB1)
|
177
|
-
|
178
|
-
def map_range(self, vaddr:int, size:int, paddrs:list[tuple[int, int]], uncached=False, system=False, snooped=False) -> AMMapping:
|
179
|
-
if AM_DEBUG >= 2: print(f"am {self.adev.devfmt}: mapping {vaddr=:#x} ({size=:#x})")
|
180
|
-
|
181
|
-
assert size == sum(p[1] for p in paddrs), f"Size mismatch {size=} {sum(p[1] for p in paddrs)=}"
|
182
|
-
|
183
|
-
ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, vaddr, create_pts=True)
|
184
|
-
for paddr, psize in paddrs:
|
185
|
-
for off, pt, pte_idx, pte_cnt, pte_covers in ctx.next(psize):
|
186
|
-
for pte_off in range(pte_cnt):
|
187
|
-
assert pt.entries[pte_idx + pte_off] & am.AMDGPU_PTE_VALID == 0, f"PTE already mapped: {pt.entries[pte_idx + pte_off]:#x}"
|
188
|
-
pt.set_entry(pte_idx + pte_off, paddr + off + pte_off * pte_covers,
|
189
|
-
uncached=uncached, system=system, snooped=snooped, frag=0 if pte_covers == 0x1000 else 0x9, valid=True)
|
190
|
-
|
191
|
-
# Invalidate TLB after mappings.
|
192
|
-
self.adev.gmc.flush_tlb(ip='GC', vmid=0)
|
193
|
-
self.adev.gmc.flush_tlb(ip='MM', vmid=0)
|
194
|
-
return AMMapping(vaddr, size, paddrs, uncached=uncached, system=system, snooped=snooped)
|
195
|
-
|
196
|
-
def unmap_range(self, vaddr:int, size:int):
|
197
|
-
if AM_DEBUG >= 2: print(f"am {self.adev.devfmt}: unmapping {vaddr=:#x} ({size=:#x})")
|
198
|
-
|
199
|
-
ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, vaddr, free_pts=True)
|
200
|
-
for off, pt, pte_idx, pte_cnt, pte_covers in ctx.next(size):
|
201
|
-
for pte_id in range(pte_idx, pte_idx + pte_cnt):
|
202
|
-
assert pt.entries[pte_id] & am.AMDGPU_PTE_VALID == am.AMDGPU_PTE_VALID, f"PTE not mapped: {pt.entries[pte_id]:#x}"
|
203
|
-
pt.set_entry(pte_id, paddr=0x0, valid=False)
|
204
|
-
|
205
|
-
@staticmethod
|
206
|
-
def alloc_vaddr(size:int, align=0x1000) -> int: return AMMemoryManager.va_allocator.alloc(size, max((1 << (size.bit_length() - 1)), align))
|
207
|
-
|
208
|
-
def valloc(self, size:int, align=0x1000, uncached=False, contigous=False) -> AMMapping:
|
209
|
-
# Alloc physical memory and map it to the virtual address
|
210
|
-
va = self.alloc_vaddr(size, align)
|
211
|
-
|
212
|
-
if contigous: paddrs = [(self.palloc(size, zero=True), size)]
|
213
|
-
else:
|
214
|
-
paddrs = []
|
215
|
-
try:
|
216
|
-
ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, va, create_pts=True)
|
217
|
-
for _, _, _, seg_cnt, seg_size in ctx.next(size): paddrs += [(self.palloc(seg_size, zero=False), seg_size) for _ in range(seg_cnt)]
|
218
|
-
except MemoryError:
|
219
|
-
for paddr, _ in paddrs: self.pa_allocator.free(paddr)
|
220
|
-
raise
|
104
|
+
def entry(self, entry_id:int) -> int: return self.entries[entry_id]
|
105
|
+
def valid(self, entry_id:int) -> bool: return (self.entries[entry_id] & am.AMDGPU_PTE_VALID) != 0
|
106
|
+
def address(self, entry_id:int) -> int: return self.entries[entry_id] & 0x0000FFFFFFFFF000
|
107
|
+
def is_huge_page(self, entry_id:int) -> bool: return self.lv == am.AMDGPU_VM_PTB or self.adev.gmc.is_pte_huge_page(self.entries[entry_id])
|
108
|
+
def supports_huge_page(self, paddr:int): return self.lv >= am.AMDGPU_VM_PDB2
|
221
109
|
|
222
|
-
|
110
|
+
class AMMemoryManager(MemoryManager):
|
111
|
+
va_allocator = TLSFAllocator(512 * (1 << 30), base=0x200000000000) # global for all devices.
|
223
112
|
|
224
|
-
def
|
225
|
-
|
226
|
-
self.
|
227
|
-
|
228
|
-
|
229
|
-
def palloc(self, size:int, align:int=0x1000, zero=True, boot=False) -> int:
|
230
|
-
assert self.adev.is_booting == boot, "During booting, only boot memory can be allocated"
|
231
|
-
paddr = (self.boot_allocator if boot else self.pa_allocator).alloc(round_up(size, 0x1000), align)
|
232
|
-
if zero: ctypes.memset(self.adev.paddr2cpu(paddr), 0, size)
|
233
|
-
return paddr
|
234
|
-
|
235
|
-
def pfree(self, paddr:int): self.pa_allocator.free(paddr)
|
236
|
-
|
237
|
-
class AMDev:
|
238
|
-
def __init__(self, devfmt, vram_bar:memoryview, doorbell_bar:memoryview, mmio_bar:memoryview):
|
239
|
-
self.devfmt = devfmt
|
240
|
-
self.vram, self.doorbell64, self.mmio = vram_bar, doorbell_bar, mmio_bar
|
241
|
-
|
242
|
-
os.umask(0) # Set umask to 0 to allow creating files with 0666 permissions
|
113
|
+
def on_range_mapped(self):
|
114
|
+
# Invalidate TLB after mappings.
|
115
|
+
self.dev.gmc.flush_tlb(ip='GC', vmid=0)
|
116
|
+
self.dev.gmc.flush_tlb(ip='MM', vmid=0)
|
243
117
|
|
244
|
-
|
245
|
-
|
246
|
-
else: self.lock_fd = os.open(lock_name, os.O_RDWR | os.O_CREAT, 0o666)
|
118
|
+
class AMDev(PCIDevImplBase):
|
119
|
+
Version = 0xA0000006
|
247
120
|
|
248
|
-
|
249
|
-
|
121
|
+
def __init__(self, devfmt, vram:MMIOInterface, doorbell:MMIOInterface, mmio:MMIOInterface, dma_regions:list[tuple[int, MMIOInterface]]|None=None):
|
122
|
+
self.devfmt, self.vram, self.doorbell64, self.mmio, self.dma_regions = devfmt, vram, doorbell, mmio, dma_regions
|
123
|
+
self.lock_fd = System.flock_acquire(f"am_{self.devfmt}.lock")
|
250
124
|
|
251
125
|
self._run_discovery()
|
252
126
|
self._build_regs()
|
@@ -260,30 +134,19 @@ class AMDev:
|
|
260
134
|
# To enable this, AM uses a separate boot memory that is guaranteed not to be overwritten. This physical memory is utilized for
|
261
135
|
# all blocks that are initialized only during the initial AM boot.
|
262
136
|
# To determine if the GPU is in the third state, AM uses regSCRATCH_REG7 as a flag.
|
263
|
-
self.is_booting
|
264
|
-
self.
|
137
|
+
self.is_booting = True
|
138
|
+
self.init_sw(smi_dev=False)
|
265
139
|
|
266
|
-
|
267
|
-
self.
|
268
|
-
|
269
|
-
|
270
|
-
# Initialize IP blocks
|
271
|
-
self.soc21:AM_SOC21 = AM_SOC21(self)
|
272
|
-
self.gmc:AM_GMC = AM_GMC(self)
|
273
|
-
self.ih:AM_IH = AM_IH(self)
|
274
|
-
self.psp:AM_PSP = AM_PSP(self)
|
275
|
-
self.smu:AM_SMU = AM_SMU(self)
|
276
|
-
self.gfx:AM_GFX = AM_GFX(self)
|
277
|
-
self.sdma:AM_SDMA = AM_SDMA(self)
|
278
|
-
|
279
|
-
if self.partial_boot and (self.reg("regGCVM_CONTEXT0_CNTL").read() != 0):
|
280
|
-
if DEBUG >= 2: print(f"am {self.devfmt}: MEC is active. Issue a full reset.")
|
140
|
+
self.partial_boot = (self.reg("regSCRATCH_REG7").read() == AMDev.Version) and (getenv("AM_RESET", 0) != 1)
|
141
|
+
if self.partial_boot and (self.reg("regGCVM_CONTEXT0_CNTL").read() != 0 or self.reg(self.gmc.pf_status_reg("GC")).read() != 0):
|
142
|
+
if DEBUG >= 2: print(f"am {self.devfmt}: Malformed state. Issuing a full reset.")
|
281
143
|
self.partial_boot = False
|
282
144
|
|
145
|
+
# Init hw for IP blocks where it is needed
|
283
146
|
if not self.partial_boot:
|
284
147
|
if self.psp.is_sos_alive() and self.smu.is_smu_alive(): self.smu.mode1_reset()
|
285
|
-
for ip in [self.
|
286
|
-
ip.
|
148
|
+
for ip in [self.soc, self.gmc, self.ih, self.psp, self.smu]:
|
149
|
+
ip.init_hw()
|
287
150
|
if DEBUG >= 2: print(f"am {self.devfmt}: {ip.__class__.__name__} initialized")
|
288
151
|
|
289
152
|
# Booting done
|
@@ -291,25 +154,44 @@ class AMDev:
|
|
291
154
|
|
292
155
|
# Re-initialize main blocks
|
293
156
|
for ip in [self.gfx, self.sdma]:
|
294
|
-
ip.
|
157
|
+
ip.init_hw()
|
295
158
|
if DEBUG >= 2: print(f"am {self.devfmt}: {ip.__class__.__name__} initialized")
|
296
159
|
|
297
160
|
self.smu.set_clocks(level=-1) # last level, max perf.
|
298
|
-
self.gfx.set_clockgating_state()
|
299
|
-
self.reg("regSCRATCH_REG7").write(
|
161
|
+
for ip in [self.soc, self.gfx]: ip.set_clockgating_state()
|
162
|
+
self.reg("regSCRATCH_REG7").write(AMDev.Version)
|
300
163
|
if DEBUG >= 2: print(f"am {self.devfmt}: boot done")
|
301
164
|
|
165
|
+
def init_sw(self, smi_dev=False):
|
166
|
+
self.smi_dev = smi_dev # During boot only boot memory can be allocated. This flag is to validate this.
|
167
|
+
|
168
|
+
# Memory manager & firmware
|
169
|
+
self.mm = AMMemoryManager(self, self.vram_size, boot_size=(32 << 20), pt_t=AMPageTableEntry, va_shifts=[12, 21, 30, 39], va_bits=48,
|
170
|
+
first_lv=am.AMDGPU_VM_PDB2, va_base=AMMemoryManager.va_allocator.base,
|
171
|
+
palloc_ranges=[(1 << (i + 12), 0x1000) for i in range(9 * (3 - am.AMDGPU_VM_PDB2), -1, -1)])
|
172
|
+
self.fw = AMFirmware(self)
|
173
|
+
|
174
|
+
# Initialize IP blocks
|
175
|
+
self.soc:AM_SOC = AM_SOC(self)
|
176
|
+
self.gmc:AM_GMC = AM_GMC(self)
|
177
|
+
self.ih:AM_IH = AM_IH(self)
|
178
|
+
self.psp:AM_PSP = AM_PSP(self)
|
179
|
+
self.smu:AM_SMU = AM_SMU(self)
|
180
|
+
self.gfx:AM_GFX = AM_GFX(self)
|
181
|
+
self.sdma:AM_SDMA = AM_SDMA(self)
|
182
|
+
|
183
|
+
# Init sw for all IP blocks
|
184
|
+
for ip in [self.soc, self.gmc, self.ih, self.psp, self.smu, self.gfx, self.sdma]: ip.init_sw()
|
185
|
+
|
302
186
|
def fini(self):
|
303
187
|
if DEBUG >= 2: print(f"am {self.devfmt}: Finalizing")
|
304
|
-
for ip in [self.sdma, self.gfx]: ip.
|
188
|
+
for ip in [self.sdma, self.gfx]: ip.fini_hw()
|
305
189
|
self.smu.set_clocks(level=0)
|
306
190
|
self.ih.interrupt_handler()
|
191
|
+
os.close(self.lock_fd)
|
307
192
|
|
308
|
-
def paddr2cpu(self, paddr:int) -> int: return mv_address(self.vram) + paddr
|
309
193
|
def paddr2mc(self, paddr:int) -> int: return self.gmc.mc_base + paddr
|
310
194
|
|
311
|
-
def ip_base(self, ip:str, inst:int, seg:int) -> int: return self.regs_offset[am.__dict__[f"{ip}_HWIP"]][inst][seg]
|
312
|
-
|
313
195
|
def reg(self, reg:str) -> AMRegister: return self.__dict__[reg]
|
314
196
|
|
315
197
|
def rreg(self, reg:int) -> int:
|
@@ -335,62 +217,45 @@ class AMDev:
|
|
335
217
|
self.reg("regBIF_BX_PF0_RSMU_INDEX").write(reg)
|
336
218
|
self.reg("regBIF_BX_PF0_RSMU_DATA").write(val)
|
337
219
|
|
338
|
-
def wait_reg(self, reg:AMRegister, value:int, mask=0xffffffff, timeout=10000) -> int:
|
339
|
-
for _ in range(timeout):
|
340
|
-
if ((rval:=reg.read()) & mask) == value: return rval
|
341
|
-
time.sleep(0.001)
|
342
|
-
raise RuntimeError(f'wait_reg timeout reg=0x{reg.reg_off:X} mask=0x{mask:X} value=0x{value:X} last_val=0x{rval}')
|
343
|
-
|
344
220
|
def _run_discovery(self):
|
345
221
|
# NOTE: Fixed register to query memory size without known ip bases to find the discovery table.
|
346
222
|
# The table is located at the end of VRAM - 64KB and is 10KB in size.
|
347
223
|
mmRCC_CONFIG_MEMSIZE = 0xde3
|
348
224
|
self.vram_size = self.rreg(mmRCC_CONFIG_MEMSIZE) << 20
|
225
|
+
tmr_offset, tmr_size = self.vram_size - (64 << 10), (10 << 10)
|
349
226
|
|
350
|
-
bhdr = am.struct_binary_header.
|
351
|
-
ihdr = am.struct_ip_discovery_header.from_address(ctypes.addressof(bhdr) + bhdr.table_list[am.IP_DISCOVERY].offset)
|
352
|
-
assert
|
227
|
+
self.bhdr = am.struct_binary_header.from_buffer(bytearray(self.vram.view(tmr_offset, tmr_size)[:]))
|
228
|
+
ihdr = am.struct_ip_discovery_header.from_address(ctypes.addressof(self.bhdr) + self.bhdr.table_list[am.IP_DISCOVERY].offset)
|
229
|
+
assert self.bhdr.binary_signature == am.BINARY_SIGNATURE and ihdr.signature == am.DISCOVERY_TABLE_SIGNATURE, "discovery signatures mismatch"
|
353
230
|
|
354
231
|
# Mapping of HW IP to Discovery HW IP
|
355
232
|
hw_id_map = {am.__dict__[x]: int(y) for x,y in am.hw_id_map}
|
356
|
-
self.regs_offset:dict[int, dict[int,
|
357
|
-
self.
|
233
|
+
self.regs_offset:dict[int, dict[int, tuple]] = collections.defaultdict(dict)
|
234
|
+
self.ip_ver:dict[int, tuple[int, int, int]] = {}
|
358
235
|
|
359
236
|
for num_die in range(ihdr.num_dies):
|
360
|
-
dhdr = am.struct_die_header.from_address(ctypes.addressof(bhdr) + ihdr.die_info[num_die].die_offset)
|
237
|
+
dhdr = am.struct_die_header.from_address(ctypes.addressof(self.bhdr) + ihdr.die_info[num_die].die_offset)
|
361
238
|
|
362
|
-
ip_offset = ctypes.addressof(bhdr) + ctypes.sizeof(dhdr) + ihdr.die_info[num_die].die_offset
|
239
|
+
ip_offset = ctypes.addressof(self.bhdr) + ctypes.sizeof(dhdr) + ihdr.die_info[num_die].die_offset
|
363
240
|
for _ in range(dhdr.num_ips):
|
364
241
|
ip = am.struct_ip_v4.from_address(ip_offset)
|
365
242
|
ba = (ctypes.c_uint32 * ip.num_base_address).from_address(ip_offset + 8)
|
366
243
|
for hw_ip in range(1, am.MAX_HWIP):
|
367
244
|
if hw_ip in hw_id_map and hw_id_map[hw_ip] == ip.hw_id:
|
368
|
-
self.regs_offset[hw_ip][ip.instance_number] = list(ba)
|
369
|
-
self.
|
245
|
+
self.regs_offset[hw_ip][ip.instance_number] = tuple(list(ba))
|
246
|
+
self.ip_ver[hw_ip] = (ip.major, ip.minor, ip.revision)
|
370
247
|
|
371
248
|
ip_offset += 8 + (8 if ihdr.base_addr_64_bit else 4) * ip.num_base_address
|
372
249
|
|
373
|
-
gc_info = am.struct_gc_info_v1_0.from_address(gc_addr:=ctypes.addressof(bhdr) + bhdr.table_list[am.GC].offset)
|
250
|
+
gc_info = am.struct_gc_info_v1_0.from_address(gc_addr:=ctypes.addressof(self.bhdr) + self.bhdr.table_list[am.GC].offset)
|
374
251
|
self.gc_info = getattr(am, f"struct_gc_info_v{gc_info.header.version_major}_{gc_info.header.version_minor}").from_address(gc_addr)
|
375
252
|
|
376
|
-
def _ip_module(self, prefix:str, hwip):
|
377
|
-
version = [self.ip_versions[hwip]//10000, (self.ip_versions[hwip]//100)%100, self.ip_versions[hwip]%100]
|
378
|
-
for ver in [version, version[:2]+[0], version[:1]+[0, 0]]:
|
379
|
-
try: return __import__(f"tinygrad.runtime.autogen.am.{prefix}_{ver[0]}_{ver[1]}_{ver[2]}", fromlist=[f"{prefix}_{ver[0]}_{ver[1]}_{ver[2]}"])
|
380
|
-
except ImportError: pass
|
381
|
-
raise ImportError(f"am {self.devfmt}: failed to load {prefix} module with version {version}")
|
253
|
+
def _ip_module(self, prefix:str, hwip, prever_prefix:str=""): return import_module(prefix, self.ip_ver[hwip], prever_prefix)
|
382
254
|
|
383
255
|
def _build_regs(self):
|
384
|
-
mods = [("
|
385
|
-
("
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
for k, val in module.__dict__.items():
|
391
|
-
if k.endswith("_MASK") and ((rname:=k.split("__")[0]) in reg_names):
|
392
|
-
reg_fields[rname][k[2+len(rname):-5].lower()] = (val, module.__dict__.get(f"{k[:-5]}__SHIFT", val.bit_length() - 1))
|
393
|
-
|
394
|
-
for k, regval in module.__dict__.items():
|
395
|
-
if k.startswith(rpref) and not k.endswith("_BASE_IDX") and (base_idx:=getattr(module, f"{k}_BASE_IDX", None)) is not None:
|
396
|
-
setattr(self, k, AMRegister(self, self.ip_base(base, 0, base_idx) + regval, reg_fields.get(k[len(rpref):], {})))
|
256
|
+
mods = [("mp", am.MP0_HWIP), ("hdp", am.HDP_HWIP), ("gc", am.GC_HWIP), ("mmhub", am.MMHUB_HWIP), ("osssys", am.OSSSYS_HWIP),
|
257
|
+
("nbio" if self.ip_ver[am.GC_HWIP] < (12,0,0) else "nbif", am.NBIO_HWIP)]
|
258
|
+
|
259
|
+
for prefix, hwip in mods:
|
260
|
+
self.__dict__.update(import_asic_regs(prefix, self.ip_ver[hwip], cls=functools.partial(AMRegister, adev=self, bases=self.regs_offset[hwip])))
|
261
|
+
self.__dict__.update(import_asic_regs('mp', (11, 0), cls=functools.partial(AMRegister, adev=self, bases=self.regs_offset[am.MP1_HWIP])))
|