tinygrad 0.7.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries; it is provided for informational purposes only.
- tinygrad/__init__.py +6 -0
 - tinygrad/codegen/kernel.py +572 -83
 - tinygrad/codegen/linearizer.py +415 -395
 - tinygrad/codegen/uops.py +415 -0
 - tinygrad/device.py +183 -0
 - tinygrad/dtype.py +113 -0
 - tinygrad/engine/__init__.py +0 -0
 - tinygrad/engine/graph.py +100 -0
 - tinygrad/engine/jit.py +195 -0
 - tinygrad/engine/realize.py +191 -0
 - tinygrad/engine/schedule.py +362 -0
 - tinygrad/engine/search.py +196 -0
 - tinygrad/{mlops.py → function.py} +76 -55
 - tinygrad/helpers.py +196 -89
 - tinygrad/lazy.py +210 -371
 - tinygrad/multi.py +169 -0
 - tinygrad/nn/__init__.py +202 -22
 - tinygrad/nn/datasets.py +7 -0
 - tinygrad/nn/optim.py +112 -32
 - tinygrad/nn/state.py +136 -39
 - tinygrad/ops.py +119 -202
 - tinygrad/renderer/__init__.py +61 -0
 - tinygrad/renderer/assembly.py +276 -0
 - tinygrad/renderer/cstyle.py +353 -166
 - tinygrad/renderer/llvmir.py +150 -138
 - tinygrad/runtime/autogen/amd_gpu.py +1900 -0
 - tinygrad/runtime/autogen/comgr.py +865 -0
 - tinygrad/runtime/autogen/cuda.py +5923 -0
 - tinygrad/runtime/autogen/hip.py +5909 -0
 - tinygrad/runtime/autogen/hsa.py +5761 -0
 - tinygrad/runtime/autogen/kfd.py +812 -0
 - tinygrad/runtime/autogen/nv_gpu.py +33328 -0
 - tinygrad/runtime/autogen/opencl.py +1795 -0
 - tinygrad/runtime/driver/hip_comgr.py +47 -0
 - tinygrad/runtime/driver/hsa.py +143 -0
 - tinygrad/runtime/graph/clang.py +38 -0
 - tinygrad/runtime/graph/cuda.py +81 -0
 - tinygrad/runtime/graph/hcq.py +143 -0
 - tinygrad/runtime/graph/hsa.py +171 -0
 - tinygrad/runtime/graph/metal.py +75 -0
 - tinygrad/runtime/ops_amd.py +564 -0
 - tinygrad/runtime/ops_clang.py +24 -77
 - tinygrad/runtime/ops_cuda.py +175 -89
 - tinygrad/runtime/ops_disk.py +56 -33
 - tinygrad/runtime/ops_gpu.py +92 -95
 - tinygrad/runtime/ops_hsa.py +278 -0
 - tinygrad/runtime/ops_llvm.py +39 -60
 - tinygrad/runtime/ops_metal.py +92 -74
 - tinygrad/runtime/ops_npy.py +9 -0
 - tinygrad/runtime/ops_nv.py +630 -0
 - tinygrad/runtime/ops_python.py +204 -0
 - tinygrad/shape/shapetracker.py +86 -254
 - tinygrad/shape/symbolic.py +166 -141
 - tinygrad/shape/view.py +296 -0
 - tinygrad/tensor.py +2619 -448
 - {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
 - tinygrad-0.9.0.dist-info/METADATA +227 -0
 - tinygrad-0.9.0.dist-info/RECORD +60 -0
 - {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
 - tinygrad/codegen/assembly.py +0 -190
 - tinygrad/codegen/optimizer.py +0 -379
 - tinygrad/codegen/search.py +0 -72
 - tinygrad/graph.py +0 -83
 - tinygrad/jit.py +0 -57
 - tinygrad/nn/image.py +0 -100
 - tinygrad/renderer/assembly_arm64.py +0 -169
 - tinygrad/renderer/assembly_ptx.py +0 -98
 - tinygrad/renderer/wgsl.py +0 -53
 - tinygrad/runtime/lib.py +0 -113
 - tinygrad/runtime/ops_cpu.py +0 -51
 - tinygrad/runtime/ops_hip.py +0 -82
 - tinygrad/runtime/ops_shm.py +0 -29
 - tinygrad/runtime/ops_torch.py +0 -30
 - tinygrad/runtime/ops_webgpu.py +0 -45
 - tinygrad-0.7.0.dist-info/METADATA +0 -212
 - tinygrad-0.7.0.dist-info/RECORD +0 -40
 - {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
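Most of the churn in this list is a reorganization rather than new user-facing surface: mlops.py is renamed to function.py, the JIT/graph/schedule/search machinery moves under tinygrad/engine/, dtypes are split out of helpers.py into tinygrad/dtype.py, and the runtimes gain autogenerated GPU bindings. A minimal upgrade smoke test, sketched below on the assumption that the tinygrad.tensor.Tensor constructor, matmul, and .numpy() behave the same in 0.7.0 and 0.9.0, touches none of the moved internals:

# Hypothetical upgrade smoke test; assumes only that tinygrad.tensor.Tensor,
# matmul, and .numpy() are unchanged between 0.7.0 and 0.9.0.
from tinygrad.tensor import Tensor

a = Tensor([[1.0, 2.0], [3.0, 4.0]])
b = Tensor([[5.0, 6.0], [7.0, 8.0]])
print((a.matmul(b) + 1).numpy())  # should print the same 2x2 result on both versions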
 
tinygrad/{mlops.py → function.py} RENAMED

@@ -1,31 +1,48 @@
+"""This is where the forwards and backwards passes live."""
 import math
 from typing import Tuple, Optional
-from tinygrad.helpers import argsort
+from tinygrad.helpers import argsort
+from tinygrad.dtype import dtypes, DType, sum_acc_dtype
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
 from tinygrad.tensor import Function
 from tinygrad.lazy import LazyBuffer
+from tinygrad.shape.symbolic import sint

 class Contiguous(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer: return x.contiguous()
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output

+class ContiguousBackward(Function):
+  def forward(self, x:LazyBuffer) -> LazyBuffer: return x
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.contiguous()
+
 class Cast(Function):
   def forward(self, x:LazyBuffer, dtype:DType, bitcast:bool=False) -> LazyBuffer:
     self.input_dtype, self.bitcast = x.dtype, bitcast
-    return x.
+    return x.cast(dtype, bitcast)

-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.e(UnaryOps.CAST, arg=(self.input_dtype, self.bitcast))
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.cast(self.input_dtype, self.bitcast)

 # ************* unary ops *************

+class Neg(Function):
+  def forward(self, x:LazyBuffer) -> LazyBuffer: return x.e(UnaryOps.NEG)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(UnaryOps.NEG)
+
+class Reciprocal(Function):
+  def forward(self, x:LazyBuffer) -> LazyBuffer:
+    self.ret = x.const(1).e(BinaryOps.DIV, x)
+    return self.ret
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
+    return grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.ret).e(BinaryOps.MUL, self.ret)
+
 class Sin(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.x = x
     return x.e(UnaryOps.SIN)

-  def backward(self,
-    return self.x.const(math.pi / 2).e(BinaryOps.SUB, self.x).e(UnaryOps.SIN).e(BinaryOps.MUL,
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
+    return self.x.const(math.pi / 2).e(BinaryOps.SUB, self.x).e(UnaryOps.SIN).e(BinaryOps.MUL, grad_output)

 # NOTE: maximum(x, 0) behaves differently where x=0
 class Relu(Function):
@@ -34,23 +51,21 @@ class Relu(Function):
     return self.ret

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return self.ret.const(0).e(BinaryOps.CMPLT, self.ret).e(BinaryOps.MUL, grad_output)
+    return self.ret.const(0).e(BinaryOps.CMPLT, self.ret).cast(grad_output.dtype).e(BinaryOps.MUL, grad_output)

 class Log(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.x = x
     return x.e(UnaryOps.LOG2).e(BinaryOps.MUL, x.const(math.log(2)))

-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.e(BinaryOps.DIV, self.x)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(BinaryOps.DIV, self.x)

 class Exp(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.ret = x.e(BinaryOps.MUL, x.const(1/math.log(2))).e(UnaryOps.EXP2)
     return self.ret

-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return self.ret.e(BinaryOps.MUL, grad_output)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return self.ret.e(BinaryOps.MUL, grad_output)

 class Sqrt(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
@@ -71,48 +86,39 @@ class Sigmoid(Function):
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
     return self.ret.e(BinaryOps.MUL, self.ret.const(1).e(BinaryOps.SUB, self.ret)).e(BinaryOps.MUL, grad_output)

-
-
-
-
-
-
-
-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.expand(self.input_shape)
-
-class Max(Function):
-  def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
-    self.x, self.ret = x, x.reduce_op(ReduceOps.MAX, new_shape)
-    return self.ret
-
-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    # 1s in locations where the max was chosen (can be two locations)
-    max_is_1s = self.x.const(1.0).e(BinaryOps.SUB, self.x.e(BinaryOps.CMPLT, self.ret.expand(self.x.shape)))
-    div = max_is_1s.reduce_op(ReduceOps.SUM, grad_output.shape).expand(self.x.shape)
-    return max_is_1s.e(BinaryOps.DIV, div).e(BinaryOps.MUL, grad_output.expand(self.x.shape))
+class Sign(Function):
+  def forward(self, x:LazyBuffer) -> LazyBuffer:
+    return x.e(BinaryOps.CMPEQ, x.const(0)).e(TernaryOps.WHERE, x.const(0),
+                                              x.e(BinaryOps.CMPLT, x.const(0)).e(TernaryOps.WHERE, x.const(-1), x.const(1)))
+  # backward always return 0 to match torch
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.const(0)

 # ************* binary ops *************

 class Less(Function):
-  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-
+  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPLT, y)
+  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None
+
+class Eq(Function):
+  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPEQ, y)
+  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None
+
+class Xor(Function):
+  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.XOR, y)

 class Add(Function):
-  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    return x.e(BinaryOps.ADD, y)
+  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.ADD, y)

   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
     return grad_output if self.needs_input_grad[0] else None, \
            grad_output if self.needs_input_grad[1] else None

 class Sub(Function):
-  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    return x.e(BinaryOps.SUB, y)
+  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.SUB, y)

   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
     return grad_output if self.needs_input_grad[0] else None, \
-           grad_output.
+           grad_output.e(UnaryOps.NEG) if self.needs_input_grad[1] else None

 class Mul(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
@@ -130,67 +136,82 @@ class Div(Function):

   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
     return grad_output.e(BinaryOps.DIV, self.y) if self.needs_input_grad[0] else None, \
-           grad_output.
+           grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.x).e(BinaryOps.DIV, self.y.e(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None  # noqa: E501

 # ************* ternary ops *************

 class Where(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer, z:LazyBuffer) -> LazyBuffer:
     self.x = x
-    return x.e(TernaryOps.WHERE, y, z)
+    return self.x.e(TernaryOps.WHERE, y, z)

   def backward(self, grad_output:LazyBuffer) -> Tuple[None, Optional[LazyBuffer], Optional[LazyBuffer]]:
     return None, \
-
-
+      self.x.e(TernaryOps.WHERE, grad_output, grad_output.const(0)) if self.needs_input_grad[1] else None, \
+      self.x.e(TernaryOps.WHERE, grad_output.const(0), grad_output) if self.needs_input_grad[2] else None
+
+# ************* reduce ops *************
+
+class Sum(Function):
+  def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
+    self.input_shape = x.shape
+    return x.r(ReduceOps.SUM, axis)
+
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.expand(self.input_shape)
+
+class Max(Function):
+  def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
+    self.x, self.ret, self.axis = x, x.r(ReduceOps.MAX, axis), axis
+    return self.ret
+
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
+    # 1s in locations where the max was chosen (can be two locations)
+    max_is_1s = self.x.e(BinaryOps.CMPEQ, self.ret.expand(self.x.shape)).cast(dtypes.float)
+    div = max_is_1s.r(ReduceOps.SUM, self.axis).expand(self.x.shape)
+    return max_is_1s.e(BinaryOps.DIV, div).cast(grad_output.dtype).e(BinaryOps.MUL, grad_output.expand(self.x.shape))

 # ************* movement ops *************

 # NOTE: this is sum in reverse
 class Expand(Function):
   def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
-    self.
+    self.expanded_axis = tuple(i for i, (si, so) in enumerate(zip(x.shape, shape)) if si != so)
     return x.expand(shape)

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.
+    return grad_output.cast(sum_acc_dtype(grad_output.dtype)).r(ReduceOps.SUM, self.expanded_axis).cast(grad_output.dtype)

 class Reshape(Function):
   def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
     self.input_shape = x.shape
     return x.reshape(shape)

-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.reshape(self.input_shape)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.reshape(self.input_shape)

 class Permute(Function):
   def forward(self, x:LazyBuffer, order:Tuple[int, ...]) -> LazyBuffer:
     self.input_order = order
     return x.permute(order)

-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.permute(argsort(self.input_order))
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.permute(argsort(self.input_order))

 class Pad(Function):
   def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
     self.narg = tuple([(p[0], s+p[0]) for s,p in zip(x.shape, arg)])
     return x.pad(arg)

-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.shrink(self.narg)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.shrink(self.narg)

 class Shrink(Function):
-  def forward(self, x:LazyBuffer, arg:Tuple[Tuple[
+  def forward(self, x:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer:
     self.narg = tuple([(p[0], s-p[1]) for s,p in zip(x.shape, arg)])
     return x.shrink(arg)

-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.pad(self.narg)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.pad(self.narg)

 class Flip(Function):
   def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
     self.arg = tuple([-1 if i in set(axis) else 1 for i in range(len(x.shape))])
     return x.stride(self.arg)

-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.stride(self.arg)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.stride(self.arg)
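None of the Function subclasses above are called directly; Tensor's autograd machinery invokes forward while building the lazy graph and backward when a gradient is requested. A minimal sketch of how the Relu and Sum paths get exercised on 0.9.0 (assuming the top-level Tensor re-export added to tinygrad/__init__.py):

from tinygrad import Tensor

x = Tensor([[1.0, -2.0], [3.0, -4.0]], requires_grad=True)
y = x.relu().sum()     # Relu.forward, then Sum.forward build the lazy graph
y.backward()           # Sum.backward expands the grad; Relu.backward masks it with the 0-comparison
print(x.grad.numpy())  # 1.0 where x > 0, 0.0 elsewhere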
    
tinygrad/helpers.py CHANGED

@@ -1,35 +1,75 @@
 from __future__ import annotations
-import os, functools, platform, time, re, contextlib
-import
-from
-from
+import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, cProfile, pstats, tempfile, pathlib, string, ctypes
+import itertools, urllib.request, subprocess
+from tqdm import tqdm
+from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Sequence
+if TYPE_CHECKING:  # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
+  from typing_extensions import TypeGuard
+  from tinygrad.shape.shapetracker import sint
+
+T = TypeVar("T")
+U = TypeVar("U")
+# NOTE: it returns int 1 if x is empty regardless of the type of x
+def prod(x:Iterable[T]) -> Union[T,int]: return functools.reduce(operator.mul, x, 1)

 # NOTE: helpers is not allowed to import from anything else in tinygrad
 OSX = platform.system() == "Darwin"
 CI = os.getenv("CI", "") != ""

-def dedup(x): return list(dict.fromkeys(x))   # retains list order
-def argfix(*x):
+def dedup(x:Iterable[T]): return list(dict.fromkeys(x))   # retains list order
+def argfix(*x):
+  if x and x[0].__class__ in (tuple, list):
+    if len(x) != 1: raise ValueError(f"bad arg {x}")
+    return tuple(x[0])
+  return x
 def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
-def all_same(items): return all(x == items[0] for x in items)
-def
-def
+def all_same(items:List[T]): return all(x == items[0] for x in items)
+def all_int(t: Sequence[Any]) -> TypeGuard[Tuple[int, ...]]: return all(isinstance(s, int) for s in t)
+def colored(st, color:Optional[str], background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st  # replace the termcolor library with one line  # noqa: E501
+def ansistrip(s:str): return re.sub('\x1b\\[(K|.*?m)', '', s)
+def ansilen(s:str): return len(ansistrip(s))
 def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
-def flatten(l:
-def
+def flatten(l:Iterable[Iterable[T]]): return [item for sublist in l for item in sublist]
+def fully_flatten(l): return [item for sublist in l for item in (fully_flatten(sublist) if isinstance(sublist, (tuple, list)) else [sublist])]
 def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
-def
-
-
-
-
-
-
+def strip_parens(fst:str): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' and fst[1:-1].find('(') <= fst[1:-1].find(')') else fst
+def round_up(num, amt:int): return (num+amt-1)//amt * amt
+def merge_dicts(ds:Iterable[Dict[T,U]]) -> Dict[T,U]:
+  assert len(kvs:=set([(k,v) for d in ds for k,v in d.items()])) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"  # noqa: E501
+  return {k:v for d in ds for k,v in d.items()}
+def partition(lst:List[T], fxn:Callable[[T],bool]) -> Tuple[List[T], List[T]]:
+  a:List[T] = []
+  b:List[T] = []
   for s in lst: (a if fxn(s) else b).append(s)
   return a,b
+def unwrap(x:Optional[T]) -> T:
+  assert x is not None
+  return x
+def unwrap2(x:Tuple[T,Any]) -> T:
+  ret, err = x
+  assert err is None, str(err)
+  return ret
+def get_child(obj, key):
+  for k in key.split('.'):
+    if k.isnumeric(): obj = obj[int(k)]
+    elif isinstance(obj, dict): obj = obj[k]
+    else: obj = getattr(obj, k)
+  return obj
+
+# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
+def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:
+  acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
+  try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new]
+  except ValueError: return None
+  return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]

 @functools.lru_cache(maxsize=None)
-def
+def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+string.digits+'_') else f'{ord(c):02X}' for c in ansistrip(s)])
+@functools.lru_cache(maxsize=None)
+def getenv(key:str, default=0): return type(default)(os.getenv(key, default))
+def temp(x:str) -> str: return (pathlib.Path(tempfile.gettempdir()) / x).as_posix()
+
+class GraphException(Exception): pass

 class Context(contextlib.ContextDecorator):
   stack: ClassVar[List[dict[str, int]]] = [{}]
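The helper functions added in the hunk above are small, self-contained utilities, so their behaviour can be read directly off the diff. A few illustrative calls (a sketch; the expected results follow from the definitions shown above, not from separate documentation):

from tinygrad.helpers import dedup, partition, merge_dicts, get_contraction

print(dedup([3, 1, 3, 2, 1]))                            # [3, 1, 2]  (order preserved)
print(partition([1, 2, 3, 4, 5], lambda n: n % 2 == 0))  # ([2, 4], [1, 3, 5])
print(merge_dicts([{"a": 1}, {"b": 2}]))                 # {'a': 1, 'b': 2}
print(get_contraction((2, 3, 4), (6, 4)))                # [[0, 1], [2]]  (axes 0 and 1 fold into the 6)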
         @@ -44,83 +84,23 @@ class Context(contextlib.ContextDecorator): 
     | 
|
| 
       44 
84 
     | 
    
         
             
            class ContextVar:
         
     | 
| 
       45 
85 
     | 
    
         
             
              _cache: ClassVar[Dict[str, ContextVar]] = {}
         
     | 
| 
       46 
86 
     | 
    
         
             
              value: int
         
     | 
| 
      
 87 
     | 
    
         
            +
              key: str
         
     | 
| 
       47 
88 
     | 
    
         
             
              def __new__(cls, key, default_value):
         
     | 
| 
       48 
89 
     | 
    
         
             
                if key in ContextVar._cache: return ContextVar._cache[key]
         
     | 
| 
       49 
90 
     | 
    
         
             
                instance = ContextVar._cache[key] = super().__new__(cls)
         
     | 
| 
       50 
     | 
    
         
            -
                instance.value = getenv(key, default_value)
         
     | 
| 
      
 91 
     | 
    
         
            +
                instance.value, instance.key = getenv(key, default_value), key
         
     | 
| 
       51 
92 
     | 
    
         
             
                return instance
         
     | 
| 
       52 
93 
     | 
    
         
             
              def __bool__(self): return bool(self.value)
         
     | 
| 
       53 
94 
     | 
    
         
             
              def __ge__(self, x): return self.value >= x
         
     | 
| 
       54 
95 
     | 
    
         
             
              def __gt__(self, x): return self.value > x
         
     | 
| 
       55 
96 
     | 
    
         
             
              def __lt__(self, x): return self.value < x
         
     | 
| 
       56 
97 
     | 
    
         | 
| 
       57 
     | 
    
         
            -
            DEBUG, IMAGE = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0)
         
     | 
| 
       58 
     | 
    
         
            -
             
     | 
| 
      
 98 
     | 
    
         
            +
            DEBUG, IMAGE, BEAM, NOOPT, JIT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0), ContextVar("JIT", 1)
         
     | 
| 
      
 99 
     | 
    
         
            +
            WINO, THREEFRY, CACHECOLLECTING = ContextVar("WINO", 0), ContextVar("THREEFRY", 0), ContextVar("CACHECOLLECTING", 1)
         
     | 
| 
      
 100 
     | 
    
         
            +
            GRAPH, GRAPHPATH, SAVE_SCHEDULE, RING = ContextVar("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net"), ContextVar("SAVE_SCHEDULE", 0), ContextVar("RING", 1)
         
     | 
| 
      
 101 
     | 
    
         
            +
            MULTIOUTPUT = ContextVar("MULTIOUTPUT", 1)
         
     | 
| 
       59 
102 
     | 
    
         | 
| 
       60 
     | 
    
         
            -
             
     | 
| 
       61 
     | 
    
         
            -
              def __init__(self, prefix="", on_exit=None, enabled=True): self.prefix, self.on_exit, self.enabled = prefix, on_exit, enabled
         
     | 
| 
       62 
     | 
    
         
            -
              def __enter__(self): self.st = time.perf_counter_ns()
         
     | 
| 
       63 
     | 
    
         
            -
              def __exit__(self, exc_type, exc_val, exc_tb):
         
     | 
| 
       64 
     | 
    
         
            -
                self.et = time.perf_counter_ns() - self.st
         
     | 
| 
       65 
     | 
    
         
            -
                if self.enabled: print(f"{self.prefix}{self.et*1e-6:.2f} ms"+(self.on_exit(self.et) if self.on_exit else ""))
         
     | 
| 
       66 
     | 
    
         
            -
             
     | 
| 
       67 
     | 
    
         
            -
            # **** tinygrad now supports dtypes! *****
         
     | 
| 
       68 
     | 
    
         
            -
             
     | 
| 
       69 
     | 
    
         
            -
            class DType(NamedTuple):
         
     | 
| 
       70 
     | 
    
         
            -
              priority: int  # this determines when things get upcasted
         
     | 
| 
       71 
     | 
    
         
            -
              itemsize: int
         
     | 
| 
       72 
     | 
    
         
            -
              name: str
         
     | 
| 
       73 
     | 
    
         
            -
              np: Optional[type]  # TODO: someday this will be removed with the "remove numpy" project
         
     | 
| 
       74 
     | 
    
         
            -
              sz: int = 1
         
     | 
| 
       75 
     | 
    
         
            -
              def __repr__(self): return f"dtypes.{self.name}"
         
     | 
| 
       76 
     | 
    
         
            -
             
     | 
| 
       77 
     | 
    
         
            -
            # dependent typing?
         
     | 
| 
       78 
     | 
    
         
            -
            class ImageDType(DType):
         
     | 
| 
       79 
     | 
    
         
            -
              def __new__(cls, priority, itemsize, name, np, shape):
         
     | 
| 
       80 
     | 
    
         
            -
                return super().__new__(cls, priority, itemsize, name, np)
         
     | 
| 
       81 
     | 
    
         
            -
              def __init__(self, priority, itemsize, name, np, shape):
         
     | 
| 
       82 
     | 
    
         
            -
                self.shape: Tuple[int, ...] = shape  # arbitrary arg for the dtype, used in image for the shape
         
     | 
| 
       83 
     | 
    
         
            -
                super().__init__()
         
     | 
| 
       84 
     | 
    
         
            -
              def __repr__(self): return f"dtypes.{self.name}({self.shape})"
         
     | 
| 
       85 
     | 
    
         
            -
             
     | 
| 
       86 
     | 
    
         
            -
            class dtypes:
         
     | 
| 
       87 
     | 
    
         
            -
              @staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool
         
     | 
| 
       88 
     | 
    
         
            -
              def is_int(x: DType)-> bool: return x in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
         
     | 
| 
       89 
     | 
    
         
            -
              @staticmethod
         
     | 
| 
       90 
     | 
    
         
            -
              def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes._half4, dtypes._float2, dtypes._float4)
         
     | 
| 
       91 
     | 
    
         
            -
              @staticmethod
         
     | 
| 
       92 
     | 
    
         
            -
              def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
         
     | 
| 
       93 
     | 
    
         
            -
              @staticmethod
         
     | 
| 
       94 
     | 
    
         
            -
              def from_np(x) -> DType: return DTYPES_DICT[np.dtype(x).name]
         
     | 
| 
       95 
     | 
    
         
            -
              @staticmethod
         
     | 
| 
       96 
     | 
    
         
            -
              def fields() -> Dict[str, DType]: return DTYPES_DICT
         
     | 
| 
       97 
     | 
    
         
            -
              bool: Final[DType] = DType(0, 1, "bool", np.bool_)
         
     | 
| 
       98 
     | 
    
         
            -
              float16: Final[DType] = DType(0, 2, "half", np.float16)
         
     | 
| 
       99 
     | 
    
         
            -
              half = float16
         
     | 
| 
       100 
     | 
    
         
            -
              float32: Final[DType] = DType(4, 4, "float", np.float32)
         
     | 
| 
       101 
     | 
    
         
            -
              float = float32
         
     | 
| 
       102 
     | 
    
         
            -
              float64: Final[DType] = DType(0, 8, "double", np.float64)
         
     | 
| 
       103 
     | 
    
         
            -
              double = float64
         
     | 
| 
       104 
     | 
    
         
            -
              int8: Final[DType] = DType(0, 1, "char", np.int8)
         
     | 
| 
       105 
     | 
    
         
            -
              int16: Final[DType] = DType(1, 2, "short", np.int16)
         
     | 
| 
       106 
     | 
    
         
            -
              int32: Final[DType] = DType(2, 4, "int", np.int32)
         
     | 
| 
       107 
     | 
    
         
            -
              int64: Final[DType] = DType(3, 8, "long", np.int64)
         
     | 
| 
       108 
     | 
    
         
            -
              uint8: Final[DType] = DType(0, 1, "unsigned char", np.uint8)
         
     | 
| 
       109 
     | 
    
         
            -
              uint16: Final[DType] = DType(1, 2, "unsigned short", np.uint16)
         
     | 
| 
       110 
     | 
    
         
            -
              uint32: Final[DType] = DType(2, 4, "unsigned int", np.uint32)
         
     | 
| 
       111 
     | 
    
         
            -
              uint64: Final[DType] = DType(3, 8, "unsigned long", np.uint64)
         
     | 
| 
       112 
     | 
    
         
            -
             
     | 
| 
       113 
     | 
    
         
            -
              # NOTE: bfloat16 isn't supported in numpy
         
     | 
| 
       114 
     | 
    
         
            -
              bfloat16: Final[DType] = DType(0, 2, "__bf16", None)
         
     | 
| 
       115 
     | 
    
         
            -
             
     | 
| 
       116 
     | 
    
         
            -
              # NOTE: these are internal dtypes, should probably check for that
         
     | 
| 
       117 
     | 
    
         
            -
              _half4: Final[DType] = DType(0, 2*4, "half4", None, 4)
         
     | 
| 
       118 
     | 
    
         
            -
              _float2: Final[DType] = DType(4, 4*2, "float2", None, 2)
         
     | 
| 
       119 
     | 
    
         
            -
              _float4: Final[DType] = DType(4, 4*4, "float4", None, 4)
         
     | 
| 
       120 
     | 
    
         
            -
              _arg_int32: Final[DType] = DType(2, 4, "_arg_int32", None)
         
     | 
| 
       121 
     | 
    
         
            -
             
     | 
| 
       122 
     | 
    
         
            -
            # HACK: staticmethods are not callable in 3.8 so we have to compare the class
         
     | 
| 
       123 
     | 
    
         
            -
            DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and not v.__class__ == staticmethod}
         
     | 
| 
      
+# **************** global state Counters ****************

 class GlobalCounters:
   global_ops: ClassVar[int] = 0
@@ -128,7 +108,134 @@ class GlobalCounters:
   time_sum_s: ClassVar[float] = 0.0
   kernel_count: ClassVar[int] = 0
   mem_used: ClassVar[int] = 0   # NOTE: this is not reset
-  mem_cached: ClassVar[int] = 0 # NOTE: this is not reset
-  cache: ClassVar[Optional[List[Tuple[Callable, Any, Dict[Any, int]]]]] = None  # List[Tuple[Callable, List[RawBuffer], Dict[Variable, int]]]
   @staticmethod
-  def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count 
+  def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count = 0,0,0.0,0
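Usage note: GlobalCounters keeps its role as a process-wide tally of launched kernels and op/memory estimates; the mem_cached and cache class variables are dropped. A usage sketch, assuming a working default device; time_sum_s only accumulates when kernel timing is collected (e.g. with DEBUG>=2), so it may stay 0.0 here:

```python
# Usage sketch, assuming a working default device; exact values depend on the backend.
from tinygrad import Tensor
from tinygrad.helpers import GlobalCounters

GlobalCounters.reset()  # zeroes global_ops, global_mem, time_sum_s, kernel_count
Tensor.randn(64, 64).matmul(Tensor.randn(64, 64)).realize()
print(f"kernels={GlobalCounters.kernel_count} ops~{GlobalCounters.global_ops} mem={GlobalCounters.mem_used}")
```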
      
+
+# **************** timer and profiler ****************
+
+class Timing(contextlib.ContextDecorator):
+  def __init__(self, prefix="", on_exit=None, enabled=True): self.prefix, self.on_exit, self.enabled = prefix, on_exit, enabled
+  def __enter__(self): self.st = time.perf_counter_ns()
+  def __exit__(self, *exc):
+    self.et = time.perf_counter_ns() - self.st
+    if self.enabled: print(f"{self.prefix}{self.et*1e-6:6.2f} ms"+(self.on_exit(self.et) if self.on_exit else ""))
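Usage note: Timing is a context manager (and, via ContextDecorator, usable as a decorator) that prints the wall-clock time of the enclosed block in milliseconds on exit; on_exit can turn the elapsed nanoseconds into extra text appended to the line. A small sketch:

```python
import time
from tinygrad.helpers import Timing

# as a context manager: prints something like "sleep:  51.23 ms" on exit
with Timing("sleep: "):
  time.sleep(0.05)

# on_exit receives the elapsed time in nanoseconds and returns extra text to append
with Timing("busy loop: ", on_exit=lambda ns: f" ({ns} ns)"):
  sum(range(100_000))
```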
      
+
+def _format_fcn(fcn): return f"{fcn[0]}:{fcn[1]}:{fcn[2]}"
+class Profiling(contextlib.ContextDecorator):
+  def __init__(self, enabled=True, sort='cumtime', frac=0.2, fn=None, ts=1):
+    self.enabled, self.sort, self.frac, self.fn, self.time_scale = enabled, sort, frac, fn, 1e3/ts
+  def __enter__(self):
+    self.pr = cProfile.Profile()
+    if self.enabled: self.pr.enable()
+  def __exit__(self, *exc):
+    if self.enabled:
+      self.pr.disable()
+      if self.fn: self.pr.dump_stats(self.fn)
+      stats = pstats.Stats(self.pr).strip_dirs().sort_stats(self.sort)
+      for fcn in stats.fcn_list[0:int(len(stats.fcn_list)*self.frac)]:    # type: ignore[attr-defined]
+        (_primitive_calls, num_calls, tottime, cumtime, callers) = stats.stats[fcn]    # type: ignore[attr-defined]
+        scallers = sorted(callers.items(), key=lambda x: -x[1][2])
+        print(f"n:{num_calls:8d}  tm:{tottime*self.time_scale:7.2f}ms  tot:{cumtime*self.time_scale:7.2f}ms",
+              colored(_format_fcn(fcn), "yellow") + " "*(50-len(_format_fcn(fcn))),
+              colored(f"<- {(scallers[0][1][2]/tottime)*100:3.0f}% {_format_fcn(scallers[0][0])}", "BLACK") if len(scallers) else '')
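Usage note: Profiling wraps cProfile/pstats: it profiles the enclosed block, prints the top frac of entries sorted by sort, and can additionally dump the raw stats to fn. A usage sketch; the workload is arbitrary:

```python
from tinygrad.helpers import Profiling

def work():
  return sorted(str(i) for i in range(200_000))

# profile the block and print the top 10% of functions by cumulative time;
# passing fn="out.prof" would also dump the raw cProfile stats to that file
with Profiling(sort='cumtime', frac=0.1):
  work()
```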
      
+
+# *** universal database cache ***
+
+_cache_dir: str = getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches" if OSX else "~/.cache"))
+CACHEDB: str = getenv("CACHEDB", os.path.abspath(os.path.join(_cache_dir, "tinygrad", "cache.db")))
+CACHELEVEL = getenv("CACHELEVEL", 2)
+
+VERSION = 16
+_db_connection = None
+def db_connection():
+  global _db_connection
+  if _db_connection is None:
+    os.makedirs(CACHEDB.rsplit(os.sep, 1)[0], exist_ok=True)
+    _db_connection = sqlite3.connect(CACHEDB)
+    if DEBUG >= 7: _db_connection.set_trace_callback(print)
+  return _db_connection
+
+def diskcache_clear():
+  cur = db_connection().cursor()
+  drop_tables = cur.execute("SELECT 'DROP TABLE IF EXISTS ' || quote(name) || ';' FROM sqlite_master WHERE type = 'table';").fetchall()
+  cur.executescript("\n".join([s[0] for s in drop_tables]))
+
+def diskcache_get(table:str, key:Union[Dict, str, int]) -> Any:
+  if CACHELEVEL == 0: return None
+  if isinstance(key, (str,int)): key = {"key": key}
+  conn = db_connection()
+  cur = conn.cursor()
+  try:
+    res = cur.execute(f"SELECT val FROM '{table}_{VERSION}' WHERE {' AND '.join([f'{x}=?' for x in key.keys()])}", tuple(key.values()))
+  except sqlite3.OperationalError:
+    return None  # table doesn't exist
+  if (val:=res.fetchone()) is not None: return pickle.loads(val[0])
+  return None
+
+_db_tables = set()
+def diskcache_put(table:str, key:Union[Dict, str, int], val:Any):
+  if CACHELEVEL == 0: return val
+  if isinstance(key, (str,int)): key = {"key": key}
+  conn = db_connection()
+  cur = conn.cursor()
+  if table not in _db_tables:
+    TYPES = {str: "text", bool: "integer", int: "integer", float: "numeric", bytes: "blob"}
+    ltypes = ', '.join(f"{k} {TYPES[type(key[k])]}" for k in key.keys())
+    cur.execute(f"CREATE TABLE IF NOT EXISTS '{table}_{VERSION}' ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))")
+    _db_tables.add(table)
+  cur.execute(f"REPLACE INTO '{table}_{VERSION}' ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (pickle.dumps(val), ))  # noqa: E501
+  conn.commit()
+  cur.close()
+  return val
+
+def diskcache(func):
+  def wrapper(*args, **kwargs) -> bytes:
+    table, key = f"cache_{func.__name__}", hashlib.sha256(pickle.dumps((args, kwargs))).hexdigest()
+    if (ret:=diskcache_get(table, key)): return ret
+    return diskcache_put(table, key, func(*args, **kwargs))
+  return wrapper
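Usage note: the disk cache pickles values into per-table SQLite tables inside CACHEDB (skipped entirely when CACHELEVEL=0). diskcache_get/diskcache_put accept a string/int key or a dict of columns, and @diskcache memoizes a function keyed by the sha256 of its pickled arguments. A usage sketch; the table and key names below are arbitrary examples:

```python
from tinygrad.helpers import diskcache, diskcache_get, diskcache_put

# direct put/get: the value is pickled into an 'example_table_<VERSION>' SQLite table
diskcache_put("example_table", "some_key", {"answer": 42})
print(diskcache_get("example_table", "some_key"))  # -> {'answer': 42}

# decorator: results are keyed by sha256 of the pickled (args, kwargs)
@diskcache
def slow_square(x: int) -> int:
  print("computing...")                 # only printed on a cache miss
  return x * x

print(slow_square(12), slow_square(12))  # second call is served from the cache
```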
      
+
+# *** http support ***
+
+def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, allow_caching=not getenv("DISABLE_HTTP_CACHE")) -> pathlib.Path:
+  if url.startswith(("/", ".")): return pathlib.Path(url)
+  fp = pathlib.Path(name) if name is not None and (isinstance(name, pathlib.Path) or '/' in name) else pathlib.Path(_cache_dir) / "tinygrad" / "downloads" / (name if name else hashlib.md5(url.encode('utf-8')).hexdigest())  # noqa: E501
+  if not fp.is_file() or not allow_caching:
+    with urllib.request.urlopen(url, timeout=10) as r:
+      assert r.status == 200
+      total_length = int(r.headers.get('content-length', 0))
+      progress_bar = tqdm(total=total_length, unit='B', unit_scale=True, desc=url)
+      (path := fp.parent).mkdir(parents=True, exist_ok=True)
+      with tempfile.NamedTemporaryFile(dir=path, delete=False) as f:
+        while chunk := r.read(16384): progress_bar.update(f.write(chunk))
+        f.close()
+        if (file_size:=os.stat(f.name).st_size) < total_length: raise RuntimeError(f"fetch size incomplete, {file_size} < {total_length}")
+        pathlib.Path(f.name).rename(fp)
+  return fp
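Usage note: fetch downloads a URL once into the shared download cache, checks the received size against Content-Length, and returns the local pathlib.Path; strings starting with "/" or "." are treated as local paths and returned untouched. A usage sketch; the URL and file names are only illustrative:

```python
from tinygrad.helpers import fetch

# downloaded on the first call, served from the on-disk cache afterwards
# (set DISABLE_HTTP_CACHE=1 to force re-downloads); substitute your own URL
readme = fetch("https://raw.githubusercontent.com/tinygrad/tinygrad/master/README.md", name="tinygrad_readme.md")
print(readme, readme.stat().st_size)

# strings that look like local paths are returned as-is, with no network access
print(fetch("./some/local/file.bin"))
```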
      
+
+# *** Exec helpers
+
+def cpu_time_execution(cb, enable):
+  if enable: st = time.perf_counter()
+  cb()
+  if enable: return time.perf_counter()-st
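Usage note: cpu_time_execution runs a callback and, when enable is truthy, returns its wall-clock duration in seconds (the callback runs either way, and None is returned when timing is disabled). A minimal sketch:

```python
from tinygrad.helpers import cpu_time_execution

# returns elapsed seconds when enable=True, None when enable=False
et = cpu_time_execution(lambda: sum(range(1_000_000)), enable=True)
print(f"{et*1e3:.2f} ms")
```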
      
+
+def cpu_objdump(lib):
+  with tempfile.NamedTemporaryFile(delete=True) as f:
+    pathlib.Path(f.name).write_bytes(lib)
+    print(subprocess.check_output(['objdump', '-d', f.name]).decode('utf-8'))
+
+# *** ctypes helpers
+
+# TODO: make this work with read only memoryviews (if possible)
+def from_mv(mv:memoryview, to_type=ctypes.c_char):
+  return ctypes.cast(ctypes.addressof(to_type.from_buffer(mv)), ctypes.POINTER(to_type * len(mv))).contents
+def to_mv(ptr, sz) -> memoryview: return memoryview(ctypes.cast(ptr, ctypes.POINTER(ctypes.c_uint8 * sz)).contents).cast("B")
+def to_char_p_p(options: List[bytes], to_type=ctypes.c_char): return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) for o in options])  # noqa: E501
+@functools.lru_cache(maxsize=None)
+def init_c_struct_t(fields: Tuple[Tuple[str, ctypes._SimpleCData], ...]):
+  class CStruct(ctypes.Structure):
+    _pack_, _fields_ = 1, fields
+  return CStruct
+def init_c_var(ctypes_var, creat_cb): return (creat_cb(ctypes_var), ctypes_var)[1]
+def flat_mv(mv:memoryview): return mv if len(mv) == 0 else mv.cast("B", shape=(mv.nbytes,))
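Usage note: the ctypes helpers convert between pointers, memoryviews, and packed structs: to_mv wraps a pointer/size as a byte memoryview, from_mv returns a ctypes array backed by a writable memoryview, flat_mv recasts a memoryview as flat unsigned bytes, and init_c_struct_t builds a packed ctypes.Structure type from (name, ctype) pairs. A self-contained sketch; the Header struct and its field values are made up for illustration, and the byte-order check assumes a little-endian host:

```python
import ctypes
from tinygrad.helpers import to_mv, from_mv, flat_mv, init_c_struct_t

# a packed struct type built from (name, ctype) pairs; the type is lru_cached per field tuple
Header = init_c_struct_t((("magic", ctypes.c_uint32), ("size", ctypes.c_uint16)))
h = Header(magic=0xDEADBEEF, size=16)
assert ctypes.sizeof(Header) == 6        # _pack_ = 1, so no padding between the fields

# view the struct's memory through its address, then round-trip through from_mv
mv = to_mv(ctypes.addressof(h), ctypes.sizeof(h))
assert mv[4] | (mv[5] << 8) == 16        # read the size field (little-endian hosts)
arr = from_mv(mv)                        # ctypes c_char array sharing the same memory
print(bytes(arr))

# flat_mv recasts any memoryview to a flat unsigned-byte view
print(bytes(flat_mv(memoryview(bytearray(b"abcd")))))
```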