tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +1 -1
- tinygrad/apps/llm.py +206 -0
- tinygrad/codegen/__init__.py +116 -0
- tinygrad/codegen/devectorizer.py +315 -172
- tinygrad/codegen/expander.py +8 -16
- tinygrad/codegen/gpudims.py +89 -0
- tinygrad/codegen/linearize.py +205 -203
- tinygrad/codegen/lowerer.py +92 -139
- tinygrad/codegen/opt/__init__.py +38 -0
- tinygrad/codegen/opt/heuristic.py +125 -0
- tinygrad/codegen/opt/kernel.py +510 -0
- tinygrad/{engine → codegen/opt}/search.py +51 -35
- tinygrad/codegen/opt/swizzler.py +134 -0
- tinygrad/codegen/opt/tc.py +127 -0
- tinygrad/codegen/quantize.py +67 -0
- tinygrad/device.py +122 -132
- tinygrad/dtype.py +152 -35
- tinygrad/engine/jit.py +81 -54
- tinygrad/engine/memory.py +46 -27
- tinygrad/engine/realize.py +82 -41
- tinygrad/engine/schedule.py +70 -445
- tinygrad/frontend/__init__.py +0 -0
- tinygrad/frontend/onnx.py +1253 -0
- tinygrad/frontend/torch.py +5 -0
- tinygrad/gradient.py +19 -27
- tinygrad/helpers.py +95 -47
- tinygrad/nn/__init__.py +7 -8
- tinygrad/nn/optim.py +72 -41
- tinygrad/nn/state.py +37 -23
- tinygrad/renderer/__init__.py +40 -60
- tinygrad/renderer/cstyle.py +143 -128
- tinygrad/renderer/llvmir.py +113 -62
- tinygrad/renderer/ptx.py +50 -32
- tinygrad/renderer/wgsl.py +27 -23
- tinygrad/runtime/autogen/am/am.py +5861 -0
- tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
- tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
- tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
- tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
- tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
- tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
- tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
- tinygrad/runtime/autogen/comgr.py +35 -9
- tinygrad/runtime/autogen/comgr_3.py +906 -0
- tinygrad/runtime/autogen/cuda.py +2419 -494
- tinygrad/runtime/autogen/hsa.py +57 -16
- tinygrad/runtime/autogen/ib.py +7171 -0
- tinygrad/runtime/autogen/io_uring.py +917 -118
- tinygrad/runtime/autogen/kfd.py +748 -26
- tinygrad/runtime/autogen/libc.py +613 -218
- tinygrad/runtime/autogen/libusb.py +1643 -0
- tinygrad/runtime/autogen/nv/nv.py +8602 -0
- tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
- tinygrad/runtime/autogen/opencl.py +2 -4
- tinygrad/runtime/autogen/sqtt.py +1789 -0
- tinygrad/runtime/autogen/vfio.py +3 -3
- tinygrad/runtime/autogen/webgpu.py +273 -264
- tinygrad/runtime/graph/cuda.py +3 -3
- tinygrad/runtime/graph/hcq.py +68 -29
- tinygrad/runtime/graph/metal.py +29 -13
- tinygrad/runtime/graph/remote.py +114 -0
- tinygrad/runtime/ops_amd.py +537 -320
- tinygrad/runtime/ops_cpu.py +108 -7
- tinygrad/runtime/ops_cuda.py +12 -14
- tinygrad/runtime/ops_disk.py +13 -10
- tinygrad/runtime/ops_dsp.py +47 -40
- tinygrad/runtime/ops_gpu.py +13 -11
- tinygrad/runtime/ops_hip.py +6 -9
- tinygrad/runtime/ops_llvm.py +35 -15
- tinygrad/runtime/ops_metal.py +29 -19
- tinygrad/runtime/ops_npy.py +5 -3
- tinygrad/runtime/ops_null.py +28 -0
- tinygrad/runtime/ops_nv.py +306 -234
- tinygrad/runtime/ops_python.py +62 -52
- tinygrad/runtime/ops_qcom.py +28 -39
- tinygrad/runtime/ops_remote.py +482 -0
- tinygrad/runtime/ops_webgpu.py +28 -28
- tinygrad/runtime/support/am/amdev.py +114 -249
- tinygrad/runtime/support/am/ip.py +211 -172
- tinygrad/runtime/support/amd.py +138 -0
- tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
- tinygrad/runtime/support/compiler_cuda.py +8 -11
- tinygrad/runtime/support/elf.py +2 -1
- tinygrad/runtime/support/hcq.py +184 -97
- tinygrad/runtime/support/ib.py +172 -0
- tinygrad/runtime/support/llvm.py +3 -4
- tinygrad/runtime/support/memory.py +251 -0
- tinygrad/runtime/support/nv/__init__.py +0 -0
- tinygrad/runtime/support/nv/ip.py +581 -0
- tinygrad/runtime/support/nv/nvdev.py +183 -0
- tinygrad/runtime/support/system.py +170 -0
- tinygrad/runtime/support/usb.py +268 -0
- tinygrad/runtime/support/webgpu.py +18 -0
- tinygrad/schedule/__init__.py +0 -0
- tinygrad/schedule/grouper.py +119 -0
- tinygrad/schedule/kernelize.py +368 -0
- tinygrad/schedule/multi.py +231 -0
- tinygrad/shape/shapetracker.py +40 -46
- tinygrad/shape/view.py +88 -52
- tinygrad/tensor.py +968 -542
- tinygrad/uop/__init__.py +117 -0
- tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
- tinygrad/uop/mathtraits.py +169 -0
- tinygrad/uop/ops.py +1021 -0
- tinygrad/uop/spec.py +228 -0
- tinygrad/{codegen → uop}/symbolic.py +239 -216
- tinygrad/uop/upat.py +163 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
- tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
- tinygrad/viz/index.html +203 -403
- tinygrad/viz/js/index.js +718 -0
- tinygrad/viz/js/worker.js +29 -0
- tinygrad/viz/serve.py +224 -102
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
- tinygrad-0.11.0.dist-info/RECORD +141 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/kernel.py +0 -693
- tinygrad/engine/multi.py +0 -161
- tinygrad/ops.py +0 -1003
- tinygrad/runtime/ops_cloud.py +0 -220
- tinygrad/runtime/support/allocator.py +0 -94
- tinygrad/spec.py +0 -155
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
- tinygrad/viz/perfetto.html +0 -178
- tinygrad-0.10.2.dist-info/RECORD +0 -99
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
tinygrad/uop/ops.py
ADDED
@@ -0,0 +1,1021 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
from typing import Any, Callable, cast, TYPE_CHECKING, Type, Sequence
|
3
|
+
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect, weakref
|
4
|
+
from dataclasses import dataclass, field
|
5
|
+
from enum import Enum, auto
|
6
|
+
from tinygrad.uop import Ops, GroupOp
|
7
|
+
from tinygrad.uop.mathtraits import MathTrait
|
8
|
+
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType
|
9
|
+
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten
|
10
|
+
from tinygrad.helpers import PICKLE_BUFFERS, PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey
|
11
|
+
if TYPE_CHECKING:
|
12
|
+
from tinygrad.shape.shapetracker import ShapeTracker
|
13
|
+
from tinygrad.device import Buffer, MultiBuffer
|
14
|
+
|
15
|
+
# https://en.wikipedia.org/wiki/Identity_element
def identity_element(op:Ops, dt:DType) -> ConstType:
  """Return the identity constant of a reduction op (0 for ADD, 1 for MUL, dtype-min for MAX), cast to dt."""
  table = {Ops.ADD: 0, Ops.MUL: 1, Ops.MAX: dtypes.min(dt)}
  return dtypes.as_const(table[op], dt)
|
17
|
+
|
18
|
+
def can_pad(root:UOp, edges:dict[UOp, None]) -> bool:
  """True iff no pad-unsafe op is reachable from root without crossing a node in edges."""
  unsafe = GroupOp.UnsafePad
  return not any(u.op in unsafe for u in root.toposort(gate=lambda x: x not in edges))
|
20
|
+
|
21
|
+
# With True as the default, this matches the old symbolic behavior
def resolve(x:UOp|bool, default:bool=True):
  """Collapse a (possibly symbolic) boolean to a Python bool, or `default` when undecidable."""
  if isinstance(x, bool): return x
  assert x.dtype == dtypes.bool, "UOp in resolve must be bool"
  # NOTE: generating the text for the exception is expensive, so we do this
  simplified = x.simplify()
  if simplified.vmin == simplified.vmax: return bool(simplified.vmin)
  return default
|
27
|
+
|
28
|
+
# smax/smin are replacements for max/min that preserve symbolic
def _suop(lst, uop_fxn, python_fxn):
  """Reduce a mixed list of UOps and plain numbers: numbers are pre-folded with python_fxn, then everything is combined with uop_fxn."""
  symbolic_part, numeric_part = partition(lst, lambda item: isinstance(item, UOp))
  if numeric_part: symbolic_part = symbolic_part + [python_fxn(numeric_part)]
  return ssimplify(functools.reduce(uop_fxn, symbolic_part))
|
32
|
+
def smax(*lst): return _suop(argfix(*lst), UOp.maximum, max)
def smin(*lst): return _suop(argfix(*lst), UOp.minimum, min)
def srender(x) -> str:
  # render symbolic UOps, stringify plain numbers
  return x.render() if isinstance(x, UOp) else str(x)

def ssimplify(uop):
  # simplify only if symbolic; plain numbers pass through unchanged
  return uop.ssimplify() if isinstance(uop, UOp) else uop
def sym_infer(uop: UOp|int, var_vals: dict[UOp, int]) -> int:
  # evaluate a symbolic expression (or pass through a plain int) given variable bindings
  return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop
|
38
|
+
|
39
|
+
# used for UOp and UPat
def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->str:
  """Render a DAG as indented text; a subtree used more than once is printed once as x<N>:=... and referenced as x<N> afterwards."""
  def count_uses(node, cache):
    # cache entry per node: [id, use-count, already-printed]
    for child in srcfn(node) or []:
      entry = cache.setdefault(child, [len(cache), 0, False])
      entry[1] += 1
      if entry[1] == 1: count_uses(child, cache)
  if cache is None:
    cache = {}
    count_uses(x, cache)
  entry = cache.setdefault(x, [0, 0, False])
  if entry[2]: return f"{' '*d} x{entry[0]}"  # already printed: emit a back-reference
  entry[2] = True
  children = srcfn(x)
  srcs = 'None' if children is None else ''.join(f'\n{pretty_print(c, rep, srcfn, cache, d+2)},' for c in children)
  label = f'x{entry[0]}:=' if entry[1] > 1 else ''
  return f"{' '*d}{label}{rep(x)}" % srcs
|
49
|
+
|
50
|
+
class UOpMetaClass(type):
  """Metaclass that interns UOps: constructing an identical (op, dtype, src, arg, tag) returns the cached instance."""
  # weakrefs so unreferenced UOps can be garbage collected (entries are removed in UOp.__del__)
  ucache:dict[tuple, weakref.ReferenceType[UOp]] = {}
  def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:tuple[UOp,...]=tuple(), arg:Any=None, tag:Any=None,
               metadata:tuple[Metadata,...]|None=None, _buffer:Buffer|None=None):
    # cache hit: an identical UOp already exists and its weakref is still alive
    if (wret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg, tag), None)) is not None and (ret:=wret()) is not None: return ret
    UOpMetaClass.ucache[key] = ref = weakref.ref(created:=super().__call__(*key))
    # register the new node as a child of each of its sources
    for s in src: s.children.add(ref)
    if metadata is not None: all_metadata[created] = metadata
    # NOTE: this value is set by pickle when pickling a realized tensor
    if _buffer is not None:
      assert op is Ops.BUFFER, f"trying to set Buffer {_buffer} for {op}"
      buffers[created] = _buffer
    return created
|
63
|
+
|
64
|
+
# some uops map to other stuff
# weak keys: when the last reference to a UOp dies, its side-table entry vanishes with it
buffers:weakref.WeakKeyDictionary[UOp, Buffer|MultiBuffer] = weakref.WeakKeyDictionary() # this maps BUFFER uops to their device Buffers
all_metadata:weakref.WeakKeyDictionary[UOp, tuple[Metadata, ...]] = weakref.WeakKeyDictionary() # TODO: should this be here?
|
67
|
+
|
68
|
+
# NOTE: this should be frozen, but frozen is slower
|
69
|
+
@dataclass(eq=False, slots=True)
|
70
|
+
class UOp(MathTrait, metaclass=UOpMetaClass):
|
71
|
+
op:Ops                                                       # the operation this node performs
dtype:DType = dtypes.void                                    # result dtype (void for effect-only ops)
src:tuple[UOp, ...] = tuple()                                # input edges (sources in the graph)
arg:Any = None                                               # op-specific payload
tag:Any = None                                               # extra marker; participates in interning (see UOpMetaClass)
children:set[weakref.ref[UOp]] = field(default_factory=set)  # weak back-edges, maintained by UOpMetaClass/__del__
|
77
|
+
def __del__(self):
  """Remove this UOp from the intern cache and release a BUFFER uop's device Buffer refcount."""
  # Ops may already be None during interpreter shutdown
  if Ops is not None and self.op is Ops.BUFFER and (buffer:=buffers.get(self)) is not None: buffer.ref(-1)
  try:
    if (ref:=UOpMetaClass.ucache.get(k:=(self.op, self.dtype, self.src, self.arg, self.tag))) is not None:
      # unregister from the sources' children sets, then drop the cache entry
      for s in self.src: s.children.discard(ref)
      del UOpMetaClass.ucache[k]
  except AttributeError: pass  # NOTE(review): presumably guards teardown/partially-constructed instances — confirm
|
84
|
+
def __reduce__(self):
  """Pickle support: rebuild via UOp(...) so interning still applies on load."""
  args = [self.op, self.dtype, self.src, self.arg, self.tag, self.metadata]
  # with PICKLE_BUFFERS, a realized BUFFER also carries its device Buffer (consumed as _buffer in UOpMetaClass.__call__)
  if self.op is Ops.BUFFER and self.realized is not None and PICKLE_BUFFERS: args.append(self.realized)
  return UOp, tuple(args)
|
88
|
+
def replace(self, **kwargs) -> UOp:
  """Return a copy with the given fields (op/dtype/src/arg/tag) swapped out; returns self unchanged when nothing differs."""
  fields = tuple(kwargs.pop(name, getattr(self, name)) for name in ("op", "dtype", "src", "arg", "tag"))
  assert len(kwargs) == 0, f"unused kwargs in replace {list(kwargs)}"
  if fields == (self.op, self.dtype, self.src, self.arg, self.tag): return self
  return UOp(*fields)
def rtag(self, tag=True): return self.replace(tag=tag)
|
95
|
+
@functools.cached_property
def key(self) -> bytes:
  """Content hash: sha256 over (op, dtype, arg) concatenated with the keys of all sources (cached)."""
  return hashlib.sha256(str((self.op, self.dtype, self.arg)).encode() + b"".join([s.key for s in self.src])).digest()
def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}{x.tagstr()}, src=(%s))")
def argstr(self):
  # REDUCE_AXIS args render as a plain tuple of strs, everything else with repr
  return f'({", ".join(map(str, self.arg))})' if self.op is Ops.REDUCE_AXIS else repr(self.arg)
def tagstr(self):
  # only show the tag when one is set
  return f", tag={self.tag}" if self.tag is not None else ""
|
101
|
+
|
102
|
+
@functools.cached_property
def parents(self:UOp) -> dict[UOp, None]:
  """All transitive sources of this UOp, as an insertion-ordered dict-used-as-set (cached)."""
  out: dict[UOp, None] = dict.fromkeys(self.src)
  for s in self.src: out.update(s.parents)
  return out
@property
def sparents(self:UOp) -> dict[UOp, None]:
  """parents plus self (self first)."""
  return {self:None, **self.parents}
|
109
|
+
|
110
|
+
def toposort(self, gate:Callable|None=None) -> dict[UOp, None]:
  """Iterative DFS post-order: every node appears after all of its sources.

  gate, when given, prunes traversal: a node failing gate is neither expanded nor emitted.
  """
  visited: dict[UOp, None] = {}
  # work stack of (node, expanded) pairs; expanded=True means the node's sources were already pushed
  work: list[tuple[UOp, bool]] = [(self, False)]
  while work:
    node, expanded = work.pop()
    if node in visited: continue
    if expanded:
      visited[node] = None  # all sources handled: emit in post-order
    elif gate is None or gate(node):
      work.append((node, True))  # revisit this node after the sources pushed below
      for parent in reversed(node.src): work.append((parent, False))
  return visited
|
122
|
+
|
123
|
+
# returns map of UOps to their children in the graph rooted by self
|
124
|
+
def get_children_map(self) -> dict[UOp, dict[UOp, None]]:
|
125
|
+
ret: dict[UOp, dict[UOp, None]] = {}
|
126
|
+
for u in self.toposort():
|
127
|
+
ret[u] = {}
|
128
|
+
for s in u.src: ret[s][u] = None
|
129
|
+
return ret
|
130
|
+
|
131
|
+
@functools.cached_property
def tuplize(self:UOp) -> tuple:
  """A fully-ordered tuple form of this UOp (recurses into src); useful for deterministic sorting."""
  return (self.op.value, self.arg, self.dtype,)+tuple([x.tuplize for x in self.src])
|
134
|
+
|
135
|
+
# *** uop shape stuff ***

@functools.cached_property
def st(self) -> ShapeTracker|None:
  """The ShapeTracker describing this UOp's view of memory, or None when it has no shape."""
  # these codegen-level ops deliberately have no shape
  if self.op in GroupOp.Block or self.op is Ops.INDEX: return None
  from tinygrad.shape.shapetracker import ShapeTracker
  # VIEW and MovementOps define a new ShapeTracker from the arg
  if self.op is Ops.VIEW: return self.arg
  # allow reshape from nothing
  if self.op is Ops.RESHAPE and self.src[0].st is None: return ShapeTracker.from_shape(self.arg)
  if self.op in GroupOp.Movement: return unwrap(self.src[0].st).mop(self.op, self.arg)
  # CONST with a DEVICE has a shape of ()
  if self.op is Ops.CONST and len(self.src) and self.src[0].op is Ops.DEVICE: return ShapeTracker.from_shape(())
  # BufferOps and ASSIGN flow ShapeTracker from a direct edge
  if self.op in {Ops.STORE, Ops.ASSIGN, Ops.LOAD}: return self.src[0].st
  if self.op in GroupOp.Buffer: return views[0] if (views:=[x.st for x in self.src if x.op is Ops.VIEW]) else None

  # BUFFER/BUFFER_VIEW and KERNEL only have a size
  if self.op in {Ops.BUFFER, Ops.BUFFER_VIEW}: return ShapeTracker.from_shape((self.size,))
  if self.op is Ops.KERNEL: return ShapeTracker.from_shape((self.arg.ast.size,))
  if self.op in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}:
    sz = cast(PtrDType, self.dtype).size
    return ShapeTracker.from_shape((sz,)) if sz > 0 else None

  # hack for PTX, CASTing the ptr loses the shape
  if self.op is Ops.CAST and self.src[0].op is Ops.DEFINE_GLOBAL: return None

  # otherwise we get the shape from sources
  if not (src_sts := [x.st for x in self.src if x.st is not None]): return None
  assert all_same([x.shape for x in src_sts]), f"UOp sources must have the same shape {self} {[x.shape for x in src_sts]}"
  match self.op:
    case Ops.MULTI: shape = tuple(self.src[0].shape[a]*len(self.device) if a == self.axis else s for a,s in enumerate(self.src[0].shape))
    case Ops.BITCAST:
      # bitcasting between dtypes of different widths rescales the last axis
      shape = src_sts[0].shape
      if self.dtype.itemsize != (input_sz:=self.src[0].dtype.itemsize): shape = shape[:-1]+((shape[-1]*input_sz) // self.dtype.itemsize,)
    case Ops.REDUCE_AXIS | Ops.WMMA: shape = src_sts[0].reduce(self.axis_arg)
    case _: shape = src_sts[0].shape
  return ShapeTracker.from_shape(shape)
|
173
|
+
|
174
|
+
@functools.cached_property
def full_shape(self) -> tuple[sint, ...]:
  """Elementwise (symbolic) max of all parents' shapes."""
  if self.op is Ops.VIEW: return self.shape
  # NOTE: if a parent doesn't have st its full_shape is empty
  parent_shapes = [x.full_shape for x in self.src]
  return tuple(smax(x) for x in itertools.zip_longest(*parent_shapes, fillvalue=1))
@property
def shape(self) -> tuple[sint, ...]:
  # shape of this UOp's ShapeTracker; asserts one exists
  assert self.st is not None, f"{self.op} doesn't have a shape"
  return unwrap(self.st).shape
@property
def size(self) -> int:
  # BUFFER_VIEW/BUFFER carry their size in arg; everything else derives it from the ShapeTracker
  return self.arg[0] if self.op is Ops.BUFFER_VIEW else self.arg if self.op is Ops.BUFFER else unwrap(self.st).size
|
186
|
+
|
187
|
+
# determine what ranges this is in
@functools.cached_property
def ranges(self) -> dict[UOp, None]:
  """The set of open RANGE uops this node depends on (dict used as an ordered set)."""
  if self.op is Ops.RANGE: return {self:None}
  # BUFFERIZE/REDUCE close the ranges listed in their trailing sources
  if self.op in {Ops.BUFFERIZE, Ops.REDUCE}:
    ret = self.src[0].ranges.copy()
    for s in self.src[1:]:
      if s in ret: del ret[s]
  elif self.op in {Ops.STORE}:
    # STORE merges buffer and value ranges, then closes the ranges in src[2:]
    ret = self.src[0].ranges.copy()
    ret.update(self.src[1].ranges)
    for s in self.src[2:]:
      if s in ret: del ret[s]
  else:
    # default: union of all sources' ranges
    ret = {}
    for s in self.src: ret.update(s.ranges)
  return ret
|
204
|
+
|
205
|
+
# *** uop evaluation ***

def simplify(self):
  """Rewrite this graph with the symbolic simplification rules."""
  # late import!
  from tinygrad.uop.symbolic import symbolic
  with Context(TRACK_MATCH_STATS=0):
    return graph_rewrite(self, symbolic)
def ssimplify(self) -> UOp|ConstType:
  # simplify, collapsing to a python constant when the result is a CONST
  return ret.arg if (ret:=self.simplify()).op is Ops.CONST else ret
def _eval(self, dtype, expected_type:Type[T]) -> T:
  """Evaluate to a single python value; raises ValueError when the value range doesn't collapse to one number."""
  assert self.dtype in dtype, f"eval with wrong dtype {self}"
  vmin, vmax = (simple_self:=self.simplify())._min_max
  if vmin != vmax: raise ValueError(f"eval failed to be a single number, range is {vmin} to {vmax} in {simple_self.render()}")
  assert isinstance(vmin, expected_type), f"vmin is wrong dtype {type(vmin)} != {expected_type}"
  return vmin
def __bool__(self): return self._eval((dtypes.bool,), bool)
def __int__(self): return self._eval(dtypes.ints, int)
def __float__(self): return self._eval(dtypes.floats, float)
def substitute(self, dvars:dict[UOp, UOp], name:str|None=None):
  """Replace occurrences of the given UOps throughout the graph (bottom-up rewrite); identity mappings are ignored."""
  dvars = {k:v for k,v in dvars.items() if k is not v}
  if len(dvars) == 0: return self
  # only track match stats when a name for the rewrite is provided
  with Context(TRACK_MATCH_STATS=(0 if name is None else TRACK_MATCH_STATS.value)):
    return graph_rewrite(self, _substitute, dvars, bottom_up=True, name=name)
|
227
|
+
|
228
|
+
# *** uop syntactic sugar ***

@property
def st_arg(self) -> ShapeTracker:
  # the ShapeTracker of a Buffer-group op
  assert self.op in GroupOp.Buffer, f"st_arg called on {self.op}"
  return unwrap(self.st)
@property
def axis_arg(self) -> tuple[int, ...]:
  # the reduced axes of a REDUCE_AXIS (arg[1]) or WMMA (arg[7])
  assert self.op in {Ops.REDUCE_AXIS, Ops.WMMA}, f"axis_arg called on {self.op}"
  ret = self.arg[1] if self.op is Ops.REDUCE_AXIS else self.arg[7]
  assert isinstance(ret, tuple) and all(isinstance(x, int) for x in ret), f"axis_arg trying to return {ret}"
  return ret
# constructors for common ops; None srcs are skipped
def sink(self, *srcs:UOp|None, **kwargs): return UOp(Ops.SINK, dtypes.void, (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
def index(self, *srcs:UOp|None, **kwargs):
  return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
def __getitem__(self, idx): return self.index(idx)
def const_like(self, b:ConstLike):
  # constants can optionally have a DEVICE source
  return UOp.const(self.dtype, b, device=self._device, shape=self.shape if self.st is not None else None)
def broadcast(self, count:int):
  """Replicate a scalar into a vector of length count (no-op for count == 1)."""
  assert self.dtype.count == 1
  if count == 1: return self
  return UOp(Ops.VECTORIZE, self.dtype.vec(count), (self,)*count)
def cast(self, dtype:DType):
  """CAST to dtype; returns self when the dtype already matches."""
  # TODO: we shouldn't have to check for dtype.count == 1 here, but CAST is misused in AMD LLVM
  if dtype.count == 1 and dtype.count != self.dtype.count: dtype = dtype.vec(self.dtype.count)
  if self.dtype == dtype: return self
  return UOp(Ops.CAST, dtype, (self,))
def bitcast(self, dtype:DType): return UOp(Ops.BITCAST, dtype, (self,))
def gep(self, i:tuple[int, ...]|int):
  """Extract element(s) i from a vector; scalar result for a single index."""
  if isinstance(i, tuple) and len(i) == 1: return self.gep(i[0])
  if isinstance(i, int):
    # NOTE: these are just shortcuts to not have to create and fold later
    if self.op is Ops.VECTORIZE: return self.src[i]
    if self.op is Ops.VCONST: return UOp.const(self.dtype.scalar(), self.arg[i])
    if self.op is Ops.CONST: return UOp.const(self.dtype.scalar(), self.arg)
    i = (i,)
  return UOp(Ops.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i)
def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, dtype=kwargs.pop("dtype", self.dtype.base), src=(self,)+src, **kwargs)
def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self, x))
def barrier(self, *src:UOp): return UOp(Ops.BARRIER, src=(self,)+src)
def alu(self, op, *src:UOp, **kwargs):
  """Build an ALU op; comparisons produce bool, everything else keeps the last operand's dtype."""
  out_dtype = (self, *src)[-1].dtype
  if op in {Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
  return UOp(op, out_dtype, (self,)+src, **kwargs)
|
275
|
+
@staticmethod
def const(dtype:DType, b:ConstLike, device:str|tuple[str, ...]|None=None, shape:tuple[sint, ...]|None=None):
  """Create a CONST (or VCONST for tuples) uop, optionally attached to a device and/or broadcast shape."""
  if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
  if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
  ret = UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype))
  if shape is not None:
    from tinygrad.shape.shapetracker import ShapeTracker
    # all-zero strides broadcast the scalar constant to the requested shape
    ret = ret.replace(src=(UOp(Ops.VIEW, dtypes.void, (), ShapeTracker.from_shape(shape, (0,)*len(shape))),))
  if device is not None:
    if shape is not None: ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device).view(unwrap(ret.st)),))
    else: ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device),))
  return ret
@staticmethod
def range(dtype:DType, end:sint, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end),), arg=idx)
def r(self, op:Ops, axis:tuple[int, ...]):
  """REDUCE_AXIS over the given axes, keeping dims (reduced axes become 1 in the output shape)."""
  # size-1 axes are no-ops for a reduce
  axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
  if len(axis) == 0: return self
  # move any non reduce axis before the first reduce axis
  move_early, rest = partition(range(axis[0], len(self.shape)), lambda i: i not in axis and resolve(self.shape[i] != 1))
  permaxis = tuple(range(axis[0])) + tuple(move_early) + tuple(rest)
  ret = self.permute(permaxis)
  new_axis = tuple([x for x in range(axis[0]+len(move_early), len(self.shape)) if resolve(ret.shape[x] != 1)])
  assert len(axis) == len(new_axis)
  ret = UOp(Ops.REDUCE_AXIS, self.dtype, (ret,), (op, new_axis))
  return ret.reshape(tuple([x if i not in axis else 1 for i,x in enumerate(self.shape)]))
def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
def contiguous(self, *args, **kwargs): return UOp(Ops.CONTIGUOUS, dtype=self.dtype, src=(self,)+args, **kwargs)
def contiguous_backward(self): return self.alu(Ops.CONTIGUOUS_BACKWARD)
def bufferize(self, *args, **kwargs): return UOp(Ops.BUFFERIZE, dtype=self.dtype, src=(self,)+args, **kwargs)
def fuse(self): return self.alu(Ops.FUSE)
def allreduce(self, op, device:str|tuple[str, ...]|UOp):
  """ALLREDUCE across the devices of a multi-device UOp."""
  assert isinstance(self.device, tuple), f"allreduce must be on tuple {self.device} isn't"
  return UOp(Ops.ALLREDUCE, self.dtype, (self, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device), op)
|
308
|
+
|
309
|
+
# *** from MultiLazyBuffer ***

def multi(self, axis:int|None):
  """Mark this UOp as sharded along axis across its (tuple) device."""
  assert isinstance(self.device, tuple), f"multi device must be tuple, {self.device} isn't"
  assert axis is not None, "multi None is no longer supported"
  return UOp(Ops.MULTI, self.dtype, (self,), axis)

@property
def bounds(self):
  # per-device (start, end) intervals along the sharded axis
  if self.axis is None: raise RuntimeError("bounds is not defined when axis is None")
  return tuple(itertools.pairwise(itertools.accumulate([self.src[0].shape[self.axis] for _ in self.device], initial=0)))

@functools.cached_property
def axis(self) -> int|None:
  """The sharded axis, tracked through ALU/reduce/movement ops; None when not sharded."""
  if self.op is Ops.MULTI: return self.arg
  # NOTE: they all have to share an axis, we always choose [-1]
  if self.op in GroupOp.ALU: return axes[-1] if (axes := dedup([x.axis for x in self.src if x.axis is not None])) else None
  if len(self.src) == 0: return None
  src_axis = self.src[0].axis
  # reducing over the sharded axis drops the sharding
  if self.op is Ops.REDUCE_AXIS: return None if src_axis is not None and src_axis in self.arg[1] else src_axis
  if self.op is Ops.RESHAPE:
    if src_axis is None: return None
    arg_acc:list[sint] = list(itertools.accumulate(self.arg, operator.mul, initial=1))
    # new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards
    # TODO: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1?
    return len(arg_acc) - arg_acc[::-1].index(prod(self.src[0].shape[:src_axis])) - 1
  if self.op is Ops.PERMUTE: return self.arg.index(src_axis) if src_axis is not None else None
  return src_axis

def _unshard(self, axis:int) -> UOp:
  # pad each shard back to full size at its device's offset (_device_num is a per-device variable)
  bsz, dcount = self.shape[axis], len(self.device)
  dnum = UOp.variable("_device_num", 0, dcount-1)
  return self.pad(tuple((0,0) if a != axis else (bsz*dnum, bsz*(dcount-1) - bsz*dnum) for a in range(len(self.shape))))

def _shard(self, axis:int) -> UOp:
  # shrink to this device's slice of the axis; the axis must divide evenly
  dcount = len(self.device)
  dnum = UOp.variable("_device_num", 0, dcount-1)
  if self.shape[axis] % dcount != 0: raise RuntimeError(f"multi axis uneven: {self.shape[axis]=} {axis=} {dcount=}")
  sz = self.shape[axis] // dcount
  return self.shrink(tuple((0,s) if i != axis else (dnum*sz,dnum*sz+sz) for i,s in enumerate(self.shape)))
def shard(self, devices:tuple[str, ...], axis:int) -> UOp: return self.copy_to_device(devices)._shard(axis).multi(axis)
|
350
|
+
|
351
|
+
# *** from LazyBuffer ***

def copy_to_device(self, device:str|tuple[str, ...]|UOp, arg=None):
  """COPY this UOp to device; arg selects a single shard (via MSELECT) from a multi-device source."""
  assert arg is None or isinstance(self.device, tuple)
  inp = self if arg is None else UOp(Ops.MSELECT, self.dtype, src=(self,), arg=arg)
  return UOp(Ops.COPY, self.dtype, (inp, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device))
def mselect(self, arg:int) -> UOp: return UOp(Ops.MSELECT, self.dtype, (self,), arg)
@property
def metadata(self) -> tuple[Metadata, ...]|None: return all_metadata.get(self, None)  # side-table lookup; see all_metadata
|
360
|
+
|
361
|
+
# *** uop movement ops ***

@property
def base(self) -> UOp:
  """Strip VIEW/MovementOps (and MULTI) to reach the underlying UOp."""
  if (self.op is Ops.VIEW and len(self.src) != 0) or self.op in GroupOp.Movement: return self.src[0].base
  if self.op is Ops.MULTI: return self.src[0].base # MULTI is really a VIEW
  return self
def view(self, new_st:ShapeTracker) -> UOp: return UOp(Ops.VIEW, self.dtype, (self,), new_st)

def _mop(self, op:Ops, arg) -> UOp:
  """Apply a movement op, returning self when the result has the same ShapeTracker."""
  ret = UOp(op, self.dtype, (self,), arg)
  if self.st == ret.st: return self # ignore NOOPs, also check ret.st
  return ret

# forced_reshape skips the NOOP check in _mop
def forced_reshape(self, arg:tuple[sint, ...]): return UOp(Ops.RESHAPE, self.dtype, src=(self,), arg=arg)
def reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg)
def pad(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.PAD, arg)
def expand(self, arg:tuple[sint, ...]): return self._mop(Ops.EXPAND, arg)
def permute(self, arg:tuple[sint, ...]): return self._mop(Ops.PERMUTE, arg)
def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg)
def flip(self, arg:tuple[bool, ...]): return self._mop(Ops.FLIP, arg)
|
382
|
+
|
383
|
+
# *** uop UNIQUE ***

# TODO: use this in Buffer
unique_num = itertools.count(0)
@staticmethod
def unique(): return UOp(Ops.UNIQUE, arg=next(UOp.unique_num))  # a fresh UNIQUE uop per call (monotone counter defeats interning)
|
389
|
+
|
390
|
+
# *** uop Buffer stuff ***
|
391
|
+
|
392
|
+
@staticmethod
|
393
|
+
def new_buffer(device:str|tuple[str, ...], size:int, dtype:DType): return UOp(Ops.BUFFER, dtype, (UOp.unique(), UOp(Ops.DEVICE, arg=device)), size)
|
394
|
+
@property
|
395
|
+
def device(self) -> str|tuple[str, ...]: return cast(str|tuple[str, ...], unwrap(self._device))
|
396
|
+
@functools.cached_property
|
397
|
+
def _device(self) -> str|tuple[str, ...]|None:
|
398
|
+
if self.op is Ops.DEVICE: return self.arg
|
399
|
+
if self.op is Ops.MSELECT:
|
400
|
+
assert isinstance(self.src[0].device, tuple), "mselect must be on tuple device"
|
401
|
+
return self.src[0].device[self.arg]
|
402
|
+
if self.op is Ops.MSTACK: return tuple(cast(str, x.device) for x in self.src)
|
403
|
+
if self.op in {Ops.COPY, Ops.BUFFER, Ops.ALLREDUCE}: return self.src[1].device
|
404
|
+
return next((x._device for x in self.src if x._device is not None), None)
|
405
|
+
@property
|
406
|
+
def buf_uop(self) -> UOp:
|
407
|
+
if self.op is Ops.BUFFER: return self
|
408
|
+
if self.op is Ops.MSELECT: return self.src[0].buf_uop.mselect(self.arg)
|
409
|
+
if self.op is Ops.MSTACK: return UOp(Ops.MSTACK, self.dtype, src=tuple(x.buf_uop for x in self.src))
|
410
|
+
assert self.op is Ops.ASSIGN, f"must be ASSIGN {self.op}"
|
411
|
+
return self.src[0].base
|
412
|
+
@property
|
413
|
+
def buffer(self) -> Buffer|MultiBuffer:
|
414
|
+
from tinygrad.device import Buffer, MultiBuffer
|
415
|
+
if self is not self.base:
|
416
|
+
assert unwrap(self.st).contiguous, "VIEW only works here if it's contiguous"
|
417
|
+
return self.src[0].buffer
|
418
|
+
if self.op is Ops.MSELECT:
|
419
|
+
ret = self.src[0].buffer
|
420
|
+
assert isinstance(ret, MultiBuffer)
|
421
|
+
return ret.bufs[self.arg]
|
422
|
+
if self.op is Ops.MSTACK:
|
423
|
+
ret = MultiBuffer.__new__(MultiBuffer)
|
424
|
+
ret.bufs = [cast(Buffer, x.buffer) for x in self.src]
|
425
|
+
assert all_same([x.size for x in ret.bufs]) and all_same([x.dtype for x in ret.bufs]), "multibuffers mismatch buffers"
|
426
|
+
return ret
|
427
|
+
assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}"
|
428
|
+
if (cret:=buffers.get(self)) is not None: return cret
|
429
|
+
rdtype = self.dtype if isinstance(self.dtype, ImageDType) else self.dtype.base
|
430
|
+
if isinstance(self.device, tuple): ret = MultiBuffer(self.device, self.size, rdtype).ref(1)
|
431
|
+
else: ret = Buffer(self.device, self.size, rdtype).ref(1)
|
432
|
+
buffers[self] = ret
|
433
|
+
return ret
|
434
|
+
@property
def realized(self) -> Buffer|MultiBuffer|None:
  """The allocated buffer behind this op, or None if it is not a BUFFER/MSTACK or not yet allocated."""
  # NOTE: this is used by the JIT to determine which inputs we capture
  return self.buffer if self.op in {Ops.BUFFER, Ops.MSTACK} and self.buffer.is_allocated() else None
|
438
|
+
@property
def is_realized(self) -> bool:
  """True when this op's base has an allocated buffer (for MULTI: when every shard does)."""
  base = self.base
  if base.op is Ops.MULTI:
    return all(shard.base.realized is not None for shard in base.src)
  return base.realized is not None
|
441
|
+
|
442
|
+
# *** uop Variable stuff ***
|
443
|
+
|
444
|
+
@staticmethod
def variable(name:str, min_val:ConstType, max_val:ConstType, dtype:DType=dtypes.int) -> UOp:
  """Create a symbolic Variable: a DEFINE_VAR UOp with arg=(name, min_val, max_val)."""
  assert not isinstance(min_val, UOp) and not isinstance(max_val, UOp), f"can't create Variable {name} with {min_val}/{max_val}"
  return UOp(Ops.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
|
448
|
+
@property
def expr(self):
  """The name of this DEFINE_VAR (first element of its arg tuple)."""
  assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
  return self.arg[0]
|
452
|
+
def bind(self, val:int|UOp):
  """Bind this DEFINE_VAR to a value, returning a BIND UOp; val must lie within the var's declared range."""
  assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
  uval = self.const_like(val) if isinstance(val, int) else val
  assert self.arg[1] <= uval.vmin and uval.vmax <= self.arg[2], f"bind {val} not in range [{self.arg[1]}, {self.arg[2]}]"
  return UOp(Ops.BIND, self.dtype, (self, uval))
|
457
|
+
def unbind(self) -> tuple[Variable, int]:
  """Split a BIND back into its (DEFINE_VAR, bound int value) pair."""
  assert self.op is Ops.BIND and self.src[0].op is Ops.DEFINE_VAR and self.src[1].op is Ops.CONST, f"can't unbind {self}"
  return self.src[0], self.src[1].arg
|
460
|
+
@property
def val(self) -> int:
  """The bound integer value of a BIND."""
  return self.unbind()[1]
|
462
|
+
def vars(self) -> set[UOp]:
  """Variables reachable from this UOp: BIND nodes for bound vars, plus any DEFINE_VARs that are never bound."""
  nodes = list(self.toposort())
  bound = {x for x in nodes if x.op is Ops.BIND and x.src[0].op is Ops.DEFINE_VAR}
  bound_bases = {x.src[0] for x in bound}
  unbound = {x for x in nodes if x.op is Ops.DEFINE_VAR and x not in bound_bases}
  return bound | unbound
|
467
|
+
def variables(self) -> list[Variable]:
  """All DEFINE_VARs in this graph (including those inside VIEW shapetrackers), unbound, sorted by name."""
  st_vars: list[set[Variable]] = [x.arg.vars() for x in self.toposort() if x.op is Ops.VIEW]
  return sorted(set.union(*st_vars, set([x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()])), key=lambda v: v.arg)
|
470
|
+
|
471
|
+
# *** uop symbolic stuff ***
|
472
|
+
|
473
|
+
def is_increasing(self:UOp) -> bool:
  # is f a monotonically increasing function regards its input
  if self.op in GroupOp.Irreducible: return True
  if self.op is Ops.ADD: return self.src[0].is_increasing() and self.src[1].is_increasing()
  # MUL/IDIV by a non-negative constant preserves monotonicity
  if self.op in (Ops.MUL, Ops.IDIV) and self.src[1].op is Ops.CONST and self.src[1].arg >= 0: return self.src[0].is_increasing()
  return False  # False if not sure
|
479
|
+
def const_factor(self) -> int:
  """largest known int that divides self"""
  # TODO: for negatives it's not the largest
  if self.op is Ops.CONST: return self.arg
  if self.op is Ops.VCONST: return math.gcd(*self.arg)
  # a sum is divisible by the gcd of its terms' factors
  if self.op is Ops.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
  # a product is divisible by any constant factor on either side
  if self.op is Ops.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1
  return 1
|
487
|
+
def divides(self, v:int) -> UOp|None:
  """Return self exactly divided by v as a UOp, or None if exact division can't be proven."""
  if v==1: return self
  if self.op is Ops.CONST: return self.const_like(self.arg//v) if self.arg%v == 0 else None
  if self.op is Ops.VCONST: return self.const_like(tuple(x//v for x in self.arg)) if all(x%v == 0 for x in self.arg) else None
  # a sum divides only if both terms divide
  if self.op is Ops.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None
  # a product divides if either factor divides
  if self.op is Ops.MUL:
    if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1]
    if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1
  return None  # generic None if we aren't sure
|
496
|
+
# split a trailing "+ const" off an ADD, returning (rest, const); (self, 0) when there is none
def pop_const(self) -> tuple[UOp, int]: return (self.src[0], self.src[1].arg) if self.op is Ops.ADD and self.src[1].op is Ops.CONST else (self, 0)
|
497
|
+
@property
def vmin(self) -> ConstType:
  """Lower bound on this UOp's value (from the cached _min_max analysis)."""
  return self._min_max[0]
|
499
|
+
@property
def vmax(self) -> ConstType:
  """Upper bound on this UOp's value (from the cached _min_max analysis)."""
  return self._min_max[1]
|
501
|
+
@functools.cached_property
def _min_max(self) -> tuple[ConstType, ConstType]:
  """Interval analysis: a conservative (vmin, vmax) bound for this UOp, falling back to the dtype's full range."""
  # integer binary ops: propagate bounds from both sources (floats are skipped because of NaN)
  if self.op in GroupOp.Binary and not dtypes.is_float(self.dtype):
    (s0_vmin, s0_vmax), (s1_vmin, s1_vmax) = self.src[0]._min_max, self.src[1]._min_max
    if self.op is Ops.ADD: return s0_vmin+s1_vmin, s0_vmax+s1_vmax
    if self.op is Ops.SUB: return s0_vmin-s1_vmax, s0_vmax-s1_vmin
    # AND with a known non-negative constant mask can only shrink the value
    if self.op is Ops.AND and s1_vmin == s1_vmax and s0_vmin >= 0 and s1_vmin >= 0: return min(0, s0_vmin), min(s0_vmax, s1_vmax)
    # MUL: extremes are among the four corner products
    if self.op is Ops.MUL: return min(vals:=(s0_vmin*s1_vmin, s0_vmin*s1_vmax, s0_vmax*s1_vmin, s0_vmax*s1_vmax)), max(vals)
    # SHL/SHR on consts only
    if self.op is Ops.SHL and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] << t[2], t[1] << t[2]
    if self.op is Ops.SHR and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] >> t[2], t[1] >> t[2]
    if self.op is Ops.MOD:
      # result sign follows the dividend's sign; magnitude bounded by |divisor|-1
      if s1_vmin > 0: return (0, s1_vmax-1) if s0_vmin >= 0 else (-(s1_vmax-1), 0) if s0_vmax <= 0 else (-(s1_vmax-1), s1_vmax-1)
      if s1_vmax < 0: return (0, -s1_vmin-1) if s0_vmin >= 0 else (-(-s1_vmin-1), 0) if s0_vmax <= 0 else (-(-s1_vmin-1), -s1_vmin-1)
    if self.op is Ops.IDIV:
      assert isinstance(s0_vmin, int) and isinstance(s0_vmax, int) and isinstance(s1_vmin, int) and isinstance(s1_vmax, int)
      if (c:=s1_vmin) == s1_vmax:  # s1 is a const
        if c > 0: return cdiv(s0_vmin, c), cdiv(s0_vmax, c)
        if c < 0: return cdiv(s0_vmax, c), cdiv(s0_vmin, c)
      # sign-definite cases: pick the corner quotients per the signs of dividend and divisor
      if (s0_vmax <= 0 and s1_vmax < 0): return cdiv(s0_vmax, s1_vmin), cdiv(s0_vmin, s1_vmax)
      if (s0_vmin >= 0 and s1_vmin > 0): return cdiv(s0_vmin, s1_vmax), cdiv(s0_vmax, s1_vmin)
      if (s0_vmax <= 0 and s1_vmin > 0): return cdiv(s0_vmin, s1_vmin), cdiv(s0_vmax, s1_vmax)
      if (s0_vmin >= 0 and s1_vmax < 0): return cdiv(s0_vmax, s1_vmax), cdiv(s0_vmin, s1_vmin)
    if self.op is Ops.MAX: return max(s0_vmin, s1_vmin), max(s0_vmax, s1_vmax)
    # comparisons: bounds are booleans (True only when provable from the intervals)
    if self.op is Ops.CMPLT: return (s0_vmax<s1_vmin, s0_vmin<s1_vmax)
    if self.op is Ops.CMPNE: return ((s0_vmax < s1_vmin) or (s1_vmax < s0_vmin), not (s0_vmin == s0_vmax == s1_vmin == s1_vmax))
    if self.dtype == dtypes.bool:
      if self.op is Ops.OR: return s0_vmin or s1_vmin, s0_vmax or s1_vmax
      if self.op is Ops.AND: return s0_vmin and s1_vmin, s0_vmax and s1_vmax
  # float has NAN issue and we use explicit NAN in transcendental
  if self.op is Ops.WHERE and dtypes.is_int(self.dtype): return min(self.src[1].vmin, self.src[2].vmin), max(self.src[1].vmax, self.src[2].vmax)
  # NOTE: returned UOp is assumed to be CONST
  if self.op is Ops.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2]
  if self.op is Ops.RANGE: return 0, (self.src[0]-1).vmax
  if self.op is Ops.BIND: return self.src[0]._min_max  # ignore the bound value
  if self.op in {Ops.UNROLL, Ops.VECTORIZE}: return min(x.vmin for x in self.src), max(x.vmax for x in self.src)
  # TODO: Ops.SPECIAL is Ops.DEFINE_VAR
  if self.op is Ops.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else self.arg[1].vmax
  if self.op is Ops.CONST: return self.arg, self.arg
  if self.op is Ops.VCONST: return (min(self.arg), max(self.arg))
  # TODO: CAST to bool/unsigned is not monotone, still some case can be simplified
  if self.op is Ops.CAST and self.dtype in (dtypes.floats+dtypes.sints):
    return max(dtypes.min(self.dtype), self.src[0].vmin), min(self.src[0].vmax, dtypes.max(self.dtype))
  return dtypes.min(self.dtype), dtypes.max(self.dtype)
|
545
|
+
|
546
|
+
@functools.cached_property
def _sym_fxn(self):
  """Compile this symbolic expression to a fast Python lambda over its variable names; returns (fxn, varnames)."""
  sself = self.simplify()
  varnames = tuple(x.arg[0] for x in sself.toposort() if x.op is Ops.DEFINE_VAR)
  # TODO: sanitize varnames, or don't use naked eval while staying fast
  return eval("lambda "+','.join(varnames)+": "+sself.render(pm=renderer_infer)), varnames  # pylint: disable=eval-used
|
552
|
+
|
553
|
+
def sym_infer(self, var_vals:dict[UOp, int]):
  """Evaluate this symbolic expression with concrete values for its variables (var_vals maps DEFINE_VAR -> int)."""
  fxn, varnames = self._sym_fxn
  # only pass the variables this expression actually uses
  return fxn(**{k.arg[0]:v for k,v in var_vals.items() if k.arg[0] in varnames})
|
556
|
+
|
557
|
+
def render(self, simplify=True, pm:PatternMatcher|None=None) -> str:
  """Render this UOp as a human-readable string using the `renderer` PatternMatcher (or a custom pm)."""
  ret = graph_rewrite(self.simplify() if simplify else self, renderer if pm is None else pm)
  # a full rewrite to NOOP yields the rendered string; otherwise fall back to repr
  return ret.arg if ret.op is Ops.NOOP else str(ret)
|
560
|
+
|
561
|
+
class AxisType(Enum):
  """Classification of a kernel axis: grid dims, plain loops, reduce axes, and upcast/unroll axes."""
  GLOBAL = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702
|
563
|
+
|
564
|
+
@dataclass(frozen=True)
class KernelInfo:
  """Immutable metadata attached to a kernel: its name, axis classification, and optimization state."""
  name: str = "test"  # name of the kernel
  axis_types: tuple[AxisType, ...] = tuple()  # AxisType for each axis of the kernel
  dont_use_locals: bool = False  # don't use local indexing
  applied_opts: tuple = tuple()  # optimizations already applied to this kernel
  opts_to_apply: tuple|None = None  # optimizations still pending, if any
  @property
  def function_name(self):
    # sanitized version of `name` usable as a code-level identifier
    return to_function_name(self.name)
|
573
|
+
|
574
|
+
# ******** ops in python ********
|
575
|
+
|
576
|
+
def safe_exp2(x):
  """Compute 2**x, saturating to +inf instead of raising OverflowError on huge float inputs."""
  try:
    return 2 ** x
  except OverflowError:
    return math.inf
|
579
|
+
|
580
|
+
def safe_pow(x, y):
  """Compute x**y with python semantics but never raise: complex results become NaN,
  division-by-zero and overflow saturate to +/-inf (matching safe_exp2's saturating contract)."""
  try: return math.nan if isinstance(p:=pow(x, y), complex) else p
  except ZeroDivisionError: return math.inf  # e.g. 0 ** -1
  # bugfix: large float bases used to crash with an uncaught OverflowError (e.g. 1e300 ** 2);
  # saturate instead, with the sign a real x**y would have (negative base, odd exponent -> -inf)
  except OverflowError: return math.inf if x >= 0 or y % 2 == 0 else -math.inf
  except ValueError: return math.inf if x > 0 else -math.inf
|
584
|
+
|
585
|
+
# maps each ALU Ops to the plain-Python callable that evaluates it (used by exec_alu for const folding / python execution)
python_alu: dict[Ops, Callable] = {
  Ops.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, Ops.EXP2: safe_exp2,
  Ops.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, Ops.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
  Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, Ops.POW: safe_pow, Ops.TRUNC: math.trunc,
  Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt,
  Ops.XOR: operator.xor, Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift, Ops.MAX: max,
  Ops.MOD: cmod, Ops.IDIV: cdiv, Ops.MULACC: lambda x,y,z: (x*y)+z, Ops.WHERE: lambda x,y,z: y if x else z, Ops.CMPEQ: operator.eq}
|
592
|
+
|
593
|
+
def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
  """Evaluate an ALU op on python values via python_alu; vector dtypes recurse per lane,
  and the result is truncated to the dtype's range unless truncate_output is False."""
  # vector dtype: evaluate each lane independently (non-tuple operands broadcast to every lane)
  if dtype.count > 1:
    lanes = []
    for i in range(dtype.count):
      lanes.append(exec_alu(op, dtype.scalar(), [x[i] if isinstance(x, tuple) else x for x in operands]))
    return tuple(lanes)
  result = python_alu[op](*operands)
  if not truncate_output: return result
  return truncate.get(dtype, lambda x: x)(result)
|
598
|
+
|
599
|
+
# ***** uop helpers *****
|
600
|
+
|
601
|
+
def print_uops(uops:list[UOp]):
  """Debug print a linearized uop list: index, op, dtype, source indices (consts shown inline), and arg."""
  for i,u in enumerate(uops):
    # show each source as its index in the list, the const value for CONSTs, or "--" if not present
    formatted_parents = [(uops.index(x) if x.op is not Ops.CONST else f"{x.arg}") if x in uops else "--" for x in u.src]
    print(f"{i:4d} {str(u.op):20s}: {str(u.dtype):30s} " f"{str(formatted_parents):32s} {u.arg}")
|
605
|
+
|
606
|
+
# ***** pattern matcher *****
|
607
|
+
|
608
|
+
def get_location() -> tuple[str, int]:
  """(filename, lineno) of the nearest caller frame outside ops.py/mathtraits.py."""
  frame = sys._getframe(1)
  def _should_skip(f):
    # skip over ops.py/mathtraits.py (unless there's nothing but ops.py/mathtraits.py)
    return pathlib.Path(f.f_code.co_filename).name in ("ops.py", "mathtraits.py") and \
           f.f_back is not None and not f.f_back.f_code.co_filename.startswith("<frozen")
  while _should_skip(frame):
    frame = frame.f_back
  return frame.f_code.co_filename, frame.f_lineno
|
615
|
+
|
616
|
+
@functools.cache
def lines(fn) -> list[str]:
  """All lines of file fn, memoized so repeated pattern-location lookups don't re-read the file."""
  with open(fn) as f:
    return f.readlines()
|
619
|
+
|
620
|
+
def printable(loc:tuple[str, int]) -> str:
  """The stripped source line at loc=(filename, lineno), or "<missing>" if the file can't be read."""
  try: return lines(loc[0])[loc[1]-1].strip()
  except FileNotFoundError: return "<missing>"
|
623
|
+
|
624
|
+
class UPat(MathTrait):
  """Pattern over UOps: matches on op, dtype, arg, name-binding, and (optionally permuted) sources."""
  __slots__ = ("op", "dtype", "arg", "name", "src")
  def __init__(self, op:Ops|tuple[Ops, ...]|set[Ops]|None=None, dtype:DType|tuple[DType, ...]|None=None,
               src:tuple[UPat, ...]|list[UPat]|UPat|None=None, arg:Any=None,
               name:str|None=None, allow_any_len:bool=False, custom_early_reject:set[Ops]|None=None, location=None):
    assert op is None or isinstance(op, (Ops, tuple, set)), "op must be Ops or tuple of Ops"
    # normalize op/dtype to tuples (None means "match anything")
    self.op: tuple[Ops, ...]|None = (op,) if isinstance(op, Ops) else (tuple(op) if isinstance(op, set) else op)
    self.dtype: tuple[DType, ...]|None = (dtype,) if isinstance(dtype, DType) else dtype
    self.arg, self.name, self._in_src, self.custom_early_reject = arg, name, src, custom_early_reject
    self.src: Any = None
    assert self.name != "ctx", "UPat can't be named ctx"
    assert dtype is None or isinstance(dtype, DType) or all(isinstance(x, DType) for x in dtype), f"invalid dtype {dtype}"

    # try all permutations if it's a list
    if isinstance(src, list): self.src = list(itertools.permutations(src)) if not all_same(src) else [tuple(src)]
    # only one if it's a tuple
    elif isinstance(src, tuple): self.src = [src]
    # repeat if it's a UPat
    elif isinstance(src, UPat): self.src = [itertools.repeat(src)]

    # a single repeated UPat or no src means source count is unconstrained
    self.strict_length = not (allow_any_len or isinstance(src, UPat) or src is None)
    self.required_len: int = 0 if isinstance(src, UPat) or src is None else len(src)
    self.location = location or get_location()

    if custom_early_reject is not None: self.early_reject = custom_early_reject
    else:
      # early_reject: ops that MUST appear among a candidate's sources for this pattern to possibly match
      upat_match = [src] if isinstance(src, UPat) else ([] if src is None else self.src[0])
      self.early_reject = {pp.op[0] for pp in upat_match if pp.op is not None and len(pp.op) == 1}

  def __reduce__(self):
    # pickle support: rebuild from the original constructor arguments
    return UPat, (self.op, self.dtype, self._in_src, self.arg, self.name, not self.strict_length, self.custom_early_reject, self.location)
  def named(self, name:str): return UPat(self.op, self.dtype, self._in_src, self.arg, name, not self.strict_length, self.custom_early_reject)

  @staticmethod
  def any(*src): return UPatAny(src=src)
  def or_casted(self, name:str|None=None): return UPat.any(self if name is None else self.named(name), UPat(Ops.CAST, name=name, src=(self,)))

  @staticmethod
  @functools.cache
  def var(name:str|None=None, dtype:DType|tuple[DType, ...]|None=None): return UPat(dtype=dtype, name=name)
  @staticmethod
  @functools.cache
  def cvar(name:str|None=None, dtype:DType|None=None, vec=True): return UPat((Ops.CONST,Ops.VCONST) if vec else Ops.CONST, dtype, name=name)
  @staticmethod
  def const(dtype:DType|tuple[DType, ...]|None, b:ConstType): return UPat(Ops.CONST, dtype=dtype, arg=b)

  # lil helper
  def f(self, op, **kwargs): return UPat(op, src=(self,), **kwargs)

  # copied from UOp
  def sink(self, *srcs:UPat|None, **kwargs): return UPat(Ops.SINK, dtypes.void, (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
  def index(self, idx:UPat, valid:UPat|None=None): return UPat(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
  def view(self, st=None, **kwargs): return UPat(Ops.VIEW, self.dtype, (self,), st, **kwargs)
  def cast(self, dtype=None, **kwargs): return UPat(Ops.CAST, dtype, (self,), **kwargs)
  def bitcast(self, dtype=None): return UPat(Ops.BITCAST, dtype, (self,))
  def gep(self, i:int|None=None, **kwargs): return UPat(Ops.GEP, None, (self,), (i,) if i is not None else None, **kwargs)
  def load(self, *src:UPat, **kwargs): return UPat(Ops.LOAD, src=(self,)+src, **kwargs)
  def store(self, *src:UPat, **kwargs): return UPat(Ops.STORE, self.dtype, (self,)+src, **kwargs)
  def assign(self, x:UPat, **kwargs): return UPat(Ops.ASSIGN, self.dtype, (self,x), **kwargs)
  def reduce(self, *src:UPat, **kwargs): return UPat(Ops.REDUCE, self.dtype, src=(self,)+src, **kwargs)
  def fuse(self): return self.alu(Ops.FUSE)
  def or_broadcasted(self, **kwargs): return UPat.any(self, UPat(Ops.VECTORIZE, self.dtype, src=self, **kwargs))
  def contiguous(self, *args, **kwargs): return UPat(Ops.CONTIGUOUS, dtype=self.dtype, src=(self,)+args, **kwargs)

  def const_like(self, b:ConstLike): return UPat.const(self.dtype, cast(ConstType, b))
  def alu(self, op:Ops, *src:UPat):
    asrc = (self,)+src
    # commutative ops use a list of sources so all argument orders are tried
    return UPat(op, dtypes.bool if op in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].dtype, list(asrc) if op in GroupOp.Commutative else asrc)

  def __repr__(self):
    def rep(x):
      form = "UPat(%s, %s, name=%s, dtype=%s, allow_any_len=%s, src=%s)"
      return form % (None if x.op is None else ('(%s)'%', '.join(map(str, x.op))), x.arg, repr(x.name),
                     set(x.dtype) if x.dtype else None, not x.strict_length, "[%s]" if x.src and len(x.src)>1 else ("(%s)" if x.src else "%s"))
    return pretty_print(self, rep, srcfn=lambda x:None if x.src is None else [next(x.src[0])] if isinstance(x.src[0], itertools.repeat) else x.src[0])

  def match(self:UPat, uop:UOp, store:dict[str, UOp]) -> list[dict[str, UOp]]:
    """Try to match uop against this pattern; returns the list of name->UOp bindings (empty list = no match)."""
    # reject on op / name-consistency / dtype / arg / source count
    if (self.op is not None and uop.op not in self.op) or \
       (self.name is not None and store.setdefault(self.name, uop) is not uop) or \
       (self.dtype is not None and uop.dtype not in self.dtype and uop.dtype.scalar() not in self.dtype) or \
       (self.arg is not None and self.arg != uop.arg) or \
       (len(uop.src) < self.required_len) or \
       (self.strict_length and len(uop.src) != self.required_len): return []
    if self.src is None: return [store]
    # try each source permutation, threading the binding store through every child match
    res: list[dict[str, UOp]] = []
    for vp in self.src:
      stores, new_stores = [store.copy()], []
      for uu, vv in zip(uop.src, vp):
        for s in stores: new_stores.extend(vv.match(uu, s))
        stores, new_stores = new_stores, []
      res.extend(stores)
    return res
|
716
|
+
|
717
|
+
class UPatAny(UPat):
  """Alternation pattern: matches if ANY of its sub-patterns matches, collecting all bindings."""
  def match(self:UPat, uop:UOp, store:dict[str, UOp]) -> list[dict[str, UOp]]:
    matches = [x.match(uop, store.copy()) for x in self.src[0]]
    return flatten([x for x in matches if x is not None])
|
721
|
+
|
722
|
+
def deconstruct_function(fxn:Callable) -> tuple:
  """Break a function into picklable pieces (code, needed globals, name, defaults) for FunctionType reconstruction."""
  # only keep the globals the function's code (and nested code objects) actually references
  new_globals = {k:v for k,v in fxn.__globals__.items() if k in fxn.__code__.co_names}
  for co in fxn.__code__.co_consts:
    if isinstance(co, types.CodeType): new_globals.update({k:v for k,v in fxn.__globals__.items() if k in co.co_names})
  # NOTE: optional round trip through pickle!
  assert fxn.__closure__ is None, "closures are not supported in pattern matchers"
  ret = fxn.__code__, new_globals, fxn.__name__, fxn.__defaults__
  return pickle.loads(pickle.dumps(ret)) if getenv("TEST_PICKLE") else ret
|
730
|
+
|
731
|
+
@functools.cache
def upat_interpret(p:UPat, fxn:Callable) -> Callable:
  """Build an interpreter-mode matcher: runs p.match and calls fxn with the bindings until one returns non-None."""
  real_fxn = types.FunctionType(*deconstruct_function(fxn))
  # two variants so the no-ctx case avoids passing an unused kwarg
  if 'ctx' in inspect.signature(real_fxn).parameters:
    def universal_match(uop, ctx):
      for match in p.match(uop, {}):
        if (ret:=real_fxn(ctx=ctx, **match)) is not None: return ret  # pylint: disable=not-callable
      return None
  else:
    def universal_match(uop, _):
      for match in p.match(uop, {}):
        if (ret:=real_fxn(**match)) is not None: return ret  # pylint: disable=not-callable
      return None
  return universal_match
|
745
|
+
|
746
|
+
class PatternMatcher:
  """A set of (UPat, rewrite-function) rules, indexed by root op for fast lookup during graph_rewrite."""
  def __init__(self, patterns:Sequence[tuple[UPat, Callable|tuple]], compiled=bool(getenv("UPAT_COMPILE", 1))):
    if compiled: from tinygrad.uop.upat import upat_compile
    # if this comes from a pickle, we reconstruct the lambda functions here
    self.patterns:list[tuple[UPat, Callable]] = [(p,types.FunctionType(*fxn) if isinstance(fxn, tuple) else fxn) for p,fxn in patterns]
    # NOTE: use of DefaultDict here is very dangerous! all keys will live for the lifetime of the PatternMatcher!
    self.pdict: dict[Ops, list[tuple[UPat, Callable, set]]] = {}
    # uop is required, arg is optional
    for p,fxn in self.patterns:
      assert p.op is not None
      # prefer the compiled matcher; fall back to the interpreter if compilation declines
      if compiled and (match:=upat_compile(p, fxn)) is not None: pass  # pylint: disable=E0606
      else: match = upat_interpret(p, fxn)
      for uop in p.op: self.pdict.setdefault(uop, []).append((p, match, p.early_reject))

  def __reduce__(self): return PatternMatcher, ([(x,deconstruct_function(fxn) if fxn.__name__ == "<lambda>" else fxn) for x,fxn in self.patterns],)

  @functools.cache  # pylint: disable=method-cache-max-size-none
  def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns)

  def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
    """Apply the first matching rule to uop; None if no rule fires (a rule returning uop itself doesn't count)."""
    ler = {u.op for u in uop.src}
    for _,match,early_reject in self.pdict.get(uop.op, []):
      # cheap source-op check before running the (expensive) matcher
      if not early_reject.issubset(ler): continue
      if (ret:=match(uop, ctx)) is not None and ret is not uop: return ret
    return None
|
771
|
+
|
772
|
+
# *** non-blocking UOp tracker ***
|
773
|
+
|
774
|
+
# registry assigning a stable int id to every tracked UOp (weak keys so tracking doesn't keep UOps alive)
ucount = itertools.count()
uop_number:weakref.WeakKeyDictionary[UOp, int] = weakref.WeakKeyDictionary()
uop_fields:dict[int, tuple] = {}  # id -> (op, dtype, src ids, arg, tag), survives after the UOp is collected
def track_uop(u:UOp):
  """Register u (and recursively its sources) in the tracker, returning its numeric id."""
  if (cret:=uop_number.get(u)) is not None: return cret
  uop_number[u] = num = next(ucount)
  # KERNEL also has a UOp in the arg
  arg = type(u.arg)(track_uop(u.arg.ast), u.arg.metadata) if u.op is Ops.KERNEL else u.arg
  uop_fields[num] = (u.op, u.dtype, tuple(track_uop(s) for s in u.src), arg, u.tag)
  return num
|
784
|
+
|
785
|
+
# *** tracking pattern matcher ***
|
786
|
+
|
787
|
+
# VIZ=1 enables the rewrite visualizer; it implies match tracking (TRACK_MATCH_STATS=2)
VIZ = ContextVar("VIZ", 0)
TRACK_MATCH_STATS = ContextVar("TRACK_MATCH_STATS", 2 if VIZ else 0)
# per-UPat stats: [match count, attempt count, time in failed attempts, time in successful matches]
match_stats:dict[UPat, list[int|float]] = dict()
|
790
|
+
|
791
|
+
@dataclass(frozen=True)
class TrackedGraphRewrite:
  """Record of one graph_rewrite call, captured for the visualizer / match stats."""
  loc:tuple[str, int]  # location that called graph_rewrite
  sink:int  # the sink input to graph_rewrite
  matches:list[tuple[int, int, tuple]]  # before/after UOp, UPat location
  name:str|None  # optional name of the rewrite
  depth:int  # depth if it's a subrewrite
  bottom_up:bool
|
799
|
+
|
800
|
+
# one TracingKey per @track_rewrites call, with a parallel list of the graph_rewrites it triggered
tracked_keys:list[TracingKey] = []
tracked_ctxs:list[list[TrackedGraphRewrite]] = []
_name_cnt:dict[str, itertools.count] = {}  # per-function counters to number repeated invocations
|
803
|
+
|
804
|
+
# process replay: buffer captured invocations in memory and flush them to the disk cache at exit
if getenv("CAPTURE_PROCESS_REPLAY"):
  replay_capture: dict[str, bytes] = {}
  import atexit
  @atexit.register
  def save_to_diskcache():
    for k,v in replay_capture.items(): diskcache_put("process_replay", k, v, prepickled=True)
|
810
|
+
|
811
|
+
def track_rewrites(name:Callable[..., str|TracingKey]|bool=True, replay:bool=False):
  """Decorator marking a function as a tracked rewrite scope: collects its graph_rewrites for the
  visualizer, profiles it, and optionally captures its inputs/outputs for process replay."""
  def _decorator(func):
    def __wrapper(*args, **kwargs):
      fn = key = func.__name__
      if TRACK_MATCH_STATS >= 2:
        # open a new tracking context named "<func> n<N>"; nested graph_rewrites append to it
        tracked_keys.append(key:=TracingKey(n:=f"{fn} n{next(_name_cnt.setdefault(fn, itertools.count(1)))}", (n,), cat=fn))
        tracked_ctxs.append([])
      with cpu_profile(key, "TINY") as e:
        ret = func(*args, **kwargs)
      if TRACK_MATCH_STATS >= 2 and callable(name):
        # `name` can compute a better display name from the call's args and return value
        name_ret = name(*args, **kwargs, ret=ret)
        assert isinstance(name_ret, (TracingKey, str)), f"name function returned {type(name_ret)}"
        tracked_keys[-1] = k = TracingKey(n:=tracked_keys[-1].display_name.replace(fn, name_ret), (n,)) if isinstance(name_ret, str) else name_ret
        e.name = TracingKey(k.display_name if isinstance(name_ret, str) else f"{fn} for {k.display_name}", k.keys, cat=fn)
      if getenv("CAPTURE_PROCESS_REPLAY") and replay:
        # find the unittest frame we're capturing in
        frm = sys._getframe(1)
        while (f_back:=frm.f_back) is not None and "unittest" not in f_back.f_code.co_filename: frm = f_back
        loc = f"{frm.f_code.co_filename.split('/')[-1]}:{frm.f_lineno} {frm.f_code.co_name}"
        # capture global context vars and all the args passed in
        with Context(PICKLE_BUFFERS=0):
          inputs = (fn, args, kwargs, ContextVar._cache)
          replay_capture[hashlib.sha256(pickle.dumps(inputs)).hexdigest()] = pickle.dumps(inputs+(loc, ret))
      return ret
    return __wrapper
  return _decorator
|
837
|
+
|
838
|
+
# stack of currently-running graph_rewrites, so nested rewrites record their depth
active_rewrites:list[TrackedGraphRewrite] = []
def track_matches(func):
  """Decorator around graph_rewrite that records the call (and its matches) when tracking is on."""
  def _track_func(*args, **kwargs):
    if tracking:=(TRACK_MATCH_STATS >= 2 and tracked_ctxs):
      loc = ((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno)
      depth = len(active_rewrites)
      tracked_ctxs[-1].append(ctx:=TrackedGraphRewrite(loc, track_uop(args[0]), [], kwargs.get("name", None), depth, kwargs.get("bottom_up", False)))
      active_rewrites.append(ctx)
    with cpu_profile(kwargs.get("name", "<unnamed>"), "TINY", display=tracking):
      ret = func(*args, **kwargs)
    if tracking: active_rewrites.pop()
    return ret
  return _track_func
|
851
|
+
|
852
|
+
class TrackedPatternMatcher(PatternMatcher):
  """PatternMatcher that additionally records per-UPat timing/match stats and feeds the visualizer."""
  def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
    ret = None
    ler = {u.op for u in uop.src}
    for p,match,early_reject in self.pdict.get(uop.op, []):
      # stats layout: [matches, attempts, time failed, time matched]
      if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]
      st = time.perf_counter()
      if not early_reject.issubset(ler):
        match_stats[p][2] += time.perf_counter()-st
        continue
      match_stats[p][1] += 1
      try: ret = match(uop, ctx)
      except Exception as e:
        # record the failing rewrite as a REWRITE_ERROR node for the visualizer, then re-raise
        if TRACK_MATCH_STATS >= 2 and active_rewrites and not isinstance(e, RewriteNotReady):
          active_rewrites[-1].matches.append((track_uop(uop), track_uop(UOp(Ops.REWRITE_ERROR, src=uop.src, arg=str(sys.exc_info()[1]))), p.location))
        raise
      if ret is not None and ret is not uop:
        match_stats[p][0] += 1
        match_stats[p][3] += (et:=time.perf_counter()-st)
        if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", printable(p.location))
        if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp) and active_rewrites:
          active_rewrites[-1].matches.append((track_uop(uop), track_uop(ret), p.location))
        return ret
      match_stats[p][2] += time.perf_counter()-st
    return None
|
877
|
+
|
878
|
+
# when tracking/profiling, swap in the tracked matcher globally and dump stats at interpreter exit
if TRACK_MATCH_STATS or PROFILE:
  PatternMatcher = TrackedPatternMatcher  # type: ignore
  import atexit
  @atexit.register
  def print_match_stats():
    if TRACK_MATCH_STATS >= 2:
      # persist the full rewrite trace for the visualizer
      with open(fn:=temp("rewrites.pkl", append_user=True), "wb") as f:
        print(f"rewrote {len(tracked_ctxs)} graphs and matched {sum(len(r.matches) for x in tracked_ctxs for r in x)} times, saved to {fn}")
        pickle.dump((tracked_keys, tracked_ctxs, uop_fields), f)
      if VIZ: launch_viz(VIZ, temp("rewrites.pkl", append_user=True))
    if getenv("PRINT_MATCH_STATS", TRACK_MATCH_STATS.value):
      # print per-pattern stats sorted by total time, then a grand total
      ret = [0,0,0.0,0.0]
      for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]+x[1][3]):
        loc_str = f"{k.location[0].split('/')[-1]}:{k.location[1]}"
        if v[1] != 0: print(f"{v[0]:6d} / {v[1]:7d} -- {v[3]*1000.:9.2f} / {(v[2]+v[3])*1000.:9.2f} ms -- {loc_str:20s}", printable(k.location))
        ret = [x+y for x,y in zip(ret, v)]
      print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[3]*1000.:9.2f} / {(ret[2]+ret[3])*1000.:9.2f} ms -- TOTAL")
      print(f"{len(match_stats)} rules, {sum(v[0] > 0 for v in match_stats.values())} matched once")
|
896
|
+
|
897
|
+
def launch_viz(var:ContextVar, data:str):
  """Re-exec into the viz server, passing the captured data path via environment variables."""
  # disable the var in the child so the server process doesn't recursively re-launch itself
  os.environ[(env_str:=var.key)] = "0"
  os.environ[f"{env_str}_DATA"] = data
  os.environ[f"{env_str}_VALUE"] = str(var.value)
  if not int(os.getenv("VIZ", "0")) and not int(os.getenv("PROFILE", "0")):
    args = ['--kernels', getenv("VIZ_DATA", "")] if getenv("VIZ_DATA", "") else []
    args += ['--profile', getenv("PROFILE_DATA", "")] if getenv("PROFILE_DATA", "") else []
    os.execv(sys.executable, [sys.executable] + [os.path.join(os.path.dirname(__file__), "../", "viz", "serve.py")] + args)
|
905
|
+
|
906
|
+
# *** simple graph rewrite engine ***
|
907
|
+
|
908
|
+
# raised internally when a node's sources haven't been rewritten yet; the node is requeued
class RewriteNotReady(Exception): pass
class RewriteContext:
  """Stack-based engine applying a top-down (pm) and/or bottom-up (bpm) PatternMatcher over a UOp graph."""
  def __init__(self, pm, bpm, ctx=None):
    self.pm: PatternMatcher|None = pm    # top-down matcher (applied after sources are rewritten)
    self.pm_cache: dict[UOp, UOp|None] = {}
    self.bpm: PatternMatcher|None = bpm  # bottom-up matcher (applied before descending into sources)
    self.bpm_cache: dict[UOp, UOp|None] = {}
    self.ctx = ctx                       # user context passed to every rewrite function
    self.replace: dict[UOp, UOp] = {}    # final mapping of original node -> fully rewritten node

  def cached_pm_rewrite(self, x:UOp):
    # cache uses False as the "not computed" sentinel since None is a valid result
    if (ret:=self.pm_cache.get(x,False)) is not False: return ret
    ret = self.pm_cache[x] = cast(PatternMatcher, self.pm).rewrite(x, self.ctx)
    return ret

  def cached_bpm_rewrite(self, x:UOp):
    if (ret:=self.bpm_cache.get(x,False)) is not False: return ret
    ret = self.bpm_cache[x] = cast(PatternMatcher, self.bpm).rewrite(x, self.ctx)
    return ret

  def unified_rewrite(self, root:UOp) -> UOp:
    """Rewrite the graph rooted at root to a fixed point; stack entries are (node, stage, working node)."""
    stack: list[tuple[UOp, int, UOp]] = [(root, 0, root)]
    while stack:
      if len(stack) >= 200000: raise RuntimeError("infinite loop in graph_rewrite (stack too big)")
      n, stage, new_n = stack.pop()
      if n in self.replace: continue  # skip any nodes we have seen
      try:
        if stage == 0:
          # if bottom up, we rewrite this node early. in both cases, we add its parents to the stack
          if self.bpm is not None:
            # apply rewrite rules until a fixed point is reached. may return `uop` itself if PatternMatcher doesn't match
            test_n: UOp|None = n
            seen = set()
            while test_n is not None:
              if test_n in seen: raise RuntimeError("infinite loop in fixed_point_rewrite")
              seen.add(test_n)
              new_n, test_n = test_n, self.cached_bpm_rewrite(test_n)
          stack.append((n, 1, new_n))
          for x in reversed(new_n.src): stack.append((x, 0, x))
        elif stage == 1:
          try: new_src = tuple([self.replace[x] for x in new_n.src])
          except KeyError: raise RewriteNotReady  # pylint: disable=raise-missing-from
          if new_src == new_n.src:
            # if top down, do the rewrite. if no rewrite or bottom up, we are done rewriting this node so we add it to the dict
            if self.pm is None or (new_src_n:=self.cached_pm_rewrite(new_n)) is None:
              self.replace[n] = new_n
              continue
          else:
            # if srcs changed from rewrites, construct a new UOp with the new srcs
            new_src_n = UOp(new_n.op, new_n.dtype, new_src, new_n.arg, new_n.tag)
          # trigger a rewrite of new_src_n, then after that rewrite is done, link it back to n
          stack.append((n, 2, new_src_n))
          stack.append((new_src_n, 0, new_src_n))
        else:
          # in stage 2, we link the result of new_n to the result of n
          try: self.replace[n] = self.replace[new_n]
          except KeyError: raise RewriteNotReady  # pylint: disable=raise-missing-from
      except RewriteNotReady:
        # retry this later
        stack.insert(0, (n, stage, new_n))
    return self.replace[root]
|
969
|
+
|
970
|
+
@track_matches
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, bpm=None) -> UOp:
  """Rewrite the graph rooted at `sink` with `pm` and return the (possibly new) root UOp.

  When `bottom_up` is set, `pm` is applied bottom-up; otherwise `pm` runs top-down and
  an optional `bpm` runs bottom-up. `name` is unused here; presumably consumed by
  @track_matches for debug labeling — confirm against the decorator.
  """
  top_matcher = None if bottom_up else pm
  bottom_matcher = pm if bottom_up else bpm
  return RewriteContext(top_matcher, bottom_matcher, ctx).unified_rewrite(sink)
|
974
|
+
|
975
|
+
@track_matches
def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, bpm=None,
                      input_map:dict[UOp, UOp]|None=None, ) -> dict[UOp, UOp]:
  """Like graph_rewrite, but returns a map from every UOp in `sink`'s graph to its rewritten form.

  If `input_map` (old -> pre-rewrite) is given, its keys are also mapped through to the
  final rewritten nodes. `name` is unused here; presumably consumed by @track_matches.
  """
  rewrite_ctx = RewriteContext(None if bottom_up else pm, pm if bottom_up else bpm, ctx)
  nodes = list(sink.toposort())
  # bottom_up visits the graph leaves-first
  if bottom_up: nodes.reverse()
  new_map: dict[UOp, UOp] = {}
  for node in nodes:
    rewritten = rewrite_ctx.unified_rewrite(node)
    new_map[node] = rewritten
    # carry metadata from a replaced node onto its replacement (deduping what's already there)
    if node is not rewritten and node.metadata is not None:
      all_metadata[rewritten] = tuple(dedup(all_metadata.get(rewritten, ()))) + node.metadata
  if input_map is not None:
    for src, dst in input_map.items(): new_map[src] = new_map.get(dst, dst)
  return new_map
|
986
|
+
|
987
|
+
def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp:
  """Promote a symbolic int to a UOp: wrap a Python int as a CONST, pass a UOp through unchanged."""
  if isinstance(x, int): return UOp.const(dtype, x)
  return x
|
988
|
+
|
989
|
+
# substitution matcher: maps any UOp through the ctx dict; returns None (no rewrite) for absent keys
_substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x))])
|
990
|
+
|
991
|
+
# for debug
# infix operator symbols used by the debug `renderer` below to pretty-print binary ALU ops
syms = { Ops.ADD: "+", Ops.SUB: "-", Ops.IDIV: "//", Ops.MOD: "%", Ops.SHL: "<<", Ops.SHR: ">>",
         Ops.MUL: "*", Ops.CMPLT: "<", Ops.CMPNE: "!=", Ops.AND: "&", Ops.OR: "|", Ops.XOR: "^"}
|
994
|
+
# debug renderer: rewrites a UOp expression tree bottom-up into NOOPs whose `arg` is a
# human-readable string. Pattern order matters: specific ops are handled before the
# generic ALU fallback at the end.
renderer = PatternMatcher([
  (UPat((Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg[0])),   # render by variable/special name
  (UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg}")),                 # loop index
  (UPat((Ops.CONST, Ops.VCONST), name="x"), lambda x: UOp(Ops.NOOP, arg=str(x.arg))),       # literal value
  (UPat(Ops.UNROLL, name="x"), lambda x: UOp(Ops.NOOP, arg=f"UNROLL({x.src[0].arg}, {x.arg})")),
  (UPat(Ops.CAST, name="x"), lambda x: UOp(Ops.NOOP, arg=f"({str(x.dtype)[7:]})({x.src[0].arg})")),  # [7:] strips the "dtypes." prefix
  (UPat(Ops.LOAD), lambda: UOp(Ops.NOOP, arg="load")),                                      # loads render opaquely
  (UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),                       # BIND renders as just the bound var
  #(UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}[={x.src[1].arg}]")),
  (UPat(Ops.NEG, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")),
  (UPat(Ops.RECIP, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(1/{x.src[0].arg})")),
  (UPat(Ops.MAX, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")),
  (UPat(Ops.MULACC, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg}+{x.src[2].arg})")),
  (UPat(Ops.WHERE, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[1].arg} if {x.src[0].arg} else {x.src[2].arg})")),
  # generic binary ALU fallback: infix symbol from the `syms` table above
  (UPat(GroupOp.ALU, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}{syms[x.op]}{x.src[1].arg})")),
])
|
1010
|
+
# variant of `renderer` that prints MOD/IDIV as cmod/cdiv calls (presumably C-style
# truncated division semantics — confirm against the cdiv/cmod helpers), then falls
# through to all of `renderer`'s patterns
renderer_infer = PatternMatcher([
  (UPat(Ops.MOD, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"cmod({x.src[0].arg}, {x.src[1].arg})")),
  (UPat(Ops.IDIV, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"cdiv({x.src[0].arg}, {x.src[1].arg})")),
  *renderer.patterns
])
|
1015
|
+
|
1016
|
+
# *** what was symbolic.py ***

# a "symbolic int": either a concrete Python int or a UOp expression tree
sint = int|UOp
# a Variable is just a UOp
Variable = UOp

# anything accepted as a constant: a scalar, a Variable, or a tuple of scalars (vector const)
ConstLike = ConstType|Variable|tuple[ConstType, ...]
|