tinygrad 0.7.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +6 -0
 - tinygrad/codegen/kernel.py +572 -83
 - tinygrad/codegen/linearizer.py +415 -395
 - tinygrad/codegen/uops.py +415 -0
 - tinygrad/device.py +183 -0
 - tinygrad/dtype.py +113 -0
 - tinygrad/engine/__init__.py +0 -0
 - tinygrad/engine/graph.py +100 -0
 - tinygrad/engine/jit.py +195 -0
 - tinygrad/engine/realize.py +191 -0
 - tinygrad/engine/schedule.py +362 -0
 - tinygrad/engine/search.py +196 -0
 - tinygrad/{mlops.py → function.py} +76 -55
 - tinygrad/helpers.py +196 -89
 - tinygrad/lazy.py +210 -371
 - tinygrad/multi.py +169 -0
 - tinygrad/nn/__init__.py +202 -22
 - tinygrad/nn/datasets.py +7 -0
 - tinygrad/nn/optim.py +112 -32
 - tinygrad/nn/state.py +136 -39
 - tinygrad/ops.py +119 -202
 - tinygrad/renderer/__init__.py +61 -0
 - tinygrad/renderer/assembly.py +276 -0
 - tinygrad/renderer/cstyle.py +353 -166
 - tinygrad/renderer/llvmir.py +150 -138
 - tinygrad/runtime/autogen/amd_gpu.py +1900 -0
 - tinygrad/runtime/autogen/comgr.py +865 -0
 - tinygrad/runtime/autogen/cuda.py +5923 -0
 - tinygrad/runtime/autogen/hip.py +5909 -0
 - tinygrad/runtime/autogen/hsa.py +5761 -0
 - tinygrad/runtime/autogen/kfd.py +812 -0
 - tinygrad/runtime/autogen/nv_gpu.py +33328 -0
 - tinygrad/runtime/autogen/opencl.py +1795 -0
 - tinygrad/runtime/driver/hip_comgr.py +47 -0
 - tinygrad/runtime/driver/hsa.py +143 -0
 - tinygrad/runtime/graph/clang.py +38 -0
 - tinygrad/runtime/graph/cuda.py +81 -0
 - tinygrad/runtime/graph/hcq.py +143 -0
 - tinygrad/runtime/graph/hsa.py +171 -0
 - tinygrad/runtime/graph/metal.py +75 -0
 - tinygrad/runtime/ops_amd.py +564 -0
 - tinygrad/runtime/ops_clang.py +24 -77
 - tinygrad/runtime/ops_cuda.py +175 -89
 - tinygrad/runtime/ops_disk.py +56 -33
 - tinygrad/runtime/ops_gpu.py +92 -95
 - tinygrad/runtime/ops_hsa.py +278 -0
 - tinygrad/runtime/ops_llvm.py +39 -60
 - tinygrad/runtime/ops_metal.py +92 -74
 - tinygrad/runtime/ops_npy.py +9 -0
 - tinygrad/runtime/ops_nv.py +630 -0
 - tinygrad/runtime/ops_python.py +204 -0
 - tinygrad/shape/shapetracker.py +86 -254
 - tinygrad/shape/symbolic.py +166 -141
 - tinygrad/shape/view.py +296 -0
 - tinygrad/tensor.py +2619 -448
 - {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
 - tinygrad-0.9.0.dist-info/METADATA +227 -0
 - tinygrad-0.9.0.dist-info/RECORD +60 -0
 - {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
 - tinygrad/codegen/assembly.py +0 -190
 - tinygrad/codegen/optimizer.py +0 -379
 - tinygrad/codegen/search.py +0 -72
 - tinygrad/graph.py +0 -83
 - tinygrad/jit.py +0 -57
 - tinygrad/nn/image.py +0 -100
 - tinygrad/renderer/assembly_arm64.py +0 -169
 - tinygrad/renderer/assembly_ptx.py +0 -98
 - tinygrad/renderer/wgsl.py +0 -53
 - tinygrad/runtime/lib.py +0 -113
 - tinygrad/runtime/ops_cpu.py +0 -51
 - tinygrad/runtime/ops_hip.py +0 -82
 - tinygrad/runtime/ops_shm.py +0 -29
 - tinygrad/runtime/ops_torch.py +0 -30
 - tinygrad/runtime/ops_webgpu.py +0 -45
 - tinygrad-0.7.0.dist-info/METADATA +0 -212
 - tinygrad-0.7.0.dist-info/RECORD +0 -40
 - {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
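Beyond the new runtime backends, the listing above reflects the 0.9.0 reorganization: scheduling, realization, the JIT and kernel search move under tinygrad/engine/, mlops.py becomes function.py, and dtypes, devices and multi-GPU handling get their own modules, while the user-facing surface remains Tensor plus tinygrad.nn. A minimal training-step sketch against that surface; the model, data and hyperparameters below are made-up illustrations, not taken from the diff:

from tinygrad import Tensor, TinyJit
from tinygrad.nn import Linear
from tinygrad.nn.optim import Adam
from tinygrad.nn.state import get_parameters

# illustrative two-layer classifier
class Net:
  def __init__(self): self.l1, self.l2 = Linear(16, 32), Linear(32, 2)
  def __call__(self, x:Tensor) -> Tensor: return self.l2(self.l1(x).relu())

net = Net()
opt = Adam(get_parameters(net), lr=1e-3)

@TinyJit
def step(x:Tensor, y:Tensor) -> Tensor:
  with Tensor.train():          # sets Tensor.training for the optimizer step
    opt.zero_grad()
    loss = net(x).sparse_categorical_crossentropy(y).backward()
    opt.step()
    return loss.realize()

x, y = Tensor.randn(8, 16), Tensor([0, 1, 0, 1, 1, 0, 1, 0])  # made-up batch
print(step(x, y).item())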
 
    
tinygrad/runtime/ops_clang.py
CHANGED

@@ -1,81 +1,28 @@
-from tinygrad.runtime.lib import RawMallocBuffer
-from tinygrad.codegen.linearizer import LinearizerOptions
-from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
-import struct
-import numpy as np
+import ctypes, subprocess, pathlib, tempfile
+from tinygrad.device import Compiled, Compiler, MallocAllocator
+from tinygrad.helpers import cpu_time_execution, DEBUG, cpu_objdump
+from tinygrad.renderer.cstyle import ClangRenderer

-}[platform.system()]
-CLANG_PROGRAM_HEADER = '#include <math.h>\n#define max(x,y) ((x>y)?x:y)\n#define int64 long\n#define half __fp16\n#define uchar unsigned char\n#define bool uchar\n'
-ADDRESS = 0x10000
-# Unicorn doesn't support external calls
-def align(addr): return (addr+4095) & ~(4095)
-mock_lm = {"sinf": np.sin, "sqrtf": np.sqrt, "exp2f": np.exp2, "log2f": np.log2}
-def emulate_ext_calls(fn, uc, address, size, user_data):
-  s_in = struct.unpack('f', struct.pack('I', uc.reg_read(getattr(arm64_const, f'UC_ARM64_REG_S{fn[2][1:]}'))))[0]
-  uc.reg_write(getattr(arm64_const, f'UC_ARM64_REG_S{fn[1][1:]}'), struct.unpack('I', struct.pack('f', mock_lm[fn[0]](s_in)))[0])  # type: ignore
+class ClangCompiler(Compiler):
+  def compile(self, src:str) -> bytes:
+    # TODO: remove file write. sadly clang doesn't like the use of /dev/stdout here
+    with tempfile.NamedTemporaryFile(delete=True) as output_file:
+      subprocess.check_output(['clang', '-include', 'tgmath.h', '-shared', '-march=native', '-O2', '-Wall', '-Werror', '-x', 'c', '-fPIC', '-',
+                               '-o', str(output_file.name)], input=src.encode('utf-8'))
+      return pathlib.Path(output_file.name).read_bytes()

 class ClangProgram:
-        prg = CLANG_PROGRAM_HEADER + prg
-        subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+tmp).split(), input=prg.encode('utf-8'))
-        os.rename(tmp, fn)
-      else:
-        if CI and ARM64:
-          prg = prg.split('\n') # type: ignore
-          self.varsize = align(int(prg[0].split(" ")[1]))
-          self.ext_calls = {(i*4+ADDRESS):ins.split(" ")[1:] for i, ins in enumerate(filter(lambda ins: ins[:4] != 'loop', prg[6:-3])) if ins[:2] == 'bl'}
-          prg = "\n".join(['nop' if ins[:2] == 'bl' else ins for ins in prg[6:-3]] + ['\n'])
-          subprocess.check_output(args=('aarch64-linux-gnu-as -o '+tmp).split(), input=prg.encode('utf-8'))
-          subprocess.check_output(args=('aarch64-linux-gnu-objcopy -O binary --only-section=.text '+tmp+' '+fn+'.bin').split())
-          self.prg = open(fn + '.bin', 'rb').read()
-          return
-        subprocess.check_output(args=('as -o' + tmp).split(), input=prg.encode('utf-8'))
-        subprocess.check_output(args=('clang -lm -shared '+tmp+' -o'+fn).split())
-    self.lib = ctypes.CDLL(fn)
-    self.fxn = self.lib[name]
-  def __call__(self, global_size, local_size, *args, wait=False):
-    if wait: st = time.monotonic()
-    if CI and ARM64:
-      mu = Uc(UC_ARCH_ARM64, UC_MODE_ARM)
-      total_mem = align(reduce(lambda total, arg: total + arg.size * arg.dtype.itemsize, args, len(self.prg)+self.varsize))
-      mu.mem_map(ADDRESS, total_mem)
-      for k, fn in self.ext_calls.items(): mu.hook_add(UC_HOOK_CODE, partial(emulate_ext_calls, fn), begin=k, end=k)
-      mu.mem_write(ADDRESS, self.prg + b''.join(bytes(arg._buf) for arg in args))
-      addr = ADDRESS + len(self.prg)
-      for i, arg in enumerate(args):
-        if i<=7:
-          mu.reg_write(getattr(arm64_const, f'UC_ARM64_REG_X{i}'), addr)
-        else:
-          # NOTE: In ARM, args beyond the first 8 are placed on the stack it also account for the stack red zone.
-          mu.mem_write(ADDRESS + total_mem - (len(args[8:])+2)*8 + 8*(i-8), addr.to_bytes(8, 'little'))
-        addr += arg.size * arg.dtype.itemsize
-      mu.reg_write(arm64_const.UC_ARM64_REG_SP, ADDRESS + total_mem - (len(args[8:])+2)*8)
-      mu.emu_start(ADDRESS, ADDRESS + len(self.prg))
-      args[0]._buf = mu.mem_read(mu.reg_read(arm64_const.UC_ARM64_REG_X0), args[0].size * args[0].dtype.itemsize)
-    else:
-      self.fxn(*[x._buf if isinstance(x, RawMallocBuffer) else x for x in args])
-    if wait: return time.monotonic()-st
+  def __init__(self, name:str, lib:bytes):
+    if DEBUG >= 6: cpu_objdump(lib)
+    self.name, self.lib = name, lib
+    # write to disk so we can load it
+    with tempfile.NamedTemporaryFile(delete=True) as cached_file_path:
+      pathlib.Path(cached_file_path.name).write_bytes(lib)
+      self.fxn = ctypes.CDLL(str(cached_file_path.name))[name]
+
+  def __call__(self, *bufs, vals=(), wait=False): return cpu_time_execution(lambda: self.fxn(*bufs, *vals), enable=wait)

+class ClangDevice(Compiled):
+  def __init__(self, device:str):
+    from tinygrad.runtime.graph.clang import ClangGraph
+    super().__init__(device, MallocAllocator, ClangRenderer(), ClangCompiler("compile_clang"), ClangProgram, ClangGraph)
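The rewrite splits the CPU backend into three small pieces: ClangCompiler turns C source into the bytes of a shared object, ClangProgram writes those bytes to a temp file and loads the named symbol through ctypes, and ClangDevice wires them to MallocAllocator, ClangRenderer and ClangGraph. A minimal sketch of driving the first two pieces directly, outside tinygrad's scheduler; the add4 kernel and the ctypes buffers are illustrative assumptions, not code from the diff:

import ctypes
from tinygrad.runtime.ops_clang import ClangCompiler, ClangProgram

# hypothetical kernel: add two float buffers of length 4
src = "void add4(float *out, const float *a, const float *b) { for (int i = 0; i < 4; i++) out[i] = a[i] + b[i]; }"
lib = ClangCompiler("compile_clang").compile(src)   # clang -shared ... -> bytes of a .so ("compile_clang" is the cache key ClangDevice uses)
prg = ClangProgram("add4", lib)                     # temp file + ctypes.CDLL, grabs the add4 symbol

a, b = (ctypes.c_float * 4)(1, 2, 3, 4), (ctypes.c_float * 4)(10, 20, 30, 40)
out = (ctypes.c_float * 4)()
prg(out, a, b)            # __call__ forwards the buffers straight to the C function
print(list(out))          # [11.0, 22.0, 33.0, 44.0]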
    
tinygrad/runtime/ops_cuda.py
CHANGED

@@ -1,99 +1,185 @@
+from __future__ import annotations
+import subprocess, hashlib, tempfile, ctypes, ctypes.util, functools, re
+from pathlib import Path
+from typing import Tuple, Optional, List
+import tinygrad.runtime.autogen.cuda as cuda
+from tinygrad.helpers import DEBUG, getenv, from_mv, to_char_p_p, init_c_var, init_c_struct_t, colored, cpu_time_execution
+from tinygrad.device import Compiled, Compiler, CompileError, BufferOptions, LRUAllocator, MallocAllocator
+from tinygrad.renderer.cstyle import CUDARenderer
+from tinygrad.renderer.assembly import PTXRenderer
+if getenv("IOCTL"): import extra.nv_gpu_driver.nv_ioctl  # noqa: F401

 def pretty_ptx(s):
   # all expressions match `<valid_before><expr><valid_after>` and replace it with `<valid_before>color(<expr>)<valid_after>`
-  s = re.sub(r'([!@<\[\s,\+\-;\n])((?:[_%$][\w%\$_]+(?:\.[xyz])?\:?)|(?:buf\d+))([<>\]\s,\+\-;\n\)])', lambda m:m[1]+colored(m[2], "blue")+m[3], s, flags=re.M) # identifiers
+  s = re.sub(r'([!@<\[\s,\+\-;\n])((?:[_%$][\w%\$_]+(?:\.[xyz])?\:?)|(?:buf\d+))([<>\]\s,\+\-;\n\)])', lambda m:m[1]+colored(m[2], "blue")+m[3], s, flags=re.M) # identifiers  # noqa: E501
   s = re.sub(r'(.)((?:b|s|u|f)(?:8|16|32|64)|pred)([\.\s])', lambda m:m[1]+colored(m[2], "green")+m[3], s, flags=re.M) # types
   s = re.sub(r'^(\s*)([\w]+)(.*?;$)', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # instructions
-  s = re.sub(r'([<>\[\]\s,\+\-;])((?:0[fF][0-9a-fA-F]{8})|(?:[0-9]+)|(?:0[xX][0-9a-fA-F]+))([<>\[\]\s,\+\-;])', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # numbers
+  s = re.sub(r'([<>\[\]\s,\+\-;])((?:0[fF][0-9a-fA-F]{8})|(?:[0-9]+)|(?:0[xX][0-9a-fA-F]+))([<>\[\]\s,\+\-;])', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # numbers  # noqa: E501
   s = re.sub(r'(\.)(param|reg|global)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # space
   s = re.sub(r'(\.)(version|target|address_size|visible|entry)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # derivatives
   return s
+
+CUDACPU = getenv("CUDACPU") == 1
+if CUDACPU:
+  gpuocelot_lib = ctypes.CDLL(ctypes.util.find_library("gpuocelot"))
+  gpuocelot_lib.ptx_run.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int]  # noqa: E501
+  cuda.cuLaunchKernel = lambda src, gx, gy, gz, lx, ly, lz, shared, stream, unused_extra, args: gpuocelot_lib.ptx_run(src, len(args), (ctypes.c_void_p * len(args))(*[ctypes.cast(x, ctypes.c_void_p) for x in args]), lx, ly, lz, gx, gy, gz, shared)  # type: ignore  # noqa: E501
+
+def check(status):
+  if status != 0: raise RuntimeError(f"CUDA Error {status}, {ctypes.string_at(init_c_var(ctypes.POINTER(ctypes.c_char)(), lambda x: cuda.cuGetErrorString(status, ctypes.byref(x)))).decode()}")  # noqa: E501
+
+def encode_args(args, vals) -> Tuple[ctypes.Structure, ctypes.Array]:
+  c_args = init_c_struct_t(tuple([(f'f{i}', cuda.CUdeviceptr_v2) for i in range(len(args))] +
+                                 [(f'v{i}', ctypes.c_int) for i in range(len(vals))]))(*args, *vals)
+  vargs = (ctypes.c_void_p * 5)(ctypes.c_void_p(1), ctypes.cast(ctypes.byref(c_args), ctypes.c_void_p), ctypes.c_void_p(2),
+                                ctypes.cast(ctypes.pointer(ctypes.c_size_t(ctypes.sizeof(c_args))), ctypes.c_void_p), ctypes.c_void_p(0))
+  return c_args, vargs
+
+def cu_time_execution(cb, enable=False) -> Optional[float]:
+  if CUDACPU: return cpu_time_execution(cb, enable=enable)
+  if not enable: return cb()
+  evs = [init_c_var(cuda.CUevent(), lambda x: cuda.cuEventCreate(ctypes.byref(x), 0)) for _ in range(2)]
+  cuda.cuEventRecord(evs[0], None)
+  cb()
+  cuda.cuEventRecord(evs[1], None)
+  check(cuda.cuEventSynchronize(evs[1]))
+  cuda.cuEventElapsedTime(ctypes.byref(ret := ctypes.c_float()), evs[0], evs[1])
+  for ev in evs: cuda.cuEventDestroy_v2(ev)
+  return ret.value * 1e-3
+
+def _get_bytes(arg, get_str, get_sz, check) -> bytes:
+  sz = init_c_var(ctypes.c_size_t(), lambda x: check(get_sz(arg, ctypes.byref(x))))
+  return ctypes.string_at(init_c_var(ctypes.create_string_buffer(sz.value), lambda x: check(get_str(arg, x))), size=sz.value)
+
+class PTXCompiler(Compiler):
+  def __init__(self, arch:str):
+    self.arch = arch
+    self.version = "7.8" if arch >= "sm_89" else "7.5"
+    super().__init__(f"compile_ptx_{self.arch}")
+  def compile(self, src:str) -> bytes: return src.replace("TARGET", self.arch).replace("VERSION", self.version).encode()
+
+class CUDACompiler(Compiler):
+  def __init__(self, arch:str):
+    self.arch = arch
+    check(cuda.nvrtcVersion((nvrtcMajor := ctypes.c_int()), (nvrtcMinor := ctypes.c_int())))
+    self.compile_options = [f'--gpu-architecture={arch}', "-I/usr/local/cuda/include", "-I/usr/include", "-I/opt/cuda/include/"]
+    if (nvrtcMajor.value, nvrtcMinor.value) >= (12, 4): self.compile_options.append("--minimal")
+    super().__init__(f"compile_cuda_{self.arch}")
+  def compile(self, src:str) -> bytes:
+    check(cuda.nvrtcCreateProgram(ctypes.byref(prog := cuda.nvrtcProgram()), src.encode(), "<null>".encode(), 0, None, None))
+    status = cuda.nvrtcCompileProgram(prog, len(self.compile_options), to_char_p_p([o.encode() for o in self.compile_options]))
+
+    if status != 0: raise CompileError(f"compile failed: {_get_bytes(prog, cuda.nvrtcGetProgramLog, cuda.nvrtcGetProgramLogSize, check).decode()}")
+    return _get_bytes(prog, cuda.nvrtcGetPTX, cuda.nvrtcGetPTXSize, check)
+
+def cuda_disassemble(lib, arch):
+  try:
+    fn = (Path(tempfile.gettempdir()) / f"tinycuda_{hashlib.md5(lib).hexdigest()}").as_posix()
+    with open(fn + ".ptx", "wb") as f: f.write(lib)
+    subprocess.run(["ptxas", f"-arch={arch}", "-o", fn, fn+".ptx"], check=True)
+    print(subprocess.check_output(['nvdisasm', fn]).decode('utf-8'))
+  except Exception as e: print("failed to generate SASS", str(e))

 class CUDAProgram:
+  def __init__(self, device:CUDADevice, name:str, lib:bytes):
+    self.device, self.name, self.lib = device, name, lib
+    if DEBUG >= 5: print("\n".join([f"{i+1:>3} {line}" for i, line in enumerate(pretty_ptx(lib.decode('utf-8')).split("\n"))]))
+    if DEBUG >= 6: cuda_disassemble(lib, device.arch)
+
+    if CUDACPU: self.prg = lib
+    else:
+      check(cuda.cuCtxSetCurrent(self.device.context))
+      self.module = cuda.CUmodule()
+      status = cuda.cuModuleLoadData(ctypes.byref(self.module), lib)
+      if status != 0:
+        del self.module
+        cuda_disassemble(lib, device.arch)
+        raise RuntimeError(f"module load failed with status code {status}: {cuda.cudaError_enum__enumvalues[status]}")
+      check(cuda.cuModuleGetFunction(ctypes.byref(prg := cuda.CUfunction()), self.module, name.encode("utf-8")))
+      self.prg = prg #type: ignore
+
+  def __del__(self):
+    if hasattr(self, 'module'): check(cuda.cuModuleUnload(self.module))
+
+  def __call__(self, *args, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
+    if CUDACPU: self.vargs = args+tuple(vals)
+    else:
+      check(cuda.cuCtxSetCurrent(self.device.context))
+      if not hasattr(self, "vargs"):
+        self.c_args, self.vargs = encode_args(args, vals) #type: ignore
+      else:
+        for i in range(len(args)): self.c_args.__setattr__(f'f{i}', args[i])
+        for i in range(len(vals)): self.c_args.__setattr__(f'v{i}', vals[i])
+    return cu_time_execution(lambda: check(cuda.cuLaunchKernel(self.prg, *global_size, *local_size, 0, None, None, self.vargs)), enable=wait)
+
+class CUDAAllocator(LRUAllocator):
+  def __init__(self, device:CUDADevice):
+    self.device = device
+    super().__init__()
+  def _alloc(self, size, options:BufferOptions):
+    check(cuda.cuCtxSetCurrent(self.device.context))
+    if options.host: return init_c_var(ctypes.c_void_p(), lambda x: check(cuda.cuMemHostAlloc(ctypes.byref(x), size, 0x01)))
+    else: return init_c_var(cuda.CUdeviceptr(), lambda x: check(cuda.cuMemAlloc_v2(ctypes.byref(x), size)))
+  def _free(self, opaque, options:BufferOptions):
+    if options.host: return check(cuda.cuMemFreeHost(opaque))
+    else: check(cuda.cuMemFree_v2(opaque))
+  def copyin(self, dest, src:memoryview):
+    check(cuda.cuCtxSetCurrent(self.device.context))
+    host_mem = self.alloc(len(src), BufferOptions(host=True))
+    self.device.pending_copyin.append((host_mem, len(src), BufferOptions(host=True)))
+    ctypes.memmove(host_mem, from_mv(src), len(src))
+    check(cuda.cuMemcpyHtoDAsync_v2(dest, host_mem, len(src), None))
+  def copyout(self, dest:memoryview, src):
+    CUDADevice.synchronize_system()
+    check(cuda.cuCtxSetCurrent(self.device.context))
+    check(cuda.cuMemcpyDtoH_v2(from_mv(dest), src, len(dest)))
+  def transfer(self, dest, src, sz:int, src_dev, dest_dev):
+    check(cuda.cuCtxSetCurrent(src_dev.context))
+    check(cuda.cuEventCreate(ctypes.byref(sync_event := cuda.CUevent()), 0))
+    check(cuda.cuMemcpyDtoDAsync_v2(dest, src, sz, None))
+    check(cuda.cuEventRecord(sync_event, None))
+    check(cuda.cuCtxSetCurrent(dest_dev.context))
+    check(cuda.cuStreamWaitEvent(None, sync_event, 0)) # sync the default stream on the dest dev
+  def offset(self, buf, size:int, offset:int): return ctypes.c_ulong(buf.value + offset)
+
+class CUDADevice(Compiled):
+  devices: List[CUDADevice] = []
+  peer_access = False
+
+  def __init__(self, device:str):
+    device_id = int(device.split(":")[1]) if ":" in device else 0
+    if not CUDACPU:
+      check(cuda.cuInit(0))
+      self.cu_device = init_c_var(cuda.CUdevice(), lambda x: check(cuda.cuDeviceGet(ctypes.byref(x), device_id)))
+      self.context = init_c_var(cuda.CUcontext(), lambda x: check(cuda.cuCtxCreate_v2(ctypes.byref(x), 0, self.cu_device)))
+      check(cuda.cuDeviceComputeCapability(ctypes.byref(major := ctypes.c_int()), ctypes.byref(minor := ctypes.c_int()), device_id))
+
+      for dev in CUDADevice.devices:
+        check(cuda.cuDeviceCanAccessPeer(ctypes.byref(val := ctypes.c_int()), self.cu_device, dev.cu_device))
+        if val.value != 1: continue
+        check(cuda.cuCtxSetCurrent(dev.context))
+        check(cuda.cuCtxEnablePeerAccess(self.context, 0))
+        check(cuda.cuCtxSetCurrent(self.context))
+        check(cuda.cuCtxEnablePeerAccess(dev.context, 0))
+        CUDADevice.peer_access = True
+
+    self.arch = f"sm_{major.value}{minor.value}" if not CUDACPU else "sm_35"
+    self.pending_copyin: List[Tuple[int, int, Optional[BufferOptions]]] = []
+    CUDADevice.devices.append(self)
+
+    from tinygrad.runtime.graph.cuda import CUDAGraph
+    super().__init__(device, CUDAAllocator(self) if not CUDACPU else MallocAllocator,
+                     PTXRenderer(self.arch) if getenv("PTX") else CUDARenderer(self.arch),
+                     PTXCompiler(self.arch) if getenv("PTX") else CUDACompiler(self.arch),
+                     functools.partial(CUDAProgram, self), graph=CUDAGraph if not CUDACPU else None)
+
+  def synchronize(self):
+    if CUDACPU: return
+    check(cuda.cuCtxSetCurrent(self.context))
+    check(cuda.cuCtxSynchronize())
+    for opaque,sz,options in self.pending_copyin: self.allocator.free(opaque, sz, options)
+    self.pending_copyin.clear()
+
+  @staticmethod
+  def synchronize_system():
+    for d in CUDADevice.devices: d.synchronize()
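The CUDA backend gets the same Compiler/Program/Device split: CUDACompiler wraps NVRTC, CUDAProgram wraps cuModuleLoadData/cuLaunchKernel (with encode_args packing the buffers and ints into the CU_LAUNCH_PARAM_BUFFER_POINTER/BUFFER_SIZE "extra" array rather than a per-argument pointer list), and CUDAAllocator stages copyin through pinned host memory that is freed at the next synchronize. A rough sketch of driving these classes by hand on a machine with a CUDA toolkit and GPU; the fill kernel, device string and sizes are illustrative assumptions, not code from the diff:

from tinygrad.runtime.ops_cuda import CUDADevice, CUDACompiler, CUDAProgram

dev = CUDADevice("CUDA:0")                    # cuInit + context, records self.arch (e.g. "sm_86")
ptx = CUDACompiler(dev.arch).compile("""
extern "C" __global__ void fill(float *out, int n) {
  int i = blockIdx.x*blockDim.x + threadIdx.x;
  if (i < n) out[i] = 3.0f;
}""")                                          # NVRTC -> PTX bytes
prg = CUDAProgram(dev, "fill", ptx)            # cuModuleLoadData + cuModuleGetFunction

out = dev.allocator.alloc(16*4)                # CUdeviceptr for 16 floats
prg(out, global_size=(1,1,1), local_size=(16,1,1), vals=(16,), wait=True)  # returns elapsed seconds

host = memoryview(bytearray(16*4))
dev.allocator.copyout(host, out)               # synchronizes, then cuMemcpyDtoH
print(host.cast('f').tolist())                 # sixteen 3.0s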
    
tinygrad/runtime/ops_disk.py
CHANGED

@@ -1,37 +1,60 @@
+from __future__ import annotations
+import os, mmap, _posixshmem, io
 from typing import Optional
-from tinygrad.runtime.lib import RawBufferMapped
-from tinygrad.ops import Interpreted, Op, MovementOps, UnaryOps
+from tinygrad.helpers import OSX
+from tinygrad.device import Compiled, Allocator

+class DiskBuffer:
+  def __init__(self, device:DiskDevice, size:int, offset=0):
+    self.device, self.size, self.offset = device, size, offset
+  def __repr__(self): return f"<DiskBuffer size={self.size} offset={self.offset}>"
+  def _buf(self) -> memoryview:
+    assert self.device.mem is not None, "DiskBuffer wasn't opened"
+    return memoryview(self.device.mem)[self.offset:self.offset+self.size]
+
+MAP_LOCKED, MAP_POPULATE = 0 if OSX else 0x2000, getattr(mmap, "MAP_POPULATE", 0 if OSX else 0x008000)
+class DiskAllocator(Allocator):
+  def __init__(self, device:DiskDevice): self.device = device
+  def _alloc(self, size:int, options):
+    self.device._might_open(size)
+    return DiskBuffer(self.device, size)
+  def _free(self, buf, options): self.device._might_close()
+  def as_buffer(self, src:DiskBuffer): return src._buf()
+  def copyin(self, dest:DiskBuffer, src:memoryview): dest._buf()[:] = src
+  def copyout(self, dest:memoryview, src:DiskBuffer):
+    if OSX and hasattr(self.device, 'fd'):
+      # OSX doesn't seem great at mmap, this is faster
+      with io.FileIO(self.device.fd, "a+b", closefd=False) as fo:
+        fo.seek(src.offset)
+        fo.readinto(dest)
     else:
-  def readinto(self, buf):
-    self._buf[0].seek(self.offset)
-    self._buf[0].readinto(buf)
+      dest[:] = src._buf()
+  def offset(self, buf:DiskBuffer, size:int, offset:int): return DiskBuffer(buf.device, size, offset)
+
+class DiskDevice(Compiled):
+  def __init__(self, device:str):
+    self.size: Optional[int] = None
+    self.count = 0
+    super().__init__(device, DiskAllocator(self), None, None, None)
+  def _might_open(self, size):
+    self.count += 1
+    assert self.size is None or size <= self.size, f"can't reopen Disk tensor with larger size, opened with {self.size}, tried to open with {size}"
+    if self.size is not None: return
+    filename = self.dname[len("disk:"):]
+    self.size = size

+    if filename.startswith("shm:"):
+      fd = _posixshmem.shm_open("/"+filename[4:].lstrip("/"), os.O_RDWR, 0o600)
+      self.mem = mmap.mmap(fd, self.size, mmap.MAP_SHARED | MAP_POPULATE | MAP_LOCKED)
+      os.close(fd)
+    else:
+      try: self.fd = os.open(filename, os.O_RDWR|os.O_CREAT|(0 if OSX else os.O_DIRECT))
+      except OSError: self.fd = os.open(filename, os.O_RDWR|os.O_CREAT)
+      if os.fstat(self.fd).st_size < self.size: os.ftruncate(self.fd, self.size)
+      self.mem = mmap.mmap(self.fd, self.size)
+    if (hp := getattr(mmap, "MADV_HUGEPAGE", None)) is not None: self.mem.madvise(hp) # type: ignore
+  def _might_close(self):
+    self.count -= 1
+    if self.count == 0:
+      if hasattr(self, 'fd'): os.close(self.fd)
+      self.size = None
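The disk backend drops the old RawBufferMapped/Interpreted machinery: DiskDevice owns a single mmap of the file (or a POSIX shared-memory segment for "disk:shm:..." devices), and DiskBuffer is just an offset/size view into it, so copyin/copyout move bytes straight through the mapping. This is the path tinygrad.nn.state uses to map checkpoints without reading them into RAM first. A small sketch of the user-visible behaviour, assuming a hypothetical /tmp/tg_demo.bin file and the default device as the copy target:

import struct, pathlib
from tinygrad import Tensor, Device, dtypes

# hypothetical file holding four little-endian float32 values
pathlib.Path("/tmp/tg_demo.bin").write_bytes(struct.pack("<4f", 1.0, 2.0, 3.0, 4.0))

t = Tensor.empty(4, dtype=dtypes.float32, device="disk:/tmp/tg_demo.bin")  # DiskDevice mmaps the file on first alloc
print(t.to(Device.DEFAULT).numpy())   # copyout reads straight out of the mmap -> [1. 2. 3. 4.]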