tinygrad 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. tinygrad/__init__.py +6 -6
  2. tinygrad/codegen/kernel.py +230 -190
  3. tinygrad/codegen/linearizer.py +278 -384
  4. tinygrad/codegen/uops.py +415 -0
  5. tinygrad/device.py +132 -275
  6. tinygrad/dtype.py +53 -37
  7. tinygrad/engine/__init__.py +0 -0
  8. tinygrad/engine/graph.py +100 -0
  9. tinygrad/engine/jit.py +195 -0
  10. tinygrad/engine/realize.py +191 -0
  11. tinygrad/engine/schedule.py +362 -0
  12. tinygrad/engine/search.py +196 -0
  13. tinygrad/{mlops.py → function.py} +28 -14
  14. tinygrad/helpers.py +72 -43
  15. tinygrad/lazy.py +141 -240
  16. tinygrad/multi.py +169 -0
  17. tinygrad/nn/__init__.py +179 -8
  18. tinygrad/nn/datasets.py +7 -0
  19. tinygrad/nn/optim.py +106 -28
  20. tinygrad/nn/state.py +86 -17
  21. tinygrad/ops.py +70 -44
  22. tinygrad/renderer/__init__.py +61 -0
  23. tinygrad/renderer/assembly.py +276 -0
  24. tinygrad/renderer/cstyle.py +299 -206
  25. tinygrad/renderer/llvmir.py +118 -123
  26. tinygrad/runtime/autogen/amd_gpu.py +1900 -0
  27. tinygrad/runtime/autogen/comgr.py +865 -0
  28. tinygrad/runtime/autogen/cuda.py +5923 -0
  29. tinygrad/runtime/autogen/hip.py +5909 -0
  30. tinygrad/runtime/autogen/hsa.py +5761 -0
  31. tinygrad/runtime/autogen/kfd.py +812 -0
  32. tinygrad/runtime/autogen/nv_gpu.py +33328 -0
  33. tinygrad/runtime/autogen/opencl.py +1795 -0
  34. tinygrad/runtime/driver/hip_comgr.py +47 -0
  35. tinygrad/runtime/driver/hsa.py +143 -0
  36. tinygrad/runtime/graph/clang.py +38 -0
  37. tinygrad/runtime/graph/cuda.py +59 -54
  38. tinygrad/runtime/graph/hcq.py +143 -0
  39. tinygrad/runtime/graph/hsa.py +171 -0
  40. tinygrad/runtime/graph/metal.py +37 -41
  41. tinygrad/runtime/ops_amd.py +564 -0
  42. tinygrad/runtime/ops_clang.py +16 -14
  43. tinygrad/runtime/ops_cuda.py +130 -38
  44. tinygrad/runtime/ops_disk.py +45 -42
  45. tinygrad/runtime/ops_gpu.py +52 -50
  46. tinygrad/runtime/ops_hsa.py +278 -0
  47. tinygrad/runtime/ops_llvm.py +36 -56
  48. tinygrad/runtime/ops_metal.py +42 -24
  49. tinygrad/runtime/ops_npy.py +9 -0
  50. tinygrad/runtime/ops_nv.py +630 -0
  51. tinygrad/runtime/ops_python.py +204 -0
  52. tinygrad/shape/shapetracker.py +41 -105
  53. tinygrad/shape/symbolic.py +98 -95
  54. tinygrad/shape/view.py +137 -35
  55. tinygrad/tensor.py +2367 -442
  56. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
  57. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/METADATA +19 -9
  58. tinygrad-0.9.0.dist-info/RECORD +60 -0
  59. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
  60. tinygrad/features/image.py +0 -93
  61. tinygrad/features/multi.py +0 -103
  62. tinygrad/features/search.py +0 -160
  63. tinygrad/graph.py +0 -106
  64. tinygrad/jit.py +0 -152
  65. tinygrad/realize.py +0 -50
  66. tinygrad/runtime/graph/hip.py +0 -24
  67. tinygrad/runtime/ops_cpu.py +0 -45
  68. tinygrad/runtime/ops_hip.py +0 -97
  69. tinygrad/runtime/ops_torch.py +0 -49
  70. tinygrad-0.8.0.dist-info/RECORD +0 -41
  71. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,415 @@
1
+ from __future__ import annotations
2
+ from typing import Iterator, Optional, Tuple, Any, Dict, List, DefaultDict, Set, Callable
3
+ import functools, itertools, heapq
4
+ from collections import defaultdict
5
+ from enum import Enum, auto
6
+ from dataclasses import dataclass
7
+ from tinygrad.dtype import dtypes, DType
8
+ from tinygrad.shape.symbolic import sint, Variable
9
+ from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, exec_alu
10
+ from tinygrad.helpers import prod, DEBUG, getenv
11
+
12
# the order of these UOps controls the order of the toposort
class UOps(Enum):
  """Kinds of micro-op nodes.

  Members get increasing `auto()` values in declaration order; `UOp.cmp_tuple`
  uses `.value` as its first sort key, so do not reorder casually.
  """
  # ops that aren't rendered
  SINK = auto()
  DEFINE_GLOBAL = auto()
  DEFINE_VAR = auto()
  DEFINE_LOCAL = auto()
  DEFINE_ACC = auto()
  CONST = auto()
  SPECIAL = auto()
  NOOP = auto()
  UNMUL = auto()
  GEP = auto()
  # math ops
  CAST = auto()
  BITCAST = auto()
  ALU = auto()
  WMMA = auto()
  # memory/assignment ops
  LOAD = auto()
  STORE = auto()
  PHI = auto()
  # control flow ops
  BARRIER = auto()
  IF = auto()
  RANGE = auto()
  # these two are not graph nodes
  ENDRANGE = auto()
  ENDIF = auto()
28
+
29
@dataclass(eq=False)
class UOp:
  """One node of the micro-op graph.

  `eq=False` keeps the default identity-based `__eq__`/`__hash__`, so nodes can be
  used as dict/set keys and two structurally equal nodes are still distinct objects.
  """
  uop: UOps
  dtype: Optional[DType] = None
  vin: Tuple[UOp, ...] = ()
  arg: Any = None

  def tuple(self):
    """Return the structural key `(uop, dtype, vin, arg)` of this node."""
    return (self.uop, self.dtype, self.vin, self.arg)

  @functools.cached_property
  def cmp_tuple(self):
    """Sort key used by `__lt__`; cached, so it must be deleted if the node is mutated."""
    # NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX
    # NOTE(review): the ALU branch is built from self.uop, not self.arg, so it is the same
    # for every ALU node and ordering falls through to dtype/vin. confirm this is intended.
    if self.uop is UOps.ALU: second = (type(self.uop), self.uop.value)
    elif self.uop is UOps.DEFINE_VAR: second = self.arg.expr
    else: second = self.arg
    return (self.uop.value, second, self.dtype, self.vin)

  def __lt__(self, x:UOp):
    return self.cmp_tuple < x.cmp_tuple

  def __repr__(self):
    return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.uop for x in self.vin]):32s} {self.arg}"

  # *** node builders: each returns a brand new UOp and leaves self untouched ***

  def cast(self, dtype):
    return UOp(UOps.CAST, dtype, (self,))

  def __neg__(self):
    return UOp.alu(UnaryOps.NEG, self)

  def __add__(self, x):
    return UOp.alu(BinaryOps.ADD, self, x)

  def __sub__(self, x):
    return UOp.alu(BinaryOps.SUB, self, x)

  def __mul__(self, x):
    return UOp.alu(BinaryOps.MUL, self, x)

  @staticmethod
  def max(x, y):
    return UOp.alu(BinaryOps.MAX, x, y)

  @staticmethod
  def min(x, y):
    # min(x, y) is expressed as -max(-x, -y)
    return -UOp.max(-x, -y)

  @staticmethod
  def const(dtype, val):
    return UOp(UOps.CONST, dtype, arg=dtypes.as_const(val, dtype))

  @staticmethod
  def alu(arg, *vin:UOp):
    """Build an ALU node; comparisons produce bool, everything else takes the last input's dtype."""
    out_dtype = dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPEQ} else vin[-1].dtype
    return UOp(UOps.ALU, out_dtype, vin, arg)

  @functools.cached_property
  def parents(self) -> Set[UOp]:
    """All nodes reachable through `vin` (direct and transitive inputs); cached."""
    reachable = set(self.vin)
    for src in self.vin: reachable |= src.parents
    return reachable
59
+
60
def uop_alu_resolve(u:UOp) -> sint:
  """Evaluate a small uop expression tree to a (possibly symbolic) integer.

  CONST and DEFINE_VAR yield their `arg`, SPECIAL yields `arg[2]-1`, and ALU
  MUL/ADD nodes are resolved recursively from their two inputs.
  Raises RuntimeError for any other node.
  """
  kind = u.uop
  if kind is UOps.CONST or kind is UOps.DEFINE_VAR: return u.arg
  if kind is UOps.SPECIAL: return u.arg[2]-1
  if kind is UOps.ALU and (u.arg is BinaryOps.MUL or u.arg is BinaryOps.ADD):
    lhs = uop_alu_resolve(u.vin[0])
    rhs = uop_alu_resolve(u.vin[1])
    return lhs * rhs if u.arg is BinaryOps.MUL else lhs + rhs
  raise RuntimeError(f"ALU resolve fail @ {u.uop}")
67
+
68
+ # *** simplification logic ***
69
+
70
+ def _match(uop:UOp, pattern:Dict[str, Any], store:Dict[str, UOp]) -> bool:
71
+ for k,v in pattern.items():
72
+ if k == "__name__":
73
+ if v in store and store[v] != uop: return False
74
+ store[v] = uop
75
+ elif k == "arg":
76
+ if uop.arg != v: return False
77
+ elif k == "dtype":
78
+ if isinstance(v, set):
79
+ if uop.dtype not in v: return False
80
+ elif uop.dtype != v: return False
81
+ elif k == "uop":
82
+ if isinstance(v, set):
83
+ if uop.uop not in v: return False
84
+ elif uop.uop != v: return False
85
+ elif k == "vin":
86
+ # only one if it's a tuple
87
+ # try all permutations if it's a list
88
+ # repeat if it's a dict
89
+ for vp in itertools.permutations(v) if isinstance(v, list) else ([v] if isinstance(v, tuple) else [(v,)*len(uop.vin)]):
90
+ if len(uop.vin) != len(vp) and (len(uop.vin) not in pattern.get('__allow_len__', [])): return False
91
+ new_store = store.copy()
92
+ if all(_match(uu, vv, new_store) for uu, vv in zip(uop.vin, vp)):
93
+ for k,v in new_store.items(): store[k] = v
94
+ return True
95
+ return False
96
+ return True
97
+
98
class PatternMatcher:
  """Ordered list of `(pattern, rewrite_fn)` rules, indexed for fast lookup.

  Rules are bucketed by `(uop, arg)` of their top-level pattern. "uop" is required
  (a single value or a set of values); "arg" is optional and defaults to None,
  which makes the rule apply to any arg of that uop.
  """
  def __init__(self, patterns:List[Tuple[Dict[str, Any], Callable]]):
    self.patterns = patterns
    self.pdict: DefaultDict[Tuple[UOps, Any], List[Tuple[Dict[str, Any], Callable]]] = defaultdict(list)
    # uop is required, arg is optional
    for p,fxn in self.patterns:
      uops = p["uop"]
      for uop in (uops if isinstance(uops, set) else (uops,)):
        self.pdict[(uop, p.get("arg", None))].append((p, fxn))

  def rewrite(self, uop:UOp) -> Optional[UOp]:
    """Return the result of the first rule that matches `uop` and produces a rewrite, else None.

    Rules keyed on the exact `(uop, arg)` are tried before the arg-agnostic ones, each
    bucket in registration order. A rule function returning None means "no rewrite",
    so the search continues with the next rule.
    """
    keys = [(uop.uop, uop.arg)]
    # when arg is None both keys are the same bucket, don't try those rules twice
    if uop.arg is not None: keys.append((uop.uop, None))
    for key in keys:
      # .get so that lookups never insert empty buckets into the defaultdict
      for p,fxn in self.pdict.get(key, ()):
        store: Dict[str, UOp] = {}
        if _match(uop, p, store) and (ret := fxn(**store)) is not None: return ret
    return None
115
+
116
def sum_collapse(phi_input, loop, val1, val2):
  """Pull a loop-invariant addend out of an accumulation over `loop`.

  Of the two addends, the first one (trying val1, then val2) that does not have
  `loop` among its parents is multiplied by the loop length
  (`loop.vin[1]-loop.vin[0]`, cast to the addend's dtype) and added outside a PHI
  that keeps accumulating only the other addend. Returns None if both addends
  depend on the loop.
  """
  orderings = ((val1, val2), (val2, val1))
  picked = next(((inv, rest) for inv, rest in orderings if loop not in inv.parents), None)
  if picked is None: return None
  invariant, remaining = picked
  loop_range = loop.vin[1]-loop.vin[0]
  hoisted = invariant*loop_range.cast(invariant.dtype)
  return UOp(UOps.PHI, phi_input.dtype, (phi_input, remaining))+hoisted
123
+
124
def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst):
  """Rewrite callback for the "arange loop folding" rule of `constant_folder`.

  Only handles a negative `mval` together with a loop starting at 0; anything else
  is left alone (returns None, with a warning when DEBUG >= 1). Otherwise returns an
  UNMUL node whose first input is the clamped count times `multconst` and whose
  second input is the loop length.
  """
  if mval.arg >= 0 or loop_start.arg != 0:
    # TODO: support and test this with other mvals and loop_starts
    if DEBUG >= 1: print(f"WARNING, NOT FOLDING: mval:{mval.arg} loop_start:{loop_start.arg}")
    return None
  quotient = UOp.alu(BinaryOps.DIV, idx-compval-mval, mval)
  shifted = quotient + (loop_end-loop_start)
  # clamp into [loop_start, loop_end]
  comprange = UOp.min(loop_end, UOp.max(shifted, loop_start))
  scaled = comprange.cast(multconst.dtype) * multconst
  return UOp(UOps.UNMUL, multconst.dtype, (scaled, loop_end-loop_start))
131
+
132
# this is symbolic 2.0
# Rule table for the uop graph rewriter. Every entry is (pattern, fxn):
#   - pattern is a nested dict understood by _match ("uop", "arg", "dtype", "vin", "__name__")
#   - a tuple "vin" is positional, a list "vin" is tried in every permutation, a dict "vin" is applied to every input
#   - fxn receives the "__name__" captures as keyword arguments, so lambda parameter names must equal the capture names
# PatternMatcher buckets rules by the top-level ("uop", "arg") and tries each bucket in the order written here,
# so the relative order of rules that share a bucket matters.
constant_folder = PatternMatcher([
  # arange loop folding (early)
  ({"uop": UOps.ALU, "arg": TernaryOps.WHERE, "vin": ({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": (
    {"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin":
      [{"__name__": "idx"}, {"uop": UOps.ALU, "arg": BinaryOps.MUL,
        "vin": [{"__name__": "mval", "uop": UOps.CONST}, {"uop": UOps.RANGE, "vin": ({"__name__": "loop_start"}, {"__name__": "loop_end"})}]}]},
      {"__name__": "compval", "uop": UOps.CONST})}, {"__name__": "multconst", "uop": UOps.CONST}, {"uop": UOps.CONST, "arg": 0})}, loop_collapse),
  # sum collapse to mul (with possible GEP)
  ({"uop": UOps.PHI, "vin": ({"__name__": "phi_input", "uop": UOps.DEFINE_ACC, "vin": ({"uop": UOps.RANGE, "__name__": "loop"},)},
      {"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "val1"}, {"__name__": "val2"})})}, sum_collapse),
  ({"uop": UOps.PHI, "vin": ({"__name__": "phi_input", "uop": UOps.GEP,
                              "vin": ({"uop": UOps.DEFINE_ACC, "vin":({"uop": UOps.RANGE, "__name__": "loop"},)},)},
      {"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "val1"}, {"__name__": "val2"})})}, sum_collapse),
  # deal with UNMUL
  # const * UNMUL(const, v) -> v, only when both consts are equal
  ({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{"uop": UOps.CONST, "__name__": "c1"},
                                                   {"uop": UOps.UNMUL, "vin": [{"uop": UOps.CONST, "__name__": "c2"}, {"__name__": "v"}]}]},
    lambda c1,c2,v: v if c1.arg == c2.arg else None),
  # UNMUL(0, anything) -> 0
  ({"uop": UOps.UNMUL, "vin": ({"uop": UOps.CONST, "__name__": "zero", "arg": 0}, {})}, lambda zero: zero),
  # CAST(UNMUL(a, b)) -> UNMUL(CAST(a), b)
  ({"__name__": "root", "uop": UOps.CAST, "vin": ({"uop": UOps.UNMUL, "__name__": "unmul"},)},
    lambda root,unmul: UOp(UOps.UNMUL, root.dtype, (unmul.vin[0].cast(root.dtype), unmul.vin[1]))),
  # max on special can go away (TODO: special should be variable, same thing applies)
  ({"uop": UOps.ALU, "arg": BinaryOps.MAX, "vin": [{"__name__": "c", "uop": UOps.CONST}, {"__name__": "s", "uop": UOps.SPECIAL}]},
    lambda c,s: c if (s.arg[2]-1) <= c.arg else None),
  # const rules
  ({"__name__": "root", "uop": UOps.GEP, "vin": ({"__name__": "c", "uop": UOps.CONST},)}, lambda root, c: UOp.const(root.dtype, c.arg)),
  ({"__name__": "root", "uop": UOps.CAST, "vin": {"__name__": "c", "uop": UOps.CONST}}, lambda root, c: UOp.const(root.dtype, c.arg)),
  # a phi on a DEFINE_ACC without loops or a CONST is a noop. this is for correctness, not just speed
  ({"uop": UOps.PHI, "vin": ({"uop": UOps.DEFINE_ACC, "__name__": "acc"}, {"__name__": "acc"})}, lambda acc: UOp.const(acc.dtype, acc.arg[0])),
  ({"uop": UOps.PHI, "vin": ({"uop": UOps.DEFINE_ACC, "vin": tuple()}, {"__name__": "x"})}, lambda x: x),
  ({"uop": UOps.PHI, "vin": ({"uop": UOps.CONST}, {"__name__": "x"})}, lambda x: x),
  # a DEFINE_ACC without inputs is a const + GEP on a const is the const
  ({"__name__": "root", "uop": UOps.DEFINE_ACC, "vin": tuple()}, lambda root: UOp.const(root.dtype, root.arg[0])),
  # (same pattern as the GEP rule under "const rules" above, with the capture named "x" instead of "c")
  ({"__name__": "root", "uop": UOps.GEP, "vin": ({"__name__": "x", "uop": UOps.CONST},)}, lambda root,x: UOp.const(root.dtype, x.arg)),
  # max -2147483648
  ({"uop": UOps.ALU, "arg": BinaryOps.MAX, "dtype": dtypes.int, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": -2147483648}]}, lambda x: x),
  # -(-x) -> x
  ({"uop": UOps.ALU, "arg": UnaryOps.NEG, "vin": ({"uop": UOps.ALU, "arg": UnaryOps.NEG, "vin": ({"__name__": "x"},)})}, lambda x: x),
  # x+-y -> x-y
  ({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "x"}, {"__name__": "my", "uop": UOps.ALU, "arg": UnaryOps.NEG})},
    lambda x, my: x-my.vin[0]),
  # -1*x -> -x
  ({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": -1}]}, lambda x: -x),
  # bool < False is always false, True < bool is always false
  ({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({}, {"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.bool, "arg": False})}, lambda x: x),
  ({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.bool, "arg": True}, {})},
    lambda x: UOp.const(dtypes.bool, False)),
  # a conditional with the same results either way is a noop, also fold const conditionals
  ({"uop": UOps.ALU, "arg": TernaryOps.WHERE, "vin": ({}, {"__name__": "val"}, {"__name__": "val"})}, lambda val: val),
  ({"uop": UOps.ALU, "arg": TernaryOps.WHERE, "vin": ({"__name__": "gate", "uop": UOps.CONST}, {"__name__": "c0"}, {"__name__": "c1"})},
    lambda gate, c0, c1: c0 if gate.arg else c1),
  # ** constant folding **
  # an ALU whose inputs are all CONST is evaluated with exec_alu
  ({"__name__": "root", "uop": UOps.ALU, "vin": {"uop": UOps.CONST}},
    lambda root: UOp.const(root.dtype, exec_alu(root.arg, root.dtype, [x.arg for x in root.vin]))),
  # ** self folding **
  ({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": 0}]}, lambda x: x),   # x+0 -> x or 0+x -> x
  ({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": 1}]}, lambda x: x),   # x*1 -> x or 1*x -> x
  ({"uop": UOps.ALU, "arg": BinaryOps.SUB, "vin": ({"__name__": "x"}, {"uop": UOps.CONST, "arg": 0})}, lambda x: x),   # x-0 -> x
  ({"uop": UOps.ALU, "arg": BinaryOps.DIV, "vin": ({"__name__": "x"}, {"uop": UOps.CONST, "arg": 1})}, lambda x: x),   # x/1 -> x
  ({"uop": UOps.ALU, "arg": BinaryOps.DIV, "vin": ({"__name__": "x"}, {"uop": UOps.CONST, "arg": -1})}, lambda x: -x), # x/-1 -> -x
  # ** zero folding **
  ({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{}, {"__name__": "c", "uop": UOps.CONST, "arg": 0}]}, lambda c: c), # x*0 -> 0 or 0*x -> 0
  ({"uop": UOps.ALU, "arg": BinaryOps.SUB, "vin": ({"__name__": "x"}, {"__name__": "x"})}, lambda x: UOp.const(x.dtype, 0)),   # x-x -> 0
  # ** load/store folding **
  # storing back the value just loaded from the same buf/idx becomes a NOOP
  ({"uop": UOps.STORE, "vin": ({"__name__": "buf"}, {"__name__": "idx"},
                               {"uop": UOps.LOAD, "vin": ({"__name__": "buf"}, {"__name__": "idx"})})}, lambda buf, idx: UOp(UOps.NOOP)),
  # ** two stage add/sub folding **
  # (x+c1)+c2 -> x+(c1+c2)
  ({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": [{"uop": UOps.ALU, "arg": BinaryOps.ADD,
                     "vin": [{"__name__": "x"}, {"__name__": "c1", "uop": UOps.CONST}]}, {"__name__": "c2", "uop": UOps.CONST}]},
     lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, c2.arg]))),
  # (x-c1)+c2 -> x+(c2-c1)
  ({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": [{"uop": UOps.ALU, "arg": BinaryOps.SUB,
                     "vin": ({"__name__": "x"}, {"__name__": "c1", "uop": UOps.CONST})}, {"__name__": "c2", "uop": UOps.CONST}]},
     lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.SUB, x.dtype, [c2.arg, c1.arg]))),
  # TODO: can do the invert of this (flip alt/load) when we fix double ops
  # STORE(buf, idx, WHERE(gate, alt, LOAD(buf, idx))) -> STORE with gate as a fourth input
  ({"uop": UOps.STORE, "vin": ({"__name__": "buf"}, {"__name__": "idx"}, {"uop": UOps.ALU, "arg": TernaryOps.WHERE,
                       "vin": ({"__name__": "gate"}, {"__name__": "alt"}, {"uop": UOps.LOAD, "vin": ({"__name__": "buf"}, {"__name__": "idx"})})})},
    lambda buf, idx, gate, alt: UOp(UOps.STORE, None, (buf, idx, alt, gate))),
  # store float4/float2 directly (remove CAST/GEP)
  ({"uop": UOps.STORE, "vin": ({"__name__": "buf"}, {"__name__": "idx"}, {"uop": UOps.CAST, "vin":
                                tuple({"uop": UOps.GEP, "vin": ({"__name__": "val"},), "arg": i} for i in range(4))})},
   lambda buf,idx,val: UOp(UOps.STORE, None, (buf, idx, val))),
  ({"uop": UOps.STORE, "vin": ({"__name__": "buf"}, {"__name__": "idx"}, {"uop": UOps.CAST, "vin":
                                tuple({"uop": UOps.GEP, "vin": ({"__name__": "val"},), "arg": i} for i in range(2))})},
   lambda buf,idx,val: UOp(UOps.STORE, None, (buf, idx, val))),
  # CAST-PHI-GEP -> PHI-CAST
  ({"__name__": "root", "uop": UOps.CAST, "vin":
    tuple({"uop": UOps.PHI, "vin": ({"uop": UOps.GEP, "vin": ({"__name__": "val"},), "arg": i}, {"__name__": f"v{i}"})} for i in range(4))},
    lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1, v2, v3))))),
  ({"__name__": "root", "uop": UOps.CAST, "vin":
    tuple({"uop": UOps.PHI, "vin": ({"uop": UOps.GEP, "vin": ({"__name__": "val"},), "arg": i}, {"__name__": f"v{i}"})} for i in range(2))},
    lambda root, val, v0, v1: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1))))),
  # NEG/CMPLT -> CMPLT
  # (-x < c) -> (-c < x), only for int constants
  ({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"uop": UOps.ALU, "arg": UnaryOps.NEG, "vin": ({"__name__": "x"},)},
                                                     {"__name__": "c", "uop": UOps.CONST, "dtype": dtypes.int})},
    lambda c,x: UOp(UOps.ALU, dtypes.bool, (UOp.const(c.dtype, -c.arg), x), BinaryOps.CMPLT)),
  # cast NOOP (NOTE: it's str to deal with PtrDType)
  ({"__name__": "root", "uop": UOps.CAST}, lambda root: root.vin[0] if str(root.dtype) == str(root.vin[0].dtype) else None),
])
230
+
231
+ # *** uop graph ***
232
+
233
class UOpGraph:
  """A set of deduplicated UOp nodes plus their lazily computed linear order.

  `nodes` maps the structural key `(uop, dtype, vin, arg)` to the single UOp built for it.
  `_uops` is the linearized list produced by `linearize()`; it stays None until first needed.
  """
  def __init__(self):
    self.nodes: Dict[Tuple, UOp] = {}
    self._uops: Optional[List[UOp]] = None

  def __iter__(self) -> Iterator[UOp]: return iter(self.uops)
  def __getitem__(self, index) -> UOp: return self.uops[index]

  # args of all DEFINE_VAR / DEFINE_GLOBAL uops, in linearized order
  def vars(self) -> List[Variable]: return [x.arg for x in self.uops if x.uop is UOps.DEFINE_VAR]
  def globals(self) -> List[Tuple[int, bool]]: return [x.arg for x in self.uops if x.uop is UOps.DEFINE_GLOBAL]

  @property
  def uops(self):
    """The linearized uop list; runs `linearize()` with default arguments on first access."""
    if self._uops is None: self.linearize()
    return self._uops

  def graph(self):
    """Pass the linearized uops to `tinygrad.engine.graph.graph_uops` (imported lazily here)."""
    from tinygrad.engine.graph import graph_uops
    graph_uops(self.uops)

  def print(self):
    """Print one line per linearized uop: position, uop, dtype, positions of its inputs, arg."""
    for i,u in enumerate(self):
      print(f"{i:4d} {str(u.uop):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str([self.uops.index(x) for x in u.vin]):32s} {u.arg}")

  def graph_rewrite(self, sink, pm):
    """Rewrite the graph reachable from `sink` with pattern matcher `pm` and return the new sink.

    Whole-graph passes are repeated until a pass performs no rewrite. The UOPS_REWRITE
    environment variable (default 1) is the initial loop condition, so 0 skips rewriting.
    Asserts if a single node is rewritten 100 times in a row or if 100 passes are reached.
    Nodes are mutated in place (`vin` is reassigned) and registered in `self.nodes`.
    """
    # recursive rewrite
    changed = getenv("UOPS_REWRITE", 1)
    run_cnt = 0
    while changed:
      changed = 0
      # a fresh cache per pass: every node is rewritten at most once per pass
      @functools.lru_cache
      def rewrite(u:UOp) -> UOp:
        nonlocal changed
        recurse_cnt = 0
        up = u
        # locally recursively rewrite
        while (rewritten := pm.rewrite(up)):
          assert recurse_cnt < 100, f"recursive_rewrite looped {up} <--> {rewritten}"
          up = rewritten
          recurse_cnt += 1
        changed += recurse_cnt
        # NOTE: this changes UOp, so we have to delete caches
        up.vin = tuple(rewrite(x) for x in up.vin)
        # NOTE(review): hasattr() on a functools.cached_property evaluates the property, so these
        # checks look like they compute the value just to delete it. confirm before changing
        if hasattr(up, "parents"): del up.parents
        if hasattr(up, "cmp_tuple"): del up.cmp_tuple
        # replace with cached nodes
        if found:=self.nodes.get(key:=up.tuple()): return found
        else: self.nodes[key] = up
        return up
      sink = rewrite(sink)
      run_cnt += 1
      assert run_cnt < 100, "exceeded 100 rewrite loops!"
    return sink

  def linearize(self, extra_pm:Optional[PatternMatcher]=None, type_verify=True):
    """Rewrite the graph and order it into `self._uops`.

    Steps: collect every STORE (and the inputs of every SINK) into a new SINK, rewrite with
    `constant_folder` (then again with `extra_pm`'s rules appended, if given), toposort the
    nodes reachable from the sink with a priority heap, emit ENDRANGE once all children of
    a RANGE are placed, append ENDIF for each IF, and finally run `type_verify()` unless disabled.
    """
    # NOTE: relinearizering should be okay
    #assert self._uops is None, "already linearized"

    # get sink
    _sinks: List[UOp] = []
    for u in self.nodes.values():
      if u.uop is UOps.STORE: _sinks.append(u)
      if u.uop is UOps.SINK: _sinks.extend(u.vin)
    sink = UOp(UOps.SINK, None, tuple(_sinks))
    del _sinks

    sink = self.graph_rewrite(sink, constant_folder)
    if extra_pm: sink = self.graph_rewrite(sink, PatternMatcher(constant_folder.patterns+extra_pm.patterns))

    # filter nodes that don't link to a sink
    # BFS toposort
    graph: DefaultDict[UOp, List[UOp]] = defaultdict(list)   # node -> nodes that consume it
    in_degree: DefaultDict[UOp, int] = defaultdict(int)      # node -> number of inputs not yet placed
    loops = []
    ifs = []
    nodes: Dict[UOp, None] = {}   # dict used as an insertion-ordered set
    def add_parents(u:UOp):
      if u in nodes: return
      nodes[u] = None
      for x in u.vin:
        add_parents(x)
        in_degree[u] += 1
        graph[x].append(u)
      if u.uop is UOps.RANGE: loops.append(u)
      if u.uop is UOps.IF: ifs.append(u)
    # NOOPs produced by rewriting are dropped from the sink
    sink = UOp(UOps.SINK, None, tuple(x for x in sink.vin if x.uop is not UOps.NOOP))
    add_parents(sink)

    @functools.lru_cache(None)
    def get_recursive_children(x:UOp, include_self=False) -> Set[UOp]:
      # everything downstream of x; the walk stops at SINK and does not continue past a PHI
      if x.uop is UOps.SINK: return set()
      return set.union(set((x,)) if include_self else set(), *([get_recursive_children(u, True) for u in graph[x]] if x.uop is not UOps.PHI else []))
    loops_children = {l:get_recursive_children(l) for l in loops[::-1]}

    queue: List = []
    def push(u):
      priority = 0
      # prefer uops that are loop children
      for l, ss in loops_children.items():
        if u in ss: priority -= l.arg[0]*1000 + l.arg[1]
      # ties on priority are broken by UOp.__lt__
      heapq.heappush(queue, (priority, u))

    for u in nodes:
      if in_degree[u] == 0: push(u)

    if getenv("FUZZ_UOPS", 0):
      from test.external.fuzz_uops import fuzz_uops
      self.fuzz_paths = fuzz_uops(graph, in_degree.copy(), loops_children)

    self._uops = []
    while queue:
      p,x = heapq.heappop(queue)
      if DEBUG >= 7: print(p,x)
      if x.uop is UOps.DEFINE_ACC and len(x.vin):
        # an accumulator with inputs is inserted before the earliest of those inputs
        idx = min([self._uops.index(l) for l in x.vin])
        self._uops.insert(idx, x)
      else:
        self._uops.append(x)
      for u, ss in loops_children.items():
        if x in ss:
          ss.remove(x)
          # last child of loop u has been placed: close the loop
          if len(ss) == 0: self._uops.append(UOp(UOps.ENDRANGE, None, (u,)))
      for u in graph[x]:
        in_degree[u] -= 1
        if in_degree[u] == 0: push(u)

    assert self._uops[-1].uop is UOps.SINK, f"didn't end with SINK, ended with {self._uops[-1]}"
    self._uops = self._uops[:-1]

    # TODO: ifs should be removed and just the store should be gated
    for u in ifs[::-1]: self._uops.append(UOp(UOps.ENDIF, None, (u,)))

    if type_verify: self.type_verify()

  def add(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(), arg:Any=None) -> UOp:
    """Return the node for `(uop, dtype, vin, arg)`, creating and registering it if it does not exist yet."""
    if found:=self.nodes.get(key:=(uop, dtype, vin, arg)): return found
    self.nodes[key] = ret = UOp(*key)
    return ret

  # *** checker functions ***

  def flops_mem(self) -> Tuple[sint, sint]:
    """Estimate `(flops, mem)` by walking the linearized uops.

    Counts are scaled by `mults`, the product of the resolved `vin[1]` of every enclosing
    RANGE. ALU counts 1 (2 for MULACC), LOAD/STORE add the itemsize of the moved dtype,
    WMMA adds `2 * prod(arg[1]) // 32`.
    """
    flops: sint = 0
    mem: sint = 0
    mults: sint = 1
    mult_stack = []   # multiplier to restore at the matching ENDRANGE
    for u in self.uops:
      if u.uop is UOps.RANGE:
        mult_stack.append(mults)
        mults *= uop_alu_resolve(u.vin[1])
      elif u.uop is UOps.ENDRANGE:
        mults = mult_stack.pop(-1)
      elif u.uop is UOps.ALU:
        flops += mults * (2 if u.arg == TernaryOps.MULACC else 1)
      elif u.uop is UOps.LOAD:
        assert u.dtype is not None
        mem += u.dtype.itemsize * mults
      elif u.uop is UOps.STORE:
        assert u.vin[2].dtype is not None
        mem += u.vin[2].dtype.itemsize * mults
      elif u.uop is UOps.WMMA:
        assert u.arg[1] is not None
        flops += 2 * prod(u.arg[1]) // 32 * mults
    return flops, mem

  def type_verify(self):
    """Assert dtype consistency of the linearized uops (CONST/DEFINE_ACC, CAST/BITCAST and ALU nodes)."""
    for u in self.uops:
      uop, arg, vin, dtype = u.uop, u.arg, u.vin, u.dtype
      if uop in {UOps.CONST, UOps.DEFINE_ACC}:
        # for DEFINE_ACC the checked value is the first element of arg
        if uop is UOps.DEFINE_ACC: arg = arg[0]
        assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
      if uop in {UOps.CAST, UOps.BITCAST}: assert arg is None   # type is the output type, not an arg
      if uop is UOps.ALU:
        if arg in UnaryOps:
          assert dtype == vin[0].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=}"
        elif arg in (BinaryOps.CMPLT, BinaryOps.CMPEQ):
          assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
          assert vin[0].dtype == vin[1].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=} != {vin[1].dtype=}"
        elif arg in BinaryOps:
          assert dtype == vin[0].dtype == vin[1].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=} != {vin[1].dtype=}"
        elif arg == TernaryOps.WHERE:
          assert vin[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {vin[0].dtype=} != {dtypes.bool}"
          assert dtype == vin[1].dtype == vin[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {vin[1].dtype=} != {vin[2].dtype=}"