tinygrad 0.7.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
 - tinygrad/__init__.py +6 -0
 - tinygrad/codegen/kernel.py +572 -83
 - tinygrad/codegen/linearizer.py +415 -395
 - tinygrad/codegen/uops.py +415 -0
 - tinygrad/device.py +183 -0
 - tinygrad/dtype.py +113 -0
 - tinygrad/engine/__init__.py +0 -0
 - tinygrad/engine/graph.py +100 -0
 - tinygrad/engine/jit.py +195 -0
 - tinygrad/engine/realize.py +191 -0
 - tinygrad/engine/schedule.py +362 -0
 - tinygrad/engine/search.py +196 -0
 - tinygrad/{mlops.py → function.py} +76 -55
 - tinygrad/helpers.py +196 -89
 - tinygrad/lazy.py +210 -371
 - tinygrad/multi.py +169 -0
 - tinygrad/nn/__init__.py +202 -22
 - tinygrad/nn/datasets.py +7 -0
 - tinygrad/nn/optim.py +112 -32
 - tinygrad/nn/state.py +136 -39
 - tinygrad/ops.py +119 -202
 - tinygrad/renderer/__init__.py +61 -0
 - tinygrad/renderer/assembly.py +276 -0
 - tinygrad/renderer/cstyle.py +353 -166
 - tinygrad/renderer/llvmir.py +150 -138
 - tinygrad/runtime/autogen/amd_gpu.py +1900 -0
 - tinygrad/runtime/autogen/comgr.py +865 -0
 - tinygrad/runtime/autogen/cuda.py +5923 -0
 - tinygrad/runtime/autogen/hip.py +5909 -0
 - tinygrad/runtime/autogen/hsa.py +5761 -0
 - tinygrad/runtime/autogen/kfd.py +812 -0
 - tinygrad/runtime/autogen/nv_gpu.py +33328 -0
 - tinygrad/runtime/autogen/opencl.py +1795 -0
 - tinygrad/runtime/driver/hip_comgr.py +47 -0
 - tinygrad/runtime/driver/hsa.py +143 -0
 - tinygrad/runtime/graph/clang.py +38 -0
 - tinygrad/runtime/graph/cuda.py +81 -0
 - tinygrad/runtime/graph/hcq.py +143 -0
 - tinygrad/runtime/graph/hsa.py +171 -0
 - tinygrad/runtime/graph/metal.py +75 -0
 - tinygrad/runtime/ops_amd.py +564 -0
 - tinygrad/runtime/ops_clang.py +24 -77
 - tinygrad/runtime/ops_cuda.py +175 -89
 - tinygrad/runtime/ops_disk.py +56 -33
 - tinygrad/runtime/ops_gpu.py +92 -95
 - tinygrad/runtime/ops_hsa.py +278 -0
 - tinygrad/runtime/ops_llvm.py +39 -60
 - tinygrad/runtime/ops_metal.py +92 -74
 - tinygrad/runtime/ops_npy.py +9 -0
 - tinygrad/runtime/ops_nv.py +630 -0
 - tinygrad/runtime/ops_python.py +204 -0
 - tinygrad/shape/shapetracker.py +86 -254
 - tinygrad/shape/symbolic.py +166 -141
 - tinygrad/shape/view.py +296 -0
 - tinygrad/tensor.py +2619 -448
 - {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
 - tinygrad-0.9.0.dist-info/METADATA +227 -0
 - tinygrad-0.9.0.dist-info/RECORD +60 -0
 - {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
 - tinygrad/codegen/assembly.py +0 -190
 - tinygrad/codegen/optimizer.py +0 -379
 - tinygrad/codegen/search.py +0 -72
 - tinygrad/graph.py +0 -83
 - tinygrad/jit.py +0 -57
 - tinygrad/nn/image.py +0 -100
 - tinygrad/renderer/assembly_arm64.py +0 -169
 - tinygrad/renderer/assembly_ptx.py +0 -98
 - tinygrad/renderer/wgsl.py +0 -53
 - tinygrad/runtime/lib.py +0 -113
 - tinygrad/runtime/ops_cpu.py +0 -51
 - tinygrad/runtime/ops_hip.py +0 -82
 - tinygrad/runtime/ops_shm.py +0 -29
 - tinygrad/runtime/ops_torch.py +0 -30
 - tinygrad/runtime/ops_webgpu.py +0 -45
 - tinygrad-0.7.0.dist-info/METADATA +0 -212
 - tinygrad-0.7.0.dist-info/RECORD +0 -40
 - {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
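
Most of the churn above is a repo-level restructuring: the scheduler, JIT, graphing, and kernel search moved into the new tinygrad/engine/ package, dtypes split out of helpers.py into dtype.py, and mlops.py was renamed to function.py. A minimal before/after sketch of the import changes a user upgrading from 0.7.0 would make, assuming the new top-level re-exports in tinygrad/__init__.py (+6 lines above) include Tensor, TinyJit, and dtypes:

# 0.7.0-style imports (modules deleted or moved in this diff)
#   from tinygrad.tensor import Tensor
#   from tinygrad.jit import TinyJit        # tinygrad/jit.py -> tinygrad/engine/jit.py
#   from tinygrad.helpers import dtypes     # dtypes moved to tinygrad/dtype.py

# 0.9.0-style imports
from tinygrad import Tensor, TinyJit, dtypes

@TinyJit
def step(x: Tensor) -> Tensor:
  return (x * 2).relu().sum().realize()

print(step(Tensor.ones(4, 4, dtype=dtypes.float32)).item())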
 
    
tinygrad/nn/image.py
DELETED

@@ -1,100 +0,0 @@
-import numpy as np
-from tinygrad.helpers import prod, IMAGE, ImageDType, getenv, dtypes
-from tinygrad.lazy import get_single_root
-
-FLOAT16 = getenv("FLOAT16", 0)
-base_image_type = (100, 2, "imageh", np.float16) if FLOAT16 else (100, 4, "imagef", np.float32)
-
-def image_dot(self, w):
-  # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
-  n1, n2 = len(self.shape), len(w.shape)
-  assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
-  assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})"
-  bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2])
-  cin, cout = w.shape[-2], w.shape[-1]
-  out_shape_t = self.shape[0:-2] + (cout,-1)
-  if len(self.shape) > 1:
-    order = tuple(range(len(self.shape)-2)) + (len(self.shape)-1, len(self.shape)-2)
-  else:
-    order, out_shape_t = (0,), (cout, )
-  worder = tuple(range(len(w.shape)-2)) + (len(w.shape)-1, len(w.shape)-2)
-
-  # NOTE: with NHWC we can remove the transposes
-  # bs x groups*cin x H x W
-  cx = self.permute(order=order).reshape(shape=(bs//groups, groups*cin, -1, 1))
-  # groups*cout x cin x H, W
-  cw = w.permute(order=worder).reshape(shape=(groups*cout, cin, 1, 1))
-  return cx.conv2d(cw, groups=groups).reshape(shape=out_shape_t).permute(order=order)
-
-def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, padding=0):
-  (bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
-  rcout = cout//groups
-  x, w = self, weight.reshape(groups, rcout, cin, H, W)
-
-  # hack for non multiples of 4 on cin
-  if cin % 4 != 0 and not (cin == 1 and groups%4 == 0):
-    x = x.reshape(bs, groups, cin, iy, ix)   # do this always?
-    added_input_channels = 4 - (cin % 4)
-    w = w.pad(tuple((0, added_input_channels) if i == 2 else (0, 0) for i in range(len(w.shape))))
-    x = x.pad(tuple((0, added_input_channels) if i == 2 else (0, 0) for i in range(len(x.shape))))
-    cin = cin + added_input_channels
-    x = x.reshape(bs, groups*cin, iy, ix)
-
-  # hack for non multiples of 4 on rcout
-  added_output_channels = 0
-  if rcout % 4 != 0 and not (rcout == 1 and groups%4 == 0):
-    added_output_channels = 4 - (rcout % 4)
-    rcout += added_output_channels
-    cout = groups * rcout
-    w = w.slice(tuple((0, rcout) if i == 1 else (0, w.shape[i]) for i in range(len(w.shape))))
-
-  # packed (note: flipping bs and iy would make the auto-padding work)
-  x = x.permute(0,2,3,1).reshape(bs * iy, ix * groups * cin//4, 4)
-  cin_last = iy == 1 and ix == 1
-  if cin == 1: w = w.reshape(cout//4,4,H*W).permute(0,2,1)
-  elif cin_last: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,1,3).reshape(cout//4, H*cin//4*W*4, 4)
-  else: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,3,1).reshape(cout//4, H*cin//4*W*4, 4)
-
-  # contiguous creates the image, and early realize static weights (TODO: test for the static weight)
-  if IMAGE >= 2: x,w = x.cast(ImageDType(*base_image_type, shape=x.shape)), w.cast(ImageDType(*base_image_type, shape=w.shape))
-  x, w = x.contiguous(), w.contiguous()
-  if get_single_root(w.lazydata).realized: w.realize()
-
-  # expand out
-  rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1
-  cout_expand = [groups//4 if cin == 1 else groups, 4 if cin == 1 else 1, rcout//4 if rcout >= 4 else 1, 4 if rcout >= 4 else 1]
-  x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo)
-  if cin_last: w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo)
-  else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4)
-
-  # padding
-  padding_ = [padding]*4 if isinstance(padding, int) else (padding if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]])
-  x = x.slice((None, (-padding_[2], x.shape[1]+padding_[3]), (-padding_[0], x.shape[2]+padding_[1]), None, None, None))
-
-  # prepare input
-  x = x.permute(0,3,4,5,1,2)._pool((H, W), stride, dilation) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
-  oy, ox = x.shape[4:6]
-  x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, oy, ox, *cout_expand[0:2], 1, 1, rcin_hi, rcin_lo, H, W)
-  x = x.expand(bs, oy, ox, *cout_expand, rcin_hi, rcin_lo, H, W)
-
-  # prepare weights
-  w = w.permute(0,4,2,5,1,3)
-  w = w.reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W)).expand(x.shape)
-
-  # the conv! (+ the bias)
-  ret = (x*w).cast(dtypes.float32).sum((-4, -3, -2, -1))
-
-  # reshape to image and cast back to image
-  ret = ret.reshape(bs*oy, ox*cout//4, 4)
-  if IMAGE >= 2: ret = ret.cast(ImageDType(*base_image_type, shape=ret.shape))
-  if IMAGE >= 3: ret = ret.contiguous()
-
-  # undo hack for non multiples of 4 on C.rcout
-  if added_output_channels != 0:
-    ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :, :, :-added_output_channels]
-    rcout -= added_output_channels
-    cout = groups * rcout
-
-  # NCHW output
-  ret = ret.reshape(bs, oy, ox, cout).permute(0,3,1,2)
-  return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
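
The NOTE at the top of image_dot is the whole trick: an m×k by k×n matmul is re-expressed as a 1×1 conv2d so it can run through the image pipeline. A quick NumPy check of that reshaping (einsum stands in for the 1×1 conv here; nothing in this sketch is tinygrad API):

import numpy as np

m, k, n = 3, 5, 2
A, B = np.random.randn(m, k), np.random.randn(k, n)

x = A.T.reshape(1, k, m, 1)    # input  (1, cin=k, H=m, W=1)
w = B.T.reshape(n, k, 1, 1)    # weight (cout=n, cin=k, kH=1, kW=1)
# a 1x1 conv is a per-pixel matmul over the channel axis
out = np.einsum('bchw,oc->bohw', x, w[:, :, 0, 0])
assert np.allclose(out[0, :, :, 0].T, A @ B)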

tinygrad/renderer/assembly_arm64.py
DELETED

@@ -1,169 +0,0 @@
-import struct
-from platform import system
-from typing import Tuple, Dict, List, Optional
-from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
-from tinygrad.codegen.linearizer import UOps, UOp
-from tinygrad.helpers import dtypes, CI
-from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
-
-def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
-def compute_offsets(total):
-  quotient, remainder = divmod(total, 4096)
-  return [4096]*quotient + [remainder] if remainder else [4096]*quotient
-
-#NOTE: Darwin needs names to start with a "_"
-def get_name(name): return ('_' if system() == 'Darwin' else '') + name
-
-class ARM64Language(AssemblyLanguage): pass
-
-def specialize_to_arm64(fn_nm, asm):
-  var_size = 16
-  prev_uop:Optional[UOps] = None
-  ins = []
-  x_regs = ['x' + str(i) for i in reversed(range(12))]
-  s_regs = ['s' + str(i) for i in reversed(range(3,32)) if i <= 7 or i >= 16]
-  type_to_reg = {dtypes.double: "d", dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'w', dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'}
-  alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
-         BinaryOps.MOD: "", BinaryOps.CMPLT: "subs",
-         UnaryOps.SIN:'bl ' + get_name('sinf'), UnaryOps.LOG2: 'bl ' + get_name("log2f"), UnaryOps.EXP2: 'bl ' + get_name("exp2f"), UnaryOps.SQRT: 'bl ' + get_name("sqrtf"),
-         TernaryOps.MULACC: "madd", TernaryOps.WHERE: "fcsel"}
-
-  def mov_imm(value, reg):
-    # Manually move value into reg if value can't fit
-    if value.__class__ is not float and abs(value) > abs(65535):
-      ins.append(f"movz w15, #{value & 0xffff}")
-      ins.append(f"movk w15, #{(value >> 16) & 0xffff}, lsl #16")
-      ins.append(f"sxtw {reg}, w15")
-    elif reg[0] == 's':
-      ins.append(f"movz x15, 0x{float_to_hex(value)[4:]}")
-      ins.append(f"movk x15, 0x{float_to_hex(value)[:4]}, lsl #16")
-      ins.append("str x15, [sp, 16]")
-      ins.append(f"ldr {reg}, [sp, 16]")
-    else:
-      ins.append(f"mov {reg}, #{value}")
-
-  # Get variables intervals
-  live_range:Dict[str, List[int]] = {}
-  for i, (uop, out, vin, arg) in enumerate(asm):
-    for var in ([v for v in [out] + vin if v is not None and v.__class__ is not int]):
-      live_range[var.nm] = [i,i] if var.nm not in live_range else [live_range[var.nm][0], i]
-
-  mem_vars:Dict[str, int] = {}
-  rtor:Dict[str, str] = {}
-  def allocate_regs(mvars):
-    nonlocal var_size
-    for v in [v for v in mvars if v is not None and v.__class__ is not int and v.nm not in rtor]:
-      available_regs = s_regs if dtypes.is_float(v[1]) else x_regs
-      #NOTE: Very simple spill, everything that don't fit in regs goes to mem
-      if not available_regs:
-        # ARM needs the stack 16-byte aligned
-        var_size += 16
-        available_regs.append('s0' if dtypes.is_float(out[1]) else 'x12')
-        mem_vars[v.nm] = var_size
-      rtor[v.nm] = available_regs.pop()
-
-  temp_floats = ['s0', 's1', 's2']
-  temp_ints = ['x12', 'x13', 'x16']
-  for i, (uop, out, vin, arg) in enumerate(asm):
-    # Clear regs out of interval
-    for var, reg in list(rtor.items()):
-      available_regs = s_regs if reg[0] == 's' else x_regs
-      if var[1] not in 'B' and var not in mem_vars and i > live_range[var][1]:
-        available_regs.append(rtor.pop(var))
-    # Assign a registers to the variables using live ranges.
-    allocate_regs([out] + vin)
-    # Assign temp regs to vin and load them before direct use
-    for i, v in enumerate([v for v in vin if v.__class__ is not int and v.nm in mem_vars]):
-      rtor[v.nm] = temp_floats[i] if dtypes.is_float(v[1]) else temp_ints[i]
-      # ARM64 addressing constraints https://devblogs.microsoft.com/oldnewthing/20220728-00/?p=106912
-      ins.append(f"mov x15, {mem_vars[v.nm]}")
-      ins.append(f"ldr {rtor[v.nm]}, [sp, x15]")
-
-    if uop == UOps.SPECIAL:
-      if arg.startswith('data'):
-        # data 8 to n into the stack
-        if int(arg[4:]) >= 8:
-          ins.append(f"ldr x15, [x17, #{(int(arg[4:]) - 8) * 8}]")
-          ins.append(f"mov {rtor[out.nm]}, x15")
-      else:
-        ins.append(f"mov {rtor[out.nm]}, #0")
-        ins.append(f"loop_{arg}:")
-    elif uop == UOps.CAST:
-      if arg == BinaryOps.CMPLT:
-        mov_imm(0.0, 's0')
-        mov_imm(1.0, 's1')
-        ins.append(f"fcsel {rtor[out.nm]}, s1, s0, lt")
-      else:
-        ins.append(f"sxtw {rtor[out.nm]}, w{rtor[vin[0].nm][1:]}")
-    elif uop == UOps.ALU:
-      if len(vin)==2 and vin[1].__class__ is int: mov_imm(vin[1], 'x15')
-      if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
-        ins.append(f"ands {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
-      elif arg == TernaryOps.WHERE:
-        ins.append(f"fcmp {rtor[vin[0].nm]}, #0.0")
-        ins.append(f"{alu[arg]} {rtor[out.nm]}, {rtor[vin[1].nm]}, {rtor[vin[2].nm]}, ne")
-      elif arg in [UnaryOps.LOG2, UnaryOps.SIN, UnaryOps.EXP2, UnaryOps.SQRT]:
-        #NOTE: Not a real instruction, use to emulate a ext call in unicorn
-        if CI: ins.append(f"{alu[arg]} {rtor[out.nm]} {rtor[vin[0].nm]}")
-        else:
-          save_regs = [k for k in rtor.keys() if k != out.nm and k not in mem_vars]
-          ins.append(f"sub sp, sp, #{(len(save_regs))*16}")
-          # Save the registers before they are cleared by func call
-          for i,k in enumerate(save_regs,1):
-            ins.append(f"str {rtor[k]}, [sp, #{16*i}]")
-          ins.append("stp x29, x30, [sp, #0]!")
-          ins.append("mov x29, sp")
-          ins.append(f"fmov s0, {rtor[vin[0].nm]}")
-          ins.append(alu[arg])
-          ins.append(f"fmov {rtor[out.nm]}, s0")
-          ins.append("mov sp, x29")
-          ins.append("ldp x29, x30, [sp], #0")
-          for i,k in enumerate(save_regs,1):
-            ins.append(f"ldr {rtor[k]}, [sp, #{16*i}]")
-          ins.append(f"add sp, sp, #{len(save_regs)*16}")
-      elif arg == BinaryOps.CMPLT:
-        ins.append(f"{alu[arg]} {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}" if not dtypes.is_float(vin[0][1]) else f"fcmp {rtor[vin[0].nm]}, {rtor[vin[1].nm]}")
-      elif arg == BinaryOps.MOD:
-        ins.append(f"udiv x14, {rtor[vin[0].nm]}, x15")
-        ins.append(f"msub {rtor[out.nm]}, x14, x15, {rtor[vin[0].nm]}")
-      else:
-        ins.append(f"{'f' if dtypes.is_float(vin[0][1]) else 's' if arg == BinaryOps.DIV else ''}{alu[arg]} {', '.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
-    elif uop == UOps.LOAD:
-      if arg.__class__ in (int, float):
-        mov_imm(arg, rtor[out.nm])
-      else:
-        #NOTE: if need casting load var in s/h0 or x/w12 temp regs
-        reg_in = type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[out.nm]
-        mov_imm(arg[0], "x15")
-        ins.append(f"add x15, {rtor[vin[0].nm]}, x15")
-        ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]")
-        if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] in [dtypes.half, dtypes.double] else 'scvtf'} {rtor[out.nm]}, {reg_in}")
-    elif uop == UOps.STORE:
-      #NOTE: if need casting load var in s/h0 or x/w12 temp regs
-      reg_out = (type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[vin[1].nm])
-      if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] not in [dtypes.half, dtypes.double] else '' } {reg_out}, {rtor[vin[1].nm]}")
-      ins.append(f"mov x15, #{arg[0]}")
-      ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl #0]")
-    elif uop == UOps.COND_BRANCH:
-      #TODO: this is a hack it shouldn't always be a cmp before a cond branch?
-      if prev_uop == UOps.LOAD:
-        ins.append(f"cmp {rtor[vin[0].nm]}, #0")
-      ins.append(f"b.{'lt' if arg[1] else 'ge'} {arg[0][1:]}")
-    elif uop == UOps.LABEL:
-      ins.append(f"{arg[1:]}:")
-    elif uop == UOps.ENDLOOP:
-      mov_imm(arg[0], "x15")
-      ins.append(f"add {rtor[vin[0].nm]}, {rtor[vin[0].nm]}, #1")
-      ins.append(f"cmp {rtor[vin[0].nm]}, x15")
-      ins.append(f"b.lt loop_{arg[1]}")
-    prev_uop = uop
-    # store regs into memory if needed
-    if out is not None and out.nm in mem_vars:
-      ins.append(f"mov x15, {mem_vars[out.nm]}")
-      ins.append(f"str {rtor[out.nm]}, [sp, x15]")
-  return "\n".join([f"//varsize {var_size}",".arch armv8-a",".text", f".global {get_name(fn_nm)}",".p2align 2", f"{get_name(fn_nm)}:", "mov x17, sp"] + [f"sub sp, sp, #{offset}" for offset in compute_offsets(var_size)]+ ins + [f"add sp, sp, #{offset}" for offset in compute_offsets(var_size)] +["ret", "\n"])
-
-def uops_to_arm64_asm(fn_nm:str, uops:List[UOp]) -> Tuple[str, List[int], List[int], bool]:
-  lang = ARM64Language()
-  global_size, local_size = uops_to_asmstyle(lang, fn_nm, uops)
-  return specialize_to_arm64(fn_nm, lang.ins), global_size[::-1], local_size[::-1], True
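
Two helpers in this deleted ARM64 backend are worth a closer look: compute_offsets chunks stack adjustments so each emitted sub sp, sp, #imm fits the AArch64 add/sub immediate encoding (12 bits, optionally shifted left by 12), and float_to_hex produces the hex string that mov_imm splits into movz/movk halves to materialize float constants. Both copied verbatim from the hunk above, with example outputs:

import struct

def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f", x)[::-1])

def compute_offsets(total):
  quotient, remainder = divmod(total, 4096)
  return [4096]*quotient + [remainder] if remainder else [4096]*quotient

print(float_to_hex(1.0))       # 3F800000 -> movz x15, 0x0000; movk x15, 0x3F80, lsl #16
print(compute_offsets(10000))  # [4096, 4096, 1808] -> three sub sp, sp, #... instructions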

tinygrad/renderer/assembly_ptx.py
DELETED

@@ -1,98 +0,0 @@
-from typing import List
-import struct
-from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
-from tinygrad.codegen.linearizer import UOps, UOp
-from tinygrad.helpers import dtypes
-from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
-from tinygrad.runtime.ops_cuda import arch
-
-dtype_to_nvtype = {dtypes.float32: "f32", dtypes.float16: "f16", dtypes.int64: "s64", dtypes.int32: "s32", dtypes.int8: "s8", dtypes.bool: "pred", dtypes.uint64: "u64", dtypes.uint32: "u32", dtypes.uint16: "u16", dtypes.uint8: "u8", "bits16": "b16", dtypes.float64: "f64"}
-def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
-
-def ptx_needs_cast(dest_dtype, src_dtype): return dtypes.is_float(dest_dtype) and dtypes.is_int(src_dtype) or dtypes.is_int(dest_dtype) and dtypes.is_float(src_dtype) or (dtypes.is_float(src_dtype) and dtypes.is_float(dest_dtype) and dest_dtype.itemsize != src_dtype.itemsize)
-
-def render_cast(ins, inp, out):
-  if inp.dtype == dtypes.bool and (dtypes.is_float(out.dtype) or dtypes.is_int(out.dtype)):
-    ins.append(f"selp.{dtype_to_nvtype[out.dtype]} {out}, {'0f3F800000, 0f00000000' if dtypes.is_float(out.dtype) else '1, 0'}, {inp};")
-  elif out.dtype == dtypes.bool:
-    ins.append(f"setp.ne.{dtype_to_nvtype[inp.dtype]} {out}, {'0f00000000' if dtypes.is_float(inp.dtype) else '0'}, {inp};")
-  else:
-    round_mod = ".rzi" if dtypes.is_int(out.dtype) and dtypes.is_float(inp.dtype) else '.rz' if dtypes.is_float(out.dtype) and (dtypes.is_int(inp.dtype) or dtypes.is_float(inp.dtype) and inp.dtype.itemsize > out.dtype.itemsize) else ''
-    ins.append(f"cvt{round_mod}.{dtype_to_nvtype[out.dtype]}.{dtype_to_nvtype[inp.dtype]} {out}, {inp};")
-
-# https://docs.nvidia.com/cuda/parallel-thread-execution/#
-
-class PTXLanguage(AssemblyLanguage):
-  supports_constant_folding: bool = True
-
-def specialize_to_ptx(lang, function_name):
-  param_cnt = 0
-  ins = []
-  alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
-         BinaryOps.MOD: "rem", BinaryOps.CMPLT: "setp.lt", UnaryOps.SQRT: "sqrt.approx",
-         UnaryOps.NOOP: "mov", UnaryOps.SIN: "sin.approx", UnaryOps.LOG2: "lg2.approx", UnaryOps.EXP2: "ex2.approx.ftz",
-         TernaryOps.MULACC: "fma.rn", TernaryOps.WHERE: "selp"}
-  for uop, out, vin, arg in lang.ins:
-    if uop == UOps.ENDLOOP:
-      ins.append("bar.sync 0;")
-    elif uop == UOps.DEFINE_LOCAL:
-      ins.append(f".shared .align 4 .b8 {arg[0]}[{arg[1]*4}];")
-    elif uop == UOps.SPECIAL:
-      if arg.startswith('data'):
-        param_cnt += 1
-        ins.append(f"ld.param.u64 {out}, [{arg}];")
-        # TODO: we sometimes want this to be local, nvcc converts to global most of the time, not sure when we would need to?
-        # ins.append(f"cvta.to.global.u64 {out}, {out};")
-      elif arg.startswith('gid'):
-        ins.append(f"mov.u32 {out}, %ctaid.{'xyz'[int(arg[3:])]};")
-      elif arg.startswith('lid'):
-        ins.append(f"mov.u32 {out}, %tid.{'xyz'[int(arg[3:])]};")
-    elif uop == UOps.ALU:
-      if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
-        ins.append(f"and.pred {out}, {', '.join(str(x) for x in vin)};")
-      else:
-        otype = vin[0].dtype if arg in [BinaryOps.CMPLT] else out.dtype
-        if arg == TernaryOps.WHERE:
-          reg = lang.newreg((vin[0], 'bool'), dtypes.bool)
-          ins.append(f"setp.ne.{dtype_to_nvtype[vin[0].dtype]} {reg}, {'0f00000000' if dtypes.is_float(vin[0].dtype) else '0'}, {vin[0]};")
-          vin = vin[1:] + [reg]
-        ins.append(f"{alu[arg]}{'.lo' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 else ''}{'.rn' if arg == BinaryOps.DIV and out.dtype == dtypes.float32 else ''}.{dtype_to_nvtype[otype]} {out}, {', '.join(str(x) for x in vin)};")
-    elif uop == UOps.LOAD:
-      if arg.__class__ in (int, float):
-        ins.append(f"mov.{dtype_to_nvtype[out.dtype]} {out}, {'0f'+float_to_hex(arg) if dtypes.is_float(out.dtype) else int(arg)};")
-      elif arg[2] is not None and (arg[2] == dtypes.bool or arg[2] != out.dtype):
-        dt = ('u16', dtypes.uint16) if arg[2] == dtypes.bool == out.dtype else ('u8', dtypes.uint8) if arg[2] == dtypes.bool else ('b16', dtypes.float16) if arg[2] == dtypes.half else (dtype_to_nvtype[arg[2]], arg[2])
-        reg = lang.newreg((out, dt[0]), dtype=dt[1])
-        ins.append(f"ld.{arg[1]}.{dt[0]} {reg}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];")
-        render_cast(ins, reg, out)
-      else:
-        ins.append(f"ld.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} {out}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];")
-    elif uop == UOps.STORE:
-      if ptx_needs_cast(dtypes.float if arg[2] is None else arg[2], vin[1].dtype) or arg[2] == dtypes.bool:
-        if arg[2] == dtypes.bool != vin[1].dtype:
-          prereg = lang.newreg((vin[1],'bool'), dtype=dtypes.bool)
-          render_cast(ins, vin[1], prereg)
-        else: prereg = vin[1]
-        reg = lang.newreg((prereg, dtypes.uint16 if arg[2] == dtypes.bool else arg[2]), dtype=dtypes.uint16 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2])
-        render_cast(ins, prereg, reg)
-        ins.append(f"st.{arg[1]}.{dtype_to_nvtype['bits16' if arg[2] == dtypes.float16 else dtypes.uint8 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {reg};")
-      else:
-        ins.append(f"st.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {vin[1]};")
-    elif uop == UOps.CAST:
-      render_cast(ins, vin[0], out)
-    elif uop == UOps.LABEL:
-      ins.append(f"{arg}:")
-    elif uop == UOps.COND_BRANCH:
-      ins.append(f"@{'!' if not arg[1] else ''}{vin[0]} bra {arg[0]};")
-
-  ins_prefix = [".version 7.8", ".target " + arch(), ".address_size 64",
-                f".visible .entry {function_name}({', '.join(f'.param .u64 data{i}' for i in range(param_cnt))}) {{"]
-  for arg in [(dtype, lang.type_to_letter(dtype), c) for dtype,c in lang.cnts.items()]: ins_prefix.append(f".reg .{dtype_to_nvtype[arg[0][0]]} %{arg[1]}<{arg[2]}>;",)
-  ins = ins_prefix + ins
-  ins += ["ret;", "}"]
-  return '\n'.join(ins)
-
-def uops_to_ptx_asm(function_name:str, uops:List[UOp]):
-  lang = PTXLanguage()
-  global_size, local_size = uops_to_asmstyle(lang, function_name, uops)
-  return specialize_to_ptx(lang, function_name), global_size[::-1], local_size[::-1], True
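
render_cast drives most of the dtype handling above; for instance, a float32-to-int32 cast takes the final branch and picks the .rzi (round-toward-zero, integer) modifier. A standalone sketch with hypothetical stand-in registers (the real ones come from AssemblyLanguage.newreg):

from dataclasses import dataclass

@dataclass
class Reg:            # hypothetical stand-in for a PTX register value
  name: str
  nvtype: str         # already mapped through dtype_to_nvtype above
  def __str__(self): return self.name

ins = []
inp, out = Reg("%f0", "f32"), Reg("%r0", "s32")
# float -> int takes round_mod = ".rzi" in render_cast:
ins.append(f"cvt.rzi.{out.nvtype}.{inp.nvtype} {out}, {inp};")
print(ins[0])         # cvt.rzi.s32.f32 %r0, %f0;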
    
tinygrad/renderer/wgsl.py
DELETED
@@ -1,53 +0,0 @@
-from tinygrad.renderer.cstyle import render_cl
-from tinygrad.helpers import dtypes, DType
-from tinygrad.renderer.cstyle import CStyleLanguage
-from typing import List, Union
-from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
-import math
-from typing import Tuple
-
-type_map = {dtypes.float: "f32", dtypes.half: "f16", dtypes.int32: "i32", dtypes.uint32: "u32", dtypes.bool: "bool"}
-class WGSLLanguage(CStyleLanguage):
-  gid = [f"i32(gindex.{'xyz'[x]})" for x in range(3)]
-  lid = [f"i32(lindex.{'xyz'[x]})" for x in range(3)]
-  size_prefix = "let"
-  barrier="workgroupBarrier();"
-  generic_var_prefix = "var "
-  external_local_bufs = True
-  code_for_op = {
-    UnaryOps.EXP2: lambda x: f"exp2({x})", UnaryOps.LOG2: lambda x: f"log2({x})", UnaryOps.SIN: lambda x: f"sin({x})", UnaryOps.SQRT: lambda x: f"sqrt({x})",
-    BinaryOps.ADD: lambda x,y: f"({x}+{y})", BinaryOps.SUB: lambda x,y: f"({x}-{y})", BinaryOps.MUL: lambda x,y: f"({x}*{y})", BinaryOps.DIV: lambda x,y: f"({x}/{y})",
-    BinaryOps.MAX: lambda x,y: f"max({x},{y})", BinaryOps.CMPLT: lambda x,y: f"f32({x}<{y})",
-    TernaryOps.MULACC: lambda x,y,z: f"fma({x},{y},{z})", TernaryOps.WHERE: lambda a,b,c: f"select({c},{b},{a}!=0.)"
-  }
-
-  def render_local(self, name: str, size: int):
-    return f"var<workgroup> {name}: array<f32,{size}>;"
-
-  def render_const(self, x:Union[float,int], var_dtype) -> str:
-    if math.isnan(x): val = "nan()"
-    elif math.isinf(x): val = ("-" if x < 0 else "") + "0x1.fffffep+127f"
-    else: val = f"{x}" + ("" if dtypes.is_int(var_dtype) else "f")
-    return self.render_cast([val]*var_dtype.sz, var_dtype) if var_dtype.sz > 1 else val
-
-  def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], global_size:List[int], local_size:List[int], prekernel:List[str]) -> Tuple[str, List[int], List[int]]:
-    local_size = local_size[::-1] if local_size else [1]
-    bind_it = iter(range(len(bufs)))
-    prg = "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }\n"
-    prg += "\n".join(prekernel+[f"@group(0) @binding({next(bind_it)}) var<storage,read_write> {name}: array<{type_map[dtype]}>;" for name,dtype in bufs])
-    prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>, @builtin(local_invocation_id) lindex: vec3<u32>) {{\n" + "\n".join(kernel) + "\n}"
-    return prg, global_size[::-1] if global_size else [1], local_size
-
-  def render_for(self, expr:str, _min:int, _max:Union[int,str]) -> str:
-    return f"for(var {expr} = {_min}; {expr} <= {_max}; {expr}++) {{"
-
-  def render_conditional(self, cond:str, x:str, y:str) -> str:
-    return f"select(f32({y}), {x}, bool({cond}))"
-
-  def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
-    return f"f32({super().render_load(output_dtype, buf_name, buf_dtype, idx, local)})"
-
-  def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx, local=False) -> str:
-    if buf_dtype != var_dtype:
-      var_name = f"{type_map[buf_dtype]}({var_name})"
-    return f"{buf_name}[{idx.render(render_cl)}] = {var_name};"
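The deleted WGSL backend above is essentially a string-templating layer: each tinygrad op maps to a lambda in code_for_op that splices already-rendered operand strings into a WGSL expression. A minimal, self-contained sketch of that dispatch pattern follows; the string op tags and the render helper are illustrative stand-ins for tinygrad's UnaryOps/BinaryOps/TernaryOps enums and the CStyleLanguage machinery, not code from the diff.

# Sketch of the code_for_op dispatch used by the deleted WGSLLanguage.
# Op tags are plain strings here, a simplification of tinygrad's op enums.
code_for_op = {
  "EXP2":  lambda x: f"exp2({x})",
  "ADD":   lambda x, y: f"({x}+{y})",
  "MAX":   lambda x, y: f"max({x},{y})",
  "WHERE": lambda a, b, c: f"select({c},{b},{a}!=0.)",  # WGSL select(false_val, true_val, cond)
}

def render(expr):
  # expr is a leaf (an already-rendered WGSL string) or a tuple (op, *operands)
  if isinstance(expr, str): return expr
  op, *args = expr
  return code_for_op[op](*(render(a) for a in args))

print(render(("WHERE", ("MAX", "x", "y"), "1.0", "0.0")))  # -> select(0.0,1.0,max(x,y)!=0.)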
    
tinygrad/runtime/lib.py
DELETED
@@ -1,113 +0,0 @@
-import ctypes
-import numpy as np
-from collections import defaultdict, deque
-from typing import TypeVar, Type, Any, Dict, Deque, Tuple
-from tinygrad.helpers import DType, dtypes, prod, GlobalCounters, ImageDType
-
-_T = TypeVar("_T")
-class RawBuffer:  # pylint: disable=abstract-method
-  def __init__(self, size:int, dtype:DType, buf:Any=None, allocator:Any=None, **kwargs):
-    self.size: int = size
-    self.dtype: DType = dtype
-    self._buf = buf if buf is not None else (allocator.alloc(size, dtype, **kwargs) if allocator else None) # If buf is provided, use it. Otherwise try to allocate from the allocator.
-    self._memsz: int = size*dtype.itemsize
-    self._allocator = allocator
-    GlobalCounters.mem_used += self._memsz
-  def __del__(self):  # NOTE: if it fails on init (bad dtype), it won't have a _memsz
-    if hasattr(self, '_memsz'): GlobalCounters.mem_used -= self._memsz
-    if hasattr(self, '_allocator') and self._allocator: self._allocator.free(self._buf)
-  def __repr__(self): return f"buffer<{self.size}, {self.dtype}>"
-  @property
-  def key(self): return (self.size, self.dtype)
-
-  # NOTE: this interface allows for 0 copy
-  @classmethod
-  def fromCPU(cls:Type[_T], x:np.ndarray) -> _T: raise NotImplementedError("must be implemented")
-  def toCPU(self) -> np.ndarray: raise NotImplementedError("must be implemented")
-
-class RawConst(RawBuffer): # pylint: disable=abstract-method
-  def __repr__(self): return f"const<{self._buf}, {self.dtype}>"
-  @property
-  def key(self): return (str(self._buf), self.dtype)
-
-def buf_is_kernel_arg(x) -> bool:
-  return x.realized is not None and x.realized.__class__ is not RawConst
-
-# --teenygrad--
-
-class RawBufferCopyIn(RawBuffer):
-  def _copyin(self, x:np.ndarray) -> None: raise NotImplementedError("must be implemented")
-
-  @classmethod
-  def fromCPU(cls, x:np.ndarray, **kwargs):
-    ret = cls(prod(x.shape), dtypes.from_np(x.dtype), **kwargs)
-    ret._copyin(x)
-    return ret
-
-class RawBufferMapped(RawBufferCopyIn):
-  def _buffer(self) -> memoryview: raise NotImplementedError("must be implemented")
-  # NOTE: this metadata prevents the backing buffer from being freed. hack can be removed with PEP688
-  def toCPU(self) -> np.ndarray: return np.frombuffer(self._buffer(), dtype=np.dtype(self.dtype.np, metadata={"backing": self}))  # type: ignore
-  def _copyin(self, x:np.ndarray) -> None: np.copyto(self.toCPU(), x.reshape(-1))
-
-# this one is simple enough that i moved it out of the runtimes
-class RawMallocBuffer(RawBufferMapped):
-  def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float64:ctypes.c_double, dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.bfloat16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8, dtypes.int32: ctypes.c_int32, dtypes.uint32: ctypes.c_uint32, dtypes.int64: ctypes.c_int64, dtypes.uint64: ctypes.c_uint64}[dtype] * size)())
-  def _buffer(self): return memoryview(self._buf)
-
-class RawBufferCopyInOut(RawBufferCopyIn):
-  def _copyout(self, x:np.ndarray) -> None: raise NotImplementedError("must be implemented")
-
-  def toCPU(self) -> np.ndarray:
-    x: np.ndarray = np.empty(self.size, dtype=self.dtype.np)
-    self._copyout(x)
-    return x
-
-class RawBufferTransfer(RawBuffer):
-  def _transfer(self, x) -> None: raise NotImplementedError("must be implemented")
-
-  @classmethod
-  def transfer(cls, x, shape, dtype, **kwargs):
-    ret = cls(prod(shape), dtype, **kwargs)
-    ret._transfer(x)
-    return ret
-
-class LRUAllocator:
-  def __init__(self, dev_memsz=(4<<30)):
-    self.epoch = 0
-    self.free_space: Dict[Any, int] = defaultdict(lambda: dev_memsz)
-    self.buffer_info: Dict[Any, Tuple[int, DType, str]] = dict()
-    self.cached_buffers: Dict[Tuple[int, ...], Deque[Tuple[Any, int]]] = defaultdict(deque) # Cached buffer storage, splitted by type and size, newest first.
-    self.aging_order: Dict[Any, Deque[Tuple[Tuple[int, ...], int]]] = defaultdict(deque) # Keys of cached_buffers, ordered from oldest to newest updates.
-  def __del__(self):
-    for v in self.cached_buffers.values():
-      for buf, _ in v: self._free_buffer(buf)
-  def _cache_reuse_buffer(self, rawbufs: Deque[Tuple[Any, int]]): # The newest cached buffer is reused.
-    GlobalCounters.mem_cached -= self._underlying_buf_memsz(rawbufs[0][0])
-    return rawbufs.popleft()[0]
-  def _alloc_buffer(self, size, dtype, device, **kwargs):
-    self.free_space[device] -= size*dtype.itemsize
-    while len(self.aging_order[device]) and self.free_space[device] < 0: # When OOM removing lru buffers.
-      bucket, epoch = self.aging_order[device].popleft()
-      if self.cached_buffers[bucket] and self.cached_buffers[bucket][-1][1] == epoch: self._free_buffer(self.cached_buffers[bucket].pop()[0]) # Free cached buffer if it is still in cache.
-    newbuf = self._do_alloc(size, dtype, device, **kwargs)
-    self.buffer_info[newbuf] = (size, dtype, device)
-    return newbuf
-  def _free_buffer(self, buf_to_free):
-    self.free_space[self.buffer_info[buf_to_free][2]] += self._underlying_buf_memsz(buf_to_free)
-    GlobalCounters.mem_cached -= self._underlying_buf_memsz(buf_to_free)
-    self.buffer_info.pop(buf_to_free)
-    self._do_free(buf_to_free)
-  def alloc(self, size, dtype, device='0', **kwargs):
-    rawbufs = self.cached_buffers.get(self._cached_bufkey(size, dtype, device), None)
-    return self._cache_reuse_buffer(rawbufs) if rawbufs else self._alloc_buffer(size, dtype, device, **kwargs)
-  def free(self, buf): # free() just caches buffer. It might be freed later when OOM during allocation.
-    self.epoch += 1
-    size, dtype, device = self.buffer_info[buf]
-    self.cached_buffers[self._cached_bufkey(size, dtype, device)].appendleft((buf, self.epoch))
-    self.aging_order[device].append((self._cached_bufkey(size, dtype, device), self.epoch))
-    GlobalCounters.mem_cached += self._underlying_buf_memsz(buf)
-  def _underlying_buf_memsz(self, buf): return self.buffer_info[buf][0] * self.buffer_info[buf][1].itemsize
-  def _cached_bufkey(self, size, dtype, device) -> Tuple[int, ...]: return (device, size, dtype, dtype.shape) if isinstance(dtype, ImageDType) else (device, size, dtype) # Provides a key for reusing device buffers with identical keys.
-  def _do_alloc(self, size, dtype, device, **kwargs): raise NotImplementedError("must be implemented")
-  def _do_free(self, buf): pass
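The LRUAllocator removed here defers real frees: free() only parks a buffer in a per-(device, size, dtype) cache, alloc() hands back the newest matching cached buffer, and the oldest cached entries are truly released only when free_space goes negative. A minimal sketch of how a backend was expected to plug into it, runnable only against the 0.7.0 wheel where tinygrad.runtime.lib still exists; CountingAllocator and its object() handle are illustrative, not from the diff.

from tinygrad.helpers import dtypes
from tinygrad.runtime.lib import LRUAllocator  # present in 0.7.0 only

class CountingAllocator(LRUAllocator):
  allocs = 0
  def _do_alloc(self, size, dtype, device, **kwargs):
    CountingAllocator.allocs += 1
    return object()  # opaque, hashable stand-in for a real device allocation
  # _do_free is inherited; the base-class default is a no-op

a = CountingAllocator()
buf1 = a.alloc(1024, dtypes.float32)
a.free(buf1)                          # parked in the cache, not released
buf2 = a.alloc(1024, dtypes.float32)  # same (device, size, dtype) key
assert buf2 is buf1 and CountingAllocator.allocs == 1  # newest cached buffer reused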
    
tinygrad/runtime/ops_cpu.py
DELETED
@@ -1,51 +0,0 @@
-import numpy as np
-import operator
-from typing import Callable, Dict, Tuple, Optional
-from tinygrad.helpers import dtypes, DType
-from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ReduceOps, TernaryOps, Op, Interpreted
-from tinygrad.runtime.lib import RawBuffer
-
-def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple[int, ...]:
-  assert len(old_shape) == len(new_shape), "reduce shapes must have same dimensions"
-  return tuple(i for i,(a,b) in enumerate(zip(old_shape, new_shape)) if a != b)
-
-base_fxn_for_op: Dict[Op, Callable] = {
-  BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul, BinaryOps.DIV: operator.truediv,
-  ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
-  ReduceOps.MAX: lambda x, new_shape: (x.amax if hasattr(x, 'amax') else x.max)(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
-  MovementOps.RESHAPE: lambda x, arg: x.reshape(arg), MovementOps.SHRINK: lambda x, arg: x[tuple(slice(p[0], p[1], None) for p in arg)],
-}
-
-def promote_types(x, y): return ret if (ret := np.promote_types(x.dtype, y.dtype)) != np.float64 else np.float32
-def match_types(x, y):
-  up = x.dtype if dtypes.from_np(x.dtype).priority > dtypes.from_np(y.dtype).priority else y.dtype
-  return x.astype(up, copy=False), y.astype(up, copy=False)
-
-def einsum_mulacc(einsum, get_strides, expand):
-  def einscripts(x): return ''.join(["abcdefghijklmnopqrstuvwxyz"[i] for i in x])
-  def axes_slice(strides): return [i for i in range(len(strides)) if strides[i] != 0], tuple(slice(None) if strides[i] != 0 else 0 for i in range(len(strides)))
-  def mulacc(a, b, new_shape):
-    (a_axes, a_slices), (b_axes, b_slices) = axes_slice(get_strides(a)), axes_slice(get_strides(b))
-    out = [i for i in range(len(new_shape)) if a.shape[i] == new_shape[i] and (i in a_axes or i in b_axes)]
-    ret = einsum(f"{einscripts(a_axes)}, {einscripts(b_axes)} -> {einscripts(out)}", a[a_slices], b[b_slices])
-    return expand(ret.reshape([(1 if i not in a_axes and i not in b_axes else s) for i,s in enumerate(new_shape)]), new_shape)
-  return mulacc
-
-numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
-  UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.SIN: np.sin,
-  UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False),
-  BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).astype(promote_types(x,y)), BinaryOps.ADD: lambda x, y: np.add(*match_types(x, y)),
-  BinaryOps.SUB: lambda x, y: np.subtract(*match_types(x, y)), BinaryOps.MUL: lambda x, y: np.multiply(*match_types(x, y)),
-  BinaryOps.DIV: lambda x, y: np.divide(*match_types(x, y)), UnaryOps.SQRT: np.sqrt,
-  MovementOps.PERMUTE: lambda x, order: x.transpose(order), MovementOps.PAD: np.pad, MovementOps.EXPAND: np.broadcast_to,
-  MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, i) for i in arg)],
-  TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, *match_types(a.copy(), b.copy()), optimize=True), lambda x: x.strides, np.broadcast_to),
-  TernaryOps.WHERE: np.where,
-}}
-
-class RawNumpyBuffer(RawBuffer):
-  def __init__(self, size:int, dtype:DType, buf:Optional[np.ndarray]=None): super().__init__(size, dtype, buf if buf is not None else np.empty([size], dtype.np))
-  @classmethod
-  def fromCPU(cls, x): return cls(x.size, dtypes.from_np(x.dtype), x)
-  def toCPU(self): return self._buf
-CPUBuffer = Interpreted(RawNumpyBuffer, numpy_fxn_for_op, from_underlying=RawNumpyBuffer.fromCPU)
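ops_cpu.py was the reference Interpreted backend: running an op amounts to looking its enum up in numpy_fxn_for_op and applying the resulting numpy callable to the realized arrays. A minimal self-contained sketch of that table-dispatch idea follows; the string op tags and the interpret helper are illustrative stand-ins, not the deleted API.

# Sketch of the Interpreted-style dispatch: op tag -> numpy callable.
import numpy as np

fxn_for_op = {
  "ADD":   np.add,
  "MAX":   np.maximum,
  "SUM":   lambda x, axis: x.sum(axis, keepdims=True),  # ReduceOps keep dims, as in the deleted table
  "WHERE": np.where,
}

def interpret(op, *srcs, arg=None):
  # mirror of how an Interpreted device evaluates one op: look up, then call
  fxn = fxn_for_op[op]
  return fxn(*srcs) if arg is None else fxn(*srcs, arg)

x = np.arange(6, dtype=np.float32).reshape(2, 3)
print(interpret("ADD", x, x))      # elementwise add, shape (2, 3)
print(interpret("SUM", x, arg=1))  # reduce along axis 1 -> shape (2, 1)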