tinygrad 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +11 -6
- tinygrad/codegen/kernel.py +308 -175
- tinygrad/codegen/linearize.py +95 -0
- tinygrad/codegen/lowerer.py +143 -0
- tinygrad/codegen/transcendental.py +257 -0
- tinygrad/codegen/uopgraph.py +506 -0
- tinygrad/device.py +72 -171
- tinygrad/dtype.py +122 -47
- tinygrad/engine/jit.py +184 -87
- tinygrad/{lazy.py → engine/lazy.py} +74 -66
- tinygrad/engine/memory.py +51 -0
- tinygrad/engine/realize.py +86 -61
- tinygrad/engine/schedule.py +366 -317
- tinygrad/engine/search.py +58 -47
- tinygrad/function.py +59 -58
- tinygrad/helpers.py +120 -102
- tinygrad/multi.py +82 -78
- tinygrad/nn/__init__.py +116 -67
- tinygrad/nn/datasets.py +12 -5
- tinygrad/nn/optim.py +1 -1
- tinygrad/nn/state.py +91 -6
- tinygrad/ops.py +1126 -143
- tinygrad/renderer/__init__.py +47 -23
- tinygrad/renderer/cstyle.py +338 -265
- tinygrad/renderer/llvmir.py +125 -143
- tinygrad/renderer/ptx.py +225 -0
- tinygrad/runtime/autogen/adreno.py +17904 -0
- tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
- tinygrad/runtime/autogen/cuda.py +6 -162
- tinygrad/runtime/autogen/io_uring.py +97 -63
- tinygrad/runtime/autogen/kfd.py +60 -47
- tinygrad/runtime/autogen/kgsl.py +1386 -0
- tinygrad/runtime/autogen/libc.py +5462 -0
- tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
- tinygrad/runtime/autogen/nvrtc.py +579 -0
- tinygrad/runtime/autogen/opencl.py +11 -11
- tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
- tinygrad/runtime/graph/clang.py +3 -3
- tinygrad/runtime/graph/cuda.py +11 -15
- tinygrad/runtime/graph/hcq.py +120 -107
- tinygrad/runtime/graph/metal.py +71 -43
- tinygrad/runtime/ops_amd.py +244 -323
- tinygrad/runtime/ops_clang.py +12 -5
- tinygrad/runtime/ops_cloud.py +220 -0
- tinygrad/runtime/ops_cuda.py +42 -99
- tinygrad/runtime/ops_disk.py +25 -26
- tinygrad/runtime/ops_dsp.py +181 -0
- tinygrad/runtime/ops_gpu.py +29 -16
- tinygrad/runtime/ops_hip.py +68 -0
- tinygrad/runtime/ops_llvm.py +15 -10
- tinygrad/runtime/ops_metal.py +147 -64
- tinygrad/runtime/ops_nv.py +356 -397
- tinygrad/runtime/ops_python.py +78 -79
- tinygrad/runtime/ops_qcom.py +405 -0
- tinygrad/runtime/support/__init__.py +0 -0
- tinygrad/runtime/support/compiler_cuda.py +77 -0
- tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
- tinygrad/runtime/support/elf.py +38 -0
- tinygrad/runtime/support/hcq.py +539 -0
- tinygrad/shape/shapetracker.py +40 -50
- tinygrad/shape/view.py +102 -63
- tinygrad/tensor.py +1109 -365
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
- tinygrad-0.10.0.dist-info/RECORD +77 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/linearizer.py +0 -528
- tinygrad/codegen/uops.py +0 -451
- tinygrad/engine/graph.py +0 -100
- tinygrad/renderer/assembly.py +0 -269
- tinygrad/shape/symbolic.py +0 -327
- tinygrad-0.9.1.dist-info/RECORD +0 -63
- /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
tinygrad/ops.py
CHANGED
@@ -1,113 +1,465 @@
|
|
1
1
|
from __future__ import annotations
|
2
|
-
from typing import Union, Tuple,
|
3
|
-
import functools,
|
4
|
-
from enum import
|
5
|
-
from dataclasses import dataclass
|
6
|
-
from
|
7
|
-
from
|
8
|
-
from tinygrad.
|
9
|
-
from tinygrad.
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
#
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
2
|
+
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, Type, DefaultDict
|
3
|
+
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect
|
4
|
+
from enum import auto, IntEnum, Enum
|
5
|
+
from dataclasses import dataclass, field
|
6
|
+
from collections import defaultdict
|
7
|
+
from weakref import WeakValueDictionary
|
8
|
+
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate
|
9
|
+
from tinygrad.helpers import ContextVar, prod, getenv, all_same, Context, partition, temp, unwrap, T
|
10
|
+
if TYPE_CHECKING:
|
11
|
+
from tinygrad.shape.shapetracker import ShapeTracker
|
12
|
+
|
13
|
+
# wrapper around IntEnum that preserves Enum.__str__ and makes auto() unique across all FastEnum subclasses
|
14
|
+
class FastEnum(IntEnum):
|
15
|
+
def __str__(self): return Enum.__str__(self)
|
16
|
+
@staticmethod
|
17
|
+
def _generate_next_value_(_, __, ___, last_values): return 1 + max([0, *last_values, *[max(c) for c in FastEnum.__subclasses__()]])
|
18
|
+
|
19
|
+
class SimpleMathTrait:
|
20
|
+
# required to implement
|
21
|
+
def alu(self:T, arg:Ops, *src) -> T: raise NotImplementedError
|
22
|
+
def const_like(self:T, b:ConstLike) -> T: raise NotImplementedError
|
23
|
+
|
24
|
+
# great functions you get!
|
25
|
+
def ufix(self, x): return self.const_like(x) if not isinstance(x, MathTrait) else x
|
26
|
+
def _binop(self, op, x, reverse): return self.ufix(x).alu(op, self) if reverse else self.alu(op, self.ufix(x))
|
27
|
+
def logical_not(self): return self.ne(True)
|
28
|
+
def neg(self):
|
29
|
+
dtype: Optional[DType] = getattr(self, 'dtype', None)
|
30
|
+
assert dtype is not None, "MathTraits __neg__ requires a dtype"
|
31
|
+
return self.logical_not() if dtype.scalar() == dtypes.bool else self*(-1)
|
32
|
+
def add(self, x, reverse=False): return self._binop(Ops.ADD, x, reverse)
|
33
|
+
def mul(self, x, reverse=False): return self._binop(Ops.MUL, x, reverse)
|
34
|
+
def bitwise_and(self, x, reverse=False): return self._binop(Ops.AND, x, reverse)
|
35
|
+
def bitwise_or(self, x, reverse=False): return self._binop(Ops.OR, x, reverse)
|
36
|
+
def xor(self, x, reverse=False): return self._binop(Ops.XOR, x, reverse)
|
37
|
+
def idiv(self, x, reverse=False): return self._binop(Ops.IDIV, x, reverse)
|
38
|
+
def sub(self, x, reverse=False): return self.ufix(x).alu(Ops.ADD, -self) if reverse else self.alu(Ops.ADD, self.ufix(-x))
|
39
|
+
def div(self, x, reverse=False): return (self.ufix(x)*self.alu(Ops.RECIP)) if reverse else (self*self.ufix(x).alu(Ops.RECIP))
|
40
|
+
|
41
|
+
def __neg__(self): return self.neg()
|
42
|
+
|
43
|
+
def __add__(self, x): return self.add(x)
|
44
|
+
def __sub__(self, x): return self.sub(x)
|
45
|
+
def __mul__(self, x): return self.mul(x)
|
46
|
+
def __truediv__(self, x): return self.div(x)
|
47
|
+
def __floordiv__(self, x): return self.idiv(x)
|
48
|
+
def __and__(self, x): return self.bitwise_and(x)
|
49
|
+
def __or__(self, x): return self.bitwise_or(x)
|
50
|
+
def __xor__(self, x): return self.xor(x)
|
51
|
+
|
52
|
+
def __radd__(self, x): return self.add(x, True)
|
53
|
+
def __rsub__(self, x): return self.sub(x, True)
|
54
|
+
def __rmul__(self, x): return self.mul(x, True)
|
55
|
+
def __rtruediv__(self, x): return self.div(x, True)
|
56
|
+
def __rfloordiv__(self, x): return self.idiv(x, True)
|
57
|
+
def __rand__(self, x): return self.bitwise_and(x, True)
|
58
|
+
def __ror__(self, x): return self.bitwise_or(x, True)
|
59
|
+
def __rxor__(self, x): return self.xor(x, True)
|
60
|
+
|
61
|
+
def lt(self, x): return self.alu(Ops.CMPLT, self.ufix(x))
|
62
|
+
def gt(self, x): return self.ufix(x).alu(Ops.CMPLT, self)
|
63
|
+
def ne(self, x): return self.alu(Ops.CMPNE, self.ufix(x))
|
64
|
+
def ge(self, x): return self.lt(x).logical_not()
|
65
|
+
def le(self, x): return self.gt(x).logical_not()
|
66
|
+
def eq(self, x): return self.ne(x).logical_not()
|
67
|
+
|
68
|
+
def __lt__(self, x): return self.lt(x)
|
69
|
+
def __gt__(self, x): return self.gt(x)
|
70
|
+
def __ne__(self, x): return self.ne(x)
|
71
|
+
def __ge__(self, x): return self.ge(x)
|
72
|
+
def __le__(self, x): return self.le(x)
|
73
|
+
# NOTE: __eq__ isn't overridden, and means the same thing as is by default
|
74
|
+
|
75
|
+
class MathTrait(SimpleMathTrait): # pylint: disable=abstract-method
|
76
|
+
# TODO: move to Tensor when new backward is done
|
77
|
+
def lshift(self, x, reverse=False): return self._binop(Ops.SHL, x, reverse)
|
78
|
+
def rshift(self, x, reverse=False): return self._binop(Ops.SHR, x, reverse)
|
79
|
+
def __lshift__(self, x): return self.lshift(x)
|
80
|
+
def __rshift__(self, x): return self.rshift(x)
|
81
|
+
def __rlshift__(self, x): return self.lshift(x, True)
|
82
|
+
def __rrshift__(self, x): return self.rshift(x, True)
|
83
|
+
|
84
|
+
# not in Tensor
|
85
|
+
def __mod__(self, x): return self.alu(Ops.MOD, self.ufix(x))
|
86
|
+
def __rmod__(self, x): return self.ufix(x).alu(Ops.MOD, self)
|
87
|
+
|
88
|
+
def maximum(self, x): return self.alu(Ops.MAX, self.ufix(x))
|
89
|
+
def minimum(self, x): return -(-self).maximum(-x)
|
90
|
+
def where(self, x, y): return self.alu(Ops.WHERE, x, x.ufix(y))
|
91
|
+
def threefry(self, seed): return self.alu(Ops.THREEFRY, seed)
|
92
|
+
def reciprocal(self): return self.alu(Ops.RECIP)
|
93
|
+
def sqrt(self): return self.alu(Ops.SQRT)
|
94
|
+
def sin(self): return self.alu(Ops.SIN)
|
95
|
+
def log2(self): return self.alu(Ops.LOG2)
|
96
|
+
def exp2(self): return self.alu(Ops.EXP2)
|
97
|
+
|
98
|
+
# the order of these Ops controls the order of the toposort
|
99
|
+
class Ops(FastEnum):
|
100
|
+
# uops that aren't rendered
|
101
|
+
SINK = auto()
|
102
|
+
CONTIGUOUS = auto()
|
103
|
+
PRELOAD = auto()
|
104
|
+
|
105
|
+
# MetaOps
|
106
|
+
COPY = auto()
|
107
|
+
EMPTY = auto()
|
108
|
+
BUFFER_VIEW = auto()
|
109
|
+
|
110
|
+
EXPAND = auto()
|
111
|
+
CONTRACT = auto()
|
112
|
+
VIEW = auto()
|
113
|
+
DEFINE_GLOBAL = auto()
|
114
|
+
BUFFER = auto()
|
115
|
+
DEFINE_VAR = auto()
|
116
|
+
DEFINE_LOCAL = auto()
|
117
|
+
DEFINE_ACC = auto()
|
118
|
+
VALID = auto()
|
119
|
+
SPECIAL = auto()
|
120
|
+
NOOP = auto()
|
121
|
+
|
122
|
+
# reduce
|
123
|
+
REDUCE_AXIS = auto()
|
124
|
+
|
125
|
+
# helper ops
|
126
|
+
GEP = auto()
|
127
|
+
VECTORIZE = auto()
|
128
|
+
|
129
|
+
# UnaryOps
|
130
|
+
CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702
|
131
|
+
|
132
|
+
# loads before math
|
133
|
+
LOAD = auto()
|
134
|
+
|
135
|
+
# math ops
|
136
|
+
WMMA = auto()
|
137
|
+
|
138
|
+
# BinaryOps
|
20
139
|
ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
|
21
|
-
SHR = auto();
|
22
|
-
|
23
|
-
|
140
|
+
SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto() # noqa: E702
|
141
|
+
|
142
|
+
# TernaryOps
|
24
143
|
WHERE = auto(); MULACC = auto() # noqa: E702
|
25
|
-
class ReduceOps(Enum):
|
26
|
-
"""A -> B (reduce)"""
|
27
|
-
SUM = auto(); MAX = auto() # noqa: E702
|
28
|
-
class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() # noqa: E702
|
29
|
-
class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); ASSIGN = auto(); VIEW = auto() # noqa: E702
|
30
144
|
|
31
|
-
|
145
|
+
# assignment ops
|
146
|
+
STORE = auto()
|
147
|
+
ASSIGN = auto()
|
148
|
+
BIND = auto()
|
32
149
|
|
33
|
-
#
|
34
|
-
|
150
|
+
# late INDEX
|
151
|
+
INDEX = auto()
|
35
152
|
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
st: ShapeTracker
|
153
|
+
# control flow ops
|
154
|
+
BARRIER = auto()
|
155
|
+
IF = auto()
|
156
|
+
RANGE = auto()
|
41
157
|
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
158
|
+
# ops that are not graph nodes
|
159
|
+
ENDRANGE = auto()
|
160
|
+
ENDIF = auto()
|
161
|
+
|
162
|
+
# consts last!
|
163
|
+
VCONST = auto()
|
164
|
+
CONST = auto()
|
165
|
+
|
166
|
+
class GroupOp:
|
167
|
+
Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG}
|
168
|
+
Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY,
|
169
|
+
Ops.SUB, Ops.FDIV}
|
170
|
+
Ternary = {Ops.WHERE, Ops.MULACC}
|
171
|
+
ALU = set.union(Unary, Binary, Ternary)
|
172
|
+
|
173
|
+
Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}
|
174
|
+
|
175
|
+
# meta ops
|
176
|
+
Meta = {Ops.COPY, Ops.EMPTY, Ops.BUFFER_VIEW}
|
177
|
+
Buffer = {Ops.LOAD, Ops.PRELOAD, Ops.STORE, Ops.VALID}
|
178
|
+
|
179
|
+
# BinaryOps that can be flipped
|
180
|
+
Commutative = {Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPNE, Ops.XOR, Ops.AND, Ops.OR}
|
181
|
+
|
182
|
+
# do not preserve f(0) = 0
|
183
|
+
UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV}
|
184
|
+
|
185
|
+
# https://en.wikipedia.org/wiki/Identity_element
|
186
|
+
def identity_element(op:Ops, dt:DType): return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
|
187
|
+
|
188
|
+
def can_pad(u:UOp) -> bool: return not any(x.op in GroupOp.UnsafePad for x in u.sparents)
|
189
|
+
|
190
|
+
END_FOR_UOP = {Ops.IF:(Ops.STORE, Ops.ENDIF), Ops.RANGE:(Ops.ASSIGN, Ops.ENDRANGE)}
|
191
|
+
|
192
|
+
# With True as the default, this matches the old symbolic behavior
|
193
|
+
def resolve(x, default:bool=True):
|
194
|
+
if not isinstance(x, UOp): return bool(x)
|
195
|
+
assert x.dtype is dtypes.bool, "UOp in resolve must be bool"
|
196
|
+
# NOTE: generating the text for the exception is expensive, so we do this
|
197
|
+
return bool(sx.vmin) if (sx:=x.simplify()).vmin == sx.vmax else default
|
198
|
+
|
199
|
+
# smax/smin are replacements for max/min that preserve symbolic
|
200
|
+
def _suop(lst, uop_fxn, python_fxn):
|
201
|
+
max_uop, max_num = partition(lst, lambda x: isinstance(x, UOp))
|
202
|
+
if len(max_uop): return functools.reduce(uop_fxn, (max_uop + [python_fxn(max_num)]) if len(max_num) else max_uop).ssimplify()
|
203
|
+
return python_fxn(max_num)
|
204
|
+
def smax(*lst): return _suop(lst[0] if isinstance(lst[0], (tuple, list)) else lst, UOp.maximum, max)
|
205
|
+
def smin(*lst): return _suop(lst[0] if isinstance(lst[0], (tuple, list)) else lst, UOp.minimum, min)
|
206
|
+
|
207
|
+
def ssimplify(uop): return uop.ssimplify() if isinstance(uop, UOp) else uop
|
208
|
+
def sym_infer(uop: Union[UOp, int], var_vals: Dict[UOp, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop
|
209
|
+
|
210
|
+
# used for UOp and UPat
|
211
|
+
def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->str:
|
212
|
+
def dfs(x:Any, cache:dict):
|
213
|
+
for s in srcfn(x) or []:
|
214
|
+
cache.setdefault(s, [len(cache), 0, False])[1] += 1
|
215
|
+
if cache[s][1] == 1: dfs(s, cache)
|
216
|
+
if cache is None: dfs(x, cache:={})
|
217
|
+
if (cx:=cache.setdefault(x, [0,0,False]))[2]: return f"{' '*d} x{cx[0]}"
|
218
|
+
cx[2], srcs = True, ('None' if srcfn(x) is None else ''.join(f'\n{pretty_print(s, rep, srcfn, cache, d+2)},' for s in srcfn(x)))
|
219
|
+
return f"{' '*d}{f'x{cx[0]}:=' * (cx[1]>1)}{rep(x)}" % srcs
|
220
|
+
|
221
|
+
class UOpMetaClass(type):
|
222
|
+
ucache:WeakValueDictionary[Tuple, UOp] = WeakValueDictionary()
|
223
|
+
def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:Tuple[UOp,...]=tuple(), arg:Any=None):
|
224
|
+
if (ret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None: return ret
|
225
|
+
UOpMetaClass.ucache[key] = ret = super().__call__(op, dtype, src, arg)
|
58
226
|
return ret
|
59
|
-
def __eq__(self, x): return self.cached_compare(x, context={})
|
60
|
-
def __repr__(self): return f"LazyOp(op={self.op}, src={self.src}, arg={self.arg})"
|
61
|
-
@functools.cached_property
|
62
|
-
def dtype(self) -> DType:
|
63
|
-
if self.op in BufferOps: return self.arg.dtype
|
64
|
-
if self.op in [UnaryOps.CAST, UnaryOps.BITCAST]: return self.arg
|
65
|
-
return dtypes.bool if self.op in {BinaryOps.CMPLT, BinaryOps.CMPNE} else self.src[-1].dtype
|
66
227
|
|
228
|
+
class UOp(MathTrait, metaclass=UOpMetaClass):
|
229
|
+
__slots__ = ["op", "dtype", "src", "arg"]
|
230
|
+
def __init__(self, op:Ops, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None):
|
231
|
+
# TODO: instant check rules here make debugging easier
|
232
|
+
self.op, self.dtype, self.src, self.arg = op, dtype, src, arg
|
233
|
+
def __reduce__(self): return UOp, (self.op, self.dtype, self.src, self.arg)
|
234
|
+
def replace(self, **kwargs) -> UOp:
|
235
|
+
for k in kwargs: assert k in self.__slots__, f"unkown replace arg, expected one of {self.__slots__}, got {k}"
|
236
|
+
new_args = (kwargs.get("op", self.op), kwargs.get("dtype", self.dtype), kwargs.get("src", self.src), kwargs.get("arg", self.arg))
|
237
|
+
if (self.op, self.dtype, self.src, self.arg) == new_args: return self
|
238
|
+
return UOp(*new_args)
|
67
239
|
@functools.cached_property
|
68
240
|
def key(self) -> bytes:
|
69
|
-
return hashlib.sha256(
|
241
|
+
return hashlib.sha256(str((self.op, self.dtype, self.arg)).encode() + b"".join([s.key for s in self.src])).digest()
|
242
|
+
def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))")
|
243
|
+
def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is Ops.REDUCE_AXIS else self.arg
|
244
|
+
@functools.cached_property
|
245
|
+
def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents}}
|
246
|
+
@functools.cached_property # parents with self
|
247
|
+
def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None}
|
248
|
+
|
249
|
+
@functools.cached_property
|
250
|
+
def tuplize(self:UOp) -> Tuple[int, Any, Optional[DType], Tuple]:
|
251
|
+
return (self.op.value, self.arg, self.dtype, tuple(x.tuplize for x in self.src))
|
252
|
+
|
253
|
+
# *** uop shape stuff ***
|
254
|
+
|
255
|
+
@property
|
256
|
+
def has_st(self) -> bool: return self.op not in {Ops.DEFINE_LOCAL, Ops.DEFINE_GLOBAL, Ops.BUFFER, Ops.CONST, Ops.DEFINE_VAR}
|
70
257
|
@functools.cached_property
|
71
|
-
def
|
72
|
-
|
258
|
+
def st(self) -> Optional[ShapeTracker]:
|
259
|
+
if not self.has_st: return None
|
260
|
+
if self.op in GroupOp.Buffer: return self.st_arg
|
261
|
+
if self.op is Ops.VIEW: return self.arg
|
262
|
+
src_sts = [x.st for x in self.src if x.st is not None]
|
263
|
+
assert all_same([x.shape for x in src_sts]), f"UOp parents must have the same shape {self} {[x.shape for x in src_sts]}"
|
264
|
+
from tinygrad.shape.shapetracker import ShapeTracker
|
265
|
+
return ShapeTracker.from_shape(src_sts[0].reduce(self.axis_arg)) if self.op is Ops.REDUCE_AXIS else src_sts[0]
|
73
266
|
@functools.cached_property
|
74
|
-
def
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
267
|
+
def full_shape(self) -> Tuple[sint, ...]:
|
268
|
+
return self.arg.shape if self.op is Ops.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
|
269
|
+
|
270
|
+
# *** uop evaluation ***
|
271
|
+
|
272
|
+
def simplify(self):
|
273
|
+
with Context(TRACK_MATCH_STATS=0):
|
274
|
+
return graph_rewrite(self, symbolic)
|
275
|
+
def ssimplify(self) -> Union[UOp, ConstType]: return ret.arg if (ret:=self.simplify()).op is Ops.CONST else ret
|
276
|
+
def _eval(self, dtype, expected_type:Type[T]) -> T:
|
277
|
+
assert self.dtype in dtype, f"eval with wrong dtype {self}"
|
278
|
+
vmin, vmax = (simple_self:=self.simplify())._min_max
|
279
|
+
if vmin != vmax: raise ValueError(f"eval failed to be a single number, range is {vmin} to {vmax} in {simple_self.render()}")
|
280
|
+
assert isinstance(vmin, expected_type), f"vmin is wrong dtype {type(vmin)} != {expected_type}"
|
281
|
+
return vmin
|
282
|
+
def __bool__(self): return self._eval((dtypes.bool,), bool)
|
283
|
+
def __int__(self): return self._eval(dtypes.ints, int)
|
284
|
+
def __float__(self): return self._eval(dtypes.floats, float)
|
285
|
+
def substitute(self, dvars:Dict[UOp, UOp]):
|
286
|
+
with Context(TRACK_MATCH_STATS=0):
|
287
|
+
return graph_rewrite(self, _substitute, dvars)
|
288
|
+
|
289
|
+
# *** uop syntactic sugar ***
|
290
|
+
|
291
|
+
@property
|
292
|
+
def st_arg(self) -> ShapeTracker:
|
293
|
+
assert self.op in GroupOp.Buffer, f"st_arg called on {self.op}"
|
294
|
+
ret = self.src[0 if self.op is Ops.VALID else 1]
|
295
|
+
assert ret.op is Ops.VIEW, f"st_arg trying to return {ret}"
|
296
|
+
return ret.arg
|
87
297
|
@property
|
88
|
-
def
|
89
|
-
|
90
|
-
self.
|
298
|
+
def axis_arg(self) -> Tuple[int, ...]:
|
299
|
+
assert self.op in {Ops.REDUCE_AXIS, Ops.WMMA}, f"axis_arg called on {self.op}"
|
300
|
+
ret = self.arg[1] if self.op is Ops.REDUCE_AXIS else self.arg[7]
|
301
|
+
assert isinstance(ret, tuple) and all(isinstance(x, int) for x in ret), f"axis_arg trying to return {ret}"
|
91
302
|
return ret
|
303
|
+
def sink(self, *srcs:UOp): return UOp(Ops.SINK, dtypes.void, (self,)+srcs)
|
304
|
+
def index(self, idx:UOp, valid:Optional[UOp]=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
|
305
|
+
def const_like(self, b:ConstLike): return UOp.const(self.dtype, b)
|
306
|
+
def broadcast(self, count:int):
|
307
|
+
assert self.dtype.count == 1
|
308
|
+
if count == 1: return self
|
309
|
+
return UOp(Ops.VECTORIZE, self.dtype.vec(count), (self,)*count)
|
310
|
+
def cast(self, dtype:DType): return UOp(Ops.CAST, dtype, (self,))
|
311
|
+
def bitcast(self, dtype:DType): return UOp(Ops.BITCAST, dtype, (self,))
|
312
|
+
def gep(self, i:Union[Tuple[int, ...], int]):
|
313
|
+
if isinstance(i, int):
|
314
|
+
# NOTE: these are just shortcuts to not have to create and fold later
|
315
|
+
if self.op is Ops.VECTORIZE: return self.src[i]
|
316
|
+
if self.op is Ops.VCONST: return UOp.const(self.dtype.scalar(), self.arg[i])
|
317
|
+
if self.op is Ops.CONST: return UOp.const(self.dtype.scalar(), self.arg)
|
318
|
+
i = (i,)
|
319
|
+
if (self.dtype.vcount == len(i) and i == tuple(range(len(i)))) or self.dtype == dtypes.void: return self
|
320
|
+
return UOp(Ops.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i)
|
321
|
+
def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, src=(self,)+src, **kwargs)
|
322
|
+
def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
|
323
|
+
def alu(self, arg, *src:UOp):
|
324
|
+
out_dtype = (self, *src)[-1].dtype
|
325
|
+
if arg in {Ops.CMPLT, Ops.CMPNE} and out_dtype is not None:
|
326
|
+
out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
|
327
|
+
return UOp(arg, out_dtype, (self,)+src)
|
328
|
+
@staticmethod
|
329
|
+
def const(dtype:DType, b:ConstLike):
|
330
|
+
if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
|
331
|
+
if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
|
332
|
+
return UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b)
|
333
|
+
@staticmethod
|
334
|
+
def range(dtype:DType, start:ConstType|UOp, end:ConstType|UOp, idx:int):
|
335
|
+
return UOp(Ops.RANGE, dtype=dtype, src=(UOp.const(dtype, start) if not isinstance(start, UOp) else start,
|
336
|
+
UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=(idx, False))
|
337
|
+
def r(self, op, axis): return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis))
|
338
|
+
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
|
339
|
+
def contiguous(self): return UOp(Ops.CONTIGUOUS, self.dtype, (self,))
|
340
|
+
@property
|
341
|
+
def is_contiguous_base(self): return self.op is Ops.CONTIGUOUS and not (self.src[0].base.op is Ops.VIEW and len(self.src[0].base.src) == 2)
|
92
342
|
|
93
|
-
|
94
|
-
BufferOps.LOAD: lambda arg: FlopCounter(arg.st.shape, 0, {arg.idx: arg.dtype.itemsize * arg.st.real_size()}),
|
95
|
-
BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, 0, {}),
|
96
|
-
BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize * arg.st.real_size()}),
|
97
|
-
UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, self.consume_flops(), self.mem), # cast uses no flops
|
98
|
-
UnaryOps.BITCAST: lambda self,arg: FlopCounter(self.shape, self.consume_flops(), self.mem), # bitcast uses no flops
|
99
|
-
**{op:lambda self: FlopCounter(self.shape, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op not in {UnaryOps.CAST, UnaryOps.BITCAST}}, # noqa: E501
|
100
|
-
**{op:lambda self,y: FlopCounter(self.shape, self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps}, # noqa: E501
|
101
|
-
**{op:lambda self,axis: FlopCounter(tuple(1 if i in axis else s for i,s in enumerate(self.shape)), self.consume_flops() + prod(self.shape), self.mem) for op in ReduceOps}, # noqa: E501
|
102
|
-
TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})} # noqa: E501
|
343
|
+
# *** uop movement ops ***
|
103
344
|
|
104
|
-
@
|
105
|
-
def
|
106
|
-
|
107
|
-
|
108
|
-
|
345
|
+
@property
|
346
|
+
def base(self): return self.src[0] if self.op is Ops.VIEW and len(self.src) == 1 else self
|
347
|
+
def view(self, st:ShapeTracker):
|
348
|
+
assert self.op is not Ops.STORE, "VIEW of STORE is invalid, STORE is always base"
|
349
|
+
return self if self.st is None or self.st == st else UOp(Ops.VIEW, self.dtype, (self,), st)
|
350
|
+
def reshape(self, arg:Tuple[sint, ...]): return self.view(unwrap(self.st).reshape(arg))
|
109
351
|
|
110
|
-
#
|
352
|
+
# *** uop Buffer stuff ***
|
353
|
+
|
354
|
+
@staticmethod
|
355
|
+
def new_buffer(device:str, size:int, dtype:DType, num=-1): return UOp(Ops.BUFFER, dtype.ptr(), (), (num, (device, size, dtype)))
|
356
|
+
|
357
|
+
@property
|
358
|
+
def buf_uop(self) -> UOp:
|
359
|
+
assert self.op in {*GroupOp.Buffer, Ops.ASSIGN, Ops.VIEW} and self.src[0].op is Ops.BUFFER, f"buf_uop called on {self.op}"
|
360
|
+
return self.src[0]
|
361
|
+
|
362
|
+
# *** uop Variable stuff ***
|
363
|
+
|
364
|
+
@staticmethod
|
365
|
+
def variable(name:str, min_val:ConstType, max_val:ConstType, dtype:DType=dtypes.int):
|
366
|
+
assert not isinstance(min_val, UOp) and not isinstance(max_val, UOp), f"can't create Variable {name} with {min_val}/{max_val}"
|
367
|
+
return UOp(Ops.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
|
368
|
+
@property
|
369
|
+
def expr(self):
|
370
|
+
assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
|
371
|
+
return self.arg[0]
|
372
|
+
def bind(self, val:int):
|
373
|
+
assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
|
374
|
+
assert self.arg[1] <= val and val <= self.arg[2], f"bind {val} not in range [{self.arg[1]}, {self.arg[2]}]"
|
375
|
+
return UOp(Ops.BIND, self.dtype, (self, self.const_like(val)))
|
376
|
+
def unbind(self) -> Tuple[Variable, int]:
|
377
|
+
assert self.op is Ops.BIND and self.src[0].op is Ops.DEFINE_VAR and self.src[1].op is Ops.CONST, f"can't unbind {self}"
|
378
|
+
return self.src[0], self.src[1].arg
|
379
|
+
@property
|
380
|
+
def val(self) -> int: return self.unbind()[1]
|
381
|
+
def vars(self) -> Set[UOp]:
|
382
|
+
bound_vars = set([x for x in self.sparents if x.op is Ops.BIND and x.src[0].op is Ops.DEFINE_VAR])
|
383
|
+
bound_var_base = set(x.src[0] for x in bound_vars)
|
384
|
+
all_vars = set([x for x in self.sparents if x.op is Ops.DEFINE_VAR])
|
385
|
+
return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
|
386
|
+
def variables(self) -> List[Variable]:
|
387
|
+
st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in GroupOp.Buffer]
|
388
|
+
return sorted(set.union(*st_vars, [x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg)
|
389
|
+
|
390
|
+
# *** uop symbolic stuff ***
|
391
|
+
|
392
|
+
def const_factor(self) -> int:
|
393
|
+
"""largest known int that divides self"""
|
394
|
+
if self.op is Ops.CONST: return self.arg
|
395
|
+
if self.op is Ops.VCONST: return math.gcd(*self.arg)
|
396
|
+
if self.op is Ops.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
|
397
|
+
if self.op is Ops.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1
|
398
|
+
return 1
|
399
|
+
def divides(self, v) -> Optional[UOp]:
|
400
|
+
if v==1: return self
|
401
|
+
if self.op is Ops.CONST: return self.const_like(self.arg//v) if self.arg%v == 0 else None
|
402
|
+
if self.op is Ops.VCONST: return self.const_like(tuple(x//v for x in self.arg)) if all(x%v == 0 for x in self.arg) else None
|
403
|
+
if self.op is Ops.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None
|
404
|
+
if self.op is Ops.MUL:
|
405
|
+
if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1]
|
406
|
+
if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1
|
407
|
+
return None # generic None if we aren't sure
|
408
|
+
@property
|
409
|
+
def vmin(self) -> ConstType: return self._min_max[0]
|
410
|
+
@property
|
411
|
+
def vmax(self) -> ConstType: return self._min_max[1]
|
412
|
+
@functools.cached_property
|
413
|
+
def _min_max(self) -> Tuple[ConstType, ConstType]:
|
414
|
+
if self.op in GroupOp.Binary and not dtypes.is_float(self.dtype):
|
415
|
+
(s0_vmin, s0_vmax), (s1_vmin, s1_vmax) = self.src[0]._min_max, self.src[1]._min_max
|
416
|
+
if self.op is Ops.ADD: return s0_vmin+s1_vmin, s0_vmax+s1_vmax
|
417
|
+
if self.op is Ops.MUL: return min(vals:=(s0_vmin*s1_vmin, s0_vmin*s1_vmax, s0_vmax*s1_vmin, s0_vmax*s1_vmax)), max(vals)
|
418
|
+
if self.op is Ops.MOD and s1_vmin > 0: return 0, s1_vmax-1
|
419
|
+
if self.op is Ops.IDIV and s1_vmin == s1_vmax: # min/max are equal in a CONST
|
420
|
+
if s1_vmin > 0: return s0_vmin//s1_vmin, s0_vmax//s1_vmin
|
421
|
+
if s1_vmin < 0 and s0_vmin >= 0: return -(s0_vmax//-s1_vmin), -(s0_vmin//-s1_vmin)
|
422
|
+
if self.op is Ops.MAX: return max(s0_vmin, s1_vmin), max(s0_vmax, s1_vmax)
|
423
|
+
if self.op is Ops.CMPLT: return (s0_vmax<s1_vmin, s0_vmin<s1_vmax)
|
424
|
+
if self.op is Ops.CMPNE: return ((s0_vmax < s1_vmin) or (s1_vmax < s0_vmin), not (s0_vmin == s0_vmax == s1_vmin == s1_vmax))
|
425
|
+
if self.dtype == dtypes.bool:
|
426
|
+
if self.op is Ops.OR: return s0_vmin or s1_vmin, s0_vmax or s1_vmax
|
427
|
+
if self.op is Ops.AND: return s0_vmin and s1_vmin, s0_vmax and s1_vmax
|
428
|
+
# float has NAN issue and we use explicit NAN in transcendental
|
429
|
+
if self.op is Ops.WHERE and dtypes.is_int(self.dtype): return min(self.src[1].vmin, self.src[2].vmin), max(self.src[1].vmax, self.src[2].vmax)
|
430
|
+
# NOTE: returned UOp is assumed to be CONST
|
431
|
+
if self.op is Ops.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2]
|
432
|
+
if self.op is Ops.RANGE: return self.src[0].vmin, (self.src[1]-1).vmax
|
433
|
+
if self.op is Ops.BIND: return self.src[0]._min_max # ignore the bound value
|
434
|
+
if self.op in {Ops.EXPAND, Ops.VECTORIZE}: return min(x.vmin for x in self.src), max(x.vmax for x in self.src)
|
435
|
+
# TODO: UOps.SPECIAL is UOps.DEFINE_VAR
|
436
|
+
if self.op is Ops.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else dtypes.max(self.dtype)
|
437
|
+
if self.op is Ops.CONST: return self.arg, self.arg
|
438
|
+
if self.op is Ops.VCONST: return (min(self.arg), max(self.arg))
|
439
|
+
return dtypes.min(self.dtype), dtypes.max(self.dtype)
|
440
|
+
|
441
|
+
@functools.cached_property
def _sym_fxn(self):
  """Compile the simplified expression into a Python lambda, cached per UOp.

  Returns ``(fxn, varnames)``. ``varnames`` are the names (``arg[0]``) of the DEFINE_VAR parents of the simplified
  expression, and ``fxn`` takes them as arguments and evaluates the rendered expression.
  """
  simplified = self.simplify()
  names = tuple(p.arg[0] for p in simplified.sparents if p.op is Ops.DEFINE_VAR)
  # TODO: sanitize varnames, or don't use naked eval while staying fast
  source = f"lambda {','.join(names)}: {simplified.render()}"
  return eval(source), names # pylint: disable=eval-used
|
447
|
+
|
448
|
+
def sym_infer(self, var_vals:Dict[UOp, int]):
  """Evaluate this symbolic expression with concrete variable values.

  ``var_vals`` maps variable UOps to ints. Entries whose name (``arg[0]``) is not used by the compiled
  expression are ignored.
  """
  fxn, varnames = self._sym_fxn
  kwargs = {}
  for var, val in var_vals.items():
    if (name := var.arg[0]) in varnames: kwargs[name] = val
  return fxn(**kwargs)
|
451
|
+
|
452
|
+
def render(self, simplify=True) -> str:
  """Render this UOp to a string using the module-level ``renderer`` pattern matcher.

  The graph is simplified first unless ``simplify`` is False. If the rewrite produces a NOOP, its ``arg`` is returned
  (presumably the rendered text); otherwise ``str()`` of the rewritten UOp is returned.
  """
  target = self.simplify() if simplify else self
  out = graph_rewrite(target, renderer)
  if out.op is Ops.NOOP: return out.arg
  return str(out)
|
455
|
+
|
456
|
+
@dataclass(frozen=True)
class KernelInfo:
  """Immutable kernel options describing how RANGE loops are remapped."""
  local_dims: int = field(default=0)  # how many dimensions are local (RANGE remapped to SPECIAL)
  upcasted: int = field(default=0)  # how many dimensions are upcasted (RANGE remapped to EXPAND)
  dont_use_locals: bool = field(default=False)  # when True, local indexing is not used
|
461
|
+
|
462
|
+
# ***** ops in python *****
|
111
463
|
|
112
464
|
def hook_overflow(dv, fxn):
|
113
465
|
def wfxn(*args):
|
@@ -115,55 +467,686 @@ def hook_overflow(dv, fxn):
|
|
115
467
|
except OverflowError: return dv
|
116
468
|
return wfxn
|
117
469
|
|
118
|
-
python_alu = {
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
if
|
156
|
-
|
157
|
-
|
158
|
-
|
470
|
+
# Pure-Python implementation of each ALU op; exec_alu looks ops up here. Comments note the edge-case conventions.
python_alu: Dict[Ops, Callable] = {
  # log2: -inf at 0 and nan for negatives; exp2: hook_overflow returns inf when 2**x raises OverflowError
  Ops.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, Ops.EXP2: hook_overflow(math.inf, lambda x: 2**x),
  # sqrt of a negative is nan; the reciprocal of a zero is inf carrying that zero's sign
  Ops.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, Ops.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
  # sin(+-inf) is nan instead of raising ValueError
  Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan,
  Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul,
  # MOD/IDIV truncate toward zero (C semantics): the remainder takes the sign of x, and the quotient is negative when
  # x and y have opposite signs. IDIV by zero gives x*inf (+-inf, or nan when x == 0); MOD by zero still raises.
  Ops.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], Ops.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else x*math.inf,
  Ops.MAX: max, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt, Ops.XOR: operator.xor,
  Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift,
  # MULACC computes x*y+z; WHERE selects y when x is truthy, else z
  Ops.MULACC: lambda x,y,z: (x*y)+z, Ops.WHERE: lambda x,y,z: y if x else z}
|
479
|
+
|
480
|
+
def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
  """Evaluate ALU ``op`` on Python values via ``python_alu``.

  For vector dtypes (``count > 1``) the op is applied lane by lane and a tuple is returned: tuple operands are
  indexed per lane and scalar operands are broadcast. For scalars, the result goes through the ``truncate`` entry
  for ``dtype`` (identity if there is none) unless ``truncate_output`` is False.
  """
  if dtype.count <= 1:
    result = python_alu[op](*operands)
    if not truncate_output: return result
    return truncate.get(dtype, lambda x: x)(result)
  lanes = []
  for lane in range(dtype.count):
    lane_operands = [opnd[lane] if isinstance(opnd, tuple) else opnd for opnd in operands]
    lanes.append(exec_alu(op, dtype.scalar(), lane_operands))
  return tuple(lanes)
|
485
|
+
|
486
|
+
# ***** uop helpers *****
|
487
|
+
|
488
|
+
def print_uops(uops:List[UOp]):
  """Print one debug line per UOp: position, op, dtype, sources and arg.

  A source is shown as its position in ``uops``, as its literal arg if it is a CONST, or as ``--`` if it is not
  in the list.
  """
  # index each UOp's first occurrence once, instead of an O(n) `in uops` + uops.index() per source (O(n^2) overall)
  first_idx: Dict[UOp, int] = {}
  for i,u in enumerate(uops): first_idx.setdefault(u, i)
  for i,u in enumerate(uops):
    formatted_parents = [(first_idx[x] if x.op is not Ops.CONST else f"{x.arg}") if x in first_idx else "--" for x in u.src]
    print(f"{i:4d} {str(u.op):20s}: {str(u.dtype):30s} " f"{str(formatted_parents):32s} {u.arg}")
|
492
|
+
|
493
|
+
def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
  """Estimate ``(flops, memory bytes)`` for a linearized list of UOps.

  A running multiplier counts how many times each UOp executes:
  - RANGE multiplies it by ``end - start``, and the matching ENDRANGE restores the previous value.
  - SPECIAL multiplies it by ``arg[1]`` for the rest of the list.

  Costs:
  - LOAD/STORE add ``itemsize * mults`` bytes.
  - ALU ops add ``mults * dtype.count`` flops (doubled for MULACC).
  - WMMA adds ``2*prod(arg[1])//arg[5]*mults``.

  With ``ignore_indexing``, flops are not counted for UOps that are parents of a LOAD/STORE's ``src[0]``, of its
  ``src[2]`` (when present), or of an IF's condition. Results may be symbolic (``sint``).
  """
  flops: sint = 0
  mem: sint = 0
  mults: sint = 1
  mult_stack: List[sint] = []
  dont_count: Set[UOp] = set()
  if ignore_indexing:
    # grow the set in place: rebuilding it with set.union() for every LOAD/STORE/IF copied it each time
    for u in uops:
      if u.op in {Ops.LOAD, Ops.STORE}:
        dont_count.update(u.src[0].sparents)
        if len(u.src) > 2: dont_count.update(u.src[2].sparents)
      elif u.op is Ops.IF:
        dont_count.update(u.src[0].sparents)
  for u in uops:
    if u.op is Ops.RANGE:
      mult_stack.append(mults)
      mults *= (u.src[1] - u.src[0]).ssimplify()
    elif u.op is Ops.ENDRANGE:
      mults = mult_stack.pop(-1)
    elif u.op is Ops.SPECIAL:
      mults *= u.arg[1] # NOTE: we don't push to the mult_stack here, you can't end these
    elif u.op is Ops.LOAD:
      mem += u.dtype.itemsize * mults
    elif u.op is Ops.STORE:
      mem += u.src[1].dtype.itemsize * mults
    elif u.op in GroupOp.ALU and u not in dont_count:
      flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
    elif u.op is Ops.WMMA and u not in dont_count:
      flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
  return flops, mem
|
523
|
+
|
524
|
+
# ***** pattern matcher *****
|
525
|
+
|
526
|
+
def get_location() -> Tuple[str, int]:
  """Return ``(filename, lineno)`` of the frame that should be credited with defining a UPat.

  Starting at the caller of this function, move outward as long as the next outer frame lives in one of
  tinygrad's pattern-building files. Return the last frame reached.
  """
  # find the real frame in the file that has the UPat, TODO: is there a better way to do this?
  internal_files = {"ops.py", "uopgraph.py", "schedule.py", "lowerer.py", "cstyle.py"}
  frame = sys._getframe(1)
  while True:
    parent = frame.f_back
    if parent is None or pathlib.Path(parent.f_code.co_filename).name not in internal_files: break
    frame = parent
  return frame.f_code.co_filename, frame.f_lineno
|
533
|
+
@functools.lru_cache(None)
def lines(fn) -> List[str]:
  """Return the lines of the file at ``fn``, cached per path (UPat.printable uses this to show a pattern's source).

  Python source files are UTF-8 by default (PEP 3120), so decode explicitly instead of with the platform's
  locale encoding, which can fail or garble non-ASCII sources (e.g. cp1252 on Windows).
  """
  with open(fn, encoding="utf-8") as f: return f.readlines()
|
536
|
+
|
537
|
+
class UPat(MathTrait):
  """A pattern over UOp graphs, matched by PatternMatcher.

  A ``None`` field is a wildcard. ``op``/``dtype`` may be a single value or a collection of alternatives. How
  ``src`` matches depends on its type:
  - a tuple matches the sources in order;
  - a list matches any permutation of its elements;
  - a single UPat matches every source, whatever their number.
  ``name`` binds the matched UOp in the result dict, and a name must bind the same UOp everywhere within one match.
  """
  __slots__ = ["op", "dtype", "arg", "name", "src"]
  def __init__(self, op:Optional[Union[Ops, Tuple[Ops, ...], Set[Ops]]]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None,
               src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None, arg:Any=None,
               name:Optional[str]=None, allow_any_len:bool=False, location=None, custom_early_reject:Optional[Set[Ops]]=None):
    assert op is None or isinstance(op, Ops) or isinstance(op, tuple) or isinstance(op, set), "op must be Ops or tuple of Ops"
    self.op: Optional[Tuple[Ops, ...]] = (op,) if isinstance(op, Ops) else (tuple(op) if isinstance(op, set) else op)
    self.dtype: Optional[Tuple[DType, ...]] = (dtype,) if isinstance(dtype, DType) else dtype
    self.arg, self.name, self._in_src, self.custom_early_reject = arg, name, src, custom_early_reject
    self.src: Any = None
    assert self.name != "ctx", "UPat can't be named ctx"

    # try all permutations if it's a list
    if isinstance(src, list): self.src = list(itertools.permutations(src)) if not all_same(src) else [src]
    # only one if it's a tuple
    elif isinstance(src, tuple): self.src = [src]
    # repeat if it's a UPat
    elif isinstance(src, UPat): self.src = [itertools.repeat(src)]

    # -1 means any number of sources is accepted
    self.allowed_len: int = -1 if allow_any_len or isinstance(src, UPat) or src is None else len(src)
    self.location = location or get_location()

    # early_reject: ops that must appear among the UOp's sources for this pattern to possibly match
    if custom_early_reject is not None: self.early_reject = custom_early_reject
    else:
      upat_match = [src] if isinstance(src, UPat) else ([] if src is None else self.src[0])
      self.early_reject = {pp.op[0] for pp in upat_match if pp.op is not None and len(pp.op) == 1}

  def named(self, name:str):
    """Return a copy of this pattern bound to ``name``."""
    # custom_early_reject must be passed by keyword: the 7th positional parameter is `location`, so passing it
    # positionally dropped the custom reject set and stored it as the location
    return UPat(self.op, self.dtype, self._in_src, self.arg, name, self.allowed_len == -1, custom_early_reject=self.custom_early_reject)

  @staticmethod
  def any(*src): return UPatAny(src=src)

  @staticmethod
  @functools.lru_cache(None)
  def var(name:Optional[str]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None): return UPat(dtype=dtype, name=name)
  @staticmethod
  @functools.lru_cache(None)
  def cvar(name:Optional[str]=None, dtype:Optional[DType]=None, vec=True):
    return UPat((Ops.CONST, Ops.VCONST) if vec else Ops.CONST, dtype=dtype, name=name)
  @staticmethod
  def const(dtype:Optional[Union[DType, Tuple[DType, ...]]], b:ConstType): return UPat(Ops.CONST, dtype=dtype, arg=b)

  # copied from UOp
  def index(self, idx:UPat, valid:Optional[UPat]=None): return UPat(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
  def view(self, st=None, **kwargs): return UPat(Ops.VIEW, self.dtype, (self,), st, **kwargs)
  def cast(self, dtype=None): return UPat(Ops.CAST, dtype, (self,))
  def bitcast(self, dtype=None): return UPat(Ops.BITCAST, dtype, (self,))
  def gep(self, i:int): return UPat(Ops.GEP, None, (self,), (i,))
  def load(self, *src:UPat, **kwargs): return UPat(Ops.LOAD, src=(self,)+src, **kwargs)
  def store(self, *src:UPat, **kwargs): return UPat(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
  def assign(self, x:UPat): return UPat(Ops.ASSIGN, self.dtype, (self,x))

  def const_like(self, b:ConstLike): return UPat.const(self.dtype, cast(ConstType, b))
  def alu(self, op:Ops, *src:UPat):
    asrc = (self,)+src
    # comparisons produce bool, so leave dtype open; commutative ops match sources in any order
    return UPat(op, None if op in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].dtype, list(asrc) if op in GroupOp.Commutative else asrc)

  def printable(self:UPat) -> str:
    """Return the stripped source line where this pattern was defined, or ``<missing>`` if the file is gone."""
    try: return lines(self.location[0])[self.location[1]-1].strip()
    except FileNotFoundError: return "<missing>"

  def __repr__(self):
    def rep(x):
      # arg is printed by keyword: the second positional parameter of UPat is dtype
      form = "UPat(%s, arg=%s, name=%s, dtype=%s, allow_any_len=%s, src=%s)"
      return form % (None if x.op is None else ('(%s)'%', '.join(map(str, x.op))), x.arg, repr(x.name),
        set(x.dtype) if x.dtype else None, x.allowed_len == -1, "[%s]" if x.src and len(x.src)>1 else "(%s)")
    return pretty_print(self, rep, srcfn=lambda x:None if x.src is None else [next(x.src[0])] if isinstance(x.src[0], itertools.repeat) else x.src[0])

  def match(self:UPat, uop:UOp, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
    """Return every binding dict (extending ``store``) under which this pattern matches ``uop``; [] if none."""
    if (self.op is not None and uop.op not in self.op) or \
       (self.name is not None and store.setdefault(self.name, uop) is not uop) or \
       (self.dtype is not None and uop.dtype not in self.dtype and uop.dtype.scalar() not in self.dtype) or \
       (self.arg is not None and self.arg != uop.arg) or \
       (self.allowed_len != -1 and len(uop.src) != self.allowed_len): return []
    if self.src is None: return [store]
    res: List[Dict[str, UOp]] = []
    for vp in self.src:
      # match the sources left to right, keeping every partial binding that is still consistent
      stores, new_stores = [store.copy()], []
      for uu, vv in zip(uop.src, vp):
        for s in stores: new_stores.extend(vv.match(uu, s))
        stores, new_stores = new_stores, []
      res.extend(stores)
    return res
|
620
|
+
|
621
|
+
class UPatAny(UPat):
  """Alternation: matches when any of its source patterns matches the UOp itself.

  Built by ``UPat.any(*alternatives)``, so ``self.src[0]`` is the tuple of alternatives. Each alternative is tried
  on its own copy of ``store``, and the bindings of all successful alternatives are returned.
  """
  def match(self:UPat, uop:UOp, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
    results: List[Dict[str, UOp]] = []
    for alternative in self.src[0]:
      found = alternative.match(uop, store.copy())
      if found: results.extend(found)
    return results
|
627
|
+
|
628
|
+
def deconstruct_function(fxn:Callable) -> Tuple:
  """Split a closure-free function into the arguments ``types.FunctionType`` needs to rebuild it.

  Returns ``(code, globals, name, defaults)``. ``globals`` keeps only the module globals referenced by name from the
  function's code or from code objects directly among its constants (e.g. comprehensions). With TEST_PICKLE set,
  the tuple is round-tripped through pickle.
  """
  code, env = fxn.__code__, fxn.__globals__
  new_globals = {k:v for k,v in env.items() if k in code.co_names}
  nested_codes = [const for const in code.co_consts if isinstance(const, types.CodeType)]
  for nested in nested_codes: new_globals.update({k:v for k,v in env.items() if k in nested.co_names})
  # NOTE: optional round trip through pickle!
  assert fxn.__closure__ is None, "closures are not supported in pattern matchers"
  pieces = code, new_globals, fxn.__name__, fxn.__defaults__
  if getenv("TEST_PICKLE"): return pickle.loads(pickle.dumps(pieces))
  return pieces
|
636
|
+
|
637
|
+
class PatternMatcher:
  """Ordered ``(UPat, callback)`` rules applied to a single UOp by ``rewrite``.

  Rules are bucketed by the ops they match. A callback receives the match's named bindings as keyword arguments
  (plus ``ctx`` if its signature has a ``ctx`` parameter). It returns a replacement, or None to let the next
  match/rule try.
  """
  def __init__(self, patterns:List[Tuple[UPat, Callable]]):
    self.patterns = patterns
    # NOTE: use of DefaultDict here is very dangerous! all keys will live for the lifetime of the PatternMatcher!
    self.pdict: Dict[Ops, List[Tuple[UPat, Callable, Set, bool]]] = {}
    # uop is required, arg is optional
    for p,fxn in self.patterns:
      assert p.op is not None
      # callbacks may already be deconstructed tuples (see __reduce__); a fresh function is rebuilt either way
      tuple_fxn = fxn if isinstance(fxn, tuple) else deconstruct_function(fxn)
      real_fxn = types.FunctionType(*tuple_fxn)
      for uop in p.op: self.pdict.setdefault(uop, []).append((p, real_fxn, p.early_reject, 'ctx' in inspect.signature(real_fxn).parameters))

  def __reduce__(self):
    # lambdas can't be pickled by reference, so ship their deconstructed form. Callbacks that are already tuples
    # (a PatternMatcher that was itself unpickled) are passed through; calling .__name__ on them would raise.
    return PatternMatcher, ([(x, deconstruct_function(fxn) if not isinstance(fxn, tuple) and fxn.__name__ == "<lambda>" else fxn)
                             for x,fxn in self.patterns],)

  @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
  def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns)

  def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]:
    """Return the first non-None callback result among the rules matching ``uop``, or None."""
    ler = {u.op for u in uop.src}
    for p,fxn,early_reject,has_ctx in self.pdict.get(uop.op, []):
      # cheap prefilter: every op the pattern requires among the sources must be present
      if not early_reject.issubset(ler): continue
      for match in p.match(uop, {}):
        if (ret:=(fxn(ctx=ctx, **match) if has_ctx else fxn(**match))) is not None: return ret
    return None
|
661
|
+
|
662
|
+
# *** tracking pattern matcher ***
|
663
|
+
|
664
|
+
# Pattern-match tracking level:
#   truthy at import -> PatternMatcher is replaced by TrackedPatternMatcher
#   >= 2             -> rewrites are also recorded (the default when VIZ is set)
#   >= 3             -> every successful match is also printed
TRACK_MATCH_STATS = ContextVar("TRACK_MATCH_STATS", 2 if getenv("VIZ") else 0)
# per UPat: [successful matches, attempts past early reject, seconds on rejected/unmatched attempts, seconds on matching attempts]
match_stats:Dict[UPat, List[Union[int, float]]] = {}
|
666
|
+
@dataclass(frozen=True)
class TrackedRewriteContext:
  """Record of one graph_rewrite call made while rewrite tracking is on."""
  loc: Tuple[str, int]  # (filename, lineno) of the code that called graph_rewrite
  sink: UOp  # the sink passed to that graph_rewrite
  # one entry per recorded rewrite attempt: (uop, replacement or None, matching UPat or None, seconds)
  matches: List[Tuple[UOp, Optional[UOp], Optional[UPat], float]] = field(default_factory=lambda: [])
|
671
|
+
|
672
|
+
rewrite_stack: List[Tuple[Any, List[TrackedRewriteContext]]] = []
|
673
|
+
contexts: List[Tuple[Any, List[TrackedRewriteContext]]] = []
|
674
|
+
_rewrite_cnt: Dict[str, int] = {}
|
675
|
+
def track_rewrites(named=False):
  """Decorator factory: when TRACK_MATCH_STATS >= 2, group the graph_rewrites made inside each call.

  Each call pushes ``(key, [])`` onto ``rewrite_stack`` and moves it to ``contexts`` when the call finishes, even if
  it raises. The key is ``"<funcname>_<n>"`` when ``named``, otherwise the call's first positional argument.
  """
  def _decorator(func):
    @functools.wraps(func)
    def __wrapper(self, *args, **kwargs):
      # read the level once, so a TRACK_MATCH_STATS change during the call can't unbalance the push/pop below
      if (tracking := TRACK_MATCH_STATS >= 2):
        if named: _rewrite_cnt[func.__name__] = _rewrite_cnt.setdefault(func.__name__, 0)+1
        rewrite_stack.append((f"{(n:=func.__name__)}_{_rewrite_cnt[n]}" if named else self, []))
      try: ret = func(self, *args, **kwargs)
      finally: # NOTE: save everything in the stack
        if tracking: contexts.append(rewrite_stack.pop())
      return ret
    return __wrapper
  return _decorator
|
687
|
+
|
688
|
+
class TrackedPatternMatcher(PatternMatcher):
  """PatternMatcher that also accumulates per-pattern statistics in ``match_stats``.

  For each UPat it records:
  - [0] successful matches;
  - [1] attempts that passed early reject;
  - [2] seconds spent on attempts that were rejected or produced nothing;
  - [3] seconds spent on attempts that produced a rewrite.
  With TRACK_MATCH_STATS >= 2 each attempt is also appended to the innermost open TrackedRewriteContext;
  with >= 3 every successful match is printed.
  """
  def __init__(self, patterns:List[Tuple[UPat, Callable]]):
    super().__init__(patterns)
    for p,_ in self.patterns:
      if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]

  def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]:
    ret = None
    ler = {u.op for u in uop.src}
    for p,fxn,early_reject,has_ctx in self.pdict.get(uop.op, []):
      st = time.perf_counter()
      if not early_reject.issubset(ler):
        match_stats[p][2] += time.perf_counter()-st
        continue
      match_stats[p][1] += 1
      for match in p.match(uop, {}):
        if (ret:=(fxn(ctx=ctx, **match) if has_ctx else fxn(**match))) is not None:
          match_stats[p][0] += 1
          match_stats[p][3] += (et:=time.perf_counter()-st)
          if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", p.printable())
          # only record when a graph_rewrite has opened a context; a direct rewrite() call inside a tracked function
          # would otherwise index an empty list
          if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0 and len(rewrite_stack[-1][1]) != 0 and isinstance(ret, UOp):
            rewrite_stack[-1][1][-1].matches.append((uop, ret, p, et))
          return ret # NOTE: if it returns None, we keep trying to match
      match_stats[p][2] += time.perf_counter()-st
    if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0 and len(rewrite_stack[-1][1]) != 0:
      rewrite_stack[-1][1][-1].matches.append((uop, ret, None, 0))
    return None
|
713
|
+
|
714
|
+
# when any tracking level is set at import, every PatternMatcher created after this point is a TrackedPatternMatcher
if TRACK_MATCH_STATS:
  PatternMatcher = TrackedPatternMatcher # type: ignore
  import atexit
  @atexit.register
  def print_match_stats():
    # runs at interpreter exit: dump recorded rewrite contexts, optionally hand off to the VIZ server, print stats
    if TRACK_MATCH_STATS >= 2:
      with open(fn:=temp("rewrites.pkl"), "wb") as f:
        print(f"rewrote {len(contexts)} graphs and matched {sum(len(r.matches) for _,x in contexts for r in x)} times, saved to {fn}")
        pickle.dump(contexts, f)
    if getenv("VIZ"):
      # NOTE(review): VIZ is reset presumably so the server process (which inherits the environment) doesn't track again;
      # os.execv replaces this process, so the stats printing below does not run in this case
      os.environ["VIZ"] = "0"
      os.execv(sys.executable, [sys.executable] + [os.path.join(os.path.dirname(__file__), ".", "viz", "serve.py"), temp("rewrites.pkl")])
    if getenv("PRINT_MATCH_STATS", 1):
      # one row per pattern, ascending by total time: matches / attempts -- matching ms / total ms -- file:line, source
      ret = [0,0,0.0,0.0]
      for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]+x[1][3]):
        loc_str = f"{k.location[0].split('/')[-1]}:{k.location[1]}"
        if v[1] != 0: print(f"{v[0]:6d} / {v[1]:7d} -- {v[3]*1000.:9.2f} / {(v[2]+v[3])*1000.:9.2f} ms -- {loc_str:15s}", k.printable())
        ret = [x+y for x,y in zip(ret, v)]
      print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[3]*1000.:9.2f} / {(ret[2]+ret[3])*1000.:9.2f} ms -- TOTAL")
|
733
|
+
|
734
|
+
# *** simple graph rewrite engine ***
|
735
|
+
|
736
|
+
class RewriteContext:
  """State for one graph_rewrite pass: rewrites a UOp graph bottom-up with ``pm`` and memoizes the results.

  ``replace`` maps each visited node to its final form, so shared subgraphs are rewritten only once.
  """
  def __init__(self, pm, ctx):
    self.pm: PatternMatcher = pm
    self.ctx = ctx  # forwarded to pm.rewrite for every node
    self.replace: Dict[UOp, UOp] = {}
  def rewrite(self, n:UOp) -> UOp:
    """Return the fully rewritten form of ``n``: sources first, then ``n`` itself until no rule applies."""
    if (done := self.replace.get(n)) is not None: return done
    # map (not a generator expression) keeps each recursion level to one Python frame
    rebuilt_src = tuple(map(self.rewrite, n.src))
    if rebuilt_src != n.src: candidate = UOp(n.op, n.dtype, rebuilt_src, n.arg)  # sources changed: rebuild, keep rewriting
    else: candidate = self.pm.rewrite(n, self.ctx)  # sources unchanged: try the patterns on n itself
    final = self.rewrite(candidate) if candidate is not None else n
    self.replace[n] = final
    return final
|
747
|
+
|
748
|
+
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp:
  """Rewrite the graph rooted at ``sink`` with ``pm``; ``ctx`` is forwarded to callbacks that accept it.

  When TRACK_MATCH_STATS >= 2 and a track_rewrites call is active, a TrackedRewriteContext tagged with the caller's
  file and line is opened first so that matches are recorded against it.
  """
  if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0:
    caller = sys._getframe(1)
    rewrite_stack[-1][1].append(TrackedRewriteContext((caller.f_code.co_filename, caller.f_lineno), sink))
  return RewriteContext(pm, ctx).rewrite(sink)
|
752
|
+
|
753
|
+
# ***** uop type spec *****
|
754
|
+
|
755
|
+
# this is the matcher for the final rendered UOps
# matcher functions return True or False (or None to not match)
# NOTE: rules are tried in order and the first non-None result wins, so a rule returning None defers to later rules
spec = PatternMatcher([
  # buffers: globals are non-local pointers/images, locals are local pointers
  (UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and not x.dtype.local),
  (UPat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.local),
  # accumulator: <initial value, RANGE...>; the initial value has the accumulator's dtype
  (UPat(Ops.DEFINE_ACC, src=(UPat.var("c"),), name="x", allow_any_len=True),
   lambda x,c: all(y.op is Ops.RANGE for y in x.src[1:]) and c.dtype == x.dtype),
  # variables carry integer bounds in arg[1] (min) and arg[2] (max)
  (UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),

  (UPat(Ops.RANGE, src=(UPat(name="x"), UPat(name="y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype),
  (UPat(Ops.SPECIAL, src=()), lambda: True),

  # TODO: confirm the args of both of these are shapetrackers
  (UPat(Ops.VIEW, dtypes.void, src=()), lambda: True),
  (UPat(Ops.VIEW, src=(UPat.var("src"),), name="x"), lambda x,src: src.op is not Ops.STORE and x.dtype == src.dtype),

  (UPat(Ops.VALID, dtypes.bool, (UPat(Ops.VIEW),)), lambda: True),
  # CONST must be scalar, and its arg must already have the Python type dtypes.as_const gives for that dtype
  (UPat(Ops.CONST, name="x"), lambda x: x.dtype == x.dtype.scalar() and (type(x.arg) is type(dtypes.as_const(x.arg, x.dtype)))),

  # early LOAD has a <buf, shapetracker, store?>
  (UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW))), lambda: True),
  (UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat(Ops.STORE))), lambda: True),

  # early STORE has a <buf, shapetracker, val>
  (UPat(Ops.STORE, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat())), lambda: True),

  # **** new style load/store ****

  # INDEX is used in new style load/store
  (UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat())), lambda: True),

  # LOAD takes a <bufidx, alt?, gate?, barrier?>
  (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)),)), lambda: True),
  (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat((Ops.IF, Ops.BARRIER)))), lambda: True),
  (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(name="alt"), UPat(dtype=dtypes.bool)), name="ld"), lambda ld,alt: ld.dtype == alt.dtype),

  # STORE takes a <bufidx, val, gate?>
  (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat())), lambda: True),
  (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
  (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(), UPat(Ops.IF))), lambda: True),

  # most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
  (UPat(Ops.WHERE, name="w", src=(UPat(dtype=dtypes.bool), UPat(name="x"), UPat(name="y"))), lambda w,x,y: w.dtype == x.dtype == y.dtype),
  (UPat((Ops.CMPLT, Ops.CMPNE), dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y"))), lambda x,y: x.dtype == y.dtype),
  # and SHL/SHR: the shift distance may have the operand's dtype or be a uint
  (UPat((Ops.SHL, Ops.SHR), src=(UPat(name="x"), UPat(name="y")), name="a"), lambda a,x,y: a.dtype == x.dtype and y.dtype in (x.dtype, dtypes.uint)),
  # IDIV is integer-only: None defers int IDIV to the generic ALU rule below, False rejects every other dtype
  (UPat(Ops.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
  (UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)),

  (UPat(Ops.ASSIGN, src=(UPat((Ops.DEFINE_ACC, Ops.DEFINE_GLOBAL)), UPat())), lambda: True),
  (UPat(Ops.ENDRANGE, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True),

  # all WMMA has 3 args, <x, w, acc>
  (UPat(Ops.WMMA, src=(UPat(), UPat(), UPat())), lambda: True),
  # CONTRACT/EXPAND: the vector width must equal the product of the axis sizes listed in arg
  (UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
  (UPat(Ops.EXPAND, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),

  # IF has a <gate, barrier?>
  (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(),)), lambda: True),
  (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(), UPat(Ops.BARRIER))), lambda: True),
  (UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),

  (UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}),
  (UPat(Ops.GEP, src=(UPat(name="src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
  (UPat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.count and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
  (UPat((Ops.BITCAST, Ops.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None),
  (UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local

  # NOTE: for testing, we let sinks be anything
  #(UPat(UOps.SINK, src=UPat(UOps.STORE)), lambda: True),
  (UPat(Ops.SINK, dtypes.void), lambda: True),
  (UPat(Ops.NOOP), lambda: True),

  # PTX LOAD/STORE (presumably with raw int64 addresses as the first source)
  (UPat((Ops.LOAD, Ops.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
  (UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, src=(UPat(dtype=dtypes.int64),), allow_any_len=True)), lambda: True),
])
|
832
|
+
|
833
|
+
def type_verify(uops:List[UOp]):
  """Check each UOp in ``uops`` against the ``spec`` matcher.

  At the first UOp that no spec rule accepts (no rule matched, or a rule returned False), the whole list is printed
  with print_uops and a RuntimeError is raised.
  """
  for i,u in enumerate(uops):
    if spec.rewrite(u): continue
    print_uops(uops)
    raise RuntimeError(f"UOp verification failed at {i} on {u.op} {u.dtype} {len(u.src)} {[x.op for x in u.src]} {u.arg}")
|
838
|
+
|
839
|
+
# *** uop helpers ***
|
840
|
+
|
841
|
+
def cast_float_to_bf16(x: UOp) -> UOp:
  """Build UOps converting a float32 value to bfloat16 through its bit pattern.

  Most values get the round-to-nearest-even bias (0x7fff plus the lowest kept bit) before the top 16 bits are kept.
  NaNs skip that bias; instead bit 16 is forced on when their low half is nonzero, so dropping the low bits cannot
  turn them into inf.
  """
  assert x.dtype == dtypes.float, "cast float -> bf16 must start with float"
  bits = x.bitcast(dtypes.uint)
  rounded = bits + ((bits >> 16) & 1) + 0x7fff
  nan_safe = (bits & 0xffff).where((bits | 0x10000), bits)
  # NOTE: uses negation where the usual C version tests `~u & 0x7f800000`; they only differ for +-0, where the result is the same
  bits = (-bits & 0x7f800000).where(rounded, nan_safe)
  return (bits >> 16).cast(dtypes.ushort).bitcast(dtypes.bfloat16)
|
846
|
+
|
847
|
+
# *** most of symbolic lives here now ***
|
848
|
+
|
849
|
+
def split_uop(x:UOp, sep:Ops):
  """Yield, left to right, the operands of the tree of nested ``sep`` ops rooted at ``x``.

  For example, with ``sep=ADD`` the expression ``a+(b+c)`` yields a, b, c. A non-``sep`` ``x`` yields only itself.
  """
  if x.op is not sep:
    yield x
    return
  for child in x.src: yield from split_uop(child, sep)
|
853
|
+
|
854
|
+
def mod_folding(x:UOp, c:int) -> Optional[UOp]:
  """Try to simplify ``x % c``; return the simplified UOp, or None if nothing changed.

  If ``x`` is non-negative and its whole range lies in one block of ``c`` values, the result is
  ``x - quotient*c``. Otherwise each ADD term of ``x`` is reduced:
  - a term whose constant factor is not already in ``[0, c)`` becomes ``(term / factor) * (factor % c)``;
  - a ``y % k`` term whose constant ``k`` is a multiple of ``c`` becomes ``y``.
  """
  # simple cancel mod case
  if 0 < c and 0 <= x.vmin and (quotient:=x.vmin//c) == x.vmax//c: return x-quotient*c

  remainder, something_changed = [], False
  for u in split_uop(x, Ops.ADD):
    if (factor:=u.const_factor())%c != factor:
      # check before multiplying: `None * int` raised TypeError before the assert could ever fire
      divides = u.divides(factor)
      assert divides is not None
      remainder.append(divides*(factor%c))
      something_changed = True
    elif u.op is Ops.MOD and (s1:=u.src[1]).op is Ops.CONST and s1.arg%c == 0:
      remainder.append(u.src[0])
      something_changed = True
    else: remainder.append(u)
  if not something_changed: return None
  return functools.reduce(operator.add, remainder)%c if remainder else x.const_like(0)
|
873
|
+
|
874
|
+
def div_folding(x:UOp, c:int) -> Optional[UOp]:
  # simplify x // c, None means no change
  # Each ADD term of x goes into one of three buckets:
  #   - CONST terms are summed into rem_const;
  #   - terms whose const factor is a multiple of c are divided by c and moved to `quotient`;
  #   - everything else stays in `remainder`.
  # For the remainder we also track the gcd of its factors (seeded with c) and a MUL factor that divides c.
  # These are used to split the division into two steps when nothing else applies.

  # simple cancel div case
  if 0 <= x.vmin and x.vmax < c: return x.const_like(0)

  quotient, remainder, rem_const, something_changed, gcd, divisor = [], [], 0, False, c, 1
  for u in split_uop(x, Ops.ADD):
    if u.op is Ops.CONST:
      # add all const together first
      if rem_const != 0: something_changed = True
      rem_const += u.arg
    elif (factor:=u.const_factor())%c == 0:
      # a factor-0 term is dropped (presumably it is identically 0)
      if factor:
        divides = u.divides(c)
        assert divides is not None
        quotient.append(divides)
      something_changed = True
    else:
      # divisor is the smallest common divisor of all MULs
      if u.op is Ops.MUL and factor > 1 and c % factor == 0 and (divisor == 1 or divisor > factor): divisor = factor
      remainder.append(u)
      gcd = math.gcd(gcd, factor)

  # handle the const
  if rem_const%c != rem_const:
    # move the multiple-of-c part of the constant into the quotient
    something_changed = True
    quotient.append(x.const_like(rem_const//c))
    rem_const = rem_const%c
  if rem_const != 0: remainder.append(x.const_like(rem_const))

  # x // c -> quotient + (remainder // div) // (c // div)
  div = gcd if gcd > 1 else divisor

  # nothing moved: try dividing by `div` first and then by the rest of c
  if not something_changed: return newx//(c//div) if 1 < div < c and (newx:=div_folding(x, div)) is not None else None
  rem:Optional[UOp] = functools.reduce(operator.add, remainder) if remainder else None
  quo:Optional[UOp] = functools.reduce(operator.add, quotient) if quotient else None
  # NOTE(review): cast() only narrows the type for the checker; these lines assume div_folding(rem, div) is not None
  if quo is None: return x.const_like(0) if rem is None else cast(UOp, div_folding(rem, div))//(c//div)
  return quo if rem is None else cast(UOp, div_folding(rem, div))//(c//div)+quo
|
913
|
+
|
914
|
+
def lt_folding(x:UOp, c:int) -> Optional[UOp]:
  """Try to simplify ``x < c`` by dividing out a common factor; return None if not applicable.

  ``x`` is split into ADD terms with const factor 1 and the remaining (scaled) terms. Let d be the gcd of the
  scaled terms' factors and ``c``. If d > 1 and the unit terms always sum to a value in ``[0, d)``, then ``x < c``
  is rewritten as ``(scaled sum)/d < c//d``.
  """
  unit_terms, scaled_terms = partition(split_uop(x, Ops.ADD), lambda u: u.const_factor() == 1)
  if not scaled_terms: return None
  d = math.gcd(*[u.const_factor() for u in scaled_terms], c)
  if d <= 1: return None
  if not (0 <= sum(u.vmin for u in unit_terms) and sum(u.vmax for u in unit_terms) < d): return None
  return cast(UOp, functools.reduce(operator.add, scaled_terms).divides(d)).lt(c//d)
|
919
|
+
|
920
|
+
def fold_unrolled_divs(divs:UOp):
  """Fold the sum of an unrolled arange division pattern back into its base expression.

  Example: ``x//4 + (x+1)//4 + (x+2)//4 + (x+3)//4 -> x``. Requirements:
  - every ADD term must be ``(ans + k)//d`` with the same ``ans`` (identity) and the same constant ``d``;
  - the offsets ``k`` must cover ``0..d-1``;
  - missing leading offsets count as present when ``0 <= ans`` and ``ans + i < d`` follow from ``ans``'s range
    (those terms may already have been folded to 0).
  Returns None when the pattern does not apply.
  """
  add_chain, denominator, seen_const, ans = list(split_uop(divs, Ops.ADD)), None, [], None
  for u in add_chain:
    if not (u.op is Ops.IDIV and u.src[1].op is Ops.CONST): return None
    if denominator is None: denominator = u.src[1].arg
    if denominator != u.src[1].arg: return None
    # assumed CONST is the last of an ADD (the CONST check was previously duplicated)
    if (s0:=u.src[0]).op is Ops.ADD and s0.src[1].op is Ops.CONST:
      seen_const.append(s0.src[1].arg)
      s0 = s0.src[0]
    else: seen_const.append(0)
    if ans is None: ans = s0
    if ans is not s0: return None
  if denominator is None: return None
  # the first (denominator-len(seen_const)) terms may have been folded to 0 already
  for i in range(denominator-len(seen_const)):
    if ans is not None and 0 <= ans.vmin and ans.vmax + i < denominator: seen_const.append(i)
  return ans if ans is not None and sorted(seen_const)==list(range(denominator)) else None
|
940
|
+
|
941
|
+
def canonicalize_simplex(X:UOp) -> Optional[UOp]:
  """Drop positive constant coefficients from a sum of non-negative irreducible terms.

  ``X = a0*x0 + a1*x1 + ... > 0`` is equivalent to ``x0 + x1 + ... > 0`` when every ``xi >= 0`` and ``ai > 0``
  (ints). Returns ``x0 + x1 + ...`` in that case, or None if any term doesn't fit or no coefficient was removed.
  """
  terms, saw_coefficient = [], False
  for term in split_uop(X, Ops.ADD):
    # assumed the const is the last src of MUL
    if term.op is Ops.MUL and term.src[1].op is Ops.CONST and term.src[1].arg > 0:
      saw_coefficient = True
      term = term.src[0]
    if term.op not in GroupOp.Irreducible or not (term.vmin >= 0): return None
    terms.append(term)
  if not saw_coefficient: return None
  return functools.reduce(operator.add, terms)
|
953
|
+
|
954
|
+
def is_increasing(f:UOp) -> bool:
  """Conservatively report whether ``f`` is a monotonically increasing function of its inputs.

  True for irreducible leaves, for sums of increasing terms, and for an increasing term multiplied or integer-divided
  by a non-negative constant. False whenever this can't be shown.
  """
  op = f.op
  if op in GroupOp.Irreducible: return True
  if op is Ops.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1])
  scaled_by_nonneg_const = op in (Ops.MUL, Ops.IDIV) and f.src[1].op is Ops.CONST and f.src[1].arg >= 0
  if scaled_by_nonneg_const: return is_increasing(f.src[0])
  return False # False if not sure
|
960
|
+
|
961
|
+
def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]:
  """Parse one bound clause of a validity condition.

  Returns ``(X, True, c)`` for an upper bound ``X <= c`` (written as ``X < c+1``) and ``(X, False, c)`` for a lower
  bound ``X >= c`` (written as ``(X < c) != True``). Raises ValueError for any other form.
  """
  src = valid.src
  # (X < c).ne(True) -> X >= c
  if valid.op is Ops.CMPNE and src[1].op is Ops.CONST and src[1].arg == 1:
    inner = src[0]
    if inner.op is Ops.CMPLT and inner.src[1].op is Ops.CONST: return inner.src[0], False, inner.src[1].arg
  # X < c -> X <= c-1
  if valid.op is Ops.CMPLT and src[1].op is Ops.CONST: return src[0], True, src[1].arg-1
  raise ValueError(f"not able to parse {valid=}")
|
971
|
+
|
972
|
+
def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
  # Simplify `uop` under the assumption that the boolean `valid` is True.
  # Returns None if `valid` is provably always False (some expr gets lower bound > upper bound),
  # otherwise the simplified uop (possibly `uop` itself).
  # If any AND-clause of `valid` cannot be parsed by parse_valid, `uop` is returned unchanged.

  # first, parse valid into {expr: [lower_bound, upper_bound]}
  bounds:DefaultDict[UOp, List[Optional[ConstType]]] = defaultdict(lambda: [None, None])
  for stmt in split_uop(valid, Ops.AND):
    try: expr, is_upper, c = parse_valid(stmt)
    except ValueError: return uop  # give up if we cannot parse the valid
    # valid is a conjunction, so several bounds on one expr intersect: keep the tightest instead of the last one seen.
    # (overwriting was sound but looser, and could hide an empty valid set such as X>=5 & X>=2 & X<=3)
    prev = bounds[expr][int(is_upper)]
    bounds[expr][int(is_upper)] = c if prev is None else (min(prev, c) if is_upper else max(prev, c))

  # simplify uop given that valid is True
  for expr,v in bounds.items():
    # some expr has lower bound > upper bound -> valid is an empty set and we return None
    if v[0] is not None and v[1] is not None and v[0] > v[1]: return None

    # every candidate is a set of constrained UOps based on valid; if every item in a set simplifies the uop into the same output, we rewrite uop
    candidates = []
    if expr.op is Ops.ADD and v[0] is not None and v[0] >= 1 and all(u.op in GroupOp.Irreducible for u in split_uop(expr, Ops.ADD)):
      # simplex: for ints, X0 + X1 + ... >= 1 (or any larger lower bound) forces at least one Xi >= 1,
      # so if every "Xi >= 1" branch simplifies into the same output, that output is valid
      candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, Ops.ADD)])
    # try checking the whole clause
    if expr in uop.sparents:
      candidates.append([(expr, UOp.variable("fake", expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1], expr.dtype))])

    for candidate in candidates:
      # if every branch in candidate gives the same simplified uop, we can rewrite the uop
      newuops = [uop.substitute({X:newX}).simplify().substitute({newX:X}).simplify() for X,newX in candidate]
      if uop.op is Ops.VECTORIZE and len(uop.src) == 2:
        # for a 2-wide VECTORIZE each lane is rewritten independently
        if all_same([uops.src[0] for uops in newuops]): uop = uop.replace(src=(newuops[0].src[0], uop.src[1]))
        if all_same([uops.src[1] for uops in newuops]): uop = uop.replace(src=(uop.src[0], newuops[0].src[1]))
      elif all_same(newuops): uop = newuops[0]

  return uop
|
1005
|
+
|
1006
|
+
def _valid_priority(v: UOp, valids:List[UOp]):
  # Sort key used by simplify_valid: a clause whose parsed expression appears in the parents of other
  # clauses gets a more negative key, so it is processed first and can help simplify those other clauses.
  # Clauses parse_valid cannot handle get the neutral key 0.
  try:
    target = parse_valid(v)[0]
    return -sum(1 for other in valids if target in other.parents)
  except ValueError: return 0
|
1010
|
+
|
1011
|
+
def simplify_valid(valid:UOp) -> Optional[UOp]:
  # Simplify each AND-clause of `valid` using the clauses already processed before it.
  # Returns the AND of the (possibly rewritten) clauses if at least one clause changed, otherwise None.
  clauses = list(split_uop(valid, Ops.AND))
  kept: List[UOp] = []
  something_changed = False
  for stmt in sorted(clauses, key=lambda c: _valid_priority(c, clauses)):
    newstmt = uop_given_valid(functools.reduce(operator.and_, kept), stmt) if kept else None
    # uop_given_valid returns None when the earlier clauses are unsatisfiable; keep the clause as-is then
    if newstmt is None: newstmt = stmt
    kept.append(newstmt)
    if newstmt is not stmt: something_changed = True
  return functools.reduce(operator.and_, kept) if something_changed else None
|
1019
|
+
|
1020
|
+
def max_var_const(x:UOp, c1:UOp, c2:UOp):
  # Fold max(x*c1, x*c2) when the sign of x is known from its range:
  #   x >= 0 -> x * (larger constant)
  #   x <= 0 -> x * (smaller constant)
  # Returns None (no rewrite) when x may take both signs.
  larger, smaller = (c1, c2) if c1.arg >= c2.arg else (c2, c1)
  if x.vmin >= 0: return x*larger
  if x.vmax <= 0: return x*smaller
  return None
|
1023
|
+
|
1024
|
+
# Basic local rewrite rules: self folding, zero folding, constant folding, bool ALU canonicalization and cast folding.
# Each entry is (pattern, rewrite); a rewrite returning None leaves the UOp unchanged.
# `symbolic` extends this matcher with more rules.
symbolic_simple = PatternMatcher([
  # ** self folding **
  (UPat.var("x") + 0, lambda x: x), # x+0 -> x
  (UPat.var("x") * 1, lambda x: x), # x*1 -> x
  (UPat.var("x") // UPat.var("x"), lambda x: x.const_like(1)), # x//x -> 1
  (UPat.var("x") // 1, lambda x: x), # x//1 -> x
  (UPat.var("x") // -1, lambda x: -x), # x//-1 -> -x
  (UPat.var("x") / UPat.var("x"), lambda x: x.const_like(1)), # x/x -> 1
  ((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x
  ((UPat.var() % UPat.var("y")).named("base") % UPat.var("y"), lambda base,y: base),  # (x%y)%y -> x%y (rewritten with base for speed)
  (UPat.var("x")%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda x,c: x), # (x%c)+(x//c)*c = x
  # bool AND/OR with a scalar constant: x&True -> x, x&False -> False, x|True -> True, x|False -> x
  (UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c", vec=False), lambda x,c: x if c.arg else c),
  (UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c", vec=False), lambda x,c: c if c.arg else x),
  # idempotent ops: max(x,x), x&x, x|x -> x
  (UPat.var("x").maximum(UPat.var("x")), lambda x: x),
  ((UPat.var("x") & UPat.var("x")), lambda x: x),
  ((UPat.var("x") | UPat.var("x")), lambda x: x),
  # double negation of a bool
  (UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x),
  # x ? True : False -> x
  (UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, True), UPat.const(dtypes.bool, False)), lambda x: x),
  # ** zero folding **
  (UPat.var("x") < UPat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x < x -> False
  (UPat.var("x", dtype=dtypes.ints) != UPat.var("x", dtype=dtypes.ints),
   lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x != x -> False (only ints)
  # x*0 -> 0 or 0*x -> 0
  # if x is nan or inf it should render the nan value.
  # NOTE: this can be wrong for loaded NaN
  (UPat.var("x") * 0, lambda x: x.const_like(float("nan") if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
  # ** constant folding **
  # an ALU op whose srcs are all CONST/VCONST is evaluated with exec_alu
  (UPat(GroupOp.ALU, name="a", src=UPat((Ops.VCONST, Ops.CONST))), lambda a: a.const_like(exec_alu(a.op, a.dtype, [x.arg for x in a.src], False))),
  # bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
  (UPat.var('x', dtype=dtypes.bool) * UPat.var('y', dtype=dtypes.bool), lambda x,y: x&y),
  (UPat.var('x', dtype=dtypes.bool) + UPat.var('y', dtype=dtypes.bool), lambda x,y: x|y),
  (UPat.var('x', dtype=dtypes.bool).maximum(UPat.var('y', dtype=dtypes.bool)), lambda x,y: x|y),
  # *** cast ***
  # casting a constant produces a constant of the target dtype
  (UPat(Ops.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)),
  # a cast to the dtype the value already has is a no-op
  (UPat(Ops.CAST, name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
])
|
1060
|
+
|
1061
|
+
# Main symbolic rewrite rules, layered on top of symbolic_simple.
# Each entry is (pattern, rewrite); a rewrite returning None leaves the UOp unchanged.
# NOTE: the lt rules below assume tinygrad's integer IDIV truncates toward zero (C semantics), not Python floor division.
symbolic = symbolic_simple+PatternMatcher([
  # ** COMMUTATIVE flipping **
  # order the srcs of commutative ops by tuplize so equivalent expressions end up structurally identical
  (UPat(GroupOp.Commutative, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
  # group like
  ((UPat.var("x") + UPat.var("y")) + UPat.var("x") * UPat.cvar("c"), lambda x,y,c: (x+x*c)+y),
  # ** boolean algebra **
  (UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x
  # ** combine terms **
  (UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1)
  (UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1)
  (UPat.var("x") + UPat.var("x"), lambda x: x*2), # (x+x)-> x*2
  ((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3)), # (x/x2)/x3 -> x/(x2*x3)
  (-1 * (UPat.var("x") + UPat.cvar("c")), lambda x,c: (-x)+(-c)),  # -(x+c) -> -x + -c
  # a conditional with the same results either way is a noop, also fold const conditionals
  (UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
  (UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
  # ALU min==max -> CONST (slow!)
  (UPat(GroupOp.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
  # max folding: if the ranges don't overlap, max is the operand with the larger range
  (UPat.maximum(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
  # TODO: why does this rule break beautiful_mnist?
  #((UPat.var("x")+UPat.var("z")).maximum(UPat.var("y")+UPat.var("z")), lambda x,y,z: x.maximum(y) + z),
  ((UPat.var("x")*UPat.cvar("c1")).maximum(UPat.var("x")*UPat.cvar("c2")), max_var_const),
  # ** two stage ALU folding **
  ((UPat.var("x") + UPat.cvar("c1")) + UPat.cvar("c2"), lambda x,c1,c2: x+(c1+c2)),
  ((UPat.var("x") * UPat.cvar("c1")) * UPat.cvar("c2"), lambda x,c1,c2: x*(c1*c2)),
  ((UPat.var("x") & UPat.cvar("c1")) & UPat.cvar("c2"), lambda x,c1,c2: x&(c1&c2)),
  ((UPat.var("x") | UPat.cvar("c1")) | UPat.cvar("c2"), lambda x,c1,c2: x|(c1|c2)),
  ((UPat.cvar("c0") + UPat.var("x")) < UPat.cvar("c1"), lambda x,c0,c1: x<(c1-c0)),  # c0 + x < c1 -> x < c1 - c0
  ((UPat.var("x") // UPat.cvar("c1")) // UPat.cvar("c2"), lambda x,c1,c2: x//(c1*c2)), # (x//c1)//c2 -> x//(c1*c2)
  # ** lt **
  # c0*x<c1 for positive int c0,c1 -> x<ceil(c1/c0); exact integer ceil instead of float division
  ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints)).lt(UPat.cvar("c1", vec=False)),
    lambda x,c0,c1: x.lt(-((-c1.arg)//c0.arg)) if c0.arg > 0 and c1.arg > 0 else None),
  # c0*x<c1 for negative int c0 and non-positive c1 -> -x < -floor(c1/c0); exact integer floor instead of float division.
  # c0 == -1 is excluded.
  ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints)).lt(UPat.cvar("c1", vec=False)),
    lambda x,c0,c1: (-x).lt(-((-c1.arg)//(-c0.arg))) if c0.arg < 0 and c0.arg != -1 and c1.arg <= 0 else None),
  # x//c0<c1 for positive int c0. IDIV truncates toward zero, so:
  #   c1 > 0:  x//c0 < c1  <=>  x < c1*c0
  #   c1 <= 0: x//c0 < c1  <=>  x < c1*c0-(c0-1)   (e.g. c0=2, c1=0: -1//2 == 0 is not < 0, so x < -1, not x < 0)
  ((UPat.var("x", dtype=dtypes.ints)//UPat.cvar("c0", vec=False)).lt(UPat.cvar("c1", vec=False)),
    lambda x,c0,c1: (x.lt(c1.arg*c0.arg) if c1.arg > 0 else x.lt(c1.arg*c0.arg-(c0.arg-1))) if c0.arg > 0 else None),
  # mul add lt: c0*x+x2 < c1 -> x < c1//c0 when c0 divides c1 and 0 <= x2 < c0
  (((UPat.cvar("c0", vec=False)*UPat.var("x"))+UPat.var("x2")).lt(UPat.cvar("c1", vec=False)),
   lambda x,x2,c0,c1: x.lt(c1//c0) if c1.arg % c0.arg == 0 and c0.arg > x2.vmax and x2.vmin >= 0 else None),
  # ** move add/mul consts to end (NOTE: this is still happening before constant folding) **
  (UPat(Ops.ADD, src=(UPat.var("x"), UPat.cvar("c1"))) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
  (UPat(Ops.MUL, src=(UPat.var("x"), UPat.cvar("c1"))) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
  # *** rules from symbolic ***
  # unrolled arange div folding
  (UPat(Ops.ADD, name="divs", src=[UPat(), UPat(Ops.IDIV)]), fold_unrolled_divs),
  # generic lt folding
  (UPat.var("x", dtypes.sints).lt(UPat.cvar("c", vec=False)), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None),
  # canonicalize a simplex with positive coefficients > 0
  # not x < 1 -> X > 0
  (UPat.var("x", dtypes.ints).lt(1).ne(True), lambda x: newx.lt(1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None),
  # ** div **
  # div folding
  (UPat.var("x", dtypes.sints) // UPat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=div_folding(x,c.arg)) is not None else None),
  # ** mod **
  # mod folding
  (UPat.var("x") % UPat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None),
])
|
1122
|
+
|
1123
|
+
# `symbolic` plus "flattening" rules that distribute a -1 factor or a constant factor over an addition.
symbolic_flat = symbolic+PatternMatcher([
  # ** combine terms (opinionated) **
  (-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)),  # -(x+y) -> -x + -y
  # (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue
  ((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
])
|
1129
|
+
|
1130
|
+
# Matches a UOp of any op; a UOp found as a key in the `ctx` dict is replaced by its value, anything else
# gets None (no rewrite). Presumably this backs UOp.substitute — NOTE(review): confirm against its definition.
_substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
|
1131
|
+
|
1132
|
+
# for debug
# infix symbols for binary ALU ops; the `renderer` matcher uses this table to print UOp graphs as readable expressions
syms = { Ops.ADD: "+", Ops.SUB: "-", Ops.IDIV: "//", Ops.MOD: "%", Ops.SHL: "<<", Ops.SHR: ">>",
         Ops.MUL: "*", Ops.CMPLT: "<", Ops.CMPNE: "!=", Ops.AND: "&", Ops.OR: "|", Ops.XOR: "^"}
|
1135
|
+
# Debug renderer: rewrites a UOp graph bottom-up into NOOP UOps whose arg is the expression's text.
# Leaves (variables, specials, ranges, consts) become NOOPs with their name/value; an op whose srcs are all
# already-rendered NOOPs is replaced by a NOOP holding the combined string.
renderer = PatternMatcher([
  # DEFINE_VAR/SPECIAL render as their name (first element of arg)
  (UPat((Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg[0])),
  # RANGE renders as ridx<N>, with N taken from arg[0]
  (UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg[0]}")),
  (UPat(Ops.CONST, name="x"), lambda x: UOp(Ops.NOOP, arg=str(x.arg))),
  # a BIND renders as its (already rendered) first src
  (UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
  (UPat(Ops.NEG, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")),
  (UPat(Ops.MAX, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")),
  (UPat(Ops.MULACC, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg}+{x.src[2].arg})")),
  (UPat(Ops.WHERE, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[1].arg} if {x.src[0].arg} else {x.src[2].arg})")),
  # remaining ALU ops render as binary infix via `syms`
  # NOTE(review): an ALU op missing from `syms` (or with fewer than two srcs) would raise here
  (UPat(GroupOp.ALU, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}{syms[x.op]}{x.src[1].arg})")),
])
|
1146
|
+
|
1147
|
+
# *** what was symbolic.py ***

# a "symbolic int": either a concrete python int or a UOp expression
sint = Union[int, UOp]
# symbolic variables are plain UOps (e.g. those created with UOp.variable)
Variable = UOp

# anything accepted where a constant is expected: a ConstType, a Variable, or a tuple of ConstTypes
ConstLike = Union[ConstType, Variable, Tuple[ConstType, ...]]
|