tinygrad 0.10.0__py3-none-any.whl → 0.10.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. tinygrad/codegen/devectorizer.py +247 -0
  2. tinygrad/codegen/expander.py +121 -0
  3. tinygrad/codegen/kernel.py +141 -201
  4. tinygrad/codegen/linearize.py +223 -84
  5. tinygrad/codegen/lowerer.py +60 -42
  6. tinygrad/codegen/symbolic.py +476 -0
  7. tinygrad/codegen/transcendental.py +22 -13
  8. tinygrad/device.py +187 -47
  9. tinygrad/dtype.py +39 -28
  10. tinygrad/engine/jit.py +83 -65
  11. tinygrad/engine/memory.py +4 -5
  12. tinygrad/engine/multi.py +161 -0
  13. tinygrad/engine/realize.py +62 -108
  14. tinygrad/engine/schedule.py +396 -357
  15. tinygrad/engine/search.py +55 -66
  16. tinygrad/gradient.py +73 -0
  17. tinygrad/helpers.py +81 -59
  18. tinygrad/nn/__init__.py +30 -32
  19. tinygrad/nn/datasets.py +1 -2
  20. tinygrad/nn/optim.py +22 -26
  21. tinygrad/nn/state.py +91 -66
  22. tinygrad/ops.py +492 -641
  23. tinygrad/renderer/__init__.py +95 -36
  24. tinygrad/renderer/cstyle.py +99 -92
  25. tinygrad/renderer/llvmir.py +83 -34
  26. tinygrad/renderer/ptx.py +83 -99
  27. tinygrad/renderer/wgsl.py +95 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  29. tinygrad/runtime/autogen/comgr.py +2 -0
  30. tinygrad/runtime/autogen/kfd.py +4 -3
  31. tinygrad/runtime/autogen/kgsl.py +1 -1
  32. tinygrad/runtime/autogen/libc.py +404 -71
  33. tinygrad/runtime/autogen/llvm.py +11379 -0
  34. tinygrad/runtime/autogen/pci.py +1333 -0
  35. tinygrad/runtime/autogen/vfio.py +891 -0
  36. tinygrad/runtime/autogen/webgpu.py +6985 -0
  37. tinygrad/runtime/graph/cuda.py +8 -9
  38. tinygrad/runtime/graph/hcq.py +84 -79
  39. tinygrad/runtime/graph/metal.py +40 -43
  40. tinygrad/runtime/ops_amd.py +498 -334
  41. tinygrad/runtime/ops_cloud.py +34 -34
  42. tinygrad/runtime/ops_cpu.py +24 -0
  43. tinygrad/runtime/ops_cuda.py +30 -27
  44. tinygrad/runtime/ops_disk.py +62 -63
  45. tinygrad/runtime/ops_dsp.py +159 -42
  46. tinygrad/runtime/ops_gpu.py +30 -30
  47. tinygrad/runtime/ops_hip.py +29 -31
  48. tinygrad/runtime/ops_llvm.py +48 -41
  49. tinygrad/runtime/ops_metal.py +149 -113
  50. tinygrad/runtime/ops_npy.py +2 -2
  51. tinygrad/runtime/ops_nv.py +238 -273
  52. tinygrad/runtime/ops_python.py +55 -50
  53. tinygrad/runtime/ops_qcom.py +129 -157
  54. tinygrad/runtime/ops_webgpu.py +225 -0
  55. tinygrad/runtime/support/allocator.py +94 -0
  56. tinygrad/runtime/support/am/__init__.py +0 -0
  57. tinygrad/runtime/support/am/amdev.py +396 -0
  58. tinygrad/runtime/support/am/ip.py +463 -0
  59. tinygrad/runtime/support/compiler_cuda.py +4 -2
  60. tinygrad/runtime/support/elf.py +28 -4
  61. tinygrad/runtime/support/hcq.py +256 -324
  62. tinygrad/runtime/support/llvm.py +26 -0
  63. tinygrad/shape/shapetracker.py +85 -53
  64. tinygrad/shape/view.py +104 -140
  65. tinygrad/spec.py +155 -0
  66. tinygrad/tensor.py +835 -527
  67. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
  68. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
  69. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
  70. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
  71. tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
  72. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
  73. tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
  74. tinygrad/viz/index.html +544 -0
  75. tinygrad/viz/perfetto.html +178 -0
  76. tinygrad/viz/serve.py +205 -0
  77. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
  78. tinygrad-0.10.2.dist-info/RECORD +99 -0
  79. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
  80. tinygrad/codegen/uopgraph.py +0 -506
  81. tinygrad/engine/lazy.py +0 -228
  82. tinygrad/function.py +0 -212
  83. tinygrad/multi.py +0 -177
  84. tinygrad/runtime/graph/clang.py +0 -39
  85. tinygrad/runtime/ops_clang.py +0 -35
  86. tinygrad-0.10.0.dist-info/RECORD +0 -77
  87. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
  88. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/ops.py CHANGED
@@ -1,14 +1,14 @@
  from __future__ import annotations
- from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, Type, DefaultDict
- import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect
+ from typing import Any, Optional, Union, Callable, cast, TYPE_CHECKING, Type, get_args
+ import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect, weakref
  from enum import auto, IntEnum, Enum
  from dataclasses import dataclass, field
- from collections import defaultdict
- from weakref import WeakValueDictionary
- from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate
- from tinygrad.helpers import ContextVar, prod, getenv, all_same, Context, partition, temp, unwrap, T
+ from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate
+ from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, _METADATA, flatten
+ from tinygrad.helpers import PICKLE_BUFFERS, SPLIT_REDUCEOP, DEBUG, dedup
  if TYPE_CHECKING:
    from tinygrad.shape.shapetracker import ShapeTracker
+   from tinygrad.device import Buffer

  # wrapper around IntEnum that preserves Enum.__str__ and makes auto() unique across all FastEnum subclasses
  class FastEnum(IntEnum):
@@ -26,8 +26,7 @@ class SimpleMathTrait:
    def _binop(self, op, x, reverse): return self.ufix(x).alu(op, self) if reverse else self.alu(op, self.ufix(x))
    def logical_not(self): return self.ne(True)
    def neg(self):
-     dtype: Optional[DType] = getattr(self, 'dtype', None)
-     assert dtype is not None, "MathTraits __neg__ requires a dtype"
+     if (dtype:=getattr(self, 'dtype')) is None: raise TypeError(f"MathTraits __neg__ requires a dtype, {self=}")
      return self.logical_not() if dtype.scalar() == dtypes.bool else self*(-1)
    def add(self, x, reverse=False): return self._binop(Ops.ADD, x, reverse)
    def mul(self, x, reverse=False): return self._binop(Ops.MUL, x, reverse)
@@ -35,6 +34,7 @@ class SimpleMathTrait:
    def bitwise_or(self, x, reverse=False): return self._binop(Ops.OR, x, reverse)
    def xor(self, x, reverse=False): return self._binop(Ops.XOR, x, reverse)
    def idiv(self, x, reverse=False): return self._binop(Ops.IDIV, x, reverse)
+   def mod(self, x, reverse=False): return self._binop(Ops.MOD, x, reverse)
    def sub(self, x, reverse=False): return self.ufix(x).alu(Ops.ADD, -self) if reverse else self.alu(Ops.ADD, self.ufix(-x))
    def div(self, x, reverse=False): return (self.ufix(x)*self.alu(Ops.RECIP)) if reverse else (self*self.ufix(x).alu(Ops.RECIP))

@@ -44,7 +44,8 @@ class SimpleMathTrait:
    def __sub__(self, x): return self.sub(x)
    def __mul__(self, x): return self.mul(x)
    def __truediv__(self, x): return self.div(x)
-   def __floordiv__(self, x): return self.idiv(x)
+   def __floordiv__(self, x): return self.idiv(x) # TODO: idiv is trunc div, not floordiv
+   def __mod__(self, x): return self.mod(x)
    def __and__(self, x): return self.bitwise_and(x)
    def __or__(self, x): return self.bitwise_or(x)
    def __xor__(self, x): return self.xor(x)
@@ -57,22 +58,19 @@ class SimpleMathTrait:
    def __rand__(self, x): return self.bitwise_and(x, True)
    def __ror__(self, x): return self.bitwise_or(x, True)
    def __rxor__(self, x): return self.xor(x, True)
+   def __rmod__(self, x): return self.mod(x, True)
+
+   def __lt__(self, x): return self.alu(Ops.CMPLT, self.ufix(x))
+   def __gt__(self, x): return self.ufix(x).alu(Ops.CMPLT, self)
+   def __ge__(self, x): return (self < x).logical_not()
+   def __le__(self, x): return (self > x).logical_not()

-   def lt(self, x): return self.alu(Ops.CMPLT, self.ufix(x))
-   def gt(self, x): return self.ufix(x).alu(Ops.CMPLT, self)
    def ne(self, x): return self.alu(Ops.CMPNE, self.ufix(x))
-   def ge(self, x): return self.lt(x).logical_not()
-   def le(self, x): return self.gt(x).logical_not()
    def eq(self, x): return self.ne(x).logical_not()
-
-   def __lt__(self, x): return self.lt(x)
-   def __gt__(self, x): return self.gt(x)
    def __ne__(self, x): return self.ne(x)
-   def __ge__(self, x): return self.ge(x)
-   def __le__(self, x): return self.le(x)
    # NOTE: __eq__ isn't overridden, and means the same thing as is by default

- class MathTrait(SimpleMathTrait): # pylint: disable=abstract-method
+ class MathTrait(SimpleMathTrait):
    # TODO: move to Tensor when new backward is done
    def lshift(self, x, reverse=False): return self._binop(Ops.SHL, x, reverse)
    def rshift(self, x, reverse=False): return self._binop(Ops.SHR, x, reverse)
@@ -81,10 +79,6 @@ class MathTrait(SimpleMathTrait): # pylint: disable=abstract-method
    def __rlshift__(self, x): return self.lshift(x, True)
    def __rrshift__(self, x): return self.rshift(x, True)

-   # not in Tensor
-   def __mod__(self, x): return self.alu(Ops.MOD, self.ufix(x))
-   def __rmod__(self, x): return self.ufix(x).alu(Ops.MOD, self)
-
    def maximum(self, x): return self.alu(Ops.MAX, self.ufix(x))
    def minimum(self, x): return -(-self).maximum(-x)
    def where(self, x, y): return self.alu(Ops.WHERE, x, x.ufix(y))
@@ -94,118 +88,126 @@ class MathTrait(SimpleMathTrait): # pylint: disable=abstract-method
    def sin(self): return self.alu(Ops.SIN)
    def log2(self): return self.alu(Ops.LOG2)
    def exp2(self): return self.alu(Ops.EXP2)
+   def pow(self, x): return self.alu(Ops.POW, self.ufix(x))

  # the order of these Ops controls the order of the toposort
  class Ops(FastEnum):
    # uops that aren't rendered
-   SINK = auto()
-   CONTIGUOUS = auto()
-   PRELOAD = auto()
+   NAME = auto(); SINK = auto(); CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); KERNEL = auto(); UNIQUE = auto() # noqa: E702

-   # MetaOps
-   COPY = auto()
+   # TODO: empty continues to exist because of tensor
    EMPTY = auto()
-   BUFFER_VIEW = auto()
-
-   EXPAND = auto()
-   CONTRACT = auto()
-   VIEW = auto()
-   DEFINE_GLOBAL = auto()
-   BUFFER = auto()
-   DEFINE_VAR = auto()
-   DEFINE_LOCAL = auto()
-   DEFINE_ACC = auto()
-   VALID = auto()
-   SPECIAL = auto()
-   NOOP = auto()
+
+   # MetaOps
+   COPY = auto(); BUFFER_VIEW = auto() # noqa: E702
+
+   # blocks in linearizer
+   BLOCK = auto(); BLOCKSTART = auto(); BLOCKFORK = auto(); BLOCKEND = auto() # noqa: E702
+
+   # movement ops!
+   RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto() # noqa: E702
+
+   # misc ops
+   UNROLL = auto(); CONTRACT = auto() # noqa: E702
+   VIEW = auto(); DEFINE_GLOBAL = auto(); BUFFER = auto() # noqa: E702
+   DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # noqa: E702
+   VALID = auto(); SPECIAL = auto(); NOOP = auto() # noqa: E702

    # reduce
    REDUCE_AXIS = auto()

    # helper ops
-   GEP = auto()
-   VECTORIZE = auto()
+   GEP = auto(); VECTORIZE = auto(); CAT = auto() # noqa: E702

    # UnaryOps
    CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702

-   # loads before math
-   LOAD = auto()
+   # load/store before math
+   LOAD = auto(); STORE = auto() # noqa: E702
+
+   # early INDEX
+   INDEX = auto()

    # math ops
    WMMA = auto()

    # BinaryOps
    ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
-   SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto() # noqa: E702
+   SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto(); POW = auto() # noqa: E702

    # TernaryOps
    WHERE = auto(); MULACC = auto() # noqa: E702

    # assignment ops
-   STORE = auto()
    ASSIGN = auto()
    BIND = auto()

-   # late INDEX
-   INDEX = auto()
-
    # control flow ops
-   BARRIER = auto()
-   IF = auto()
-   RANGE = auto()
-
-   # ops that are not graph nodes
-   ENDRANGE = auto()
-   ENDIF = auto()
+   BARRIER = auto(); RANGE = auto(); IF = auto(); ENDRANGE = auto(); ENDIF = auto() # noqa: E702

    # consts last!
-   VCONST = auto()
-   CONST = auto()
+   VCONST = auto(); CONST = auto() # noqa: E702
+
+   # device
+   DEVICE = auto()
+   MULTI = auto()
+   CUSTOM = auto()

  class GroupOp:
    Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG}
    Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY,
-             Ops.SUB, Ops.FDIV}
+             Ops.SUB, Ops.FDIV, Ops.POW}
    Ternary = {Ops.WHERE, Ops.MULACC}
    ALU = set.union(Unary, Binary, Ternary)

    Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}
+   Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP}

-   # meta ops
-   Meta = {Ops.COPY, Ops.EMPTY, Ops.BUFFER_VIEW}
-   Buffer = {Ops.LOAD, Ops.PRELOAD, Ops.STORE, Ops.VALID}
+   Buffer = {Ops.LOAD, Ops.STORE, Ops.VALID, Ops.CONST, Ops.DEFINE_VAR}
+   Block = {Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKFORK, Ops.BLOCKSTART}

    # BinaryOps that can be flipped
    Commutative = {Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPNE, Ops.XOR, Ops.AND, Ops.OR}

+   # BinaryOps where f(f(a,b),c) = f(a,f(b,c))
+   Associative = {Ops.ADD, Ops.MUL, Ops.AND, Ops.OR, Ops.MAX}
+
+   # BinaryOps that satisfy f(x,x)=x see https://en.wikipedia.org/wiki/Idempotence
+   Idempotent = {Ops.OR, Ops.AND, Ops.MAX}
+
    # do not preserve f(0) = 0
-   UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV}
+   UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV, Ops.POW}

- # https://en.wikipedia.org/wiki/Identity_element
- def identity_element(op:Ops, dt:DType): return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
+   All = set(Ops)

- def can_pad(u:UOp) -> bool: return not any(x.op in GroupOp.UnsafePad for x in u.sparents)
+ # some BUFFER ops can be processed with only a view
+ view_supported_devices = {"LLVM", "CPU", "CUDA", "NV", "AMD", "METAL", "QCOM", "DSP", "DISK"}

- END_FOR_UOP = {Ops.IF:(Ops.STORE, Ops.ENDIF), Ops.RANGE:(Ops.ASSIGN, Ops.ENDRANGE)}
+ # https://en.wikipedia.org/wiki/Identity_element
+ def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
+
+ def can_pad(u:UOp, edges:dict[UOp, None], cache:dict[UOp, None]) -> bool:
+   if u.op in GroupOp.UnsafePad: return False
+   if u in edges or u in cache: return True
+   cache[u] = None
+   return all(can_pad(x.base, edges, cache) for x in u.src)

  # With True as the default, this matches the old symbolic behavior
- def resolve(x, default:bool=True):
-   if not isinstance(x, UOp): return bool(x)
-   assert x.dtype is dtypes.bool, "UOp in resolve must be bool"
+ def resolve(x:UOp|bool, default:bool=True):
+   if isinstance(x, bool): return x
+   assert x.dtype == dtypes.bool, "UOp in resolve must be bool"
    # NOTE: generating the text for the exception is expensive, so we do this
    return bool(sx.vmin) if (sx:=x.simplify()).vmin == sx.vmax else default

  # smax/smin are replacements for max/min that preserve symbolic
  def _suop(lst, uop_fxn, python_fxn):
-   max_uop, max_num = partition(lst, lambda x: isinstance(x, UOp))
-   if len(max_uop): return functools.reduce(uop_fxn, (max_uop + [python_fxn(max_num)]) if len(max_num) else max_uop).ssimplify()
-   return python_fxn(max_num)
- def smax(*lst): return _suop(lst[0] if isinstance(lst[0], (tuple, list)) else lst, UOp.maximum, max)
- def smin(*lst): return _suop(lst[0] if isinstance(lst[0], (tuple, list)) else lst, UOp.minimum, min)
+   uops, nums = partition(lst, lambda x: isinstance(x, UOp))
+   return ssimplify(functools.reduce(uop_fxn, uops + ([python_fxn(nums)] if nums else [])))
+ def smax(*lst): return _suop(argfix(*lst), UOp.maximum, max)
+ def smin(*lst): return _suop(argfix(*lst), UOp.minimum, min)

  def ssimplify(uop): return uop.ssimplify() if isinstance(uop, UOp) else uop
- def sym_infer(uop: Union[UOp, int], var_vals: Dict[UOp, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop
+ def sym_infer(uop: Union[UOp, int], var_vals: dict[UOp, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop

  # used for UOp and UPat
  def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->str:
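A quick sanity check of the reworked `smax`/`smin` (an illustrative sketch, not part of the diff; it assumes `UOp.variable`, the variable constructor defined elsewhere in ops.py):

    from tinygrad.ops import UOp, smax, smin

    v = UOp.variable("v", 2, 8)         # symbolic int with known bounds [2, 8]
    assert smax(3, 5) == 5              # plain numbers fold to a plain number
    assert smin(v, 1) == 1              # bounds prove v >= 2 > 1, so ssimplify folds this
    assert isinstance(smax(v, 4), UOp)  # undecidable from bounds: stays a symbolic MAX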
@@ -219,57 +221,108 @@ def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->str:
    return f"{' '*d}{f'x{cx[0]}:=' * (cx[1]>1)}{rep(x)}" % srcs

  class UOpMetaClass(type):
-   ucache:WeakValueDictionary[Tuple, UOp] = WeakValueDictionary()
-   def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:Tuple[UOp,...]=tuple(), arg:Any=None):
-     if (ret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None: return ret
-     UOpMetaClass.ucache[key] = ret = super().__call__(op, dtype, src, arg)
-     return ret
-
+   ucache:dict[tuple, weakref.ReferenceType[UOp]] = {}
+   def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:tuple[UOp,...]=tuple(), arg:Any=None, _buffer:Buffer|None=None):
+     if (wret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None and (ret:=wret()) is not None: return ret
+     UOpMetaClass.ucache[key] = ref = weakref.ref(created:=super().__call__(*key))
+     for s in src: s.children.add(ref)
+     # NOTE: this will soon be set by Tensor once we remove function.py
+     if (metadata:=_METADATA.get()) is not None: all_metadata[created] = metadata
+     # NOTE: this value is set by pickle when pickling a realized tensor
+     if _buffer is not None:
+       assert op is Ops.BUFFER, f"trying to set Buffer {_buffer} for {op}"
+       buffers[created] = _buffer
+     return created
+
+ # some uops map to other stuff
+ buffers:weakref.WeakKeyDictionary[UOp, Buffer] = weakref.WeakKeyDictionary() # this maps BUFFER uops to their device Buffers
+ all_metadata:weakref.WeakKeyDictionary[UOp, Metadata] = weakref.WeakKeyDictionary()
+
+ # NOTE: this should be frozen, but frozen is slower
+ @dataclass(eq=False, slots=True)
  class UOp(MathTrait, metaclass=UOpMetaClass):
-   __slots__ = ["op", "dtype", "src", "arg"]
-   def __init__(self, op:Ops, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None):
-     # TODO: instant check rules here make debugging easier
-     self.op, self.dtype, self.src, self.arg = op, dtype, src, arg
-   def __reduce__(self): return UOp, (self.op, self.dtype, self.src, self.arg)
+   op:Ops
+   dtype:DType = dtypes.void
+   src:tuple[UOp, ...] = tuple()
+   arg:Any = None
+   children:set[weakref.ref[UOp]] = field(default_factory=set)
+   def __del__(self):
+     if self.op is Ops.BUFFER and (buffer:=buffers.get(self)) is not None: buffer.ref(-1)
+     if (ref:=UOpMetaClass.ucache.get(k:=(self.op, self.dtype, self.src, self.arg))) is not None:
+       for s in self.src: s.children.discard(ref)
+       del UOpMetaClass.ucache[k]
+   def __reduce__(self):
+     args = [self.op, self.dtype, self.src, self.arg]
+     if self.op is Ops.BUFFER and self.realized is not None and PICKLE_BUFFERS: args.append(self.realized)
+     return UOp, tuple(args)
    def replace(self, **kwargs) -> UOp:
-     for k in kwargs: assert k in self.__slots__, f"unkown replace arg, expected one of {self.__slots__}, got {k}"
-     new_args = (kwargs.get("op", self.op), kwargs.get("dtype", self.dtype), kwargs.get("src", self.src), kwargs.get("arg", self.arg))
+     new_args = (kwargs.pop("op", self.op), kwargs.pop("dtype", self.dtype), kwargs.pop("src", self.src), kwargs.pop("arg", self.arg))
+     assert len(kwargs) == 0, f"unused kwargs in replace {list(kwargs)}"
      if (self.op, self.dtype, self.src, self.arg) == new_args: return self
      return UOp(*new_args)
    @functools.cached_property
    def key(self) -> bytes:
      return hashlib.sha256(str((self.op, self.dtype, self.arg)).encode() + b"".join([s.key for s in self.src])).digest()
    def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))")
-   def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is Ops.REDUCE_AXIS else self.arg
-   @functools.cached_property
-   def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents}}
-   @functools.cached_property # parents with self
-   def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None}
+   def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is Ops.REDUCE_AXIS else repr(self.arg)
+
+   @property
+   def toposort(self) -> dict[UOp, None]:
+     def _toposort(u:UOp, cache:set[UOp]):
+       if u in cache: return {}
+       nodes: dict[UOp, None] = {}
+       # NOTE: this is a lot faster than the comprehension in parents
+       for parent in u.src: nodes.update(_toposort(parent, cache))
+       nodes[u] = None
+       cache.add(u)
+       return nodes
+     return _toposort(self, cache=set())

    @functools.cached_property
-   def tuplize(self:UOp) -> Tuple[int, Any, Optional[DType], Tuple]:
-     return (self.op.value, self.arg, self.dtype, tuple(x.tuplize for x in self.src))
+   def tuplize(self:UOp) -> tuple[int, Any, Optional[DType], tuple]: return (self.op.value, self.arg, self.dtype, tuple(x.tuplize for x in self.src))

    # *** uop shape stuff ***

-   @property
-   def has_st(self) -> bool: return self.op not in {Ops.DEFINE_LOCAL, Ops.DEFINE_GLOBAL, Ops.BUFFER, Ops.CONST, Ops.DEFINE_VAR}
    @functools.cached_property
-   def st(self) -> Optional[ShapeTracker]:
-     if not self.has_st: return None
-     if self.op in GroupOp.Buffer: return self.st_arg
-     if self.op is Ops.VIEW: return self.arg
-     src_sts = [x.st for x in self.src if x.st is not None]
-     assert all_same([x.shape for x in src_sts]), f"UOp parents must have the same shape {self} {[x.shape for x in src_sts]}"
+   def st(self) -> ShapeTracker|None:
      from tinygrad.shape.shapetracker import ShapeTracker
-     return ShapeTracker.from_shape(src_sts[0].reduce(self.axis_arg)) if self.op is Ops.REDUCE_AXIS else src_sts[0]
+     if self.op is Ops.MULTI:
+       return ShapeTracker.from_shape(
+         tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape)))
+     if self.op is Ops.BUFFER: return ShapeTracker.from_shape((self.size,))
+     if self.op is Ops.KERNEL: return ShapeTracker.from_shape(self.arg.ast.shape)
+     # these ops define a ShapeTracker from the arg
+     if self.op is Ops.VIEW: return self.arg
+     if self.op in GroupOp.Movement: return unwrap(self.src[0].st).mop(self.op, self.arg)
+     # buffer ops return the ShapeTracker from sources
+     if self.op in GroupOp.Buffer: return vsrc[0] if len(vsrc:=[x.st for x in self.src if x.op is Ops.VIEW]) != 0 else None
+     if not (src_sts := [x.st for x in self.src if x.st is not None]): return None
+     assert all_same([x.shape for x in src_sts]), f"UOp sources must have the same shape {self} {[x.shape for x in src_sts]}"
+     if self.op in {Ops.BITCAST, Ops.BUFFER_VIEW}:
+       shape = src_sts[0].shape
+       if self.dtype.itemsize != (input_sz:=self.src[0].dtype.itemsize): shape = shape[:-1]+((shape[-1]*input_sz) // self.dtype.itemsize,)
+     # only reduce ops are allowed to change shape, everything else derives shape from sources
+     elif self.op in {Ops.REDUCE_AXIS, Ops.WMMA}: shape = src_sts[0].reduce(self.axis_arg)
+     else: shape = src_sts[0].shape
+     return ShapeTracker.from_shape(shape)
+
    @functools.cached_property
-   def full_shape(self) -> Tuple[sint, ...]:
-     return self.arg.shape if self.op is Ops.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
+   def full_shape(self) -> tuple[sint, ...]:
+     if self.op is Ops.VIEW: return self.shape
+     # TODO: this should check if st is None, it cannot because local reduce has implicit movement ops
+     return tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {Ops.DEFINE_GLOBAL,Ops.DEFINE_LOCAL} \
+       # TODO: this exists because wmma creates consts without ShapeTracker in the AST, there's probably a way to fix this
+       and not (x.op is Ops.CONST and x.st is None)]))
+   @property
+   def shape(self) -> tuple[sint, ...]: return unwrap(self.st).shape
+   @property
+   def size(self) -> int: return self.arg if self.op is Ops.BUFFER else unwrap(self.st).size

    # *** uop evaluation ***

    def simplify(self):
+     # late import!
+     from tinygrad.codegen.symbolic import symbolic
      with Context(TRACK_MATCH_STATS=0):
        return graph_rewrite(self, symbolic)
    def ssimplify(self) -> Union[UOp, ConstType]: return ret.arg if (ret:=self.simplify()).op is Ops.CONST else ret
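The ucache change above swaps `WeakValueDictionary` for a plain dict of weakrefs with explicit cleanup in `__del__`, which is what lets `__call__` also register every new node in its sources' `children` sets. A standalone sketch of the underlying hash-consing pattern (a simplified illustration, not tinygrad's exact code):

    import weakref

    class Interned:
      _cache: dict[tuple, weakref.ref] = {}
      def __new__(cls, *key):
        # if a live instance already exists for this key, return it (hash-consing)
        if (ref := cls._cache.get(key)) is not None and (obj := ref()) is not None: return obj
        obj = super().__new__(cls)
        obj._key = key
        cls._cache[key] = weakref.ref(obj)
        return obj
      def __del__(self):
        # the weakref is already cleared by the time __del__ runs; drop the stale slot
        if (ref := self._cache.get(self._key)) is not None and ref() is None: del self._cache[self._key]

    assert Interned(1, 2) is Interned(1, 2)  # same key -> same object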
@@ -282,34 +335,37 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
    def __bool__(self): return self._eval((dtypes.bool,), bool)
    def __int__(self): return self._eval(dtypes.ints, int)
    def __float__(self): return self._eval(dtypes.floats, float)
-   def substitute(self, dvars:Dict[UOp, UOp]):
+   def substitute(self, dvars:dict[UOp, UOp]):
      with Context(TRACK_MATCH_STATS=0):
-       return graph_rewrite(self, _substitute, dvars)
+       return graph_rewrite(self, _substitute, dvars, bottom_up=True)

    # *** uop syntactic sugar ***

    @property
    def st_arg(self) -> ShapeTracker:
      assert self.op in GroupOp.Buffer, f"st_arg called on {self.op}"
-     ret = self.src[0 if self.op is Ops.VALID else 1]
-     assert ret.op is Ops.VIEW, f"st_arg trying to return {ret}"
-     return ret.arg
+     return unwrap(self.st)
    @property
-   def axis_arg(self) -> Tuple[int, ...]:
+   def axis_arg(self) -> tuple[int, ...]:
      assert self.op in {Ops.REDUCE_AXIS, Ops.WMMA}, f"axis_arg called on {self.op}"
      ret = self.arg[1] if self.op is Ops.REDUCE_AXIS else self.arg[7]
      assert isinstance(ret, tuple) and all(isinstance(x, int) for x in ret), f"axis_arg trying to return {ret}"
      return ret
    def sink(self, *srcs:UOp): return UOp(Ops.SINK, dtypes.void, (self,)+srcs)
-   def index(self, idx:UOp, valid:Optional[UOp]=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
-   def const_like(self, b:ConstLike): return UOp.const(self.dtype, b)
+   def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
+   def index(self, idx:UOp, valid:UOp|None=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
+   def const_like(self, b:ConstLike):
+     # constants can optionally have a DEVICE source
+     if self._device is None: return UOp.const(self.dtype, b)
+     if isinstance(self.device, tuple): return UOp.multi(*[UOp.metaop(Ops.CONST, self.shape, self.dtype, d, b) for d in self.device], axis=None)
+     return UOp.metaop(Ops.CONST, self.shape, self.dtype, self.device, b)
    def broadcast(self, count:int):
      assert self.dtype.count == 1
      if count == 1: return self
      return UOp(Ops.VECTORIZE, self.dtype.vec(count), (self,)*count)
-   def cast(self, dtype:DType): return UOp(Ops.CAST, dtype, (self,))
-   def bitcast(self, dtype:DType): return UOp(Ops.BITCAST, dtype, (self,))
-   def gep(self, i:Union[Tuple[int, ...], int]):
+   def cast(self, dtype:DType): return self if self.dtype == dtype else UOp(Ops.CAST, dtype, (self,))
+   def bitcast(self, dtype:DType): return self if self.dtype == dtype else UOp(Ops.BITCAST, dtype, (self,))
+   def gep(self, i:Union[tuple[int, ...], int]):
      if isinstance(i, int):
        # NOTE: these are just shortcuts to not have to create and fold later
        if self.op is Ops.VECTORIZE: return self.src[i]
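A sketch of the new `cast`/`bitcast` short-circuit (illustrative):

    from tinygrad.dtype import dtypes
    from tinygrad.ops import Ops, UOp

    x = UOp.const(dtypes.float32, 1.0)
    assert x.cast(dtypes.float32) is x           # no-op cast returns the same interned node
    assert x.cast(dtypes.int32).op is Ops.CAST   # a real cast still builds a CAST uop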
@@ -322,42 +378,192 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
    def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
    def alu(self, arg, *src:UOp):
      out_dtype = (self, *src)[-1].dtype
-     if arg in {Ops.CMPLT, Ops.CMPNE} and out_dtype is not None:
-       out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
+     if arg in {Ops.CMPLT, Ops.CMPNE}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
      return UOp(arg, out_dtype, (self,)+src)
    @staticmethod
    def const(dtype:DType, b:ConstLike):
      if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
      if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
-     return UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b)
+     return UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype))
+   def valid(self, st:ShapeTracker):
+     assert self.op in {Ops.CONST, Ops.DEFINE_VAR}, f"can only create VALID from a constant, got {self.op}"
+     from tinygrad.shape.shapetracker import ShapeTracker
+     # NOTE: only VALID has a masked ShapeTracker, the CONST operands are unmasked
+     unmasked_st = ShapeTracker.from_shape(()).reshape((1,)*len(st.shape)).expand(st.shape).to_uop()
+     return UOp(Ops.VALID, dtypes.bool, (st.to_uop(),)).where(self.replace(src=(unmasked_st,)), UOp.const(self.dtype, 0).replace(src=(unmasked_st,)))
    @staticmethod
-   def range(dtype:DType, start:ConstType|UOp, end:ConstType|UOp, idx:int):
-     return UOp(Ops.RANGE, dtype=dtype, src=(UOp.const(dtype, start) if not isinstance(start, UOp) else start,
-                                             UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=(idx, False))
-   def r(self, op, axis): return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis))
+   def range(dtype:DType, start:sint, end:sint, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(start), sint_to_uop(end)), arg=idx)
+   def _reduce_op(self, op:Ops, axis:tuple[int, ...]):
+     axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
+     return self if len(axis) == 0 else UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis))
+   def r(self, op:Ops, axis:tuple[int, ...]) -> UOp:
+     new_shape = unwrap(self.st).reduce(axis)
+
+     # TODO: can we split symbolic shape if the reduce axis is not symbolic?
+     # TODO: this shouldn't be here, it belongs in scheduler! that's why it broke multi
+     if not SPLIT_REDUCEOP or isinstance(self._device, tuple) or not all_int(self.shape) or (0 in self.shape) or \
+       prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
+       return self._reduce_op(op, axis)
+
+     # if there are few globals, make some reduces into globals by splitting into two kernels
+     # cap output buffer to 2**22: heuristic number of global outputs to achieve max occupancy with enough locals+upcasts for gemm
+     #   ~2**10 should be enough if GROUP is used
+     # 256 split maximum should be "negligible reduce" for low prod(new_shape), 8 split minimum.
+     # split is moved to the end to provide maximum locality for the second phase reduce.
+     self_real_strides = unwrap(self.st).real_strides(ignore_valid=True)
+     split_candidates = [(i, x) for i in axis for x in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(new_shape)),8-1,-1)
+                         if self.shape[i] % x == 0 and self_real_strides[i] != 0]
+     if not split_candidates: return self._reduce_op(op, axis)
+     dim_to_split, divisor = split_candidates[0]
+     splitted_shape = self.shape[:dim_to_split] + (divisor,) + (self.shape[dim_to_split]//divisor,) + self.shape[dim_to_split+1:]
+     splitted = self.reshape(splitted_shape).permute(tuple([x for x in range(len(splitted_shape)) if x != dim_to_split]+[dim_to_split]))
+     if DEBUG >= 3: print(f"split {divisor}: {self.shape} -> {splitted.shape} -> {new_shape}")
+     return splitted._reduce_op(op, axis)._reduce_op(op, (len(new_shape),)).reshape(new_shape) # reduce original axes, then split
    def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
-   def contiguous(self): return UOp(Ops.CONTIGUOUS, self.dtype, (self,))
+   def contiguous(self): return self.alu(Ops.CONTIGUOUS)
+   def contiguous_backward(self): return self.alu(Ops.CONTIGUOUS_BACKWARD)
+
+   # *** from MultiLazyBuffer ***
+
+   def multi(self, *more:UOp, axis:int|None, real:tuple[bool,...]|None=None):
+     parents = (self,)+more
+     assert all_same([x.dtype for x in parents]), "multi parents must have the same dtype"
+     return UOp(Ops.MULTI, self.dtype, parents, (axis, real if real is not None else (True,)*len(parents)))
+
+   @property
+   def bounds(self):
+     if self.axis is None: raise RuntimeError("bounds is not defined when axis is None")
+     return tuple(itertools.pairwise(itertools.accumulate([lb.shape[self.axis] for lb in self.src], initial=0)))
+
+   @functools.cached_property
+   def axis(self) -> Optional[int]:
+     if self.op is Ops.MULTI: return self.arg[0]
+     # NOTE: they all have to share an axis, we always choose [-1]
+     if self.op in GroupOp.ALU: return axes[-1] if (axes := dedup([x.axis for x in self.src if x.axis is not None])) else None
+     src_axis = self.src[0].axis
+     if self.op is Ops.REDUCE_AXIS: return None if src_axis is not None and src_axis in self.arg[1] else src_axis
+     if self.op is Ops.RESHAPE:
+       if src_axis is None: return None
+       arg_acc:list[sint] = list(itertools.accumulate(self.arg, operator.mul, initial=1))
+       # new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards
+       # TODO: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1?
+       return len(arg_acc) - arg_acc[::-1].index(prod(self.src[0].shape[:src_axis])) - 1
+     if self.op is Ops.PERMUTE: return self.arg.index(src_axis) if src_axis is not None else None
+     return src_axis
+
+   @property
+   def real(self):
+     assert self.op is Ops.MULTI
+     return self.arg[1]
+
    @property
-   def is_contiguous_base(self): return self.op is Ops.CONTIGUOUS and not (self.src[0].base.op is Ops.VIEW and len(self.src[0].base.src) == 2)
+   def real_lbs(self): return [lb for lb,r in zip(self.src, self.real) if r]
+
+   def shard(self, devices:tuple[str, ...], axis:Optional[int]=None) -> UOp:
+     if axis is None: lbs = [self] * len(devices)
+     else:
+       if self.shape[axis] % len(devices) != 0: raise RuntimeError(f"multi axis uneven: {self.shape[axis]=} {axis=} {len(devices)=}")
+       # NOTE: this works for both even shards and uneven shards
+       sz = self.shape[axis] // len(devices)
+       sizes = [max(0, min(sz, self.shape[axis] - sz*i)) for i in range(len(devices))]
+       lbs = []
+       for sz,off in zip(sizes, itertools.accumulate(sizes, initial=0)):
+         lbs.append(self.shrink(tuple((0,s) if i != axis else (off,off+sz) for i,s in enumerate(self.shape))))
+     sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(lbs, devices)]
+     return UOp.multi(*[lb.contiguous() for lb in sharded_lbs], axis=axis)
+
+   # *** from LazyBuffer ***
+
+   @staticmethod
+   def metaop(op:Ops, shape:tuple[sint, ...], dtype:DType, device:str, arg=None) -> UOp:
+     from tinygrad.shape.shapetracker import ShapeTracker
+     # Tensor const is CONST(VIEW(DEVICE)) -> RESHAPE -> EXPAND
+     if op is Ops.CONST:
+       assert isinstance(arg, get_args(ConstType)), f"trying to create CONST with {arg=}"
+       return UOp.const(dtype, unwrap(arg)).replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),),
+                                                             ShapeTracker.from_shape(())),)).reshape((1,)*len(shape)).expand(shape)
+     # Tensor variable binding is BIND(VAR(VIEW(DEVICE)), CONST(VIEW(DEVICE)))
+     if op is Ops.BIND:
+       var, val = arg.unbind()
+       return var.replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),), ShapeTracker.from_shape(shape)),)).bind(val)
+     # otherwise it's just a RESHAPE(BUFFER)
+     if not isinstance(size:=prod([x.vmax if isinstance(x, UOp) else x for x in shape]), int): raise ValueError(f"size must be int {size}")
+     return UOp.new_buffer(device, size, dtype).reshape(shape)
+   def copy_to_device(self, device:str|tuple[str, ...], clone:bool=False) -> UOp:
+     # if it's a shrink, do the shrink before the copy with CONTIGUOUS
+     if prod(self.shape) < prod(self.base.shape): return self.contiguous().copy_to_device(device)
+     # COPY is COPY(DEVICE, copyin.base) -> VIEW(copyin.st)
+     ret = UOp(Ops.COPY, self.base.dtype, (UOp(Ops.DEVICE, arg=device), self.base), clone)
+     op_arg = []
+     mop = self
+     while mop is not self.base:
+       op_arg.append((mop.op, mop.arg))
+       mop = mop.src[0]
+     for op,arg in reversed(op_arg): ret = UOp(op, ret.dtype, (ret,), arg)
+     return ret
+   def clone(self) -> UOp: return self.copy_to_device(self.device, clone=True)
+   @property
+   def metadata(self) -> tuple[Metadata, ...]|Metadata|None: return self.arg.metadata if self.op is Ops.KERNEL else all_metadata.get(self, None)

    # *** uop movement ops ***

    @property
-   def base(self): return self.src[0] if self.op is Ops.VIEW and len(self.src) == 1 else self
-   def view(self, st:ShapeTracker):
-     assert self.op is not Ops.STORE, "VIEW of STORE is invalid, STORE is always base"
-     return self if self.st is None or self.st == st else UOp(Ops.VIEW, self.dtype, (self,), st)
-   def reshape(self, arg:Tuple[sint, ...]): return self.view(unwrap(self.st).reshape(arg))
+   def base(self) -> UOp:
+     if (self.op is Ops.VIEW and len(self.src) != 0) or self.op in GroupOp.Movement: return self.src[0].base
+     return self
+   def view(self, new_st:ShapeTracker) -> UOp: return UOp(Ops.VIEW, self.dtype, (self.base,), new_st)
+
+   def _mop(self, op:Ops, arg):
+     ret = UOp(op, self.dtype, (self,), arg)
+     if self.st == ret.st: return self # ignore NOOPs, also check ret.st
+     return ret

-   # *** uop Buffer stuff ***
+   def reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg)
+   def pad(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.PAD, arg)
+   def expand(self, arg:tuple[sint, ...]): return self._mop(Ops.EXPAND, arg)
+   def permute(self, arg:tuple[sint, ...]): return self._mop(Ops.PERMUTE, arg)
+   def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg)
+   def flip(self, arg:tuple[bool, ...]): return self._mop(Ops.FLIP, arg)
+
+   # *** uop UNIQUE ***

+   # TODO: use this in Buffer
+   unique_num = itertools.count(0)
    @staticmethod
-   def new_buffer(device:str, size:int, dtype:DType, num=-1): return UOp(Ops.BUFFER, dtype.ptr(), (), (num, (device, size, dtype)))
+   def unique(): return UOp(Ops.UNIQUE, arg=next(UOp.unique_num))

+   # *** uop Buffer stuff ***
+
+   @staticmethod
+   def new_buffer(device:str, size:int, dtype:DType): return UOp(Ops.BUFFER, dtype, (UOp(Ops.DEVICE, arg=device), UOp.unique()), size)
+   @property
+   def device(self) -> str|tuple[str, ...]: return cast(str|tuple[str, ...], unwrap(self._device))
+   @functools.cached_property
+   def _device(self) -> Optional[str|tuple[str, ...]]:
+     if self.op is Ops.DEVICE: return self.arg
+     if self.op is Ops.MULTI: return tuple(cast(str, x.device) for x in self.src)
+     return dsrcs[0]._device if len(dsrcs:=[x for x in self.src if x._device is not None]) != 0 else None
    @property
    def buf_uop(self) -> UOp:
-     assert self.op in {*GroupOp.Buffer, Ops.ASSIGN, Ops.VIEW} and self.src[0].op is Ops.BUFFER, f"buf_uop called on {self.op}"
-     return self.src[0]
+     if self.base.op is Ops.BUFFER: return self.base
+     assert self.base.op in {*GroupOp.Buffer, Ops.ASSIGN}, f"buf_uop called on {self.op}"
+     return self.src[0].buf_uop
+   @property
+   def buffer(self) -> Buffer:
+     if self is not self.base:
+       assert unwrap(self.st).contiguous, "VIEW only works here if it's contiguous"
+       return self.src[0].buffer
+     assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}"
+     if (cret:=buffers.get(self)) is not None: return cret
+     from tinygrad.device import Buffer
+     assert isinstance(self.device, str), f"buffer not supported on multi {self.device}"
+     buffers[self] = ret = Buffer(self.device, self.size, self.dtype if isinstance(self.dtype, ImageDType) else self.dtype.base)
+     return ret
+   @property
+   def realized(self) -> Optional[Buffer]: return self.buffer if self.op is Ops.BUFFER else None
+   @property
+   def is_realized(self) -> bool:
+     return all(x.base.realized is not None for x in self.base.real_lbs) if self.base.op is Ops.MULTI else self.base.realized is not None

    # *** uop Variable stuff ***

@@ -373,22 +579,28 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
      assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
      assert self.arg[1] <= val and val <= self.arg[2], f"bind {val} not in range [{self.arg[1]}, {self.arg[2]}]"
      return UOp(Ops.BIND, self.dtype, (self, self.const_like(val)))
-   def unbind(self) -> Tuple[Variable, int]:
+   def unbind(self) -> tuple[Variable, int]:
      assert self.op is Ops.BIND and self.src[0].op is Ops.DEFINE_VAR and self.src[1].op is Ops.CONST, f"can't unbind {self}"
      return self.src[0], self.src[1].arg
    @property
    def val(self) -> int: return self.unbind()[1]
-   def vars(self) -> Set[UOp]:
-     bound_vars = set([x for x in self.sparents if x.op is Ops.BIND and x.src[0].op is Ops.DEFINE_VAR])
+   def vars(self) -> set[UOp]:
+     bound_vars = set([x for x in self.toposort if x.op is Ops.BIND and x.src[0].op is Ops.DEFINE_VAR])
      bound_var_base = set(x.src[0] for x in bound_vars)
-     all_vars = set([x for x in self.sparents if x.op is Ops.DEFINE_VAR])
+     all_vars = set([x for x in self.toposort if x.op is Ops.DEFINE_VAR])
      return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
-   def variables(self) -> List[Variable]:
-     st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in GroupOp.Buffer]
+   def variables(self) -> list[Variable]:
+     st_vars: list[set[Variable]] = [x.st_arg.vars() for x in self.toposort if x.op in GroupOp.Buffer]
      return sorted(set.union(*st_vars, [x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg)

    # *** uop symbolic stuff ***

+   def is_increasing(self:UOp) -> bool:
+     # is f a monotonically increasing function regards its input
+     if self.op in GroupOp.Irreducible: return True
+     if self.op is Ops.ADD: return self.src[0].is_increasing() and self.src[1].is_increasing()
+     if self.op in (Ops.MUL, Ops.IDIV) and self.src[1].op is Ops.CONST and self.src[1].arg >= 0: return self.src[0].is_increasing()
+     return False # False if not sure
    def const_factor(self) -> int:
      """largest known int that divides self"""
      if self.op is Ops.CONST: return self.arg
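`is_increasing` is a new conservative monotonicity check (it answers False when unsure), used by the div/mod folding rules. A sketch of what it accepts (assuming `UOp.variable`, as before):

    from tinygrad.ops import UOp

    i = UOp.variable("i", 0, 15)
    assert (i*4 + 2).is_increasing()      # sum of increasing parts, MUL by a const >= 0
    assert (i//3).is_increasing()         # IDIV by a non-negative const
    assert not (7 - i).is_increasing()    # decreasing in i, conservatively False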
@@ -396,7 +608,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
      if self.op is Ops.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
      if self.op is Ops.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1
      return 1
-   def divides(self, v) -> Optional[UOp]:
+   def divides(self, v:int) -> UOp|None:
      if v==1: return self
      if self.op is Ops.CONST: return self.const_like(self.arg//v) if self.arg%v == 0 else None
      if self.op is Ops.VCONST: return self.const_like(tuple(x//v for x in self.arg)) if all(x%v == 0 for x in self.arg) else None
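`const_factor` and `divides` power exact-division folding in the simplifier; the excerpt above shows only the leading cases, and the recursion continues with ADD/MUL handling not shown here. A sketch of the intended behavior (assuming `UOp.variable` and those unshown branches):

    from tinygrad.ops import UOp

    i = UOp.variable("i", 0, 10)
    e = i*6 + 12
    assert e.const_factor() == 6        # gcd of 6 (from i*6) and 12
    assert e.divides(3) is not None     # term-wise quotient: i*2 + 4
    assert e.divides(5) is None         # 5 doesn't divide the i*6 term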
@@ -410,15 +622,23 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
    @property
    def vmax(self) -> ConstType: return self._min_max[1]
    @functools.cached_property
-   def _min_max(self) -> Tuple[ConstType, ConstType]:
+   def _min_max(self) -> tuple[ConstType, ConstType]:
      if self.op in GroupOp.Binary and not dtypes.is_float(self.dtype):
        (s0_vmin, s0_vmax), (s1_vmin, s1_vmax) = self.src[0]._min_max, self.src[1]._min_max
        if self.op is Ops.ADD: return s0_vmin+s1_vmin, s0_vmax+s1_vmax
        if self.op is Ops.MUL: return min(vals:=(s0_vmin*s1_vmin, s0_vmin*s1_vmax, s0_vmax*s1_vmin, s0_vmax*s1_vmax)), max(vals)
-       if self.op is Ops.MOD and s1_vmin > 0: return 0, s1_vmax-1
-       if self.op is Ops.IDIV and s1_vmin == s1_vmax: # min/max are equal in a CONST
-         if s1_vmin > 0: return s0_vmin//s1_vmin, s0_vmax//s1_vmin
-         if s1_vmin < 0 and s0_vmin >= 0: return -(s0_vmax//-s1_vmin), -(s0_vmin//-s1_vmin)
+       # SHL/SHR on consts only
+       if self.op is Ops.SHL and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] << t[2], t[1] << t[2]
+       if self.op is Ops.SHR and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] >> t[2], t[1] >> t[2]
+       if self.op is Ops.MOD and s1_vmin > 0:
+         return (0, s1_vmax-1) if s0_vmin >= 0 else (-(s1_vmax-1), s1_vmax-1)
+       if self.op is Ops.IDIV:
+         if s1_vmin == s1_vmax: # min/max are equal in a CONST
+           if s1_vmin > 0: return s0_vmin//s1_vmin, s0_vmax//s1_vmin
+           if s1_vmin < 0 and s0_vmin >= 0: return -(s0_vmax//-s1_vmin), -(s0_vmin//-s1_vmin)
+         # don't know exact bounds, but know the sign
+         if (s0_vmax <= 0 and s1_vmin < 0) or (s0_vmin >= 0 and s1_vmin > 0): return 0, dtypes.max(self.dtype)
+         if (s0_vmax <= 0 and s1_vmin > 0) or (s0_vmin >= 0 and s1_vmin < 0): return dtypes.min(self.dtype), 0
        if self.op is Ops.MAX: return max(s0_vmin, s1_vmin), max(s0_vmax, s1_vmax)
        if self.op is Ops.CMPLT: return (s0_vmax<s1_vmin, s0_vmin<s1_vmax)
        if self.op is Ops.CMPNE: return ((s0_vmax < s1_vmin) or (s1_vmax < s0_vmin), not (s0_vmin == s0_vmax == s1_vmin == s1_vmax))
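A worked check of the MOD and constant-divisor IDIV bounds above (assuming `UOp.variable`):

    from tinygrad.ops import UOp

    x = UOp.variable("x", 0, 20)
    assert ((x//4).vmin, (x//4).vmax) == (0, 5)   # positive constant divisor: exact bounds
    assert ((x%4).vmin, (x%4).vmax) == (0, 3)     # positive modulus, non-negative numerator: [0, m-1]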
@@ -431,9 +651,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
      if self.op is Ops.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2]
      if self.op is Ops.RANGE: return self.src[0].vmin, (self.src[1]-1).vmax
      if self.op is Ops.BIND: return self.src[0]._min_max # ignore the bound value
-     if self.op in {Ops.EXPAND, Ops.VECTORIZE}: return min(x.vmin for x in self.src), max(x.vmax for x in self.src)
-     # TODO: UOps.SPECIAL is UOps.DEFINE_VAR
-     if self.op is Ops.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else dtypes.max(self.dtype)
+     if self.op in {Ops.UNROLL, Ops.VECTORIZE}: return min(x.vmin for x in self.src), max(x.vmax for x in self.src)
+     # TODO: Ops.SPECIAL is Ops.DEFINE_VAR
+     if self.op is Ops.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else self.arg[1].vmax
      if self.op is Ops.CONST: return self.arg, self.arg
      if self.op is Ops.VCONST: return (min(self.arg), max(self.arg))
      return dtypes.min(self.dtype), dtypes.max(self.dtype)
@@ -441,11 +661,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
    @functools.cached_property
    def _sym_fxn(self):
      sself = self.simplify()
-     varnames = tuple(x.arg[0] for x in sself.sparents if x.op is Ops.DEFINE_VAR)
+     varnames = tuple(x.arg[0] for x in sself.toposort if x.op is Ops.DEFINE_VAR)
      # TODO: sanitize varnames, or don't use naked eval while staying fast
      return eval("lambda "+','.join(varnames)+": "+sself.render()), varnames # pylint: disable=eval-used

-   def sym_infer(self, var_vals:Dict[UOp, int]):
+   def sym_infer(self, var_vals:dict[UOp, int]):
      fxn, varnames = self._sym_fxn
      return fxn(**{k.arg[0]:v for k,v in var_vals.items() if k.arg[0] in varnames})

@@ -455,26 +675,29 @@ class UOp(MathTrait, metaclass=UOpMetaClass):

  @dataclass(frozen=True)
  class KernelInfo:
+   name: str = "test" # name of the kernel
    local_dims: int = 0 # number of local dimensions (this is remapping RANGE to SPECIAL)
-   upcasted: int = 0 # count that are upcasted (this is remapping RANGE to EXPAND)
+   upcasted: int = 0 # count that are upcasted (this is remapping RANGE to UNROLL)
    dont_use_locals: bool = False # don't use local indexing

- # ***** ops in python *****
+ # ******** ops in python ********
+
+ def safe_exp2(x):
+   try: return 2 ** x
+   except OverflowError: return math.inf

- def hook_overflow(dv, fxn):
-   def wfxn(*args):
-     try: return fxn(*args)
-     except OverflowError: return dv
-   return wfxn
+ def safe_pow(x, y):
+   try: return math.nan if isinstance(p:=pow(x, y), complex) else p
+   except ZeroDivisionError: return math.inf
+   except ValueError: return math.inf if x > 0 else -math.inf

- python_alu: Dict[Ops, Callable] = {
-   Ops.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, Ops.EXP2: hook_overflow(math.inf, lambda x: 2**x),
+ python_alu: dict[Ops, Callable] = {
+   Ops.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, Ops.EXP2: safe_exp2,
    Ops.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, Ops.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
-   Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan,
-   Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul,
-   Ops.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], Ops.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else x*math.inf,
-   Ops.MAX: max, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt, Ops.XOR: operator.xor,
-   Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift,
+   Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, Ops.POW: safe_pow,
+   Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt,
+   Ops.XOR: operator.xor, Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift, Ops.MAX: max,
+   Ops.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], Ops.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else 0,
    Ops.MULACC: lambda x,y,z: (x*y)+z, Ops.WHERE: lambda x,y,z: y if x else z}

  def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
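`safe_pow` backs the new `Ops.POW` entry in `python_alu`; a sketch of the edge cases it normalizes when driven through `exec_alu`:

    import math
    from tinygrad.dtype import dtypes
    from tinygrad.ops import Ops, exec_alu

    assert exec_alu(Ops.POW, dtypes.float32, (2.0, 10.0)) == 1024.0
    assert exec_alu(Ops.POW, dtypes.float32, (0.0, -1.0)) == math.inf    # ZeroDivisionError path
    assert math.isnan(exec_alu(Ops.POW, dtypes.float32, (-1.0, 0.5)))    # complex result -> nan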
@@ -485,63 +708,33 @@ def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):

  # ***** uop helpers *****

- def print_uops(uops:List[UOp]):
+ def print_uops(uops:list[UOp]):
    for i,u in enumerate(uops):
      formatted_parents = [(uops.index(x) if x.op is not Ops.CONST else f"{x.arg}") if x in uops else "--" for x in u.src]
      print(f"{i:4d} {str(u.op):20s}: {str(u.dtype):30s} " f"{str(formatted_parents):32s} {u.arg}")

- def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
-   flops: sint = 0
-   mem: sint = 0
-   mults: sint = 1
-   mult_stack: List[sint] = []
-   dont_count: Set[UOp] = set()
-   if ignore_indexing:
-     for u in uops:
-       if u.op in {Ops.LOAD, Ops.STORE}:
-         dont_count = dont_count.union(u.src[0].sparents)
-         if len(u.src) > 2: dont_count = dont_count.union(u.src[2].sparents)
-       elif u.op is Ops.IF:
-         dont_count = dont_count.union(u.src[0].sparents)
-   for u in uops:
-     if u.op is Ops.RANGE:
-       mult_stack.append(mults)
-       mults *= (u.src[1] - u.src[0]).ssimplify()
-     elif u.op is Ops.ENDRANGE:
-       mults = mult_stack.pop(-1)
-     elif u.op is Ops.SPECIAL:
-       mults *= u.arg[1] # NOTE: we don't push to the mult_stack here, you can't end these
-     elif u.op is Ops.LOAD:
-       mem += u.dtype.itemsize * mults
-     elif u.op is Ops.STORE:
-       mem += u.src[1].dtype.itemsize * mults
-     elif u.op in GroupOp.ALU and u not in dont_count:
-       flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
-     elif u.op is Ops.WMMA and u not in dont_count:
-       flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
-   return flops, mem
-
  # ***** pattern matcher *****

- def get_location() -> Tuple[str, int]:
+ def get_location() -> tuple[str, int]:
    frm = sys._getframe(1)
    # find the real frame in the file that has the UPat, TODO: is there a better way to do this?
-   while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "uopgraph.py", "schedule.py",
-                                                                                         "lowerer.py", "cstyle.py"}:
+   while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "rewriter.py", "schedule.py", "multi.py",
+                                                                                         "symbolic.py", "expander.py", "lowerer.py", "cstyle.py",
+                                                                                         "linearize.py"}:
      frm = frm.f_back
    return frm.f_code.co_filename, frm.f_lineno
  @functools.lru_cache(None)
- def lines(fn) -> List[str]:
+ def lines(fn) -> list[str]:
    with open(fn) as f: return f.readlines()

  class UPat(MathTrait):
-   __slots__ = ["op", "dtype", "arg", "name", "src"]
-   def __init__(self, op:Optional[Union[Ops, Tuple[Ops, ...], Set[Ops]]]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None,
-                src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None, arg:Any=None,
-                name:Optional[str]=None, allow_any_len:bool=False, location=None, custom_early_reject:Optional[Set[Ops]]=None):
+   __slots__ = ("op", "dtype", "arg", "name", "src")
+   def __init__(self, op:Optional[Union[Ops, tuple[Ops, ...], set[Ops]]]=None, dtype:Optional[Union[DType, tuple[DType, ...]]]=None,
+                src:Optional[Union[tuple[UPat, ...], list[UPat], UPat]]=None, arg:Any=None,
+                name:Optional[str]=None, allow_any_len:bool=False, location=None, custom_early_reject:Optional[set[Ops]]=None):
      assert op is None or isinstance(op, Ops) or isinstance(op, tuple) or isinstance(op, set), "op must be Ops or tuple of Ops"
-     self.op: Optional[Tuple[Ops, ...]] = (op,) if isinstance(op, Ops) else (tuple(op) if isinstance(op, set) else op)
-     self.dtype: Optional[Tuple[DType, ...]] = (dtype,) if isinstance(dtype, DType) else dtype
+     self.op: Optional[tuple[Ops, ...]] = (op,) if isinstance(op, Ops) else (tuple(op) if isinstance(op, set) else op)
+     self.dtype: Optional[tuple[DType, ...]] = (dtype,) if isinstance(dtype, DType) else dtype
      self.arg, self.name, self._in_src, self.custom_early_reject = arg, name, src, custom_early_reject
      self.src: Any = None
      assert self.name != "ctx", "UPat can't be named ctx"
@@ -568,13 +761,12 @@ class UPat(MathTrait):
 
    @staticmethod
    @functools.lru_cache(None)
-   def var(name:Optional[str]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None): return UPat(dtype=dtype, name=name)
+   def var(name:Optional[str]=None, dtype:Optional[Union[DType, tuple[DType, ...]]]=None): return UPat(dtype=dtype, name=name)
    @staticmethod
    @functools.lru_cache(None)
-   def cvar(name:Optional[str]=None, dtype:Optional[DType]=None, vec=True):
-     return UPat((Ops.CONST, Ops.VCONST) if vec else Ops.CONST, dtype=dtype, name=name)
+   def cvar(name:Optional[str]=None, dtype:Optional[DType]=None, vec=True): return UPat((Ops.CONST,Ops.VCONST) if vec else Ops.CONST, dtype, name=name)
    @staticmethod
-   def const(dtype:Optional[Union[DType, Tuple[DType, ...]]], b:ConstType): return UPat(Ops.CONST, dtype=dtype, arg=b)
+   def const(dtype:Optional[Union[DType, tuple[DType, ...]]], b:ConstType): return UPat(Ops.CONST, dtype=dtype, arg=b)
 
    # copied from UOp
    def index(self, idx:UPat, valid:Optional[UPat]=None): return UPat(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
@@ -589,7 +781,7 @@ class UPat(MathTrait):
    def const_like(self, b:ConstLike): return UPat.const(self.dtype, cast(ConstType, b))
    def alu(self, op:Ops, *src:UPat):
      asrc = (self,)+src
-     return UPat(op, None if op in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].dtype, list(asrc) if op in GroupOp.Commutative else asrc)
+     return UPat(op, dtypes.bool if op in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].dtype, list(asrc) if op in GroupOp.Commutative else asrc)
 
    def printable(self:UPat) -> str:
      try: return lines(self.location[0])[self.location[1]-1].strip()
@@ -602,14 +794,14 @@ class UPat(MathTrait):
          set(x.dtype) if x.dtype else None, x.allowed_len == 0, "[%s]" if x.src and len(x.src)>1 else "(%s)")
      return pretty_print(self, rep, srcfn=lambda x:None if x.src is None else [next(x.src[0])] if isinstance(x.src[0], itertools.repeat) else x.src[0])
 
-   def match(self:UPat, uop:UOp, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
+   def match(self:UPat, uop:UOp, store:dict[str, UOp]) -> list[dict[str, UOp]]:
      if (self.op is not None and uop.op not in self.op) or \
        (self.name is not None and store.setdefault(self.name, uop) is not uop) or \
        (self.dtype is not None and uop.dtype not in self.dtype and uop.dtype.scalar() not in self.dtype) or \
        (self.arg is not None and self.arg != uop.arg) or \
        (self.allowed_len != -1 and len(uop.src) != self.allowed_len): return []
      if self.src is None: return [store]
-     res: List[Dict[str, UOp]] = []
+     res: list[dict[str, UOp]] = []
      for vp in self.src:
        stores, new_stores = [store.copy()], []
        for uu, vv in zip(uop.src, vp):
@@ -619,13 +811,11 @@ class UPat(MathTrait):
      return res
 
  class UPatAny(UPat):
-   def match(self:UPat, uop:UOp, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
-     ret = []
-     for x in self.src[0]:
-       if (match:=x.match(uop, store.copy())): ret.extend(match)
-     return ret
+   def match(self:UPat, uop:UOp, store:dict[str, UOp]) -> list[dict[str, UOp]]:
+     matches = [x.match(uop, store.copy()) for x in self.src[0]]
+     return flatten([x for x in matches if x is not None])
 
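For readers following the UPat changes above: a UPat is a structural template over UOps, and match returns a list of name-to-UOp binding dicts, empty when nothing lines up. A minimal sketch of the matching contract, not taken from this diff:

    from tinygrad.ops import UOp, UPat, Ops
    from tinygrad.dtype import dtypes

    c = UOp.const(dtypes.int, 3)
    pat = UPat(Ops.CONST, name="c")            # match any CONST and bind it under "c"
    assert pat.match(c, {}) == [{"c": c}]      # one store: the matched UOp under its name
    assert UPat(Ops.ADD, name="a").match(c, {}) == []  # op mismatch: no bindings
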
- def deconstruct_function(fxn:Callable) -> Tuple:
+ def deconstruct_function(fxn:Callable) -> tuple:
    new_globals = {k:v for k,v in fxn.__globals__.items() if k in fxn.__code__.co_names}
    for co in fxn.__code__.co_consts:
      if isinstance(co, types.CodeType): new_globals.update({k:v for k,v in fxn.__globals__.items() if k in co.co_names})
@@ -635,10 +825,10 @@ def deconstruct_function(fxn:Callable) -> Tuple:
    return pickle.loads(pickle.dumps(ret)) if getenv("TEST_PICKLE") else ret
 
  class PatternMatcher:
-   def __init__(self, patterns:List[Tuple[UPat, Callable]]):
+   def __init__(self, patterns:list[tuple[UPat, Callable]]):
      self.patterns = patterns
      # NOTE: use of DefaultDict here is very dangerous! all keys will live for the lifetime of the PatternMatcher!
-     self.pdict: Dict[Ops, List[Tuple[UPat, Callable, Set, bool]]] = {}
+     self.pdict: dict[Ops, list[tuple[UPat, Callable, set, bool]]] = {}
      # uop is required, arg is optional
      for p,fxn in self.patterns:
        assert p.op is not None
@@ -651,7 +841,7 @@ class PatternMatcher:
    @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
    def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns)
 
-   def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]:
+   def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
      ler = {u.op for u in uop.src}
      for p,fxn,early_reject,has_ctx in self.pdict.get(uop.op, []):
        if not early_reject.issubset(ler): continue
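To make the rewrite loop above concrete: a PatternMatcher holds (UPat, Callable) pairs indexed by op, early-rejects on the ops of uop.src, and returns the first non-None result. A minimal sketch, with illustrative names not taken from this diff:

    from tinygrad.ops import UOp, UPat, PatternMatcher
    from tinygrad.dtype import dtypes

    # one rule: x+0 -> x, in the spirit of the symbolic identity folds
    pm = PatternMatcher([(UPat.var("x") + 0, lambda x: x)])
    x = UOp.variable("x", 0, 10, dtypes.int)
    assert pm.rewrite(x + 0) is x   # matched: the callable receives the named binding
    assert pm.rewrite(x) is None    # no pattern under DEFINE_VAR: caller keeps the node
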
@@ -662,39 +852,34 @@ class PatternMatcher:
  # *** tracking pattern matcher ***
 
  TRACK_MATCH_STATS = ContextVar("TRACK_MATCH_STATS", 2 if getenv("VIZ") else 0)
- match_stats:Dict[UPat, List[Union[int, float]]] = dict()
+ match_stats:dict[UPat, list[Union[int, float]]] = dict()
  @dataclass(frozen=True)
- class TrackedRewriteContext:
-   loc: Tuple[str, int] # location that called graph_rewrite
-   sink: UOp # the sink passed into the rewrite
-   matches: List[Tuple[UOp, Optional[UOp], Optional[UPat], float]] = field(default_factory=list) # all matches of sparents
-
- rewrite_stack: List[Tuple[Any, List[TrackedRewriteContext]]] = []
- contexts: List[Tuple[Any, List[TrackedRewriteContext]]] = []
- _rewrite_cnt: Dict[str, int] = {}
+ class TrackedGraphRewrite:
+   loc: tuple[str, int] # location that called graph_rewrite
+   sink: UOp # the sink input to graph_rewrite
+   bottom_up: bool
+   matches: list[tuple[UOp, UOp, UPat]] = field(default_factory=list) # before+after of all the matches
+   name: Optional[str] = None
+ tracked_keys:list[Any] = []
+ tracked_ctxs:list[list[TrackedGraphRewrite]] = []
+ _name_cnt:dict[str, int] = {}
  def track_rewrites(named=False):
    def _decorator(func):
      def __wrapper(self, *args, **kwargs):
        if TRACK_MATCH_STATS >= 2:
-         if named: _rewrite_cnt[func.__name__] = _rewrite_cnt.setdefault(func.__name__, 0)+1
-         rewrite_stack.append((f"{(n:=func.__name__)}_{_rewrite_cnt[n]}" if named else self, []))
-       try: ret = func(self, *args, **kwargs)
-       finally: # NOTE: save everything in the stack
-         if TRACK_MATCH_STATS >= 2: contexts.append(rewrite_stack.pop())
-       return ret
+         if named: _name_cnt[func.__name__] = _name_cnt.get(func.__name__, 0)+1
+         tracked_keys.append(f"{func.__name__}_{_name_cnt[func.__name__]}" if named else self)
+         tracked_ctxs.append([])
+       return func(self, *args, **kwargs)
      return __wrapper
    return _decorator
 
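A quick orientation on the decorator above: it is written for methods (note the self parameter in __wrapper). With TRACK_MATCH_STATS >= 2, each decorated call pushes a key (the object itself, or "name_N" when named=True) plus a fresh list that subsequent graph_rewrite calls append their TrackedGraphRewrite records to. A hypothetical use, with illustrative names:

    from tinygrad.ops import UOp, UPat, PatternMatcher, graph_rewrite, track_rewrites
    from tinygrad.dtype import dtypes

    pm = PatternMatcher([(UPat.var("x") * 1, lambda x: x)])

    class Toy:  # illustrative class, not from this diff
      @track_rewrites(named=True)  # traces keyed "run_1", "run_2", ... when tracking is on
      def run(self, sink:UOp) -> UOp: return graph_rewrite(sink, pm)

    x = UOp.variable("x", 0, 10, dtypes.int)
    assert Toy().run(x * 1) is x  # with tracking off the wrapper just calls through
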
  class TrackedPatternMatcher(PatternMatcher):
-   def __init__(self, patterns:List[Tuple[UPat, Callable]]):
-     super().__init__(patterns)
-     for p,_ in self.patterns:
-       if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]
-
-   def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]:
+   def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
      ret = None
      ler = {u.op for u in uop.src}
      for p,fxn,early_reject,has_ctx in self.pdict.get(uop.op, []):
+       if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]
        st = time.perf_counter()
        if not early_reject.issubset(ler):
          match_stats[p][2] += time.perf_counter()-st
@@ -705,10 +890,9 @@ class TrackedPatternMatcher(PatternMatcher):
          match_stats[p][0] += 1
          match_stats[p][3] += (et:=time.perf_counter()-st)
          if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", p.printable())
-         if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0 and isinstance(ret, UOp): rewrite_stack[-1][1][-1].matches.append((uop, ret, p, et))
+         if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp) and len(tracked_ctxs) != 0: tracked_ctxs[-1][-1].matches.append((uop, ret, p))
          return ret # NOTE: if it returns None, we keep trying to match
        match_stats[p][2] += time.perf_counter()-st
-     if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0: rewrite_stack[-1][1][-1].matches.append((uop, ret, None, 0))
      return None
 
  if TRACK_MATCH_STATS:
@@ -717,12 +901,10 @@ if TRACK_MATCH_STATS:
    @atexit.register
    def print_match_stats():
      if TRACK_MATCH_STATS >= 2:
-       with open(fn:=temp("rewrites.pkl"), "wb") as f:
-         print(f"rewrote {len(contexts)} graphs and matched {sum(len(r.matches) for _,x in contexts for r in x)} times, saved to {fn}")
-         pickle.dump(contexts, f)
-       if getenv("VIZ"):
-         os.environ["VIZ"] = "0"
-         os.execv(sys.executable, [sys.executable] + [os.path.join(os.path.dirname(__file__), ".", "viz", "serve.py"), temp("rewrites.pkl")])
+       with open(fn:=temp("rewrites.pkl", append_user=True), "wb") as f:
+         print(f"rewrote {len(tracked_ctxs)} graphs and matched {sum(len(r.matches) for x in tracked_ctxs for r in x)} times, saved to {fn}")
+         with Context(PICKLE_BUFFERS=0): pickle.dump((tracked_keys, tracked_ctxs), f)
+       if getenv("VIZ"): launch_viz("VIZ", temp("rewrites.pkl", append_user=True))
      if getenv("PRINT_MATCH_STATS", 1):
        ret = [0,0,0.0,0.0]
        for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]+x[1][3]):
@@ -731,401 +913,47 @@ if TRACK_MATCH_STATS:
          ret = [x+y for x,y in zip(ret, v)]
        print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[3]*1000.:9.2f} / {(ret[2]+ret[3])*1000.:9.2f} ms -- TOTAL")
 
+ def launch_viz(env_str:str, data:str):
+   os.environ[env_str] = "0"
+   os.environ[f"{env_str}_DATA"] = data
+   if not int(os.getenv("VIZ", "0")) and not int(os.getenv("PROFILE", "0")):
+     args = ['--kernels', getenv("VIZ_DATA", "")] if getenv("VIZ_DATA", "") else []
+     args += ['--profile', getenv("PROFILE_DATA", "")] if getenv("PROFILE_DATA", "") else []
+     os.execv(sys.executable, [sys.executable] + [os.path.join(os.path.dirname(__file__), ".", "viz", "serve.py")] + args)
+
  # *** simple graph rewrite engine ***
 
  class RewriteContext:
-   def __init__(self, pm, ctx):
+   def __init__(self, pm, ctx=None):
      self.pm: PatternMatcher = pm
      self.ctx = ctx
-     self.replace: Dict[UOp, UOp] = {}
-   def rewrite(self, n:UOp) -> UOp:
+     self.replace: dict[UOp, UOp] = {}
+   def top_down_rewrite(self, n:UOp) -> UOp:
      if (rn := self.replace.get(n)) is not None: return rn
-     new_src = tuple(map(self.rewrite, n.src))
+     new_src = tuple([self.top_down_rewrite(x) for x in n.src])
      new_n = self.pm.rewrite(n, self.ctx) if new_src == n.src else UOp(n.op, n.dtype, new_src, n.arg)
-     self.replace[n] = ret = n if new_n is None else self.rewrite(new_n)
+     self.replace[n] = ret = n if new_n is None else self.top_down_rewrite(new_n)
+     return ret
+   def bottom_up_rewrite(self, n:UOp) -> UOp:
+     if (rn := self.replace.get(n)) is not None: return rn
+     new_n: UOp|None = n
+     while new_n is not None: last_n, new_n = new_n, self.pm.rewrite(new_n, self.ctx)
+     new_src = tuple([self.bottom_up_rewrite(x) for x in last_n.src])
+     self.replace[n] = ret = last_n if new_src == last_n.src else self.bottom_up_rewrite(UOp(last_n.op, last_n.dtype, new_src, last_n.arg))
      return ret
 
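A note on the two traversals added here: top_down_rewrite recurses into sources first and applies the matcher on the way back up, re-entering itself until no rule fires, while bottom_up_rewrite drives the matcher to a fixed point at each node (the while loop) before descending into the surviving node's sources. Both memoize through self.replace, so shared subgraphs are rewritten once. A small illustration, assuming only the public behavior:

    from tinygrad.ops import UOp, UPat, PatternMatcher, graph_rewrite
    from tinygrad.dtype import dtypes

    pm = PatternMatcher([(UPat.var("x") + 0, lambda x: x)])
    x = UOp.variable("x", 0, 10, dtypes.int)
    sink = (x + 0) + 0
    assert graph_rewrite(sink, pm) is x                  # chained folds reach a fixed point
    assert graph_rewrite(sink, pm, bottom_up=True) is x  # same result via the other order
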
- def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp:
-   if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0:
-     rewrite_stack[-1][1].append(TrackedRewriteContext(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink))
-   return RewriteContext(pm, ctx).rewrite(sink)
-
- # ***** uop type spec *****
-
- # this is the matcher for the final rendered UOps
- # matcher functions return True or False (or None to not match)
- spec = PatternMatcher([
-   (UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and not x.dtype.local),
-   (UPat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.local),
-   (UPat(Ops.DEFINE_ACC, src=(UPat.var("c"),), name="x", allow_any_len=True),
-    lambda x,c: all(y.op is Ops.RANGE for y in x.src[1:]) and c.dtype == x.dtype),
-   (UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
-
-   (UPat(Ops.RANGE, src=(UPat(name="x"), UPat(name="y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype),
-   (UPat(Ops.SPECIAL, src=()), lambda: True),
-
-   # TODO: confirm the args of both of these are shapetrackers
-   (UPat(Ops.VIEW, dtypes.void, src=()), lambda: True),
-   (UPat(Ops.VIEW, src=(UPat.var("src"),), name="x"), lambda x,src: src.op is not Ops.STORE and x.dtype == src.dtype),
-
-   (UPat(Ops.VALID, dtypes.bool, (UPat(Ops.VIEW),)), lambda: True),
-   (UPat(Ops.CONST, name="x"), lambda x: x.dtype == x.dtype.scalar() and (type(x.arg) is type(dtypes.as_const(x.arg, x.dtype)))),
-
-   # early LOAD has a <buf, shapetracker, store?>
-   (UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW))), lambda: True),
-   (UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat(Ops.STORE))), lambda: True),
-
-   # early STORE has a <buf, shapetracker, val>
-   (UPat(Ops.STORE, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat())), lambda: True),
-
-   # **** new style load/store ****
-
-   # INDEX is used in new style load/store
-   (UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat())), lambda: True),
-
-   # LOAD takes a <bufidx, alt?, gate?, barrier?>
-   (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)),)), lambda: True),
-   (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat((Ops.IF, Ops.BARRIER)))), lambda: True),
-   (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(name="alt"), UPat(dtype=dtypes.bool)), name="ld"), lambda ld,alt: ld.dtype == alt.dtype),
-
-   # STORE takes a <bufidx, val, gate?>
-   (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat())), lambda: True),
-   (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
-   (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(), UPat(Ops.IF))), lambda: True),
-
-   # most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
-   (UPat(Ops.WHERE, name="w", src=(UPat(dtype=dtypes.bool), UPat(name="x"), UPat(name="y"))), lambda w,x,y: w.dtype == x.dtype == y.dtype),
-   (UPat((Ops.CMPLT, Ops.CMPNE), dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y"))), lambda x,y: x.dtype == y.dtype),
-   # and SHL/SHR, the shift distance can be an int
-   (UPat((Ops.SHL, Ops.SHR), src=(UPat(name="x"), UPat(name="y")), name="a"), lambda a,x,y: a.dtype == x.dtype and y.dtype in (x.dtype, dtypes.uint)),
-   (UPat(Ops.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
-   (UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)),
-
-   (UPat(Ops.ASSIGN, src=(UPat((Ops.DEFINE_ACC, Ops.DEFINE_GLOBAL)), UPat())), lambda: True),
-   (UPat(Ops.ENDRANGE, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True),
-
-   # all WMMA has 3 args, <x, w, acc>
-   (UPat(Ops.WMMA, src=(UPat(), UPat(), UPat())), lambda: True),
-   (UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
-   (UPat(Ops.EXPAND, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
-
-   # if has a <gate, barrier?>
-   (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(),)), lambda: True),
-   (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(), UPat(Ops.BARRIER))), lambda: True),
-   (UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),
-
-   (UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}),
-   (UPat(Ops.GEP, src=(UPat(name="src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
-   (UPat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.count and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
-   (UPat((Ops.BITCAST, Ops.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None),
-   (UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local
-
-   # NOTE: for testing, we let sinks be anything
-   #(UPat(UOps.SINK, src=UPat(UOps.STORE)), lambda: True),
-   (UPat(Ops.SINK, dtypes.void), lambda: True),
-   (UPat(Ops.NOOP), lambda: True),
-
-   # PTX LOAD/STORE
-   (UPat((Ops.LOAD, Ops.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
-   (UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, src=(UPat(dtype=dtypes.int64),), allow_any_len=True)), lambda: True),
- ])
-
- def type_verify(uops:List[UOp]):
-   for i,u in enumerate(uops):
-     if not spec.rewrite(u):
-       print_uops(uops)
-       raise RuntimeError(f"UOp verification failed at {i} on {u.op} {u.dtype} {len(u.src)} {[x.op for x in u.src]} {u.arg}")
-
- # *** uop helpers ***
-
- def cast_float_to_bf16(x: UOp) -> UOp:
-   assert x.dtype == dtypes.float, "cast float -> bf16 must start with float"
-   x = x.bitcast(dtypes.uint)
-   x = (-x & 0x7f800000).where(x + ((x >> 16) & 1) + 0x7fff, (x & 0xffff).where((x | 0x10000), x))
-   return (x >> 16).cast(dtypes.ushort).bitcast(dtypes.bfloat16)
-
- # *** most of symbolic lives here now ***
-
- def split_uop(x:UOp, sep:Ops):
-   if x.op is sep:
-     for s in x.src: yield from split_uop(s, sep)
-   else: yield x
-
- def mod_folding(x:UOp, c:int) -> Optional[UOp]:
-   # simplify x % c, None means no change
-
-   # simple cancel mod case
-   if 0 < c and 0 <= x.vmin and (quotient:=x.vmin//c) == x.vmax//c: return x-quotient*c
-
-   remainder, something_changed = [], False
-   for u in split_uop(x, Ops.ADD):
-     if (factor:=u.const_factor())%c != factor:
-       divides = u.divides(factor)*(factor%c)
-       assert divides is not None
-       remainder.append(divides)
-       something_changed = True
-     elif u.op is Ops.MOD and (s1:=u.src[1]).op is Ops.CONST and s1.arg%c == 0:
-       remainder.append(u.src[0])
-       something_changed = True
-     else: remainder.append(u)
-   if not something_changed: return None
-   return functools.reduce(operator.add, remainder)%c if remainder else x.const_like(0)
-
- def div_folding(x:UOp, c:int) -> Optional[UOp]:
-   # simplify x // c, None means no change
-
-   # simple cancel div case
-   if 0 <= x.vmin and x.vmax < c: return x.const_like(0)
-
-   quotient, remainder, rem_const, something_changed, gcd, divisor = [], [], 0, False, c, 1
-   for u in split_uop(x, Ops.ADD):
-     if u.op is Ops.CONST:
-       # add all const together first
-       if rem_const != 0: something_changed = True
-       rem_const += u.arg
-     elif (factor:=u.const_factor())%c == 0:
-       if factor:
-         divides = u.divides(c)
-         assert divides is not None
-         quotient.append(divides)
-       something_changed = True
-     else:
-       # divisor is the smallest common divisor of all MULs
-       if u.op is Ops.MUL and factor > 1 and c % factor == 0 and (divisor == 1 or divisor > factor): divisor = factor
-       remainder.append(u)
-       gcd = math.gcd(gcd, factor)
-
-   # handle the const
-   if rem_const%c != rem_const:
-     something_changed = True
-     quotient.append(x.const_like(rem_const//c))
-     rem_const = rem_const%c
-   if rem_const != 0: remainder.append(x.const_like(rem_const))
-
-   # x // c -> quotient + (remainder // div) // (c // div)
-   div = gcd if gcd > 1 else divisor
-
-   if not something_changed: return newx//(c//div) if 1 < div < c and (newx:=div_folding(x, div)) is not None else None
-   rem:Optional[UOp] = functools.reduce(operator.add, remainder) if remainder else None
-   quo:Optional[UOp] = functools.reduce(operator.add, quotient) if quotient else None
-   if quo is None: return x.const_like(0) if rem is None else cast(UOp, div_folding(rem, div))//(c//div)
-   return quo if rem is None else cast(UOp, div_folding(rem, div))//(c//div)+quo
-
- def lt_folding(x:UOp, c:int) -> Optional[UOp]:
-   p, np = partition(split_uop(x, Ops.ADD), lambda u: u.const_factor() == 1)
-   if np and (d:=math.gcd(*[u.const_factor() for u in np], c)) > 1 and 0 <= sum(u.vmin for u in p) and sum(u.vmax for u in p) < d:
-     return cast(UOp, functools.reduce(operator.add, np).divides(d)).lt(c//d)
-   return None
-
- def fold_unrolled_divs(divs:UOp):
-   # div pattern in unrolled arange
-   # example: (x//4 + (x+1)//4 + (x+2)//4 + (x+3)//4) -> x
-   add_chain, denominator, seen_const, ans = list(split_uop(divs, Ops.ADD)), None, [], None
-   for u in add_chain:
-     if not (u.op is Ops.IDIV and u.src[1].op is Ops.CONST): return None
-     if denominator is None: denominator = u.src[1].arg
-     if denominator != u.src[1].arg: return None
-     # assumed CONST is the last of an ADD
-     if (s0:=u.src[0]).op is Ops.ADD and s0.src[1].op is Ops.CONST and s0.src[1].op is Ops.CONST:
-       seen_const.append(s0.src[1].arg)
-       s0 = s0.src[0]
-     else: seen_const.append(0)
-     if ans is None: ans = s0
-     if ans is not s0: return None
-   if denominator is None: return None
-   # the first (denominator-len(seen_const)) terms may have been folded to 0 already
-   for i in range(denominator-len(seen_const)):
-     if ans is not None and 0 <= ans.vmin and ans.vmax + i < denominator: seen_const.append(i)
-   return ans if ans is not None and sorted(seen_const)==list(range(denominator)) else None
-
- def canonicalize_simplex(X:UOp) -> Optional[UOp]:
-   # (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints.
-   # returns x0 + x1 + ... in such case, or None if not
-   changed, ret = False, []
-   for u in split_uop(X, Ops.ADD):
-     # assumed the const is the last src of MUL
-     if u.op is Ops.MUL and u.src[1].op is Ops.CONST and u.src[1].arg > 0:
-       changed = True
-       u = u.src[0]
-     if not (u.op in GroupOp.Irreducible and u.vmin >= 0): return None
-     ret.append(u)
-   return functools.reduce(operator.add, ret) if changed else None
-
- def is_increasing(f:UOp) -> bool:
-   # is f a monotonically increasing function with regard to its input
-   if f.op in GroupOp.Irreducible: return True
-   if f.op is Ops.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1])
-   if f.op in (Ops.MUL, Ops.IDIV) and f.src[1].op is Ops.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0])
-   return False # False if not sure
-
- def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]:
-   # if it's X <= c, returns X, True, c
-   # if it's X >= c, returns X, False, c
-
-   # (X < c).ne(True) -> X >= c
-   if valid.op is Ops.CMPNE and valid.src[1].op is Ops.CONST and valid.src[1].arg == 1 and \
-     (s0:=valid.src[0]).op is Ops.CMPLT and s0.src[1].op is Ops.CONST: return s0.src[0], False, s0.src[1].arg
-   # X < c -> X <= c-1
-   if valid.op is Ops.CMPLT and valid.src[1].op is Ops.CONST: return valid.src[0], True, valid.src[1].arg-1
-   raise ValueError(f"not able to parse {valid=}")
-
- def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
-   # return None if valid is always False, otherwise the simplified uop (might be the same as input)
-
-   # first, parse valid into {expr: (lower_bound, upper_bound)}
-   bounds:DefaultDict[UOp, List[Optional[ConstType]]] = defaultdict(lambda: [None, None])
-   for stmt in split_uop(valid, Ops.AND):
-     try: expr, is_upper, c = parse_valid(stmt)
-     except ValueError: return uop # give up if we cannot parse the valid
-     bounds[expr][int(is_upper)] = c
-
-   # simplify uop given that valid is True
-   for expr,v in bounds.items():
-     # some expr has lower bound > upper bound -> valid is an empty set and we return None
-     if v[0] is not None and v[1] is not None and v[0] > v[1]: return None
-
-     # every candidate is a set of constrained UOps based on valid, and if every item in a set simplifies the uop into the same output, we rewrite uop
-     candidates = []
-     if expr.op is Ops.ADD and v[0] == 1 and all(u.op in GroupOp.Irreducible for u in split_uop(expr, Ops.ADD)):
-       # if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
-       candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, Ops.ADD)])
-     # try checking the whole clause
-     if expr in uop.sparents:
-       candidates.append([(expr, UOp.variable("fake", expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1], expr.dtype))])
-
-     for candidate in candidates:
-       # if every branch in candidate gives the same simplified uop, we can rewrite the uop
-       newuops = [uop.substitute({X:newX}).simplify().substitute({newX:X}).simplify() for X,newX in candidate]
-       if uop.op is Ops.VECTORIZE and len(uop.src) == 2:
-         if all_same([uops.src[0] for uops in newuops]): uop = uop.replace(src=(newuops[0].src[0], uop.src[1]))
-         if all_same([uops.src[1] for uops in newuops]): uop = uop.replace(src=(uop.src[0], newuops[0].src[1]))
-       elif all_same(newuops): uop = newuops[0]
-
-   return uop
-
- def _valid_priority(v: UOp, valids:List[UOp]):
-   # we want valid that's in other valids' parents to be first, so it's more likely the other valids get simplified
-   try: return sum(-1 if parse_valid(v)[0] in other.parents else 0 for other in valids)
-   except ValueError: return 0
-
- def simplify_valid(valid:UOp) -> Optional[UOp]:
-   ret:List[UOp] = []
-   something_changed = False
-   valids = list(split_uop(valid, Ops.AND))
-   for stmt in sorted(valids, key=lambda v: _valid_priority(v, valids)):
-     ret.append(newstmt if ret and (newstmt:=uop_given_valid(functools.reduce(operator.and_, ret), stmt)) is not None else stmt)
-     if ret[-1] is not stmt: something_changed = True
-   return functools.reduce(operator.and_, ret) if something_changed else None
-
- def max_var_const(x:UOp, c1:UOp, c2:UOp):
-   if x.vmin >= 0: return x*c1 if c1.arg >= c2.arg else x*c2
-   if x.vmax <= 0: return x*c2 if c1.arg >= c2.arg else x*c1
-
- symbolic_simple = PatternMatcher([
-   # ** self folding **
-   (UPat.var("x") + 0, lambda x: x), # x+0 -> x
-   (UPat.var("x") * 1, lambda x: x), # x*1 -> x
-   (UPat.var("x") // UPat.var("x"), lambda x: x.const_like(1)), # x//x -> 1
-   (UPat.var("x") // 1, lambda x: x), # x//1 -> x
-   (UPat.var("x") // -1, lambda x: -x), # x//-1 -> -x
-   (UPat.var("x") / UPat.var("x"), lambda x: x.const_like(1)), # x/x -> 1
-   ((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x
-   ((UPat.var() % UPat.var("y")).named("base") % UPat.var("y"), lambda base,y: base), # (x%y)%y -> x%y (rewritten with base for speed)
-   (UPat.var("x")%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda x,c: x), # (x%c)+(x//c)*c = x
-   (UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c", vec=False), lambda x,c: x if c.arg else c),
-   (UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c", vec=False), lambda x,c: c if c.arg else x),
-   (UPat.var("x").maximum(UPat.var("x")), lambda x: x),
-   ((UPat.var("x") & UPat.var("x")), lambda x: x),
-   ((UPat.var("x") | UPat.var("x")), lambda x: x),
-   (UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x),
-   (UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, True), UPat.const(dtypes.bool, False)), lambda x: x),
-   # ** zero folding **
-   (UPat.var("x") < UPat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x < x -> False
-   (UPat.var("x", dtype=dtypes.ints) != UPat.var("x", dtype=dtypes.ints),
-    lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x != x -> False (only ints)
-   # x*0 -> 0 or 0*x -> 0
-   # if x is nan or inf it should render the nan value.
-   # NOTE: this can be wrong for loaded NaN
-   (UPat.var("x") * 0, lambda x: x.const_like(float("nan") if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
-   # ** constant folding **
-   (UPat(GroupOp.ALU, name="a", src=UPat((Ops.VCONST, Ops.CONST))), lambda a: a.const_like(exec_alu(a.op, a.dtype, [x.arg for x in a.src], False))),
-   # bool MUL is AND, ADD/MAX is OR. prevents other rules from rewriting bool ADD/MUL incorrectly
-   (UPat.var('x', dtype=dtypes.bool) * UPat.var('y', dtype=dtypes.bool), lambda x,y: x&y),
-   (UPat.var('x', dtype=dtypes.bool) + UPat.var('y', dtype=dtypes.bool), lambda x,y: x|y),
-   (UPat.var('x', dtype=dtypes.bool).maximum(UPat.var('y', dtype=dtypes.bool)), lambda x,y: x|y),
-   # *** cast ***
-   (UPat(Ops.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)),
-   (UPat(Ops.CAST, name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
- ])
+ def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None) -> UOp:
+   if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0:
+     tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up, name=name))
+   return RewriteContext(pm, ctx).bottom_up_rewrite(sink) if bottom_up else RewriteContext(pm, ctx).top_down_rewrite(sink)
 
- symbolic = symbolic_simple+PatternMatcher([
-   # ** COMMUTATIVE flipping **
-   (UPat(GroupOp.Commutative, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
-   # group like
-   ((UPat.var("x") + UPat.var("y")) + UPat.var("x") * UPat.cvar("c"), lambda x,y,c: (x+x*c)+y),
-   # ** boolean algebra **
-   (UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x
-   # ** combine terms **
-   (UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1)
-   (UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c) -> x*(c+1)
-   (UPat.var("x") + UPat.var("x"), lambda x: x*2), # (x+x) -> x*2
-   ((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3)), # (x/x2)/x3 -> x/(x2*x3)
-   (-1 * (UPat.var("x") + UPat.cvar("c")), lambda x,c: (-x)+(-c)), # -(x+c) -> -x + -c
-   # a conditional with the same results either way is a noop, also fold const conditionals
-   (UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
-   (UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
-   # ALU min==max -> CONST (slow!)
-   (UPat(GroupOp.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
-   # max folding
-   (UPat.maximum(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
-   # TODO: why does this rule break beautiful_mnist?
-   #((UPat.var("x")+UPat.var("z")).maximum(UPat.var("y")+UPat.var("z")), lambda x,y,z: x.maximum(y) + z),
-   ((UPat.var("x")*UPat.cvar("c1")).maximum(UPat.var("x")*UPat.cvar("c2")), max_var_const),
-   # ** two stage ALU folding **
-   ((UPat.var("x") + UPat.cvar("c1")) + UPat.cvar("c2"), lambda x,c1,c2: x+(c1+c2)),
-   ((UPat.var("x") * UPat.cvar("c1")) * UPat.cvar("c2"), lambda x,c1,c2: x*(c1*c2)),
-   ((UPat.var("x") & UPat.cvar("c1")) & UPat.cvar("c2"), lambda x,c1,c2: x&(c1&c2)),
-   ((UPat.var("x") | UPat.cvar("c1")) | UPat.cvar("c2"), lambda x,c1,c2: x|(c1|c2)),
-   ((UPat.cvar("c0") + UPat.var("x")) < UPat.cvar("c1"), lambda x,c0,c1: x<(c1-c0)), # c0 + x < c1 -> x < c1 - c0
-   ((UPat.var("x") // UPat.cvar("c1")) // UPat.cvar("c2"), lambda x,c1,c2: x//(c1*c2)), # (x//c1)//c2 -> x//(c1*c2)
-   # ** lt **
-   # c0*x<c1 for positive int c0,c1
-   ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints)).lt(UPat.cvar("c1", vec=False)),
-    lambda x,c0,c1: x.lt(math.ceil(c1.arg/c0.arg)) if c0.arg > 0 and c1.arg > 0 else None),
-   # c0*x<c1 for negative int c0 and non-positive c1
-   ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints)).lt(UPat.cvar("c1", vec=False)),
-    lambda x,c0,c1: (-x).lt(-(math.floor(-c1.arg/-c0.arg))) if c0.arg < 0 and c0.arg != -1 and c1.arg <= 0 else None),
-   # x//c0<c1 for positive int c0
-   ((UPat.var("x", dtype=dtypes.ints)//UPat.cvar("c0", vec=False)).lt(UPat.cvar("c1", vec=False)),
-    lambda x,c0,c1: x.lt(c1.arg*c0.arg) if c0.arg > 0 else None),
-   # mul add lt
-   (((UPat.cvar("c0", vec=False)*UPat.var("x"))+UPat.var("x2")).lt(UPat.cvar("c1", vec=False)),
-    lambda x,x2,c0,c1: x.lt(c1//c0) if c1.arg % c0.arg == 0 and c0.arg > x2.vmax and x2.vmin >= 0 else None),
-   # ** move add/mul consts to end (NOTE: this is still happening before constant folding) **
-   (UPat(Ops.ADD, src=(UPat.var("x"), UPat.cvar("c1"))) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
-   (UPat(Ops.MUL, src=(UPat.var("x"), UPat.cvar("c1"))) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
-   # *** rules from symbolic ***
-   # unrolled arange div folding
-   (UPat(Ops.ADD, name="divs", src=[UPat(), UPat(Ops.IDIV)]), fold_unrolled_divs),
-   # generic lt folding
-   (UPat.var("x", dtypes.sints).lt(UPat.cvar("c", vec=False)), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None),
-   # canonicalize a simplex with positive coefficients > 0
-   # not (X < 1) -> X > 0
-   (UPat.var("x", dtypes.ints).lt(1).ne(True), lambda x: newx.lt(1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None),
-   # ** div **
-   # div folding
-   (UPat.var("x", dtypes.sints) // UPat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=div_folding(x,c.arg)) is not None else None),
-   # ** mod **
-   # mod folding
-   (UPat.var("x") % UPat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None),
- ])
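The symbolic_simple and symbolic matchers removed above (and symbolic_flat just below) are not gone from the package: in 0.10.2 they live, essentially unchanged, in the new tinygrad/codegen/symbolic.py. Their effect is still reachable through UOp.simplify; a small check of the div folding rule, assuming simplify keeps wrapping these rules:

    from tinygrad.ops import UOp

    x = UOp.variable("x", 0, 100)
    # (x*4+8)//4 -> x+2 via div_folding; compare canonical forms, since UOps are cached by identity
    assert ((x*4 + 8) // 4).simplify() is (x + 2).simplify()
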
+ def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None) -> dict[UOp, UOp]:
+   if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0:
+     tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up, name=name))
+   rewrite_ctx = RewriteContext(pm, ctx)
+   return {k:(rewrite_ctx.bottom_up_rewrite(k) if bottom_up else rewrite_ctx.top_down_rewrite(k)) for k in list(sink.toposort)[::-1]}
 
- symbolic_flat = symbolic+PatternMatcher([
-   # ** combine terms (opinionated) **
-   (-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y
-   # (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue
-   ((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
- ])
+ def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x
 
  _substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
 
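Two small additions land here as well: graph_rewrite_map returns the rewrite of every node of the input graph (keyed by the original UOps, in reverse toposort order) rather than just the rewritten sink, and sint_to_uop normalizes the int-or-UOp "sint" values used for symbolic shapes. A sketch of both, with illustrative names:

    from tinygrad.ops import UOp, UPat, PatternMatcher, graph_rewrite_map, sint_to_uop
    from tinygrad.dtype import dtypes

    pm = PatternMatcher([(UPat.var("x") + 0, lambda x: x)])
    x = UOp.variable("x", 0, 10, dtypes.int)
    mapping = graph_rewrite_map(x + 0, pm)
    assert mapping[x + 0] is x and mapping[x] is x  # every input node maps to its rewritten form
    assert sint_to_uop(4).arg == 4 and sint_to_uop(x) is x
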
@@ -1134,7 +962,7 @@ syms = { Ops.ADD: "+", Ops.SUB: "-", Ops.IDIV: "//", Ops.MOD: "%", Ops.SHL: "<<"
         Ops.MUL: "*", Ops.CMPLT: "<", Ops.CMPNE: "!=", Ops.AND: "&", Ops.OR: "|", Ops.XOR: "^"}
  renderer = PatternMatcher([
    (UPat((Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg[0])),
-   (UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg[0]}")),
+   (UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg}")),
    (UPat(Ops.CONST, name="x"), lambda x: UOp(Ops.NOOP, arg=str(x.arg))),
    (UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
    (UPat(Ops.NEG, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")),
@@ -1149,4 +977,27 @@ renderer = PatternMatcher([
  sint = Union[int, UOp]
  Variable = UOp
 
- ConstLike = Union[ConstType, Variable, Tuple[ConstType, ...]]
+ ConstLike = Union[ConstType, Variable, tuple[ConstType, ...]]
+
+ # *** UOp merge views and swizzling ***
+
+ merge_views = PatternMatcher([
+   # VIEW(VIEW) merges to a single VIEW
+   (UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.VIEW, name="vm2"),)), lambda vm1,vm2: vm2.replace(arg=vm2.st+vm1.st)),
+   # remove VIEW if it's contiguous and same as the base shape
+   (UPat(Ops.VIEW, name="vm", src=(UPat(GroupOp.All-{Ops.DEVICE}, name="x"),)), lambda vm,x: x if vm.st.contiguous and x.shape == vm.shape else None),
+   # merge unmasked const views
+   (UPat(Ops.VIEW, name="view", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="const", src=(UPat(Ops.VIEW, name="st"),) ),)),
+    lambda st,const,view: const.replace(src=(st.replace(arg=st.st+view.st),)) if all(v.mask is None for v in (st.st+view.st).views) else None),
+ ])
+
+ # push VIEW to parents
+ view_left = merge_views+PatternMatcher([
+   # VIEW(CONST) becomes VALID
+   (UPat(Ops.VIEW, name="vm", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),)), lambda vm,x: x.valid(vm.st)),
+   # VIEW before elementwise/buffer ops
+   (UPat(Ops.VIEW, name="vm", src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e"),)),
+    lambda e,vm: e.replace(src=tuple(s if s.st is None else s.view(vm.st) if s is s.base else s.base.view(s.st+vm.st) for s in e.src))),
+   (UPat(Ops.VIEW, name="vm", src=(UPat(GroupOp.Buffer, name="b"),)),
+    lambda b,vm: b.replace(src=tuple((s.st+vm.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
+ ])
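
A closing note on these added matchers: merge_views composes chains of VIEW nodes by adding their ShapeTrackers and deletes views that are contiguous no-ops over their base, while view_left extends that by pushing a VIEW toward the leaves of the graph: a view of a CONST or DEFINE_VAR becomes a VALID-gated value, and a view of an elementwise or buffer op is re-applied to that op's sources instead. The net effect is that movement ops accumulate on buffers and constants, where later scheduling passes can absorb them.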