tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +1 -1
- tinygrad/apps/llm.py +206 -0
- tinygrad/codegen/__init__.py +116 -0
- tinygrad/codegen/devectorizer.py +315 -172
- tinygrad/codegen/expander.py +8 -16
- tinygrad/codegen/gpudims.py +89 -0
- tinygrad/codegen/linearize.py +205 -203
- tinygrad/codegen/lowerer.py +92 -139
- tinygrad/codegen/opt/__init__.py +38 -0
- tinygrad/codegen/opt/heuristic.py +125 -0
- tinygrad/codegen/opt/kernel.py +510 -0
- tinygrad/{engine → codegen/opt}/search.py +51 -35
- tinygrad/codegen/opt/swizzler.py +134 -0
- tinygrad/codegen/opt/tc.py +127 -0
- tinygrad/codegen/quantize.py +67 -0
- tinygrad/device.py +122 -132
- tinygrad/dtype.py +152 -35
- tinygrad/engine/jit.py +81 -54
- tinygrad/engine/memory.py +46 -27
- tinygrad/engine/realize.py +82 -41
- tinygrad/engine/schedule.py +70 -445
- tinygrad/frontend/__init__.py +0 -0
- tinygrad/frontend/onnx.py +1253 -0
- tinygrad/frontend/torch.py +5 -0
- tinygrad/gradient.py +19 -27
- tinygrad/helpers.py +95 -47
- tinygrad/nn/__init__.py +7 -8
- tinygrad/nn/optim.py +72 -41
- tinygrad/nn/state.py +37 -23
- tinygrad/renderer/__init__.py +40 -60
- tinygrad/renderer/cstyle.py +143 -128
- tinygrad/renderer/llvmir.py +113 -62
- tinygrad/renderer/ptx.py +50 -32
- tinygrad/renderer/wgsl.py +27 -23
- tinygrad/runtime/autogen/am/am.py +5861 -0
- tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
- tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
- tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
- tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
- tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
- tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
- tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
- tinygrad/runtime/autogen/comgr.py +35 -9
- tinygrad/runtime/autogen/comgr_3.py +906 -0
- tinygrad/runtime/autogen/cuda.py +2419 -494
- tinygrad/runtime/autogen/hsa.py +57 -16
- tinygrad/runtime/autogen/ib.py +7171 -0
- tinygrad/runtime/autogen/io_uring.py +917 -118
- tinygrad/runtime/autogen/kfd.py +748 -26
- tinygrad/runtime/autogen/libc.py +613 -218
- tinygrad/runtime/autogen/libusb.py +1643 -0
- tinygrad/runtime/autogen/nv/nv.py +8602 -0
- tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
- tinygrad/runtime/autogen/opencl.py +2 -4
- tinygrad/runtime/autogen/sqtt.py +1789 -0
- tinygrad/runtime/autogen/vfio.py +3 -3
- tinygrad/runtime/autogen/webgpu.py +273 -264
- tinygrad/runtime/graph/cuda.py +3 -3
- tinygrad/runtime/graph/hcq.py +68 -29
- tinygrad/runtime/graph/metal.py +29 -13
- tinygrad/runtime/graph/remote.py +114 -0
- tinygrad/runtime/ops_amd.py +537 -320
- tinygrad/runtime/ops_cpu.py +108 -7
- tinygrad/runtime/ops_cuda.py +12 -14
- tinygrad/runtime/ops_disk.py +13 -10
- tinygrad/runtime/ops_dsp.py +47 -40
- tinygrad/runtime/ops_gpu.py +13 -11
- tinygrad/runtime/ops_hip.py +6 -9
- tinygrad/runtime/ops_llvm.py +35 -15
- tinygrad/runtime/ops_metal.py +29 -19
- tinygrad/runtime/ops_npy.py +5 -3
- tinygrad/runtime/ops_null.py +28 -0
- tinygrad/runtime/ops_nv.py +306 -234
- tinygrad/runtime/ops_python.py +62 -52
- tinygrad/runtime/ops_qcom.py +28 -39
- tinygrad/runtime/ops_remote.py +482 -0
- tinygrad/runtime/ops_webgpu.py +28 -28
- tinygrad/runtime/support/am/amdev.py +114 -249
- tinygrad/runtime/support/am/ip.py +211 -172
- tinygrad/runtime/support/amd.py +138 -0
- tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
- tinygrad/runtime/support/compiler_cuda.py +8 -11
- tinygrad/runtime/support/elf.py +2 -1
- tinygrad/runtime/support/hcq.py +184 -97
- tinygrad/runtime/support/ib.py +172 -0
- tinygrad/runtime/support/llvm.py +3 -4
- tinygrad/runtime/support/memory.py +251 -0
- tinygrad/runtime/support/nv/__init__.py +0 -0
- tinygrad/runtime/support/nv/ip.py +581 -0
- tinygrad/runtime/support/nv/nvdev.py +183 -0
- tinygrad/runtime/support/system.py +170 -0
- tinygrad/runtime/support/usb.py +268 -0
- tinygrad/runtime/support/webgpu.py +18 -0
- tinygrad/schedule/__init__.py +0 -0
- tinygrad/schedule/grouper.py +119 -0
- tinygrad/schedule/kernelize.py +368 -0
- tinygrad/schedule/multi.py +231 -0
- tinygrad/shape/shapetracker.py +40 -46
- tinygrad/shape/view.py +88 -52
- tinygrad/tensor.py +968 -542
- tinygrad/uop/__init__.py +117 -0
- tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
- tinygrad/uop/mathtraits.py +169 -0
- tinygrad/uop/ops.py +1021 -0
- tinygrad/uop/spec.py +228 -0
- tinygrad/{codegen → uop}/symbolic.py +239 -216
- tinygrad/uop/upat.py +163 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
- tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
- tinygrad/viz/index.html +203 -403
- tinygrad/viz/js/index.js +718 -0
- tinygrad/viz/js/worker.js +29 -0
- tinygrad/viz/serve.py +224 -102
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
- tinygrad-0.11.0.dist-info/RECORD +141 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/kernel.py +0 -693
- tinygrad/engine/multi.py +0 -161
- tinygrad/ops.py +0 -1003
- tinygrad/runtime/ops_cloud.py +0 -220
- tinygrad/runtime/support/allocator.py +0 -94
- tinygrad/spec.py +0 -155
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
- tinygrad/viz/perfetto.html +0 -178
- tinygrad-0.10.2.dist-info/RECORD +0 -99
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,183 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
import ctypes, time, functools, re, gzip, struct
|
3
|
+
from tinygrad.helpers import getenv, DEBUG, fetch, getbits, to_mv
|
4
|
+
from tinygrad.runtime.support.hcq import MMIOInterface
|
5
|
+
from tinygrad.runtime.support.memory import TLSFAllocator, MemoryManager
|
6
|
+
from tinygrad.runtime.support.nv.ip import NV_FLCN, NV_FLCN_COT, NV_GSP
|
7
|
+
from tinygrad.runtime.support.system import System, PCIDevImplBase
|
8
|
+
|
9
|
+
NV_DEBUG = getenv("NV_DEBUG", 0)
|
10
|
+
|
11
|
+
class NVReg:
  """Accessor for one MMIO register with named bitfields.

  `off` is either a plain offset or, for register arrays, a callable mapping an
  index to an offset (resolved via __getitem__). `fields` maps a field name to
  its (start_bit, end_bit) span, inclusive.
  """
  def __init__(self, nvdev, base, off, fields=None):
    self.nvdev, self.base, self.off, self.fields = nvdev, base, off, fields

  def __getitem__(self, idx:int):
    # Resolve an indexed register: off must be callable here.
    return NVReg(self.nvdev, self.base, self.off(idx), fields=self.fields)

  def add_field(self, name:str, start:int, end:int):
    self.fields[name] = (start, end)

  def with_base(self, base:int):
    # Same register definition relocated to another base address.
    return NVReg(self.nvdev, base + self.base, self.off, self.fields)

  def read(self):
    return self.nvdev.rreg(self.base + self.off)

  def read_bitfields(self) -> dict[str, int]:
    return self.decode(self.read())

  def write(self, _ini_val:int=0, **kwargs):
    self.nvdev.wreg(self.base + self.off, _ini_val | self.encode(**kwargs))

  def update(self, **kwargs):
    # Read-modify-write: clear only the named fields, then or-in their new values.
    self.write(self.read() & ~self.mask(*kwargs.keys()), **kwargs)

  def mask(self, *names):
    """Bitmask covering every named field."""
    acc = 0
    for nm in names:
      lo, hi = self.fields[nm]
      acc |= ((1 << (hi - lo + 1)) - 1) << lo
    return acc

  def encode(self, **kwargs) -> int:
    """Pack field=value pairs into a raw register value (values are not range-checked)."""
    acc = 0
    for name, value in kwargs.items():
      acc |= value << self.fields[name][0]
    return acc

  def decode(self, val: int) -> dict:
    """Unpack a raw register value into a {field_name: value} dict."""
    return {name: getbits(val, start, end) for name, (start, end) in self.fields.items()}
|
31
|
+
|
32
|
+
class NVPageTableEntry:
  """One 4KB page-table page of the GPU MMU radix tree, backed directly by VRAM.

  `entries` is a 64-bit-word view of the page. At the second-to-last level
  ("dual" PDE, separate big/small page pointers) each logical entry is two
  words (128 bits); everywhere else it is one word.
  """
  # lv: this page's depth in the tree (0 = root); layout comes from nvdev.mm.
  def __init__(self, nvdev, paddr, lv): self.nvdev, self.paddr, self.lv, self.entries = nvdev, paddr, lv, nvdev.vram.view(paddr, 0x1000, fmt='Q')

  def _is_dual_pde(self) -> bool: return self.lv == self.nvdev.mm.level_cnt - 2

  def set_entry(self, entry_id:int, paddr:int, table=False, uncached=False, system=False, snooped=False, frag=0, valid=True):
    """Encode and store entry `entry_id` pointing at `paddr` (a leaf PTE unless table=True).

    NOTE(review): `snooped` and `frag` are accepted but not used in the encoding below — confirm intended.
    """
    if not table:
      # Leaf PTE. aperture=2 selects system memory, 0 selects VRAM. mmu v3 (hopper+)
      # expresses cacheability via 'pcf'; v2 via the 'vol' (volatile) bit.
      # kind=6: memory kind field — presumably the generic/pitch kind, verify against dev_mmu.h.
      x = self.nvdev.pte_t.encode(valid=valid, address_sys=paddr >> 12, aperture=2 if system else 0, kind=6,
                                  **({'pcf': int(uncached)} if self.nvdev.mmu_ver == 3 else {'vol': uncached}))
    else:
      # Non-leaf PDE. Dual PDEs use the "_small" variants of the fields; v2 address
      # fields carry a "_sys" suffix. Validity is expressed through a nonzero aperture.
      pde = self.nvdev.dual_pde_t if self._is_dual_pde() else self.nvdev.pde_t
      small, sys = ("_small" if self._is_dual_pde() else ""), "" if self.nvdev.mmu_ver == 3 else "_sys"
      x = pde.encode(is_pte=False, **{f'aperture{small}': 1 if valid else 0, f'address{small}{sys}': paddr >> 12},
                     **({f'pcf{small}': 0b10} if self.nvdev.mmu_ver == 3 else {'no_ats': 1}))

    # Dual PDEs span two consecutive 64-bit words (low word first).
    if self._is_dual_pde(): self.entries[2*entry_id], self.entries[2*entry_id+1] = x & 0xffffffffffffffff, x >> 64
    else: self.entries[entry_id] = x

  def entry(self, entry_id:int) -> int:
    # Raw entry value: 128 bits at the dual-PDE level, 64 bits otherwise.
    return (self.entries[2*entry_id+1]<<64) | self.entries[2*entry_id] if self._is_dual_pde() else self.entries[entry_id]

  def read_fields(self, entry_id:int) -> dict:
    # Decode with the PTE layout if this entry is a (huge-)page leaf, else the PDE layout.
    if self.is_huge_page(entry_id): return self.nvdev.pte_t.decode(self.entry(entry_id))
    return (self.nvdev.dual_pde_t if self._is_dual_pde() else self.nvdev.pde_t).decode(self.entry(entry_id))

  # Bit 0 ("is_pte"-style flag) marks a leaf at non-final levels; the last level is always PTEs.
  def is_huge_page(self, entry_id) -> bool: return (self.entry(entry_id) & 1 == 1) if self.lv < self.nvdev.mm.level_cnt - 1 else True
  # Huge pages are only representable in the last three levels and only at aligned addresses.
  def supports_huge_page(self, paddr:int): return self.lv >= self.nvdev.mm.level_cnt - 3 and paddr % self.nvdev.mm.pte_covers[self.lv] == 0

  def valid(self, entry_id):
    # PTEs carry an explicit valid bit; PDEs are valid iff their aperture is nonzero.
    if self.is_huge_page(entry_id): return self.read_fields(entry_id)['valid']
    return self.read_fields(entry_id)['aperture_small' if self._is_dual_pde() else 'aperture'] != 0

  def address(self, entry_id:int) -> int:
    """Physical address stored in entry `entry_id` (field value is in units of 4KB pages)."""
    small, sys = ("_small" if self._is_dual_pde() else ""), "_sys" if self.nvdev.mmu_ver == 2 or self.lv == self.nvdev.mm.level_cnt - 1 else ""
    return self.read_fields(entry_id)[f'address{small}{sys}'] << 12
|
67
|
+
|
68
|
+
class NVMemoryManager(MemoryManager):
  """GPU MMU manager for NV devices (page tables handled via NVPageTableEntry)."""
  # One VA allocator shared by all devices: 44-bit space based at 64GB, so every
  # device's allocations land at distinct virtual addresses (presumably to make
  # peer mappings trivial — confirm).
  va_allocator = TLSFAllocator((1 << 44), base=0x1000000000) # global for all devices.

  # After (un)mapping a range, invalidate the MMU TLB. Bit meanings come from the
  # NV_VIRTUAL_FUNCTION_PRIV_MMU_INVALIDATE register fields — verify against the parsed headers.
  def on_range_mapped(self): self.dev.NV_VIRTUAL_FUNCTION_PRIV_MMU_INVALIDATE.write((1 << 0) | (1 << 1) | (1 << 6) | (1 << 31))
|
72
|
+
|
73
|
+
class NVDev(PCIDevImplBase):
  """Userspace driver core for one NVIDIA PCI GPU.

  Owns the MMIO and VRAM windows, the GPU memory manager, and the boot IPs
  (falcon/GSP). Register definitions are not hardcoded: they are parsed at
  runtime from the public open-gpu-kernel-modules headers (see include()).
  """
  def __init__(self, devfmt:str, mmio:MMIOInterface, vram:MMIOInterface, venid:int, subvenid:int, rev:int, bars:dict):
    self.devfmt, self.mmio, self.vram, self.venid, self.subvenid, self.rev, self.bars = devfmt, mmio, vram, venid, subvenid, rev, bars
    # One process per device: the flock is held for the lifetime of this object.
    self.lock_fd = System.flock_acquire(f"nv_{self.devfmt}.lock")

    self.smi_dev, self.is_booting = False, True
    self._early_init()

    # Page-table geometry (UVM depth / HW level / VA bits):
    # 0  PDE4                               56:56 (hopper+)
    # 1  PDE3                               55:47
    # 2  PDE2                               46:38
    # 3  PDE1 (or 512M PTE)                 37:29
    # 4  PDE0 (dual 64k/4k PDE, or 2M PTE)  28:21
    # 5  PTE_64K / PTE_4K                   20:16 / 20:12
    bits, shifts = (56, [12, 21, 29, 38, 47, 56]) if self.mmu_ver == 3 else (48, [12, 21, 29, 38, 47])
    self.mm = NVMemoryManager(self, self.vram_size, boot_size=(2 << 20), pt_t=NVPageTableEntry, va_bits=bits, va_shifts=shifts, va_base=0,
                              palloc_ranges=[(x, x) for x in [512 << 20, 2 << 20, 4 << 10]])
    # Pre-hopper parts boot through the classic falcon path; hopper+ (fmc_boot) uses chain-of-trust.
    self.flcn:NV_FLCN|NV_FLCN_COT = NV_FLCN_COT(self) if self.fmc_boot else NV_FLCN(self)
    self.gsp:NV_GSP = NV_GSP(self)

    # Early boot is over here; the GSP client below is loaded from a clean state.
    self.is_booting = False

    for ip in [self.flcn, self.gsp]: ip.init_sw()
    for ip in [self.flcn, self.gsp]: ip.init_hw()

  def fini(self):
    # Teardown in reverse of bring-up order.
    for ip in [self.gsp, self.flcn]: ip.fini_hw()

  # Registers parsed by include() are stored as instance attributes; look one up by name.
  def reg(self, reg:str) -> NVReg: return self.__dict__[reg]
  def wreg(self, addr:int, value:int):
    # mmio is a 32-bit word view, hence addr // 4.
    self.mmio[addr // 4] = value
    if NV_DEBUG >= 4: print(f"wreg: {hex(addr)} = {hex(value)}")
  def rreg(self, addr:int) -> int: return self.mmio[addr // 4]

  def _early_init(self):
    """Identify the chip and set up just enough registers to boot (runs before the MM exists)."""
    self.reg_names:set[str] = set()
    self.reg_offsets:dict[str, tuple[int, int]] = {}

    self.include("src/common/inc/swref/published/nv_ref.h")
    self.chip_id = self.reg("NV_PMC_BOOT_0").read()
    self.chip_details = self.reg("NV_PMC_BOOT_42").read_bitfields()
    # e.g. architecture 0x17 + implementation 0x02 -> "GA102".
    self.chip_name = {0x17: "GA1", 0x19: "AD1", 0x1b: "GB2"}[self.chip_details['architecture']] + f"{self.chip_details['implementation']:02d}"
    # 0x1a+ (hopper and newer): MMU v3 and FMC (chain-of-trust) boot.
    self.mmu_ver, self.fmc_boot = (3, True) if self.chip_details['architecture'] >= 0x1a else (2, False)

    self.include("src/common/inc/swref/published/turing/tu102/dev_fb.h")
    # A live write-protected region 2 means a previous driver instance left the GPU
    # booted; a full PCI reset returns it to a clean state.
    if self.reg("NV_PFB_PRI_MMU_WPR2_ADDR_HI").read() != 0:
      if DEBUG >= 2: print(f"nv {self.devfmt}: WPR2 is up. Issuing a full reset.")
      System.pci_reset(self.devfmt)
      time.sleep(0.5)

    self.include("src/common/inc/swref/published/turing/tu102/dev_vm.h")
    self.include("src/common/inc/swref/published/ampere/ga102/dev_gc6_island.h")
    self.include("src/common/inc/swref/published/ampere/ga102/dev_gc6_island_addendum.h")

    # MMU Init: create the PTE/PDE field containers first, then let include() fill
    # their bitfields from the version-appropriate dev_mmu.h.
    self.reg_names.update(mmu_pd_names:=[f'NV_MMU_VER{self.mmu_ver}_PTE', f'NV_MMU_VER{self.mmu_ver}_PDE', f'NV_MMU_VER{self.mmu_ver}_DUAL_PDE'])
    for name in mmu_pd_names: self.__dict__[name] = NVReg(self, None, None, fields={})
    self.include(f"kernel-open/nvidia-uvm/hwref/{'hopper/gh100' if self.mmu_ver == 3 else 'turing/tu102'}/dev_mmu.h")
    self.pte_t, self.pde_t, self.dual_pde_t = tuple([self.__dict__[name] for name in mmu_pd_names])

    # Scratch register holds VRAM size in MB.
    self.vram_size = self.reg("NV_PGC6_AON_SECURE_SCRATCH_GROUP_42").read() << 20

  def _alloc_boot_struct(self, struct:ctypes.Structure) -> tuple[ctypes.Structure, int]:
    """Copy `struct` into fresh contiguous sysmem; return (live view, physical address)."""
    va, paddrs = System.alloc_sysmem(sz:=ctypes.sizeof(type(struct)), contiguous=True)
    to_mv(va, sz)[:] = bytes(struct)
    return type(struct).from_address(va), paddrs[0]

  def _download(self, file:str) -> str:
    # Headers are pinned to a fixed open-gpu-kernel-modules commit and cached by fetch().
    url = f"https://raw.githubusercontent.com/NVIDIA/open-gpu-kernel-modules/8ec351aeb96a93a4bb69ccc12a542bf8a8df2b6f/{file}"
    return fetch(url, subdir="defines").read_text()

  def extract_fw(self, file:str, dname:str) -> bytes:
    """Extract a firmware blob from an open-gpu-kernel-modules g_bindata_*.c array.

    The array is hex bytes; if its COMPRESSION marker says YES, the payload is a
    raw deflate stream, so a minimal gzip header is prepended before decompression.
    """
    tname = file.replace("kgsp", "kgspGet")
    text = self._download(f"src/nvidia/generated/g_bindata_{tname}_{self.chip_name}.c")
    # info: the "COMPRESSION: ..." marker preceding the array; sl: array body after "name[] = {".
    info, sl = text[text[:text.index(dnm:=f'{file}_{self.chip_name}_{dname}')].rindex("COMPRESSION:"):][:16], text[text.index(dnm) + len(dnm) + 7:]
    image = bytes.fromhex(sl[:sl.find("};")].strip().replace("0x", "").replace(",", "").replace(" ", "").replace("\n", ""))
    return gzip.decompress(struct.pack("<4BL2B", 0x1f, 0x8b, 8, 0, 0, 0, 3) + image) if "COMPRESSION: YES" in info else image

  def include(self, file:str):
    """Parse #defines from one header into NVReg objects / plain constants on self.

    Three kinds of lines are recognized: `hi:lo` bitfield defines, parameterized
    register-array defines (become lambdas), and plain numeric defines. NOTE: the
    values are eval()'d — acceptable only because the source is a pinned upstream repo.
    """
    # Base offsets for each register block prefix.
    regs_off = {'NV_PFALCON_FALCON': 0x0, 'NV_PGSP_FALCON': 0x0, 'NV_PSEC_FALCON': 0x0, 'NV_PRISCV_RISCV': 0x1000, 'NV_PGC6_AON': 0x0, 'NV_PFSP': 0x0,
      'NV_PGC6_BSI': 0x0, 'NV_PFALCON_FBIF': 0x600, 'NV_PFALCON2_FALCON': 0x1000, 'NV_PBUS': 0x0, 'NV_PFB': 0x0, 'NV_PMC': 0x0, 'NV_PGSP_QUEUE': 0x0,
      'NV_VIRTUAL_FUNCTION':0xb80000}

    for raw in self._download(file).splitlines():
      if not raw.startswith("#define "): continue

      if m:=re.match(r'#define\s+(\w+)\s+([0-9\+\-\*\(\)]+):([0-9\+\-\*\(\)]+)', raw): # bitfields
        name, hi, lo = m.groups()

        # If the name extends a known register, attach the field to it; otherwise
        # remember the span so a later register define can pick it up.
        reg = next((r for r in self.reg_names if name.startswith(r+"_")), None)
        if reg is not None: self.__dict__[reg].add_field(name[len(reg)+1:].lower(), eval(lo), eval(hi))
        else: self.reg_offsets[name] = (eval(lo), eval(hi))
        continue

      if m:=re.match(r'#define\s+(\w+)\s*\(\s*(\w+)\s*\)\s*(.+)', raw): # reg set
        # Strip trailing backslash and inline comments before building the lambda.
        fn = m.groups()[2].strip().rstrip('\\').split('/*')[0].rstrip()
        name, value = m.groups()[0], eval(f"lambda {m.groups()[1]}: {fn}")
      elif m:=re.match(r'#define\s+(\w+)\s+([0-9A-Fa-fx]+)(?![^\n]*:)', raw): name, value = m.groups()[0], int(m.groups()[1], 0) # reg value
      else: continue

      reg_pref = next((prefix for prefix in regs_off.keys() if name.startswith(prefix)), None)
      # A name that extends an existing register is a field/value of it, not a new register.
      not_already_reg = not any(name.startswith(r+"_") for r in self.reg_names)

      if reg_pref is not None and not_already_reg:
        # Adopt any previously-seen field spans that belong to this register.
        fields = {k[len(name)+1:]: v for k, v in self.reg_offsets.items() if k.startswith(name+'_')}
        self.__dict__[name] = NVReg(self, regs_off[reg_pref], value, fields=fields)
        self.reg_names.add(name)
      else: self.__dict__[name] = value
|
@@ -0,0 +1,170 @@
|
|
1
|
+
import os, mmap, array, functools, ctypes, select, contextlib, dataclasses, sys
|
2
|
+
from typing import cast, ClassVar
|
3
|
+
from tinygrad.helpers import round_up, to_mv, getenv, OSX, temp
|
4
|
+
from tinygrad.runtime.autogen import libc, vfio
|
5
|
+
from tinygrad.runtime.support.hcq import FileIOInterface, MMIOInterface, HCQBuffer
|
6
|
+
from tinygrad.runtime.support.memory import MemoryManager, VirtMapping
|
7
|
+
|
8
|
+
MAP_FIXED, MAP_LOCKED, MAP_POPULATE, MAP_NORESERVE = 0x10, 0 if OSX else 0x2000, getattr(mmap, "MAP_POPULATE", 0 if OSX else 0x008000), 0x400
|
9
|
+
|
10
|
+
class _System:
  """OS-level helpers for userspace drivers: sysmem allocation, physical-address
  lookup, PCI bus scanning/reset, VFIO setup and device lock files.

  Several helpers shell out via `sudo sh -c` to poke sysfs/procfs knobs, so they
  are Linux-only and may prompt for credentials.
  """
  def reserve_hugepages(self, cnt): os.system(f"sudo sh -c 'echo {cnt} > /proc/sys/vm/nr_hugepages'")

  # Full seq-cst fence via libatomic's atomic_thread_fence(__ATOMIC_SEQ_CST=5); no-op if unavailable.
  def memory_barrier(self): lib.atomic_thread_fence(__ATOMIC_SEQ_CST:=5) if (lib:=self.atomic_lib()) is not None else None

  def lock_memory(self, addr:int, size:int):
    # mlock pins the pages so their physical addresses stay valid for DMA.
    if libc.mlock(ctypes.c_void_p(addr), size): raise RuntimeError(f"Failed to lock memory at {addr:#x} with size {size:#x}")

  def system_paddrs(self, vaddr:int, size:int) -> list[int]:
    """Physical address of each page backing [vaddr, vaddr+size) via /proc/self/pagemap.

    Each 8-byte pagemap entry keeps the page frame number in its low 55 bits.
    """
    self.pagemap().seek(vaddr // mmap.PAGESIZE * 8)
    return [(x & ((1<<55) - 1)) * mmap.PAGESIZE for x in array.array('Q', self.pagemap().read(size//mmap.PAGESIZE*8, binary=True))]

  def alloc_sysmem(self, size:int, vaddr:int=0, contiguous:bool=False, data:bytes|None=None) -> tuple[int, list[int]]:
    """Allocate locked anonymous memory; returns (virtual address, per-page physical addrs).

    contiguous=True uses a hugetlb page (hence the 2MB cap) so the backing is
    physically contiguous; optional `data` is copied in.
    """
    assert not contiguous or size <= (2 << 20), "Contiguous allocation is only supported for sizes up to 2MB"
    flags = (libc.MAP_HUGETLB if contiguous and (size:=round_up(size, mmap.PAGESIZE)) > 0x1000 else 0) | (MAP_FIXED if vaddr else 0)
    # MAP_POPULATE|MAP_LOCKED: fault everything in now so paddrs are stable.
    va = FileIOInterface.anon_mmap(vaddr, size, mmap.PROT_READ|mmap.PROT_WRITE, mmap.MAP_SHARED|mmap.MAP_ANONYMOUS|MAP_POPULATE|MAP_LOCKED|flags, 0)

    if data is not None: to_mv(va, len(data))[:] = data
    return va, self.system_paddrs(va, size)

  def pci_reset(self, gpu): os.system(f"sudo sh -c 'echo 1 > /sys/bus/pci/devices/{gpu}/reset'")
  def pci_scan_bus(self, target_vendor:int, target_devices:list[int]) -> list[str]:
    """Sorted sysfs bus addresses of all PCI devices matching vendor + any of the device ids."""
    result = []
    for pcibus in FileIOInterface("/sys/bus/pci/devices").listdir():
      vendor = int(FileIOInterface(f"/sys/bus/pci/devices/{pcibus}/vendor").read(), 16)
      device = int(FileIOInterface(f"/sys/bus/pci/devices/{pcibus}/device").read(), 16)
      if vendor == target_vendor and device in target_devices: result.append(pcibus)
    return sorted(result)

  @functools.cache
  def atomic_lib(self): return ctypes.CDLL(ctypes.util.find_library('atomic')) if sys.platform == "linux" else None

  @functools.cache
  def pagemap(self) -> FileIOInterface:
    # compact_unevictable_allowed must be 0 or the kernel may migrate mlocked pages,
    # invalidating the physical addresses we hand to devices.
    if FileIOInterface(reloc_sysfs:="/proc/sys/vm/compact_unevictable_allowed", os.O_RDONLY).read()[0] != "0":
      os.system(cmd:=f"sudo sh -c 'echo 0 > {reloc_sysfs}'")
      assert FileIOInterface(reloc_sysfs, os.O_RDONLY).read()[0] == "0", f"Failed to disable migration of locked pages. Please run {cmd} manually."
    return FileIOInterface("/proc/self/pagemap", os.O_RDONLY)

  @functools.cache
  def vfio(self) -> FileIOInterface|None:
    """Open the VFIO container in (unsafe) no-IOMMU mode; None if VFIO is unusable here."""
    try:
      if not FileIOInterface.exists("/sys/module/vfio"): os.system("sudo modprobe vfio-pci disable_idle_d3=1")

      FileIOInterface("/sys/module/vfio/parameters/enable_unsafe_noiommu_mode", os.O_RDWR).write("1")
      vfio_fd = FileIOInterface("/dev/vfio/vfio", os.O_RDWR)
      vfio.VFIO_CHECK_EXTENSION(vfio_fd, vfio.VFIO_NOIOMMU_IOMMU)

      return vfio_fd
    except OSError: return None

  def flock_acquire(self, name:str) -> int:
    """Take an exclusive, non-blocking flock on a temp lockfile; returns the fd.

    Raises RuntimeError immediately if another process already holds it.
    """
    import fcntl # imported lazily: fcntl does not exist on windows

    os.umask(0) # Set umask to 0 so the lockfile is created world-writable (0666)

    # Avoid O_CREAT because we don't want to re-create/replace an existing file (triggers extra perms checks) when opening as non-owner.
    if os.path.exists(lock_name:=temp(name)): self.lock_fd = os.open(lock_name, os.O_RDWR)
    else: self.lock_fd = os.open(lock_name, os.O_RDWR | os.O_CREAT | os.O_CLOEXEC, 0o666)

    try: fcntl.flock(self.lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
    except OSError: raise RuntimeError(f"Failed to take lock file {name}. It's already in use.")

    return self.lock_fd
|
74
|
+
|
75
|
+
# Module-level singleton: _System holds no per-call state beyond its caches, so one shared instance suffices.
System = _System()
|
76
|
+
|
77
|
+
class PCIDevice:
  """Direct sysfs/VFIO handle to one PCI device: unbinds the kernel driver,
  optionally resizes BARs, and exposes config-space and BAR access."""
  def __init__(self, pcibus:str, bars:list[int], resize_bars:list[int]|None=None):
    self.pcibus, self.irq_poller = pcibus, None

    # Detach whatever kernel driver currently owns the device.
    if FileIOInterface.exists(f"/sys/bus/pci/devices/{self.pcibus}/driver"):
      FileIOInterface(f"/sys/bus/pci/devices/{self.pcibus}/driver/unbind", os.O_WRONLY).write(self.pcibus)

    # Grow each requested BAR to its maximum supported size (resizable BAR).
    # The sysfs file reports a bitmask of supported sizes; write log2 of the largest.
    for i in resize_bars or []:
      if FileIOInterface.exists(rpath:=f"/sys/bus/pci/devices/{self.pcibus}/resource{i}_resize"):
        try: FileIOInterface(rpath, os.O_RDWR).write(str(int(FileIOInterface(rpath, os.O_RDONLY).read(), 16).bit_length() - 1))
        except OSError as e: raise RuntimeError(f"Cannot resize BAR {i}: {e}. Ensure the resizable BAR option is enabled on your system.") from e

    if getenv("VFIO", 0) and (vfio_fd:=System.vfio()) is not None:
      # Bind the device to vfio-pci and join its (no-IOMMU) group.
      FileIOInterface(f"/sys/bus/pci/devices/{self.pcibus}/driver_override", os.O_WRONLY).write("vfio-pci")
      FileIOInterface("/sys/bus/pci/drivers_probe", os.O_WRONLY).write(self.pcibus)
      iommu_group = FileIOInterface.readlink(f"/sys/bus/pci/devices/{self.pcibus}/iommu_group").split('/')[-1]

      self.vfio_group = FileIOInterface(f"/dev/vfio/noiommu-{iommu_group}", os.O_RDWR)
      vfio.VFIO_GROUP_SET_CONTAINER(self.vfio_group, ctypes.c_int(vfio_fd.fd))

      with contextlib.suppress(OSError): vfio.VFIO_SET_IOMMU(vfio_fd, vfio.VFIO_NOIOMMU_IOMMU) # set iommu works only once for the fd.
      self.vfio_dev = FileIOInterface(fd=vfio.VFIO_GROUP_GET_DEVICE_FD(self.vfio_group, ctypes.create_string_buffer(self.pcibus.encode())))

      # Route MSI interrupts to an eventfd we can poll on.
      self.irq_fd = FileIOInterface.eventfd(0, 0)
      self.irq_poller = select.poll()
      self.irq_poller.register(self.irq_fd.fd, select.POLLIN)

      irqs = vfio.struct_vfio_irq_set(index=vfio.VFIO_PCI_MSI_IRQ_INDEX, flags=vfio.VFIO_IRQ_SET_DATA_EVENTFD|vfio.VFIO_IRQ_SET_ACTION_TRIGGER,
        argsz=ctypes.sizeof(vfio.struct_vfio_irq_set), count=1, data=(ctypes.c_int * 1)(self.irq_fd.fd))
      vfio.VFIO_DEVICE_SET_IRQS(self.vfio_dev, irqs)
    # Without VFIO, just enable the device so its resources are accessible.
    else: FileIOInterface(f"/sys/bus/pci/devices/{self.pcibus}/enable", os.O_RDWR).write("1")

    self.cfg_fd = FileIOInterface(f"/sys/bus/pci/devices/{self.pcibus}/config", os.O_RDWR | os.O_SYNC | os.O_CLOEXEC)
    self.bar_fds = {b: FileIOInterface(f"/sys/bus/pci/devices/{self.pcibus}/resource{b}", os.O_RDWR | os.O_SYNC | os.O_CLOEXEC) for b in bars}

    # bar_info: {bar_index: (start, end, flags)} parsed from the sysfs resource table.
    bar_info = FileIOInterface(f"/sys/bus/pci/devices/{self.pcibus}/resource", os.O_RDONLY).read().splitlines()
    self.bar_info = {j:(int(start,16), int(end,16), int(flgs,16)) for j,(start,end,flgs) in enumerate(l.split() for l in bar_info)}

  # Little-endian config-space accessors (offset/size in bytes).
  def read_config(self, offset:int, size:int): return int.from_bytes(self.cfg_fd.read(size, binary=True, offset=offset), byteorder='little')
  def write_config(self, offset:int, value:int, size:int): self.cfg_fd.write(value.to_bytes(size, byteorder='little'), binary=True, offset=offset)
  def map_bar(self, bar:int, off:int=0, addr:int=0, size:int|None=None, fmt='B') -> MMIOInterface:
    """mmap (part of) a BAR; addr!=0 requests a fixed virtual address. MADV_DONTFORK keeps children from inheriting the device mapping."""
    fd, sz = self.bar_fds[bar], size or (self.bar_info[bar][1] - self.bar_info[bar][0] + 1)
    libc.madvise(loc:=fd.mmap(addr, sz, mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED | (MAP_FIXED if addr else 0), off), sz, libc.MADV_DONTFORK)
    return MMIOInterface(loc, sz, fmt=fmt)
|
121
|
+
|
122
|
+
class PCIDevImplBase:
  """Base for chip-specific device implementations (e.g. NVDev); concrete
  subclasses must provide the per-device memory manager."""
  mm: MemoryManager
|
124
|
+
|
125
|
+
# Metadata stored on HCQBuffer.meta for PCI allocations: the virtual mapping,
# whether a CPU mapping exists (so free() knows to munmap), and a backing
# handle / physical address (hMemory).
@dataclasses.dataclass
class PCIAllocationMeta: mapping:VirtMapping; has_cpu_mapping:bool; hMemory:int=0 # noqa: E702
|
127
|
+
|
128
|
+
class PCIIfaceBase:
  """Shared PCI allocator/mapper interface for HCQ devices backed by a PCIDevice."""
  # dev_impl: the chip-specific implementation (a PCIDevImplBase) set by subclasses.
  dev_impl:PCIDevImplBase
  # One scanned gpu list per interface class (populated on first construction).
  gpus:ClassVar[list[str]] = []

  def __init__(self, dev, dev_id, vendor, devices, bars, vram_bar, va_start, va_size):
    if len((cls:=type(self)).gpus) == 0:
      cls.gpus = System.pci_scan_bus(vendor, devices)
      # VISIBLE_DEVICES filters/reorders by index into the scanned list.
      visible_devices = [int(x) for x in (getenv('VISIBLE_DEVICES', '')).split(',') if x.strip()]
      cls.gpus = [cls.gpus[x] for x in visible_devices] if visible_devices else cls.gpus

      # Acquire va range to avoid collisions.
      FileIOInterface.anon_mmap(va_start, va_size, 0, mmap.MAP_PRIVATE | mmap.MAP_ANONYMOUS | MAP_NORESERVE | MAP_FIXED, 0)
    self.pci_dev, self.dev, self.vram_bar = PCIDevice(cls.gpus[dev_id], bars=bars, resize_bars=[vram_bar]), dev, vram_bar
    # Base of the VRAM BAR in host physical space; used to translate peer addresses in map().
    self.p2p_base_addr = self.pci_dev.bar_info[vram_bar][0]

  def alloc(self, size:int, host=False, uncached=False, cpu_access=False, contiguous=False, **kwargs) -> HCQBuffer:
    """Allocate a buffer visible to this device: system memory for host/gtt-like
    requests, otherwise VRAM (optionally mapped into the CPU through the BAR)."""
    if host or (uncached and cpu_access): # host or gtt-like memory.
      vaddr = self.dev_impl.mm.alloc_vaddr(size:=round_up(size, mmap.PAGESIZE), align=mmap.PAGESIZE)
      paddrs = [(paddr, mmap.PAGESIZE) for paddr in System.alloc_sysmem(size, vaddr=vaddr, contiguous=contiguous)[1]]
      mapping = self.dev_impl.mm.map_range(vaddr, size, paddrs, system=True, snooped=True, uncached=True)
      return HCQBuffer(vaddr, size, meta=PCIAllocationMeta(mapping, has_cpu_mapping=True, hMemory=paddrs[0][0]),
        view=MMIOInterface(mapping.va_addr, size, fmt='B'), owner=self.dev)

    # VRAM path. cpu_access forces a contiguous allocation so a single BAR window can cover it.
    mapping = self.dev_impl.mm.valloc(size:=round_up(size, 4 << 10), uncached=uncached, contiguous=cpu_access)
    if cpu_access: self.pci_dev.map_bar(bar=self.vram_bar, off=mapping.paddrs[0][0], addr=mapping.va_addr, size=mapping.size)
    return HCQBuffer(mapping.va_addr, size, view=MMIOInterface(mapping.va_addr, size, fmt='B') if cpu_access else None,
      meta=PCIAllocationMeta(mapping, has_cpu_mapping=cpu_access, hMemory=mapping.paddrs[0][0]), owner=self.dev)

  def free(self, b:HCQBuffer):
    # Unmap from every peer device first (mapped_devs[0] is presumably the owner — confirm),
    # then release the VA/pages, then drop the CPU mapping if we created one.
    for dev in b.mapped_devs[1:]: dev.iface.dev_impl.mm.unmap_range(b.va_addr, b.size)
    if not b.meta.mapping.system: self.dev_impl.mm.vfree(b.meta.mapping)
    if b.owner == self.dev and b.meta.has_cpu_mapping: FileIOInterface.munmap(b.va_addr, b.size)

  def map(self, b:HCQBuffer):
    """Map a buffer owned elsewhere (CPU or a peer PCI device) into this device's GPU VA space."""
    if b.owner is not None and b.owner._is_cpu():
      # CPU memory: pin it, then translate to physical pages.
      System.lock_memory(cast(int, b.va_addr), b.size)
      paddrs, snooped, uncached = [(x, 0x1000) for x in System.system_paddrs(cast(int, b.va_addr), round_up(b.size, 0x1000))], True, False
    elif (ifa:=getattr(b.owner, "iface", None)) is not None and isinstance(ifa, PCIIfaceBase):
      # Peer device VRAM: offset by the owner's BAR base for p2p; system pages pass through unchanged.
      paddrs = [(paddr if b.meta.mapping.system else (paddr + ifa.p2p_base_addr), size) for paddr,size in b.meta.mapping.paddrs]
      snooped, uncached = b.meta.mapping.snooped, b.meta.mapping.uncached
    else: raise RuntimeError(f"map failed: {b.owner} -> {self.dev}")

    self.dev_impl.mm.map_range(cast(int, b.va_addr), round_up(b.size, 0x1000), paddrs, system=True, snooped=snooped, uncached=uncached)
|
@@ -0,0 +1,268 @@
|
|
1
|
+
import ctypes, struct, dataclasses, array, itertools
|
2
|
+
from typing import Sequence
|
3
|
+
from tinygrad.runtime.autogen import libusb
|
4
|
+
from tinygrad.helpers import DEBUG, to_mv, round_up, OSX
|
5
|
+
from tinygrad.runtime.support.hcq import MMIOInterface
|
6
|
+
|
7
|
+
class USB3:
|
8
|
+
def __init__(self, vendor:int, dev:int, ep_data_in:int, ep_stat_in:int, ep_data_out:int, ep_cmd_out:int, max_streams:int=31):
|
9
|
+
self.vendor, self.dev = vendor, dev
|
10
|
+
self.ep_data_in, self.ep_stat_in, self.ep_data_out, self.ep_cmd_out = ep_data_in, ep_stat_in, ep_data_out, ep_cmd_out
|
11
|
+
self.max_streams = max_streams
|
12
|
+
self.ctx = ctypes.POINTER(libusb.struct_libusb_context)()
|
13
|
+
|
14
|
+
if libusb.libusb_init(ctypes.byref(self.ctx)): raise RuntimeError("libusb_init failed")
|
15
|
+
if DEBUG >= 6: libusb.libusb_set_option(self.ctx, libusb.LIBUSB_OPTION_LOG_LEVEL, 4)
|
16
|
+
|
17
|
+
self.handle = libusb.libusb_open_device_with_vid_pid(self.ctx, self.vendor, self.dev)
|
18
|
+
if not self.handle: raise RuntimeError(f"device {self.vendor:04x}:{self.dev:04x} not found. sudo required?")
|
19
|
+
|
20
|
+
# Detach kernel driver if needed
|
21
|
+
if libusb.libusb_kernel_driver_active(self.handle, 0):
|
22
|
+
libusb.libusb_detach_kernel_driver(self.handle, 0)
|
23
|
+
libusb.libusb_reset_device(self.handle)
|
24
|
+
|
25
|
+
# Set configuration and claim interface
|
26
|
+
if libusb.libusb_set_configuration(self.handle, 1): raise RuntimeError("set_configuration failed")
|
27
|
+
if libusb.libusb_claim_interface(self.handle, 0): raise RuntimeError("claim_interface failed. sudo required?")
|
28
|
+
if libusb.libusb_set_interface_alt_setting(self.handle, 0, 1): raise RuntimeError("alt_setting failed")
|
29
|
+
|
30
|
+
# Clear any stalled endpoints
|
31
|
+
all_eps = (self.ep_data_out, self.ep_data_in, self.ep_stat_in, self.ep_cmd_out)
|
32
|
+
for ep in all_eps: libusb.libusb_clear_halt(self.handle, ep)
|
33
|
+
|
34
|
+
# Allocate streams
|
35
|
+
stream_eps = (ctypes.c_uint8 * 3)(self.ep_data_out, self.ep_data_in, self.ep_stat_in)
|
36
|
+
if (rc:=libusb.libusb_alloc_streams(self.handle, self.max_streams * len(stream_eps), stream_eps, len(stream_eps))) < 0:
|
37
|
+
raise RuntimeError(f"alloc_streams failed: {rc}")
|
38
|
+
|
39
|
+
# Base cmd
|
40
|
+
cmd_template = bytes([0x01, 0x00, 0x00, 0x01, *([0] * 12), 0xE4, 0x24, 0x00, 0xB2, 0x1A, 0x00, 0x00, 0x00, *([0] * 8)])
|
41
|
+
|
42
|
+
# Init pools
|
43
|
+
self.tr = {ep: [libusb.libusb_alloc_transfer(0) for _ in range(self.max_streams)] for ep in all_eps}
|
44
|
+
|
45
|
+
self.buf_cmd = [(ctypes.c_uint8 * len(cmd_template))(*cmd_template) for _ in range(self.max_streams)]
|
46
|
+
self.buf_stat = [(ctypes.c_uint8 * 64)() for _ in range(self.max_streams)]
|
47
|
+
self.buf_data_in = [(ctypes.c_uint8 * 0x1000)() for _ in range(self.max_streams)]
|
48
|
+
self.buf_data_out = [(ctypes.c_uint8 * 0x80000)() for _ in range(self.max_streams)]
|
49
|
+
self.buf_data_out_mvs = [to_mv(ctypes.addressof(self.buf_data_out[i]), 0x80000) for i in range(self.max_streams)]
|
50
|
+
|
51
|
+
for slot in range(self.max_streams): struct.pack_into(">B", self.buf_cmd[slot], 3, slot + 1)
|
52
|
+
|
53
|
+
def _prep_transfer(self, tr, ep, stream_id, buf, length):
  """Fill a libusb transfer struct for a bulk submission on endpoint `ep` and return it.

  A non-None stream_id selects a bulk-stream transfer and tags it with that stream.
  status starts at 0xff, the sentinel `_submit_and_wait` treats as "still pending".
  """
  c = tr.contents
  c.dev_handle = self.handle
  c.endpoint = ep
  c.length = length
  c.buffer = buf
  c.status = 0xff          # not-yet-completed sentinel
  c.flags = 0
  c.timeout = 1000
  c.num_iso_packets = 0
  if stream_id is None:
    c.type = libusb.LIBUSB_TRANSFER_TYPE_BULK
  else:
    c.type = libusb.LIBUSB_TRANSFER_TYPE_BULK_STREAM
    libusb.libusb_transfer_set_stream_id(tr, stream_id)
  return tr
|
59
|
+
|
60
|
+
def _submit_and_wait(self, cmds):
  """Submit every prepared transfer, then pump libusb events until all have completed.

  Raises RuntimeError as soon as any transfer reports an error status (anything other
  than COMPLETED or the 0xff pending sentinel set by _prep_transfer).
  """
  for tr in cmds: libusb.libusb_submit_transfer(tr)

  remaining = len(cmds)
  while remaining:
    libusb.libusb_handle_events(self.ctx)
    # recount what is still in flight after this round of event handling
    remaining = 0
    for tr in cmds:
      if tr.contents.status == libusb.LIBUSB_TRANSFER_COMPLETED: continue
      if tr.contents.status != 0xFF: raise RuntimeError(f"EP 0x{tr.contents.endpoint:02X} error: {tr.contents.status}")
      remaining += 1
|
70
|
+
|
71
|
+
def send_batch(self, cdbs:list[bytes], idata:list[int]|None=None, odata:list[bytes|None]|None=None) -> list[bytes|None]:
  """Send a batch of SCSI CDBs, pipelined over the bulk-stream endpoints.

  cdbs: command blocks. idata: per-command expected read length (0 = no data-in).
  odata: per-command payload to send (None = no data-out). Commands are issued in
  windows of up to self.max_streams; returns, per command, the bytes read or None.
  """
  idata, odata = idata or [0] * len(cdbs), odata or [None] * len(cdbs)
  results, tr_window, op_window = [], [], []

  for idx, (cdb, rlen, send_data) in enumerate(zip(cdbs, idata, odata)):
    # allocate slot and stream. stream is 1-based
    slot, stream = idx % self.max_streams, (idx % self.max_streams) + 1

    # build cmd packet: the CDB lives at offset 16 of the command template
    self.buf_cmd[slot][16:16+len(cdb)] = list(cdb)

    # cmd + stat transfers
    tr_window.append(self._prep_transfer(self.tr[self.ep_cmd_out][slot], self.ep_cmd_out, None, self.buf_cmd[slot], len(self.buf_cmd[slot])))
    tr_window.append(self._prep_transfer(self.tr[self.ep_stat_in][slot], self.ep_stat_in, stream, self.buf_stat[slot], 64))

    if rlen:
      # grow the data-in buffer (page-rounded) if this command reads more than it holds
      if rlen > len(self.buf_data_in[slot]): self.buf_data_in[slot] = (ctypes.c_uint8 * round_up(rlen, 0x1000))()
      tr_window.append(self._prep_transfer(self.tr[self.ep_data_in][slot], self.ep_data_in, stream, self.buf_data_in[slot], rlen))

    if send_data is not None:
      # grow the data-out buffer (and its memoryview) if the payload doesn't fit
      if len(send_data) > len(self.buf_data_out[slot]):
        self.buf_data_out[slot] = (ctypes.c_uint8 * len(send_data))()
        self.buf_data_out_mvs[slot] = to_mv(ctypes.addressof(self.buf_data_out[slot]), len(send_data))

      self.buf_data_out_mvs[slot][:len(send_data)] = bytes(send_data)
      tr_window.append(self._prep_transfer(self.tr[self.ep_data_out][slot], self.ep_data_out, stream, self.buf_data_out[slot], len(send_data)))

    op_window.append((idx, slot, rlen))
    if (idx + 1 == len(cdbs)) or len(op_window) >= self.max_streams:
      self._submit_and_wait(tr_window)
      # use distinct names so the outer loop's idx/slot/rlen aren't clobbered
      for _, done_slot, done_rlen in op_window: results.append(bytes(self.buf_data_in[done_slot][:done_rlen]) if done_rlen else None)
      # BUG FIX: op_window must be cleared along with tr_window, otherwise completed ops
      # stay in the window, get re-appended to results on every later flush, and force a
      # flush on every subsequent iteration.
      tr_window, op_window = [], []

  return results
|
105
|
+
|
106
|
+
@dataclasses.dataclass(frozen=True)
class WriteOp:
  """Write `data` byte-by-byte starting at register address `addr`.

  With ignore_cache=False, exec_ops may skip bytes the device is already known to hold.
  """
  addr: int
  data: bytes
  ignore_cache: bool = True
|
108
|
+
|
109
|
+
@dataclasses.dataclass(frozen=True)
class ReadOp:
  """Read `size` bytes starting at register address `addr`."""
  addr: int
  size: int
|
111
|
+
|
112
|
+
@dataclasses.dataclass(frozen=True)
class ScsiWriteOp:
  """SCSI WRITE of `data` at logical block address `lba` (default 0)."""
  data: bytes
  lba: int = 0
|
114
|
+
|
115
|
+
class ASM24Controller:
  """Driver for an ASMedia USB-to-PCIe bridge, layered on the USB3 SCSI transport.

  Internal registers are accessed with vendor CDBs (0xE4 read / 0xE5 write); PCIe TLPs
  are issued by programming the bridge's 0xB2xx register window. Two write-through
  caches (_cache for bridge registers, _pci_cache for PCIe memory writes) let repeated
  writes of already-known values be skipped.
  """
  def __init__(self):
    # vid/pid 0xADD1:0x0001; endpoints: data-in 0x81, stat-in 0x83, data-out 0x02, cmd-out 0x04
    self.usb = USB3(0xADD1, 0x0001, 0x81, 0x83, 0x02, 0x04)
    self._cache: dict[int, int|None] = {}                 # register addr -> last known byte (None = unknown)
    self._pci_cacheable: list[tuple[int, int]] = []       # (base, size) ranges where PCIe writes may be deduped
    self._pci_cache: dict[int, int|None] = {}             # PCIe addr -> last written dword (None = unknown)

    # Init controller.
    self.exec_ops([WriteOp(0x54b, b' '), WriteOp(0x54e, b'\x04'), WriteOp(0x5a8, b'\x02'), WriteOp(0x5f8, b'\x04'),
                   WriteOp(0x7ec, b'\x01\x00\x00\x00'), WriteOp(0xc422, b'\x02'), WriteOp(0x0, b'\x33')])

  def exec_ops(self, ops:Sequence[WriteOp|ReadOp|ScsiWriteOp]):
    """Translate ops into SCSI CDBs and send them as one pipelined batch.

    Returns usb.send_batch's result list (bytes per read request, None otherwise).
    """
    cdbs:list[bytes] = []
    idata:list[int] = []
    odata:list[bytes|None] = []

    def _add_req(cdb:bytes, i:int, o:bytes|None):
      # queue one request: CDB, expected read length, optional outgoing payload
      nonlocal cdbs, idata, odata
      cdbs, idata, odata = cdbs + [cdb], idata + [i], odata + [o]

    for op in ops:
      if isinstance(op, WriteOp):
        # one vendor 0xE5 CDB per byte; register space is a 17-bit window at 0x500000
        for off, value in enumerate(op.data):
          addr = ((op.addr + off) & 0x1FFFF) | 0x500000
          if not op.ignore_cache and self._cache.get(addr) == value: continue  # value already on device
          _add_req(struct.pack('>BBBHB', 0xE5, value, addr >> 16, addr & 0xFFFF, 0), 0, None)
          self._cache[addr] = value
      elif isinstance(op, ReadOp):
        assert op.size <= 0xff
        addr = (op.addr & 0x1FFFF) | 0x500000
        _add_req(struct.pack('>BBBHB', 0xE4, op.size, addr >> 16, addr & 0xFFFF, 0), op.size, None)
        # reads may reflect device-side changes, so drop cached values for this range
        for i in range(op.size): self._cache[addr + i] = None
      elif isinstance(op, ScsiWriteOp):
        # SCSI WRITE(16): payload is zero-padded to whole 512-byte sectors
        sectors = round_up(len(op.data), 512) // 512
        _add_req(struct.pack('>BBQIBB', 0x8A, 0, op.lba, sectors, 0, 0), 0, op.data+b'\x00'*((sectors*512)-len(op.data)))

    return self.usb.send_batch(cdbs, idata, odata)

  def write(self, base_addr:int, data:bytes, ignore_cache:bool=True): return self.exec_ops([WriteOp(base_addr, data, ignore_cache)])

  def scsi_write(self, buf:bytes, lba:int=0):
    """Write buf via SCSI WRITE in 64KiB chunks, with the register pokes the bridge needs.

    NOTE(review): the 0x171 / 0xce6e / 0xce40 pokes look like per-chunk kick and
    post-transfer cleanup sequences — semantics inferred from usage, confirm with
    the device datasheet.
    """
    # pad large buffers up to a 64KiB multiple so every chunk is full-sized
    if len(buf) > 0x4000: buf += b'\x00' * (round_up(len(buf), 0x10000) - len(buf))

    for i in range(0, len(buf), 0x10000):
      self.exec_ops([ScsiWriteOp(buf[i:i+0x10000], lba), WriteOp(0x171, b'\xff\xff\xff', ignore_cache=True)])
      self.exec_ops([WriteOp(0xce6e, b'\x00\x00', ignore_cache=True)])

    if len(buf) > 0x4000:
      for i in range(4): self.exec_ops([WriteOp(0xce40 + i, b'\x00', ignore_cache=True)])

  def read(self, base_addr:int, length:int, stride:int=0xff) -> bytes:
    """Read `length` bytes at base_addr, split into reads of at most `stride` (<= 0xff) bytes."""
    parts = self.exec_ops([ReadOp(base_addr + off, min(stride, length - off)) for off in range(0, length, stride)])
    return b''.join(p or b'' for p in parts)[:length]

  def _is_pci_cacheable(self, addr:int) -> bool: return any(x <= addr <= x + sz for x, sz in self._pci_cacheable)

  def pcie_prep_request(self, fmt_type:int, address:int, value:int|None=None, size:int=4) -> list[WriteOp]:
    """Build (without sending) the WriteOps that program one PCIe TLP into the bridge.

    fmt_type: TLP fmt/type byte. value: payload for writes, None for reads. size: 1..4 bytes.
    Returns [] when a cacheable 32-bit memory write would re-store an already-known value.
    """
    if fmt_type == 0x60 and size == 4 and self._is_pci_cacheable(address) and self._pci_cache.get(address) == value: return []

    assert fmt_type >> 8 == 0 and size > 0 and size <= 4, f"Invalid fmt_type {fmt_type} or size {size}"
    if DEBUG >= 5: print("pcie_request", hex(fmt_type), hex(address), value, size)

    # requests are dword-aligned; offset selects the byte lanes within the dword
    masked_address, offset = address & 0xFFFFFFFC, address & 0x3
    assert size + offset <= 4 and (value is None or value >> (8 * size) == 0)
    self._pci_cache[address] = value if size == 4 and fmt_type == 0x60 else None

    # 0xB220=data, 0xB218/0xB21c=address lo/hi, 0xB217=byte enables, 0xB210=fmt_type;
    # 0xB254/0xB296 presumably kick off the request and arm the status bit — confirm with datasheet
    return ([WriteOp(0xB220, struct.pack('>I', value << (8 * offset)), ignore_cache=False)] if value is not None else []) + \
      [WriteOp(0xB218, struct.pack('>I', masked_address), ignore_cache=False), WriteOp(0xB21c, struct.pack('>I', address>>32), ignore_cache=False),
       WriteOp(0xB217, bytes([((1 << size) - 1) << offset]), ignore_cache=False), WriteOp(0xB210, bytes([fmt_type]), ignore_cache=False),
       WriteOp(0xB254, b"\x0f", ignore_cache=True), WriteOp(0xB296, b"\x04", ignore_cache=True)]

  def pcie_request(self, fmt_type, address, value=None, size=4, cnt=10):
    """Issue one PCIe TLP; for non-posted requests, poll for and validate the completion.

    Returns the read value for reads, None for writes. cnt bounds retries after an error bit.
    """
    self.exec_ops(self.pcie_prep_request(fmt_type, address, value, size))

    # Fast path for write requests
    if ((fmt_type & 0b11011111) == 0b01000000) or ((fmt_type & 0b10111000) == 0b00110000): return

    # poll status at 0xB296: bit 1 = done, bit 0 = error (clear it and retry up to cnt times)
    while (stat:=self.read(0xB296, 1)[0]) & 2 == 0:
      if stat & 1:
        self.write(0xB296, bytes([0x01]))
        if cnt > 0: return self.pcie_request(fmt_type, address, value, size, cnt=cnt-1)
        assert stat == 2, f"stat read 2 was {stat}"

    # Retrieve completion data from Link Status (0xB22A, 0xB22B)
    b284 = self.read(0xB284, 1)[0]
    completion = struct.unpack('>H', self.read(0xB22A, 2))

    # Validate completion status based on PCIe request typ
    # Completion TLPs for configuration requests always have a byte count of 4.
    assert completion[0] & 0xfff == (4 if (fmt_type & 0xbe == 0x04) else size)

    # Extract completion status field
    status = (completion[0] >> 13) & 0x7

    # Handle completion errors or inconsistencies
    if status or ((fmt_type & 0xbe == 0x04) and (((value is None) and (not (b284 & 0x01))) or ((value is not None) and (b284 & 0x01)))):
      status_map = {0b001: f"Unsupported Request: invalid address/function (target might not be reachable): {address:#x}",
                    0b100: "Completer Abort: abort due to internal error", 0b010: "Configuration Request Retry Status: configuration space busy"}
      raise RuntimeError(f"TLP status: {status_map.get(status, 'Reserved (0b{:03b})'.format(status))}")

    # reads: pull the dword back out of 0xB220 and extract the requested byte lanes
    if value is None: return (struct.unpack('>I', self.read(0xB220, 4))[0] >> (8 * (address & 0x3))) & ((1 << (8 * size)) - 1)

  def pcie_cfg_req(self, byte_addr, bus=1, dev=0, fn=0, value=None, size=4):
    """Configuration read/write to bus/dev/fn at register byte_addr (type 0 on bus 0, type 1 otherwise)."""
    assert byte_addr >> 12 == 0 and bus >> 8 == 0 and dev >> 5 == 0 and fn >> 3 == 0, f"Invalid byte_addr {byte_addr}, bus {bus}, dev {dev}, fn {fn}"

    fmt_type = (0x44 if value is not None else 0x4) | int(bus > 0)
    address = (bus << 24) | (dev << 19) | (fn << 16) | (byte_addr & 0xfff)
    return self.pcie_request(fmt_type, address, value, size)

  def pcie_mem_req(self, address, value=None, size=4): return self.pcie_request(0x60 if value is not None else 0x20, address, value, size)

  def pcie_mem_write(self, address, values, size):
    """Write consecutive size-byte values starting at address, batching register programs per USB round-trip."""
    ops = [self.pcie_prep_request(0x60, address + i * size, value, size) for i, value in enumerate(values)]

    # Send in batches of 4 for OSX and 16 for Linux (benchmarked values)
    for i in range(0, len(ops), bs:=(4 if OSX else 16)): self.exec_ops(list(itertools.chain.from_iterable(ops[i:i+bs])))
|
230
|
+
|
231
|
+
class USBMMIOInterface(MMIOInterface):
  """MMIOInterface backed by the USB bridge.

  With pcimem=True accesses become PCIe memory requests; with pcimem=False they hit the
  bridge's internal registers via the vendor read/write CDBs. fmt is a struct format
  character describing one element; el_sz is its size in bytes.
  """
  def __init__(self, usb, addr, size, fmt, pcimem=True):
    self.usb, self.addr, self.nbytes, self.fmt, self.pcimem, self.el_sz = usb, addr, size, fmt, pcimem, struct.calcsize(fmt)

  def __getitem__(self, index): return self._access_items(index)
  def __setitem__(self, index, val): self._access_items(index, val)

  def _access_items(self, index, val=None):
    # slices take the bulk path; a single index uses one PCIe request when pcimem
    if isinstance(index, slice): return self._acc((index.start or 0) * self.el_sz, ((index.stop or len(self))-(index.start or 0)) * self.el_sz, val)
    return self._acc_one(index * self.el_sz, self.el_sz, val) if self.pcimem else self._acc(index * self.el_sz, self.el_sz, val)

  def view(self, offset:int=0, size:int|None=None, fmt=None):
    """Return a new interface over a byte sub-range, optionally reinterpreted with a different fmt."""
    return USBMMIOInterface(self.usb, self.addr+offset, size or (self.nbytes - offset), fmt=fmt or self.fmt, pcimem=self.pcimem)

  # widest of I/H/B whose size divides sz; returns (array typecode, access size)
  def _acc_size(self, sz): return next(x for x in [('I', 4), ('H', 2), ('B', 1)] if sz % x[1] == 0)

  def _acc_one(self, off, sz, val=None):
    # single access of up to 8 bytes as one or two 32-bit PCIe requests; returns the value for reads
    upper = 0 if sz < 8 else self.usb.pcie_mem_req(self.addr + off + 4, val if val is None else (val >> 32), 4)
    lower = self.usb.pcie_mem_req(self.addr + off, val if val is None else val & 0xffffffff, min(sz, 4))
    if val is None: return lower | (upper << 32)

  def _acc(self, off, sz, data=None):
    if data is None: # read op
      if not self.pcimem:
        # internal registers: a single-element read decodes to int, larger reads return raw bytes
        return int.from_bytes(self.usb.read(self.addr + off, sz), "little") if sz == self.el_sz else self.usb.read(self.addr + off, sz)

      # PCIe: issue sz // acc_size requests of the widest aligned width and join the bytes
      acc, acc_size = self._acc_size(sz)
      return bytes(array.array(acc, [self._acc_one(off + i * acc_size, acc_size) for i in range(sz // acc_size)]))
    else: # write op
      data = struct.pack(self.fmt, data) if isinstance(data, int) else bytes(data)

      if not self.pcimem:
        # Fast path for writing into buffer 0xf000
        use_cache = 0xa800 <= self.addr <= 0xb000
        return self.usb.scsi_write(bytes(data)) if self.addr == 0xf000 else self.usb.write(self.addr + off, bytes(data), ignore_cache=not use_cache)

      _, acc_sz = self._acc_size(len(data) * struct.calcsize(self.fmt))
      self.usb.pcie_mem_write(self.addr+off, [int.from_bytes(data[i:i+acc_sz], "little") for i in range(0, len(data), acc_sz)], acc_sz)
|
@@ -0,0 +1,18 @@
|
|
1
|
+
import ctypes, ctypes.util, os, subprocess, platform, sysconfig
|
2
|
+
from tinygrad.helpers import OSX
|
3
|
+
|
4
|
+
# Resolve the platform-specific path to the Dawn WebGPU shared library at import time.
WEBGPU_PATH: str | None

if OSX:
  # macOS: dawn comes from the wpmed92/dawn Homebrew tap
  brew_prefix = subprocess.check_output(['brew', '--prefix', 'dawn']).decode().strip()
  if not os.path.exists(brew_prefix):
    raise FileNotFoundError('dawn library not found. Install it with `brew tap wpmed92/dawn && brew install dawn`')
  WEBGPU_PATH = os.path.join(brew_prefix, 'lib', 'libwebgpu_dawn.dylib')
elif platform.system() == "Windows":
  # Windows: the dawn-python wheel ships the DLL inside its pydawn package
  pydawn_path = os.path.join(sysconfig.get_paths()["purelib"], "pydawn")
  if not os.path.exists(pydawn_path):
    raise FileNotFoundError("dawn library not found. Install it with `pip install dawn-python`")
  WEBGPU_PATH = os.path.join(pydawn_path, "lib", "libwebgpu_dawn.dll")
else:
  # Linux and others: rely on the system linker search path
  WEBGPU_PATH = ctypes.util.find_library('webgpu_dawn')
  if WEBGPU_PATH is None:
    raise FileNotFoundError("dawn library not found. " +
                            "Install it with `sudo curl -L https://github.com/wpmed92/pydawn/releases/download/v0.3.0/" +
                            f"libwebgpu_dawn_{platform.machine()}.so -o /usr/lib/libwebgpu_dawn.so`")
|
File without changes
|