tinygrad-0.10.0-py3-none-any.whl → tinygrad-0.10.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/kernel.py +114 -172
- tinygrad/codegen/linearize.py +211 -81
- tinygrad/codegen/lowerer.py +30 -35
- tinygrad/codegen/{uopgraph.py → rewriter.py} +69 -59
- tinygrad/codegen/transcendental.py +12 -13
- tinygrad/device.py +170 -47
- tinygrad/dtype.py +28 -26
- tinygrad/engine/jit.py +80 -63
- tinygrad/engine/memory.py +4 -5
- tinygrad/engine/multi.py +162 -0
- tinygrad/engine/realize.py +58 -107
- tinygrad/engine/schedule.py +381 -314
- tinygrad/engine/search.py +40 -44
- tinygrad/gradient.py +70 -0
- tinygrad/helpers.py +77 -58
- tinygrad/nn/__init__.py +30 -32
- tinygrad/nn/datasets.py +1 -2
- tinygrad/nn/optim.py +22 -26
- tinygrad/nn/state.py +89 -64
- tinygrad/ops.py +562 -446
- tinygrad/renderer/__init__.py +79 -36
- tinygrad/renderer/cstyle.py +70 -84
- tinygrad/renderer/llvmir.py +32 -20
- tinygrad/renderer/ptx.py +79 -99
- tinygrad/renderer/wgsl.py +87 -0
- tinygrad/runtime/autogen/amd_gpu.py +39507 -12
- tinygrad/runtime/autogen/comgr.py +2 -0
- tinygrad/runtime/autogen/kfd.py +4 -3
- tinygrad/runtime/autogen/kgsl.py +1 -1
- tinygrad/runtime/autogen/libpciaccess.py +2023 -0
- tinygrad/runtime/autogen/llvm.py +11379 -0
- tinygrad/runtime/autogen/vfio.py +891 -0
- tinygrad/runtime/graph/cuda.py +8 -9
- tinygrad/runtime/graph/hcq.py +84 -79
- tinygrad/runtime/graph/metal.py +19 -21
- tinygrad/runtime/ops_amd.py +488 -327
- tinygrad/runtime/ops_clang.py +15 -28
- tinygrad/runtime/ops_cloud.py +34 -34
- tinygrad/runtime/ops_cuda.py +30 -27
- tinygrad/runtime/ops_disk.py +62 -63
- tinygrad/runtime/ops_dsp.py +129 -38
- tinygrad/runtime/ops_gpu.py +30 -30
- tinygrad/runtime/ops_hip.py +29 -31
- tinygrad/runtime/ops_llvm.py +45 -40
- tinygrad/runtime/ops_metal.py +93 -73
- tinygrad/runtime/ops_npy.py +2 -2
- tinygrad/runtime/ops_nv.py +232 -270
- tinygrad/runtime/ops_python.py +51 -46
- tinygrad/runtime/ops_qcom.py +129 -157
- tinygrad/runtime/ops_webgpu.py +63 -0
- tinygrad/runtime/support/allocator.py +94 -0
- tinygrad/runtime/support/am/__init__.py +0 -0
- tinygrad/runtime/support/am/amdev.py +384 -0
- tinygrad/runtime/support/am/ip.py +463 -0
- tinygrad/runtime/support/compiler_cuda.py +4 -2
- tinygrad/runtime/support/elf.py +26 -4
- tinygrad/runtime/support/hcq.py +254 -324
- tinygrad/runtime/support/llvm.py +32 -0
- tinygrad/shape/shapetracker.py +84 -53
- tinygrad/shape/view.py +103 -138
- tinygrad/spec.py +154 -0
- tinygrad/tensor.py +744 -496
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/METADATA +32 -21
- tinygrad-0.10.1.dist-info/RECORD +86 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/WHEEL +1 -1
- tinygrad/engine/lazy.py +0 -228
- tinygrad/function.py +0 -212
- tinygrad/multi.py +0 -177
- tinygrad/runtime/graph/clang.py +0 -39
- tinygrad-0.10.0.dist-info/RECORD +0 -77
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/top_level.txt +0 -0
tinygrad/ops.py
CHANGED
@@ -1,14 +1,15 @@
 from __future__ import annotations
-from typing import Any,
-import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect
+from typing import Any, Optional, Union, Callable, cast, TYPE_CHECKING, Type, Literal, get_args
+import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect, weakref
 from enum import auto, IntEnum, Enum
 from dataclasses import dataclass, field
 from collections import defaultdict
-from
-from tinygrad.
-from tinygrad.helpers import
+from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate
+from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, _METADATA, flatten
+from tinygrad.helpers import PICKLE_BUFFERS, SPLIT_REDUCEOP, DEBUG, dedup
 if TYPE_CHECKING:
   from tinygrad.shape.shapetracker import ShapeTracker
+  from tinygrad.device import Buffer
 
 # wrapper around IntEnum that preserves Enum.__str__ and makes auto() unique across all FastEnum subclasses
 class FastEnum(IntEnum):
@@ -26,8 +27,7 @@ class SimpleMathTrait:
   def _binop(self, op, x, reverse): return self.ufix(x).alu(op, self) if reverse else self.alu(op, self.ufix(x))
   def logical_not(self): return self.ne(True)
   def neg(self):
-    dtype
-    assert dtype is not None, "MathTraits __neg__ requires a dtype"
+    if (dtype:=getattr(self, 'dtype')) is None: raise TypeError(f"MathTraits __neg__ requires a dtype, {self=}")
     return self.logical_not() if dtype.scalar() == dtypes.bool else self*(-1)
   def add(self, x, reverse=False): return self._binop(Ops.ADD, x, reverse)
   def mul(self, x, reverse=False): return self._binop(Ops.MUL, x, reverse)
@@ -35,6 +35,7 @@ class SimpleMathTrait:
   def bitwise_or(self, x, reverse=False): return self._binop(Ops.OR, x, reverse)
   def xor(self, x, reverse=False): return self._binop(Ops.XOR, x, reverse)
   def idiv(self, x, reverse=False): return self._binop(Ops.IDIV, x, reverse)
+  def mod(self, x, reverse=False): return self._binop(Ops.MOD, x, reverse)
   def sub(self, x, reverse=False): return self.ufix(x).alu(Ops.ADD, -self) if reverse else self.alu(Ops.ADD, self.ufix(-x))
   def div(self, x, reverse=False): return (self.ufix(x)*self.alu(Ops.RECIP)) if reverse else (self*self.ufix(x).alu(Ops.RECIP))
 
@@ -44,7 +45,8 @@ class SimpleMathTrait:
   def __sub__(self, x): return self.sub(x)
   def __mul__(self, x): return self.mul(x)
   def __truediv__(self, x): return self.div(x)
-  def __floordiv__(self, x): return self.idiv(x)
+  def __floordiv__(self, x): return self.idiv(x)  # TODO: idiv is trunc div, not floordiv
+  def __mod__(self, x): return self.mod(x)
   def __and__(self, x): return self.bitwise_and(x)
   def __or__(self, x): return self.bitwise_or(x)
   def __xor__(self, x): return self.xor(x)
@@ -57,22 +59,19 @@ class SimpleMathTrait:
   def __rand__(self, x): return self.bitwise_and(x, True)
   def __ror__(self, x): return self.bitwise_or(x, True)
   def __rxor__(self, x): return self.xor(x, True)
+  def __rmod__(self, x): return self.mod(x, True)
+
+  def __lt__(self, x): return self.alu(Ops.CMPLT, self.ufix(x))
+  def __gt__(self, x): return self.ufix(x).alu(Ops.CMPLT, self)
+  def __ge__(self, x): return (self < x).logical_not()
+  def __le__(self, x): return (self > x).logical_not()
 
-  def lt(self, x): return self.alu(Ops.CMPLT, self.ufix(x))
-  def gt(self, x): return self.ufix(x).alu(Ops.CMPLT, self)
   def ne(self, x): return self.alu(Ops.CMPNE, self.ufix(x))
-  def ge(self, x): return self.lt(x).logical_not()
-  def le(self, x): return self.gt(x).logical_not()
   def eq(self, x): return self.ne(x).logical_not()
-
-  def __lt__(self, x): return self.lt(x)
-  def __gt__(self, x): return self.gt(x)
   def __ne__(self, x): return self.ne(x)
-  def __ge__(self, x): return self.ge(x)
-  def __le__(self, x): return self.le(x)
   # NOTE: __eq__ isn't overridden, and means the same thing as is by default
 
-class MathTrait(SimpleMathTrait): # pylint: disable=abstract-method
+class MathTrait(SimpleMathTrait):
   # TODO: move to Tensor when new backward is done
   def lshift(self, x, reverse=False): return self._binop(Ops.SHL, x, reverse)
   def rshift(self, x, reverse=False): return self._binop(Ops.SHR, x, reverse)
@@ -81,10 +80,6 @@ class MathTrait(SimpleMathTrait): # pylint: disable=abstract-method
   def __rlshift__(self, x): return self.lshift(x, True)
   def __rrshift__(self, x): return self.rshift(x, True)
 
-  # not in Tensor
-  def __mod__(self, x): return self.alu(Ops.MOD, self.ufix(x))
-  def __rmod__(self, x): return self.ufix(x).alu(Ops.MOD, self)
-
   def maximum(self, x): return self.alu(Ops.MAX, self.ufix(x))
   def minimum(self, x): return -(-self).maximum(-x)
   def where(self, x, y): return self.alu(Ops.WHERE, x, x.ufix(y))
@@ -98,39 +93,40 @@ class MathTrait(SimpleMathTrait): # pylint: disable=abstract-method
 # the order of these Ops controls the order of the toposort
 class Ops(FastEnum):
   # uops that aren't rendered
-  SINK = auto()
-  CONTIGUOUS = auto()
-  PRELOAD = auto()
+  SINK = auto(); CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); PRELOAD = auto(); KERNEL = auto() # noqa: E702
 
-  #
-  COPY = auto()
+  # TODO: empty continues to exist because of tensor
   EMPTY = auto()
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+  # MetaOps
+  COPY = auto(); BUFFER_VIEW = auto() # noqa: E702
+
+  # blocks in linearizer
+  BLOCK = auto(); BLOCKSTART = auto(); BLOCKFORK = auto(); BLOCKEND = auto() # noqa: E702
+
+  # movement ops!
+  RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto() # noqa: E702
+
+  # misc ops
+  UNROLL = auto(); CONTRACT = auto() # noqa: E702
+  VIEW = auto(); DEFINE_GLOBAL = auto(); BUFFER = auto() # noqa: E702
+  DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # noqa: E702
+  VALID = auto(); SPECIAL = auto(); NOOP = auto() # noqa: E702
 
   # reduce
   REDUCE_AXIS = auto()
 
   # helper ops
-  GEP = auto()
-  VECTORIZE = auto()
+  GEP = auto(); VECTORIZE = auto() # noqa: E702
 
   # UnaryOps
   CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702
 
-  #
-  LOAD = auto()
+  # load/store before math
+  LOAD = auto(); STORE = auto() # noqa: E702
+
+  # early INDEX
+  INDEX = auto()
 
   # math ops
   WMMA = auto()
@@ -143,25 +139,18 @@ class Ops(FastEnum):
   WHERE = auto(); MULACC = auto() # noqa: E702
 
   # assignment ops
-  STORE = auto()
   ASSIGN = auto()
   BIND = auto()
 
-  # late INDEX
-  INDEX = auto()
-
   # control flow ops
-  BARRIER = auto()
-  IF = auto()
-  RANGE = auto()
-
-  # ops that are not graph nodes
-  ENDRANGE = auto()
-  ENDIF = auto()
+  BARRIER = auto(); RANGE = auto(); IF = auto(); ENDRANGE = auto(); ENDIF = auto() # noqa: E702
 
   # consts last!
-  VCONST = auto()
-
+  VCONST = auto(); CONST = auto() # noqa: E702
+
+  # device
+  DEVICE = auto()
+  MULTI = auto()
 
 class GroupOp:
   Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG}
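Note: Ops is a FastEnum (an IntEnum), so the declaration order above is a total order over ops; the "order of these Ops controls the order of the toposort" comment relies on it. A minimal sketch, assuming a tinygrad 0.10.1 install:

    from tinygrad.ops import Ops
    print(Ops.SINK < Ops.CONST)                          # True: SINK is declared before CONST
    print(sorted([Ops.CONST, Ops.DEVICE, Ops.SINK])[0])  # Ops.SINK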
@@ -171,41 +160,53 @@ class GroupOp:
   ALU = set.union(Unary, Binary, Ternary)
 
   Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}
+  Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP}
 
-
-
-  Buffer = {Ops.LOAD, Ops.PRELOAD, Ops.STORE, Ops.VALID}
+  Buffer = {Ops.LOAD, Ops.PRELOAD, Ops.STORE, Ops.VALID, Ops.CONST, Ops.DEFINE_VAR}
+  Block = {Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKFORK, Ops.BLOCKSTART}
 
   # BinaryOps that can be flipped
   Commutative = {Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPNE, Ops.XOR, Ops.AND, Ops.OR}
 
+  # BinaryOps where f(f(a,b),c) = f(a,f(b,c))
+  Associative = {Ops.ADD, Ops.MUL, Ops.AND, Ops.OR, Ops.MAX}
+
+  # BinaryOps that satisfy f(x,x)=x see https://en.wikipedia.org/wiki/Idempotence
+  Idempotent = {Ops.OR, Ops.AND, Ops.MAX}
+
   # do not preserve f(0) = 0
   UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV}
 
-
-
+  All = set(Ops)
+
+# some BUFFER ops can be processed with only a view
+view_supported_devices = {"LLVM", "CLANG", "CUDA", "NV", "AMD", "METAL", "QCOM", "DSP", "DISK"}
 
-
+# https://en.wikipedia.org/wiki/Identity_element
+def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
 
-
+def can_pad(u:UOp, edges:dict[UOp, UOp], visisted:dict[UOp, None]) -> bool:
+  if u.op in GroupOp.UnsafePad: return False
+  if (len(u.src) == 2 and u.src[0] in edges) or u in visisted: return True
+  visisted[u] = None
+  return all(can_pad(x.base, edges, visisted) for x in u.src)
 
 # With True as the default, this matches the old symbolic behavior
-def resolve(x, default:bool=True):
-  if
-  assert x.dtype
+def resolve(x:UOp|bool, default:bool=True):
+  if isinstance(x, bool): return x
+  assert x.dtype == dtypes.bool, "UOp in resolve must be bool"
   # NOTE: generating the text for the exception is expensive, so we do this
   return bool(sx.vmin) if (sx:=x.simplify()).vmin == sx.vmax else default
 
 # smax/smin are replacements for max/min that preserve symbolic
 def _suop(lst, uop_fxn, python_fxn):
-
-
-
-def
-def smin(*lst): return _suop(lst[0] if isinstance(lst[0], (tuple, list)) else lst, UOp.minimum, min)
+  uops, nums = partition(lst, lambda x: isinstance(x, UOp))
+  return ssimplify(functools.reduce(uop_fxn, uops + ([python_fxn(nums)] if nums else [])))
+def smax(*lst): return _suop(argfix(*lst), UOp.maximum, max)
+def smin(*lst): return _suop(argfix(*lst), UOp.minimum, min)
 
 def ssimplify(uop): return uop.ssimplify() if isinstance(uop, UOp) else uop
-def sym_infer(uop: Union[UOp, int], var_vals:
+def sym_infer(uop: Union[UOp, int], var_vals: dict[UOp, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop
 
 # used for UOp and UPat
 def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->str:
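Note: resolve/smax/smin above keep comparisons symbolic instead of forcing Python bools. A minimal sketch of the behavior, assuming a tinygrad 0.10.1 install (UOp.variable builds a bounded DEFINE_VAR in this release):

    from tinygrad.ops import UOp, resolve, smax

    v = UOp.variable("v", 1, 10)                          # 1 <= v <= 10
    print(resolve(v < 11))                                # True: vmin and vmax agree
    print(resolve(v < 5), resolve(v < 5, default=False))  # undecidable, falls back to the default
    print(type(smax(v, 3)))                               # still a UOp, not an int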
@@ -219,53 +220,101 @@ def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->s
     return f"{' '*d}{f'x{cx[0]}:=' * (cx[1]>1)}{rep(x)}" % srcs
 
 class UOpMetaClass(type):
-  ucache:
-  def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:
-    if (
-    UOpMetaClass.ucache[key] =
-
-
+  ucache:dict[tuple, weakref.ReferenceType[UOp]] = {}
+  def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:tuple[UOp,...]=tuple(), arg:Any=None, _buffer:Buffer|None=None):
+    if (wret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None and (ret:=wret()) is not None: return ret
+    UOpMetaClass.ucache[key] = ref = weakref.ref(created:=super().__call__(*key))
+    for s in src: s.children.add(ref)
+    # NOTE: this will soon be set by Tensor once we remove function.py
+    if (metadata:=_METADATA.get()) is not None: all_metadata[created] = metadata
+    # NOTE: this value is set by pickle when pickling a realized tensor
+    if _buffer is not None:
+      assert op is Ops.BUFFER, f"trying to set Buffer {_buffer} for {op}"
+      buffers[created] = _buffer
+    return created
+
+# some uops map to other stuff
+buffers:weakref.WeakKeyDictionary[UOp, Buffer] = weakref.WeakKeyDictionary() # this maps BUFFER uops to their device Buffers
+all_metadata:weakref.WeakKeyDictionary[UOp, Metadata] = weakref.WeakKeyDictionary()
+
+# NOTE: this should be frozen, but frozen is slower
+@dataclass(eq=False, slots=True)
 class UOp(MathTrait, metaclass=UOpMetaClass):
-
-
-
-
-
+  op:Ops
+  dtype:DType = dtypes.void
+  src:tuple[UOp, ...] = tuple()
+  arg:Any = None
+  children:set[weakref.ref[UOp]] = field(default_factory=set)
+  def __del__(self):
+    if self.op is Ops.BUFFER and (buffer:=buffers.get(self)) is not None: buffer.ref(-1)
+    if (ref:=UOpMetaClass.ucache.get(k:=(self.op, self.dtype, self.src, self.arg))) is not None:
+      for s in self.src: s.children.discard(ref)
+      del UOpMetaClass.ucache[k]
+  def __reduce__(self):
+    args = [self.op, self.dtype, self.src, self.arg]
+    if self.op is Ops.BUFFER and self.realized is not None and PICKLE_BUFFERS: args.append(self.realized)
+    return UOp, tuple(args)
   def replace(self, **kwargs) -> UOp:
-
-
+    new_args = (kwargs.pop("op", self.op), kwargs.pop("dtype", self.dtype), kwargs.pop("src", self.src), kwargs.pop("arg", self.arg))
+    assert len(kwargs) == 0, f"unused kwargs in replace {list(kwargs)}"
     if (self.op, self.dtype, self.src, self.arg) == new_args: return self
     return UOp(*new_args)
   @functools.cached_property
   def key(self) -> bytes:
     return hashlib.sha256(str((self.op, self.dtype, self.arg)).encode() + b"".join([s.key for s in self.src])).digest()
   def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))")
-  def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is Ops.REDUCE_AXIS else self.arg
-
-
-
-
+  def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is Ops.REDUCE_AXIS else repr(self.arg)
+
+  @property
+  def toposort(self) -> dict[UOp, None]:
+    def _toposort(u:UOp, cache:set[UOp]):
+      if u in cache: return {}
+      nodes: dict[UOp, None] = {}
+      # NOTE: this is a lot faster than the comprehension in parents
+      for parent in u.src: nodes.update(_toposort(parent, cache))
+      nodes[u] = None
+      cache.add(u)
+      return nodes
+    return _toposort(self, cache=set())
 
   @functools.cached_property
-  def tuplize(self:UOp) ->
-    return (self.op.value, self.arg, self.dtype, tuple(x.tuplize for x in self.src))
+  def tuplize(self:UOp) -> tuple[int, Any, Optional[DType], tuple]: return (self.op.value, self.arg, self.dtype, tuple(x.tuplize for x in self.src))
 
   # *** uop shape stuff ***
 
-  @property
-  def has_st(self) -> bool: return self.op not in {Ops.DEFINE_LOCAL, Ops.DEFINE_GLOBAL, Ops.BUFFER, Ops.CONST, Ops.DEFINE_VAR}
   @functools.cached_property
-  def st(self) ->
-    if not self.has_st: return None
-    if self.op in GroupOp.Buffer: return self.st_arg
-    if self.op is Ops.VIEW: return self.arg
-    src_sts = [x.st for x in self.src if x.st is not None]
-    assert all_same([x.shape for x in src_sts]), f"UOp parents must have the same shape {self} {[x.shape for x in src_sts]}"
+  def st(self) -> ShapeTracker|None:
     from tinygrad.shape.shapetracker import ShapeTracker
-
+    if self.op is Ops.MULTI:
+      return ShapeTracker.from_shape(
+        tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape)))
+    if self.op is Ops.BUFFER: return ShapeTracker.from_shape((self.size,))
+    # these ops define a ShapeTracker from the arg
+    if self.op is Ops.VIEW: return self.arg
+    if self.op in GroupOp.Movement: return unwrap(self.src[0].st).mop(self.op, self.arg)
+    # buffer ops return the ShapeTracker from sources
+    if self.op in GroupOp.Buffer: return vsrc[0] if len(vsrc:=[x.st for x in self.src if x.op is Ops.VIEW]) != 0 else None
+    if not (src_sts := [x.st for x in self.src if x.st is not None]): return None
+    assert all_same([x.shape for x in src_sts]), f"UOp sources must have the same shape {self} {[x.shape for x in src_sts]}"
+    if self.op in {Ops.BITCAST, Ops.BUFFER_VIEW}:
+      shape = src_sts[0].shape
+      if self.dtype.itemsize != (input_sz:=self.src[0].dtype.itemsize): shape = shape[:-1]+((shape[-1]*input_sz) // self.dtype.itemsize,)
+    # only reduce ops are allowed to change shape, everything else derives shape from sources
+    elif self.op in {Ops.REDUCE_AXIS, Ops.WMMA}: shape = src_sts[0].reduce(self.axis_arg)
+    else: shape = src_sts[0].shape
+    return ShapeTracker.from_shape(shape)
+
   @functools.cached_property
-  def full_shape(self) ->
-
+  def full_shape(self) -> tuple[sint, ...]:
+    if self.op is Ops.VIEW: return self.shape
+    # TODO: this should check if st is None, it cannot because local reduce has implicit movement ops
+    return tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {Ops.DEFINE_GLOBAL,Ops.DEFINE_LOCAL} \
+      # TODO: this exists because wmma creates consts without ShapeTracker in the AST, there's probably a way to fix this
+      and not (x.op is Ops.CONST and x.st is None)]))
+  @property
+  def shape(self) -> tuple[sint, ...]: return unwrap(self.st).shape
+  @property
+  def size(self) -> int: return self.arg[1] if self.op is Ops.BUFFER else unwrap(self.st).size
 
   # *** uop evaluation ***
 
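Note: the rewritten UOpMetaClass above hash-conses UOps through a weakref ucache: structurally identical nodes are the same Python object, and __del__ evicts dead entries. A minimal sketch, assuming a tinygrad 0.10.1 install:

    from tinygrad.ops import UOp
    from tinygrad.dtype import dtypes

    a = UOp.const(dtypes.int, 42)
    b = UOp.const(dtypes.int, 42)
    print(a is b)  # True: the same (op, dtype, src, arg) key hits the same ucache entry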
@@ -282,34 +331,44 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   def __bool__(self): return self._eval((dtypes.bool,), bool)
   def __int__(self): return self._eval(dtypes.ints, int)
   def __float__(self): return self._eval(dtypes.floats, float)
-  def substitute(self, dvars:
+  def substitute(self, dvars:dict[UOp, UOp]):
     with Context(TRACK_MATCH_STATS=0):
-      return graph_rewrite(self, _substitute, dvars)
+      return graph_rewrite(self, _substitute, dvars, bottom_up=True)
 
   # *** uop syntactic sugar ***
 
   @property
   def st_arg(self) -> ShapeTracker:
     assert self.op in GroupOp.Buffer, f"st_arg called on {self.op}"
-
-
-
+    return unwrap(self.st)
+  @property
+  def const_arg(self) -> ConstType:
+    match self.base.op:
+      case Ops.CONST: ret = self.base.arg
+      case op: raise AssertionError(f"const_arg called on {op}")
+    assert isinstance(ret, get_args(ConstType)), f"const_arg trying to return {ret}"
+    return ret
   @property
-  def axis_arg(self) ->
+  def axis_arg(self) -> tuple[int, ...]:
     assert self.op in {Ops.REDUCE_AXIS, Ops.WMMA}, f"axis_arg called on {self.op}"
     ret = self.arg[1] if self.op is Ops.REDUCE_AXIS else self.arg[7]
     assert isinstance(ret, tuple) and all(isinstance(x, int) for x in ret), f"axis_arg trying to return {ret}"
     return ret
   def sink(self, *srcs:UOp): return UOp(Ops.SINK, dtypes.void, (self,)+srcs)
-  def
-  def
+  def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
+  def index(self, idx:UOp, valid:UOp|None=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
+  def const_like(self, b:ConstLike):
+    # constants can optionally have a DEVICE source
+    if self._device is None: return UOp.const(self.dtype, b)
+    if isinstance(self.device, tuple): return UOp.multi(*[UOp.metaop(Ops.CONST, self.shape, self.dtype, d, b) for d in self.device], axis=None)
+    return UOp.metaop(Ops.CONST, self.shape, self.dtype, self.device, b)
   def broadcast(self, count:int):
     assert self.dtype.count == 1
     if count == 1: return self
     return UOp(Ops.VECTORIZE, self.dtype.vec(count), (self,)*count)
   def cast(self, dtype:DType): return UOp(Ops.CAST, dtype, (self,))
   def bitcast(self, dtype:DType): return UOp(Ops.BITCAST, dtype, (self,))
-  def gep(self, i:Union[
+  def gep(self, i:Union[tuple[int, ...], int]):
     if isinstance(i, int):
       # NOTE: these are just shortcuts to not have to create and fold later
       if self.op is Ops.VECTORIZE: return self.src[i]
@@ -322,42 +381,185 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
   def alu(self, arg, *src:UOp):
     out_dtype = (self, *src)[-1].dtype
-    if arg in {Ops.CMPLT, Ops.CMPNE}
-      out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
+    if arg in {Ops.CMPLT, Ops.CMPNE}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
     return UOp(arg, out_dtype, (self,)+src)
   @staticmethod
   def const(dtype:DType, b:ConstLike):
     if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
     if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
-    return UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype)
+    return UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype))
+  def valid(self, st:ShapeTracker):
+    assert self.op in {Ops.CONST, Ops.DEFINE_VAR}, f"can only create VALID from a constant, got {self.op}"
+    from tinygrad.shape.shapetracker import ShapeTracker
+    # NOTE: only VALID has a masked ShapeTracker, the CONST operands are unmasked
+    unmasked_st = ShapeTracker.from_shape(()).reshape((1,)*len(st.shape)).expand(st.shape).to_uop()
+    return UOp(Ops.VALID, dtypes.bool, (st.to_uop(),)).where(self.replace(src=(unmasked_st,)), UOp.const(self.dtype, 0).replace(src=(unmasked_st,)))
   @staticmethod
-  def range(dtype:DType, start:
-
-
-
+  def range(dtype:DType, start:sint, end:sint, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(start), sint_to_uop(end)), arg=idx)
+  def _reduce_op(self, op:Ops, axis:tuple[int, ...]):
+    axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
+    return self if len(axis) == 0 else UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis))
+  def r(self, op:Ops, axis:tuple[int, ...]) -> UOp:
+    new_shape = unwrap(self.st).reduce(axis)
+
+    # TODO: can we split symbolic shape if the reduce axis is not symbolic?
+    # TODO: this shouldn't be here, it belongs in scheduler! that's why it broke multi
+    if not SPLIT_REDUCEOP or isinstance(self._device, tuple) or not all_int(self.shape) or (0 in self.shape) or \
+      prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
+      return self._reduce_op(op, axis)
+
+    # if there are few globals, make some reduces into globals by splitting into two kernels
+    # cap output buffer to 2**22: heuristic number of global outputs to achieve max occupancy with enough locals+upcasts for gemm
+    # ~2**10 should be enough if GROUP is used
+    # 256 split maximum should be "negligible reduce" for low prod(new_shape), 8 split minimum.
+    # split is moved to the end to provide maximum locality for the second phase reduce.
+    self_real_strides = unwrap(self.st).real_strides(ignore_valid=True)
+    split_candidates = [(i, x) for i in axis for x in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(new_shape)),8-1,-1)
+                        if self.shape[i] % x == 0 and self_real_strides[i] != 0]
+    if not split_candidates: return self._reduce_op(op, axis)
+    dim_to_split, divisor = split_candidates[0]
+    splitted_shape = self.shape[:dim_to_split] + (divisor,) + (self.shape[dim_to_split]//divisor,) + self.shape[dim_to_split+1:]
+    splitted = self.reshape(splitted_shape).permute(tuple([x for x in range(len(splitted_shape)) if x != dim_to_split]+[dim_to_split]))
+    if DEBUG >= 3: print(f"split {divisor}: {self.shape} -> {splitted.shape} -> {new_shape}")
+    return splitted._reduce_op(op, axis)._reduce_op(op, (len(new_shape),)).reshape(new_shape) # reduce original axes, then split
   def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
-  def contiguous(self): return
+  def contiguous(self): return self.alu(Ops.CONTIGUOUS)
+  def contiguous_backward(self): return self.alu(Ops.CONTIGUOUS_BACKWARD)
+
+  # *** from MultiLazyBuffer ***
+
+  def multi(self, *more:UOp, axis:int|None, real:tuple[bool,...]|None=None):
+    parents = (self,)+more
+    assert all_same([x.dtype for x in parents]), "multi parents must have the same dtype"
+    return UOp(Ops.MULTI, self.dtype, parents, (axis, real if real is not None else (True,)*len(parents)))
+
+  @property
+  def bounds(self):
+    if self.axis is None: raise RuntimeError("bounds is not defined when axis is None")
+    return tuple(itertools.pairwise(itertools.accumulate([lb.shape[self.axis] for lb in self.src], initial=0)))
+
+  @functools.cached_property
+  def axis(self) -> Optional[int]:
+    if self.op is Ops.MULTI: return self.arg[0]
+    # NOTE: they all have to share an axis, we always choose [-1]
+    if self.op in GroupOp.ALU: return axes[-1] if (axes := dedup([x.axis for x in self.src if x.axis is not None])) else None
+    src_axis = self.src[0].axis
+    if self.op is Ops.REDUCE_AXIS: return None if src_axis is not None and src_axis in self.arg[1] else src_axis
+    if self.op is Ops.RESHAPE:
+      if src_axis is None: return None
+      arg_acc:list[sint] = list(itertools.accumulate(self.arg, operator.mul, initial=1))
+      # new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards
+      # TODO: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1?
+      return len(arg_acc) - arg_acc[::-1].index(prod(self.src[0].shape[:src_axis])) - 1
+    if self.op is Ops.PERMUTE: return self.arg.index(src_axis) if src_axis is not None else None
+    return src_axis
+
+  @property
+  def real(self):
+    assert self.op is Ops.MULTI
+    return self.arg[1]
+
   @property
-  def
+  def real_lbs(self): return [lb for lb,r in zip(self.src, self.real) if r]
+
+  def shard(self, devices:tuple[str, ...], axis:Optional[int]=None) -> UOp:
+    if axis is None: lbs = [self] * len(devices)
+    else:
+      if self.shape[axis] % len(devices) != 0: raise RuntimeError(f"multi axis uneven: {self.shape[axis]=} {axis=} {len(devices)=}")
+      # NOTE: this works for both even shards and uneven shards
+      sz = self.shape[axis] // len(devices)
+      sizes = [max(0, min(sz, self.shape[axis] - sz*i)) for i in range(len(devices))]
+      lbs = []
+      for sz,off in zip(sizes, itertools.accumulate(sizes, initial=0)):
+        lbs.append(self.shrink(tuple((0,s) if i != axis else (off,off+sz) for i,s in enumerate(self.shape))))
+    sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(lbs, devices)]
+    return UOp.multi(*[lb.contiguous() for lb in sharded_lbs], axis=axis)
+
+  # *** from LazyBuffer ***
+
+  @staticmethod
+  def metaop(op:Ops, shape:tuple[sint, ...], dtype:DType, device:str, arg=None) -> UOp:
+    from tinygrad.shape.shapetracker import ShapeTracker
+    # Tensor const is CONST(VIEW(DEVICE)) -> RESHAPE -> EXPAND
+    if op is Ops.CONST:
+      assert isinstance(arg, get_args(ConstType)), f"trying to create CONST with {arg=}"
+      return UOp.const(dtype, unwrap(arg)).replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),),
+                                                            ShapeTracker.from_shape(())),)).reshape((1,)*len(shape)).expand(shape)
+    # Tensor variable binding is BIND(VAR(VIEW(DEVICE)), CONST(VIEW(DEVICE)))
+    if op is Ops.BIND:
+      var, val = arg.unbind()
+      return var.replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),), ShapeTracker.from_shape(shape)),)).bind(val)
+    # otherwise it's just a VIEW(BUFFER)
+    return UOp(Ops.VIEW, dtype, (UOp.new_buffer(device, (st:=ShapeTracker.from_shape(shape)).size, dtype),), st)
+  def copy_to_device(self, device:str|tuple[str, ...], clone:bool=False) -> UOp:
+    # if it's a shrink, do the shrink before the copy with CONTIGUOUS
+    if prod(self.shape) < prod(self.base.shape): return self.contiguous().copy_to_device(device)
+    # COPY is COPY(DEVICE, copyin.base) -> VIEW(copyin.st)
+    ret = UOp(Ops.COPY, self.base.dtype, (UOp(Ops.DEVICE, arg=device), self.base), clone)
+    op_arg = []
+    mop = self
+    while mop is not self.base:
+      op_arg.append((mop.op, mop.arg))
+      mop = mop.src[0]
+    for op,arg in reversed(op_arg): ret = UOp(op, ret.dtype, (ret,), arg)
+    return ret
+  def clone(self) -> UOp: return self.copy_to_device(self.device, clone=True)
+  @property
+  def metadata(self): return all_metadata.get(self, None)
 
   # *** uop movement ops ***
 
   @property
-  def base(self)
-
-
-
-
+  def base(self) -> UOp:
+    if self.op in GroupOp.Movement: return self.src[0].base
+    return self.src[0].base if self.op is Ops.VIEW and len(self.src) == 1 else self
+  def view(self, new_st:ShapeTracker) -> UOp: return UOp(Ops.VIEW, self.dtype, (self.base,), new_st)
+
+  def _mop(self, op:Ops, arg):
+    ret = UOp(op, self.dtype, (self,), arg)
+    if self.st == ret.st: return self # ignore NOOPs, also check ret.st
+    return ret
+
+  def reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg)
+  def pad(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.PAD, arg)
+  def expand(self, arg:tuple[sint, ...]): return self._mop(Ops.EXPAND, arg)
+  def permute(self, arg:tuple[sint, ...]): return self._mop(Ops.PERMUTE, arg)
+  def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg)
+  def flip(self, arg:tuple[bool, ...]): return self._mop(Ops.FLIP, arg)
 
   # *** uop Buffer stuff ***
 
+  buffer_num = itertools.count(0)
   @staticmethod
-  def new_buffer(device:str, size:int, dtype:DType
-
+  def new_buffer(device:str, size:int, dtype:DType): return UOp(Ops.BUFFER, dtype, (UOp(Ops.DEVICE, arg=device),), (next(UOp.buffer_num), size))
+  @property
+  def device(self) -> str|tuple[str, ...]: return cast(str|tuple[str, ...], unwrap(self._device))
+  @functools.cached_property
+  def _device(self) -> Optional[str|tuple[str, ...]]:
+    if self.op is Ops.DEVICE: return self.arg
+    if self.op is Ops.MULTI: return tuple(cast(str, x.device) for x in self.src)
+    return dsrcs[0]._device if len(dsrcs:=[x for x in self.src if x._device is not None]) != 0 else None
   @property
   def buf_uop(self) -> UOp:
-
-
+    if self.base.op is Ops.BUFFER: return self.base
+    assert self.base.op in {*GroupOp.Buffer, Ops.ASSIGN, Ops.VIEW}, f"buf_uop called on {self.op}"
+    return self.src[0].buf_uop
+  @property
+  def buffer(self) -> Buffer:
+    if self.op is Ops.VIEW:
+      assert unwrap(self.st).contiguous, "VIEW only works here if it's contiguous"
+      return self.src[0].buffer
+    assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}"
+    if (cret:=buffers.get(self)) is not None: return cret
+    from tinygrad.device import Buffer
+    assert isinstance(self.device, str), f"buffer not supported on multi {self.device}"
+    buffers[self] = ret = Buffer(self.device, self.size, self.dtype if isinstance(self.dtype, ImageDType) else self.dtype.base)
+    return ret
+  @property
+  def realized(self) -> Optional[Buffer]: return self.buffer if self.op is Ops.BUFFER else None
+  @property
+  def is_realized(self) -> bool:
+    return all(x.base.realized is not None for x in self.base.real_lbs) if self.base.op is Ops.MULTI else self.base.realized is not None
 
   # *** uop Variable stuff ***
 
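Note: the shard-size arithmetic in UOp.shard above is self-contained; restated as plain Python (shard_bounds is an illustrative helper, not a tinygrad API):

    import itertools

    def shard_bounds(dim: int, n: int) -> list[tuple[int, int]]:
      sz = dim // n
      sizes = [max(0, min(sz, dim - sz*i)) for i in range(n)]
      # the same accumulate trick shard() uses to turn sizes into shrink windows
      return [(off, off+s) for s, off in zip(sizes, itertools.accumulate(sizes, initial=0))]

    print(shard_bounds(8, 4))  # [(0, 2), (2, 4), (4, 6), (6, 8)]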
@@ -373,18 +575,18 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
     assert self.arg[1] <= val and val <= self.arg[2], f"bind {val} not in range [{self.arg[1]}, {self.arg[2]}]"
     return UOp(Ops.BIND, self.dtype, (self, self.const_like(val)))
-  def unbind(self) ->
+  def unbind(self) -> tuple[Variable, int]:
     assert self.op is Ops.BIND and self.src[0].op is Ops.DEFINE_VAR and self.src[1].op is Ops.CONST, f"can't unbind {self}"
     return self.src[0], self.src[1].arg
   @property
   def val(self) -> int: return self.unbind()[1]
-  def vars(self) ->
-    bound_vars = set([x for x in self.
+  def vars(self) -> set[UOp]:
+    bound_vars = set([x for x in self.toposort if x.op is Ops.BIND and x.src[0].op is Ops.DEFINE_VAR])
     bound_var_base = set(x.src[0] for x in bound_vars)
-    all_vars = set([x for x in self.
+    all_vars = set([x for x in self.toposort if x.op is Ops.DEFINE_VAR])
     return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
-  def variables(self) ->
-    st_vars:
+  def variables(self) -> list[Variable]:
+    st_vars: list[set[Variable]] = [x.st_arg.vars() for x in self.toposort if x.op in GroupOp.Buffer]
     return sorted(set.union(*st_vars, [x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg)
 
 # *** uop symbolic stuff ***
@@ -396,7 +598,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     if self.op is Ops.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
     if self.op is Ops.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1
     return 1
-  def divides(self, v) ->
+  def divides(self, v) -> UOp|None:
    if v==1: return self
    if self.op is Ops.CONST: return self.const_like(self.arg//v) if self.arg%v == 0 else None
    if self.op is Ops.VCONST: return self.const_like(tuple(x//v for x in self.arg)) if all(x%v == 0 for x in self.arg) else None
@@ -410,15 +612,23 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   @property
   def vmax(self) -> ConstType: return self._min_max[1]
   @functools.cached_property
-  def _min_max(self) ->
+  def _min_max(self) -> tuple[ConstType, ConstType]:
     if self.op in GroupOp.Binary and not dtypes.is_float(self.dtype):
       (s0_vmin, s0_vmax), (s1_vmin, s1_vmax) = self.src[0]._min_max, self.src[1]._min_max
       if self.op is Ops.ADD: return s0_vmin+s1_vmin, s0_vmax+s1_vmax
       if self.op is Ops.MUL: return min(vals:=(s0_vmin*s1_vmin, s0_vmin*s1_vmax, s0_vmax*s1_vmin, s0_vmax*s1_vmax)), max(vals)
-
-      if self.op is Ops.
-
-
+      # SHL/SHR on consts only
+      if self.op is Ops.SHL and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] << t[2], t[1] << t[2]
+      if self.op is Ops.SHR and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] >> t[2], t[1] >> t[2]
+      if self.op is Ops.MOD and s1_vmin > 0:
+        return (0, s1_vmax-1) if s0_vmin >= 0 else (-(s1_vmax-1), s1_vmax-1)
+      if self.op is Ops.IDIV:
+        if s1_vmin == s1_vmax: # min/max are equal in a CONST
+          if s1_vmin > 0: return s0_vmin//s1_vmin, s0_vmax//s1_vmin
+          if s1_vmin < 0 and s0_vmin >= 0: return -(s0_vmax//-s1_vmin), -(s0_vmin//-s1_vmin)
+        # don't know exact bounds, but know the sign
+        if (s0_vmax <= 0 and s1_vmin < 0) or (s0_vmin >= 0 and s1_vmin > 0): return 0, dtypes.max(self.dtype)
+        if (s0_vmax <= 0 and s1_vmin > 0) or (s0_vmin >= 0 and s1_vmin < 0): return dtypes.min(self.dtype), 0
       if self.op is Ops.MAX: return max(s0_vmin, s1_vmin), max(s0_vmax, s1_vmax)
       if self.op is Ops.CMPLT: return (s0_vmax<s1_vmin, s0_vmin<s1_vmax)
       if self.op is Ops.CMPNE: return ((s0_vmax < s1_vmin) or (s1_vmax < s0_vmin), not (s0_vmin == s0_vmax == s1_vmin == s1_vmax))
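Note: the new MOD/IDIV branches in _min_max above are plain interval arithmetic; a quick standalone check of the positive-divisor case:

    s0_vmin, s0_vmax, d = 0, 10, 3  # x in [0, 10], constant divisor 3
    print(s0_vmin//d, s0_vmax//d)   # 0 3, the s1_vmin == s1_vmax IDIV branch
    print(0, d-1)                   # 0 2, the MOD branch when s0_vmin >= 0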
@@ -431,9 +641,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     if self.op is Ops.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2]
     if self.op is Ops.RANGE: return self.src[0].vmin, (self.src[1]-1).vmax
     if self.op is Ops.BIND: return self.src[0]._min_max # ignore the bound value
-    if self.op in {Ops.
+    if self.op in {Ops.UNROLL, Ops.VECTORIZE}: return min(x.vmin for x in self.src), max(x.vmax for x in self.src)
     # TODO: UOps.SPECIAL is UOps.DEFINE_VAR
-    if self.op is Ops.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else
+    if self.op is Ops.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else self.arg[1].vmax
     if self.op is Ops.CONST: return self.arg, self.arg
     if self.op is Ops.VCONST: return (min(self.arg), max(self.arg))
     return dtypes.min(self.dtype), dtypes.max(self.dtype)
@@ -441,11 +651,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   @functools.cached_property
   def _sym_fxn(self):
     sself = self.simplify()
-    varnames = tuple(x.arg[0] for x in sself.
+    varnames = tuple(x.arg[0] for x in sself.toposort if x.op is Ops.DEFINE_VAR)
     # TODO: sanitize varnames, or don't use naked eval while staying fast
     return eval("lambda "+','.join(varnames)+": "+sself.render()), varnames # pylint: disable=eval-used
 
-  def sym_infer(self, var_vals:
+  def sym_infer(self, var_vals:dict[UOp, int]):
     fxn, varnames = self._sym_fxn
     return fxn(**{k.arg[0]:v for k,v in var_vals.items() if k.arg[0] in varnames})
 
@@ -456,25 +666,22 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
 @dataclass(frozen=True)
 class KernelInfo:
   local_dims: int = 0 # number of local dimensions (this is remapping RANGE to SPECIAL)
-  upcasted: int = 0 # count that are upcasted (this is remapping RANGE to
+  upcasted: int = 0 # count that are upcasted (this is remapping RANGE to UNROLL)
   dont_use_locals: bool = False # don't use local indexing
 
 # ***** ops in python *****
 
-def
-
-
-    except OverflowError: return dv
-  return wfxn
+def safe_exp2(x):
+  try: return 2 ** x
+  except OverflowError: return math.inf
 
-python_alu:
-  Ops.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, Ops.EXP2:
+python_alu: dict[Ops, Callable] = {
+  Ops.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, Ops.EXP2: safe_exp2,
   Ops.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, Ops.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
   Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan,
-  Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul,
-  Ops.
-  Ops.
-  Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift,
+  Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt,
+  Ops.XOR: operator.xor, Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift, Ops.MAX: max,
+  Ops.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], Ops.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else 0,
   Ops.MULACC: lambda x,y,z: (x*y)+z, Ops.WHERE: lambda x,y,z: y if x else z}
 
 def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
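Note: the Ops.MOD and Ops.IDIV lambdas in python_alu above implement C-style truncated division rather than Python's floored division, which is why __floordiv__ carries a TODO earlier in this diff. A standalone check:

    idiv = lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else 0
    mod = lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0]
    print(idiv(-7, 2), -7//2)  # -3 vs Python's -4
    print(mod(-7, 2), -7%2)    # -1 vs Python's 1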
@@ -485,63 +692,32 @@ def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
 
 # ***** uop helpers *****
 
-def print_uops(uops:
+def print_uops(uops:list[UOp]):
   for i,u in enumerate(uops):
     formatted_parents = [(uops.index(x) if x.op is not Ops.CONST else f"{x.arg}") if x in uops else "--" for x in u.src]
     print(f"{i:4d} {str(u.op):20s}: {str(u.dtype):30s} " f"{str(formatted_parents):32s} {u.arg}")
 
-def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
-  flops: sint = 0
-  mem: sint = 0
-  mults: sint = 1
-  mult_stack: List[sint] = []
-  dont_count: Set[UOp] = set()
-  if ignore_indexing:
-    for u in uops:
-      if u.op in {Ops.LOAD, Ops.STORE}:
-        dont_count = dont_count.union(u.src[0].sparents)
-        if len(u.src) > 2: dont_count = dont_count.union(u.src[2].sparents)
-      elif u.op is Ops.IF:
-        dont_count = dont_count.union(u.src[0].sparents)
-  for u in uops:
-    if u.op is Ops.RANGE:
-      mult_stack.append(mults)
-      mults *= (u.src[1] - u.src[0]).ssimplify()
-    elif u.op is Ops.ENDRANGE:
-      mults = mult_stack.pop(-1)
-    elif u.op is Ops.SPECIAL:
-      mults *= u.arg[1] # NOTE: we don't push to the mult_stack here, you can't end these
-    elif u.op is Ops.LOAD:
-      mem += u.dtype.itemsize * mults
-    elif u.op is Ops.STORE:
-      mem += u.src[1].dtype.itemsize * mults
-    elif u.op in GroupOp.ALU and u not in dont_count:
-      flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
-    elif u.op is Ops.WMMA and u not in dont_count:
-      flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
-  return flops, mem
-
 # ***** pattern matcher *****
 
-def get_location() ->
+def get_location() -> tuple[str, int]:
   frm = sys._getframe(1)
   # find the real frame in the file that has the UPat, TODO: is there a better way to do this?
-  while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "
-                                                                                        "lowerer.py", "cstyle.py"}:
+  while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "rewriter.py", "schedule.py", "multi.py",
+                                                                                        "lowerer.py", "cstyle.py", "linearize.py"}:
     frm = frm.f_back
   return frm.f_code.co_filename, frm.f_lineno
 @functools.lru_cache(None)
-def lines(fn) ->
+def lines(fn) -> list[str]:
   with open(fn) as f: return f.readlines()
 
 class UPat(MathTrait):
-  __slots__ =
-  def __init__(self, op:Optional[Union[Ops,
-               src:Optional[Union[
-               name:Optional[str]=None, allow_any_len:bool=False, location=None, custom_early_reject:Optional[
+  __slots__ = ("op", "dtype", "arg", "name", "src")
+  def __init__(self, op:Optional[Union[Ops, tuple[Ops, ...], set[Ops]]]=None, dtype:Optional[Union[DType, tuple[DType, ...]]]=None,
+               src:Optional[Union[tuple[UPat, ...], list[UPat], UPat]]=None, arg:Any=None,
+               name:Optional[str]=None, allow_any_len:bool=False, location=None, custom_early_reject:Optional[set[Ops]]=None):
     assert op is None or isinstance(op, Ops) or isinstance(op, tuple) or isinstance(op, set), "op must be Ops or tuple of Ops"
-    self.op: Optional[
-    self.dtype: Optional[
+    self.op: Optional[tuple[Ops, ...]] = (op,) if isinstance(op, Ops) else (tuple(op) if isinstance(op, set) else op)
+    self.dtype: Optional[tuple[DType, ...]] = (dtype,) if isinstance(dtype, DType) else dtype
     self.arg, self.name, self._in_src, self.custom_early_reject = arg, name, src, custom_early_reject
     self.src: Any = None
     assert self.name != "ctx", "UPat can't be named ctx"
@@ -568,13 +744,12 @@ class UPat(MathTrait):
 
   @staticmethod
   @functools.lru_cache(None)
-  def var(name:Optional[str]=None, dtype:Optional[Union[DType,
+  def var(name:Optional[str]=None, dtype:Optional[Union[DType, tuple[DType, ...]]]=None): return UPat(dtype=dtype, name=name)
   @staticmethod
   @functools.lru_cache(None)
-  def cvar(name:Optional[str]=None, dtype:Optional[DType]=None, vec=True):
-    return UPat((Ops.CONST, Ops.VCONST) if vec else Ops.CONST, dtype=dtype, name=name)
+  def cvar(name:Optional[str]=None, dtype:Optional[DType]=None, vec=True): return UPat((Ops.CONST,Ops.VCONST) if vec else Ops.CONST, dtype, name=name)
   @staticmethod
-  def const(dtype:Optional[Union[DType,
+  def const(dtype:Optional[Union[DType, tuple[DType, ...]]], b:ConstType): return UPat(Ops.CONST, dtype=dtype, arg=b)
 
   # copied from UOp
   def index(self, idx:UPat, valid:Optional[UPat]=None): return UPat(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
@@ -589,7 +764,7 @@ class UPat(MathTrait):
   def const_like(self, b:ConstLike): return UPat.const(self.dtype, cast(ConstType, b))
   def alu(self, op:Ops, *src:UPat):
     asrc = (self,)+src
-    return UPat(op,
+    return UPat(op, dtypes.bool if op in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].dtype, list(asrc) if op in GroupOp.Commutative else asrc)
 
   def printable(self:UPat) -> str:
     try: return lines(self.location[0])[self.location[1]-1].strip()
@@ -602,14 +777,14 @@ class UPat(MathTrait):
                set(x.dtype) if x.dtype else None, x.allowed_len == 0, "[%s]" if x.src and len(x.src)>1 else "(%s)")
     return pretty_print(self, rep, srcfn=lambda x:None if x.src is None else [next(x.src[0])] if isinstance(x.src[0], itertools.repeat) else x.src[0])
 
-  def match(self:UPat, uop:UOp, store:
+  def match(self:UPat, uop:UOp, store:dict[str, UOp]) -> list[dict[str, UOp]]:
     if (self.op is not None and uop.op not in self.op) or \
        (self.name is not None and store.setdefault(self.name, uop) is not uop) or \
        (self.dtype is not None and uop.dtype not in self.dtype and uop.dtype.scalar() not in self.dtype) or \
        (self.arg is not None and self.arg != uop.arg) or \
        (self.allowed_len != -1 and len(uop.src) != self.allowed_len): return []
     if self.src is None: return [store]
-    res:
+    res: list[dict[str, UOp]] = []
     for vp in self.src:
       stores, new_stores = [store.copy()], []
       for uu, vv in zip(uop.src, vp):
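Note: these match() methods back PatternMatcher below; a minimal end-to-end sketch, assuming a tinygrad 0.10.1 install (graph_rewrite is also defined in ops.py), that folds x+0 into x:

    from tinygrad.ops import UOp, UPat, PatternMatcher, graph_rewrite
    from tinygrad.dtype import dtypes

    pm = PatternMatcher([(UPat.var("x")+UPat.const(None, 0), lambda x: x)])
    expr = UOp.const(dtypes.int, 5) + 0  # builds ADD(CONST 5, CONST 0)
    print(graph_rewrite(expr, pm).arg)   # 5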
@@ -619,13 +794,11 @@ class UPat(MathTrait):
     return res
 
 class UPatAny(UPat):
-  def match(self:UPat, uop:UOp, store:
-
-    for x in
-      if (match:=x.match(uop, store.copy())): ret.extend(match)
-    return ret
+  def match(self:UPat, uop:UOp, store:dict[str, UOp]) -> list[dict[str, UOp]]:
+    matches = [x.match(uop, store.copy()) for x in self.src[0]]
+    return flatten([x for x in matches if x is not None])
 
-def deconstruct_function(fxn:Callable) ->
+def deconstruct_function(fxn:Callable) -> tuple:
   new_globals = {k:v for k,v in fxn.__globals__.items() if k in fxn.__code__.co_names}
   for co in fxn.__code__.co_consts:
     if isinstance(co, types.CodeType): new_globals.update({k:v for k,v in fxn.__globals__.items() if k in co.co_names})
@@ -635,10 +808,10 @@ def deconstruct_function(fxn:Callable) -> Tuple:
   return pickle.loads(pickle.dumps(ret)) if getenv("TEST_PICKLE") else ret
 
 class PatternMatcher:
-  def __init__(self, patterns:
+  def __init__(self, patterns:list[tuple[UPat, Callable]]):
     self.patterns = patterns
     # NOTE: use of DefaultDict here is very dangerous! all keys will live for the lifetime of the PatternMatcher!
-    self.pdict:
+    self.pdict: dict[Ops, list[tuple[UPat, Callable, set, bool]]] = {}
     # uop is required, arg is optional
     for p,fxn in self.patterns:
       assert p.op is not None
@@ -651,7 +824,7 @@ class PatternMatcher:
   @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
   def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns)
 
-  def rewrite(self, uop:UOp, ctx=None) ->
+  def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
     ler = {u.op for u in uop.src}
     for p,fxn,early_reject,has_ctx in self.pdict.get(uop.op, []):
       if not early_reject.issubset(ler): continue
@@ -662,39 +835,32 @@ class PatternMatcher:
 # *** tracking pattern matcher ***
 
 TRACK_MATCH_STATS = ContextVar("TRACK_MATCH_STATS", 2 if getenv("VIZ") else 0)
-match_stats:Dict[UPat, List[Union[int, float]]] = dict()
+match_stats:dict[UPat, list[Union[int, float]]] = dict()
 @dataclass(frozen=True)
-class TrackedRewriteContext:
-  loc: Tuple[str, int]                                               # location that called graph_rewrite
-  sink: UOp                                                          # the sink passed into the rewrite
-  matches: List[Tuple[UOp, Optional[UOp], Optional[UPat], float]] = field(default_factory=list)
-
-rewrite_stack: List[Tuple[Any, List[TrackedRewriteContext]]] = []
-contexts: List[Tuple[Any, List[TrackedRewriteContext]]] = []
-_rewrite_cnt: Dict[str, int] = {}
+class TrackedGraphRewrite:
+  loc: tuple[str, int]                                               # location that called graph_rewrite
+  sink: UOp                                                          # the sink input to graph_rewrite
+  matches: list[tuple[UOp, UOp, UPat]] = field(default_factory=list) # before+after of all the matches
+tracked_keys:list[Any] = []
+tracked_ctxs:list[list[TrackedGraphRewrite]] = []
+_name_cnt:dict[str, int] = {}
 def track_rewrites(named=False):
   def _decorator(func):
     def __wrapper(self, *args, **kwargs):
       if TRACK_MATCH_STATS >= 2:
-        if named: _rewrite_cnt[func.__name__] = _rewrite_cnt.setdefault(func.__name__, 0)+1
-        rewrite_stack.append((f"{func.__name__}_{_rewrite_cnt[func.__name__]}" if named else self, []))
-      try: ret = func(self, *args, **kwargs)
-      finally:
-        if TRACK_MATCH_STATS >= 2: contexts.append(rewrite_stack.pop())
-      return ret
+        if named: _name_cnt[func.__name__] = _name_cnt.get(func.__name__, 0)+1
+        tracked_keys.append(f"{func.__name__}_{_name_cnt[func.__name__]}" if named else self)
+        tracked_ctxs.append([])
+      return func(self, *args, **kwargs)
     return __wrapper
   return _decorator
 
 class TrackedPatternMatcher(PatternMatcher):
-  def __init__(self, patterns:List[Tuple[UPat, Callable]]):
-    super().__init__(patterns)
-    for p,_ in self.patterns:
-      if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]
-
-  def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]:
+  def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
     ret = None
     ler = {u.op for u in uop.src}
     for p,fxn,early_reject,has_ctx in self.pdict.get(uop.op, []):
+      if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]
       st = time.perf_counter()
       if not early_reject.issubset(ler):
         match_stats[p][2] += time.perf_counter()-st
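Note: the tracking refactor replaces the old `rewrite_stack`/`contexts` stack of `(key, list)` tuples with two parallel module-level lists, so `tracked_keys[i]` names the rewrite trace stored in `tracked_ctxs[i]`; the `try/finally` pop disappears because both entries are appended up front. A toy sketch of the new decorator shape (standalone code with the `TRACK_MATCH_STATS` gate dropped, for illustration only):

```python
tracked_keys: list = []
tracked_ctxs: list[list] = []
_name_cnt: dict[str, int] = {}

def track_rewrites(named=False):
  def _decorator(func):
    def __wrapper(self, *args, **kwargs):
      if named: _name_cnt[func.__name__] = _name_cnt.get(func.__name__, 0)+1
      tracked_keys.append(f"{func.__name__}_{_name_cnt[func.__name__]}" if named else self)
      tracked_ctxs.append([])   # graph_rewrite appends its trace records here
      return func(self, *args, **kwargs)
    return __wrapper
  return _decorator
```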
@@ -705,10 +871,9 @@ class TrackedPatternMatcher(PatternMatcher):
         match_stats[p][0] += 1
         match_stats[p][3] += (et:=time.perf_counter()-st)
         if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", p.printable())
-        if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0: rewrite_stack[-1][1][-1].matches.append((uop, ret, p, et))
+        if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp) and len(tracked_ctxs) != 0: tracked_ctxs[-1][-1].matches.append((uop, ret, p))
         return ret # NOTE: if it returns None, we keep trying to match
       match_stats[p][2] += time.perf_counter()-st
-      if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0: rewrite_stack[-1][1][-1].matches.append((uop, ret, None, 0))
     return None
 
 if TRACK_MATCH_STATS:
@@ -717,12 +882,10 @@ if TRACK_MATCH_STATS:
   @atexit.register
   def print_match_stats():
     if TRACK_MATCH_STATS >= 2:
-      with open(fn:=temp("rewrites.pkl"), "wb") as f:
-        print(f"rewrote {len(contexts)} graphs and matched {sum(len(r.matches) for _,x in contexts for r in x)} times, saved to {fn}")
-        pickle.dump(contexts, f)
-      if getenv("VIZ"):
-        os.environ["VIZ"] = "0"
-        os.execv(sys.executable, [sys.executable] + [os.path.join(os.path.dirname(__file__), ".", "viz", "serve.py"), temp("rewrites.pkl")])
+      with open(fn:=temp("rewrites.pkl", append_user=True), "wb") as f:
+        print(f"rewrote {len(tracked_ctxs)} graphs and matched {sum(len(r.matches) for x in tracked_ctxs for r in x)} times, saved to {fn}")
+        with Context(PICKLE_BUFFERS=0): pickle.dump((tracked_keys, tracked_ctxs), f)
+      if getenv("VIZ"): launch_viz("VIZ", temp("rewrites.pkl", append_user=True))
     if getenv("PRINT_MATCH_STATS", 1):
       ret = [0,0,0.0,0.0]
       for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]+x[1][3]):
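Note: `launch_viz` factors the old inline `os.execv` into a helper shared by the VIZ and PROFILE paths: it marks the triggering env var consumed (so the re-exec'd server does not recurse), stashes the pickle path in `<ENV>_DATA`, and only exec's the server once neither flag is still set. A stripped-down sketch of the handoff pattern (hypothetical helper and paths, not the library function):

```python
import os, sys

def launch_server_sketch(env_str: str, data: str, server_py: str):
  os.environ[env_str] = "0"             # mark the flag consumed before re-exec
  os.environ[f"{env_str}_DATA"] = data  # the server reads the data path from here
  if not int(os.getenv("VIZ", "0")) and not int(os.getenv("PROFILE", "0")):
    os.execv(sys.executable, [sys.executable, server_py, "--kernels", data])
```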
@@ -731,118 +894,46 @@ if TRACK_MATCH_STATS:
         ret = [x+y for x,y in zip(ret, v)]
       print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[3]*1000.:9.2f} / {(ret[2]+ret[3])*1000.:9.2f} ms -- TOTAL")
 
+def launch_viz(env_str:str, data:str):
+  os.environ[env_str] = "0"
+  os.environ[f"{env_str}_DATA"] = data
+  if not int(os.getenv("VIZ", "0")) and not int(os.getenv("PROFILE", "0")):
+    args = ['--kernels', getenv("VIZ_DATA", "")] if getenv("VIZ_DATA", "") else []
+    args += ['--profile', getenv("PROFILE_DATA", "")] if getenv("PROFILE_DATA", "") else []
+    os.execv(sys.executable, [sys.executable] + [os.path.join(os.path.dirname(__file__), ".", "viz", "serve.py")] + args)
+
 # *** simple graph rewrite engine ***
 
 class RewriteContext:
-  def __init__(self, pm, ctx):
+  def __init__(self, pm, ctx=None):
     self.pm: PatternMatcher = pm
     self.ctx = ctx
-    self.replace: Dict[UOp, UOp] = {}
-  def rewrite(self, n:UOp) -> UOp:
+    self.replace: dict[UOp, UOp] = {}
+  def top_down_rewrite(self, n:UOp) -> UOp:
     if (rn := self.replace.get(n)) is not None: return rn
-    new_src = tuple(map(self.rewrite, n.src))
+    new_src = tuple([self.top_down_rewrite(x) for x in n.src])
     new_n = self.pm.rewrite(n, self.ctx) if new_src == n.src else UOp(n.op, n.dtype, new_src, n.arg)
-    self.replace[n] = ret = n if new_n is None else self.rewrite(new_n)
+    self.replace[n] = ret = n if new_n is None else self.top_down_rewrite(new_n)
+    return ret
+  def bottom_up_rewrite(self, n:UOp) -> UOp:
+    if (rn := self.replace.get(n)) is not None: return rn
+    new_n: UOp|None = n
+    while new_n is not None: last_n, new_n = new_n, self.pm.rewrite(new_n, self.ctx)
+    new_src = tuple([self.bottom_up_rewrite(x) for x in last_n.src])
+    self.replace[n] = ret = last_n if new_src == last_n.src else self.bottom_up_rewrite(UOp(last_n.op, last_n.dtype, new_src, last_n.arg))
     return ret
 
-def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp:
-  if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0:
-    rewrite_stack[-1][1].append(TrackedRewriteContext(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink))
-  return RewriteContext(pm, ctx).rewrite(sink)
-
-# ***** uop type spec *****
-
-# this is the matcher for the final rendered UOps
-# matcher functions returns True or False (or None to not match)
-spec = PatternMatcher([
-  (UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and not x.dtype.local),
-  (UPat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.local),
-  (UPat(Ops.DEFINE_ACC, src=(UPat.var("c"),), name="x", allow_any_len=True),
-   lambda x,c: all(y.op is Ops.RANGE for y in x.src[1:]) and c.dtype == x.dtype),
-  (UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
-
-  (UPat(Ops.RANGE, src=(UPat(name="x"), UPat(name="y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype),
-  (UPat(Ops.SPECIAL, src=()), lambda: True),
-
-  # TODO: confirm the args of both of these are shapetrackers
-  (UPat(Ops.VIEW, dtypes.void, src=()), lambda: True),
-  (UPat(Ops.VIEW, src=(UPat.var("src"),), name="x"), lambda x,src: src.op is not Ops.STORE and x.dtype == src.dtype),
-
-  (UPat(Ops.VALID, dtypes.bool, (UPat(Ops.VIEW),)), lambda: True),
-  (UPat(Ops.CONST, name="x"), lambda x: x.dtype == x.dtype.scalar() and (type(x.arg) is type(dtypes.as_const(x.arg, x.dtype)))),
-
-  # early LOAD has a <buf, shapetracker, store?>
-  (UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW))), lambda: True),
-  (UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat(Ops.STORE))), lambda: True),
-
-  # early STORE has a <buf, shapetracker, val>
-  (UPat(Ops.STORE, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat())), lambda: True),
-
-  # **** new style load/store ****
-
-  # INDEX is used in new style load/store
-  (UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat())), lambda: True),
-
-  # LOAD takes a <bufidx, alt?, gate?, barrier?>
-  (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)),)), lambda: True),
-  (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat((Ops.IF, Ops.BARRIER)))), lambda: True),
-  (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(name="alt"), UPat(dtype=dtypes.bool)), name="ld"), lambda ld,alt: ld.dtype == alt.dtype),
-
-  # STORE takes a <bufidx, val, gate?>
-  (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat())), lambda: True),
-  (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
-  (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(), UPat(Ops.IF))), lambda: True),
-
-  # most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
-  (UPat(Ops.WHERE, name="w", src=(UPat(dtype=dtypes.bool), UPat(name="x"), UPat(name="y"))), lambda w,x,y: w.dtype == x.dtype == y.dtype),
-  (UPat((Ops.CMPLT, Ops.CMPNE), dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y"))), lambda x,y: x.dtype == y.dtype),
-  # and SHL/SHR, the shift distance can be an int
-  (UPat((Ops.SHL, Ops.SHR), src=(UPat(name="x"), UPat(name="y")), name="a"), lambda a,x,y: a.dtype == x.dtype and y.dtype in (x.dtype, dtypes.uint)),
-  (UPat(Ops.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
-  (UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)),
-
-  (UPat(Ops.ASSIGN, src=(UPat((Ops.DEFINE_ACC, Ops.DEFINE_GLOBAL)), UPat())), lambda: True),
-  (UPat(Ops.ENDRANGE, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True),
-
-  # all WMMA has 3 args, <x, w, acc>
-  (UPat(Ops.WMMA, src=(UPat(), UPat(), UPat())), lambda: True),
-  (UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
-  (UPat(Ops.EXPAND, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
-
-  # if has a <gate, barrier?>
-  (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(),)), lambda: True),
-  (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(), UPat(Ops.BARRIER))), lambda: True),
-  (UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),
-
-  (UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}),
-  (UPat(Ops.GEP, src=(UPat(name="src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
-  (UPat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.count and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
-  (UPat((Ops.BITCAST, Ops.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None),
-  (UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local
-
-  # NOTE: for testing, we let sinks be anything
-  #(UPat(UOps.SINK, src=UPat(UOps.STORE)), lambda: True),
-  (UPat(Ops.SINK, dtypes.void), lambda: True),
-  (UPat(Ops.NOOP), lambda: True),
-
-  # PTX LOAD/STORE
-  (UPat((Ops.LOAD, Ops.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
-  (UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, src=(UPat(dtype=dtypes.int64),), allow_any_len=True)), lambda: True),
-])
-
-def type_verify(uops:List[UOp]):
-  for i,u in enumerate(uops):
-    if not spec.rewrite(u):
-      print_uops(uops)
-      raise RuntimeError(f"UOp verification failed at {i} on {u.op} {u.dtype} {len(u.src)} {[x.op for x in u.src]} {u.arg}")
+def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False) -> UOp:
+  if TRACK_MATCH_STATS >= 2 and not bottom_up and len(tracked_ctxs) != 0: # TODO: make viz work with bottom_up=True
+    tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink))
+  return RewriteContext(pm, ctx).bottom_up_rewrite(sink) if bottom_up else RewriteContext(pm, ctx).top_down_rewrite(sink)
 
-
+def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False) -> dict[UOp, UOp]:
+  if TRACK_MATCH_STATS >= 2 and not bottom_up and len(tracked_ctxs) != 0: # TODO: make viz work with bottom_up=True
+    tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink))
+  rewrite_ctx = RewriteContext(pm, ctx)
+  return {k:(rewrite_ctx.bottom_up_rewrite(k) if bottom_up else rewrite_ctx.top_down_rewrite(k)) for k in list(sink.toposort)[::-1]}
 
-def cast_float_to_bf16(x: UOp) -> UOp:
-  assert x.dtype == dtypes.float, "cast float -> bf16 must start with float"
-  x = x.bitcast(dtypes.uint)
-  x = (-x & 0x7f800000).where(x + ((x >> 16) & 1) + 0x7fff, (x & 0xffff).where((x | 0x10000), x))
-  return (x >> 16).cast(dtypes.ushort).bitcast(dtypes.bfloat16)
 
 # *** most of symbolic lives here now ***
 
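Note: `RewriteContext` now carries both strategies: `top_down_rewrite` rewrites children first and then the node (the old behavior), while the new `bottom_up_rewrite` drives a node to a fixed point before descending into its rewritten sources; `graph_rewrite_map` additionally returns the whole old-to-new node mapping. A minimal usage sketch, assuming tinygrad 0.10.1 is installed:

```python
from tinygrad.ops import UOp, UPat, PatternMatcher, graph_rewrite, graph_rewrite_map

pm = PatternMatcher([(UPat.var("x") + 0, lambda x: x)])   # one rule: x+0 -> x
x = UOp.variable("x", 0, 10)
print(graph_rewrite(x + 0, pm))                  # the bare variable survives
print(graph_rewrite(x + 0, pm, bottom_up=True))  # same result via the new strategy
print(len(graph_rewrite_map(x + 0, pm)))         # one entry per node in the sink
```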
@@ -851,70 +942,68 @@ def split_uop(x:UOp, sep:Ops):
     for s in x.src: yield from split_uop(s, sep)
   else: yield x
 
-def mod_folding(x:UOp, c:int) -> Optional[UOp]:
-  # simplify x % c, None means no change
+def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split_rem: bool=False) -> UOp|None:
+  # simplify x // y or x % y, None means no change
+  # simple cancel div/mod case
+  if y.vmin != 0 != y.vmax and (q:=x.vmin//y.vmin) == x.vmin//y.vmax == x.vmax//y.vmin == x.vmax//y.vmax:
+    return x - q*y if which is Ops.MOD else x.const_like(q)
 
-  # simple cancel mod case
-  if 0 < c and 0 <= x.vmin and (quotient:=x.vmin//c) == x.vmax//c: return x-quotient*c
+  if (y.op is not Ops.CONST) or ((c := y.arg) <= 0) or (x.dtype.count > 1): return None
 
-  remainder, something_changed = [], False
+  svars, factors, quotients, remainders, gcd, div, const, offset, something_changed = [], [], [], [], c, 1, 0, 0, False
   for u in split_uop(x, Ops.ADD):
-    if (factor:=u.const_factor())%c != factor:
-      divides = u.divides(factor)
-      assert divides is not None
-      remainder.append(divides*(factor%c))
-      something_changed = True
-    elif u.op is Ops.MOD and (s1:=u.src[1]).op is Ops.CONST and s1.arg%c == 0:
-      remainder.append(u.src[0])
+    if u.op is Ops.MOD and which is Ops.MOD and u.src[1].op is Ops.CONST and u.src[1].arg%c == 0:
+      u = u.src[0]
       something_changed = True
-    else: remainder.append(u)
-  if not something_changed: return None
-  return functools.reduce(operator.add, remainder)%c if remainder else x.const_like(0)
-
-def div_folding(x:UOp, c:int) -> Optional[UOp]:
-  # simplify x // c, None means no change
+    v: UOp = u.divides(f:=u.const_factor())
+    q, r = divmod(f, c)
+    if r==0 or ((which is Ops.MOD or split_rem or u.op is Ops.CONST) and r!=f): something_changed = True
+    offset += r*v.vmin
+    if u.op is Ops.CONST: const += f
+    else: # div is the smallest common divisor of all terms
+      if f > 1 and c % f == 0 and (div == 1 or div > f): div = f
+      gcd = math.gcd(r, gcd)
+      factors.append(f); svars.append(v); quotients.append(q); remainders.append(r)  # noqa: E702
+
+  lbound = ubound = offset = offset % c
+  # we can fold if the expression has only one non-constant term and this term can only take on two values
+  if len(svars)==1 and (v:=svars[0]).vmax-v.vmin == 1:
+    r = (offset+remainders[0])%c - offset%c
+    offset -= r * v.vmin
+    if which is Ops.MOD: return r*v + offset
+    return (factors[0]-r)//c * v + (const-offset)//c
 
-  # simple cancel div case
-  if 0 <= x.vmin and x.vmax//c == x.vmin//c: return x.const_like(x.vmin//c)
+  # a//c = (a-a%c)/c, if we can fold a%c, we can fold a//c
+  # within a mod we can freely subtract multiples of c, we use this to see if a is congruent to an expression whose vmin/vmax are between 0 and c
+  for (r, v) in zip(remainders, svars):
+    if r > c//2:
+      if (lbound := lbound + (r:=r-c) * (v.vmax-v.vmin)) < 0: break
+    elif (ubound := ubound + r * (v.vmax-v.vmin)) >= c: break
+    offset -= r * v.vmin # determine what the new offset would be
+  else: # vmin/vmax of the remainder is between 0 and c, we can remove the mod/div
+    remainders = [min(r, r-c, key=abs) for r in remainders]
+    if which is Ops.MOD: return functools.reduce(operator.add, [r*v for r,v in zip(remainders,svars)], x.const_like(offset))
+    return functools.reduce(operator.add, [(f-r)//c * v for f,r,v in zip(factors, remainders,svars)], x.const_like((const-offset)//c))
 
-  quotient, remainder, rem_const, something_changed, gcd, divisor = [], [], 0, False, c, 1
-  for u in split_uop(x, Ops.ADD):
-    if u.op is Ops.CONST:
-      # add all const together first
-      if rem_const != 0: something_changed = True
-      rem_const += u.arg
-    elif (factor:=u.const_factor())%c == 0:
-      if factor:
-        divides = u.divides(c)
-        assert divides is not None
-        quotient.append(divides)
-        something_changed = True
-    else:
-      # divisor is the smallest common divisor of all MULs
-      if u.op is Ops.MUL and factor > 1 and c % factor == 0 and (divisor == 1 or divisor > factor): divisor = factor
-      remainder.append(u)
-      gcd = math.gcd(gcd, factor)
-
-  # handle the const
-  if rem_const%c != rem_const:
-    something_changed = True
-    quotient.append(x.const_like(rem_const//c))
-    rem_const = rem_const%c
-  if rem_const != 0: remainder.append(x.const_like(rem_const))
-
-  # x // c -> quotient + (remainder // div) // (c // div)
-  div = gcd if gcd > 1 else divisor
-
-  if not something_changed: return newx//(c//div) if 1 < div < c and (newx:=div_folding(x, div)) is not None else None
-  rem:Optional[UOp] = functools.reduce(operator.add, remainder) if remainder else None
-  quo:Optional[UOp] = functools.reduce(operator.add, quotient) if quotient else None
-  if quo is None: return x.const_like(0) if rem is None else cast(UOp, div_folding(rem, div))//(c//div)
-  return quo if rem is None else cast(UOp, div_folding(rem, div))//(c//div)+quo
-
-def lt_folding(x:UOp, c:int) -> Optional[UOp]:
+  if gcd != 1: something_changed = True
+  if not something_changed:
+    if which is Ops.IDIV and (1 < div < c) and (newx:=div_and_mod_folding(x, UOp.const(dtypes.int, div), Ops.IDIV)) is not None: return newx//(c//div)
+    return None
+  quo, rem = x.const_like(const//c), x.const_like((const%c)//gcd)
+  for q,r,f,v in zip(quotients, remainders, factors, svars):
+    if which is Ops.IDIV and (not split_rem) and r!=0:
+      rem += f//gcd * v
+    else:
+      rem += r//gcd * v
+      quo += q * v
 
+  if which is Ops.MOD: return gcd*(rem % (c//gcd)) + const%gcd
+  return rem//(c//gcd)+quo
 
+def lt_folding(x:UOp, c:int) -> UOp|None:
   p, np = partition(split_uop(x, Ops.ADD), lambda u: u.const_factor() == 1)
   if np and (d:=math.gcd(*[u.const_factor() for u in np], c)) > 1 and 0 <= sum(u.vmin for u in p) and sum(u.vmax for u in p) < d:
-    return cast(UOp, functools.reduce(operator.add, np).divides(d)).lt(c//d)
+    return cast(UOp, functools.reduce(operator.add, np).divides(d))<(c//d)
   return None
 
 def fold_unrolled_divs(divs:UOp):
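Note: `div_and_mod_folding` unifies the old `mod_folding`/`div_folding` pair. Per ADD term it splits the constant factor `f` as `f = q*c + r` (`q, r = divmod(f, c)`), accumulates the quotients, and only keeps a division over the remainder part when that part provably stays inside `[0, c)`. A pure-Python check of the underlying identity, independent of tinygrad:

```python
c = 8
for v in range(5):                  # stand-in for a variable with vmin=0, vmax=4
  x = 17*v + 3                      # one term with factor 17 = 2*8 + 1, const 3
  assert x//c == 2*v + (v + 3)//c   # quotient part plus folded remainder part
  assert x%c  == (v + 3) % c        # the mod only sees the remainder part
print("identity holds on the whole range")
```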
@@ -938,7 +1027,7 @@ def fold_unrolled_divs(divs:UOp):
     if ans is not None and 0 <= ans.vmin and ans.vmax + i < denominator: seen_const.append(i)
   return ans if ans is not None and sorted(seen_const)==list(range(denominator)) else None
 
-def canonicalize_simplex(X:UOp) -> Optional[UOp]:
+def canonicalize_simplex(X:UOp) -> UOp|None:
   # (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints.
   # returns x0 + x1 + ... in such case, or None if not
   changed, ret = False, []
@@ -958,7 +1047,7 @@ def is_increasing(f:UOp) -> bool:
   if f.op in (Ops.MUL, Ops.IDIV) and f.src[1].op is Ops.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0])
   return False  # False if not sure
 
-def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]:
+def parse_valid(valid:UOp) -> tuple[UOp, bool, int]:
   # if it's X <= c, returns X, True, c
   # if it's X >= c, returns X, False, c
 
@@ -966,14 +1055,14 @@ def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]:
   if valid.op is Ops.CMPNE and valid.src[1].op is Ops.CONST and valid.src[1].arg == 1 and \
     (s0:=valid.src[0]).op is Ops.CMPLT and s0.src[1].op is Ops.CONST: return s0.src[0], False, s0.src[1].arg
   # X < c -> X <= c-1
-  if valid.op is Ops.CMPLT and valid.src[1].op is Ops.CONST: return valid.src[0], True, valid.src[1].arg-1
+  if valid.op is Ops.CMPLT and valid.src[1].op is Ops.CONST and dtypes.is_int(valid.src[0].dtype): return valid.src[0], True, valid.src[1].arg-1
   raise ValueError(f"not able to parse {valid=}")
 
-def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
+def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None:
   # return None if valid is always False, otherwise the simplified uop (might be the same as input)
 
   # first, parse valid into {expr: (lower_bound, upper_bound)}
-  bounds:DefaultDict[UOp, List[Optional[ConstType]]] = defaultdict(lambda: [None, None])
+  bounds:defaultdict[UOp, list[ConstType|None]] = defaultdict(lambda: [None, None])
   for stmt in split_uop(valid, Ops.AND):
     try: expr, is_upper, c = parse_valid(stmt)
     except ValueError: return uop  # give up if we cannot parse the valid
@@ -990,7 +1079,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
     # if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
     candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, Ops.ADD)])
     # try checking the whole clause
-    if expr in uop.sparents:
+    if expr in uop.toposort:
       candidates.append([(expr, UOp.variable("fake", expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1], expr.dtype))])
 
   for candidate in candidates:
@@ -1003,13 +1092,13 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
 
   return uop
 
-def _valid_priority(v: UOp, valids:List[UOp]):
+def _valid_priority(v: UOp, valids:list[UOp]):
   # we want valid that's in other valids' parents to be first, so it's more likely the other valids get simplified
-  try: return sum(-1 if parse_valid(v)[0] in other.parents else 0 for other in valids)
+  try: return sum(-1 if parse_valid(v)[0] in other.toposort else 0 for other in valids)
   except ValueError: return 0
 
-def simplify_valid(valid:UOp) -> Optional[UOp]:
-  ret:List[UOp] = []
+def simplify_valid(valid:UOp) -> UOp|None:
+  ret:list[UOp] = []
   something_changed = False
   valids = list(split_uop(valid, Ops.AND))
   for stmt in sorted(valids, key=lambda v: _valid_priority(v, valids)):
@@ -1017,9 +1106,11 @@ def simplify_valid(valid:UOp) -> Optional[UOp]:
     if ret[-1] is not stmt: something_changed = True
   return functools.reduce(operator.and_, ret) if something_changed else None
 
-def max_var_const(x:UOp, c1:UOp, c2:UOp):
-  if x.vmin >= 0: return x*c1 if c1.arg >= c2.arg else x*c2
-  if x.vmax <= 0: return x*c2 if c1.arg >= c2.arg else x*c1
+# def max_var_const(x:UOp, c1:UOp, c2:UOp):
+#   if x.vmin >= 0: return x*c1 if c1.arg >= c2.arg else x*c2
+#   if x.vmax <= 0: return x*c2 if c1.arg >= c2.arg else x*c1
+
+def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x
 
 symbolic_simple = PatternMatcher([
   # ** self folding **
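Note: `max_var_const` is retired (kept only as a comment, together with its rule further down), and the new `sint_to_uop` helper normalizes a symbolic int (`sint = Union[int, UOp]`) to a UOp. Usage sketch, assuming tinygrad 0.10.1 is installed:

```python
from tinygrad.ops import UOp, sint_to_uop

print(sint_to_uop(4).op, sint_to_uop(4).arg)   # Ops.CONST 4
v = UOp.variable("i", 1, 10)
print(sint_to_uop(v) is v)                     # True: UOps pass through unchanged
```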
@@ -1032,23 +1123,25 @@ symbolic_simple = PatternMatcher([
   ((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x
   ((UPat.var() % UPat.var("y")).named("base") % UPat.var("y"), lambda base,y: base), # (x%y)%y = -> x%y (rewritten with base for speed)
   (UPat.var("x")%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda x,c: x), # (x%c)+(x//c)*c = x
+  ((UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3")+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"),
+   lambda x,c1,c2,c3: x*c2 if c1.arg*c2.arg==c3.arg else None), # (x%c1)*c2+(x//c1)*c3 = x*c2 if c1*c2==c3
   (UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c", vec=False), lambda x,c: x if c.arg else c),
   (UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c", vec=False), lambda x,c: c if c.arg else x),
-  (UPat.var("x").maximum(UPat.var("x")), lambda x: x),
-  ((UPat.var("x") & UPat.var("x")), lambda x: x),
-  ((UPat.var("x") | UPat.var("x")), lambda x: x),
+  (UPat(GroupOp.Idempotent, src=(UPat.var("x"), UPat.var("x"))), lambda x: x),
   (UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x),
   (UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, True), UPat.const(dtypes.bool, False)), lambda x: x),
   # ** zero folding **
-  (UPat.var("x") < UPat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x < x -> False
+  (UPat.var("x") < UPat.var("x"), lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x < x -> False
   (UPat.var("x", dtype=dtypes.ints) != UPat.var("x", dtype=dtypes.ints),
-   lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x != x -> False (only ints)
+   lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints)
   # x*0 -> 0 or 0*x -> 0
   # if x is nan or inf it should render the nan value.
   # NOTE: this can be wrong for loaded NaN
   (UPat.var("x") * 0, lambda x: x.const_like(float("nan") if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
   # ** constant folding **
-  (UPat(GroupOp.ALU, name="a", src=UPat((Ops.VCONST, Ops.CONST))), lambda a: a.const_like(exec_alu(a.op, a.dtype, [x.arg for x in a.src], False))),
+  # TODO: add const folding for Ops.THREEFRY
+  (UPat(GroupOp.ALU, name="a", src=UPat((Ops.VCONST, Ops.CONST))),
+   lambda a: a.const_like(exec_alu(a.op, a.dtype, [x.arg for x in a.src], False)) if a.op is not Ops.THREEFRY else None),
   # bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
   (UPat.var('x', dtype=dtypes.bool) * UPat.var('y', dtype=dtypes.bool), lambda x,y: x&y),
   (UPat.var('x', dtype=dtypes.bool) + UPat.var('y', dtype=dtypes.bool), lambda x,y: x|y),
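Note: two `symbolic_simple` changes worth calling out: the separate max/`&`/`|` self-folding rules collapse into one rule over `GroupOp.Idempotent`, and ALU-on-const folding now explicitly refuses `Ops.THREEFRY` (no constant folding for the PRNG op). Sketch, assuming tinygrad 0.10.1 is installed:

```python
from tinygrad.ops import UOp, graph_rewrite, symbolic_simple
from tinygrad.dtype import dtypes

expr = UOp.const(dtypes.int, 6) * UOp.const(dtypes.int, 7)
print(graph_rewrite(expr, symbolic_simple).arg)           # 42: CONST*CONST folds

x = UOp.variable("x", 0, 10)
print(graph_rewrite(x.maximum(x), symbolic_simple) is x)  # idempotent: max(x,x) -> x
```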
@@ -1061,46 +1154,45 @@ symbolic_simple = PatternMatcher([
 symbolic = symbolic_simple+PatternMatcher([
   # ** COMMUTATIVE flipping **
   (UPat(GroupOp.Commutative, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
-  # group like
-  ((UPat.var("x") + UPat.var("y")) + UPat.var("x") * UPat.cvar("c"), lambda x,y,c: (x+x*c)+y),
   # ** boolean algebra **
   (UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x
   # ** combine terms **
   (UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1)
+  ((UPat.var("y") + UPat.var("x") * UPat.cvar("c0")) + UPat.var("x") * UPat.cvar("c1"), lambda x,y,c0,c1: y+x*(c0+c1)),
   (UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1)
+  ((UPat.var("y") + UPat.var("x")) + UPat.var("x") * UPat.cvar("c"), lambda x,y,c: y+x*(c+1)),
   (UPat.var("x") + UPat.var("x"), lambda x: x*2), # (x+x)-> x*2
-  ((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3)), # (x/x2)/x3 -> x/(x2*x3)
+  ((UPat.var("y") + UPat.var("x")) + UPat.var("x"), lambda y,x: y+x*2),
+  ((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3) if x2 is not x3 else None), # (x/x2)/x3 -> x/(x2*x3)
   (-1 * (UPat.var("x") + UPat.cvar("c")), lambda x,c: (-x)+(-c)), # -(x+c) -> -x + -c
   # a conditional with the same results either way is a noop, also fold const conditionals
   (UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
   (UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
+  # alu of two where with same conds can combine, only do if true branch or false branch is const
+  (UPat(GroupOp.Binary, name="alu", src=(UPat.var("c").where(UPat.var("t"), UPat.var("f")), UPat.var("c").where(UPat.var("tt"), UPat.var("ff")))), \
+   lambda alu,c,t,tt,f,ff: c.where(t.alu(alu.op, tt), f.alu(alu.op, ff)) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None),
   # ALU min==max -> CONST (slow!)
   (UPat(GroupOp.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
   # max folding
   (UPat.maximum(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
   # TODO: why does this rule break beautiful_mnist?
   #((UPat.var("x")+UPat.var("z")).maximum(UPat.var("y")+UPat.var("z")), lambda x,y,z: x.maximum(y) + z),
-  ((UPat.var("x")*UPat.cvar("c1")).maximum(UPat.var("x")*UPat.cvar("c2")), max_var_const),
+  #((UPat.var("x")*UPat.cvar("c1")).maximum(UPat.var("x")*UPat.cvar("c2")), max_var_const),
   # ** two stage ALU folding **
-  ((UPat.var("x") + UPat.cvar("c1")) + UPat.cvar("c2"), lambda x,c1,c2: x+(c1+c2)),
-  ((UPat.var("x") * UPat.cvar("c1")) * UPat.cvar("c2"), lambda x,c1,c2: x*(c1*c2)),
-  ((UPat.var("x") & UPat.cvar("c1")) & UPat.cvar("c2"), lambda x,c1,c2: x&(c1&c2)),
-  ((UPat.var("x") | UPat.cvar("c1")) | UPat.cvar("c2"), lambda x,c1,c2: x|(c1|c2)),
+  *((UPat.var("x").alu(op, UPat.cvar("c1")).alu(op, UPat.cvar("c2")).named("f"),
+     lambda f,x,c1,c2: x.alu(f.op,c1.alu(f.op,c2))) for op in GroupOp.Associative),
   ((UPat.cvar("c0") + UPat.var("x")) < UPat.cvar("c1"), lambda x,c0,c1: x<(c1-c0)), # c0 + x < c1 -> x < c1 - c0
   ((UPat.var("x") // UPat.cvar("c1")) // UPat.cvar("c2"), lambda x,c1,c2: x//(c1*c2)), # (x//c1)//c2 -> x//(c1*c2)
   # ** lt **
   # c0*x<c1 for positive int c0,c1
-  ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints)).lt(UPat.cvar("c1", vec=False)),
-   lambda x,c0,c1: x.lt(math.ceil(c1.arg/c0.arg)) if c0.arg > 0 and c1.arg > 0 else None),
+  ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints))<UPat.cvar("c1", vec=False),
+   lambda x,c0,c1: x<math.ceil(c1.arg/c0.arg) if c0.arg > 0 and c1.arg > 0 else None),
   # c0*x<c1 for negative int c0 and non-positive c1
-  ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints)).lt(UPat.cvar("c1", vec=False)),
-   lambda x,c0,c1: (-x).lt(-(math.floor(-c1.arg/-c0.arg))) if c0.arg < 0 and c0.arg != -1 and c1.arg <= 0 else None),
+  ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints))<UPat.cvar("c1", vec=False),
+   lambda x,c0,c1: (-x)<(-(math.floor(-c1.arg/-c0.arg))) if c0.arg < 0 and c0.arg != -1 and c1.arg <= 0 else None),
   # x//c0<c1 for positive int c0
-  ((UPat.var("x", dtype=dtypes.ints)//UPat.cvar("c0", vec=False)).lt(UPat.cvar("c1", vec=False)),
-   lambda x,c0,c1: x.lt(c1.arg*c0.arg) if c0.arg > 0 else None),
-  # mul add lt
-  (((UPat.cvar("c0", vec=False)*UPat.var("x"))+UPat.var("x2")).lt(UPat.cvar("c1", vec=False)),
-   lambda x,x2,c0,c1: x.lt(c1//c0) if c1.arg % c0.arg == 0 and c0.arg > x2.vmax and x2.vmin >= 0 else None),
+  ((UPat.var("x", dtype=dtypes.ints)//UPat.cvar("c0", vec=False))<UPat.cvar("c1", vec=False),
+   lambda x,c0,c1: x<(c1.arg*c0.arg) if c0.arg > 0 else None),
   # ** move add/mul consts to end (NOTE: this is still happening before constant folding) **
   (UPat(Ops.ADD, src=(UPat.var("x"), UPat.cvar("c1"))) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
   (UPat(Ops.MUL, src=(UPat.var("x"), UPat.cvar("c1"))) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
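Note: the four hand-written two-stage rules (`+`, `*`, `&`, `|`) become one generated family over `GroupOp.Associative`, and the `.lt(...)` spellings are rewritten with the `<` operator. Sketch, assuming tinygrad 0.10.1 is installed:

```python
from tinygrad.ops import UOp, graph_rewrite, symbolic

x = UOp.variable("x", 0, 100)
out = graph_rewrite((x + 3) + 4, symbolic)   # two-stage ALU folding plus const folding
print(out.render())                          # expect something like (x+7)
```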
@@ -1108,18 +1200,20 @@ symbolic = symbolic_simple+PatternMatcher([
   # unrolled arange div folding
   (UPat(Ops.ADD, name="divs", src=[UPat(), UPat(Ops.IDIV)]), fold_unrolled_divs),
   # generic lt folding
-  (UPat.var("x", dtypes.sints).lt(UPat.cvar("c", vec=False)), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None),
+  (UPat.var("x", dtypes.sints)<UPat.cvar("c", vec=False), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None),
   # canonicalize a simplex with positive coefficients > 0
   # not x < 1 -> X > 0
-  (UPat.var("x", dtypes.ints).lt(1).ne(True), lambda x: (newx.lt(1)).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None),
+  ((UPat.var("x", dtypes.ints)<1).ne(True), lambda x: (newx<1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None),
   # ** div **
-  # div folding
-  (UPat.var("x", dtypes.sints) // UPat.cvar("c", vec=False), lambda x,c: div_folding(x, c.arg) if 0 < c.arg else None),
+  # div folding
+  ((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"), lambda x,c,a,d: (x+a*c)//(c*d)), # (x//c+a)//d -> (x+a*c)//(c*d)
+  (UPat.var("x", dtypes.sints) // UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.IDIV)),
   # ** mod **
   # mod folding
-  (UPat.var("x") % UPat.cvar("c", vec=False), lambda x,c: mod_folding(x, c.arg) if 0 < c.arg else None),
+  (UPat.var("x") % UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.MOD)),
 ])
 
+
 symbolic_flat = symbolic+PatternMatcher([
   # ** combine terms (opinionated) **
   (-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y
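Note: the div and mod rules now accept a variable right-hand side (`UPat.var("y")` instead of `UPat.cvar("c")`); `div_and_mod_folding` falls back to the constant path internally, but its "simple cancel" case fires even for non-constant `y` whenever `x//y` is the same constant at all four corners of the value ranges. Pure-Python checks of that corner condition and of the new `(x//c+a)//d` rule:

```python
import itertools

# x in [10, 13], y in [7, 9]: x//y == 1 at all four range corners...
assert {10//7, 10//9, 13//7, 13//9} == {1}
# ...so x % y == x - 1*y and x // y == 1 hold across the whole ranges
for x in range(10, 14):
  for y in range(7, 10):
    assert x % y == x - 1*y and x // y == 1

# the new rule (x//c+a)//d -> (x+a*c)//(c*d) is exact for positive c, d
for x, c, a, d in itertools.product(range(-9, 10), (2, 3), (-2, 5), (2, 4)):
  assert (x//c + a)//d == (x + a*c)//(c*d)
```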
@@ -1134,7 +1228,7 @@ syms = { Ops.ADD: "+", Ops.SUB: "-", Ops.IDIV: "//", Ops.MOD: "%", Ops.SHL: "<<",
          Ops.MUL: "*", Ops.CMPLT: "<", Ops.CMPNE: "!=", Ops.AND: "&", Ops.OR: "|", Ops.XOR: "^"}
 renderer = PatternMatcher([
   (UPat((Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg[0])),
-  (UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg[0]}")),
+  (UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg}")),
   (UPat(Ops.CONST, name="x"), lambda x: UOp(Ops.NOOP, arg=str(x.arg))),
   (UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
   (UPat(Ops.NEG, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")),
@@ -1149,4 +1243,26 @@ renderer = PatternMatcher([
 sint = Union[int, UOp]
 Variable = UOp
 
-ConstLike = Union[ConstType, Variable, Tuple[ConstType, ...]]
+ConstLike = Union[ConstType, Variable, tuple[ConstType, ...]]
+
+# *** UOp merge views and swizzling ***
+
+merge_views = PatternMatcher([
+  # VIEW(VIEW) merges to a single VIEW
+  (UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.VIEW, name="vm2"),)), lambda vm1,vm2: vm2.replace(arg=vm2.st+vm1.st)),
+  (UPat(Ops.VIEW, name="vm", src=(UPat.var("x"),)), lambda vm,x: x if vm.st.contiguous and x.st is not None and x.shape == vm.shape else None),
+  # merge unmasked const views
+  (UPat(Ops.VIEW, name="view", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="const", src=(UPat(Ops.VIEW, name="st"),) ),)),
+   lambda st,const,view: const.replace(src=(st.replace(arg=st.st+view.st),)) if all(v.mask is None for v in (st.st+view.st).views) else None),
+])
+
+# push VIEW to parents
+view_left = merge_views+PatternMatcher([
+  # VIEW(CONST) becomes VALID
+  (UPat(Ops.VIEW, name="vm", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),)), lambda vm,x: x.valid(vm.st)),
+  # VIEW before elementwise/buffer ops
+  (UPat(Ops.VIEW, name="vm", src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e"),)),
+   lambda e,vm: e.replace(src=tuple(s if s.st is None else s.view(vm.st) if s is s.base else s.base.view(s.st+vm.st) for s in e.src))),
+  (UPat(Ops.VIEW, name="vm", src=(UPat(GroupOp.Buffer, name="b"),)),
+   lambda b,vm: b.replace(src=tuple((s.st+vm.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
+])
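Note: the new `merge_views`/`view_left` matchers fold stacked VIEWs by composing their ShapeTrackers with `+` (the outer view's tracker applied after the inner one, as in `vm2.st+vm1.st`), and push remaining VIEWs toward parents, turning VIEW-of-CONST into VALID. A small composition sketch, assuming tinygrad 0.10.1 is installed:

```python
from tinygrad.shape.shapetracker import ShapeTracker

st1 = ShapeTracker.from_shape((4, 4)).reshape((16,))    # inner view
st2 = ShapeTracker.from_shape((16,)).reshape((2, 8))    # outer view
print((st1 + st2).shape)   # (2, 8): the composed tracker, as used by merge_views
```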