tinygrad 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +11 -6
- tinygrad/codegen/kernel.py +308 -175
- tinygrad/codegen/linearize.py +95 -0
- tinygrad/codegen/lowerer.py +143 -0
- tinygrad/codegen/transcendental.py +257 -0
- tinygrad/codegen/uopgraph.py +506 -0
- tinygrad/device.py +72 -171
- tinygrad/dtype.py +122 -47
- tinygrad/engine/jit.py +184 -87
- tinygrad/{lazy.py → engine/lazy.py} +74 -66
- tinygrad/engine/memory.py +51 -0
- tinygrad/engine/realize.py +86 -61
- tinygrad/engine/schedule.py +366 -317
- tinygrad/engine/search.py +58 -47
- tinygrad/function.py +59 -58
- tinygrad/helpers.py +120 -102
- tinygrad/multi.py +82 -78
- tinygrad/nn/__init__.py +116 -67
- tinygrad/nn/datasets.py +12 -5
- tinygrad/nn/optim.py +1 -1
- tinygrad/nn/state.py +91 -6
- tinygrad/ops.py +1126 -143
- tinygrad/renderer/__init__.py +47 -23
- tinygrad/renderer/cstyle.py +338 -265
- tinygrad/renderer/llvmir.py +125 -143
- tinygrad/renderer/ptx.py +225 -0
- tinygrad/runtime/autogen/adreno.py +17904 -0
- tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
- tinygrad/runtime/autogen/cuda.py +6 -162
- tinygrad/runtime/autogen/io_uring.py +97 -63
- tinygrad/runtime/autogen/kfd.py +60 -47
- tinygrad/runtime/autogen/kgsl.py +1386 -0
- tinygrad/runtime/autogen/libc.py +5462 -0
- tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
- tinygrad/runtime/autogen/nvrtc.py +579 -0
- tinygrad/runtime/autogen/opencl.py +11 -11
- tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
- tinygrad/runtime/graph/clang.py +3 -3
- tinygrad/runtime/graph/cuda.py +11 -15
- tinygrad/runtime/graph/hcq.py +120 -107
- tinygrad/runtime/graph/metal.py +71 -43
- tinygrad/runtime/ops_amd.py +244 -323
- tinygrad/runtime/ops_clang.py +12 -5
- tinygrad/runtime/ops_cloud.py +220 -0
- tinygrad/runtime/ops_cuda.py +42 -99
- tinygrad/runtime/ops_disk.py +25 -26
- tinygrad/runtime/ops_dsp.py +181 -0
- tinygrad/runtime/ops_gpu.py +29 -16
- tinygrad/runtime/ops_hip.py +68 -0
- tinygrad/runtime/ops_llvm.py +15 -10
- tinygrad/runtime/ops_metal.py +147 -64
- tinygrad/runtime/ops_nv.py +356 -397
- tinygrad/runtime/ops_python.py +78 -79
- tinygrad/runtime/ops_qcom.py +405 -0
- tinygrad/runtime/support/__init__.py +0 -0
- tinygrad/runtime/support/compiler_cuda.py +77 -0
- tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
- tinygrad/runtime/support/elf.py +38 -0
- tinygrad/runtime/support/hcq.py +539 -0
- tinygrad/shape/shapetracker.py +40 -50
- tinygrad/shape/view.py +102 -63
- tinygrad/tensor.py +1109 -365
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
- tinygrad-0.10.0.dist-info/RECORD +77 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/linearizer.py +0 -528
- tinygrad/codegen/uops.py +0 -451
- tinygrad/engine/graph.py +0 -100
- tinygrad/renderer/assembly.py +0 -269
- tinygrad/shape/symbolic.py +0 -327
- tinygrad-0.9.1.dist-info/RECORD +0 -63
- /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
tinygrad/codegen/uops.py
DELETED
@@ -1,451 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
from typing import Iterator, Optional, Tuple, Any, Dict, List, DefaultDict, Set, Callable, Union, cast, TypeVar
|
3
|
-
import functools, itertools, heapq, math
|
4
|
-
from collections import defaultdict
|
5
|
-
from enum import Enum, auto
|
6
|
-
from dataclasses import dataclass, field
|
7
|
-
from tinygrad.dtype import ConstType, dtypes, DType
|
8
|
-
from tinygrad.shape.symbolic import sint, Variable
|
9
|
-
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, exec_alu
|
10
|
-
from tinygrad.helpers import prod, DEBUG, getenv
|
11
|
-
|
12
|
-
# the order of these UOps controls the order of the toposort
|
13
|
-
class UOps(Enum):
  # Declaration order matters: auto() numbers the members 1..N in the order written, and UOp.cmp_tuple
  # sorts by op.value first, so this order feeds the linearizer's toposort tie-breaking.
  # -- ops that aren't rendered --
  SINK = auto()
  VAR = auto()
  DEFINE_GLOBAL = auto()
  DEFINE_VAR = auto()
  DEFINE_LOCAL = auto()
  DEFINE_ACC = auto()
  CONST = auto()
  SPECIAL = auto()
  NOOP = auto()
  UNMUL = auto()
  GEP = auto()
  # -- math ops --
  CAST = auto()
  BITCAST = auto()
  ALU = auto()
  WMMA = auto()
  # -- memory/assignment ops --
  LOAD = auto()
  STORE = auto()
  PHI = auto()
  # -- control flow ops --
  BARRIER = auto()
  IF = auto()
  RANGE = auto()
  # -- scope terminators: emitted by the linearizer, never built as graph nodes --
  ENDRANGE = auto()
  ENDIF = auto()
|
28
|
-
|
29
|
-
def ufix(dtype: Optional[DType], x):
  """Lift x into the graph: a UOp is returned unchanged, anything else becomes UOp.const(dtype, x)."""
  if isinstance(x, UOp):
    return x
  return UOp.const(dtype, x)
|
30
|
-
@dataclass(frozen=True, eq=False)
class UOp:
  """One node of the micro-op graph: an op kind, an optional result dtype, source nodes and an op-specific arg.

  eq=False keeps identity-based __eq__/__hash__, so two structurally equal nodes are distinct set/dict keys
  (graph_rewrite is what deduplicates them). The arithmetic/comparison helpers build new ALU nodes and lift
  plain Python values to CONST nodes of this node's dtype via ufix.
  """
  op: UOps
  dtype: Optional[DType] = None
  src: Tuple[UOp, ...] = tuple()
  arg: Any = None
  def commutative(self) -> bool:
    """True for ALU nodes whose operands may be swapped; UPat.compile matches their sources in any order."""
    return self.op is UOps.ALU and self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR}
  @functools.cached_property
  def cmp_tuple(self):
    """Sort key behind __lt__; the linearizer's heap falls back to it when priorities tie."""
    # NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX
    # NOTE(review): for ALU nodes the middle element is (type(self.op), self.op.value), which is identical for every
    # ALU op, so the ALU arg never affects ordering. Left unchanged: altering it would reorder generated kernels.
    return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg.expr) if self.op is not UOps.ALU else \
            (type(self.op), self.op.value), self.dtype, self.src)
  def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple
  def __repr__(self):
    return f"{str(self.op):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.op for x in self.src]):32s} {self.arg}"
  def cast(self, dtype=None): return UOp(UOps.CAST, dtype, (self,))
  # wraps self in a VAR node carrying `name`; UPat.compile turns this into a named pattern binding
  def name(self, name:Optional[str]): return UOp(UOps.VAR, src=(self,), arg=name)
  def __neg__(self): return UOp.alu(UnaryOps.NEG, self)
  def __add__(self, x): return UOp.alu(BinaryOps.ADD, self, ufix(self.dtype, x))
  def __radd__(self, x): return UOp.alu(BinaryOps.ADD, ufix(self.dtype, x), self)
  # subtraction is expressed as ADD of the negated operand
  def __sub__(self, x): return UOp.alu(BinaryOps.ADD, self, -ufix(self.dtype, x))
  def __mul__(self, x): return UOp.alu(BinaryOps.MUL, self, ufix(self.dtype, x))
  def __rmul__(self, x): return UOp.alu(BinaryOps.MUL, ufix(self.dtype, x), self)
  def __floordiv__(self, x): return UOp.alu(BinaryOps.IDIV, self, ufix(self.dtype, x))
  # true division is expressed as MUL by RECIP of the divisor
  def __truediv__(self, x): return UOp.alu(BinaryOps.MUL, self, UOp.alu(UnaryOps.RECIP, ufix(self.dtype, x)))
  def __mod__(self, x): return UOp.alu(BinaryOps.MOD, self, ufix(self.dtype, x))
  def lt(self, x): return UOp.alu(BinaryOps.CMPLT, self, ufix(self.dtype, x))
  # NEG of the CMPLT node; presumably relies on NEG acting as logical not for bool (not shown here)
  def ge(self, x): return -self.lt(x)
  # max/min now lift plain Python values through ufix like every other operator; before, u.max(0) / u.min(3)
  # crashed inside alu(), which reads src[-1].dtype from a bare int.
  def max(self, x): return UOp.alu(BinaryOps.MAX, self, ufix(self.dtype, x))
  def min(self, x): return -UOp.alu(BinaryOps.MAX, -self, -ufix(self.dtype, x))
  @staticmethod
  @functools.lru_cache(maxsize=None)
  def const(dtype:Optional[DType], b:ConstType|Variable):
    """Memoized constant constructor: a Variable becomes DEFINE_VAR, anything else a CONST (arg normalized via
    dtypes.as_const when a dtype is given). Equal (dtype, b) pairs return the same UOp object."""
    if isinstance(b, Variable): return UOp(UOps.DEFINE_VAR, dtype, (), b)
    return UOp(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b)
  @staticmethod
  def alu(arg, *src:UOp):
    """Build an ALU node; comparisons produce bool, everything else takes the dtype of the last source."""
    return UOp(UOps.ALU, dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else src[-1].dtype, src, arg)
  @staticmethod
  def load(*src:UOp, dtype:Optional[DType]=None, **kwargs): return UOp(UOps.LOAD, dtype, tuple(src)+tuple(kwargs.values()))
  @staticmethod
  def store(*src:UOp, **kwargs): return UOp(UOps.STORE, None, tuple(src)+tuple(kwargs.values()))
  @staticmethod
  def var(name:Optional[str]=None, dtype:Optional[DType]=None): return UOp(UOps.VAR, dtype=dtype, arg=name)
  @staticmethod
  def cvar(name:Optional[str]=None, dtype:Optional[DType]=None): return UOp(UOps.CONST, dtype=dtype).name(name)
  @functools.cached_property
  def parents(self) -> Set[UOp]:
    """All transitive sources of this node (excluding itself); cached per node."""
    return set.union(set(self.src), *[x.parents for x in self.src])
  @property # parents with self
  def sparents(self) -> Set[UOp]: return set([self]).union(self.parents)
  def vars(self) -> Set[UOp]: return set([x for x in set.union(set([self]), self.parents) if x.op is UOps.DEFINE_VAR])
|
81
|
-
|
82
|
-
def uop_alu_resolve(u:UOp) -> sint:
  """Evaluate a small index expression made of UOps.

  CONST and DEFINE_VAR yield their arg, SPECIAL yields u.arg[2]-1, and ALU MUL / SHL / ADD recurse into both
  sources. Any other node raises RuntimeError.
  """
  op, arg = u.op, u.arg
  if op in (UOps.CONST, UOps.DEFINE_VAR):
    return arg
  if op is UOps.SPECIAL:
    return arg[2]-1
  if op is UOps.ALU:
    if arg is BinaryOps.MUL:
      return uop_alu_resolve(u.src[0]) * uop_alu_resolve(u.src[1])
    if arg is BinaryOps.SHL:
      # a left shift by k is evaluated as multiplication by 2**k
      return uop_alu_resolve(u.src[0]) * (2**cast(int, uop_alu_resolve(u.src[1])))
    if arg is BinaryOps.ADD:
      return uop_alu_resolve(u.src[0]) + uop_alu_resolve(u.src[1])
  raise RuntimeError(f"ALU resolve fail @ {u.op}")
|
90
|
-
|
91
|
-
# *** simplification logic ***
|
92
|
-
|
93
|
-
@dataclass(frozen=True)
class UPat:
  """A structural pattern over UOps, consumed by _match.

  Any field left as None acts as a wildcard. `src` is a tuple (sources matched in order), a list (any permutation
  may match) or a single UPat (applied to every source). `name` binds the matched UOp in the match store, and
  `allow_len` lists source counts tolerated even when they differ from the pattern's length.
  """
  op: Optional[Union[UOps, Set[UOps]]] = None
  arg: Any = None
  src: Optional[Union[Tuple[UPat, ...], List[UPat], UPat]] = None
  name: Optional[str] = None
  dtype: Optional[Union[DType, Set[DType]]] = None
  allow_len: Set[int] = field(default_factory=set)

  @staticmethod
  def compile(u: UOp, name:Optional[str]=None) -> UPat:
    """Translate a template built from UOps into a UPat.

    A VAR with no sources becomes a wildcard named by its arg; a VAR wrapping a node names that node's pattern.
    Sources of commutative ALU nodes become a list, so any operand order matches.
    """
    if u.op is UOps.VAR:
      if len(u.src) == 0:
        return UPat(name=name or u.arg, dtype=u.dtype)
      return UPat.compile(u.src[0], name or u.arg)
    if u.src == ():
      compiled_src = None
    else:
      container = list if u.commutative() else tuple
      compiled_src = container([UPat.compile(child) for child in u.src])
    return UPat(u.op, u.arg, compiled_src, name, u.dtype)
|
106
|
-
|
107
|
-
T = TypeVar("T")
def __unmatch(m1:Union[T, Set[T]], m2:T) -> bool:
  """Return True when m2 does NOT satisfy m1, where m1 is one expected value or a set of acceptable values."""
  if isinstance(m1, set):
    return m2 not in m1
  return m2 != m1
|
109
|
-
|
110
|
-
def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> bool:
  """Return True if `uop` matches `pat`, recording every named sub-match in `store`.

  A name that is already bound must refer to the very same UOp object (identity, not structure). The binding is
  written before the arg/dtype/op checks, so `store` may hold it even when the match fails. Callers discard such
  partial state: the source loop below works on a copy, and PatternMatcher.rewrite uses a fresh dict per rule.
  """
  if pat.name is not None and store.setdefault(pat.name, uop) is not uop: return False
  if pat.arg is not None and __unmatch(pat.arg, uop.arg): return False
  # dtype is only checked when the node has one
  if pat.dtype is not None and uop.dtype is not None and __unmatch(pat.dtype, uop.dtype): return False
  if pat.op is not None and __unmatch(pat.op, uop.op): return False
  if pat.src is None: return True
  # only one if it's a tuple
  # try all permutations if it's a list
  # repeat if it's a UPat
  for vp in itertools.permutations(pat.src) if isinstance(pat.src,list) else ([pat.src] if isinstance(pat.src,tuple) else [(pat.src,)*len(uop.src)]):
    # every candidate has the same length, so a disallowed length mismatch rejects the whole match
    if len(uop.src) != len(vp) and (len(uop.src) not in pat.allow_len): return False
    # try this arrangement on a copy so a failed attempt leaves no stray bindings behind
    new_store = store.copy()
    # zip stops at the shorter side, so with allow_len any extra sources are not checked
    if all(_match(uu, vv, new_store) for uu, vv in zip(uop.src, vp)):
      store.update(new_store)
      return True
  return False
|
126
|
-
|
127
|
-
class PatternMatcher:
  """A rewrite engine over an ordered list of (pattern, callback) rules.

  Patterns may be given as UPat or as UOp templates (compiled with UPat.compile). Rules are bucketed by (op, arg)
  at construction. A pattern whose op is a set is registered under each of its ops.
  """
  def __init__(self, patterns:List[Tuple[Union[UPat, UOp], Callable]]):
    self.patterns = patterns
    self.pdict: DefaultDict[Tuple[UOps, Any], List[Tuple[UPat, Callable]]] = defaultdict(list)
    # uop is required, arg is optional
    for p,fxn in self.patterns:
      if isinstance(p, UOp): p = UPat.compile(p)
      assert p.op is not None
      for uop in (p.op if isinstance(p.op, set) else (p.op,)): self.pdict[(uop, p.arg)].append((p, fxn))

  def rewrite(self, uop:UOp) -> Optional[UOp]:
    """Apply the first rule that matches `uop` and return its callback's result, or None if no rule matches.

    Rules registered for the node's exact arg are tried before wildcard-arg (None) rules, each in insertion order.
    NOTE(review): a matching rule's callback result is returned even when it is None, which ends the search, so a
    declining callback hides later rules that would also match. Kept as-is because changing it alters rewrites.
    """
    # .get instead of indexing the defaultdict: indexing inserted an empty bucket for every distinct (op, arg) ever
    # seen (e.g. one per CONST value), growing pdict without bound.
    exact = self.pdict.get((uop.op, uop.arg), [])
    # when the arg is None both keys name the same bucket; scanning it twice only repeated failed matches
    wildcard = self.pdict.get((uop.op, None), []) if uop.arg is not None else []
    for p,fxn in itertools.chain(exact, wildcard):
      store: Dict[str, UOp] = {}
      if _match(uop, p, store): return fxn(**store)
    return None
|
145
|
-
|
146
|
-
def sum_collapse(phi_input, loop, val1, val2):
  """Hoist a loop-invariant addend out of an accumulating PHI.

  If one of val1/val2 does not have `loop` among its parents, it is multiplied by the loop's range length
  (loop.src[1]-loop.src[0], cast to its dtype) and added after a PHI that only accumulates the other term.
  Returns None when both terms depend on the loop.
  """
  for invariant, carried in ((val1, val2), (val2, val1)):
    if loop in invariant.parents:
      continue
    range_len = loop.src[1] - loop.src[0]
    hoisted = invariant * range_len.cast(invariant.dtype)
    return UOp(UOps.PHI, phi_input.dtype, (phi_input, carried)) + hoisted
  return None
|
153
|
-
|
154
|
-
def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng):
  """Closed-form replacement for the arange pattern where((idx + mval*rng) < compval, multconst, 0) over a reduce loop.

  Only a negative mval with a loop starting at 0 is handled; other cases return None (with a warning at DEBUG>=1).
  The result is an UNMUL node pairing (clamped count * multconst) with the loop length loop_end-loop_start.
  """
  if not rng.arg[2]:
    return None # must be a reduce
  if mval.arg >= 0 or loop_start.arg != 0:
    # TODO: support and test this with other mvals and loop_starts
    if DEBUG >= 1: print(f"WARNING, NOT FOLDING: mval:{mval.arg} loop_start:{loop_start.arg}")
    return None
  quotient = UOp.alu(BinaryOps.IDIV, idx-compval-mval, mval)
  # clamp the count to [loop_start, loop_end]
  clamped = UOp.min(loop_end, UOp.max(quotient + (loop_end-loop_start), loop_start))
  scaled = clamped.cast(multconst.dtype) * multconst
  return UOp(UOps.UNMUL, multconst.dtype, (scaled, loop_end-loop_start))
|
162
|
-
|
163
|
-
# this is symbolic 2.0
# Rule order matters: PatternMatcher.rewrite returns the result of the first matching rule in a bucket (even None),
# so earlier rules shadow later ones with the same (op, arg) key.
constant_folder = PatternMatcher([
  # arange loop folding (early)
  (UPat(UOps.ALU, TernaryOps.WHERE, src=(UPat(UOps.ALU, BinaryOps.CMPLT, src=(
    UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="idx"), UPat(UOps.ALU, BinaryOps.MUL, src=[UPat(UOps.CONST, name="mval"),
      UPat(UOps.RANGE, name="rng", src=(UPat(name="loop_start"), UPat(name="loop_end")))])]),
    UPat(UOps.CONST, name="compval"))), UPat(UOps.CONST, name="multconst"), UPat(UOps.CONST, 0))), loop_collapse),
  # same as above with the range negated instead of multiplied (treated as mval = -1)
  (UPat(UOps.ALU, TernaryOps.WHERE, src=(UPat(UOps.ALU, BinaryOps.CMPLT, src=(
    UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="idx"), UPat(UOps.ALU, UnaryOps.NEG, src=[
      UPat(UOps.RANGE, name="rng", src=(UPat(name="loop_start"), UPat(name="loop_end")))])]),
    UPat(UOps.CONST, name="compval"))), UPat(UOps.CONST, name="multconst"), UPat(UOps.CONST, 0))),
    lambda **kwargs: loop_collapse(mval=UOp.const(dtypes.int, -1), **kwargs)),
  # sum collapse to mul (with possible GEP)
  (UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="phi_input", src=[UPat(UOps.CONST), UPat(UOps.RANGE, name="loop")]),
                       UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
  (UPat(UOps.PHI, src=(UPat(UOps.GEP, name="phi_input", src=(UPat(UOps.DEFINE_ACC, src=[UPat(UOps.CONST), UPat(UOps.RANGE, name="loop")]),)),
                       UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
  # deal with UNMUL
  # c * UNMUL(c2, v) cancels to v only when the constants are equal; otherwise the callback declines (None)
  (UPat(UOps.ALU, BinaryOps.MUL, [UPat(UOps.CONST, name="c1"), UPat(UOps.UNMUL, src=[UPat(UOps.CONST, name="c2"), UPat(name="v")])]),
   lambda c1,c2,v: v if c1.arg == c2.arg else None),
  (UOp(UOps.UNMUL, src=(UOp.const(None, 0).name('zero'), UOp.var())), lambda zero: zero),
  # push a CAST of an UNMUL into the UNMUL's first source
  (UOp(UOps.UNMUL).name('unmul').cast().name('root'), lambda root,unmul: UOp(UOps.UNMUL, root.dtype, (unmul.src[0].cast(root.dtype), unmul.src[1]))),
  # max on special can go away (TODO: special should be variable, same thing applies)
  (UOp.max(UOp.cvar('c'), UOp(UOps.SPECIAL).name('s')), lambda c,s: c if (s.arg[2]-1) <= c.arg else None),
  # const rules
  (UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="c"),)), lambda root, c: UOp.const(root.dtype, c.arg)),
  (UPat(UOps.CAST, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: UOp.const(root.dtype, c.arg)),
  # a phi on a DEFINE_ACC without loops or a CONST is a noop. this is for correctness, not just speed
  # (the first rule needs both PHI sources to be the same DEFINE_ACC object, since a repeated name matches by identity)
  (UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="acc"), UPat(name="acc"))), lambda acc: UOp.cast(acc.src[0], acc.dtype)),
  (UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, src=(UPat(UOps.CONST),)), UPat(name="x"))), lambda x: x),
  (UPat(UOps.PHI, src=(UPat(UOps.CONST), UPat(name="x"))), lambda x: x),
  # a DEFINE_ACC without inputs is a const + GEP on a const is the const
  (UPat(UOps.DEFINE_ACC, name="root", src=(UPat(UOps.CONST),)), lambda root: UOp.cast(root.src[0], root.dtype)),
  # NOTE(review): this GEP-on-CONST rule duplicates the one under "const rules"; the earlier one always wins, so this
  # entry is never reached.
  (UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="x"),)), lambda root,x: UOp.const(root.dtype, x.arg)),
  # max -2147483648
  (UOp.max(UOp.var('x'), UOp.const(dtypes.int, -2147483648)), lambda x: x),
  # bool < False is always false, True < bool is always false
  (UOp.var().lt(UOp.const(dtypes.bool, False)), lambda: UOp.const(dtypes.bool, False)),
  (UOp.const(dtypes.bool, True).lt(UOp.var()), lambda: UOp.const(dtypes.bool, False)),
  # a conditional with the same results either way is a noop, also fold const conditionals
  (UOp.alu(TernaryOps.WHERE, UOp.var(), UOp.var("val"), UOp.var("val")), lambda val: val),
  (UOp.alu(TernaryOps.WHERE, UOp.cvar('gate'), UOp.var('c0'), UOp.var('c1')), lambda gate, c0, c1: c0 if gate.arg else c1),
  # ** constant folding **
  # an ALU whose sources are all CONST is evaluated at compile time with exec_alu
  (UPat(UOps.ALU, name="root", src=UPat(UOps.CONST)), lambda root: UOp.const(root.dtype, exec_alu(root.arg, root.dtype, [x.arg for x in root.src]))),
  # ** self folding **
  (-(-UOp.var('x')), lambda x: x), # -(-x) -> x
  (UOp.var('x') + 0, lambda x: x), # x+0 -> x
  (UOp.var('x') - 0, lambda x: x), # x-0 -> x
  (UOp.var('x') * 1, lambda x: x), # x*1 -> x
  (UOp.var('x') * -1, lambda x: -x), # x*-1 -> -x
  (UOp.var('x') // UOp.var('x'), lambda x: UOp.const(x.dtype, 1)), # x//x -> 1
  (UOp.var('x') // 1, lambda x: x), # x//1 -> x
  (UOp.var('x') // -1, lambda x: -x), # x//-1 -> -x
  (UOp.var('x') / UOp.var('x'), lambda x: UOp.const(x.dtype, 1)), # x/x -> 1
  (UOp.var('x') / UOp.cvar('c'), lambda x,c: x*exec_alu(UnaryOps.RECIP, c.dtype, [c.arg])), # x/c -> x*(1/c)
  (UOp.var('x', dtype=dtypes.bool).max(UOp.const(dtypes.bool, False)), lambda x: x), # max(x, False) -> x
  # ** zero folding **
  #x*0 -> 0 or 0*x -> 0
  #if x is nan or inf it should render the nan value.
  # NOTE: this can be wrong for loaded NaN
  (UOp.var('x') * 0, lambda x: UOp.const(x.dtype, float('nan') if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
  (UOp.var('x') - UOp.var('x'), lambda x: UOp.const(x.dtype, 0)), # x-x -> 0
  # ** load/store folding **
  # storing back the value just loaded from the same buf/idx is a no-op
  (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.load(UOp.var("buf"), UOp.var("idx"))), lambda buf,idx:UOp(UOps.NOOP)),
  # ** two stage add/sub folding **
  ((UOp.var('x') + UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, c2.arg]))),
  ((UOp.var('x') - UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c2.arg, -c1.arg]))),
  # *** rules from symbolic ***
  # two stage mul, (x*c1)*c2 = x*(c1*c2)
  ((UOp.var("x") * UOp.cvar("c1")) * UOp.cvar("c2"), lambda x,c1,c2: x*UOp.const(x.dtype, exec_alu(BinaryOps.MUL, x.dtype, [c1.arg, c2.arg]))),
  # x%1 -> 0
  (UOp.var("x") % UOp.const(None, 1), lambda x: UOp.const(x.dtype, 0)),
  # (x*c0)+(x*c1) -> x*(c0+c1)
  (UOp.var("x") * UOp.cvar("c0") + UOp.var("x") * UOp.cvar("c1"), lambda x,c0,c1: x*exec_alu(BinaryOps.ADD, x.dtype, [c0.arg, c1.arg])),
  # (x*c0)+(y*c0) -> (x+y)*c0
  #((UOp.var("x") * UOp.cvar("c0")) + (UOp.var("y") * UOp.cvar("c0")), lambda x,y,c0: c0*(x+y)),
  # (x*c0)//c0 -> x
  ((UOp.var("x") * UOp.cvar("c0")) // UOp.cvar("c0"), lambda x,c0: x if c0.arg != 0 else None),
  # (x*x2)/x2 -> x
  ((UOp.var("x") * UOp.var("x2")) / UOp.var("x2"), lambda x,x2: x),
  # (x//c0)//c1 -> x//(c0*c1)
  ((UOp.var("x") // UOp.cvar("c0")) // UOp.cvar("c1"), lambda x,c0,c1: x//UOp.const(x.dtype, exec_alu(BinaryOps.MUL, x.dtype, [c0.arg, c1.arg]))),
  # (x/x1)/x2 -> x/(x1*x2)
  ((UOp.var("x") / UOp.var("x2")) / UOp.var("x3"), lambda x,x2,x3: x/(x2*x3)),
  # c0 + x < c1 -> x < c1 - c0
  ((UOp.cvar("c0") + UOp.var("x")).lt(UOp.cvar("c1")),
    lambda x,c0,c1: UOp.lt(x, UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, -c0.arg])))),
  # (x+x*c0)-> x*(c0+1)
  (UOp.var("x") + UOp.var("x") * UOp.cvar("c0"), lambda x,c0: x*UOp.const(x.dtype, c0.arg+1)),
  # TODO: can do the invert of this (flip alt/load) when we fix double ops
  # store(buf, idx, where(gate, alt, load(buf, idx))) -> gated store of alt
  (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.alu(TernaryOps.WHERE, UOp.var("gate"), UOp.var("alt"), UOp.load(UOp.var("buf"), UOp.var("idx")))),
   lambda buf, idx, gate, alt: UOp.store(buf, idx, alt, gate)),
  # store float4/float2 directly (remove CAST/GEP)
  # NOTE: UOp.store is the callback; it appends kwargs.values() in insertion order, so this relies on the match
  # store binding buf, idx, val in that order
  (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, src=tuple(UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i in range(4)))), UOp.store),
  (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, src=tuple(UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i in range(2)))), UOp.store),
  # CAST-PHI-GEP -> PHI-CAST
  (UPat(UOps.CAST, name="root", src=tuple(UPat(UOps.PHI, src=(UPat(UOps.GEP, i, src=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(4))),
    lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1, v2, v3))))),
  (UPat(UOps.CAST, name="root", src=tuple(UPat(UOps.PHI, src=(UPat(UOps.GEP, i, src=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(2))),
    lambda root, val, v0, v1: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1))))),
  # NEG/CMPLT -> CMPLT
  (UOp.lt(-UOp.var('x'), UOp.cvar('c', dtypes.int)), lambda c,x: UOp.lt(UOp.const(c.dtype, -c.arg), x)),
  # cast NOOP (NOTE: it's str to deal with PtrDType)
  (UPat(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
  # fold gated LOAD/STORE
  # a gate of constant 1 drops the gate and alt value; a gate of constant 0 yields the alt value
  (UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(None, 1), UOp.cvar("var")), lambda buf,idx,var: UOp.load(buf, idx, dtype=var.dtype)),
  (UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(None, 1), UOp.cvar("var"), UOp.var("barrier")),
   lambda buf,idx,var,barrier: UOp.load(buf, idx, barrier, dtype=var.dtype)),
  (UOp.load(UOp.var(), UOp.var(), UOp.const(None, 0), UOp.cvar("var")), lambda var: var),
  (UOp.load(UOp.var(), UOp.var(), UOp.const(None, 0), UOp.cvar("var"), UOp.var()), lambda var: var),
  (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.var("val"), UOp.const(None, 1)), UOp.store),
  (UOp.store(UOp.var(), UOp.var(), UOp.var(), UOp.const(None, 0)), lambda: UOp(UOps.NOOP)),
  # remove NOOPs from SINK
  (UPat(UOps.SINK, name="root"),
    lambda root: UOp(UOps.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not UOps.NOOP)) != len(root.src) else None)
])
|
279
|
-
|
280
|
-
# *** uop graph ***
|
281
|
-
|
282
|
-
def get_children_dfs(u:UOp, children:Dict[UOp, List[UOp]], in_degree:Dict[UOp, int]):
  """Depth-first walk from `u` over its sources, filling the reverse edge map and per-node source counts.

  `children[n]` collects every node that lists n as a source (in visit order); `in_degree[n]` is len(n.src).
  Nodes already present in `children` are not revisited.
  """
  if u not in children:
    children[u] = []
    for parent in u.src:
      get_children_dfs(parent, children, in_degree)
      children[parent].append(u)
    in_degree[u] = len(u.src)
|
289
|
-
|
290
|
-
def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp:
  """Return a rewritten, deduplicated copy of the graph rooted at `sink`.

  Nodes are rebuilt bottom-up and hash-consed on (op, dtype, rewritten sources, arg), so structurally identical
  nodes collapse to one object. Each newly built node is offered to pm.rewrite, and a non-None result is itself
  rewritten recursively, so rules keep firing until none applies.
  """
  nodes: Dict[Tuple, UOp] = {}    # structural key -> canonical rewritten node
  replace: Dict[UOp, UOp] = {}    # input node -> its rewritten result (memo)
  def _rewrite_node(n:UOp) -> UOp:
    if n in replace: return replace[n]
    key = (n.op, n.dtype, tuple(_rewrite_node(child) for child in n.src), n.arg)
    canonical = nodes.get(key)
    if canonical:
      replace[n] = canonical
      return canonical
    rebuilt = UOp(*key)
    rewritten = pm.rewrite(rebuilt)
    result = _rewrite_node(rewritten) if rewritten else rebuilt
    nodes[key] = result
    replace[n] = result
    return result
  return _rewrite_node(sink)
|
300
|
-
|
301
|
-
class UOpGraph:
  """Holds a kernel's sink UOps and lazily produces their linearized (toposorted) instruction list."""
  def __init__(self, sinks:List[UOp]):
    self.sinks: List[UOp] = sinks
    # used by linearizer
    # filled by linearize(); None means "not linearized yet"
    self._uops: Optional[List[UOp]] = None

  def __iter__(self) -> Iterator[UOp]: return iter(self.uops)
  def __getitem__(self, index) -> UOp: return self.uops[index]

  # args of all DEFINE_VAR uops, sorted by their .expr
  def vars(self) -> List[Variable]: return sorted([x.arg for x in self.uops if x.op is UOps.DEFINE_VAR], key=lambda v: v.expr)
  # args of all DEFINE_GLOBAL uops, in linearized order
  def globals(self) -> List[Tuple[int, bool]]: return [x.arg for x in self.uops if x.op is UOps.DEFINE_GLOBAL]

  @property
  def uops(self):
    """The linearized uop list; runs linearize() (with its defaults) on first access."""
    if self._uops is None: self.linearize()
    return self._uops

  def graph(self):
    from tinygrad.engine.graph import graph_uops
    graph_uops(self.uops)

  def print(self):
    """Debug listing: index, op, dtype, indices of the sources within self.uops, and arg."""
    # inside this method `print` is the builtin (class scope does not enclose method bodies)
    # NOTE(review): self.uops.index(x) is a linear scan per source, so this is O(n^2) on large kernels
    for i,u in enumerate(self):
      print(f"{i:4d} {str(u.op):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str([self.uops.index(x) for x in u.src]):32s} {u.arg}")

  def linearize(self, extra_pm:Optional[PatternMatcher]=None, type_verify=True):
    """Rewrite the graph and toposort it into self._uops, inserting ENDRANGE/ENDIF scope terminators.

    Steps: wrap the sinks in a SINK node; run constant_folder (plus extra_pm's rules, if given); build the reverse
    edge map; compute each RANGE/IF's scope; then pop nodes from a priority heap whose ties are broken by UOp.__lt__.
    The trailing SINK is stripped and, when type_verify is set, type_verify() runs.
    """
    # NOTE: relinearizering should be okay
    #assert self._uops is None, "already linearized"
    sink = UOp(UOps.SINK, None, tuple(self.sinks))

    # dedup all nodes and do graph rewrite
    sink = graph_rewrite(sink, constant_folder)
    if extra_pm: sink = graph_rewrite(sink, PatternMatcher(constant_folder.patterns+extra_pm.patterns))

    # filter nodes that don't link to a sink
    # BFS toposort
    children: Dict[UOp, List[UOp]] = {}
    in_degree: Dict[UOp, int] = {}
    get_children_dfs(sink, children, in_degree)

    # nodes reachable from x through children; nodes of op `end` are included but not expanded past, and SINK
    # contributes nothing. The cache is local to this call.
    @functools.lru_cache(None)
    def get_recursive_children(x:UOp, end:UOps, include_self=False) -> Set[UOp]:
      if x.op is UOps.SINK: return set()
      return set.union(set((x,)) if include_self else set(), *([get_recursive_children(u, end, True) for u in children[x] if x.op is not end]))

    # scope children impact the toposort and END* insertion
    # RANGE scopes stop at PHI (closed by ENDRANGE); IF scopes stop at STORE (closed by ENDIF)
    end_for_uop = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.PHI, UOps.ENDRANGE)}
    loops, ifs = [x for x in in_degree if x.op is UOps.RANGE], [x for x in in_degree if x.op is UOps.IF]
    scope_children = {p:get_recursive_children(p, end_for_uop[p.op][0]) for p in (loops+ifs)[::-1]}

    queue:List[Tuple[int, UOp]] = []
    def push(u:UOp):
      priority = 0
      # prefer uops that are loop children
      # (min-heap: subtracting l.arg[0]*1000 + l.arg[1] for each enclosing RANGE pops such uops earlier;
      #  the meaning of RANGE's arg fields is not shown here)
      for l, ss in scope_children.items():
        if l.op is UOps.RANGE and u in ss: priority -= l.arg[0]*1000 + l.arg[1]
      heapq.heappush(queue, (priority, u))

    for u in children:
      if in_degree[u] == 0: push(u)

    if getenv("FUZZ_UOPS", 0):
      from test.external.fuzz_uops import fuzz_uops
      self.fuzz_paths = fuzz_uops(children, in_degree.copy(), scope_children)

    self._uops = []
    while queue:
      p,x = heapq.heappop(queue)
      if DEBUG >= 7: print(p,x)
      # a DEFINE_ACC tied to loops is placed before the earliest of its RANGE sources already emitted
      if x.op is UOps.DEFINE_ACC and len(x.src) > 1:
        idx = min([self._uops.index(l) for l in x.src if l.op is UOps.RANGE])
        self._uops.insert(idx, x)
      else:
        self._uops.append(x)
      # once every member of a scope has been emitted, close it with ENDRANGE/ENDIF
      # (ss is the cached set from get_recursive_children; mutating it is safe because the cache is per call)
      for u, ss in scope_children.items():
        if x in ss:
          ss.remove(x)
          if len(ss) == 0: self._uops.append(UOp(end_for_uop[u.op][1], None, (u,)))
      for u in children[x]:
        in_degree[u] -= 1
        if in_degree[u] == 0: push(u)

    assert self._uops[-1].op is UOps.SINK, f"didn't end with SINK, ended with {self._uops[-1]}"
    self._uops = self._uops[:-1]

    if type_verify: self.type_verify()

  # *** checker functions ***

  def flops_mem(self, ignore_indexing=False) -> Tuple[sint, sint]:
    """Estimate (flops, bytes moved) by walking the linearized uops, scaling counts by the enclosing loop sizes.

    With ignore_indexing, ALU/WMMA nodes feeding LOAD/STORE index (src[1]) and gate operands are not counted.
    """
    flops: sint = 0
    mem: sint = 0
    mults: sint = 1
    mult_stack = []
    dont_count: Set[UOp] = set()
    if ignore_indexing:
      for u in self.uops:
        if u.op is UOps.LOAD:
          dont_count = dont_count.union(u.src[1].sparents)
          if len(u.src) > 3: dont_count = dont_count.union(u.src[2].sparents)
        elif u.op is UOps.STORE:
          dont_count = dont_count.union(u.src[1].sparents)
          if len(u.src) > 3: dont_count = dont_count.union(u.src[3].sparents)
    for u in self.uops:
      if u.op is UOps.RANGE:
        mult_stack.append(mults)
        # NOTE(review): multiplies by the loop end (src[1]), not end-start — assumes loops start at 0; confirm
        mults *= uop_alu_resolve(u.src[1])
      elif u.op is UOps.ENDRANGE:
        mults = mult_stack.pop(-1)
      elif u.op is UOps.LOAD:
        assert u.dtype is not None
        mem += u.dtype.itemsize * mults
      elif u.op is UOps.STORE:
        assert u.src[2].dtype is not None
        mem += u.src[2].dtype.itemsize * mults
      elif u.op is UOps.ALU and u not in dont_count:
        flops += mults * (2 if u.arg == TernaryOps.MULACC else 1)
      elif u.op is UOps.WMMA and u not in dont_count:
        assert u.arg[1] is not None
        flops += 2 * prod(u.arg[1]) // 32 * mults
    return flops, mem

  def type_verify(self):
    """Assert dtype invariants over the linearized uops (CONST/DEFINE_ACC arg types, bool gates, ALU operand dtypes)."""
    for u in self.uops:
      uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype
      if uop in (UOps.CONST, UOps.DEFINE_ACC):
        if uop is UOps.DEFINE_ACC:
          assert dtype is not None and src[0].dtype == dtype.scalar(), f"type of {src[0].dtype=} must be a scalar {dtype.scalar()}"
          # a DEFINE_ACC is checked through the arg of its initial CONST source
          arg = src[0].arg
        assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
      if uop in {UOps.CAST, UOps.BITCAST}: assert arg is None # type is the output type, not an arg
      if uop is UOps.LOAD and len(src) > 2 and src[2].op not in {UOps.IF, UOps.BARRIER}: assert src[2].dtype == dtypes.bool
      if uop is UOps.STORE and len(src) == 4: assert src[3].dtype == dtypes.bool
      if uop is UOps.ALU:
        if arg in UnaryOps:
          assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
        elif arg in (BinaryOps.CMPLT, BinaryOps.CMPNE):
          assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
          assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
        elif arg is BinaryOps.IDIV:
          assert dtypes.is_int(src[0].dtype) and dtypes.is_int(src[1].dtype), \
              f"input dtype mismatch {dtypes.int} != {src[0].dtype=} != {src[1].dtype=}"
          assert dtypes.is_int(dtype), f"{arg} output dtype mismatch {dtype=} != {dtypes.int}"
        elif arg in {BinaryOps.SHL, BinaryOps.SHR}:
          # the distance to shift isn't typechecked
          assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
        elif arg in BinaryOps:
          assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
        elif arg == TernaryOps.WHERE:
          assert src[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {src[0].dtype=} != {dtypes.bool}"
          assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"
|
tinygrad/engine/graph.py
DELETED
@@ -1,100 +0,0 @@
|
|
1
|
-
import os, atexit, functools, contextlib
|
2
|
-
from collections import defaultdict
|
3
|
-
from typing import List, Any, DefaultDict, Union
|
4
|
-
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LoadOps, BufferOps, TernaryOps, LazyOp
|
5
|
-
from tinygrad.device import Device
|
6
|
-
from tinygrad.helpers import GRAPHPATH, DEBUG, GlobalCounters, getenv
|
7
|
-
from tinygrad.codegen.uops import UOps, UOp, UPat
|
8
|
-
from tinygrad.shape.symbolic import NumNode
|
9
|
-
from tinygrad.lazy import LazyBuffer
|
10
|
-
|
11
|
-
with contextlib.suppress(ImportError): import networkx as nx
|
12
|
-
|
13
|
-
# **** debugging and graphing ****
|
14
|
-
|
15
|
-
if DEBUG >= 2:
|
16
|
-
def print_globalcounters():
|
17
|
-
if GlobalCounters.time_sum_s == 0: return
|
18
|
-
print(f"avg: {GlobalCounters.global_ops*1e-9/GlobalCounters.time_sum_s:8.2f} GFLOPS {GlobalCounters.global_mem*1e-9/GlobalCounters.time_sum_s:8.2f} GB/s", # noqa: E501
|
19
|
-
f"{' '*10}total: {GlobalCounters.kernel_count:5d} kernels {GlobalCounters.global_ops*1e-9:8.2f} GOPS {GlobalCounters.global_mem*1e-9:8.2f} GB {GlobalCounters.time_sum_s*1e3:8.2f} ms") # noqa: E501
|
20
|
-
atexit.register(print_globalcounters)
|
21
|
-
|
22
|
-
def save_graph(G, fn, opt=""):
|
23
|
-
print("saving", G, f"to {fn}.svg")
|
24
|
-
nx.drawing.nx_pydot.write_dot(G, f'{fn}.dot')
|
25
|
-
os.system(f'dot {opt} -Tsvg {fn}.dot -o {fn}.svg')
|
26
|
-
|
27
|
-
G:Any = None
|
28
|
-
def init_graph():
|
29
|
-
global G
|
30
|
-
if G is not None: return
|
31
|
-
G = nx.DiGraph()
|
32
|
-
atexit.register(functools.partial(save_graph, G, GRAPHPATH)) # -Gnslimit=100 can make it finish, but you won't like results
|
33
|
-
|
34
|
-
counts: DefaultDict[type, int] = defaultdict(int)
|
35
|
-
def nm(x):
|
36
|
-
if not hasattr(x, 'node_id'):
|
37
|
-
setattr(x, 'node_id', counts[type(x)])
|
38
|
-
counts[type(x)] += 1
|
39
|
-
return x.node_id
|
40
|
-
|
41
|
-
def realized_lazybuffer(lb:'LazyBuffer', num):
|
42
|
-
init_graph()
|
43
|
-
G.nodes[nm(lb)]['style'] = '"filled,bold"'
|
44
|
-
G.nodes[nm(lb)]['fillcolor'] = G.nodes[nm(lb)]['fillcolor'][:-2]
|
45
|
-
G.nodes[nm(lb)]['label'] = '"' + G.nodes[nm(lb)]["label"].replace('"', '') + f'\nK:{num}"'
|
46
|
-
|
47
|
-
top_colors = {LoadOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#FFA0A0", BinaryOps: "#c0c0c0",
|
48
|
-
TernaryOps: "#c0c0c0", BufferOps: '#a0a0ff'}
|
49
|
-
def log_lazybuffer(lb:'LazyBuffer', scheduled=False):
|
50
|
-
init_graph()
|
51
|
-
if lb.base.realized is None and lb.base.op is LoadOps.CONST: return
|
52
|
-
if lb.base != lb:
|
53
|
-
offset = lb.st.expr_idxs([NumNode(0)] * len(lb.st.shape))[0]
|
54
|
-
label = f"{lb.st.shape}\n{lb.st.real_strides()}" + (f"\n{offset}" if offset != 0 else "")
|
55
|
-
G.add_node(nm(lb), style='"filled,dashed"', fillcolor="#80ff8080", color="black", label=label)
|
56
|
-
G.add_edge(nm(lb.base), nm(lb), color='#00000060')
|
57
|
-
lb = lb.base
|
58
|
-
if lb.realized is None:
|
59
|
-
label_append = []
|
60
|
-
for idx,x in enumerate(lb.srcs):
|
61
|
-
if nm(x) not in G.nodes: log_lazybuffer(x)
|
62
|
-
if x.base.realized is None and x.base.op is LoadOps.CONST:
|
63
|
-
label_append.append(f"\nCONST{idx} {x.base.arg:g}")
|
64
|
-
else:
|
65
|
-
G.add_edge(nm(x), nm(lb), color='#a0a0a0')
|
66
|
-
label = '"' + \
|
67
|
-
(str(set(x.shape for x in lb.srcs))+"\n"+str(lb.shape) if lb.op in ReduceOps else str(lb.shape)) + \
|
68
|
-
(f"\n{lb.dtype.name}" if lb.dtype.name != "float" else "")+f"\n{lb.op}"+(f"\n{lb.arg}" if lb.op in {LoadOps.CONST, UnaryOps.CAST} else "") + \
|
69
|
-
(f"\n{lb.device}" if lb.device != Device.DEFAULT else "") + ''.join(label_append) + '"'
|
70
|
-
G.add_node(nm(lb), style='"filled,dashed"', fillcolor=[v for k,v in top_colors.items() if lb.op in k][0] + "80", color="black", label=label)
|
71
|
-
if scheduled: G.nodes[nm(lb)]['shape'] = 'box'
|
72
|
-
else:
|
73
|
-
if nm(lb) not in G.nodes:
|
74
|
-
# realized but unseen?
|
75
|
-
G.add_node(nm(lb), label=f'"{str(lb.base.realized)[5:-1].replace(" ", chr(10))}\nb:{nm(lb.realized)}"', style='filled', fillcolor="#f0c08080")
|
76
|
-
|
77
|
-
def _tree(dag:Union[LazyOp, UOp, UPat], cycles, cnt):
|
78
|
-
cnt[0] += 1
|
79
|
-
src = dag.src if isinstance(dag.src, (list, tuple)) else [] if dag.src is None else [dag.src]
|
80
|
-
if len(src) == 0: return [f"━━ {dag.op} {dag.arg}"]
|
81
|
-
if (lid := id(dag)) in cycles and cycles[lid][1] > (tcnt := getenv("TREE_CYCLE_CNT", 5)) and tcnt >= 0:
|
82
|
-
return [f"━⬆︎ goto {cycles[id(dag)][0]}: {dag.op}"]
|
83
|
-
cycles[lid] = (cnt[0], 1 if lid not in cycles else cycles[lid][1]+1)
|
84
|
-
lines = [f"━┳ {dag.op} {dag.arg}"]
|
85
|
-
childs = [_tree(c, cycles, cnt) for c in src]
|
86
|
-
for c in childs[:-1]: lines += [f" ┣{c[0]}"] + [f" ┃{l}" for l in c[1:]]
|
87
|
-
return lines + [" ┗"+childs[-1][0]] + [" "+l for l in childs[-1][1:]]
|
88
|
-
|
89
|
-
def print_tree(dag:Union[LazyOp, UOp, UPat]): print("\n".join([f"{str(i).rjust(3)} {s}" for i,s in enumerate(_tree(dag, {}, [-1]))]))
|
90
|
-
|
91
|
-
def graph_uops(uops:List[UOp]):
|
92
|
-
colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.SPECIAL: "#c0c0ff", UOps.CONST: "#e0e0e0",
|
93
|
-
UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0",
|
94
|
-
UOps.RANGE: "#c8a0e0", UOps.PHI: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0"}
|
95
|
-
G = nx.DiGraph()
|
96
|
-
for u in uops:
|
97
|
-
if u.op in {UOps.ENDRANGE, UOps.ENDIF}: continue
|
98
|
-
G.add_node(uops.index(u), label=f"{str(u.op)[5:]}{(' '+str(u.arg).replace(':', '')) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.op, "#ffffff")) # noqa: E501
|
99
|
-
for v in u.src: G.add_edge(uops.index(v), uops.index(u))
|
100
|
-
save_graph(G, f'{GRAPHPATH}.uops', '-Grankdir=LR')
|