tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +1 -1
- tinygrad/apps/llm.py +206 -0
- tinygrad/codegen/__init__.py +116 -0
- tinygrad/codegen/devectorizer.py +315 -172
- tinygrad/codegen/expander.py +8 -16
- tinygrad/codegen/gpudims.py +89 -0
- tinygrad/codegen/linearize.py +205 -203
- tinygrad/codegen/lowerer.py +92 -139
- tinygrad/codegen/opt/__init__.py +38 -0
- tinygrad/codegen/opt/heuristic.py +125 -0
- tinygrad/codegen/opt/kernel.py +510 -0
- tinygrad/{engine → codegen/opt}/search.py +51 -35
- tinygrad/codegen/opt/swizzler.py +134 -0
- tinygrad/codegen/opt/tc.py +127 -0
- tinygrad/codegen/quantize.py +67 -0
- tinygrad/device.py +122 -132
- tinygrad/dtype.py +152 -35
- tinygrad/engine/jit.py +81 -54
- tinygrad/engine/memory.py +46 -27
- tinygrad/engine/realize.py +82 -41
- tinygrad/engine/schedule.py +70 -445
- tinygrad/frontend/__init__.py +0 -0
- tinygrad/frontend/onnx.py +1253 -0
- tinygrad/frontend/torch.py +5 -0
- tinygrad/gradient.py +19 -27
- tinygrad/helpers.py +95 -47
- tinygrad/nn/__init__.py +7 -8
- tinygrad/nn/optim.py +72 -41
- tinygrad/nn/state.py +37 -23
- tinygrad/renderer/__init__.py +40 -60
- tinygrad/renderer/cstyle.py +143 -128
- tinygrad/renderer/llvmir.py +113 -62
- tinygrad/renderer/ptx.py +50 -32
- tinygrad/renderer/wgsl.py +27 -23
- tinygrad/runtime/autogen/am/am.py +5861 -0
- tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
- tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
- tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
- tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
- tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
- tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
- tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
- tinygrad/runtime/autogen/comgr.py +35 -9
- tinygrad/runtime/autogen/comgr_3.py +906 -0
- tinygrad/runtime/autogen/cuda.py +2419 -494
- tinygrad/runtime/autogen/hsa.py +57 -16
- tinygrad/runtime/autogen/ib.py +7171 -0
- tinygrad/runtime/autogen/io_uring.py +917 -118
- tinygrad/runtime/autogen/kfd.py +748 -26
- tinygrad/runtime/autogen/libc.py +613 -218
- tinygrad/runtime/autogen/libusb.py +1643 -0
- tinygrad/runtime/autogen/nv/nv.py +8602 -0
- tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
- tinygrad/runtime/autogen/opencl.py +2 -4
- tinygrad/runtime/autogen/sqtt.py +1789 -0
- tinygrad/runtime/autogen/vfio.py +3 -3
- tinygrad/runtime/autogen/webgpu.py +273 -264
- tinygrad/runtime/graph/cuda.py +3 -3
- tinygrad/runtime/graph/hcq.py +68 -29
- tinygrad/runtime/graph/metal.py +29 -13
- tinygrad/runtime/graph/remote.py +114 -0
- tinygrad/runtime/ops_amd.py +537 -320
- tinygrad/runtime/ops_cpu.py +108 -7
- tinygrad/runtime/ops_cuda.py +12 -14
- tinygrad/runtime/ops_disk.py +13 -10
- tinygrad/runtime/ops_dsp.py +47 -40
- tinygrad/runtime/ops_gpu.py +13 -11
- tinygrad/runtime/ops_hip.py +6 -9
- tinygrad/runtime/ops_llvm.py +35 -15
- tinygrad/runtime/ops_metal.py +29 -19
- tinygrad/runtime/ops_npy.py +5 -3
- tinygrad/runtime/ops_null.py +28 -0
- tinygrad/runtime/ops_nv.py +306 -234
- tinygrad/runtime/ops_python.py +62 -52
- tinygrad/runtime/ops_qcom.py +28 -39
- tinygrad/runtime/ops_remote.py +482 -0
- tinygrad/runtime/ops_webgpu.py +28 -28
- tinygrad/runtime/support/am/amdev.py +114 -249
- tinygrad/runtime/support/am/ip.py +211 -172
- tinygrad/runtime/support/amd.py +138 -0
- tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
- tinygrad/runtime/support/compiler_cuda.py +8 -11
- tinygrad/runtime/support/elf.py +2 -1
- tinygrad/runtime/support/hcq.py +184 -97
- tinygrad/runtime/support/ib.py +172 -0
- tinygrad/runtime/support/llvm.py +3 -4
- tinygrad/runtime/support/memory.py +251 -0
- tinygrad/runtime/support/nv/__init__.py +0 -0
- tinygrad/runtime/support/nv/ip.py +581 -0
- tinygrad/runtime/support/nv/nvdev.py +183 -0
- tinygrad/runtime/support/system.py +170 -0
- tinygrad/runtime/support/usb.py +268 -0
- tinygrad/runtime/support/webgpu.py +18 -0
- tinygrad/schedule/__init__.py +0 -0
- tinygrad/schedule/grouper.py +119 -0
- tinygrad/schedule/kernelize.py +368 -0
- tinygrad/schedule/multi.py +231 -0
- tinygrad/shape/shapetracker.py +40 -46
- tinygrad/shape/view.py +88 -52
- tinygrad/tensor.py +968 -542
- tinygrad/uop/__init__.py +117 -0
- tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
- tinygrad/uop/mathtraits.py +169 -0
- tinygrad/uop/ops.py +1021 -0
- tinygrad/uop/spec.py +228 -0
- tinygrad/{codegen → uop}/symbolic.py +239 -216
- tinygrad/uop/upat.py +163 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
- tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
- tinygrad/viz/index.html +203 -403
- tinygrad/viz/js/index.js +718 -0
- tinygrad/viz/js/worker.js +29 -0
- tinygrad/viz/serve.py +224 -102
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
- tinygrad-0.11.0.dist-info/RECORD +141 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/kernel.py +0 -693
- tinygrad/engine/multi.py +0 -161
- tinygrad/ops.py +0 -1003
- tinygrad/runtime/ops_cloud.py +0 -220
- tinygrad/runtime/support/allocator.py +0 -94
- tinygrad/spec.py +0 -155
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
- tinygrad/viz/perfetto.html +0 -178
- tinygrad-0.10.2.dist-info/RECORD +0 -99
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
tinygrad/ops.py
DELETED
@@ -1,1003 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
from typing import Any, Optional, Union, Callable, cast, TYPE_CHECKING, Type, get_args
|
3
|
-
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect, weakref
|
4
|
-
from enum import auto, IntEnum, Enum
|
5
|
-
from dataclasses import dataclass, field
|
6
|
-
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate
|
7
|
-
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, _METADATA, flatten
|
8
|
-
from tinygrad.helpers import PICKLE_BUFFERS, SPLIT_REDUCEOP, DEBUG, dedup
|
9
|
-
if TYPE_CHECKING:
|
10
|
-
from tinygrad.shape.shapetracker import ShapeTracker
|
11
|
-
from tinygrad.device import Buffer
|
12
|
-
|
13
|
-
# wrapper around IntEnum that preserves Enum.__str__ and makes auto() unique across all FastEnum subclasses
class FastEnum(IntEnum):
  """IntEnum whose str() is "Class.MEMBER" and whose auto() values never collide across FastEnum subclasses.

  Each auto() value is 1 + the largest value already used by an earlier member of the class being defined
  or by any existing FastEnum subclass, at any depth of the subclass tree.
  """
  def __str__(self): return Enum.__str__(self)

  @staticmethod
  def _generate_next_value_(_, __, ___, last_values):
    # walk the full subclass tree: __subclasses__() only lists direct children, so grandchildren would be missed.
    # default=0 keeps member-less subclasses (the only kind Enum lets you subclass further) from crashing max()
    used, pending = [], list(FastEnum.__subclasses__())
    while pending:
      sub = pending.pop()
      used.append(max(sub, default=0))
      pending.extend(sub.__subclasses__())
    return 1 + max([0, *last_values, *used])
|
18
|
-
|
19
|
-
class SimpleMathTrait:
  """Mixin that derives Python operator overloads from two primitives: alu() and const_like().

  Subclasses implement alu(op, *srcs) and const_like(value). Arithmetic, bitwise, comparison and reflected
  operators are all expressed in terms of those two.
  """
  # required to implement
  def alu(self:T, arg:Ops, *src) -> T: raise NotImplementedError
  def const_like(self:T, b:ConstLike) -> T: raise NotImplementedError

  # great functions you get!
  def ufix(self, x):
    """Return x unchanged if it is already a MathTrait, otherwise lift it to a constant via self.const_like."""
    return self.const_like(x) if not isinstance(x, MathTrait) else x
  def _binop(self, op, x, reverse):
    # reverse=True computes `x op self`; the reflected __r*__ operators rely on this
    return self.ufix(x).alu(op, self) if reverse else self.alu(op, self.ufix(x))
  def logical_not(self): return self.ne(True)
  def neg(self):
    """Return -self; for a bool dtype this is logical not. Raises TypeError if self has no (or a None) dtype."""
    # getattr needs a default: without one a missing attribute raises AttributeError and never reaches the TypeError
    if (dtype:=getattr(self, 'dtype', None)) is None: raise TypeError(f"MathTraits __neg__ requires a dtype, {self=}")
    return self.logical_not() if dtype.scalar() == dtypes.bool else self*(-1)
  def add(self, x, reverse=False): return self._binop(Ops.ADD, x, reverse)
  def mul(self, x, reverse=False): return self._binop(Ops.MUL, x, reverse)
  def bitwise_and(self, x, reverse=False): return self._binop(Ops.AND, x, reverse)
  def bitwise_or(self, x, reverse=False): return self._binop(Ops.OR, x, reverse)
  def xor(self, x, reverse=False): return self._binop(Ops.XOR, x, reverse)
  def idiv(self, x, reverse=False): return self._binop(Ops.IDIV, x, reverse)
  def mod(self, x, reverse=False): return self._binop(Ops.MOD, x, reverse)
  # subtraction is ADD of a negation; division is MUL by a reciprocal
  def sub(self, x, reverse=False): return self.ufix(x).alu(Ops.ADD, -self) if reverse else self.alu(Ops.ADD, self.ufix(-x))
  def div(self, x, reverse=False): return (self.ufix(x)*self.alu(Ops.RECIP)) if reverse else (self*self.ufix(x).alu(Ops.RECIP))

  def __neg__(self): return self.neg()

  def __add__(self, x): return self.add(x)
  def __sub__(self, x): return self.sub(x)
  def __mul__(self, x): return self.mul(x)
  def __truediv__(self, x): return self.div(x)
  def __floordiv__(self, x): return self.idiv(x) # TODO: idiv is trunc div, not floordiv
  def __mod__(self, x): return self.mod(x)
  def __and__(self, x): return self.bitwise_and(x)
  def __or__(self, x): return self.bitwise_or(x)
  def __xor__(self, x): return self.xor(x)

  def __radd__(self, x): return self.add(x, True)
  def __rsub__(self, x): return self.sub(x, True)
  def __rmul__(self, x): return self.mul(x, True)
  def __rtruediv__(self, x): return self.div(x, True)
  def __rfloordiv__(self, x): return self.idiv(x, True)
  def __rand__(self, x): return self.bitwise_and(x, True)
  def __ror__(self, x): return self.bitwise_or(x, True)
  def __rxor__(self, x): return self.xor(x, True)
  def __rmod__(self, x): return self.mod(x, True)

  # only CMPLT exists: > swaps operands, >= and <= negate the strict comparisons
  def __lt__(self, x): return self.alu(Ops.CMPLT, self.ufix(x))
  def __gt__(self, x): return self.ufix(x).alu(Ops.CMPLT, self)
  def __ge__(self, x): return (self < x).logical_not()
  def __le__(self, x): return (self > x).logical_not()

  def ne(self, x): return self.alu(Ops.CMPNE, self.ufix(x))
  def eq(self, x): return self.ne(x).logical_not()
  def __ne__(self, x): return self.ne(x)
  # NOTE: __eq__ isn't overridden, and means the same thing as is by default
|
72
|
-
|
73
|
-
class MathTrait(SimpleMathTrait):
  """SimpleMathTrait extended with shifts, max/min, select and unary math, all built on alu()."""
  # TODO: move to Tensor when new backward is done

  # shifts; `reverse` swaps operand order and is what the reflected operators pass
  def lshift(self, x, reverse=False):
    return self._binop(Ops.SHL, x, reverse)
  def rshift(self, x, reverse=False):
    return self._binop(Ops.SHR, x, reverse)
  def __lshift__(self, x):
    return self.lshift(x)
  def __rshift__(self, x):
    return self.rshift(x)
  def __rlshift__(self, x):
    return self.lshift(x, True)
  def __rrshift__(self, x):
    return self.rshift(x, True)

  def maximum(self, x):
    other = self.ufix(x)
    return self.alu(Ops.MAX, other)
  def minimum(self, x):
    # min(a, b) is computed as -max(-a, -b)
    flipped = -self
    return -flipped.maximum(-x)
  def where(self, x, y):
    # y is lifted through x.ufix, i.e. to a constant like x
    return self.alu(Ops.WHERE, x, x.ufix(y))
  def threefry(self, seed):
    return self.alu(Ops.THREEFRY, seed)

  # unary math
  def reciprocal(self):
    return self.alu(Ops.RECIP)
  def sqrt(self):
    return self.alu(Ops.SQRT)
  def sin(self):
    return self.alu(Ops.SIN)
  def log2(self):
    return self.alu(Ops.LOG2)
  def exp2(self):
    return self.alu(Ops.EXP2)
  def pow(self, x):
    return self.alu(Ops.POW, self.ufix(x))
|
92
|
-
|
93
|
-
# the order of these Ops controls the order of the toposort
class Ops(FastEnum):
  """Enumeration of every operation a UOp can carry.

  Values come from FastEnum's auto(), which returns 1 + the largest value seen so far. Values therefore increase
  in declaration order, and the relative order of the declarations below is significant (see the comment above
  this class). Reordering members changes their values.
  """
  # uops that aren't rendered
  NAME = auto(); SINK = auto(); CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); KERNEL = auto(); UNIQUE = auto() # noqa: E702

  # TODO: empty continues to exist because of tensor
  EMPTY = auto()

  # MetaOps
  COPY = auto(); BUFFER_VIEW = auto() # noqa: E702

  # blocks in linearizer
  BLOCK = auto(); BLOCKSTART = auto(); BLOCKFORK = auto(); BLOCKEND = auto() # noqa: E702

  # movement ops!
  RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto() # noqa: E702

  # misc ops
  UNROLL = auto(); CONTRACT = auto() # noqa: E702
  VIEW = auto(); DEFINE_GLOBAL = auto(); BUFFER = auto() # noqa: E702
  DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # noqa: E702
  VALID = auto(); SPECIAL = auto(); NOOP = auto() # noqa: E702

  # reduce
  REDUCE_AXIS = auto()

  # helper ops
  GEP = auto(); VECTORIZE = auto(); CAT = auto() # noqa: E702

  # UnaryOps (CAST and BITCAST are declared here but are not part of GroupOp.Unary)
  CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702

  # load/store before math
  LOAD = auto(); STORE = auto() # noqa: E702

  # early INDEX
  INDEX = auto()

  # math ops
  WMMA = auto()

  # BinaryOps
  ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
  SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto(); POW = auto() # noqa: E702

  # TernaryOps
  WHERE = auto(); MULACC = auto() # noqa: E702

  # assignment ops
  ASSIGN = auto()
  BIND = auto()

  # control flow ops
  BARRIER = auto(); RANGE = auto(); IF = auto(); ENDRANGE = auto(); ENDIF = auto() # noqa: E702

  # consts last!
  VCONST = auto(); CONST = auto() # noqa: E702

  # device
  DEVICE = auto()
  MULTI = auto()
  CUSTOM = auto()
|
155
|
-
|
156
|
-
class GroupOp:
  """Named sets of Ops used for membership tests such as `op in GroupOp.ALU`."""
  # elementwise math, grouped by operand count
  Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG}
  Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.XOR,
            Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY, Ops.SUB, Ops.FDIV, Ops.POW}
  Ternary = {Ops.WHERE, Ops.MULACC}
  ALU = Unary | Binary | Ternary

  Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}
  Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP}

  Buffer = {Ops.LOAD, Ops.STORE, Ops.VALID, Ops.CONST, Ops.DEFINE_VAR}
  Block = {Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKFORK, Ops.BLOCKSTART}

  # binary ops whose operands may be swapped: f(a,b) == f(b,a)
  Commutative = {Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPNE, Ops.XOR, Ops.AND, Ops.OR}

  # binary ops with f(f(a,b),c) == f(a,f(b,c))
  Associative = {Ops.ADD, Ops.MUL, Ops.AND, Ops.OR, Ops.MAX}

  # binary ops with f(x,x) == x, see https://en.wikipedia.org/wiki/Idempotence
  Idempotent = {Ops.OR, Ops.AND, Ops.MAX}

  # ops that do not map 0 to 0, so padding with zeros through them is unsafe
  UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV, Ops.POW}

  All = {*Ops}
|
182
|
-
|
183
|
-
# device names on which some BUFFER ops can be processed with only a view
view_supported_devices = set("LLVM CPU CUDA NV AMD METAL QCOM DSP DISK".split())
|
185
|
-
|
186
|
-
# https://en.wikipedia.org/wiki/Identity_element
def identity_element(op:Ops, dt:DType) -> ConstType:
  """Return the identity element of reduction op `op` as a constant of dtype `dt`.

  ADD -> 0, MUL -> 1, MAX -> dtypes.min(dt); the value is converted with dtypes.as_const.
  Any other op raises KeyError.
  """
  identities = {Ops.ADD: 0, Ops.MUL: 1, Ops.MAX: dtypes.min(dt)}
  neutral = identities[op]
  return dtypes.as_const(neutral, dt)
|
188
|
-
|
189
|
-
def can_pad(u:UOp, edges:dict[UOp, None], cache:dict[UOp, None]) -> bool:
  """Return False if u, or any node reached by recursing into the .base of its sources, is in GroupOp.UnsafePad.

  A node's own op is checked first. After that, nodes present in `edges` or already in `cache` are accepted
  without descending further. `cache` is owned by the caller; each node is inserted before its sources are
  checked, so it acts as a visited set.
  """
  if u.op in GroupOp.UnsafePad: return False
  if u in edges or u in cache: return True
  cache[u] = None
  for parent in u.src:
    if not can_pad(parent.base, edges, cache): return False
  return True
|
194
|
-
|
195
|
-
# With True as the default, this matches the old symbolic behavior
def resolve(x:UOp|bool, default:bool=True):
  """Return x as a Python bool when it can be decided, otherwise `default`.

  A plain bool is returned unchanged. A bool-dtype UOp is simplified; if its vmin and vmax agree, that
  value (as bool) is returned, else `default`.
  """
  if isinstance(x, bool): return x
  assert x.dtype == dtypes.bool, "UOp in resolve must be bool"
  # NOTE: compare vmin/vmax directly instead of calling bool(x) and catching its error, because
  # generating the text for that exception is expensive
  simplified = x.simplify()
  if simplified.vmin != simplified.vmax: return default
  return bool(simplified.vmin)
|
201
|
-
|
202
|
-
# smax/smin are replacements for max/min that preserve symbolic
def _suop(lst, uop_fxn, python_fxn):
  """Combine a mix of UOps and plain numbers into one value, then ssimplify it.

  Plain numbers are first collapsed with python_fxn. The UOps, followed by that collapsed number (if any),
  are then folded left to right with uop_fxn.
  """
  items = tuple(lst)
  operands = [v for v in items if isinstance(v, UOp)]
  plain = [v for v in items if not isinstance(v, UOp)]
  if plain: operands.append(python_fxn(plain))
  return ssimplify(functools.reduce(uop_fxn, operands))
def smax(*lst):
  """max() that also accepts UOps; arguments are normalized with argfix."""
  return _suop(argfix(*lst), UOp.maximum, max)
def smin(*lst):
  """min() that also accepts UOps; arguments are normalized with argfix."""
  return _suop(argfix(*lst), UOp.minimum, min)
|
208
|
-
|
209
|
-
def ssimplify(uop):
  """Call UOp.ssimplify on UOps (a plain constant comes back if it folds to CONST); return anything else unchanged."""
  if not isinstance(uop, UOp): return uop
  return uop.ssimplify()
def sym_infer(uop: Union[UOp, int], var_vals: dict[UOp, int]) -> int:
  """Delegate to UOp.sym_infer(var_vals) for UOps; plain ints are returned unchanged."""
  if isinstance(uop, UOp): return uop.sym_infer(var_vals)
  return uop
|
211
|
-
|
212
|
-
# used for UOp and UPat
def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->str:
  """Render the graph rooted at x as an indented, multi-line string.

  rep(node) must return a %-format string with exactly one `%s`. That slot receives the node's rendered
  sources, or 'None' when srcfn(node) returns None. srcfn(node) returns the node's sources. d is the number of
  leading spaces for the current node; each source level adds 2.

  A node referenced as a source more than once is printed in full the first time, prefixed with `xN:=`, and
  afterwards only as the back-reference `xN`. `cache` holds this bookkeeping and is passed through on the
  recursive calls.
  """
  # first pass: count how many times each node appears as a source and give each one an index
  def dfs(x:Any, cache:dict):
    for s in srcfn(x) or []:
      # cache entry layout: [index, reference_count, already_printed]
      cache.setdefault(s, [len(cache), 0, False])[1] += 1
      # descend only on the first sighting so shared subgraphs are walked once
      if cache[s][1] == 1: dfs(s, cache)
  if cache is None: dfs(x, cache:={})
  # node already printed: emit a back-reference instead of repeating its subtree
  if (cx:=cache.setdefault(x, [0,0,False]))[2]: return f"{' '*d} x{cx[0]}"
  # the right-hand side (all sources rendered, one per line) is evaluated before cx[2] is set to True
  cx[2], srcs = True, ('None' if srcfn(x) is None else ''.join(f'\n{pretty_print(s, rep, srcfn, cache, d+2)},' for s in srcfn(x)))
  # the "xN:=" prefix is added only when the node is referenced more than once
  return f"{' '*d}{f'x{cx[0]}:=' * (cx[1]>1)}{rep(x)}" % srcs
|
222
|
-
|
223
|
-
class UOpMetaClass(type):
  """Metaclass that interns UOps.

  Calling UOp(op, dtype, src, arg) looks up the key (op, dtype, src, arg) in `ucache`. If a live UOp with
  that key exists, it is returned. Otherwise a new UOp is created, a weak reference to it is cached, and the
  new UOp is registered as a child of each of its sources.
  """
  # key (op, dtype, src, arg) -> weak reference to the UOp with that key; weak so the cache alone keeps nothing alive.
  # UOp.__del__ removes its own entry.
  ucache:dict[tuple, weakref.ReferenceType[UOp]] = {}
  def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:tuple[UOp,...]=tuple(), arg:Any=None, _buffer:Buffer|None=None):
    # hit: a weakref exists and its target is still alive
    # NOTE(review): on a hit the existing UOp is returned as-is; `_buffer` and the current metadata are ignored
    if (wret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None and (ret:=wret()) is not None: return ret
    # miss (or dead weakref): construct from the key fields in (op, dtype, src, arg) order and cache a weakref to it
    UOpMetaClass.ucache[key] = ref = weakref.ref(created:=super().__call__(*key))
    # each source records the new node (by weakref) in its children set
    for s in src: s.children.add(ref)
    # NOTE: this will soon be set by Tensor once we remove function.py
    if (metadata:=_METADATA.get()) is not None: all_metadata[created] = metadata
    # NOTE: this value is set by pickle when pickling a realized tensor
    if _buffer is not None:
      assert op is Ops.BUFFER, f"trying to set Buffer {_buffer} for {op}"
      buffers[created] = _buffer
    return created
|
236
|
-
|
237
|
-
# some uops map to other stuff
# buffers maps BUFFER uops to their device Buffers. all_metadata maps a UOp to the Metadata that was set when
# the UOp was created. Both are weak-keyed, so an entry disappears once its UOp is garbage collected.
buffers: weakref.WeakKeyDictionary[UOp, Buffer]
all_metadata: weakref.WeakKeyDictionary[UOp, Metadata]
buffers, all_metadata = weakref.WeakKeyDictionary(), weakref.WeakKeyDictionary()
|
240
|
-
|
241
|
-
# NOTE: this should be frozen, but frozen is slower
|
242
|
-
@dataclass(eq=False, slots=True)
|
243
|
-
class UOp(MathTrait, metaclass=UOpMetaClass):
|
244
|
-
op:Ops
|
245
|
-
dtype:DType = dtypes.void
|
246
|
-
src:tuple[UOp, ...] = tuple()
|
247
|
-
arg:Any = None
|
248
|
-
children:set[weakref.ref[UOp]] = field(default_factory=set)
|
249
|
-
def __del__(self):
|
250
|
-
if self.op is Ops.BUFFER and (buffer:=buffers.get(self)) is not None: buffer.ref(-1)
|
251
|
-
if (ref:=UOpMetaClass.ucache.get(k:=(self.op, self.dtype, self.src, self.arg))) is not None:
|
252
|
-
for s in self.src: s.children.discard(ref)
|
253
|
-
del UOpMetaClass.ucache[k]
|
254
|
-
def __reduce__(self):
|
255
|
-
args = [self.op, self.dtype, self.src, self.arg]
|
256
|
-
if self.op is Ops.BUFFER and self.realized is not None and PICKLE_BUFFERS: args.append(self.realized)
|
257
|
-
return UOp, tuple(args)
|
258
|
-
def replace(self, **kwargs) -> UOp:
|
259
|
-
new_args = (kwargs.pop("op", self.op), kwargs.pop("dtype", self.dtype), kwargs.pop("src", self.src), kwargs.pop("arg", self.arg))
|
260
|
-
assert len(kwargs) == 0, f"unused kwargs in replace {list(kwargs)}"
|
261
|
-
if (self.op, self.dtype, self.src, self.arg) == new_args: return self
|
262
|
-
return UOp(*new_args)
|
263
|
-
@functools.cached_property
|
264
|
-
def key(self) -> bytes:
|
265
|
-
return hashlib.sha256(str((self.op, self.dtype, self.arg)).encode() + b"".join([s.key for s in self.src])).digest()
|
266
|
-
def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))")
|
267
|
-
def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is Ops.REDUCE_AXIS else repr(self.arg)
|
268
|
-
|
269
|
-
@property
|
270
|
-
def toposort(self) -> dict[UOp, None]:
|
271
|
-
def _toposort(u:UOp, cache:set[UOp]):
|
272
|
-
if u in cache: return {}
|
273
|
-
nodes: dict[UOp, None] = {}
|
274
|
-
# NOTE: this is a lot faster than the comprehension in parents
|
275
|
-
for parent in u.src: nodes.update(_toposort(parent, cache))
|
276
|
-
nodes[u] = None
|
277
|
-
cache.add(u)
|
278
|
-
return nodes
|
279
|
-
return _toposort(self, cache=set())
|
280
|
-
|
281
|
-
@functools.cached_property
|
282
|
-
def tuplize(self:UOp) -> tuple[int, Any, Optional[DType], tuple]: return (self.op.value, self.arg, self.dtype, tuple(x.tuplize for x in self.src))
|
283
|
-
|
284
|
-
# *** uop shape stuff ***
|
285
|
-
|
286
|
-
@functools.cached_property
|
287
|
-
def st(self) -> ShapeTracker|None:
|
288
|
-
from tinygrad.shape.shapetracker import ShapeTracker
|
289
|
-
if self.op is Ops.MULTI:
|
290
|
-
return ShapeTracker.from_shape(
|
291
|
-
tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape)))
|
292
|
-
if self.op is Ops.BUFFER: return ShapeTracker.from_shape((self.size,))
|
293
|
-
if self.op is Ops.KERNEL: return ShapeTracker.from_shape(self.arg.ast.shape)
|
294
|
-
# these ops define a ShapeTracker from the arg
|
295
|
-
if self.op is Ops.VIEW: return self.arg
|
296
|
-
if self.op in GroupOp.Movement: return unwrap(self.src[0].st).mop(self.op, self.arg)
|
297
|
-
# buffer ops return the ShapeTracker from sources
|
298
|
-
if self.op in GroupOp.Buffer: return vsrc[0] if len(vsrc:=[x.st for x in self.src if x.op is Ops.VIEW]) != 0 else None
|
299
|
-
if not (src_sts := [x.st for x in self.src if x.st is not None]): return None
|
300
|
-
assert all_same([x.shape for x in src_sts]), f"UOp sources must have the same shape {self} {[x.shape for x in src_sts]}"
|
301
|
-
if self.op in {Ops.BITCAST, Ops.BUFFER_VIEW}:
|
302
|
-
shape = src_sts[0].shape
|
303
|
-
if self.dtype.itemsize != (input_sz:=self.src[0].dtype.itemsize): shape = shape[:-1]+((shape[-1]*input_sz) // self.dtype.itemsize,)
|
304
|
-
# only reduce ops are allowed to change shape, everything else derives shape from sources
|
305
|
-
elif self.op in {Ops.REDUCE_AXIS, Ops.WMMA}: shape = src_sts[0].reduce(self.axis_arg)
|
306
|
-
else: shape = src_sts[0].shape
|
307
|
-
return ShapeTracker.from_shape(shape)
|
308
|
-
|
309
|
-
@functools.cached_property
|
310
|
-
def full_shape(self) -> tuple[sint, ...]:
|
311
|
-
if self.op is Ops.VIEW: return self.shape
|
312
|
-
# TODO: this should check if st is None, it cannot because local reduce has implicit movement ops
|
313
|
-
return tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {Ops.DEFINE_GLOBAL,Ops.DEFINE_LOCAL} \
|
314
|
-
# TODO: this exists because wmma creates consts without ShapeTracker in the AST, there's probably a way to fix this
|
315
|
-
and not (x.op is Ops.CONST and x.st is None)]))
|
316
|
-
@property
|
317
|
-
def shape(self) -> tuple[sint, ...]: return unwrap(self.st).shape
|
318
|
-
@property
|
319
|
-
def size(self) -> int: return self.arg if self.op is Ops.BUFFER else unwrap(self.st).size
|
320
|
-
|
321
|
-
# *** uop evaluation ***
|
322
|
-
|
323
|
-
def simplify(self):
|
324
|
-
# late import!
|
325
|
-
from tinygrad.codegen.symbolic import symbolic
|
326
|
-
with Context(TRACK_MATCH_STATS=0):
|
327
|
-
return graph_rewrite(self, symbolic)
|
328
|
-
def ssimplify(self) -> Union[UOp, ConstType]: return ret.arg if (ret:=self.simplify()).op is Ops.CONST else ret
|
329
|
-
def _eval(self, dtype, expected_type:Type[T]) -> T:
|
330
|
-
assert self.dtype in dtype, f"eval with wrong dtype {self}"
|
331
|
-
vmin, vmax = (simple_self:=self.simplify())._min_max
|
332
|
-
if vmin != vmax: raise ValueError(f"eval failed to be a single number, range is {vmin} to {vmax} in {simple_self.render()}")
|
333
|
-
assert isinstance(vmin, expected_type), f"vmin is wrong dtype {type(vmin)} != {expected_type}"
|
334
|
-
return vmin
|
335
|
-
def __bool__(self): return self._eval((dtypes.bool,), bool)
|
336
|
-
def __int__(self): return self._eval(dtypes.ints, int)
|
337
|
-
def __float__(self): return self._eval(dtypes.floats, float)
|
338
|
-
def substitute(self, dvars:dict[UOp, UOp]):
|
339
|
-
with Context(TRACK_MATCH_STATS=0):
|
340
|
-
return graph_rewrite(self, _substitute, dvars, bottom_up=True)
|
341
|
-
|
342
|
-
# *** uop syntactic sugar ***
|
343
|
-
|
344
|
-
@property
|
345
|
-
def st_arg(self) -> ShapeTracker:
|
346
|
-
assert self.op in GroupOp.Buffer, f"st_arg called on {self.op}"
|
347
|
-
return unwrap(self.st)
|
348
|
-
@property
|
349
|
-
def axis_arg(self) -> tuple[int, ...]:
|
350
|
-
assert self.op in {Ops.REDUCE_AXIS, Ops.WMMA}, f"axis_arg called on {self.op}"
|
351
|
-
ret = self.arg[1] if self.op is Ops.REDUCE_AXIS else self.arg[7]
|
352
|
-
assert isinstance(ret, tuple) and all(isinstance(x, int) for x in ret), f"axis_arg trying to return {ret}"
|
353
|
-
return ret
|
354
|
-
def sink(self, *srcs:UOp): return UOp(Ops.SINK, dtypes.void, (self,)+srcs)
|
355
|
-
def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
|
356
|
-
def index(self, idx:UOp, valid:UOp|None=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
|
357
|
-
def const_like(self, b:ConstLike):
|
358
|
-
# constants can optionally have a DEVICE source
|
359
|
-
if self._device is None: return UOp.const(self.dtype, b)
|
360
|
-
if isinstance(self.device, tuple): return UOp.multi(*[UOp.metaop(Ops.CONST, self.shape, self.dtype, d, b) for d in self.device], axis=None)
|
361
|
-
return UOp.metaop(Ops.CONST, self.shape, self.dtype, self.device, b)
|
362
|
-
def broadcast(self, count:int):
|
363
|
-
assert self.dtype.count == 1
|
364
|
-
if count == 1: return self
|
365
|
-
return UOp(Ops.VECTORIZE, self.dtype.vec(count), (self,)*count)
|
366
|
-
def cast(self, dtype:DType): return self if self.dtype == dtype else UOp(Ops.CAST, dtype, (self,))
|
367
|
-
def bitcast(self, dtype:DType): return self if self.dtype == dtype else UOp(Ops.BITCAST, dtype, (self,))
|
368
|
-
def gep(self, i:Union[tuple[int, ...], int]):
|
369
|
-
if isinstance(i, int):
|
370
|
-
# NOTE: these are just shortcuts to not have to create and fold later
|
371
|
-
if self.op is Ops.VECTORIZE: return self.src[i]
|
372
|
-
if self.op is Ops.VCONST: return UOp.const(self.dtype.scalar(), self.arg[i])
|
373
|
-
if self.op is Ops.CONST: return UOp.const(self.dtype.scalar(), self.arg)
|
374
|
-
i = (i,)
|
375
|
-
if (self.dtype.vcount == len(i) and i == tuple(range(len(i)))) or self.dtype == dtypes.void: return self
|
376
|
-
return UOp(Ops.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i)
|
377
|
-
def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, src=(self,)+src, **kwargs)
|
378
|
-
def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
|
379
|
-
def alu(self, arg, *src:UOp):
|
380
|
-
out_dtype = (self, *src)[-1].dtype
|
381
|
-
if arg in {Ops.CMPLT, Ops.CMPNE}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
|
382
|
-
return UOp(arg, out_dtype, (self,)+src)
|
383
|
-
@staticmethod
|
384
|
-
def const(dtype:DType, b:ConstLike):
|
385
|
-
if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
|
386
|
-
if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
|
387
|
-
return UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype))
|
388
|
-
def valid(self, st:ShapeTracker):
|
389
|
-
assert self.op in {Ops.CONST, Ops.DEFINE_VAR}, f"can only create VALID from a constant, got {self.op}"
|
390
|
-
from tinygrad.shape.shapetracker import ShapeTracker
|
391
|
-
# NOTE: only VALID has a masked ShapeTracker, the CONST operands are unmasked
|
392
|
-
unmasked_st = ShapeTracker.from_shape(()).reshape((1,)*len(st.shape)).expand(st.shape).to_uop()
|
393
|
-
return UOp(Ops.VALID, dtypes.bool, (st.to_uop(),)).where(self.replace(src=(unmasked_st,)), UOp.const(self.dtype, 0).replace(src=(unmasked_st,)))
|
394
|
-
@staticmethod
|
395
|
-
def range(dtype:DType, start:sint, end:sint, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(start), sint_to_uop(end)), arg=idx)
|
396
|
-
def _reduce_op(self, op:Ops, axis:tuple[int, ...]):
|
397
|
-
axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
|
398
|
-
return self if len(axis) == 0 else UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis))
|
399
|
-
def r(self, op:Ops, axis:tuple[int, ...]) -> UOp:
|
400
|
-
new_shape = unwrap(self.st).reduce(axis)
|
401
|
-
|
402
|
-
# TODO: can we split symbolic shape if the reduce axis is not symbolic?
|
403
|
-
# TODO: this shouldn't be here, it belongs in scheduler! that's why it broke multi
|
404
|
-
if not SPLIT_REDUCEOP or isinstance(self._device, tuple) or not all_int(self.shape) or (0 in self.shape) or \
|
405
|
-
prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
|
406
|
-
return self._reduce_op(op, axis)
|
407
|
-
|
408
|
-
# if there are few globals, make some reduces into globals by splitting into two kernels
|
409
|
-
# cap output buffer to 2**22: heuristic number of global outputs to achieve max occupancy with enough locals+upcasts for gemm
|
410
|
-
# ~2**10 should be enough if GROUP is used
|
411
|
-
# 256 split maximum should be "negligible reduce" for low prod(new_shape), 8 split minimum.
|
412
|
-
# split is moved to the end to provide maximum locality for the second phase reduce.
|
413
|
-
self_real_strides = unwrap(self.st).real_strides(ignore_valid=True)
|
414
|
-
split_candidates = [(i, x) for i in axis for x in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(new_shape)),8-1,-1)
|
415
|
-
if self.shape[i] % x == 0 and self_real_strides[i] != 0]
|
416
|
-
if not split_candidates: return self._reduce_op(op, axis)
|
417
|
-
dim_to_split, divisor = split_candidates[0]
|
418
|
-
splitted_shape = self.shape[:dim_to_split] + (divisor,) + (self.shape[dim_to_split]//divisor,) + self.shape[dim_to_split+1:]
|
419
|
-
splitted = self.reshape(splitted_shape).permute(tuple([x for x in range(len(splitted_shape)) if x != dim_to_split]+[dim_to_split]))
|
420
|
-
if DEBUG >= 3: print(f"split {divisor}: {self.shape} -> {splitted.shape} -> {new_shape}")
|
421
|
-
return splitted._reduce_op(op, axis)._reduce_op(op, (len(new_shape),)).reshape(new_shape) # reduce original axes, then split
|
422
|
-
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
|
423
|
-
def contiguous(self): return self.alu(Ops.CONTIGUOUS)
|
424
|
-
def contiguous_backward(self): return self.alu(Ops.CONTIGUOUS_BACKWARD)
|
425
|
-
|
426
|
-
# *** from MultiLazyBuffer ***
|
427
|
-
|
428
|
-
def multi(self, *more:UOp, axis:int|None, real:tuple[bool,...]|None=None):
|
429
|
-
parents = (self,)+more
|
430
|
-
assert all_same([x.dtype for x in parents]), "multi parents must have the same dtype"
|
431
|
-
return UOp(Ops.MULTI, self.dtype, parents, (axis, real if real is not None else (True,)*len(parents)))
|
432
|
-
|
433
|
-
@property
|
434
|
-
def bounds(self):
|
435
|
-
if self.axis is None: raise RuntimeError("bounds is not defined when axis is None")
|
436
|
-
return tuple(itertools.pairwise(itertools.accumulate([lb.shape[self.axis] for lb in self.src], initial=0)))
|
437
|
-
|
438
|
-
@functools.cached_property
def axis(self) -> Optional[int]:
  """The axis this UOp is sharded along, derived from its sources (None when there is none).

  MULTI stores it in arg[0]; ALU ops take the last distinct non-None axis of their sources;
  REDUCE_AXIS drops it when it is one of the reduced axes (arg[1]); RESHAPE and PERMUTE remap it;
  every other op inherits src[0]'s axis.
  """
  op = self.op
  if op is Ops.MULTI: return self.arg[0]
  if op in GroupOp.ALU:
    # NOTE: they all have to share an axis, we always choose [-1]
    src_axes = dedup([x.axis for x in self.src if x.axis is not None])
    return src_axes[-1] if src_axes else None
  src_axis = self.src[0].axis
  if op is Ops.REDUCE_AXIS:
    reduced_away = src_axis is not None and src_axis in self.arg[1]
    return None if reduced_away else src_axis
  if op is Ops.RESHAPE:
    if src_axis is None: return None
    prefix:list[sint] = list(itertools.accumulate(self.arg, operator.mul, initial=1))
    # new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards
    # TODO: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1?
    return len(prefix) - prefix[::-1].index(prod(self.src[0].shape[:src_axis])) - 1
  if op is Ops.PERMUTE: return None if src_axis is None else self.arg.index(src_axis)
  return src_axis
|
453
|
-
|
454
|
-
@property
def real(self):
  """Per-shard `real` flags of a MULTI UOp; multi() stores its arg as (axis, real)."""
  assert self.op is Ops.MULTI
  flags = self.arg[1]
  return flags
|
458
|
-
|
459
|
-
@property
def real_lbs(self):
  """List of the sources whose matching `real` flag is truthy, in source order."""
  return list(itertools.compress(self.src, self.real))
|
461
|
-
|
462
|
-
def shard(self, devices:tuple[str, ...], axis:Optional[int]=None) -> UOp:
  """Distribute `self` over `devices` and return a MULTI UOp.

  axis=None places the whole UOp on every device. Otherwise the axis is cut into len(devices)
  slices with shrink(). Each piece is copied to its device and made contiguous.
  Raises RuntimeError if the axis length is not divisible by the number of devices.
  """
  n = len(devices)
  if axis is None:
    pieces = [self] * n
  else:
    if self.shape[axis] % n != 0: raise RuntimeError(f"multi axis uneven: {self.shape[axis]=} {axis=} {len(devices)=}")
    # NOTE: this works for both even shards and uneven shards
    chunk = self.shape[axis] // n
    sizes = [max(0, min(chunk, self.shape[axis] - chunk*i)) for i in range(n)]
    pieces = []
    for size, start in zip(sizes, itertools.accumulate(sizes, initial=0)):
      window = tuple((0, dim) if i != axis else (start, start+size) for i, dim in enumerate(self.shape))
      pieces.append(self.shrink(window))
  on_device = [piece.copy_to_device(dev) for piece, dev in zip(pieces, devices)]
  return UOp.multi(*[piece.contiguous() for piece in on_device], axis=axis)
|
474
|
-
|
475
|
-
# *** from LazyBuffer ***
|
476
|
-
|
477
|
-
@staticmethod
def metaop(op:Ops, shape:tuple[sint, ...], dtype:DType, device:str, arg=None) -> UOp:
  """Create a leaf UOp of `shape` on `device`.

  CONST: CONST(VIEW(DEVICE)) reshaped to all-ones, then expanded to `shape`.
  BIND:  the variable from `arg.unbind()`, re-sourced on VIEW(DEVICE) of `shape`, bound to its value.
  other: a new BUFFER of prod(shape) elements (symbolic dims count as their vmax) reshaped to `shape`.
  Raises ValueError if that element count is not an int.
  """
  from tinygrad.shape.shapetracker import ShapeTracker
  def device_view(st_shape): return UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),), ShapeTracker.from_shape(st_shape))
  if op is Ops.CONST:
    assert isinstance(arg, get_args(ConstType)), f"trying to create CONST with {arg=}"
    # Tensor const is CONST(VIEW(DEVICE)) -> RESHAPE -> EXPAND
    return UOp.const(dtype, unwrap(arg)).replace(src=(device_view(()),)).reshape((1,)*len(shape)).expand(shape)
  if op is Ops.BIND:
    # Tensor variable binding is BIND(VAR(VIEW(DEVICE)), CONST(VIEW(DEVICE)))
    var, val = arg.unbind()
    return var.replace(src=(device_view(shape),)).bind(val)
  # otherwise it's just a RESHAPE(BUFFER)
  size = prod([d.vmax if isinstance(d, UOp) else d for d in shape])
  if not isinstance(size, int): raise ValueError(f"size must be int {size}")
  return UOp.new_buffer(device, size, dtype).reshape(shape)
|
492
|
-
def copy_to_device(self, device:str|tuple[str, ...], clone:bool=False) -> UOp:
  """COPY this UOp's base to `device`, then re-apply the movement ops between self and its base.

  When the view has fewer elements than its base, the view is made contiguous first and that
  result is copied instead. `clone` is stored as the COPY's arg.
  """
  base = self.base
  # if it's a shrink, do the shrink before the copy with CONTIGUOUS
  if prod(self.shape) < prod(base.shape): return self.contiguous().copy_to_device(device)
  # walk from self down to the base, remembering each movement op on the way
  chain: list[tuple[Ops, Any]] = []
  node = self
  while node is not base:
    chain.append((node.op, node.arg))
    node = node.src[0]
  # COPY is COPY(DEVICE, copyin.base) -> VIEW(copyin.st): replay the chain bottom-up on the copy
  out = UOp(Ops.COPY, base.dtype, (UOp(Ops.DEVICE, arg=device), base), clone)
  for op, arg in chain[::-1]: out = UOp(op, out.dtype, (out,), arg)
  return out
|
504
|
-
def clone(self) -> UOp:
  """Copy this UOp onto the device it already lives on, with the COPY marked clone=True."""
  same_device = self.device
  return self.copy_to_device(same_device, clone=True)
|
505
|
-
@property
def metadata(self) -> tuple[Metadata, ...]|Metadata|None:
  """Metadata for this UOp: a KERNEL carries it in its arg; anything else is looked up in all_metadata (None if absent)."""
  if self.op is Ops.KERNEL: return self.arg.metadata
  return all_metadata.get(self)
|
507
|
-
|
508
|
-
# *** uop movement ops ***
|
509
|
-
|
510
|
-
@property
def base(self) -> UOp:
  """Follow movement ops and VIEWs that have a source down to the UOp underneath them."""
  node = self
  while (node.op is Ops.VIEW and len(node.src) != 0) or node.op in GroupOp.Movement: node = node.src[0]
  return node
|
514
|
-
def view(self, new_st:ShapeTracker) -> UOp:
  """VIEW of this UOp's base with ShapeTracker `new_st` (movement ops above the base are not kept)."""
  underlying = self.base
  return UOp(Ops.VIEW, self.dtype, (underlying,), new_st)
|
515
|
-
|
516
|
-
def _mop(self, op:Ops, arg):
  """Apply movement op `op` with `arg`; return self unchanged when the ShapeTracker would not change."""
  moved = UOp(op, self.dtype, (self,), arg)
  # ignore NOOPs, also check ret.st
  return self if self.st == moved.st else moved
|
520
|
-
|
521
|
-
def reshape(self, arg:tuple[sint, ...]):
  """RESHAPE to `arg` (returns self when this is a no-op, see _mop)."""
  return self._mop(Ops.RESHAPE, arg)
def pad(self, arg:tuple[tuple[sint, sint], ...]):
  """PAD with the per-axis pairs in `arg` (no-ops return self)."""
  return self._mop(Ops.PAD, arg)
def expand(self, arg:tuple[sint, ...]):
  """EXPAND to shape `arg` (no-ops return self)."""
  return self._mop(Ops.EXPAND, arg)
def permute(self, arg:tuple[sint, ...]):
  """PERMUTE the axes in the order given by `arg` (no-ops return self)."""
  return self._mop(Ops.PERMUTE, arg)
def shrink(self, arg:tuple[tuple[sint, sint], ...]):
  """SHRINK each axis to the (start, end) pair in `arg` (no-ops return self)."""
  return self._mop(Ops.SHRINK, arg)
def flip(self, arg:tuple[bool, ...]):
  """FLIP the axes whose flag in `arg` is set (no-ops return self)."""
  return self._mop(Ops.FLIP, arg)
|
527
|
-
|
528
|
-
# *** uop UNIQUE ***
|
529
|
-
|
530
|
-
# TODO: use this in Buffer
unique_num = itertools.count(0)  # class-level counter; each UOp.unique() call consumes the next value
@staticmethod
def unique():
  """Return a UNIQUE UOp tagged with the next number from UOp.unique_num."""
  tag = next(UOp.unique_num)
  return UOp(Ops.UNIQUE, arg=tag)
|
534
|
-
|
535
|
-
# *** uop Buffer stuff ***
|
536
|
-
|
537
|
-
@staticmethod
def new_buffer(device:str, size:int, dtype:DType):
  """BUFFER UOp of `size` elements on `device`; its UNIQUE source gives every call a different source tuple."""
  srcs = (UOp(Ops.DEVICE, arg=device), UOp.unique())
  return UOp(Ops.BUFFER, dtype, srcs, size)
|
539
|
-
@property
def device(self) -> str|tuple[str, ...]:
  """Device of this UOp (a tuple of devices for MULTI); passed through `unwrap`, so it is never returned as None."""
  found = unwrap(self._device)
  return cast(str|tuple[str, ...], found)
|
541
|
-
@functools.cached_property
def _device(self) -> Optional[str|tuple[str, ...]]:
  """Device of this UOp, or None when none of its sources has one.

  DEVICE returns its arg; MULTI returns the tuple of its sources' devices; any other op takes the
  device of its first source that has one.
  """
  if self.op is Ops.DEVICE: return self.arg
  if self.op is Ops.MULTI: return tuple(cast(str, x.device) for x in self.src)
  with_device = [s for s in self.src if s._device is not None]
  return with_device[0]._device if with_device else None
|
546
|
-
@property
def buf_uop(self) -> UOp:
  """The BUFFER UOp behind this one.

  Returns the base when it is a BUFFER. Otherwise the base must be a Buffer-group op or ASSIGN
  (asserted), and the search continues through src[0].
  """
  base = self.base
  if base.op is Ops.BUFFER: return base
  assert base.op in {*GroupOp.Buffer, Ops.ASSIGN}, f"buf_uop called on {self.op}"
  return self.src[0].buf_uop
|
551
|
-
@property
def buffer(self) -> Buffer:
  """The device Buffer for this UOp, created on first access and cached in `buffers`.

  A view (self is not its base) must be contiguous (asserted) and resolves to its source's buffer.
  Only single-device BUFFER UOps own a Buffer; image dtypes are kept, other dtypes use `.base`.
  """
  if self is not self.base:
    assert unwrap(self.st).contiguous, "VIEW only works here if it's contiguous"
    return self.src[0].buffer
  assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}"
  cached = buffers.get(self)
  if cached is not None: return cached
  from tinygrad.device import Buffer
  assert isinstance(self.device, str), f"buffer not supported on multi {self.device}"
  buffers[self] = ret = Buffer(self.device, self.size, self.dtype if isinstance(self.dtype, ImageDType) else self.dtype.base)
  return ret
|
562
|
-
@property
def realized(self) -> Optional[Buffer]:
  """This UOp's Buffer when it is itself a BUFFER, else None."""
  if self.op is not Ops.BUFFER: return None
  return self.buffer
|
564
|
-
@property
def is_realized(self) -> bool:
  """True when the base is realized; for a MULTI base, every real shard's base must be realized."""
  base = self.base
  if base.op is Ops.MULTI: return all(lb.base.realized is not None for lb in base.real_lbs)
  return base.realized is not None
|
567
|
-
|
568
|
-
# *** uop Variable stuff ***
|
569
|
-
|
570
|
-
@staticmethod
def variable(name:str, min_val:ConstType, max_val:ConstType, dtype:DType=dtypes.int):
  """DEFINE_VAR called `name` with inclusive bounds [min_val, max_val]; the bounds must not be UOps."""
  assert not (isinstance(min_val, UOp) or isinstance(max_val, UOp)), f"can't create Variable {name} with {min_val}/{max_val}"
  bounds_arg = (name, min_val, max_val)
  return UOp(Ops.DEFINE_VAR, dtype, arg=bounds_arg)
|
574
|
-
@property
def expr(self):
  """Name of a DEFINE_VAR (the first element of its arg)."""
  assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
  name, *_ = self.arg
  return name
|
578
|
-
def bind(self, val:int):
  """Bind this DEFINE_VAR to `val`, which must lie in its [min, max]; returns BIND(var, const(val))."""
  assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
  lo, hi = self.arg[1], self.arg[2]
  assert lo <= val <= hi, f"bind {val} not in range [{self.arg[1]}, {self.arg[2]}]"
  return UOp(Ops.BIND, self.dtype, (self, self.const_like(val)))
|
582
|
-
def unbind(self) -> tuple[Variable, int]:
  """Split BIND(DEFINE_VAR, CONST) into (the variable, the bound value)."""
  assert self.op is Ops.BIND and self.src[0].op is Ops.DEFINE_VAR and self.src[1].op is Ops.CONST, f"can't unbind {self}"
  var, bound = self.src[0], self.src[1]
  return var, bound.arg
|
585
|
-
@property
def val(self) -> int:
  """Value a BIND is bound to (second element of unbind())."""
  _, value = self.unbind()
  return value
|
587
|
-
def vars(self) -> set[UOp]:
  """Variables reachable from this UOp: every BIND of a DEFINE_VAR, plus DEFINE_VARs that no BIND wraps."""
  topo = self.toposort
  bound = {u for u in topo if u.op is Ops.BIND and u.src[0].op is Ops.DEFINE_VAR}
  bound_defs = {b.src[0] for b in bound}
  unbound = {u for u in topo if u.op is Ops.DEFINE_VAR and u not in bound_defs}
  return bound | unbound
|
592
|
-
def variables(self) -> list[Variable]:
  """Sorted list of the DEFINE_VARs reachable from this UOp.

  Combines the variables of the ShapeTrackers (st_arg) of all Buffer-group ops with `vars()`, where
  BINDs are unwrapped to their DEFINE_VAR. The result is sorted by each variable's arg.
  """
  st_vars: list[set[Variable]] = [x.st_arg.vars() for x in self.toposort if x.op in GroupOp.Buffer]
  plain_vars = [x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()]
  # set().union(...) instead of set.union(*st_vars, ...): with no Buffer ops st_vars is empty, and the unbound
  # set.union would then receive the list as `self` and raise TypeError
  return sorted(set().union(*st_vars, plain_vars), key=lambda v: v.arg)
|
595
|
-
|
596
|
-
# *** uop symbolic stuff ***
|
597
|
-
|
598
|
-
def is_increasing(self:UOp) -> bool:
  """Conservatively decide whether this expression is monotonically increasing in its inputs (False if unsure)."""
  op = self.op
  if op in GroupOp.Irreducible: return True
  if op is Ops.ADD: return self.src[0].is_increasing() and self.src[1].is_increasing()
  # multiplying or dividing by a non-negative constant keeps the direction of the left operand
  by_nonneg_const = op in (Ops.MUL, Ops.IDIV) and self.src[1].op is Ops.CONST and self.src[1].arg >= 0
  return self.src[0].is_increasing() if by_nonneg_const else False
|
604
|
-
def const_factor(self) -> int:
  """largest known int that divides self"""
  op = self.op
  if op is Ops.CONST: return self.arg
  if op is Ops.VCONST: return math.gcd(*self.arg)
  if op is Ops.ADD:
    left, right = self.src[0].const_factor(), self.src[1].const_factor()
    return math.gcd(left, right)
  if op is Ops.MUL:
    # only a constant operand contributes a known factor
    if self.src[0].op is Ops.CONST: return self.src[0].arg
    if self.src[1].op is Ops.CONST: return self.src[1].arg
  return 1
|
611
|
-
def divides(self, v:int) -> UOp|None:
  """Return self divided by `v` when that division is provably exact, else None (None also means 'not sure')."""
  if v == 1: return self
  op = self.op
  if op is Ops.CONST:
    if self.arg % v != 0: return None
    return self.const_like(self.arg // v)
  if op is Ops.VCONST:
    if any(x % v != 0 for x in self.arg): return None
    return self.const_like(tuple(x // v for x in self.arg))
  if op is Ops.ADD:
    d0 = self.src[0].divides(v)
    if d0 is None: return None
    d1 = self.src[1].divides(v)
    return None if d1 is None else d0 + d1
  if op is Ops.MUL:
    # dividing either factor is enough
    if (d0 := self.src[0].divides(v)) is not None: return d0 * self.src[1]
    if (d1 := self.src[1].divides(v)) is not None: return self.src[0] * d1
  return None  # generic None if we aren't sure
|
620
|
-
@property
def vmin(self) -> ConstType:
  """Lower bound of this UOp's value (first element of _min_max)."""
  lo, _ = self._min_max
  return lo
|
622
|
-
@property
def vmax(self) -> ConstType:
  """Upper bound of this UOp's value (second element of _min_max)."""
  _, hi = self._min_max
  return hi
|
624
|
-
@functools.cached_property
def _min_max(self) -> tuple[ConstType, ConstType]:
  """Conservative (vmin, vmax) bounds on the value of this UOp.

  Non-float binary ops are bounded from their sources' bounds. WHERE on ints uses its two branches.
  Leaves (DEFINE_VAR, RANGE, BIND, SPECIAL, CONST, VCONST) and UNROLL/VECTORIZE use their args or sources.
  Anything else falls back to the full range of the dtype.
  """
  if self.op in GroupOp.Binary and not dtypes.is_float(self.dtype):
    (s0_vmin, s0_vmax), (s1_vmin, s1_vmax) = self.src[0]._min_max, self.src[1]._min_max
    if self.op is Ops.ADD: return s0_vmin+s1_vmin, s0_vmax+s1_vmax
    if self.op is Ops.MUL: return min(vals:=(s0_vmin*s1_vmin, s0_vmin*s1_vmax, s0_vmax*s1_vmin, s0_vmax*s1_vmax)), max(vals)
    # SHL/SHR on consts only
    if self.op is Ops.SHL and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] << t[2], t[1] << t[2]
    if self.op is Ops.SHR and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] >> t[2], t[1] >> t[2]
    if self.op is Ops.MOD and s1_vmin > 0:
      # truncating mod: result has the dividend's sign and |result| < divisor
      return (0, s1_vmax-1) if s0_vmin >= 0 else (-(s1_vmax-1), s1_vmax-1)
    if self.op is Ops.IDIV:
      # IDIV truncates toward zero (same formula as python_alu's IDIV). Python's floor // would give a
      # too-small upper bound for negative dividends (e.g. [-5,-1]//2 really reaches 0, floor says -1)
      def tdiv(a, b): return abs(a)//abs(b)*(1,-1)[a*b<0]
      if s1_vmin == s1_vmax:  # min/max are equal in a CONST
        # truncating division by a constant is monotonic in the dividend: increasing for c>0, decreasing for c<0
        if s1_vmin > 0: return tdiv(s0_vmin, s1_vmin), tdiv(s0_vmax, s1_vmin)
        if s1_vmin < 0: return tdiv(s0_vmax, s1_vmin), tdiv(s0_vmin, s1_vmin)
      # don't know exact bounds, but know the sign -- only when the divisor range lies entirely on one side of 0
      if (s0_vmax <= 0 and s1_vmax < 0) or (s0_vmin >= 0 and s1_vmin > 0): return 0, dtypes.max(self.dtype)
      if (s0_vmax <= 0 and s1_vmin > 0) or (s0_vmin >= 0 and s1_vmax < 0): return dtypes.min(self.dtype), 0
    if self.op is Ops.MAX: return max(s0_vmin, s1_vmin), max(s0_vmax, s1_vmax)
    if self.op is Ops.CMPLT: return (s0_vmax<s1_vmin, s0_vmin<s1_vmax)
    if self.op is Ops.CMPNE: return ((s0_vmax < s1_vmin) or (s1_vmax < s0_vmin), not (s0_vmin == s0_vmax == s1_vmin == s1_vmax))
    if self.dtype == dtypes.bool:
      if self.op is Ops.OR: return s0_vmin or s1_vmin, s0_vmax or s1_vmax
      if self.op is Ops.AND: return s0_vmin and s1_vmin, s0_vmax and s1_vmax
  # float has NAN issue and we use explicit NAN in transcendental
  if self.op is Ops.WHERE and dtypes.is_int(self.dtype): return min(self.src[1].vmin, self.src[2].vmin), max(self.src[1].vmax, self.src[2].vmax)
  # NOTE: returned UOp is assumed to be CONST
  if self.op is Ops.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2]
  if self.op is Ops.RANGE: return self.src[0].vmin, (self.src[1]-1).vmax
  if self.op is Ops.BIND: return self.src[0]._min_max # ignore the bound value
  if self.op in {Ops.UNROLL, Ops.VECTORIZE}: return min(x.vmin for x in self.src), max(x.vmax for x in self.src)
  # TODO: Ops.SPECIAL is Ops.DEFINE_VAR
  if self.op is Ops.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else self.arg[1].vmax
  if self.op is Ops.CONST: return self.arg, self.arg
  if self.op is Ops.VCONST: return (min(self.arg), max(self.arg))
  return dtypes.min(self.dtype), dtypes.max(self.dtype)
|
660
|
-
|
661
|
-
@functools.cached_property
def _sym_fxn(self):
  """(fxn, varnames): a Python lambda compiled from the simplified, rendered expression, taking one
  parameter per DEFINE_VAR name in toposort order."""
  sself = self.simplify()
  varnames = tuple(x.arg[0] for x in sself.toposort if x.op is Ops.DEFINE_VAR)
  # TODO: sanitize varnames, or don't use naked eval while staying fast
  return eval(f"lambda {','.join(varnames)}: {sself.render()}"), varnames # pylint: disable=eval-used
|
667
|
-
|
668
|
-
def sym_infer(self, var_vals:dict[UOp, int]):
  """Evaluate this symbolic expression, taking variable values from `var_vals` by name (entries it doesn't use are ignored)."""
  fxn, varnames = self._sym_fxn
  kwargs = {var.arg[0]: value for var, value in var_vals.items() if var.arg[0] in varnames}
  return fxn(**kwargs)
|
671
|
-
|
672
|
-
def render(self, simplify=True) -> str:
  """Render this UOp as a Python expression string (simplified first unless simplify=False).

  If the `renderer` patterns don't reduce the graph to a single NOOP, str() of the rewritten UOp is returned.
  """
  start = self.simplify() if simplify else self
  ret = graph_rewrite(start, renderer)
  if ret.op is Ops.NOOP: return ret.arg
  return str(ret)
|
675
|
-
|
676
|
-
@dataclass(frozen=True)
class KernelInfo:
  """Immutable (frozen dataclass) bundle of kernel-level settings."""
  name: str = "test" # name of the kernel
  local_dims: int = 0 # number of local dimensions (this is remapping RANGE to SPECIAL)
  upcasted: int = 0 # count that are upcasted (this is remapping RANGE to UNROLL)
  dont_use_locals: bool = False # don't use local indexing
|
682
|
-
|
683
|
-
# ******** ops in python ********
|
684
|
-
|
685
|
-
def safe_exp2(x):
  """Return 2**x, or math.inf when the result overflows a float."""
  try:
    result = 2 ** x
  except OverflowError:
    return math.inf
  return result
|
688
|
-
|
689
|
-
def safe_pow(x, y):
  """pow(x, y) that returns IEEE-style values instead of raising or going complex.

  - complex result (negative base, fractional exponent) -> nan
  - ZeroDivisionError (0 to a negative power) -> inf
  - ValueError -> inf for positive x, -inf otherwise
  - OverflowError (float result out of range) -> +/-inf with the sign the real result has,
    or nan for a negative base with a non-integer exponent (the real result would be complex)
  """
  try: return math.nan if isinstance(p:=pow(x, y), complex) else p
  except ZeroDivisionError: return math.inf
  except ValueError: return math.inf if x > 0 else -math.inf
  except OverflowError:
    # saturate like safe_exp2 instead of letting constant folding crash
    if x >= 0: return math.inf
    if y != math.floor(y): return math.nan
    return -math.inf if int(y) % 2 else math.inf
|
693
|
-
|
694
|
-
# Python implementations of the ALU ops, looked up by exec_alu.
# LOG2/SQRT/RECIP/SIN/POW/EXP2 return inf/nan on domain edges instead of raising.
# MOD and IDIV round toward zero (result sign from the operands' signs, not Python's floor semantics);
# IDIV by zero yields 0. NOTE(review): MOD by zero raises ZeroDivisionError, unlike IDIV.
python_alu: dict[Ops, Callable] = {
  Ops.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, Ops.EXP2: safe_exp2,
  Ops.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, Ops.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
  Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, Ops.POW: safe_pow,
  Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt,
  Ops.XOR: operator.xor, Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift, Ops.MAX: max,
  Ops.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], Ops.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else 0,
  Ops.MULACC: lambda x,y,z: (x*y)+z, Ops.WHERE: lambda x,y,z: y if x else z}
|
702
|
-
|
703
|
-
def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
  """Evaluate ALU `op` on Python values using python_alu.

  For vector dtypes (count > 1) the op is applied per lane: tuple operands are indexed by lane, scalar
  operands are broadcast, and a tuple is returned. With truncate_output=True the result is passed through
  `truncate[dtype]` when such an entry exists (presumably coercing into the dtype's range -- not visible here).
  """
  if dtype.count > 1:
    # forward truncate_output so every lane honors it (it used to be dropped, so lanes were always truncated)
    return tuple([exec_alu(op, dtype.scalar(), [x[i] if isinstance(x, tuple) else x for x in operands], truncate_output)
                  for i in range(dtype.count)])
  alu = python_alu[op](*operands)
  return truncate.get(dtype, lambda x: x)(alu) if truncate_output else alu
|
708
|
-
|
709
|
-
# ***** uop helpers *****
|
710
|
-
|
711
|
-
def print_uops(uops:list[UOp]):
  """Print a debug table of `uops`: index, op, dtype, sources, arg.

  A source is shown by its index in `uops` (a CONST source by its value), or "--" if it is not in the list.
  """
  # first position of each UOp, built once; per-source uops.index / `in uops` scans made this O(n^2)
  pos: dict[UOp, int] = {}
  for i,u in enumerate(uops): pos.setdefault(u, i)
  for i,u in enumerate(uops):
    formatted_parents = [(pos[x] if x.op is not Ops.CONST else f"{x.arg}") if x in pos else "--" for x in u.src]
    print(f"{i:4d} {str(u.op):20s}: {str(u.dtype):30s} " f"{str(formatted_parents):32s} {u.arg}")
|
715
|
-
|
716
|
-
# ***** pattern matcher *****
|
717
|
-
|
718
|
-
def get_location() -> tuple[str, int]:
  """(filename, lineno) to attribute a UPat to.

  Starts at the caller's frame and steps outward as long as the next outer frame lives in one of
  tinygrad's own pattern/rewrite files (matched by basename). Stops at the outermost frame of that run.
  """
  internal = {"ops.py", "rewriter.py", "schedule.py", "multi.py", "symbolic.py", "expander.py", "lowerer.py", "cstyle.py",
              "linearize.py"}
  frame = sys._getframe(1)
  # find the real frame in the file that has the UPat, TODO: is there a better way to do this?
  while (outer := frame.f_back) is not None and pathlib.Path(outer.f_code.co_filename).name in internal:
    frame = outer
  return frame.f_code.co_filename, frame.f_lineno
|
726
|
-
@functools.lru_cache(None)
def lines(fn) -> list[str]:
  """Lines of the source file `fn` (newlines kept), cached per path.

  Read as UTF-8 (Python's default source encoding) so the result doesn't depend on the locale. Undecodable
  bytes are replaced rather than raising, since this only feeds debug output (UPat.printable).
  """
  with open(fn, encoding="utf-8", errors="replace") as f: return f.readlines()
|
729
|
-
|
730
|
-
class UPat(MathTrait):
  """Pattern over UOps, used by PatternMatcher.

  A UOp matches when every constraint that is set agrees:
  - op: one op or a tuple/set of ops
  - dtype: the UOp's dtype or its scalar dtype
  - arg: compared with ==
  - src, one of:
    - tuple: sub-patterns matched positionally (length must agree unless allow_any_len)
    - list: every permutation of the sub-patterns is tried
    - UPat: repeated for every source, with any number of sources
  `name` binds the matched UOp in the store passed to the rewrite function; "ctx" is reserved.
  """
  __slots__ = ("op", "dtype", "arg", "name", "src")
  def __init__(self, op:Optional[Union[Ops, tuple[Ops, ...], set[Ops]]]=None, dtype:Optional[Union[DType, tuple[DType, ...]]]=None,
               src:Optional[Union[tuple[UPat, ...], list[UPat], UPat]]=None, arg:Any=None,
               name:Optional[str]=None, allow_any_len:bool=False, location=None, custom_early_reject:Optional[set[Ops]]=None):
    assert op is None or isinstance(op, Ops) or isinstance(op, tuple) or isinstance(op, set), "op must be Ops or tuple of Ops"
    self.op: Optional[tuple[Ops, ...]] = (op,) if isinstance(op, Ops) else (tuple(op) if isinstance(op, set) else op)
    self.dtype: Optional[tuple[DType, ...]] = (dtype,) if isinstance(dtype, DType) else dtype
    self.arg, self.name, self._in_src, self.custom_early_reject = arg, name, src, custom_early_reject
    self.src: Any = None
    assert self.name != "ctx", "UPat can't be named ctx"

    # try all permutations if it's a list
    if isinstance(src, list): self.src = list(itertools.permutations(src)) if not all_same(src) else [src]
    # only one if it's a tuple
    elif isinstance(src, tuple): self.src = [src]
    # repeat if it's a UPat
    elif isinstance(src, UPat): self.src = [itertools.repeat(src)]

    # -1 means any number of sources is accepted
    self.allowed_len: int = -1 if allow_any_len or isinstance(src, UPat) or src is None else len(src)
    self.location = location or get_location()

    # ops that must all appear among a UOp's direct sources for this pattern to be able to match;
    # PatternMatcher.rewrite skips the pattern when they don't
    if custom_early_reject is not None: self.early_reject = custom_early_reject
    else:
      upat_match = [src] if isinstance(src, UPat) else ([] if src is None else self.src[0])
      self.early_reject = {pp.op[0] for pp in upat_match if pp.op is not None and len(pp.op) == 1}

  def named(self, name:str):
    """Copy of this pattern that binds the matched UOp under `name`."""
    # custom_early_reject must be passed by keyword: the 7th positional parameter is `location`, so passing it
    # positionally used to drop the custom early reject and store it as the source location
    return UPat(self.op, self.dtype, self._in_src, self.arg, name, self.allowed_len == -1, custom_early_reject=self.custom_early_reject)

  @staticmethod
  def any(*src):
    """Pattern matching if any of the given sub-patterns matches."""
    return UPatAny(src=src)

  @staticmethod
  @functools.lru_cache(None)
  def var(name:Optional[str]=None, dtype:Optional[Union[DType, tuple[DType, ...]]]=None): return UPat(dtype=dtype, name=name)
  @staticmethod
  @functools.lru_cache(None)
  def cvar(name:Optional[str]=None, dtype:Optional[DType]=None, vec=True): return UPat((Ops.CONST,Ops.VCONST) if vec else Ops.CONST, dtype, name=name)
  @staticmethod
  def const(dtype:Optional[Union[DType, tuple[DType, ...]]], b:ConstType): return UPat(Ops.CONST, dtype=dtype, arg=b)

  # copied from UOp
  def index(self, idx:UPat, valid:Optional[UPat]=None): return UPat(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
  def view(self, st=None, **kwargs): return UPat(Ops.VIEW, self.dtype, (self,), st, **kwargs)
  def cast(self, dtype=None): return UPat(Ops.CAST, dtype, (self,))
  def bitcast(self, dtype=None): return UPat(Ops.BITCAST, dtype, (self,))
  def gep(self, i:int): return UPat(Ops.GEP, None, (self,), (i,))
  def load(self, *src:UPat, **kwargs): return UPat(Ops.LOAD, src=(self,)+src, **kwargs)
  def store(self, *src:UPat, **kwargs): return UPat(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
  def assign(self, x:UPat): return UPat(Ops.ASSIGN, self.dtype, (self,x))

  def const_like(self, b:ConstLike): return UPat.const(self.dtype, cast(ConstType, b))
  def alu(self, op:Ops, *src:UPat):
    """Pattern for `op` applied to self and `src`; commutative ops get a list src, so any order matches."""
    asrc = (self,)+src
    return UPat(op, dtypes.bool if op in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].dtype, list(asrc) if op in GroupOp.Commutative else asrc)

  def printable(self:UPat) -> str:
    """The source line that created this pattern, or "<missing>" if the file can't be found."""
    try: return lines(self.location[0])[self.location[1]-1].strip()
    except FileNotFoundError: return "<missing>"

  def __repr__(self):
    def rep(x):
      form = "UPat(%s, %s, name=%s, dtype=%s, allow_any_len=%s, src=%s)"
      # allowed_len == -1 is the "any length" marker set in __init__ (the old `== 0` was never true for it)
      return form % (None if x.op is None else ('(%s)'%', '.join(map(str, x.op))), x.arg, repr(x.name),
        set(x.dtype) if x.dtype else None, x.allowed_len == -1, "[%s]" if x.src and len(x.src)>1 else "(%s)")
    return pretty_print(self, rep, srcfn=lambda x:None if x.src is None else [next(x.src[0])] if isinstance(x.src[0], itertools.repeat) else x.src[0])

  def match(self:UPat, uop:UOp, store:dict[str, UOp]) -> list[dict[str, UOp]]:
    """All ways `uop` matches this pattern, as a list of name->UOp stores; an empty list means no match."""
    # a name already in the store must bind the very same UOp again
    if (self.op is not None and uop.op not in self.op) or \
       (self.name is not None and store.setdefault(self.name, uop) is not uop) or \
       (self.dtype is not None and uop.dtype not in self.dtype and uop.dtype.scalar() not in self.dtype) or \
       (self.arg is not None and self.arg != uop.arg) or \
       (self.allowed_len != -1 and len(uop.src) != self.allowed_len): return []
    if self.src is None: return [store]
    res: list[dict[str, UOp]] = []
    # each entry of self.src is one source ordering to try (several for list/permutation patterns)
    for vp in self.src:
      stores, new_stores = [store.copy()], []
      for uu, vv in zip(uop.src, vp):
        for s in stores: new_stores.extend(vv.match(uu, s))
        stores, new_stores = new_stores, []
      res.extend(stores)
    return res
|
812
|
-
|
813
|
-
class UPatAny(UPat):
  """Pattern that matches when any one of its alternatives matches (created by UPat.any)."""
  def match(self:UPat, uop:UOp, store:dict[str, UOp]) -> list[dict[str, UOp]]:
    # every alternative starts from its own copy of the store, so bindings don't leak between them
    results: list[dict[str, UOp]] = []
    for alternative in self.src[0]:
      found = alternative.match(uop, store.copy())
      if found is not None: results.extend(found)
    return results
|
817
|
-
|
818
|
-
def deconstruct_function(fxn:Callable) -> tuple:
  """Break `fxn` into (code, globals, name, defaults), the arguments types.FunctionType needs to rebuild it.

  Only the globals referenced by the function's code, or by code objects directly in its constants
  (e.g. comprehensions), are kept. Closures are rejected (asserted). With TEST_PICKLE set, the tuple is
  round-tripped through pickle.
  """
  outer = fxn.__code__
  new_globals = {k:v for k,v in fxn.__globals__.items() if k in outer.co_names}
  for nested in (c for c in outer.co_consts if isinstance(c, types.CodeType)):
    new_globals.update({k:v for k,v in fxn.__globals__.items() if k in nested.co_names})
  # NOTE: optional round trip through pickle!
  assert fxn.__closure__ is None, "closures are not supported in pattern matchers"
  parts = (outer, new_globals, fxn.__name__, fxn.__defaults__)
  if getenv("TEST_PICKLE"): return pickle.loads(pickle.dumps(parts))
  return parts
|
826
|
-
|
827
|
-
class PatternMatcher:
  """Ordered list of (UPat, rewrite function) pairs, indexed by op for lookup.

  rewrite() returns the first non-None result of a matching pattern's function, trying patterns in order.
  """
  def __init__(self, patterns:list[tuple[UPat, Callable]]):
    self.patterns = patterns
    # NOTE: use of DefaultDict here is very dangerous! all keys will live for the lifetime of the PatternMatcher!
    self.pdict: dict[Ops, list[tuple[UPat, Callable, set, bool]]] = {}
    # uop is required, arg is optional
    for p,fxn in self.patterns:
      assert p.op is not None
      # functions are rebuilt from (code, globals, name, defaults); __reduce__ sends lambdas in that tuple form
      tuple_fxn = fxn if isinstance(fxn, tuple) else deconstruct_function(fxn)
      real_fxn = types.FunctionType(*tuple_fxn)
      # per op: (pattern, function, early-reject op set, whether the function takes a `ctx` parameter)
      for uop in p.op: self.pdict.setdefault(uop, []).append((p, real_fxn, p.early_reject, 'ctx' in inspect.signature(real_fxn).parameters))

  # pickle support: lambdas are deconstructed into tuples, named functions are pickled by reference
  def __reduce__(self): return PatternMatcher, ([(x,deconstruct_function(fxn) if fxn.__name__ == "<lambda>" else fxn) for x,fxn in self.patterns],)

  # NOTE(review): lru_cache on a method keeps every PatternMatcher used in `+` alive for the process lifetime;
  # presumably intentional so a given sum is only built once
  @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
  def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns)

  def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
    """Return the first non-None result of a pattern matching `uop`, or None if nothing applies."""
    # ops of the direct sources; a pattern whose early_reject set isn't contained in it can't match
    ler = {u.op for u in uop.src}
    for p,fxn,early_reject,has_ctx in self.pdict.get(uop.op, []):
      if not early_reject.issubset(ler): continue
      for match in p.match(uop, {}):
        if (ret:=(fxn(ctx=ctx, **match) if has_ctx else fxn(**match))) is not None: return ret
    return None
|
851
|
-
|
852
|
-
# *** tracking pattern matcher ***
|
853
|
-
|
854
|
-
# pattern tracking level: nonzero enables per-pattern stats, >= 2 (the default under VIZ) also records graph
# rewrites for viz, >= 3 also prints the time of every successful match
TRACK_MATCH_STATS = ContextVar("TRACK_MATCH_STATS", 2 if getenv("VIZ") else 0)
# per pattern: [matches, attempts past early reject, seconds spent without a match, seconds spent on matches]
match_stats:dict[UPat, list[Union[int, float]]] = dict()
|
856
|
-
@dataclass(frozen=True)
class TrackedGraphRewrite:
  """Record of one graph_rewrite call made while TRACK_MATCH_STATS >= 2 (matches is appended to as rewrites happen)."""
  loc: tuple[str, int] # location that called graph_rewrite
  sink: UOp # the sink input to graph_rewrite
  bottom_up: bool # the bottom_up flag graph_rewrite was called with
  matches: list[tuple[UOp, UOp, UPat]] = field(default_factory=list) # before+after of all the matches
  name: Optional[str] = None # optional label passed to graph_rewrite
|
863
|
-
# one key per tracked call of a track_rewrites function: its first argument, or "<funcname>_<n>" when named
tracked_keys:list[Any] = []
# parallel to tracked_keys: the graph_rewrite records made during each call
tracked_ctxs:list[list[TrackedGraphRewrite]] = []
# per-function call counters used to build the named keys
_name_cnt:dict[str, int] = {}
|
866
|
-
def track_rewrites(named=False):
  """Decorator factory: open a new rewrite-tracking context on every call of the decorated function.

  When TRACK_MATCH_STATS >= 2, each call appends a key to `tracked_keys` (the first positional argument, or
  "<funcname>_<n>" when named=True) and an empty list to `tracked_ctxs`, which graph_rewrite then appends to.
  functools.wraps keeps the decorated function's __name__/__doc__ visible on the wrapper.
  """
  def _decorator(func):
    @functools.wraps(func)
    def __wrapper(self, *args, **kwargs):
      if TRACK_MATCH_STATS >= 2:
        if named: _name_cnt[func.__name__] = _name_cnt.get(func.__name__, 0)+1
        tracked_keys.append(f"{func.__name__}_{_name_cnt[func.__name__]}" if named else self)
        tracked_ctxs.append([])
      return func(self, *args, **kwargs)
    return __wrapper
  return _decorator
|
876
|
-
|
877
|
-
class TrackedPatternMatcher(PatternMatcher):
  """PatternMatcher that also records per-pattern counts and timings in match_stats.

  At TRACK_MATCH_STATS >= 2 it additionally appends each successful UOp rewrite to the newest
  TrackedGraphRewrite, when there is one.
  """
  def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
    ret = None
    ler = {u.op for u in uop.src}
    for p,fxn,early_reject,has_ctx in self.pdict.get(uop.op, []):
      # stats layout: [matches, attempts past early reject, time without a match, time on matches]
      if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]
      st = time.perf_counter()
      if not early_reject.issubset(ler):
        match_stats[p][2] += time.perf_counter()-st
        continue
      match_stats[p][1] += 1
      for match in p.match(uop, {}):
        if (ret:=(fxn(ctx=ctx, **match) if has_ctx else fxn(**match))) is not None:
          match_stats[p][0] += 1
          match_stats[p][3] += (et:=time.perf_counter()-st)
          if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", p.printable())
          # the newest tracking context can still be empty (rewrite called before any graph_rewrite in it);
          # indexing [-1][-1] then raised IndexError
          if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp) and len(tracked_ctxs) != 0 and len(tracked_ctxs[-1]) != 0:
            tracked_ctxs[-1][-1].matches.append((uop, ret, p))
          return ret # NOTE: if it returns None, we keep trying to match
      match_stats[p][2] += time.perf_counter()-st
    return None
|
897
|
-
|
898
|
-
# with tracking on, the module-level name PatternMatcher is rebound, so later constructions through it
# (in this module or by importers of this name) create TrackedPatternMatchers
if TRACK_MATCH_STATS:
  PatternMatcher = TrackedPatternMatcher # type: ignore
  import atexit
  @atexit.register
  def print_match_stats():
    """At interpreter exit: at level >= 2 pickle (tracked_keys, tracked_ctxs) to rewrites.pkl; launch viz if VIZ is set;
    then, if PRINT_MATCH_STATS (default 1), print per-pattern stats sorted by total time, followed by a TOTAL row."""
    if TRACK_MATCH_STATS >= 2:
      with open(fn:=temp("rewrites.pkl", append_user=True), "wb") as f:
        print(f"rewrote {len(tracked_ctxs)} graphs and matched {sum(len(r.matches) for x in tracked_ctxs for r in x)} times, saved to {fn}")
        with Context(PICKLE_BUFFERS=0): pickle.dump((tracked_keys, tracked_ctxs), f)
    # NOTE(review): nesting of this line was lost in the source rendering; it must run after the file above is closed
    if getenv("VIZ"): launch_viz("VIZ", temp("rewrites.pkl", append_user=True))
    if getenv("PRINT_MATCH_STATS", 1):
      ret = [0,0,0.0,0.0]
      for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]+x[1][3]):
        loc_str = f"{k.location[0].split('/')[-1]}:{k.location[1]}"
        # columns: matches / attempts -- match time / total time (ms) -- location, then the pattern's source line
        if v[1] != 0: print(f"{v[0]:6d} / {v[1]:7d} -- {v[3]*1000.:9.2f} / {(v[2]+v[3])*1000.:9.2f} ms -- {loc_str:15s}", k.printable())
        ret = [x+y for x,y in zip(ret, v)]
      print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[3]*1000.:9.2f} / {(ret[2]+ret[3])*1000.:9.2f} ms -- TOTAL")
|
915
|
-
|
916
|
-
def launch_viz(env_str:str, data:str):
  """Record `data` in {env_str}_DATA and turn {env_str} off. Once neither VIZ nor PROFILE is still enabled,
  replace this process (os.execv) with viz/serve.py, passing whichever kernel/profile data paths are recorded."""
  os.environ[env_str] = "0"
  os.environ[f"{env_str}_DATA"] = data
  # only the last of VIZ/PROFILE to finish gets past this check and launches the server
  if not int(os.getenv("VIZ", "0")) and not int(os.getenv("PROFILE", "0")):
    # NOTE(review): these use the project getenv, not os.getenv; confirm it re-reads values set above in this process
    args = ['--kernels', getenv("VIZ_DATA", "")] if getenv("VIZ_DATA", "") else []
    args += ['--profile', getenv("PROFILE_DATA", "")] if getenv("PROFILE_DATA", "") else []
    os.execv(sys.executable, [sys.executable] + [os.path.join(os.path.dirname(__file__), ".", "viz", "serve.py")] + args)
|
923
|
-
|
924
|
-
# *** simple graph rewrite engine ***
|
925
|
-
|
926
|
-
class RewriteContext:
  """Applies a PatternMatcher over a UOp graph, memoizing each node's final result in `replace`."""
  def __init__(self, pm, ctx=None):
    self.pm: PatternMatcher = pm
    self.ctx = ctx # forwarded to pm.rewrite
    self.replace: dict[UOp, UOp] = {} # original node -> fully rewritten node
  def top_down_rewrite(self, n:UOp) -> UOp:
    # sources first: if any changed, rebuild n on the new sources; otherwise try the patterns on n itself.
    # whatever comes out is rewritten again until nothing changes, and the result is memoized
    if (rn := self.replace.get(n)) is not None: return rn
    new_src = tuple([self.top_down_rewrite(x) for x in n.src])
    new_n = self.pm.rewrite(n, self.ctx) if new_src == n.src else UOp(n.op, n.dtype, new_src, n.arg)
    self.replace[n] = ret = n if new_n is None else self.top_down_rewrite(new_n)
    return ret
  def bottom_up_rewrite(self, n:UOp) -> UOp:
    # node first: apply patterns to n until none matches, then rewrite the sources of the result;
    # if any source changed, rebuild that node and rewrite it again. The final result is memoized
    if (rn := self.replace.get(n)) is not None: return rn
    new_n: UOp|None = n
    while new_n is not None: last_n, new_n = new_n, self.pm.rewrite(new_n, self.ctx)
    new_src = tuple([self.bottom_up_rewrite(x) for x in last_n.src])
    self.replace[n] = ret = last_n if new_src == last_n.src else self.bottom_up_rewrite(UOp(last_n.op, last_n.dtype, new_src, last_n.arg))
    return ret
|
944
|
-
|
945
|
-
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None) -> UOp:
  """Rewrite the graph rooted at `sink` with `pm` and return the new root.

  bottom_up=True tries the patterns on a node before its sources; the default rewrites sources first.
  With TRACK_MATCH_STATS >= 2 and an open tracking context, the call site is recorded as a TrackedGraphRewrite.
  """
  if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0:
    caller = sys._getframe(1)
    tracked_ctxs[-1].append(TrackedGraphRewrite((caller.f_code.co_filename, caller.f_lineno), sink, bottom_up, name=name))
  rctx = RewriteContext(pm, ctx)
  return rctx.bottom_up_rewrite(sink) if bottom_up else rctx.top_down_rewrite(sink)
|
949
|
-
|
950
|
-
def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None) -> dict[UOp, UOp]:
  """Like graph_rewrite, but return a dict mapping every node of `sink.toposort` to its rewritten node.

  All keys share one RewriteContext, so work done for one node is reused by the others.
  Nodes are processed (and the dict is ordered) in reverse toposort order.
  """
  if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0:
    # frame 1 is whoever called graph_rewrite_map; this must stay in this function to keep the depth right
    caller = sys._getframe(1)
    tracked_ctxs[-1].append(TrackedGraphRewrite((caller.f_code.co_filename, caller.f_lineno), sink, bottom_up, name=name))
  rewriter = RewriteContext(pm, ctx)
  rewrite_one = rewriter.bottom_up_rewrite if bottom_up else rewriter.top_down_rewrite
  return {k: rewrite_one(k) for k in reversed(list(sink.toposort))}
|
955
|
-
|
956
|
-
def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp:
  """Wrap a Python int as UOp.const(dtype, x); any non-int value is returned unchanged."""
  if not isinstance(x, int): return x
  return UOp.const(dtype, x)
|
957
|
-
|
958
|
-
# Substitution matcher: matches a node of any op and replaces it with ctx.get(x, None).
# NOTE(review): ctx is presumably a dict of UOp -> replacement UOp supplied by the caller (confirm at call sites);
# a missing key yields None, which this file's rules use to mean "no rewrite".
_substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
|
959
|
-
|
960
|
-
# for debug: turn a UOp graph into a human-readable expression string

# infix operator text for the binary ALU ops the generic rule below can render
syms = { Ops.ADD: "+", Ops.SUB: "-", Ops.IDIV: "//", Ops.MOD: "%", Ops.SHL: "<<", Ops.SHR: ">>",
         Ops.MUL: "*", Ops.CMPLT: "<", Ops.CMPNE: "!=", Ops.AND: "&", Ops.OR: "|", Ops.XOR: "^"}

# Each rule replaces a node with a NOOP whose arg is its rendered text. Rules written with
# src=UPat(Ops.NOOP) presumably fire only once every source is already a rendered NOOP (TODO confirm
# UPat repeat-src semantics), so a fully renderable graph collapses into a single NOOP.
renderer = PatternMatcher([
  (UPat((Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg[0])),
  (UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg}")),
  (UPat(Ops.CONST, name="x"), lambda x: UOp(Ops.NOOP, arg=str(x.arg))),
  (UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
  (UPat(Ops.NEG, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")),
  (UPat(Ops.MAX, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")),
  (UPat(Ops.MULACC, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg}+{x.src[2].arg})")),
  (UPat(Ops.WHERE, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[1].arg} if {x.src[0].arg} else {x.src[2].arg})")),
  # generic infix fallback. Only ops listed in syms can be rendered here; any other ALU op reaching this rule
  # previously raised KeyError (or IndexError on x.src[1] for a one-source op). Returning None leaves such a
  # node unrendered instead of crashing the debug printer.
  (UPat(GroupOp.ALU, src=UPat(Ops.NOOP), name="x"),
   lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}{syms[x.op]}{x.src[1].arg})") if x.op in syms else None),
])
|
974
|
-
|
975
|
-
# *** what was symbolic.py ***

# "symbolic int": either a concrete Python int or a UOp (presumably a symbolic expression — TODO confirm)
sint = Union[int, UOp]
# variables are plain UOps; this alias exists for readability in signatures
Variable = UOp

# values accepted where a constant is expected: a ConstType scalar, a Variable, or a tuple of ConstType scalars
ConstLike = Union[ConstType, Variable, tuple[ConstType, ...]]
|
981
|
-
|
982
|
-
# *** UOp merge views and swizzling ***
|
983
|
-
|
984
|
-
# Rules that collapse VIEW chains or drop VIEWs that have no effect. Rule order matters: earlier rules are
# presumably tried first (PatternMatcher behavior — TODO confirm).
merge_views = PatternMatcher([
  # VIEW(VIEW) merges to a single VIEW
  # NOTE(review): `vm2.st+vm1.st` presumably composes the inner view's tracker with the outer one; confirm ShapeTracker.__add__
  (UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.VIEW, name="vm2"),)), lambda vm1,vm2: vm2.replace(arg=vm2.st+vm1.st)),
  # remove VIEW if it's contiguous and same as the base shape (any source op except DEVICE); otherwise None = no rewrite
  (UPat(Ops.VIEW, name="vm", src=(UPat(GroupOp.All-{Ops.DEVICE}, name="x"),)), lambda vm,x: x if vm.st.contiguous and x.shape == vm.shape else None),
  # merge unmasked const views: VIEW(CONST|DEFINE_VAR(VIEW)) folds the outer view into the inner one,
  # but only when no view in the combined tracker carries a mask
  (UPat(Ops.VIEW, name="view", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="const", src=(UPat(Ops.VIEW, name="st"),) ),)),
   lambda st,const,view: const.replace(src=(st.replace(arg=st.st+view.st),)) if all(v.mask is None for v in (st.st+view.st).views) else None),
])
|
993
|
-
|
994
|
-
# push VIEW to parents
# view_left = all merge_views rules plus the rules below (PatternMatcher `+` presumably concatenates rule lists)
view_left = merge_views+PatternMatcher([
  # VIEW(CONST) becomes VALID
  (UPat(Ops.VIEW, name="vm", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),)), lambda vm,x: x.valid(vm.st)),
  # VIEW before elementwise/buffer ops
  # For each source of the ALU/CAST/BITCAST/ASSIGN node e:
  #   - s.st is None      -> source left untouched
  #   - s is its own base -> apply vm's view directly to s
  #   - otherwise         -> compose s's own view with vm's and apply the result to s.base
  (UPat(Ops.VIEW, name="vm", src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e"),)),
   lambda e,vm: e.replace(src=tuple(s if s.st is None else s.view(vm.st) if s is s.base else s.base.view(s.st+vm.st) for s in e.src))),
  # VIEW over a buffer op: fold vm's tracker into each VIEW source of b (re-materialized via to_uop); other sources unchanged
  (UPat(Ops.VIEW, name="vm", src=(UPat(GroupOp.Buffer, name="b"),)),
   lambda b,vm: b.replace(src=tuple((s.st+vm.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
])
|