tinygrad 0.10.1__py3-none-any.whl → 0.10.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/devectorizer.py +247 -0
- tinygrad/codegen/expander.py +121 -0
- tinygrad/codegen/kernel.py +35 -37
- tinygrad/codegen/linearize.py +19 -10
- tinygrad/codegen/lowerer.py +31 -8
- tinygrad/codegen/symbolic.py +476 -0
- tinygrad/codegen/transcendental.py +10 -0
- tinygrad/device.py +28 -11
- tinygrad/dtype.py +12 -3
- tinygrad/engine/jit.py +3 -2
- tinygrad/engine/multi.py +0 -1
- tinygrad/engine/realize.py +7 -4
- tinygrad/engine/schedule.py +227 -255
- tinygrad/engine/search.py +20 -27
- tinygrad/gradient.py +3 -0
- tinygrad/helpers.py +7 -4
- tinygrad/nn/state.py +2 -2
- tinygrad/ops.py +64 -329
- tinygrad/renderer/__init__.py +19 -3
- tinygrad/renderer/cstyle.py +39 -18
- tinygrad/renderer/llvmir.py +55 -18
- tinygrad/renderer/ptx.py +6 -2
- tinygrad/renderer/wgsl.py +20 -12
- tinygrad/runtime/autogen/libc.py +404 -71
- tinygrad/runtime/autogen/{libpciaccess.py → pci.py} +25 -715
- tinygrad/runtime/autogen/webgpu.py +6985 -0
- tinygrad/runtime/graph/metal.py +28 -29
- tinygrad/runtime/ops_amd.py +37 -34
- tinygrad/runtime/{ops_clang.py → ops_cpu.py} +4 -2
- tinygrad/runtime/ops_disk.py +1 -1
- tinygrad/runtime/ops_dsp.py +59 -33
- tinygrad/runtime/ops_llvm.py +14 -12
- tinygrad/runtime/ops_metal.py +78 -62
- tinygrad/runtime/ops_nv.py +9 -6
- tinygrad/runtime/ops_python.py +5 -5
- tinygrad/runtime/ops_webgpu.py +200 -38
- tinygrad/runtime/support/am/amdev.py +23 -11
- tinygrad/runtime/support/am/ip.py +10 -10
- tinygrad/runtime/support/elf.py +2 -0
- tinygrad/runtime/support/hcq.py +7 -5
- tinygrad/runtime/support/llvm.py +8 -14
- tinygrad/shape/shapetracker.py +3 -2
- tinygrad/shape/view.py +2 -3
- tinygrad/spec.py +21 -20
- tinygrad/tensor.py +150 -90
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
- tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
- tinygrad/viz/index.html +544 -0
- tinygrad/viz/perfetto.html +178 -0
- tinygrad/viz/serve.py +205 -0
- {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/METADATA +20 -8
- tinygrad-0.10.2.dist-info/RECORD +99 -0
- tinygrad/codegen/rewriter.py +0 -516
- tinygrad-0.10.1.dist-info/RECORD +0 -86
- {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +0 -0
- {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,7 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
import ctypes, collections, time, dataclasses, pathlib, fcntl, os
|
3
3
|
from tinygrad.helpers import to_mv, mv_address, getenv, round_up, DEBUG, temp
|
4
|
-
from tinygrad.runtime.autogen.am import am, mp_11_0
|
4
|
+
from tinygrad.runtime.autogen.am import am, mp_11_0
|
5
5
|
from tinygrad.runtime.support.allocator import TLSFAllocator
|
6
6
|
from tinygrad.runtime.support.am.ip import AM_SOC21, AM_GMC, AM_IH, AM_PSP, AM_SMU, AM_GFX, AM_SDMA
|
7
7
|
|
@@ -32,11 +32,13 @@ class AMRegister:
|
|
32
32
|
def read(self, **kwargs): return self.adev.rreg(self.reg_off) & self._parse_kwargs(**kwargs)[0]
|
33
33
|
|
34
34
|
class AMFirmware:
|
35
|
-
def __init__(self):
|
35
|
+
def __init__(self, adev):
|
36
|
+
def fmt_ver(hwip): return f"{adev.ip_versions[hwip]//10000}_{(adev.ip_versions[hwip]//100)%100}_{adev.ip_versions[hwip]%100}"
|
37
|
+
|
36
38
|
# Load SOS firmware
|
37
39
|
self.sos_fw = {}
|
38
40
|
|
39
|
-
blob, sos_hdr = self.load_fw("
|
41
|
+
blob, sos_hdr = self.load_fw(f"psp_{fmt_ver(am.MP0_HWIP)}_sos.bin", am.struct_psp_firmware_header_v2_0)
|
40
42
|
fw_bin = sos_hdr.psp_fw_bin
|
41
43
|
|
42
44
|
for fw_i in range(sos_hdr.psp_fw_bin_count):
|
@@ -48,17 +50,17 @@ class AMFirmware:
|
|
48
50
|
self.ucode_start: dict[str, int] = {}
|
49
51
|
self.descs: list[tuple[int, memoryview]] = []
|
50
52
|
|
51
|
-
blob, hdr = self.load_fw("
|
53
|
+
blob, hdr = self.load_fw(f"smu_{fmt_ver(am.MP1_HWIP)}.bin", am.struct_smc_firmware_header_v1_0)
|
52
54
|
self.smu_psp_desc = self.desc(am.GFX_FW_TYPE_SMU, blob, hdr.header.ucode_array_offset_bytes, hdr.header.ucode_size_bytes)
|
53
55
|
|
54
56
|
# SDMA firmware
|
55
|
-
blob, hdr = self.load_fw("
|
57
|
+
blob, hdr = self.load_fw(f"sdma_{fmt_ver(am.SDMA0_HWIP)}.bin", am.struct_sdma_firmware_header_v2_0)
|
56
58
|
self.descs += [self.desc(am.GFX_FW_TYPE_SDMA_UCODE_TH0, blob, hdr.header.ucode_array_offset_bytes, hdr.ctx_ucode_size_bytes)]
|
57
59
|
self.descs += [self.desc(am.GFX_FW_TYPE_SDMA_UCODE_TH1, blob, hdr.ctl_ucode_offset, hdr.ctl_ucode_size_bytes)]
|
58
60
|
|
59
61
|
# PFP, ME, MEC firmware
|
60
62
|
for (fw_name, fw_cnt) in [('PFP', 2), ('ME', 2), ('MEC', 4)]:
|
61
|
-
blob, hdr = self.load_fw(f"
|
63
|
+
blob, hdr = self.load_fw(f"gc_{fmt_ver(am.GC_HWIP)}_{fw_name.lower()}.bin", am.struct_gfx_firmware_header_v2_0)
|
62
64
|
|
63
65
|
# Code part
|
64
66
|
self.descs += [self.desc(getattr(am, f'GFX_FW_TYPE_RS64_{fw_name}'), blob, hdr.header.ucode_array_offset_bytes, hdr.ucode_size_bytes)]
|
@@ -69,12 +71,12 @@ class AMFirmware:
|
|
69
71
|
self.ucode_start[fw_name] = hdr.ucode_start_addr_lo | (hdr.ucode_start_addr_hi << 32)
|
70
72
|
|
71
73
|
# IMU firmware
|
72
|
-
blob, hdr = self.load_fw("
|
74
|
+
blob, hdr = self.load_fw(f"gc_{fmt_ver(am.GC_HWIP)}_imu.bin", am.struct_imu_firmware_header_v1_0)
|
73
75
|
imu_i_off, imu_i_sz, imu_d_sz = hdr.header.ucode_array_offset_bytes, hdr.imu_iram_ucode_size_bytes, hdr.imu_dram_ucode_size_bytes
|
74
76
|
self.descs += [self.desc(am.GFX_FW_TYPE_IMU_I, blob, imu_i_off, imu_i_sz), self.desc(am.GFX_FW_TYPE_IMU_D, blob, imu_i_off + imu_i_sz, imu_d_sz)]
|
75
77
|
|
76
78
|
# RLC firmware
|
77
|
-
blob, hdr0, hdr1, hdr2, hdr3 = self.load_fw("
|
79
|
+
blob, hdr0, hdr1, hdr2, hdr3 = self.load_fw(f"gc_{fmt_ver(am.GC_HWIP)}_rlc.bin", am.struct_rlc_firmware_header_v2_0,
|
78
80
|
am.struct_rlc_firmware_header_v2_1, am.struct_rlc_firmware_header_v2_2, am.struct_rlc_firmware_header_v2_3)
|
79
81
|
|
80
82
|
for mem in ['GPM', 'SRM']:
|
@@ -263,7 +265,7 @@ class AMDev:
|
|
263
265
|
|
264
266
|
# Memory manager & firmware
|
265
267
|
self.mm = AMMemoryManager(self, self.vram_size)
|
266
|
-
self.fw = AMFirmware()
|
268
|
+
self.fw = AMFirmware(self)
|
267
269
|
|
268
270
|
# Initialize IP blocks
|
269
271
|
self.soc21:AM_SOC21 = AM_SOC21(self)
|
@@ -274,7 +276,7 @@ class AMDev:
|
|
274
276
|
self.gfx:AM_GFX = AM_GFX(self)
|
275
277
|
self.sdma:AM_SDMA = AM_SDMA(self)
|
276
278
|
|
277
|
-
if self.partial_boot and (self.reg("
|
279
|
+
if self.partial_boot and (self.reg("regGCVM_CONTEXT0_CNTL").read() != 0):
|
278
280
|
if DEBUG >= 2: print(f"am {self.devfmt}: MEC is active. Issue a full reset.")
|
279
281
|
self.partial_boot = False
|
280
282
|
|
@@ -298,8 +300,10 @@ class AMDev:
|
|
298
300
|
if DEBUG >= 2: print(f"am {self.devfmt}: boot done")
|
299
301
|
|
300
302
|
def fini(self):
|
303
|
+
if DEBUG >= 2: print(f"am {self.devfmt}: Finalizing")
|
301
304
|
for ip in [self.sdma, self.gfx]: ip.fini()
|
302
305
|
self.smu.set_clocks(level=0)
|
306
|
+
self.ih.interrupt_handler()
|
303
307
|
|
304
308
|
def paddr2cpu(self, paddr:int) -> int: return mv_address(self.vram) + paddr
|
305
309
|
def paddr2mc(self, paddr:int) -> int: return self.gmc.mc_base + paddr
|
@@ -369,8 +373,16 @@ class AMDev:
|
|
369
373
|
gc_info = am.struct_gc_info_v1_0.from_address(gc_addr:=ctypes.addressof(bhdr) + bhdr.table_list[am.GC].offset)
|
370
374
|
self.gc_info = getattr(am, f"struct_gc_info_v{gc_info.header.version_major}_{gc_info.header.version_minor}").from_address(gc_addr)
|
371
375
|
|
376
|
+
def _ip_module(self, prefix:str, hwip):
|
377
|
+
version = [self.ip_versions[hwip]//10000, (self.ip_versions[hwip]//100)%100, self.ip_versions[hwip]%100]
|
378
|
+
for ver in [version, version[:2]+[0], version[:1]+[0, 0]]:
|
379
|
+
try: return __import__(f"tinygrad.runtime.autogen.am.{prefix}_{ver[0]}_{ver[1]}_{ver[2]}", fromlist=[f"{prefix}_{ver[0]}_{ver[1]}_{ver[2]}"])
|
380
|
+
except ImportError: pass
|
381
|
+
raise ImportError(f"am {self.devfmt}: failed to load {prefix} module with version {version}")
|
382
|
+
|
372
383
|
def _build_regs(self):
|
373
|
-
mods = [("MP0",
|
384
|
+
mods = [("MP0", self._ip_module("mp", am.MP0_HWIP)), ("NBIO", self._ip_module("nbio", am.NBIO_HWIP)), ("GC", self._ip_module("gc", am.GC_HWIP)),
|
385
|
+
("MP1", mp_11_0), ("MMHUB", self._ip_module("mmhub", am.MMHUB_HWIP)), ("OSSSYS", self._ip_module("osssys", am.OSSSYS_HWIP))]
|
374
386
|
for base, module in mods:
|
375
387
|
rpref = "mm" if base == "MP1" else "reg" # MP1 regs starts with mm
|
376
388
|
reg_names: set[str] = set(k[len(rpref):] for k in module.__dict__.keys() if k.startswith(rpref) and not k.endswith("_BASE_IDX"))
|
@@ -18,7 +18,7 @@ class AM_GMC(AM_IP):
|
|
18
18
|
super().__init__(adev)
|
19
19
|
|
20
20
|
# Memory controller aperture
|
21
|
-
self.mc_base = self.adev.regMMMC_VM_FB_LOCATION_BASE.read() << 24
|
21
|
+
self.mc_base = (self.adev.regMMMC_VM_FB_LOCATION_BASE.read() & 0xFFFFFF) << 24
|
22
22
|
self.mc_end = self.mc_base + self.adev.mm.vram_size - 1
|
23
23
|
|
24
24
|
# VM aperture
|
@@ -189,8 +189,6 @@ class AM_GFX(AM_IP):
|
|
189
189
|
self.adev.regCP_HQD_DEQUEUE_REQUEST.write(0x2) # 1 - DRAIN_PIPE; 2 - RESET_WAVES
|
190
190
|
self.adev.regSPI_COMPUTE_QUEUE_RESET.write(1)
|
191
191
|
self._grbm_select()
|
192
|
-
self.adev.regCP_MEC_RS64_CNTL.update(mec_invalidate_icache=1, mec_pipe0_reset=1, mec_pipe1_reset=1, mec_pipe2_reset=1, mec_pipe3_reset=1,
|
193
|
-
mec_pipe0_active=0, mec_pipe1_active=0, mec_pipe2_active=0, mec_pipe3_active=0, mec_halt=1)
|
194
192
|
self.adev.regGCVM_CONTEXT0_CNTL.write(0)
|
195
193
|
|
196
194
|
def setup_ring(self, ring_addr:int, ring_size:int, rptr_addr:int, wptr_addr:int, eop_addr:int, eop_size:int, doorbell:int, pipe:int, queue:int):
|
@@ -225,6 +223,8 @@ class AM_GFX(AM_IP):
|
|
225
223
|
self.adev.reg(f"regCP_ME1_PIPE{pipe}_INT_CNTL").update(time_stamp_int_enable=1, generic0_int_enable=1)
|
226
224
|
|
227
225
|
def set_clockgating_state(self):
|
226
|
+
if hasattr(self.adev, 'regMM_ATC_L2_MISC_CG'): self.adev.regMM_ATC_L2_MISC_CG.write(enable=1, mem_ls_enable=1)
|
227
|
+
|
228
228
|
self.adev.regRLC_SAFE_MODE.write(message=1, cmd=1)
|
229
229
|
self.adev.wait_reg(self.adev.regRLC_SAFE_MODE, mask=0x1, value=0x0)
|
230
230
|
|
@@ -233,6 +233,7 @@ class AM_GFX(AM_IP):
|
|
233
233
|
self.adev.regCP_RB_WPTR_POLL_CNTL.update(poll_frequency=0x100, idle_poll_count=0x90)
|
234
234
|
self.adev.regCP_INT_CNTL.update(cntx_busy_int_enable=1, cntx_empty_int_enable=1, cmp_busy_int_enable=1, gfx_idle_int_enable=1)
|
235
235
|
self.adev.regSDMA0_RLC_CGCG_CTRL.update(cgcg_int_enable=1)
|
236
|
+
self.adev.regSDMA1_RLC_CGCG_CTRL.update(cgcg_int_enable=1)
|
236
237
|
|
237
238
|
self.adev.regRLC_CGTT_MGCG_OVERRIDE.update(perfmon_clock_state=0, gfxip_fgcg_override=0, gfxip_repeater_fgcg_override=0,
|
238
239
|
grbm_cgtt_sclk_override=0, rlc_cgtt_sclk_override=0, gfxip_mgcg_override=0, gfxip_cgls_override=0, gfxip_cgcg_override=0)
|
@@ -311,17 +312,16 @@ class AM_SDMA(AM_IP):
|
|
311
312
|
self.adev.reg(f"regSDMA{pipe}_QUEUE{queue}_IB_CNTL").update(ib_enable=1)
|
312
313
|
|
313
314
|
def init(self):
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
315
|
+
for pipe in range(2):
|
316
|
+
self.adev.reg(f"regSDMA{pipe}_WATCHDOG_CNTL").update(queue_hang_count=100) # 10s, 100ms per unit
|
317
|
+
self.adev.reg(f"regSDMA{pipe}_UTCL1_CNTL").update(resp_mode=3, redo_delay=9)
|
318
|
+
self.adev.reg(f"regSDMA{pipe}_UTCL1_PAGE").update(rd_l2_policy=0x2, wr_l2_policy=0x3, llc_noalloc=1) # rd=noa, wr=bypass
|
319
|
+
self.adev.reg(f"regSDMA{pipe}_F32_CNTL").update(halt=0, th1_reset=0)
|
320
|
+
self.adev.reg(f"regSDMA{pipe}_CNTL").update(ctxempty_int_enable=1, trap_enable=1)
|
320
321
|
|
321
322
|
def fini(self):
|
322
323
|
self.adev.regSDMA0_QUEUE0_RB_CNTL.update(rb_enable=0)
|
323
324
|
self.adev.regSDMA0_QUEUE0_IB_CNTL.update(ib_enable=0)
|
324
|
-
self.adev.regSDMA0_F32_CNTL.update(halt=1, th1_reset=1)
|
325
325
|
self.adev.regGRBM_SOFT_RESET.write(soft_reset_sdma0=1)
|
326
326
|
time.sleep(0.01)
|
327
327
|
self.adev.regGRBM_SOFT_RESET.write(0x0)
|
tinygrad/runtime/support/elf.py
CHANGED
@@ -32,6 +32,8 @@ def elf_loader(blob:bytes, force_section_align:int=1) -> tuple[memoryview, list[
|
|
32
32
|
for sh, trgt_sh_name, c_rels in rel + rela:
|
33
33
|
target_image_off = next(tsh for tsh in sections if tsh.name == trgt_sh_name).header.sh_addr
|
34
34
|
rels = [(r.r_offset, symtab[libc.ELF64_R_SYM(r.r_info)], libc.ELF64_R_TYPE(r.r_info), getattr(r, "r_addend", 0)) for r in c_rels]
|
35
|
+
for roff, sym, r_type_, r_addend in rels:
|
36
|
+
if sym.st_shndx == 0: raise RuntimeError(f'Attempting to relocate against an undefined symbol {repr(_strtab(sh_strtab, sym.st_name))}')
|
35
37
|
relocs += [(target_image_off + roff, sections[sym.st_shndx].header.sh_addr + sym.st_value, rtype, raddend) for roff, sym, rtype, raddend in rels]
|
36
38
|
|
37
39
|
return memoryview(image), sections, relocs
|
tinygrad/runtime/support/hcq.py
CHANGED
@@ -4,7 +4,7 @@ import contextlib, decimal, statistics, time, ctypes, array, os, fcntl
|
|
4
4
|
from tinygrad.helpers import PROFILE, from_mv, getenv, to_mv, round_up
|
5
5
|
from tinygrad.renderer import Renderer
|
6
6
|
from tinygrad.device import BufferSpec, Compiler, Compiled, LRUAllocator, ProfileRangeEvent, ProfileDeviceEvent
|
7
|
-
from tinygrad.ops import sym_infer, sint, Variable
|
7
|
+
from tinygrad.ops import sym_infer, sint, Variable, UOp
|
8
8
|
from tinygrad.runtime.autogen import libc
|
9
9
|
|
10
10
|
class HWInterface:
|
@@ -19,9 +19,11 @@ class HWInterface:
|
|
19
19
|
if hasattr(self, 'fd'): os.close(self.fd)
|
20
20
|
def ioctl(self, request, arg): return fcntl.ioctl(self.fd, request, arg)
|
21
21
|
def mmap(self, start, sz, prot, flags, offset): return libc.mmap(start, sz, prot, flags, self.fd, offset)
|
22
|
-
def read(self, size=None, binary=False):
|
22
|
+
def read(self, size=None, binary=False, offset=None):
|
23
|
+
if offset is not None: self.seek(offset)
|
23
24
|
with open(self.fd, "rb" if binary else "r", closefd=False) as file: return file.read(size)
|
24
|
-
def write(self, content, binary=False):
|
25
|
+
def write(self, content, binary=False, offset=None):
|
26
|
+
if offset is not None: self.seek(offset)
|
25
27
|
with open(self.fd, "wb" if binary else "w", closefd=False) as file: file.write(content)
|
26
28
|
def listdir(self): return os.listdir(self.path)
|
27
29
|
def seek(self, offset): os.lseek(self.fd, offset, os.SEEK_SET)
|
@@ -83,10 +85,10 @@ class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
|
|
83
85
|
"""
|
84
86
|
|
85
87
|
for v in values:
|
86
|
-
if isinstance(v,
|
87
|
-
else:
|
88
|
+
if isinstance(v, UOp):
|
88
89
|
self.q_sints.append((len(self._q), self._new_sym(v)))
|
89
90
|
self._q.append(0xbadc0ded)
|
91
|
+
else: self._q.append(v)
|
90
92
|
|
91
93
|
# *** common commands ***
|
92
94
|
|
tinygrad/runtime/support/llvm.py
CHANGED
@@ -6,27 +6,21 @@ if sys.platform == 'win32':
|
|
6
6
|
# winget also doesn't have something like `brew --prefix llvm` so just hardcode default installation path with an option to override
|
7
7
|
LLVM_PATH = getenv('LLVM_PATH', 'C:\\Program Files\\LLVM\\bin\\LLVM-C.dll')
|
8
8
|
if not os.path.exists(LLVM_PATH):
|
9
|
-
raise
|
10
|
-
elif OSX and 'tinygrad.runtime.ops_metal' in sys.modules:
|
11
|
-
# Opening METAL after LLVM doesn't fail because ctypes.CDLL opens with RTLD_LOCAL but MTLCompiler opens it's own llvm with RTLD_GLOBAL
|
12
|
-
# This means that MTLCompiler's llvm will create it's own instances of global state because RTLD_LOCAL doesn't export symbols, but if RTLD_GLOBAL
|
13
|
-
# library is loaded first then RTLD_LOCAL library will just use it's symbols. On linux there is RTLD_DEEPBIND to prevent that, but on macos there
|
14
|
-
# doesn't seem to be anything we can do.
|
15
|
-
LLVM_PATH = ctypes.util.find_library('tinyllvm')
|
16
|
-
if LLVM_PATH is None:
|
17
|
-
raise RuntimeError("LLVM can't be opened in the same process with metal. You can install llvm distribution which supports that via `brew install uuuvn/tinygrad/tinyllvm`") # noqa: E501
|
9
|
+
raise FileNotFoundError('LLVM not found, you can install it with `winget install LLVM.LLVM` or point at a custom dll with LLVM_PATH')
|
18
10
|
elif OSX:
|
11
|
+
# Will raise FileNotFoundError if brew is not installed
|
19
12
|
brew_prefix = subprocess.check_output(['brew', '--prefix', 'llvm']).decode().strip()
|
20
13
|
# `brew --prefix` will return even if formula is not installed
|
21
14
|
if not os.path.exists(brew_prefix):
|
22
|
-
raise
|
23
|
-
LLVM_PATH = os.path.join(brew_prefix, 'lib', 'libLLVM.dylib')
|
15
|
+
raise FileNotFoundError('LLVM not found, you can install it with `brew install llvm`')
|
16
|
+
LLVM_PATH: str|None = os.path.join(brew_prefix, 'lib', 'libLLVM.dylib')
|
24
17
|
else:
|
25
18
|
LLVM_PATH = ctypes.util.find_library('LLVM')
|
26
|
-
|
19
|
+
# use newer LLVM if possible
|
20
|
+
for ver in reversed(range(14, 19+1)):
|
27
21
|
if LLVM_PATH is not None: break
|
28
22
|
LLVM_PATH = ctypes.util.find_library(f'LLVM-{ver}')
|
29
23
|
if LLVM_PATH is None:
|
30
|
-
raise
|
24
|
+
raise FileNotFoundError("No LLVM library found on the system. Install it via your distro's package manager and ensure it's findable as 'LLVM'")
|
31
25
|
|
32
|
-
if DEBUG>=
|
26
|
+
if DEBUG>=3: print(f'Using LLVM at {repr(LLVM_PATH)}')
|
tinygrad/shape/shapetracker.py
CHANGED
@@ -6,8 +6,8 @@ from typing import Optional, Callable
|
|
6
6
|
from tinygrad.helpers import merge_dicts, getenv
|
7
7
|
from tinygrad.shape.view import View, strides_for_shape, unravel
|
8
8
|
from tinygrad.dtype import dtypes
|
9
|
-
from tinygrad.ops import UOp, Ops, graph_rewrite,
|
10
|
-
from tinygrad.codegen.
|
9
|
+
from tinygrad.ops import UOp, Ops, graph_rewrite, Variable, sint, sint_to_uop, Context
|
10
|
+
from tinygrad.codegen.symbolic import sym, split_uop, symbolic_flat, uop_given_valid, simplify_valid
|
11
11
|
|
12
12
|
def overflow(u: UOp): return u.vmax > dtypes.max(dtypes.int) or u.vmin < dtypes.min(dtypes.int)
|
13
13
|
|
@@ -109,6 +109,7 @@ class ShapeTracker:
|
|
109
109
|
|
110
110
|
def unbind(self) -> tuple[ShapeTracker, dict[Variable, int]]:
|
111
111
|
unbound_views, var_vals = zip(*[v.unbind() for v in self.views])
|
112
|
+
if all(len(x) == 0 for x in var_vals): return self, {}
|
112
113
|
return ShapeTracker(tuple(unbound_views)), merge_dicts(var_vals)
|
113
114
|
|
114
115
|
def real_strides(self, ignore_valid=False) -> tuple[Optional[sint], ...]: return views_to_real_strides(self.views, ignore_valid)
|
tinygrad/shape/view.py
CHANGED
@@ -107,8 +107,7 @@ class View:
|
|
107
107
|
@staticmethod
|
108
108
|
@functools.lru_cache(maxsize=None)
|
109
109
|
def create(shape:tuple[sint, ...], strides:Optional[tuple[sint, ...]]=None, offset:sint=0, mask:Optional[tuple[tuple[sint, sint], ...]]=None):
|
110
|
-
|
111
|
-
if not all(resolve(s >= 0) for s in shape): raise ValueError(f"Trying to create View with negative dimension: {shape=}")
|
110
|
+
if not all(s >= 0 for s in shape): raise ValueError(f"Trying to create View with negative dimension: {shape=}")
|
112
111
|
strides = canonicalize_strides(shape, strides) if strides else strides_for_shape(shape)
|
113
112
|
# canonicalize 0 in shape
|
114
113
|
if 0 in shape: return View(shape, (0,) * len(shape), offset=0, mask=None, contiguous=True)
|
@@ -274,7 +273,7 @@ class View:
|
|
274
273
|
def reshape(self, new_shape: tuple[sint, ...]) -> Optional[View]:
|
275
274
|
if self.shape == new_shape: return self
|
276
275
|
|
277
|
-
|
276
|
+
if not all(x >= 0 for x in new_shape): raise ValueError(f"shape can't contain negative numbers {new_shape}")
|
278
277
|
# check for the same size
|
279
278
|
if (self_all_int := all_int(self.shape)):
|
280
279
|
assert all(isinstance(s, (int, UOp)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
|
tinygrad/spec.py
CHANGED
@@ -1,21 +1,25 @@
|
|
1
1
|
from typing import cast
|
2
2
|
from tinygrad.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops
|
3
3
|
from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType
|
4
|
-
from tinygrad.helpers import
|
4
|
+
from tinygrad.helpers import all_same, dedup, prod
|
5
5
|
|
6
|
-
|
7
|
-
|
8
|
-
tensor_uop_spec = PatternMatcher([
|
6
|
+
buffer_spec = PatternMatcher([
|
7
|
+
(UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True),
|
9
8
|
(UPat(Ops.DEVICE, dtypes.void, (), name="device"), lambda device: isinstance(device.arg, str)),
|
10
|
-
(UPat(Ops.BUFFER, src=(UPat(Ops.DEVICE),), name="buf"),
|
11
|
-
lambda buf: isinstance(buf.arg,
|
9
|
+
(UPat(Ops.BUFFER, src=(UPat(Ops.DEVICE), UPat(Ops.UNIQUE)), name="buf"),
|
10
|
+
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
|
11
|
+
])
|
12
12
|
|
13
|
+
# *** this is the spec of a Tensor in UOp ***
|
14
|
+
|
15
|
+
tensor_uop_spec = buffer_spec+PatternMatcher([
|
13
16
|
(UPat(GroupOp.Movement, name="mv", src=(UPat.var("x"),)),
|
14
17
|
# naturally correct
|
15
18
|
lambda mv,x: (isinstance(mv.arg, tuple) and mv.dtype == x.dtype) or
|
16
19
|
# "make things that can't be images not images" can change the buffer dtype
|
17
20
|
# this is fine as long as it's a realized buffer and base dtypes match.
|
18
21
|
((isinstance(mv.dtype, ImageDType) or isinstance(x.dtype, ImageDType)) and x.dtype.base == mv.dtype.base and x.is_realized)),
|
22
|
+
(UPat(Ops.VIEW, src=(UPat(GroupOp.All-{Ops.CONST, Ops.DEVICE}),)), lambda: False),
|
19
23
|
|
20
24
|
# Tensor variable bindings
|
21
25
|
(UPat(Ops.BIND, dtypes.int, (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=dtypes.int)), arg=None), lambda: True),
|
@@ -32,11 +36,6 @@ tensor_uop_spec = PatternMatcher([
|
|
32
36
|
# NOTE: the arg here specifies clone=True, which prevents folding same device copy
|
33
37
|
(UPat(Ops.COPY, name="copy", src=(UPat(Ops.DEVICE), UPat.var("x"))), lambda copy,x: isinstance(copy.arg, bool) and copy.dtype == x.dtype),
|
34
38
|
|
35
|
-
# VIEW(BUFFER) applies a ShapeTracker on top of the underlying device buffer
|
36
|
-
# NOTE: VIEW size exactly matches the underlying BUFFER, tensor doesn't apply movement ops to the VIEW
|
37
|
-
(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),
|
38
|
-
lambda view,buf: view.dtype == buf.dtype and view.size == buf.size and view.st.contiguous),
|
39
|
-
|
40
39
|
# ASSIGN changes the value of a realized buffer
|
41
40
|
(UPat(Ops.ASSIGN, name="assign", src=(UPat.var("target"), UPat.var("new_val"))),
|
42
41
|
lambda assign,target,new_val: target.is_realized and (assign.dtype == target.dtype == new_val.dtype)),
|
@@ -58,7 +57,7 @@ spec = PatternMatcher([
|
|
58
57
|
|
59
58
|
# TODO: confirm the args of both of these are shapetrackers
|
60
59
|
(UPat(Ops.VIEW, dtypes.void, src=()), lambda: True),
|
61
|
-
(UPat(Ops.VIEW, src=(UPat.var("src"),), name="x"), lambda x,src: src.op is not Ops.STORE and x.dtype == src.dtype),
|
60
|
+
(UPat(Ops.VIEW, src=(UPat.var("src"),), name="x"), lambda x,src: src.op is not Ops.STORE and x.dtype.base == src.dtype.base),
|
62
61
|
|
63
62
|
(UPat(Ops.VALID, dtypes.bool, (UPat(Ops.VIEW),)), lambda: True),
|
64
63
|
(UPat(Ops.CONST, name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),
|
@@ -113,9 +112,9 @@ spec = PatternMatcher([
|
|
113
112
|
(UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local
|
114
113
|
|
115
114
|
# NOTE: for testing, we let sinks be anything
|
116
|
-
#(UPat(
|
117
|
-
(UPat(Ops.SINK, dtypes.void), lambda: True),
|
118
|
-
(UPat(Ops.NOOP), lambda: True),
|
115
|
+
#(UPat(Ops.SINK, src=UPat(Ops.STORE)), lambda: True),
|
116
|
+
(UPat((Ops.NAME, Ops.SINK), dtypes.void), lambda: True),
|
117
|
+
(UPat((Ops.NOOP, Ops.CUSTOM)), lambda: True),
|
119
118
|
|
120
119
|
# PTX LOAD/STORE
|
121
120
|
(UPat((Ops.LOAD, Ops.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
|
@@ -123,11 +122,13 @@ spec = PatternMatcher([
|
|
123
122
|
|
124
123
|
# *** this is the spec of a Kernel in UOp ***
|
125
124
|
|
126
|
-
kernel_spec = PatternMatcher([
|
127
|
-
(UPat(Ops.
|
128
|
-
|
129
|
-
|
130
|
-
|
125
|
+
kernel_spec = buffer_spec+PatternMatcher([
|
126
|
+
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.ASSIGN))), lambda: True),
|
127
|
+
# assign has a buffer view and kernel source, it can optionally depend on other assigns
|
128
|
+
(UPat(Ops.ASSIGN, src=UPat((Ops.BUFFER, Ops.VIEW, Ops.KERNEL, Ops.ASSIGN))), lambda: True),
|
129
|
+
# view/sink/const can also exist in the kernel graph
|
130
|
+
(UPat((Ops.VIEW, Ops.SINK, Ops.CONST)), lambda: True),
|
131
|
+
(UPat(GroupOp.All), lambda: False),
|
131
132
|
])
|
132
133
|
|
133
134
|
# *** this is the UOp shape spec ***
|