tinygrad 0.10.0__py3-none-any.whl → 0.10.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/devectorizer.py +247 -0
- tinygrad/codegen/expander.py +121 -0
- tinygrad/codegen/kernel.py +141 -201
- tinygrad/codegen/linearize.py +223 -84
- tinygrad/codegen/lowerer.py +60 -42
- tinygrad/codegen/symbolic.py +476 -0
- tinygrad/codegen/transcendental.py +22 -13
- tinygrad/device.py +187 -47
- tinygrad/dtype.py +39 -28
- tinygrad/engine/jit.py +83 -65
- tinygrad/engine/memory.py +4 -5
- tinygrad/engine/multi.py +161 -0
- tinygrad/engine/realize.py +62 -108
- tinygrad/engine/schedule.py +396 -357
- tinygrad/engine/search.py +55 -66
- tinygrad/gradient.py +73 -0
- tinygrad/helpers.py +81 -59
- tinygrad/nn/__init__.py +30 -32
- tinygrad/nn/datasets.py +1 -2
- tinygrad/nn/optim.py +22 -26
- tinygrad/nn/state.py +91 -66
- tinygrad/ops.py +492 -641
- tinygrad/renderer/__init__.py +95 -36
- tinygrad/renderer/cstyle.py +99 -92
- tinygrad/renderer/llvmir.py +83 -34
- tinygrad/renderer/ptx.py +83 -99
- tinygrad/renderer/wgsl.py +95 -0
- tinygrad/runtime/autogen/amd_gpu.py +39507 -12
- tinygrad/runtime/autogen/comgr.py +2 -0
- tinygrad/runtime/autogen/kfd.py +4 -3
- tinygrad/runtime/autogen/kgsl.py +1 -1
- tinygrad/runtime/autogen/libc.py +404 -71
- tinygrad/runtime/autogen/llvm.py +11379 -0
- tinygrad/runtime/autogen/pci.py +1333 -0
- tinygrad/runtime/autogen/vfio.py +891 -0
- tinygrad/runtime/autogen/webgpu.py +6985 -0
- tinygrad/runtime/graph/cuda.py +8 -9
- tinygrad/runtime/graph/hcq.py +84 -79
- tinygrad/runtime/graph/metal.py +40 -43
- tinygrad/runtime/ops_amd.py +498 -334
- tinygrad/runtime/ops_cloud.py +34 -34
- tinygrad/runtime/ops_cpu.py +24 -0
- tinygrad/runtime/ops_cuda.py +30 -27
- tinygrad/runtime/ops_disk.py +62 -63
- tinygrad/runtime/ops_dsp.py +159 -42
- tinygrad/runtime/ops_gpu.py +30 -30
- tinygrad/runtime/ops_hip.py +29 -31
- tinygrad/runtime/ops_llvm.py +48 -41
- tinygrad/runtime/ops_metal.py +149 -113
- tinygrad/runtime/ops_npy.py +2 -2
- tinygrad/runtime/ops_nv.py +238 -273
- tinygrad/runtime/ops_python.py +55 -50
- tinygrad/runtime/ops_qcom.py +129 -157
- tinygrad/runtime/ops_webgpu.py +225 -0
- tinygrad/runtime/support/allocator.py +94 -0
- tinygrad/runtime/support/am/__init__.py +0 -0
- tinygrad/runtime/support/am/amdev.py +396 -0
- tinygrad/runtime/support/am/ip.py +463 -0
- tinygrad/runtime/support/compiler_cuda.py +4 -2
- tinygrad/runtime/support/elf.py +28 -4
- tinygrad/runtime/support/hcq.py +256 -324
- tinygrad/runtime/support/llvm.py +26 -0
- tinygrad/shape/shapetracker.py +85 -53
- tinygrad/shape/view.py +104 -140
- tinygrad/spec.py +155 -0
- tinygrad/tensor.py +835 -527
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
- tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
- tinygrad/viz/index.html +544 -0
- tinygrad/viz/perfetto.html +178 -0
- tinygrad/viz/serve.py +205 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
- tinygrad-0.10.2.dist-info/RECORD +99 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
- tinygrad/codegen/uopgraph.py +0 -506
- tinygrad/engine/lazy.py +0 -228
- tinygrad/function.py +0 -212
- tinygrad/multi.py +0 -177
- tinygrad/runtime/graph/clang.py +0 -39
- tinygrad/runtime/ops_clang.py +0 -35
- tinygrad-0.10.0.dist-info/RECORD +0 -77
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/ops.py
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
from __future__ import annotations
|
2
|
-
from typing import Any,
|
3
|
-
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect
|
2
|
+
from typing import Any, Optional, Union, Callable, cast, TYPE_CHECKING, Type, get_args
|
3
|
+
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect, weakref
|
4
4
|
from enum import auto, IntEnum, Enum
|
5
5
|
from dataclasses import dataclass, field
|
6
|
-
from
|
7
|
-
from
|
8
|
-
from tinygrad.
|
9
|
-
from tinygrad.helpers import ContextVar, prod, getenv, all_same, Context, partition, temp, unwrap, T
|
6
|
+
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate
|
7
|
+
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, _METADATA, flatten
|
8
|
+
from tinygrad.helpers import PICKLE_BUFFERS, SPLIT_REDUCEOP, DEBUG, dedup
|
10
9
|
if TYPE_CHECKING:
|
11
10
|
from tinygrad.shape.shapetracker import ShapeTracker
|
11
|
+
from tinygrad.device import Buffer
|
12
12
|
|
13
13
|
# wrapper around IntEnum that preserves Enum.__str__ and makes auto() unique across all FastEnum subclasses
|
14
14
|
class FastEnum(IntEnum):
|
@@ -26,8 +26,7 @@ class SimpleMathTrait:
|
|
26
26
|
def _binop(self, op, x, reverse): return self.ufix(x).alu(op, self) if reverse else self.alu(op, self.ufix(x))
|
27
27
|
def logical_not(self): return self.ne(True)
|
28
28
|
def neg(self):
|
29
|
-
dtype
|
30
|
-
assert dtype is not None, "MathTraits __neg__ requires a dtype"
|
29
|
+
if (dtype:=getattr(self, 'dtype')) is None: raise TypeError(f"MathTraits __neg__ requires a dtype, {self=}")
|
31
30
|
return self.logical_not() if dtype.scalar() == dtypes.bool else self*(-1)
|
32
31
|
def add(self, x, reverse=False): return self._binop(Ops.ADD, x, reverse)
|
33
32
|
def mul(self, x, reverse=False): return self._binop(Ops.MUL, x, reverse)
|
@@ -35,6 +34,7 @@ class SimpleMathTrait:
|
|
35
34
|
def bitwise_or(self, x, reverse=False): return self._binop(Ops.OR, x, reverse)
|
36
35
|
def xor(self, x, reverse=False): return self._binop(Ops.XOR, x, reverse)
|
37
36
|
def idiv(self, x, reverse=False): return self._binop(Ops.IDIV, x, reverse)
|
37
|
+
def mod(self, x, reverse=False): return self._binop(Ops.MOD, x, reverse)
|
38
38
|
def sub(self, x, reverse=False): return self.ufix(x).alu(Ops.ADD, -self) if reverse else self.alu(Ops.ADD, self.ufix(-x))
|
39
39
|
def div(self, x, reverse=False): return (self.ufix(x)*self.alu(Ops.RECIP)) if reverse else (self*self.ufix(x).alu(Ops.RECIP))
|
40
40
|
|
@@ -44,7 +44,8 @@ class SimpleMathTrait:
|
|
44
44
|
def __sub__(self, x): return self.sub(x)
|
45
45
|
def __mul__(self, x): return self.mul(x)
|
46
46
|
def __truediv__(self, x): return self.div(x)
|
47
|
-
def __floordiv__(self, x): return self.idiv(x)
|
47
|
+
def __floordiv__(self, x): return self.idiv(x) # TODO: idiv is trunc div, not floordiv
|
48
|
+
def __mod__(self, x): return self.mod(x)
|
48
49
|
def __and__(self, x): return self.bitwise_and(x)
|
49
50
|
def __or__(self, x): return self.bitwise_or(x)
|
50
51
|
def __xor__(self, x): return self.xor(x)
|
@@ -57,22 +58,19 @@ class SimpleMathTrait:
|
|
57
58
|
def __rand__(self, x): return self.bitwise_and(x, True)
|
58
59
|
def __ror__(self, x): return self.bitwise_or(x, True)
|
59
60
|
def __rxor__(self, x): return self.xor(x, True)
|
61
|
+
def __rmod__(self, x): return self.mod(x, True)
|
62
|
+
|
63
|
+
def __lt__(self, x): return self.alu(Ops.CMPLT, self.ufix(x))
|
64
|
+
def __gt__(self, x): return self.ufix(x).alu(Ops.CMPLT, self)
|
65
|
+
def __ge__(self, x): return (self < x).logical_not()
|
66
|
+
def __le__(self, x): return (self > x).logical_not()
|
60
67
|
|
61
|
-
def lt(self, x): return self.alu(Ops.CMPLT, self.ufix(x))
|
62
|
-
def gt(self, x): return self.ufix(x).alu(Ops.CMPLT, self)
|
63
68
|
def ne(self, x): return self.alu(Ops.CMPNE, self.ufix(x))
|
64
|
-
def ge(self, x): return self.lt(x).logical_not()
|
65
|
-
def le(self, x): return self.gt(x).logical_not()
|
66
69
|
def eq(self, x): return self.ne(x).logical_not()
|
67
|
-
|
68
|
-
def __lt__(self, x): return self.lt(x)
|
69
|
-
def __gt__(self, x): return self.gt(x)
|
70
70
|
def __ne__(self, x): return self.ne(x)
|
71
|
-
def __ge__(self, x): return self.ge(x)
|
72
|
-
def __le__(self, x): return self.le(x)
|
73
71
|
# NOTE: __eq__ isn't overridden, and means the same thing as is by default
|
74
72
|
|
75
|
-
class MathTrait(SimpleMathTrait):
|
73
|
+
class MathTrait(SimpleMathTrait):
|
76
74
|
# TODO: move to Tensor when new backward is done
|
77
75
|
def lshift(self, x, reverse=False): return self._binop(Ops.SHL, x, reverse)
|
78
76
|
def rshift(self, x, reverse=False): return self._binop(Ops.SHR, x, reverse)
|
@@ -81,10 +79,6 @@ class MathTrait(SimpleMathTrait): # pylint: disable=abstract-method
|
|
81
79
|
def __rlshift__(self, x): return self.lshift(x, True)
|
82
80
|
def __rrshift__(self, x): return self.rshift(x, True)
|
83
81
|
|
84
|
-
# not in Tensor
|
85
|
-
def __mod__(self, x): return self.alu(Ops.MOD, self.ufix(x))
|
86
|
-
def __rmod__(self, x): return self.ufix(x).alu(Ops.MOD, self)
|
87
|
-
|
88
82
|
def maximum(self, x): return self.alu(Ops.MAX, self.ufix(x))
|
89
83
|
def minimum(self, x): return -(-self).maximum(-x)
|
90
84
|
def where(self, x, y): return self.alu(Ops.WHERE, x, x.ufix(y))
|
@@ -94,118 +88,126 @@ class MathTrait(SimpleMathTrait): # pylint: disable=abstract-method
|
|
94
88
|
def sin(self): return self.alu(Ops.SIN)
|
95
89
|
def log2(self): return self.alu(Ops.LOG2)
|
96
90
|
def exp2(self): return self.alu(Ops.EXP2)
|
91
|
+
def pow(self, x): return self.alu(Ops.POW, self.ufix(x))
|
97
92
|
|
98
93
|
# the order of these Ops controls the order of the toposort
|
99
94
|
class Ops(FastEnum):
|
100
95
|
# uops that aren't rendered
|
101
|
-
SINK = auto()
|
102
|
-
CONTIGUOUS = auto()
|
103
|
-
PRELOAD = auto()
|
96
|
+
NAME = auto(); SINK = auto(); CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); KERNEL = auto(); UNIQUE = auto() # noqa: E702
|
104
97
|
|
105
|
-
#
|
106
|
-
COPY = auto()
|
98
|
+
# TODO: empty continues to exist because of tensor
|
107
99
|
EMPTY = auto()
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
100
|
+
|
101
|
+
# MetaOps
|
102
|
+
COPY = auto(); BUFFER_VIEW = auto() # noqa: E702
|
103
|
+
|
104
|
+
# blocks in linearizer
|
105
|
+
BLOCK = auto(); BLOCKSTART = auto(); BLOCKFORK = auto(); BLOCKEND = auto() # noqa: E702
|
106
|
+
|
107
|
+
# movement ops!
|
108
|
+
RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto() # noqa: E702
|
109
|
+
|
110
|
+
# misc ops
|
111
|
+
UNROLL = auto(); CONTRACT = auto() # noqa: E702
|
112
|
+
VIEW = auto(); DEFINE_GLOBAL = auto(); BUFFER = auto() # noqa: E702
|
113
|
+
DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # noqa: E702
|
114
|
+
VALID = auto(); SPECIAL = auto(); NOOP = auto() # noqa: E702
|
121
115
|
|
122
116
|
# reduce
|
123
117
|
REDUCE_AXIS = auto()
|
124
118
|
|
125
119
|
# helper ops
|
126
|
-
GEP = auto()
|
127
|
-
VECTORIZE = auto()
|
120
|
+
GEP = auto(); VECTORIZE = auto(); CAT = auto() # noqa: E702
|
128
121
|
|
129
122
|
# UnaryOps
|
130
123
|
CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702
|
131
124
|
|
132
|
-
#
|
133
|
-
LOAD = auto()
|
125
|
+
# load/store before math
|
126
|
+
LOAD = auto(); STORE = auto() # noqa: E702
|
127
|
+
|
128
|
+
# early INDEX
|
129
|
+
INDEX = auto()
|
134
130
|
|
135
131
|
# math ops
|
136
132
|
WMMA = auto()
|
137
133
|
|
138
134
|
# BinaryOps
|
139
135
|
ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
|
140
|
-
SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto() # noqa: E702
|
136
|
+
SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto(); POW = auto() # noqa: E702
|
141
137
|
|
142
138
|
# TernaryOps
|
143
139
|
WHERE = auto(); MULACC = auto() # noqa: E702
|
144
140
|
|
145
141
|
# assignment ops
|
146
|
-
STORE = auto()
|
147
142
|
ASSIGN = auto()
|
148
143
|
BIND = auto()
|
149
144
|
|
150
|
-
# late INDEX
|
151
|
-
INDEX = auto()
|
152
|
-
|
153
145
|
# control flow ops
|
154
|
-
BARRIER = auto()
|
155
|
-
IF = auto()
|
156
|
-
RANGE = auto()
|
157
|
-
|
158
|
-
# ops that are not graph nodes
|
159
|
-
ENDRANGE = auto()
|
160
|
-
ENDIF = auto()
|
146
|
+
BARRIER = auto(); RANGE = auto(); IF = auto(); ENDRANGE = auto(); ENDIF = auto() # noqa: E702
|
161
147
|
|
162
148
|
# consts last!
|
163
|
-
VCONST = auto()
|
164
|
-
|
149
|
+
VCONST = auto(); CONST = auto() # noqa: E702
|
150
|
+
|
151
|
+
# device
|
152
|
+
DEVICE = auto()
|
153
|
+
MULTI = auto()
|
154
|
+
CUSTOM = auto()
|
165
155
|
|
166
156
|
class GroupOp:
|
167
157
|
Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG}
|
168
158
|
Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY,
|
169
|
-
Ops.SUB, Ops.FDIV}
|
159
|
+
Ops.SUB, Ops.FDIV, Ops.POW}
|
170
160
|
Ternary = {Ops.WHERE, Ops.MULACC}
|
171
161
|
ALU = set.union(Unary, Binary, Ternary)
|
172
162
|
|
173
163
|
Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}
|
164
|
+
Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP}
|
174
165
|
|
175
|
-
|
176
|
-
|
177
|
-
Buffer = {Ops.LOAD, Ops.PRELOAD, Ops.STORE, Ops.VALID}
|
166
|
+
Buffer = {Ops.LOAD, Ops.STORE, Ops.VALID, Ops.CONST, Ops.DEFINE_VAR}
|
167
|
+
Block = {Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKFORK, Ops.BLOCKSTART}
|
178
168
|
|
179
169
|
# BinaryOps that can be flipped
|
180
170
|
Commutative = {Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPNE, Ops.XOR, Ops.AND, Ops.OR}
|
181
171
|
|
172
|
+
# BinaryOps where f(f(a,b),c) = f(a,f(b,c))
|
173
|
+
Associative = {Ops.ADD, Ops.MUL, Ops.AND, Ops.OR, Ops.MAX}
|
174
|
+
|
175
|
+
# BinaryOps that satisfy f(x,x)=x see https://en.wikipedia.org/wiki/Idempotence
|
176
|
+
Idempotent = {Ops.OR, Ops.AND, Ops.MAX}
|
177
|
+
|
182
178
|
# do not preserve f(0) = 0
|
183
|
-
UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV}
|
179
|
+
UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV, Ops.POW}
|
184
180
|
|
185
|
-
|
186
|
-
def identity_element(op:Ops, dt:DType): return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
|
181
|
+
All = set(Ops)
|
187
182
|
|
188
|
-
|
183
|
+
# some BUFFER ops can be processed with only a view
|
184
|
+
view_supported_devices = {"LLVM", "CPU", "CUDA", "NV", "AMD", "METAL", "QCOM", "DSP", "DISK"}
|
189
185
|
|
190
|
-
|
186
|
+
# https://en.wikipedia.org/wiki/Identity_element
|
187
|
+
def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
|
188
|
+
|
189
|
+
def can_pad(u:UOp, edges:dict[UOp, None], cache:dict[UOp, None]) -> bool:
|
190
|
+
if u.op in GroupOp.UnsafePad: return False
|
191
|
+
if u in edges or u in cache: return True
|
192
|
+
cache[u] = None
|
193
|
+
return all(can_pad(x.base, edges, cache) for x in u.src)
|
191
194
|
|
192
195
|
# With True as the default, this matches the old symbolic behavior
|
193
|
-
def resolve(x, default:bool=True):
|
194
|
-
if
|
195
|
-
assert x.dtype
|
196
|
+
def resolve(x:UOp|bool, default:bool=True):
|
197
|
+
if isinstance(x, bool): return x
|
198
|
+
assert x.dtype == dtypes.bool, "UOp in resolve must be bool"
|
196
199
|
# NOTE: generating the text for the exception is expensive, so we do this
|
197
200
|
return bool(sx.vmin) if (sx:=x.simplify()).vmin == sx.vmax else default
|
198
201
|
|
199
202
|
# smax/smin are replacements for max/min that preserve symbolic
|
200
203
|
def _suop(lst, uop_fxn, python_fxn):
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
def
|
205
|
-
def smin(*lst): return _suop(lst[0] if isinstance(lst[0], (tuple, list)) else lst, UOp.minimum, min)
|
204
|
+
uops, nums = partition(lst, lambda x: isinstance(x, UOp))
|
205
|
+
return ssimplify(functools.reduce(uop_fxn, uops + ([python_fxn(nums)] if nums else [])))
|
206
|
+
def smax(*lst): return _suop(argfix(*lst), UOp.maximum, max)
|
207
|
+
def smin(*lst): return _suop(argfix(*lst), UOp.minimum, min)
|
206
208
|
|
207
209
|
def ssimplify(uop): return uop.ssimplify() if isinstance(uop, UOp) else uop
|
208
|
-
def sym_infer(uop: Union[UOp, int], var_vals:
|
210
|
+
def sym_infer(uop: Union[UOp, int], var_vals: dict[UOp, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop
|
209
211
|
|
210
212
|
# used for UOp and UPat
|
211
213
|
def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->str:
|
@@ -219,57 +221,108 @@ def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->s
|
|
219
221
|
return f"{' '*d}{f'x{cx[0]}:=' * (cx[1]>1)}{rep(x)}" % srcs
|
220
222
|
|
221
223
|
class UOpMetaClass(type):
|
222
|
-
ucache:
|
223
|
-
def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:
|
224
|
-
if (
|
225
|
-
UOpMetaClass.ucache[key] =
|
226
|
-
|
227
|
-
|
224
|
+
ucache:dict[tuple, weakref.ReferenceType[UOp]] = {}
|
225
|
+
def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:tuple[UOp,...]=tuple(), arg:Any=None, _buffer:Buffer|None=None):
|
226
|
+
if (wret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None and (ret:=wret()) is not None: return ret
|
227
|
+
UOpMetaClass.ucache[key] = ref = weakref.ref(created:=super().__call__(*key))
|
228
|
+
for s in src: s.children.add(ref)
|
229
|
+
# NOTE: this will soon be set by Tensor once we remove function.py
|
230
|
+
if (metadata:=_METADATA.get()) is not None: all_metadata[created] = metadata
|
231
|
+
# NOTE: this value is set by pickle when pickling a realized tensor
|
232
|
+
if _buffer is not None:
|
233
|
+
assert op is Ops.BUFFER, f"trying to set Buffer {_buffer} for {op}"
|
234
|
+
buffers[created] = _buffer
|
235
|
+
return created
|
236
|
+
|
237
|
+
# some uops map to other stuff
|
238
|
+
buffers:weakref.WeakKeyDictionary[UOp, Buffer] = weakref.WeakKeyDictionary() # this maps BUFFER uops to their device Buffers
|
239
|
+
all_metadata:weakref.WeakKeyDictionary[UOp, Metadata] = weakref.WeakKeyDictionary()
|
240
|
+
|
241
|
+
# NOTE: this should be frozen, but frozen is slower
|
242
|
+
@dataclass(eq=False, slots=True)
|
228
243
|
class UOp(MathTrait, metaclass=UOpMetaClass):
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
244
|
+
op:Ops
|
245
|
+
dtype:DType = dtypes.void
|
246
|
+
src:tuple[UOp, ...] = tuple()
|
247
|
+
arg:Any = None
|
248
|
+
children:set[weakref.ref[UOp]] = field(default_factory=set)
|
249
|
+
def __del__(self):
|
250
|
+
if self.op is Ops.BUFFER and (buffer:=buffers.get(self)) is not None: buffer.ref(-1)
|
251
|
+
if (ref:=UOpMetaClass.ucache.get(k:=(self.op, self.dtype, self.src, self.arg))) is not None:
|
252
|
+
for s in self.src: s.children.discard(ref)
|
253
|
+
del UOpMetaClass.ucache[k]
|
254
|
+
def __reduce__(self):
|
255
|
+
args = [self.op, self.dtype, self.src, self.arg]
|
256
|
+
if self.op is Ops.BUFFER and self.realized is not None and PICKLE_BUFFERS: args.append(self.realized)
|
257
|
+
return UOp, tuple(args)
|
234
258
|
def replace(self, **kwargs) -> UOp:
|
235
|
-
|
236
|
-
|
259
|
+
new_args = (kwargs.pop("op", self.op), kwargs.pop("dtype", self.dtype), kwargs.pop("src", self.src), kwargs.pop("arg", self.arg))
|
260
|
+
assert len(kwargs) == 0, f"unused kwargs in replace {list(kwargs)}"
|
237
261
|
if (self.op, self.dtype, self.src, self.arg) == new_args: return self
|
238
262
|
return UOp(*new_args)
|
239
263
|
@functools.cached_property
|
240
264
|
def key(self) -> bytes:
|
241
265
|
return hashlib.sha256(str((self.op, self.dtype, self.arg)).encode() + b"".join([s.key for s in self.src])).digest()
|
242
266
|
def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))")
|
243
|
-
def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is Ops.REDUCE_AXIS else self.arg
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
267
|
+
def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is Ops.REDUCE_AXIS else repr(self.arg)
|
268
|
+
|
269
|
+
@property
|
270
|
+
def toposort(self) -> dict[UOp, None]:
|
271
|
+
def _toposort(u:UOp, cache:set[UOp]):
|
272
|
+
if u in cache: return {}
|
273
|
+
nodes: dict[UOp, None] = {}
|
274
|
+
# NOTE: this is a lot faster than the comprehension in parents
|
275
|
+
for parent in u.src: nodes.update(_toposort(parent, cache))
|
276
|
+
nodes[u] = None
|
277
|
+
cache.add(u)
|
278
|
+
return nodes
|
279
|
+
return _toposort(self, cache=set())
|
248
280
|
|
249
281
|
@functools.cached_property
|
250
|
-
def tuplize(self:UOp) ->
|
251
|
-
return (self.op.value, self.arg, self.dtype, tuple(x.tuplize for x in self.src))
|
282
|
+
def tuplize(self:UOp) -> tuple[int, Any, Optional[DType], tuple]: return (self.op.value, self.arg, self.dtype, tuple(x.tuplize for x in self.src))
|
252
283
|
|
253
284
|
# *** uop shape stuff ***
|
254
285
|
|
255
|
-
@property
|
256
|
-
def has_st(self) -> bool: return self.op not in {Ops.DEFINE_LOCAL, Ops.DEFINE_GLOBAL, Ops.BUFFER, Ops.CONST, Ops.DEFINE_VAR}
|
257
286
|
@functools.cached_property
|
258
|
-
def st(self) ->
|
259
|
-
if not self.has_st: return None
|
260
|
-
if self.op in GroupOp.Buffer: return self.st_arg
|
261
|
-
if self.op is Ops.VIEW: return self.arg
|
262
|
-
src_sts = [x.st for x in self.src if x.st is not None]
|
263
|
-
assert all_same([x.shape for x in src_sts]), f"UOp parents must have the same shape {self} {[x.shape for x in src_sts]}"
|
287
|
+
def st(self) -> ShapeTracker|None:
|
264
288
|
from tinygrad.shape.shapetracker import ShapeTracker
|
265
|
-
|
289
|
+
if self.op is Ops.MULTI:
|
290
|
+
return ShapeTracker.from_shape(
|
291
|
+
tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape)))
|
292
|
+
if self.op is Ops.BUFFER: return ShapeTracker.from_shape((self.size,))
|
293
|
+
if self.op is Ops.KERNEL: return ShapeTracker.from_shape(self.arg.ast.shape)
|
294
|
+
# these ops define a ShapeTracker from the arg
|
295
|
+
if self.op is Ops.VIEW: return self.arg
|
296
|
+
if self.op in GroupOp.Movement: return unwrap(self.src[0].st).mop(self.op, self.arg)
|
297
|
+
# buffer ops return the ShapeTracker from sources
|
298
|
+
if self.op in GroupOp.Buffer: return vsrc[0] if len(vsrc:=[x.st for x in self.src if x.op is Ops.VIEW]) != 0 else None
|
299
|
+
if not (src_sts := [x.st for x in self.src if x.st is not None]): return None
|
300
|
+
assert all_same([x.shape for x in src_sts]), f"UOp sources must have the same shape {self} {[x.shape for x in src_sts]}"
|
301
|
+
if self.op in {Ops.BITCAST, Ops.BUFFER_VIEW}:
|
302
|
+
shape = src_sts[0].shape
|
303
|
+
if self.dtype.itemsize != (input_sz:=self.src[0].dtype.itemsize): shape = shape[:-1]+((shape[-1]*input_sz) // self.dtype.itemsize,)
|
304
|
+
# only reduce ops are allowed to change shape, everything else derives shape from sources
|
305
|
+
elif self.op in {Ops.REDUCE_AXIS, Ops.WMMA}: shape = src_sts[0].reduce(self.axis_arg)
|
306
|
+
else: shape = src_sts[0].shape
|
307
|
+
return ShapeTracker.from_shape(shape)
|
308
|
+
|
266
309
|
@functools.cached_property
|
267
|
-
def full_shape(self) ->
|
268
|
-
|
310
|
+
def full_shape(self) -> tuple[sint, ...]:
|
311
|
+
if self.op is Ops.VIEW: return self.shape
|
312
|
+
# TODO: this should check if st is None, it cannot because local reduce has implicit movement ops
|
313
|
+
return tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {Ops.DEFINE_GLOBAL,Ops.DEFINE_LOCAL} \
|
314
|
+
# TODO: this exists because wmma creates consts without ShapeTracker in the AST, there's probably a way to fix this
|
315
|
+
and not (x.op is Ops.CONST and x.st is None)]))
|
316
|
+
@property
|
317
|
+
def shape(self) -> tuple[sint, ...]: return unwrap(self.st).shape
|
318
|
+
@property
|
319
|
+
def size(self) -> int: return self.arg if self.op is Ops.BUFFER else unwrap(self.st).size
|
269
320
|
|
270
321
|
# *** uop evaluation ***
|
271
322
|
|
272
323
|
def simplify(self):
|
324
|
+
# late import!
|
325
|
+
from tinygrad.codegen.symbolic import symbolic
|
273
326
|
with Context(TRACK_MATCH_STATS=0):
|
274
327
|
return graph_rewrite(self, symbolic)
|
275
328
|
def ssimplify(self) -> Union[UOp, ConstType]: return ret.arg if (ret:=self.simplify()).op is Ops.CONST else ret
|
@@ -282,34 +335,37 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|
282
335
|
def __bool__(self): return self._eval((dtypes.bool,), bool)
|
283
336
|
def __int__(self): return self._eval(dtypes.ints, int)
|
284
337
|
def __float__(self): return self._eval(dtypes.floats, float)
|
285
|
-
def substitute(self, dvars:
|
338
|
+
def substitute(self, dvars:dict[UOp, UOp]):
|
286
339
|
with Context(TRACK_MATCH_STATS=0):
|
287
|
-
return graph_rewrite(self, _substitute, dvars)
|
340
|
+
return graph_rewrite(self, _substitute, dvars, bottom_up=True)
|
288
341
|
|
289
342
|
# *** uop syntactic sugar ***
|
290
343
|
|
291
344
|
@property
|
292
345
|
def st_arg(self) -> ShapeTracker:
|
293
346
|
assert self.op in GroupOp.Buffer, f"st_arg called on {self.op}"
|
294
|
-
|
295
|
-
assert ret.op is Ops.VIEW, f"st_arg trying to return {ret}"
|
296
|
-
return ret.arg
|
347
|
+
return unwrap(self.st)
|
297
348
|
@property
|
298
|
-
def axis_arg(self) ->
|
349
|
+
def axis_arg(self) -> tuple[int, ...]:
|
299
350
|
assert self.op in {Ops.REDUCE_AXIS, Ops.WMMA}, f"axis_arg called on {self.op}"
|
300
351
|
ret = self.arg[1] if self.op is Ops.REDUCE_AXIS else self.arg[7]
|
301
352
|
assert isinstance(ret, tuple) and all(isinstance(x, int) for x in ret), f"axis_arg trying to return {ret}"
|
302
353
|
return ret
|
303
354
|
def sink(self, *srcs:UOp): return UOp(Ops.SINK, dtypes.void, (self,)+srcs)
|
304
|
-
def
|
305
|
-
def
|
355
|
+
def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
|
356
|
+
def index(self, idx:UOp, valid:UOp|None=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
|
357
|
+
def const_like(self, b:ConstLike):
|
358
|
+
# constants can optionally have a DEVICE source
|
359
|
+
if self._device is None: return UOp.const(self.dtype, b)
|
360
|
+
if isinstance(self.device, tuple): return UOp.multi(*[UOp.metaop(Ops.CONST, self.shape, self.dtype, d, b) for d in self.device], axis=None)
|
361
|
+
return UOp.metaop(Ops.CONST, self.shape, self.dtype, self.device, b)
|
306
362
|
def broadcast(self, count:int):
|
307
363
|
assert self.dtype.count == 1
|
308
364
|
if count == 1: return self
|
309
365
|
return UOp(Ops.VECTORIZE, self.dtype.vec(count), (self,)*count)
|
310
|
-
def cast(self, dtype:DType): return UOp(Ops.CAST, dtype, (self,))
|
311
|
-
def bitcast(self, dtype:DType): return UOp(Ops.BITCAST, dtype, (self,))
|
312
|
-
def gep(self, i:Union[
|
366
|
+
def cast(self, dtype:DType): return self if self.dtype == dtype else UOp(Ops.CAST, dtype, (self,))
|
367
|
+
def bitcast(self, dtype:DType): return self if self.dtype == dtype else UOp(Ops.BITCAST, dtype, (self,))
|
368
|
+
def gep(self, i:Union[tuple[int, ...], int]):
|
313
369
|
if isinstance(i, int):
|
314
370
|
# NOTE: these are just shortcuts to not have to create and fold later
|
315
371
|
if self.op is Ops.VECTORIZE: return self.src[i]
|
@@ -322,42 +378,192 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|
322
378
|
def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
|
323
379
|
def alu(self, arg, *src:UOp):
|
324
380
|
out_dtype = (self, *src)[-1].dtype
|
325
|
-
if arg in {Ops.CMPLT, Ops.CMPNE}
|
326
|
-
out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
|
381
|
+
if arg in {Ops.CMPLT, Ops.CMPNE}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
|
327
382
|
return UOp(arg, out_dtype, (self,)+src)
|
328
383
|
@staticmethod
|
329
384
|
def const(dtype:DType, b:ConstLike):
|
330
385
|
if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
|
331
386
|
if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
|
332
|
-
return UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype)
|
387
|
+
return UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype))
|
388
|
+
def valid(self, st:ShapeTracker):
|
389
|
+
assert self.op in {Ops.CONST, Ops.DEFINE_VAR}, f"can only create VALID from a constant, got {self.op}"
|
390
|
+
from tinygrad.shape.shapetracker import ShapeTracker
|
391
|
+
# NOTE: only VALID has a masked ShapeTracker, the CONST operands are unmasked
|
392
|
+
unmasked_st = ShapeTracker.from_shape(()).reshape((1,)*len(st.shape)).expand(st.shape).to_uop()
|
393
|
+
return UOp(Ops.VALID, dtypes.bool, (st.to_uop(),)).where(self.replace(src=(unmasked_st,)), UOp.const(self.dtype, 0).replace(src=(unmasked_st,)))
|
333
394
|
@staticmethod
|
334
|
-
def range(dtype:DType, start:
|
335
|
-
|
336
|
-
|
337
|
-
|
395
|
+
def range(dtype:DType, start:sint, end:sint, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(start), sint_to_uop(end)), arg=idx)
|
396
|
+
def _reduce_op(self, op:Ops, axis:tuple[int, ...]):
|
397
|
+
axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
|
398
|
+
return self if len(axis) == 0 else UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis))
|
399
|
+
def r(self, op:Ops, axis:tuple[int, ...]) -> UOp:
|
400
|
+
new_shape = unwrap(self.st).reduce(axis)
|
401
|
+
|
402
|
+
# TODO: can we split symbolic shape if the reduce axis is not symbolic?
|
403
|
+
# TODO: this shouldn't be here, it belongs in scheduler! that's why it broke multi
|
404
|
+
if not SPLIT_REDUCEOP or isinstance(self._device, tuple) or not all_int(self.shape) or (0 in self.shape) or \
|
405
|
+
prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
|
406
|
+
return self._reduce_op(op, axis)
|
407
|
+
|
408
|
+
# if there are few globals, make some reduces into globals by splitting into two kernels
|
409
|
+
# cap output buffer to 2**22: heuristic number of global outputs to achieve max occupancy with enough locals+upcasts for gemm
|
410
|
+
# ~2**10 should be enough if GROUP is used
|
411
|
+
# 256 split maximum should be "negligible reduce" for low prod(new_shape), 8 split minimum.
|
412
|
+
# split is moved to the end to provide maximum locality for the second phase reduce.
|
413
|
+
self_real_strides = unwrap(self.st).real_strides(ignore_valid=True)
|
414
|
+
split_candidates = [(i, x) for i in axis for x in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(new_shape)),8-1,-1)
|
415
|
+
if self.shape[i] % x == 0 and self_real_strides[i] != 0]
|
416
|
+
if not split_candidates: return self._reduce_op(op, axis)
|
417
|
+
dim_to_split, divisor = split_candidates[0]
|
418
|
+
splitted_shape = self.shape[:dim_to_split] + (divisor,) + (self.shape[dim_to_split]//divisor,) + self.shape[dim_to_split+1:]
|
419
|
+
splitted = self.reshape(splitted_shape).permute(tuple([x for x in range(len(splitted_shape)) if x != dim_to_split]+[dim_to_split]))
|
420
|
+
if DEBUG >= 3: print(f"split {divisor}: {self.shape} -> {splitted.shape} -> {new_shape}")
|
421
|
+
return splitted._reduce_op(op, axis)._reduce_op(op, (len(new_shape),)).reshape(new_shape) # reduce original axes, then split
|
338
422
|
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
|
339
|
-
def contiguous(self): return
|
423
|
+
def contiguous(self): return self.alu(Ops.CONTIGUOUS)
|
424
|
+
def contiguous_backward(self): return self.alu(Ops.CONTIGUOUS_BACKWARD)
|
425
|
+
|
426
|
+
# *** from MultiLazyBuffer ***
|
427
|
+
|
428
|
+
def multi(self, *more:UOp, axis:int|None, real:tuple[bool,...]|None=None):
|
429
|
+
parents = (self,)+more
|
430
|
+
assert all_same([x.dtype for x in parents]), "multi parents must have the same dtype"
|
431
|
+
return UOp(Ops.MULTI, self.dtype, parents, (axis, real if real is not None else (True,)*len(parents)))
|
432
|
+
|
433
|
+
@property
|
434
|
+
def bounds(self):
|
435
|
+
if self.axis is None: raise RuntimeError("bounds is not defined when axis is None")
|
436
|
+
return tuple(itertools.pairwise(itertools.accumulate([lb.shape[self.axis] for lb in self.src], initial=0)))
|
437
|
+
|
438
|
+
@functools.cached_property
|
439
|
+
def axis(self) -> Optional[int]:
|
440
|
+
if self.op is Ops.MULTI: return self.arg[0]
|
441
|
+
# NOTE: they all have to share an axis, we always choose [-1]
|
442
|
+
if self.op in GroupOp.ALU: return axes[-1] if (axes := dedup([x.axis for x in self.src if x.axis is not None])) else None
|
443
|
+
src_axis = self.src[0].axis
|
444
|
+
if self.op is Ops.REDUCE_AXIS: return None if src_axis is not None and src_axis in self.arg[1] else src_axis
|
445
|
+
if self.op is Ops.RESHAPE:
|
446
|
+
if src_axis is None: return None
|
447
|
+
arg_acc:list[sint] = list(itertools.accumulate(self.arg, operator.mul, initial=1))
|
448
|
+
# new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards
|
449
|
+
# TODO: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1?
|
450
|
+
return len(arg_acc) - arg_acc[::-1].index(prod(self.src[0].shape[:src_axis])) - 1
|
451
|
+
if self.op is Ops.PERMUTE: return self.arg.index(src_axis) if src_axis is not None else None
|
452
|
+
return src_axis
|
453
|
+
|
454
|
+
@property
|
455
|
+
def real(self):
|
456
|
+
assert self.op is Ops.MULTI
|
457
|
+
return self.arg[1]
|
458
|
+
|
340
459
|
@property
|
341
|
-
def
|
460
|
+
def real_lbs(self): return [lb for lb,r in zip(self.src, self.real) if r]
|
461
|
+
|
462
|
+
def shard(self, devices:tuple[str, ...], axis:Optional[int]=None) -> UOp:
|
463
|
+
if axis is None: lbs = [self] * len(devices)
|
464
|
+
else:
|
465
|
+
if self.shape[axis] % len(devices) != 0: raise RuntimeError(f"multi axis uneven: {self.shape[axis]=} {axis=} {len(devices)=}")
|
466
|
+
# NOTE: this works for both even shards and uneven shards
|
467
|
+
sz = self.shape[axis] // len(devices)
|
468
|
+
sizes = [max(0, min(sz, self.shape[axis] - sz*i)) for i in range(len(devices))]
|
469
|
+
lbs = []
|
470
|
+
for sz,off in zip(sizes, itertools.accumulate(sizes, initial=0)):
|
471
|
+
lbs.append(self.shrink(tuple((0,s) if i != axis else (off,off+sz) for i,s in enumerate(self.shape))))
|
472
|
+
sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(lbs, devices)]
|
473
|
+
return UOp.multi(*[lb.contiguous() for lb in sharded_lbs], axis=axis)
|
474
|
+
|
475
|
+
# *** from LazyBuffer ***
|
476
|
+
|
477
|
+
@staticmethod
|
478
|
+
def metaop(op:Ops, shape:tuple[sint, ...], dtype:DType, device:str, arg=None) -> UOp:
|
479
|
+
from tinygrad.shape.shapetracker import ShapeTracker
|
480
|
+
# Tensor const is CONST(VIEW(DEVICE)) -> RESHAPE -> EXPAND
|
481
|
+
if op is Ops.CONST:
|
482
|
+
assert isinstance(arg, get_args(ConstType)), f"trying to create CONST with {arg=}"
|
483
|
+
return UOp.const(dtype, unwrap(arg)).replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),),
|
484
|
+
ShapeTracker.from_shape(())),)).reshape((1,)*len(shape)).expand(shape)
|
485
|
+
# Tensor variable binding is BIND(VAR(VIEW(DEVICE)), CONST(VIEW(DEVICE)))
|
486
|
+
if op is Ops.BIND:
|
487
|
+
var, val = arg.unbind()
|
488
|
+
return var.replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),), ShapeTracker.from_shape(shape)),)).bind(val)
|
489
|
+
# otherwise it's just a RESHAPE(BUFFER)
|
490
|
+
if not isinstance(size:=prod([x.vmax if isinstance(x, UOp) else x for x in shape]), int): raise ValueError(f"size must be int {size}")
|
491
|
+
return UOp.new_buffer(device, size, dtype).reshape(shape)
|
492
|
+
def copy_to_device(self, device:str|tuple[str, ...], clone:bool=False) -> UOp:
|
493
|
+
# if it's a shrink, do the shrink before the copy with CONTIGUOUS
|
494
|
+
if prod(self.shape) < prod(self.base.shape): return self.contiguous().copy_to_device(device)
|
495
|
+
# COPY is COPY(DEVICE, copyin.base) -> VIEW(copyin.st)
|
496
|
+
ret = UOp(Ops.COPY, self.base.dtype, (UOp(Ops.DEVICE, arg=device), self.base), clone)
|
497
|
+
op_arg = []
|
498
|
+
mop = self
|
499
|
+
while mop is not self.base:
|
500
|
+
op_arg.append((mop.op, mop.arg))
|
501
|
+
mop = mop.src[0]
|
502
|
+
for op,arg in reversed(op_arg): ret = UOp(op, ret.dtype, (ret,), arg)
|
503
|
+
return ret
|
504
|
+
def clone(self) -> UOp: return self.copy_to_device(self.device, clone=True)
|
505
|
+
@property
|
506
|
+
def metadata(self) -> tuple[Metadata, ...]|Metadata|None: return self.arg.metadata if self.op is Ops.KERNEL else all_metadata.get(self, None)
|
342
507
|
|
343
508
|
# *** uop movement ops ***
|
344
509
|
|
345
510
|
@property
|
346
|
-
def base(self)
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
511
|
+
def base(self) -> UOp:
|
512
|
+
if (self.op is Ops.VIEW and len(self.src) != 0) or self.op in GroupOp.Movement: return self.src[0].base
|
513
|
+
return self
|
514
|
+
def view(self, new_st:ShapeTracker) -> UOp: return UOp(Ops.VIEW, self.dtype, (self.base,), new_st)
|
515
|
+
|
516
|
+
def _mop(self, op:Ops, arg):
|
517
|
+
ret = UOp(op, self.dtype, (self,), arg)
|
518
|
+
if self.st == ret.st: return self # ignore NOOPs, also check ret.st
|
519
|
+
return ret
|
351
520
|
|
352
|
-
|
521
|
+
def reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg)
|
522
|
+
def pad(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.PAD, arg)
|
523
|
+
def expand(self, arg:tuple[sint, ...]): return self._mop(Ops.EXPAND, arg)
|
524
|
+
def permute(self, arg:tuple[sint, ...]): return self._mop(Ops.PERMUTE, arg)
|
525
|
+
def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg)
|
526
|
+
def flip(self, arg:tuple[bool, ...]): return self._mop(Ops.FLIP, arg)
|
527
|
+
|
528
|
+
# *** uop UNIQUE ***
|
353
529
|
|
530
|
+
# TODO: use this in Buffer
|
531
|
+
unique_num = itertools.count(0)
|
354
532
|
@staticmethod
|
355
|
-
def
|
533
|
+
def unique(): return UOp(Ops.UNIQUE, arg=next(UOp.unique_num))
|
356
534
|
|
535
|
+
# *** uop Buffer stuff ***
|
536
|
+
|
537
|
+
@staticmethod
|
538
|
+
def new_buffer(device:str, size:int, dtype:DType): return UOp(Ops.BUFFER, dtype, (UOp(Ops.DEVICE, arg=device), UOp.unique()), size)
|
539
|
+
@property
|
540
|
+
def device(self) -> str|tuple[str, ...]: return cast(str|tuple[str, ...], unwrap(self._device))
|
541
|
+
@functools.cached_property
|
542
|
+
def _device(self) -> Optional[str|tuple[str, ...]]:
|
543
|
+
if self.op is Ops.DEVICE: return self.arg
|
544
|
+
if self.op is Ops.MULTI: return tuple(cast(str, x.device) for x in self.src)
|
545
|
+
return dsrcs[0]._device if len(dsrcs:=[x for x in self.src if x._device is not None]) != 0 else None
|
357
546
|
@property
|
358
547
|
def buf_uop(self) -> UOp:
|
359
|
-
|
360
|
-
|
548
|
+
if self.base.op is Ops.BUFFER: return self.base
|
549
|
+
assert self.base.op in {*GroupOp.Buffer, Ops.ASSIGN}, f"buf_uop called on {self.op}"
|
550
|
+
return self.src[0].buf_uop
|
551
|
+
@property
|
552
|
+
def buffer(self) -> Buffer:
|
553
|
+
if self is not self.base:
|
554
|
+
assert unwrap(self.st).contiguous, "VIEW only works here if it's contiguous"
|
555
|
+
return self.src[0].buffer
|
556
|
+
assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}"
|
557
|
+
if (cret:=buffers.get(self)) is not None: return cret
|
558
|
+
from tinygrad.device import Buffer
|
559
|
+
assert isinstance(self.device, str), f"buffer not supported on multi {self.device}"
|
560
|
+
buffers[self] = ret = Buffer(self.device, self.size, self.dtype if isinstance(self.dtype, ImageDType) else self.dtype.base)
|
561
|
+
return ret
|
562
|
+
@property
|
563
|
+
def realized(self) -> Optional[Buffer]: return self.buffer if self.op is Ops.BUFFER else None
|
564
|
+
@property
|
565
|
+
def is_realized(self) -> bool:
|
566
|
+
return all(x.base.realized is not None for x in self.base.real_lbs) if self.base.op is Ops.MULTI else self.base.realized is not None
|
361
567
|
|
362
568
|
# *** uop Variable stuff ***
|
363
569
|
|
@@ -373,22 +579,28 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|
373
579
|
assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
|
374
580
|
assert self.arg[1] <= val and val <= self.arg[2], f"bind {val} not in range [{self.arg[1]}, {self.arg[2]}]"
|
375
581
|
return UOp(Ops.BIND, self.dtype, (self, self.const_like(val)))
|
376
|
-
def unbind(self) ->
|
582
|
+
def unbind(self) -> tuple[Variable, int]:
|
377
583
|
assert self.op is Ops.BIND and self.src[0].op is Ops.DEFINE_VAR and self.src[1].op is Ops.CONST, f"can't unbind {self}"
|
378
584
|
return self.src[0], self.src[1].arg
|
379
585
|
@property
|
380
586
|
def val(self) -> int: return self.unbind()[1]
|
381
|
-
def vars(self) ->
|
382
|
-
bound_vars = set([x for x in self.
|
587
|
+
def vars(self) -> set[UOp]:
|
588
|
+
bound_vars = set([x for x in self.toposort if x.op is Ops.BIND and x.src[0].op is Ops.DEFINE_VAR])
|
383
589
|
bound_var_base = set(x.src[0] for x in bound_vars)
|
384
|
-
all_vars = set([x for x in self.
|
590
|
+
all_vars = set([x for x in self.toposort if x.op is Ops.DEFINE_VAR])
|
385
591
|
return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
|
386
|
-
def variables(self) ->
|
387
|
-
st_vars:
|
592
|
+
def variables(self) -> list[Variable]:
|
593
|
+
st_vars: list[set[Variable]] = [x.st_arg.vars() for x in self.toposort if x.op in GroupOp.Buffer]
|
388
594
|
return sorted(set.union(*st_vars, [x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg)
|
389
595
|
|
390
596
|
# *** uop symbolic stuff ***
|
391
597
|
|
598
|
+
def is_increasing(self:UOp) -> bool:
|
599
|
+
# is f a monotonically increasing function regards its input
|
600
|
+
if self.op in GroupOp.Irreducible: return True
|
601
|
+
if self.op is Ops.ADD: return self.src[0].is_increasing() and self.src[1].is_increasing()
|
602
|
+
if self.op in (Ops.MUL, Ops.IDIV) and self.src[1].op is Ops.CONST and self.src[1].arg >= 0: return self.src[0].is_increasing()
|
603
|
+
return False # False if not sure
|
392
604
|
def const_factor(self) -> int:
|
393
605
|
"""largest known int that divides self"""
|
394
606
|
if self.op is Ops.CONST: return self.arg
|
@@ -396,7 +608,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|
396
608
|
if self.op is Ops.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
|
397
609
|
if self.op is Ops.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1
|
398
610
|
return 1
|
399
|
-
def divides(self, v) ->
|
611
|
+
def divides(self, v:int) -> UOp|None:
|
400
612
|
if v==1: return self
|
401
613
|
if self.op is Ops.CONST: return self.const_like(self.arg//v) if self.arg%v == 0 else None
|
402
614
|
if self.op is Ops.VCONST: return self.const_like(tuple(x//v for x in self.arg)) if all(x%v == 0 for x in self.arg) else None
|
@@ -410,15 +622,23 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|
410
622
|
@property
|
411
623
|
def vmax(self) -> ConstType: return self._min_max[1]
|
412
624
|
@functools.cached_property
|
413
|
-
def _min_max(self) ->
|
625
|
+
def _min_max(self) -> tuple[ConstType, ConstType]:
|
414
626
|
if self.op in GroupOp.Binary and not dtypes.is_float(self.dtype):
|
415
627
|
(s0_vmin, s0_vmax), (s1_vmin, s1_vmax) = self.src[0]._min_max, self.src[1]._min_max
|
416
628
|
if self.op is Ops.ADD: return s0_vmin+s1_vmin, s0_vmax+s1_vmax
|
417
629
|
if self.op is Ops.MUL: return min(vals:=(s0_vmin*s1_vmin, s0_vmin*s1_vmax, s0_vmax*s1_vmin, s0_vmax*s1_vmax)), max(vals)
|
418
|
-
|
419
|
-
if self.op is Ops.
|
420
|
-
|
421
|
-
|
630
|
+
# SHL/SHR on consts only
|
631
|
+
if self.op is Ops.SHL and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] << t[2], t[1] << t[2]
|
632
|
+
if self.op is Ops.SHR and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] >> t[2], t[1] >> t[2]
|
633
|
+
if self.op is Ops.MOD and s1_vmin > 0:
|
634
|
+
return (0, s1_vmax-1) if s0_vmin >= 0 else (-(s1_vmax-1), s1_vmax-1)
|
635
|
+
if self.op is Ops.IDIV:
|
636
|
+
if s1_vmin == s1_vmax: # min/max are equal in a CONST
|
637
|
+
if s1_vmin > 0: return s0_vmin//s1_vmin, s0_vmax//s1_vmin
|
638
|
+
if s1_vmin < 0 and s0_vmin >= 0: return -(s0_vmax//-s1_vmin), -(s0_vmin//-s1_vmin)
|
639
|
+
# don't know exact bounds, but know the sign
|
640
|
+
if (s0_vmax <= 0 and s1_vmin < 0) or (s0_vmin >= 0 and s1_vmin > 0): return 0, dtypes.max(self.dtype)
|
641
|
+
if (s0_vmax <= 0 and s1_vmin > 0) or (s0_vmin >= 0 and s1_vmin < 0): return dtypes.min(self.dtype), 0
|
422
642
|
if self.op is Ops.MAX: return max(s0_vmin, s1_vmin), max(s0_vmax, s1_vmax)
|
423
643
|
if self.op is Ops.CMPLT: return (s0_vmax<s1_vmin, s0_vmin<s1_vmax)
|
424
644
|
if self.op is Ops.CMPNE: return ((s0_vmax < s1_vmin) or (s1_vmax < s0_vmin), not (s0_vmin == s0_vmax == s1_vmin == s1_vmax))
|
@@ -431,9 +651,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|
431
651
|
if self.op is Ops.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2]
|
432
652
|
if self.op is Ops.RANGE: return self.src[0].vmin, (self.src[1]-1).vmax
|
433
653
|
if self.op is Ops.BIND: return self.src[0]._min_max # ignore the bound value
|
434
|
-
if self.op in {Ops.
|
435
|
-
# TODO:
|
436
|
-
if self.op is Ops.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else
|
654
|
+
if self.op in {Ops.UNROLL, Ops.VECTORIZE}: return min(x.vmin for x in self.src), max(x.vmax for x in self.src)
|
655
|
+
# TODO: Ops.SPECIAL is Ops.DEFINE_VAR
|
656
|
+
if self.op is Ops.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else self.arg[1].vmax
|
437
657
|
if self.op is Ops.CONST: return self.arg, self.arg
|
438
658
|
if self.op is Ops.VCONST: return (min(self.arg), max(self.arg))
|
439
659
|
return dtypes.min(self.dtype), dtypes.max(self.dtype)
|
@@ -441,11 +661,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|
441
661
|
@functools.cached_property
|
442
662
|
def _sym_fxn(self):
|
443
663
|
sself = self.simplify()
|
444
|
-
varnames = tuple(x.arg[0] for x in sself.
|
664
|
+
varnames = tuple(x.arg[0] for x in sself.toposort if x.op is Ops.DEFINE_VAR)
|
445
665
|
# TODO: sanitize varnames, or don't use naked eval while staying fast
|
446
666
|
return eval("lambda "+','.join(varnames)+": "+sself.render()), varnames # pylint: disable=eval-used
|
447
667
|
|
448
|
-
def sym_infer(self, var_vals:
|
668
|
+
def sym_infer(self, var_vals:dict[UOp, int]):
|
449
669
|
fxn, varnames = self._sym_fxn
|
450
670
|
return fxn(**{k.arg[0]:v for k,v in var_vals.items() if k.arg[0] in varnames})
|
451
671
|
|
@@ -455,26 +675,29 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|
455
675
|
|
456
676
|
@dataclass(frozen=True)
|
457
677
|
class KernelInfo:
|
678
|
+
name: str = "test" # name of the kernel
|
458
679
|
local_dims: int = 0 # number of local dimensions (this is remapping RANGE to SPECIAL)
|
459
|
-
upcasted: int = 0 # count that are upcasted (this is remapping RANGE to
|
680
|
+
upcasted: int = 0 # count that are upcasted (this is remapping RANGE to UNROLL)
|
460
681
|
dont_use_locals: bool = False # don't use local indexing
|
461
682
|
|
462
|
-
#
|
683
|
+
# ******** ops in python ********
|
684
|
+
|
685
|
+
def safe_exp2(x):
|
686
|
+
try: return 2 ** x
|
687
|
+
except OverflowError: return math.inf
|
463
688
|
|
464
|
-
def
|
465
|
-
|
466
|
-
|
467
|
-
|
468
|
-
return wfxn
|
689
|
+
def safe_pow(x, y):
|
690
|
+
try: return math.nan if isinstance(p:=pow(x, y), complex) else p
|
691
|
+
except ZeroDivisionError: return math.inf
|
692
|
+
except ValueError: return math.inf if x > 0 else -math.inf
|
469
693
|
|
470
|
-
python_alu:
|
471
|
-
Ops.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, Ops.EXP2:
|
694
|
+
python_alu: dict[Ops, Callable] = {
|
695
|
+
Ops.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, Ops.EXP2: safe_exp2,
|
472
696
|
Ops.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, Ops.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
|
473
|
-
Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan,
|
474
|
-
Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul,
|
475
|
-
Ops.
|
476
|
-
Ops.
|
477
|
-
Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift,
|
697
|
+
Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, Ops.POW: safe_pow,
|
698
|
+
Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt,
|
699
|
+
Ops.XOR: operator.xor, Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift, Ops.MAX: max,
|
700
|
+
Ops.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], Ops.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else 0,
|
478
701
|
Ops.MULACC: lambda x,y,z: (x*y)+z, Ops.WHERE: lambda x,y,z: y if x else z}
|
479
702
|
|
480
703
|
def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
|
@@ -485,63 +708,33 @@ def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
|
|
485
708
|
|
486
709
|
# ***** uop helpers *****
|
487
710
|
|
488
|
-
def print_uops(uops:
|
711
|
+
def print_uops(uops:list[UOp]):
|
489
712
|
for i,u in enumerate(uops):
|
490
713
|
formatted_parents = [(uops.index(x) if x.op is not Ops.CONST else f"{x.arg}") if x in uops else "--" for x in u.src]
|
491
714
|
print(f"{i:4d} {str(u.op):20s}: {str(u.dtype):30s} " f"{str(formatted_parents):32s} {u.arg}")
|
492
715
|
|
493
|
-
def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
|
494
|
-
flops: sint = 0
|
495
|
-
mem: sint = 0
|
496
|
-
mults: sint = 1
|
497
|
-
mult_stack: List[sint] = []
|
498
|
-
dont_count: Set[UOp] = set()
|
499
|
-
if ignore_indexing:
|
500
|
-
for u in uops:
|
501
|
-
if u.op in {Ops.LOAD, Ops.STORE}:
|
502
|
-
dont_count = dont_count.union(u.src[0].sparents)
|
503
|
-
if len(u.src) > 2: dont_count = dont_count.union(u.src[2].sparents)
|
504
|
-
elif u.op is Ops.IF:
|
505
|
-
dont_count = dont_count.union(u.src[0].sparents)
|
506
|
-
for u in uops:
|
507
|
-
if u.op is Ops.RANGE:
|
508
|
-
mult_stack.append(mults)
|
509
|
-
mults *= (u.src[1] - u.src[0]).ssimplify()
|
510
|
-
elif u.op is Ops.ENDRANGE:
|
511
|
-
mults = mult_stack.pop(-1)
|
512
|
-
elif u.op is Ops.SPECIAL:
|
513
|
-
mults *= u.arg[1] # NOTE: we don't push to the mult_stack here, you can't end these
|
514
|
-
elif u.op is Ops.LOAD:
|
515
|
-
mem += u.dtype.itemsize * mults
|
516
|
-
elif u.op is Ops.STORE:
|
517
|
-
mem += u.src[1].dtype.itemsize * mults
|
518
|
-
elif u.op in GroupOp.ALU and u not in dont_count:
|
519
|
-
flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
|
520
|
-
elif u.op is Ops.WMMA and u not in dont_count:
|
521
|
-
flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
|
522
|
-
return flops, mem
|
523
|
-
|
524
716
|
# ***** pattern matcher *****
|
525
717
|
|
526
|
-
def get_location() ->
|
718
|
+
def get_location() -> tuple[str, int]:
|
527
719
|
frm = sys._getframe(1)
|
528
720
|
# find the real frame in the file that has the UPat, TODO: is there a better way to do this?
|
529
|
-
while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "
|
530
|
-
"lowerer.py", "cstyle.py"
|
721
|
+
while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "rewriter.py", "schedule.py", "multi.py",
|
722
|
+
"symbolic.py", "expander.py", "lowerer.py", "cstyle.py",
|
723
|
+
"linearize.py"}:
|
531
724
|
frm = frm.f_back
|
532
725
|
return frm.f_code.co_filename, frm.f_lineno
|
533
726
|
@functools.lru_cache(None)
|
534
|
-
def lines(fn) ->
|
727
|
+
def lines(fn) -> list[str]:
|
535
728
|
with open(fn) as f: return f.readlines()
|
536
729
|
|
537
730
|
class UPat(MathTrait):
|
538
|
-
__slots__ =
|
539
|
-
def __init__(self, op:Optional[Union[Ops,
|
540
|
-
src:Optional[Union[
|
541
|
-
name:Optional[str]=None, allow_any_len:bool=False, location=None, custom_early_reject:Optional[
|
731
|
+
__slots__ = ("op", "dtype", "arg", "name", "src")
|
732
|
+
def __init__(self, op:Optional[Union[Ops, tuple[Ops, ...], set[Ops]]]=None, dtype:Optional[Union[DType, tuple[DType, ...]]]=None,
|
733
|
+
src:Optional[Union[tuple[UPat, ...], list[UPat], UPat]]=None, arg:Any=None,
|
734
|
+
name:Optional[str]=None, allow_any_len:bool=False, location=None, custom_early_reject:Optional[set[Ops]]=None):
|
542
735
|
assert op is None or isinstance(op, Ops) or isinstance(op, tuple) or isinstance(op, set), "op must be Ops or tuple of Ops"
|
543
|
-
self.op: Optional[
|
544
|
-
self.dtype: Optional[
|
736
|
+
self.op: Optional[tuple[Ops, ...]] = (op,) if isinstance(op, Ops) else (tuple(op) if isinstance(op, set) else op)
|
737
|
+
self.dtype: Optional[tuple[DType, ...]] = (dtype,) if isinstance(dtype, DType) else dtype
|
545
738
|
self.arg, self.name, self._in_src, self.custom_early_reject = arg, name, src, custom_early_reject
|
546
739
|
self.src: Any = None
|
547
740
|
assert self.name != "ctx", "UPat can't be named ctx"
|
@@ -568,13 +761,12 @@ class UPat(MathTrait):
|
|
568
761
|
|
569
762
|
@staticmethod
|
570
763
|
@functools.lru_cache(None)
|
571
|
-
def var(name:Optional[str]=None, dtype:Optional[Union[DType,
|
764
|
+
def var(name:Optional[str]=None, dtype:Optional[Union[DType, tuple[DType, ...]]]=None): return UPat(dtype=dtype, name=name)
|
572
765
|
@staticmethod
|
573
766
|
@functools.lru_cache(None)
|
574
|
-
def cvar(name:Optional[str]=None, dtype:Optional[DType]=None, vec=True):
|
575
|
-
return UPat((Ops.CONST, Ops.VCONST) if vec else Ops.CONST, dtype=dtype, name=name)
|
767
|
+
def cvar(name:Optional[str]=None, dtype:Optional[DType]=None, vec=True): return UPat((Ops.CONST,Ops.VCONST) if vec else Ops.CONST, dtype, name=name)
|
576
768
|
@staticmethod
|
577
|
-
def const(dtype:Optional[Union[DType,
|
769
|
+
def const(dtype:Optional[Union[DType, tuple[DType, ...]]], b:ConstType): return UPat(Ops.CONST, dtype=dtype, arg=b)
|
578
770
|
|
579
771
|
# copied from UOp
|
580
772
|
def index(self, idx:UPat, valid:Optional[UPat]=None): return UPat(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
|
@@ -589,7 +781,7 @@ class UPat(MathTrait):
|
|
589
781
|
def const_like(self, b:ConstLike): return UPat.const(self.dtype, cast(ConstType, b))
|
590
782
|
def alu(self, op:Ops, *src:UPat):
|
591
783
|
asrc = (self,)+src
|
592
|
-
return UPat(op,
|
784
|
+
return UPat(op, dtypes.bool if op in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].dtype, list(asrc) if op in GroupOp.Commutative else asrc)
|
593
785
|
|
594
786
|
def printable(self:UPat) -> str:
|
595
787
|
try: return lines(self.location[0])[self.location[1]-1].strip()
|
@@ -602,14 +794,14 @@ class UPat(MathTrait):
|
|
602
794
|
set(x.dtype) if x.dtype else None, x.allowed_len == 0, "[%s]" if x.src and len(x.src)>1 else "(%s)")
|
603
795
|
return pretty_print(self, rep, srcfn=lambda x:None if x.src is None else [next(x.src[0])] if isinstance(x.src[0], itertools.repeat) else x.src[0])
|
604
796
|
|
605
|
-
def match(self:UPat, uop:UOp, store:
|
797
|
+
def match(self:UPat, uop:UOp, store:dict[str, UOp]) -> list[dict[str, UOp]]:
|
606
798
|
if (self.op is not None and uop.op not in self.op) or \
|
607
799
|
(self.name is not None and store.setdefault(self.name, uop) is not uop) or \
|
608
800
|
(self.dtype is not None and uop.dtype not in self.dtype and uop.dtype.scalar() not in self.dtype) or \
|
609
801
|
(self.arg is not None and self.arg != uop.arg) or \
|
610
802
|
(self.allowed_len != -1 and len(uop.src) != self.allowed_len): return []
|
611
803
|
if self.src is None: return [store]
|
612
|
-
res:
|
804
|
+
res: list[dict[str, UOp]] = []
|
613
805
|
for vp in self.src:
|
614
806
|
stores, new_stores = [store.copy()], []
|
615
807
|
for uu, vv in zip(uop.src, vp):
|
@@ -619,13 +811,11 @@ class UPat(MathTrait):
|
|
619
811
|
return res
|
620
812
|
|
621
813
|
class UPatAny(UPat):
|
622
|
-
def match(self:UPat, uop:UOp, store:
|
623
|
-
|
624
|
-
for x in
|
625
|
-
if (match:=x.match(uop, store.copy())): ret.extend(match)
|
626
|
-
return ret
|
814
|
+
def match(self:UPat, uop:UOp, store:dict[str, UOp]) -> list[dict[str, UOp]]:
|
815
|
+
matches = [x.match(uop, store.copy()) for x in self.src[0]]
|
816
|
+
return flatten([x for x in matches if x is not None])
|
627
817
|
|
628
|
-
def deconstruct_function(fxn:Callable) ->
|
818
|
+
def deconstruct_function(fxn:Callable) -> tuple:
|
629
819
|
new_globals = {k:v for k,v in fxn.__globals__.items() if k in fxn.__code__.co_names}
|
630
820
|
for co in fxn.__code__.co_consts:
|
631
821
|
if isinstance(co, types.CodeType): new_globals.update({k:v for k,v in fxn.__globals__.items() if k in co.co_names})
|
@@ -635,10 +825,10 @@ def deconstruct_function(fxn:Callable) -> Tuple:
|
|
635
825
|
return pickle.loads(pickle.dumps(ret)) if getenv("TEST_PICKLE") else ret
|
636
826
|
|
637
827
|
class PatternMatcher:
|
638
|
-
def __init__(self, patterns:
|
828
|
+
def __init__(self, patterns:list[tuple[UPat, Callable]]):
|
639
829
|
self.patterns = patterns
|
640
830
|
# NOTE: use of DefaultDict here is very dangerous! all keys will live for the lifetime of the PatternMatcher!
|
641
|
-
self.pdict:
|
831
|
+
self.pdict: dict[Ops, list[tuple[UPat, Callable, set, bool]]] = {}
|
642
832
|
# uop is required, arg is optional
|
643
833
|
for p,fxn in self.patterns:
|
644
834
|
assert p.op is not None
|
@@ -651,7 +841,7 @@ class PatternMatcher:
|
|
651
841
|
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
|
652
842
|
def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns)
|
653
843
|
|
654
|
-
def rewrite(self, uop:UOp, ctx=None) ->
|
844
|
+
def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
|
655
845
|
ler = {u.op for u in uop.src}
|
656
846
|
for p,fxn,early_reject,has_ctx in self.pdict.get(uop.op, []):
|
657
847
|
if not early_reject.issubset(ler): continue
|
@@ -662,39 +852,34 @@ class PatternMatcher:
|
|
662
852
|
# *** tracking pattern matcher ***
|
663
853
|
|
664
854
|
TRACK_MATCH_STATS = ContextVar("TRACK_MATCH_STATS", 2 if getenv("VIZ") else 0)
|
665
|
-
match_stats:
|
855
|
+
match_stats:dict[UPat, list[Union[int, float]]] = dict()
|
666
856
|
@dataclass(frozen=True)
|
667
|
-
class
|
668
|
-
loc:
|
669
|
-
sink: UOp
|
670
|
-
|
671
|
-
|
672
|
-
|
673
|
-
|
674
|
-
|
857
|
+
class TrackedGraphRewrite:
|
858
|
+
loc: tuple[str, int] # location that called graph_rewrite
|
859
|
+
sink: UOp # the sink input to graph_rewrite
|
860
|
+
bottom_up: bool
|
861
|
+
matches: list[tuple[UOp, UOp, UPat]] = field(default_factory=list) # before+after of all the matches
|
862
|
+
name: Optional[str] = None
|
863
|
+
tracked_keys:list[Any] = []
|
864
|
+
tracked_ctxs:list[list[TrackedGraphRewrite]] = []
|
865
|
+
_name_cnt:dict[str, int] = {}
|
675
866
|
def track_rewrites(named=False):
|
676
867
|
def _decorator(func):
|
677
868
|
def __wrapper(self, *args, **kwargs):
|
678
869
|
if TRACK_MATCH_STATS >= 2:
|
679
|
-
if named:
|
680
|
-
|
681
|
-
|
682
|
-
|
683
|
-
if TRACK_MATCH_STATS >= 2: contexts.append(rewrite_stack.pop())
|
684
|
-
return ret
|
870
|
+
if named: _name_cnt[func.__name__] = _name_cnt.get(func.__name__, 0)+1
|
871
|
+
tracked_keys.append(f"{func.__name__}_{_name_cnt[func.__name__]}" if named else self)
|
872
|
+
tracked_ctxs.append([])
|
873
|
+
return func(self, *args, **kwargs)
|
685
874
|
return __wrapper
|
686
875
|
return _decorator
|
687
876
|
|
688
877
|
class TrackedPatternMatcher(PatternMatcher):
|
689
|
-
def
|
690
|
-
super().__init__(patterns)
|
691
|
-
for p,_ in self.patterns:
|
692
|
-
if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]
|
693
|
-
|
694
|
-
def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]:
|
878
|
+
def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
|
695
879
|
ret = None
|
696
880
|
ler = {u.op for u in uop.src}
|
697
881
|
for p,fxn,early_reject,has_ctx in self.pdict.get(uop.op, []):
|
882
|
+
if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]
|
698
883
|
st = time.perf_counter()
|
699
884
|
if not early_reject.issubset(ler):
|
700
885
|
match_stats[p][2] += time.perf_counter()-st
|
@@ -705,10 +890,9 @@ class TrackedPatternMatcher(PatternMatcher):
|
|
705
890
|
match_stats[p][0] += 1
|
706
891
|
match_stats[p][3] += (et:=time.perf_counter()-st)
|
707
892
|
if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", p.printable())
|
708
|
-
if TRACK_MATCH_STATS >= 2 and len(
|
893
|
+
if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp) and len(tracked_ctxs) != 0: tracked_ctxs[-1][-1].matches.append((uop, ret, p))
|
709
894
|
return ret # NOTE: if it returns None, we keep trying to match
|
710
895
|
match_stats[p][2] += time.perf_counter()-st
|
711
|
-
if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0: rewrite_stack[-1][1][-1].matches.append((uop, ret, None, 0))
|
712
896
|
return None
|
713
897
|
|
714
898
|
if TRACK_MATCH_STATS:
|
@@ -717,12 +901,10 @@ if TRACK_MATCH_STATS:
|
|
717
901
|
@atexit.register
|
718
902
|
def print_match_stats():
|
719
903
|
if TRACK_MATCH_STATS >= 2:
|
720
|
-
with open(fn:=temp("rewrites.pkl"), "wb") as f:
|
721
|
-
print(f"rewrote {len(
|
722
|
-
pickle.dump(
|
723
|
-
if getenv("VIZ"):
|
724
|
-
os.environ["VIZ"] = "0"
|
725
|
-
os.execv(sys.executable, [sys.executable] + [os.path.join(os.path.dirname(__file__), ".", "viz", "serve.py"), temp("rewrites.pkl")])
|
904
|
+
with open(fn:=temp("rewrites.pkl", append_user=True), "wb") as f:
|
905
|
+
print(f"rewrote {len(tracked_ctxs)} graphs and matched {sum(len(r.matches) for x in tracked_ctxs for r in x)} times, saved to {fn}")
|
906
|
+
with Context(PICKLE_BUFFERS=0): pickle.dump((tracked_keys, tracked_ctxs), f)
|
907
|
+
if getenv("VIZ"): launch_viz("VIZ", temp("rewrites.pkl", append_user=True))
|
726
908
|
if getenv("PRINT_MATCH_STATS", 1):
|
727
909
|
ret = [0,0,0.0,0.0]
|
728
910
|
for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]+x[1][3]):
|
@@ -731,401 +913,47 @@ if TRACK_MATCH_STATS:
|
|
731
913
|
ret = [x+y for x,y in zip(ret, v)]
|
732
914
|
print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[3]*1000.:9.2f} / {(ret[2]+ret[3])*1000.:9.2f} ms -- TOTAL")
|
733
915
|
|
916
|
+
def launch_viz(env_str:str, data:str):
|
917
|
+
os.environ[env_str] = "0"
|
918
|
+
os.environ[f"{env_str}_DATA"] = data
|
919
|
+
if not int(os.getenv("VIZ", "0")) and not int(os.getenv("PROFILE", "0")):
|
920
|
+
args = ['--kernels', getenv("VIZ_DATA", "")] if getenv("VIZ_DATA", "") else []
|
921
|
+
args += ['--profile', getenv("PROFILE_DATA", "")] if getenv("PROFILE_DATA", "") else []
|
922
|
+
os.execv(sys.executable, [sys.executable] + [os.path.join(os.path.dirname(__file__), ".", "viz", "serve.py")] + args)
|
923
|
+
|
734
924
|
# *** simple graph rewrite engine ***
|
735
925
|
|
736
926
|
class RewriteContext:
|
737
|
-
def __init__(self, pm, ctx):
|
927
|
+
def __init__(self, pm, ctx=None):
|
738
928
|
self.pm: PatternMatcher = pm
|
739
929
|
self.ctx = ctx
|
740
|
-
self.replace:
|
741
|
-
def
|
930
|
+
self.replace: dict[UOp, UOp] = {}
|
931
|
+
def top_down_rewrite(self, n:UOp) -> UOp:
|
742
932
|
if (rn := self.replace.get(n)) is not None: return rn
|
743
|
-
new_src = tuple(
|
933
|
+
new_src = tuple([self.top_down_rewrite(x) for x in n.src])
|
744
934
|
new_n = self.pm.rewrite(n, self.ctx) if new_src == n.src else UOp(n.op, n.dtype, new_src, n.arg)
|
745
|
-
self.replace[n] = ret = n if new_n is None else self.
|
935
|
+
self.replace[n] = ret = n if new_n is None else self.top_down_rewrite(new_n)
|
936
|
+
return ret
|
937
|
+
def bottom_up_rewrite(self, n:UOp) -> UOp:
|
938
|
+
if (rn := self.replace.get(n)) is not None: return rn
|
939
|
+
new_n: UOp|None = n
|
940
|
+
while new_n is not None: last_n, new_n = new_n, self.pm.rewrite(new_n, self.ctx)
|
941
|
+
new_src = tuple([self.bottom_up_rewrite(x) for x in last_n.src])
|
942
|
+
self.replace[n] = ret = last_n if new_src == last_n.src else self.bottom_up_rewrite(UOp(last_n.op, last_n.dtype, new_src, last_n.arg))
|
746
943
|
return ret
|
747
944
|
|
748
|
-
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp:
|
749
|
-
if TRACK_MATCH_STATS >= 2 and len(
|
750
|
-
|
751
|
-
return RewriteContext(pm, ctx).
|
752
|
-
|
753
|
-
# ***** uop type spec *****
|
754
|
-
|
755
|
-
# this is the matcher for the final rendered UOps
|
756
|
-
# matcher functions returns True or False (or None to not match)
|
757
|
-
spec = PatternMatcher([
|
758
|
-
(UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and not x.dtype.local),
|
759
|
-
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.local),
|
760
|
-
(UPat(Ops.DEFINE_ACC, src=(UPat.var("c"),), name="x", allow_any_len=True),
|
761
|
-
lambda x,c: all(y.op is Ops.RANGE for y in x.src[1:]) and c.dtype == x.dtype),
|
762
|
-
(UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
|
763
|
-
|
764
|
-
(UPat(Ops.RANGE, src=(UPat(name="x"), UPat(name="y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype),
|
765
|
-
(UPat(Ops.SPECIAL, src=()), lambda: True),
|
766
|
-
|
767
|
-
# TODO: confirm the args of both of these are shapetrackers
|
768
|
-
(UPat(Ops.VIEW, dtypes.void, src=()), lambda: True),
|
769
|
-
(UPat(Ops.VIEW, src=(UPat.var("src"),), name="x"), lambda x,src: src.op is not Ops.STORE and x.dtype == src.dtype),
|
770
|
-
|
771
|
-
(UPat(Ops.VALID, dtypes.bool, (UPat(Ops.VIEW),)), lambda: True),
|
772
|
-
(UPat(Ops.CONST, name="x"), lambda x: x.dtype == x.dtype.scalar() and (type(x.arg) is type(dtypes.as_const(x.arg, x.dtype)))),
|
773
|
-
|
774
|
-
# early LOAD has a <buf, shapetracker, store?>
|
775
|
-
(UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW))), lambda: True),
|
776
|
-
(UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat(Ops.STORE))), lambda: True),
|
777
|
-
|
778
|
-
# early STORE has a <buf, shapetracker, val>
|
779
|
-
(UPat(Ops.STORE, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat())), lambda: True),
|
780
|
-
|
781
|
-
# **** new style load/store ****
|
782
|
-
|
783
|
-
# INDEX is used in new style load/store
|
784
|
-
(UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat())), lambda: True),
|
785
|
-
|
786
|
-
# LOAD takes a <bufidx, alt?, gate?, barrier?>
|
787
|
-
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)),)), lambda: True),
|
788
|
-
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat((Ops.IF, Ops.BARRIER)))), lambda: True),
|
789
|
-
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(name="alt"), UPat(dtype=dtypes.bool)), name="ld"), lambda ld,alt: ld.dtype == alt.dtype),
|
790
|
-
|
791
|
-
# STORE takes a <bufidx, val, gate?>
|
792
|
-
(UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat())), lambda: True),
|
793
|
-
(UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
|
794
|
-
(UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(), UPat(Ops.IF))), lambda: True),
|
795
|
-
|
796
|
-
# most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
|
797
|
-
(UPat(Ops.WHERE, name="w", src=(UPat(dtype=dtypes.bool), UPat(name="x"), UPat(name="y"))), lambda w,x,y: w.dtype == x.dtype == y.dtype),
|
798
|
-
(UPat((Ops.CMPLT, Ops.CMPNE), dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y"))), lambda x,y: x.dtype == y.dtype),
|
799
|
-
# and SHL/SHR, the shift distance can be an int
|
800
|
-
(UPat((Ops.SHL, Ops.SHR), src=(UPat(name="x"), UPat(name="y")), name="a"), lambda a,x,y: a.dtype == x.dtype and y.dtype in (x.dtype, dtypes.uint)),
|
801
|
-
(UPat(Ops.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
|
802
|
-
(UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)),
|
803
|
-
|
804
|
-
(UPat(Ops.ASSIGN, src=(UPat((Ops.DEFINE_ACC, Ops.DEFINE_GLOBAL)), UPat())), lambda: True),
|
805
|
-
(UPat(Ops.ENDRANGE, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True),
|
806
|
-
|
807
|
-
# all WMMA has 3 args, <x, w, acc>
|
808
|
-
(UPat(Ops.WMMA, src=(UPat(), UPat(), UPat())), lambda: True),
|
809
|
-
(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
|
810
|
-
(UPat(Ops.EXPAND, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
|
811
|
-
|
812
|
-
# if has a <gate, barrier?>
|
813
|
-
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(),)), lambda: True),
|
814
|
-
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(), UPat(Ops.BARRIER))), lambda: True),
|
815
|
-
(UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),
|
816
|
-
|
817
|
-
(UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}),
|
818
|
-
(UPat(Ops.GEP, src=(UPat(name="src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
|
819
|
-
(UPat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.count and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
|
820
|
-
(UPat((Ops.BITCAST, Ops.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None),
|
821
|
-
(UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local
|
822
|
-
|
823
|
-
# NOTE: for testing, we let sinks be anything
|
824
|
-
#(UPat(UOps.SINK, src=UPat(UOps.STORE)), lambda: True),
|
825
|
-
(UPat(Ops.SINK, dtypes.void), lambda: True),
|
826
|
-
(UPat(Ops.NOOP), lambda: True),
|
827
|
-
|
828
|
-
# PTX LOAD/STORE
|
829
|
-
(UPat((Ops.LOAD, Ops.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
|
830
|
-
(UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, src=(UPat(dtype=dtypes.int64),), allow_any_len=True)), lambda: True),
|
831
|
-
])
|
832
|
-
|
833
|
-
def type_verify(uops:List[UOp]):
|
834
|
-
for i,u in enumerate(uops):
|
835
|
-
if not spec.rewrite(u):
|
836
|
-
print_uops(uops)
|
837
|
-
raise RuntimeError(f"UOp verification failed at {i} on {u.op} {u.dtype} {len(u.src)} {[x.op for x in u.src]} {u.arg}")
|
838
|
-
|
839
|
-
# *** uop helpers ***
|
840
|
-
|
841
|
-
def cast_float_to_bf16(x: UOp) -> UOp:
|
842
|
-
assert x.dtype == dtypes.float, "cast float -> bf16 must start with float"
|
843
|
-
x = x.bitcast(dtypes.uint)
|
844
|
-
x = (-x & 0x7f800000).where(x + ((x >> 16) & 1) + 0x7fff, (x & 0xffff).where((x | 0x10000), x))
|
845
|
-
return (x >> 16).cast(dtypes.ushort).bitcast(dtypes.bfloat16)
|
846
|
-
|
847
|
-
# *** most of symbolic lives here now ***
|
848
|
-
|
849
|
-
def split_uop(x:UOp, sep:Ops):
|
850
|
-
if x.op is sep:
|
851
|
-
for s in x.src: yield from split_uop(s, sep)
|
852
|
-
else: yield x
|
853
|
-
|
854
|
-
def mod_folding(x:UOp, c:int) -> Optional[UOp]:
|
855
|
-
# simplify x % c, None means no change
|
856
|
-
|
857
|
-
# simple cancel mod case
|
858
|
-
if 0 < c and 0 <= x.vmin and (quotient:=x.vmin//c) == x.vmax//c: return x-quotient*c
|
859
|
-
|
860
|
-
remainder, something_changed = [], False
|
861
|
-
for u in split_uop(x, Ops.ADD):
|
862
|
-
if (factor:=u.const_factor())%c != factor:
|
863
|
-
divides = u.divides(factor)*(factor%c)
|
864
|
-
assert divides is not None
|
865
|
-
remainder.append(divides)
|
866
|
-
something_changed = True
|
867
|
-
elif u.op is Ops.MOD and (s1:=u.src[1]).op is Ops.CONST and s1.arg%c == 0:
|
868
|
-
remainder.append(u.src[0])
|
869
|
-
something_changed = True
|
870
|
-
else: remainder.append(u)
|
871
|
-
if not something_changed: return None
|
872
|
-
return functools.reduce(operator.add, remainder)%c if remainder else x.const_like(0)
|
873
|
-
|
874
|
-
def div_folding(x:UOp, c:int) -> Optional[UOp]:
|
875
|
-
# simplify x // c, None means no change
|
876
|
-
|
877
|
-
# simple cancel div case
|
878
|
-
if 0 <= x.vmin and x.vmax < c: return x.const_like(0)
|
879
|
-
|
880
|
-
quotient, remainder, rem_const, something_changed, gcd, divisor = [], [], 0, False, c, 1
|
881
|
-
for u in split_uop(x, Ops.ADD):
|
882
|
-
if u.op is Ops.CONST:
|
883
|
-
# add all const together first
|
884
|
-
if rem_const != 0: something_changed = True
|
885
|
-
rem_const += u.arg
|
886
|
-
elif (factor:=u.const_factor())%c == 0:
|
887
|
-
if factor:
|
888
|
-
divides = u.divides(c)
|
889
|
-
assert divides is not None
|
890
|
-
quotient.append(divides)
|
891
|
-
something_changed = True
|
892
|
-
else:
|
893
|
-
# divisor is the smallest common divisor of all MULs
|
894
|
-
if u.op is Ops.MUL and factor > 1 and c % factor == 0 and (divisor == 1 or divisor > factor): divisor = factor
|
895
|
-
remainder.append(u)
|
896
|
-
gcd = math.gcd(gcd, factor)
|
897
|
-
|
898
|
-
# handle the const
|
899
|
-
if rem_const%c != rem_const:
|
900
|
-
something_changed = True
|
901
|
-
quotient.append(x.const_like(rem_const//c))
|
902
|
-
rem_const = rem_const%c
|
903
|
-
if rem_const != 0: remainder.append(x.const_like(rem_const))
|
904
|
-
|
905
|
-
# x // c -> quotient + (remainder // div) // (c // div)
|
906
|
-
div = gcd if gcd > 1 else divisor
|
907
|
-
|
908
|
-
if not something_changed: return newx//(c//div) if 1 < div < c and (newx:=div_folding(x, div)) is not None else None
|
909
|
-
rem:Optional[UOp] = functools.reduce(operator.add, remainder) if remainder else None
|
910
|
-
quo:Optional[UOp] = functools.reduce(operator.add, quotient) if quotient else None
|
911
|
-
if quo is None: return x.const_like(0) if rem is None else cast(UOp, div_folding(rem, div))//(c//div)
|
912
|
-
return quo if rem is None else cast(UOp, div_folding(rem, div))//(c//div)+quo
|
913
|
-
|
914
|
-
def lt_folding(x:UOp, c:int) -> Optional[UOp]:
|
915
|
-
p, np = partition(split_uop(x, Ops.ADD), lambda u: u.const_factor() == 1)
|
916
|
-
if np and (d:=math.gcd(*[u.const_factor() for u in np], c)) > 1 and 0 <= sum(u.vmin for u in p) and sum(u.vmax for u in p) < d:
|
917
|
-
return cast(UOp, functools.reduce(operator.add, np).divides(d)).lt(c//d)
|
918
|
-
return None
|
919
|
-
|
920
|
-
def fold_unrolled_divs(divs:UOp):
|
921
|
-
# div pattern in unrolled arange
|
922
|
-
# example: (x//4+(x+1)//4+(x+2)//4+(x+3)//4 -> x
|
923
|
-
add_chain, denominator, seen_const, ans = list(split_uop(divs, Ops.ADD)), None, [], None
|
924
|
-
for u in add_chain:
|
925
|
-
if not (u.op is Ops.IDIV and u.src[1].op is Ops.CONST): return None
|
926
|
-
if denominator is None: denominator = u.src[1].arg
|
927
|
-
if denominator != u.src[1].arg: return None
|
928
|
-
# assumed CONST is the last of an ADD
|
929
|
-
if (s0:=u.src[0]).op is Ops.ADD and s0.src[1].op is Ops.CONST and s0.src[1].op is Ops.CONST:
|
930
|
-
seen_const.append(s0.src[1].arg)
|
931
|
-
s0 = s0.src[0]
|
932
|
-
else: seen_const.append(0)
|
933
|
-
if ans is None: ans = s0
|
934
|
-
if ans is not s0: return None
|
935
|
-
if denominator is None: return None
|
936
|
-
# the first (denominator-len(seen_const)) terms may have been folded to 0 already
|
937
|
-
for i in range(denominator-len(seen_const)):
|
938
|
-
if ans is not None and 0 <= ans.vmin and ans.vmax + i < denominator: seen_const.append(i)
|
939
|
-
return ans if ans is not None and sorted(seen_const)==list(range(denominator)) else None
|
940
|
-
|
941
|
-
def canonicalize_simplex(X:UOp) -> Optional[UOp]:
|
942
|
-
# (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints.
|
943
|
-
# returns x0 + x1 + ... in such case, or None if not
|
944
|
-
changed, ret = False, []
|
945
|
-
for u in split_uop(X, Ops.ADD):
|
946
|
-
# assumed the const is the last src of MUL
|
947
|
-
if u.op is Ops.MUL and u.src[1].op is Ops.CONST and u.src[1].arg > 0:
|
948
|
-
changed = True
|
949
|
-
u = u.src[0]
|
950
|
-
if not (u.op in GroupOp.Irreducible and u.vmin >= 0): return None
|
951
|
-
ret.append(u)
|
952
|
-
return functools.reduce(operator.add, ret) if changed else None
|
953
|
-
|
954
|
-
def is_increasing(f:UOp) -> bool:
|
955
|
-
# is f a monotonically increasing function regards its input
|
956
|
-
if f.op in GroupOp.Irreducible: return True
|
957
|
-
if f.op is Ops.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1])
|
958
|
-
if f.op in (Ops.MUL, Ops.IDIV) and f.src[1].op is Ops.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0])
|
959
|
-
return False # False if not sure
|
960
|
-
|
961
|
-
def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]:
|
962
|
-
# if it's X <= c, returns X, True, c
|
963
|
-
# if it's X >= c, returns X, False, c
|
964
|
-
|
965
|
-
# (X < c).ne(True) -> X >= c
|
966
|
-
if valid.op is Ops.CMPNE and valid.src[1].op is Ops.CONST and valid.src[1].arg == 1 and \
|
967
|
-
(s0:=valid.src[0]).op is Ops.CMPLT and s0.src[1].op is Ops.CONST: return s0.src[0], False, s0.src[1].arg
|
968
|
-
# X < c -> X <= c-1
|
969
|
-
if valid.op is Ops.CMPLT and valid.src[1].op is Ops.CONST: return valid.src[0], True, valid.src[1].arg-1
|
970
|
-
raise ValueError(f"not able to parse {valid=}")
|
971
|
-
|
972
|
-
def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
|
973
|
-
# return None if valid is always False, otherwise the simplified uop (might be the same as input)
|
974
|
-
|
975
|
-
# first, parse valid into {expr: (lower_bound, upper_bound)}
|
976
|
-
bounds:DefaultDict[UOp, List[Optional[ConstType]]] = defaultdict(lambda: [None, None])
|
977
|
-
for stmt in split_uop(valid, Ops.AND):
|
978
|
-
try: expr, is_upper, c = parse_valid(stmt)
|
979
|
-
except ValueError: return uop # give up if we cannot parse the valid
|
980
|
-
bounds[expr][int(is_upper)] = c
|
981
|
-
|
982
|
-
# simplify uop given that valid is True
|
983
|
-
for expr,v in bounds.items():
|
984
|
-
# some expr has lower bound > upper bound -> valid is an empty set and we return None
|
985
|
-
if v[0] is not None and v[1] is not None and v[0] > v[1]: return None
|
986
|
-
|
987
|
-
# every candidate is a set of contrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop
|
988
|
-
candidates = []
|
989
|
-
if expr.op is Ops.ADD and v[0] == 1 and all(u.op in GroupOp.Irreducible for u in split_uop(expr, Ops.ADD)):
|
990
|
-
# if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
|
991
|
-
candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, Ops.ADD)])
|
992
|
-
# try checking the whole clause
|
993
|
-
if expr in uop.sparents:
|
994
|
-
candidates.append([(expr, UOp.variable("fake", expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1], expr.dtype))])
|
995
|
-
|
996
|
-
for candidate in candidates:
|
997
|
-
# if every branch in candidate gives the same simplified uop, we can rewrite the uop
|
998
|
-
newuops = [uop.substitute({X:newX}).simplify().substitute({newX:X}).simplify() for X,newX in candidate]
|
999
|
-
if uop.op is Ops.VECTORIZE and len(uop.src) == 2:
|
1000
|
-
if all_same([uops.src[0] for uops in newuops]): uop = uop.replace(src=(newuops[0].src[0], uop.src[1]))
|
1001
|
-
if all_same([uops.src[1] for uops in newuops]): uop = uop.replace(src=(uop.src[0], newuops[0].src[1]))
|
1002
|
-
elif all_same(newuops): uop = newuops[0]
|
1003
|
-
|
1004
|
-
return uop
|
1005
|
-
|
1006
|
-
def _valid_priority(v: UOp, valids:List[UOp]):
|
1007
|
-
# we want valid that's in other valids' parents to be first, so it's more likely the other valids get simplified
|
1008
|
-
try: return sum(-1 if parse_valid(v)[0] in other.parents else 0 for other in valids)
|
1009
|
-
except ValueError: return 0
|
1010
|
-
|
1011
|
-
def simplify_valid(valid:UOp) -> Optional[UOp]:
|
1012
|
-
ret:List[UOp] = []
|
1013
|
-
something_changed = False
|
1014
|
-
valids = list(split_uop(valid, Ops.AND))
|
1015
|
-
for stmt in sorted(valids, key=lambda v: _valid_priority(v, valids)):
|
1016
|
-
ret.append(newstmt if ret and (newstmt:=uop_given_valid(functools.reduce(operator.and_, ret), stmt)) is not None else stmt)
|
1017
|
-
if ret[-1] is not stmt: something_changed = True
|
1018
|
-
return functools.reduce(operator.and_, ret) if something_changed else None
|
1019
|
-
|
1020
|
-
def max_var_const(x:UOp, c1:UOp, c2:UOp):
|
1021
|
-
if x.vmin >= 0: return x*c1 if c1.arg >= c2.arg else x*c2
|
1022
|
-
if x.vmax <= 0: return x*c2 if c1.arg >= c2.arg else x*c1
|
1023
|
-
|
1024
|
-
symbolic_simple = PatternMatcher([
|
1025
|
-
# ** self folding **
|
1026
|
-
(UPat.var("x") + 0, lambda x: x), # x+0 -> x
|
1027
|
-
(UPat.var("x") * 1, lambda x: x), # x*1 -> x
|
1028
|
-
(UPat.var("x") // UPat.var("x"), lambda x: x.const_like(1)), # x//x -> 1
|
1029
|
-
(UPat.var("x") // 1, lambda x: x), # x//1 -> x
|
1030
|
-
(UPat.var("x") // -1, lambda x: -x), # x//-1 -> -x
|
1031
|
-
(UPat.var("x") / UPat.var("x"), lambda x: x.const_like(1)), # x/x -> 1
|
1032
|
-
((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x
|
1033
|
-
((UPat.var() % UPat.var("y")).named("base") % UPat.var("y"), lambda base,y: base), # (x%y)%y = -> x%y (rewritten with base for speed)
|
1034
|
-
(UPat.var("x")%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda x,c: x), # (x%c)+(x//c)*c = x
|
1035
|
-
(UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c", vec=False), lambda x,c: x if c.arg else c),
|
1036
|
-
(UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c", vec=False), lambda x,c: c if c.arg else x),
|
1037
|
-
(UPat.var("x").maximum(UPat.var("x")), lambda x: x),
|
1038
|
-
((UPat.var("x") & UPat.var("x")), lambda x: x),
|
1039
|
-
((UPat.var("x") | UPat.var("x")), lambda x: x),
|
1040
|
-
(UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x),
|
1041
|
-
(UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, True), UPat.const(dtypes.bool, False)), lambda x: x),
|
1042
|
-
# ** zero folding **
|
1043
|
-
(UPat.var("x") < UPat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x < x -> False
|
1044
|
-
(UPat.var("x", dtype=dtypes.ints) != UPat.var("x", dtype=dtypes.ints),
|
1045
|
-
lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x != x -> False (only ints)
|
1046
|
-
# x*0 -> 0 or 0*x -> 0
|
1047
|
-
# if x is nan or inf it should render the nan value.
|
1048
|
-
# NOTE: this can be wrong for loaded NaN
|
1049
|
-
(UPat.var("x") * 0, lambda x: x.const_like(float("nan") if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
|
1050
|
-
# ** constant folding **
|
1051
|
-
(UPat(GroupOp.ALU, name="a", src=UPat((Ops.VCONST, Ops.CONST))), lambda a: a.const_like(exec_alu(a.op, a.dtype, [x.arg for x in a.src], False))),
|
1052
|
-
# bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
|
1053
|
-
(UPat.var('x', dtype=dtypes.bool) * UPat.var('y', dtype=dtypes.bool), lambda x,y: x&y),
|
1054
|
-
(UPat.var('x', dtype=dtypes.bool) + UPat.var('y', dtype=dtypes.bool), lambda x,y: x|y),
|
1055
|
-
(UPat.var('x', dtype=dtypes.bool).maximum(UPat.var('y', dtype=dtypes.bool)), lambda x,y: x|y),
|
1056
|
-
# *** cast ***
|
1057
|
-
(UPat(Ops.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)),
|
1058
|
-
(UPat(Ops.CAST, name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
|
1059
|
-
])
|
945
|
+
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None) -> UOp:
|
946
|
+
if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0:
|
947
|
+
tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up, name=name))
|
948
|
+
return RewriteContext(pm, ctx).bottom_up_rewrite(sink) if bottom_up else RewriteContext(pm, ctx).top_down_rewrite(sink)
|
1060
949
|
|
1061
|
-
|
1062
|
-
|
1063
|
-
|
1064
|
-
|
1065
|
-
(
|
1066
|
-
# ** boolean algebra **
|
1067
|
-
(UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x
|
1068
|
-
# ** combine terms **
|
1069
|
-
(UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1)
|
1070
|
-
(UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1)
|
1071
|
-
(UPat.var("x") + UPat.var("x"), lambda x: x*2), # (x+x)-> x*2
|
1072
|
-
((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3)), # (x/x2)/x3 -> x/(x2*x3)
|
1073
|
-
(-1 * (UPat.var("x") + UPat.cvar("c")), lambda x,c: (-x)+(-c)), # -(x+c) -> -x + -c
|
1074
|
-
# a conditional with the same results either way is a noop, also fold const conditionals
|
1075
|
-
(UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
|
1076
|
-
(UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
|
1077
|
-
# ALU min==max -> CONST (slow!)
|
1078
|
-
(UPat(GroupOp.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
|
1079
|
-
# max folding
|
1080
|
-
(UPat.maximum(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
|
1081
|
-
# TODO: why does this rule break beautiful_mnist?
|
1082
|
-
#((UPat.var("x")+UPat.var("z")).maximum(UPat.var("y")+UPat.var("z")), lambda x,y,z: x.maximum(y) + z),
|
1083
|
-
((UPat.var("x")*UPat.cvar("c1")).maximum(UPat.var("x")*UPat.cvar("c2")), max_var_const),
|
1084
|
-
# ** two stage ALU folding **
|
1085
|
-
((UPat.var("x") + UPat.cvar("c1")) + UPat.cvar("c2"), lambda x,c1,c2: x+(c1+c2)),
|
1086
|
-
((UPat.var("x") * UPat.cvar("c1")) * UPat.cvar("c2"), lambda x,c1,c2: x*(c1*c2)),
|
1087
|
-
((UPat.var("x") & UPat.cvar("c1")) & UPat.cvar("c2"), lambda x,c1,c2: x&(c1&c2)),
|
1088
|
-
((UPat.var("x") | UPat.cvar("c1")) | UPat.cvar("c2"), lambda x,c1,c2: x|(c1|c2)),
|
1089
|
-
((UPat.cvar("c0") + UPat.var("x")) < UPat.cvar("c1"), lambda x,c0,c1: x<(c1-c0)), # c0 + x < c1 -> x < c1 - c0
|
1090
|
-
((UPat.var("x") // UPat.cvar("c1")) // UPat.cvar("c2"), lambda x,c1,c2: x//(c1*c2)), # (x//c1)//c2 -> x//(c1*c2)
|
1091
|
-
# ** lt **
|
1092
|
-
# c0*x<c1 for positive int c0,c1
|
1093
|
-
((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints)).lt(UPat.cvar("c1", vec=False)),
|
1094
|
-
lambda x,c0,c1: x.lt(math.ceil(c1.arg/c0.arg)) if c0.arg > 0 and c1.arg > 0 else None),
|
1095
|
-
# c0*x<c1 for negative int c0 and non-positive c1
|
1096
|
-
((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints)).lt(UPat.cvar("c1", vec=False)),
|
1097
|
-
lambda x,c0,c1: (-x).lt(-(math.floor(-c1.arg/-c0.arg))) if c0.arg < 0 and c0.arg != -1 and c1.arg <= 0 else None),
|
1098
|
-
# x//c0<c1 for positive int c0
|
1099
|
-
((UPat.var("x", dtype=dtypes.ints)//UPat.cvar("c0", vec=False)).lt(UPat.cvar("c1", vec=False)),
|
1100
|
-
lambda x,c0,c1: x.lt(c1.arg*c0.arg) if c0.arg > 0 else None),
|
1101
|
-
# mul add lt
|
1102
|
-
(((UPat.cvar("c0", vec=False)*UPat.var("x"))+UPat.var("x2")).lt(UPat.cvar("c1", vec=False)),
|
1103
|
-
lambda x,x2,c0,c1: x.lt(c1//c0) if c1.arg % c0.arg == 0 and c0.arg > x2.vmax and x2.vmin >= 0 else None),
|
1104
|
-
# ** move add/mul consts to end (NOTE: this is still happening before constant folding) **
|
1105
|
-
(UPat(Ops.ADD, src=(UPat.var("x"), UPat.cvar("c1"))) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
|
1106
|
-
(UPat(Ops.MUL, src=(UPat.var("x"), UPat.cvar("c1"))) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
|
1107
|
-
# *** rules from symbolic ***
|
1108
|
-
# unrolled arange div folding
|
1109
|
-
(UPat(Ops.ADD, name="divs", src=[UPat(), UPat(Ops.IDIV)]), fold_unrolled_divs),
|
1110
|
-
# generic lt folding
|
1111
|
-
(UPat.var("x", dtypes.sints).lt(UPat.cvar("c", vec=False)), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None),
|
1112
|
-
# canonicalize a simplex with positive coefficients > 0
|
1113
|
-
# not x < 1 -> X > 0
|
1114
|
-
(UPat.var("x", dtypes.ints).lt(1).ne(True), lambda x: newx.lt(1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None),
|
1115
|
-
# ** div **
|
1116
|
-
# # div folding
|
1117
|
-
(UPat.var("x", dtypes.sints) // UPat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=div_folding(x,c.arg)) is not None else None),
|
1118
|
-
# ** mod **
|
1119
|
-
# mod folding
|
1120
|
-
(UPat.var("x") % UPat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None),
|
1121
|
-
])
|
950
|
+
def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None) -> dict[UOp, UOp]:
|
951
|
+
if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0:
|
952
|
+
tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up, name=name))
|
953
|
+
rewrite_ctx = RewriteContext(pm, ctx)
|
954
|
+
return {k:(rewrite_ctx.bottom_up_rewrite(k) if bottom_up else rewrite_ctx.top_down_rewrite(k)) for k in list(sink.toposort)[::-1]}
|
1122
955
|
|
1123
|
-
|
1124
|
-
# ** combine terms (opinionated) **
|
1125
|
-
(-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y
|
1126
|
-
# (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue
|
1127
|
-
((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
|
1128
|
-
])
|
956
|
+
def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x
|
1129
957
|
|
1130
958
|
_substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
|
1131
959
|
|
@@ -1134,7 +962,7 @@ syms = { Ops.ADD: "+", Ops.SUB: "-", Ops.IDIV: "//", Ops.MOD: "%", Ops.SHL: "<<"
|
|
1134
962
|
Ops.MUL: "*", Ops.CMPLT: "<", Ops.CMPNE: "!=", Ops.AND: "&", Ops.OR: "|", Ops.XOR: "^"}
|
1135
963
|
renderer = PatternMatcher([
|
1136
964
|
(UPat((Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg[0])),
|
1137
|
-
(UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg
|
965
|
+
(UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg}")),
|
1138
966
|
(UPat(Ops.CONST, name="x"), lambda x: UOp(Ops.NOOP, arg=str(x.arg))),
|
1139
967
|
(UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
|
1140
968
|
(UPat(Ops.NEG, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")),
|
@@ -1149,4 +977,27 @@ renderer = PatternMatcher([
|
|
1149
977
|
sint = Union[int, UOp]
|
1150
978
|
Variable = UOp
|
1151
979
|
|
1152
|
-
ConstLike = Union[ConstType, Variable,
|
980
|
+
ConstLike = Union[ConstType, Variable, tuple[ConstType, ...]]
|
981
|
+
|
982
|
+
# *** UOp merge views and swizzling ***
|
983
|
+
|
984
|
+
merge_views = PatternMatcher([
|
985
|
+
# VIEW(VIEW) merges to a single VIEW
|
986
|
+
(UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.VIEW, name="vm2"),)), lambda vm1,vm2: vm2.replace(arg=vm2.st+vm1.st)),
|
987
|
+
# remove VIEW if it's contiguous and same as the base shape
|
988
|
+
(UPat(Ops.VIEW, name="vm", src=(UPat(GroupOp.All-{Ops.DEVICE}, name="x"),)), lambda vm,x: x if vm.st.contiguous and x.shape == vm.shape else None),
|
989
|
+
# merge unmasked const views
|
990
|
+
(UPat(Ops.VIEW, name="view", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="const", src=(UPat(Ops.VIEW, name="st"),) ),)),
|
991
|
+
lambda st,const,view: const.replace(src=(st.replace(arg=st.st+view.st),)) if all(v.mask is None for v in (st.st+view.st).views) else None),
|
992
|
+
])
|
993
|
+
|
994
|
+
# push VIEW to parents
|
995
|
+
view_left = merge_views+PatternMatcher([
|
996
|
+
# VIEW(CONST) becomes VALID
|
997
|
+
(UPat(Ops.VIEW, name="vm", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),)), lambda vm,x: x.valid(vm.st)),
|
998
|
+
# VIEW before elementwise/buffer ops
|
999
|
+
(UPat(Ops.VIEW, name="vm", src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e"),)),
|
1000
|
+
lambda e,vm: e.replace(src=tuple(s if s.st is None else s.view(vm.st) if s is s.base else s.base.view(s.st+vm.st) for s in e.src))),
|
1001
|
+
(UPat(Ops.VIEW, name="vm", src=(UPat(GroupOp.Buffer, name="b"),)),
|
1002
|
+
lambda b,vm: b.replace(src=tuple((s.st+vm.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
|
1003
|
+
])
|