tinygrad 0.8.0__py3-none-any.whl → 0.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +6 -6
- tinygrad/codegen/__init__.py +0 -0
- tinygrad/codegen/kernel.py +253 -225
- tinygrad/codegen/linearizer.py +398 -436
- tinygrad/codegen/uops.py +451 -0
- tinygrad/device.py +268 -274
- tinygrad/dtype.py +56 -40
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +198 -0
- tinygrad/engine/realize.py +192 -0
- tinygrad/engine/schedule.py +370 -0
- tinygrad/engine/search.py +199 -0
- tinygrad/{mlops.py → function.py} +40 -32
- tinygrad/helpers.py +144 -46
- tinygrad/lazy.py +143 -242
- tinygrad/multi.py +173 -0
- tinygrad/nn/__init__.py +180 -9
- tinygrad/nn/datasets.py +8 -0
- tinygrad/nn/optim.py +106 -28
- tinygrad/nn/state.py +87 -19
- tinygrad/ops.py +104 -45
- tinygrad/renderer/__init__.py +65 -0
- tinygrad/renderer/assembly.py +269 -0
- tinygrad/renderer/cstyle.py +308 -210
- tinygrad/renderer/llvmir.py +119 -124
- tinygrad/runtime/__init__.py +0 -0
- tinygrad/runtime/autogen/amd_gpu.py +13403 -0
- tinygrad/runtime/autogen/comgr.py +891 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5893 -0
- tinygrad/runtime/autogen/io_uring.py +1486 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33597 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/__init__.py +0 -0
- tinygrad/runtime/driver/hip_comgr.py +56 -0
- tinygrad/runtime/graph/__init__.py +0 -0
- tinygrad/runtime/graph/clang.py +39 -0
- tinygrad/runtime/graph/cuda.py +59 -54
- tinygrad/runtime/graph/hcq.py +187 -0
- tinygrad/runtime/graph/metal.py +37 -41
- tinygrad/runtime/ops_amd.py +550 -0
- tinygrad/runtime/ops_clang.py +16 -14
- tinygrad/runtime/ops_cuda.py +129 -37
- tinygrad/runtime/ops_disk.py +111 -43
- tinygrad/runtime/ops_gpu.py +52 -50
- tinygrad/runtime/ops_llvm.py +36 -56
- tinygrad/runtime/ops_metal.py +41 -24
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +625 -0
- tinygrad/runtime/ops_python.py +208 -0
- tinygrad/shape/__init__.py +0 -0
- tinygrad/shape/shapetracker.py +46 -107
- tinygrad/shape/symbolic.py +99 -98
- tinygrad/shape/view.py +162 -45
- tinygrad/tensor.py +2492 -483
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +1 -1
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +31 -13
- tinygrad-0.9.1.dist-info/RECORD +63 -0
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
- tinygrad/features/image.py +0 -93
- tinygrad/features/multi.py +0 -103
- tinygrad/features/search.py +0 -160
- tinygrad/graph.py +0 -106
- tinygrad/jit.py +0 -152
- tinygrad/realize.py +0 -50
- tinygrad/runtime/graph/hip.py +0 -24
- tinygrad/runtime/ops_cpu.py +0 -45
- tinygrad/runtime/ops_hip.py +0 -97
- tinygrad/runtime/ops_torch.py +0 -49
- tinygrad-0.8.0.dist-info/RECORD +0 -41
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
    
        tinygrad/device.py
    CHANGED
    
    | @@ -1,326 +1,320 @@ | |
| 1 1 | 
             
            from __future__ import annotations
         | 
| 2 | 
            -
            import  | 
| 2 | 
            +
            import multiprocessing
         | 
| 3 | 
            +
            from dataclasses import dataclass
         | 
| 3 4 | 
             
            from collections import defaultdict
         | 
| 4 | 
            -
            from typing import  | 
| 5 | 
            -
            import importlib, inspect, functools, pathlib,  | 
| 6 | 
            -
            from tinygrad. | 
| 7 | 
            -
            from tinygrad. | 
| 8 | 
            -
            from tinygrad. | 
| 9 | 
            -
            from tinygrad.ops import LazyOp, TernaryOps, get_lazyop_info, ReduceOps, BufferOps, BinaryOps, UnaryOps, Op, GlobalCounters
         | 
| 10 | 
            -
             | 
| 11 | 
            -
            if TYPE_CHECKING:
         | 
| 12 | 
            -
              from tinygrad.codegen.linearizer import Linearizer
         | 
| 13 | 
            -
              from tinygrad.codegen.kernel import LinearizerOptions
         | 
| 5 | 
            +
            from typing import List, Optional, Dict, Tuple, Any, cast
         | 
| 6 | 
            +
            import importlib, inspect, functools, pathlib, os, ctypes, atexit, time, contextlib
         | 
| 7 | 
            +
            from tinygrad.helpers import getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, ProfileLogger, PROFILE
         | 
| 8 | 
            +
            from tinygrad.dtype import DType, ImageDType
         | 
| 9 | 
            +
            from tinygrad.renderer import Renderer
         | 
| 14 10 |  | 
| 15 11 | 
             
            # **************** Device ****************
         | 
| 16 12 |  | 
| 17 13 | 
             
            class _Device:
         | 
| 18 14 | 
             
              def __init__(self) -> None: self._devices: List[str] = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]  # noqa: E501
         | 
| 19 | 
            -
              def canonicalize(self, device:Optional[str]) -> str: return (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") if device is not None else self.DEFAULT  # noqa: E501
         | 
| 20 | 
            -
              def __getitem__(self, ix:str) -> Union[Interpreted, Compiled]: return self.__get_canonicalized_item(self.canonicalize(ix))
         | 
| 21 15 | 
             
              @functools.lru_cache(maxsize=None)  # this class is a singleton, pylint: disable=method-cache-max-size-none
         | 
| 22 | 
            -
              def  | 
| 16 | 
            +
              def _canonicalize(self, device:str) -> str: return (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "")   # noqa: E501
         | 
| 17 | 
            +
              # NOTE: you can't cache canonicalize in case Device.DEFAULT changes
         | 
| 18 | 
            +
              def canonicalize(self, device:Optional[str]) -> str: return self._canonicalize(device) if device is not None else Device.DEFAULT
         | 
| 19 | 
            +
              def __getitem__(self, ix:str) -> Compiled: return self.__get_canonicalized_item(self.canonicalize(ix))
         | 
| 20 | 
            +
              @functools.lru_cache(maxsize=None)  # this class is a singleton, pylint: disable=method-cache-max-size-none
         | 
| 21 | 
            +
              def __get_canonicalized_item(self, ix:str) -> Compiled:
         | 
| 22 | 
            +
                assert ((cpn:=multiprocessing.current_process().name) == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY"], \
         | 
| 23 | 
            +
                  f"can only open device {ix} from parent, not {cpn}"
         | 
| 23 24 | 
             
                x = ix.split(":")[0].upper()
         | 
| 24 | 
            -
                ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._devices][0]  # noqa: E501
         | 
| 25 | 
            -
                if  | 
| 25 | 
            +
                ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._devices][0](ix)  # noqa: E501
         | 
| 26 | 
            +
                if DEBUG >= 1: print(f"opened device {ix} from pid:{os.getpid()}")
         | 
| 26 27 | 
             
                return ret
         | 
| 27 28 | 
             
              @functools.cached_property
         | 
| 28 29 | 
             
              def DEFAULT(self) -> str:
         | 
| 29 30 | 
             
                device_from_env: Optional[str] = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._devices, None)   # type: ignore
         | 
| 30 31 | 
             
                if device_from_env: return device_from_env
         | 
| 31 | 
            -
                for device in ["METAL", " | 
| 32 | 
            +
                for device in ["METAL", "AMD", "CUDA", "GPU", "CLANG", "LLVM"]:
         | 
| 32 33 | 
             
                  try:
         | 
| 33 | 
            -
                    if self[device]: | 
| 34 | 
            +
                    if self[device]:
         | 
| 35 | 
            +
                      os.environ[device] = "1"   # we set this in environment for spawned children
         | 
| 36 | 
            +
                      return device
         | 
| 34 37 | 
             
                  except Exception: pass
         | 
| 35 | 
            -
                 | 
| 38 | 
            +
                raise RuntimeError("no usable devices")
         | 
| 36 39 | 
             
            Device = _Device()
         | 
| 37 40 |  | 
| 38 | 
            -
            # ****************  | 
| 39 | 
            -
             | 
| 40 | 
            -
             | 
| 41 | 
            -
             | 
| 42 | 
            -
             | 
| 43 | 
            -
               | 
| 44 | 
            -
             | 
| 45 | 
            -
             | 
| 46 | 
            -
             | 
| 47 | 
            -
                CacheCollector.add(self, rawbufs, var_vals)
         | 
| 48 | 
            -
                return et
         | 
| 49 | 
            -
              def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
         | 
| 50 | 
            -
                raise NotImplementedError("override this")
         | 
| 51 | 
            -
             | 
| 52 | 
            -
            def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count:int, jit=False, num_kernels=1, lra: Optional[Dict]=None, device:str="", first_run=False):  # noqa: E501
         | 
| 53 | 
            -
              if var_vals is None: var_vals = {}
         | 
| 54 | 
            -
              op_estimate, mem_estimate = sym_infer(op_estimate, var_vals), sym_infer(mem_estimate, var_vals)
         | 
| 55 | 
            -
              GlobalCounters.kernel_count += num_kernels
         | 
| 56 | 
            -
              GlobalCounters.global_ops += op_estimate
         | 
| 57 | 
            -
              GlobalCounters.global_mem += mem_estimate
         | 
| 58 | 
            -
              if et is not None: GlobalCounters.time_sum_s += et
         | 
| 59 | 
            -
              if DEBUG >= 2:
         | 
| 60 | 
            -
                ptm = (colored(f"{et*1e3:9.2f}ms", "yellow") if et > 0.01 else f"{et*1e6:9.2f}us") if et is not None else ""
         | 
| 61 | 
            -
                print(f"{colored(f'*** {device[:7]:7s} {GlobalCounters.kernel_count:4d}', ('magenta' if num_kernels == 1 else 'CYAN') if jit else ('green' if first_run else None))} {name+' '*(37-ansilen(name))} arg {buf_count:3d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " +  # noqa: E501
         | 
| 62 | 
            -
                      (str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))  # noqa: E501
         | 
| 63 | 
            -
             | 
| 64 | 
            -
            # **************** Buffer / Allocator ****************
         | 
| 41 | 
            +
            # **************** Buffer + Allocators ****************
         | 
| 42 | 
            +
             | 
| 43 | 
            +
            @dataclass(frozen=True, eq=True)
         | 
| 44 | 
            +
            class BufferOptions:
         | 
| 45 | 
            +
              image: Optional[ImageDType] = None
         | 
| 46 | 
            +
              uncached: bool = False
         | 
| 47 | 
            +
              cpu_access: bool = False
         | 
| 48 | 
            +
              host: bool = False
         | 
| 49 | 
            +
              nolru: bool = False
         | 
| 65 50 |  | 
| 66 51 | 
             
            class Buffer:
         | 
| 67 | 
            -
              def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None | 
| 52 | 
            +
              def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None,
         | 
| 53 | 
            +
                           initial_value:Optional[bytes]=None, lb_refcount=0, base:Optional[Buffer]=None, offset:int=0, preallocate=False):
         | 
| 68 54 | 
             
                assert isinstance(dtype, DType)
         | 
| 69 | 
            -
                 | 
| 55 | 
            +
                if isinstance(dtype, ImageDType): options = BufferOptions(image=dtype) # TODO: image hack shouldn't be here. where should it be?
         | 
| 56 | 
            +
                self.device, self.size, self.dtype, self.options, self.offset = device, size, dtype, options, offset
         | 
| 57 | 
            +
                if base is None:
         | 
| 58 | 
            +
                  assert offset == 0, "base buffers can't have offset"
         | 
| 59 | 
            +
                  self._base = None
         | 
| 60 | 
            +
                  self._lb_refcount = lb_refcount
         | 
| 61 | 
            +
                  if opaque is not None: self.allocate(opaque)
         | 
| 62 | 
            +
                  if initial_value is not None:
         | 
| 63 | 
            +
                    self.allocate()
         | 
| 64 | 
            +
                    self.copyin(memoryview(initial_value))
         | 
| 65 | 
            +
                else:
         | 
| 66 | 
            +
                  assert base._base is None, "base can't have a base"
         | 
| 67 | 
            +
                  assert device == base.device, "base must have the same device"
         | 
| 68 | 
            +
                  self._base = base
         | 
| 69 | 
            +
                if preallocate: self.allocate()
         | 
| 70 | 
            +
              @property
         | 
| 71 | 
            +
              def base(self) -> Buffer: return self._base if self._base is not None else self
         | 
| 72 | 
            +
              @property
         | 
| 73 | 
            +
              def lb_refcount(self): return self.base._lb_refcount
         | 
| 74 | 
            +
              def ref(self, cnt): self.base._lb_refcount += cnt
         | 
| 75 | 
            +
              def is_allocated(self) -> bool: return hasattr(self, '_buf')
         | 
| 76 | 
            +
              def ensure_allocated(self) -> Buffer: return self.allocate() if not hasattr(self, '_buf') else self
         | 
| 77 | 
            +
              def allocate(self, opaque=None) -> Buffer:
         | 
| 78 | 
            +
                assert not hasattr(self, '_buf'), "can't allocate already allocated buffer"
         | 
| 70 79 | 
             
                self.allocator = Device[self.device].allocator
         | 
| 71 | 
            -
                 | 
| 72 | 
            -
             | 
| 73 | 
            -
             | 
| 74 | 
            -
             | 
| 80 | 
            +
                if self._base is not None:
         | 
| 81 | 
            +
                  self._base.ensure_allocated()
         | 
| 82 | 
            +
                  assert hasattr(self.allocator, "offset"), "offset function required for view"
         | 
| 83 | 
            +
                  self._buf: Any = self.allocator.offset(self.base._buf, self.nbytes, self.offset)
         | 
| 84 | 
            +
                else:
         | 
| 85 | 
            +
                  self._buf = opaque if opaque is not None else self.allocator.alloc(self.nbytes, self.options)
         | 
| 86 | 
            +
                  if not self.device.startswith("DISK"): GlobalCounters.mem_used += self.nbytes
         | 
| 87 | 
            +
                return self
         | 
| 88 | 
            +
              def __reduce__(self):
         | 
| 89 | 
            +
                buf = None
         | 
| 90 | 
            +
                if self._base is not None:
         | 
| 91 | 
            +
                  return self.__class__, (self.device, self.size, self.dtype, None, None, None, 0, self.base, self.offset, hasattr(self, '_buf'))
         | 
| 92 | 
            +
                if self.device == "NPY": return self.__class__, (self.device, self.size, self.dtype, self._buf, self.options, None, self.lb_refcount)
         | 
| 93 | 
            +
                if self.is_allocated():
         | 
| 94 | 
            +
                  buf = bytearray(self.nbytes)
         | 
| 95 | 
            +
                  self.copyout(memoryview(buf))
         | 
| 96 | 
            +
                return self.__class__, (self.device, self.size, self.dtype, None, self.options, buf, self.lb_refcount)
         | 
| 97 | 
            +
              @property
         | 
| 98 | 
            +
              def nbytes(self): return self.size*self.dtype.itemsize
         | 
| 75 99 | 
             
              def __del__(self):
         | 
| 76 | 
            -
                if not hasattr(self, '_buf'): return | 
| 77 | 
            -
                if  | 
| 78 | 
            -
             | 
| 79 | 
            -
             | 
| 80 | 
            -
              def __repr__(self): | 
| 100 | 
            +
                if not hasattr(self, '_buf'): return
         | 
| 101 | 
            +
                if self._base is None:
         | 
| 102 | 
            +
                  if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
         | 
| 103 | 
            +
                  self.allocator.free(self._buf, self.nbytes, self.options)
         | 
| 104 | 
            +
              def __repr__(self):
         | 
| 105 | 
            +
                return f"<buf real:{hasattr(self, '_buf')} device:{self.device} size:{self.size} dtype:{self.dtype}" + \
         | 
| 106 | 
            +
                       (f" offset:{self.offset}" if hasattr(self, "base") else "") + \
         | 
| 107 | 
            +
                       (">" if self.options is None else f" {self.options=}>")
         | 
| 108 | 
            +
              def as_buffer(self, allow_zero_copy=False, force_zero_copy=False) -> memoryview:
         | 
| 109 | 
            +
                # zero copy with as_buffer (disabled by default due to use after free)
         | 
| 110 | 
            +
                if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, 'as_buffer'): return self.allocator.as_buffer(self._buf)
         | 
| 111 | 
            +
                assert not force_zero_copy, "force zero copy was passed, but copy is required"
         | 
| 112 | 
            +
                return self.copyout(memoryview(bytearray(self.nbytes)))
         | 
| 81 113 | 
             
              def copyin(self, mv:memoryview):
         | 
| 82 114 | 
             
                mv = flat_mv(mv)
         | 
| 83 | 
            -
                assert len(mv) == self. | 
| 115 | 
            +
                assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
         | 
| 116 | 
            +
                assert self.is_allocated(), "can't copyin to unallocated buffer"
         | 
| 84 117 | 
             
                self.allocator.copyin(self._buf, mv)
         | 
| 85 118 | 
             
                return self
         | 
| 86 | 
            -
               | 
| 87 | 
            -
             | 
| 88 | 
            -
             | 
| 89 | 
            -
                 | 
| 90 | 
            -
                 | 
| 91 | 
            -
             | 
| 92 | 
            -
             | 
| 93 | 
            -
                 | 
| 94 | 
            -
                return  | 
| 95 | 
            -
             | 
| 96 | 
            -
            def _internal_buffer_copy(dest:Buffer, src:Buffer):
         | 
| 97 | 
            -
              if hasattr(dest.allocator, 'transfer') and type(dest.allocator) is type(src.allocator):  # noqa: E721
         | 
| 98 | 
            -
                # fast path, used on HIP between GPUs
         | 
| 99 | 
            -
                # NOTE: it's important we use the dest device here to ensure the transfer is ready
         | 
| 100 | 
            -
                Device[src.device].synchronize()   # TODO: async this
         | 
| 101 | 
            -
                dest.allocator.transfer(dest._buf, src._buf, dest.size*dest.dtype.itemsize)
         | 
| 102 | 
            -
                return
         | 
| 103 | 
            -
              if getenv("FROM_BUFFER") and hasattr(dest.allocator, 'from_buffer') and hasattr(dest.allocator, 'transfer') and hasattr(src.allocator, 'as_buffer'):
         | 
| 104 | 
            -
                # fast path, used on Metal in OS X Sonoma
         | 
| 105 | 
            -
                # NOTE: this is *only* faster if the pages from disk are already loaded into memory
         | 
| 106 | 
            -
                fb = dest.allocator.from_buffer(src.allocator.as_buffer(src._buf))
         | 
| 107 | 
            -
                if fb:
         | 
| 108 | 
            -
                  dest.allocator.transfer(dest._buf, fb, dest.size*dest.dtype.itemsize)
         | 
| 109 | 
            -
                  return
         | 
| 110 | 
            -
              if hasattr(dest.allocator, 'copy_from_fd') and src.device.startswith("DISK") and src.size*src.dtype.itemsize >= 4096 and src._buf.ud.fd is not None:
         | 
| 111 | 
            -
                dest.allocator.copy_from_fd(dest._buf, src._buf.ud.fd, src._buf.offset, src.size*src.dtype.itemsize)
         | 
| 112 | 
            -
              elif hasattr(dest.allocator, 'as_buffer'):
         | 
| 113 | 
            -
                # fast(ish) path, uses readinto in diskbuffers
         | 
| 114 | 
            -
                src.allocator.copyout(dest.allocator.as_buffer(dest._buf), src._buf)
         | 
| 115 | 
            -
              elif hasattr(src.allocator, 'as_buffer'):
         | 
| 116 | 
            -
                dest.allocator.copyin(dest._buf, src.allocator.as_buffer(src._buf))
         | 
| 117 | 
            -
              else:
         | 
| 118 | 
            -
                # slow path, allocates a CPU buffer
         | 
| 119 | 
            -
                dest.copyin(src.toCPU().data)
         | 
| 120 | 
            -
             | 
| 121 | 
            -
            class _BufferCopy(JITRunner):
         | 
| 122 | 
            -
              # TODO: make wait work
         | 
| 123 | 
            -
              def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False):
         | 
| 124 | 
            -
                dest, src = rawbufs
         | 
| 125 | 
            -
                assert dest.size == src.size, f"buffer copy size mismatch, {dest.size} != {src.size}"
         | 
| 126 | 
            -
                assert dest.dtype == src.dtype, f"buffer copy dtype mismatch, {dest.dtype} != {src.dtype}"
         | 
| 127 | 
            -
                st = time.perf_counter()
         | 
| 128 | 
            -
                _internal_buffer_copy(dest, src)
         | 
| 129 | 
            -
                et = None
         | 
| 130 | 
            -
                if wait or DEBUG >= 2:
         | 
| 131 | 
            -
                  Device[dest.device].synchronize()
         | 
| 132 | 
            -
                  et = time.perf_counter() - st
         | 
| 133 | 
            -
                update_stats(colored(f"copy {dest.size:8d}, {dest.device[:7]:>7s} <- {src.device[:7]:7s}", "yellow"), 0, dest.size*dest.dtype.itemsize, {}, et, 2, jit, device=dest.device)  # noqa: E501
         | 
| 134 | 
            -
            BufferCopy = _BufferCopy()
         | 
| 119 | 
            +
              def copyout(self, mv:memoryview) -> memoryview:
         | 
| 120 | 
            +
                mv = flat_mv(mv)
         | 
| 121 | 
            +
                assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
         | 
| 122 | 
            +
                assert self.is_allocated(), "can't copyout unallocated buffer"
         | 
| 123 | 
            +
                self.allocator.copyout(mv, self._buf)
         | 
| 124 | 
            +
                return mv
         | 
| 125 | 
            +
              def view(self, size:int, dtype:DType, offset:int) -> Buffer:
         | 
| 126 | 
            +
                assert offset < self.nbytes, "offset must be less than nbytes"
         | 
| 127 | 
            +
                if self._base is not None: return Buffer(self.device, size, dtype, base=self._base, offset=self.offset+offset)
         | 
| 128 | 
            +
                return Buffer(self.device, size, dtype, base=self, offset=offset)
         | 
| 135 129 |  | 
| 136 130 | 
             
            # TODO: size, dest, src are the same type. can we enforce this?
         | 
| 137 | 
            -
            sz_type = Union[ImageDType, int]
         | 
| 138 131 | 
             
            class Allocator:
         | 
| 139 | 
            -
              def alloc(self, size: | 
| 132 | 
            +
              def alloc(self, size:int, options:Optional[BufferOptions]=None):
         | 
| 140 133 | 
             
                assert not isinstance(size, int) or size > 0, f"alloc size must be positve, getting {size}"
         | 
| 141 | 
            -
                return self. | 
| 142 | 
            -
              def _alloc(self, size:int): raise NotImplementedError("need alloc")
         | 
| 143 | 
            -
              def  | 
| 144 | 
            -
             | 
| 145 | 
            -
              def _free(self, opaque): pass
         | 
| 134 | 
            +
                return self._alloc(size, options if options is not None else BufferOptions())
         | 
| 135 | 
            +
              def _alloc(self, size:int, options:BufferOptions): raise NotImplementedError("need alloc")
         | 
| 136 | 
            +
              def free(self, opaque, size:int, options:Optional[BufferOptions]=None):
         | 
| 137 | 
            +
                self._free(opaque, options if options is not None else BufferOptions())
         | 
| 138 | 
            +
              def _free(self, opaque, options:BufferOptions): pass  # if opaque is a Python object, you don't need a free
         | 
| 146 139 | 
             
              def copyin(self, dest, src:memoryview): raise NotImplementedError("need copyin")
         | 
| 147 140 | 
             
              def copyout(self, dest:memoryview, src): raise NotImplementedError("need copyout")
         | 
| 148 141 |  | 
| 149 142 | 
             
            class LRUAllocator(Allocator):  # pylint: disable=abstract-method
         | 
| 150 | 
            -
              def __init__(self): self.cache: Dict[ | 
| 151 | 
            -
              def alloc(self, size: | 
| 152 | 
            -
                if len(c := self.cache[size]): return c.pop()
         | 
| 153 | 
            -
                try:
         | 
| 154 | 
            -
             | 
| 155 | 
            -
                except MemoryError:
         | 
| 143 | 
            +
              def __init__(self): self.cache: Dict[Tuple[int, Optional[BufferOptions]], Any] = defaultdict(list)
         | 
| 144 | 
            +
              def alloc(self, size:int, options:Optional[BufferOptions]=None):
         | 
| 145 | 
            +
                if len(c := self.cache[(size, options)]): return c.pop()
         | 
| 146 | 
            +
                try: return super().alloc(size, options)
         | 
| 147 | 
            +
                except (RuntimeError, MemoryError):
         | 
| 156 148 | 
             
                  self.free_cache()
         | 
| 157 | 
            -
                  return super().alloc(size)
         | 
| 149 | 
            +
                  return super().alloc(size, options)
         | 
| 158 150 | 
             
              def free_cache(self):
         | 
| 159 | 
            -
                for opaques in self.cache. | 
| 160 | 
            -
                  for opaque in opaques:  | 
| 151 | 
            +
                for (sz,options),opaques in self.cache.items():
         | 
| 152 | 
            +
                  for opaque in opaques: super().free(opaque, sz, options)
         | 
| 161 153 | 
             
                  opaques.clear()
         | 
| 162 | 
            -
              def free(self, opaque:Any, size: | 
| 163 | 
            -
                if getenv("LRU", 1): self.cache[size].append(opaque)
         | 
| 164 | 
            -
                else:  | 
| 154 | 
            +
              def free(self, opaque:Any, size:int, options:Optional[BufferOptions]=None):
         | 
| 155 | 
            +
                if getenv("LRU", 1) and (options is None or not options.nolru): self.cache[(size, options)].append(opaque)
         | 
| 156 | 
            +
                else: super().free(opaque, size, options)
         | 
| 165 157 |  | 
| 166 158 | 
             
            class _MallocAllocator(LRUAllocator):
         | 
| 167 | 
            -
              def _alloc(self, size:int): return (ctypes.c_uint8 * size)()
         | 
| 159 | 
            +
              def _alloc(self, size:int, options:BufferOptions): return (ctypes.c_uint8 * size)()
         | 
| 168 160 | 
             
              def as_buffer(self, src) -> memoryview: return flat_mv(memoryview(src))
         | 
| 169 161 | 
             
              def copyin(self, dest, src:memoryview): ctypes.memmove(dest, from_mv(src), len(src))
         | 
| 170 162 | 
             
              def copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest))
         | 
| 171 | 
            -
             | 
| 172 | 
            -
             | 
| 173 | 
            -
            # **************** for Interpreted Devices ****************
         | 
| 174 | 
            -
             | 
| 175 | 
            -
            class InterpretedASTRunner(JITRunner):
         | 
| 176 | 
            -
              def __init__(self, ast:LazyOp, fxn:Callable):
         | 
| 177 | 
            -
                super().__init__()
         | 
| 178 | 
            -
                self.fxn = fxn
         | 
| 179 | 
            -
                info = get_lazyop_info(ast)
         | 
| 180 | 
            -
                self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate
         | 
| 181 | 
            -
             | 
| 182 | 
            -
              def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> float:
         | 
| 183 | 
            -
                st = time.perf_counter()
         | 
| 184 | 
            -
                rawbufs[0]._buf = self.fxn([x._buf for x in rawbufs[1:]], var_vals)
         | 
| 185 | 
            -
                et = time.perf_counter() - st
         | 
| 186 | 
            -
                update_stats(f"<interpreted {rawbufs[0].size}>", self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, device=rawbufs[0].device)
         | 
| 187 | 
            -
                return et
         | 
| 188 | 
            -
             | 
| 189 | 
            -
            class Interpreted:
         | 
| 190 | 
            -
              def __init__(self, allocator: Allocator, fxn_for_op:Dict[Op, Callable]):
         | 
| 191 | 
            -
                self.allocator, self.fxn_for_op = allocator, fxn_for_op
         | 
| 192 | 
            -
                self.synchronize, self.codegen, self.graph = lambda: None, None, None
         | 
| 193 | 
            -
             | 
| 194 | 
            -
              @functools.lru_cache(None)    # pylint: disable=method-cache-max-size-none
         | 
| 195 | 
            -
              def get_runner(self, ast:LazyOp) -> InterpretedASTRunner: return _get_interpreted_fxn(self.fxn_for_op, ast)
         | 
| 196 | 
            -
             | 
| 197 | 
            -
            def _get_interpreted_fxn(fxn_for_op:Dict[Op, Callable], ast:LazyOp) -> InterpretedASTRunner:
         | 
| 198 | 
            -
              if DEBUG >= 3:
         | 
| 199 | 
            -
                from tinygrad.graph import print_tree
         | 
| 200 | 
            -
                print_tree(ast)
         | 
| 201 | 
            -
              tglob: Dict[str, Any] = {"Variable": Variable}
         | 
| 202 | 
            -
             | 
| 203 | 
            -
              @functools.lru_cache(None)
         | 
| 204 | 
            -
              def gstr(x:Any, nm=None) -> str:
         | 
| 205 | 
            -
                if ('Variable' in (str_arg := repr(x)) or 'NumNode' in str_arg):
         | 
| 206 | 
            -
                  str_arg = re.sub(r'Variable\(.*?\)', lambda m: f'var_vals[{str(m.group(0))}]', str_arg)
         | 
| 207 | 
            -
                  # TODO: (Variable - Variable) might create NumNode. can we remove it?
         | 
| 208 | 
            -
                  return re.sub(r'NumNode\((.*?)\)', r'\1', str_arg)
         | 
| 209 | 
            -
                ret = str(nm).replace(".", "_") if nm else f"m{len(tglob):04d}"
         | 
| 210 | 
            -
                tglob[ret] = x
         | 
| 211 | 
            -
                return ret
         | 
| 212 | 
            -
             | 
| 213 | 
            -
              lines: List[str] = []
         | 
| 214 | 
            -
              @functools.lru_cache(None)
         | 
| 215 | 
            -
              def _interpret_ast(ast:LazyOp) -> str:
         | 
| 216 | 
            -
                # TODO: shortcutted store won't work with strides
         | 
| 217 | 
            -
                if ast.op == BufferOps.STORE: return _interpret_ast(ast.src[0])
         | 
| 218 | 
            -
                if TernaryOps.MULACC in fxn_for_op and ast.op == ReduceOps.SUM:
         | 
| 219 | 
            -
                  if ast.src[0].op == BinaryOps.MUL: ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
         | 
| 220 | 
            -
                  if (castop:=ast.src[0]).op == UnaryOps.CAST and (mulop:=castop.src[0]).op == BinaryOps.MUL:
         | 
| 221 | 
            -
                    # MULACC with acc cast rewrite: MUL -> CAST -> SUM => CAST -> MULACC
         | 
| 222 | 
            -
                    ast = LazyOp(TernaryOps.MULACC, tuple(LazyOp(UnaryOps.CAST, (s, ), castop.arg) for s in mulop.src), ast.arg)
         | 
| 223 | 
            -
             | 
| 224 | 
            -
                if ast.op in BufferOps:
         | 
| 225 | 
            -
                  if ast.op == BufferOps.CONST: tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})"
         | 
| 226 | 
            -
                  else: tmp = f"{gstr(fxn_for_op[UnaryOps.CAST], UnaryOps.CAST)}(inputs[{ast.arg.idx-1}], ({gstr(ast.arg.dtype)}, True))"
         | 
| 227 | 
            -
                  for mop,arg in ast.arg.st.to_movement_ops(): tmp = f"{gstr(fxn_for_op[mop], mop)}({tmp}, {gstr(arg)})"
         | 
| 228 | 
            -
                else:
         | 
| 229 | 
            -
                  tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({', '.join([_interpret_ast(src) for src in ast.src] + ([gstr(ast.arg)] if ast.arg else []))})"
         | 
| 230 | 
            -
             | 
| 231 | 
            -
                ret = f"a{len(lines)}"
         | 
| 232 | 
            -
                lines.append(f"  {ret} = {tmp}")
         | 
| 233 | 
            -
                return ret
         | 
| 163 | 
            +
              def offset(self, buf, size:int, offset:int): return from_mv(self.as_buffer(buf)[offset:offset+size])
         | 
| 234 164 |  | 
| 235 | 
            -
             | 
| 236 | 
            -
              src = '\n'.join(['def run(inputs, var_vals):'] + lines + [f"  return {ret}"])
         | 
| 237 | 
            -
              if DEBUG >= 4: print(functools.reduce(lambda x,y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src))
         | 
| 238 | 
            -
              exec(compile(src, "<ast>", "exec"), tglob) # pylint: disable=exec-used
         | 
| 239 | 
            -
              return InterpretedASTRunner(ast, tglob['run'])
         | 
| 165 | 
            +
            MallocAllocator = _MallocAllocator()
         | 
| 240 166 |  | 
| 241 167 | 
             
            # **************** for Compiled Devices ****************
         | 
| 242 168 |  | 
| 243 | 
            -
            class  | 
| 244 | 
            -
              def __init__(self, ast:Optional[LazyOp], name:str, prg:str, lib:bytes, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None):  # noqa: E501
         | 
| 245 | 
            -
                super().__init__()
         | 
| 246 | 
            -
                if DEBUG >= 4: print(prg)
         | 
| 247 | 
            -
                if global_size is not None: global_size = global_size + [1]*(3-len(global_size))
         | 
| 248 | 
            -
                if local_size is not None: local_size = local_size + [1]*(3-len(local_size))
         | 
| 249 | 
            -
                self.name, self.display_name, self.prg, self.lib, self.global_size, self.local_size, self.first_run = \
         | 
| 250 | 
            -
                  to_function_name(name), name, prg, lib, global_size, local_size, True
         | 
| 251 | 
            -
                self.vars: List[Variable] = []
         | 
| 252 | 
            -
                if ast:
         | 
| 253 | 
            -
                  info = get_lazyop_info(ast)
         | 
| 254 | 
            -
                  self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate
         | 
| 255 | 
            -
                  self.vars = ast.vars()
         | 
| 256 | 
            -
                  assert all(v._val is None for v in self.vars), f"ASTRunner contains bound Variable {self.vars}"
         | 
| 257 | 
            -
             | 
| 258 | 
            -
              def build(self, runtime):
         | 
| 259 | 
            -
                self.clprg = runtime(self.name, self.lib)
         | 
| 260 | 
            -
                return self
         | 
| 169 | 
            +
            class CompileError(Exception): pass
         | 
| 261 170 |  | 
| 262 | 
            -
             | 
| 263 | 
            -
             | 
| 264 | 
            -
             | 
| 265 | 
            -
             | 
| 266 | 
            -
             | 
| 267 | 
            -
             | 
| 268 | 
            -
             | 
| 269 | 
            -
             | 
| 270 | 
            -
             | 
| 271 | 
            -
                  from tinygrad.features.search import optimize_local_size
         | 
| 272 | 
            -
                  local_size = self.local_size = optimize_local_size(self.clprg, global_size, rawbufs)
         | 
| 273 | 
            -
                  global_size = self.global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
         | 
| 274 | 
            -
                lra = {}
         | 
| 275 | 
            -
                if global_size: lra['global_size'] = global_size
         | 
| 276 | 
            -
                if local_size: lra['local_size'] = local_size
         | 
| 277 | 
            -
                et = self.clprg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.vars), wait=wait or DEBUG>=2)
         | 
| 278 | 
            -
                if do_update_stats: update_stats(self.display_name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, lra=lra, device=rawbufs[0].device, first_run=self.first_run)  # noqa: E501
         | 
| 279 | 
            -
                self.first_run = False
         | 
| 280 | 
            -
                return et
         | 
| 171 | 
            +
            class Compiler:
         | 
| 172 | 
            +
              def __init__(self, cachekey:Optional[str]=None): self.cachekey = None if getenv("DISABLE_COMPILER_CACHE") else cachekey
         | 
| 173 | 
            +
              def compile(self, src:str) -> bytes: raise NotImplementedError("need a compile function")
         | 
| 174 | 
            +
              def compile_cached(self, src:str) -> bytes:
         | 
| 175 | 
            +
                if self.cachekey is None or (lib := diskcache_get(self.cachekey, src)) is None:
         | 
| 176 | 
            +
                  assert not getenv("ASSERT_COMPILE"), f"tried to compile with ASSERT_COMPILE set\n{src}"
         | 
| 177 | 
            +
                  lib = self.compile(src)
         | 
| 178 | 
            +
                  if self.cachekey is not None: diskcache_put(self.cachekey, src, lib)
         | 
| 179 | 
            +
                return lib
         | 
| 281 180 |  | 
| 282 181 | 
             
            class Compiled:
         | 
| 283 | 
            -
              def __init__(self,  | 
| 284 | 
            -
                self. | 
| 182 | 
            +
              def __init__(self, device:str, allocator:Allocator, renderer:Optional[Renderer], compiler:Optional[Compiler], runtime, graph=None):
         | 
| 183 | 
            +
                self.dname, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler or Compiler(), runtime, graph
         | 
| 184 | 
            +
                self.renderer = renderer or Renderer()
         | 
| 285 185 | 
             
              def synchronize(self): pass  # override this in your device
         | 
| 286 186 |  | 
| 287 | 
            -
             | 
| 288 | 
            -
             | 
| 289 | 
            -
             | 
| 290 | 
            -
             | 
| 291 | 
            -
             | 
| 292 | 
            -
             | 
| 293 | 
            -
             | 
| 294 | 
            -
             | 
| 295 | 
            -
             | 
| 296 | 
            -
             | 
| 297 | 
            -
             | 
| 298 | 
            -
             | 
| 299 | 
            -
             | 
| 300 | 
            -
             | 
| 301 | 
            -
                 | 
| 302 | 
            -
             | 
| 303 | 
            -
             | 
| 304 | 
            -
                 | 
| 305 | 
            -
                 | 
| 306 | 
            -
             | 
| 307 | 
            -
                 | 
| 308 | 
            -
             | 
| 309 | 
            -
             | 
| 310 | 
            -
             | 
| 311 | 
            -
             | 
| 312 | 
            -
             | 
| 313 | 
            -
             | 
| 314 | 
            -
             | 
| 315 | 
            -
             | 
| 316 | 
            -
             | 
| 317 | 
            -
             | 
| 318 | 
            -
             | 
| 319 | 
            -
             | 
| 320 | 
            -
             | 
| 321 | 
            -
             | 
| 322 | 
            -
             | 
| 323 | 
            -
             | 
| 324 | 
            -
             | 
| 325 | 
            -
               | 
| 326 | 
            -
             | 
| 187 | 
            +
            # **************** for HCQ Compatible Devices ****************
         | 
| 188 | 
            +
             | 
| 189 | 
            +
            @contextlib.contextmanager
         | 
| 190 | 
            +
            def hcq_profile(dev, queue_type, enabled, desc):
         | 
| 191 | 
            +
              st, en = (dev._get_signal(), dev._get_signal()) if enabled else (None, None)
         | 
| 192 | 
            +
              if enabled: queue_type().timestamp(st).submit(dev)
         | 
| 193 | 
            +
              try: yield (st, en)
         | 
| 194 | 
            +
              finally:
         | 
| 195 | 
            +
                if enabled: queue_type().timestamp(en).submit(dev)
         | 
| 196 | 
            +
                if enabled and PROFILE: dev.sig_prof_records.append((st, en, desc, queue_type is dev.hw_copy_queue_t))
         | 
| 197 | 
            +
             | 
| 198 | 
            +
            class HCQCompatCompiled(Compiled):
         | 
| 199 | 
            +
              def __init__(self, device:str, allocator:Allocator, renderer:Renderer, compiler:Compiler, runtime, comp_queue_t, copy_queue_t, timeline_signals):
         | 
| 200 | 
            +
                self.hw_compute_queue_t, self.hw_copy_queue_t = comp_queue_t, copy_queue_t
         | 
| 201 | 
            +
                self.timeline_value: int = 1
         | 
| 202 | 
            +
                self.timeline_signal, self._shadow_timeline_signal = timeline_signals
         | 
| 203 | 
            +
                self.sig_prof_records: List[Tuple[Any, Any, str, bool]] = []
         | 
| 204 | 
            +
                self.raw_prof_records: List[Tuple[int, int, str, bool]] = []
         | 
| 205 | 
            +
                if PROFILE: self._prof_setup()
         | 
| 206 | 
            +
             | 
| 207 | 
            +
                from tinygrad.runtime.graph.hcq import HCQGraph
         | 
| 208 | 
            +
                super().__init__(device, allocator, renderer, compiler, runtime, HCQGraph)
         | 
| 209 | 
            +
             | 
| 210 | 
            +
              @classmethod
         | 
| 211 | 
            +
              def _read_signal(self, sig): raise NotImplementedError("need _read_signal") # reads a value for a signal
         | 
| 212 | 
            +
             | 
| 213 | 
            +
              @classmethod
         | 
| 214 | 
            +
              def _read_timestamp(self, sig): raise NotImplementedError("need _read_timestamp") # reads a timestamp for a signal
         | 
| 215 | 
            +
             | 
| 216 | 
            +
              @classmethod
         | 
| 217 | 
            +
              def _set_signal(self, sig, value): raise NotImplementedError("need _set_signal") # sets a value for a signal
         | 
| 218 | 
            +
             | 
| 219 | 
            +
              @classmethod
         | 
| 220 | 
            +
              def _get_signal(self, value=0, **kwargs): raise NotImplementedError("need _get_signal") # allocates a new signal
         | 
| 221 | 
            +
             | 
| 222 | 
            +
              @classmethod
         | 
| 223 | 
            +
              def _wait_signal(self, signal, value=0, timeout=10000): raise NotImplementedError("need _wait_signal") # waits for a signal value
         | 
| 224 | 
            +
             | 
| 225 | 
            +
              def _gpu2cpu_time(self, gpu_time, is_copy): raise NotImplementedError("need _gpu2cpu_time")
         | 
| 226 | 
            +
             | 
| 227 | 
            +
              def _prof_setup(self):
         | 
| 228 | 
            +
                self.profile_logger = ProfileLogger()
         | 
| 229 | 
            +
             | 
| 230 | 
            +
                def _sync_queue(q_t):
         | 
| 231 | 
            +
                  q_t().timestamp(self.timeline_signal).signal(self.timeline_signal, self.timeline_value).submit(self)
         | 
| 232 | 
            +
                  self.timeline_value += 1
         | 
| 233 | 
            +
                  cpu_start_time = time.perf_counter_ns() / 1e3
         | 
| 234 | 
            +
                  self._wait_signal(self.timeline_signal, self.timeline_value - 1)
         | 
| 235 | 
            +
                  return cpu_start_time, self._read_timestamp(self.timeline_signal)
         | 
| 236 | 
            +
                self.cpu_start_time, self.gpu_start_time = _sync_queue(self.hw_compute_queue_t)
         | 
| 237 | 
            +
                self.copy_cpu_start_time, self.copy_gpu_start_time = _sync_queue(self.hw_copy_queue_t)
         | 
| 238 | 
            +
             | 
| 239 | 
            +
                atexit.register(self._prof_finalize)
         | 
| 240 | 
            +
             | 
| 241 | 
            +
              def _prof_process_events(self):
         | 
| 242 | 
            +
                self.raw_prof_records += [(self._read_timestamp(st), self._read_timestamp(en), name, is_cp) for st, en, name, is_cp in self.sig_prof_records]
         | 
| 243 | 
            +
                for st, en, _, _ in self.sig_prof_records: self.signals_pool += [st, en] # type: ignore
         | 
| 244 | 
            +
                self.sig_prof_records = []
         | 
| 245 | 
            +
             | 
| 246 | 
            +
              def _prof_finalize(self):
         | 
| 247 | 
            +
                for st, en, name, is_cp in self.raw_prof_records:
         | 
| 248 | 
            +
                  self.profile_logger.events += [(name, self._gpu2cpu_time(st, is_cp), self._gpu2cpu_time(en, is_cp), self.dname, ["COMPUTE", "DMA"][is_cp])]
         | 
| 249 | 
            +
                del self.profile_logger
         | 
| 250 | 
            +
             | 
| 251 | 
            +
              def _wrap_timeline_signal(self):
         | 
| 252 | 
            +
                self.timeline_signal, self._shadow_timeline_signal, self.timeline_value = self._shadow_timeline_signal, self.timeline_signal, 1
         | 
| 253 | 
            +
                self._set_signal(self.timeline_signal, 0)
         | 
| 254 | 
            +
                cast(HCQCompatAllocator, self.allocator).b_timeline = [0] * len(cast(HCQCompatAllocator, self.allocator).b)
         | 
| 255 | 
            +
             | 
| 256 | 
            +
            class HCQCompatAllocator(LRUAllocator): # pylint: disable=abstract-method
         | 
| 257 | 
            +
              def __init__(self, device, batch_size=(2 << 20), batch_cnt=32):
         | 
| 258 | 
            +
                self.device = device
         | 
| 259 | 
            +
                self.b = [self._alloc(batch_size, BufferOptions(host=True)) for _ in range(batch_cnt)]
         | 
| 260 | 
            +
                self.b_timeline, self.b_next = [0] * len(self.b), 0
         | 
| 261 | 
            +
                super().__init__()
         | 
| 262 | 
            +
             | 
| 263 | 
            +
              def copyin(self, dest, src: memoryview):
         | 
| 264 | 
            +
                with hcq_profile(self.device, self.device.hw_copy_queue_t, desc=f"CPU -> {self.device.dname}", enabled=PROFILE):
         | 
| 265 | 
            +
                  for i in range(0, src.nbytes, self.b[0].size):
         | 
| 266 | 
            +
                    self.b_next = (self.b_next + 1) % len(self.b)
         | 
| 267 | 
            +
                    self.device._wait_signal(self.device.timeline_signal, self.b_timeline[self.b_next])
         | 
| 268 | 
            +
                    ctypes.memmove(self.b[self.b_next].va_addr, from_mv(src[i:]), lsize:=min(self.b[self.b_next].size, src.nbytes-i))
         | 
| 269 | 
            +
                    self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
         | 
| 270 | 
            +
                                                 .copy(dest.va_addr+i, self.b[self.b_next].va_addr, lsize) \
         | 
| 271 | 
            +
                                                 .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
         | 
| 272 | 
            +
                    self.b_timeline[self.b_next] = self.device.timeline_value
         | 
| 273 | 
            +
                    self.device.timeline_value += 1
         | 
| 274 | 
            +
             | 
| 275 | 
            +
              def copy_from_disk(self, dest, src, size):
         | 
| 276 | 
            +
                def _get_temp_buf():
         | 
| 277 | 
            +
                  # Check if the next buffer is safe to be used (its signal has passed) and reserve it.
         | 
| 278 | 
            +
                  if self.b_timeline[(self.b_next + 1) % len(self.b)] <= self.device._read_signal(self.device.timeline_signal):
         | 
| 279 | 
            +
                    self.b_timeline[(self.b_next + 1) % len(self.b)], self.b_next = (1 << 64), (self.b_next + 1) % len(self.b)
         | 
| 280 | 
            +
                    return (self.b[self.b_next].va_addr, self.b_next)
         | 
| 281 | 
            +
                  return None
         | 
| 282 | 
            +
             | 
| 283 | 
            +
                with hcq_profile(self.device, self.device.hw_copy_queue_t, desc=f"DISK -> {self.device.dname}", enabled=PROFILE):
         | 
| 284 | 
            +
                  for (batch_info, dst_off, src_off, copy_size) in src.device.allocator._copyout_sharded(src, size, _get_temp_buf, seg_len=self.b[0].size):
         | 
| 285 | 
            +
                    self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
         | 
| 286 | 
            +
                                                 .copy(dest.va_addr + dst_off, batch_info[0] + src_off, copy_size) \
         | 
| 287 | 
            +
                                                 .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
         | 
| 288 | 
            +
                    self.b_timeline[batch_info[1]] = self.device.timeline_value
         | 
| 289 | 
            +
                    self.device.timeline_value += 1
         | 
| 290 | 
            +
             | 
| 291 | 
            +
              def copyout(self, dest:memoryview, src):
         | 
| 292 | 
            +
                self.device.synchronize()
         | 
| 293 | 
            +
             | 
| 294 | 
            +
                with hcq_profile(self.device, self.device.hw_copy_queue_t, desc=f"{self.device.dname} -> CPU", enabled=PROFILE):
         | 
| 295 | 
            +
                  for i in range(0, dest.nbytes, self.b[0].size):
         | 
| 296 | 
            +
                    self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
         | 
| 297 | 
            +
                                                 .copy(self.b[0].va_addr, src.va_addr+i, lsize:=min(self.b[0].size, dest.nbytes-i)) \
         | 
| 298 | 
            +
                                                 .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
         | 
| 299 | 
            +
                    self.device._wait_signal(self.device.timeline_signal, self.device.timeline_value)
         | 
| 300 | 
            +
                    self.device.timeline_value += 1
         | 
| 301 | 
            +
             | 
| 302 | 
            +
                    ctypes.memmove(from_mv(dest[i:]), self.b[0].va_addr, lsize)
         | 
| 303 | 
            +
             | 
| 304 | 
            +
              def transfer(self, dest, src, sz: int, src_dev, dest_dev):
         | 
| 305 | 
            +
                src_dev._gpu_map(dest)
         | 
| 306 | 
            +
             | 
| 307 | 
            +
                with hcq_profile(self.device, self.device.hw_copy_queue_t, desc=f"{src_dev.dname} -> {dest_dev.dname}", enabled=PROFILE):
         | 
| 308 | 
            +
                  src_dev.hw_copy_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
         | 
| 309 | 
            +
                                           .wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
         | 
| 310 | 
            +
                                           .copy(dest.va_addr, src.va_addr, sz) \
         | 
| 311 | 
            +
                                           .signal(src_dev.timeline_signal, src_dev.timeline_value).submit(src_dev)
         | 
| 312 | 
            +
                  src_dev.timeline_value += 1
         | 
| 313 | 
            +
             | 
| 314 | 
            +
                if src_dev != dest_dev:
         | 
| 315 | 
            +
                  dest_dev.hw_compute_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
         | 
| 316 | 
            +
                                               .wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
         | 
| 317 | 
            +
                                               .signal(dest_dev.timeline_signal, dest_dev.timeline_value).submit(dest_dev)
         | 
| 318 | 
            +
                  dest_dev.timeline_value += 1
         | 
| 319 | 
            +
             | 
| 320 | 
            +
              def offset(self, buf, size:int, offset:int): return type(buf)(base=buf.base + offset, va_addr=buf.va_addr + offset, length=size, size=size)
         |