tinygrad 0.7.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +6 -0
 - tinygrad/codegen/kernel.py +572 -83
 - tinygrad/codegen/linearizer.py +415 -395
 - tinygrad/codegen/uops.py +415 -0
 - tinygrad/device.py +183 -0
 - tinygrad/dtype.py +113 -0
 - tinygrad/engine/__init__.py +0 -0
 - tinygrad/engine/graph.py +100 -0
 - tinygrad/engine/jit.py +195 -0
 - tinygrad/engine/realize.py +191 -0
 - tinygrad/engine/schedule.py +362 -0
 - tinygrad/engine/search.py +196 -0
 - tinygrad/{mlops.py → function.py} +76 -55
 - tinygrad/helpers.py +196 -89
 - tinygrad/lazy.py +210 -371
 - tinygrad/multi.py +169 -0
 - tinygrad/nn/__init__.py +202 -22
 - tinygrad/nn/datasets.py +7 -0
 - tinygrad/nn/optim.py +112 -32
 - tinygrad/nn/state.py +136 -39
 - tinygrad/ops.py +119 -202
 - tinygrad/renderer/__init__.py +61 -0
 - tinygrad/renderer/assembly.py +276 -0
 - tinygrad/renderer/cstyle.py +353 -166
 - tinygrad/renderer/llvmir.py +150 -138
 - tinygrad/runtime/autogen/amd_gpu.py +1900 -0
 - tinygrad/runtime/autogen/comgr.py +865 -0
 - tinygrad/runtime/autogen/cuda.py +5923 -0
 - tinygrad/runtime/autogen/hip.py +5909 -0
 - tinygrad/runtime/autogen/hsa.py +5761 -0
 - tinygrad/runtime/autogen/kfd.py +812 -0
 - tinygrad/runtime/autogen/nv_gpu.py +33328 -0
 - tinygrad/runtime/autogen/opencl.py +1795 -0
 - tinygrad/runtime/driver/hip_comgr.py +47 -0
 - tinygrad/runtime/driver/hsa.py +143 -0
 - tinygrad/runtime/graph/clang.py +38 -0
 - tinygrad/runtime/graph/cuda.py +81 -0
 - tinygrad/runtime/graph/hcq.py +143 -0
 - tinygrad/runtime/graph/hsa.py +171 -0
 - tinygrad/runtime/graph/metal.py +75 -0
 - tinygrad/runtime/ops_amd.py +564 -0
 - tinygrad/runtime/ops_clang.py +24 -77
 - tinygrad/runtime/ops_cuda.py +175 -89
 - tinygrad/runtime/ops_disk.py +56 -33
 - tinygrad/runtime/ops_gpu.py +92 -95
 - tinygrad/runtime/ops_hsa.py +278 -0
 - tinygrad/runtime/ops_llvm.py +39 -60
 - tinygrad/runtime/ops_metal.py +92 -74
 - tinygrad/runtime/ops_npy.py +9 -0
 - tinygrad/runtime/ops_nv.py +630 -0
 - tinygrad/runtime/ops_python.py +204 -0
 - tinygrad/shape/shapetracker.py +86 -254
 - tinygrad/shape/symbolic.py +166 -141
 - tinygrad/shape/view.py +296 -0
 - tinygrad/tensor.py +2619 -448
 - {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
 - tinygrad-0.9.0.dist-info/METADATA +227 -0
 - tinygrad-0.9.0.dist-info/RECORD +60 -0
 - {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
 - tinygrad/codegen/assembly.py +0 -190
 - tinygrad/codegen/optimizer.py +0 -379
 - tinygrad/codegen/search.py +0 -72
 - tinygrad/graph.py +0 -83
 - tinygrad/jit.py +0 -57
 - tinygrad/nn/image.py +0 -100
 - tinygrad/renderer/assembly_arm64.py +0 -169
 - tinygrad/renderer/assembly_ptx.py +0 -98
 - tinygrad/renderer/wgsl.py +0 -53
 - tinygrad/runtime/lib.py +0 -113
 - tinygrad/runtime/ops_cpu.py +0 -51
 - tinygrad/runtime/ops_hip.py +0 -82
 - tinygrad/runtime/ops_shm.py +0 -29
 - tinygrad/runtime/ops_torch.py +0 -30
 - tinygrad/runtime/ops_webgpu.py +0 -45
 - tinygrad-0.7.0.dist-info/METADATA +0 -212
 - tinygrad-0.7.0.dist-info/RECORD +0 -40
 - {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
 
    
        tinygrad/codegen/linearizer.py
    CHANGED
    
    | 
         @@ -1,440 +1,460 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            from  
     | 
| 
       2 
     | 
    
         
            -
            import  
     | 
| 
      
 1 
     | 
    
         
            +
            from __future__ import annotations
         
     | 
| 
      
 2 
     | 
    
         
            +
            from typing import List, Tuple, Any, Optional, cast, DefaultDict, Dict, Union, Final, Iterator, Sequence
         
     | 
| 
      
 3 
     | 
    
         
            +
            import itertools, math, functools
         
     | 
| 
       3 
4 
     | 
    
         
             
            from collections import defaultdict
         
     | 
| 
       4 
     | 
    
         
            -
            from enum import Enum, auto
         
     | 
| 
       5 
5 
     | 
    
         | 
| 
       6 
     | 
    
         
            -
            from tinygrad. 
     | 
| 
       7 
     | 
    
         
            -
            from tinygrad. 
     | 
| 
       8 
     | 
    
         
            -
            from tinygrad. 
     | 
| 
       9 
     | 
    
         
            -
            from tinygrad.ops import ReduceOps, BinaryOps, TernaryOps
         
     | 
| 
       10 
     | 
    
         
            -
            from tinygrad.runtime.lib import RawConst
         
     | 
| 
      
 6 
     | 
    
         
            +
            from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
         
     | 
| 
      
 7 
     | 
    
         
            +
            from tinygrad.helpers import colored, DEBUG, prod, getenv, to_function_name
         
     | 
| 
      
 8 
     | 
    
         
            +
            from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, get_lazyop_info
         
     | 
| 
       11 
9 
     | 
    
         
             
            from tinygrad.shape.shapetracker import ShapeTracker
         
     | 
| 
       12 
     | 
    
         
            -
            from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode,  
     | 
| 
       13 
     | 
    
         
            -
            from tinygrad.codegen. 
     | 
| 
       14 
     | 
    
         
            -
            from tinygrad. 
     | 
| 
       15 
     | 
    
         
            -
             
     | 
| 
       16 
     | 
    
         
            -
             
     | 
| 
       17 
     | 
    
         
            -
             
     | 
| 
       18 
     | 
    
         
            -
             
     | 
| 
       19 
     | 
    
         
            -
               
     | 
| 
       20 
     | 
    
         
            -
               
     | 
| 
       21 
     | 
    
         
            -
             
     | 
| 
       22 
     | 
    
         
            -
             
     | 
| 
       23 
     | 
    
         
            -
             
     | 
| 
       24 
     | 
    
         
            -
             
     | 
| 
       25 
     | 
    
         
            -
             
     | 
| 
       26 
     | 
    
         
            -
             
     | 
| 
       27 
     | 
    
         
            -
             
     | 
| 
       28 
     | 
    
         
            -
             
     | 
| 
       29 
     | 
    
         
            -
             
     | 
| 
       30 
     | 
    
         
            -
             
     | 
| 
       31 
     | 
    
         
            -
             
     | 
| 
       32 
     | 
    
         
            -
             
     | 
| 
       33 
     | 
    
         
            -
             
     | 
| 
       34 
     | 
    
         
            -
             
     | 
| 
       35 
     | 
    
         
            -
             
     | 
| 
       36 
     | 
    
         
            -
             
     | 
| 
       37 
     | 
    
         
            -
             
     | 
| 
       38 
     | 
    
         
            -
             
     | 
| 
       39 
     | 
    
         
            -
             
     | 
| 
       40 
     | 
    
         
            -
             
     | 
| 
       41 
     | 
    
         
            -
             
     | 
| 
       42 
     | 
    
         
            -
             
     | 
| 
       43 
     | 
    
         
            -
             
     | 
| 
       44 
     | 
    
         
            -
             
     | 
| 
       45 
     | 
    
         
            -
             
     | 
| 
       46 
     | 
    
         
            -
             
     | 
| 
       47 
     | 
    
         
            -
             
     | 
| 
       48 
     | 
    
         
            -
             
     | 
| 
       49 
     | 
    
         
            -
             
     | 
| 
       50 
     | 
    
         
            -
               
     | 
| 
       51 
     | 
    
         
            -
             
     | 
| 
       52 
     | 
    
         
            -
             
     | 
| 
       53 
     | 
    
         
            -
             
     | 
| 
       54 
     | 
    
         
            -
             
     | 
| 
       55 
     | 
    
         
            -
             
     | 
| 
       56 
     | 
    
         
            -
             
     | 
| 
       57 
     | 
    
         
            -
                return  
     | 
| 
       58 
     | 
    
         
            -
             
     | 
| 
       59 
     | 
    
         
            -
             
     | 
| 
       60 
     | 
    
         
            -
             
     | 
| 
       61 
     | 
    
         
            -
             
     | 
| 
       62 
     | 
    
         
            -
               
     | 
| 
       63 
     | 
    
         
            -
               
     | 
| 
       64 
     | 
    
         
            -
             
     | 
| 
       65 
     | 
    
         
            -
             
     | 
| 
       66 
     | 
    
         
            -
             
     | 
| 
       67 
     | 
    
         
            -
             
     | 
| 
       68 
     | 
    
         
            -
             
     | 
| 
       69 
     | 
    
         
            -
             
     | 
| 
       70 
     | 
    
         
            -
             
     | 
| 
       71 
     | 
    
         
            -
             
     | 
| 
       72 
     | 
    
         
            -
             
     | 
| 
       73 
     | 
    
         
            -
             
     | 
| 
       74 
     | 
    
         
            -
             
     | 
| 
       75 
     | 
    
         
            -
             
     | 
| 
       76 
     | 
    
         
            -
                 
     | 
| 
       77 
     | 
    
         
            -
             
     | 
| 
       78 
     | 
    
         
            -
             
     | 
| 
       79 
     | 
    
         
            -
             
     | 
| 
       80 
     | 
    
         
            -
             
     | 
| 
       81 
     | 
    
         
            -
             
     | 
| 
       82 
     | 
    
         
            -
             
     | 
| 
       83 
     | 
    
         
            -
                 
     | 
| 
       84 
     | 
    
         
            -
             
     | 
| 
       85 
     | 
    
         
            -
             
     | 
| 
       86 
     | 
    
         
            -
             
     | 
| 
       87 
     | 
    
         
            -
             
     | 
| 
       88 
     | 
    
         
            -
             
     | 
| 
       89 
     | 
    
         
            -
             
     | 
| 
       90 
     | 
    
         
            -
             
     | 
| 
       91 
     | 
    
         
            -
                 
     | 
| 
       92 
     | 
    
         
            -
             
     | 
| 
       93 
     | 
    
         
            -
                 
     | 
| 
       94 
     | 
    
         
            -
             
     | 
| 
       95 
     | 
    
         
            -
                  if any(x is None for x in nv): break
         
     | 
| 
       96 
     | 
    
         
            -
                  new_idxs.append(idxs[i:i+4])
         
     | 
| 
       97 
     | 
    
         
            -
                  new_values.append(nv)
         
     | 
| 
       98 
     | 
    
         
            -
                if len(new_values) == len(idxs)//4:
         
     | 
| 
       99 
     | 
    
         
            -
                  return zip(new_idxs, new_values)
         
     | 
| 
       100 
     | 
    
         
            -
              return zip([[i] for i in range(len(values[0]))], zip(*values))
         
     | 
| 
       101 
     | 
    
         
            -
             
     | 
| 
       102 
     | 
    
         
            -
            # TODO: generic visitor pattern?
         
     | 
| 
       103 
     | 
    
         
            -
            def expand_node(idx:Node) -> List[Node]:
         
     | 
| 
       104 
     | 
    
         
            -
              if isinstance(idx, Variable): return [idx] if idx.expr is not None else [Variable.num(j) for j in range(idx.min, idx.max+1)]
         
     | 
| 
       105 
     | 
    
         
            -
              if isinstance(idx, NumNode): return [idx]
         
     | 
| 
       106 
     | 
    
         
            -
              if isinstance(idx, MulNode): return [x*idx.b for x in expand_node(idx.a)]
         
     | 
| 
       107 
     | 
    
         
            -
              if isinstance(idx, SumNode): return [Variable.sum(list(it)) for it in itertools.product(*[expand_node(x) for x in idx.nodes])]
         
     | 
| 
       108 
     | 
    
         
            -
              raise NotImplementedError(idx)
         
     | 
| 
       109 
     | 
    
         
            -
             
     | 
| 
       110 
     | 
    
         
            -
            def expand_idxs(idxs:Sequence[Node]) -> Iterator[Tuple[Node, ...]]:
         
     | 
| 
       111 
     | 
    
         
            -
              for x in itertools.product(*[expand_node(idx) for idx in idxs[::-1]]):
         
     | 
| 
       112 
     | 
    
         
            -
                yield x[::-1]
         
     | 
| 
       113 
     | 
    
         
            -
             
     | 
| 
       114 
     | 
    
         
            -
            class MemOp(NamedTuple):
         
     | 
| 
       115 
     | 
    
         
            -
              name: str
         
     | 
| 
       116 
     | 
    
         
            -
              idx: Node
         
     | 
| 
       117 
     | 
    
         
            -
              local: bool
         
     | 
| 
       118 
     | 
    
         
            -
              memory_dtype: DType
         
     | 
| 
       119 
     | 
    
         
            -
             
     | 
| 
       120 
     | 
    
         
            -
              # shared
         
     | 
| 
       121 
     | 
    
         
            -
              valid: Node
         
     | 
| 
       122 
     | 
    
         
            -
              invalid_value: Union[float, int] = 0.0
         
     | 
| 
       123 
     | 
    
         
            -
             
     | 
| 
       124 
     | 
    
         
            -
            class ConstOp(NamedTuple):
         
     | 
| 
       125 
     | 
    
         
            -
              value: Union[float, int]
         
     | 
| 
       126 
     | 
    
         
            -
             
     | 
| 
       127 
     | 
    
         
            -
              # shared
         
     | 
| 
       128 
     | 
    
         
            -
              valid: Node
         
     | 
| 
       129 
     | 
    
         
            -
              invalid_value: Union[float, int] = 0.0
         
     | 
| 
       130 
     | 
    
         
            -
             
     | 
| 
       131 
     | 
    
         
            -
            class UOp(NamedTuple):
         
     | 
| 
       132 
     | 
    
         
            -
              uop: UOps
         
     | 
| 
       133 
     | 
    
         
            -
              out: Optional[Token]
         
     | 
| 
       134 
     | 
    
         
            -
              vin: List[Token]
         
     | 
| 
       135 
     | 
    
         
            -
              arg: Any
         
     | 
| 
       136 
     | 
    
         
            -
              def __repr__(self): return f"{str(self.uop):20s}: {str(self.out) if self.out is not None else '':25s} {str(self.vin):32s} {self.arg}"
         
     | 
| 
       137 
     | 
    
         
            -
             
     | 
| 
       138 
     | 
    
         
            -
            class Linearizer(OptimizedKernel):
         
     | 
| 
       139 
     | 
    
         
            -
              def get_buffer_name(self, i):
         
     | 
| 
       140 
     | 
    
         
            -
                if self.bufs[i].__class__ == LocalBuffer: return self.bufs[i].name
         
     | 
| 
       141 
     | 
    
         
            -
                assert self.bufs[i].realized.__class__ is not RawConst  # constants shouldn't be loaded with memops
         
     | 
| 
       142 
     | 
    
         
            -
                return self.arg_bufs[self.bufs[i].realized]
         
     | 
| 
       143 
     | 
    
         
            -
             
     | 
| 
       144 
     | 
    
         
            -
              def global_load(self, i:int, idxs:Sequence[VariableOrNum], acc=None) -> List[Token]:
         
     | 
| 
       145 
     | 
    
         
            -
                const = self.bufs[i].realized._buf if isinstance(self.bufs[i].realized, RawConst) else acc
         
     | 
| 
       146 
     | 
    
         
            -
             
     | 
| 
       147 
     | 
    
         
            -
                expanded_nodes = [expand_node(idx) for idx in idxs]
         
     | 
| 
       148 
     | 
    
         
            -
                _idxs = [x[::-1] for x in itertools.product(*expanded_nodes[::-1])]
         
     | 
| 
       149 
     | 
    
         
            -
                upcast_dim = self.get_upcast_dim(i)
         
     | 
| 
       150 
     | 
    
         
            -
             
     | 
| 
       151 
     | 
    
         
            -
                amt = 1
         
     | 
| 
       152 
     | 
    
         
            -
                if len(upcast_dim) == 1 and len(expanded_nodes[upcast_dim[0]]) in [4,2]:
         
     | 
| 
       153 
     | 
    
         
            -
                  dim, amt = upcast_dim[0], len(expanded_nodes[upcast_dim[0]])
         
     | 
| 
      
 10 
     | 
    
         
            +
            from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode, create_lt_node
         
     | 
| 
      
 11 
     | 
    
         
            +
            from tinygrad.codegen.kernel import LocalBuffer, Kernel
         
     | 
| 
      
 12 
     | 
    
         
            +
            from tinygrad.renderer import Program
         
     | 
| 
      
 13 
     | 
    
         
            +
             
     | 
| 
      
 14 
     | 
    
         
            +
            from tinygrad.codegen.uops import UOps, UOp, UOpGraph
         
     | 
| 
      
 15 
     | 
    
         
            +
             
     | 
| 
      
 16 
     | 
    
         
            +
            def get_grouped_dims(prefix, start_dim, local_dims, maxdim:int=0):
         
     | 
| 
      
 17 
     | 
    
         
            +
              local_idxs = loop_local_idxs = [Variable(f"{prefix}{start_dim+i}", 0, s-1) for i,s in enumerate((prod(local_dims[:-(maxdim-1)]),) + local_dims[-(maxdim-1):] if len(local_dims) > maxdim else local_dims)]  # noqa: E501
         
     | 
| 
      
 18 
     | 
    
         
            +
              if maxdim != 0 and len(local_dims) > maxdim:
         
     | 
| 
      
 19 
     | 
    
         
            +
                dd = local_idxs[0]
         
     | 
| 
      
 20 
     | 
    
         
            +
                nli = []
         
     | 
| 
      
 21 
     | 
    
         
            +
                for s in local_dims[:-(maxdim-1)]:
         
     | 
| 
      
 22 
     | 
    
         
            +
                  nli.append(dd % s)
         
     | 
| 
      
 23 
     | 
    
         
            +
                  dd //= s
         
     | 
| 
      
 24 
     | 
    
         
            +
                local_idxs = nli + local_idxs[-(maxdim-1):]
         
     | 
| 
      
 25 
     | 
    
         
            +
              return local_idxs, [x for x in loop_local_idxs if not isinstance(x, NumNode)]
         
     | 
| 
      
 26 
     | 
    
         
            +
             
     | 
| 
      
 27 
     | 
    
         
            +
            def expand_idx(node:Node) -> Union[Variable, NumNode]: return next((v for v in node.vars() if v.expr.startswith("_uidx")), NumNode(0))
         
     | 
| 
      
 28 
     | 
    
         
            +
            def expand_idxs(nodes:Sequence[Node]) -> Tuple[Union[Variable, NumNode], ...]:
         
     | 
| 
      
 29 
     | 
    
         
            +
              eidxs = [expand_idx(node) for node in nodes]
         
     | 
| 
      
 30 
     | 
    
         
            +
              return tuple([v if v not in eidxs[:j] else NumNode(0) for j, v in enumerate(eidxs)])  # take only first occurrence of expand variable
         
     | 
| 
      
 31 
     | 
    
         
            +
            def iter_idxs(idxs:Tuple[Union[Variable, NumNode], ...]) -> Iterator[Tuple[int,...]]:
         
     | 
| 
      
 32 
     | 
    
         
            +
              yield from (x[::-1] for x in itertools.product(*[[x for x in range(v.min, v.max + 1)] for v in idxs[::-1]]))
         
     | 
| 
      
 33 
     | 
    
         
            +
             
     | 
| 
      
 34 
     | 
    
         
            +
            def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node) -> Tuple[Tuple[Node, Node], Node]:
         
     | 
| 
      
 35 
     | 
    
         
            +
              idx, idy = (idxy // 4) % base_shape[1], (idxy // (4 * base_shape[1]))
         
     | 
| 
      
 36 
     | 
    
         
            +
              # TODO: bring back the valid removal logic (correct!)
         
     | 
| 
      
 37 
     | 
    
         
            +
              if DEBUG>=5: print("to_image_idx", base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy, valid)
         
     | 
| 
      
 38 
     | 
    
         
            +
              return (idx, idy), valid
         
     | 
| 
      
 39 
     | 
    
         
            +
             
     | 
| 
      
 40 
     | 
    
         
            +
            # expand a Node into List[Node] that enumerates the underlying Variables from min to max
         
     | 
| 
      
 41 
     | 
    
         
            +
            # expand increments earlier variables faster than later variables (as specified in the argument)
         
     | 
| 
      
 42 
     | 
    
         
            +
            @functools.lru_cache(maxsize=None)
         
     | 
| 
      
 43 
     | 
    
         
            +
            def expand_node(node:Node, idxs:Optional[Tuple[Union[Variable, NumNode], ...]]=None) -> List[Node]:
         
     | 
| 
      
 44 
     | 
    
         
            +
              if idxs is None: idxs = (expand_idx(node),)
         
     | 
| 
      
 45 
     | 
    
         
            +
              return [node.substitute({k:v for k,v in zip(idxs, (NumNode(x) for x in rep)) if isinstance(k, Variable)}) for rep in iter_idxs(idxs)]
         
     | 
| 
      
 46 
     | 
    
         
            +
             
     | 
| 
      
 47 
     | 
    
         
            +
            class Linearizer(Kernel):
         
     | 
| 
      
 48 
     | 
    
         
            +
              def uop_alu_idx(self, a:UOp, b, ops, ctx:Linearizer, op): return UOp.alu(op, a, (NumNode(b) if not isinstance(b, Node) else b).render(ops, ctx))
         
     | 
| 
      
 49 
     | 
    
         
            +
             
     | 
| 
      
 50 
     | 
    
         
            +
              # NOTE: the consts have to be cached for deduping of downstream uops to work
         
     | 
| 
      
 51 
     | 
    
         
            +
              def const(self, b:ConstType, dtype:DType=dtypes.int32) -> UOp:
         
     | 
| 
      
 52 
     | 
    
         
            +
                return self.uops.add(UOps.DEFINE_VAR, dtype, (), b.unbind()[0]) if isinstance(b, Variable) else UOp.const(dtype, b)
         
     | 
| 
      
 53 
     | 
    
         
            +
             
     | 
| 
      
 54 
     | 
    
         
            +
              def get_reduce_acc(self, reduceop:LazyOp):
         
     | 
| 
      
 55 
     | 
    
         
            +
                if reduceop.op is ReduceOps.SUM: return 0.0 if dtypes.is_float(reduceop.dtype) else 0
         
     | 
| 
      
 56 
     | 
    
         
            +
                if reduceop.op is ReduceOps.MAX:
         
     | 
| 
      
 57 
     | 
    
         
            +
                  if dtypes.is_int(reduceop.dtype): return 0 if dtypes.is_unsigned(reduceop.dtype) else -2**(reduceop.dtype.itemsize*8-1)
         
     | 
| 
      
 58 
     | 
    
         
            +
                  return -math.inf if dtypes.is_float(reduceop.dtype) else False
         
     | 
| 
      
 59 
     | 
    
         
            +
             
     | 
| 
      
 60 
     | 
    
         
            +
              # NOTE: once images are loaded, we uop them as their base float
         
     | 
| 
      
 61 
     | 
    
         
            +
              def get_base_dtype(self, dt:DType) -> DType: return dt.base if isinstance(dt, ImageDType) else dt
         
     | 
| 
      
 62 
     | 
    
         
            +
             
     | 
| 
      
 63 
     | 
    
         
            +
              render_ops: Any = { Variable: lambda self, ops, ctx: ctx.loop_uops[self.expr], NumNode: lambda self, ops, ctx: ctx.const(self.b),
         
     | 
| 
      
 64 
     | 
    
         
            +
                            MulNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MUL),
         
     | 
| 
      
 65 
     | 
    
         
            +
                            DivNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.DIV),
         
     | 
| 
      
 66 
     | 
    
         
            +
                            ModNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MOD),
         
     | 
| 
      
 67 
     | 
    
         
            +
                            LtNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.CMPLT),
         
     | 
| 
      
 68 
     | 
    
         
            +
                SumNode: lambda self,ops,ctx:
         
     | 
| 
      
 69 
     | 
    
         
            +
                  functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.ADD), self.nodes[1:], self.nodes[0].render(ops,ctx)),
         
     | 
| 
      
 70 
     | 
    
         
            +
                AndNode: lambda self,ops,ctx:
         
     | 
| 
      
 71 
     | 
    
         
            +
                  functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.MUL), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
         
     | 
| 
      
 72 
     | 
    
         
            +
             
     | 
| 
      
 73 
     | 
    
         
            +
              def global_load(self, i:int, idxs:List[Node], acc:Optional[LazyOp]=None, barrier:Optional[UOp]=None, loop_ctx:Tuple[UOp, ...]=()) -> List[UOp]:
         
     | 
| 
      
 74 
     | 
    
         
            +
                buf = self.bufs[i]
         
     | 
| 
      
 75 
     | 
    
         
            +
                localtype = self.get_base_dtype(buf.dtype if acc is None else acc.dtype)
         
     | 
| 
      
 76 
     | 
    
         
            +
                const = buf.val if isinstance(buf, ConstBuffer) else None
         
     | 
| 
      
 77 
     | 
    
         
            +
             
     | 
| 
      
 78 
     | 
    
         
            +
                expand_vars = expand_idxs(idxs)
         
     | 
| 
      
 79 
     | 
    
         
            +
             
     | 
| 
      
 80 
     | 
    
         
            +
                dim, amt = None, 1
         
     | 
| 
      
 81 
     | 
    
         
            +
                # float 4 grouping
         
     | 
| 
      
 82 
     | 
    
         
            +
                if len(upcast_dim := self.get_float4_upcast_dim(i)) == 1 and len(float4_expand := expand_node(idxs[upcast_dim[0]])) in [4,2]:
         
     | 
| 
      
 83 
     | 
    
         
            +
                  dim, amt = upcast_dim[0], len(float4_expand)
         
     | 
| 
      
 84 
     | 
    
         
            +
                  g_idx, g_valid = self.sts[i].expr_idxs(idxs[:dim] + [float4_expand[0]] + idxs[dim+1:])
         
     | 
| 
      
 85 
     | 
    
         
            +
                  # do not use float4 if idx is not aligned
         
     | 
| 
      
 86 
     | 
    
         
            +
                  if g_idx != (g_idx//amt*amt): dim, amt = None, 1
         
     | 
| 
      
 87 
     | 
    
         
            +
                if dim is None:
         
     | 
| 
      
 88 
     | 
    
         
            +
                  g_idx, g_valid = self.sts[i].expr_idxs(idxs)
         
     | 
| 
      
 89 
     | 
    
         
            +
                # todo: multioutput test with different output valids to add if acc is None: g_valid = NumNode(1)
         
     | 
| 
      
 90 
     | 
    
         
            +
             
     | 
| 
      
 91 
     | 
    
         
            +
                if amt > 1: localtype = localtype.vec(amt)
         
     | 
| 
      
 92 
     | 
    
         
            +
                e_idxs, e_valids = expand_node(g_idx, expand_vars), expand_node(g_valid, expand_vars)
         
     | 
| 
       154 
93 
     | 
    
         | 
| 
       155 
94 
     | 
    
         
             
                ret = []
         
     | 
| 
       156 
     | 
    
         
            -
                invalid_value = 0 
     | 
| 
       157 
     | 
    
         
            -
                 
     | 
| 
       158 
     | 
    
         
            -
             
     | 
| 
       159 
     | 
    
         
            -
             
     | 
| 
       160 
     | 
    
         
            -
             
     | 
| 
       161 
     | 
    
         
            -
             
     | 
| 
       162 
     | 
    
         
            -
                      idx, valid = self.sts[i].expr_idxs(_idx)
         
     | 
| 
       163 
     | 
    
         
            -
                      localtype = dtypes.float32
         
     | 
| 
       164 
     | 
    
         
            -
                  else:
         
     | 
| 
       165 
     | 
    
         
            -
                    idx, valid = self.sts[i].expr_idxs(_idx)
         
     | 
| 
       166 
     | 
    
         
            -
                    localtype = dtypes.float32
         
     | 
| 
       167 
     | 
    
         
            -
                  this_const, idx, valid = (invalid_value, Variable.num(0), Variable.num(1)) if valid.max == 0 else (const, idx, valid)
         
     | 
| 
       168 
     | 
    
         
            -
                  key = f"{acc}{localtype}{this_const if this_const is not None and acc is None else self.get_buffer_name(i)}{idx.render()}{valid.render()}"
         
     | 
| 
      
 95 
     | 
    
         
            +
                invalid_value = 0
         
     | 
| 
      
 96 
     | 
    
         
            +
                acc_count = 0
         
     | 
| 
      
 97 
     | 
    
         
            +
                for idx, valid, rep_idx in zip(e_idxs, e_valids, iter_idxs(expand_vars)):
         
     | 
| 
      
 98 
     | 
    
         
            +
                  this_const, idx, valid = (invalid_value, NumNode(0), NumNode(1)) if valid.max == 0 else (const, idx, valid)
         
     | 
| 
      
 99 
     | 
    
         
            +
                  # todo: when multiple reduceops are supported, clearly disambiguate and test acc load keys are unique for each reduceop
         
     | 
| 
      
 100 
     | 
    
         
            +
                  key = f"{acc is not None}{localtype}{'CONST'+str(this_const) if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}"  # noqa: E501
         
     | 
| 
       169 
101 
     | 
    
         
             
                  if key not in self.load_cache:
         
     | 
| 
       170 
     | 
    
         
            -
                    if  
     | 
| 
       171 
     | 
    
         
            -
             
     | 
| 
       172 
     | 
    
         
            -
             
     | 
| 
       173 
     | 
    
         
            -
             
     | 
| 
      
 102 
     | 
    
         
            +
                    if acc is not None:
         
     | 
| 
      
 103 
     | 
    
         
            +
                      self.load_cache[key] = self.uops.add(UOps.DEFINE_ACC, localtype, loop_ctx, (self.get_reduce_acc(acc), i, acc_count))
         
     | 
| 
      
 104 
     | 
    
         
            +
                      acc_count += 1
         
     | 
| 
      
 105 
     | 
    
         
            +
                    elif this_const is not None:
         
     | 
| 
      
 106 
     | 
    
         
            +
                      self.load_cache[key] = self.const(this_const, localtype)
         
     | 
| 
      
 107 
     | 
    
         
            +
                      if valid.min == 0 and valid.max == 1:
         
     | 
| 
      
 108 
     | 
    
         
            +
                        valid_rendered = valid.render(self.render_ops, self)
         
     | 
| 
      
 109 
     | 
    
         
            +
                        self.load_cache[key] = UOp.alu(TernaryOps.WHERE, valid_rendered, self.load_cache[key], self.const(invalid_value, localtype))
         
     | 
| 
      
 110 
     | 
    
         
            +
                    elif isinstance(buf.dtype, ImageDType):
         
     | 
| 
      
 111 
     | 
    
         
            +
                      buf_uop = self.buf_uops[i]
         
     | 
| 
      
 112 
     | 
    
         
            +
                      assert buf_uop is not None, f"buffer {i} wasn't UOped"
         
     | 
| 
      
 113 
     | 
    
         
            +
                      image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
         
     | 
| 
      
 114 
     | 
    
         
            +
                      rendered_idx = self.uops.add(UOps.CAST, dtypes.int.vec(2), tuple(x.render(self.render_ops, self) for x in image_idx))
         
     | 
| 
      
 115 
     | 
    
         
            +
                      valid_tuple = (valid.render(self.render_ops, self), self.const(invalid_value, buf.dtype.base.vec(4))) if valid.min == 0 else tuple()
         
     | 
| 
      
 116 
     | 
    
         
            +
                      self.load_cache[key] = self.uops.add(UOps.LOAD, buf.dtype.base.vec(4),
         
     | 
| 
      
 117 
     | 
    
         
            +
                                                           (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
         
     | 
| 
      
 118 
     | 
    
         
            +
                      if localtype == localtype.scalar():
         
     | 
| 
      
 119 
     | 
    
         
            +
                        idx_small = idx%4
         
     | 
| 
      
 120 
     | 
    
         
            +
                        res = idx_small.render(self.render_ops, self)
         
     | 
| 
      
 121 
     | 
    
         
            +
                        out = self.uops.add(UOps.GEP, localtype, (self.load_cache[key],), idx_small.max)
         
     | 
| 
      
 122 
     | 
    
         
            +
                        for ix in range(idx_small.max, idx_small.min, -1):
         
     | 
| 
      
 123 
     | 
    
         
            +
                          rvv = self.uops.add(UOps.GEP, localtype, (self.load_cache[key],), ix-1)
         
     | 
| 
      
 124 
     | 
    
         
            +
                          sel = UOp.alu(BinaryOps.CMPLT, res, self.const(ix))
         
     | 
| 
      
 125 
     | 
    
         
            +
                          out = UOp.alu(TernaryOps.WHERE, sel, rvv, out)
         
     | 
| 
      
 126 
     | 
    
         
            +
                        self.load_cache[key] = out
         
     | 
| 
      
 127 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 128 
     | 
    
         
            +
                      buf_uop = self.buf_uops[i]
         
     | 
| 
      
 129 
     | 
    
         
            +
                      assert buf_uop is not None, f"buffer {i} wasn't UOped"
         
     | 
| 
      
 130 
     | 
    
         
            +
                      rendered_idx = idx.render(self.render_ops, self)
         
     | 
| 
      
 131 
     | 
    
         
            +
                      valid_tuple = (valid.render(self.render_ops, self), self.const(invalid_value, localtype)) if valid.min == 0 else tuple()
         
     | 
| 
      
 132 
     | 
    
         
            +
                      self.load_cache[key] = self.uops.add(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
         
     | 
| 
      
 133 
     | 
    
         
            +
                  ret.append(self.uops.add(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
         
     | 
| 
       174 
134 
     | 
    
         
             
                return ret
         
     | 
| 
       175 
135 
     | 
    
         | 
| 
       176 
     | 
    
         
            -
              def global_store(self, i, idxs:List[ 
     | 
| 
       177 
     | 
    
         
            -
                 
     | 
| 
       178 
     | 
    
         
            -
                 
     | 
| 
       179 
     | 
    
         
            -
                 
     | 
| 
      
 136 
     | 
    
         
            +
              def global_store(self, i:int, idxs:List[Node], store:List[UOp]) -> List[UOp]:
         
     | 
| 
      
 137 
     | 
    
         
            +
                buf = self.bufs[i]
         
     | 
| 
      
 138 
     | 
    
         
            +
                buf_uop = self.buf_uops[i]
         
     | 
| 
      
 139 
     | 
    
         
            +
                assert buf_uop is not None, f"buffer {i} wasn't UOped"
         
     | 
| 
       180 
140 
     | 
    
         | 
| 
      
 141 
     | 
    
         
            +
                expand_vars = expand_idxs(idxs)
         
     | 
| 
      
 142 
     | 
    
         
            +
                _idxs = zip(*[expand_node(idx, expand_vars) for idx in idxs]) if idxs else [tuple()]  # transpose
         
     | 
| 
       181 
143 
     | 
    
         
             
                store_offset = dict(zip(_idxs, store))
         
     | 
| 
       182 
144 
     | 
    
         | 
| 
       183 
145 
     | 
    
         
             
                # float4 grouping
         
     | 
| 
       184 
     | 
    
         
            -
                if len(upcast_dim) == 1 and len( 
     | 
| 
      
 146 
     | 
    
         
            +
                if len(upcast_dim := self.get_float4_upcast_dim(i)) == 1 and len(float4_expand := expand_node(idxs[upcast_dim[0]])) in [2,4]:
         
     | 
| 
       185 
147 
     | 
    
         
             
                  grouped_store_offset = defaultdict(list)
         
     | 
| 
       186 
148 
     | 
    
         
             
                  for k in store_offset:
         
     | 
| 
       187 
     | 
    
         
            -
                    _idx = k[:upcast_dim[0]] + ( 
     | 
| 
      
 149 
     | 
    
         
            +
                    _idx = k[:upcast_dim[0]] + (float4_expand[0],) + k[upcast_dim[0]+1:]
         
     | 
| 
       188 
150 
     | 
    
         
             
                    grouped_store_offset[_idx].append(store_offset[k])
         
     | 
| 
       189 
151 
     | 
    
         
             
                  store_offset_new = {}
         
     | 
| 
       190 
     | 
    
         
            -
                  for k, 
     | 
| 
       191 
     | 
    
         
            -
                    amt = len( 
     | 
| 
      
 152 
     | 
    
         
            +
                  for k,grouped in grouped_store_offset.items():
         
     | 
| 
      
 153 
     | 
    
         
            +
                    amt = len(grouped)
         
     | 
| 
       192 
154 
     | 
    
         
             
                    idx, valid = self.sts[i].expr_idxs(k)
         
     | 
| 
       193 
     | 
    
         
            -
                    assert idx 
     | 
| 
       194 
     | 
    
         
            -
                     
     | 
| 
       195 
     | 
    
         
            -
                    if all_same([x.name for x in out_tokens]) and tuple(range(amt)) == tuple(x.offset for x in out_tokens):
         
     | 
| 
       196 
     | 
    
         
            -
                      store_offset_new[k] = Token(out_tokens[0].name, dtypes._float4 if amt == 4 else dtypes._float2)
         
     | 
| 
       197 
     | 
    
         
            -
                    else:
         
     | 
| 
       198 
     | 
    
         
            -
                      store_offset_new[k] = self.uop(UOps.CAST, ssa("alu", dtypes._float4 if amt == 4 else dtypes._float2), out_tokens)
         
     | 
| 
      
 155 
     | 
    
         
            +
                    assert idx == ((idx//amt)*amt), "float4 stores are always aligned"
         
     | 
| 
      
 156 
     | 
    
         
            +
                    store_offset_new[k] = self.uops.add(UOps.CAST, buf.dtype.vec(amt), tuple(grouped))
         
     | 
| 
       199 
157 
     | 
    
         
             
                  store_offset = store_offset_new
         
     | 
| 
       200 
158 
     | 
    
         | 
| 
       201 
     | 
    
         
            -
                 
     | 
| 
       202 
     | 
    
         
            -
             
     | 
| 
       203 
     | 
    
         
            -
                   
     | 
| 
       204 
     | 
    
         
            -
                   
     | 
| 
      
 159 
     | 
    
         
            +
                stores = []
         
     | 
| 
      
 160 
     | 
    
         
            +
                for _idx, var in store_offset.items():
         
     | 
| 
      
 161 
     | 
    
         
            +
                  idx, valid = self.sts[i].expr_idxs(_idx)
         
     | 
| 
      
 162 
     | 
    
         
            +
                  if isinstance(buf.dtype, ImageDType):
         
     | 
| 
      
 163 
     | 
    
         
            +
                    image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
         
     | 
| 
      
 164 
     | 
    
         
            +
                    rendered_idx = self.uops.add(UOps.CAST, dtypes.int.vec(2), \
         
     | 
| 
      
 165 
     | 
    
         
            +
                                  tuple(x.render(self.render_ops, self) for x in image_idx))
         
     | 
| 
      
 166 
     | 
    
         
            +
                  else:
         
     | 
| 
      
 167 
     | 
    
         
            +
                    rendered_idx = idx.render(self.render_ops, self)
         
     | 
| 
      
 168 
     | 
    
         
            +
                  if valid.min == 1: stores.append(self.uops.add(UOps.STORE, None, (buf_uop, rendered_idx, var)))
         
     | 
| 
      
 169 
     | 
    
         
            +
                  else: stores.append(self.uops.add(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(self.render_ops, self))))
         
     | 
| 
      
 170 
     | 
    
         
            +
                return stores
         
     | 
| 
      
 171 
     | 
    
         
            +
             
     | 
| 
      
 172 
     | 
    
         
            +
              # render loop
         
     | 
| 
      
 173 
     | 
    
         
            +
              def render_loop(self, xx:List[Variable], depth:int) -> Tuple[UOp, ...]:
         
     | 
| 
      
 174 
     | 
    
         
            +
                new_loops = {x.expr:self.uops.add(UOps.RANGE, dtypes.int32, (
         
     | 
| 
      
 175 
     | 
    
         
            +
                  self.const(x.min) if isinstance(x.min, int) else cast(Node, x.min).render(self.render_ops, self),
         
     | 
| 
      
 176 
     | 
    
         
            +
                  self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), arg=(depth,i)) for i,x in enumerate(xx) if not isinstance(x, NumNode) and x.expr is not None}  # noqa: E501
         
     | 
| 
      
 177 
     | 
    
         
            +
                self.loop_uops.update(new_loops)
         
     | 
| 
      
 178 
     | 
    
         
            +
                return tuple(new_loops.values())
         
     | 
| 
      
 179 
     | 
    
         
            +
             
     | 
| 
      
 180 
     | 
    
         
            +
              def render_reduceop(self, reduceop:LazyOp, accs:Dict[LazyOp, List[UOp]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]],
         
     | 
| 
      
 181 
     | 
    
         
            +
                                  global_idxs, local_idxs, upcast_idxs):
         
     | 
| 
      
 182 
     | 
    
         
            +
                # define indicies
         
     | 
| 
      
 183 
     | 
    
         
            +
                full_upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.full_shape[self.shape_len-self.upcasted:])]
         
     | 
| 
      
 184 
     | 
    
         
            +
                reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)]  # noqa: E501
         
     | 
| 
      
 185 
     | 
    
         
            +
                fake_reduce_idxs = [x*0 for x in reduce_idxs]
         
     | 
| 
      
 186 
     | 
    
         
            +
             
     | 
| 
      
 187 
     | 
    
         
            +
                def calc_tc_idxs(local_sizes: List[int], aliases: List[List[int]]):
         
     | 
| 
      
 188 
     | 
    
         
            +
                  replace_idxs, thread_idxs, thread_idx = [], [], Variable("_uidx_tc", 0, prod(local_sizes)-1)
         
     | 
| 
      
 189 
     | 
    
         
            +
                  for s in local_sizes:
         
     | 
| 
      
 190 
     | 
    
         
            +
                    thread_idxs.append(thread_idx % s)
         
     | 
| 
      
 191 
     | 
    
         
            +
                    thread_idx //= s
         
     | 
| 
      
 192 
     | 
    
         
            +
                  for alias in aliases:
         
     | 
| 
      
 193 
     | 
    
         
            +
                    full_var, full_var_sz = NumNode(0), 1
         
     | 
| 
      
 194 
     | 
    
         
            +
                    if alias[0] != 0:
         
     | 
| 
      
 195 
     | 
    
         
            +
                      for i in alias:
         
     | 
| 
      
 196 
     | 
    
         
            +
                        next_var = local_idxs[i-1] if i > 0 else thread_idxs[-i-1]
         
     | 
| 
      
 197 
     | 
    
         
            +
                        full_var += next_var * full_var_sz
         
     | 
| 
      
 198 
     | 
    
         
            +
                        full_var_sz *= next_var.max+1
         
     | 
| 
      
 199 
     | 
    
         
            +
                    replace_idxs.append(full_var)
         
     | 
| 
      
 200 
     | 
    
         
            +
                  return replace_idxs
         
     | 
| 
      
 201 
     | 
    
         
            +
             
     | 
| 
      
 202 
     | 
    
         
            +
                # compute local aliases - modify idxs if necessary for TC
         
     | 
| 
      
 203 
     | 
    
         
            +
                alias_buf_idxs = []
         
     | 
| 
      
 204 
     | 
    
         
            +
                for i in self.local_alias:
         
     | 
| 
      
 205 
     | 
    
         
            +
                  localbuf_idx = self.bufs.index(self.local_alias[i])
         
     | 
| 
      
 206 
     | 
    
         
            +
                  buf_idxs = [idx*0 if s == 0 else idx for idx,s in zip(global_idxs+local_idxs+reduce_idxs+full_upcast_idxs,self.sts[i].real_strides())]
         
     | 
| 
      
 207 
     | 
    
         
            +
                  if (tc:=self.tensor_core):
         
     | 
| 
      
 208 
     | 
    
         
            +
                    min_alias_idx = min(self.local_alias.keys())
         
     | 
| 
      
 209 
     | 
    
         
            +
                    replace_input_idxs = calc_tc_idxs(tc.thread_local_sizes[i-min_alias_idx], tc.thread_local_aliases[i-min_alias_idx])
         
     | 
| 
      
 210 
     | 
    
         
            +
                    for n in range(len(tc.threads)):
         
     | 
| 
      
 211 
     | 
    
         
            +
                      buf_idxs[self.global_dims+n] = replace_input_idxs[n] # replace locals
         
     | 
| 
      
 212 
     | 
    
         
            +
                    for n in range(tc.num_upcasts()):
         
     | 
| 
      
 213 
     | 
    
         
            +
                      buf_idxs[self.shape_len-self.upcasted+n] = replace_input_idxs[len(tc.threads)+n] # replace upcasts
         
     | 
| 
      
 214 
     | 
    
         
            +
                  if DEBUG >= 3: print(f"{localbuf_idx} alias {i}: sts={self.sts[i]} idxs={buf_idxs}")
         
     | 
| 
      
 215 
     | 
    
         
            +
                  alias_buf_idxs.append((i, localbuf_idx, buf_idxs,))
         
     | 
| 
      
 216 
     | 
    
         
            +
             
     | 
| 
      
 217 
     | 
    
         
            +
                # reduce loop
         
     | 
| 
      
 218 
     | 
    
         
            +
                loop_ctx = self.render_loop(reduce_idxs, 2)
         
     | 
| 
      
 219 
     | 
    
         
            +
             
     | 
| 
      
 220 
     | 
    
         
            +
                # define accumulator - modify idxs if necessary for TC
         
     | 
| 
      
 221 
     | 
    
         
            +
                out_buf = -1 if self.group_for_reduces else 0
         
     | 
| 
      
 222 
     | 
    
         
            +
                if (tc:=self.tensor_core):
         
     | 
| 
      
 223 
     | 
    
         
            +
                  replace_acc_idxs = calc_tc_idxs(tc.thread_local_sizes[2], tc.thread_local_aliases[2])
         
     | 
| 
      
 224 
     | 
    
         
            +
                  for n in range(len(tc.threads)):
         
     | 
| 
      
 225 
     | 
    
         
            +
                    local_idxs[n] = replace_acc_idxs[n] # replace locals
         
     | 
| 
      
 226 
     | 
    
         
            +
                  for n in range(len(replace_acc_idxs)-len(tc.threads)):
         
     | 
| 
      
 227 
     | 
    
         
            +
                    upcast_idxs[n] = replace_acc_idxs[len(tc.threads)+n] # replace upcasts
         
     | 
| 
      
 228 
     | 
    
         
            +
                  if DEBUG >= 3: print(f"store alias: sts={self.sts[0]} idxs={global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs}")
         
     | 
| 
      
 229 
     | 
    
         
            +
                accs[reduceop] = self.global_load(out_buf, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc=reduceop, loop_ctx=loop_ctx)
         
     | 
| 
      
 230 
     | 
    
         
            +
             
     | 
| 
      
 231 
     | 
    
         
            +
                # store local aliases
         
     | 
| 
      
 232 
     | 
    
         
            +
                locals_to_store = [(localbuf_idx, buf_idxs, self.global_load(i, buf_idxs)) for i, localbuf_idx, buf_idxs in alias_buf_idxs]
         
     | 
| 
      
 233 
     | 
    
         
            +
             
     | 
| 
      
 234 
     | 
    
         
            +
                if (tc:=self.tensor_core):
         
     | 
| 
      
 235 
     | 
    
         
            +
                  # run tensor cores AST
         
     | 
| 
      
 236 
     | 
    
         
            +
                  wmma_sz = [prod(l) for l in tc.thread_local_sizes]
         
     | 
| 
      
 237 
     | 
    
         
            +
                  def upcast_strides(buf:int):
         
     | 
| 
      
 238 
     | 
    
         
            +
                    strides, next = [], 1
         
     | 
| 
      
 239 
     | 
    
         
            +
                    for (sz, stride, reduce) in self.upcasted_axis(buf)[tc.num_upcasts():]:
         
     | 
| 
      
 240 
     | 
    
         
            +
                      strides.append((0 if stride == 0 else next, sz))
         
     | 
| 
      
 241 
     | 
    
         
            +
                      next *= 1 if stride == 0 else sz
         
     | 
| 
      
 242 
     | 
    
         
            +
                    return strides
         
     | 
| 
      
 243 
     | 
    
         
            +
                  upcasts, dev = [upcast_strides(x) for x in [locals_to_store[0][0], locals_to_store[1][0], 0]], self.opts.device
         
     | 
| 
      
 244 
     | 
    
         
            +
                  # cast initial accs
         
     | 
| 
      
 245 
     | 
    
         
            +
                  wmmas = [self.uops.add(UOps.CAST, (dt3:=tc.dtype_out.vec(wmma_sz[2])), tuple(accs[reduceop][x:x+wmma_sz[2]]))
         
     | 
| 
      
 246 
     | 
    
         
            +
                           for x in range(0, len(accs[reduceop]), wmma_sz[2])]
         
     | 
| 
      
 247 
     | 
    
         
            +
                  for iter in [x[::-1] for x in itertools.product(*[x for x in [range(sz) for _,sz in upcasts[0]][::-1]])]:
         
     | 
| 
      
 248 
     | 
    
         
            +
                    offs = [x*y for (x,y) in zip([sum([prod(x) for x in zip(iter, [stride for stride,_ in y])]) for y in upcasts], wmma_sz)]
         
     | 
| 
      
 249 
     | 
    
         
            +
                    ops = (self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][offs[0]:offs[0]+wmma_sz[0]])),
         
     | 
| 
      
 250 
     | 
    
         
            +
                            self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][offs[1]:offs[1]+wmma_sz[1]])),
         
     | 
| 
      
 251 
     | 
    
         
            +
                            wmmas[(wmma_idx:=offs[2]//wmma_sz[2])])
         
     | 
| 
      
 252 
     | 
    
         
            +
                    # TODO: don't need to DEFINE_ACC, pass to WMMA in op3, or PHI accs that are not valid
         
     | 
| 
      
 253 
     | 
    
         
            +
                    wmmas[wmma_idx] = self.uops.add(UOps.WMMA, dt3, ops, (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, tuple(wmma_sz), dev))
         
     | 
| 
      
 254 
     | 
    
         
            +
                  # phi the last wmmas back to accs
         
     | 
| 
      
 255 
     | 
    
         
            +
                  accs[reduceop] = [self.uops.add(UOps.PHI, tc.dtype_out, (acc, self.uops.add(UOps.GEP, tc.dtype_out, (wmmas[z//wmma_sz[2]],), z%wmma_sz[2])))
         
     | 
| 
      
 256 
     | 
    
         
            +
                                    for z, acc in enumerate(accs[reduceop])]
         
     | 
| 
      
 257 
     | 
    
         
            +
                else:
         
     | 
| 
      
 258 
     | 
    
         
            +
                  assert not locals_to_store, "storing locals isn't supported here"
         
     | 
| 
      
 259 
     | 
    
         
            +
             
     | 
| 
      
 260 
     | 
    
         
            +
                  # load earlybufs
         
     | 
| 
      
 261 
     | 
    
         
            +
                  loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i,
         
     | 
| 
      
 262 
     | 
    
         
            +
                    global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs) if b in self.earlybufs})
         
     | 
| 
      
 263 
     | 
    
         
            +
             
     | 
| 
      
 264 
     | 
    
         
            +
                  # run early AST (with reduce)
         
     | 
| 
      
 265 
     | 
    
         
            +
                  self.ast_parse(reduceop, accs, self.acc_offsets(self.full_buf_index), loaded_buffers, reduce_acc=accs[reduceop])
         
     | 
| 
      
 266 
     | 
    
         
            +
             
     | 
| 
      
 267 
     | 
    
         
            +
                # end the reduce loop
         
     | 
| 
      
 268 
     | 
    
         
            +
                self.load_cache.clear()
         
     | 
| 
      
 269 
     | 
    
         
            +
             
     | 
| 
      
 270 
     | 
    
         
            +
                # end the local loop, do the local reduce
         
     | 
| 
      
 271 
     | 
    
         
            +
                if self.group_for_reduces:
         
     | 
| 
      
 272 
     | 
    
         
            +
                  fake_global_idxs = [x*0 for x in global_idxs]
         
     | 
| 
      
 273 
     | 
    
         
            +
                  stores = self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, accs[reduceop])  # store accumulators
         
     | 
| 
      
 274 
     | 
    
         
            +
                  barrier = self.uops.add(UOps.BARRIER, None, tuple(stores))
         
     | 
| 
      
 275 
     | 
    
         
            +
                  if self.opts.has_local:
         
     | 
| 
      
 276 
     | 
    
         
            +
                    fake_idxs = [NumNode(0)]*len(self.sts[-1].shape)
         
     | 
| 
      
 277 
     | 
    
         
            +
                    fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:]
         
     | 
| 
      
 278 
     | 
    
         
            +
                    if_cond: UOp = create_lt_node(self.sts[-1].expr_idxs(fake_idxs)[0], 1).render(self.render_ops, self)
         
     | 
| 
      
 279 
     | 
    
         
            +
                    barrier = self.uops.add(UOps.IF, None, (if_cond, barrier))
         
     | 
| 
      
 280 
     | 
    
         
            +
             
     | 
| 
      
 281 
     | 
    
         
            +
                  # create new late reduce local loops and replace local_idxs that have been used
         
     | 
| 
      
 282 
     | 
    
         
            +
                  end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+self.group_for_reduces)]  # noqa: E501
         
     | 
| 
      
 283 
     | 
    
         
            +
                  local_idxs = local_idxs[:self.local_dims] + end_local_idxs[self.global_dims + self.local_dims:]
         
     | 
| 
      
 284 
     | 
    
         
            +
             
     | 
| 
      
 285 
     | 
    
         
            +
                  # if any group_for_reduce items aren't reduces, upcast them here
         
     | 
| 
      
 286 
     | 
    
         
            +
                  for j in self.upcast_in_mid_reduce_axes:
         
     | 
| 
      
 287 
     | 
    
         
            +
                    self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != j] + [j])
         
     | 
| 
      
 288 
     | 
    
         
            +
                    self.upcast()
         
     | 
| 
      
 289 
     | 
    
         
            +
                    self.group_for_reduces -= 1
         
     | 
| 
      
 290 
     | 
    
         
            +
                    local_idxs = local_idxs[:-1]
         
     | 
| 
      
 291 
     | 
    
         
            +
                    end_local_idxs = end_local_idxs[:-1]
         
     | 
| 
      
 292 
     | 
    
         
            +
                    # regenerate upcast_idxs
         
     | 
| 
      
 293 
     | 
    
         
            +
                    upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.output_shape[self.shape_len-self.upcasted:])]
         
     | 
| 
      
 294 
     | 
    
         
            +
             
     | 
| 
      
 295 
     | 
    
         
            +
                  # NOTE: this structure is the same as the reduce op above
         
     | 
| 
      
 296 
     | 
    
         
            +
             
     | 
| 
      
 297 
     | 
    
         
            +
                  # late reduce loop
         
     | 
| 
      
 298 
     | 
    
         
            +
                  loop_ctx = self.render_loop(end_local_idxs, 3)
         
     | 
| 
      
 299 
     | 
    
         
            +
             
     | 
| 
      
 300 
     | 
    
         
            +
                  # define late accumulator
         
     | 
| 
      
 301 
     | 
    
         
            +
                  accs[reduceop] = self.global_load(0, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc=reduceop, loop_ctx=loop_ctx)
         
     | 
| 
      
 302 
     | 
    
         
            +
             
     | 
| 
      
 303 
     | 
    
         
            +
                  # load localbufs
         
     | 
| 
      
 304 
     | 
    
         
            +
                  loaded_buffers[self.bufs[-1]] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier)
         
     | 
| 
      
 305 
     | 
    
         
            +
             
     | 
| 
      
 306 
     | 
    
         
            +
                  # there's no AST here (and there's no shape for the reduce LazyOp)
         
     | 
| 
      
 307 
     | 
    
         
            +
                  self.ast_parse(LazyOp(reduceop.op, (LazyOp(BufferOps.LOAD, (), self.bufs[-1]),)),\
         
     | 
| 
      
 308 
     | 
    
         
            +
                                 accs, self.acc_offsets(-1), loaded_buffers, reduce_acc=accs[reduceop])
         
     | 
| 
      
 309 
     | 
    
         
            +
             
     | 
| 
      
 310 
     | 
    
         
            +
                  # end the late reduce loop
         
     | 
| 
      
 311 
     | 
    
         
            +
                  self.load_cache.clear()
         
     | 
| 
      
 312 
     | 
    
         
            +
             
     | 
| 
      
 313 
     | 
    
         
            +
                  # all local indices which were used for group_for_reduce are not valid any more and should be replaced with fake NumNode(0), since they have
         
     | 
| 
      
 314 
     | 
    
         
            +
                  # been rewritten with fake end_local_idxs.
         
     | 
| 
      
 315 
     | 
    
         
            +
                return (accs, loaded_buffers, fake_reduce_idxs, local_idxs[:self.local_dims] + [NumNode(0) for i in range(self.group_for_reduces)], upcast_idxs)
         
     | 
| 
       205 
316 
     | 
    
         | 
| 
       206 
317 
     | 
    
         
             
              kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
         
     | 
| 
       207 
318 
     | 
    
         
             
              def linearize(self):
         
     | 
| 
       208 
     | 
    
         
            -
                 
     | 
| 
      
 319 
     | 
    
         
            +
                # no new opts and we already ran? skip relinearizing
         
     | 
| 
      
 320 
     | 
    
         
            +
                if self.applied_opts == self.applied_opts_cache: return self
         
     | 
| 
      
 321 
     | 
    
         
            +
             
     | 
| 
      
 322 
     | 
    
         
            +
                # late alias the tensor core buffers
         
     | 
| 
      
 323 
     | 
    
         
            +
                if (tc:=self.tensor_core) and (tc_opts:=self.tensor_core_opts):
         
     | 
| 
      
 324 
     | 
    
         
            +
                  alias_pattern = [0]*(self.global_dims) + [2]*(len(tc.threads)) + [0]*(self.local_dims-len(tc.threads)) + [0]*(self.shape_len-self.upcasted-self.first_reduce) + [1,1] + [3]*(self.upcasted-2)  # noqa: E501
         
     | 
| 
      
 325 
     | 
    
         
            +
                  for tc_buf in tc_opts.bufs:
         
     | 
| 
      
 326 
     | 
    
         
            +
                    self.alias_buffer(tc_buf, alias_pattern)
         
     | 
| 
      
 327 
     | 
    
         
            +
             
     | 
| 
      
 328 
     | 
    
         
            +
                # save backups
         
     | 
| 
      
 329 
     | 
    
         
            +
                sts_backup, gfr_backup, upc_backup = self.sts[:], self.group_for_reduces, self.upcasted
         
     | 
| 
      
 330 
     | 
    
         
            +
             
     | 
| 
      
 331 
     | 
    
         
            +
                # global uop cache
         
     | 
| 
      
 332 
     | 
    
         
            +
                self.saved_exprs: Dict[Tuple, UOp] = dict()
         
     | 
| 
       209 
333 
     | 
    
         | 
| 
       210 
334 
     | 
    
         
             
                # limit dims if we need to
         
     | 
| 
       211 
     | 
    
         
            -
                if self.opts.global_max and self.opts.local_max: self. 
     | 
| 
      
 335 
     | 
    
         
            +
                if self.opts.global_max and self.opts.local_max: self.limit_dims_to_max(self.opts.global_max, self.opts.local_max)
         
     | 
| 
       212 
336 
     | 
    
         | 
| 
       213 
337 
     | 
    
         
             
                # uops
         
     | 
| 
       214 
     | 
    
         
            -
                self.uops:  
     | 
| 
       215 
     | 
    
         
            -
                self. 
     | 
| 
       216 
     | 
    
         
            -
                self. 
     | 
| 
      
 338 
     | 
    
         
            +
                self.uops:UOpGraph = UOpGraph()
         
     | 
| 
      
 339 
     | 
    
         
            +
                self.buf_uops: List[Optional[UOp]] = [None]*len(self.bufs)
         
     | 
| 
      
 340 
     | 
    
         
            +
                self.loop_uops: Dict[str, UOp] = {}
         
     | 
| 
       217 
341 
     | 
    
         | 
| 
       218 
342 
     | 
    
         
             
                # add global buffers
         
     | 
| 
       219 
     | 
    
         
            -
                for buf 
     | 
| 
       220 
     | 
    
         
            -
                   
     | 
| 
       221 
     | 
    
         
            -
             
     | 
| 
       222 
     | 
    
         
            -
             
     | 
| 
       223 
     | 
    
         
            -
             
     | 
| 
       224 
     | 
    
         
            -
             
     | 
| 
       225 
     | 
    
         
            -
                 
     | 
| 
       226 
     | 
    
         
            -
             
     | 
| 
       227 
     | 
    
         
            -
                   
     | 
| 
       228 
     | 
    
         
            -
                  self.sts.append(ShapeTracker(tuple([1] * self.first_reduce + self.group_for_reduce + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)])))
         
     | 
| 
       229 
     | 
    
         
            -
                  self.bufs.append(LocalBuffer("temp", self.sts[-1].size()))
         
     | 
| 
       230 
     | 
    
         
            -
                  self.uop(UOps.DEFINE_LOCAL, None, [], ("temp", self.sts[-1].size()))
         
     | 
| 
       231 
     | 
    
         
            -
             
     | 
| 
      
 343 
     | 
    
         
            +
                for i,buf in enumerate(self.bufs):
         
     | 
| 
      
 344 
     | 
    
         
            +
                  if isinstance(buf, MemBuffer):
         
     | 
| 
      
 345 
     | 
    
         
            +
                    self.buf_uops[i] = self.uops.add(UOps.DEFINE_GLOBAL,
         
     | 
| 
      
 346 
     | 
    
         
            +
                                                     buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (),
         
     | 
| 
      
 347 
     | 
    
         
            +
                                                     (buf.idx, any(buf.idx == x.idx for x in self.outbufs)))
         
     | 
| 
      
 348 
     | 
    
         
            +
                # add var vals
         
     | 
| 
      
 349 
     | 
    
         
            +
                for i,var in enumerate(self.vars):
         
     | 
| 
      
 350 
     | 
    
         
            +
                  assert var.expr is not None
         
     | 
| 
      
 351 
     | 
    
         
            +
                  self.loop_uops[var.expr] = self.uops.add(UOps.DEFINE_VAR, dtypes.int32, (), var)
         
     | 
| 
       232 
352 
     | 
    
         
             
                # define local buffers
         
     | 
| 
       233 
353 
     | 
    
         
             
                for lb in self.local_alias.values():
         
     | 
| 
       234 
     | 
    
         
            -
                  self. 
     | 
| 
       235 
     | 
    
         
            -
             
     | 
| 
       236 
     | 
    
         
            -
                #  
     | 
| 
       237 
     | 
    
         
            -
                if  
     | 
| 
      
 354 
     | 
    
         
            +
                  self.buf_uops[self.bufs.index(lb)] = self.uops.add(UOps.DEFINE_LOCAL,
         
     | 
| 
      
 355 
     | 
    
         
            +
                                                                     PtrDType(dtypes.float32), (), (lb.name, self.sts[self.bufs.index(lb)].size))
         
     | 
| 
      
 356 
     | 
    
         
            +
                # add a local buffer for multistage reduce. # TODO: use local alias
         
     | 
| 
      
 357 
     | 
    
         
            +
                if self.group_for_reduces:
         
     | 
| 
      
 358 
     | 
    
         
            +
                  # TODO: the strides of this can be controlled
         
     | 
| 
      
 359 
     | 
    
         
            +
                  self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+self.group_for_reduces]) + [1] * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)])))  # noqa: E501
         
     | 
| 
      
 360 
     | 
    
         
            +
                  temp_dtype = self.get_base_dtype(cast(LazyOp, self.reduceop).dtype)
         
     | 
| 
      
 361 
     | 
    
         
            +
                  self.bufs.append(LocalBuffer("temp", self.sts[-1].size, temp_dtype))
         
     | 
| 
      
 362 
     | 
    
         
            +
                  self.buf_uops.append(self.uops.add(UOps.DEFINE_LOCAL, PtrDType(temp_dtype), (), ("temp", self.sts[-1].size)))
         
     | 
| 
       238 
363 
     | 
    
         | 
| 
       239 
364 
     | 
    
         
             
                # kernel name (before late upcast)
         
     | 
| 
       240 
     | 
    
         
            -
                self. 
     | 
| 
       241 
     | 
    
         
            -
             
     | 
| 
       242 
     | 
    
         
            -
             
     | 
| 
       243 
     | 
    
         
            -
                # parse AST
         
     | 
| 
       244 
     | 
    
         
            -
                loaded_buffers = {}
         
     | 
| 
       245 
     | 
    
         
            -
                acc = []
         
     | 
| 
       246 
     | 
    
         
            -
             
     | 
| 
       247 
     | 
    
         
            -
                # ssa
         
     | 
| 
       248 
     | 
    
         
            -
                _ssa:DefaultDict[str,int] = defaultdict(int)
         
     | 
| 
       249 
     | 
    
         
            -
                def ssa(name, ltype=dtypes.float) -> Token:
         
     | 
| 
       250 
     | 
    
         
            -
                  _ssa[name] += 1
         
     | 
| 
       251 
     | 
    
         
            -
                  return Token(f"{name}{_ssa[name]-1}", ltype)
         
     | 
| 
       252 
     | 
    
         
            -
             
     | 
| 
       253 
     | 
    
         
            -
                # global loop
         
     | 
| 
       254 
     | 
    
         
            -
                global_idxs = [Variable(f"gidx{i}", 0, self.full_shape[i]-1) for i in range(0, self.first_reduce-self.local_dims)]
         
     | 
| 
       255 
     | 
    
         
            -
                self.uop(UOps.LOOP, None, [], (global_idxs, "global"))
         
     | 
| 
      
 365 
     | 
    
         
            +
                self.name = ("r" if self.reduceop else ("C" if all(x.op in BufferOps for x in self.lazyops) else "E")) + \
         
     | 
| 
      
 366 
     | 
    
         
            +
                             (f"{len(self.outbufs)}_" if len(self.outbufs) > 1 else "_") + \
         
     | 
| 
      
 367 
     | 
    
         
            +
                             colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
         
     | 
| 
       256 
368 
     | 
    
         | 
| 
       257 
     | 
    
         
            -
                #  
     | 
| 
       258 
     | 
    
         
            -
                 
     | 
| 
       259 
     | 
    
         
            -
                 
     | 
| 
      
 369 
     | 
    
         
            +
                # name the function something unique
         
     | 
| 
      
 370 
     | 
    
         
            +
                Linearizer.kernel_cnt[(function_name := to_function_name(self.name))] += 1
         
     | 
| 
      
 371 
     | 
    
         
            +
                suffix = f"{'n'+str(Linearizer.kernel_cnt[function_name]-1)}" if Linearizer.kernel_cnt[function_name] > 1 else ""
         
     | 
| 
      
 372 
     | 
    
         
            +
                self.name = self.name+colored(suffix, 'BLACK')
         
     | 
| 
      
 373 
     | 
    
         
            +
             
     | 
| 
      
 374 
     | 
    
         
            +
                # define indexes
         
     | 
| 
      
 375 
     | 
    
         
            +
                global_idxs, loop_global_idxs = get_grouped_dims("gidx", 0, self.full_shape[:self.global_dims], 3 if self.opts.has_local else 0)
         
     | 
| 
      
 376 
     | 
    
         
            +
                local_idxs, loop_local_idxs = get_grouped_dims("lidx", self.global_dims, self.full_shape[self.global_dims:self.first_reduce+self.group_for_reduces], 3 if self.opts.has_local else 0)  # noqa: E501
         
     | 
| 
      
 377 
     | 
    
         
            +
                upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.output_shape[self.shape_len-self.upcasted:])]
         
     | 
| 
      
 378 
     | 
    
         
            +
             
     | 
| 
      
 379 
     | 
    
         
            +
                # set global/local size
         
     | 
| 
      
 380 
     | 
    
         
            +
                self.global_size: Optional[List[int]] = None
         
     | 
| 
      
 381 
     | 
    
         
            +
                self.local_size: Optional[List[int]] = None
         
     | 
| 
      
 382 
     | 
    
         
            +
                if self.dont_use_locals:
         
     | 
| 
      
 383 
     | 
    
         
            +
                  self.global_size = [x.max+1 for x in loop_global_idxs][::-1]
         
     | 
| 
      
 384 
     | 
    
         
            +
                  self.loop_uops.update({x.expr:self.uops.add(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)})  # noqa: E501
         
     | 
| 
      
 385 
     | 
    
         
            +
                elif self.opts.has_local:
         
     | 
| 
      
 386 
     | 
    
         
            +
                  self.global_size, self.local_size = [x.max+1 for x in loop_global_idxs][::-1], [x.max+1 for x in loop_local_idxs]
         
     | 
| 
      
 387 
     | 
    
         
            +
                  self.loop_uops.update({x.expr:self.uops.add(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)})  # noqa: E501
         
     | 
| 
      
 388 
     | 
    
         
            +
                  self.loop_uops.update({x.expr:self.uops.add(UOps.SPECIAL, dtypes.int32, (), (i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)})
         
     | 
| 
      
 389 
     | 
    
         
            +
                else:
         
     | 
| 
      
 390 
     | 
    
         
            +
                  self.render_loop(loop_global_idxs+loop_local_idxs, 1)
         
     | 
| 
      
 391 
     | 
    
         
            +
                if self.global_size is not None: self.global_size += [1]*(3-len(self.global_size))
         
     | 
| 
      
 392 
     | 
    
         
            +
                if self.local_size is not None: self.local_size += [1]*(3-len(self.local_size))
         
     | 
| 
       260 
393 
     | 
    
         | 
| 
       261 
     | 
    
         
            -
                #  
     | 
| 
       262 
     | 
    
         
            -
                 
     | 
| 
       263 
     | 
    
         
            -
                 
     | 
| 
      
 394 
     | 
    
         
            +
                # parse AST
         
     | 
| 
      
 395 
     | 
    
         
            +
                loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]] = {}
         
     | 
| 
      
 396 
     | 
    
         
            +
                accs: Dict[LazyOp, List[UOp]] = {}
         
     | 
| 
      
 397 
     | 
    
         
            +
                self.load_cache: Dict[str, UOp] = {}
         
     | 
| 
       264 
398 
     | 
    
         | 
| 
       265 
399 
     | 
    
         
             
                # reduce op
         
     | 
| 
       266 
     | 
    
         
            -
                fake_reduce_idxs = []
         
     | 
| 
       267 
     | 
    
         
            -
                if self.reduceop is not None:
         
     | 
| 
       268 
     | 
    
         
            -
                   
     | 
| 
       269 
     | 
    
         
            -
             
     | 
| 
       270 
     | 
    
         
            -
                  fake_reduce_idxs = [x*0 for x in reduce_idxs]
         
     | 
| 
       271 
     | 
    
         
            -
             
     | 
| 
       272 
     | 
    
         
            -
                  # define accumulator
         
     | 
| 
       273 
     | 
    
         
            -
                  acc = self.global_load(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
         
     | 
| 
       274 
     | 
    
         
            -
             
     | 
| 
       275 
     | 
    
         
            -
                  # reduce loop
         
     | 
| 
       276 
     | 
    
         
            -
                  self.uop(UOps.LOOP, None, [], (reduce_idxs, "reduce"))
         
     | 
| 
       277 
     | 
    
         
            -
             
     | 
| 
       278 
     | 
    
         
            -
                  # barrier for fast GEMM
         
     | 
| 
       279 
     | 
    
         
            -
                  if self.use_tensor_cores: self.uop(UOps.BARRIER, None, [], ())
         
     | 
| 
       280 
     | 
    
         
            -
             
     | 
| 
       281 
     | 
    
         
            -
                  # compute local aliases
         
     | 
| 
       282 
     | 
    
         
            -
                  locals_to_store = []
         
     | 
| 
       283 
     | 
    
         
            -
                  for i in self.local_alias:
         
     | 
| 
       284 
     | 
    
         
            -
                    strides = self.sts[i].real_strides()
         
     | 
| 
       285 
     | 
    
         
            -
                    extra_locals = [lidx for lidx,st in zip(local_idxs[self.exclude_local_upcast:], strides[len(global_idxs)+self.exclude_local_upcast:self.first_reduce]) if st == 0]
         
     | 
| 
       286 
     | 
    
         
            -
                    this_upcast_idxs: List[Node] = []
         
     | 
| 
       287 
     | 
    
         
            -
                    # TODO: just flipping the order here is likely not generic at all
         
     | 
| 
       288 
     | 
    
         
            -
                    for j,v in list(enumerate(full_upcast_idxs))[::-1] if self.reverse_upcast_dir else list(enumerate(full_upcast_idxs)):
         
     | 
| 
       289 
     | 
    
         
            -
                      if strides[len(global_idxs)+len(local_idxs)+len(reduce_idxs)+j] == 0:
         
     | 
| 
       290 
     | 
    
         
            -
                        if DEBUG >= 4: print(f"upcasting@{j} stride 0")
         
     | 
| 
       291 
     | 
    
         
            -
                        this_upcast_idxs.append(Variable.num(0))
         
     | 
| 
       292 
     | 
    
         
            -
                      elif (elc:=[el for el in extra_locals if v.min == el.min and v.max == el.max]):
         
     | 
| 
       293 
     | 
    
         
            -
                        if DEBUG >= 4: print(f"upcasting@{j} matched stride {elc[0]}")
         
     | 
| 
       294 
     | 
    
         
            -
                        this_upcast_idxs.append(elc[0])
         
     | 
| 
       295 
     | 
    
         
            -
                        extra_locals.remove(elc[0])
         
     | 
| 
       296 
     | 
    
         
            -
                      elif (elc:=[el for el in extra_locals if v.min == el.min and (v.max+1)%(el.max+1) == 0]):
         
     | 
| 
       297 
     | 
    
         
            -
                        tacc = Variable.num(0)
         
     | 
| 
       298 
     | 
    
         
            -
                        rem = v.max+1
         
     | 
| 
       299 
     | 
    
         
            -
                        while len(elc) and rem%(elc[0].max+1) == 0:
         
     | 
| 
       300 
     | 
    
         
            -
                          if DEBUG >= 4: print(f"upcasting@{j} partial stride {rem} {elc[0]} left: {elc[1:]}")
         
     | 
| 
       301 
     | 
    
         
            -
                          rem = rem//(elc[0].max+1)
         
     | 
| 
       302 
     | 
    
         
            -
                          tacc += (elc[0] * rem)
         
     | 
| 
       303 
     | 
    
         
            -
                          extra_locals.remove(elc[0])
         
     | 
| 
       304 
     | 
    
         
            -
                          elc = [el for el in extra_locals if v.min == el.min and rem%(el.max+1) == 0]
         
     | 
| 
       305 
     | 
    
         
            -
                        if DEBUG >= 4 and rem > 1: print(f"failed upcasting@{j} partial stride {rem} extra locals {extra_locals}")
         
     | 
| 
       306 
     | 
    
         
            -
                        this_upcast_idxs.append(tacc + Variable(None, 0, rem-1))
         
     | 
| 
       307 
     | 
    
         
            -
                      else:
         
     | 
| 
       308 
     | 
    
         
            -
                        if DEBUG >= 4: print(f"failed upcasting@{j} stride {v} extra locals {extra_locals}")
         
     | 
| 
       309 
     | 
    
         
            -
                        this_upcast_idxs.append(v)
         
     | 
| 
       310 
     | 
    
         
            -
                    idxs = global_idxs+local_idxs+reduce_idxs+(this_upcast_idxs[::-1] if self.reverse_upcast_dir else this_upcast_idxs)
         
     | 
| 
       311 
     | 
    
         
            -
                    ll = self.global_load(i, idxs)
         
     | 
| 
       312 
     | 
    
         
            -
                    locals_to_store.append((self.bufs.index(self.local_alias[i]), idxs, ll))
         
     | 
| 
       313 
     | 
    
         
            -
             
     | 
| 
       314 
     | 
    
         
            -
                  # copy in any global buffers
         
     | 
| 
       315 
     | 
    
         
            -
                  if self.use_tensor_cores:
         
     | 
| 
       316 
     | 
    
         
            -
                    if self.bufs[0].device == "METAL":
         
     | 
| 
       317 
     | 
    
         
            -
                      i = 0
         
     | 
| 
       318 
     | 
    
         
            -
                      for y0,y1 in zip(locals_to_store[1][2][::2], locals_to_store[1][2][1::2]):
         
     | 
| 
       319 
     | 
    
         
            -
                        for x0,x1 in zip(locals_to_store[0][2][::2], locals_to_store[0][2][1::2]):
         
     | 
| 
       320 
     | 
    
         
            -
                          self.uop(UOps.WMMA, None, [x0, x1, y0, y1, acc[i], acc[i+1]], "METAL")
         
     | 
| 
       321 
     | 
    
         
            -
                          i += 2
         
     | 
| 
       322 
     | 
    
         
            -
                    elif self.bufs[0].device == "HIP":
         
     | 
| 
       323 
     | 
    
         
            -
                      i = 0
         
     | 
| 
       324 
     | 
    
         
            -
                      for y in range(0, len(locals_to_store[1][2]), 0x10):
         
     | 
| 
       325 
     | 
    
         
            -
                        for x in range(0, len(locals_to_store[0][2]), 0x10):
         
     | 
| 
       326 
     | 
    
         
            -
                          self.uop(UOps.WMMA, None, acc[i:i+8]+locals_to_store[0][2][x:x+0x10]+locals_to_store[1][2][y:y+0x10], "HIP")
         
     | 
| 
       327 
     | 
    
         
            -
                          i += 8
         
     | 
| 
       328 
     | 
    
         
            -
                  else:
         
     | 
| 
       329 
     | 
    
         
            -
                    if locals_to_store:
         
     | 
| 
       330 
     | 
    
         
            -
                      self.uop(UOps.BARRIER, None, [], ())
         
     | 
| 
       331 
     | 
    
         
            -
                      for i, idxs, ll in locals_to_store: self.global_store(i, idxs, ll, ssa)
         
     | 
| 
       332 
     | 
    
         
            -
                      self.uop(UOps.BARRIER, None, [], ())
         
     | 
| 
       333 
     | 
    
         
            -
             
     | 
| 
       334 
     | 
    
         
            -
                    # load earlybufs
         
     | 
| 
       335 
     | 
    
         
            -
                    loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i, global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs[1:], start=1) if b in self.earlybufs})
         
     | 
| 
       336 
     | 
    
         
            -
             
     | 
| 
       337 
     | 
    
         
            -
                    # run early AST (with reduce)
         
     | 
| 
       338 
     | 
    
         
            -
                    self.ast_parse(self.reduceop, [acc[off] for off in self.acc_offsets(self.full_buf_index)], loaded_buffers, ssa, do_reduce=True)
         
     | 
| 
       339 
     | 
    
         
            -
             
     | 
| 
       340 
     | 
    
         
            -
                  # end the reduce loop
         
     | 
| 
       341 
     | 
    
         
            -
                  self.uop(UOps.ENDLOOP, None, [], (reduce_idxs, "reduce"))
         
     | 
| 
       342 
     | 
    
         
            -
                  self.load_cache.clear()
         
     | 
| 
       343 
     | 
    
         
            -
             
     | 
| 
       344 
     | 
    
         
            -
                  # end the local loop, do the local reduce
         
     | 
| 
       345 
     | 
    
         
            -
                  if self.group_for_reduce:
         
     | 
| 
       346 
     | 
    
         
            -
                    fake_global_idxs = [x*0 for x in global_idxs]
         
     | 
| 
       347 
     | 
    
         
            -
                    self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc, ssa)  # store accumulators
         
     | 
| 
       348 
     | 
    
         
            -
                    self.uop(UOps.BARRIER, None, [], ())
         
     | 
| 
       349 
     | 
    
         
            -
                    self.uop(UOps.ENDLOOP, None, [], (local_idxs, "local"))
         
     | 
| 
       350 
     | 
    
         
            -
             
     | 
| 
       351 
     | 
    
         
            -
                    # local indexs are over, 0 them out
         
     | 
| 
       352 
     | 
    
         
            -
                    local_idxs = [x*0 for x in local_idxs]
         
     | 
| 
       353 
     | 
    
         
            -
             
     | 
| 
       354 
     | 
    
         
            -
                    # if any group_for_reduce items aren't reduces, upcast them here
         
     | 
| 
       355 
     | 
    
         
            -
                    for j in self.upcast_in_mid_reduce_axes:
         
     | 
| 
       356 
     | 
    
         
            -
                      self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != j] + [j])
         
     | 
| 
       357 
     | 
    
         
            -
                      self.upcast()
         
     | 
| 
       358 
     | 
    
         
            -
                      self.group_for_reduce.pop()
         
     | 
| 
       359 
     | 
    
         
            -
                      local_idxs = local_idxs[:-1]
         
     | 
| 
       360 
     | 
    
         
            -
                      # regenerate upcast_idxs
         
     | 
| 
       361 
     | 
    
         
            -
                      upcast_idxs = [Variable(None, 0, s-1) for s in self.output_shape[self.shape_len-self.upcasted:]]
         
     | 
| 
       362 
     | 
    
         
            -
             
     | 
| 
       363 
     | 
    
         
            -
                    # NOTE: this structure is the same as the reduce op above
         
     | 
| 
       364 
     | 
    
         
            -
             
     | 
| 
       365 
     | 
    
         
            -
                    # define late accumulator
         
     | 
| 
       366 
     | 
    
         
            -
                    acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
         
     | 
| 
       367 
     | 
    
         
            -
             
     | 
| 
       368 
     | 
    
         
            -
                    # late reduce loop
         
     | 
| 
       369 
     | 
    
         
            -
                    end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))]
         
     | 
| 
       370 
     | 
    
         
            -
                    self.uop(UOps.LOOP, None, [], (end_local_idxs, "late_reduce"))
         
     | 
| 
       371 
     | 
    
         
            -
             
     | 
| 
       372 
     | 
    
         
            -
                    # load localbufs
         
     | 
| 
       373 
     | 
    
         
            -
                    loaded_buffers["LOCAL_BUFFER"] = self.global_load(-1, end_local_idxs+fake_reduce_idxs+upcast_idxs)
         
     | 
| 
       374 
     | 
    
         
            -
             
     | 
| 
       375 
     | 
    
         
            -
                    # there's no AST here (and there's no shape for the reduce LazyOp)
         
     | 
| 
       376 
     | 
    
         
            -
                    self.ast_parse(LazyOp(self.reduceop.op, ("LOCAL_BUFFER",)), [acc[off] for off in self.acc_offsets(-1)], loaded_buffers, ssa, do_reduce=True) # type: ignore
         
     | 
| 
       377 
     | 
    
         
            -
             
     | 
| 
       378 
     | 
    
         
            -
                    # end the late reduce loop
         
     | 
| 
       379 
     | 
    
         
            -
                    self.uop(UOps.ENDLOOP, None, [], (end_local_idxs, "late_reduce"))
         
     | 
| 
       380 
     | 
    
         
            -
                    self.load_cache.clear()
         
     | 
| 
      
 400 
     | 
    
         
            +
                fake_reduce_idxs: List[Variable] = []
         
     | 
| 
      
 401 
     | 
    
         
            +
                for reduceop in [self.reduceop] if self.reduceop is not None else []:
         
     | 
| 
      
 402 
     | 
    
         
            +
                  accs,loaded_buffers,fake_reduce_idxs,local_idxs,upcast_idxs = \
         
     | 
| 
      
 403 
     | 
    
         
            +
                    self.render_reduceop(reduceop,accs,loaded_buffers,global_idxs,local_idxs,upcast_idxs)
         
     | 
| 
       381 
404 
     | 
    
         | 
| 
       382 
405 
     | 
    
         
             
                # load latebufs
         
     | 
| 
       383 
     | 
    
         
            -
                loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs)  
     | 
| 
      
 406 
     | 
    
         
            +
                loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) \
         
     | 
| 
      
 407 
     | 
    
         
            +
                                       for i,b in enumerate(self.bufs) if b not in self.earlybufs and b.__class__ is not LocalBuffer})
         
     | 
| 
       384 
408 
     | 
    
         | 
| 
       385 
     | 
    
         
            -
                # run late AST
         
     | 
| 
       386 
     | 
    
         
            -
                 
     | 
| 
      
 409 
     | 
    
         
            +
                # run late AST (without the store)
         
     | 
| 
      
 410 
     | 
    
         
            +
                for op in self.ast:
         
     | 
| 
      
 411 
     | 
    
         
            +
                  val = self.ast_parse(op.src[0], accs, None, loaded_buffers)
         
     | 
| 
      
 412 
     | 
    
         
            +
                  self.global_store(op.arg.idx, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)
         
     | 
| 
       387 
413 
     | 
    
         | 
| 
       388 
     | 
    
         
            -
                #  
     | 
| 
       389 
     | 
    
         
            -
                self. 
     | 
| 
      
 414 
     | 
    
         
            +
                # maybe graph the uops
         
     | 
| 
      
 415 
     | 
    
         
            +
                if DEBUG >= 5: self.uops.print()
         
     | 
| 
      
 416 
     | 
    
         
            +
                if getenv("GRAPHUOPS"): self.uops.graph()
         
     | 
| 
       390 
417 
     | 
    
         | 
| 
       391 
     | 
    
         
            -
                 
     | 
| 
       392 
     | 
    
         
            -
             
     | 
| 
       393 
     | 
    
         
            -
                  self.uop(UOps.ENDLOOP, None, [], (global_idxs+local_idxs, "global+local"))
         
     | 
| 
       394 
     | 
    
         
            -
                else:
         
     | 
| 
       395 
     | 
    
         
            -
                  # end the global loop
         
     | 
| 
       396 
     | 
    
         
            -
                  self.uop(UOps.ENDLOOP, None, [], (global_idxs, "global"))
         
     | 
| 
      
 418 
     | 
    
         
            +
                # restore backups
         
     | 
| 
      
 419 
     | 
    
         
            +
                self.sts, self.group_for_reduces, self.upcasted = sts_backup, gfr_backup, upc_backup
         
     | 
| 
       397 
420 
     | 
    
         | 
| 
       398 
     | 
    
         
            -
                #  
     | 
| 
       399 
     | 
    
         
            -
                 
     | 
| 
       400 
     | 
    
         
            -
                suffix = f"{'n'+str(Linearizer.kernel_cnt[self.function_name]-1)}" if Linearizer.kernel_cnt[self.function_name] > 1 else ""
         
     | 
| 
       401 
     | 
    
         
            -
                self.function_name, self.display_name = self.function_name+suffix, self.display_name+colored(suffix, 'BLACK')
         
     | 
| 
      
 421 
     | 
    
         
            +
                # set cache and return
         
     | 
| 
      
 422 
     | 
    
         
            +
                self.applied_opts_cache = self.applied_opts[:]
         
     | 
| 
       402 
423 
     | 
    
         
             
                return self
         
     | 
| 
       403 
424 
     | 
    
         | 
| 
       404 
     | 
    
         
            -
               
     | 
| 
       405 
     | 
    
         
            -
             
     | 
| 
       406 
     | 
    
         
            -
                 
     | 
| 
       407 
     | 
    
         
            -
                if  
     | 
| 
       408 
     | 
    
         
            -
                 
     | 
| 
       409 
     | 
    
         
            -
             
     | 
| 
       410 
     | 
    
         
            -
             
     | 
| 
       411 
     | 
    
         
            -
                 
     | 
| 
       412 
     | 
    
         
            -
             
     | 
| 
       413 
     | 
    
         
            -
             
     | 
| 
       414 
     | 
    
         
            -
             
     | 
| 
       415 
     | 
    
         
            -
             
     | 
| 
       416 
     | 
    
         
            -
                 
     | 
| 
       417 
     | 
    
         
            -
                if x.op in [UnaryOps.NOOP, UnaryOps.CAST]: return self.ast_parse(x.src[0], acc, loaded_buffers, ssa)  # cast isn't an ALU op
         
     | 
| 
       418 
     | 
    
         
            -
                if x.op in ReduceOps and not do_reduce: return acc
         
     | 
| 
       419 
     | 
    
         
            -
                # MULACC fusion. TODO: this is copied from Interpreted
         
     | 
| 
       420 
     | 
    
         
            -
                if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == BinaryOps.MUL:
         
     | 
| 
       421 
     | 
    
         
            -
                  x = LazyOp(TernaryOps.MULACC, x.src[0].src, x.arg)
         
     | 
| 
       422 
     | 
    
         
            -
                if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == UnaryOps.CAST and x.src[0].src[0].__class__ is LazyOp and x.src[0].src[0].op == BinaryOps.MUL:
         
     | 
| 
       423 
     | 
    
         
            -
                  x = LazyOp(TernaryOps.MULACC, x.src[0].src[0].src, x.arg)
         
     | 
| 
       424 
     | 
    
         
            -
                if x.op in {BinaryOps.ADD, BinaryOps.MUL}:
         
     | 
| 
       425 
     | 
    
         
            -
                  # Reorder sources to put constants first so get_grouped_maybe_float4 can fold the op
         
     | 
| 
       426 
     | 
    
         
            -
                  srcs = sorted(x.src, key=lambda x: (x.realized.__class__ != RawConst) if x.__class__ == LazyBuffer else 0)
         
     | 
| 
       427 
     | 
    
         
            -
                  x.src = tuple(srcs)
         
     | 
| 
       428 
     | 
    
         
            -
                values = [self.ast_parse(v, acc, loaded_buffers, ssa) for v in x.src]
         
     | 
| 
       429 
     | 
    
         
            -
                ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, TernaryOps.MULACC:TernaryOps.MULACC}
         
     | 
| 
      
 425 
     | 
    
         
            +
              def ast_parse(self, x:LazyOp, accs:Dict[LazyOp, List[UOp]], offs:Optional[List[int]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], reduce_acc:Optional[List[UOp]]=None, cache=None) -> List[UOp]: # noqa: E501
         
     | 
| 
      
 426 
     | 
    
         
            +
                if cache is None: cache = {}
         
     | 
| 
      
 427 
     | 
    
         
            +
                if x in cache: return cache[x]
         
     | 
| 
      
 428 
     | 
    
         
            +
                if x.op in BufferOps: return loaded_buffers[x.arg]
         
     | 
| 
      
 429 
     | 
    
         
            +
                if x.op in [UnaryOps.CAST, UnaryOps.BITCAST]:
         
     | 
| 
      
 430 
     | 
    
         
            +
                  return [self.uops.add(UOps.BITCAST if x.op is UnaryOps.BITCAST else UOps.CAST,
         
     | 
| 
      
 431 
     | 
    
         
            +
                                        self.get_base_dtype(x.arg), (u,)) for u in self.ast_parse(x.src[0], accs, offs, loaded_buffers)]
         
     | 
| 
      
 432 
     | 
    
         
            +
                if x.op in ReduceOps and reduce_acc is None:
         
     | 
| 
      
 433 
     | 
    
         
            +
                  assert offs is None, "not available if we aren't doing reduce"
         
     | 
| 
      
 434 
     | 
    
         
            +
                  return accs[x]
         
     | 
| 
      
 435 
     | 
    
         
            +
             
     | 
| 
      
 436 
     | 
    
         
            +
                values = [self.ast_parse(v, accs, offs, loaded_buffers, cache=cache) for v in x.src]
         
     | 
| 
      
 437 
     | 
    
         
            +
                ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}
         
     | 
| 
       430 
438 
     | 
    
         
             
                if x.op in ops:
         
     | 
| 
       431 
     | 
    
         
            -
                   
     | 
| 
       432 
     | 
    
         
            -
             
     | 
| 
       433 
     | 
    
         
            -
                   
     | 
| 
       434 
     | 
    
         
            -
             
     | 
| 
       435 
     | 
    
         
            -
             
     | 
| 
       436 
     | 
    
         
            -
             
     | 
| 
       437 
     | 
    
         
            -
                  for  
     | 
| 
       438 
     | 
    
         
            -
                     
     | 
| 
       439 
     | 
    
         
            -
             
     | 
| 
       440 
     | 
    
         
            -
                 
     | 
| 
      
 439 
     | 
    
         
            +
                  assert reduce_acc is not None
         
     | 
| 
      
 440 
     | 
    
         
            +
                  ret: List[UOp] = []
         
     | 
| 
      
 441 
     | 
    
         
            +
                  acc, input_acc = reduce_acc, reduce_acc[:]
         
     | 
| 
      
 442 
     | 
    
         
            +
                  for val, off in zip(zip(*values), cast(List[int], offs)):
         
     | 
| 
      
 443 
     | 
    
         
            +
                    acc[off] = UOp.alu(ops[cast(ReduceOps, x.op)], *(val+(acc[off], )))
         
     | 
| 
      
 444 
     | 
    
         
            +
                    ret.append(acc[off])
         
     | 
| 
      
 445 
     | 
    
         
            +
                  for off in range(len(acc)):
         
     | 
| 
      
 446 
     | 
    
         
            +
                    if input_acc[off] != acc[off]:
         
     | 
| 
      
 447 
     | 
    
         
            +
                      acc[off] = self.uops.add(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off]))
         
     | 
| 
      
 448 
     | 
    
         
            +
                else: ret = [UOp.alu(x.op, *vin) for vin in zip(*values)]
         
     | 
| 
      
 449 
     | 
    
         
            +
                cache[x] = ret
         
     | 
| 
      
 450 
     | 
    
         
            +
                return ret
         
     | 
| 
      
 451 
     | 
    
         
            +
             
     | 
| 
      
 452 
     | 
    
         
            +
              def to_program(self) -> Program:
         
     | 
| 
      
 453 
     | 
    
         
            +
                self.linearize()
         
     | 
| 
      
 454 
     | 
    
         
            +
                info = get_lazyop_info(self.ast[0])
         
     | 
| 
      
 455 
     | 
    
         
            +
                src = self.opts.render(to_function_name(self.name), self.uops)
         
     | 
| 
      
 456 
     | 
    
         
            +
                ops, mem = self.uops.flops_mem()
         
     | 
| 
      
 457 
     | 
    
         
            +
                run_count = prod((self.global_size if self.global_size else []) + (self.local_size if self.local_size else []))
         
     | 
| 
      
 458 
     | 
    
         
            +
                # NOTE: we use min here to ignore the indexing FLOPS
         
     | 
| 
      
 459 
     | 
    
         
            +
                return Program(self.name, src, self.opts.device, self.global_size, self.local_size,
         
     | 
| 
      
 460 
     | 
    
         
            +
                               self.uops, min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))
         
     |