tinygrad 0.10.1__py3-none-any.whl → 0.10.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/devectorizer.py +247 -0
 - tinygrad/codegen/expander.py +121 -0
 - tinygrad/codegen/kernel.py +35 -37
 - tinygrad/codegen/linearize.py +19 -10
 - tinygrad/codegen/lowerer.py +31 -8
 - tinygrad/codegen/symbolic.py +476 -0
 - tinygrad/codegen/transcendental.py +10 -0
 - tinygrad/device.py +28 -11
 - tinygrad/dtype.py +12 -3
 - tinygrad/engine/jit.py +3 -2
 - tinygrad/engine/multi.py +0 -1
 - tinygrad/engine/realize.py +7 -4
 - tinygrad/engine/schedule.py +227 -255
 - tinygrad/engine/search.py +20 -27
 - tinygrad/gradient.py +3 -0
 - tinygrad/helpers.py +7 -4
 - tinygrad/nn/state.py +2 -2
 - tinygrad/ops.py +64 -329
 - tinygrad/renderer/__init__.py +19 -3
 - tinygrad/renderer/cstyle.py +39 -18
 - tinygrad/renderer/llvmir.py +55 -18
 - tinygrad/renderer/ptx.py +6 -2
 - tinygrad/renderer/wgsl.py +20 -12
 - tinygrad/runtime/autogen/libc.py +404 -71
 - tinygrad/runtime/autogen/{libpciaccess.py → pci.py} +25 -715
 - tinygrad/runtime/autogen/webgpu.py +6985 -0
 - tinygrad/runtime/graph/metal.py +28 -29
 - tinygrad/runtime/ops_amd.py +37 -34
 - tinygrad/runtime/{ops_clang.py → ops_cpu.py} +4 -2
 - tinygrad/runtime/ops_disk.py +1 -1
 - tinygrad/runtime/ops_dsp.py +59 -33
 - tinygrad/runtime/ops_llvm.py +14 -12
 - tinygrad/runtime/ops_metal.py +78 -62
 - tinygrad/runtime/ops_nv.py +9 -6
 - tinygrad/runtime/ops_python.py +5 -5
 - tinygrad/runtime/ops_webgpu.py +200 -38
 - tinygrad/runtime/support/am/amdev.py +23 -11
 - tinygrad/runtime/support/am/ip.py +10 -10
 - tinygrad/runtime/support/elf.py +2 -0
 - tinygrad/runtime/support/hcq.py +7 -5
 - tinygrad/runtime/support/llvm.py +8 -14
 - tinygrad/shape/shapetracker.py +3 -2
 - tinygrad/shape/view.py +2 -3
 - tinygrad/spec.py +21 -20
 - tinygrad/tensor.py +150 -90
 - tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
 - tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
 - tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
 - tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
 - tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
 - tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
 - tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
 - tinygrad/viz/index.html +544 -0
 - tinygrad/viz/perfetto.html +178 -0
 - tinygrad/viz/serve.py +205 -0
 - {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/METADATA +20 -8
 - tinygrad-0.10.2.dist-info/RECORD +99 -0
 - tinygrad/codegen/rewriter.py +0 -516
 - tinygrad-0.10.1.dist-info/RECORD +0 -86
 - {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
 - {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +0 -0
 - {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
 
| 
         @@ -1,7 +1,7 @@ 
     | 
|
| 
       1 
1 
     | 
    
         
             
            from __future__ import annotations
         
     | 
| 
       2 
2 
     | 
    
         
             
            import ctypes, collections, time, dataclasses, pathlib, fcntl, os
         
     | 
| 
       3 
3 
     | 
    
         
             
            from tinygrad.helpers import to_mv, mv_address, getenv, round_up, DEBUG, temp
         
     | 
| 
       4 
     | 
    
         
            -
            from tinygrad.runtime.autogen.am import am, mp_11_0 
     | 
| 
      
 4 
     | 
    
         
            +
            from tinygrad.runtime.autogen.am import am, mp_11_0
         
     | 
| 
       5 
5 
     | 
    
         
             
            from tinygrad.runtime.support.allocator import TLSFAllocator
         
     | 
| 
       6 
6 
     | 
    
         
             
            from tinygrad.runtime.support.am.ip import AM_SOC21, AM_GMC, AM_IH, AM_PSP, AM_SMU, AM_GFX, AM_SDMA
         
     | 
| 
       7 
7 
     | 
    
         | 
| 
         @@ -32,11 +32,13 @@ class AMRegister: 
     | 
|
| 
       32 
32 
     | 
    
         
             
              def read(self, **kwargs): return self.adev.rreg(self.reg_off) & self._parse_kwargs(**kwargs)[0]
         
     | 
| 
       33 
33 
     | 
    
         | 
| 
       34 
34 
     | 
    
         
             
            class AMFirmware:
         
     | 
| 
       35 
     | 
    
         
            -
              def __init__(self):
         
     | 
| 
      
 35 
     | 
    
         
            +
              def __init__(self, adev):
         
     | 
| 
      
 36 
     | 
    
         
            +
                def fmt_ver(hwip): return f"{adev.ip_versions[hwip]//10000}_{(adev.ip_versions[hwip]//100)%100}_{adev.ip_versions[hwip]%100}"
         
     | 
| 
      
 37 
     | 
    
         
            +
             
     | 
| 
       36 
38 
     | 
    
         
             
                # Load SOS firmware
         
     | 
| 
       37 
39 
     | 
    
         
             
                self.sos_fw = {}
         
     | 
| 
       38 
40 
     | 
    
         | 
| 
       39 
     | 
    
         
            -
                blob, sos_hdr = self.load_fw(" 
     | 
| 
      
 41 
     | 
    
         
            +
                blob, sos_hdr = self.load_fw(f"psp_{fmt_ver(am.MP0_HWIP)}_sos.bin", am.struct_psp_firmware_header_v2_0)
         
     | 
| 
       40 
42 
     | 
    
         
             
                fw_bin = sos_hdr.psp_fw_bin
         
     | 
| 
       41 
43 
     | 
    
         | 
| 
       42 
44 
     | 
    
         
             
                for fw_i in range(sos_hdr.psp_fw_bin_count):
         
     | 
| 
         @@ -48,17 +50,17 @@ class AMFirmware: 
     | 
|
| 
       48 
50 
     | 
    
         
             
                self.ucode_start: dict[str, int] = {}
         
     | 
| 
       49 
51 
     | 
    
         
             
                self.descs: list[tuple[int, memoryview]] = []
         
     | 
| 
       50 
52 
     | 
    
         | 
| 
       51 
     | 
    
         
            -
                blob, hdr = self.load_fw(" 
     | 
| 
      
 53 
     | 
    
         
            +
                blob, hdr = self.load_fw(f"smu_{fmt_ver(am.MP1_HWIP)}.bin", am.struct_smc_firmware_header_v1_0)
         
     | 
| 
       52 
54 
     | 
    
         
             
                self.smu_psp_desc = self.desc(am.GFX_FW_TYPE_SMU, blob, hdr.header.ucode_array_offset_bytes, hdr.header.ucode_size_bytes)
         
     | 
| 
       53 
55 
     | 
    
         | 
| 
       54 
56 
     | 
    
         
             
                # SDMA firmware
         
     | 
| 
       55 
     | 
    
         
            -
                blob, hdr = self.load_fw(" 
     | 
| 
      
 57 
     | 
    
         
            +
                blob, hdr = self.load_fw(f"sdma_{fmt_ver(am.SDMA0_HWIP)}.bin", am.struct_sdma_firmware_header_v2_0)
         
     | 
| 
       56 
58 
     | 
    
         
             
                self.descs += [self.desc(am.GFX_FW_TYPE_SDMA_UCODE_TH0, blob, hdr.header.ucode_array_offset_bytes, hdr.ctx_ucode_size_bytes)]
         
     | 
| 
       57 
59 
     | 
    
         
             
                self.descs += [self.desc(am.GFX_FW_TYPE_SDMA_UCODE_TH1, blob, hdr.ctl_ucode_offset, hdr.ctl_ucode_size_bytes)]
         
     | 
| 
       58 
60 
     | 
    
         | 
| 
       59 
61 
     | 
    
         
             
                # PFP, ME, MEC firmware
         
     | 
| 
       60 
62 
     | 
    
         
             
                for (fw_name, fw_cnt) in [('PFP', 2), ('ME', 2), ('MEC', 4)]:
         
     | 
| 
       61 
     | 
    
         
            -
                  blob, hdr = self.load_fw(f" 
     | 
| 
      
 63 
     | 
    
         
            +
                  blob, hdr = self.load_fw(f"gc_{fmt_ver(am.GC_HWIP)}_{fw_name.lower()}.bin", am.struct_gfx_firmware_header_v2_0)
         
     | 
| 
       62 
64 
     | 
    
         | 
| 
       63 
65 
     | 
    
         
             
                  # Code part
         
     | 
| 
       64 
66 
     | 
    
         
             
                  self.descs += [self.desc(getattr(am, f'GFX_FW_TYPE_RS64_{fw_name}'), blob, hdr.header.ucode_array_offset_bytes, hdr.ucode_size_bytes)]
         
     | 
| 
         @@ -69,12 +71,12 @@ class AMFirmware: 
     | 
|
| 
       69 
71 
     | 
    
         
             
                  self.ucode_start[fw_name] = hdr.ucode_start_addr_lo | (hdr.ucode_start_addr_hi << 32)
         
     | 
| 
       70 
72 
     | 
    
         | 
| 
       71 
73 
     | 
    
         
             
                # IMU firmware
         
     | 
| 
       72 
     | 
    
         
            -
                blob, hdr = self.load_fw(" 
     | 
| 
      
 74 
     | 
    
         
            +
                blob, hdr = self.load_fw(f"gc_{fmt_ver(am.GC_HWIP)}_imu.bin", am.struct_imu_firmware_header_v1_0)
         
     | 
| 
       73 
75 
     | 
    
         
             
                imu_i_off, imu_i_sz, imu_d_sz = hdr.header.ucode_array_offset_bytes, hdr.imu_iram_ucode_size_bytes, hdr.imu_dram_ucode_size_bytes
         
     | 
| 
       74 
76 
     | 
    
         
             
                self.descs += [self.desc(am.GFX_FW_TYPE_IMU_I, blob, imu_i_off, imu_i_sz), self.desc(am.GFX_FW_TYPE_IMU_D, blob, imu_i_off + imu_i_sz, imu_d_sz)]
         
     | 
| 
       75 
77 
     | 
    
         | 
| 
       76 
78 
     | 
    
         
             
                # RLC firmware
         
     | 
| 
       77 
     | 
    
         
            -
                blob, hdr0, hdr1, hdr2, hdr3 = self.load_fw(" 
     | 
| 
      
 79 
     | 
    
         
            +
                blob, hdr0, hdr1, hdr2, hdr3 = self.load_fw(f"gc_{fmt_ver(am.GC_HWIP)}_rlc.bin", am.struct_rlc_firmware_header_v2_0,
         
     | 
| 
       78 
80 
     | 
    
         
             
                  am.struct_rlc_firmware_header_v2_1, am.struct_rlc_firmware_header_v2_2, am.struct_rlc_firmware_header_v2_3)
         
     | 
| 
       79 
81 
     | 
    
         | 
| 
       80 
82 
     | 
    
         
             
                for mem in ['GPM', 'SRM']:
         
     | 
| 
         @@ -263,7 +265,7 @@ class AMDev: 
     | 
|
| 
       263 
265 
     | 
    
         | 
| 
       264 
266 
     | 
    
         
             
                # Memory manager & firmware
         
     | 
| 
       265 
267 
     | 
    
         
             
                self.mm = AMMemoryManager(self, self.vram_size)
         
     | 
| 
       266 
     | 
    
         
            -
                self.fw = AMFirmware()
         
     | 
| 
      
 268 
     | 
    
         
            +
                self.fw = AMFirmware(self)
         
     | 
| 
       267 
269 
     | 
    
         | 
| 
       268 
270 
     | 
    
         
             
                # Initialize IP blocks
         
     | 
| 
       269 
271 
     | 
    
         
             
                self.soc21:AM_SOC21 = AM_SOC21(self)
         
     | 
| 
         @@ -274,7 +276,7 @@ class AMDev: 
     | 
|
| 
       274 
276 
     | 
    
         
             
                self.gfx:AM_GFX = AM_GFX(self)
         
     | 
| 
       275 
277 
     | 
    
         
             
                self.sdma:AM_SDMA = AM_SDMA(self)
         
     | 
| 
       276 
278 
     | 
    
         | 
| 
       277 
     | 
    
         
            -
                if self.partial_boot and (self.reg(" 
     | 
| 
      
 279 
     | 
    
         
            +
                if self.partial_boot and (self.reg("regGCVM_CONTEXT0_CNTL").read() != 0):
         
     | 
| 
       278 
280 
     | 
    
         
             
                  if DEBUG >= 2: print(f"am {self.devfmt}: MEC is active. Issue a full reset.")
         
     | 
| 
       279 
281 
     | 
    
         
             
                  self.partial_boot = False
         
     | 
| 
       280 
282 
     | 
    
         | 
| 
         @@ -298,8 +300,10 @@ class AMDev: 
     | 
|
| 
       298 
300 
     | 
    
         
             
                if DEBUG >= 2: print(f"am {self.devfmt}: boot done")
         
     | 
| 
       299 
301 
     | 
    
         | 
| 
       300 
302 
     | 
    
         
             
              def fini(self):
         
     | 
| 
      
 303 
     | 
    
         
            +
                if DEBUG >= 2: print(f"am {self.devfmt}: Finalizing")
         
     | 
| 
       301 
304 
     | 
    
         
             
                for ip in [self.sdma, self.gfx]: ip.fini()
         
     | 
| 
       302 
305 
     | 
    
         
             
                self.smu.set_clocks(level=0)
         
     | 
| 
      
 306 
     | 
    
         
            +
                self.ih.interrupt_handler()
         
     | 
| 
       303 
307 
     | 
    
         | 
| 
       304 
308 
     | 
    
         
             
              def paddr2cpu(self, paddr:int) -> int: return mv_address(self.vram) + paddr
         
     | 
| 
       305 
309 
     | 
    
         
             
              def paddr2mc(self, paddr:int) -> int: return self.gmc.mc_base + paddr
         
     | 
| 
         @@ -369,8 +373,16 @@ class AMDev: 
     | 
|
| 
       369 
373 
     | 
    
         
             
                gc_info = am.struct_gc_info_v1_0.from_address(gc_addr:=ctypes.addressof(bhdr) + bhdr.table_list[am.GC].offset)
         
     | 
| 
       370 
374 
     | 
    
         
             
                self.gc_info = getattr(am, f"struct_gc_info_v{gc_info.header.version_major}_{gc_info.header.version_minor}").from_address(gc_addr)
         
     | 
| 
       371 
375 
     | 
    
         | 
| 
      
 376 
     | 
    
         
            +
              def _ip_module(self, prefix:str, hwip):
         
     | 
| 
      
 377 
     | 
    
         
            +
                version = [self.ip_versions[hwip]//10000, (self.ip_versions[hwip]//100)%100, self.ip_versions[hwip]%100]
         
     | 
| 
      
 378 
     | 
    
         
            +
                for ver in [version, version[:2]+[0], version[:1]+[0, 0]]:
         
     | 
| 
      
 379 
     | 
    
         
            +
                  try: return __import__(f"tinygrad.runtime.autogen.am.{prefix}_{ver[0]}_{ver[1]}_{ver[2]}", fromlist=[f"{prefix}_{ver[0]}_{ver[1]}_{ver[2]}"])
         
     | 
| 
      
 380 
     | 
    
         
            +
                  except ImportError: pass
         
     | 
| 
      
 381 
     | 
    
         
            +
                raise ImportError(f"am {self.devfmt}: failed to load {prefix} module with version {version}")
         
     | 
| 
      
 382 
     | 
    
         
            +
             
     | 
| 
       372 
383 
     | 
    
         
             
              def _build_regs(self):
         
     | 
| 
       373 
     | 
    
         
            -
                mods = [("MP0",  
     | 
| 
      
 384 
     | 
    
         
            +
                mods = [("MP0", self._ip_module("mp", am.MP0_HWIP)), ("NBIO", self._ip_module("nbio", am.NBIO_HWIP)), ("GC", self._ip_module("gc", am.GC_HWIP)),
         
     | 
| 
      
 385 
     | 
    
         
            +
                  ("MP1", mp_11_0), ("MMHUB", self._ip_module("mmhub", am.MMHUB_HWIP)), ("OSSSYS", self._ip_module("osssys", am.OSSSYS_HWIP))]
         
     | 
| 
       374 
386 
     | 
    
         
             
                for base, module in mods:
         
     | 
| 
       375 
387 
     | 
    
         
             
                  rpref = "mm" if base == "MP1" else "reg" # MP1 regs starts with mm
         
     | 
| 
       376 
388 
     | 
    
         
             
                  reg_names: set[str] = set(k[len(rpref):] for k in module.__dict__.keys() if k.startswith(rpref) and not k.endswith("_BASE_IDX"))
         
     | 
| 
         @@ -18,7 +18,7 @@ class AM_GMC(AM_IP): 
     | 
|
| 
       18 
18 
     | 
    
         
             
                super().__init__(adev)
         
     | 
| 
       19 
19 
     | 
    
         | 
| 
       20 
20 
     | 
    
         
             
                # Memory controller aperture
         
     | 
| 
       21 
     | 
    
         
            -
                self.mc_base = self.adev.regMMMC_VM_FB_LOCATION_BASE.read() << 24
         
     | 
| 
      
 21 
     | 
    
         
            +
                self.mc_base = (self.adev.regMMMC_VM_FB_LOCATION_BASE.read() & 0xFFFFFF) << 24
         
     | 
| 
       22 
22 
     | 
    
         
             
                self.mc_end = self.mc_base + self.adev.mm.vram_size - 1
         
     | 
| 
       23 
23 
     | 
    
         | 
| 
       24 
24 
     | 
    
         
             
                # VM aperture
         
     | 
| 
         @@ -189,8 +189,6 @@ class AM_GFX(AM_IP): 
     | 
|
| 
       189 
189 
     | 
    
         
             
                self.adev.regCP_HQD_DEQUEUE_REQUEST.write(0x2) # 1 - DRAIN_PIPE; 2 - RESET_WAVES
         
     | 
| 
       190 
190 
     | 
    
         
             
                self.adev.regSPI_COMPUTE_QUEUE_RESET.write(1)
         
     | 
| 
       191 
191 
     | 
    
         
             
                self._grbm_select()
         
     | 
| 
       192 
     | 
    
         
            -
                self.adev.regCP_MEC_RS64_CNTL.update(mec_invalidate_icache=1, mec_pipe0_reset=1, mec_pipe1_reset=1, mec_pipe2_reset=1, mec_pipe3_reset=1,
         
     | 
| 
       193 
     | 
    
         
            -
                                                     mec_pipe0_active=0, mec_pipe1_active=0, mec_pipe2_active=0, mec_pipe3_active=0, mec_halt=1)
         
     | 
| 
       194 
192 
     | 
    
         
             
                self.adev.regGCVM_CONTEXT0_CNTL.write(0)
         
     | 
| 
       195 
193 
     | 
    
         | 
| 
       196 
194 
     | 
    
         
             
              def setup_ring(self, ring_addr:int, ring_size:int, rptr_addr:int, wptr_addr:int, eop_addr:int, eop_size:int, doorbell:int, pipe:int, queue:int):
         
     | 
| 
         @@ -225,6 +223,8 @@ class AM_GFX(AM_IP): 
     | 
|
| 
       225 
223 
     | 
    
         
             
                self.adev.reg(f"regCP_ME1_PIPE{pipe}_INT_CNTL").update(time_stamp_int_enable=1, generic0_int_enable=1)
         
     | 
| 
       226 
224 
     | 
    
         | 
| 
       227 
225 
     | 
    
         
             
              def set_clockgating_state(self):
         
     | 
| 
      
 226 
     | 
    
         
            +
                if hasattr(self.adev, 'regMM_ATC_L2_MISC_CG'): self.adev.regMM_ATC_L2_MISC_CG.write(enable=1, mem_ls_enable=1)
         
     | 
| 
      
 227 
     | 
    
         
            +
             
     | 
| 
       228 
228 
     | 
    
         
             
                self.adev.regRLC_SAFE_MODE.write(message=1, cmd=1)
         
     | 
| 
       229 
229 
     | 
    
         
             
                self.adev.wait_reg(self.adev.regRLC_SAFE_MODE, mask=0x1, value=0x0)
         
     | 
| 
       230 
230 
     | 
    
         | 
| 
         @@ -233,6 +233,7 @@ class AM_GFX(AM_IP): 
     | 
|
| 
       233 
233 
     | 
    
         
             
                self.adev.regCP_RB_WPTR_POLL_CNTL.update(poll_frequency=0x100, idle_poll_count=0x90)
         
     | 
| 
       234 
234 
     | 
    
         
             
                self.adev.regCP_INT_CNTL.update(cntx_busy_int_enable=1, cntx_empty_int_enable=1, cmp_busy_int_enable=1, gfx_idle_int_enable=1)
         
     | 
| 
       235 
235 
     | 
    
         
             
                self.adev.regSDMA0_RLC_CGCG_CTRL.update(cgcg_int_enable=1)
         
     | 
| 
      
 236 
     | 
    
         
            +
                self.adev.regSDMA1_RLC_CGCG_CTRL.update(cgcg_int_enable=1)
         
     | 
| 
       236 
237 
     | 
    
         | 
| 
       237 
238 
     | 
    
         
             
                self.adev.regRLC_CGTT_MGCG_OVERRIDE.update(perfmon_clock_state=0, gfxip_fgcg_override=0, gfxip_repeater_fgcg_override=0,
         
     | 
| 
       238 
239 
     | 
    
         
             
                  grbm_cgtt_sclk_override=0, rlc_cgtt_sclk_override=0, gfxip_mgcg_override=0, gfxip_cgls_override=0, gfxip_cgcg_override=0)
         
     | 
| 
         @@ -311,17 +312,16 @@ class AM_SDMA(AM_IP): 
     | 
|
| 
       311 
312 
     | 
    
         
             
                self.adev.reg(f"regSDMA{pipe}_QUEUE{queue}_IB_CNTL").update(ib_enable=1)
         
     | 
| 
       312 
313 
     | 
    
         | 
| 
       313 
314 
     | 
    
         
             
              def init(self):
         
     | 
| 
       314 
     | 
    
         
            -
                 
     | 
| 
       315 
     | 
    
         
            -
             
     | 
| 
       316 
     | 
    
         
            -
             
     | 
| 
       317 
     | 
    
         
            -
             
     | 
| 
       318 
     | 
    
         
            -
             
     | 
| 
       319 
     | 
    
         
            -
             
     | 
| 
      
 315 
     | 
    
         
            +
                for pipe in range(2):
         
     | 
| 
      
 316 
     | 
    
         
            +
                  self.adev.reg(f"regSDMA{pipe}_WATCHDOG_CNTL").update(queue_hang_count=100) # 10s, 100ms per unit
         
     | 
| 
      
 317 
     | 
    
         
            +
                  self.adev.reg(f"regSDMA{pipe}_UTCL1_CNTL").update(resp_mode=3, redo_delay=9)
         
     | 
| 
      
 318 
     | 
    
         
            +
                  self.adev.reg(f"regSDMA{pipe}_UTCL1_PAGE").update(rd_l2_policy=0x2, wr_l2_policy=0x3, llc_noalloc=1) # rd=noa, wr=bypass
         
     | 
| 
      
 319 
     | 
    
         
            +
                  self.adev.reg(f"regSDMA{pipe}_F32_CNTL").update(halt=0, th1_reset=0)
         
     | 
| 
      
 320 
     | 
    
         
            +
                  self.adev.reg(f"regSDMA{pipe}_CNTL").update(ctxempty_int_enable=1, trap_enable=1)
         
     | 
| 
       320 
321 
     | 
    
         | 
| 
       321 
322 
     | 
    
         
             
              def fini(self):
         
     | 
| 
       322 
323 
     | 
    
         
             
                self.adev.regSDMA0_QUEUE0_RB_CNTL.update(rb_enable=0)
         
     | 
| 
       323 
324 
     | 
    
         
             
                self.adev.regSDMA0_QUEUE0_IB_CNTL.update(ib_enable=0)
         
     | 
| 
       324 
     | 
    
         
            -
                self.adev.regSDMA0_F32_CNTL.update(halt=1, th1_reset=1)
         
     | 
| 
       325 
325 
     | 
    
         
             
                self.adev.regGRBM_SOFT_RESET.write(soft_reset_sdma0=1)
         
     | 
| 
       326 
326 
     | 
    
         
             
                time.sleep(0.01)
         
     | 
| 
       327 
327 
     | 
    
         
             
                self.adev.regGRBM_SOFT_RESET.write(0x0)
         
     | 
    
        tinygrad/runtime/support/elf.py
    CHANGED
    
    | 
         @@ -32,6 +32,8 @@ def elf_loader(blob:bytes, force_section_align:int=1) -> tuple[memoryview, list[ 
     | 
|
| 
       32 
32 
     | 
    
         
             
              for sh, trgt_sh_name, c_rels in rel + rela:
         
     | 
| 
       33 
33 
     | 
    
         
             
                target_image_off = next(tsh for tsh in sections if tsh.name == trgt_sh_name).header.sh_addr
         
     | 
| 
       34 
34 
     | 
    
         
             
                rels = [(r.r_offset, symtab[libc.ELF64_R_SYM(r.r_info)], libc.ELF64_R_TYPE(r.r_info), getattr(r, "r_addend", 0)) for r in c_rels]
         
     | 
| 
      
 35 
     | 
    
         
            +
                for roff, sym, r_type_, r_addend in rels:
         
     | 
| 
      
 36 
     | 
    
         
            +
                  if sym.st_shndx == 0: raise RuntimeError(f'Attempting to relocate against an undefined symbol {repr(_strtab(sh_strtab, sym.st_name))}')
         
     | 
| 
       35 
37 
     | 
    
         
             
                relocs += [(target_image_off + roff, sections[sym.st_shndx].header.sh_addr + sym.st_value, rtype, raddend) for roff, sym, rtype, raddend in rels]
         
     | 
| 
       36 
38 
     | 
    
         | 
| 
       37 
39 
     | 
    
         
             
              return memoryview(image), sections, relocs
         
     | 
    
        tinygrad/runtime/support/hcq.py
    CHANGED
    
    | 
         @@ -4,7 +4,7 @@ import contextlib, decimal, statistics, time, ctypes, array, os, fcntl 
     | 
|
| 
       4 
4 
     | 
    
         
             
            from tinygrad.helpers import PROFILE, from_mv, getenv, to_mv, round_up
         
     | 
| 
       5 
5 
     | 
    
         
             
            from tinygrad.renderer import Renderer
         
     | 
| 
       6 
6 
     | 
    
         
             
            from tinygrad.device import BufferSpec, Compiler, Compiled, LRUAllocator, ProfileRangeEvent, ProfileDeviceEvent
         
     | 
| 
       7 
     | 
    
         
            -
            from tinygrad.ops import sym_infer, sint, Variable
         
     | 
| 
      
 7 
     | 
    
         
            +
            from tinygrad.ops import sym_infer, sint, Variable, UOp
         
     | 
| 
       8 
8 
     | 
    
         
             
            from tinygrad.runtime.autogen import libc
         
     | 
| 
       9 
9 
     | 
    
         | 
| 
       10 
10 
     | 
    
         
             
            class HWInterface:
         
     | 
| 
         @@ -19,9 +19,11 @@ class HWInterface: 
     | 
|
| 
       19 
19 
     | 
    
         
             
                if hasattr(self, 'fd'): os.close(self.fd)
         
     | 
| 
       20 
20 
     | 
    
         
             
              def ioctl(self, request, arg): return fcntl.ioctl(self.fd, request, arg)
         
     | 
| 
       21 
21 
     | 
    
         
             
              def mmap(self, start, sz, prot, flags, offset): return libc.mmap(start, sz, prot, flags, self.fd, offset)
         
     | 
| 
       22 
     | 
    
         
            -
              def read(self, size=None, binary=False):
         
     | 
| 
      
 22 
     | 
    
         
            +
              def read(self, size=None, binary=False, offset=None):
         
     | 
| 
      
 23 
     | 
    
         
            +
                if offset is not None: self.seek(offset)
         
     | 
| 
       23 
24 
     | 
    
         
             
                with open(self.fd, "rb" if binary else "r", closefd=False) as file: return file.read(size)
         
     | 
| 
       24 
     | 
    
         
            -
              def write(self, content, binary=False):
         
     | 
| 
      
 25 
     | 
    
         
            +
              def write(self, content, binary=False, offset=None):
         
     | 
| 
      
 26 
     | 
    
         
            +
                if offset is not None: self.seek(offset)
         
     | 
| 
       25 
27 
     | 
    
         
             
                with open(self.fd, "wb" if binary else "w", closefd=False) as file: file.write(content)
         
     | 
| 
       26 
28 
     | 
    
         
             
              def listdir(self): return os.listdir(self.path)
         
     | 
| 
       27 
29 
     | 
    
         
             
              def seek(self, offset): os.lseek(self.fd, offset, os.SEEK_SET)
         
     | 
| 
         @@ -83,10 +85,10 @@ class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]): 
     | 
|
| 
       83 
85 
     | 
    
         
             
                """
         
     | 
| 
       84 
86 
     | 
    
         | 
| 
       85 
87 
     | 
    
         
             
                for v in values:
         
     | 
| 
       86 
     | 
    
         
            -
                  if isinstance(v,  
     | 
| 
       87 
     | 
    
         
            -
                  else:
         
     | 
| 
      
 88 
     | 
    
         
            +
                  if isinstance(v, UOp):
         
     | 
| 
       88 
89 
     | 
    
         
             
                    self.q_sints.append((len(self._q), self._new_sym(v)))
         
     | 
| 
       89 
90 
     | 
    
         
             
                    self._q.append(0xbadc0ded)
         
     | 
| 
      
 91 
     | 
    
         
            +
                  else: self._q.append(v)
         
     | 
| 
       90 
92 
     | 
    
         | 
| 
       91 
93 
     | 
    
         
             
              # *** common commands  ***
         
     | 
| 
       92 
94 
     | 
    
         | 
    
        tinygrad/runtime/support/llvm.py
    CHANGED
    
    | 
         @@ -6,27 +6,21 @@ if sys.platform == 'win32': 
     | 
|
| 
       6 
6 
     | 
    
         
             
              # winget also doesn't have something like `brew --prefix llvm` so just hardcode default installation path with an option to override
         
     | 
| 
       7 
7 
     | 
    
         
             
              LLVM_PATH = getenv('LLVM_PATH', 'C:\\Program Files\\LLVM\\bin\\LLVM-C.dll')
         
     | 
| 
       8 
8 
     | 
    
         
             
              if not os.path.exists(LLVM_PATH):
         
     | 
| 
       9 
     | 
    
         
            -
                raise  
     | 
| 
       10 
     | 
    
         
            -
            elif OSX and 'tinygrad.runtime.ops_metal' in sys.modules:
         
     | 
| 
       11 
     | 
    
         
            -
              # Opening METAL after LLVM doesn't fail because ctypes.CDLL opens with RTLD_LOCAL but MTLCompiler opens it's own llvm with RTLD_GLOBAL
         
     | 
| 
       12 
     | 
    
         
            -
              # This means that MTLCompiler's llvm will create it's own instances of global state because RTLD_LOCAL doesn't export symbols, but if RTLD_GLOBAL
         
     | 
| 
       13 
     | 
    
         
            -
              # library is loaded first then RTLD_LOCAL library will just use it's symbols. On linux there is RTLD_DEEPBIND to prevent that, but on macos there
         
     | 
| 
       14 
     | 
    
         
            -
              # doesn't seem to be anything we can do.
         
     | 
| 
       15 
     | 
    
         
            -
              LLVM_PATH = ctypes.util.find_library('tinyllvm')
         
     | 
| 
       16 
     | 
    
         
            -
              if LLVM_PATH is None:
         
     | 
| 
       17 
     | 
    
         
            -
                raise RuntimeError("LLVM can't be opened in the same process with metal. You can install llvm distribution which supports that via `brew install uuuvn/tinygrad/tinyllvm`") # noqa: E501
         
     | 
| 
      
 9 
     | 
    
         
            +
                raise FileNotFoundError('LLVM not found, you can install it with `winget install LLVM.LLVM` or point at a custom dll with LLVM_PATH')
         
     | 
| 
       18 
10 
     | 
    
         
             
            elif OSX:
         
     | 
| 
      
 11 
     | 
    
         
            +
              # Will raise FileNotFoundError if brew is not installed
         
     | 
| 
       19 
12 
     | 
    
         
             
              brew_prefix = subprocess.check_output(['brew', '--prefix', 'llvm']).decode().strip()
         
     | 
| 
       20 
13 
     | 
    
         
             
              # `brew --prefix` will return even if formula is not installed
         
     | 
| 
       21 
14 
     | 
    
         
             
              if not os.path.exists(brew_prefix):
         
     | 
| 
       22 
     | 
    
         
            -
                raise  
     | 
| 
       23 
     | 
    
         
            -
              LLVM_PATH = os.path.join(brew_prefix, 'lib', 'libLLVM.dylib')
         
     | 
| 
      
 15 
     | 
    
         
            +
                raise FileNotFoundError('LLVM not found, you can install it with `brew install llvm`')
         
     | 
| 
      
 16 
     | 
    
         
            +
              LLVM_PATH: str|None = os.path.join(brew_prefix, 'lib', 'libLLVM.dylib')
         
     | 
| 
       24 
17 
     | 
    
         
             
            else:
         
     | 
| 
       25 
18 
     | 
    
         
             
              LLVM_PATH = ctypes.util.find_library('LLVM')
         
     | 
| 
       26 
     | 
    
         
            -
               
     | 
| 
      
 19 
     | 
    
         
            +
              # use newer LLVM if possible
         
     | 
| 
      
 20 
     | 
    
         
            +
              for ver in reversed(range(14, 19+1)):
         
     | 
| 
       27 
21 
     | 
    
         
             
                if LLVM_PATH is not None: break
         
     | 
| 
       28 
22 
     | 
    
         
             
                LLVM_PATH = ctypes.util.find_library(f'LLVM-{ver}')
         
     | 
| 
       29 
23 
     | 
    
         
             
              if LLVM_PATH is None:
         
     | 
| 
       30 
     | 
    
         
            -
                raise  
     | 
| 
      
 24 
     | 
    
         
            +
                raise FileNotFoundError("No LLVM library found on the system. Install it via your distro's package manager and ensure it's findable as 'LLVM'")
         
     | 
| 
       31 
25 
     | 
    
         | 
| 
       32 
     | 
    
         
            -
            if DEBUG>= 
     | 
| 
      
 26 
     | 
    
         
            +
            if DEBUG>=3: print(f'Using LLVM at {repr(LLVM_PATH)}')
         
     | 
    
        tinygrad/shape/shapetracker.py
    CHANGED
    
    | 
         @@ -6,8 +6,8 @@ from typing import Optional, Callable 
     | 
|
| 
       6 
6 
     | 
    
         
             
            from tinygrad.helpers import merge_dicts, getenv
         
     | 
| 
       7 
7 
     | 
    
         
             
            from tinygrad.shape.view import View, strides_for_shape, unravel
         
     | 
| 
       8 
8 
     | 
    
         
             
            from tinygrad.dtype import dtypes
         
     | 
| 
       9 
     | 
    
         
            -
            from tinygrad.ops import UOp, Ops, graph_rewrite,  
     | 
| 
       10 
     | 
    
         
            -
            from tinygrad.codegen. 
     | 
| 
      
 9 
     | 
    
         
            +
            from tinygrad.ops import UOp, Ops, graph_rewrite, Variable, sint, sint_to_uop, Context
         
     | 
| 
      
 10 
     | 
    
         
            +
            from tinygrad.codegen.symbolic import sym, split_uop, symbolic_flat, uop_given_valid, simplify_valid
         
     | 
| 
       11 
11 
     | 
    
         | 
| 
       12 
12 
     | 
    
         
             
            def overflow(u: UOp): return u.vmax > dtypes.max(dtypes.int) or u.vmin < dtypes.min(dtypes.int)
         
     | 
| 
       13 
13 
     | 
    
         | 
| 
         @@ -109,6 +109,7 @@ class ShapeTracker: 
     | 
|
| 
       109 
109 
     | 
    
         | 
| 
       110 
110 
     | 
    
         
             
              def unbind(self) -> tuple[ShapeTracker, dict[Variable, int]]:
         
     | 
| 
       111 
111 
     | 
    
         
             
                unbound_views, var_vals = zip(*[v.unbind() for v in self.views])
         
     | 
| 
      
 112 
     | 
    
         
            +
                if all(len(x) == 0 for x in var_vals): return self, {}
         
     | 
| 
       112 
113 
     | 
    
         
             
                return ShapeTracker(tuple(unbound_views)), merge_dicts(var_vals)
         
     | 
| 
       113 
114 
     | 
    
         | 
| 
       114 
115 
     | 
    
         
             
              def real_strides(self, ignore_valid=False) -> tuple[Optional[sint], ...]: return views_to_real_strides(self.views, ignore_valid)
         
     | 
    
        tinygrad/shape/view.py
    CHANGED
    
    | 
         @@ -107,8 +107,7 @@ class View: 
     | 
|
| 
       107 
107 
     | 
    
         
             
              @staticmethod
         
     | 
| 
       108 
108 
     | 
    
         
             
              @functools.lru_cache(maxsize=None)
         
     | 
| 
       109 
109 
     | 
    
         
             
              def create(shape:tuple[sint, ...], strides:Optional[tuple[sint, ...]]=None, offset:sint=0, mask:Optional[tuple[tuple[sint, sint], ...]]=None):
         
     | 
| 
       110 
     | 
    
         
            -
                 
     | 
| 
       111 
     | 
    
         
            -
                if not all(resolve(s >= 0) for s in shape): raise ValueError(f"Trying to create View with negative dimension: {shape=}")
         
     | 
| 
      
 110 
     | 
    
         
            +
                if not all(s >= 0 for s in shape): raise ValueError(f"Trying to create View with negative dimension: {shape=}")
         
     | 
| 
       112 
111 
     | 
    
         
             
                strides = canonicalize_strides(shape, strides) if strides else strides_for_shape(shape)
         
     | 
| 
       113 
112 
     | 
    
         
             
                # canonicalize 0 in shape
         
     | 
| 
       114 
113 
     | 
    
         
             
                if 0 in shape: return View(shape, (0,) * len(shape), offset=0, mask=None, contiguous=True)
         
     | 
| 
         @@ -274,7 +273,7 @@ class View: 
     | 
|
| 
       274 
273 
     | 
    
         
             
              def reshape(self, new_shape: tuple[sint, ...]) -> Optional[View]:
         
     | 
| 
       275 
274 
     | 
    
         
             
                if self.shape == new_shape: return self
         
     | 
| 
       276 
275 
     | 
    
         | 
| 
       277 
     | 
    
         
            -
                 
     | 
| 
      
 276 
     | 
    
         
            +
                if not all(x >= 0 for x in new_shape): raise ValueError(f"shape can't contain negative numbers {new_shape}")
         
     | 
| 
       278 
277 
     | 
    
         
             
                # check for the same size
         
     | 
| 
       279 
278 
     | 
    
         
             
                if (self_all_int := all_int(self.shape)):
         
     | 
| 
       280 
279 
     | 
    
         
             
                  assert all(isinstance(s, (int, UOp)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
         
     | 
    
        tinygrad/spec.py
    CHANGED
    
    | 
         @@ -1,21 +1,25 @@ 
     | 
|
| 
       1 
1 
     | 
    
         
             
            from typing import cast
         
     | 
| 
       2 
2 
     | 
    
         
             
            from tinygrad.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops
         
     | 
| 
       3 
3 
     | 
    
         
             
            from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType
         
     | 
| 
       4 
     | 
    
         
            -
            from tinygrad.helpers import  
     | 
| 
      
 4 
     | 
    
         
            +
            from tinygrad.helpers import all_same, dedup, prod
         
     | 
| 
       5 
5 
     | 
    
         | 
| 
       6 
     | 
    
         
            -
             
     | 
| 
       7 
     | 
    
         
            -
             
     | 
| 
       8 
     | 
    
         
            -
            tensor_uop_spec = PatternMatcher([
         
     | 
| 
      
 6 
     | 
    
         
            +
            buffer_spec = PatternMatcher([
         
     | 
| 
      
 7 
     | 
    
         
            +
              (UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True),
         
     | 
| 
       9 
8 
     | 
    
         
             
              (UPat(Ops.DEVICE, dtypes.void, (), name="device"), lambda device: isinstance(device.arg, str)),
         
     | 
| 
       10 
     | 
    
         
            -
              (UPat(Ops.BUFFER, src=(UPat(Ops.DEVICE),), name="buf"),
         
     | 
| 
       11 
     | 
    
         
            -
               lambda buf: isinstance(buf.arg,  
     | 
| 
      
 9 
     | 
    
         
            +
              (UPat(Ops.BUFFER, src=(UPat(Ops.DEVICE), UPat(Ops.UNIQUE)), name="buf"),
         
     | 
| 
      
 10 
     | 
    
         
            +
               lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
         
     | 
| 
      
 11 
     | 
    
         
            +
            ])
         
     | 
| 
       12 
12 
     | 
    
         | 
| 
      
 13 
     | 
    
         
            +
            # *** this is the spec of a Tensor in UOp ***
         
     | 
| 
      
 14 
     | 
    
         
            +
             
     | 
| 
      
 15 
     | 
    
         
            +
            tensor_uop_spec = buffer_spec+PatternMatcher([
         
     | 
| 
       13 
16 
     | 
    
         
             
              (UPat(GroupOp.Movement, name="mv", src=(UPat.var("x"),)),
         
     | 
| 
       14 
17 
     | 
    
         
             
               # naturally correct
         
     | 
| 
       15 
18 
     | 
    
         
             
               lambda mv,x: (isinstance(mv.arg, tuple) and mv.dtype == x.dtype) or
         
     | 
| 
       16 
19 
     | 
    
         
             
               # "make things that can't be images not images" can change the buffer dtype
         
     | 
| 
       17 
20 
     | 
    
         
             
               # this is fine as long as it's a realized buffer and base dtypes match.
         
     | 
| 
       18 
21 
     | 
    
         
             
               ((isinstance(mv.dtype, ImageDType) or isinstance(x.dtype, ImageDType)) and x.dtype.base == mv.dtype.base and x.is_realized)),
         
     | 
| 
      
 22 
     | 
    
         
            +
              (UPat(Ops.VIEW, src=(UPat(GroupOp.All-{Ops.CONST, Ops.DEVICE}),)), lambda: False),
         
     | 
| 
       19 
23 
     | 
    
         | 
| 
       20 
24 
     | 
    
         
             
              # Tensor variable bindings
         
     | 
| 
       21 
25 
     | 
    
         
             
              (UPat(Ops.BIND, dtypes.int, (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=dtypes.int)), arg=None), lambda: True),
         
     | 
| 
         @@ -32,11 +36,6 @@ tensor_uop_spec = PatternMatcher([ 
     | 
|
| 
       32 
36 
     | 
    
         
             
              # NOTE: the arg here specifies clone=True, which prevents folding same device copy
         
     | 
| 
       33 
37 
     | 
    
         
             
              (UPat(Ops.COPY, name="copy", src=(UPat(Ops.DEVICE), UPat.var("x"))), lambda copy,x: isinstance(copy.arg, bool) and copy.dtype == x.dtype),
         
     | 
| 
       34 
38 
     | 
    
         | 
| 
       35 
     | 
    
         
            -
              # VIEW(BUFFER) applies a ShapeTracker on top of the underlying device buffer
         
     | 
| 
       36 
     | 
    
         
            -
              # NOTE: VIEW size exactly matches the underlying BUFFER, tensor doesn't apply movement ops to the VIEW
         
     | 
| 
       37 
     | 
    
         
            -
              (UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),
         
     | 
| 
       38 
     | 
    
         
            -
               lambda view,buf: view.dtype == buf.dtype and view.size == buf.size and view.st.contiguous),
         
     | 
| 
       39 
     | 
    
         
            -
             
     | 
| 
       40 
39 
     | 
    
         
             
              # ASSIGN changes the value of a realized buffer
         
     | 
| 
       41 
40 
     | 
    
         
             
              (UPat(Ops.ASSIGN, name="assign", src=(UPat.var("target"), UPat.var("new_val"))),
         
     | 
| 
       42 
41 
     | 
    
         
             
               lambda assign,target,new_val: target.is_realized and (assign.dtype == target.dtype == new_val.dtype)),
         
     | 
| 
         @@ -58,7 +57,7 @@ spec = PatternMatcher([ 
     | 
|
| 
       58 
57 
     | 
    
         | 
| 
       59 
58 
     | 
    
         
             
              # TODO: confirm the args of both of these are shapetrackers
         
     | 
| 
       60 
59 
     | 
    
         
             
              (UPat(Ops.VIEW, dtypes.void, src=()), lambda: True),
         
     | 
| 
       61 
     | 
    
         
            -
              (UPat(Ops.VIEW, src=(UPat.var("src"),), name="x"), lambda x,src: src.op is not Ops.STORE and x.dtype == src.dtype),
         
     | 
| 
      
 60 
     | 
    
         
            +
              (UPat(Ops.VIEW, src=(UPat.var("src"),), name="x"), lambda x,src: src.op is not Ops.STORE and x.dtype.base == src.dtype.base),
         
     | 
| 
       62 
61 
     | 
    
         | 
| 
       63 
62 
     | 
    
         
             
              (UPat(Ops.VALID, dtypes.bool, (UPat(Ops.VIEW),)), lambda: True),
         
     | 
| 
       64 
63 
     | 
    
         
             
              (UPat(Ops.CONST, name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),
         
     | 
| 
         @@ -113,9 +112,9 @@ spec = PatternMatcher([ 
     | 
|
| 
       113 
112 
     | 
    
         
             
              (UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local
         
     | 
| 
       114 
113 
     | 
    
         | 
| 
       115 
114 
     | 
    
         
             
              # NOTE: for testing, we let sinks be anything
         
     | 
| 
       116 
     | 
    
         
            -
              #(UPat( 
     | 
| 
       117 
     | 
    
         
            -
              (UPat(Ops.SINK, dtypes.void), lambda: True),
         
     | 
| 
       118 
     | 
    
         
            -
              (UPat(Ops.NOOP), lambda: True),
         
     | 
| 
      
 115 
     | 
    
         
            +
              #(UPat(Ops.SINK, src=UPat(Ops.STORE)), lambda: True),
         
     | 
| 
      
 116 
     | 
    
         
            +
              (UPat((Ops.NAME, Ops.SINK), dtypes.void), lambda: True),
         
     | 
| 
      
 117 
     | 
    
         
            +
              (UPat((Ops.NOOP, Ops.CUSTOM)), lambda: True),
         
     | 
| 
       119 
118 
     | 
    
         | 
| 
       120 
119 
     | 
    
         
             
              # PTX LOAD/STORE
         
     | 
| 
       121 
120 
     | 
    
         
             
              (UPat((Ops.LOAD, Ops.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
         
     | 
| 
         @@ -123,11 +122,13 @@ spec = PatternMatcher([ 
     | 
|
| 
       123 
122 
     | 
    
         | 
| 
       124 
123 
     | 
    
         
             
            # *** this is the spec of a Kernel in UOp ***
         
     | 
| 
       125 
124 
     | 
    
         | 
| 
       126 
     | 
    
         
            -
            kernel_spec = PatternMatcher([
         
     | 
| 
       127 
     | 
    
         
            -
              (UPat(Ops. 
     | 
| 
       128 
     | 
    
         
            -
               
     | 
| 
       129 
     | 
    
         
            -
               
     | 
| 
       130 
     | 
    
         
            -
               
     | 
| 
      
 125 
     | 
    
         
            +
            kernel_spec = buffer_spec+PatternMatcher([
         
     | 
| 
      
 126 
     | 
    
         
            +
              (UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.ASSIGN))), lambda: True),
         
     | 
| 
      
 127 
     | 
    
         
            +
              # assign has a buffer view and kernel source, it can optionally depend on other assigns
         
     | 
| 
      
 128 
     | 
    
         
            +
              (UPat(Ops.ASSIGN, src=UPat((Ops.BUFFER, Ops.VIEW, Ops.KERNEL, Ops.ASSIGN))), lambda: True),
         
     | 
| 
      
 129 
     | 
    
         
            +
              # view/sink/const can also exist in the kernel graph
         
     | 
| 
      
 130 
     | 
    
         
            +
              (UPat((Ops.VIEW, Ops.SINK, Ops.CONST)), lambda: True),
         
     | 
| 
      
 131 
     | 
    
         
            +
              (UPat(GroupOp.All), lambda: False),
         
     | 
| 
       131 
132 
     | 
    
         
             
            ])
         
     | 
| 
       132 
133 
     | 
    
         | 
| 
       133 
134 
     | 
    
         
             
            # *** this is the UOp shape spec ***
         
     |