tinygrad-0.10.0-py3-none-any.whl → tinygrad-0.10.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. tinygrad/codegen/kernel.py +114 -172
  2. tinygrad/codegen/linearize.py +211 -81
  3. tinygrad/codegen/lowerer.py +30 -35
  4. tinygrad/codegen/{uopgraph.py → rewriter.py} +69 -59
  5. tinygrad/codegen/transcendental.py +12 -13
  6. tinygrad/device.py +170 -47
  7. tinygrad/dtype.py +28 -26
  8. tinygrad/engine/jit.py +80 -63
  9. tinygrad/engine/memory.py +4 -5
  10. tinygrad/engine/multi.py +162 -0
  11. tinygrad/engine/realize.py +58 -107
  12. tinygrad/engine/schedule.py +381 -314
  13. tinygrad/engine/search.py +40 -44
  14. tinygrad/gradient.py +70 -0
  15. tinygrad/helpers.py +77 -58
  16. tinygrad/nn/__init__.py +30 -32
  17. tinygrad/nn/datasets.py +1 -2
  18. tinygrad/nn/optim.py +22 -26
  19. tinygrad/nn/state.py +89 -64
  20. tinygrad/ops.py +562 -446
  21. tinygrad/renderer/__init__.py +79 -36
  22. tinygrad/renderer/cstyle.py +70 -84
  23. tinygrad/renderer/llvmir.py +32 -20
  24. tinygrad/renderer/ptx.py +79 -99
  25. tinygrad/renderer/wgsl.py +87 -0
  26. tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  27. tinygrad/runtime/autogen/comgr.py +2 -0
  28. tinygrad/runtime/autogen/kfd.py +4 -3
  29. tinygrad/runtime/autogen/kgsl.py +1 -1
  30. tinygrad/runtime/autogen/libpciaccess.py +2023 -0
  31. tinygrad/runtime/autogen/llvm.py +11379 -0
  32. tinygrad/runtime/autogen/vfio.py +891 -0
  33. tinygrad/runtime/graph/cuda.py +8 -9
  34. tinygrad/runtime/graph/hcq.py +84 -79
  35. tinygrad/runtime/graph/metal.py +19 -21
  36. tinygrad/runtime/ops_amd.py +488 -327
  37. tinygrad/runtime/ops_clang.py +15 -28
  38. tinygrad/runtime/ops_cloud.py +34 -34
  39. tinygrad/runtime/ops_cuda.py +30 -27
  40. tinygrad/runtime/ops_disk.py +62 -63
  41. tinygrad/runtime/ops_dsp.py +129 -38
  42. tinygrad/runtime/ops_gpu.py +30 -30
  43. tinygrad/runtime/ops_hip.py +29 -31
  44. tinygrad/runtime/ops_llvm.py +45 -40
  45. tinygrad/runtime/ops_metal.py +93 -73
  46. tinygrad/runtime/ops_npy.py +2 -2
  47. tinygrad/runtime/ops_nv.py +232 -270
  48. tinygrad/runtime/ops_python.py +51 -46
  49. tinygrad/runtime/ops_qcom.py +129 -157
  50. tinygrad/runtime/ops_webgpu.py +63 -0
  51. tinygrad/runtime/support/allocator.py +94 -0
  52. tinygrad/runtime/support/am/__init__.py +0 -0
  53. tinygrad/runtime/support/am/amdev.py +384 -0
  54. tinygrad/runtime/support/am/ip.py +463 -0
  55. tinygrad/runtime/support/compiler_cuda.py +4 -2
  56. tinygrad/runtime/support/elf.py +26 -4
  57. tinygrad/runtime/support/hcq.py +254 -324
  58. tinygrad/runtime/support/llvm.py +32 -0
  59. tinygrad/shape/shapetracker.py +84 -53
  60. tinygrad/shape/view.py +103 -138
  61. tinygrad/spec.py +154 -0
  62. tinygrad/tensor.py +744 -496
  63. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/METADATA +32 -21
  64. tinygrad-0.10.1.dist-info/RECORD +86 -0
  65. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/WHEEL +1 -1
  66. tinygrad/engine/lazy.py +0 -228
  67. tinygrad/function.py +0 -212
  68. tinygrad/multi.py +0 -177
  69. tinygrad/runtime/graph/clang.py +0 -39
  70. tinygrad-0.10.0.dist-info/RECORD +0 -77
  71. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/LICENSE +0 -0
  72. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/top_level.txt +0 -0
tinygrad/ops.py CHANGED
@@ -1,14 +1,15 @@
  from __future__ import annotations
- from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, Type, DefaultDict
- import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect
+ from typing import Any, Optional, Union, Callable, cast, TYPE_CHECKING, Type, Literal, get_args
+ import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect, weakref
  from enum import auto, IntEnum, Enum
  from dataclasses import dataclass, field
  from collections import defaultdict
- from weakref import WeakValueDictionary
- from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate
- from tinygrad.helpers import ContextVar, prod, getenv, all_same, Context, partition, temp, unwrap, T
+ from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate
+ from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, _METADATA, flatten
+ from tinygrad.helpers import PICKLE_BUFFERS, SPLIT_REDUCEOP, DEBUG, dedup
  if TYPE_CHECKING:
  from tinygrad.shape.shapetracker import ShapeTracker
+ from tinygrad.device import Buffer

  # wrapper around IntEnum that preserves Enum.__str__ and makes auto() unique across all FastEnum subclasses
  class FastEnum(IntEnum):
@@ -26,8 +27,7 @@ class SimpleMathTrait:
  def _binop(self, op, x, reverse): return self.ufix(x).alu(op, self) if reverse else self.alu(op, self.ufix(x))
  def logical_not(self): return self.ne(True)
  def neg(self):
- dtype: Optional[DType] = getattr(self, 'dtype', None)
- assert dtype is not None, "MathTraits __neg__ requires a dtype"
+ if (dtype:=getattr(self, 'dtype')) is None: raise TypeError(f"MathTraits __neg__ requires a dtype, {self=}")
  return self.logical_not() if dtype.scalar() == dtypes.bool else self*(-1)
  def add(self, x, reverse=False): return self._binop(Ops.ADD, x, reverse)
  def mul(self, x, reverse=False): return self._binop(Ops.MUL, x, reverse)
@@ -35,6 +35,7 @@ class SimpleMathTrait:
  def bitwise_or(self, x, reverse=False): return self._binop(Ops.OR, x, reverse)
  def xor(self, x, reverse=False): return self._binop(Ops.XOR, x, reverse)
  def idiv(self, x, reverse=False): return self._binop(Ops.IDIV, x, reverse)
+ def mod(self, x, reverse=False): return self._binop(Ops.MOD, x, reverse)
  def sub(self, x, reverse=False): return self.ufix(x).alu(Ops.ADD, -self) if reverse else self.alu(Ops.ADD, self.ufix(-x))
  def div(self, x, reverse=False): return (self.ufix(x)*self.alu(Ops.RECIP)) if reverse else (self*self.ufix(x).alu(Ops.RECIP))

@@ -44,7 +45,8 @@ class SimpleMathTrait:
  def __sub__(self, x): return self.sub(x)
  def __mul__(self, x): return self.mul(x)
  def __truediv__(self, x): return self.div(x)
- def __floordiv__(self, x): return self.idiv(x)
+ def __floordiv__(self, x): return self.idiv(x) # TODO: idiv is trunc div, not floordiv
+ def __mod__(self, x): return self.mod(x)
  def __and__(self, x): return self.bitwise_and(x)
  def __or__(self, x): return self.bitwise_or(x)
  def __xor__(self, x): return self.xor(x)
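
Note on the TODO above: idiv is C-style truncating division, while Python's // floors. A minimal illustration of the gap, mirroring the Ops.IDIV lambda that appears later in python_alu (plain Python, not tinygrad code):

def trunc_div(x: int, y: int) -> int:
    # C-style: divide the magnitudes, then apply the sign (same trick as the python_alu IDIV lambda)
    return abs(x) // abs(y) * (1, -1)[x * y < 0]

assert trunc_div(7, 2) == 3 and 7 // 2 == 3      # positive operands: both agree
assert trunc_div(-7, 2) == -3 and -7 // 2 == -4  # trunc rounds toward zero, floor rounds down
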
@@ -57,22 +59,19 @@ class SimpleMathTrait:
  def __rand__(self, x): return self.bitwise_and(x, True)
  def __ror__(self, x): return self.bitwise_or(x, True)
  def __rxor__(self, x): return self.xor(x, True)
+ def __rmod__(self, x): return self.mod(x, True)
+
+ def __lt__(self, x): return self.alu(Ops.CMPLT, self.ufix(x))
+ def __gt__(self, x): return self.ufix(x).alu(Ops.CMPLT, self)
+ def __ge__(self, x): return (self < x).logical_not()
+ def __le__(self, x): return (self > x).logical_not()

- def lt(self, x): return self.alu(Ops.CMPLT, self.ufix(x))
- def gt(self, x): return self.ufix(x).alu(Ops.CMPLT, self)
  def ne(self, x): return self.alu(Ops.CMPNE, self.ufix(x))
- def ge(self, x): return self.lt(x).logical_not()
- def le(self, x): return self.gt(x).logical_not()
  def eq(self, x): return self.ne(x).logical_not()
-
- def __lt__(self, x): return self.lt(x)
- def __gt__(self, x): return self.gt(x)
  def __ne__(self, x): return self.ne(x)
- def __ge__(self, x): return self.ge(x)
- def __le__(self, x): return self.le(x)
  # NOTE: __eq__ isn't overridden, and means the same thing as is by default

- class MathTrait(SimpleMathTrait): # pylint: disable=abstract-method
+ class MathTrait(SimpleMathTrait):
  # TODO: move to Tensor when new backward is done
  def lshift(self, x, reverse=False): return self._binop(Ops.SHL, x, reverse)
  def rshift(self, x, reverse=False): return self._binop(Ops.SHR, x, reverse)
@@ -81,10 +80,6 @@ class MathTrait(SimpleMathTrait): # pylint: disable=abstract-method
  def __rlshift__(self, x): return self.lshift(x, True)
  def __rrshift__(self, x): return self.rshift(x, True)

- # not in Tensor
- def __mod__(self, x): return self.alu(Ops.MOD, self.ufix(x))
- def __rmod__(self, x): return self.ufix(x).alu(Ops.MOD, self)
-
  def maximum(self, x): return self.alu(Ops.MAX, self.ufix(x))
  def minimum(self, x): return -(-self).maximum(-x)
  def where(self, x, y): return self.alu(Ops.WHERE, x, x.ufix(y))
@@ -98,39 +93,40 @@ class MathTrait(SimpleMathTrait): # pylint: disable=abstract-method
  # the order of these Ops controls the order of the toposort
  class Ops(FastEnum):
  # uops that aren't rendered
- SINK = auto()
- CONTIGUOUS = auto()
- PRELOAD = auto()
+ SINK = auto(); CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); PRELOAD = auto(); KERNEL = auto() # noqa: E702

- # MetaOps
- COPY = auto()
+ # TODO: empty continues to exist because of tensor
  EMPTY = auto()
- BUFFER_VIEW = auto()
-
- EXPAND = auto()
- CONTRACT = auto()
- VIEW = auto()
- DEFINE_GLOBAL = auto()
- BUFFER = auto()
- DEFINE_VAR = auto()
- DEFINE_LOCAL = auto()
- DEFINE_ACC = auto()
- VALID = auto()
- SPECIAL = auto()
- NOOP = auto()
+
+ # MetaOps
+ COPY = auto(); BUFFER_VIEW = auto() # noqa: E702
+
+ # blocks in linearizer
+ BLOCK = auto(); BLOCKSTART = auto(); BLOCKFORK = auto(); BLOCKEND = auto() # noqa: E702
+
+ # movement ops!
+ RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto() # noqa: E702
+
+ # misc ops
+ UNROLL = auto(); CONTRACT = auto() # noqa: E702
+ VIEW = auto(); DEFINE_GLOBAL = auto(); BUFFER = auto() # noqa: E702
+ DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # noqa: E702
+ VALID = auto(); SPECIAL = auto(); NOOP = auto() # noqa: E702

  # reduce
  REDUCE_AXIS = auto()

  # helper ops
- GEP = auto()
- VECTORIZE = auto()
+ GEP = auto(); VECTORIZE = auto() # noqa: E702

  # UnaryOps
  CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702

- # loads before math
- LOAD = auto()
+ # load/store before math
+ LOAD = auto(); STORE = auto() # noqa: E702
+
+ # early INDEX
+ INDEX = auto()

  # math ops
  WMMA = auto()
@@ -143,25 +139,18 @@ class Ops(FastEnum):
  WHERE = auto(); MULACC = auto() # noqa: E702

  # assignment ops
- STORE = auto()
  ASSIGN = auto()
  BIND = auto()

- # late INDEX
- INDEX = auto()
-
  # control flow ops
- BARRIER = auto()
- IF = auto()
- RANGE = auto()
-
- # ops that are not graph nodes
- ENDRANGE = auto()
- ENDIF = auto()
+ BARRIER = auto(); RANGE = auto(); IF = auto(); ENDRANGE = auto(); ENDIF = auto() # noqa: E702

  # consts last!
- VCONST = auto()
- CONST = auto()
+ VCONST = auto(); CONST = auto() # noqa: E702
+
+ # device
+ DEVICE = auto()
+ MULTI = auto()

  class GroupOp:
  Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG}
@@ -171,41 +160,53 @@ class GroupOp:
  ALU = set.union(Unary, Binary, Ternary)

  Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}
+ Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP}

- # meta ops
- Meta = {Ops.COPY, Ops.EMPTY, Ops.BUFFER_VIEW}
- Buffer = {Ops.LOAD, Ops.PRELOAD, Ops.STORE, Ops.VALID}
+ Buffer = {Ops.LOAD, Ops.PRELOAD, Ops.STORE, Ops.VALID, Ops.CONST, Ops.DEFINE_VAR}
+ Block = {Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKFORK, Ops.BLOCKSTART}

  # BinaryOps that can be flipped
  Commutative = {Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPNE, Ops.XOR, Ops.AND, Ops.OR}

+ # BinaryOps where f(f(a,b),c) = f(a,f(b,c))
+ Associative = {Ops.ADD, Ops.MUL, Ops.AND, Ops.OR, Ops.MAX}
+
+ # BinaryOps that satisfy f(x,x)=x see https://en.wikipedia.org/wiki/Idempotence
+ Idempotent = {Ops.OR, Ops.AND, Ops.MAX}
+
  # do not preserve f(0) = 0
  UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV}

- # https://en.wikipedia.org/wiki/Identity_element
- def identity_element(op:Ops, dt:DType): return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
+ All = set(Ops)
+
+ # some BUFFER ops can be processed with only a view
+ view_supported_devices = {"LLVM", "CLANG", "CUDA", "NV", "AMD", "METAL", "QCOM", "DSP", "DISK"}

- def can_pad(u:UOp) -> bool: return not any(x.op in GroupOp.UnsafePad for x in u.sparents)
+ # https://en.wikipedia.org/wiki/Identity_element
+ def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)

- END_FOR_UOP = {Ops.IF:(Ops.STORE, Ops.ENDIF), Ops.RANGE:(Ops.ASSIGN, Ops.ENDRANGE)}
+ def can_pad(u:UOp, edges:dict[UOp, UOp], visisted:dict[UOp, None]) -> bool:
+ if u.op in GroupOp.UnsafePad: return False
+ if (len(u.src) == 2 and u.src[0] in edges) or u in visisted: return True
+ visisted[u] = None
+ return all(can_pad(x.base, edges, visisted) for x in u.src)

  # With True as the default, this matches the old symbolic behavior
- def resolve(x, default:bool=True):
- if not isinstance(x, UOp): return bool(x)
- assert x.dtype is dtypes.bool, "UOp in resolve must be bool"
+ def resolve(x:UOp|bool, default:bool=True):
+ if isinstance(x, bool): return x
+ assert x.dtype == dtypes.bool, "UOp in resolve must be bool"
  # NOTE: generating the text for the exception is expensive, so we do this
  return bool(sx.vmin) if (sx:=x.simplify()).vmin == sx.vmax else default

  # smax/smin are replacements for max/min that preserve symbolic
  def _suop(lst, uop_fxn, python_fxn):
- max_uop, max_num = partition(lst, lambda x: isinstance(x, UOp))
- if len(max_uop): return functools.reduce(uop_fxn, (max_uop + [python_fxn(max_num)]) if len(max_num) else max_uop).ssimplify()
- return python_fxn(max_num)
- def smax(*lst): return _suop(lst[0] if isinstance(lst[0], (tuple, list)) else lst, UOp.maximum, max)
- def smin(*lst): return _suop(lst[0] if isinstance(lst[0], (tuple, list)) else lst, UOp.minimum, min)
+ uops, nums = partition(lst, lambda x: isinstance(x, UOp))
+ return ssimplify(functools.reduce(uop_fxn, uops + ([python_fxn(nums)] if nums else [])))
+ def smax(*lst): return _suop(argfix(*lst), UOp.maximum, max)
+ def smin(*lst): return _suop(argfix(*lst), UOp.minimum, min)

  def ssimplify(uop): return uop.ssimplify() if isinstance(uop, UOp) else uop
- def sym_infer(uop: Union[UOp, int], var_vals: Dict[UOp, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop
+ def sym_infer(uop: Union[UOp, int], var_vals: dict[UOp, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop

  # used for UOp and UPat
  def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->str:
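
Note: the reworked resolve makes the contract explicit: plain bools pass through, and a boolean UOp is only decided when its simplified value range pins down one answer, otherwise the default wins. A standalone sketch of that idea, where Sym is a hypothetical stand-in for a bool-typed UOp with computed bounds:

from dataclasses import dataclass

@dataclass
class Sym:
    vmin: bool  # hypothetical stand-in for a bool UOp's computed lower bound
    vmax: bool  # and its upper bound

def resolve(x, default=True):
    if isinstance(x, bool): return x                 # plain bools pass through
    return x.vmin if x.vmin == x.vmax else default   # undecided ranges fall back to default

assert resolve(Sym(True, True)) is True      # provably true
assert resolve(Sym(False, False)) is False   # provably false
assert resolve(Sym(False, True)) is True     # unknown: default wins
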
@@ -219,53 +220,101 @@ def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->s
  return f"{' '*d}{f'x{cx[0]}:=' * (cx[1]>1)}{rep(x)}" % srcs

  class UOpMetaClass(type):
- ucache:WeakValueDictionary[Tuple, UOp] = WeakValueDictionary()
- def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:Tuple[UOp,...]=tuple(), arg:Any=None):
- if (ret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None: return ret
- UOpMetaClass.ucache[key] = ret = super().__call__(op, dtype, src, arg)
- return ret
-
+ ucache:dict[tuple, weakref.ReferenceType[UOp]] = {}
+ def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:tuple[UOp,...]=tuple(), arg:Any=None, _buffer:Buffer|None=None):
+ if (wret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None and (ret:=wret()) is not None: return ret
+ UOpMetaClass.ucache[key] = ref = weakref.ref(created:=super().__call__(*key))
+ for s in src: s.children.add(ref)
+ # NOTE: this will soon be set by Tensor once we remove function.py
+ if (metadata:=_METADATA.get()) is not None: all_metadata[created] = metadata
+ # NOTE: this value is set by pickle when pickling a realized tensor
+ if _buffer is not None:
+ assert op is Ops.BUFFER, f"trying to set Buffer {_buffer} for {op}"
+ buffers[created] = _buffer
+ return created
+
+ # some uops map to other stuff
+ buffers:weakref.WeakKeyDictionary[UOp, Buffer] = weakref.WeakKeyDictionary() # this maps BUFFER uops to their device Buffers
+ all_metadata:weakref.WeakKeyDictionary[UOp, Metadata] = weakref.WeakKeyDictionary()
+
+ # NOTE: this should be frozen, but frozen is slower
+ @dataclass(eq=False, slots=True)
  class UOp(MathTrait, metaclass=UOpMetaClass):
- __slots__ = ["op", "dtype", "src", "arg"]
- def __init__(self, op:Ops, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None):
- # TODO: instant check rules here make debugging easier
- self.op, self.dtype, self.src, self.arg = op, dtype, src, arg
- def __reduce__(self): return UOp, (self.op, self.dtype, self.src, self.arg)
+ op:Ops
+ dtype:DType = dtypes.void
+ src:tuple[UOp, ...] = tuple()
+ arg:Any = None
+ children:set[weakref.ref[UOp]] = field(default_factory=set)
+ def __del__(self):
+ if self.op is Ops.BUFFER and (buffer:=buffers.get(self)) is not None: buffer.ref(-1)
+ if (ref:=UOpMetaClass.ucache.get(k:=(self.op, self.dtype, self.src, self.arg))) is not None:
+ for s in self.src: s.children.discard(ref)
+ del UOpMetaClass.ucache[k]
+ def __reduce__(self):
+ args = [self.op, self.dtype, self.src, self.arg]
+ if self.op is Ops.BUFFER and self.realized is not None and PICKLE_BUFFERS: args.append(self.realized)
+ return UOp, tuple(args)
  def replace(self, **kwargs) -> UOp:
- for k in kwargs: assert k in self.__slots__, f"unkown replace arg, expected one of {self.__slots__}, got {k}"
- new_args = (kwargs.get("op", self.op), kwargs.get("dtype", self.dtype), kwargs.get("src", self.src), kwargs.get("arg", self.arg))
+ new_args = (kwargs.pop("op", self.op), kwargs.pop("dtype", self.dtype), kwargs.pop("src", self.src), kwargs.pop("arg", self.arg))
+ assert len(kwargs) == 0, f"unused kwargs in replace {list(kwargs)}"
  if (self.op, self.dtype, self.src, self.arg) == new_args: return self
  return UOp(*new_args)
  @functools.cached_property
  def key(self) -> bytes:
  return hashlib.sha256(str((self.op, self.dtype, self.arg)).encode() + b"".join([s.key for s in self.src])).digest()
  def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))")
- def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is Ops.REDUCE_AXIS else self.arg
- @functools.cached_property
- def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents}}
- @functools.cached_property # parents with self
- def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None}
+ def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is Ops.REDUCE_AXIS else repr(self.arg)
+
+ @property
+ def toposort(self) -> dict[UOp, None]:
+ def _toposort(u:UOp, cache:set[UOp]):
+ if u in cache: return {}
+ nodes: dict[UOp, None] = {}
+ # NOTE: this is a lot faster than the comprehension in parents
+ for parent in u.src: nodes.update(_toposort(parent, cache))
+ nodes[u] = None
+ cache.add(u)
+ return nodes
+ return _toposort(self, cache=set())

  @functools.cached_property
- def tuplize(self:UOp) -> Tuple[int, Any, Optional[DType], Tuple]:
- return (self.op.value, self.arg, self.dtype, tuple(x.tuplize for x in self.src))
+ def tuplize(self:UOp) -> tuple[int, Any, Optional[DType], tuple]: return (self.op.value, self.arg, self.dtype, tuple(x.tuplize for x in self.src))

  # *** uop shape stuff ***

- @property
- def has_st(self) -> bool: return self.op not in {Ops.DEFINE_LOCAL, Ops.DEFINE_GLOBAL, Ops.BUFFER, Ops.CONST, Ops.DEFINE_VAR}
  @functools.cached_property
- def st(self) -> Optional[ShapeTracker]:
- if not self.has_st: return None
- if self.op in GroupOp.Buffer: return self.st_arg
- if self.op is Ops.VIEW: return self.arg
- src_sts = [x.st for x in self.src if x.st is not None]
- assert all_same([x.shape for x in src_sts]), f"UOp parents must have the same shape {self} {[x.shape for x in src_sts]}"
+ def st(self) -> ShapeTracker|None:
  from tinygrad.shape.shapetracker import ShapeTracker
- return ShapeTracker.from_shape(src_sts[0].reduce(self.axis_arg)) if self.op is Ops.REDUCE_AXIS else src_sts[0]
+ if self.op is Ops.MULTI:
+ return ShapeTracker.from_shape(
+ tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape)))
+ if self.op is Ops.BUFFER: return ShapeTracker.from_shape((self.size,))
+ # these ops define a ShapeTracker from the arg
+ if self.op is Ops.VIEW: return self.arg
+ if self.op in GroupOp.Movement: return unwrap(self.src[0].st).mop(self.op, self.arg)
+ # buffer ops return the ShapeTracker from sources
+ if self.op in GroupOp.Buffer: return vsrc[0] if len(vsrc:=[x.st for x in self.src if x.op is Ops.VIEW]) != 0 else None
+ if not (src_sts := [x.st for x in self.src if x.st is not None]): return None
+ assert all_same([x.shape for x in src_sts]), f"UOp sources must have the same shape {self} {[x.shape for x in src_sts]}"
+ if self.op in {Ops.BITCAST, Ops.BUFFER_VIEW}:
+ shape = src_sts[0].shape
+ if self.dtype.itemsize != (input_sz:=self.src[0].dtype.itemsize): shape = shape[:-1]+((shape[-1]*input_sz) // self.dtype.itemsize,)
+ # only reduce ops are allowed to change shape, everything else derives shape from sources
+ elif self.op in {Ops.REDUCE_AXIS, Ops.WMMA}: shape = src_sts[0].reduce(self.axis_arg)
+ else: shape = src_sts[0].shape
+ return ShapeTracker.from_shape(shape)
+
  @functools.cached_property
- def full_shape(self) -> Tuple[sint, ...]:
- return self.arg.shape if self.op is Ops.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
+ def full_shape(self) -> tuple[sint, ...]:
+ if self.op is Ops.VIEW: return self.shape
+ # TODO: this should check if st is None, it cannot because local reduce has implicit movement ops
+ return tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {Ops.DEFINE_GLOBAL,Ops.DEFINE_LOCAL} \
+ # TODO: this exists because wmma creates consts without ShapeTracker in the AST, there's probably a way to fix this
+ and not (x.op is Ops.CONST and x.st is None)]))
+ @property
+ def shape(self) -> tuple[sint, ...]: return unwrap(self.st).shape
+ @property
+ def size(self) -> int: return self.arg[1] if self.op is Ops.BUFFER else unwrap(self.st).size

  # *** uop evaluation ***

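Note: the ucache rework above replaces WeakValueDictionary with a plain dict of weakref.ref plus explicit cleanup in __del__, keeping hash-consing (structurally identical UOps are the same Python object) while still letting dead graph nodes be collected. A minimal sketch of that pattern; Interned here is illustrative, not the real class:

import weakref

class Interned:
    _cache: dict = {}  # key -> weakref.ref to the single live instance

    def __new__(cls, *key):
        # a cache hit only counts if the weakref is still alive
        if (ref := cls._cache.get(key)) is not None and (obj := ref()) is not None:
            return obj
        obj = super().__new__(cls)
        obj.key = key
        cls._cache[key] = weakref.ref(obj)
        return obj

    def __del__(self):
        # drop the slot only if it still points at a dead object (i.e. us)
        if (ref := self._cache.get(self.key)) is not None and ref() is None:
            del self._cache[self.key]

a, b = Interned(1, "x"), Interned(1, "x")
assert a is b  # structural equality becomes object identity
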
@@ -282,34 +331,44 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
  def __bool__(self): return self._eval((dtypes.bool,), bool)
  def __int__(self): return self._eval(dtypes.ints, int)
  def __float__(self): return self._eval(dtypes.floats, float)
- def substitute(self, dvars:Dict[UOp, UOp]):
+ def substitute(self, dvars:dict[UOp, UOp]):
  with Context(TRACK_MATCH_STATS=0):
- return graph_rewrite(self, _substitute, dvars)
+ return graph_rewrite(self, _substitute, dvars, bottom_up=True)

  # *** uop syntactic sugar ***

  @property
  def st_arg(self) -> ShapeTracker:
  assert self.op in GroupOp.Buffer, f"st_arg called on {self.op}"
- ret = self.src[0 if self.op is Ops.VALID else 1]
- assert ret.op is Ops.VIEW, f"st_arg trying to return {ret}"
- return ret.arg
+ return unwrap(self.st)
+ @property
+ def const_arg(self) -> ConstType:
+ match self.base.op:
+ case Ops.CONST: ret = self.base.arg
+ case op: raise AssertionError(f"const_arg called on {op}")
+ assert isinstance(ret, get_args(ConstType)), f"const_arg trying to return {ret}"
+ return ret
  @property
- def axis_arg(self) -> Tuple[int, ...]:
+ def axis_arg(self) -> tuple[int, ...]:
  assert self.op in {Ops.REDUCE_AXIS, Ops.WMMA}, f"axis_arg called on {self.op}"
  ret = self.arg[1] if self.op is Ops.REDUCE_AXIS else self.arg[7]
  assert isinstance(ret, tuple) and all(isinstance(x, int) for x in ret), f"axis_arg trying to return {ret}"
  return ret
  def sink(self, *srcs:UOp): return UOp(Ops.SINK, dtypes.void, (self,)+srcs)
- def index(self, idx:UOp, valid:Optional[UOp]=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
- def const_like(self, b:ConstLike): return UOp.const(self.dtype, b)
+ def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
+ def index(self, idx:UOp, valid:UOp|None=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
+ def const_like(self, b:ConstLike):
+ # constants can optionally have a DEVICE source
+ if self._device is None: return UOp.const(self.dtype, b)
+ if isinstance(self.device, tuple): return UOp.multi(*[UOp.metaop(Ops.CONST, self.shape, self.dtype, d, b) for d in self.device], axis=None)
+ return UOp.metaop(Ops.CONST, self.shape, self.dtype, self.device, b)
  def broadcast(self, count:int):
  assert self.dtype.count == 1
  if count == 1: return self
  return UOp(Ops.VECTORIZE, self.dtype.vec(count), (self,)*count)
  def cast(self, dtype:DType): return UOp(Ops.CAST, dtype, (self,))
  def bitcast(self, dtype:DType): return UOp(Ops.BITCAST, dtype, (self,))
- def gep(self, i:Union[Tuple[int, ...], int]):
+ def gep(self, i:Union[tuple[int, ...], int]):
  if isinstance(i, int):
  # NOTE: these are just shortcuts to not have to create and fold later
  if self.op is Ops.VECTORIZE: return self.src[i]
@@ -322,42 +381,185 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
  def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
  def alu(self, arg, *src:UOp):
  out_dtype = (self, *src)[-1].dtype
- if arg in {Ops.CMPLT, Ops.CMPNE} and out_dtype is not None:
- out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
+ if arg in {Ops.CMPLT, Ops.CMPNE}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
  return UOp(arg, out_dtype, (self,)+src)
  @staticmethod
  def const(dtype:DType, b:ConstLike):
  if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
  if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
- return UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b)
+ return UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype))
+ def valid(self, st:ShapeTracker):
+ assert self.op in {Ops.CONST, Ops.DEFINE_VAR}, f"can only create VALID from a constant, got {self.op}"
+ from tinygrad.shape.shapetracker import ShapeTracker
+ # NOTE: only VALID has a masked ShapeTracker, the CONST operands are unmasked
+ unmasked_st = ShapeTracker.from_shape(()).reshape((1,)*len(st.shape)).expand(st.shape).to_uop()
+ return UOp(Ops.VALID, dtypes.bool, (st.to_uop(),)).where(self.replace(src=(unmasked_st,)), UOp.const(self.dtype, 0).replace(src=(unmasked_st,)))
  @staticmethod
- def range(dtype:DType, start:ConstType|UOp, end:ConstType|UOp, idx:int):
- return UOp(Ops.RANGE, dtype=dtype, src=(UOp.const(dtype, start) if not isinstance(start, UOp) else start,
- UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=(idx, False))
- def r(self, op, axis): return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis))
+ def range(dtype:DType, start:sint, end:sint, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(start), sint_to_uop(end)), arg=idx)
+ def _reduce_op(self, op:Ops, axis:tuple[int, ...]):
+ axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
+ return self if len(axis) == 0 else UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis))
+ def r(self, op:Ops, axis:tuple[int, ...]) -> UOp:
+ new_shape = unwrap(self.st).reduce(axis)
+
+ # TODO: can we split symbolic shape if the reduce axis is not symbolic?
+ # TODO: this shouldn't be here, it belongs in scheduler! that's why it broke multi
+ if not SPLIT_REDUCEOP or isinstance(self._device, tuple) or not all_int(self.shape) or (0 in self.shape) or \
+ prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
+ return self._reduce_op(op, axis)
+
+ # if there are few globals, make some reduces into globals by splitting into two kernels
+ # cap output buffer to 2**22: heuristic number of global outputs to achieve max occupancy with enough locals+upcasts for gemm
+ # ~2**10 should be enough if GROUP is used
+ # 256 split maximum should be "negligible reduce" for low prod(new_shape), 8 split minimum.
+ # split is moved to the end to provide maximum locality for the second phase reduce.
+ self_real_strides = unwrap(self.st).real_strides(ignore_valid=True)
+ split_candidates = [(i, x) for i in axis for x in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(new_shape)),8-1,-1)
+ if self.shape[i] % x == 0 and self_real_strides[i] != 0]
+ if not split_candidates: return self._reduce_op(op, axis)
+ dim_to_split, divisor = split_candidates[0]
+ splitted_shape = self.shape[:dim_to_split] + (divisor,) + (self.shape[dim_to_split]//divisor,) + self.shape[dim_to_split+1:]
+ splitted = self.reshape(splitted_shape).permute(tuple([x for x in range(len(splitted_shape)) if x != dim_to_split]+[dim_to_split]))
+ if DEBUG >= 3: print(f"split {divisor}: {self.shape} -> {splitted.shape} -> {new_shape}")
+ return splitted._reduce_op(op, axis)._reduce_op(op, (len(new_shape),)).reshape(new_shape) # reduce original axes, then split
  def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
- def contiguous(self): return UOp(Ops.CONTIGUOUS, self.dtype, (self,))
+ def contiguous(self): return self.alu(Ops.CONTIGUOUS)
+ def contiguous_backward(self): return self.alu(Ops.CONTIGUOUS_BACKWARD)
+
+ # *** from MultiLazyBuffer ***
+
+ def multi(self, *more:UOp, axis:int|None, real:tuple[bool,...]|None=None):
+ parents = (self,)+more
+ assert all_same([x.dtype for x in parents]), "multi parents must have the same dtype"
+ return UOp(Ops.MULTI, self.dtype, parents, (axis, real if real is not None else (True,)*len(parents)))
+
+ @property
+ def bounds(self):
+ if self.axis is None: raise RuntimeError("bounds is not defined when axis is None")
+ return tuple(itertools.pairwise(itertools.accumulate([lb.shape[self.axis] for lb in self.src], initial=0)))
+
+ @functools.cached_property
+ def axis(self) -> Optional[int]:
+ if self.op is Ops.MULTI: return self.arg[0]
+ # NOTE: they all have to share an axis, we always choose [-1]
+ if self.op in GroupOp.ALU: return axes[-1] if (axes := dedup([x.axis for x in self.src if x.axis is not None])) else None
+ src_axis = self.src[0].axis
+ if self.op is Ops.REDUCE_AXIS: return None if src_axis is not None and src_axis in self.arg[1] else src_axis
+ if self.op is Ops.RESHAPE:
+ if src_axis is None: return None
+ arg_acc:list[sint] = list(itertools.accumulate(self.arg, operator.mul, initial=1))
+ # new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards
+ # TODO: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1?
+ return len(arg_acc) - arg_acc[::-1].index(prod(self.src[0].shape[:src_axis])) - 1
+ if self.op is Ops.PERMUTE: return self.arg.index(src_axis) if src_axis is not None else None
+ return src_axis
+
+ @property
+ def real(self):
+ assert self.op is Ops.MULTI
+ return self.arg[1]
+
  @property
- def is_contiguous_base(self): return self.op is Ops.CONTIGUOUS and not (self.src[0].base.op is Ops.VIEW and len(self.src[0].base.src) == 2)
+ def real_lbs(self): return [lb for lb,r in zip(self.src, self.real) if r]
+
+ def shard(self, devices:tuple[str, ...], axis:Optional[int]=None) -> UOp:
+ if axis is None: lbs = [self] * len(devices)
+ else:
+ if self.shape[axis] % len(devices) != 0: raise RuntimeError(f"multi axis uneven: {self.shape[axis]=} {axis=} {len(devices)=}")
+ # NOTE: this works for both even shards and uneven shards
+ sz = self.shape[axis] // len(devices)
+ sizes = [max(0, min(sz, self.shape[axis] - sz*i)) for i in range(len(devices))]
+ lbs = []
+ for sz,off in zip(sizes, itertools.accumulate(sizes, initial=0)):
+ lbs.append(self.shrink(tuple((0,s) if i != axis else (off,off+sz) for i,s in enumerate(self.shape))))
+ sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(lbs, devices)]
+ return UOp.multi(*[lb.contiguous() for lb in sharded_lbs], axis=axis)
+
+ # *** from LazyBuffer ***
+
+ @staticmethod
+ def metaop(op:Ops, shape:tuple[sint, ...], dtype:DType, device:str, arg=None) -> UOp:
+ from tinygrad.shape.shapetracker import ShapeTracker
+ # Tensor const is CONST(VIEW(DEVICE)) -> RESHAPE -> EXPAND
+ if op is Ops.CONST:
+ assert isinstance(arg, get_args(ConstType)), f"trying to create CONST with {arg=}"
+ return UOp.const(dtype, unwrap(arg)).replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),),
+ ShapeTracker.from_shape(())),)).reshape((1,)*len(shape)).expand(shape)
+ # Tensor variable binding is BIND(VAR(VIEW(DEVICE)), CONST(VIEW(DEVICE)))
+ if op is Ops.BIND:
+ var, val = arg.unbind()
+ return var.replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),), ShapeTracker.from_shape(shape)),)).bind(val)
+ # otherwise it's just a VIEW(BUFFER)
+ return UOp(Ops.VIEW, dtype, (UOp.new_buffer(device, (st:=ShapeTracker.from_shape(shape)).size, dtype),), st)
+ def copy_to_device(self, device:str|tuple[str, ...], clone:bool=False) -> UOp:
+ # if it's a shrink, do the shrink before the copy with CONTIGUOUS
+ if prod(self.shape) < prod(self.base.shape): return self.contiguous().copy_to_device(device)
+ # COPY is COPY(DEVICE, copyin.base) -> VIEW(copyin.st)
+ ret = UOp(Ops.COPY, self.base.dtype, (UOp(Ops.DEVICE, arg=device), self.base), clone)
+ op_arg = []
+ mop = self
+ while mop is not self.base:
+ op_arg.append((mop.op, mop.arg))
+ mop = mop.src[0]
+ for op,arg in reversed(op_arg): ret = UOp(op, ret.dtype, (ret,), arg)
+ return ret
+ def clone(self) -> UOp: return self.copy_to_device(self.device, clone=True)
+ @property
+ def metadata(self): return all_metadata.get(self, None)

  # *** uop movement ops ***

  @property
- def base(self): return self.src[0] if self.op is Ops.VIEW and len(self.src) == 1 else self
- def view(self, st:ShapeTracker):
- assert self.op is not Ops.STORE, "VIEW of STORE is invalid, STORE is always base"
- return self if self.st is None or self.st == st else UOp(Ops.VIEW, self.dtype, (self,), st)
- def reshape(self, arg:Tuple[sint, ...]): return self.view(unwrap(self.st).reshape(arg))
+ def base(self) -> UOp:
+ if self.op in GroupOp.Movement: return self.src[0].base
+ return self.src[0].base if self.op is Ops.VIEW and len(self.src) == 1 else self
+ def view(self, new_st:ShapeTracker) -> UOp: return UOp(Ops.VIEW, self.dtype, (self.base,), new_st)
+
+ def _mop(self, op:Ops, arg):
+ ret = UOp(op, self.dtype, (self,), arg)
+ if self.st == ret.st: return self # ignore NOOPs, also check ret.st
+ return ret
+
+ def reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg)
+ def pad(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.PAD, arg)
+ def expand(self, arg:tuple[sint, ...]): return self._mop(Ops.EXPAND, arg)
+ def permute(self, arg:tuple[sint, ...]): return self._mop(Ops.PERMUTE, arg)
+ def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg)
+ def flip(self, arg:tuple[bool, ...]): return self._mop(Ops.FLIP, arg)

  # *** uop Buffer stuff ***

+ buffer_num = itertools.count(0)
  @staticmethod
- def new_buffer(device:str, size:int, dtype:DType, num=-1): return UOp(Ops.BUFFER, dtype.ptr(), (), (num, (device, size, dtype)))
-
+ def new_buffer(device:str, size:int, dtype:DType): return UOp(Ops.BUFFER, dtype, (UOp(Ops.DEVICE, arg=device),), (next(UOp.buffer_num), size))
+ @property
+ def device(self) -> str|tuple[str, ...]: return cast(str|tuple[str, ...], unwrap(self._device))
+ @functools.cached_property
+ def _device(self) -> Optional[str|tuple[str, ...]]:
+ if self.op is Ops.DEVICE: return self.arg
+ if self.op is Ops.MULTI: return tuple(cast(str, x.device) for x in self.src)
+ return dsrcs[0]._device if len(dsrcs:=[x for x in self.src if x._device is not None]) != 0 else None
  @property
  def buf_uop(self) -> UOp:
- assert self.op in {*GroupOp.Buffer, Ops.ASSIGN, Ops.VIEW} and self.src[0].op is Ops.BUFFER, f"buf_uop called on {self.op}"
- return self.src[0]
+ if self.base.op is Ops.BUFFER: return self.base
+ assert self.base.op in {*GroupOp.Buffer, Ops.ASSIGN, Ops.VIEW}, f"buf_uop called on {self.op}"
+ return self.src[0].buf_uop
+ @property
+ def buffer(self) -> Buffer:
+ if self.op is Ops.VIEW:
+ assert unwrap(self.st).contiguous, "VIEW only works here if it's contiguous"
+ return self.src[0].buffer
+ assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}"
+ if (cret:=buffers.get(self)) is not None: return cret
+ from tinygrad.device import Buffer
+ assert isinstance(self.device, str), f"buffer not supported on multi {self.device}"
+ buffers[self] = ret = Buffer(self.device, self.size, self.dtype if isinstance(self.dtype, ImageDType) else self.dtype.base)
+ return ret
+ @property
+ def realized(self) -> Optional[Buffer]: return self.buffer if self.op is Ops.BUFFER else None
+ @property
+ def is_realized(self) -> bool:
+ return all(x.base.realized is not None for x in self.base.real_lbs) if self.base.op is Ops.MULTI else self.base.realized is not None

  # *** uop Variable stuff ***

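Note: shard above slices a tensor into per-device shrinks along axis; the arithmetic is a size list plus a running sum of offsets via itertools.accumulate. A standalone illustration with a hypothetical shard_slices helper (this call site raises on uneven splits, so an even total is assumed):

import itertools

def shard_slices(total: int, n: int) -> list[tuple[int, int]]:
    sz = total // n
    # same size formula as shard(); with an even split every entry is just sz
    sizes = [max(0, min(sz, total - sz*i)) for i in range(n)]
    # running-sum offsets paired with their ends give the per-device shrink bounds
    return [(off, off + s) for s, off in zip(sizes, itertools.accumulate(sizes, initial=0))]

assert shard_slices(8, 4) == [(0, 2), (2, 4), (4, 6), (6, 8)]
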
@@ -373,18 +575,18 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
  assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
  assert self.arg[1] <= val and val <= self.arg[2], f"bind {val} not in range [{self.arg[1]}, {self.arg[2]}]"
  return UOp(Ops.BIND, self.dtype, (self, self.const_like(val)))
- def unbind(self) -> Tuple[Variable, int]:
+ def unbind(self) -> tuple[Variable, int]:
  assert self.op is Ops.BIND and self.src[0].op is Ops.DEFINE_VAR and self.src[1].op is Ops.CONST, f"can't unbind {self}"
  return self.src[0], self.src[1].arg
  @property
  def val(self) -> int: return self.unbind()[1]
- def vars(self) -> Set[UOp]:
- bound_vars = set([x for x in self.sparents if x.op is Ops.BIND and x.src[0].op is Ops.DEFINE_VAR])
+ def vars(self) -> set[UOp]:
+ bound_vars = set([x for x in self.toposort if x.op is Ops.BIND and x.src[0].op is Ops.DEFINE_VAR])
  bound_var_base = set(x.src[0] for x in bound_vars)
- all_vars = set([x for x in self.sparents if x.op is Ops.DEFINE_VAR])
+ all_vars = set([x for x in self.toposort if x.op is Ops.DEFINE_VAR])
  return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
- def variables(self) -> List[Variable]:
- st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in GroupOp.Buffer]
+ def variables(self) -> list[Variable]:
+ st_vars: list[set[Variable]] = [x.st_arg.vars() for x in self.toposort if x.op in GroupOp.Buffer]
  return sorted(set.union(*st_vars, [x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg)

  # *** uop symbolic stuff ***
@@ -396,7 +598,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
  if self.op is Ops.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
  if self.op is Ops.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1
  return 1
- def divides(self, v) -> Optional[UOp]:
+ def divides(self, v) -> UOp|None:
  if v==1: return self
  if self.op is Ops.CONST: return self.const_like(self.arg//v) if self.arg%v == 0 else None
  if self.op is Ops.VCONST: return self.const_like(tuple(x//v for x in self.arg)) if all(x%v == 0 for x in self.arg) else None
@@ -410,15 +612,23 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
  @property
  def vmax(self) -> ConstType: return self._min_max[1]
  @functools.cached_property
- def _min_max(self) -> Tuple[ConstType, ConstType]:
+ def _min_max(self) -> tuple[ConstType, ConstType]:
  if self.op in GroupOp.Binary and not dtypes.is_float(self.dtype):
  (s0_vmin, s0_vmax), (s1_vmin, s1_vmax) = self.src[0]._min_max, self.src[1]._min_max
  if self.op is Ops.ADD: return s0_vmin+s1_vmin, s0_vmax+s1_vmax
  if self.op is Ops.MUL: return min(vals:=(s0_vmin*s1_vmin, s0_vmin*s1_vmax, s0_vmax*s1_vmin, s0_vmax*s1_vmax)), max(vals)
- if self.op is Ops.MOD and s1_vmin > 0: return 0, s1_vmax-1
- if self.op is Ops.IDIV and s1_vmin == s1_vmax: # min/max are equal in a CONST
- if s1_vmin > 0: return s0_vmin//s1_vmin, s0_vmax//s1_vmin
- if s1_vmin < 0 and s0_vmin >= 0: return -(s0_vmax//-s1_vmin), -(s0_vmin//-s1_vmin)
+ # SHL/SHR on consts only
+ if self.op is Ops.SHL and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] << t[2], t[1] << t[2]
+ if self.op is Ops.SHR and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] >> t[2], t[1] >> t[2]
+ if self.op is Ops.MOD and s1_vmin > 0:
+ return (0, s1_vmax-1) if s0_vmin >= 0 else (-(s1_vmax-1), s1_vmax-1)
+ if self.op is Ops.IDIV:
+ if s1_vmin == s1_vmax: # min/max are equal in a CONST
+ if s1_vmin > 0: return s0_vmin//s1_vmin, s0_vmax//s1_vmin
+ if s1_vmin < 0 and s0_vmin >= 0: return -(s0_vmax//-s1_vmin), -(s0_vmin//-s1_vmin)
+ # don't know exact bounds, but know the sign
+ if (s0_vmax <= 0 and s1_vmin < 0) or (s0_vmin >= 0 and s1_vmin > 0): return 0, dtypes.max(self.dtype)
+ if (s0_vmax <= 0 and s1_vmin > 0) or (s0_vmin >= 0 and s1_vmin < 0): return dtypes.min(self.dtype), 0
  if self.op is Ops.MAX: return max(s0_vmin, s1_vmin), max(s0_vmax, s1_vmax)
  if self.op is Ops.CMPLT: return (s0_vmax<s1_vmin, s0_vmin<s1_vmax)
  if self.op is Ops.CMPNE: return ((s0_vmax < s1_vmin) or (s1_vmax < s0_vmin), not (s0_vmin == s0_vmax == s1_vmin == s1_vmax))
@@ -431,9 +641,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
  if self.op is Ops.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2]
  if self.op is Ops.RANGE: return self.src[0].vmin, (self.src[1]-1).vmax
  if self.op is Ops.BIND: return self.src[0]._min_max # ignore the bound value
- if self.op in {Ops.EXPAND, Ops.VECTORIZE}: return min(x.vmin for x in self.src), max(x.vmax for x in self.src)
+ if self.op in {Ops.UNROLL, Ops.VECTORIZE}: return min(x.vmin for x in self.src), max(x.vmax for x in self.src)
  # TODO: UOps.SPECIAL is UOps.DEFINE_VAR
- if self.op is Ops.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else dtypes.max(self.dtype)
+ if self.op is Ops.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else self.arg[1].vmax
  if self.op is Ops.CONST: return self.arg, self.arg
  if self.op is Ops.VCONST: return (min(self.arg), max(self.arg))
  return dtypes.min(self.dtype), dtypes.max(self.dtype)
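
Note: _min_max is conservative interval arithmetic over the operands' (vmin, vmax) ranges; this hunk extends MOD to negative dividends and lets IDIV keep sign information when exact bounds are unknown. The core rules restated as a self-contained sketch:

def add_bounds(a, b):
    # intervals a=(amin, amax), b=(bmin, bmax): addition is componentwise
    return a[0] + b[0], a[1] + b[1]

def mul_bounds(a, b):
    # either operand may be negative, so the extremes are among the four corners
    vals = (a[0]*b[0], a[0]*b[1], a[1]*b[0], a[1]*b[1])
    return min(vals), max(vals)

def mod_bounds(a, b):
    # positive divisor: trunc-style mod keeps the dividend's sign (the new rule above)
    assert b[0] > 0
    return (0, b[1]-1) if a[0] >= 0 else (-(b[1]-1), b[1]-1)

assert add_bounds((1, 3), (10, 20)) == (11, 23)
assert mul_bounds((-2, 3), (4, 5)) == (-10, 15)
assert mod_bounds((-7, 7), (1, 4)) == (-3, 3)
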
@@ -441,11 +651,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
  @functools.cached_property
  def _sym_fxn(self):
  sself = self.simplify()
- varnames = tuple(x.arg[0] for x in sself.sparents if x.op is Ops.DEFINE_VAR)
+ varnames = tuple(x.arg[0] for x in sself.toposort if x.op is Ops.DEFINE_VAR)
  # TODO: sanitize varnames, or don't use naked eval while staying fast
  return eval("lambda "+','.join(varnames)+": "+sself.render()), varnames # pylint: disable=eval-used

- def sym_infer(self, var_vals:Dict[UOp, int]):
+ def sym_infer(self, var_vals:dict[UOp, int]):
  fxn, varnames = self._sym_fxn
  return fxn(**{k.arg[0]:v for k,v in var_vals.items() if k.arg[0] in varnames})

@@ -456,25 +666,22 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
  @dataclass(frozen=True)
  class KernelInfo:
  local_dims: int = 0 # number of local dimensions (this is remapping RANGE to SPECIAL)
- upcasted: int = 0 # count that are upcasted (this is remapping RANGE to EXPAND)
+ upcasted: int = 0 # count that are upcasted (this is remapping RANGE to UNROLL)
  dont_use_locals: bool = False # don't use local indexing

  # ***** ops in python *****

- def hook_overflow(dv, fxn):
- def wfxn(*args):
- try: return fxn(*args)
- except OverflowError: return dv
- return wfxn
+ def safe_exp2(x):
+ try: return 2 ** x
+ except OverflowError: return math.inf

- python_alu: Dict[Ops, Callable] = {
- Ops.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, Ops.EXP2: hook_overflow(math.inf, lambda x: 2**x),
+ python_alu: dict[Ops, Callable] = {
+ Ops.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, Ops.EXP2: safe_exp2,
  Ops.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, Ops.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
  Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan,
- Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul,
- Ops.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], Ops.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else x*math.inf,
- Ops.MAX: max, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt, Ops.XOR: operator.xor,
- Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift,
+ Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt,
+ Ops.XOR: operator.xor, Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift, Ops.MAX: max,
+ Ops.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], Ops.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else 0,
  Ops.MULACC: lambda x,y,z: (x*y)+z, Ops.WHERE: lambda x,y,z: y if x else z}

  def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
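
Note: two behavior changes land in this hunk: EXP2 overflow now saturates to math.inf through safe_exp2 instead of the removed hook_overflow wrapper, and IDIV by zero now returns 0 rather than x*math.inf. A quick check of the overflow path in plain Python:

import math

def safe_exp2(x):
    # float 2**x raises OverflowError past the float range; saturate instead
    try: return 2 ** x
    except OverflowError: return math.inf

assert safe_exp2(10) == 1024
assert safe_exp2(10000.0) == math.inf  # 2 ** 10000.0 would raise without the guard
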
@@ -485,63 +692,32 @@ def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
485
692
 
486
693
  # ***** uop helpers *****
487
694
 
488
- def print_uops(uops:List[UOp]):
695
+ def print_uops(uops:list[UOp]):
489
696
  for i,u in enumerate(uops):
490
697
  formatted_parents = [(uops.index(x) if x.op is not Ops.CONST else f"{x.arg}") if x in uops else "--" for x in u.src]
491
698
  print(f"{i:4d} {str(u.op):20s}: {str(u.dtype):30s} " f"{str(formatted_parents):32s} {u.arg}")
492
699
 
493
- def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
494
- flops: sint = 0
495
- mem: sint = 0
496
- mults: sint = 1
497
- mult_stack: List[sint] = []
498
- dont_count: Set[UOp] = set()
499
- if ignore_indexing:
500
- for u in uops:
501
- if u.op in {Ops.LOAD, Ops.STORE}:
502
- dont_count = dont_count.union(u.src[0].sparents)
503
- if len(u.src) > 2: dont_count = dont_count.union(u.src[2].sparents)
504
- elif u.op is Ops.IF:
505
- dont_count = dont_count.union(u.src[0].sparents)
506
- for u in uops:
507
- if u.op is Ops.RANGE:
508
- mult_stack.append(mults)
509
- mults *= (u.src[1] - u.src[0]).ssimplify()
510
- elif u.op is Ops.ENDRANGE:
511
- mults = mult_stack.pop(-1)
512
- elif u.op is Ops.SPECIAL:
513
- mults *= u.arg[1] # NOTE: we don't push to the mult_stack here, you can't end these
514
- elif u.op is Ops.LOAD:
515
- mem += u.dtype.itemsize * mults
516
- elif u.op is Ops.STORE:
517
- mem += u.src[1].dtype.itemsize * mults
518
- elif u.op in GroupOp.ALU and u not in dont_count:
519
- flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
520
- elif u.op is Ops.WMMA and u not in dont_count:
521
- flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
522
- return flops, mem
523
-
524
700
  # ***** pattern matcher *****
525
701
 
526
- def get_location() -> Tuple[str, int]:
702
+ def get_location() -> tuple[str, int]:
527
703
  frm = sys._getframe(1)
528
704
  # find the real frame in the file that has the UPat, TODO: is there a better way to do this?
529
- while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "uopgraph.py", "schedule.py",
530
- "lowerer.py", "cstyle.py"}:
705
+ while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "rewriter.py", "schedule.py", "multi.py",
706
+ "lowerer.py", "cstyle.py", "linearize.py"}:
531
707
  frm = frm.f_back
532
708
  return frm.f_code.co_filename, frm.f_lineno
533
709
  @functools.lru_cache(None)
534
- def lines(fn) -> List[str]:
710
+ def lines(fn) -> list[str]:
535
711
  with open(fn) as f: return f.readlines()
536
712
 
537
713
  class UPat(MathTrait):
538
- __slots__ = ["op", "dtype", "arg", "name", "src"]
539
- def __init__(self, op:Optional[Union[Ops, Tuple[Ops, ...], Set[Ops]]]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None,
540
- src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None, arg:Any=None,
541
- name:Optional[str]=None, allow_any_len:bool=False, location=None, custom_early_reject:Optional[Set[Ops]]=None):
714
+ __slots__ = ("op", "dtype", "arg", "name", "src")
715
+ def __init__(self, op:Optional[Union[Ops, tuple[Ops, ...], set[Ops]]]=None, dtype:Optional[Union[DType, tuple[DType, ...]]]=None,
716
+ src:Optional[Union[tuple[UPat, ...], list[UPat], UPat]]=None, arg:Any=None,
717
+ name:Optional[str]=None, allow_any_len:bool=False, location=None, custom_early_reject:Optional[set[Ops]]=None):
542
718
  assert op is None or isinstance(op, Ops) or isinstance(op, tuple) or isinstance(op, set), "op must be Ops or tuple of Ops"
543
- self.op: Optional[Tuple[Ops, ...]] = (op,) if isinstance(op, Ops) else (tuple(op) if isinstance(op, set) else op)
544
- self.dtype: Optional[Tuple[DType, ...]] = (dtype,) if isinstance(dtype, DType) else dtype
719
+ self.op: Optional[tuple[Ops, ...]] = (op,) if isinstance(op, Ops) else (tuple(op) if isinstance(op, set) else op)
720
+ self.dtype: Optional[tuple[DType, ...]] = (dtype,) if isinstance(dtype, DType) else dtype
545
721
  self.arg, self.name, self._in_src, self.custom_early_reject = arg, name, src, custom_early_reject
546
722
  self.src: Any = None
547
723
  assert self.name != "ctx", "UPat can't be named ctx"
@@ -568,13 +744,12 @@ class UPat(MathTrait):
568
744
 
569
745
  @staticmethod
570
746
  @functools.lru_cache(None)
571
- def var(name:Optional[str]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None): return UPat(dtype=dtype, name=name)
747
+ def var(name:Optional[str]=None, dtype:Optional[Union[DType, tuple[DType, ...]]]=None): return UPat(dtype=dtype, name=name)
572
748
  @staticmethod
573
749
  @functools.lru_cache(None)
574
- def cvar(name:Optional[str]=None, dtype:Optional[DType]=None, vec=True):
575
- return UPat((Ops.CONST, Ops.VCONST) if vec else Ops.CONST, dtype=dtype, name=name)
750
+ def cvar(name:Optional[str]=None, dtype:Optional[DType]=None, vec=True): return UPat((Ops.CONST,Ops.VCONST) if vec else Ops.CONST, dtype, name=name)
576
751
  @staticmethod
577
- def const(dtype:Optional[Union[DType, Tuple[DType, ...]]], b:ConstType): return UPat(Ops.CONST, dtype=dtype, arg=b)
752
+ def const(dtype:Optional[Union[DType, tuple[DType, ...]]], b:ConstType): return UPat(Ops.CONST, dtype=dtype, arg=b)
578
753
 
579
754
  # copied from UOp
580
755
  def index(self, idx:UPat, valid:Optional[UPat]=None): return UPat(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
@@ -589,7 +764,7 @@ class UPat(MathTrait):
589
764
  def const_like(self, b:ConstLike): return UPat.const(self.dtype, cast(ConstType, b))
590
765
  def alu(self, op:Ops, *src:UPat):
591
766
  asrc = (self,)+src
592
- return UPat(op, None if op in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].dtype, list(asrc) if op in GroupOp.Commutative else asrc)
767
+ return UPat(op, dtypes.bool if op in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].dtype, list(asrc) if op in GroupOp.Commutative else asrc)
593
768
 
594
769
  def printable(self:UPat) -> str:
595
770
  try: return lines(self.location[0])[self.location[1]-1].strip()
@@ -602,14 +777,14 @@ class UPat(MathTrait):
602
777
  set(x.dtype) if x.dtype else None, x.allowed_len == 0, "[%s]" if x.src and len(x.src)>1 else "(%s)")
603
778
  return pretty_print(self, rep, srcfn=lambda x:None if x.src is None else [next(x.src[0])] if isinstance(x.src[0], itertools.repeat) else x.src[0])
604
779
 
605
- def match(self:UPat, uop:UOp, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
780
+ def match(self:UPat, uop:UOp, store:dict[str, UOp]) -> list[dict[str, UOp]]:
606
781
  if (self.op is not None and uop.op not in self.op) or \
607
782
  (self.name is not None and store.setdefault(self.name, uop) is not uop) or \
608
783
  (self.dtype is not None and uop.dtype not in self.dtype and uop.dtype.scalar() not in self.dtype) or \
609
784
  (self.arg is not None and self.arg != uop.arg) or \
610
785
  (self.allowed_len != -1 and len(uop.src) != self.allowed_len): return []
611
786
  if self.src is None: return [store]
612
- res: List[Dict[str, UOp]] = []
787
+ res: list[dict[str, UOp]] = []
613
788
  for vp in self.src:
614
789
  stores, new_stores = [store.copy()], []
615
790
  for uu, vv in zip(uop.src, vp):
@@ -619,13 +794,11 @@ class UPat(MathTrait):
  return res
 
  class UPatAny(UPat):
- def match(self:UPat, uop:UOp, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
- ret = []
- for x in self.src[0]:
- if (match:=x.match(uop, store.copy())): ret.extend(match)
- return ret
+ def match(self:UPat, uop:UOp, store:dict[str, UOp]) -> list[dict[str, UOp]]:
+ matches = [x.match(uop, store.copy()) for x in self.src[0]]
+ return flatten([x for x in matches if x is not None])
 
- def deconstruct_function(fxn:Callable) -> Tuple:
+ def deconstruct_function(fxn:Callable) -> tuple:
  new_globals = {k:v for k,v in fxn.__globals__.items() if k in fxn.__code__.co_names}
  for co in fxn.__code__.co_consts:
  if isinstance(co, types.CodeType): new_globals.update({k:v for k,v in fxn.__globals__.items() if k in co.co_names})
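`match` threads the `store` dict of name bindings through the source patterns and returns one store per way the pattern fits, or an empty list on failure. A small sketch (the exact printed repr will vary):

```python
from tinygrad.ops import UOp, UPat, Ops
from tinygrad import dtypes

u = UOp.const(dtypes.int, 2) * UOp.const(dtypes.int, 3)
pat = UPat(Ops.MUL, src=(UPat.cvar("a"), UPat.cvar("b")))
print(pat.match(u, {}))  # one store: [{'a': <CONST 2>, 'b': <CONST 3>}]
```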
@@ -635,10 +808,10 @@ def deconstruct_function(fxn:Callable) -> Tuple:
  return pickle.loads(pickle.dumps(ret)) if getenv("TEST_PICKLE") else ret
 
  class PatternMatcher:
- def __init__(self, patterns:List[Tuple[UPat, Callable]]):
+ def __init__(self, patterns:list[tuple[UPat, Callable]]):
  self.patterns = patterns
  # NOTE: use of DefaultDict here is very dangerous! all keys will live for the lifetime of the PatternMatcher!
- self.pdict: Dict[Ops, List[Tuple[UPat, Callable, Set, bool]]] = {}
+ self.pdict: dict[Ops, list[tuple[UPat, Callable, set, bool]]] = {}
  # uop is required, arg is optional
  for p,fxn in self.patterns:
  assert p.op is not None
@@ -651,7 +824,7 @@ class PatternMatcher:
  @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
  def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns)
 
- def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]:
+ def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
  ler = {u.op for u in uop.src}
  for p,fxn,early_reject,has_ctx in self.pdict.get(uop.op, []):
  if not early_reject.issubset(ler): continue
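`rewrite` only consults the patterns indexed under `uop.op` in `pdict`, and `early_reject` lets it skip a pattern unless every op that pattern needs appears among the sources. Reusing the hypothetical `mul_by_one` matcher sketched earlier:

```python
from tinygrad.ops import UOp

x = UOp.variable("x", 0, 10)
print(mul_by_one.rewrite(x * 1))  # the MUL rule fires and returns the bound x
print(mul_by_one.rewrite(x + 1))  # None: nothing is registered under Ops.ADD
```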
@@ -662,39 +835,32 @@ class PatternMatcher:
  # *** tracking pattern matcher ***
 
  TRACK_MATCH_STATS = ContextVar("TRACK_MATCH_STATS", 2 if getenv("VIZ") else 0)
- match_stats:Dict[UPat, List[Union[int, float]]] = dict()
+ match_stats:dict[UPat, list[Union[int, float]]] = dict()
  @dataclass(frozen=True)
- class TrackedRewriteContext:
- loc: Tuple[str, int] # location that called graph_rewrite
- sink: UOp # the sink passed into the rewrite
- matches: List[Tuple[UOp, Optional[UOp], Optional[UPat], float]] = field(default_factory=list) # all matches of sparents
-
- rewrite_stack: List[Tuple[Any, List[TrackedRewriteContext]]] = []
- contexts: List[Tuple[Any, List[TrackedRewriteContext]]] = []
- _rewrite_cnt: Dict[str, int] = {}
+ class TrackedGraphRewrite:
+ loc: tuple[str, int] # location that called graph_rewrite
+ sink: UOp # the sink input to graph_rewrite
+ matches: list[tuple[UOp, UOp, UPat]] = field(default_factory=list) # before+after of all the matches
+ tracked_keys:list[Any] = []
+ tracked_ctxs:list[list[TrackedGraphRewrite]] = []
+ _name_cnt:dict[str, int] = {}
  def track_rewrites(named=False):
  def _decorator(func):
  def __wrapper(self, *args, **kwargs):
  if TRACK_MATCH_STATS >= 2:
- if named: _rewrite_cnt[func.__name__] = _rewrite_cnt.setdefault(func.__name__, 0)+1
- rewrite_stack.append((f"{(n:=func.__name__)}_{_rewrite_cnt[n]}" if named else self, []))
- try: ret = func(self, *args, **kwargs)
- finally: # NOTE: save everything in the stack
- if TRACK_MATCH_STATS >= 2: contexts.append(rewrite_stack.pop())
- return ret
+ if named: _name_cnt[func.__name__] = _name_cnt.get(func.__name__, 0)+1
+ tracked_keys.append(f"{func.__name__}_{_name_cnt[func.__name__]}" if named else self)
+ tracked_ctxs.append([])
+ return func(self, *args, **kwargs)
  return __wrapper
  return _decorator
 
  class TrackedPatternMatcher(PatternMatcher):
- def __init__(self, patterns:List[Tuple[UPat, Callable]]):
- super().__init__(patterns)
- for p,_ in self.patterns:
- if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]
-
- def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]:
+ def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
  ret = None
  ler = {u.op for u in uop.src}
  for p,fxn,early_reject,has_ctx in self.pdict.get(uop.op, []):
+ if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]
  st = time.perf_counter()
  if not early_reject.issubset(ler):
  match_stats[p][2] += time.perf_counter()-st
@@ -705,10 +871,9 @@ class TrackedPatternMatcher(PatternMatcher):
  match_stats[p][0] += 1
  match_stats[p][3] += (et:=time.perf_counter()-st)
  if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", p.printable())
- if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0 and isinstance(ret, UOp): rewrite_stack[-1][1][-1].matches.append((uop, ret, p, et))
+ if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp) and len(tracked_ctxs) != 0: tracked_ctxs[-1][-1].matches.append((uop, ret, p))
  return ret # NOTE: if it returns None, we keep trying to match
  match_stats[p][2] += time.perf_counter()-st
- if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0: rewrite_stack[-1][1][-1].matches.append((uop, ret, None, 0))
  return None
 
  if TRACK_MATCH_STATS:
@@ -717,12 +882,10 @@ if TRACK_MATCH_STATS:
  @atexit.register
  def print_match_stats():
  if TRACK_MATCH_STATS >= 2:
- with open(fn:=temp("rewrites.pkl"), "wb") as f:
- print(f"rewrote {len(contexts)} graphs and matched {sum(len(r.matches) for _,x in contexts for r in x)} times, saved to {fn}")
- pickle.dump(contexts, f)
- if getenv("VIZ"):
- os.environ["VIZ"] = "0"
- os.execv(sys.executable, [sys.executable] + [os.path.join(os.path.dirname(__file__), ".", "viz", "serve.py"), temp("rewrites.pkl")])
+ with open(fn:=temp("rewrites.pkl", append_user=True), "wb") as f:
+ print(f"rewrote {len(tracked_ctxs)} graphs and matched {sum(len(r.matches) for x in tracked_ctxs for r in x)} times, saved to {fn}")
+ with Context(PICKLE_BUFFERS=0): pickle.dump((tracked_keys, tracked_ctxs), f)
+ if getenv("VIZ"): launch_viz("VIZ", temp("rewrites.pkl", append_user=True))
  if getenv("PRINT_MATCH_STATS", 1):
  ret = [0,0,0.0,0.0]
  for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]+x[1][3]):
@@ -731,118 +894,46 @@ if TRACK_MATCH_STATS:
  ret = [x+y for x,y in zip(ret, v)]
  print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[3]*1000.:9.2f} / {(ret[2]+ret[3])*1000.:9.2f} ms -- TOTAL")
 
+ def launch_viz(env_str:str, data:str):
+ os.environ[env_str] = "0"
+ os.environ[f"{env_str}_DATA"] = data
+ if not int(os.getenv("VIZ", "0")) and not int(os.getenv("PROFILE", "0")):
+ args = ['--kernels', getenv("VIZ_DATA", "")] if getenv("VIZ_DATA", "") else []
+ args += ['--profile', getenv("PROFILE_DATA", "")] if getenv("PROFILE_DATA", "") else []
+ os.execv(sys.executable, [sys.executable] + [os.path.join(os.path.dirname(__file__), ".", "viz", "serve.py")] + args)
+
  # *** simple graph rewrite engine ***
 
  class RewriteContext:
- def __init__(self, pm, ctx):
+ def __init__(self, pm, ctx=None):
  self.pm: PatternMatcher = pm
  self.ctx = ctx
- self.replace: Dict[UOp, UOp] = {}
- def rewrite(self, n:UOp) -> UOp:
+ self.replace: dict[UOp, UOp] = {}
+ def top_down_rewrite(self, n:UOp) -> UOp:
  if (rn := self.replace.get(n)) is not None: return rn
- new_src = tuple(map(self.rewrite, n.src))
+ new_src = tuple([self.top_down_rewrite(x) for x in n.src])
  new_n = self.pm.rewrite(n, self.ctx) if new_src == n.src else UOp(n.op, n.dtype, new_src, n.arg)
- self.replace[n] = ret = n if new_n is None else self.rewrite(new_n)
+ self.replace[n] = ret = n if new_n is None else self.top_down_rewrite(new_n)
+ return ret
+ def bottom_up_rewrite(self, n:UOp) -> UOp:
+ if (rn := self.replace.get(n)) is not None: return rn
+ new_n: UOp|None = n
+ while new_n is not None: last_n, new_n = new_n, self.pm.rewrite(new_n, self.ctx)
+ new_src = tuple([self.bottom_up_rewrite(x) for x in last_n.src])
+ self.replace[n] = ret = last_n if new_src == last_n.src else self.bottom_up_rewrite(UOp(last_n.op, last_n.dtype, new_src, last_n.arg))
  return ret
 
- def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp:
- if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0:
- rewrite_stack[-1][1].append(TrackedRewriteContext(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink))
- return RewriteContext(pm, ctx).rewrite(sink)
-
- # ***** uop type spec *****
-
- # this is the matcher for the final rendered UOps
- # matcher functions returns True or False (or None to not match)
- spec = PatternMatcher([
- (UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and not x.dtype.local),
- (UPat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.local),
- (UPat(Ops.DEFINE_ACC, src=(UPat.var("c"),), name="x", allow_any_len=True),
- lambda x,c: all(y.op is Ops.RANGE for y in x.src[1:]) and c.dtype == x.dtype),
- (UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
-
- (UPat(Ops.RANGE, src=(UPat(name="x"), UPat(name="y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype),
- (UPat(Ops.SPECIAL, src=()), lambda: True),
-
- # TODO: confirm the args of both of these are shapetrackers
- (UPat(Ops.VIEW, dtypes.void, src=()), lambda: True),
- (UPat(Ops.VIEW, src=(UPat.var("src"),), name="x"), lambda x,src: src.op is not Ops.STORE and x.dtype == src.dtype),
-
- (UPat(Ops.VALID, dtypes.bool, (UPat(Ops.VIEW),)), lambda: True),
- (UPat(Ops.CONST, name="x"), lambda x: x.dtype == x.dtype.scalar() and (type(x.arg) is type(dtypes.as_const(x.arg, x.dtype)))),
-
- # early LOAD has a <buf, shapetracker, store?>
- (UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW))), lambda: True),
- (UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat(Ops.STORE))), lambda: True),
-
- # early STORE has a <buf, shapetracker, val>
- (UPat(Ops.STORE, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat())), lambda: True),
-
- # **** new style load/store ****
-
- # INDEX is used in new style load/store
- (UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat())), lambda: True),
-
- # LOAD takes a <bufidx, alt?, gate?, barrier?>
- (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)),)), lambda: True),
- (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat((Ops.IF, Ops.BARRIER)))), lambda: True),
- (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(name="alt"), UPat(dtype=dtypes.bool)), name="ld"), lambda ld,alt: ld.dtype == alt.dtype),
-
- # STORE takes a <bufidx, val, gate?>
- (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat())), lambda: True),
- (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
- (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(), UPat(Ops.IF))), lambda: True),
-
- # most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
- (UPat(Ops.WHERE, name="w", src=(UPat(dtype=dtypes.bool), UPat(name="x"), UPat(name="y"))), lambda w,x,y: w.dtype == x.dtype == y.dtype),
- (UPat((Ops.CMPLT, Ops.CMPNE), dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y"))), lambda x,y: x.dtype == y.dtype),
- # and SHL/SHR, the shift distance can be an int
- (UPat((Ops.SHL, Ops.SHR), src=(UPat(name="x"), UPat(name="y")), name="a"), lambda a,x,y: a.dtype == x.dtype and y.dtype in (x.dtype, dtypes.uint)),
- (UPat(Ops.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
- (UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)),
-
- (UPat(Ops.ASSIGN, src=(UPat((Ops.DEFINE_ACC, Ops.DEFINE_GLOBAL)), UPat())), lambda: True),
- (UPat(Ops.ENDRANGE, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True),
-
- # all WMMA has 3 args, <x, w, acc>
- (UPat(Ops.WMMA, src=(UPat(), UPat(), UPat())), lambda: True),
- (UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
- (UPat(Ops.EXPAND, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
-
- # if has a <gate, barrier?>
- (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(),)), lambda: True),
- (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(), UPat(Ops.BARRIER))), lambda: True),
- (UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),
-
- (UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}),
- (UPat(Ops.GEP, src=(UPat(name="src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
- (UPat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.count and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
- (UPat((Ops.BITCAST, Ops.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None),
- (UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local
-
- # NOTE: for testing, we let sinks be anything
- #(UPat(UOps.SINK, src=UPat(UOps.STORE)), lambda: True),
- (UPat(Ops.SINK, dtypes.void), lambda: True),
- (UPat(Ops.NOOP), lambda: True),
-
- # PTX LOAD/STORE
- (UPat((Ops.LOAD, Ops.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
- (UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, src=(UPat(dtype=dtypes.int64),), allow_any_len=True)), lambda: True),
- ])
-
- def type_verify(uops:List[UOp]):
- for i,u in enumerate(uops):
- if not spec.rewrite(u):
- print_uops(uops)
- raise RuntimeError(f"UOp verification failed at {i} on {u.op} {u.dtype} {len(u.src)} {[x.op for x in u.src]} {u.arg}")
+ def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False) -> UOp:
+ if TRACK_MATCH_STATS >= 2 and not bottom_up and len(tracked_ctxs) != 0: # TODO: make viz work with bottom_up=True
+ tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink))
+ return RewriteContext(pm, ctx).bottom_up_rewrite(sink) if bottom_up else RewriteContext(pm, ctx).top_down_rewrite(sink)
 
- # *** uop helpers ***
+ def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False) -> dict[UOp, UOp]:
+ if TRACK_MATCH_STATS >= 2 and not bottom_up and len(tracked_ctxs) != 0: # TODO: make viz work with bottom_up=True
+ tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink))
+ rewrite_ctx = RewriteContext(pm, ctx)
+ return {k:(rewrite_ctx.bottom_up_rewrite(k) if bottom_up else rewrite_ctx.top_down_rewrite(k)) for k in list(sink.toposort)[::-1]}
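`graph_rewrite` applies a `PatternMatcher` over the whole graph until no rule fires, top-down by default and memoized per node in `RewriteContext`; `graph_rewrite_map` additionally returns the old-to-new mapping for every node in the sink's toposort. A hedged sketch using the `symbolic` rules defined later in this file:

```python
from tinygrad.ops import UOp, graph_rewrite, symbolic

x = UOp.variable("x", 0, 100)
expr = (x + 1) + 2  # plain ints are promoted to CONST UOps by MathTrait
print(graph_rewrite(expr, symbolic))  # associative + constant folding leaves x+3
```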
 
- def cast_float_to_bf16(x: UOp) -> UOp:
- assert x.dtype == dtypes.float, "cast float -> bf16 must start with float"
- x = x.bitcast(dtypes.uint)
- x = (-x & 0x7f800000).where(x + ((x >> 16) & 1) + 0x7fff, (x & 0xffff).where((x | 0x10000), x))
- return (x >> 16).cast(dtypes.ushort).bitcast(dtypes.bfloat16)
 
  # *** most of symbolic lives here now ***
 
@@ -851,70 +942,68 @@ def split_uop(x:UOp, sep:Ops):
  for s in x.src: yield from split_uop(s, sep)
  else: yield x
 
- def mod_folding(x:UOp, c:int) -> Optional[UOp]:
- # simplify x % c, None means no change
+ def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split_rem: bool=False) -> UOp|None:
+ # simplify x // y or x % y, None means no change
+ # simple cancel div/mod case
+ if y.vmin != 0 != y.vmax and (q:=x.vmin//y.vmin) == x.vmin//y.vmax == x.vmax//y.vmin == x.vmax//y.vmax:
+ return x - q*y if which is Ops.MOD else x.const_like(q)
 
- # simple cancel mod case
- if 0 < c and 0 <= x.vmin and (quotient:=x.vmin//c) == x.vmax//c: return x-quotient*c
+ if (y.op is not Ops.CONST) or ((c := y.arg) <= 0) or (x.dtype.count > 1): return None
 
- remainder, something_changed = [], False
+ svars, factors, quotients, remainders, gcd, div, const, offset, something_changed = [], [], [], [], c, 1, 0, 0, False
  for u in split_uop(x, Ops.ADD):
- if (factor:=u.const_factor())%c != factor:
- divides = u.divides(factor)*(factor%c)
- assert divides is not None
- remainder.append(divides)
- something_changed = True
- elif u.op is Ops.MOD and (s1:=u.src[1]).op is Ops.CONST and s1.arg%c == 0:
- remainder.append(u.src[0])
+ if u.op is Ops.MOD and which is Ops.MOD and u.src[1].op is Ops.CONST and u.src[1].arg%c == 0:
+ u = u.src[0]
  something_changed = True
- else: remainder.append(u)
- if not something_changed: return None
- return functools.reduce(operator.add, remainder)%c if remainder else x.const_like(0)
-
- def div_folding(x:UOp, c:int) -> Optional[UOp]:
- # simplify x // c, None means no change
+ v: UOp = u.divides(f:=u.const_factor())
+ q, r = divmod(f, c)
+ if r==0 or ((which is Ops.MOD or split_rem or u.op is Ops.CONST) and r!=f): something_changed = True
+ offset += r*v.vmin
+ if u.op is Ops.CONST: const += f
+ else: # div is the smallest common divisor of all terms
+ if f > 1 and c % f == 0 and (div == 1 or div > f): div = f
+ gcd = math.gcd(r, gcd)
+ factors.append(f); svars.append(v); quotients.append(q); remainders.append(r) # noqa: E702
+
+ lbound = ubound = offset = offset % c
+ # we can fold if the expression has only one non-constant term and this term can only take on two values
+ if len(svars)==1 and (v:=svars[0]).vmax-v.vmin == 1:
+ r = (offset+remainders[0])%c - offset%c
+ offset -= r * v.vmin
+ if which is Ops.MOD: return r*v + offset
+ return (factors[0]-r)//c * v + (const-offset)//c
+
+ # a//c = (a-a%c)/c, if we can fold a%c, we can fold a//c
+ # within a mod we can freely subtract multiples of c, we use this to see if a is congruent to an expression whose vmin/vmax are between 0 and c
+ for (r, v) in zip(remainders, svars):
+ if r > c//2:
+ if (lbound := lbound + (r:=r-c) * (v.vmax-v.vmin)) < 0: break
+ elif (ubound := ubound + r * (v.vmax-v.vmin)) >= c: break
+ offset -= r * v.vmin # determine what the new offset would be
+ else: # vmin/vmax of the remainder is between 0 and c, we can remove the mod/div
+ remainders = [min(r, r-c, key=abs) for r in remainders]
+ if which is Ops.MOD: return functools.reduce(operator.add, [r*v for r,v in zip(remainders,svars)], x.const_like(offset))
+ return functools.reduce(operator.add, [(f-r)//c * v for f,r,v in zip(factors, remainders,svars)], x.const_like((const-offset)//c))
+
+ if gcd != 1: something_changed = True
+ if not something_changed:
+ if which is Ops.IDIV and (1 < div < c) and (newx:=div_and_mod_folding(x, UOp.const(dtypes.int, div), Ops.IDIV)) is not None: return newx//(c//div)
+ return None
+ quo, rem = x.const_like(const//c), x.const_like((const%c)//gcd)
+ for q,r,f,v in zip(quotients, remainders, factors, svars):
+ if which is Ops.IDIV and (not split_rem) and r!=0:
+ rem += f//gcd * v
+ else:
+ rem += r//gcd * v
+ quo += q * v
 
- # simple cancel div case
- if 0 <= x.vmin and x.vmax < c: return x.const_like(0)
+ if which is Ops.MOD: return gcd*(rem % (c//gcd)) + const%gcd
+ return rem//(c//gcd)+quo
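A worked example of the combined folding (a sketch; this function is reached through the generic `x // y` and `x % y` rules registered in `symbolic` below):

```python
from tinygrad.ops import UOp, graph_rewrite, symbolic

x = UOp.variable("x", 0, 100)
# per-term divmod by 2: 8*x contributes quotient 4*x with no remainder, the
# constant 9 contributes quotient 4 and remainder 1, so this folds to x*4 + 4
print(graph_rewrite((x * 8 + 9) // 2, symbolic))
```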
 
- quotient, remainder, rem_const, something_changed, gcd, divisor = [], [], 0, False, c, 1
- for u in split_uop(x, Ops.ADD):
- if u.op is Ops.CONST:
- # add all const together first
- if rem_const != 0: something_changed = True
- rem_const += u.arg
- elif (factor:=u.const_factor())%c == 0:
- if factor:
- divides = u.divides(c)
- assert divides is not None
- quotient.append(divides)
- something_changed = True
- else:
- # divisor is the smallest common divisor of all MULs
- if u.op is Ops.MUL and factor > 1 and c % factor == 0 and (divisor == 1 or divisor > factor): divisor = factor
- remainder.append(u)
- gcd = math.gcd(gcd, factor)
-
- # handle the const
- if rem_const%c != rem_const:
- something_changed = True
- quotient.append(x.const_like(rem_const//c))
- rem_const = rem_const%c
- if rem_const != 0: remainder.append(x.const_like(rem_const))
-
- # x // c -> quotient + (remainder // div) // (c // div)
- div = gcd if gcd > 1 else divisor
-
- if not something_changed: return newx//(c//div) if 1 < div < c and (newx:=div_folding(x, div)) is not None else None
- rem:Optional[UOp] = functools.reduce(operator.add, remainder) if remainder else None
- quo:Optional[UOp] = functools.reduce(operator.add, quotient) if quotient else None
- if quo is None: return x.const_like(0) if rem is None else cast(UOp, div_folding(rem, div))//(c//div)
- return quo if rem is None else cast(UOp, div_folding(rem, div))//(c//div)+quo
-
- def lt_folding(x:UOp, c:int) -> Optional[UOp]:
+ def lt_folding(x:UOp, c:int) -> UOp|None:
  p, np = partition(split_uop(x, Ops.ADD), lambda u: u.const_factor() == 1)
  if np and (d:=math.gcd(*[u.const_factor() for u in np], c)) > 1 and 0 <= sum(u.vmin for u in p) and sum(u.vmax for u in p) < d:
- return cast(UOp, functools.reduce(operator.add, np).divides(d)).lt(c//d)
+ return cast(UOp, functools.reduce(operator.add, np).divides(d))<(c//d)
  return None
 
  def fold_unrolled_divs(divs:UOp):
@@ -938,7 +1027,7 @@ def fold_unrolled_divs(divs:UOp):
  if ans is not None and 0 <= ans.vmin and ans.vmax + i < denominator: seen_const.append(i)
  return ans if ans is not None and sorted(seen_const)==list(range(denominator)) else None
 
- def canonicalize_simplex(X:UOp) -> Optional[UOp]:
+ def canonicalize_simplex(X:UOp) -> UOp|None:
  # (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints.
  # returns x0 + x1 + ... in such case, or None if not
  changed, ret = False, []
@@ -958,7 +1047,7 @@ def is_increasing(f:UOp) -> bool:
  if f.op in (Ops.MUL, Ops.IDIV) and f.src[1].op is Ops.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0])
  return False # False if not sure
 
- def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]:
+ def parse_valid(valid:UOp) -> tuple[UOp, bool, int]:
  # if it's X <= c, returns X, True, c
  # if it's X >= c, returns X, False, c
 
@@ -966,14 +1055,14 @@ def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]:
  if valid.op is Ops.CMPNE and valid.src[1].op is Ops.CONST and valid.src[1].arg == 1 and \
  (s0:=valid.src[0]).op is Ops.CMPLT and s0.src[1].op is Ops.CONST: return s0.src[0], False, s0.src[1].arg
  # X < c -> X <= c-1
- if valid.op is Ops.CMPLT and valid.src[1].op is Ops.CONST: return valid.src[0], True, valid.src[1].arg-1
+ if valid.op is Ops.CMPLT and valid.src[1].op is Ops.CONST and dtypes.is_int(valid.src[0].dtype): return valid.src[0], True, valid.src[1].arg-1
  raise ValueError(f"not able to parse {valid=}")
 
- def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
+ def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None:
  # return None if valid is always False, otherwise the simplified uop (might be the same as input)
 
  # first, parse valid into {expr: (lower_bound, upper_bound)}
- bounds:DefaultDict[UOp, List[Optional[ConstType]]] = defaultdict(lambda: [None, None])
+ bounds:defaultdict[UOp, list[ConstType|None]] = defaultdict(lambda: [None, None])
  for stmt in split_uop(valid, Ops.AND):
  try: expr, is_upper, c = parse_valid(stmt)
  except ValueError: return uop # give up if we cannot parse the valid
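`parse_valid` normalizes a comparison into `(expr, is_upper, bound)`. A quick sketch of the two accepted shapes:

```python
from tinygrad.ops import UOp, parse_valid

x = UOp.variable("x", 0, 10)
print(parse_valid(x < 5))             # (x, True, 4): "x < 5" is the upper bound x <= 4
print(parse_valid((x < 5).ne(True)))  # (x, False, 5): a negated CMPLT is the lower bound x >= 5
```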
@@ -990,7 +1079,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
  # if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
  candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, Ops.ADD)])
  # try checking the whole clause
- if expr in uop.sparents:
+ if expr in uop.toposort:
  candidates.append([(expr, UOp.variable("fake", expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1], expr.dtype))])
 
  for candidate in candidates:
@@ -1003,13 +1092,13 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
 
  return uop
 
- def _valid_priority(v: UOp, valids:List[UOp]):
+ def _valid_priority(v: UOp, valids:list[UOp]):
  # we want valid that's in other valids' parents to be first, so it's more likely the other valids get simplified
- try: return sum(-1 if parse_valid(v)[0] in other.parents else 0 for other in valids)
+ try: return sum(-1 if parse_valid(v)[0] in other.toposort else 0 for other in valids)
  except ValueError: return 0
 
- def simplify_valid(valid:UOp) -> Optional[UOp]:
- ret:List[UOp] = []
+ def simplify_valid(valid:UOp) -> UOp|None:
+ ret:list[UOp] = []
  something_changed = False
  valids = list(split_uop(valid, Ops.AND))
  for stmt in sorted(valids, key=lambda v: _valid_priority(v, valids)):
@@ -1017,9 +1106,11 @@ def simplify_valid(valid:UOp) -> Optional[UOp]:
  if ret[-1] is not stmt: something_changed = True
  return functools.reduce(operator.and_, ret) if something_changed else None
 
- def max_var_const(x:UOp, c1:UOp, c2:UOp):
- if x.vmin >= 0: return x*c1 if c1.arg >= c2.arg else x*c2
- if x.vmax <= 0: return x*c2 if c1.arg >= c2.arg else x*c1
+ # def max_var_const(x:UOp, c1:UOp, c2:UOp):
+ # if x.vmin >= 0: return x*c1 if c1.arg >= c2.arg else x*c2
+ # if x.vmax <= 0: return x*c2 if c1.arg >= c2.arg else x*c1
+
+ def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x
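`sint_to_uop` normalizes the `sint = Union[int, UOp]` values used for symbolic shapes. A tiny sketch:

```python
from tinygrad.ops import UOp, sint_to_uop

print(sint_to_uop(3))       # plain ints become UOp.const(dtypes.int, 3)
n = UOp.variable("n", 1, 10)
print(sint_to_uop(n) is n)  # True: UOps pass through unchanged
```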
 
  symbolic_simple = PatternMatcher([
  # ** self folding **
@@ -1032,23 +1123,25 @@ symbolic_simple = PatternMatcher([
  ((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x
  ((UPat.var() % UPat.var("y")).named("base") % UPat.var("y"), lambda base,y: base), # (x%y)%y = -> x%y (rewritten with base for speed)
  (UPat.var("x")%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda x,c: x), # (x%c)+(x//c)*c = x
+ ((UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3")+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"),
+ lambda x,c1,c2,c3: x*c2 if c1.arg*c2.arg==c3.arg else None), # (x%c1)*c2+(x//c1)*c3 = x*c2 if c1*c2==c3
  (UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c", vec=False), lambda x,c: x if c.arg else c),
  (UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c", vec=False), lambda x,c: c if c.arg else x),
- (UPat.var("x").maximum(UPat.var("x")), lambda x: x),
- ((UPat.var("x") & UPat.var("x")), lambda x: x),
- ((UPat.var("x") | UPat.var("x")), lambda x: x),
+ (UPat(GroupOp.Idempotent, src=(UPat.var("x"), UPat.var("x"))), lambda x: x),
  (UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x),
  (UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, True), UPat.const(dtypes.bool, False)), lambda x: x),
  # ** zero folding **
- (UPat.var("x") < UPat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x < x -> False
+ (UPat.var("x") < UPat.var("x"), lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x < x -> False
  (UPat.var("x", dtype=dtypes.ints) != UPat.var("x", dtype=dtypes.ints),
- lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x != x -> False (only ints)
+ lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints)
  # x*0 -> 0 or 0*x -> 0
  # if x is nan or inf it should render the nan value.
  # NOTE: this can be wrong for loaded NaN
  (UPat.var("x") * 0, lambda x: x.const_like(float("nan") if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
  # ** constant folding **
- (UPat(GroupOp.ALU, name="a", src=UPat((Ops.VCONST, Ops.CONST))), lambda a: a.const_like(exec_alu(a.op, a.dtype, [x.arg for x in a.src], False))),
+ # TODO: add const folding for Ops.THREEFRY
+ (UPat(GroupOp.ALU, name="a", src=UPat((Ops.VCONST, Ops.CONST))),
+ lambda a: a.const_like(exec_alu(a.op, a.dtype, [x.arg for x in a.src], False)) if a.op is not Ops.THREEFRY else None),
  # bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
  (UPat.var('x', dtype=dtypes.bool) * UPat.var('y', dtype=dtypes.bool), lambda x,y: x&y),
  (UPat.var('x', dtype=dtypes.bool) + UPat.var('y', dtype=dtypes.bool), lambda x,y: x|y),
@@ -1061,46 +1154,45 @@ symbolic = symbolic_simple+PatternMatcher([
  symbolic = symbolic_simple+PatternMatcher([
  # ** COMMUTATIVE flipping **
  (UPat(GroupOp.Commutative, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
- # group like
- ((UPat.var("x") + UPat.var("y")) + UPat.var("x") * UPat.cvar("c"), lambda x,y,c: (x+x*c)+y),
  # ** boolean algebra **
  (UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x
  # ** combine terms **
  (UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1)
+ ((UPat.var("y") + UPat.var("x") * UPat.cvar("c0")) + UPat.var("x") * UPat.cvar("c1"), lambda x,y,c0,c1: y+x*(c0+c1)),
  (UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1)
+ ((UPat.var("y") + UPat.var("x")) + UPat.var("x") * UPat.cvar("c"), lambda x,y,c: y+x*(c+1)),
  (UPat.var("x") + UPat.var("x"), lambda x: x*2), # (x+x)-> x*2
- ((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3)), # (x/x2)/x3 -> x/(x2*x3)
+ ((UPat.var("y") + UPat.var("x")) + UPat.var("x"), lambda y,x: y+x*2),
+ ((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3) if x2 is not x3 else None), # (x/x2)/x3 -> x/(x2*x3)
  (-1 * (UPat.var("x") + UPat.cvar("c")), lambda x,c: (-x)+(-c)), # -(x+c) -> -x + -c
  # a conditional with the same results either way is a noop, also fold const conditionals
  (UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
  (UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
+ # alu of two where with same conds can combine, only do if true branch or false branch is const
+ (UPat(GroupOp.Binary, name="alu", src=(UPat.var("c").where(UPat.var("t"), UPat.var("f")), UPat.var("c").where(UPat.var("tt"), UPat.var("ff")))), \
+ lambda alu,c,t,tt,f,ff: c.where(t.alu(alu.op, tt), f.alu(alu.op, ff)) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None),
  # ALU min==max -> CONST (slow!)
  (UPat(GroupOp.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
  # max folding
  (UPat.maximum(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
  # TODO: why does this rule break beautiful_mnist?
  #((UPat.var("x")+UPat.var("z")).maximum(UPat.var("y")+UPat.var("z")), lambda x,y,z: x.maximum(y) + z),
- ((UPat.var("x")*UPat.cvar("c1")).maximum(UPat.var("x")*UPat.cvar("c2")), max_var_const),
+ #((UPat.var("x")*UPat.cvar("c1")).maximum(UPat.var("x")*UPat.cvar("c2")), max_var_const),
  # ** two stage ALU folding **
- ((UPat.var("x") + UPat.cvar("c1")) + UPat.cvar("c2"), lambda x,c1,c2: x+(c1+c2)),
- ((UPat.var("x") * UPat.cvar("c1")) * UPat.cvar("c2"), lambda x,c1,c2: x*(c1*c2)),
- ((UPat.var("x") & UPat.cvar("c1")) & UPat.cvar("c2"), lambda x,c1,c2: x&(c1&c2)),
- ((UPat.var("x") | UPat.cvar("c1")) | UPat.cvar("c2"), lambda x,c1,c2: x|(c1|c2)),
+ *((UPat.var("x").alu(op, UPat.cvar("c1")).alu(op, UPat.cvar("c2")).named("f"),
+ lambda f,x,c1,c2: x.alu(f.op,c1.alu(f.op,c2))) for op in GroupOp.Associative),
  ((UPat.cvar("c0") + UPat.var("x")) < UPat.cvar("c1"), lambda x,c0,c1: x<(c1-c0)), # c0 + x < c1 -> x < c1 - c0
  ((UPat.var("x") // UPat.cvar("c1")) // UPat.cvar("c2"), lambda x,c1,c2: x//(c1*c2)), # (x//c1)//c2 -> x//(c1*c2)
  # ** lt **
  # c0*x<c1 for positive int c0,c1
- ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints)).lt(UPat.cvar("c1", vec=False)),
- lambda x,c0,c1: x.lt(math.ceil(c1.arg/c0.arg)) if c0.arg > 0 and c1.arg > 0 else None),
+ ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints))<UPat.cvar("c1", vec=False),
+ lambda x,c0,c1: x<math.ceil(c1.arg/c0.arg) if c0.arg > 0 and c1.arg > 0 else None),
  # c0*x<c1 for negative int c0 and non-positive c1
- ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints)).lt(UPat.cvar("c1", vec=False)),
- lambda x,c0,c1: (-x).lt(-(math.floor(-c1.arg/-c0.arg))) if c0.arg < 0 and c0.arg != -1 and c1.arg <= 0 else None),
+ ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints))<UPat.cvar("c1", vec=False),
+ lambda x,c0,c1: (-x)<(-(math.floor(-c1.arg/-c0.arg))) if c0.arg < 0 and c0.arg != -1 and c1.arg <= 0 else None),
  # x//c0<c1 for positive int c0
- ((UPat.var("x", dtype=dtypes.ints)//UPat.cvar("c0", vec=False)).lt(UPat.cvar("c1", vec=False)),
- lambda x,c0,c1: x.lt(c1.arg*c0.arg) if c0.arg > 0 else None),
- # mul add lt
- (((UPat.cvar("c0", vec=False)*UPat.var("x"))+UPat.var("x2")).lt(UPat.cvar("c1", vec=False)),
- lambda x,x2,c0,c1: x.lt(c1//c0) if c1.arg % c0.arg == 0 and c0.arg > x2.vmax and x2.vmin >= 0 else None),
+ ((UPat.var("x", dtype=dtypes.ints)//UPat.cvar("c0", vec=False))<UPat.cvar("c1", vec=False),
+ lambda x,c0,c1: x<(c1.arg*c0.arg) if c0.arg > 0 else None),
  # ** move add/mul consts to end (NOTE: this is still happening before constant folding) **
  (UPat(Ops.ADD, src=(UPat.var("x"), UPat.cvar("c1"))) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
  (UPat(Ops.MUL, src=(UPat.var("x"), UPat.cvar("c1"))) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
@@ -1108,18 +1200,20 @@ symbolic = symbolic_simple+PatternMatcher([
  # unrolled arange div folding
  (UPat(Ops.ADD, name="divs", src=[UPat(), UPat(Ops.IDIV)]), fold_unrolled_divs),
  # generic lt folding
- (UPat.var("x", dtypes.sints).lt(UPat.cvar("c", vec=False)), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None),
+ (UPat.var("x", dtypes.sints)<UPat.cvar("c", vec=False), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None),
  # canonicalize a simplex with positive coefficients > 0
  # not x < 1 -> X > 0
- (UPat.var("x", dtypes.ints).lt(1).ne(True), lambda x: newx.lt(1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None),
+ ((UPat.var("x", dtypes.ints)<1).ne(True), lambda x: (newx<1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None),
  # ** div **
- # # div folding
- (UPat.var("x", dtypes.sints) // UPat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=div_folding(x,c.arg)) is not None else None),
+ # div folding
+ ((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"), lambda x,c,a,d: (x+a*c)//(c*d)), # (x//c+a)//d -> (x+a*c)//(c*d)
+ (UPat.var("x", dtypes.sints) // UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.IDIV)),
  # ** mod **
  # mod folding
- (UPat.var("x") % UPat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None),
+ (UPat.var("x") % UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.MOD)),
  ])
1215
 
1216
+
1123
1217
  symbolic_flat = symbolic+PatternMatcher([
1124
1218
  # ** combine terms (opinionated) **
1125
1219
  (-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y
@@ -1134,7 +1228,7 @@ syms = { Ops.ADD: "+", Ops.SUB: "-", Ops.IDIV: "//", Ops.MOD: "%", Ops.SHL: "<<"
  Ops.MUL: "*", Ops.CMPLT: "<", Ops.CMPNE: "!=", Ops.AND: "&", Ops.OR: "|", Ops.XOR: "^"}
  renderer = PatternMatcher([
  (UPat((Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg[0])),
- (UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg[0]}")),
+ (UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg}")),
  (UPat(Ops.CONST, name="x"), lambda x: UOp(Ops.NOOP, arg=str(x.arg))),
  (UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
  (UPat(Ops.NEG, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")),
@@ -1149,4 +1243,26 @@ renderer = PatternMatcher([
  sint = Union[int, UOp]
  Variable = UOp
 
- ConstLike = Union[ConstType, Variable, Tuple[ConstType, ...]]
+ ConstLike = Union[ConstType, Variable, tuple[ConstType, ...]]
+
+ # *** UOp merge views and swizzling ***
+
+ merge_views = PatternMatcher([
+ # VIEW(VIEW) merges to a single VIEW
+ (UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.VIEW, name="vm2"),)), lambda vm1,vm2: vm2.replace(arg=vm2.st+vm1.st)),
+ (UPat(Ops.VIEW, name="vm", src=(UPat.var("x"),)), lambda vm,x: x if vm.st.contiguous and x.st is not None and x.shape == vm.shape else None),
+ # merge unmasked const views
+ (UPat(Ops.VIEW, name="view", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="const", src=(UPat(Ops.VIEW, name="st"),) ),)),
+ lambda st,const,view: const.replace(src=(st.replace(arg=st.st+view.st),)) if all(v.mask is None for v in (st.st+view.st).views) else None),
+ ])
+
+ # push VIEW to parents
+ view_left = merge_views+PatternMatcher([
+ # VIEW(CONST) becomes VALID
+ (UPat(Ops.VIEW, name="vm", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),)), lambda vm,x: x.valid(vm.st)),
+ # VIEW before elementwise/buffer ops
+ (UPat(Ops.VIEW, name="vm", src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e"),)),
+ lambda e,vm: e.replace(src=tuple(s if s.st is None else s.view(vm.st) if s is s.base else s.base.view(s.st+vm.st) for s in e.src))),
+ (UPat(Ops.VIEW, name="vm", src=(UPat(GroupOp.Buffer, name="b"),)),
+ lambda b,vm: b.replace(src=tuple((s.st+vm.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
+ ])
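Both matchers compose `ShapeTracker`s with `+` (as in `vm2.st + vm1.st` above). A hedged sketch of that composition in isolation, assuming the `tinygrad.shape.shapetracker` API:

```python
from tinygrad.shape.shapetracker import ShapeTracker

# stacking a permuted view on top of a plain (3,4) view; `+` appends the views
# and simplifies, which is how VIEW(VIEW) collapses into a single VIEW
st1 = ShapeTracker.from_shape((3, 4))
st2 = ShapeTracker.from_shape((3, 4)).permute((1, 0))
print((st1 + st2).shape)  # expected: (4, 3)
```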