tinygrad 0.7.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +6 -0
 - tinygrad/codegen/kernel.py +572 -83
 - tinygrad/codegen/linearizer.py +415 -395
 - tinygrad/codegen/uops.py +415 -0
 - tinygrad/device.py +183 -0
 - tinygrad/dtype.py +113 -0
 - tinygrad/engine/__init__.py +0 -0
 - tinygrad/engine/graph.py +100 -0
 - tinygrad/engine/jit.py +195 -0
 - tinygrad/engine/realize.py +191 -0
 - tinygrad/engine/schedule.py +362 -0
 - tinygrad/engine/search.py +196 -0
 - tinygrad/{mlops.py → function.py} +76 -55
 - tinygrad/helpers.py +196 -89
 - tinygrad/lazy.py +210 -371
 - tinygrad/multi.py +169 -0
 - tinygrad/nn/__init__.py +202 -22
 - tinygrad/nn/datasets.py +7 -0
 - tinygrad/nn/optim.py +112 -32
 - tinygrad/nn/state.py +136 -39
 - tinygrad/ops.py +119 -202
 - tinygrad/renderer/__init__.py +61 -0
 - tinygrad/renderer/assembly.py +276 -0
 - tinygrad/renderer/cstyle.py +353 -166
 - tinygrad/renderer/llvmir.py +150 -138
 - tinygrad/runtime/autogen/amd_gpu.py +1900 -0
 - tinygrad/runtime/autogen/comgr.py +865 -0
 - tinygrad/runtime/autogen/cuda.py +5923 -0
 - tinygrad/runtime/autogen/hip.py +5909 -0
 - tinygrad/runtime/autogen/hsa.py +5761 -0
 - tinygrad/runtime/autogen/kfd.py +812 -0
 - tinygrad/runtime/autogen/nv_gpu.py +33328 -0
 - tinygrad/runtime/autogen/opencl.py +1795 -0
 - tinygrad/runtime/driver/hip_comgr.py +47 -0
 - tinygrad/runtime/driver/hsa.py +143 -0
 - tinygrad/runtime/graph/clang.py +38 -0
 - tinygrad/runtime/graph/cuda.py +81 -0
 - tinygrad/runtime/graph/hcq.py +143 -0
 - tinygrad/runtime/graph/hsa.py +171 -0
 - tinygrad/runtime/graph/metal.py +75 -0
 - tinygrad/runtime/ops_amd.py +564 -0
 - tinygrad/runtime/ops_clang.py +24 -77
 - tinygrad/runtime/ops_cuda.py +175 -89
 - tinygrad/runtime/ops_disk.py +56 -33
 - tinygrad/runtime/ops_gpu.py +92 -95
 - tinygrad/runtime/ops_hsa.py +278 -0
 - tinygrad/runtime/ops_llvm.py +39 -60
 - tinygrad/runtime/ops_metal.py +92 -74
 - tinygrad/runtime/ops_npy.py +9 -0
 - tinygrad/runtime/ops_nv.py +630 -0
 - tinygrad/runtime/ops_python.py +204 -0
 - tinygrad/shape/shapetracker.py +86 -254
 - tinygrad/shape/symbolic.py +166 -141
 - tinygrad/shape/view.py +296 -0
 - tinygrad/tensor.py +2619 -448
 - {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
 - tinygrad-0.9.0.dist-info/METADATA +227 -0
 - tinygrad-0.9.0.dist-info/RECORD +60 -0
 - {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
 - tinygrad/codegen/assembly.py +0 -190
 - tinygrad/codegen/optimizer.py +0 -379
 - tinygrad/codegen/search.py +0 -72
 - tinygrad/graph.py +0 -83
 - tinygrad/jit.py +0 -57
 - tinygrad/nn/image.py +0 -100
 - tinygrad/renderer/assembly_arm64.py +0 -169
 - tinygrad/renderer/assembly_ptx.py +0 -98
 - tinygrad/renderer/wgsl.py +0 -53
 - tinygrad/runtime/lib.py +0 -113
 - tinygrad/runtime/ops_cpu.py +0 -51
 - tinygrad/runtime/ops_hip.py +0 -82
 - tinygrad/runtime/ops_shm.py +0 -29
 - tinygrad/runtime/ops_torch.py +0 -30
 - tinygrad/runtime/ops_webgpu.py +0 -45
 - tinygrad-0.7.0.dist-info/METADATA +0 -212
 - tinygrad-0.7.0.dist-info/RECORD +0 -40
 - {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
 
    
        tinygrad/shape/view.py
    ADDED
    
    | 
         @@ -0,0 +1,296 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            from __future__ import annotations
         
     | 
| 
      
 2 
     | 
    
         
            +
            import functools, operator, itertools, math
         
     | 
| 
      
 3 
     | 
    
         
            +
            from dataclasses import dataclass
         
     | 
| 
      
 4 
     | 
    
         
            +
            from typing import Tuple, List, Optional, Dict, Set, cast
         
     | 
| 
      
 5 
     | 
    
         
            +
            from tinygrad.helpers import prod, all_int, argsort
         
     | 
| 
      
 6 
     | 
    
         
            +
            from tinygrad.shape.symbolic import Node, NumNode, Variable, sint
         
     | 
| 
      
 7 
     | 
    
         
            +
             
     | 
| 
      
 8 
     | 
    
         
            +
            @functools.lru_cache(maxsize=None)
         
     | 
| 
      
 9 
     | 
    
         
            +
            def canonicalize_strides(shape:Tuple[sint, ...], strides:Tuple[sint, ...]) -> Tuple[sint, ...]:
         
     | 
| 
      
 10 
     | 
    
         
            +
              return tuple(0 if s == 1 else st for s, st in zip(shape, strides))
         
     | 
| 
      
 11 
     | 
    
         
            +
             
     | 
| 
      
 12 
     | 
    
         
            +
            @functools.lru_cache(maxsize=None)
         
     | 
| 
      
 13 
     | 
    
         
            +
            def strides_for_shape(shape:Tuple[sint, ...]) -> Tuple[sint, ...]:
         
     | 
| 
      
 14 
     | 
    
         
            +
              if not shape: return ()
         
     | 
| 
      
 15 
     | 
    
         
            +
              strides = tuple(itertools.accumulate(reversed(shape[1:]), operator.mul, initial=1))
         
     | 
| 
      
 16 
     | 
    
         
            +
              return canonicalize_strides(shape, strides[::-1])
         
     | 
| 
      
 17 
     | 
    
         
            +
             
     | 
| 
      
 18 
     | 
    
         
            +
            @functools.lru_cache(maxsize=None)
         
     | 
| 
      
 19 
     | 
    
         
            +
            def _merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tuple[Tuple[int, int], ...]]=None) -> Tuple[Tuple[int, int, int], ...]:
         
     | 
| 
      
 20 
     | 
    
         
            +
              # merge contiguous subparts or zero strided dims. ret = List[(merged_dims, stride, merged dims w/o zero stride), ...]
         
     | 
| 
      
 21 
     | 
    
         
            +
              if not shape: return tuple()
         
     | 
| 
      
 22 
     | 
    
         
            +
              assert len(shape) == len(strides)
         
     | 
| 
      
 23 
     | 
    
         
            +
              ret = [(shape[0], strides[0], shape[0] if strides[0] else 0)]
         
     | 
| 
      
 24 
     | 
    
         
            +
              # wrt merging zero strided dimensions
         
     | 
| 
      
 25 
     | 
    
         
            +
              merging = strides[0] == 0 and (mask[0][1] - mask[0][0] == 1 if mask else shape[0] == 1)
         
     | 
| 
      
 26 
     | 
    
         
            +
              for i, (sh, st) in enumerate(zip(shape[1:], strides[1:]), start=1):
         
     | 
| 
      
 27 
     | 
    
         
            +
                if sh == 1: continue
         
     | 
| 
      
 28 
     | 
    
         
            +
                if merging or ret[-1][1] == sh * st: # mergeable
         
     | 
| 
      
 29 
     | 
    
         
            +
                  ret[-1] = (ret[-1][0] * sh, st, (sh if merging else ret[-1][2] * sh) if st else 0)
         
     | 
| 
      
 30 
     | 
    
         
            +
                else: ret.append((sh, st, sh if st else 0)) # begin new
         
     | 
| 
      
 31 
     | 
    
         
            +
                # merging ends with either non-zero strided dim or zero strided dim with mask range > 1
         
     | 
| 
      
 32 
     | 
    
         
            +
                merging = st == 0 and (mask[i][1] - mask[i][0] == 1 if mask else sh == 1)
         
     | 
| 
      
 33 
     | 
    
         
            +
              return tuple(ret)
         
     | 
| 
      
 34 
     | 
    
         
            +
             
     | 
| 
      
 35 
     | 
    
         
            +
            @functools.lru_cache(maxsize=None)
         
     | 
| 
      
 36 
     | 
    
         
            +
            def _reshape_mask(view: View, new_shape:Tuple[sint, ...]) -> Tuple[Optional[Tuple[Tuple[sint, sint], ...]], bool]:
         
     | 
| 
      
 37 
     | 
    
         
            +
              if view.mask is None: return view.mask, False
         
     | 
| 
      
 38 
     | 
    
         
            +
              if any(not isinstance(m[0], int) or not isinstance(m[1], int) for m in view.mask): return view.mask, True
         
     | 
| 
      
 39 
     | 
    
         
            +
              new_mask: List[Tuple[int, int]] = []
         
     | 
| 
      
 40 
     | 
    
         
            +
             
     | 
| 
      
 41 
     | 
    
         
            +
              r_masks, r_shape, r_new_shape = reversed(view.mask), reversed(view.shape), reversed(new_shape)
         
     | 
| 
      
 42 
     | 
    
         
            +
              curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1))
         
     | 
| 
      
 43 
     | 
    
         
            +
              if mask[1] - mask[0] < 1: return ((0, 0),) * len(new_shape), False # invalid mask
         
     | 
| 
      
 44 
     | 
    
         
            +
             
     | 
| 
      
 45 
     | 
    
         
            +
              while len(new_mask) < len(new_shape):
         
     | 
| 
      
 46 
     | 
    
         
            +
                (l, r), next_stride = mask, new_dim * curr_stride
         
     | 
| 
      
 47 
     | 
    
         
            +
             
     | 
| 
      
 48 
     | 
    
         
            +
                if old_dim >= next_stride: # need to split mask.
         
     | 
| 
      
 49 
     | 
    
         
            +
                  if old_dim == next_stride: # simply copy the mask and get next batch for merging
         
     | 
| 
      
 50 
     | 
    
         
            +
                    new_mask.append((l // curr_stride, (r - 1) // curr_stride + 1))
         
     | 
| 
      
 51 
     | 
    
         
            +
                    curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1))
         
     | 
| 
      
 52 
     | 
    
         
            +
                    if mask[1] - mask[0] < 1: return ((0, 0),) * len(new_shape), False # invalid mask
         
     | 
| 
      
 53 
     | 
    
         
            +
             
     | 
| 
      
 54 
     | 
    
         
            +
                  else: # mask can only be splitted if reshape doesn't cut across the mask.
         
     | 
| 
      
 55 
     | 
    
         
            +
                    if (((l % next_stride != 0 or r % next_stride != 0) and l // next_stride != (r - 1) // next_stride)
         
     | 
| 
      
 56 
     | 
    
         
            +
                        or old_dim % next_stride != 0): return view.mask, True
         
     | 
| 
      
 57 
     | 
    
         
            +
                    new_mask.append((l % next_stride // curr_stride, (r - 1) % next_stride // curr_stride + 1))
         
     | 
| 
      
 58 
     | 
    
         
            +
                    curr_stride, new_dim = next_stride,  next(r_new_shape, 1) # need to get mask for next dimension
         
     | 
| 
      
 59 
     | 
    
         
            +
             
     | 
| 
      
 60 
     | 
    
         
            +
                else:
         
     | 
| 
      
 61 
     | 
    
         
            +
                  next_mask = next(r_masks, (0, 1))
         
     | 
| 
      
 62 
     | 
    
         
            +
                  # combine if the mask can unfold continuously
         
     | 
| 
      
 63 
     | 
    
         
            +
                  if mask != (0, old_dim) and next_mask[1] - next_mask[0] != 1: return view.mask, True
         
     | 
| 
      
 64 
     | 
    
         
            +
                  mask, old_dim = (next_mask[0] * old_dim + l, (next_mask[1] - 1) * old_dim + r), old_dim * next(r_shape, 1)
         
     | 
| 
      
 65 
     | 
    
         
            +
             
     | 
| 
      
 66 
     | 
    
         
            +
              for mask in r_masks: # if the old shape has leading 1s, need to make sure their mask is (0,1)
         
     | 
| 
      
 67 
     | 
    
         
            +
                if mask != (0, 1): return ((0, 0),) * len(new_shape), False # invalid mask
         
     | 
| 
      
 68 
     | 
    
         
            +
             
     | 
| 
      
 69 
     | 
    
         
            +
              return tuple(reversed(new_mask)), False
         
     | 
| 
      
 70 
     | 
    
         
            +
             
     | 
| 
      
 71 
     | 
    
         
            +
            def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]:
         
     | 
| 
      
 72 
     | 
    
         
            +
              strides = strides_for_shape(shape)
         
     | 
| 
      
 73 
     | 
    
         
            +
              result = []
         
     | 
| 
      
 74 
     | 
    
         
            +
              for stride in strides:
         
     | 
| 
      
 75 
     | 
    
         
            +
                here = offs // stride if stride else 0
         
     | 
| 
      
 76 
     | 
    
         
            +
                result.append(here)
         
     | 
| 
      
 77 
     | 
    
         
            +
                offs -= here * stride
         
     | 
| 
      
 78 
     | 
    
         
            +
              return result
         
     | 
| 
      
 79 
     | 
    
         
            +
             
     | 
| 
      
 80 
     | 
    
         
            +
            @dataclass(frozen=True)
         
     | 
| 
      
 81 
     | 
    
         
            +
            class View:
         
     | 
| 
      
 82 
     | 
    
         
            +
              shape:Tuple[sint, ...]
         
     | 
| 
      
 83 
     | 
    
         
            +
              strides:Tuple[sint, ...]
         
     | 
| 
      
 84 
     | 
    
         
            +
              offset:sint
         
     | 
| 
      
 85 
     | 
    
         
            +
              mask:Optional[Tuple[Tuple[sint, sint], ...]]
         
     | 
| 
      
 86 
     | 
    
         
            +
              contiguous:bool
         
     | 
| 
      
 87 
     | 
    
         
            +
             
     | 
| 
      
 88 
     | 
    
         
            +
              @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
         
     | 
| 
      
 89 
     | 
    
         
            +
              def size(self) -> int:
         
     | 
| 
      
 90 
     | 
    
         
            +
                # NOTE: Variable and the Node derived from it in symbolic shapes can only have int as max.
         
     | 
| 
      
 91 
     | 
    
         
            +
                ret = prod([x.max if isinstance(x, Node) else x for x in self.shape])
         
     | 
| 
      
 92 
     | 
    
         
            +
                assert isinstance(ret, int), f"{ret=} is not int"
         
     | 
| 
      
 93 
     | 
    
         
            +
                return ret
         
     | 
| 
      
 94 
     | 
    
         
            +
             
     | 
| 
      
 95 
     | 
    
         
            +
              @staticmethod
         
     | 
| 
      
 96 
     | 
    
         
            +
              @functools.lru_cache(maxsize=None)
         
     | 
| 
      
 97 
     | 
    
         
            +
              def create(shape:Tuple[sint, ...], strides:Optional[Tuple[sint, ...]]=None, offset:sint=0, mask:Optional[Tuple[Tuple[sint, sint], ...]]=None):
         
     | 
| 
      
 98 
     | 
    
         
            +
                strides = canonicalize_strides(shape, strides) if strides else strides_for_shape(shape)
         
     | 
| 
      
 99 
     | 
    
         
            +
                # canonicalize empty mask
         
     | 
| 
      
 100 
     | 
    
         
            +
                if mask is not None and all(m == (0,s) for m,s in zip(mask, shape)): mask = None
         
     | 
| 
      
 101 
     | 
    
         
            +
                contiguous = offset == 0 and mask is None and strides == strides_for_shape(shape)
         
     | 
| 
      
 102 
     | 
    
         
            +
                # if any dimension has size >1, but is masked such that only one index in the dimension is unmasked
         
     | 
| 
      
 103 
     | 
    
         
            +
                # then its stride can also be set to 0, albeit with a corresponding adjustment required to the offset
         
     | 
| 
      
 104 
     | 
    
         
            +
                # TODO: assert comparison with LtNode to avoid mis-using symbolic
         
     | 
| 
      
 105 
     | 
    
         
            +
                if mask and any(elim := [not (b+1 < e) for b,e in mask]):
         
     | 
| 
      
 106 
     | 
    
         
            +
                  if any(not (b < e) for b,e in mask):
         
     | 
| 
      
 107 
     | 
    
         
            +
                    strides, offset, mask = (0,) * len(shape), 0, ((0,0),) * len(shape)
         
     | 
| 
      
 108 
     | 
    
         
            +
                  offset += sum((strides[i] * mask[i][0]) if e else 0 for i, e in enumerate(elim))
         
     | 
| 
      
 109 
     | 
    
         
            +
                  strides = tuple(0 if e else st for st,e in zip(strides, elim))
         
     | 
| 
      
 110 
     | 
    
         
            +
                return View(shape, strides, offset, mask, contiguous)
         
     | 
| 
      
 111 
     | 
    
         
            +
             
     | 
| 
      
 112 
     | 
    
         
            +
              @functools.lru_cache(None)  # pylint: disable=method-cache-max-size-none
         
     | 
| 
      
 113 
     | 
    
         
            +
              def vars(self) -> Set[Variable]:
         
     | 
| 
      
 114 
     | 
    
         
            +
                flatten_mask = tuple(x for m in self.mask for x in m) if self.mask is not None else tuple()
         
     | 
| 
      
 115 
     | 
    
         
            +
                return functools.reduce(operator.or_, [x.vars() for x in self.shape+self.strides+(self.offset,)+flatten_mask if isinstance(x, Node)], set())
         
     | 
| 
      
 116 
     | 
    
         
            +
             
     | 
| 
      
 117 
     | 
    
         
            +
              @functools.lru_cache(None)  # pylint: disable=method-cache-max-size-none
         
     | 
| 
      
 118 
     | 
    
         
            +
              def unbind(self) -> Tuple[View, Dict[Variable, int]]:
         
     | 
| 
      
 119 
     | 
    
         
            +
                var_unboundvar_val = [(v, v.unbind()) for v in self.vars() if v.val is not None]
         
     | 
| 
      
 120 
     | 
    
         
            +
                unbound_vars = {v:uv for v,(uv,_) in var_unboundvar_val}
         
     | 
| 
      
 121 
     | 
    
         
            +
                new_shape = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.shape])
         
     | 
| 
      
 122 
     | 
    
         
            +
                new_strides = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.strides])
         
     | 
| 
      
 123 
     | 
    
         
            +
                new_offset = self.offset if isinstance(self.offset, int) else self.offset.substitute(unbound_vars)
         
     | 
| 
      
 124 
     | 
    
         
            +
                new_mask = tuple((a if isinstance(a, int) else a.substitute(unbound_vars),
         
     | 
| 
      
 125 
     | 
    
         
            +
                                  b if isinstance(b, int) else b.substitute(unbound_vars)) for (a, b) in self.mask) if self.mask is not None else None
         
     | 
| 
      
 126 
     | 
    
         
            +
                return View.create(new_shape, new_strides, new_offset, new_mask), dict(x[1] for x in var_unboundvar_val)
         
     | 
| 
      
 127 
     | 
    
         
            +
             
     | 
| 
      
 128 
     | 
    
         
            +
              @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
         
     | 
| 
      
 129 
     | 
    
         
            +
              def __add__(self, vm1:View) -> Optional[View]:
         
     | 
| 
      
 130 
     | 
    
         
            +
                vm2 = self
         
     | 
| 
      
 131 
     | 
    
         
            +
                if vm2.contiguous: return vm1
         
     | 
| 
      
 132 
     | 
    
         
            +
                if vm1.contiguous and vm1.shape == vm2.shape: return vm2
         
     | 
| 
      
 133 
     | 
    
         
            +
                if vm1.contiguous and vm1.size() == vm2.size() and (ret := vm2.reshape(vm1.shape)) is not None: return ret
         
     | 
| 
      
 134 
     | 
    
         
            +
                if vm1.mask:
         
     | 
| 
      
 135 
     | 
    
         
            +
                  for b,e in vm1.mask:
         
     | 
| 
      
 136 
     | 
    
         
            +
                    if not (b < e): return View.create(vm1.shape, (0,) * len(vm1.shape), 0, ((0,0),) * len(vm1.shape))
         
     | 
| 
      
 137 
     | 
    
         
            +
                  return (merged := vm2 + vm1.shrink(vm1.mask)) and merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape)))
         
     | 
| 
      
 138 
     | 
    
         
            +
             
     | 
| 
      
 139 
     | 
    
         
            +
                # Project vm1's offset and strides on to vm2.
         
     | 
| 
      
 140 
     | 
    
         
            +
                origin = un1d(vm2.shape, vm1.offset)
         
     | 
| 
      
 141 
     | 
    
         
            +
                terms: List[List[Tuple[int, sint]]] = [[] for _ in origin]
         
     | 
| 
      
 142 
     | 
    
         
            +
                strides: List[sint] = [0] * len(vm1.shape)
         
     | 
| 
      
 143 
     | 
    
         
            +
                for d1, st in enumerate(vm1.strides):
         
     | 
| 
      
 144 
     | 
    
         
            +
                  if st == 0: continue
         
     | 
| 
      
 145 
     | 
    
         
            +
                  for d2, (o, s1) in enumerate(zip(origin, un1d(vm2.shape, vm1.offset + st))):
         
     | 
| 
      
 146 
     | 
    
         
            +
                    if (s1 := s1 - o) == 0: continue
         
     | 
| 
      
 147 
     | 
    
         
            +
                    terms[d2].append((d1, s1))
         
     | 
| 
      
 148 
     | 
    
         
            +
                    strides[d1] += s1 * vm2.strides[d2]
         
     | 
| 
      
 149 
     | 
    
         
            +
             
     | 
| 
      
 150 
     | 
    
         
            +
                # Merge dimensions in vm2 if required.
         
     | 
| 
      
 151 
     | 
    
         
            +
                # NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
         
     | 
| 
      
 152 
     | 
    
         
            +
                idxs: List[Node] = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
         
     | 
| 
      
 153 
     | 
    
         
            +
                merged_size, merged_term = 1, NumNode(0)
         
     | 
| 
      
 154 
     | 
    
         
            +
                extents: List[Tuple[sint, Node]] = []
         
     | 
| 
      
 155 
     | 
    
         
            +
                for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)):
         
     | 
| 
      
 156 
     | 
    
         
            +
                  merged_term += Variable.sum([idxs[d1] * (s1 * merged_size) for d1, s1 in term]) + o * merged_size
         
     | 
| 
      
 157 
     | 
    
         
            +
                  merged_size *= s
         
     | 
| 
      
 158 
     | 
    
         
            +
                  if not (merged_term >= merged_size) and not (merged_term < 0):
         
     | 
| 
      
 159 
     | 
    
         
            +
                    extents.append((merged_size, merged_term))
         
     | 
| 
      
 160 
     | 
    
         
            +
                    merged_size, merged_term = 1, NumNode(0)
         
     | 
| 
      
 161 
     | 
    
         
            +
                if merged_term: return None
         
     | 
| 
      
 162 
     | 
    
         
            +
                if (vm2_shape := tuple(s for s,_ in reversed(extents))) != vm2.shape:
         
     | 
| 
      
 163 
     | 
    
         
            +
                  return (reshaped_vm2 := vm2.reshape(vm2_shape)) and reshaped_vm2 + vm1
         
     | 
| 
      
 164 
     | 
    
         
            +
             
     | 
| 
      
 165 
     | 
    
         
            +
                if vm2.mask:
         
     | 
| 
      
 166 
     | 
    
         
            +
                  # Try to project vm2's mask on to vm1.
         
     | 
| 
      
 167 
     | 
    
         
            +
                  newb, newe, bad = [0] * len(vm1.shape), list(vm1.shape), False
         
     | 
| 
      
 168 
     | 
    
         
            +
                  for d2, ((b, e), o, (_, t)) in enumerate(zip(vm2.mask, origin, reversed(extents))):
         
     | 
| 
      
 169 
     | 
    
         
            +
                    if not (t.min < b or t.max >= e): continue
         
     | 
| 
      
 170 
     | 
    
         
            +
                    if not isinstance(o, int) or not isinstance(b, int) or not isinstance(e, int):
         
     | 
| 
      
 171 
     | 
    
         
            +
                      bad = True
         
     | 
| 
      
 172 
     | 
    
         
            +
                      continue
         
     | 
| 
      
 173 
     | 
    
         
            +
                    term = terms[d2]
         
     | 
| 
      
 174 
     | 
    
         
            +
                    if len(term) != 1:
         
     | 
| 
      
 175 
     | 
    
         
            +
                      if not term and newe: newe[0] = 0
         
     | 
| 
      
 176 
     | 
    
         
            +
                      else: bad = True
         
     | 
| 
      
 177 
     | 
    
         
            +
                      continue
         
     | 
| 
      
 178 
     | 
    
         
            +
                    d1, s1 = term[0]
         
     | 
| 
      
 179 
     | 
    
         
            +
                    if not isinstance(s1, int) or not isinstance(newe[d1], int):
         
     | 
| 
      
 180 
     | 
    
         
            +
                      bad = True
         
     | 
| 
      
 181 
     | 
    
         
            +
                      continue
         
     | 
| 
      
 182 
     | 
    
         
            +
                    newb[d1] = max(newb[d1], math.ceil((b - o if s1 > 0 else e - o - 1) / s1))
         
     | 
| 
      
 183 
     | 
    
         
            +
                    newe[d1] = min(newe[d1], (b - o if s1 < 0 else e - o - 1) // s1 + 1)
         
     | 
| 
      
 184 
     | 
    
         
            +
             
     | 
| 
      
 185 
     | 
    
         
            +
                  # If any of vm1 was masked off, try again with that mask in place.
         
     | 
| 
      
 186 
     | 
    
         
            +
                  for b, e, s in zip(newb, newe, vm1.shape):
         
     | 
| 
      
 187 
     | 
    
         
            +
                    if b != 0 or e != s:
         
     | 
| 
      
 188 
     | 
    
         
            +
                      return vm2 + View.create(vm1.shape, vm1.strides, vm1.offset, tuple(zip(newb, newe)))
         
     | 
| 
      
 189 
     | 
    
         
            +
                  # Otherwise if vm2's mask was violated, then cannot merge.
         
     | 
| 
      
 190 
     | 
    
         
            +
                  if bad: return None
         
     | 
| 
      
 191 
     | 
    
         
            +
             
     | 
| 
      
 192 
     | 
    
         
            +
                return View.create(vm1.shape, tuple(strides), sum(o * s for o, s in zip(origin, vm2.strides)) + vm2.offset)
         
     | 
| 
      
 193 
     | 
    
         
            +
             
     | 
| 
      
 194 
     | 
    
         
            +
              @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
         
     | 
| 
      
 195 
     | 
    
         
            +
              def invert(self, out_shape:Tuple[sint, ...]) -> Optional[View]:
         
     | 
| 
      
 196 
     | 
    
         
            +
                ret = View.create(self.shape)
         
     | 
| 
      
 197 
     | 
    
         
            +
                if self.mask: ret = ret.shrink(self.mask)
         
     | 
| 
      
 198 
     | 
    
         
            +
                ret = ret.stride(tuple(-1 if x < 0 else 1 for x in self.strides)).permute(argsort(tuple(-x if x > 0 else x for x in self.strides)))
         
     | 
| 
      
 199 
     | 
    
         
            +
                return ret if prod(ret.shape) == prod(out_shape) else None   # don't support shrink, expand, or stride != (-1, 1)
         
     | 
| 
      
 200 
     | 
    
         
            +
             
     | 
| 
      
 201 
     | 
    
         
            +
              @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
         
     | 
| 
      
 202 
     | 
    
         
            +
              def minify(self):
         
     | 
| 
      
 203 
     | 
    
         
            +
                min_shape = tuple(x[0] for x in _merge_dims(self.shape, self.strides, self.mask))
         
     | 
| 
      
 204 
     | 
    
         
            +
                return nv if (nv := self.reshape(min_shape)) else self
         
     | 
| 
      
 205 
     | 
    
         
            +
             
     | 
| 
      
 206 
     | 
    
         
            +
              def __unsafe_resize(self, arg: Tuple[Tuple[sint, sint], ...], mask=None) -> View:
                """Build the view spanning `arg[i] = (begin, end)` on every axis; the bounds are not validated here.

                An existing mask is shifted into the new coordinates and clamped to `[0, end-begin]`. If `mask` is also
                passed, the two are intersected. A mask covering every axis completely is dropped.
                """
                if self.mask:
                  # move the old mask into the coordinate frame of the resized view
                  shifted = tuple([(max(0, min(mx-ax,ay-ax)), max(0, min(my-ax,ay-ax))) for (mx,my),(ax,ay) in zip(self.mask, arg)])
                  # with two masks, keep only their overlap
                  if mask is None: mask = shifted
                  else: mask = tuple([(max(lo1, lo2), min(hi1, hi2)) for (lo1, hi1), (lo2, hi2) in zip(shifted, mask)])
                shape = [end-begin for begin,end in arg]
                # a mask of (0, size) on every axis hides nothing, so it is not stored
                if mask is not None and all(m[0] == 0 and m[1] == s for m,s in zip(mask, shape)): mask = None
                # the new origin sits `begin` elements along each axis of the old view
                offset = sum(st * rng[0] for st, rng in zip(self.strides, arg))
                out_shape = tuple(s.b if isinstance(s, NumNode) else s for s in shape)
                return View.create(out_shape, self.strides, self.offset+offset, mask)
         
     | 
| 
      
 216 
     | 
    
         
            +
             
     | 
| 
      
 217 
     | 
    
         
            +
              @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
              def pad(self, arg: Tuple[Tuple[sint, sint], ...]) -> View:
                """Grow every axis by `arg[i] = (before, after)` elements; returns `self` when all paddings are zero.

                The result carries a mask of `(before, size+before)` per axis, i.e. the span the original extent occupies.
                """
                assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape), f"{self.shape=}, {arg=}"
                if not any(b or e for b, e in arg): return self
                resize_arg, original_span = [], []
                for s,(b,e) in zip(self.shape, arg):
                  resize_arg.append((-b, s+e))       # range to resize to, expressed in the old coordinates
                  original_span.append((b, s+b))     # where the old extent lands in the new coordinates
                return self.__unsafe_resize(tuple(resize_arg), mask=tuple(original_span))
         
     | 
| 
      
 225 
     | 
    
         
            +
             
     | 
| 
      
 226 
     | 
    
         
            +
              @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
              def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> View:
                """Restrict every axis to the range `arg[i] = (begin, end)`, which must satisfy `0 <= begin <= end <= size`."""
                assert all((0<=lo<=hi<=size) for size,(lo,hi) in zip(self.shape,arg)) and len(arg) == len(self.shape), \
                  f"invalid shrink {arg} for {self.shape}"
                return self.__unsafe_resize(arg)
         
     | 
| 
      
 230 
     | 
    
         
            +
             
     | 
| 
      
 231 
     | 
    
         
            +
              @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
              def expand(self, new_shape: Tuple[sint, ...]) -> View:
                """Expand axes of size 1 and stride 0 to the sizes in `new_shape`; all other axes must keep their size.

                Raises ValueError when `new_shape` has a different number of dimensions than this view.
                """
                if len(new_shape) != len(self.shape):
                  raise ValueError(f"expand arg {new_shape=} must have same number of dimensions as shape {self.shape=}")
                if 0 in self.shape:
                  # a view with a zero-sized axis gets a fresh view of the new shape
                  assert all((s == x == 0) or (s > 0 and (x % s) == 0) for s,x in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
                  return View.create(new_shape)
                assert all((s == x or (s == 1 and st == 0)) for s,x,st in zip(self.shape, new_shape, self.strides)), f"can't expand {self.shape} into {new_shape}"
                if not self.mask: return View.create(new_shape, self.strides, self.offset, None)
                # NOTE: can the mask ever be (0,0)?
                widened = []
                for m,s,ns in zip(self.mask, self.shape, new_shape):
                  # an expanded axis is either fully valid (old mask was (0,1)) or fully masked out
                  if s != ns: widened.append((0,0) if m != (0,1) else (0,ns))
                  else: widened.append(m)
                return View.create(new_shape, self.strides, self.offset, tuple(widened))
         
     | 
| 
      
 241 
     | 
    
         
            +
             
     | 
| 
      
 242 
     | 
    
         
            +
              @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
              def permute(self, axis: Tuple[int, ...]) -> View:
                """Reorder the axes so that output axis `i` is input axis `axis[i]`; `axis` must be a permutation of all axes."""
                assert all(isinstance(x, int) and 0 <= x < len(self.shape) for x in axis), f"invalid permute {axis} for {self.shape}"
                assert len(set(axis)) == len(axis) and len(axis) == len(self.shape), f"can't permute {self.shape} with {axis}"
                # shape, strides and (if present) mask are all reordered the same way; the offset is untouched
                perm_shape = tuple(self.shape[a] for a in axis)
                perm_strides = tuple(self.strides[a] for a in axis)
                perm_mask = None if self.mask is None else tuple(self.mask[a] for a in axis)
                return View.create(perm_shape, perm_strides, self.offset, perm_mask)
         
     | 
| 
      
 248 
     | 
    
         
            +
             
     | 
| 
      
 249 
     | 
    
         
            +
              @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
              def stride(self, mul: Tuple[int, ...]) -> View:
                """Scale each axis' stride by the non-zero integer `mul[i]`; the new size is `ceil(size / abs(mul[i]))`.

                For a negative multiplier the offset is advanced by `(size-1)*stride` on that axis and the mask bounds
                are mirrored around the axis size.
                """
                # except for the negative case, you can build this from the others. invertible in the negative case
                assert all(isinstance(x, int) and x != 0 for x in mul), f"invalid stride {mul} for {self.shape}"
                steps = tuple(abs(m) for m in mul)
                new_strides = tuple(z*m for z,m in zip(self.strides, mul))
                new_shape = tuple((s+(k-1))//k for s,k in zip(self.shape, steps))
                # only axes with a negative multiplier shift the starting offset
                shift = sum((s-1)*z for s,z,m in zip(self.shape, self.strides, mul) if m < 0)
                if self.mask is None: new_mask = None
                else:
                  new_mask = tuple((((mx if m > 0 else s-my)+(k-1))//k, ((my if m > 0 else s-mx)+(k-1))//k)
                                   for (mx,my),s,m,k in zip(self.mask, self.shape, mul, steps))
                return View.create(new_shape, new_strides, self.offset + shift, new_mask)
         
     | 
| 
      
 259 
     | 
    
         
            +
             
     | 
| 
      
 260 
     | 
    
         
            +
              @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
              def reshape(self, new_shape: Tuple[sint, ...]) -> Optional[View]:
                """Try to express this view with shape `new_shape` as a single View.

                Returns `self` when the shape is unchanged, a new View when the reshape can be represented, and None
                when it cannot. Raises ValueError when this view's shape is all ints and the element counts differ.
                """
                if self.shape == new_shape: return self

                assert all(x >= 0 for x in new_shape), f"shape can't contain negative numbers {new_shape}"
                if 0 in self.shape:
                  assert 0 in new_shape, f"cannot reshape 0 size to {new_shape}"
                  return View.create(new_shape)
                # check for the same size
                if all_int(self.shape):
                  assert all(isinstance(s, (int, Variable)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
                  # a Variable dim contributes its `.val` to the element count
                  concrete_dims = [s if isinstance(s, int) else cast(Variable,s).val for s in new_shape]
                  if prod(self.shape) != prod(concrete_dims):
                    raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")

                if new_shape == () and self.mask and any(mx==my for (mx,my) in self.mask): return None

                # after the asserts, it's okay to check contiguous
                if self.contiguous: return View.create(new_shape)

                # walk the merged dims and the new dims from the innermost axis outwards, one shared iterator over new_shape
                rev_strides, remaining_dims = [], reversed(new_shape)
                for merged_dim, new_stride, real_dim in reversed(_merge_dims(self.shape, self.strides, self.mask)):
                  acc = 1
                  # TODO: this <= and != is for symbolic!?
                  while acc <= merged_dim and acc != merged_dim and (new_dim := next(remaining_dims, None)):
                    rev_strides.append(new_stride)
                    if new_dim != 1:
                      acc = acc * new_dim
                      new_stride *= new_dim if acc < real_dim else 0
                  # the new dims consumed so far do not multiply out to this merged dim
                  if acc != merged_dim: return None

                # any new dims not consumed above get stride 0
                rev_strides.extend([0] * (len(new_shape) - len(rev_strides)))
                new_mask, extra = _reshape_mask(self, new_shape)
                if extra: return None
                new_strides = canonicalize_strides(tuple(e-b for b,e in new_mask) if new_mask else new_shape, tuple(reversed(rev_strides)))
                # offset correction: mask-start contribution under the old strides minus the one under the new strides
                old_base = sum(m[0] * s for m,s in zip(self.mask, self.strides)) if self.mask else 0
                new_base = sum(m[0] * s for m,s in zip(new_mask, new_strides)) if new_mask else 0
                return View.create(new_shape, new_strides, self.offset + (old_base - new_base), new_mask)
         
     |