tinygrad 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74) hide show
  1. tinygrad/__init__.py +11 -6
  2. tinygrad/codegen/kernel.py +308 -175
  3. tinygrad/codegen/linearize.py +95 -0
  4. tinygrad/codegen/lowerer.py +143 -0
  5. tinygrad/codegen/transcendental.py +257 -0
  6. tinygrad/codegen/uopgraph.py +506 -0
  7. tinygrad/device.py +72 -171
  8. tinygrad/dtype.py +122 -47
  9. tinygrad/engine/jit.py +184 -87
  10. tinygrad/{lazy.py → engine/lazy.py} +74 -66
  11. tinygrad/engine/memory.py +51 -0
  12. tinygrad/engine/realize.py +86 -61
  13. tinygrad/engine/schedule.py +366 -317
  14. tinygrad/engine/search.py +58 -47
  15. tinygrad/function.py +59 -58
  16. tinygrad/helpers.py +120 -102
  17. tinygrad/multi.py +82 -78
  18. tinygrad/nn/__init__.py +116 -67
  19. tinygrad/nn/datasets.py +12 -5
  20. tinygrad/nn/optim.py +1 -1
  21. tinygrad/nn/state.py +91 -6
  22. tinygrad/ops.py +1126 -143
  23. tinygrad/renderer/__init__.py +47 -23
  24. tinygrad/renderer/cstyle.py +338 -265
  25. tinygrad/renderer/llvmir.py +125 -143
  26. tinygrad/renderer/ptx.py +225 -0
  27. tinygrad/runtime/autogen/adreno.py +17904 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
  29. tinygrad/runtime/autogen/cuda.py +6 -162
  30. tinygrad/runtime/autogen/io_uring.py +97 -63
  31. tinygrad/runtime/autogen/kfd.py +60 -47
  32. tinygrad/runtime/autogen/kgsl.py +1386 -0
  33. tinygrad/runtime/autogen/libc.py +5462 -0
  34. tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
  35. tinygrad/runtime/autogen/nvrtc.py +579 -0
  36. tinygrad/runtime/autogen/opencl.py +11 -11
  37. tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
  38. tinygrad/runtime/graph/clang.py +3 -3
  39. tinygrad/runtime/graph/cuda.py +11 -15
  40. tinygrad/runtime/graph/hcq.py +120 -107
  41. tinygrad/runtime/graph/metal.py +71 -43
  42. tinygrad/runtime/ops_amd.py +244 -323
  43. tinygrad/runtime/ops_clang.py +12 -5
  44. tinygrad/runtime/ops_cloud.py +220 -0
  45. tinygrad/runtime/ops_cuda.py +42 -99
  46. tinygrad/runtime/ops_disk.py +25 -26
  47. tinygrad/runtime/ops_dsp.py +181 -0
  48. tinygrad/runtime/ops_gpu.py +29 -16
  49. tinygrad/runtime/ops_hip.py +68 -0
  50. tinygrad/runtime/ops_llvm.py +15 -10
  51. tinygrad/runtime/ops_metal.py +147 -64
  52. tinygrad/runtime/ops_nv.py +356 -397
  53. tinygrad/runtime/ops_python.py +78 -79
  54. tinygrad/runtime/ops_qcom.py +405 -0
  55. tinygrad/runtime/support/__init__.py +0 -0
  56. tinygrad/runtime/support/compiler_cuda.py +77 -0
  57. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
  58. tinygrad/runtime/support/elf.py +38 -0
  59. tinygrad/runtime/support/hcq.py +539 -0
  60. tinygrad/shape/shapetracker.py +40 -50
  61. tinygrad/shape/view.py +102 -63
  62. tinygrad/tensor.py +1109 -365
  63. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
  64. tinygrad-0.10.0.dist-info/RECORD +77 -0
  65. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
  66. tinygrad/codegen/linearizer.py +0 -528
  67. tinygrad/codegen/uops.py +0 -451
  68. tinygrad/engine/graph.py +0 -100
  69. tinygrad/renderer/assembly.py +0 -269
  70. tinygrad/shape/symbolic.py +0 -327
  71. tinygrad-0.9.1.dist-info/RECORD +0 -63
  72. /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
  73. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
  74. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
tinygrad/codegen/uops.py DELETED
@@ -1,451 +0,0 @@
1
- from __future__ import annotations
2
- from typing import Iterator, Optional, Tuple, Any, Dict, List, DefaultDict, Set, Callable, Union, cast, TypeVar
3
- import functools, itertools, heapq, math
4
- from collections import defaultdict
5
- from enum import Enum, auto
6
- from dataclasses import dataclass, field
7
- from tinygrad.dtype import ConstType, dtypes, DType
8
- from tinygrad.shape.symbolic import sint, Variable
9
- from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, exec_alu
10
- from tinygrad.helpers import prod, DEBUG, getenv
11
-
12
- # the order of these UOps controls the order of the toposort
13
- class UOps(Enum):
14
- # ops that aren't rendered
15
- SINK = auto(); VAR = auto() # noqa: E702
16
- DEFINE_GLOBAL = auto(); DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # noqa: E702
17
- CONST = auto(); SPECIAL = auto() # noqa: E702
18
- NOOP = auto(); UNMUL = auto(); GEP = auto() # noqa: E702
19
- # math ops
20
- CAST = auto(); BITCAST = auto() # noqa: E702
21
- ALU = auto(); WMMA = auto() # noqa: E702
22
- # memory/assignment ops
23
- LOAD = auto(); STORE = auto(); PHI = auto() # noqa: E702
24
- # control flow ops
25
- BARRIER = auto(); IF = auto(); RANGE = auto() # noqa: E702
26
- # these two are not graph nodes
27
- ENDRANGE = auto(); ENDIF = auto() # noqa: E702
28
-
29
- def ufix(dtype: Optional[DType], x): return UOp.const(dtype, x) if not isinstance(x, UOp) else x
30
- @dataclass(frozen=True, eq=False)
31
- class UOp:
32
- op: UOps
33
- dtype: Optional[DType] = None
34
- src: Tuple[UOp, ...] = tuple()
35
- arg: Any = None
36
- def commutative(self) -> bool:
37
- return self.op is UOps.ALU and self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR}
38
- @functools.cached_property
39
- def cmp_tuple(self):
40
- # NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX
41
- return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg.expr) if self.op is not UOps.ALU else \
42
- (type(self.op), self.op.value), self.dtype, self.src)
43
- def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple
44
- def __repr__(self):
45
- return f"{str(self.op):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.op for x in self.src]):32s} {self.arg}"
46
- def cast(self, dtype=None): return UOp(UOps.CAST, dtype, (self,))
47
- def name(self, name:Optional[str]): return UOp(UOps.VAR, src=(self,), arg=name)
48
- def __neg__(self): return UOp.alu(UnaryOps.NEG, self)
49
- def __add__(self, x): return UOp.alu(BinaryOps.ADD, self, ufix(self.dtype, x))
50
- def __radd__(self, x): return UOp.alu(BinaryOps.ADD, ufix(self.dtype, x), self)
51
- def __sub__(self, x): return UOp.alu(BinaryOps.ADD, self, -ufix(self.dtype, x))
52
- def __mul__(self, x): return UOp.alu(BinaryOps.MUL, self, ufix(self.dtype, x))
53
- def __rmul__(self, x): return UOp.alu(BinaryOps.MUL, ufix(self.dtype, x), self)
54
- def __floordiv__(self, x): return UOp.alu(BinaryOps.IDIV, self, ufix(self.dtype, x))
55
- def __truediv__(self, x): return UOp.alu(BinaryOps.MUL, self, UOp.alu(UnaryOps.RECIP, ufix(self.dtype, x)))
56
- def __mod__(self, x): return UOp.alu(BinaryOps.MOD, self, ufix(self.dtype, x))
57
- def lt(self, x): return UOp.alu(BinaryOps.CMPLT, self, ufix(self.dtype, x))
58
- def ge(self, x): return -self.lt(x)
59
- def max(self, x): return UOp.alu(BinaryOps.MAX, self, x)
60
- def min(self, x): return -UOp.alu(BinaryOps.MAX, -self, -x)
61
- @staticmethod
62
- @functools.lru_cache(maxsize=None)
63
- def const(dtype:Optional[DType], b:ConstType|Variable):
64
- if isinstance(b, Variable): return UOp(UOps.DEFINE_VAR, dtype, (), b)
65
- return UOp(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b)
66
- @staticmethod
67
- def alu(arg, *src:UOp): return UOp(UOps.ALU, dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else src[-1].dtype, src, arg)
68
- @staticmethod
69
- def load(*src:UOp, dtype:Optional[DType]=None, **kwargs): return UOp(UOps.LOAD, dtype, tuple(src)+tuple(kwargs.values()))
70
- @staticmethod
71
- def store(*src:UOp, **kwargs): return UOp(UOps.STORE, None, tuple(src)+tuple(kwargs.values()))
72
- @staticmethod
73
- def var(name:Optional[str]=None, dtype:Optional[DType]=None): return UOp(UOps.VAR, dtype=dtype, arg=name)
74
- @staticmethod
75
- def cvar(name:Optional[str]=None, dtype:Optional[DType]=None): return UOp(UOps.CONST, dtype=dtype).name(name)
76
- @functools.cached_property
77
- def parents(self) -> Set[UOp]: return set.union(set(self.src), *[x.parents for x in self.src])
78
- @property # parents with self
79
- def sparents(self) -> Set[UOp]: return set([self]).union(self.parents)
80
- def vars(self) -> Set[UOp]: return set([x for x in set.union(set([self]), self.parents) if x.op is UOps.DEFINE_VAR])
81
-
82
- def uop_alu_resolve(u:UOp) -> sint:
83
- if u.op is UOps.CONST: return u.arg
84
- if u.op is UOps.DEFINE_VAR: return u.arg
85
- if u.op is UOps.SPECIAL: return u.arg[2]-1
86
- if u.op is UOps.ALU and u.arg is BinaryOps.MUL: return uop_alu_resolve(u.src[0]) * uop_alu_resolve(u.src[1])
87
- if u.op is UOps.ALU and u.arg is BinaryOps.SHL: return uop_alu_resolve(u.src[0]) * (2**cast(int, uop_alu_resolve(u.src[1])))
88
- if u.op is UOps.ALU and u.arg is BinaryOps.ADD: return uop_alu_resolve(u.src[0]) + uop_alu_resolve(u.src[1])
89
- raise RuntimeError(f"ALU resolve fail @ {u.op}")
90
-
91
- # *** simplification logic ***
92
-
93
- @dataclass(frozen=True)
94
- class UPat:
95
- op: Optional[Union[UOps, Set[UOps]]] = None
96
- arg: Any = None
97
- src: Optional[Union[Tuple[UPat, ...], List[UPat], UPat]] = None
98
- name: Optional[str] = None
99
- dtype: Optional[Union[DType, Set[DType]]] = None
100
- allow_len: Set[int] = field(default_factory=set)
101
-
102
- @staticmethod
103
- def compile(u: UOp, name:Optional[str]=None) -> UPat:
104
- if u.op is UOps.VAR: return UPat(name=name or u.arg, dtype=u.dtype) if len(u.src) == 0 else UPat.compile(u.src[0], name or u.arg)
105
- return UPat(u.op, u.arg, (list if u.commutative() else tuple)([UPat.compile(src) for src in u.src]) if u.src != () else None, name, u.dtype)
106
-
107
- T = TypeVar("T")
108
- def __unmatch(m1:Union[T, Set[T]], m2:T) -> bool: return m2 not in m1 if isinstance(m1, set) else m2 != m1
109
-
110
- def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> bool:
111
- if pat.name is not None and store.setdefault(pat.name, uop) is not uop: return False
112
- if pat.arg is not None and __unmatch(pat.arg, uop.arg): return False
113
- if pat.dtype is not None and uop.dtype is not None and __unmatch(pat.dtype, uop.dtype): return False
114
- if pat.op is not None and __unmatch(pat.op, uop.op): return False
115
- if pat.src is None: return True
116
- # only one if it's a tuple
117
- # try all permutations if it's a list
118
- # repeat if it's a UPat
119
- for vp in itertools.permutations(pat.src) if isinstance(pat.src,list) else ([pat.src] if isinstance(pat.src,tuple) else [(pat.src,)*len(uop.src)]):
120
- if len(uop.src) != len(vp) and (len(uop.src) not in pat.allow_len): return False
121
- new_store = store.copy()
122
- if all(_match(uu, vv, new_store) for uu, vv in zip(uop.src, vp)):
123
- store.update(new_store)
124
- return True
125
- return False
126
-
127
- class PatternMatcher:
128
- def __init__(self, patterns:List[Tuple[Union[UPat, UOp], Callable]]):
129
- self.patterns = patterns
130
- self.pdict: DefaultDict[Tuple[UOps, Any], List[Tuple[UPat, Callable]]] = defaultdict(list)
131
- # uop is required, arg is optional
132
- for p,fxn in self.patterns:
133
- if isinstance(p, UOp): p = UPat.compile(p)
134
- assert p.op is not None
135
- if isinstance(p.op, set):
136
- for uop in p.op: self.pdict[(uop, p.arg)].append((p, fxn))
137
- else:
138
- self.pdict[(p.op, p.arg)].append((p, fxn))
139
-
140
- def rewrite(self, uop:UOp) -> Optional[UOp]:
141
- for p,fxn in itertools.chain(self.pdict[(uop.op, uop.arg)], self.pdict[(uop.op, None)]):
142
- store: Dict[str, UOp] = {}
143
- if _match(uop, p, store): return fxn(**store)
144
- return None
145
-
146
- def sum_collapse(phi_input, loop, val1, val2):
147
- for v1,v2 in [(val1, val2), (val2, val1)]:
148
- if loop not in v1.parents:
149
- loop_range = loop.src[1]-loop.src[0]
150
- ret = v1*loop_range.cast(v1.dtype)
151
- return UOp(UOps.PHI, phi_input.dtype, (phi_input, v2))+ret
152
- return None
153
-
154
- def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng):
155
- if not rng.arg[2]: return None # must be a reduce
156
- if mval.arg >= 0 or loop_start.arg != 0:
157
- # TODO: support and test this with other mvals and loop_starts
158
- if DEBUG >= 1: print(f"WARNING, NOT FOLDING: mval:{mval.arg} loop_start:{loop_start.arg}")
159
- return None
160
- comprange = UOp.min(loop_end, UOp.max(UOp.alu(BinaryOps.IDIV, idx-compval-mval, mval) + (loop_end-loop_start), loop_start))
161
- return UOp(UOps.UNMUL, multconst.dtype, (comprange.cast(multconst.dtype) * multconst, loop_end-loop_start))
162
-
163
- # this is symbolic 2.0
164
- constant_folder = PatternMatcher([
165
- # arange loop folding (early)
166
- (UPat(UOps.ALU, TernaryOps.WHERE, src=(UPat(UOps.ALU, BinaryOps.CMPLT, src=(
167
- UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="idx"), UPat(UOps.ALU, BinaryOps.MUL, src=[UPat(UOps.CONST, name="mval"),
168
- UPat(UOps.RANGE, name="rng", src=(UPat(name="loop_start"), UPat(name="loop_end")))])]),
169
- UPat(UOps.CONST, name="compval"))), UPat(UOps.CONST, name="multconst"), UPat(UOps.CONST, 0))), loop_collapse),
170
- (UPat(UOps.ALU, TernaryOps.WHERE, src=(UPat(UOps.ALU, BinaryOps.CMPLT, src=(
171
- UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="idx"), UPat(UOps.ALU, UnaryOps.NEG, src=[
172
- UPat(UOps.RANGE, name="rng", src=(UPat(name="loop_start"), UPat(name="loop_end")))])]),
173
- UPat(UOps.CONST, name="compval"))), UPat(UOps.CONST, name="multconst"), UPat(UOps.CONST, 0))),
174
- lambda **kwargs: loop_collapse(mval=UOp.const(dtypes.int, -1), **kwargs)),
175
- # sum collapse to mul (with possible GEP)
176
- (UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="phi_input", src=[UPat(UOps.CONST), UPat(UOps.RANGE, name="loop")]),
177
- UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
178
- (UPat(UOps.PHI, src=(UPat(UOps.GEP, name="phi_input", src=(UPat(UOps.DEFINE_ACC, src=[UPat(UOps.CONST), UPat(UOps.RANGE, name="loop")]),)),
179
- UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
180
- # deal with UNMUL
181
- (UPat(UOps.ALU, BinaryOps.MUL, [UPat(UOps.CONST, name="c1"), UPat(UOps.UNMUL, src=[UPat(UOps.CONST, name="c2"), UPat(name="v")])]),
182
- lambda c1,c2,v: v if c1.arg == c2.arg else None),
183
- (UOp(UOps.UNMUL, src=(UOp.const(None, 0).name('zero'), UOp.var())), lambda zero: zero),
184
- (UOp(UOps.UNMUL).name('unmul').cast().name('root'), lambda root,unmul: UOp(UOps.UNMUL, root.dtype, (unmul.src[0].cast(root.dtype), unmul.src[1]))),
185
- # max on special can go away (TODO: special should be variable, same thing applies)
186
- (UOp.max(UOp.cvar('c'), UOp(UOps.SPECIAL).name('s')), lambda c,s: c if (s.arg[2]-1) <= c.arg else None),
187
- # const rules
188
- (UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="c"),)), lambda root, c: UOp.const(root.dtype, c.arg)),
189
- (UPat(UOps.CAST, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: UOp.const(root.dtype, c.arg)),
190
- # a phi on a DEFINE_ACC without loops or a CONST is a noop. this is for correctness, not just speed
191
- (UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="acc"), UPat(name="acc"))), lambda acc: UOp.cast(acc.src[0], acc.dtype)),
192
- (UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, src=(UPat(UOps.CONST),)), UPat(name="x"))), lambda x: x),
193
- (UPat(UOps.PHI, src=(UPat(UOps.CONST), UPat(name="x"))), lambda x: x),
194
- # a DEFINE_ACC without inputs is a const + GEP on a const is the const
195
- (UPat(UOps.DEFINE_ACC, name="root", src=(UPat(UOps.CONST),)), lambda root: UOp.cast(root.src[0], root.dtype)),
196
- (UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="x"),)), lambda root,x: UOp.const(root.dtype, x.arg)),
197
- # max -2147483648
198
- (UOp.max(UOp.var('x'), UOp.const(dtypes.int, -2147483648)), lambda x: x),
199
- # bool < False is always false, True < bool is always false
200
- (UOp.var().lt(UOp.const(dtypes.bool, False)), lambda: UOp.const(dtypes.bool, False)),
201
- (UOp.const(dtypes.bool, True).lt(UOp.var()), lambda: UOp.const(dtypes.bool, False)),
202
- # a conditional with the same results either way is a noop, also fold const conditionals
203
- (UOp.alu(TernaryOps.WHERE, UOp.var(), UOp.var("val"), UOp.var("val")), lambda val: val),
204
- (UOp.alu(TernaryOps.WHERE, UOp.cvar('gate'), UOp.var('c0'), UOp.var('c1')), lambda gate, c0, c1: c0 if gate.arg else c1),
205
- # ** constant folding **
206
- (UPat(UOps.ALU, name="root", src=UPat(UOps.CONST)), lambda root: UOp.const(root.dtype, exec_alu(root.arg, root.dtype, [x.arg for x in root.src]))),
207
- # ** self folding **
208
- (-(-UOp.var('x')), lambda x: x), # -(-x) -> x
209
- (UOp.var('x') + 0, lambda x: x), # x+0 -> x
210
- (UOp.var('x') - 0, lambda x: x), # x-0 -> x
211
- (UOp.var('x') * 1, lambda x: x), # x*1 -> x
212
- (UOp.var('x') * -1, lambda x: -x), # x*-1 -> -x
213
- (UOp.var('x') // UOp.var('x'), lambda x: UOp.const(x.dtype, 1)), # x//x -> 1
214
- (UOp.var('x') // 1, lambda x: x), # x//1 -> x
215
- (UOp.var('x') // -1, lambda x: -x), # x//-1 -> -x
216
- (UOp.var('x') / UOp.var('x'), lambda x: UOp.const(x.dtype, 1)), # x/x -> 1
217
- (UOp.var('x') / UOp.cvar('c'), lambda x,c: x*exec_alu(UnaryOps.RECIP, c.dtype, [c.arg])), # x/c -> x*(1/c)
218
- (UOp.var('x', dtype=dtypes.bool).max(UOp.const(dtypes.bool, False)), lambda x: x), # max(x, False) -> x
219
- # ** zero folding **
220
- #x*0 -> 0 or 0*x -> 0
221
- #if x is nan or inf it should render the nan value.
222
- # NOTE: this can be wrong for loaded NaN
223
- (UOp.var('x') * 0, lambda x: UOp.const(x.dtype, float('nan') if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
224
- (UOp.var('x') - UOp.var('x'), lambda x: UOp.const(x.dtype, 0)), # x-x -> 0
225
- # ** load/store folding **
226
- (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.load(UOp.var("buf"), UOp.var("idx"))), lambda buf,idx:UOp(UOps.NOOP)),
227
- # ** two stage add/sub folding **
228
- ((UOp.var('x') + UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, c2.arg]))),
229
- ((UOp.var('x') - UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c2.arg, -c1.arg]))),
230
- # *** rules from symbolic ***
231
- # two stage mul, (x*c1)*c2 = x*(c1*c2)
232
- ((UOp.var("x") * UOp.cvar("c1")) * UOp.cvar("c2"), lambda x,c1,c2: x*UOp.const(x.dtype, exec_alu(BinaryOps.MUL, x.dtype, [c1.arg, c2.arg]))),
233
- # x%1 -> 0
234
- (UOp.var("x") % UOp.const(None, 1), lambda x: UOp.const(x.dtype, 0)),
235
- # (x*c0)+(x*c1) -> x*(c0+c1)
236
- (UOp.var("x") * UOp.cvar("c0") + UOp.var("x") * UOp.cvar("c1"), lambda x,c0,c1: x*exec_alu(BinaryOps.ADD, x.dtype, [c0.arg, c1.arg])),
237
- # (x*c0)+(y*c0) -> (x+y)*c0
238
- #((UOp.var("x") * UOp.cvar("c0")) + (UOp.var("y") * UOp.cvar("c0")), lambda x,y,c0: c0*(x+y)),
239
- # (x*c0)//c0 -> x
240
- ((UOp.var("x") * UOp.cvar("c0")) // UOp.cvar("c0"), lambda x,c0: x if c0.arg != 0 else None),
241
- # (x*x2)/x2 -> x
242
- ((UOp.var("x") * UOp.var("x2")) / UOp.var("x2"), lambda x,x2: x),
243
- # (x//c0)//c1 -> x//(c0*c1)
244
- ((UOp.var("x") // UOp.cvar("c0")) // UOp.cvar("c1"), lambda x,c0,c1: x//UOp.const(x.dtype, exec_alu(BinaryOps.MUL, x.dtype, [c0.arg, c1.arg]))),
245
- # (x/x1)/x2 -> x/(x1*x2)
246
- ((UOp.var("x") / UOp.var("x2")) / UOp.var("x3"), lambda x,x2,x3: x/(x2*x3)),
247
- # c0 + x < c1 -> x < c1 - c0
248
- ((UOp.cvar("c0") + UOp.var("x")).lt(UOp.cvar("c1")),
249
- lambda x,c0,c1: UOp.lt(x, UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, -c0.arg])))),
250
- # (x+x*c0)-> x*(c0+1)
251
- (UOp.var("x") + UOp.var("x") * UOp.cvar("c0"), lambda x,c0: x*UOp.const(x.dtype, c0.arg+1)),
252
- # TODO: can do the invert of this (flip alt/load) when we fix double ops
253
- (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.alu(TernaryOps.WHERE, UOp.var("gate"), UOp.var("alt"), UOp.load(UOp.var("buf"), UOp.var("idx")))),
254
- lambda buf, idx, gate, alt: UOp.store(buf, idx, alt, gate)),
255
- # store float4/float2 directly (remove CAST/GEP)
256
- (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, src=tuple(UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i in range(4)))), UOp.store),
257
- (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, src=tuple(UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i in range(2)))), UOp.store),
258
- # CAST-PHI-GEP -> PHI-CAST
259
- (UPat(UOps.CAST, name="root", src=tuple(UPat(UOps.PHI, src=(UPat(UOps.GEP, i, src=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(4))),
260
- lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1, v2, v3))))),
261
- (UPat(UOps.CAST, name="root", src=tuple(UPat(UOps.PHI, src=(UPat(UOps.GEP, i, src=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(2))),
262
- lambda root, val, v0, v1: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1))))),
263
- # NEG/CMPLT -> CMPLT
264
- (UOp.lt(-UOp.var('x'), UOp.cvar('c', dtypes.int)), lambda c,x: UOp.lt(UOp.const(c.dtype, -c.arg), x)),
265
- # cast NOOP (NOTE: it's str to deal with PtrDType)
266
- (UPat(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
267
- # fold gated LOAD/STORE
268
- (UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(None, 1), UOp.cvar("var")), lambda buf,idx,var: UOp.load(buf, idx, dtype=var.dtype)),
269
- (UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(None, 1), UOp.cvar("var"), UOp.var("barrier")),
270
- lambda buf,idx,var,barrier: UOp.load(buf, idx, barrier, dtype=var.dtype)),
271
- (UOp.load(UOp.var(), UOp.var(), UOp.const(None, 0), UOp.cvar("var")), lambda var: var),
272
- (UOp.load(UOp.var(), UOp.var(), UOp.const(None, 0), UOp.cvar("var"), UOp.var()), lambda var: var),
273
- (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.var("val"), UOp.const(None, 1)), UOp.store),
274
- (UOp.store(UOp.var(), UOp.var(), UOp.var(), UOp.const(None, 0)), lambda: UOp(UOps.NOOP)),
275
- # remove NOOPs from SINK
276
- (UPat(UOps.SINK, name="root"),
277
- lambda root: UOp(UOps.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not UOps.NOOP)) != len(root.src) else None)
278
- ])
279
-
280
- # *** uop graph ***
281
-
282
- def get_children_dfs(u:UOp, children:Dict[UOp, List[UOp]], in_degree:Dict[UOp, int]):
283
- if u in children: return
284
- children[u] = []
285
- for x in u.src:
286
- get_children_dfs(x, children, in_degree)
287
- children[x].append(u)
288
- in_degree[u] = len(u.src)
289
-
290
- def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp:
291
- nodes: Dict[Tuple, UOp] = {}
292
- replace: Dict[UOp, UOp] = {}
293
- def __inner_rewrite(n:UOp) -> UOp:
294
- if n in replace: return replace[n]
295
- replace_source = (n.op, n.dtype, tuple(__inner_rewrite(y) for y in n.src), n.arg)
296
- if found := nodes.get(replace_source): replace[n] = found
297
- else: nodes[replace_source] = replace[n] = __inner_rewrite(new_x) if (new_x := pm.rewrite(x:=UOp(*replace_source))) else x
298
- return replace[n]
299
- return __inner_rewrite(sink)
300
-
301
- class UOpGraph:
302
- def __init__(self, sinks:List[UOp]):
303
- self.sinks: List[UOp] = sinks
304
- # used by linearizer
305
- self._uops: Optional[List[UOp]] = None
306
-
307
- def __iter__(self) -> Iterator[UOp]: return iter(self.uops)
308
- def __getitem__(self, index) -> UOp: return self.uops[index]
309
-
310
- def vars(self) -> List[Variable]: return sorted([x.arg for x in self.uops if x.op is UOps.DEFINE_VAR], key=lambda v: v.expr)
311
- def globals(self) -> List[Tuple[int, bool]]: return [x.arg for x in self.uops if x.op is UOps.DEFINE_GLOBAL]
312
-
313
- @property
314
- def uops(self):
315
- if self._uops is None: self.linearize()
316
- return self._uops
317
-
318
- def graph(self):
319
- from tinygrad.engine.graph import graph_uops
320
- graph_uops(self.uops)
321
-
322
- def print(self):
323
- for i,u in enumerate(self):
324
- print(f"{i:4d} {str(u.op):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str([self.uops.index(x) for x in u.src]):32s} {u.arg}")
325
-
326
- def linearize(self, extra_pm:Optional[PatternMatcher]=None, type_verify=True):
327
- # NOTE: relinearizering should be okay
328
- #assert self._uops is None, "already linearized"
329
- sink = UOp(UOps.SINK, None, tuple(self.sinks))
330
-
331
- # dedup all nodes and do graph rewrite
332
- sink = graph_rewrite(sink, constant_folder)
333
- if extra_pm: sink = graph_rewrite(sink, PatternMatcher(constant_folder.patterns+extra_pm.patterns))
334
-
335
- # filter nodes that don't link to a sink
336
- # BFS toposort
337
- children: Dict[UOp, List[UOp]] = {}
338
- in_degree: Dict[UOp, int] = {}
339
- get_children_dfs(sink, children, in_degree)
340
-
341
- @functools.lru_cache(None)
342
- def get_recursive_children(x:UOp, end:UOps, include_self=False) -> Set[UOp]:
343
- if x.op is UOps.SINK: return set()
344
- return set.union(set((x,)) if include_self else set(), *([get_recursive_children(u, end, True) for u in children[x] if x.op is not end]))
345
-
346
- # scope children impact the toposort and END* insertion
347
- end_for_uop = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.PHI, UOps.ENDRANGE)}
348
- loops, ifs = [x for x in in_degree if x.op is UOps.RANGE], [x for x in in_degree if x.op is UOps.IF]
349
- scope_children = {p:get_recursive_children(p, end_for_uop[p.op][0]) for p in (loops+ifs)[::-1]}
350
-
351
- queue:List[Tuple[int, UOp]] = []
352
- def push(u:UOp):
353
- priority = 0
354
- # prefer uops that are loop children
355
- for l, ss in scope_children.items():
356
- if l.op is UOps.RANGE and u in ss: priority -= l.arg[0]*1000 + l.arg[1]
357
- heapq.heappush(queue, (priority, u))
358
-
359
- for u in children:
360
- if in_degree[u] == 0: push(u)
361
-
362
- if getenv("FUZZ_UOPS", 0):
363
- from test.external.fuzz_uops import fuzz_uops
364
- self.fuzz_paths = fuzz_uops(children, in_degree.copy(), scope_children)
365
-
366
- self._uops = []
367
- while queue:
368
- p,x = heapq.heappop(queue)
369
- if DEBUG >= 7: print(p,x)
370
- if x.op is UOps.DEFINE_ACC and len(x.src) > 1:
371
- idx = min([self._uops.index(l) for l in x.src if l.op is UOps.RANGE])
372
- self._uops.insert(idx, x)
373
- else:
374
- self._uops.append(x)
375
- for u, ss in scope_children.items():
376
- if x in ss:
377
- ss.remove(x)
378
- if len(ss) == 0: self._uops.append(UOp(end_for_uop[u.op][1], None, (u,)))
379
- for u in children[x]:
380
- in_degree[u] -= 1
381
- if in_degree[u] == 0: push(u)
382
-
383
- assert self._uops[-1].op is UOps.SINK, f"didn't end with SINK, ended with {self._uops[-1]}"
384
- self._uops = self._uops[:-1]
385
-
386
- if type_verify: self.type_verify()
387
-
388
- # *** checker functions ***
389
-
390
- def flops_mem(self, ignore_indexing=False) -> Tuple[sint, sint]:
391
- flops: sint = 0
392
- mem: sint = 0
393
- mults: sint = 1
394
- mult_stack = []
395
- dont_count: Set[UOp] = set()
396
- if ignore_indexing:
397
- for u in self.uops:
398
- if u.op is UOps.LOAD:
399
- dont_count = dont_count.union(u.src[1].sparents)
400
- if len(u.src) > 3: dont_count = dont_count.union(u.src[2].sparents)
401
- elif u.op is UOps.STORE:
402
- dont_count = dont_count.union(u.src[1].sparents)
403
- if len(u.src) > 3: dont_count = dont_count.union(u.src[3].sparents)
404
- for u in self.uops:
405
- if u.op is UOps.RANGE:
406
- mult_stack.append(mults)
407
- mults *= uop_alu_resolve(u.src[1])
408
- elif u.op is UOps.ENDRANGE:
409
- mults = mult_stack.pop(-1)
410
- elif u.op is UOps.LOAD:
411
- assert u.dtype is not None
412
- mem += u.dtype.itemsize * mults
413
- elif u.op is UOps.STORE:
414
- assert u.src[2].dtype is not None
415
- mem += u.src[2].dtype.itemsize * mults
416
- elif u.op is UOps.ALU and u not in dont_count:
417
- flops += mults * (2 if u.arg == TernaryOps.MULACC else 1)
418
- elif u.op is UOps.WMMA and u not in dont_count:
419
- assert u.arg[1] is not None
420
- flops += 2 * prod(u.arg[1]) // 32 * mults
421
- return flops, mem
422
-
423
- def type_verify(self):
424
- for u in self.uops:
425
- uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype
426
- if uop in (UOps.CONST, UOps.DEFINE_ACC):
427
- if uop is UOps.DEFINE_ACC:
428
- assert dtype is not None and src[0].dtype == dtype.scalar(), f"type of {src[0].dtype=} must be a scalar {dtype.scalar()}"
429
- arg = src[0].arg
430
- assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
431
- if uop in {UOps.CAST, UOps.BITCAST}: assert arg is None # type is the output type, not an arg
432
- if uop is UOps.LOAD and len(src) > 2 and src[2].op not in {UOps.IF, UOps.BARRIER}: assert src[2].dtype == dtypes.bool
433
- if uop is UOps.STORE and len(src) == 4: assert src[3].dtype == dtypes.bool
434
- if uop is UOps.ALU:
435
- if arg in UnaryOps:
436
- assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
437
- elif arg in (BinaryOps.CMPLT, BinaryOps.CMPNE):
438
- assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
439
- assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
440
- elif arg is BinaryOps.IDIV:
441
- assert dtypes.is_int(src[0].dtype) and dtypes.is_int(src[1].dtype), \
442
- f"input dtype mismatch {dtypes.int} != {src[0].dtype=} != {src[1].dtype=}"
443
- assert dtypes.is_int(dtype), f"{arg} output dtype mismatch {dtype=} != {dtypes.int}"
444
- elif arg in {BinaryOps.SHL, BinaryOps.SHR}:
445
- # the distance to shift isn't typechecked
446
- assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
447
- elif arg in BinaryOps:
448
- assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
449
- elif arg == TernaryOps.WHERE:
450
- assert src[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {src[0].dtype=} != {dtypes.bool}"
451
- assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"
tinygrad/engine/graph.py DELETED
@@ -1,100 +0,0 @@
1
- import os, atexit, functools, contextlib
2
- from collections import defaultdict
3
- from typing import List, Any, DefaultDict, Union
4
- from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LoadOps, BufferOps, TernaryOps, LazyOp
5
- from tinygrad.device import Device
6
- from tinygrad.helpers import GRAPHPATH, DEBUG, GlobalCounters, getenv
7
- from tinygrad.codegen.uops import UOps, UOp, UPat
8
- from tinygrad.shape.symbolic import NumNode
9
- from tinygrad.lazy import LazyBuffer
10
-
11
- with contextlib.suppress(ImportError): import networkx as nx
12
-
13
- # **** debugging and graphing ****
14
-
15
- if DEBUG >= 2:
16
- def print_globalcounters():
17
- if GlobalCounters.time_sum_s == 0: return
18
- print(f"avg: {GlobalCounters.global_ops*1e-9/GlobalCounters.time_sum_s:8.2f} GFLOPS {GlobalCounters.global_mem*1e-9/GlobalCounters.time_sum_s:8.2f} GB/s", # noqa: E501
19
- f"{' '*10}total: {GlobalCounters.kernel_count:5d} kernels {GlobalCounters.global_ops*1e-9:8.2f} GOPS {GlobalCounters.global_mem*1e-9:8.2f} GB {GlobalCounters.time_sum_s*1e3:8.2f} ms") # noqa: E501
20
- atexit.register(print_globalcounters)
21
-
22
- def save_graph(G, fn, opt=""):
23
- print("saving", G, f"to {fn}.svg")
24
- nx.drawing.nx_pydot.write_dot(G, f'{fn}.dot')
25
- os.system(f'dot {opt} -Tsvg {fn}.dot -o {fn}.svg')
26
-
27
- G:Any = None
28
- def init_graph():
29
- global G
30
- if G is not None: return
31
- G = nx.DiGraph()
32
- atexit.register(functools.partial(save_graph, G, GRAPHPATH)) # -Gnslimit=100 can make it finish, but you won't like results
33
-
34
- counts: DefaultDict[type, int] = defaultdict(int)
35
- def nm(x):
36
- if not hasattr(x, 'node_id'):
37
- setattr(x, 'node_id', counts[type(x)])
38
- counts[type(x)] += 1
39
- return x.node_id
40
-
41
- def realized_lazybuffer(lb:'LazyBuffer', num):
42
- init_graph()
43
- G.nodes[nm(lb)]['style'] = '"filled,bold"'
44
- G.nodes[nm(lb)]['fillcolor'] = G.nodes[nm(lb)]['fillcolor'][:-2]
45
- G.nodes[nm(lb)]['label'] = '"' + G.nodes[nm(lb)]["label"].replace('"', '') + f'\nK:{num}"'
46
-
47
- top_colors = {LoadOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#FFA0A0", BinaryOps: "#c0c0c0",
48
- TernaryOps: "#c0c0c0", BufferOps: '#a0a0ff'}
49
- def log_lazybuffer(lb:'LazyBuffer', scheduled=False):
50
- init_graph()
51
- if lb.base.realized is None and lb.base.op is LoadOps.CONST: return
52
- if lb.base != lb:
53
- offset = lb.st.expr_idxs([NumNode(0)] * len(lb.st.shape))[0]
54
- label = f"{lb.st.shape}\n{lb.st.real_strides()}" + (f"\n{offset}" if offset != 0 else "")
55
- G.add_node(nm(lb), style='"filled,dashed"', fillcolor="#80ff8080", color="black", label=label)
56
- G.add_edge(nm(lb.base), nm(lb), color='#00000060')
57
- lb = lb.base
58
- if lb.realized is None:
59
- label_append = []
60
- for idx,x in enumerate(lb.srcs):
61
- if nm(x) not in G.nodes: log_lazybuffer(x)
62
- if x.base.realized is None and x.base.op is LoadOps.CONST:
63
- label_append.append(f"\nCONST{idx} {x.base.arg:g}")
64
- else:
65
- G.add_edge(nm(x), nm(lb), color='#a0a0a0')
66
- label = '"' + \
67
- (str(set(x.shape for x in lb.srcs))+"\n"+str(lb.shape) if lb.op in ReduceOps else str(lb.shape)) + \
68
- (f"\n{lb.dtype.name}" if lb.dtype.name != "float" else "")+f"\n{lb.op}"+(f"\n{lb.arg}" if lb.op in {LoadOps.CONST, UnaryOps.CAST} else "") + \
69
- (f"\n{lb.device}" if lb.device != Device.DEFAULT else "") + ''.join(label_append) + '"'
70
- G.add_node(nm(lb), style='"filled,dashed"', fillcolor=[v for k,v in top_colors.items() if lb.op in k][0] + "80", color="black", label=label)
71
- if scheduled: G.nodes[nm(lb)]['shape'] = 'box'
72
- else:
73
- if nm(lb) not in G.nodes:
74
- # realized but unseen?
75
- G.add_node(nm(lb), label=f'"{str(lb.base.realized)[5:-1].replace(" ", chr(10))}\nb:{nm(lb.realized)}"', style='filled', fillcolor="#f0c08080")
76
-
77
- def _tree(dag:Union[LazyOp, UOp, UPat], cycles, cnt):
78
- cnt[0] += 1
79
- src = dag.src if isinstance(dag.src, (list, tuple)) else [] if dag.src is None else [dag.src]
80
- if len(src) == 0: return [f"━━ {dag.op} {dag.arg}"]
81
- if (lid := id(dag)) in cycles and cycles[lid][1] > (tcnt := getenv("TREE_CYCLE_CNT", 5)) and tcnt >= 0:
82
- return [f"━⬆︎ goto {cycles[id(dag)][0]}: {dag.op}"]
83
- cycles[lid] = (cnt[0], 1 if lid not in cycles else cycles[lid][1]+1)
84
- lines = [f"━┳ {dag.op} {dag.arg}"]
85
- childs = [_tree(c, cycles, cnt) for c in src]
86
- for c in childs[:-1]: lines += [f" ┣{c[0]}"] + [f" ┃{l}" for l in c[1:]]
87
- return lines + [" ┗"+childs[-1][0]] + [" "+l for l in childs[-1][1:]]
88
-
89
- def print_tree(dag:Union[LazyOp, UOp, UPat]): print("\n".join([f"{str(i).rjust(3)} {s}" for i,s in enumerate(_tree(dag, {}, [-1]))]))
90
-
91
- def graph_uops(uops:List[UOp]):
92
- colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.SPECIAL: "#c0c0ff", UOps.CONST: "#e0e0e0",
93
- UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0",
94
- UOps.RANGE: "#c8a0e0", UOps.PHI: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0"}
95
- G = nx.DiGraph()
96
- for u in uops:
97
- if u.op in {UOps.ENDRANGE, UOps.ENDIF}: continue
98
- G.add_node(uops.index(u), label=f"{str(u.op)[5:]}{(' '+str(u.arg).replace(':', '')) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.op, "#ffffff")) # noqa: E501
99
- for v in u.src: G.add_edge(uops.index(v), uops.index(u))
100
- save_graph(G, f'{GRAPHPATH}.uops', '-Grankdir=LR')