tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. tinygrad/__init__.py +1 -1
  2. tinygrad/apps/llm.py +206 -0
  3. tinygrad/codegen/__init__.py +116 -0
  4. tinygrad/codegen/devectorizer.py +315 -172
  5. tinygrad/codegen/expander.py +8 -16
  6. tinygrad/codegen/gpudims.py +89 -0
  7. tinygrad/codegen/linearize.py +205 -203
  8. tinygrad/codegen/lowerer.py +92 -139
  9. tinygrad/codegen/opt/__init__.py +38 -0
  10. tinygrad/codegen/opt/heuristic.py +125 -0
  11. tinygrad/codegen/opt/kernel.py +510 -0
  12. tinygrad/{engine → codegen/opt}/search.py +51 -35
  13. tinygrad/codegen/opt/swizzler.py +134 -0
  14. tinygrad/codegen/opt/tc.py +127 -0
  15. tinygrad/codegen/quantize.py +67 -0
  16. tinygrad/device.py +122 -132
  17. tinygrad/dtype.py +152 -35
  18. tinygrad/engine/jit.py +81 -54
  19. tinygrad/engine/memory.py +46 -27
  20. tinygrad/engine/realize.py +82 -41
  21. tinygrad/engine/schedule.py +70 -445
  22. tinygrad/frontend/__init__.py +0 -0
  23. tinygrad/frontend/onnx.py +1253 -0
  24. tinygrad/frontend/torch.py +5 -0
  25. tinygrad/gradient.py +19 -27
  26. tinygrad/helpers.py +95 -47
  27. tinygrad/nn/__init__.py +7 -8
  28. tinygrad/nn/optim.py +72 -41
  29. tinygrad/nn/state.py +37 -23
  30. tinygrad/renderer/__init__.py +40 -60
  31. tinygrad/renderer/cstyle.py +143 -128
  32. tinygrad/renderer/llvmir.py +113 -62
  33. tinygrad/renderer/ptx.py +50 -32
  34. tinygrad/renderer/wgsl.py +27 -23
  35. tinygrad/runtime/autogen/am/am.py +5861 -0
  36. tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
  37. tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
  38. tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
  39. tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
  40. tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
  41. tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
  42. tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
  43. tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
  44. tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
  45. tinygrad/runtime/autogen/comgr.py +35 -9
  46. tinygrad/runtime/autogen/comgr_3.py +906 -0
  47. tinygrad/runtime/autogen/cuda.py +2419 -494
  48. tinygrad/runtime/autogen/hsa.py +57 -16
  49. tinygrad/runtime/autogen/ib.py +7171 -0
  50. tinygrad/runtime/autogen/io_uring.py +917 -118
  51. tinygrad/runtime/autogen/kfd.py +748 -26
  52. tinygrad/runtime/autogen/libc.py +613 -218
  53. tinygrad/runtime/autogen/libusb.py +1643 -0
  54. tinygrad/runtime/autogen/nv/nv.py +8602 -0
  55. tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
  56. tinygrad/runtime/autogen/opencl.py +2 -4
  57. tinygrad/runtime/autogen/sqtt.py +1789 -0
  58. tinygrad/runtime/autogen/vfio.py +3 -3
  59. tinygrad/runtime/autogen/webgpu.py +273 -264
  60. tinygrad/runtime/graph/cuda.py +3 -3
  61. tinygrad/runtime/graph/hcq.py +68 -29
  62. tinygrad/runtime/graph/metal.py +29 -13
  63. tinygrad/runtime/graph/remote.py +114 -0
  64. tinygrad/runtime/ops_amd.py +537 -320
  65. tinygrad/runtime/ops_cpu.py +108 -7
  66. tinygrad/runtime/ops_cuda.py +12 -14
  67. tinygrad/runtime/ops_disk.py +13 -10
  68. tinygrad/runtime/ops_dsp.py +47 -40
  69. tinygrad/runtime/ops_gpu.py +13 -11
  70. tinygrad/runtime/ops_hip.py +6 -9
  71. tinygrad/runtime/ops_llvm.py +35 -15
  72. tinygrad/runtime/ops_metal.py +29 -19
  73. tinygrad/runtime/ops_npy.py +5 -3
  74. tinygrad/runtime/ops_null.py +28 -0
  75. tinygrad/runtime/ops_nv.py +306 -234
  76. tinygrad/runtime/ops_python.py +62 -52
  77. tinygrad/runtime/ops_qcom.py +28 -39
  78. tinygrad/runtime/ops_remote.py +482 -0
  79. tinygrad/runtime/ops_webgpu.py +28 -28
  80. tinygrad/runtime/support/am/amdev.py +114 -249
  81. tinygrad/runtime/support/am/ip.py +211 -172
  82. tinygrad/runtime/support/amd.py +138 -0
  83. tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
  84. tinygrad/runtime/support/compiler_cuda.py +8 -11
  85. tinygrad/runtime/support/elf.py +2 -1
  86. tinygrad/runtime/support/hcq.py +184 -97
  87. tinygrad/runtime/support/ib.py +172 -0
  88. tinygrad/runtime/support/llvm.py +3 -4
  89. tinygrad/runtime/support/memory.py +251 -0
  90. tinygrad/runtime/support/nv/__init__.py +0 -0
  91. tinygrad/runtime/support/nv/ip.py +581 -0
  92. tinygrad/runtime/support/nv/nvdev.py +183 -0
  93. tinygrad/runtime/support/system.py +170 -0
  94. tinygrad/runtime/support/usb.py +268 -0
  95. tinygrad/runtime/support/webgpu.py +18 -0
  96. tinygrad/schedule/__init__.py +0 -0
  97. tinygrad/schedule/grouper.py +119 -0
  98. tinygrad/schedule/kernelize.py +368 -0
  99. tinygrad/schedule/multi.py +231 -0
  100. tinygrad/shape/shapetracker.py +40 -46
  101. tinygrad/shape/view.py +88 -52
  102. tinygrad/tensor.py +968 -542
  103. tinygrad/uop/__init__.py +117 -0
  104. tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
  105. tinygrad/uop/mathtraits.py +169 -0
  106. tinygrad/uop/ops.py +1021 -0
  107. tinygrad/uop/spec.py +228 -0
  108. tinygrad/{codegen → uop}/symbolic.py +239 -216
  109. tinygrad/uop/upat.py +163 -0
  110. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
  111. tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
  112. tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
  113. tinygrad/viz/index.html +203 -403
  114. tinygrad/viz/js/index.js +718 -0
  115. tinygrad/viz/js/worker.js +29 -0
  116. tinygrad/viz/serve.py +224 -102
  117. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
  118. tinygrad-0.11.0.dist-info/RECORD +141 -0
  119. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
  120. tinygrad/codegen/kernel.py +0 -693
  121. tinygrad/engine/multi.py +0 -161
  122. tinygrad/ops.py +0 -1003
  123. tinygrad/runtime/ops_cloud.py +0 -220
  124. tinygrad/runtime/support/allocator.py +0 -94
  125. tinygrad/spec.py +0 -155
  126. tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
  127. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
  128. tinygrad/viz/perfetto.html +0 -178
  129. tinygrad-0.10.2.dist-info/RECORD +0 -99
  130. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
  131. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
tinygrad/runtime/support/nv/nvdev.py (new file)
@@ -0,0 +1,183 @@
+ from __future__ import annotations
+ import ctypes, time, functools, re, gzip, struct
+ from tinygrad.helpers import getenv, DEBUG, fetch, getbits, to_mv
+ from tinygrad.runtime.support.hcq import MMIOInterface
+ from tinygrad.runtime.support.memory import TLSFAllocator, MemoryManager
+ from tinygrad.runtime.support.nv.ip import NV_FLCN, NV_FLCN_COT, NV_GSP
+ from tinygrad.runtime.support.system import System, PCIDevImplBase
+
+ NV_DEBUG = getenv("NV_DEBUG", 0)
+
+ class NVReg:
+   def __init__(self, nvdev, base, off, fields=None): self.nvdev, self.base, self.off, self.fields = nvdev, base, off, fields
+
+   def __getitem__(self, idx:int): return NVReg(self.nvdev, self.base, self.off(idx), fields=self.fields)
+
+   def add_field(self, name:str, start:int, end:int): self.fields[name] = (start, end)
+   def with_base(self, base:int): return NVReg(self.nvdev, base + self.base, self.off, self.fields)
+
+   def read(self): return self.nvdev.rreg(self.base + self.off)
+   def read_bitfields(self) -> dict[str, int]: return self.decode(self.read())
+
+   def write(self, _ini_val:int=0, **kwargs): self.nvdev.wreg(self.base + self.off, _ini_val | self.encode(**kwargs))
+
+   def update(self, **kwargs): self.write(self.read() & ~self.mask(*kwargs.keys()), **kwargs)
+
+   def mask(self, *names):
+     return functools.reduce(int.__or__, ((((1 << (self.fields[nm][1]-self.fields[nm][0] + 1)) - 1) << self.fields[nm][0]) for nm in names), 0)
+
+   def encode(self, **kwargs) -> int: return functools.reduce(int.__or__, (value << self.fields[name][0] for name,value in kwargs.items()), 0)
+   def decode(self, val: int) -> dict: return {name:getbits(val, start, end) for name,(start,end) in self.fields.items()}
+
+ class NVPageTableEntry:
+   def __init__(self, nvdev, paddr, lv): self.nvdev, self.paddr, self.lv, self.entries = nvdev, paddr, lv, nvdev.vram.view(paddr, 0x1000, fmt='Q')
+
+   def _is_dual_pde(self) -> bool: return self.lv == self.nvdev.mm.level_cnt - 2
+
+   def set_entry(self, entry_id:int, paddr:int, table=False, uncached=False, system=False, snooped=False, frag=0, valid=True):
+     if not table:
+       x = self.nvdev.pte_t.encode(valid=valid, address_sys=paddr >> 12, aperture=2 if system else 0, kind=6,
+         **({'pcf': int(uncached)} if self.nvdev.mmu_ver == 3 else {'vol': uncached}))
+     else:
+       pde = self.nvdev.dual_pde_t if self._is_dual_pde() else self.nvdev.pde_t
+       small, sys = ("_small" if self._is_dual_pde() else ""), "" if self.nvdev.mmu_ver == 3 else "_sys"
+       x = pde.encode(is_pte=False, **{f'aperture{small}': 1 if valid else 0, f'address{small}{sys}': paddr >> 12},
+         **({f'pcf{small}': 0b10} if self.nvdev.mmu_ver == 3 else {'no_ats': 1}))
+
+     if self._is_dual_pde(): self.entries[2*entry_id], self.entries[2*entry_id+1] = x & 0xffffffffffffffff, x >> 64
+     else: self.entries[entry_id] = x
+
+   def entry(self, entry_id:int) -> int:
+     return (self.entries[2*entry_id+1]<<64) | self.entries[2*entry_id] if self._is_dual_pde() else self.entries[entry_id]
+
+   def read_fields(self, entry_id:int) -> dict:
+     if self.is_huge_page(entry_id): return self.nvdev.pte_t.decode(self.entry(entry_id))
+     return (self.nvdev.dual_pde_t if self._is_dual_pde() else self.nvdev.pde_t).decode(self.entry(entry_id))
+
+   def is_huge_page(self, entry_id) -> bool: return (self.entry(entry_id) & 1 == 1) if self.lv < self.nvdev.mm.level_cnt - 1 else True
+   def supports_huge_page(self, paddr:int): return self.lv >= self.nvdev.mm.level_cnt - 3 and paddr % self.nvdev.mm.pte_covers[self.lv] == 0
+
+   def valid(self, entry_id):
+     if self.is_huge_page(entry_id): return self.read_fields(entry_id)['valid']
+     return self.read_fields(entry_id)['aperture_small' if self._is_dual_pde() else 'aperture'] != 0
+
+   def address(self, entry_id:int) -> int:
+     small, sys = ("_small" if self._is_dual_pde() else ""), "_sys" if self.nvdev.mmu_ver == 2 or self.lv == self.nvdev.mm.level_cnt - 1 else ""
+     return self.read_fields(entry_id)[f'address{small}{sys}'] << 12
+
+ class NVMemoryManager(MemoryManager):
+   va_allocator = TLSFAllocator((1 << 44), base=0x1000000000) # global for all devices.
+
+   def on_range_mapped(self): self.dev.NV_VIRTUAL_FUNCTION_PRIV_MMU_INVALIDATE.write((1 << 0) | (1 << 1) | (1 << 6) | (1 << 31))
+
+ class NVDev(PCIDevImplBase):
+   def __init__(self, devfmt:str, mmio:MMIOInterface, vram:MMIOInterface, venid:int, subvenid:int, rev:int, bars:dict):
+     self.devfmt, self.mmio, self.vram, self.venid, self.subvenid, self.rev, self.bars = devfmt, mmio, vram, venid, subvenid, rev, bars
+     self.lock_fd = System.flock_acquire(f"nv_{self.devfmt}.lock")
+
+     self.smi_dev, self.is_booting = False, True
+     self._early_init()
+
+     # UVM depth   HW level                            VA bits
+     # 0           PDE4                                56:56 (hopper+)
+     # 1           PDE3                                55:47
+     # 2           PDE2                                46:38
+     # 3           PDE1 (or 512M PTE)                  37:29
+     # 4           PDE0 (dual 64k/4k PDE, or 2M PTE)   28:21
+     # 5           PTE_64K / PTE_4K                    20:16 / 20:12
+     bits, shifts = (56, [12, 21, 29, 38, 47, 56]) if self.mmu_ver == 3 else (48, [12, 21, 29, 38, 47])
+     self.mm = NVMemoryManager(self, self.vram_size, boot_size=(2 << 20), pt_t=NVPageTableEntry, va_bits=bits, va_shifts=shifts, va_base=0,
+       palloc_ranges=[(x, x) for x in [512 << 20, 2 << 20, 4 << 10]])
+     self.flcn:NV_FLCN|NV_FLCN_COT = NV_FLCN_COT(self) if self.fmc_boot else NV_FLCN(self)
+     self.gsp:NV_GSP = NV_GSP(self)
+
+     # Turn off booting early; the gsp client is loaded from a clean state.
+     self.is_booting = False
+
+     for ip in [self.flcn, self.gsp]: ip.init_sw()
+     for ip in [self.flcn, self.gsp]: ip.init_hw()
+
+   def fini(self):
+     for ip in [self.gsp, self.flcn]: ip.fini_hw()
+
+   def reg(self, reg:str) -> NVReg: return self.__dict__[reg]
+   def wreg(self, addr:int, value:int):
+     self.mmio[addr // 4] = value
+     if NV_DEBUG >= 4: print(f"wreg: {hex(addr)} = {hex(value)}")
+   def rreg(self, addr:int) -> int: return self.mmio[addr // 4]
+
+   def _early_init(self):
+     self.reg_names:set[str] = set()
+     self.reg_offsets:dict[str, tuple[int, int]] = {}
+
+     self.include("src/common/inc/swref/published/nv_ref.h")
+     self.chip_id = self.reg("NV_PMC_BOOT_0").read()
+     self.chip_details = self.reg("NV_PMC_BOOT_42").read_bitfields()
+     self.chip_name = {0x17: "GA1", 0x19: "AD1", 0x1b: "GB2"}[self.chip_details['architecture']] + f"{self.chip_details['implementation']:02d}"
+     self.mmu_ver, self.fmc_boot = (3, True) if self.chip_details['architecture'] >= 0x1a else (2, False)
+
+     self.include("src/common/inc/swref/published/turing/tu102/dev_fb.h")
+     if self.reg("NV_PFB_PRI_MMU_WPR2_ADDR_HI").read() != 0:
+       if DEBUG >= 2: print(f"nv {self.devfmt}: WPR2 is up. Issuing a full reset.")
+       System.pci_reset(self.devfmt)
+       time.sleep(0.5)
+
+     self.include("src/common/inc/swref/published/turing/tu102/dev_vm.h")
+     self.include("src/common/inc/swref/published/ampere/ga102/dev_gc6_island.h")
+     self.include("src/common/inc/swref/published/ampere/ga102/dev_gc6_island_addendum.h")
+
+     # MMU Init
+     self.reg_names.update(mmu_pd_names:=[f'NV_MMU_VER{self.mmu_ver}_PTE', f'NV_MMU_VER{self.mmu_ver}_PDE', f'NV_MMU_VER{self.mmu_ver}_DUAL_PDE'])
+     for name in mmu_pd_names: self.__dict__[name] = NVReg(self, None, None, fields={})
+     self.include(f"kernel-open/nvidia-uvm/hwref/{'hopper/gh100' if self.mmu_ver == 3 else 'turing/tu102'}/dev_mmu.h")
+     self.pte_t, self.pde_t, self.dual_pde_t = tuple([self.__dict__[name] for name in mmu_pd_names])
+
+     self.vram_size = self.reg("NV_PGC6_AON_SECURE_SCRATCH_GROUP_42").read() << 20
+
+   def _alloc_boot_struct(self, struct:ctypes.Structure) -> tuple[ctypes.Structure, int]:
+     va, paddrs = System.alloc_sysmem(sz:=ctypes.sizeof(type(struct)), contiguous=True)
+     to_mv(va, sz)[:] = bytes(struct)
+     return type(struct).from_address(va), paddrs[0]
+
+   def _download(self, file:str) -> str:
+     url = f"https://raw.githubusercontent.com/NVIDIA/open-gpu-kernel-modules/8ec351aeb96a93a4bb69ccc12a542bf8a8df2b6f/{file}"
+     return fetch(url, subdir="defines").read_text()
+
+   def extract_fw(self, file:str, dname:str) -> bytes:
+     # Extracts the firmware binary from the given header
+     tname = file.replace("kgsp", "kgspGet")
+     text = self._download(f"src/nvidia/generated/g_bindata_{tname}_{self.chip_name}.c")
+     info, sl = text[text[:text.index(dnm:=f'{file}_{self.chip_name}_{dname}')].rindex("COMPRESSION:"):][:16], text[text.index(dnm) + len(dnm) + 7:]
+     image = bytes.fromhex(sl[:sl.find("};")].strip().replace("0x", "").replace(",", "").replace(" ", "").replace("\n", ""))
+     return gzip.decompress(struct.pack("<4BL2B", 0x1f, 0x8b, 8, 0, 0, 0, 3) + image) if "COMPRESSION: YES" in info else image
+
+   def include(self, file:str):
+     regs_off = {'NV_PFALCON_FALCON': 0x0, 'NV_PGSP_FALCON': 0x0, 'NV_PSEC_FALCON': 0x0, 'NV_PRISCV_RISCV': 0x1000, 'NV_PGC6_AON': 0x0, 'NV_PFSP': 0x0,
+       'NV_PGC6_BSI': 0x0, 'NV_PFALCON_FBIF': 0x600, 'NV_PFALCON2_FALCON': 0x1000, 'NV_PBUS': 0x0, 'NV_PFB': 0x0, 'NV_PMC': 0x0, 'NV_PGSP_QUEUE': 0x0,
+       'NV_VIRTUAL_FUNCTION':0xb80000}
+
+     for raw in self._download(file).splitlines():
+       if not raw.startswith("#define "): continue
+
+       if m:=re.match(r'#define\s+(\w+)\s+([0-9\+\-\*\(\)]+):([0-9\+\-\*\(\)]+)', raw): # bitfields
+         name, hi, lo = m.groups()
+
+         reg = next((r for r in self.reg_names if name.startswith(r+"_")), None)
+         if reg is not None: self.__dict__[reg].add_field(name[len(reg)+1:].lower(), eval(lo), eval(hi))
+         else: self.reg_offsets[name] = (eval(lo), eval(hi))
+         continue
+
+       if m:=re.match(r'#define\s+(\w+)\s*\(\s*(\w+)\s*\)\s*(.+)', raw): # reg set
+         fn = m.groups()[2].strip().rstrip('\\').split('/*')[0].rstrip()
+         name, value = m.groups()[0], eval(f"lambda {m.groups()[1]}: {fn}")
+       elif m:=re.match(r'#define\s+(\w+)\s+([0-9A-Fa-fx]+)(?![^\n]*:)', raw): name, value = m.groups()[0], int(m.groups()[1], 0) # reg value
+       else: continue
+
+       reg_pref = next((prefix for prefix in regs_off.keys() if name.startswith(prefix)), None)
+       not_already_reg = not any(name.startswith(r+"_") for r in self.reg_names)
+
+       if reg_pref is not None and not_already_reg:
+         fields = {k[len(name)+1:]: v for k, v in self.reg_offsets.items() if k.startswith(name+'_')}
+         self.__dict__[name] = NVReg(self, regs_off[reg_pref], value, fields=fields)
+         self.reg_names.add(name)
+       else: self.__dict__[name] = value
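Editor's note: for illustration, here is a minimal standalone sketch of the bitfield encode/decode scheme NVReg uses above; the field names and bit ranges below are invented for the example and are not taken from the NVIDIA headers.

    import functools

    fields = {"valid": (0, 0), "aperture": (1, 2), "address_sys": (8, 59)}  # name -> (start_bit, end_bit)

    def encode(**kwargs) -> int:
      # OR each value shifted to its field's start bit, the same reduction NVReg.encode performs
      return functools.reduce(int.__or__, (v << fields[k][0] for k, v in kwargs.items()), 0)

    def decode(val: int) -> dict:
      # pull each field back out with a shift and mask (NVReg.decode delegates this to helpers.getbits)
      return {k: (val >> s) & ((1 << (e - s + 1)) - 1) for k, (s, e) in fields.items()}

    pte = encode(valid=1, aperture=2, address_sys=0x1234)
    assert decode(pte) == {"valid": 1, "aperture": 2, "address_sys": 0x1234}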
tinygrad/runtime/support/system.py (new file)
@@ -0,0 +1,170 @@
+ import os, mmap, array, functools, ctypes, select, contextlib, dataclasses, sys
+ from typing import cast, ClassVar
+ from tinygrad.helpers import round_up, to_mv, getenv, OSX, temp
+ from tinygrad.runtime.autogen import libc, vfio
+ from tinygrad.runtime.support.hcq import FileIOInterface, MMIOInterface, HCQBuffer
+ from tinygrad.runtime.support.memory import MemoryManager, VirtMapping
+
+ MAP_FIXED, MAP_LOCKED, MAP_POPULATE, MAP_NORESERVE = 0x10, 0 if OSX else 0x2000, getattr(mmap, "MAP_POPULATE", 0 if OSX else 0x008000), 0x400
+
+ class _System:
+   def reserve_hugepages(self, cnt): os.system(f"sudo sh -c 'echo {cnt} > /proc/sys/vm/nr_hugepages'")
+
+   def memory_barrier(self): lib.atomic_thread_fence(__ATOMIC_SEQ_CST:=5) if (lib:=self.atomic_lib()) is not None else None
+
+   def lock_memory(self, addr:int, size:int):
+     if libc.mlock(ctypes.c_void_p(addr), size): raise RuntimeError(f"Failed to lock memory at {addr:#x} with size {size:#x}")
+
+   def system_paddrs(self, vaddr:int, size:int) -> list[int]:
+     self.pagemap().seek(vaddr // mmap.PAGESIZE * 8)
+     return [(x & ((1<<55) - 1)) * mmap.PAGESIZE for x in array.array('Q', self.pagemap().read(size//mmap.PAGESIZE*8, binary=True))]
+
+   def alloc_sysmem(self, size:int, vaddr:int=0, contiguous:bool=False, data:bytes|None=None) -> tuple[int, list[int]]:
+     assert not contiguous or size <= (2 << 20), "Contiguous allocation is only supported for sizes up to 2MB"
+     flags = (libc.MAP_HUGETLB if contiguous and (size:=round_up(size, mmap.PAGESIZE)) > 0x1000 else 0) | (MAP_FIXED if vaddr else 0)
+     va = FileIOInterface.anon_mmap(vaddr, size, mmap.PROT_READ|mmap.PROT_WRITE, mmap.MAP_SHARED|mmap.MAP_ANONYMOUS|MAP_POPULATE|MAP_LOCKED|flags, 0)
+
+     if data is not None: to_mv(va, len(data))[:] = data
+     return va, self.system_paddrs(va, size)
+
+   def pci_reset(self, gpu): os.system(f"sudo sh -c 'echo 1 > /sys/bus/pci/devices/{gpu}/reset'")
+   def pci_scan_bus(self, target_vendor:int, target_devices:list[int]) -> list[str]:
+     result = []
+     for pcibus in FileIOInterface("/sys/bus/pci/devices").listdir():
+       vendor = int(FileIOInterface(f"/sys/bus/pci/devices/{pcibus}/vendor").read(), 16)
+       device = int(FileIOInterface(f"/sys/bus/pci/devices/{pcibus}/device").read(), 16)
+       if vendor == target_vendor and device in target_devices: result.append(pcibus)
+     return sorted(result)
+
+   @functools.cache
+   def atomic_lib(self): return ctypes.CDLL(ctypes.util.find_library('atomic')) if sys.platform == "linux" else None
+
+   @functools.cache
+   def pagemap(self) -> FileIOInterface:
+     if FileIOInterface(reloc_sysfs:="/proc/sys/vm/compact_unevictable_allowed", os.O_RDONLY).read()[0] != "0":
+       os.system(cmd:=f"sudo sh -c 'echo 0 > {reloc_sysfs}'")
+       assert FileIOInterface(reloc_sysfs, os.O_RDONLY).read()[0] == "0", f"Failed to disable migration of locked pages. Please run {cmd} manually."
+     return FileIOInterface("/proc/self/pagemap", os.O_RDONLY)
+
+   @functools.cache
+   def vfio(self) -> FileIOInterface|None:
+     try:
+       if not FileIOInterface.exists("/sys/module/vfio"): os.system("sudo modprobe vfio-pci disable_idle_d3=1")
+
+       FileIOInterface("/sys/module/vfio/parameters/enable_unsafe_noiommu_mode", os.O_RDWR).write("1")
+       vfio_fd = FileIOInterface("/dev/vfio/vfio", os.O_RDWR)
+       vfio.VFIO_CHECK_EXTENSION(vfio_fd, vfio.VFIO_NOIOMMU_IOMMU)
+
+       return vfio_fd
+     except OSError: return None
+
+   def flock_acquire(self, name:str) -> int:
+     import fcntl # to support windows
+
+     os.umask(0) # Set umask to 0 to allow creating files with 0666 permissions
+
+     # Avoid O_CREAT because we don't want to re-create/replace an existing file (triggers extra perms checks) when opening as non-owner.
+     if os.path.exists(lock_name:=temp(name)): self.lock_fd = os.open(lock_name, os.O_RDWR)
+     else: self.lock_fd = os.open(lock_name, os.O_RDWR | os.O_CREAT | os.O_CLOEXEC, 0o666)
+
+     try: fcntl.flock(self.lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
+     except OSError: raise RuntimeError(f"Failed to take lock file {name}. It's already in use.")
+
+     return self.lock_fd
+
+ System = _System()
+
+ class PCIDevice:
+   def __init__(self, pcibus:str, bars:list[int], resize_bars:list[int]|None=None):
+     self.pcibus, self.irq_poller = pcibus, None
+
+     if FileIOInterface.exists(f"/sys/bus/pci/devices/{self.pcibus}/driver"):
+       FileIOInterface(f"/sys/bus/pci/devices/{self.pcibus}/driver/unbind", os.O_WRONLY).write(self.pcibus)
+
+     for i in resize_bars or []:
+       if FileIOInterface.exists(rpath:=f"/sys/bus/pci/devices/{self.pcibus}/resource{i}_resize"):
+         try: FileIOInterface(rpath, os.O_RDWR).write(str(int(FileIOInterface(rpath, os.O_RDONLY).read(), 16).bit_length() - 1))
+         except OSError as e: raise RuntimeError(f"Cannot resize BAR {i}: {e}. Ensure the resizable BAR option is enabled on your system.") from e
+
+     if getenv("VFIO", 0) and (vfio_fd:=System.vfio()) is not None:
+       FileIOInterface(f"/sys/bus/pci/devices/{self.pcibus}/driver_override", os.O_WRONLY).write("vfio-pci")
+       FileIOInterface("/sys/bus/pci/drivers_probe", os.O_WRONLY).write(self.pcibus)
+       iommu_group = FileIOInterface.readlink(f"/sys/bus/pci/devices/{self.pcibus}/iommu_group").split('/')[-1]
+
+       self.vfio_group = FileIOInterface(f"/dev/vfio/noiommu-{iommu_group}", os.O_RDWR)
+       vfio.VFIO_GROUP_SET_CONTAINER(self.vfio_group, ctypes.c_int(vfio_fd.fd))
+
+       with contextlib.suppress(OSError): vfio.VFIO_SET_IOMMU(vfio_fd, vfio.VFIO_NOIOMMU_IOMMU) # set iommu works only once for the fd.
+       self.vfio_dev = FileIOInterface(fd=vfio.VFIO_GROUP_GET_DEVICE_FD(self.vfio_group, ctypes.create_string_buffer(self.pcibus.encode())))
+
+       self.irq_fd = FileIOInterface.eventfd(0, 0)
+       self.irq_poller = select.poll()
+       self.irq_poller.register(self.irq_fd.fd, select.POLLIN)
+
+       irqs = vfio.struct_vfio_irq_set(index=vfio.VFIO_PCI_MSI_IRQ_INDEX, flags=vfio.VFIO_IRQ_SET_DATA_EVENTFD|vfio.VFIO_IRQ_SET_ACTION_TRIGGER,
+         argsz=ctypes.sizeof(vfio.struct_vfio_irq_set), count=1, data=(ctypes.c_int * 1)(self.irq_fd.fd))
+       vfio.VFIO_DEVICE_SET_IRQS(self.vfio_dev, irqs)
+     else: FileIOInterface(f"/sys/bus/pci/devices/{self.pcibus}/enable", os.O_RDWR).write("1")
+
+     self.cfg_fd = FileIOInterface(f"/sys/bus/pci/devices/{self.pcibus}/config", os.O_RDWR | os.O_SYNC | os.O_CLOEXEC)
+     self.bar_fds = {b: FileIOInterface(f"/sys/bus/pci/devices/{self.pcibus}/resource{b}", os.O_RDWR | os.O_SYNC | os.O_CLOEXEC) for b in bars}
+
+     bar_info = FileIOInterface(f"/sys/bus/pci/devices/{self.pcibus}/resource", os.O_RDONLY).read().splitlines()
+     self.bar_info = {j:(int(start,16), int(end,16), int(flgs,16)) for j,(start,end,flgs) in enumerate(l.split() for l in bar_info)}
+
+   def read_config(self, offset:int, size:int): return int.from_bytes(self.cfg_fd.read(size, binary=True, offset=offset), byteorder='little')
+   def write_config(self, offset:int, value:int, size:int): self.cfg_fd.write(value.to_bytes(size, byteorder='little'), binary=True, offset=offset)
+   def map_bar(self, bar:int, off:int=0, addr:int=0, size:int|None=None, fmt='B') -> MMIOInterface:
+     fd, sz = self.bar_fds[bar], size or (self.bar_info[bar][1] - self.bar_info[bar][0] + 1)
+     libc.madvise(loc:=fd.mmap(addr, sz, mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED | (MAP_FIXED if addr else 0), off), sz, libc.MADV_DONTFORK)
+     return MMIOInterface(loc, sz, fmt=fmt)
+
+ class PCIDevImplBase:
+   mm: MemoryManager
+
+ @dataclasses.dataclass
+ class PCIAllocationMeta: mapping:VirtMapping; has_cpu_mapping:bool; hMemory:int=0 # noqa: E702
+
+ class PCIIfaceBase:
+   dev_impl:PCIDevImplBase
+   gpus:ClassVar[list[str]] = []
+
+   def __init__(self, dev, dev_id, vendor, devices, bars, vram_bar, va_start, va_size):
+     if len((cls:=type(self)).gpus) == 0:
+       cls.gpus = System.pci_scan_bus(vendor, devices)
+       visible_devices = [int(x) for x in (getenv('VISIBLE_DEVICES', '')).split(',') if x.strip()]
+       cls.gpus = [cls.gpus[x] for x in visible_devices] if visible_devices else cls.gpus
+
+     # Acquire va range to avoid collisions.
+     FileIOInterface.anon_mmap(va_start, va_size, 0, mmap.MAP_PRIVATE | mmap.MAP_ANONYMOUS | MAP_NORESERVE | MAP_FIXED, 0)
+     self.pci_dev, self.dev, self.vram_bar = PCIDevice(cls.gpus[dev_id], bars=bars, resize_bars=[vram_bar]), dev, vram_bar
+     self.p2p_base_addr = self.pci_dev.bar_info[vram_bar][0]
+
+   def alloc(self, size:int, host=False, uncached=False, cpu_access=False, contiguous=False, **kwargs) -> HCQBuffer:
+     if host or (uncached and cpu_access): # host or gtt-like memory.
+       vaddr = self.dev_impl.mm.alloc_vaddr(size:=round_up(size, mmap.PAGESIZE), align=mmap.PAGESIZE)
+       paddrs = [(paddr, mmap.PAGESIZE) for paddr in System.alloc_sysmem(size, vaddr=vaddr, contiguous=contiguous)[1]]
+       mapping = self.dev_impl.mm.map_range(vaddr, size, paddrs, system=True, snooped=True, uncached=True)
+       return HCQBuffer(vaddr, size, meta=PCIAllocationMeta(mapping, has_cpu_mapping=True, hMemory=paddrs[0][0]),
+         view=MMIOInterface(mapping.va_addr, size, fmt='B'), owner=self.dev)
+
+     mapping = self.dev_impl.mm.valloc(size:=round_up(size, 4 << 10), uncached=uncached, contiguous=cpu_access)
+     if cpu_access: self.pci_dev.map_bar(bar=self.vram_bar, off=mapping.paddrs[0][0], addr=mapping.va_addr, size=mapping.size)
+     return HCQBuffer(mapping.va_addr, size, view=MMIOInterface(mapping.va_addr, size, fmt='B') if cpu_access else None,
+       meta=PCIAllocationMeta(mapping, has_cpu_mapping=cpu_access, hMemory=mapping.paddrs[0][0]), owner=self.dev)
+
+   def free(self, b:HCQBuffer):
+     for dev in b.mapped_devs[1:]: dev.iface.dev_impl.mm.unmap_range(b.va_addr, b.size)
+     if not b.meta.mapping.system: self.dev_impl.mm.vfree(b.meta.mapping)
+     if b.owner == self.dev and b.meta.has_cpu_mapping: FileIOInterface.munmap(b.va_addr, b.size)
+
+   def map(self, b:HCQBuffer):
+     if b.owner is not None and b.owner._is_cpu():
+       System.lock_memory(cast(int, b.va_addr), b.size)
+       paddrs, snooped, uncached = [(x, 0x1000) for x in System.system_paddrs(cast(int, b.va_addr), round_up(b.size, 0x1000))], True, False
+     elif (ifa:=getattr(b.owner, "iface", None)) is not None and isinstance(ifa, PCIIfaceBase):
+       paddrs = [(paddr if b.meta.mapping.system else (paddr + ifa.p2p_base_addr), size) for paddr,size in b.meta.mapping.paddrs]
+       snooped, uncached = b.meta.mapping.snooped, b.meta.mapping.uncached
+     else: raise RuntimeError(f"map failed: {b.owner} -> {self.dev}")
+
+     self.dev_impl.mm.map_range(cast(int, b.va_addr), round_up(b.size, 0x1000), paddrs, system=True, snooped=snooped, uncached=uncached)
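Editor's note on system_paddrs above: /proc/self/pagemap holds one 8-byte record per virtual page, with the page frame number in bits 0-54, which is exactly the mask the method applies. A minimal sketch of that lookup for a single address (Linux only; on recent kernels the PFN reads as 0 without CAP_SYS_ADMIN):

    import mmap, struct

    def paddr_of(vaddr: int) -> int:
      with open("/proc/self/pagemap", "rb") as pagemap:
        pagemap.seek(vaddr // mmap.PAGESIZE * 8)             # one 8-byte entry per virtual page
        entry = struct.unpack("<Q", pagemap.read(8))[0]
        pfn = entry & ((1 << 55) - 1)                        # bits 0-54: page frame number
        return pfn * mmap.PAGESIZE + vaddr % mmap.PAGESIZE   # physical page base + offset within page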
tinygrad/runtime/support/usb.py (new file)
@@ -0,0 +1,268 @@
+ import ctypes, struct, dataclasses, array, itertools
+ from typing import Sequence
+ from tinygrad.runtime.autogen import libusb
+ from tinygrad.helpers import DEBUG, to_mv, round_up, OSX
+ from tinygrad.runtime.support.hcq import MMIOInterface
+
+ class USB3:
+   def __init__(self, vendor:int, dev:int, ep_data_in:int, ep_stat_in:int, ep_data_out:int, ep_cmd_out:int, max_streams:int=31):
+     self.vendor, self.dev = vendor, dev
+     self.ep_data_in, self.ep_stat_in, self.ep_data_out, self.ep_cmd_out = ep_data_in, ep_stat_in, ep_data_out, ep_cmd_out
+     self.max_streams = max_streams
+     self.ctx = ctypes.POINTER(libusb.struct_libusb_context)()
+
+     if libusb.libusb_init(ctypes.byref(self.ctx)): raise RuntimeError("libusb_init failed")
+     if DEBUG >= 6: libusb.libusb_set_option(self.ctx, libusb.LIBUSB_OPTION_LOG_LEVEL, 4)
+
+     self.handle = libusb.libusb_open_device_with_vid_pid(self.ctx, self.vendor, self.dev)
+     if not self.handle: raise RuntimeError(f"device {self.vendor:04x}:{self.dev:04x} not found. sudo required?")
+
+     # Detach kernel driver if needed
+     if libusb.libusb_kernel_driver_active(self.handle, 0):
+       libusb.libusb_detach_kernel_driver(self.handle, 0)
+       libusb.libusb_reset_device(self.handle)
+
+     # Set configuration and claim interface
+     if libusb.libusb_set_configuration(self.handle, 1): raise RuntimeError("set_configuration failed")
+     if libusb.libusb_claim_interface(self.handle, 0): raise RuntimeError("claim_interface failed. sudo required?")
+     if libusb.libusb_set_interface_alt_setting(self.handle, 0, 1): raise RuntimeError("alt_setting failed")
+
+     # Clear any stalled endpoints
+     all_eps = (self.ep_data_out, self.ep_data_in, self.ep_stat_in, self.ep_cmd_out)
+     for ep in all_eps: libusb.libusb_clear_halt(self.handle, ep)
+
+     # Allocate streams
+     stream_eps = (ctypes.c_uint8 * 3)(self.ep_data_out, self.ep_data_in, self.ep_stat_in)
+     if (rc:=libusb.libusb_alloc_streams(self.handle, self.max_streams * len(stream_eps), stream_eps, len(stream_eps))) < 0:
+       raise RuntimeError(f"alloc_streams failed: {rc}")
+
+     # Base cmd
+     cmd_template = bytes([0x01, 0x00, 0x00, 0x01, *([0] * 12), 0xE4, 0x24, 0x00, 0xB2, 0x1A, 0x00, 0x00, 0x00, *([0] * 8)])
+
+     # Init pools
+     self.tr = {ep: [libusb.libusb_alloc_transfer(0) for _ in range(self.max_streams)] for ep in all_eps}
+
+     self.buf_cmd = [(ctypes.c_uint8 * len(cmd_template))(*cmd_template) for _ in range(self.max_streams)]
+     self.buf_stat = [(ctypes.c_uint8 * 64)() for _ in range(self.max_streams)]
+     self.buf_data_in = [(ctypes.c_uint8 * 0x1000)() for _ in range(self.max_streams)]
+     self.buf_data_out = [(ctypes.c_uint8 * 0x80000)() for _ in range(self.max_streams)]
+     self.buf_data_out_mvs = [to_mv(ctypes.addressof(self.buf_data_out[i]), 0x80000) for i in range(self.max_streams)]
+
+     for slot in range(self.max_streams): struct.pack_into(">B", self.buf_cmd[slot], 3, slot + 1)
+
+   def _prep_transfer(self, tr, ep, stream_id, buf, length):
+     tr.contents.dev_handle, tr.contents.endpoint, tr.contents.length, tr.contents.buffer = self.handle, ep, length, buf
+     tr.contents.status, tr.contents.flags, tr.contents.timeout, tr.contents.num_iso_packets = 0xff, 0, 1000, 0
+     tr.contents.type = (libusb.LIBUSB_TRANSFER_TYPE_BULK_STREAM if stream_id is not None else libusb.LIBUSB_TRANSFER_TYPE_BULK)
+     if stream_id is not None: libusb.libusb_transfer_set_stream_id(tr, stream_id)
+     return tr
+
+   def _submit_and_wait(self, cmds):
+     for tr in cmds: libusb.libusb_submit_transfer(tr)
+
+     running = len(cmds)
+     while running:
+       libusb.libusb_handle_events(self.ctx)
+       running = len(cmds)
+       for tr in cmds:
+         if tr.contents.status == libusb.LIBUSB_TRANSFER_COMPLETED: running -= 1
+         elif tr.contents.status != 0xFF: raise RuntimeError(f"EP 0x{tr.contents.endpoint:02X} error: {tr.contents.status}")
+
+   def send_batch(self, cdbs:list[bytes], idata:list[int]|None=None, odata:list[bytes|None]|None=None) -> list[bytes|None]:
+     idata, odata = idata or [0] * len(cdbs), odata or [None] * len(cdbs)
+     results, tr_window, op_window = [], [], []
+
+     for idx, (cdb, rlen, send_data) in enumerate(zip(cdbs, idata, odata)):
+       # allocate slot and stream. stream is 1-based
+       slot, stream = idx % self.max_streams, (idx % self.max_streams) + 1
+
+       # build cmd packet
+       self.buf_cmd[slot][16:16+len(cdb)] = list(cdb)
+
+       # cmd + stat transfers
+       tr_window.append(self._prep_transfer(self.tr[self.ep_cmd_out][slot], self.ep_cmd_out, None, self.buf_cmd[slot], len(self.buf_cmd[slot])))
+       tr_window.append(self._prep_transfer(self.tr[self.ep_stat_in][slot], self.ep_stat_in, stream, self.buf_stat[slot], 64))
+
+       if rlen:
+         if rlen > len(self.buf_data_in[slot]): self.buf_data_in[slot] = (ctypes.c_uint8 * round_up(rlen, 0x1000))()
+         tr_window.append(self._prep_transfer(self.tr[self.ep_data_in][slot], self.ep_data_in, stream, self.buf_data_in[slot], rlen))
+
+       if send_data is not None:
+         if len(send_data) > len(self.buf_data_out[slot]):
+           self.buf_data_out[slot] = (ctypes.c_uint8 * len(send_data))()
+           self.buf_data_out_mvs[slot] = to_mv(ctypes.addressof(self.buf_data_out[slot]), len(send_data))
+
+         self.buf_data_out_mvs[slot][:len(send_data)] = bytes(send_data)
+         tr_window.append(self._prep_transfer(self.tr[self.ep_data_out][slot], self.ep_data_out, stream, self.buf_data_out[slot], len(send_data)))
+
+       op_window.append((idx, slot, rlen))
+       if (idx + 1 == len(cdbs)) or len(op_window) >= self.max_streams:
+         self._submit_and_wait(tr_window)
+         for idx, slot, rlen in op_window: results.append(bytes(self.buf_data_in[slot][:rlen]) if rlen else None)
+         tr_window = []
+
+     return results
+
+ @dataclasses.dataclass(frozen=True)
+ class WriteOp: addr:int; data:bytes; ignore_cache:bool=True # noqa: E702
+
+ @dataclasses.dataclass(frozen=True)
+ class ReadOp: addr:int; size:int # noqa: E702
+
+ @dataclasses.dataclass(frozen=True)
+ class ScsiWriteOp: data:bytes; lba:int=0 # noqa: E702
+
+ class ASM24Controller:
+   def __init__(self):
+     self.usb = USB3(0xADD1, 0x0001, 0x81, 0x83, 0x02, 0x04)
+     self._cache: dict[int, int|None] = {}
+     self._pci_cacheable: list[tuple[int, int]] = []
+     self._pci_cache: dict[int, int|None] = {}
+
+     # Init controller.
+     self.exec_ops([WriteOp(0x54b, b' '), WriteOp(0x54e, b'\x04'), WriteOp(0x5a8, b'\x02'), WriteOp(0x5f8, b'\x04'),
+       WriteOp(0x7ec, b'\x01\x00\x00\x00'), WriteOp(0xc422, b'\x02'), WriteOp(0x0, b'\x33')])
+
+   def exec_ops(self, ops:Sequence[WriteOp|ReadOp|ScsiWriteOp]):
+     cdbs:list[bytes] = []
+     idata:list[int] = []
+     odata:list[bytes|None] = []
+
+     def _add_req(cdb:bytes, i:int, o:bytes|None):
+       nonlocal cdbs, idata, odata
+       cdbs, idata, odata = cdbs + [cdb], idata + [i], odata + [o]
+
+     for op in ops:
+       if isinstance(op, WriteOp):
+         for off, value in enumerate(op.data):
+           addr = ((op.addr + off) & 0x1FFFF) | 0x500000
+           if not op.ignore_cache and self._cache.get(addr) == value: continue
+           _add_req(struct.pack('>BBBHB', 0xE5, value, addr >> 16, addr & 0xFFFF, 0), 0, None)
+           self._cache[addr] = value
+       elif isinstance(op, ReadOp):
+         assert op.size <= 0xff
+         addr = (op.addr & 0x1FFFF) | 0x500000
+         _add_req(struct.pack('>BBBHB', 0xE4, op.size, addr >> 16, addr & 0xFFFF, 0), op.size, None)
+         for i in range(op.size): self._cache[addr + i] = None
+       elif isinstance(op, ScsiWriteOp):
+         sectors = round_up(len(op.data), 512) // 512
+         _add_req(struct.pack('>BBQIBB', 0x8A, 0, op.lba, sectors, 0, 0), 0, op.data+b'\x00'*((sectors*512)-len(op.data)))
+
+     return self.usb.send_batch(cdbs, idata, odata)
+
+   def write(self, base_addr:int, data:bytes, ignore_cache:bool=True): return self.exec_ops([WriteOp(base_addr, data, ignore_cache)])
+
+   def scsi_write(self, buf:bytes, lba:int=0):
+     if len(buf) > 0x4000: buf += b'\x00' * (round_up(len(buf), 0x10000) - len(buf))
+
+     for i in range(0, len(buf), 0x10000):
+       self.exec_ops([ScsiWriteOp(buf[i:i+0x10000], lba), WriteOp(0x171, b'\xff\xff\xff', ignore_cache=True)])
+       self.exec_ops([WriteOp(0xce6e, b'\x00\x00', ignore_cache=True)])
+
+     if len(buf) > 0x4000:
+       for i in range(4): self.exec_ops([WriteOp(0xce40 + i, b'\x00', ignore_cache=True)])
+
+   def read(self, base_addr:int, length:int, stride:int=0xff) -> bytes:
+     parts = self.exec_ops([ReadOp(base_addr + off, min(stride, length - off)) for off in range(0, length, stride)])
+     return b''.join(p or b'' for p in parts)[:length]
+
+   def _is_pci_cacheable(self, addr:int) -> bool: return any(x <= addr <= x + sz for x, sz in self._pci_cacheable)
+   def pcie_prep_request(self, fmt_type:int, address:int, value:int|None=None, size:int=4) -> list[WriteOp]:
+     if fmt_type == 0x60 and size == 4 and self._is_pci_cacheable(address) and self._pci_cache.get(address) == value: return []
+
+     assert fmt_type >> 8 == 0 and size > 0 and size <= 4, f"Invalid fmt_type {fmt_type} or size {size}"
+     if DEBUG >= 5: print("pcie_request", hex(fmt_type), hex(address), value, size)
+
+     masked_address, offset = address & 0xFFFFFFFC, address & 0x3
+     assert size + offset <= 4 and (value is None or value >> (8 * size) == 0)
+     self._pci_cache[address] = value if size == 4 and fmt_type == 0x60 else None
+
+     return ([WriteOp(0xB220, struct.pack('>I', value << (8 * offset)), ignore_cache=False)] if value is not None else []) + \
+       [WriteOp(0xB218, struct.pack('>I', masked_address), ignore_cache=False), WriteOp(0xB21c, struct.pack('>I', address>>32), ignore_cache=False),
+        WriteOp(0xB217, bytes([((1 << size) - 1) << offset]), ignore_cache=False), WriteOp(0xB210, bytes([fmt_type]), ignore_cache=False),
+        WriteOp(0xB254, b"\x0f", ignore_cache=True), WriteOp(0xB296, b"\x04", ignore_cache=True)]
+
+   def pcie_request(self, fmt_type, address, value=None, size=4, cnt=10):
+     self.exec_ops(self.pcie_prep_request(fmt_type, address, value, size))
+
+     # Fast path for write requests
+     if ((fmt_type & 0b11011111) == 0b01000000) or ((fmt_type & 0b10111000) == 0b00110000): return
+
+     while (stat:=self.read(0xB296, 1)[0]) & 2 == 0:
+       if stat & 1:
+         self.write(0xB296, bytes([0x01]))
+         if cnt > 0: return self.pcie_request(fmt_type, address, value, size, cnt=cnt-1)
+         assert stat == 2, f"stat read 2 was {stat}"
+
+     # Retrieve completion data from Link Status (0xB22A, 0xB22B)
+     b284 = self.read(0xB284, 1)[0]
+     completion = struct.unpack('>H', self.read(0xB22A, 2))
+
+     # Validate completion status based on PCIe request type
+     # Completion TLPs for configuration requests always have a byte count of 4.
+     assert completion[0] & 0xfff == (4 if (fmt_type & 0xbe == 0x04) else size)
+
+     # Extract completion status field
+     status = (completion[0] >> 13) & 0x7
+
+     # Handle completion errors or inconsistencies
+     if status or ((fmt_type & 0xbe == 0x04) and (((value is None) and (not (b284 & 0x01))) or ((value is not None) and (b284 & 0x01)))):
+       status_map = {0b001: f"Unsupported Request: invalid address/function (target might not be reachable): {address:#x}",
+         0b100: "Completer Abort: abort due to internal error", 0b010: "Configuration Request Retry Status: configuration space busy"}
+       raise RuntimeError(f"TLP status: {status_map.get(status, 'Reserved (0b{:03b})'.format(status))}")
+
+     if value is None: return (struct.unpack('>I', self.read(0xB220, 4))[0] >> (8 * (address & 0x3))) & ((1 << (8 * size)) - 1)
+
+   def pcie_cfg_req(self, byte_addr, bus=1, dev=0, fn=0, value=None, size=4):
+     assert byte_addr >> 12 == 0 and bus >> 8 == 0 and dev >> 5 == 0 and fn >> 3 == 0, f"Invalid byte_addr {byte_addr}, bus {bus}, dev {dev}, fn {fn}"
+
+     fmt_type = (0x44 if value is not None else 0x4) | int(bus > 0)
+     address = (bus << 24) | (dev << 19) | (fn << 16) | (byte_addr & 0xfff)
+     return self.pcie_request(fmt_type, address, value, size)
+
+   def pcie_mem_req(self, address, value=None, size=4): return self.pcie_request(0x60 if value is not None else 0x20, address, value, size)
+
+   def pcie_mem_write(self, address, values, size):
+     ops = [self.pcie_prep_request(0x60, address + i * size, value, size) for i, value in enumerate(values)]
+
+     # Send in batches of 4 for OSX and 16 for Linux (benchmarked values)
+     for i in range(0, len(ops), bs:=(4 if OSX else 16)): self.exec_ops(list(itertools.chain.from_iterable(ops[i:i+bs])))
+
+ class USBMMIOInterface(MMIOInterface):
+   def __init__(self, usb, addr, size, fmt, pcimem=True):
+     self.usb, self.addr, self.nbytes, self.fmt, self.pcimem, self.el_sz = usb, addr, size, fmt, pcimem, struct.calcsize(fmt)
+
+   def __getitem__(self, index): return self._access_items(index)
+   def __setitem__(self, index, val): self._access_items(index, val)
+
+   def _access_items(self, index, val=None):
+     if isinstance(index, slice): return self._acc((index.start or 0) * self.el_sz, ((index.stop or len(self))-(index.start or 0)) * self.el_sz, val)
+     return self._acc_one(index * self.el_sz, self.el_sz, val) if self.pcimem else self._acc(index * self.el_sz, self.el_sz, val)
+
+   def view(self, offset:int=0, size:int|None=None, fmt=None):
+     return USBMMIOInterface(self.usb, self.addr+offset, size or (self.nbytes - offset), fmt=fmt or self.fmt, pcimem=self.pcimem)
+
+   def _acc_size(self, sz): return next(x for x in [('I', 4), ('H', 2), ('B', 1)] if sz % x[1] == 0)
+
+   def _acc_one(self, off, sz, val=None):
+     upper = 0 if sz < 8 else self.usb.pcie_mem_req(self.addr + off + 4, val if val is None else (val >> 32), 4)
+     lower = self.usb.pcie_mem_req(self.addr + off, val if val is None else val & 0xffffffff, min(sz, 4))
+     if val is None: return lower | (upper << 32)
+
+   def _acc(self, off, sz, data=None):
+     if data is None: # read op
+       if not self.pcimem:
+         return int.from_bytes(self.usb.read(self.addr + off, sz), "little") if sz == self.el_sz else self.usb.read(self.addr + off, sz)
+
+       acc, acc_size = self._acc_size(sz)
+       return bytes(array.array(acc, [self._acc_one(off + i * acc_size, acc_size) for i in range(sz // acc_size)]))
+     else: # write op
+       data = struct.pack(self.fmt, data) if isinstance(data, int) else bytes(data)
+
+       if not self.pcimem:
+         # Fast path for writing into buffer 0xf000
+         use_cache = 0xa800 <= self.addr <= 0xb000
+         return self.usb.scsi_write(bytes(data)) if self.addr == 0xf000 else self.usb.write(self.addr + off, bytes(data), ignore_cache=not use_cache)
+
+       _, acc_sz = self._acc_size(len(data) * struct.calcsize(self.fmt))
+       self.usb.pcie_mem_write(self.addr+off, [int.from_bytes(data[i:i+acc_sz], "little") for i in range(0, len(data), acc_sz)], acc_sz)
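Editor's note: for illustration, this is how the single-byte register-write command built in exec_ops above is laid out: opcode 0xE5, the value, then the windowed register address split into a high byte and a big-endian half-word. The register offset and the value 0x20 below are arbitrary example inputs.

    import struct

    addr = (0x054B & 0x1FFFF) | 0x500000                          # window the register offset, as exec_ops does
    cdb = struct.pack('>BBBHB', 0xE5, 0x20, addr >> 16, addr & 0xFFFF, 0)
    assert cdb.hex() == 'e52050054b00'                            # opcode, value, addr[23:16], addr[15:0], pad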
tinygrad/runtime/support/webgpu.py (new file)
@@ -0,0 +1,18 @@
+ import ctypes, ctypes.util, os, subprocess, platform, sysconfig
+ from tinygrad.helpers import OSX
+
+ WEBGPU_PATH: str | None
+
+ if OSX:
+   if not os.path.exists(brew_prefix:=subprocess.check_output(['brew', '--prefix', 'dawn']).decode().strip()):
+     raise FileNotFoundError('dawn library not found. Install it with `brew tap wpmed92/dawn && brew install dawn`')
+   WEBGPU_PATH = os.path.join(brew_prefix, 'lib', 'libwebgpu_dawn.dylib')
+ elif platform.system() == "Windows":
+   if not os.path.exists(pydawn_path:=os.path.join(sysconfig.get_paths()["purelib"], "pydawn")):
+     raise FileNotFoundError("dawn library not found. Install it with `pip install dawn-python`")
+   WEBGPU_PATH = os.path.join(pydawn_path, "lib", "libwebgpu_dawn.dll")
+ else:
+   if (WEBGPU_PATH:=ctypes.util.find_library('webgpu_dawn')) is None:
+     raise FileNotFoundError("dawn library not found. " +
+       "Install it with `sudo curl -L https://github.com/wpmed92/pydawn/releases/download/v0.3.0/" +
+       f"libwebgpu_dawn_{platform.machine()}.so -o /usr/lib/libwebgpu_dawn.so`")