tinygrad 0.10.1__py3-none-any.whl → 0.10.2__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
 - tinygrad/codegen/devectorizer.py +247 -0
 - tinygrad/codegen/expander.py +121 -0
 - tinygrad/codegen/kernel.py +35 -37
 - tinygrad/codegen/linearize.py +19 -10
 - tinygrad/codegen/lowerer.py +31 -8
 - tinygrad/codegen/symbolic.py +476 -0
 - tinygrad/codegen/transcendental.py +10 -0
 - tinygrad/device.py +28 -11
 - tinygrad/dtype.py +12 -3
 - tinygrad/engine/jit.py +3 -2
 - tinygrad/engine/multi.py +0 -1
 - tinygrad/engine/realize.py +7 -4
 - tinygrad/engine/schedule.py +227 -255
 - tinygrad/engine/search.py +20 -27
 - tinygrad/gradient.py +3 -0
 - tinygrad/helpers.py +7 -4
 - tinygrad/nn/state.py +2 -2
 - tinygrad/ops.py +64 -329
 - tinygrad/renderer/__init__.py +19 -3
 - tinygrad/renderer/cstyle.py +39 -18
 - tinygrad/renderer/llvmir.py +55 -18
 - tinygrad/renderer/ptx.py +6 -2
 - tinygrad/renderer/wgsl.py +20 -12
 - tinygrad/runtime/autogen/libc.py +404 -71
 - tinygrad/runtime/autogen/{libpciaccess.py → pci.py} +25 -715
 - tinygrad/runtime/autogen/webgpu.py +6985 -0
 - tinygrad/runtime/graph/metal.py +28 -29
 - tinygrad/runtime/ops_amd.py +37 -34
 - tinygrad/runtime/{ops_clang.py → ops_cpu.py} +4 -2
 - tinygrad/runtime/ops_disk.py +1 -1
 - tinygrad/runtime/ops_dsp.py +59 -33
 - tinygrad/runtime/ops_llvm.py +14 -12
 - tinygrad/runtime/ops_metal.py +78 -62
 - tinygrad/runtime/ops_nv.py +9 -6
 - tinygrad/runtime/ops_python.py +5 -5
 - tinygrad/runtime/ops_webgpu.py +200 -38
 - tinygrad/runtime/support/am/amdev.py +23 -11
 - tinygrad/runtime/support/am/ip.py +10 -10
 - tinygrad/runtime/support/elf.py +2 -0
 - tinygrad/runtime/support/hcq.py +7 -5
 - tinygrad/runtime/support/llvm.py +8 -14
 - tinygrad/shape/shapetracker.py +3 -2
 - tinygrad/shape/view.py +2 -3
 - tinygrad/spec.py +21 -20
 - tinygrad/tensor.py +150 -90
 - tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
 - tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
 - tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
 - tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
 - tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
 - tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
 - tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
 - tinygrad/viz/index.html +544 -0
 - tinygrad/viz/perfetto.html +178 -0
 - tinygrad/viz/serve.py +205 -0
 - {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/METADATA +20 -8
 - tinygrad-0.10.2.dist-info/RECORD +99 -0
 - tinygrad/codegen/rewriter.py +0 -516
 - tinygrad-0.10.1.dist-info/RECORD +0 -86
 - {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
 - {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +0 -0
 - {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
 
    
tinygrad/ops.py CHANGED
@@ -1,9 +1,8 @@
 from __future__ import annotations
-from typing import Any, Optional, Union, Callable, cast, TYPE_CHECKING, Type,
+from typing import Any, Optional, Union, Callable, cast, TYPE_CHECKING, Type, get_args
 import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect, weakref
 from enum import auto, IntEnum, Enum
 from dataclasses import dataclass, field
-from collections import defaultdict
 from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate
 from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, _METADATA, flatten
 from tinygrad.helpers import PICKLE_BUFFERS, SPLIT_REDUCEOP, DEBUG, dedup
@@ -89,11 +88,12 @@ class MathTrait(SimpleMathTrait):
   def sin(self): return self.alu(Ops.SIN)
   def log2(self): return self.alu(Ops.LOG2)
   def exp2(self): return self.alu(Ops.EXP2)
+  def pow(self, x): return self.alu(Ops.POW, self.ufix(x))

 # the order of these Ops controls the order of the toposort
 class Ops(FastEnum):
   # uops that aren't rendered
-  SINK = auto(); CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto();
+  NAME = auto(); SINK = auto(); CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); KERNEL = auto(); UNIQUE = auto() # noqa: E702

   # TODO: empty continues to exist because of tensor
   EMPTY = auto()
@@ -117,7 +117,7 @@ class Ops(FastEnum):
   REDUCE_AXIS = auto()

   # helper ops
-  GEP = auto(); VECTORIZE = auto() # noqa: E702
+  GEP = auto(); VECTORIZE = auto(); CAT = auto() # noqa: E702

   # UnaryOps
   CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702
@@ -133,7 +133,7 @@ class Ops(FastEnum):

   # BinaryOps
   ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
-  SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto() # noqa: E702
+  SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto(); POW = auto() # noqa: E702

   # TernaryOps
   WHERE = auto(); MULACC = auto() # noqa: E702
@@ -151,18 +151,19 @@ class Ops(FastEnum):
   # device
   DEVICE = auto()
   MULTI = auto()
+  CUSTOM = auto()

 class GroupOp:
   Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG}
   Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY,
-            Ops.SUB, Ops.FDIV}
+            Ops.SUB, Ops.FDIV, Ops.POW}
   Ternary = {Ops.WHERE, Ops.MULACC}
   ALU = set.union(Unary, Binary, Ternary)

   Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}
   Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP}

-  Buffer = {Ops.LOAD, Ops.
+  Buffer = {Ops.LOAD, Ops.STORE, Ops.VALID, Ops.CONST, Ops.DEFINE_VAR}
   Block = {Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKFORK, Ops.BLOCKSTART}

   # BinaryOps that can be flipped
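
Note: the hunks above add Ops.POW as a first-class binary op (MathTrait.pow, the Ops enum, and GroupOp.Binary). A minimal illustrative sketch, not part of the diff, assuming tinygrad 0.10.2 is installed:

    from tinygrad.dtype import dtypes
    from tinygrad.ops import Ops, UOp

    a = UOp.const(dtypes.float, 3.0)
    b = a.pow(2.0)                 # MathTrait.pow -> self.alu(Ops.POW, self.ufix(x))
    assert b.op is Ops.POW and b.src[0] is a

The per-backend lowering of POW happens outside ops.py and is not shown in this file's hunks.
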
@@ -175,21 +176,21 @@ class GroupOp:
   Idempotent = {Ops.OR, Ops.AND, Ops.MAX}

   # do not preserve f(0) = 0
-  UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV}
+  UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV, Ops.POW}

   All = set(Ops)

 # some BUFFER ops can be processed with only a view
-view_supported_devices = {"LLVM", "
+view_supported_devices = {"LLVM", "CPU", "CUDA", "NV", "AMD", "METAL", "QCOM", "DSP", "DISK"}

 # https://en.wikipedia.org/wiki/Identity_element
 def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)

-def can_pad(u:UOp, edges:dict[UOp,
+def can_pad(u:UOp, edges:dict[UOp, None], cache:dict[UOp, None]) -> bool:
   if u.op in GroupOp.UnsafePad: return False
-  if
-
-  return all(can_pad(x.base, edges,
+  if u in edges or u in cache: return True
+  cache[u] = None
+  return all(can_pad(x.base, edges, cache) for x in u.src)

 # With True as the default, this matches the old symbolic behavior
 def resolve(x:UOp|bool, default:bool=True):
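
Note: can_pad now takes an explicit memoization cache and treats UOps already present in edges or cache as safe. A small illustrative call (not part of the diff, assumes tinygrad 0.10.2); POW joins UnsafePad because, like EXP2, it does not preserve f(0) = 0, so zero padding cannot be pushed through it:

    from tinygrad.dtype import dtypes
    from tinygrad.ops import UOp, can_pad

    x = UOp.const(dtypes.float, 1.0)
    print(can_pad(x.exp2(), {}, {}))   # False: EXP2 is in GroupOp.UnsafePad
    print(can_pad(x + 1, {}, {}))      # True: ADD of CONSTs keeps padded zeros harmless
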
@@ -289,6 +290,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
       return ShapeTracker.from_shape(
         tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape)))
     if self.op is Ops.BUFFER: return ShapeTracker.from_shape((self.size,))
+    if self.op is Ops.KERNEL: return ShapeTracker.from_shape(self.arg.ast.shape)
     # these ops define a ShapeTracker from the arg
     if self.op is Ops.VIEW: return self.arg
     if self.op in GroupOp.Movement: return unwrap(self.src[0].st).mop(self.op, self.arg)
@@ -314,11 +316,13 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   @property
   def shape(self) -> tuple[sint, ...]: return unwrap(self.st).shape
   @property
-  def size(self) -> int: return self.arg
+  def size(self) -> int: return self.arg if self.op is Ops.BUFFER else unwrap(self.st).size

   # *** uop evaluation ***

   def simplify(self):
+    # late import!
+    from tinygrad.codegen.symbolic import symbolic
     with Context(TRACK_MATCH_STATS=0):
       return graph_rewrite(self, symbolic)
   def ssimplify(self) -> Union[UOp, ConstType]: return ret.arg if (ret:=self.simplify()).op is Ops.CONST else ret
@@ -342,13 +346,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     assert self.op in GroupOp.Buffer, f"st_arg called on {self.op}"
     return unwrap(self.st)
   @property
-  def const_arg(self) -> ConstType:
-    match self.base.op:
-      case Ops.CONST: ret = self.base.arg
-      case op: raise AssertionError(f"const_arg called on {op}")
-    assert isinstance(ret, get_args(ConstType)), f"const_arg trying to return {ret}"
-    return ret
-  @property
   def axis_arg(self) -> tuple[int, ...]:
     assert self.op in {Ops.REDUCE_AXIS, Ops.WMMA}, f"axis_arg called on {self.op}"
     ret = self.arg[1] if self.op is Ops.REDUCE_AXIS else self.arg[7]
@@ -366,8 +363,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     assert self.dtype.count == 1
     if count == 1: return self
     return UOp(Ops.VECTORIZE, self.dtype.vec(count), (self,)*count)
-  def cast(self, dtype:DType): return UOp(Ops.CAST, dtype, (self,))
-  def bitcast(self, dtype:DType): return UOp(Ops.BITCAST, dtype, (self,))
+  def cast(self, dtype:DType): return self if self.dtype == dtype else UOp(Ops.CAST, dtype, (self,))
+  def bitcast(self, dtype:DType): return self if self.dtype == dtype else UOp(Ops.BITCAST, dtype, (self,))
   def gep(self, i:Union[tuple[int, ...], int]):
     if isinstance(i, int):
       # NOTE: these are just shortcuts to not have to create and fold later
@@ -489,8 +486,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     if op is Ops.BIND:
       var, val = arg.unbind()
       return var.replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),), ShapeTracker.from_shape(shape)),)).bind(val)
-    # otherwise it's just a
-
+    # otherwise it's just a RESHAPE(BUFFER)
+    if not isinstance(size:=prod([x.vmax if isinstance(x, UOp) else x for x in shape]), int): raise ValueError(f"size must be int {size}")
+    return UOp.new_buffer(device, size, dtype).reshape(shape)
   def copy_to_device(self, device:str|tuple[str, ...], clone:bool=False) -> UOp:
     # if it's a shrink, do the shrink before the copy with CONTIGUOUS
     if prod(self.shape) < prod(self.base.shape): return self.contiguous().copy_to_device(device)
@@ -505,14 +503,14 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     return ret
   def clone(self) -> UOp: return self.copy_to_device(self.device, clone=True)
   @property
-  def metadata(self): return all_metadata.get(self, None)
+  def metadata(self) -> tuple[Metadata, ...]|Metadata|None: return self.arg.metadata if self.op is Ops.KERNEL else all_metadata.get(self, None)

   # *** uop movement ops ***

   @property
   def base(self) -> UOp:
-    if self.op in GroupOp.Movement: return self.src[0].base
-    return self
+    if (self.op is Ops.VIEW and len(self.src) != 0) or self.op in GroupOp.Movement: return self.src[0].base
+    return self
   def view(self, new_st:ShapeTracker) -> UOp: return UOp(Ops.VIEW, self.dtype, (self.base,), new_st)

   def _mop(self, op:Ops, arg):
@@ -527,11 +525,17 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg)
   def flip(self, arg:tuple[bool, ...]): return self._mop(Ops.FLIP, arg)

+  # *** uop UNIQUE ***
+
+  # TODO: use this in Buffer
+  unique_num = itertools.count(0)
+  @staticmethod
+  def unique(): return UOp(Ops.UNIQUE, arg=next(UOp.unique_num))
+
   # *** uop Buffer stuff ***

-  buffer_num = itertools.count(0)
   @staticmethod
-  def new_buffer(device:str, size:int, dtype:DType): return UOp(Ops.BUFFER, dtype, (UOp(Ops.DEVICE, arg=device),
+  def new_buffer(device:str, size:int, dtype:DType): return UOp(Ops.BUFFER, dtype, (UOp(Ops.DEVICE, arg=device), UOp.unique()), size)
   @property
   def device(self) -> str|tuple[str, ...]: return cast(str|tuple[str, ...], unwrap(self._device))
   @functools.cached_property
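
Note: BUFFER UOps are now tagged with a UNIQUE source created by UOp.unique() instead of a buffer_num counter arg, and new_buffer's arg carries only the size. Illustrative sketch, not part of the diff, assuming tinygrad 0.10.2:

    from tinygrad.dtype import dtypes
    from tinygrad.ops import Ops, UOp

    a = UOp.new_buffer("CPU", 16, dtypes.float)
    b = UOp.new_buffer("CPU", 16, dtypes.float)
    print(a.arg, a.src[1].op)   # 16 Ops.UNIQUE
    print(a is b)               # False: distinct UNIQUE tags keep equal-shaped buffers apart
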
@@ -542,11 +546,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   @property
   def buf_uop(self) -> UOp:
     if self.base.op is Ops.BUFFER: return self.base
-    assert self.base.op in {*GroupOp.Buffer, Ops.ASSIGN
+    assert self.base.op in {*GroupOp.Buffer, Ops.ASSIGN}, f"buf_uop called on {self.op}"
     return self.src[0].buf_uop
   @property
   def buffer(self) -> Buffer:
-    if self
+    if self is not self.base:
       assert unwrap(self.st).contiguous, "VIEW only works here if it's contiguous"
       return self.src[0].buffer
     assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}"
@@ -591,6 +595,12 @@ class UOp(MathTrait, metaclass=UOpMetaClass):

   # *** uop symbolic stuff ***

+  def is_increasing(self:UOp) -> bool:
+    # is f a monotonically increasing function regards its input
+    if self.op in GroupOp.Irreducible: return True
+    if self.op is Ops.ADD: return self.src[0].is_increasing() and self.src[1].is_increasing()
+    if self.op in (Ops.MUL, Ops.IDIV) and self.src[1].op is Ops.CONST and self.src[1].arg >= 0: return self.src[0].is_increasing()
+    return False  # False if not sure
   def const_factor(self) -> int:
     """largest known int that divides self"""
     if self.op is Ops.CONST: return self.arg
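
Note: is_increasing moves from a module-level helper (removed further down in this diff) onto UOp itself, so symbolic passes can ask a UOp about its own monotonicity. Illustrative sketch, not part of the diff, assuming tinygrad 0.10.2 and its UOp.variable constructor:

    from tinygrad.ops import UOp

    i = UOp.variable("i", 0, 15)          # DEFINE_VAR is in GroupOp.Irreducible
    print((i*4 + 2).is_increasing())      # True: ADD and MUL by a non-negative CONST
    print((i*(-4)).is_increasing())       # False: the method answers False when unsure
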
@@ -598,7 +608,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     if self.op is Ops.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
     if self.op is Ops.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1
     return 1
-  def divides(self, v) -> UOp|None:
+  def divides(self, v:int) -> UOp|None:
     if v==1: return self
     if self.op is Ops.CONST: return self.const_like(self.arg//v) if self.arg%v == 0 else None
     if self.op is Ops.VCONST: return self.const_like(tuple(x//v for x in self.arg)) if all(x%v == 0 for x in self.arg) else None
@@ -642,7 +652,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     if self.op is Ops.RANGE: return self.src[0].vmin, (self.src[1]-1).vmax
     if self.op is Ops.BIND: return self.src[0]._min_max # ignore the bound value
     if self.op in {Ops.UNROLL, Ops.VECTORIZE}: return min(x.vmin for x in self.src), max(x.vmax for x in self.src)
-    # TODO:
+    # TODO: Ops.SPECIAL is Ops.DEFINE_VAR
     if self.op is Ops.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else self.arg[1].vmax
     if self.op is Ops.CONST: return self.arg, self.arg
     if self.op is Ops.VCONST: return (min(self.arg), max(self.arg))
@@ -665,20 +675,26 @@ class UOp(MathTrait, metaclass=UOpMetaClass):

 @dataclass(frozen=True)
 class KernelInfo:
+  name: str = "test"            # name of the kernel
   local_dims: int = 0           # number of local dimensions  (this is remapping RANGE to SPECIAL)
   upcasted: int = 0             # count that are upcasted     (this is remapping RANGE to UNROLL)
   dont_use_locals: bool = False # don't use local indexing

-#
+# ******** ops in python ********

 def safe_exp2(x):
   try: return 2 ** x
   except OverflowError: return math.inf

+def safe_pow(x, y):
+  try: return math.nan if isinstance(p:=pow(x, y), complex) else p
+  except ZeroDivisionError: return math.inf
+  except ValueError: return math.inf if x > 0 else -math.inf
+
 python_alu: dict[Ops, Callable]  = {
   Ops.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, Ops.EXP2: safe_exp2,
   Ops.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, Ops.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
-  Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan,
+  Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, Ops.POW: safe_pow,
   Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt,
   Ops.XOR: operator.xor, Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift, Ops.MAX: max,
   Ops.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], Ops.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else 0,
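
Note: safe_pow backs the new Ops.POW entry in python_alu, the pure-Python op table used for constant folding and the Python emulation path. A plain-Python copy of the function showing the edge cases it normalizes (illustrative only):

    import math

    def safe_pow(x, y):
      try: return math.nan if isinstance(p:=pow(x, y), complex) else p
      except ZeroDivisionError: return math.inf
      except ValueError: return math.inf if x > 0 else -math.inf

    print(safe_pow(2.0, 10))     # 1024.0: the ordinary case
    print(safe_pow(-1.0, 0.5))   # nan: a complex result is mapped to nan
    print(safe_pow(0.0, -1.0))   # inf: ZeroDivisionError is mapped to inf
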
@@ -703,7 +719,8 @@ def get_location() -> tuple[str, int]:
   frm = sys._getframe(1)
   # find the real frame in the file that has the UPat, TODO: is there a better way to do this?
   while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "rewriter.py", "schedule.py", "multi.py",
-                                                                                        "
+                                                                                        "symbolic.py", "expander.py", "lowerer.py", "cstyle.py",
+                                                                                        "linearize.py"}:
     frm = frm.f_back
   return frm.f_code.co_filename, frm.f_lineno
 @functools.lru_cache(None)
@@ -840,7 +857,9 @@ match_stats:dict[UPat, list[Union[int, float]]] = dict()
 class TrackedGraphRewrite:
   loc: tuple[str, int]                                                                       # location that called graph_rewrite
   sink: UOp                                                                                  # the sink input to graph_rewrite
+  bottom_up: bool
   matches: list[tuple[UOp, UOp, UPat]] = field(default_factory=list)                         # before+after of all the matches
+  name: Optional[str] = None
 tracked_keys:list[Any] = []
 tracked_ctxs:list[list[TrackedGraphRewrite]] = []
 _name_cnt:dict[str, int] = {}
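
Note: graph_rewrite and graph_rewrite_map (rewritten just below) gain a name argument, and TrackedGraphRewrite now records bottom_up and name, presumably so the new viz/ UI can label rewrite passes. Minimal usage sketch, not part of the diff, assuming tinygrad 0.10.2:

    from tinygrad.codegen.symbolic import symbolic
    from tinygrad.ops import UOp, graph_rewrite

    expr = UOp.variable("i", 0, 15) * 1 + 0
    print(graph_rewrite(expr, symbolic, name="fold_identities"))   # folds x*1 and x+0 away
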
         @@ -923,304 +942,19 @@ class RewriteContext: 
     | 
|
| 
       923 
942 
     | 
    
         
             
                self.replace[n] = ret = last_n if new_src == last_n.src else self.bottom_up_rewrite(UOp(last_n.op, last_n.dtype, new_src, last_n.arg))
         
     | 
| 
       924 
943 
     | 
    
         
             
                return ret
         
     | 
| 
       925 
944 
     | 
    
         | 
| 
       926 
     | 
    
         
            -
            def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False) -> UOp:
         
     | 
| 
       927 
     | 
    
         
            -
              if TRACK_MATCH_STATS >= 2 and  
     | 
| 
       928 
     | 
    
         
            -
                tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink))
         
     | 
| 
      
 945 
     | 
    
         
            +
            def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None) -> UOp:
         
     | 
| 
      
 946 
     | 
    
         
            +
              if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0:
         
     | 
| 
      
 947 
     | 
    
         
            +
                tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up, name=name))
         
     | 
| 
       929 
948 
     | 
    
         
             
              return RewriteContext(pm, ctx).bottom_up_rewrite(sink) if bottom_up else RewriteContext(pm, ctx).top_down_rewrite(sink)
         
     | 
| 
       930 
949 
     | 
    
         | 
| 
       931 
     | 
    
         
            -
            def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False) -> dict[UOp, UOp]:
         
     | 
| 
       932 
     | 
    
         
            -
              if TRACK_MATCH_STATS >= 2 and  
     | 
| 
       933 
     | 
    
         
            -
                tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink))
         
     | 
| 
      
 950 
     | 
    
         
            +
            def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None) -> dict[UOp, UOp]:
         
     | 
| 
      
 951 
     | 
    
         
            +
              if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0:
         
     | 
| 
      
 952 
     | 
    
         
            +
                tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up, name=name))
         
     | 
| 
       934 
953 
     | 
    
         
             
              rewrite_ctx = RewriteContext(pm, ctx)
         
     | 
| 
       935 
954 
     | 
    
         
             
              return {k:(rewrite_ctx.bottom_up_rewrite(k) if bottom_up else rewrite_ctx.top_down_rewrite(k)) for k in list(sink.toposort)[::-1]}
         
     | 
| 
       936 
955 
     | 
    
         | 
| 
       937 
     | 
    
         
            -
             
     | 
| 
       938 
     | 
    
         
            -
            # *** most of symbolic lives here now ***
         
     | 
| 
       939 
     | 
    
         
            -
             
     | 
| 
       940 
     | 
    
         
            -
            def split_uop(x:UOp, sep:Ops):
         
     | 
| 
       941 
     | 
    
         
            -
              if x.op is sep:
         
     | 
| 
       942 
     | 
    
         
            -
                for s in x.src: yield from split_uop(s, sep)
         
     | 
| 
       943 
     | 
    
         
            -
              else: yield x
         
     | 
| 
       944 
     | 
    
         
            -
             
     | 
| 
       945 
     | 
    
         
            -
            def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split_rem: bool=False) -> UOp|None:
         
     | 
| 
       946 
     | 
    
         
            -
              # simplify x // y or x % y, None means no change
         
     | 
| 
       947 
     | 
    
         
            -
              # simple cancel div/mod case
         
     | 
| 
       948 
     | 
    
         
            -
              if y.vmin != 0 != y.vmax and (q:=x.vmin//y.vmin) == x.vmin//y.vmax == x.vmax//y.vmin == x.vmax//y.vmax:
         
     | 
| 
       949 
     | 
    
         
            -
                return x - q*y if which is Ops.MOD else x.const_like(q)
         
     | 
| 
       950 
     | 
    
         
            -
             
     | 
| 
       951 
     | 
    
         
            -
              if (y.op is not Ops.CONST) or ((c := y.arg) <= 0) or (x.dtype.count > 1): return None
         
     | 
| 
       952 
     | 
    
         
            -
             
     | 
| 
       953 
     | 
    
         
            -
              svars, factors, quotients, remainders, gcd, div, const, offset, something_changed = [], [], [], [], c, 1, 0, 0, False
         
     | 
| 
       954 
     | 
    
         
            -
              for u in split_uop(x, Ops.ADD):
         
     | 
| 
       955 
     | 
    
         
            -
                if u.op is Ops.MOD and which is Ops.MOD and u.src[1].op is Ops.CONST and u.src[1].arg%c == 0:
         
     | 
| 
       956 
     | 
    
         
            -
                  u = u.src[0]
         
     | 
| 
       957 
     | 
    
         
            -
                  something_changed = True
         
     | 
| 
       958 
     | 
    
         
            -
                v: UOp = u.divides(f:=u.const_factor())
         
     | 
| 
       959 
     | 
    
         
            -
                q, r = divmod(f, c)
         
     | 
| 
       960 
     | 
    
         
            -
                if r==0 or ((which is Ops.MOD or split_rem or u.op is Ops.CONST) and r!=f): something_changed = True
         
     | 
| 
       961 
     | 
    
         
            -
                offset += r*v.vmin
         
     | 
| 
       962 
     | 
    
         
            -
                if u.op is Ops.CONST: const += f
         
     | 
| 
       963 
     | 
    
         
            -
                else:  # div is the smallest common divisor of all terms
         
     | 
| 
       964 
     | 
    
         
            -
                  if f > 1 and c % f == 0 and (div == 1 or div > f): div = f
         
     | 
| 
       965 
     | 
    
         
            -
                  gcd = math.gcd(r, gcd)
         
     | 
| 
       966 
     | 
    
         
            -
                  factors.append(f); svars.append(v); quotients.append(q); remainders.append(r)  # noqa: E702
         
     | 
| 
       967 
     | 
    
         
            -
             
     | 
| 
       968 
     | 
    
         
            -
              lbound = ubound = offset = offset % c
         
     | 
| 
       969 
     | 
    
         
            -
              # we can fold if the expression has only one non-constant term and this term can only take on two values
         
     | 
| 
       970 
     | 
    
         
            -
              if len(svars)==1 and (v:=svars[0]).vmax-v.vmin == 1:
         
     | 
| 
       971 
     | 
    
         
            -
                r = (offset+remainders[0])%c - offset%c
         
     | 
| 
       972 
     | 
    
         
            -
                offset -= r * v.vmin
         
     | 
| 
       973 
     | 
    
         
            -
                if which is Ops.MOD: return r*v + offset
         
     | 
| 
       974 
     | 
    
         
            -
                return (factors[0]-r)//c * v + (const-offset)//c
         
     | 
| 
       975 
     | 
    
         
            -
             
     | 
| 
       976 
     | 
    
         
            -
              # a//c = (a-a%c)/c, if we can fold a%c, we can fold a//c
         
     | 
| 
       977 
     | 
    
         
            -
              # within a mod we can freely subtract multiples of c, we use this to see if a is congruent to an expression whose vmin/vmax are between 0 and c
         
     | 
| 
       978 
     | 
    
         
            -
              for (r, v) in zip(remainders, svars):
         
     | 
| 
       979 
     | 
    
         
            -
                if r > c//2:
         
     | 
| 
       980 
     | 
    
         
            -
                  if (lbound := lbound + (r:=r-c) * (v.vmax-v.vmin)) < 0: break
         
     | 
| 
       981 
     | 
    
         
            -
                elif (ubound := ubound + r * (v.vmax-v.vmin)) >= c: break
         
     | 
| 
       982 
     | 
    
         
            -
                offset -= r * v.vmin  # determine what the new offset would be
         
     | 
| 
       983 
     | 
    
         
            -
              else: # vmin/vmax of the remainder is between 0 and c, we can remove the mod/div
         
     | 
| 
       984 
     | 
    
         
            -
                remainders = [min(r, r-c, key=abs) for r in remainders]
         
     | 
| 
       985 
     | 
    
         
            -
                if which is Ops.MOD: return functools.reduce(operator.add, [r*v for r,v in zip(remainders,svars)], x.const_like(offset))
         
     | 
| 
       986 
     | 
    
         
            -
                return functools.reduce(operator.add, [(f-r)//c * v for f,r,v in zip(factors, remainders,svars)], x.const_like((const-offset)//c))
         
     | 
| 
       987 
     | 
    
         
            -
             
     | 
| 
       988 
     | 
    
         
            -
              if gcd != 1: something_changed = True
         
     | 
| 
       989 
     | 
    
         
            -
              if not something_changed:
         
     | 
| 
       990 
     | 
    
         
            -
                if which is Ops.IDIV and (1 < div < c) and (newx:=div_and_mod_folding(x, UOp.const(dtypes.int, div), Ops.IDIV)) is not None: return newx//(c//div)
         
     | 
| 
       991 
     | 
    
         
            -
                return None
         
     | 
| 
       992 
     | 
    
         
            -
              quo, rem = x.const_like(const//c), x.const_like((const%c)//gcd)
         
     | 
| 
       993 
     | 
    
         
            -
              for q,r,f,v in zip(quotients, remainders, factors, svars):
         
     | 
| 
       994 
     | 
    
         
            -
                if which is Ops.IDIV and (not split_rem) and r!=0:
         
     | 
| 
       995 
     | 
    
         
            -
                  rem += f//gcd * v
         
     | 
| 
       996 
     | 
    
         
            -
                else:
         
     | 
| 
       997 
     | 
    
         
            -
                  rem += r//gcd * v
         
     | 
| 
       998 
     | 
    
         
            -
                  quo += q * v
         
     | 
| 
       999 
     | 
    
         
            -
             
     | 
| 
       1000 
     | 
    
         
            -
              if which is Ops.MOD: return gcd*(rem % (c//gcd)) + const%gcd
         
     | 
| 
       1001 
     | 
    
         
            -
              return rem//(c//gcd)+quo
         
     | 
| 
       1002 
     | 
    
         
            -
             
     | 
| 
       1003 
     | 
    
         
            -
            def lt_folding(x:UOp, c:int) -> UOp|None:
         
     | 
| 
       1004 
     | 
    
         
            -
              p, np = partition(split_uop(x, Ops.ADD), lambda u: u.const_factor() == 1)
         
     | 
| 
       1005 
     | 
    
         
            -
              if np and (d:=math.gcd(*[u.const_factor() for u in np], c)) > 1 and 0 <= sum(u.vmin for u in p) and sum(u.vmax for u in p) < d:
         
     | 
| 
       1006 
     | 
    
         
            -
                return cast(UOp, functools.reduce(operator.add, np).divides(d))<(c//d)
         
     | 
| 
       1007 
     | 
    
         
            -
              return None
         
     | 
| 
       1008 
     | 
    
         
            -
             
     | 
| 
       1009 
     | 
    
         
            -
            def fold_unrolled_divs(divs:UOp):
         
     | 
| 
       1010 
     | 
    
         
            -
              # div pattern in unrolled arange
         
     | 
| 
       1011 
     | 
    
         
            -
              # example: (x//4+(x+1)//4+(x+2)//4+(x+3)//4 -> x
         
     | 
| 
       1012 
     | 
    
         
            -
              add_chain, denominator, seen_const, ans = list(split_uop(divs, Ops.ADD)), None, [], None
         
     | 
| 
       1013 
     | 
    
         
            -
              for u in add_chain:
         
     | 
| 
       1014 
     | 
    
         
            -
                if not (u.op is Ops.IDIV and u.src[1].op is Ops.CONST): return None
         
     | 
| 
       1015 
     | 
    
         
            -
                if denominator is None: denominator = u.src[1].arg
         
     | 
| 
       1016 
     | 
    
         
            -
                if denominator != u.src[1].arg: return None
         
     | 
| 
       1017 
     | 
    
         
            -
                # assumed CONST is the last of an ADD
         
     | 
| 
       1018 
     | 
    
         
            -
                if (s0:=u.src[0]).op is Ops.ADD and s0.src[1].op is Ops.CONST and s0.src[1].op is Ops.CONST:
         
     | 
| 
       1019 
     | 
    
         
            -
                  seen_const.append(s0.src[1].arg)
         
     | 
| 
       1020 
     | 
    
         
            -
                  s0 = s0.src[0]
         
     | 
| 
       1021 
     | 
    
         
            -
                else: seen_const.append(0)
         
     | 
| 
       1022 
     | 
    
         
            -
                if ans is None: ans = s0
         
     | 
| 
       1023 
     | 
    
         
            -
                if ans is not s0: return None
         
     | 
| 
       1024 
     | 
    
         
            -
              if denominator is None: return None
         
     | 
| 
       1025 
     | 
    
         
            -
              # the first (denominator-len(seen_const)) terms may have been folded to 0 already
         
     | 
| 
       1026 
     | 
    
         
            -
              for i in range(denominator-len(seen_const)):
         
     | 
| 
       1027 
     | 
    
         
            -
                if ans is not None and 0 <= ans.vmin and ans.vmax + i < denominator: seen_const.append(i)
         
     | 
| 
       1028 
     | 
    
         
            -
              return ans if ans is not None and sorted(seen_const)==list(range(denominator)) else None
         
     | 
| 
       1029 
     | 
    
         
            -
             
     | 
| 
       1030 
     | 
    
         
            -
            def canonicalize_simplex(X:UOp) -> UOp|None:
         
     | 
| 
       1031 
     | 
    
         
            -
              # (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints.
         
     | 
| 
       1032 
     | 
    
         
            -
              # returns x0 + x1 + ... in such case, or None if not
         
     | 
| 
       1033 
     | 
    
         
            -
              changed, ret = False, []
         
     | 
| 
       1034 
     | 
    
         
            -
              for u in split_uop(X, Ops.ADD):
         
     | 
| 
       1035 
     | 
    
         
            -
                # assumed the const is the last src of MUL
         
     | 
| 
       1036 
     | 
    
         
            -
                if u.op is Ops.MUL and u.src[1].op is Ops.CONST and u.src[1].arg > 0:
         
     | 
| 
       1037 
     | 
    
         
            -
                  changed = True
         
     | 
| 
       1038 
     | 
    
         
            -
                  u = u.src[0]
         
     | 
| 
       1039 
     | 
    
         
            -
                if not (u.op in GroupOp.Irreducible and u.vmin >= 0): return None
         
     | 
| 
       1040 
     | 
    
         
            -
                ret.append(u)
         
     | 
| 
       1041 
     | 
    
         
            -
              return functools.reduce(operator.add, ret) if changed else None
         
     | 
| 
       1042 
     | 
    
         
            -
             
     | 
-def is_increasing(f:UOp) -> bool:
-  # is f a monotonically increasing function regards its input
-  if f.op in GroupOp.Irreducible: return True
-  if f.op is Ops.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1])
-  if f.op in (Ops.MUL, Ops.IDIV) and f.src[1].op is Ops.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0])
-  return False  # False if not sure
-
-def parse_valid(valid:UOp) -> tuple[UOp, bool, int]:
-  # if it's X <= c, returns X, True, c
-  # if it's X >= c, returns X, False, c
-
-  # (X < c).ne(True) -> X >= c
-  if valid.op is Ops.CMPNE and valid.src[1].op is Ops.CONST and valid.src[1].arg == 1 and \
-    (s0:=valid.src[0]).op is Ops.CMPLT and s0.src[1].op is Ops.CONST: return s0.src[0], False, s0.src[1].arg
-  # X < c -> X <= c-1
-  if valid.op is Ops.CMPLT and valid.src[1].op is Ops.CONST and dtypes.is_int(valid.src[0].dtype): return valid.src[0], True, valid.src[1].arg-1
-  raise ValueError(f"not able to parse {valid=}")
-
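The removed parse_valid normalizes a comparison into (expr, is_upper_bound, bound). A plain-integer sketch of the same two normalizations (the helper and its inputs are hypothetical, not the UOp version):

# not (X < c)  is  X >= c    -> lower bound c    (is_upper=False)
# X < c        is  X <= c-1  -> upper bound c-1  (is_upper=True), for integer X
def parse_valid_int(kind: str, c: int) -> tuple[bool, int]:
    if kind == "not_lt": return (False, c)
    if kind == "lt": return (True, c - 1)
    raise ValueError(f"not able to parse {kind=}")

assert parse_valid_int("not_lt", 0) == (False, 0)   # e.g. a mask like "idx >= 0"
assert parse_valid_int("lt", 10) == (True, 9)       # e.g. a mask like "idx < 10"
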
-def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None:
-  # return None if valid is always False, otherwise the simplified uop (might be the same as input)
-
-  # first, parse valid into {expr: (lower_bound, upper_bound)}
-  bounds:defaultdict[UOp, list[ConstType|None]] = defaultdict(lambda: [None, None])
-  for stmt in split_uop(valid, Ops.AND):
-    try: expr, is_upper, c = parse_valid(stmt)
-    except ValueError: return uop  # give up if we cannot parse the valid
-    bounds[expr][int(is_upper)] = c
-
-  # simplify uop given that valid is True
-  for expr,v in bounds.items():
-    # some expr has lower bound > upper bound -> valid is an empty set and we return None
-    if v[0] is not None and v[1] is not None and v[0] > v[1]: return None
-
-    # every candidate is a set of contrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop
-    candidates = []
-    if expr.op is Ops.ADD and v[0] == 1 and all(u.op in GroupOp.Irreducible for u in split_uop(expr, Ops.ADD)):
-      # if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
-      candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, Ops.ADD)])
-    # try checking the whole clause
-    if expr in uop.toposort:
-      candidates.append([(expr, UOp.variable("fake", expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1], expr.dtype))])
-
-    for candidate in candidates:
-      # if every branch in candidate gives the same simplified uop, we can rewrite the uop
-      newuops = [uop.substitute({X:newX}).simplify().substitute({newX:X}).simplify() for X,newX in candidate]
-      if uop.op is Ops.VECTORIZE and len(uop.src) == 2:
-        if all_same([uops.src[0] for uops in newuops]): uop = uop.replace(src=(newuops[0].src[0], uop.src[1]))
-        if all_same([uops.src[1] for uops in newuops]): uop = uop.replace(src=(uop.src[0], newuops[0].src[1]))
-      elif all_same(newuops): uop = newuops[0]
-
-  return uop
-
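uop_given_valid first folds every ANDed condition into per-expression (lower, upper) bounds and bails out when a range is empty. The bookkeeping, sketched with plain dicts and made-up names instead of UOps:

from collections import defaultdict

# each parsed condition contributes one side of a bound; lower > upper means the
# valid can never be true (the removed function returns None in that case)
bounds: defaultdict[str, list] = defaultdict(lambda: [None, None])
for expr, is_upper, c in [("idx", False, 0), ("idx", True, 9)]:  # idx >= 0 and idx <= 9
    bounds[expr][int(is_upper)] = c

lo, hi = bounds["idx"]
assert (lo, hi) == (0, 9)
assert not (lo is not None and hi is not None and lo > hi)  # non-empty, keep simplifying
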
-def _valid_priority(v: UOp, valids:list[UOp]):
-  # we want valid that's in other valids' parents to be first, so it's more likely the other valids get simplified
-  try: return sum(-1 if parse_valid(v)[0] in other.toposort else 0 for other in valids)
-  except ValueError: return 0
-
-def simplify_valid(valid:UOp) -> UOp|None:
-  ret:list[UOp] = []
-  something_changed = False
-  valids = list(split_uop(valid, Ops.AND))
-  for stmt in sorted(valids, key=lambda v: _valid_priority(v, valids)):
-    ret.append(newstmt if ret and (newstmt:=uop_given_valid(functools.reduce(operator.and_, ret), stmt)) is not None else stmt)
-    if ret[-1] is not stmt: something_changed = True
-  return functools.reduce(operator.and_, ret) if something_changed else None
-
-# def max_var_const(x:UOp, c1:UOp, c2:UOp):
-#   if x.vmin >= 0: return x*c1 if c1.arg >= c2.arg else x*c2
-#   if x.vmax <= 0: return x*c2 if c1.arg >= c2.arg else x*c1
-
 def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x

-symbolic_simple = PatternMatcher([
-  # ** self folding **
-  (UPat.var("x") + 0, lambda x: x),    # x+0 -> x
-  (UPat.var("x") * 1, lambda x: x),    # x*1 -> x
-  (UPat.var("x") // UPat.var("x"), lambda x: x.const_like(1)), # x//x -> 1
-  (UPat.var("x") // 1, lambda x: x),   # x//1 -> x
-  (UPat.var("x") // -1, lambda x: -x), # x//-1 -> -x
-  (UPat.var("x") / UPat.var("x"), lambda x: x.const_like(1)), # x/x -> 1
-  ((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x
-  ((UPat.var() % UPat.var("y")).named("base") % UPat.var("y"), lambda base,y: base),  # (x%y)%y = -> x%y (rewritten with base for speed)
-  (UPat.var("x")%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda x,c: x), # (x%c)+(x//c)*c = x
-  ((UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3")+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"),
-    lambda x,c1,c2,c3: x*c2 if c1.arg*c2.arg==c3.arg else None), # (x%c1)*c2+(x//c1)*c3 = x*c2 if c1*c2==c3
-  (UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c", vec=False), lambda x,c: x if c.arg else c),
-  (UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c", vec=False), lambda x,c: c if c.arg else x),
-  (UPat(GroupOp.Idempotent, src=(UPat.var("x"), UPat.var("x"))), lambda x: x),
-  (UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x),
-  (UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, True), UPat.const(dtypes.bool, False)), lambda x: x),
-  # ** zero folding **
-  (UPat.var("x") < UPat.var("x"), lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x < x -> False
-  (UPat.var("x", dtype=dtypes.ints) != UPat.var("x", dtype=dtypes.ints),
-   lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints)
-  # x*0 -> 0 or 0*x -> 0
-  # if x is nan or inf it should render the nan value.
-  # NOTE: this can be wrong for loaded NaN
-  (UPat.var("x") * 0, lambda x: x.const_like(float("nan") if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
-  # ** constant folding **
-  # TODO: add const folding for Ops.THREEFRY
-  (UPat(GroupOp.ALU, name="a", src=UPat((Ops.VCONST, Ops.CONST))),
-   lambda a: a.const_like(exec_alu(a.op, a.dtype, [x.arg for x in a.src], False)) if a.op is not Ops.THREEFRY else None),
-  # bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
-  (UPat.var('x', dtype=dtypes.bool) * UPat.var('y', dtype=dtypes.bool), lambda x,y: x&y),
-  (UPat.var('x', dtype=dtypes.bool) + UPat.var('y', dtype=dtypes.bool), lambda x,y: x|y),
-  (UPat.var('x', dtype=dtypes.bool).maximum(UPat.var('y', dtype=dtypes.bool)), lambda x,y: x|y),
-  # *** cast ***
-  (UPat(Ops.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)),
-  (UPat(Ops.CAST, name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
-])
-
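The removed symbolic_simple table is a list of (pattern, rewrite) rules. The same shape of rule table on a toy tuple-based IR (entirely hypothetical, just to illustrate how one rewrite step works):

Expr = tuple  # e.g. ("add", ("var", "x"), ("const", 0))

def simplify_once(e: Expr) -> Expr:
    op, *srcs = e
    if op == "add" and len(srcs) == 2 and srcs[1] == ("const", 0): return srcs[0]   # x+0 -> x
    if op == "mul" and len(srcs) == 2 and srcs[1] == ("const", 1): return srcs[0]   # x*1 -> x
    if op == "idiv" and len(srcs) == 2 and srcs[0] == srcs[1]: return ("const", 1)  # x//x -> 1
    return e

assert simplify_once(("add", ("var", "x"), ("const", 0))) == ("var", "x")
assert simplify_once(("idiv", ("var", "x"), ("var", "x"))) == ("const", 1)
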
-symbolic = symbolic_simple+PatternMatcher([
-  # ** COMMUTATIVE flipping **
-  (UPat(GroupOp.Commutative, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
-  # ** boolean algebra **
-  (UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x
-  # ** combine terms **
-  (UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1)
-  ((UPat.var("y") + UPat.var("x") * UPat.cvar("c0")) + UPat.var("x") * UPat.cvar("c1"), lambda x,y,c0,c1: y+x*(c0+c1)),
-  (UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1)
-  ((UPat.var("y") + UPat.var("x")) + UPat.var("x") * UPat.cvar("c"), lambda x,y,c: y+x*(c+1)),
-  (UPat.var("x") + UPat.var("x"), lambda x: x*2), # (x+x)-> x*2
-  ((UPat.var("y") + UPat.var("x")) + UPat.var("x"), lambda y,x: y+x*2),
-  ((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3) if x2 is not x3 else None), # (x/x2)/x3 -> x/(x2*x3)
-  (-1 * (UPat.var("x") + UPat.cvar("c")), lambda x,c: (-x)+(-c)),  # -(x+c) -> -x + -c
-  # a conditional with the same results either way is a noop, also fold const conditionals
-  (UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
-  (UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
-  # alu of two where with same conds can combine, only do if true branch or false branch is const
-  (UPat(GroupOp.Binary, name="alu", src=(UPat.var("c").where(UPat.var("t"), UPat.var("f")), UPat.var("c").where(UPat.var("tt"), UPat.var("ff")))), \
-   lambda alu,c,t,tt,f,ff: c.where(t.alu(alu.op, tt), f.alu(alu.op, ff)) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None),
-  # ALU min==max -> CONST (slow!)
-  (UPat(GroupOp.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
-  # max folding
-  (UPat.maximum(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
-  # TODO: why does this rule break beautiful_mnist?
-  #((UPat.var("x")+UPat.var("z")).maximum(UPat.var("y")+UPat.var("z")), lambda x,y,z: x.maximum(y) + z),
-  #((UPat.var("x")*UPat.cvar("c1")).maximum(UPat.var("x")*UPat.cvar("c2")), max_var_const),
-  # ** two stage ALU folding **
-  *((UPat.var("x").alu(op, UPat.cvar("c1")).alu(op, UPat.cvar("c2")).named("f"),
-     lambda f,x,c1,c2: x.alu(f.op,c1.alu(f.op,c2))) for op in GroupOp.Associative),
-  ((UPat.cvar("c0") + UPat.var("x")) < UPat.cvar("c1"), lambda x,c0,c1: x<(c1-c0)),  # c0 + x < c1 -> x < c1 - c0
-  ((UPat.var("x") // UPat.cvar("c1")) // UPat.cvar("c2"), lambda x,c1,c2: x//(c1*c2)), # (x//c1)//c2 -> x//(c1*c2)
-  # ** lt **
-  # c0*x<c1 for positive int c0,c1
-  ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints))<UPat.cvar("c1", vec=False),
-   lambda x,c0,c1: x<math.ceil(c1.arg/c0.arg) if c0.arg > 0 and c1.arg > 0 else None),
-  # c0*x<c1 for negative int c0 and non-positive c1
-  ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints))<UPat.cvar("c1", vec=False),
-   lambda x,c0,c1: (-x)<(-(math.floor(-c1.arg/-c0.arg))) if c0.arg < 0 and c0.arg != -1 and c1.arg <= 0 else None),
-  # x//c0<c1 for positive int c0
-  ((UPat.var("x", dtype=dtypes.ints)//UPat.cvar("c0", vec=False))<UPat.cvar("c1", vec=False),
-   lambda x,c0,c1: x<(c1.arg*c0.arg) if c0.arg > 0 else None),
-  # ** move add/mul consts to end (NOTE: this is still happening before constant folding) **
-  (UPat(Ops.ADD, src=(UPat.var("x"), UPat.cvar("c1"))) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
-  (UPat(Ops.MUL, src=(UPat.var("x"), UPat.cvar("c1"))) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
-  # *** rules from symbolic ***
-  # unrolled arange div folding
-  (UPat(Ops.ADD, name="divs", src=[UPat(), UPat(Ops.IDIV)]), fold_unrolled_divs),
-  # generic lt folding
-  (UPat.var("x", dtypes.sints)<UPat.cvar("c", vec=False), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None),
-  # canonicalize a simplex with positive coefficients > 0
-  # not x < 1 -> X > 0
-  ((UPat.var("x", dtypes.ints)<1).ne(True), lambda x: (newx<1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None),
-  # ** div **
-  # div folding
-  ((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"), lambda x,c,a,d: (x+a*c)//(c*d)),  # (x//c+a)//d -> (x+a*c)//(c*d)
-  (UPat.var("x", dtypes.sints) // UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.IDIV)),
-  # ** mod **
-  # mod folding
-  (UPat.var("x") % UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.MOD)),
-])
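The lt-folding rules in the removed symbolic matcher lean on integer inequalities that are easy to spot-check numerically. A quick plain-Python verification with hypothetical constants:

import math

# for positive ints c0, c1:  c0*x < c1  iff  x < ceil(c1/c0),  and  x//c0 < c1  iff  x < c1*c0
c0, c1 = 4, 10
for x in range(-50, 50):
    assert (c0*x < c1) == (x < math.ceil(c1/c0))
    assert (x//c0 < c1) == (x < c1*c0)
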
-
-
-symbolic_flat = symbolic+PatternMatcher([
-  # ** combine terms (opinionated) **
-  (-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)),  # -(x+y) -> -x + -y
-  # (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue
-  ((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
-])
-
 _substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])

 # for debug
@@ -1250,7 +984,8 @@ ConstLike = Union[ConstType, Variable, tuple[ConstType, ...]]
 merge_views = PatternMatcher([
   # VIEW(VIEW) merges to a single VIEW
   (UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.VIEW, name="vm2"),)), lambda vm1,vm2: vm2.replace(arg=vm2.st+vm1.st)),
-
+  # remove VIEW if it's contiguous and same as the base shape
+  (UPat(Ops.VIEW, name="vm", src=(UPat(GroupOp.All-{Ops.DEVICE}, name="x"),)), lambda vm,x: x if vm.st.contiguous and x.shape == vm.shape else None),
   # merge unmasked const views
   (UPat(Ops.VIEW, name="view", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="const", src=(UPat(Ops.VIEW, name="st"),) ),)),
    lambda st,const,view: const.replace(src=(st.replace(arg=st.st+view.st),)) if all(v.mask is None for v in (st.st+view.st).views) else None),
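The new merge_views rule added here drops a VIEW wrapper that adds no information. The intent, sketched with a toy view type rather than tinygrad's ShapeTracker:

from dataclasses import dataclass

@dataclass(frozen=True)
class ToyView:  # hypothetical stand-in for a view: just shape plus contiguity
    shape: tuple
    contiguous: bool

def is_noop_view(view: ToyView, base_shape: tuple) -> bool:
    # mirrors the condition in the added pattern: contiguous and same shape as the base
    return view.contiguous and view.shape == base_shape

assert is_noop_view(ToyView((4, 4), True), (4, 4))        # removable wrapper
assert not is_noop_view(ToyView((16,), True), (4, 4))     # reshape: keep the VIEW
assert not is_noop_view(ToyView((4, 4), False), (4, 4))   # permuted/strided: keep the VIEW
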
    
tinygrad/renderer/__init__.py
CHANGED

@@ -1,11 +1,24 @@
 from __future__ import annotations
 from typing import Optional, Callable
 import functools, math
+from enum import Enum, auto
 from dataclasses import dataclass, field, replace
 from tinygrad.helpers import to_function_name, dedup, prod
 from tinygrad.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, GroupOp, PatternMatcher
 from tinygrad.dtype import DType

+class OptOps(Enum):
+  TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
+  GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); SWAP = auto() # noqa: E702
+  def __lt__(self, x:OptOps): return self.value < x.value
+
+@dataclass(frozen=True, order=True)
+class Opt:
+  op: OptOps
+  axis: Optional[int] = None
+  arg: Optional[int | tuple] = None
+  def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, arg={self.arg})"
+
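A usage sketch for the Opt/OptOps definitions added above (the axis and arg values are made up for illustration); frozen plus order=True means opts hash, compare, and sort cleanly:

upcast = Opt(op=OptOps.UPCAST, axis=0, arg=4)
local = Opt(op=OptOps.LOCAL, axis=1, arg=16)
print(upcast)                                 # Opt(op=OptOps.UPCAST, axis=0, arg=4)
assert sorted([local, upcast])[0] == upcast   # OptOps.__lt__ + order=True make opts sortable
assert upcast in {upcast, local}              # frozen=True makes them hashable
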
 @dataclass(frozen=True)
 class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x N)
   dims: tuple[int,int,int] # N, M, K
@@ -70,7 +83,9 @@ class ProgramSpec:
   name:str
   src:str
   device:str
+  ast:UOp  # save the base ast (this is method cache key)
   uops:Optional[list[UOp]]=None
+  applied_opts:Optional[list[Opt]]=None
   mem_estimate:sint=0  # TODO: get this from the load/store uops once min/max are good

   # filled in from uops (if we have uops)
@@ -121,12 +136,13 @@ class Renderer:
   has_local: bool = True
   has_shared: bool = True
   # NOTE: these two should be in (x,y,z) order to match the max_sizes argument in get_grouped_dims
-  global_max: Optional[tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO:
-  local_max: Optional[tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO:
+  global_max: Optional[tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: Ops.SPECIAL int32 indexes right now
+  local_max: Optional[tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: Ops.SPECIAL int32 indexes right now
   shared_max: int = 32768
   tensor_cores: list[TensorCore] = []
+  pre_matcher: Optional[PatternMatcher] = None
   extra_matcher: Optional[PatternMatcher] = None
   code_for_op: dict[Ops, Callable] = {}

   def __reduce__(self): return self.__class__, ()
-  def render(self,
+  def render(self, uops:list[UOp]) -> str: raise NotImplementedError("needs a renderer")
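A minimal sketch of what the new render signature implies for a backend, assuming tinygrad 0.10.2 is installed; the class name, device string, and output format are hypothetical, not a real tinygrad backend:

from tinygrad.renderer import Renderer
from tinygrad.ops import UOp

class ToyRenderer(Renderer):
  device = "TOY"  # assumed attribute; real backends set their device string here
  def render(self, uops:list[UOp]) -> str:
    # a real renderer walks the uop list and emits kernel source; this stub just reports its length
    return f"// toy kernel with {len(uops)} uops"

print(ToyRenderer().render([]))  # "// toy kernel with 0 uops"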