tinygrad 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. tinygrad/__init__.py +11 -6
  2. tinygrad/codegen/kernel.py +308 -175
  3. tinygrad/codegen/linearize.py +95 -0
  4. tinygrad/codegen/lowerer.py +143 -0
  5. tinygrad/codegen/transcendental.py +257 -0
  6. tinygrad/codegen/uopgraph.py +506 -0
  7. tinygrad/device.py +72 -171
  8. tinygrad/dtype.py +122 -47
  9. tinygrad/engine/jit.py +184 -87
  10. tinygrad/{lazy.py → engine/lazy.py} +74 -66
  11. tinygrad/engine/memory.py +51 -0
  12. tinygrad/engine/realize.py +86 -61
  13. tinygrad/engine/schedule.py +366 -317
  14. tinygrad/engine/search.py +58 -47
  15. tinygrad/function.py +59 -58
  16. tinygrad/helpers.py +120 -102
  17. tinygrad/multi.py +82 -78
  18. tinygrad/nn/__init__.py +116 -67
  19. tinygrad/nn/datasets.py +12 -5
  20. tinygrad/nn/optim.py +1 -1
  21. tinygrad/nn/state.py +91 -6
  22. tinygrad/ops.py +1126 -143
  23. tinygrad/renderer/__init__.py +47 -23
  24. tinygrad/renderer/cstyle.py +338 -265
  25. tinygrad/renderer/llvmir.py +125 -143
  26. tinygrad/renderer/ptx.py +225 -0
  27. tinygrad/runtime/autogen/adreno.py +17904 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
  29. tinygrad/runtime/autogen/cuda.py +6 -162
  30. tinygrad/runtime/autogen/io_uring.py +97 -63
  31. tinygrad/runtime/autogen/kfd.py +60 -47
  32. tinygrad/runtime/autogen/kgsl.py +1386 -0
  33. tinygrad/runtime/autogen/libc.py +5462 -0
  34. tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
  35. tinygrad/runtime/autogen/nvrtc.py +579 -0
  36. tinygrad/runtime/autogen/opencl.py +11 -11
  37. tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
  38. tinygrad/runtime/graph/clang.py +3 -3
  39. tinygrad/runtime/graph/cuda.py +11 -15
  40. tinygrad/runtime/graph/hcq.py +120 -107
  41. tinygrad/runtime/graph/metal.py +71 -43
  42. tinygrad/runtime/ops_amd.py +244 -323
  43. tinygrad/runtime/ops_clang.py +12 -5
  44. tinygrad/runtime/ops_cloud.py +220 -0
  45. tinygrad/runtime/ops_cuda.py +42 -99
  46. tinygrad/runtime/ops_disk.py +25 -26
  47. tinygrad/runtime/ops_dsp.py +181 -0
  48. tinygrad/runtime/ops_gpu.py +29 -16
  49. tinygrad/runtime/ops_hip.py +68 -0
  50. tinygrad/runtime/ops_llvm.py +15 -10
  51. tinygrad/runtime/ops_metal.py +147 -64
  52. tinygrad/runtime/ops_nv.py +356 -397
  53. tinygrad/runtime/ops_python.py +78 -79
  54. tinygrad/runtime/ops_qcom.py +405 -0
  55. tinygrad/runtime/support/__init__.py +0 -0
  56. tinygrad/runtime/support/compiler_cuda.py +77 -0
  57. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
  58. tinygrad/runtime/support/elf.py +38 -0
  59. tinygrad/runtime/support/hcq.py +539 -0
  60. tinygrad/shape/shapetracker.py +40 -50
  61. tinygrad/shape/view.py +102 -63
  62. tinygrad/tensor.py +1109 -365
  63. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
  64. tinygrad-0.10.0.dist-info/RECORD +77 -0
  65. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
  66. tinygrad/codegen/linearizer.py +0 -528
  67. tinygrad/codegen/uops.py +0 -451
  68. tinygrad/engine/graph.py +0 -100
  69. tinygrad/renderer/assembly.py +0 -269
  70. tinygrad/shape/symbolic.py +0 -327
  71. tinygrad-0.9.1.dist-info/RECORD +0 -63
  72. /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
  73. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
  74. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
tinygrad/ops.py CHANGED
@@ -1,113 +1,465 @@
1
1
  from __future__ import annotations
2
- from typing import Union, Tuple, Any, List, Dict, Callable
3
- import functools, hashlib, math, operator, ctypes, struct
4
- from enum import Enum, auto
5
- from dataclasses import dataclass
6
- from tinygrad.helpers import prod, dedup
7
- from tinygrad.dtype import dtypes, DType, ConstType
8
- from tinygrad.shape.symbolic import Variable, sint
9
- from tinygrad.shape.shapetracker import ShapeTracker
10
-
11
- # these are the llops your accelerator must implement, along with toCpu
12
- # the Enum class doesn't work with mypy, this is static. sorry it's ugly
13
- # NOTE: MOD, CMPLT don't have to be implemented on vectors, just scalars
14
- # NOTE: many GPUs don't have DIV, but UnaryOps.RECIP doesn't work for integer division
15
- class UnaryOps(Enum):
16
- """A -> A (elementwise)"""
17
- EXP2 = auto(); LOG2 = auto(); CAST = auto(); BITCAST = auto(); SIN = auto(); SQRT = auto(); NEG = auto(); RECIP = auto() # noqa: E702
18
- class BinaryOps(Enum):
19
- """A + A -> A (elementwise)"""
2
+ from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, Type, DefaultDict
3
+ import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect
4
+ from enum import auto, IntEnum, Enum
5
+ from dataclasses import dataclass, field
6
+ from collections import defaultdict
7
+ from weakref import WeakValueDictionary
8
+ from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate
9
+ from tinygrad.helpers import ContextVar, prod, getenv, all_same, Context, partition, temp, unwrap, T
10
+ if TYPE_CHECKING:
11
+ from tinygrad.shape.shapetracker import ShapeTracker
12
+
13
+ # wrapper around IntEnum that preserves Enum.__str__ and makes auto() unique across all FastEnum subclasses
14
+ class FastEnum(IntEnum):
15
+ def __str__(self): return Enum.__str__(self)
16
+ @staticmethod
17
+ def _generate_next_value_(_, __, ___, last_values): return 1 + max([0, *last_values, *[max(c) for c in FastEnum.__subclasses__()]])
18
+
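A minimal standalone sketch (not part of the diff) of what the FastEnum above guarantees: auto() values stay unique across subclasses, and str() keeps the Enum-style name instead of the raw integer.

  from enum import IntEnum, Enum, auto

  class FastEnum(IntEnum):  # same definition as above, repeated so this sketch runs on its own
    def __str__(self): return Enum.__str__(self)
    @staticmethod
    def _generate_next_value_(_, __, ___, last_values): return 1 + max([0, *last_values, *[max(c) for c in FastEnum.__subclasses__()]])

  class Colors(FastEnum): RED = auto(); BLUE = auto()   # RED=1, BLUE=2
  class Shapes(FastEnum): BOX = auto()                  # BOX=3, not 1: unique across subclasses
  assert int(Shapes.BOX) == 3 and str(Colors.RED) == "Colors.RED"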
19
+ class SimpleMathTrait:
20
+ # required to implement
21
+ def alu(self:T, arg:Ops, *src) -> T: raise NotImplementedError
22
+ def const_like(self:T, b:ConstLike) -> T: raise NotImplementedError
23
+
24
+ # great functions you get!
25
+ def ufix(self, x): return self.const_like(x) if not isinstance(x, MathTrait) else x
26
+ def _binop(self, op, x, reverse): return self.ufix(x).alu(op, self) if reverse else self.alu(op, self.ufix(x))
27
+ def logical_not(self): return self.ne(True)
28
+ def neg(self):
29
+ dtype: Optional[DType] = getattr(self, 'dtype', None)
30
+ assert dtype is not None, "MathTraits __neg__ requires a dtype"
31
+ return self.logical_not() if dtype.scalar() == dtypes.bool else self*(-1)
32
+ def add(self, x, reverse=False): return self._binop(Ops.ADD, x, reverse)
33
+ def mul(self, x, reverse=False): return self._binop(Ops.MUL, x, reverse)
34
+ def bitwise_and(self, x, reverse=False): return self._binop(Ops.AND, x, reverse)
35
+ def bitwise_or(self, x, reverse=False): return self._binop(Ops.OR, x, reverse)
36
+ def xor(self, x, reverse=False): return self._binop(Ops.XOR, x, reverse)
37
+ def idiv(self, x, reverse=False): return self._binop(Ops.IDIV, x, reverse)
38
+ def sub(self, x, reverse=False): return self.ufix(x).alu(Ops.ADD, -self) if reverse else self.alu(Ops.ADD, self.ufix(-x))
39
+ def div(self, x, reverse=False): return (self.ufix(x)*self.alu(Ops.RECIP)) if reverse else (self*self.ufix(x).alu(Ops.RECIP))
40
+
41
+ def __neg__(self): return self.neg()
42
+
43
+ def __add__(self, x): return self.add(x)
44
+ def __sub__(self, x): return self.sub(x)
45
+ def __mul__(self, x): return self.mul(x)
46
+ def __truediv__(self, x): return self.div(x)
47
+ def __floordiv__(self, x): return self.idiv(x)
48
+ def __and__(self, x): return self.bitwise_and(x)
49
+ def __or__(self, x): return self.bitwise_or(x)
50
+ def __xor__(self, x): return self.xor(x)
51
+
52
+ def __radd__(self, x): return self.add(x, True)
53
+ def __rsub__(self, x): return self.sub(x, True)
54
+ def __rmul__(self, x): return self.mul(x, True)
55
+ def __rtruediv__(self, x): return self.div(x, True)
56
+ def __rfloordiv__(self, x): return self.idiv(x, True)
57
+ def __rand__(self, x): return self.bitwise_and(x, True)
58
+ def __ror__(self, x): return self.bitwise_or(x, True)
59
+ def __rxor__(self, x): return self.xor(x, True)
60
+
61
+ def lt(self, x): return self.alu(Ops.CMPLT, self.ufix(x))
62
+ def gt(self, x): return self.ufix(x).alu(Ops.CMPLT, self)
63
+ def ne(self, x): return self.alu(Ops.CMPNE, self.ufix(x))
64
+ def ge(self, x): return self.lt(x).logical_not()
65
+ def le(self, x): return self.gt(x).logical_not()
66
+ def eq(self, x): return self.ne(x).logical_not()
67
+
68
+ def __lt__(self, x): return self.lt(x)
69
+ def __gt__(self, x): return self.gt(x)
70
+ def __ne__(self, x): return self.ne(x)
71
+ def __ge__(self, x): return self.ge(x)
72
+ def __le__(self, x): return self.le(x)
73
+ # NOTE: __eq__ isn't overridden, and means the same thing as is by default
74
+
75
+ class MathTrait(SimpleMathTrait): # pylint: disable=abstract-method
76
+ # TODO: move to Tensor when new backward is done
77
+ def lshift(self, x, reverse=False): return self._binop(Ops.SHL, x, reverse)
78
+ def rshift(self, x, reverse=False): return self._binop(Ops.SHR, x, reverse)
79
+ def __lshift__(self, x): return self.lshift(x)
80
+ def __rshift__(self, x): return self.rshift(x)
81
+ def __rlshift__(self, x): return self.lshift(x, True)
82
+ def __rrshift__(self, x): return self.rshift(x, True)
83
+
84
+ # not in Tensor
85
+ def __mod__(self, x): return self.alu(Ops.MOD, self.ufix(x))
86
+ def __rmod__(self, x): return self.ufix(x).alu(Ops.MOD, self)
87
+
88
+ def maximum(self, x): return self.alu(Ops.MAX, self.ufix(x))
89
+ def minimum(self, x): return -(-self).maximum(-x)
90
+ def where(self, x, y): return self.alu(Ops.WHERE, x, x.ufix(y))
91
+ def threefry(self, seed): return self.alu(Ops.THREEFRY, seed)
92
+ def reciprocal(self): return self.alu(Ops.RECIP)
93
+ def sqrt(self): return self.alu(Ops.SQRT)
94
+ def sin(self): return self.alu(Ops.SIN)
95
+ def log2(self): return self.alu(Ops.LOG2)
96
+ def exp2(self): return self.alu(Ops.EXP2)
97
+
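Illustrative sketch (not part of the diff, assuming tinygrad 0.10.0 is importable): every operator dunder in SimpleMathTrait/MathTrait above funnels into .alu(Ops.*), so ordinary Python operators applied to a UOp build graph nodes instead of computing numbers.

  from tinygrad.ops import UOp, Ops
  from tinygrad.dtype import dtypes

  x = UOp.variable("x", 0, 10)   # DEFINE_VAR node, dtype defaults to dtypes.int
  expr = (x * 2 + 1) < 8         # MathTrait builds MUL -> ADD -> CMPLT UOps
  assert expr.op is Ops.CMPLT and expr.dtype is dtypes.bool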
98
+ # the order of these Ops controls the order of the toposort
99
+ class Ops(FastEnum):
100
+ # uops that aren't rendered
101
+ SINK = auto()
102
+ CONTIGUOUS = auto()
103
+ PRELOAD = auto()
104
+
105
+ # MetaOps
106
+ COPY = auto()
107
+ EMPTY = auto()
108
+ BUFFER_VIEW = auto()
109
+
110
+ EXPAND = auto()
111
+ CONTRACT = auto()
112
+ VIEW = auto()
113
+ DEFINE_GLOBAL = auto()
114
+ BUFFER = auto()
115
+ DEFINE_VAR = auto()
116
+ DEFINE_LOCAL = auto()
117
+ DEFINE_ACC = auto()
118
+ VALID = auto()
119
+ SPECIAL = auto()
120
+ NOOP = auto()
121
+
122
+ # reduce
123
+ REDUCE_AXIS = auto()
124
+
125
+ # helper ops
126
+ GEP = auto()
127
+ VECTORIZE = auto()
128
+
129
+ # UnaryOps
130
+ CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702
131
+
132
+ # loads before math
133
+ LOAD = auto()
134
+
135
+ # math ops
136
+ WMMA = auto()
137
+
138
+ # BinaryOps
20
139
  ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
21
- SHR = auto(); SHL = auto() # noqa: E702
22
- class TernaryOps(Enum):
23
- """A + A + A -> A (elementwise)"""
140
+ SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto() # noqa: E702
141
+
142
+ # TernaryOps
24
143
  WHERE = auto(); MULACC = auto() # noqa: E702
25
- class ReduceOps(Enum):
26
- """A -> B (reduce)"""
27
- SUM = auto(); MAX = auto() # noqa: E702
28
- class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() # noqa: E702
29
- class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); ASSIGN = auto(); VIEW = auto() # noqa: E702
30
144
 
31
- Op = Union[UnaryOps, BinaryOps, ReduceOps, LoadOps, TernaryOps, BufferOps]
145
+ # assignment ops
146
+ STORE = auto()
147
+ ASSIGN = auto()
148
+ BIND = auto()
32
149
 
33
- # do not preserve f(0) = 0
34
- UNSAFE_PAD_OPS = {UnaryOps.RECIP, UnaryOps.LOG2, UnaryOps.EXP2, BinaryOps.IDIV}
150
+ # late INDEX
151
+ INDEX = auto()
35
152
 
36
- @dataclass(frozen=True)
37
- class MemBuffer:
38
- idx: int
39
- dtype: DType
40
- st: ShapeTracker
153
+ # control flow ops
154
+ BARRIER = auto()
155
+ IF = auto()
156
+ RANGE = auto()
41
157
 
42
- @dataclass(frozen=True)
43
- class ConstBuffer:
44
- val: ConstType | Variable
45
- dtype: DType
46
- st: ShapeTracker
47
-
48
- @dataclass(frozen=True, eq=False)
49
- class LazyOp:
50
- op: Op
51
- src: Tuple[LazyOp, ...] = ()
52
- arg: Any = None
53
- def cached_compare(self, x, context):
54
- if id(self) == id(x): return True
55
- if self.op != x.op or self.arg != x.arg or len(self.src) != len(x.src): return False
56
- if (key := (id(self), id(x))) in context: return context[key]
57
- ret = context[key] = all(a.cached_compare(b, context) for a,b in zip(self.src, x.src))
158
+ # ops that are not graph nodes
159
+ ENDRANGE = auto()
160
+ ENDIF = auto()
161
+
162
+ # consts last!
163
+ VCONST = auto()
164
+ CONST = auto()
165
+
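Because auto() assigns strictly increasing values in declaration order, comparing Ops members follows the order written above, with the consts last, which is what the toposort comment relies on. Illustrative check, not part of the diff:

  from tinygrad.ops import Ops
  assert Ops.SINK < Ops.LOAD < Ops.ADD < Ops.CONST   # IntEnum comparison follows declaration order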
166
+ class GroupOp:
167
+ Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG}
168
+ Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY,
169
+ Ops.SUB, Ops.FDIV}
170
+ Ternary = {Ops.WHERE, Ops.MULACC}
171
+ ALU = set.union(Unary, Binary, Ternary)
172
+
173
+ Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}
174
+
175
+ # meta ops
176
+ Meta = {Ops.COPY, Ops.EMPTY, Ops.BUFFER_VIEW}
177
+ Buffer = {Ops.LOAD, Ops.PRELOAD, Ops.STORE, Ops.VALID}
178
+
179
+ # BinaryOps that can be flipped
180
+ Commutative = {Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPNE, Ops.XOR, Ops.AND, Ops.OR}
181
+
182
+ # do not preserve f(0) = 0
183
+ UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV}
184
+
185
+ # https://en.wikipedia.org/wiki/Identity_element
186
+ def identity_element(op:Ops, dt:DType): return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
187
+
188
+ def can_pad(u:UOp) -> bool: return not any(x.op in GroupOp.UnsafePad for x in u.sparents)
189
+
190
+ END_FOR_UOP = {Ops.IF:(Ops.STORE, Ops.ENDIF), Ops.RANGE:(Ops.ASSIGN, Ops.ENDRANGE)}
191
+
192
+ # With True as the default, this matches the old symbolic behavior
193
+ def resolve(x, default:bool=True):
194
+ if not isinstance(x, UOp): return bool(x)
195
+ assert x.dtype is dtypes.bool, "UOp in resolve must be bool"
196
+ # NOTE: generating the text for the exception is expensive, so we do this
197
+ return bool(sx.vmin) if (sx:=x.simplify()).vmin == sx.vmax else default
198
+
199
+ # smax/smin are replacements for max/min that preserve symbolic
200
+ def _suop(lst, uop_fxn, python_fxn):
201
+ max_uop, max_num = partition(lst, lambda x: isinstance(x, UOp))
202
+ if len(max_uop): return functools.reduce(uop_fxn, (max_uop + [python_fxn(max_num)]) if len(max_num) else max_uop).ssimplify()
203
+ return python_fxn(max_num)
204
+ def smax(*lst): return _suop(lst[0] if isinstance(lst[0], (tuple, list)) else lst, UOp.maximum, max)
205
+ def smin(*lst): return _suop(lst[0] if isinstance(lst[0], (tuple, list)) else lst, UOp.minimum, min)
206
+
207
+ def ssimplify(uop): return uop.ssimplify() if isinstance(uop, UOp) else uop
208
+ def sym_infer(uop: Union[UOp, int], var_vals: Dict[UOp, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop
209
+
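Illustrative sketch (not part of the diff): how resolve and smax behave on a symbolic value, including the default described in the comment above.

  from tinygrad.ops import UOp, resolve, smax

  v = UOp.variable("v", 1, 10)
  assert resolve(v < 20) and not resolve(v < 1)                # decidable from the [1, 10] range
  assert resolve(v < 5) and not resolve(v < 5, default=False)  # undecidable, so the default is returned
  assert smax(v, 3).vmax == 10                                 # smax keeps the result symbolic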
210
+ # used for UOp and UPat
211
+ def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->str:
212
+ def dfs(x:Any, cache:dict):
213
+ for s in srcfn(x) or []:
214
+ cache.setdefault(s, [len(cache), 0, False])[1] += 1
215
+ if cache[s][1] == 1: dfs(s, cache)
216
+ if cache is None: dfs(x, cache:={})
217
+ if (cx:=cache.setdefault(x, [0,0,False]))[2]: return f"{' '*d} x{cx[0]}"
218
+ cx[2], srcs = True, ('None' if srcfn(x) is None else ''.join(f'\n{pretty_print(s, rep, srcfn, cache, d+2)},' for s in srcfn(x)))
219
+ return f"{' '*d}{f'x{cx[0]}:=' * (cx[1]>1)}{rep(x)}" % srcs
220
+
221
+ class UOpMetaClass(type):
222
+ ucache:WeakValueDictionary[Tuple, UOp] = WeakValueDictionary()
223
+ def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:Tuple[UOp,...]=tuple(), arg:Any=None):
224
+ if (ret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None: return ret
225
+ UOpMetaClass.ucache[key] = ret = super().__call__(op, dtype, src, arg)
58
226
  return ret
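Illustrative sketch (not part of the diff): UOpMetaClass hash-conses UOps through ucache, so structurally identical nodes are literally the same Python object, which is why equality on UOps can stay identity-based.

  from tinygrad.ops import UOp, Ops
  from tinygrad.dtype import dtypes

  a = UOp(Ops.CONST, dtypes.int, arg=1)
  b = UOp(Ops.CONST, dtypes.int, arg=1)
  assert a is b   # same (op, dtype, src, arg) key -> same cached object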
59
- def __eq__(self, x): return self.cached_compare(x, context={})
60
- def __repr__(self): return f"LazyOp(op={self.op}, src={self.src}, arg={self.arg})"
61
- @functools.cached_property
62
- def dtype(self) -> DType:
63
- if self.op in BufferOps: return self.arg.dtype
64
- if self.op in [UnaryOps.CAST, UnaryOps.BITCAST]: return self.arg
65
- return dtypes.bool if self.op in {BinaryOps.CMPLT, BinaryOps.CMPNE} else self.src[-1].dtype
66
227
 
228
+ class UOp(MathTrait, metaclass=UOpMetaClass):
229
+ __slots__ = ["op", "dtype", "src", "arg"]
230
+ def __init__(self, op:Ops, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None):
231
+ # TODO: instant check rules here make debugging easier
232
+ self.op, self.dtype, self.src, self.arg = op, dtype, src, arg
233
+ def __reduce__(self): return UOp, (self.op, self.dtype, self.src, self.arg)
234
+ def replace(self, **kwargs) -> UOp:
235
+ for k in kwargs: assert k in self.__slots__, f"unkown replace arg, expected one of {self.__slots__}, got {k}"
236
+ new_args = (kwargs.get("op", self.op), kwargs.get("dtype", self.dtype), kwargs.get("src", self.src), kwargs.get("arg", self.arg))
237
+ if (self.op, self.dtype, self.src, self.arg) == new_args: return self
238
+ return UOp(*new_args)
67
239
  @functools.cached_property
68
240
  def key(self) -> bytes:
69
- return hashlib.sha256(functools.reduce(lambda x,y: x+y, [s.key for s in self.src], str((self.op, self.arg)).encode())).digest()
241
+ return hashlib.sha256(str((self.op, self.dtype, self.arg)).encode() + b"".join([s.key for s in self.src])).digest()
242
+ def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))")
243
+ def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is Ops.REDUCE_AXIS else self.arg
244
+ @functools.cached_property
245
+ def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents}}
246
+ @functools.cached_property # parents with self
247
+ def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None}
248
+
249
+ @functools.cached_property
250
+ def tuplize(self:UOp) -> Tuple[int, Any, Optional[DType], Tuple]:
251
+ return (self.op.value, self.arg, self.dtype, tuple(x.tuplize for x in self.src))
252
+
253
+ # *** uop shape stuff ***
254
+
255
+ @property
256
+ def has_st(self) -> bool: return self.op not in {Ops.DEFINE_LOCAL, Ops.DEFINE_GLOBAL, Ops.BUFFER, Ops.CONST, Ops.DEFINE_VAR}
70
257
  @functools.cached_property
71
- def hash(self): return hash((self.op, self.src, self.arg))
72
- def __hash__(self): return self.hash
258
+ def st(self) -> Optional[ShapeTracker]:
259
+ if not self.has_st: return None
260
+ if self.op in GroupOp.Buffer: return self.st_arg
261
+ if self.op is Ops.VIEW: return self.arg
262
+ src_sts = [x.st for x in self.src if x.st is not None]
263
+ assert all_same([x.shape for x in src_sts]), f"UOp parents must have the same shape {self} {[x.shape for x in src_sts]}"
264
+ from tinygrad.shape.shapetracker import ShapeTracker
265
+ return ShapeTracker.from_shape(src_sts[0].reduce(self.axis_arg)) if self.op is Ops.REDUCE_AXIS else src_sts[0]
73
266
  @functools.cached_property
74
- def lazyops(self) -> List[LazyOp]: return dedup([self] + [item for x in self.src for item in x.lazyops])
75
- def vars(self) -> List[Variable]:
76
- extract_vars = [x.arg.st.vars() for x in self.lazyops if x.op in BufferOps]
77
- const_vars = [x.arg.val for x in self.lazyops if x.op is BufferOps.CONST and isinstance(x.arg.val, Variable)]
78
- return sorted(set.union(*extract_vars, set(const_vars)), key=lambda v: v.expr)
79
-
80
- # **************** independent FlopCounter ****************
81
-
82
- @dataclass
83
- class FlopCounter:
84
- shape: Tuple[int, ...]
85
- flops: sint
86
- mem: Dict[int, int]
267
+ def full_shape(self) -> Tuple[sint, ...]:
268
+ return self.arg.shape if self.op is Ops.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
269
+
270
+ # *** uop evaluation ***
271
+
272
+ def simplify(self):
273
+ with Context(TRACK_MATCH_STATS=0):
274
+ return graph_rewrite(self, symbolic)
275
+ def ssimplify(self) -> Union[UOp, ConstType]: return ret.arg if (ret:=self.simplify()).op is Ops.CONST else ret
276
+ def _eval(self, dtype, expected_type:Type[T]) -> T:
277
+ assert self.dtype in dtype, f"eval with wrong dtype {self}"
278
+ vmin, vmax = (simple_self:=self.simplify())._min_max
279
+ if vmin != vmax: raise ValueError(f"eval failed to be a single number, range is {vmin} to {vmax} in {simple_self.render()}")
280
+ assert isinstance(vmin, expected_type), f"vmin is wrong dtype {type(vmin)} != {expected_type}"
281
+ return vmin
282
+ def __bool__(self): return self._eval((dtypes.bool,), bool)
283
+ def __int__(self): return self._eval(dtypes.ints, int)
284
+ def __float__(self): return self._eval(dtypes.floats, float)
285
+ def substitute(self, dvars:Dict[UOp, UOp]):
286
+ with Context(TRACK_MATCH_STATS=0):
287
+ return graph_rewrite(self, _substitute, dvars)
288
+
289
+ # *** uop syntactic sugar ***
290
+
291
+ @property
292
+ def st_arg(self) -> ShapeTracker:
293
+ assert self.op in GroupOp.Buffer, f"st_arg called on {self.op}"
294
+ ret = self.src[0 if self.op is Ops.VALID else 1]
295
+ assert ret.op is Ops.VIEW, f"st_arg trying to return {ret}"
296
+ return ret.arg
87
297
  @property
88
- def mem_estimate(self): return sum(self.mem.values())
89
- def consume_flops(self):
90
- self.flops, ret = 0, self.flops
298
+ def axis_arg(self) -> Tuple[int, ...]:
299
+ assert self.op in {Ops.REDUCE_AXIS, Ops.WMMA}, f"axis_arg called on {self.op}"
300
+ ret = self.arg[1] if self.op is Ops.REDUCE_AXIS else self.arg[7]
301
+ assert isinstance(ret, tuple) and all(isinstance(x, int) for x in ret), f"axis_arg trying to return {ret}"
91
302
  return ret
303
+ def sink(self, *srcs:UOp): return UOp(Ops.SINK, dtypes.void, (self,)+srcs)
304
+ def index(self, idx:UOp, valid:Optional[UOp]=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
305
+ def const_like(self, b:ConstLike): return UOp.const(self.dtype, b)
306
+ def broadcast(self, count:int):
307
+ assert self.dtype.count == 1
308
+ if count == 1: return self
309
+ return UOp(Ops.VECTORIZE, self.dtype.vec(count), (self,)*count)
310
+ def cast(self, dtype:DType): return UOp(Ops.CAST, dtype, (self,))
311
+ def bitcast(self, dtype:DType): return UOp(Ops.BITCAST, dtype, (self,))
312
+ def gep(self, i:Union[Tuple[int, ...], int]):
313
+ if isinstance(i, int):
314
+ # NOTE: these are just shortcuts to not have to create and fold later
315
+ if self.op is Ops.VECTORIZE: return self.src[i]
316
+ if self.op is Ops.VCONST: return UOp.const(self.dtype.scalar(), self.arg[i])
317
+ if self.op is Ops.CONST: return UOp.const(self.dtype.scalar(), self.arg)
318
+ i = (i,)
319
+ if (self.dtype.vcount == len(i) and i == tuple(range(len(i)))) or self.dtype == dtypes.void: return self
320
+ return UOp(Ops.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i)
321
+ def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, src=(self,)+src, **kwargs)
322
+ def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
323
+ def alu(self, arg, *src:UOp):
324
+ out_dtype = (self, *src)[-1].dtype
325
+ if arg in {Ops.CMPLT, Ops.CMPNE} and out_dtype is not None:
326
+ out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
327
+ return UOp(arg, out_dtype, (self,)+src)
328
+ @staticmethod
329
+ def const(dtype:DType, b:ConstLike):
330
+ if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
331
+ if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
332
+ return UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b)
333
+ @staticmethod
334
+ def range(dtype:DType, start:ConstType|UOp, end:ConstType|UOp, idx:int):
335
+ return UOp(Ops.RANGE, dtype=dtype, src=(UOp.const(dtype, start) if not isinstance(start, UOp) else start,
336
+ UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=(idx, False))
337
+ def r(self, op, axis): return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis))
338
+ def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
339
+ def contiguous(self): return UOp(Ops.CONTIGUOUS, self.dtype, (self,))
340
+ @property
341
+ def is_contiguous_base(self): return self.op is Ops.CONTIGUOUS and not (self.src[0].base.op is Ops.VIEW and len(self.src[0].base.src) == 2)
92
342
 
93
- InterpretedFlopCounter: Dict[Op, Callable] = {
94
- BufferOps.LOAD: lambda arg: FlopCounter(arg.st.shape, 0, {arg.idx: arg.dtype.itemsize * arg.st.real_size()}),
95
- BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, 0, {}),
96
- BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize * arg.st.real_size()}),
97
- UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, self.consume_flops(), self.mem), # cast uses no flops
98
- UnaryOps.BITCAST: lambda self,arg: FlopCounter(self.shape, self.consume_flops(), self.mem), # bitcast uses no flops
99
- **{op:lambda self: FlopCounter(self.shape, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op not in {UnaryOps.CAST, UnaryOps.BITCAST}}, # noqa: E501
100
- **{op:lambda self,y: FlopCounter(self.shape, self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps}, # noqa: E501
101
- **{op:lambda self,axis: FlopCounter(tuple(1 if i in axis else s for i,s in enumerate(self.shape)), self.consume_flops() + prod(self.shape), self.mem) for op in ReduceOps}, # noqa: E501
102
- TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})} # noqa: E501
343
+ # *** uop movement ops ***
103
344
 
104
- @functools.lru_cache(None)
105
- def get_lazyop_info(ast:LazyOp) -> FlopCounter:
106
- @functools.lru_cache(None) # NOTE: this cache needs to be recreated for new ASTs
107
- def run_ast(ast): return InterpretedFlopCounter[ast.op](*([run_ast(x) for x in ast.src]+([ast.arg] if ast.arg is not None else [])))
108
- return run_ast(ast)
345
+ @property
346
+ def base(self): return self.src[0] if self.op is Ops.VIEW and len(self.src) == 1 else self
347
+ def view(self, st:ShapeTracker):
348
+ assert self.op is not Ops.STORE, "VIEW of STORE is invalid, STORE is always base"
349
+ return self if self.st is None or self.st == st else UOp(Ops.VIEW, self.dtype, (self,), st)
350
+ def reshape(self, arg:Tuple[sint, ...]): return self.view(unwrap(self.st).reshape(arg))
109
351
 
110
- # **************** ops in python ****************
352
+ # *** uop Buffer stuff ***
353
+
354
+ @staticmethod
355
+ def new_buffer(device:str, size:int, dtype:DType, num=-1): return UOp(Ops.BUFFER, dtype.ptr(), (), (num, (device, size, dtype)))
356
+
357
+ @property
358
+ def buf_uop(self) -> UOp:
359
+ assert self.op in {*GroupOp.Buffer, Ops.ASSIGN, Ops.VIEW} and self.src[0].op is Ops.BUFFER, f"buf_uop called on {self.op}"
360
+ return self.src[0]
361
+
362
+ # *** uop Variable stuff ***
363
+
364
+ @staticmethod
365
+ def variable(name:str, min_val:ConstType, max_val:ConstType, dtype:DType=dtypes.int):
366
+ assert not isinstance(min_val, UOp) and not isinstance(max_val, UOp), f"can't create Variable {name} with {min_val}/{max_val}"
367
+ return UOp(Ops.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
368
+ @property
369
+ def expr(self):
370
+ assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
371
+ return self.arg[0]
372
+ def bind(self, val:int):
373
+ assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
374
+ assert self.arg[1] <= val and val <= self.arg[2], f"bind {val} not in range [{self.arg[1]}, {self.arg[2]}]"
375
+ return UOp(Ops.BIND, self.dtype, (self, self.const_like(val)))
376
+ def unbind(self) -> Tuple[Variable, int]:
377
+ assert self.op is Ops.BIND and self.src[0].op is Ops.DEFINE_VAR and self.src[1].op is Ops.CONST, f"can't unbind {self}"
378
+ return self.src[0], self.src[1].arg
379
+ @property
380
+ def val(self) -> int: return self.unbind()[1]
381
+ def vars(self) -> Set[UOp]:
382
+ bound_vars = set([x for x in self.sparents if x.op is Ops.BIND and x.src[0].op is Ops.DEFINE_VAR])
383
+ bound_var_base = set(x.src[0] for x in bound_vars)
384
+ all_vars = set([x for x in self.sparents if x.op is Ops.DEFINE_VAR])
385
+ return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
386
+ def variables(self) -> List[Variable]:
387
+ st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in GroupOp.Buffer]
388
+ return sorted(set.union(*st_vars, [x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg)
389
+
390
+ # *** uop symbolic stuff ***
391
+
392
+ def const_factor(self) -> int:
393
+ """largest known int that divides self"""
394
+ if self.op is Ops.CONST: return self.arg
395
+ if self.op is Ops.VCONST: return math.gcd(*self.arg)
396
+ if self.op is Ops.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
397
+ if self.op is Ops.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1
398
+ return 1
399
+ def divides(self, v) -> Optional[UOp]:
400
+ if v==1: return self
401
+ if self.op is Ops.CONST: return self.const_like(self.arg//v) if self.arg%v == 0 else None
402
+ if self.op is Ops.VCONST: return self.const_like(tuple(x//v for x in self.arg)) if all(x%v == 0 for x in self.arg) else None
403
+ if self.op is Ops.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None
404
+ if self.op is Ops.MUL:
405
+ if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1]
406
+ if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1
407
+ return None # generic None if we aren't sure
408
+ @property
409
+ def vmin(self) -> ConstType: return self._min_max[0]
410
+ @property
411
+ def vmax(self) -> ConstType: return self._min_max[1]
412
+ @functools.cached_property
413
+ def _min_max(self) -> Tuple[ConstType, ConstType]:
414
+ if self.op in GroupOp.Binary and not dtypes.is_float(self.dtype):
415
+ (s0_vmin, s0_vmax), (s1_vmin, s1_vmax) = self.src[0]._min_max, self.src[1]._min_max
416
+ if self.op is Ops.ADD: return s0_vmin+s1_vmin, s0_vmax+s1_vmax
417
+ if self.op is Ops.MUL: return min(vals:=(s0_vmin*s1_vmin, s0_vmin*s1_vmax, s0_vmax*s1_vmin, s0_vmax*s1_vmax)), max(vals)
418
+ if self.op is Ops.MOD and s1_vmin > 0: return 0, s1_vmax-1
419
+ if self.op is Ops.IDIV and s1_vmin == s1_vmax: # min/max are equal in a CONST
420
+ if s1_vmin > 0: return s0_vmin//s1_vmin, s0_vmax//s1_vmin
421
+ if s1_vmin < 0 and s0_vmin >= 0: return -(s0_vmax//-s1_vmin), -(s0_vmin//-s1_vmin)
422
+ if self.op is Ops.MAX: return max(s0_vmin, s1_vmin), max(s0_vmax, s1_vmax)
423
+ if self.op is Ops.CMPLT: return (s0_vmax<s1_vmin, s0_vmin<s1_vmax)
424
+ if self.op is Ops.CMPNE: return ((s0_vmax < s1_vmin) or (s1_vmax < s0_vmin), not (s0_vmin == s0_vmax == s1_vmin == s1_vmax))
425
+ if self.dtype == dtypes.bool:
426
+ if self.op is Ops.OR: return s0_vmin or s1_vmin, s0_vmax or s1_vmax
427
+ if self.op is Ops.AND: return s0_vmin and s1_vmin, s0_vmax and s1_vmax
428
+ # float has NAN issue and we use explicit NAN in transcendental
429
+ if self.op is Ops.WHERE and dtypes.is_int(self.dtype): return min(self.src[1].vmin, self.src[2].vmin), max(self.src[1].vmax, self.src[2].vmax)
430
+ # NOTE: returned UOp is assumed to be CONST
431
+ if self.op is Ops.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2]
432
+ if self.op is Ops.RANGE: return self.src[0].vmin, (self.src[1]-1).vmax
433
+ if self.op is Ops.BIND: return self.src[0]._min_max # ignore the bound value
434
+ if self.op in {Ops.EXPAND, Ops.VECTORIZE}: return min(x.vmin for x in self.src), max(x.vmax for x in self.src)
435
+ # TODO: UOps.SPECIAL is UOps.DEFINE_VAR
436
+ if self.op is Ops.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else dtypes.max(self.dtype)
437
+ if self.op is Ops.CONST: return self.arg, self.arg
438
+ if self.op is Ops.VCONST: return (min(self.arg), max(self.arg))
439
+ return dtypes.min(self.dtype), dtypes.max(self.dtype)
440
+
441
+ @functools.cached_property
442
+ def _sym_fxn(self):
443
+ sself = self.simplify()
444
+ varnames = tuple(x.arg[0] for x in sself.sparents if x.op is Ops.DEFINE_VAR)
445
+ # TODO: sanitize varnames, or don't use naked eval while staying fast
446
+ return eval("lambda "+','.join(varnames)+": "+sself.render()), varnames # pylint: disable=eval-used
447
+
448
+ def sym_infer(self, var_vals:Dict[UOp, int]):
449
+ fxn, varnames = self._sym_fxn
450
+ return fxn(**{k.arg[0]:v for k,v in var_vals.items() if k.arg[0] in varnames})
451
+
452
+ def render(self, simplify=True) -> str:
453
+ ret = graph_rewrite(self.simplify() if simplify else self, renderer)
454
+ return ret.arg if ret.op is Ops.NOOP else str(ret)
455
+
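Illustrative sketch (not part of the diff): the symbolic side of UOp in one place, with interval bounds from _min_max and evaluation through sym_infer.

  from tinygrad.ops import UOp

  i = UOp.variable("i", 0, 7)
  e = i * 4 + 3
  assert (e.vmin, e.vmax) == (3, 31)   # bounds propagate through integer ADD/MUL
  assert e.sym_infer({i: 5}) == 23     # plug in a concrete value for the variable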
456
+ @dataclass(frozen=True)
457
+ class KernelInfo:
458
+ local_dims: int = 0 # number of local dimensions (this is remapping RANGE to SPECIAL)
459
+ upcasted: int = 0 # count that are upcasted (this is remapping RANGE to EXPAND)
460
+ dont_use_locals: bool = False # don't use local indexing
461
+
462
+ # ***** ops in python *****
111
463
 
112
464
  def hook_overflow(dv, fxn):
113
465
  def wfxn(*args):
@@ -115,55 +467,686 @@ def hook_overflow(dv, fxn):
115
467
  except OverflowError: return dv
116
468
  return wfxn
117
469
 
118
- python_alu = {
119
- UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan,
120
- UnaryOps.EXP2: hook_overflow(math.inf, lambda x: 2**x),
121
- UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan,
122
- UnaryOps.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan,
123
- UnaryOps.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
124
- UnaryOps.NEG: lambda x: (not x) if isinstance(x, bool) else -x,
125
- BinaryOps.SHR: operator.rshift, BinaryOps.SHL: operator.lshift,
126
- BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add,
127
- BinaryOps.XOR: operator.xor, BinaryOps.MAX: max, BinaryOps.CMPNE: operator.ne, BinaryOps.CMPLT: operator.lt,
128
- BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], BinaryOps.IDIV: lambda x, y: int(x/y) if y != 0 else x*math.inf,
129
- TernaryOps.WHERE: lambda x,y,z: y if x else z}
130
-
131
- def truncate_fp16(x):
132
- try:
133
- x = float(x)
134
- struct.pack("@e", x)
135
- return x
136
- except OverflowError: return math.copysign(math.inf, x)
137
-
138
- truncate: Dict[DType, Callable] = {dtypes.bool: bool,
139
- # TODO: bfloat16
140
- dtypes.float16: truncate_fp16, dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
141
- dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
142
- dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
143
- dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value,
144
- dtypes.int32: lambda x: ctypes.c_int32(x).value, dtypes.int64: lambda x: ctypes.c_int64(x).value,}
145
-
146
- def exec_alu(op:Op, dtype:DType, operands): return truncate.get(dtype, lambda x: x)(python_alu[op](*operands))
147
-
148
- # the living definition of LazyOps
149
- def verify_lazyop(*ast:LazyOp):
150
- sts: Dict[LazyOp, ShapeTracker] = {}
151
- def dfs(op:LazyOp, st:ShapeTracker):
152
- if op in sts: return
153
- for x in op.src: dfs(x, st)
154
- # only reduceop is allowed to change shape, limited to turning n to 1
155
- if op.op in ReduceOps:
156
- expected_shape = tuple(1 if i in op.arg else s for i,s in enumerate(sts[op.src[0]].shape))
157
- assert st.shape == expected_shape, f"unexpected reduceop shape {st.shape} != {expected_shape}"
158
- st = ShapeTracker.from_shape(expected_shape)
470
+ python_alu: Dict[Ops, Callable] = {
471
+ Ops.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, Ops.EXP2: hook_overflow(math.inf, lambda x: 2**x),
472
+ Ops.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, Ops.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
473
+ Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan,
474
+ Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul,
475
+ Ops.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], Ops.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else x*math.inf,
476
+ Ops.MAX: max, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt, Ops.XOR: operator.xor,
477
+ Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift,
478
+ Ops.MULACC: lambda x,y,z: (x*y)+z, Ops.WHERE: lambda x,y,z: y if x else z}
479
+
480
+ def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
481
+ if dtype.count > 1:
482
+ return tuple([exec_alu(op, dtype.scalar(), [x[i] if isinstance(x, tuple) else x for x in operands]) for i in range(dtype.count)])
483
+ alu = python_alu[op](*operands)
484
+ return truncate.get(dtype, lambda x: x)(alu) if truncate_output else alu
485
+
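Illustrative sketch (not part of the diff): exec_alu evaluates an op in plain Python and, unless truncate_output=False, truncates the result to the target dtype the way hardware would.

  from tinygrad.ops import exec_alu, Ops
  from tinygrad.dtype import dtypes

  assert exec_alu(Ops.ADD, dtypes.uint8, (200, 100)) == 44                         # 300 wraps mod 256
  assert exec_alu(Ops.ADD, dtypes.uint8, (200, 100), truncate_output=False) == 300
  assert exec_alu(Ops.WHERE, dtypes.int, (True, 3, 5)) == 3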
486
+ # ***** uop helpers *****
487
+
488
+ def print_uops(uops:List[UOp]):
489
+ for i,u in enumerate(uops):
490
+ formatted_parents = [(uops.index(x) if x.op is not Ops.CONST else f"{x.arg}") if x in uops else "--" for x in u.src]
491
+ print(f"{i:4d} {str(u.op):20s}: {str(u.dtype):30s} " f"{str(formatted_parents):32s} {u.arg}")
492
+
493
+ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
494
+ flops: sint = 0
495
+ mem: sint = 0
496
+ mults: sint = 1
497
+ mult_stack: List[sint] = []
498
+ dont_count: Set[UOp] = set()
499
+ if ignore_indexing:
500
+ for u in uops:
501
+ if u.op in {Ops.LOAD, Ops.STORE}:
502
+ dont_count = dont_count.union(u.src[0].sparents)
503
+ if len(u.src) > 2: dont_count = dont_count.union(u.src[2].sparents)
504
+ elif u.op is Ops.IF:
505
+ dont_count = dont_count.union(u.src[0].sparents)
506
+ for u in uops:
507
+ if u.op is Ops.RANGE:
508
+ mult_stack.append(mults)
509
+ mults *= (u.src[1] - u.src[0]).ssimplify()
510
+ elif u.op is Ops.ENDRANGE:
511
+ mults = mult_stack.pop(-1)
512
+ elif u.op is Ops.SPECIAL:
513
+ mults *= u.arg[1] # NOTE: we don't push to the mult_stack here, you can't end these
514
+ elif u.op is Ops.LOAD:
515
+ mem += u.dtype.itemsize * mults
516
+ elif u.op is Ops.STORE:
517
+ mem += u.src[1].dtype.itemsize * mults
518
+ elif u.op in GroupOp.ALU and u not in dont_count:
519
+ flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
520
+ elif u.op is Ops.WMMA and u not in dont_count:
521
+ flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
522
+ return flops, mem
523
+
524
+ # ***** pattern matcher *****
525
+
526
+ def get_location() -> Tuple[str, int]:
527
+ frm = sys._getframe(1)
528
+ # find the real frame in the file that has the UPat, TODO: is there a better way to do this?
529
+ while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "uopgraph.py", "schedule.py",
530
+ "lowerer.py", "cstyle.py"}:
531
+ frm = frm.f_back
532
+ return frm.f_code.co_filename, frm.f_lineno
533
+ @functools.lru_cache(None)
534
+ def lines(fn) -> List[str]:
535
+ with open(fn) as f: return f.readlines()
536
+
537
+ class UPat(MathTrait):
538
+ __slots__ = ["op", "dtype", "arg", "name", "src"]
539
+ def __init__(self, op:Optional[Union[Ops, Tuple[Ops, ...], Set[Ops]]]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None,
540
+ src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None, arg:Any=None,
541
+ name:Optional[str]=None, allow_any_len:bool=False, location=None, custom_early_reject:Optional[Set[Ops]]=None):
542
+ assert op is None or isinstance(op, Ops) or isinstance(op, tuple) or isinstance(op, set), "op must be Ops or tuple of Ops"
543
+ self.op: Optional[Tuple[Ops, ...]] = (op,) if isinstance(op, Ops) else (tuple(op) if isinstance(op, set) else op)
544
+ self.dtype: Optional[Tuple[DType, ...]] = (dtype,) if isinstance(dtype, DType) else dtype
545
+ self.arg, self.name, self._in_src, self.custom_early_reject = arg, name, src, custom_early_reject
546
+ self.src: Any = None
547
+ assert self.name != "ctx", "UPat can't be named ctx"
548
+
549
+ # try all permutations if it's a list
550
+ if isinstance(src, list): self.src = list(itertools.permutations(src)) if not all_same(src) else [src]
551
+ # only one if it's a tuple
552
+ elif isinstance(src, tuple): self.src = [src]
553
+ # repeat if it's a UPat
554
+ elif isinstance(src, UPat): self.src = [itertools.repeat(src)]
555
+
556
+ self.allowed_len: int = -1 if allow_any_len or isinstance(src, UPat) or src is None else len(src)
557
+ self.location = location or get_location()
558
+
559
+ if custom_early_reject is not None: self.early_reject = custom_early_reject
159
560
  else:
160
- # movementops are pushed to the edges with LOAD
161
- if op.op in BufferOps: st = op.arg.st
162
- else: st = sts[op.src[0]]
163
- for x in op.src: assert sts[x].shape == st.shape, f"found implicit movement op {x.op} {sts[x].shape} != {op.op} {st.shape}"
164
- sts[op] = st
165
- for i, out in enumerate(ast):
166
- assert out.arg.idx == i, f"unexpected output buffer idx {out.arg.idx} != {i}"
167
- assert out.op is BufferOps.STORE, f"kernels must have stores as the output, got {out.op}"
168
- assert out.arg.st.size == ast[-1].arg.st.size, f"outputs must have the same size, got {out.arg.st.size}"
169
- dfs(out, out.arg.st)
561
+ upat_match = [src] if isinstance(src, UPat) else ([] if src is None else self.src[0])
562
+ self.early_reject = {pp.op[0] for pp in upat_match if pp.op is not None and len(pp.op) == 1}
563
+
564
+ def named(self, name:str): return UPat(self.op, self.dtype, self._in_src, self.arg, name, self.allowed_len == -1, self.custom_early_reject)
565
+
566
+ @staticmethod
567
+ def any(*src): return UPatAny(src=src)
568
+
569
+ @staticmethod
570
+ @functools.lru_cache(None)
571
+ def var(name:Optional[str]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None): return UPat(dtype=dtype, name=name)
572
+ @staticmethod
573
+ @functools.lru_cache(None)
574
+ def cvar(name:Optional[str]=None, dtype:Optional[DType]=None, vec=True):
575
+ return UPat((Ops.CONST, Ops.VCONST) if vec else Ops.CONST, dtype=dtype, name=name)
576
+ @staticmethod
577
+ def const(dtype:Optional[Union[DType, Tuple[DType, ...]]], b:ConstType): return UPat(Ops.CONST, dtype=dtype, arg=b)
578
+
579
+ # copied from UOp
580
+ def index(self, idx:UPat, valid:Optional[UPat]=None): return UPat(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
581
+ def view(self, st=None, **kwargs): return UPat(Ops.VIEW, self.dtype, (self,), st, **kwargs)
582
+ def cast(self, dtype=None): return UPat(Ops.CAST, dtype, (self,))
583
+ def bitcast(self, dtype=None): return UPat(Ops.BITCAST, dtype, (self,))
584
+ def gep(self, i:int): return UPat(Ops.GEP, None, (self,), (i,))
585
+ def load(self, *src:UPat, **kwargs): return UPat(Ops.LOAD, src=(self,)+src, **kwargs)
586
+ def store(self, *src:UPat, **kwargs): return UPat(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
587
+ def assign(self, x:UPat): return UPat(Ops.ASSIGN, self.dtype, (self,x))
588
+
589
+ def const_like(self, b:ConstLike): return UPat.const(self.dtype, cast(ConstType, b))
590
+ def alu(self, op:Ops, *src:UPat):
591
+ asrc = (self,)+src
592
+ return UPat(op, None if op in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].dtype, list(asrc) if op in GroupOp.Commutative else asrc)
593
+
594
+ def printable(self:UPat) -> str:
595
+ try: return lines(self.location[0])[self.location[1]-1].strip()
596
+ except FileNotFoundError: return "<missing>"
597
+
598
+ def __repr__(self):
599
+ def rep(x):
600
+ form = "UPat(%s, %s, name=%s, dtype=%s, allow_any_len=%s, src=%s)"
601
+ return form % (None if x.op is None else ('(%s)'%', '.join(map(str, x.op))), x.arg, repr(x.name),
602
+ set(x.dtype) if x.dtype else None, x.allowed_len == 0, "[%s]" if x.src and len(x.src)>1 else "(%s)")
603
+ return pretty_print(self, rep, srcfn=lambda x:None if x.src is None else [next(x.src[0])] if isinstance(x.src[0], itertools.repeat) else x.src[0])
604
+
605
+ def match(self:UPat, uop:UOp, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
606
+ if (self.op is not None and uop.op not in self.op) or \
607
+ (self.name is not None and store.setdefault(self.name, uop) is not uop) or \
608
+ (self.dtype is not None and uop.dtype not in self.dtype and uop.dtype.scalar() not in self.dtype) or \
609
+ (self.arg is not None and self.arg != uop.arg) or \
610
+ (self.allowed_len != -1 and len(uop.src) != self.allowed_len): return []
611
+ if self.src is None: return [store]
612
+ res: List[Dict[str, UOp]] = []
613
+ for vp in self.src:
614
+ stores, new_stores = [store.copy()], []
615
+ for uu, vv in zip(uop.src, vp):
616
+ for s in stores: new_stores.extend(vv.match(uu, s))
617
+ stores, new_stores = new_stores, []
618
+ res.extend(stores)
619
+ return res
620
+
621
+ class UPatAny(UPat):
622
+ def match(self:UPat, uop:UOp, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
623
+ ret = []
624
+ for x in self.src[0]:
625
+ if (match:=x.match(uop, store.copy())): ret.extend(match)
626
+ return ret
627
+
628
+ def deconstruct_function(fxn:Callable) -> Tuple:
629
+ new_globals = {k:v for k,v in fxn.__globals__.items() if k in fxn.__code__.co_names}
630
+ for co in fxn.__code__.co_consts:
631
+ if isinstance(co, types.CodeType): new_globals.update({k:v for k,v in fxn.__globals__.items() if k in co.co_names})
632
+ # NOTE: optional round trip through pickle!
633
+ assert fxn.__closure__ is None, "closures are not supported in pattern matchers"
634
+ ret = fxn.__code__, new_globals, fxn.__name__, fxn.__defaults__
635
+ return pickle.loads(pickle.dumps(ret)) if getenv("TEST_PICKLE") else ret
636
+
637
+ class PatternMatcher:
638
+ def __init__(self, patterns:List[Tuple[UPat, Callable]]):
639
+ self.patterns = patterns
640
+ # NOTE: use of DefaultDict here is very dangerous! all keys will live for the lifetime of the PatternMatcher!
641
+ self.pdict: Dict[Ops, List[Tuple[UPat, Callable, Set, bool]]] = {}
642
+ # uop is required, arg is optional
643
+ for p,fxn in self.patterns:
644
+ assert p.op is not None
645
+ tuple_fxn = fxn if isinstance(fxn, tuple) else deconstruct_function(fxn)
646
+ real_fxn = types.FunctionType(*tuple_fxn)
647
+ for uop in p.op: self.pdict.setdefault(uop, []).append((p, real_fxn, p.early_reject, 'ctx' in inspect.signature(real_fxn).parameters))
648
+
649
+ def __reduce__(self): return PatternMatcher, ([(x,deconstruct_function(fxn) if fxn.__name__ == "<lambda>" else fxn) for x,fxn in self.patterns],)
650
+
651
+ @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
652
+ def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns)
653
+
654
+ def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]:
655
+ ler = {u.op for u in uop.src}
656
+ for p,fxn,early_reject,has_ctx in self.pdict.get(uop.op, []):
657
+ if not early_reject.issubset(ler): continue
658
+ for match in p.match(uop, {}):
659
+ if (ret:=(fxn(ctx=ctx, **match) if has_ctx else fxn(**match))) is not None: return ret
660
+ return None
661
+
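Illustrative sketch (not part of the diff): a tiny PatternMatcher wired into graph_rewrite, the same machinery the symbolic rules and the renderer are built from. The x*1 rule here is only for demonstration.

  from tinygrad.ops import UOp, UPat, PatternMatcher, graph_rewrite

  mul_by_one = PatternMatcher([
    (UPat.var("x") * UPat.const(None, 1), lambda x: x),   # fold x*1 -> x
  ])
  x = UOp.variable("x", 0, 10)
  assert graph_rewrite(x * 1, mul_by_one) is x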
662
+ # *** tracking pattern matcher ***
663
+
664
+ TRACK_MATCH_STATS = ContextVar("TRACK_MATCH_STATS", 2 if getenv("VIZ") else 0)
665
+ match_stats:Dict[UPat, List[Union[int, float]]] = dict()
666
+ @dataclass(frozen=True)
667
+ class TrackedRewriteContext:
668
+ loc: Tuple[str, int] # location that called graph_rewrite
669
+ sink: UOp # the sink passed into the rewrite
670
+ matches: List[Tuple[UOp, Optional[UOp], Optional[UPat], float]] = field(default_factory=list) # all matches of sparents
671
+
672
+ rewrite_stack: List[Tuple[Any, List[TrackedRewriteContext]]] = []
673
+ contexts: List[Tuple[Any, List[TrackedRewriteContext]]] = []
674
+ _rewrite_cnt: Dict[str, int] = {}
675
+ def track_rewrites(named=False):
676
+ def _decorator(func):
677
+ def __wrapper(self, *args, **kwargs):
678
+ if TRACK_MATCH_STATS >= 2:
679
+ if named: _rewrite_cnt[func.__name__] = _rewrite_cnt.setdefault(func.__name__, 0)+1
680
+ rewrite_stack.append((f"{(n:=func.__name__)}_{_rewrite_cnt[n]}" if named else self, []))
681
+ try: ret = func(self, *args, **kwargs)
682
+ finally: # NOTE: save everything in the stack
683
+ if TRACK_MATCH_STATS >= 2: contexts.append(rewrite_stack.pop())
684
+ return ret
685
+ return __wrapper
686
+ return _decorator
687
+
688
+ class TrackedPatternMatcher(PatternMatcher):
689
+ def __init__(self, patterns:List[Tuple[UPat, Callable]]):
690
+ super().__init__(patterns)
691
+ for p,_ in self.patterns:
692
+ if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]
693
+
694
+ def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]:
695
+ ret = None
696
+ ler = {u.op for u in uop.src}
697
+ for p,fxn,early_reject,has_ctx in self.pdict.get(uop.op, []):
698
+ st = time.perf_counter()
699
+ if not early_reject.issubset(ler):
700
+ match_stats[p][2] += time.perf_counter()-st
701
+ continue
702
+ match_stats[p][1] += 1
703
+ for match in p.match(uop, {}):
704
+ if (ret:=(fxn(ctx=ctx, **match) if has_ctx else fxn(**match))) is not None:
705
+ match_stats[p][0] += 1
706
+ match_stats[p][3] += (et:=time.perf_counter()-st)
707
+ if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", p.printable())
708
+ if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0 and isinstance(ret, UOp): rewrite_stack[-1][1][-1].matches.append((uop, ret, p, et))
709
+ return ret # NOTE: if it returns None, we keep trying to match
710
+ match_stats[p][2] += time.perf_counter()-st
711
+ if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0: rewrite_stack[-1][1][-1].matches.append((uop, ret, None, 0))
712
+ return None
713
+
714
+ if TRACK_MATCH_STATS:
715
+ PatternMatcher = TrackedPatternMatcher # type: ignore
716
+ import atexit
717
+ @atexit.register
718
+ def print_match_stats():
719
+ if TRACK_MATCH_STATS >= 2:
720
+ with open(fn:=temp("rewrites.pkl"), "wb") as f:
721
+ print(f"rewrote {len(contexts)} graphs and matched {sum(len(r.matches) for _,x in contexts for r in x)} times, saved to {fn}")
722
+ pickle.dump(contexts, f)
723
+ if getenv("VIZ"):
724
+ os.environ["VIZ"] = "0"
725
+ os.execv(sys.executable, [sys.executable] + [os.path.join(os.path.dirname(__file__), ".", "viz", "serve.py"), temp("rewrites.pkl")])
726
+ if getenv("PRINT_MATCH_STATS", 1):
727
+ ret = [0,0,0.0,0.0]
728
+ for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]+x[1][3]):
729
+ loc_str = f"{k.location[0].split('/')[-1]}:{k.location[1]}"
730
+ if v[1] != 0: print(f"{v[0]:6d} / {v[1]:7d} -- {v[3]*1000.:9.2f} / {(v[2]+v[3])*1000.:9.2f} ms -- {loc_str:15s}", k.printable())
731
+ ret = [x+y for x,y in zip(ret, v)]
732
+ print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[3]*1000.:9.2f} / {(ret[2]+ret[3])*1000.:9.2f} ms -- TOTAL")
733
+
734
+ # *** simple graph rewrite engine ***
735
+
736
+ class RewriteContext:
737
+ def __init__(self, pm, ctx):
738
+ self.pm: PatternMatcher = pm
739
+ self.ctx = ctx
740
+ self.replace: Dict[UOp, UOp] = {}
741
+ def rewrite(self, n:UOp) -> UOp:
742
+ if (rn := self.replace.get(n)) is not None: return rn
743
+ new_src = tuple(map(self.rewrite, n.src))
744
+ new_n = self.pm.rewrite(n, self.ctx) if new_src == n.src else UOp(n.op, n.dtype, new_src, n.arg)
745
+ self.replace[n] = ret = n if new_n is None else self.rewrite(new_n)
746
+ return ret
747
+
748
+ def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp:
749
+ if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0:
750
+ rewrite_stack[-1][1].append(TrackedRewriteContext(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink))
751
+ return RewriteContext(pm, ctx).rewrite(sink)
752
+
753
+ # ***** uop type spec *****
754
+
755
+ # this is the matcher for the final rendered UOps
756
+ # matcher functions returns True or False (or None to not match)
757
+ spec = PatternMatcher([
758
+ (UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and not x.dtype.local),
759
+ (UPat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.local),
760
+ (UPat(Ops.DEFINE_ACC, src=(UPat.var("c"),), name="x", allow_any_len=True),
761
+ lambda x,c: all(y.op is Ops.RANGE for y in x.src[1:]) and c.dtype == x.dtype),
762
+ (UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
763
+
764
+ (UPat(Ops.RANGE, src=(UPat(name="x"), UPat(name="y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype),
765
+ (UPat(Ops.SPECIAL, src=()), lambda: True),
766
+
767
+ # TODO: confirm the args of both of these are shapetrackers
768
+ (UPat(Ops.VIEW, dtypes.void, src=()), lambda: True),
769
+ (UPat(Ops.VIEW, src=(UPat.var("src"),), name="x"), lambda x,src: src.op is not Ops.STORE and x.dtype == src.dtype),
770
+
771
+ (UPat(Ops.VALID, dtypes.bool, (UPat(Ops.VIEW),)), lambda: True),
772
+ (UPat(Ops.CONST, name="x"), lambda x: x.dtype == x.dtype.scalar() and (type(x.arg) is type(dtypes.as_const(x.arg, x.dtype)))),
773
+
774
+ # early LOAD has a <buf, shapetracker, store?>
775
+ (UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW))), lambda: True),
776
+ (UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat(Ops.STORE))), lambda: True),
777
+
778
+ # early STORE has a <buf, shapetracker, val>
779
+ (UPat(Ops.STORE, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat())), lambda: True),
780
+
781
+ # **** new style load/store ****
782
+
783
+ # INDEX is used in new style load/store
784
+ (UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat())), lambda: True),
785
+
786
+ # LOAD takes a <bufidx, alt?, gate?, barrier?>
787
+ (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)),)), lambda: True),
788
+ (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat((Ops.IF, Ops.BARRIER)))), lambda: True),
789
+ (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(name="alt"), UPat(dtype=dtypes.bool)), name="ld"), lambda ld,alt: ld.dtype == alt.dtype),
790
+
791
+ # STORE takes a <bufidx, val, gate?>
792
+ (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat())), lambda: True),
793
+ (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
794
+ (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(), UPat(Ops.IF))), lambda: True),
795
+
796
+ # most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
797
+ (UPat(Ops.WHERE, name="w", src=(UPat(dtype=dtypes.bool), UPat(name="x"), UPat(name="y"))), lambda w,x,y: w.dtype == x.dtype == y.dtype),
798
+ (UPat((Ops.CMPLT, Ops.CMPNE), dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y"))), lambda x,y: x.dtype == y.dtype),
799
+ # and SHL/SHR, the shift distance can be an int
800
+ (UPat((Ops.SHL, Ops.SHR), src=(UPat(name="x"), UPat(name="y")), name="a"), lambda a,x,y: a.dtype == x.dtype and y.dtype in (x.dtype, dtypes.uint)),
801
+ (UPat(Ops.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
802
+ (UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)),
803
+
804
+ (UPat(Ops.ASSIGN, src=(UPat((Ops.DEFINE_ACC, Ops.DEFINE_GLOBAL)), UPat())), lambda: True),
805
+ (UPat(Ops.ENDRANGE, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True),
806
+
807
+ # all WMMA has 3 args, <x, w, acc>
808
+ (UPat(Ops.WMMA, src=(UPat(), UPat(), UPat())), lambda: True),
809
+ (UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
810
+ (UPat(Ops.EXPAND, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
811
+
812
+ # if has a <gate, barrier?>
813
+ (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(),)), lambda: True),
814
+ (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(), UPat(Ops.BARRIER))), lambda: True),
815
+ (UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),
816
+
817
+ (UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}),
818
+ (UPat(Ops.GEP, src=(UPat(name="src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
819
+ (UPat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.count and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
820
+ (UPat((Ops.BITCAST, Ops.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None),
821
+ (UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local
822
+
823
+ # NOTE: for testing, we let sinks be anything
824
+ #(UPat(UOps.SINK, src=UPat(UOps.STORE)), lambda: True),
825
+ (UPat(Ops.SINK, dtypes.void), lambda: True),
826
+ (UPat(Ops.NOOP), lambda: True),
827
+
828
+ # PTX LOAD/STORE
829
+ (UPat((Ops.LOAD, Ops.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
830
+ (UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, src=(UPat(dtype=dtypes.int64),), allow_any_len=True)), lambda: True),
831
+ ])
832
+
833
+ def type_verify(uops:List[UOp]):
834
+ for i,u in enumerate(uops):
835
+ if not spec.rewrite(u):
836
+ print_uops(uops)
837
+ raise RuntimeError(f"UOp verification failed at {i} on {u.op} {u.dtype} {len(u.src)} {[x.op for x in u.src]} {u.arg}")
838
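A minimal usage sketch of the verifier. The import path and the UOp.const signature are assumptions based on where this code lives in the diff; a well-formed scalar CONST satisfies the CONST rule in spec, so type_verify accepts it without raising.

    # hypothetical quick check; assumes these names are importable from tinygrad.ops
    from tinygrad.ops import UOp, type_verify
    from tinygrad.dtype import dtypes

    type_verify([UOp.const(dtypes.int, 1)])   # passes: scalar dtype, arg type matches the dtype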
+
839
+ # *** uop helpers ***
840
+
841
+ def cast_float_to_bf16(x: UOp) -> UOp:
842
+ assert x.dtype == dtypes.float, "cast float -> bf16 must start with float"
843
+ x = x.bitcast(dtypes.uint)
844
+ x = (-x & 0x7f800000).where(x + ((x >> 16) & 1) + 0x7fff, (x & 0xffff).where((x | 0x10000), x))
845
+ return (x >> 16).cast(dtypes.ushort).bitcast(dtypes.bfloat16)
846
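The non-NaN arm above is plain round-to-nearest-even into the high 16 bits. A pure-Python spot check of just that rounding step (the hex constants are illustrative float32 bit patterns, not taken from the source):

    def f32_bits_to_bf16_bits(u: int) -> int:
        # add 0x7fff plus the current low bit of the result, then keep the top 16 bits
        return (u + ((u >> 16) & 1) + 0x7fff) >> 16

    assert f32_bits_to_bf16_bits(0x3F800000) == 0x3F80  # 1.0 stays 1.0
    assert f32_bits_to_bf16_bits(0x3F808000) == 0x3F80  # exact tie rounds to the even value
    assert f32_bits_to_bf16_bits(0x3F808001) == 0x3F81  # just past the tie rounds up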
+
847
+ # *** most of symbolic lives here now ***
848
+
849
+ def split_uop(x:UOp, sep:Ops):
850
+ if x.op is sep:
851
+ for s in x.src: yield from split_uop(s, sep)
852
+ else: yield x
853
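For illustration, split_uop flattens a nested ADD chain into its individual terms. The snippet assumes UOp.variable and the arithmetic operator overloads behave as they are used elsewhere in this file:

    from tinygrad.ops import UOp, Ops, split_uop  # assumed import path

    a, b = UOp.variable("a", 0, 10), UOp.variable("b", 0, 10)
    terms = list(split_uop(a + b + 2, Ops.ADD))   # the tree is ((a+b)+2)
    assert terms[0] is a and terms[1] is b and terms[2].arg == 2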
+
854
+ def mod_folding(x:UOp, c:int) -> Optional[UOp]:
855
+ # simplify x % c, None means no change
856
+
857
+ # simple cancel mod case
858
+ if 0 < c and 0 <= x.vmin and (quotient:=x.vmin//c) == x.vmax//c: return x-quotient*c
859
+
860
+ remainder, something_changed = [], False
861
+ for u in split_uop(x, Ops.ADD):
862
+ if (factor:=u.const_factor())%c != factor:
863
+ divides = u.divides(factor)*(factor%c)
864
+ assert divides is not None
865
+ remainder.append(divides)
866
+ something_changed = True
867
+ elif u.op is Ops.MOD and (s1:=u.src[1]).op is Ops.CONST and s1.arg%c == 0:
868
+ remainder.append(u.src[0])
869
+ something_changed = True
870
+ else: remainder.append(u)
871
+ if not something_changed: return None
872
+ return functools.reduce(operator.add, remainder)%c if remainder else x.const_like(0)
873
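The per-term rewrite leans on a simple congruence: reducing a term's constant factor mod c does not change the value of the outer % c. A plain-Python spot check of that identity (no UOps involved):

    import random
    for _ in range(1000):
        u, rest = random.randrange(100), random.randrange(100)
        k, c = random.randrange(1, 100), random.randrange(1, 100)
        assert (k*u + rest) % c == ((k % c)*u + rest) % c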
+
874
+ def div_folding(x:UOp, c:int) -> Optional[UOp]:
875
+ # simplify x // c, None means no change
876
+
877
+ # simple cancel div case
878
+ if 0 <= x.vmin and x.vmax < c: return x.const_like(0)
879
+
880
+ quotient, remainder, rem_const, something_changed, gcd, divisor = [], [], 0, False, c, 1
881
+ for u in split_uop(x, Ops.ADD):
882
+ if u.op is Ops.CONST:
883
+ # add all const together first
884
+ if rem_const != 0: something_changed = True
885
+ rem_const += u.arg
886
+ elif (factor:=u.const_factor())%c == 0:
887
+ if factor:
888
+ divides = u.divides(c)
889
+ assert divides is not None
890
+ quotient.append(divides)
891
+ something_changed = True
892
+ else:
893
+ # divisor is the smallest factor (>1) among the MUL terms that divides c
894
+ if u.op is Ops.MUL and factor > 1 and c % factor == 0 and (divisor == 1 or divisor > factor): divisor = factor
895
+ remainder.append(u)
896
+ gcd = math.gcd(gcd, factor)
897
+
898
+ # handle the const
899
+ if rem_const%c != rem_const:
900
+ something_changed = True
901
+ quotient.append(x.const_like(rem_const//c))
902
+ rem_const = rem_const%c
903
+ if rem_const != 0: remainder.append(x.const_like(rem_const))
904
+
905
+ # x // c -> quotient + (remainder // div) // (c // div)
906
+ div = gcd if gcd > 1 else divisor
907
+
908
+ if not something_changed: return newx//(c//div) if 1 < div < c and (newx:=div_folding(x, div)) is not None else None
909
+ rem:Optional[UOp] = functools.reduce(operator.add, remainder) if remainder else None
910
+ quo:Optional[UOp] = functools.reduce(operator.add, quotient) if quotient else None
911
+ if quo is None: return x.const_like(0) if rem is None else cast(UOp, div_folding(rem, div))//(c//div)
912
+ return quo if rem is None else cast(UOp, div_folding(rem, div))//(c//div)+quo
913
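The final rewrite uses the fact that, for a non-negative remainder and any divisor d of c, the terms divisible by c can be pulled out and the rest divided in two stages: x//c == quotient + (remainder//d)//(c//d). A plain-Python spot check:

    import random
    for _ in range(1000):
        c = random.randrange(2, 50)
        d = random.choice([k for k in range(1, c + 1) if c % k == 0])
        q, r = random.randrange(50), random.randrange(500)
        assert (q*c + r)//c == q + (r//d)//(c//d)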
+
914
+ def lt_folding(x:UOp, c:int) -> Optional[UOp]:
915
+ p, np = partition(split_uop(x, Ops.ADD), lambda u: u.const_factor() == 1)
916
+ if np and (d:=math.gcd(*[u.const_factor() for u in np], c)) > 1 and 0 <= sum(u.vmin for u in p) and sum(u.vmax for u in p) < d:
917
+ return cast(UOp, functools.reduce(operator.add, np).divides(d)).lt(c//d)
918
+ return None
919
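The rewrite is justified by: if x = n + p where d divides both n and c and 0 <= p < d, then x < c exactly when n//d < c//d. A plain-Python spot check:

    import random
    for _ in range(1000):
        d = random.randrange(2, 20)
        c = d * random.randrange(1, 20)
        n = d * random.randrange(0, 40)
        p = random.randrange(0, d)          # the unit-factor part, bounded by [0, d)
        assert ((n + p) < c) == ((n // d) < (c // d))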
+
920
+ def fold_unrolled_divs(divs:UOp):
921
+ # div pattern in unrolled arange
922
+ # example: x//4+(x+1)//4+(x+2)//4+(x+3)//4 -> x
923
+ add_chain, denominator, seen_const, ans = list(split_uop(divs, Ops.ADD)), None, [], None
924
+ for u in add_chain:
925
+ if not (u.op is Ops.IDIV and u.src[1].op is Ops.CONST): return None
926
+ if denominator is None: denominator = u.src[1].arg
927
+ if denominator != u.src[1].arg: return None
928
+ # assumed CONST is the last of an ADD
929
+ if (s0:=u.src[0]).op is Ops.ADD and s0.src[1].op is Ops.CONST:
930
+ seen_const.append(s0.src[1].arg)
931
+ s0 = s0.src[0]
932
+ else: seen_const.append(0)
933
+ if ans is None: ans = s0
934
+ if ans is not s0: return None
935
+ if denominator is None: return None
936
+ # the first (denominator-len(seen_const)) terms may have been folded to 0 already
937
+ for i in range(denominator-len(seen_const)):
938
+ if ans is not None and 0 <= ans.vmin and ans.vmax + i < denominator: seen_const.append(i)
939
+ return ans if ans is not None and sorted(seen_const)==list(range(denominator)) else None
940
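The pattern in the comment is an instance of Hermite's identity: the sum of (x+i)//n for i in 0..n-1 equals x for integer x, which is what makes folding the whole chain back to x sound:

    for x in range(-100, 100):
        assert sum((x + i) // 4 for i in range(4)) == x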
+
941
+ def canonicalize_simplex(X:UOp) -> Optional[UOp]:
942
+ # (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints.
943
+ # returns x0 + x1 + ... in such case, or None if not
944
+ changed, ret = False, []
945
+ for u in split_uop(X, Ops.ADD):
946
+ # assumed the const is the last src of MUL
947
+ if u.op is Ops.MUL and u.src[1].op is Ops.CONST and u.src[1].arg > 0:
948
+ changed = True
949
+ u = u.src[0]
950
+ if not (u.op in GroupOp.Irreducible and u.vmin >= 0): return None
951
+ ret.append(u)
952
+ return functools.reduce(operator.add, ret) if changed else None
953
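A plain-Python spot check of the stated equivalence: with non-negative xi and strictly positive ai, the weighted sum is positive exactly when the plain sum is:

    import random
    for _ in range(1000):
        a0, a1 = random.randrange(1, 10), random.randrange(1, 10)
        x0, x1 = random.randrange(0, 10), random.randrange(0, 10)
        assert (a0*x0 + a1*x1 > 0) == (x0 + x1 > 0)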
+
954
+ def is_increasing(f:UOp) -> bool:
955
+ # is f a monotonically increasing function with respect to its input
956
+ if f.op in GroupOp.Irreducible: return True
957
+ if f.op is Ops.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1])
958
+ if f.op in (Ops.MUL, Ops.IDIV) and f.src[1].op is Ops.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0])
959
+ return False # False if not sure
960
+
961
+ def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]:
962
+ # if it's X <= c, returns X, True, c
963
+ # if it's X >= c, returns X, False, c
964
+
965
+ # (X < c).ne(True) -> X >= c
966
+ if valid.op is Ops.CMPNE and valid.src[1].op is Ops.CONST and valid.src[1].arg == 1 and \
967
+ (s0:=valid.src[0]).op is Ops.CMPLT and s0.src[1].op is Ops.CONST: return s0.src[0], False, s0.src[1].arg
968
+ # X < c -> X <= c-1
969
+ if valid.op is Ops.CMPLT and valid.src[1].op is Ops.CONST: return valid.src[0], True, valid.src[1].arg-1
970
+ raise ValueError(f"not able to parse {valid=}")
971
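An illustrative sketch of the two accepted shapes, assuming UOp.variable, .lt and .ne behave as they are used elsewhere in this file:

    from tinygrad.ops import UOp, parse_valid  # assumed import path

    x = UOp.variable("ridx", 0, 10)
    expr, is_upper, c = parse_valid(x.lt(5))            # x < 5       ->  x <= 4
    assert expr is x and is_upper and c == 4
    expr, is_upper, c = parse_valid(x.lt(5).ne(True))   # not(x < 5)  ->  x >= 5
    assert expr is x and not is_upper and c == 5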
+
972
+ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
973
+ # return None if valid is always False, otherwise the simplified uop (might be the same as input)
974
+
975
+ # first, parse valid into {expr: (lower_bound, upper_bound)}
976
+ bounds:DefaultDict[UOp, List[Optional[ConstType]]] = defaultdict(lambda: [None, None])
977
+ for stmt in split_uop(valid, Ops.AND):
978
+ try: expr, is_upper, c = parse_valid(stmt)
979
+ except ValueError: return uop # give up if we cannot parse the valid
980
+ bounds[expr][int(is_upper)] = c
981
+
982
+ # simplify uop given that valid is True
983
+ for expr,v in bounds.items():
984
+ # some expr has lower bound > upper bound -> valid is an empty set and we return None
985
+ if v[0] is not None and v[1] is not None and v[0] > v[1]: return None
986
+
987
+ # every candidate is a set of constrained UOps based on valid; if every item in a set simplifies the uop to the same output, we rewrite the uop
988
+ candidates = []
989
+ if expr.op is Ops.ADD and v[0] == 1 and all(u.op in GroupOp.Irreducible for u in split_uop(expr, Ops.ADD)):
990
+ # if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
991
+ candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, Ops.ADD)])
992
+ # try checking the whole clause
993
+ if expr in uop.sparents:
994
+ candidates.append([(expr, UOp.variable("fake", expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1], expr.dtype))])
995
+
996
+ for candidate in candidates:
997
+ # if every branch in candidate gives the same simplified uop, we can rewrite the uop
998
+ newuops = [uop.substitute({X:newX}).simplify().substitute({newX:X}).simplify() for X,newX in candidate]
999
+ if uop.op is Ops.VECTORIZE and len(uop.src) == 2:
1000
+ if all_same([uops.src[0] for uops in newuops]): uop = uop.replace(src=(newuops[0].src[0], uop.src[1]))
1001
+ if all_same([uops.src[1] for uops in newuops]): uop = uop.replace(src=(uop.src[0], newuops[0].src[1]))
1002
+ elif all_same(newuops): uop = newuops[0]
1003
+
1004
+ return uop
1005
+
1006
+ def _valid_priority(v: UOp, valids:List[UOp]):
1007
+ # we want valids that appear in other valids' parents to come first, so the other valids are more likely to be simplified
1008
+ try: return sum(-1 if parse_valid(v)[0] in other.parents else 0 for other in valids)
1009
+ except ValueError: return 0
1010
+
1011
+ def simplify_valid(valid:UOp) -> Optional[UOp]:
1012
+ ret:List[UOp] = []
1013
+ something_changed = False
1014
+ valids = list(split_uop(valid, Ops.AND))
1015
+ for stmt in sorted(valids, key=lambda v: _valid_priority(v, valids)):
1016
+ ret.append(newstmt if ret and (newstmt:=uop_given_valid(functools.reduce(operator.and_, ret), stmt)) is not None else stmt)
1017
+ if ret[-1] is not stmt: something_changed = True
1018
+ return functools.reduce(operator.and_, ret) if something_changed else None
1019
+
1020
+ def max_var_const(x:UOp, c1:UOp, c2:UOp):
1021
+ if x.vmin >= 0: return x*c1 if c1.arg >= c2.arg else x*c2
1022
+ if x.vmax <= 0: return x*c2 if c1.arg >= c2.arg else x*c1
1023
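A plain-Python spot check of the sign reasoning: max(x*c1, x*c2) picks the larger constant when x is non-negative and the smaller one when x is non-positive:

    import random
    for _ in range(1000):
        c1, c2 = random.randrange(-10, 10), random.randrange(-10, 10)
        x = random.randrange(0, 10)
        assert max(x*c1, x*c2) == x*max(c1, c2)      # the x >= 0 case
        assert max(-x*c1, -x*c2) == -x*min(c1, c2)   # the x <= 0 case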
+
1024
+ symbolic_simple = PatternMatcher([
1025
+ # ** self folding **
1026
+ (UPat.var("x") + 0, lambda x: x), # x+0 -> x
1027
+ (UPat.var("x") * 1, lambda x: x), # x*1 -> x
1028
+ (UPat.var("x") // UPat.var("x"), lambda x: x.const_like(1)), # x//x -> 1
1029
+ (UPat.var("x") // 1, lambda x: x), # x//1 -> x
1030
+ (UPat.var("x") // -1, lambda x: -x), # x//-1 -> -x
1031
+ (UPat.var("x") / UPat.var("x"), lambda x: x.const_like(1)), # x/x -> 1
1032
+ ((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x
1033
+ ((UPat.var() % UPat.var("y")).named("base") % UPat.var("y"), lambda base,y: base), # (x%y)%y -> x%y (rewritten with base for speed)
1034
+ (UPat.var("x")%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda x,c: x), # (x%c)+(x//c)*c = x
1035
+ (UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c", vec=False), lambda x,c: x if c.arg else c),
1036
+ (UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c", vec=False), lambda x,c: c if c.arg else x),
1037
+ (UPat.var("x").maximum(UPat.var("x")), lambda x: x),
1038
+ ((UPat.var("x") & UPat.var("x")), lambda x: x),
1039
+ ((UPat.var("x") | UPat.var("x")), lambda x: x),
1040
+ (UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x),
1041
+ (UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, True), UPat.const(dtypes.bool, False)), lambda x: x),
1042
+ # ** zero folding **
1043
+ (UPat.var("x") < UPat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x < x -> False
1044
+ (UPat.var("x", dtype=dtypes.ints) != UPat.var("x", dtype=dtypes.ints),
1045
+ lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x != x -> False (only ints)
1046
+ # x*0 -> 0 or 0*x -> 0
1047
+ # if x is nan or inf it should render the nan value.
1048
+ # NOTE: this can be wrong for loaded NaN
1049
+ (UPat.var("x") * 0, lambda x: x.const_like(float("nan") if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
1050
+ # ** constant folding **
1051
+ (UPat(GroupOp.ALU, name="a", src=UPat((Ops.VCONST, Ops.CONST))), lambda a: a.const_like(exec_alu(a.op, a.dtype, [x.arg for x in a.src], False))),
1052
+ # bool MUL is AND, ADD/MAX is OR. prevents other rules from rewriting bool ADD/MUL incorrectly
1053
+ (UPat.var('x', dtype=dtypes.bool) * UPat.var('y', dtype=dtypes.bool), lambda x,y: x&y),
1054
+ (UPat.var('x', dtype=dtypes.bool) + UPat.var('y', dtype=dtypes.bool), lambda x,y: x|y),
1055
+ (UPat.var('x', dtype=dtypes.bool).maximum(UPat.var('y', dtype=dtypes.bool)), lambda x,y: x|y),
1056
+ # *** cast ***
1057
+ (UPat(Ops.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)),
1058
+ (UPat(Ops.CAST, name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
1059
+ ])
1060
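A small usage sketch. It assumes PatternMatcher.rewrite applies the first matching rule to a single node, the same way type_verify uses spec above:

    from tinygrad.ops import UOp, symbolic_simple  # assumed import path

    v = UOp.variable("v", 0, 10)
    assert symbolic_simple.rewrite(v * 1) is v        # x*1 -> x
    assert symbolic_simple.rewrite(v // v).arg == 1   # x//x -> 1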
+
1061
+ symbolic = symbolic_simple+PatternMatcher([
1062
+ # ** COMMUTATIVE flipping **
1063
+ (UPat(GroupOp.Commutative, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
1064
+ # group like
1065
+ ((UPat.var("x") + UPat.var("y")) + UPat.var("x") * UPat.cvar("c"), lambda x,y,c: (x+x*c)+y),
1066
+ # ** boolean algebra **
1067
+ (UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x
1068
+ # ** combine terms **
1069
+ (UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1)
1070
+ (UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1)
1071
+ (UPat.var("x") + UPat.var("x"), lambda x: x*2), # (x+x)-> x*2
1072
+ ((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3)), # (x/x2)/x3 -> x/(x2*x3)
1073
+ (-1 * (UPat.var("x") + UPat.cvar("c")), lambda x,c: (-x)+(-c)), # -(x+c) -> -x + -c
1074
+ # a conditional with the same results either way is a noop, also fold const conditionals
1075
+ (UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
1076
+ (UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
1077
+ # ALU min==max -> CONST (slow!)
1078
+ (UPat(GroupOp.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
1079
+ # max folding
1080
+ (UPat.maximum(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
1081
+ # TODO: why does this rule break beautiful_mnist?
1082
+ #((UPat.var("x")+UPat.var("z")).maximum(UPat.var("y")+UPat.var("z")), lambda x,y,z: x.maximum(y) + z),
1083
+ ((UPat.var("x")*UPat.cvar("c1")).maximum(UPat.var("x")*UPat.cvar("c2")), max_var_const),
1084
+ # ** two stage ALU folding **
1085
+ ((UPat.var("x") + UPat.cvar("c1")) + UPat.cvar("c2"), lambda x,c1,c2: x+(c1+c2)),
1086
+ ((UPat.var("x") * UPat.cvar("c1")) * UPat.cvar("c2"), lambda x,c1,c2: x*(c1*c2)),
1087
+ ((UPat.var("x") & UPat.cvar("c1")) & UPat.cvar("c2"), lambda x,c1,c2: x&(c1&c2)),
1088
+ ((UPat.var("x") | UPat.cvar("c1")) | UPat.cvar("c2"), lambda x,c1,c2: x|(c1|c2)),
1089
+ ((UPat.cvar("c0") + UPat.var("x")) < UPat.cvar("c1"), lambda x,c0,c1: x<(c1-c0)), # c0 + x < c1 -> x < c1 - c0
1090
+ ((UPat.var("x") // UPat.cvar("c1")) // UPat.cvar("c2"), lambda x,c1,c2: x//(c1*c2)), # (x//c1)//c2 -> x//(c1*c2)
1091
+ # ** lt **
1092
+ # c0*x<c1 for positive int c0,c1
1093
+ ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints)).lt(UPat.cvar("c1", vec=False)),
1094
+ lambda x,c0,c1: x.lt(math.ceil(c1.arg/c0.arg)) if c0.arg > 0 and c1.arg > 0 else None),
1095
+ # c0*x<c1 for negative int c0 and non-positive c1
1096
+ ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints)).lt(UPat.cvar("c1", vec=False)),
1097
+ lambda x,c0,c1: (-x).lt(-(math.floor(-c1.arg/-c0.arg))) if c0.arg < 0 and c0.arg != -1 and c1.arg <= 0 else None),
1098
+ # x//c0<c1 for positive int c0
1099
+ ((UPat.var("x", dtype=dtypes.ints)//UPat.cvar("c0", vec=False)).lt(UPat.cvar("c1", vec=False)),
1100
+ lambda x,c0,c1: x.lt(c1.arg*c0.arg) if c0.arg > 0 else None),
1101
+ # mul add lt
1102
+ (((UPat.cvar("c0", vec=False)*UPat.var("x"))+UPat.var("x2")).lt(UPat.cvar("c1", vec=False)),
1103
+ lambda x,x2,c0,c1: x.lt(c1//c0) if c1.arg % c0.arg == 0 and c0.arg > x2.vmax and x2.vmin >= 0 else None),
1104
+ # ** move add/mul consts to end (NOTE: this is still happening before constant folding) **
1105
+ (UPat(Ops.ADD, src=(UPat.var("x"), UPat.cvar("c1"))) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
1106
+ (UPat(Ops.MUL, src=(UPat.var("x"), UPat.cvar("c1"))) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
1107
+ # *** rules from symbolic ***
1108
+ # unrolled arange div folding
1109
+ (UPat(Ops.ADD, name="divs", src=[UPat(), UPat(Ops.IDIV)]), fold_unrolled_divs),
1110
+ # generic lt folding
1111
+ (UPat.var("x", dtypes.sints).lt(UPat.cvar("c", vec=False)), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None),
1112
+ # canonicalize a simplex with positive coefficients > 0
1113
+ # not x < 1 -> X > 0
1114
+ (UPat.var("x", dtypes.ints).lt(1).ne(True), lambda x: newx.lt(1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None),
1115
+ # ** div **
1116
+ # div folding
1117
+ (UPat.var("x", dtypes.sints) // UPat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=div_folding(x,c.arg)) is not None else None),
1118
+ # ** mod **
1119
+ # mod folding
1120
+ (UPat.var("x") % UPat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None),
1121
+ ])
1122
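A sketch of the combined rules through UOp.simplify() (used above in uop_given_valid; the assumption here is that simplify runs these symbolic rules to a fixed point): two-stage ALU folding collapses nested constant adds:

    from tinygrad.ops import UOp, Ops  # assumed import path

    v = UOp.variable("v", 0, 10)
    out = ((v + 2) + 3).simplify()                    # (x+c1)+c2 -> x+(c1+c2)
    assert out.op is Ops.ADD and any(s.op is Ops.CONST and s.arg == 5 for s in out.src)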
+
1123
+ symbolic_flat = symbolic+PatternMatcher([
1124
+ # ** combine terms (opinionated) **
1125
+ (-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y
1126
+ # (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue
1127
+ ((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
1128
+ ])
1129
+
1130
+ _substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
1131
+
1132
+ # for debug
1133
+ syms = { Ops.ADD: "+", Ops.SUB: "-", Ops.IDIV: "//", Ops.MOD: "%", Ops.SHL: "<<", Ops.SHR: ">>",
1134
+ Ops.MUL: "*", Ops.CMPLT: "<", Ops.CMPNE: "!=", Ops.AND: "&", Ops.OR: "|", Ops.XOR: "^"}
1135
+ renderer = PatternMatcher([
1136
+ (UPat((Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg[0])),
1137
+ (UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg[0]}")),
1138
+ (UPat(Ops.CONST, name="x"), lambda x: UOp(Ops.NOOP, arg=str(x.arg))),
1139
+ (UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
1140
+ (UPat(Ops.NEG, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")),
1141
+ (UPat(Ops.MAX, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")),
1142
+ (UPat(Ops.MULACC, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg}+{x.src[2].arg})")),
1143
+ (UPat(Ops.WHERE, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[1].arg} if {x.src[0].arg} else {x.src[2].arg})")),
1144
+ (UPat(GroupOp.ALU, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}{syms[x.op]}{x.src[1].arg})")),
1145
+ ])
1146
+
1147
+ # *** what was symbolic.py ***
1148
+
1149
+ sint = Union[int, UOp]
1150
+ Variable = UOp
1151
+
1152
+ ConstLike = Union[ConstType, Variable, Tuple[ConstType, ...]]