tinygrad 0.10.1__py3-none-any.whl → 0.10.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. tinygrad/codegen/devectorizer.py +247 -0
  2. tinygrad/codegen/expander.py +121 -0
  3. tinygrad/codegen/kernel.py +35 -37
  4. tinygrad/codegen/linearize.py +19 -10
  5. tinygrad/codegen/lowerer.py +31 -8
  6. tinygrad/codegen/symbolic.py +476 -0
  7. tinygrad/codegen/transcendental.py +10 -0
  8. tinygrad/device.py +28 -11
  9. tinygrad/dtype.py +12 -3
  10. tinygrad/engine/jit.py +3 -2
  11. tinygrad/engine/multi.py +0 -1
  12. tinygrad/engine/realize.py +7 -4
  13. tinygrad/engine/schedule.py +227 -255
  14. tinygrad/engine/search.py +20 -27
  15. tinygrad/gradient.py +3 -0
  16. tinygrad/helpers.py +7 -4
  17. tinygrad/nn/state.py +2 -2
  18. tinygrad/ops.py +64 -329
  19. tinygrad/renderer/__init__.py +19 -3
  20. tinygrad/renderer/cstyle.py +39 -18
  21. tinygrad/renderer/llvmir.py +55 -18
  22. tinygrad/renderer/ptx.py +6 -2
  23. tinygrad/renderer/wgsl.py +20 -12
  24. tinygrad/runtime/autogen/libc.py +404 -71
  25. tinygrad/runtime/autogen/{libpciaccess.py → pci.py} +25 -715
  26. tinygrad/runtime/autogen/webgpu.py +6985 -0
  27. tinygrad/runtime/graph/metal.py +28 -29
  28. tinygrad/runtime/ops_amd.py +37 -34
  29. tinygrad/runtime/{ops_clang.py → ops_cpu.py} +4 -2
  30. tinygrad/runtime/ops_disk.py +1 -1
  31. tinygrad/runtime/ops_dsp.py +59 -33
  32. tinygrad/runtime/ops_llvm.py +14 -12
  33. tinygrad/runtime/ops_metal.py +78 -62
  34. tinygrad/runtime/ops_nv.py +9 -6
  35. tinygrad/runtime/ops_python.py +5 -5
  36. tinygrad/runtime/ops_webgpu.py +200 -38
  37. tinygrad/runtime/support/am/amdev.py +23 -11
  38. tinygrad/runtime/support/am/ip.py +10 -10
  39. tinygrad/runtime/support/elf.py +2 -0
  40. tinygrad/runtime/support/hcq.py +7 -5
  41. tinygrad/runtime/support/llvm.py +8 -14
  42. tinygrad/shape/shapetracker.py +3 -2
  43. tinygrad/shape/view.py +2 -3
  44. tinygrad/spec.py +21 -20
  45. tinygrad/tensor.py +150 -90
  46. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
  47. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
  48. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
  49. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
  50. tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
  51. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
  52. tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
  53. tinygrad/viz/index.html +544 -0
  54. tinygrad/viz/perfetto.html +178 -0
  55. tinygrad/viz/serve.py +205 -0
  56. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/METADATA +20 -8
  57. tinygrad-0.10.2.dist-info/RECORD +99 -0
  58. tinygrad/codegen/rewriter.py +0 -516
  59. tinygrad-0.10.1.dist-info/RECORD +0 -86
  60. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
  61. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +0 -0
  62. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/runtime/support/am/amdev.py CHANGED
@@ -1,7 +1,7 @@
  from __future__ import annotations
  import ctypes, collections, time, dataclasses, pathlib, fcntl, os
  from tinygrad.helpers import to_mv, mv_address, getenv, round_up, DEBUG, temp
- from tinygrad.runtime.autogen.am import am, mp_11_0, mp_13_0_0, nbio_4_3_0, mmhub_3_0_0, gc_11_0_0, osssys_6_0_0
+ from tinygrad.runtime.autogen.am import am, mp_11_0
  from tinygrad.runtime.support.allocator import TLSFAllocator
  from tinygrad.runtime.support.am.ip import AM_SOC21, AM_GMC, AM_IH, AM_PSP, AM_SMU, AM_GFX, AM_SDMA

@@ -32,11 +32,13 @@ class AMRegister:
  def read(self, **kwargs): return self.adev.rreg(self.reg_off) & self._parse_kwargs(**kwargs)[0]

  class AMFirmware:
- def __init__(self):
+ def __init__(self, adev):
+ def fmt_ver(hwip): return f"{adev.ip_versions[hwip]//10000}_{(adev.ip_versions[hwip]//100)%100}_{adev.ip_versions[hwip]%100}"
+
  # Load SOS firmware
  self.sos_fw = {}

- blob, sos_hdr = self.load_fw("psp_13_0_0_sos.bin", am.struct_psp_firmware_header_v2_0)
+ blob, sos_hdr = self.load_fw(f"psp_{fmt_ver(am.MP0_HWIP)}_sos.bin", am.struct_psp_firmware_header_v2_0)
  fw_bin = sos_hdr.psp_fw_bin

  for fw_i in range(sos_hdr.psp_fw_bin_count):
@@ -48,17 +50,17 @@ class AMFirmware:
  self.ucode_start: dict[str, int] = {}
  self.descs: list[tuple[int, memoryview]] = []

- blob, hdr = self.load_fw("smu_13_0_0.bin", am.struct_smc_firmware_header_v1_0)
+ blob, hdr = self.load_fw(f"smu_{fmt_ver(am.MP1_HWIP)}.bin", am.struct_smc_firmware_header_v1_0)
  self.smu_psp_desc = self.desc(am.GFX_FW_TYPE_SMU, blob, hdr.header.ucode_array_offset_bytes, hdr.header.ucode_size_bytes)

  # SDMA firmware
- blob, hdr = self.load_fw("sdma_6_0_0.bin", am.struct_sdma_firmware_header_v2_0)
+ blob, hdr = self.load_fw(f"sdma_{fmt_ver(am.SDMA0_HWIP)}.bin", am.struct_sdma_firmware_header_v2_0)
  self.descs += [self.desc(am.GFX_FW_TYPE_SDMA_UCODE_TH0, blob, hdr.header.ucode_array_offset_bytes, hdr.ctx_ucode_size_bytes)]
  self.descs += [self.desc(am.GFX_FW_TYPE_SDMA_UCODE_TH1, blob, hdr.ctl_ucode_offset, hdr.ctl_ucode_size_bytes)]

  # PFP, ME, MEC firmware
  for (fw_name, fw_cnt) in [('PFP', 2), ('ME', 2), ('MEC', 4)]:
- blob, hdr = self.load_fw(f"gc_11_0_0_{fw_name.lower()}.bin", am.struct_gfx_firmware_header_v2_0)
+ blob, hdr = self.load_fw(f"gc_{fmt_ver(am.GC_HWIP)}_{fw_name.lower()}.bin", am.struct_gfx_firmware_header_v2_0)

  # Code part
  self.descs += [self.desc(getattr(am, f'GFX_FW_TYPE_RS64_{fw_name}'), blob, hdr.header.ucode_array_offset_bytes, hdr.ucode_size_bytes)]
@@ -69,12 +71,12 @@ class AMFirmware:
  self.ucode_start[fw_name] = hdr.ucode_start_addr_lo | (hdr.ucode_start_addr_hi << 32)

  # IMU firmware
- blob, hdr = self.load_fw("gc_11_0_0_imu.bin", am.struct_imu_firmware_header_v1_0)
+ blob, hdr = self.load_fw(f"gc_{fmt_ver(am.GC_HWIP)}_imu.bin", am.struct_imu_firmware_header_v1_0)
  imu_i_off, imu_i_sz, imu_d_sz = hdr.header.ucode_array_offset_bytes, hdr.imu_iram_ucode_size_bytes, hdr.imu_dram_ucode_size_bytes
  self.descs += [self.desc(am.GFX_FW_TYPE_IMU_I, blob, imu_i_off, imu_i_sz), self.desc(am.GFX_FW_TYPE_IMU_D, blob, imu_i_off + imu_i_sz, imu_d_sz)]

  # RLC firmware
- blob, hdr0, hdr1, hdr2, hdr3 = self.load_fw("gc_11_0_0_rlc.bin", am.struct_rlc_firmware_header_v2_0,
+ blob, hdr0, hdr1, hdr2, hdr3 = self.load_fw(f"gc_{fmt_ver(am.GC_HWIP)}_rlc.bin", am.struct_rlc_firmware_header_v2_0,
  am.struct_rlc_firmware_header_v2_1, am.struct_rlc_firmware_header_v2_2, am.struct_rlc_firmware_header_v2_3)

  for mem in ['GPM', 'SRM']:
@@ -263,7 +265,7 @@ class AMDev:

  # Memory manager & firmware
  self.mm = AMMemoryManager(self, self.vram_size)
- self.fw = AMFirmware()
+ self.fw = AMFirmware(self)

  # Initialize IP blocks
  self.soc21:AM_SOC21 = AM_SOC21(self)
@@ -274,7 +276,7 @@ class AMDev:
  self.gfx:AM_GFX = AM_GFX(self)
  self.sdma:AM_SDMA = AM_SDMA(self)

- if self.partial_boot and (self.reg("regCP_MEC_RS64_CNTL").read() & gc_11_0_0.CP_MEC_RS64_CNTL__MEC_HALT_MASK == 0):
+ if self.partial_boot and (self.reg("regGCVM_CONTEXT0_CNTL").read() != 0):
  if DEBUG >= 2: print(f"am {self.devfmt}: MEC is active. Issue a full reset.")
  self.partial_boot = False

@@ -298,8 +300,10 @@ class AMDev:
  if DEBUG >= 2: print(f"am {self.devfmt}: boot done")

  def fini(self):
+ if DEBUG >= 2: print(f"am {self.devfmt}: Finalizing")
  for ip in [self.sdma, self.gfx]: ip.fini()
  self.smu.set_clocks(level=0)
+ self.ih.interrupt_handler()

  def paddr2cpu(self, paddr:int) -> int: return mv_address(self.vram) + paddr
  def paddr2mc(self, paddr:int) -> int: return self.gmc.mc_base + paddr
@@ -369,8 +373,16 @@ class AMDev:
  gc_info = am.struct_gc_info_v1_0.from_address(gc_addr:=ctypes.addressof(bhdr) + bhdr.table_list[am.GC].offset)
  self.gc_info = getattr(am, f"struct_gc_info_v{gc_info.header.version_major}_{gc_info.header.version_minor}").from_address(gc_addr)

+ def _ip_module(self, prefix:str, hwip):
+ version = [self.ip_versions[hwip]//10000, (self.ip_versions[hwip]//100)%100, self.ip_versions[hwip]%100]
+ for ver in [version, version[:2]+[0], version[:1]+[0, 0]]:
+ try: return __import__(f"tinygrad.runtime.autogen.am.{prefix}_{ver[0]}_{ver[1]}_{ver[2]}", fromlist=[f"{prefix}_{ver[0]}_{ver[1]}_{ver[2]}"])
+ except ImportError: pass
+ raise ImportError(f"am {self.devfmt}: failed to load {prefix} module with version {version}")
+
  def _build_regs(self):
- mods = [("MP0", mp_13_0_0), ("MP1", mp_11_0), ("NBIO", nbio_4_3_0), ("MMHUB", mmhub_3_0_0), ("GC", gc_11_0_0), ("OSSSYS", osssys_6_0_0)]
+ mods = [("MP0", self._ip_module("mp", am.MP0_HWIP)), ("NBIO", self._ip_module("nbio", am.NBIO_HWIP)), ("GC", self._ip_module("gc", am.GC_HWIP)),
+ ("MP1", mp_11_0), ("MMHUB", self._ip_module("mmhub", am.MMHUB_HWIP)), ("OSSSYS", self._ip_module("osssys", am.OSSSYS_HWIP))]
  for base, module in mods:
  rpref = "mm" if base == "MP1" else "reg" # MP1 regs starts with mm
  reg_names: set[str] = set(k[len(rpref):] for k in module.__dict__.keys() if k.startswith(rpref) and not k.endswith("_BASE_IDX"))
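
Note on the amdev.py changes above: firmware file names and register modules were hard-coded for one IP level (psp_13_0_0, gc_11_0_0, ...); 0.10.2 derives them from the IP versions discovered per hardware block. A minimal sketch of the decode-and-fallback scheme used by fmt_ver and _ip_module, assuming version integers packed as major*10000 + minor*100 + rev (the values 110002 and 130502 below are hypothetical ip_versions entries):

    def decode_ip_version(v:int) -> tuple[int, int, int]:
      # e.g. 110002 -> (11, 0, 2), which fmt_ver formats as "11_0_2"
      return v//10000, (v//100)%100, v%100

    def candidate_modules(prefix:str, v:int) -> list[str]:
      major, minor, rev = decode_ip_version(v)
      # exact version first, then progressively zeroed fallbacks, mirroring _ip_module's import loop
      return [f"{prefix}_{major}_{minor}_{rev}", f"{prefix}_{major}_{minor}_0", f"{prefix}_{major}_0_0"]

    assert decode_ip_version(110002) == (11, 0, 2)
    assert candidate_modules("mp", 130502) == ["mp_13_5_2", "mp_13_5_0", "mp_13_0_0"]
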
tinygrad/runtime/support/am/ip.py CHANGED
@@ -18,7 +18,7 @@ class AM_GMC(AM_IP):
  super().__init__(adev)

  # Memory controller aperture
- self.mc_base = self.adev.regMMMC_VM_FB_LOCATION_BASE.read() << 24
+ self.mc_base = (self.adev.regMMMC_VM_FB_LOCATION_BASE.read() & 0xFFFFFF) << 24
  self.mc_end = self.mc_base + self.adev.mm.vram_size - 1

  # VM aperture
@@ -189,8 +189,6 @@ class AM_GFX(AM_IP):
  self.adev.regCP_HQD_DEQUEUE_REQUEST.write(0x2) # 1 - DRAIN_PIPE; 2 - RESET_WAVES
  self.adev.regSPI_COMPUTE_QUEUE_RESET.write(1)
  self._grbm_select()
- self.adev.regCP_MEC_RS64_CNTL.update(mec_invalidate_icache=1, mec_pipe0_reset=1, mec_pipe1_reset=1, mec_pipe2_reset=1, mec_pipe3_reset=1,
- mec_pipe0_active=0, mec_pipe1_active=0, mec_pipe2_active=0, mec_pipe3_active=0, mec_halt=1)
  self.adev.regGCVM_CONTEXT0_CNTL.write(0)

  def setup_ring(self, ring_addr:int, ring_size:int, rptr_addr:int, wptr_addr:int, eop_addr:int, eop_size:int, doorbell:int, pipe:int, queue:int):
@@ -225,6 +223,8 @@ class AM_GFX(AM_IP):
  self.adev.reg(f"regCP_ME1_PIPE{pipe}_INT_CNTL").update(time_stamp_int_enable=1, generic0_int_enable=1)

  def set_clockgating_state(self):
+ if hasattr(self.adev, 'regMM_ATC_L2_MISC_CG'): self.adev.regMM_ATC_L2_MISC_CG.write(enable=1, mem_ls_enable=1)
+
  self.adev.regRLC_SAFE_MODE.write(message=1, cmd=1)
  self.adev.wait_reg(self.adev.regRLC_SAFE_MODE, mask=0x1, value=0x0)

@@ -233,6 +233,7 @@ class AM_GFX(AM_IP):
  self.adev.regCP_RB_WPTR_POLL_CNTL.update(poll_frequency=0x100, idle_poll_count=0x90)
  self.adev.regCP_INT_CNTL.update(cntx_busy_int_enable=1, cntx_empty_int_enable=1, cmp_busy_int_enable=1, gfx_idle_int_enable=1)
  self.adev.regSDMA0_RLC_CGCG_CTRL.update(cgcg_int_enable=1)
+ self.adev.regSDMA1_RLC_CGCG_CTRL.update(cgcg_int_enable=1)

  self.adev.regRLC_CGTT_MGCG_OVERRIDE.update(perfmon_clock_state=0, gfxip_fgcg_override=0, gfxip_repeater_fgcg_override=0,
  grbm_cgtt_sclk_override=0, rlc_cgtt_sclk_override=0, gfxip_mgcg_override=0, gfxip_cgls_override=0, gfxip_cgcg_override=0)
@@ -311,17 +312,16 @@ class AM_SDMA(AM_IP):
  self.adev.reg(f"regSDMA{pipe}_QUEUE{queue}_IB_CNTL").update(ib_enable=1)

  def init(self):
- self.adev.regSDMA0_SEM_WAIT_FAIL_TIMER_CNTL.write(0x0)
- self.adev.regSDMA0_WATCHDOG_CNTL.update(queue_hang_count=100) # 10s, 100ms per unit
- self.adev.regSDMA0_UTCL1_CNTL.update(resp_mode=3, redo_delay=9)
- self.adev.regSDMA0_UTCL1_PAGE.update(rd_l2_policy=0x2, wr_l2_policy=0x3, llc_noalloc=1) # rd=noa, wr=bypass
- self.adev.regSDMA0_F32_CNTL.update(halt=0, th1_reset=0)
- self.adev.regSDMA0_CNTL.update(ctxempty_int_enable=1, trap_enable=1)
+ for pipe in range(2):
+ self.adev.reg(f"regSDMA{pipe}_WATCHDOG_CNTL").update(queue_hang_count=100) # 10s, 100ms per unit
+ self.adev.reg(f"regSDMA{pipe}_UTCL1_CNTL").update(resp_mode=3, redo_delay=9)
+ self.adev.reg(f"regSDMA{pipe}_UTCL1_PAGE").update(rd_l2_policy=0x2, wr_l2_policy=0x3, llc_noalloc=1) # rd=noa, wr=bypass
+ self.adev.reg(f"regSDMA{pipe}_F32_CNTL").update(halt=0, th1_reset=0)
+ self.adev.reg(f"regSDMA{pipe}_CNTL").update(ctxempty_int_enable=1, trap_enable=1)

  def fini(self):
  self.adev.regSDMA0_QUEUE0_RB_CNTL.update(rb_enable=0)
  self.adev.regSDMA0_QUEUE0_IB_CNTL.update(ib_enable=0)
- self.adev.regSDMA0_F32_CNTL.update(halt=1, th1_reset=1)
  self.adev.regGRBM_SOFT_RESET.write(soft_reset_sdma0=1)
  time.sleep(0.01)
  self.adev.regGRBM_SOFT_RESET.write(0x0)
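
One detail worth calling out in the AM_GMC change: the raw MMMC_VM_FB_LOCATION_BASE read is now masked to its low 24 bits before the <<24 shift, which (judging from the mask in the diff) keeps any bits above the base-address field out of the computed aperture. A tiny illustration with a made-up register value:

    raw = 0xC0012345                       # hypothetical read with bits set above the 24-bit base field
    old_base = raw << 24                   # 0.10.1: high bits leak into the address
    new_base = (raw & 0xFFFFFF) << 24      # 0.10.2: only the base field is used
    assert new_base == 0x012345_000000 and old_base != new_base
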
tinygrad/runtime/support/elf.py CHANGED
@@ -32,6 +32,8 @@ def elf_loader(blob:bytes, force_section_align:int=1) -> tuple[memoryview, list[
  for sh, trgt_sh_name, c_rels in rel + rela:
  target_image_off = next(tsh for tsh in sections if tsh.name == trgt_sh_name).header.sh_addr
  rels = [(r.r_offset, symtab[libc.ELF64_R_SYM(r.r_info)], libc.ELF64_R_TYPE(r.r_info), getattr(r, "r_addend", 0)) for r in c_rels]
+ for roff, sym, r_type_, r_addend in rels:
+ if sym.st_shndx == 0: raise RuntimeError(f'Attempting to relocate against an undefined symbol {repr(_strtab(sh_strtab, sym.st_name))}')
  relocs += [(target_image_off + roff, sections[sym.st_shndx].header.sh_addr + sym.st_value, rtype, raddend) for roff, sym, rtype, raddend in rels]

  return memoryview(image), sections, relocs
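
The elf_loader addition fails fast when a relocation references a symbol with st_shndx == 0 (SHN_UNDEF in the ELF spec) instead of silently indexing section 0. A self-contained sketch of the same guard, using a minimal stand-in for the autogenerated libc symbol struct:

    from dataclasses import dataclass

    @dataclass
    class Sym:
      name: str
      st_shndx: int  # section index; 0 (SHN_UNDEF) means the symbol is not defined in this object

    def check_relocs(rels:list[tuple[int, Sym, int, int]]) -> None:
      for _roff, sym, _rtype, _addend in rels:
        if sym.st_shndx == 0: raise RuntimeError(f"Attempting to relocate against an undefined symbol {sym.name!r}")

    check_relocs([(0x10, Sym("local_fn", 3), 2, 0)])     # fine
    # check_relocs([(0x20, Sym("missing", 0), 2, 0)])    # would raise RuntimeError
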
tinygrad/runtime/support/hcq.py CHANGED
@@ -4,7 +4,7 @@ import contextlib, decimal, statistics, time, ctypes, array, os, fcntl
  from tinygrad.helpers import PROFILE, from_mv, getenv, to_mv, round_up
  from tinygrad.renderer import Renderer
  from tinygrad.device import BufferSpec, Compiler, Compiled, LRUAllocator, ProfileRangeEvent, ProfileDeviceEvent
- from tinygrad.ops import sym_infer, sint, Variable
+ from tinygrad.ops import sym_infer, sint, Variable, UOp
  from tinygrad.runtime.autogen import libc

  class HWInterface:
@@ -19,9 +19,11 @@ class HWInterface:
  if hasattr(self, 'fd'): os.close(self.fd)
  def ioctl(self, request, arg): return fcntl.ioctl(self.fd, request, arg)
  def mmap(self, start, sz, prot, flags, offset): return libc.mmap(start, sz, prot, flags, self.fd, offset)
- def read(self, size=None, binary=False):
+ def read(self, size=None, binary=False, offset=None):
+ if offset is not None: self.seek(offset)
  with open(self.fd, "rb" if binary else "r", closefd=False) as file: return file.read(size)
- def write(self, content, binary=False):
+ def write(self, content, binary=False, offset=None):
+ if offset is not None: self.seek(offset)
  with open(self.fd, "wb" if binary else "w", closefd=False) as file: file.write(content)
  def listdir(self): return os.listdir(self.path)
  def seek(self, offset): os.lseek(self.fd, offset, os.SEEK_SET)
@@ -83,10 +85,10 @@ class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
  """

  for v in values:
- if isinstance(v, int): self._q.append(v)
- else:
+ if isinstance(v, UOp):
  self.q_sints.append((len(self._q), self._new_sym(v)))
  self._q.append(0xbadc0ded)
+ else: self._q.append(v)

  # *** common commands ***

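Two things change in hcq.py: HWInterface.read/write accept an optional offset that seeks before doing the I/O, and HWQueue packs queue words by checking isinstance(v, UOp) rather than isinstance(v, int). A short sketch of the offset-aware helper pattern over a plain file descriptor (FileIface is a stand-in, not the real HWInterface):

    import os

    class FileIface:
      def __init__(self, path:str): self.fd = os.open(path, os.O_RDWR | os.O_CREAT)
      def seek(self, offset:int): os.lseek(self.fd, offset, os.SEEK_SET)
      def read(self, size=None, binary=False, offset=None):
        if offset is not None: self.seek(offset)   # optional positioned read, mirroring the 0.10.2 signature
        with open(self.fd, "rb" if binary else "r", closefd=False) as f: return f.read(size)
      def write(self, content, binary=False, offset=None):
        if offset is not None: self.seek(offset)   # optional positioned write
        with open(self.fd, "wb" if binary else "w", closefd=False) as f: f.write(content)
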
tinygrad/runtime/support/llvm.py CHANGED
@@ -6,27 +6,21 @@ if sys.platform == 'win32':
  # winget also doesn't have something like `brew --prefix llvm` so just hardcode default installation path with an option to override
  LLVM_PATH = getenv('LLVM_PATH', 'C:\\Program Files\\LLVM\\bin\\LLVM-C.dll')
  if not os.path.exists(LLVM_PATH):
- raise RuntimeError('LLVM not found, you can install it with `winget install LLVM.LLVM` or point at a custom dll with LLVM_PATH')
- elif OSX and 'tinygrad.runtime.ops_metal' in sys.modules:
- # Opening METAL after LLVM doesn't fail because ctypes.CDLL opens with RTLD_LOCAL but MTLCompiler opens it's own llvm with RTLD_GLOBAL
- # This means that MTLCompiler's llvm will create it's own instances of global state because RTLD_LOCAL doesn't export symbols, but if RTLD_GLOBAL
- # library is loaded first then RTLD_LOCAL library will just use it's symbols. On linux there is RTLD_DEEPBIND to prevent that, but on macos there
- # doesn't seem to be anything we can do.
- LLVM_PATH = ctypes.util.find_library('tinyllvm')
- if LLVM_PATH is None:
- raise RuntimeError("LLVM can't be opened in the same process with metal. You can install llvm distribution which supports that via `brew install uuuvn/tinygrad/tinyllvm`") # noqa: E501
+ raise FileNotFoundError('LLVM not found, you can install it with `winget install LLVM.LLVM` or point at a custom dll with LLVM_PATH')
  elif OSX:
+ # Will raise FileNotFoundError if brew is not installed
  brew_prefix = subprocess.check_output(['brew', '--prefix', 'llvm']).decode().strip()
  # `brew --prefix` will return even if formula is not installed
  if not os.path.exists(brew_prefix):
- raise RuntimeError('LLVM not found, you can install it with `brew install llvm`')
- LLVM_PATH = os.path.join(brew_prefix, 'lib', 'libLLVM.dylib')
+ raise FileNotFoundError('LLVM not found, you can install it with `brew install llvm`')
+ LLVM_PATH: str|None = os.path.join(brew_prefix, 'lib', 'libLLVM.dylib')
  else:
  LLVM_PATH = ctypes.util.find_library('LLVM')
- for ver in range(14, 19+1):
+ # use newer LLVM if possible
+ for ver in reversed(range(14, 19+1)):
  if LLVM_PATH is not None: break
  LLVM_PATH = ctypes.util.find_library(f'LLVM-{ver}')
  if LLVM_PATH is None:
- raise RuntimeError("No LLVM library found on the system. Install it via your distro's package manager and ensure it's findable as 'LLVM'")
+ raise FileNotFoundError("No LLVM library found on the system. Install it via your distro's package manager and ensure it's findable as 'LLVM'")

- if DEBUG>=2: print(f'Using LLVM at {repr(LLVM_PATH)}')
+ if DEBUG>=3: print(f'Using LLVM at {repr(LLVM_PATH)}')
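
Beyond swapping RuntimeError for FileNotFoundError and dropping the tinyllvm/Metal special case, the Linux branch now walks the versioned library names newest-first. A minimal sketch of that lookup order using only ctypes.util.find_library (the 14..19 bounds are the ones in the diff):

    import ctypes.util

    def find_llvm() -> str:
      path = ctypes.util.find_library('LLVM')
      for ver in reversed(range(14, 19+1)):        # prefer the newest versioned library if the unversioned one is missing
        if path is not None: break
        path = ctypes.util.find_library(f'LLVM-{ver}')
      if path is None: raise FileNotFoundError("No LLVM library found on the system")
      return path
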
tinygrad/shape/shapetracker.py CHANGED
@@ -6,8 +6,8 @@ from typing import Optional, Callable
  from tinygrad.helpers import merge_dicts, getenv
  from tinygrad.shape.view import View, strides_for_shape, unravel
  from tinygrad.dtype import dtypes
- from tinygrad.ops import UOp, Ops, graph_rewrite, split_uop, symbolic_flat, Variable, sint, uop_given_valid, simplify_valid, sint_to_uop, Context
- from tinygrad.codegen.rewriter import sym
+ from tinygrad.ops import UOp, Ops, graph_rewrite, Variable, sint, sint_to_uop, Context
+ from tinygrad.codegen.symbolic import sym, split_uop, symbolic_flat, uop_given_valid, simplify_valid

  def overflow(u: UOp): return u.vmax > dtypes.max(dtypes.int) or u.vmin < dtypes.min(dtypes.int)

@@ -109,6 +109,7 @@ class ShapeTracker:

  def unbind(self) -> tuple[ShapeTracker, dict[Variable, int]]:
  unbound_views, var_vals = zip(*[v.unbind() for v in self.views])
+ if all(len(x) == 0 for x in var_vals): return self, {}
  return ShapeTracker(tuple(unbound_views)), merge_dicts(var_vals)

  def real_strides(self, ignore_valid=False) -> tuple[Optional[sint], ...]: return views_to_real_strides(self.views, ignore_valid)
tinygrad/shape/view.py CHANGED
@@ -107,8 +107,7 @@ class View:
  @staticmethod
  @functools.lru_cache(maxsize=None)
  def create(shape:tuple[sint, ...], strides:Optional[tuple[sint, ...]]=None, offset:sint=0, mask:Optional[tuple[tuple[sint, sint], ...]]=None):
- # TODO: this resolve shouldn't be needed
- if not all(resolve(s >= 0) for s in shape): raise ValueError(f"Trying to create View with negative dimension: {shape=}")
+ if not all(s >= 0 for s in shape): raise ValueError(f"Trying to create View with negative dimension: {shape=}")
  strides = canonicalize_strides(shape, strides) if strides else strides_for_shape(shape)
  # canonicalize 0 in shape
  if 0 in shape: return View(shape, (0,) * len(shape), offset=0, mask=None, contiguous=True)
@@ -274,7 +273,7 @@ class View:
  def reshape(self, new_shape: tuple[sint, ...]) -> Optional[View]:
  if self.shape == new_shape: return self

- assert all(x >= 0 for x in new_shape), f"shape can't contain negative numbers {new_shape}"
+ if not all(x >= 0 for x in new_shape): raise ValueError(f"shape can't contain negative numbers {new_shape}")
  # check for the same size
  if (self_all_int := all_int(self.shape)):
  assert all(isinstance(s, (int, UOp)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
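
Both view.py edits replace an assert (or a resolve-wrapped check) with an explicit ValueError. A likely motivation is that asserts disappear under `python -O`, while a raised ValueError is always enforced and is the conventional error type for bad user input; a quick illustration on a hypothetical shape:

    def check_shape(shape:tuple[int, ...]) -> None:
      # assert all(x >= 0 for x in shape)   # would be skipped under `python -O`
      if not all(x >= 0 for x in shape):    # always enforced
        raise ValueError(f"shape can't contain negative numbers {shape}")

    check_shape((2, 3))      # ok
    # check_shape((2, -1))   # raises ValueError even with optimizations enabled
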
tinygrad/spec.py CHANGED
@@ -1,21 +1,25 @@
  from typing import cast
  from tinygrad.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops
  from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType
- from tinygrad.helpers import all_int, all_same, dedup, prod
+ from tinygrad.helpers import all_same, dedup, prod

- # *** this is the spec of a Tensor in UOp ***
-
- tensor_uop_spec = PatternMatcher([
+ buffer_spec = PatternMatcher([
+ (UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True),
  (UPat(Ops.DEVICE, dtypes.void, (), name="device"), lambda device: isinstance(device.arg, str)),
- (UPat(Ops.BUFFER, src=(UPat(Ops.DEVICE),), name="buf"),
- lambda buf: isinstance(buf.arg, tuple) and len(buf.arg) == 2 and all_int(buf.arg) and isinstance(buf.dtype, (DType, ImageDType))),
+ (UPat(Ops.BUFFER, src=(UPat(Ops.DEVICE), UPat(Ops.UNIQUE)), name="buf"),
+ lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
+ ])

+ # *** this is the spec of a Tensor in UOp ***
+
+ tensor_uop_spec = buffer_spec+PatternMatcher([
  (UPat(GroupOp.Movement, name="mv", src=(UPat.var("x"),)),
  # naturally correct
  lambda mv,x: (isinstance(mv.arg, tuple) and mv.dtype == x.dtype) or
  # "make things that can't be images not images" can change the buffer dtype
  # this is fine as long as it's a realized buffer and base dtypes match.
  ((isinstance(mv.dtype, ImageDType) or isinstance(x.dtype, ImageDType)) and x.dtype.base == mv.dtype.base and x.is_realized)),
+ (UPat(Ops.VIEW, src=(UPat(GroupOp.All-{Ops.CONST, Ops.DEVICE}),)), lambda: False),

  # Tensor variable bindings
  (UPat(Ops.BIND, dtypes.int, (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=dtypes.int)), arg=None), lambda: True),
@@ -32,11 +36,6 @@ tensor_uop_spec = PatternMatcher([
  # NOTE: the arg here specifies clone=True, which prevents folding same device copy
  (UPat(Ops.COPY, name="copy", src=(UPat(Ops.DEVICE), UPat.var("x"))), lambda copy,x: isinstance(copy.arg, bool) and copy.dtype == x.dtype),

- # VIEW(BUFFER) applies a ShapeTracker on top of the underlying device buffer
- # NOTE: VIEW size exactly matches the underlying BUFFER, tensor doesn't apply movement ops to the VIEW
- (UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),
- lambda view,buf: view.dtype == buf.dtype and view.size == buf.size and view.st.contiguous),
-
  # ASSIGN changes the value of a realized buffer
  (UPat(Ops.ASSIGN, name="assign", src=(UPat.var("target"), UPat.var("new_val"))),
  lambda assign,target,new_val: target.is_realized and (assign.dtype == target.dtype == new_val.dtype)),
@@ -58,7 +57,7 @@ spec = PatternMatcher([

  # TODO: confirm the args of both of these are shapetrackers
  (UPat(Ops.VIEW, dtypes.void, src=()), lambda: True),
- (UPat(Ops.VIEW, src=(UPat.var("src"),), name="x"), lambda x,src: src.op is not Ops.STORE and x.dtype == src.dtype),
+ (UPat(Ops.VIEW, src=(UPat.var("src"),), name="x"), lambda x,src: src.op is not Ops.STORE and x.dtype.base == src.dtype.base),

  (UPat(Ops.VALID, dtypes.bool, (UPat(Ops.VIEW),)), lambda: True),
  (UPat(Ops.CONST, name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),
@@ -113,9 +112,9 @@ spec = PatternMatcher([
  (UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local

  # NOTE: for testing, we let sinks be anything
- #(UPat(UOps.SINK, src=UPat(UOps.STORE)), lambda: True),
- (UPat(Ops.SINK, dtypes.void), lambda: True),
- (UPat(Ops.NOOP), lambda: True),
+ #(UPat(Ops.SINK, src=UPat(Ops.STORE)), lambda: True),
+ (UPat((Ops.NAME, Ops.SINK), dtypes.void), lambda: True),
+ (UPat((Ops.NOOP, Ops.CUSTOM)), lambda: True),

  # PTX LOAD/STORE
  (UPat((Ops.LOAD, Ops.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
@@ -123,11 +122,13 @@ spec = PatternMatcher([

  # *** this is the spec of a Kernel in UOp ***

- kernel_spec = PatternMatcher([
- (UPat(Ops.DEVICE, src=()), lambda: True),
- (UPat(Ops.BUFFER, src=(UPat(Ops.DEVICE),)), lambda: True),
- # TODO: currently kernel only has buffer parents, this is incomplete. it should be BUFFER and ASSIGN
- (UPat(Ops.KERNEL, src=UPat(Ops.BUFFER)), lambda: True),
+ kernel_spec = buffer_spec+PatternMatcher([
+ (UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.ASSIGN))), lambda: True),
+ # assign has a buffer view and kernel source, it can optionally depend on other assigns
+ (UPat(Ops.ASSIGN, src=UPat((Ops.BUFFER, Ops.VIEW, Ops.KERNEL, Ops.ASSIGN))), lambda: True),
+ # view/sink/const can also exist in the kernel graph
+ (UPat((Ops.VIEW, Ops.SINK, Ops.CONST)), lambda: True),
+ (UPat(GroupOp.All), lambda: False),
  ])

  # *** this is the UOp shape spec ***
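
The structural shift in spec.py is that the DEVICE/BUFFER/UNIQUE rules now live in a shared buffer_spec, and both tensor_uop_spec and kernel_spec are built by adding further PatternMatcher rules on top of it, with a trailing (UPat(GroupOp.All), lambda: False) rejecting anything not explicitly allowed in the kernel graph. A rough sketch of that allow-list composition idea in plain Python, not using tinygrad's PatternMatcher/UPat API:

    # each rule maps an op name to a verdict; a shared base spec is extended per graph kind
    buffer_rules = {"UNIQUE": True, "DEVICE": True, "BUFFER": True}
    kernel_rules = {**buffer_rules, "KERNEL": True, "ASSIGN": True, "VIEW": True, "SINK": True, "CONST": True}

    def allowed(op:str, rules:dict[str, bool]) -> bool:
      return rules.get(op, False)   # unknown ops are rejected, like the trailing catch-all rule

    assert allowed("ASSIGN", kernel_rules) and not allowed("ASSIGN", buffer_rules)
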