tinygrad 0.7.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77) hide show
  1. tinygrad/__init__.py +6 -0
  2. tinygrad/codegen/kernel.py +572 -83
  3. tinygrad/codegen/linearizer.py +415 -395
  4. tinygrad/codegen/uops.py +415 -0
  5. tinygrad/device.py +183 -0
  6. tinygrad/dtype.py +113 -0
  7. tinygrad/engine/__init__.py +0 -0
  8. tinygrad/engine/graph.py +100 -0
  9. tinygrad/engine/jit.py +195 -0
  10. tinygrad/engine/realize.py +191 -0
  11. tinygrad/engine/schedule.py +362 -0
  12. tinygrad/engine/search.py +196 -0
  13. tinygrad/{mlops.py → function.py} +76 -55
  14. tinygrad/helpers.py +196 -89
  15. tinygrad/lazy.py +210 -371
  16. tinygrad/multi.py +169 -0
  17. tinygrad/nn/__init__.py +202 -22
  18. tinygrad/nn/datasets.py +7 -0
  19. tinygrad/nn/optim.py +112 -32
  20. tinygrad/nn/state.py +136 -39
  21. tinygrad/ops.py +119 -202
  22. tinygrad/renderer/__init__.py +61 -0
  23. tinygrad/renderer/assembly.py +276 -0
  24. tinygrad/renderer/cstyle.py +353 -166
  25. tinygrad/renderer/llvmir.py +150 -138
  26. tinygrad/runtime/autogen/amd_gpu.py +1900 -0
  27. tinygrad/runtime/autogen/comgr.py +865 -0
  28. tinygrad/runtime/autogen/cuda.py +5923 -0
  29. tinygrad/runtime/autogen/hip.py +5909 -0
  30. tinygrad/runtime/autogen/hsa.py +5761 -0
  31. tinygrad/runtime/autogen/kfd.py +812 -0
  32. tinygrad/runtime/autogen/nv_gpu.py +33328 -0
  33. tinygrad/runtime/autogen/opencl.py +1795 -0
  34. tinygrad/runtime/driver/hip_comgr.py +47 -0
  35. tinygrad/runtime/driver/hsa.py +143 -0
  36. tinygrad/runtime/graph/clang.py +38 -0
  37. tinygrad/runtime/graph/cuda.py +81 -0
  38. tinygrad/runtime/graph/hcq.py +143 -0
  39. tinygrad/runtime/graph/hsa.py +171 -0
  40. tinygrad/runtime/graph/metal.py +75 -0
  41. tinygrad/runtime/ops_amd.py +564 -0
  42. tinygrad/runtime/ops_clang.py +24 -77
  43. tinygrad/runtime/ops_cuda.py +175 -89
  44. tinygrad/runtime/ops_disk.py +56 -33
  45. tinygrad/runtime/ops_gpu.py +92 -95
  46. tinygrad/runtime/ops_hsa.py +278 -0
  47. tinygrad/runtime/ops_llvm.py +39 -60
  48. tinygrad/runtime/ops_metal.py +92 -74
  49. tinygrad/runtime/ops_npy.py +9 -0
  50. tinygrad/runtime/ops_nv.py +630 -0
  51. tinygrad/runtime/ops_python.py +204 -0
  52. tinygrad/shape/shapetracker.py +86 -254
  53. tinygrad/shape/symbolic.py +166 -141
  54. tinygrad/shape/view.py +296 -0
  55. tinygrad/tensor.py +2619 -448
  56. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
  57. tinygrad-0.9.0.dist-info/METADATA +227 -0
  58. tinygrad-0.9.0.dist-info/RECORD +60 -0
  59. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
  60. tinygrad/codegen/assembly.py +0 -190
  61. tinygrad/codegen/optimizer.py +0 -379
  62. tinygrad/codegen/search.py +0 -72
  63. tinygrad/graph.py +0 -83
  64. tinygrad/jit.py +0 -57
  65. tinygrad/nn/image.py +0 -100
  66. tinygrad/renderer/assembly_arm64.py +0 -169
  67. tinygrad/renderer/assembly_ptx.py +0 -98
  68. tinygrad/renderer/wgsl.py +0 -53
  69. tinygrad/runtime/lib.py +0 -113
  70. tinygrad/runtime/ops_cpu.py +0 -51
  71. tinygrad/runtime/ops_hip.py +0 -82
  72. tinygrad/runtime/ops_shm.py +0 -29
  73. tinygrad/runtime/ops_torch.py +0 -30
  74. tinygrad/runtime/ops_webgpu.py +0 -45
  75. tinygrad-0.7.0.dist-info/METADATA +0 -212
  76. tinygrad-0.7.0.dist-info/RECORD +0 -40
  77. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,415 @@
1
+ from __future__ import annotations
2
+ from typing import Iterator, Optional, Tuple, Any, Dict, List, DefaultDict, Set, Callable
3
+ import functools, itertools, heapq
4
+ from collections import defaultdict
5
+ from enum import Enum, auto
6
+ from dataclasses import dataclass
7
+ from tinygrad.dtype import dtypes, DType
8
+ from tinygrad.shape.symbolic import sint, Variable
9
+ from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, exec_alu
10
+ from tinygrad.helpers import prod, DEBUG, getenv
11
+
12
+ # the order of these UOps controls the order of the toposort
13
class UOps(Enum):
  """Opcodes for the nodes of a uop graph.

  Members are numbered by auto() in declaration order, so moving a member changes
  its value. The file notes that this order controls the order of the toposort.
  """
  # ops that aren't rendered
  SINK = auto()
  DEFINE_GLOBAL = auto()
  DEFINE_VAR = auto()
  DEFINE_LOCAL = auto()
  DEFINE_ACC = auto()
  CONST = auto()
  SPECIAL = auto()
  NOOP = auto()
  UNMUL = auto()
  GEP = auto()
  # math ops
  CAST = auto()
  BITCAST = auto()
  ALU = auto()
  WMMA = auto()
  # memory/assignment ops
  LOAD = auto()
  STORE = auto()
  PHI = auto()
  # control flow ops
  BARRIER = auto()
  IF = auto()
  RANGE = auto()
  # these two are not graph nodes
  ENDRANGE = auto()
  ENDIF = auto()
28
+
29
@dataclass(eq=False)
class UOp:
  """One node of the uop graph: an opcode, an optional dtype, input nodes and an argument.

  eq=False means the dataclass does not generate __eq__, so instances keep the
  default identity-based equality and hashing and can be used in sets and as dict keys.
  """
  uop: UOps
  dtype: Optional[DType] = None
  vin: Tuple[UOp, ...] = ()
  arg: Any = None

  def tuple(self):
    """Return the four fields as a plain tuple (uop, dtype, vin, arg)."""
    return (self.uop, self.dtype, self.vin, self.arg)

  @functools.cached_property
  def cmp_tuple(self):
    """Sort key used by __lt__; cached per instance until explicitly deleted."""
    # NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX
    if self.uop is UOps.ALU:
      second = (type(self.uop), self.uop.value)
    elif self.uop is UOps.DEFINE_VAR:
      second = self.arg.expr
    else:
      second = self.arg
    return (self.uop.value, second, self.dtype, self.vin)

  def __lt__(self, x:UOp):
    return self.cmp_tuple < x.cmp_tuple

  def __repr__(self):
    return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.uop for x in self.vin]):32s} {self.arg}"

  def cast(self, dtype):
    """Wrap this node in a CAST to `dtype`."""
    return UOp(UOps.CAST, dtype, (self,))

  # arithmetic operators build ALU nodes instead of computing values
  def __neg__(self):
    return UOp.alu(UnaryOps.NEG, self)

  def __add__(self, x):
    return UOp.alu(BinaryOps.ADD, self, x)

  def __sub__(self, x):
    return UOp.alu(BinaryOps.SUB, self, x)

  def __mul__(self, x):
    return UOp.alu(BinaryOps.MUL, self, x)

  @staticmethod
  def max(x, y):
    return UOp.alu(BinaryOps.MAX, x, y)

  @staticmethod
  def min(x, y):
    # min(x, y) expressed as -max(-x, -y)
    return -UOp.alu(BinaryOps.MAX, -x, -y)

  @staticmethod
  def const(dtype, val):
    """Build a CONST node whose arg is `val` converted with dtypes.as_const."""
    return UOp(UOps.CONST, dtype, arg=dtypes.as_const(val, dtype))

  @staticmethod
  def alu(arg, *vin:UOp):
    """Build an ALU node; comparisons get dtypes.bool, everything else the last input's dtype."""
    out_dtype = dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPEQ} else vin[-1].dtype
    return UOp(UOps.ALU, out_dtype, vin, arg)

  @functools.cached_property
  def parents(self) -> Set[UOp]:
    """All nodes reachable through vin (transitively), excluding self; cached per instance."""
    reachable = set(self.vin)
    for x in self.vin:
      reachable |= x.parents
    return reachable
59
+
60
def uop_alu_resolve(u:UOp) -> sint:
  """Evaluate a small uop expression tree to a value.

  CONST and DEFINE_VAR yield their arg, SPECIAL yields arg[2]-1, and ALU nodes
  with MUL or ADD combine the resolved values of their first two inputs.

  Raises:
    RuntimeError: for any other uop (including ALU with a different op).
  """
  kind = u.uop
  if kind is UOps.CONST or kind is UOps.DEFINE_VAR:
    return u.arg
  if kind is UOps.SPECIAL:
    return u.arg[2]-1
  if kind is UOps.ALU:
    if u.arg is BinaryOps.MUL:
      return uop_alu_resolve(u.vin[0]) * uop_alu_resolve(u.vin[1])
    if u.arg is BinaryOps.ADD:
      return uop_alu_resolve(u.vin[0]) + uop_alu_resolve(u.vin[1])
  raise RuntimeError(f"ALU resolve fail @ {u.uop}")
67
+
68
+ # *** simplification logic ***
69
+
70
+ def _match(uop:UOp, pattern:Dict[str, Any], store:Dict[str, UOp]) -> bool:
71
+ for k,v in pattern.items():
72
+ if k == "__name__":
73
+ if v in store and store[v] != uop: return False
74
+ store[v] = uop
75
+ elif k == "arg":
76
+ if uop.arg != v: return False
77
+ elif k == "dtype":
78
+ if isinstance(v, set):
79
+ if uop.dtype not in v: return False
80
+ elif uop.dtype != v: return False
81
+ elif k == "uop":
82
+ if isinstance(v, set):
83
+ if uop.uop not in v: return False
84
+ elif uop.uop != v: return False
85
+ elif k == "vin":
86
+ # only one if it's a tuple
87
+ # try all permutations if it's a list
88
+ # repeat if it's a dict
89
+ for vp in itertools.permutations(v) if isinstance(v, list) else ([v] if isinstance(v, tuple) else [(v,)*len(uop.vin)]):
90
+ if len(uop.vin) != len(vp) and (len(uop.vin) not in pattern.get('__allow_len__', [])): return False
91
+ new_store = store.copy()
92
+ if all(_match(uu, vv, new_store) for uu, vv in zip(uop.vin, vp)):
93
+ for k,v in new_store.items(): store[k] = v
94
+ return True
95
+ return False
96
+ return True
97
+
98
class PatternMatcher:
  """Ordered list of (pattern, function) rewrite rules, indexed by (uop, arg) for lookup.

  Attributes:
    patterns: the rules exactly as passed in.
    pdict: (uop, arg) -> rules whose top-level pattern has that uop and arg;
      rules without a top-level "arg" are stored under (uop, None).
  """
  def __init__(self, patterns:List[Tuple[Dict[str, Any], Callable]]):
    self.patterns = patterns
    self.pdict: DefaultDict[Tuple[UOps, Any], List[Tuple[Dict[str, Any], Callable]]] = defaultdict(list)
    # uop is required, arg is optional
    for p,fxn in self.patterns:
      uops = p["uop"]
      # a set of uops registers the rule once per member
      for uop in (uops if isinstance(uops, set) else (uops,)):
        self.pdict[(uop, p.get("arg", None))].append((p, fxn))

  def rewrite(self, uop:UOp) -> Optional[UOp]:
    """Apply the first matching rule to `uop`.

    Rules registered for (uop.uop, uop.arg) are tried first, then rules registered
    for (uop.uop, None). The matched rule's function is called with the captures as
    keyword arguments and its result is returned, even if that result is None.

    Returns:
      The function's result for the first rule that matches, or None if no rule matches.
    """
    # .get() instead of indexing: a lookup must not insert empty buckets into the defaultdict
    candidates = self.pdict.get((uop.uop, uop.arg), [])
    # when arg is None both keys name the same bucket, so don't walk it twice
    if uop.arg is not None:
      candidates = itertools.chain(candidates, self.pdict.get((uop.uop, None), []))
    for p,fxn in candidates:
      store: Dict[str, UOp] = {}
      if _match(uop, p, store): return fxn(**store)
    return None
115
+
116
def sum_collapse(phi_input, loop, val1, val2):
  """Rewrite a PHI over `val1 + val2` when one addend does not depend on `loop`.

  The addend whose parents do not include `loop` is multiplied by
  (loop.vin[1] - loop.vin[0]) and added to a new PHI over the other addend.
  val1 is preferred when both qualify.

  Returns:
    The replacement uop, or None when both addends have `loop` among their parents.
  """
  if loop not in val1.parents:
    hoisted, kept = val1, val2
  elif loop not in val2.parents:
    hoisted, kept = val2, val1
  else:
    return None
  loop_range = loop.vin[1]-loop.vin[0]
  ret = hoisted*loop_range.cast(hoisted.dtype)
  return UOp(UOps.PHI, phi_input.dtype, (phi_input, kept))+ret
123
+
124
def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst):
  """Rewrite function for the "arange loop folding" rule.

  Only handles a negative `mval.arg` with `loop_start.arg == 0`; in any other case
  it returns None (printing a warning when DEBUG >= 1).

  Returns:
    An UNMUL uop of dtype multconst.dtype whose inputs are (comprange cast to that dtype
    times multconst, loop_end-loop_start), or None when the case is not supported.
  """
  if mval.arg >= 0 or loop_start.arg != 0:
    # TODO: support and test this with other mvals and loop_starts
    if DEBUG >= 1:
      print(f"WARNING, NOT FOLDING: mval:{mval.arg} loop_start:{loop_start.arg}")
    return None
  quotient = UOp.alu(BinaryOps.DIV, idx-compval-mval, mval)
  lower_clamped = UOp.max(quotient + (loop_end-loop_start), loop_start)
  comprange = UOp.min(loop_end, lower_clamped)
  scaled = comprange.cast(multconst.dtype) * multconst
  return UOp(UOps.UNMUL, multconst.dtype, (scaled, loop_end-loop_start))
131
+
132
+ # this is symbolic 2.0
133
# Rewrite rules applied to the uop graph. Each entry is (pattern, function):
#  - the pattern is a dict understood by _match ("uop", "arg", "dtype", "vin", "__name__");
#    a "vin" tuple is positional, a "vin" list is tried in every permutation, and a "vin" dict
#    is applied to every input
#  - the function receives the "__name__" captures as keyword arguments and returns the
#    replacement uop; returning None leaves the node as it is
# Rules sharing a (uop, arg) key are tried in the order they are listed here.
constant_folder = PatternMatcher([
  # arange loop folding (early)
  ({"uop": UOps.ALU, "arg": TernaryOps.WHERE, "vin": ({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": (
    {"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin":
      [{"__name__": "idx"}, {"uop": UOps.ALU, "arg": BinaryOps.MUL,
        "vin": [{"__name__": "mval", "uop": UOps.CONST}, {"uop": UOps.RANGE, "vin": ({"__name__": "loop_start"}, {"__name__": "loop_end"})}]}]},
      {"__name__": "compval", "uop": UOps.CONST})}, {"__name__": "multconst", "uop": UOps.CONST}, {"uop": UOps.CONST, "arg": 0})}, loop_collapse),
  # sum collapse to mul (with possible GEP)
  ({"uop": UOps.PHI, "vin": ({"__name__": "phi_input", "uop": UOps.DEFINE_ACC, "vin": ({"uop": UOps.RANGE, "__name__": "loop"},)},
    {"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "val1"}, {"__name__": "val2"})})}, sum_collapse),
  ({"uop": UOps.PHI, "vin": ({"__name__": "phi_input", "uop": UOps.GEP,
    "vin": ({"uop": UOps.DEFINE_ACC, "vin":({"uop": UOps.RANGE, "__name__": "loop"},)},)},
    {"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "val1"}, {"__name__": "val2"})})}, sum_collapse),
  # deal with UNMUL
  # const * UNMUL(const, v) -> v when both constants are equal
  ({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{"uop": UOps.CONST, "__name__": "c1"},
    {"uop": UOps.UNMUL, "vin": [{"uop": UOps.CONST, "__name__": "c2"}, {"__name__": "v"}]}]},
    lambda c1,c2,v: v if c1.arg == c2.arg else None),
  # UNMUL of a zero constant is that zero
  ({"uop": UOps.UNMUL, "vin": ({"uop": UOps.CONST, "__name__": "zero", "arg": 0}, {})}, lambda zero: zero),
  # CAST(UNMUL(a, b)) -> UNMUL(CAST(a), b)
  ({"__name__": "root", "uop": UOps.CAST, "vin": ({"uop": UOps.UNMUL, "__name__": "unmul"},)},
    lambda root,unmul: UOp(UOps.UNMUL, root.dtype, (unmul.vin[0].cast(root.dtype), unmul.vin[1]))),
  # max on special can go away (TODO: special should be variable, same thing applies)
  ({"uop": UOps.ALU, "arg": BinaryOps.MAX, "vin": [{"__name__": "c", "uop": UOps.CONST}, {"__name__": "s", "uop": UOps.SPECIAL}]},
    lambda c,s: c if (s.arg[2]-1) <= c.arg else None),
  # const rules
  ({"__name__": "root", "uop": UOps.GEP, "vin": ({"__name__": "c", "uop": UOps.CONST},)}, lambda root, c: UOp.const(root.dtype, c.arg)),
  ({"__name__": "root", "uop": UOps.CAST, "vin": {"__name__": "c", "uop": UOps.CONST}}, lambda root, c: UOp.const(root.dtype, c.arg)),
  # a phi on a DEFINE_ACC without loops or a CONST is a noop. this is for correctness, not just speed
  ({"uop": UOps.PHI, "vin": ({"uop": UOps.DEFINE_ACC, "__name__": "acc"}, {"__name__": "acc"})}, lambda acc: UOp.const(acc.dtype, acc.arg[0])),
  ({"uop": UOps.PHI, "vin": ({"uop": UOps.DEFINE_ACC, "vin": tuple()}, {"__name__": "x"})}, lambda x: x),
  ({"uop": UOps.PHI, "vin": ({"uop": UOps.CONST}, {"__name__": "x"})}, lambda x: x),
  # a DEFINE_ACC without inputs is a const + GEP on a const is the const
  ({"__name__": "root", "uop": UOps.DEFINE_ACC, "vin": tuple()}, lambda root: UOp.const(root.dtype, root.arg[0])),
  # NOTE(review): same pattern and result as the GEP rule under "const rules" above; it looks redundant
  ({"__name__": "root", "uop": UOps.GEP, "vin": ({"__name__": "x", "uop": UOps.CONST},)}, lambda root,x: UOp.const(root.dtype, x.arg)),
  # max -2147483648
  ({"uop": UOps.ALU, "arg": BinaryOps.MAX, "dtype": dtypes.int, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": -2147483648}]}, lambda x: x),
  # -(-x) -> x
  ({"uop": UOps.ALU, "arg": UnaryOps.NEG, "vin": ({"uop": UOps.ALU, "arg": UnaryOps.NEG, "vin": ({"__name__": "x"},)})}, lambda x: x),
  # x+-y -> x-y
  ({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "x"}, {"__name__": "my", "uop": UOps.ALU, "arg": UnaryOps.NEG})},
    lambda x, my: x-my.vin[0]),
  # -1*x -> -x
  ({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": -1}]}, lambda x: -x),
  # bool < False is always false, True < bool is always false
  ({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({}, {"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.bool, "arg": False})}, lambda x: x),
  ({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.bool, "arg": True}, {})},
    lambda x: UOp.const(dtypes.bool, False)),
  # a conditional with the same results either way is a noop, also fold const conditionals
  ({"uop": UOps.ALU, "arg": TernaryOps.WHERE, "vin": ({}, {"__name__": "val"}, {"__name__": "val"})}, lambda val: val),
  ({"uop": UOps.ALU, "arg": TernaryOps.WHERE, "vin": ({"__name__": "gate", "uop": UOps.CONST}, {"__name__": "c0"}, {"__name__": "c1"})},
    lambda gate, c0, c1: c0 if gate.arg else c1),
  # ** constant folding **
  # an ALU whose inputs are all CONST is evaluated with exec_alu
  ({"__name__": "root", "uop": UOps.ALU, "vin": {"uop": UOps.CONST}},
    lambda root: UOp.const(root.dtype, exec_alu(root.arg, root.dtype, [x.arg for x in root.vin]))),
  # ** self folding **
  ({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": 0}]}, lambda x: x), # x+0 -> x or 0+x -> x
  ({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": 1}]}, lambda x: x), # x*1 -> x or 1*x -> x
  ({"uop": UOps.ALU, "arg": BinaryOps.SUB, "vin": ({"__name__": "x"}, {"uop": UOps.CONST, "arg": 0})}, lambda x: x), # x-0 -> x
  ({"uop": UOps.ALU, "arg": BinaryOps.DIV, "vin": ({"__name__": "x"}, {"uop": UOps.CONST, "arg": 1})}, lambda x: x), # x/1 -> x
  ({"uop": UOps.ALU, "arg": BinaryOps.DIV, "vin": ({"__name__": "x"}, {"uop": UOps.CONST, "arg": -1})}, lambda x: -x), # x/-1 -> -x
  # ** zero folding **
  ({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{}, {"__name__": "c", "uop": UOps.CONST, "arg": 0}]}, lambda c: c), # x*0 -> 0 or 0*x -> 0
  ({"uop": UOps.ALU, "arg": BinaryOps.SUB, "vin": ({"__name__": "x"}, {"__name__": "x"})}, lambda x: UOp.const(x.dtype, 0)), # x-x -> 0
  # ** load/store folding **
  # storing back the value just loaded from the same buf/idx becomes a NOOP
  ({"uop": UOps.STORE, "vin": ({"__name__": "buf"}, {"__name__": "idx"},
    {"uop": UOps.LOAD, "vin": ({"__name__": "buf"}, {"__name__": "idx"})})}, lambda buf, idx: UOp(UOps.NOOP)),
  # ** two stage add/sub folding **
  # (x+c1)+c2 -> x+(c1+c2)
  ({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": [{"uop": UOps.ALU, "arg": BinaryOps.ADD,
    "vin": [{"__name__": "x"}, {"__name__": "c1", "uop": UOps.CONST}]}, {"__name__": "c2", "uop": UOps.CONST}]},
    lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, c2.arg]))),
  # (x-c1)+c2 -> x+(c2-c1)
  ({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": [{"uop": UOps.ALU, "arg": BinaryOps.SUB,
    "vin": ({"__name__": "x"}, {"__name__": "c1", "uop": UOps.CONST})}, {"__name__": "c2", "uop": UOps.CONST}]},
    lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.SUB, x.dtype, [c2.arg, c1.arg]))),
  # STORE(buf, idx, WHERE(gate, alt, LOAD(buf, idx))) -> STORE(buf, idx, alt, gate)
  # TODO: can do the invert of this (flip alt/load) when we fix double ops
  ({"uop": UOps.STORE, "vin": ({"__name__": "buf"}, {"__name__": "idx"}, {"uop": UOps.ALU, "arg": TernaryOps.WHERE,
    "vin": ({"__name__": "gate"}, {"__name__": "alt"}, {"uop": UOps.LOAD, "vin": ({"__name__": "buf"}, {"__name__": "idx"})})})},
    lambda buf, idx, gate, alt: UOp(UOps.STORE, None, (buf, idx, alt, gate))),
  # store float4/float2 directly (remove CAST/GEP)
  # NOTE(review): in the GEP sub-patterns below "arg" is listed after "vin"; confirm _match
  # still checks keys that come after "vin", otherwise the element index is not verified
  ({"uop": UOps.STORE, "vin": ({"__name__": "buf"}, {"__name__": "idx"}, {"uop": UOps.CAST, "vin":
    tuple({"uop": UOps.GEP, "vin": ({"__name__": "val"},), "arg": i} for i in range(4))})},
    lambda buf,idx,val: UOp(UOps.STORE, None, (buf, idx, val))),
  ({"uop": UOps.STORE, "vin": ({"__name__": "buf"}, {"__name__": "idx"}, {"uop": UOps.CAST, "vin":
    tuple({"uop": UOps.GEP, "vin": ({"__name__": "val"},), "arg": i} for i in range(2))})},
    lambda buf,idx,val: UOp(UOps.STORE, None, (buf, idx, val))),
  # CAST-PHI-GEP -> PHI-CAST
  ({"__name__": "root", "uop": UOps.CAST, "vin":
    tuple({"uop": UOps.PHI, "vin": ({"uop": UOps.GEP, "vin": ({"__name__": "val"},), "arg": i}, {"__name__": f"v{i}"})} for i in range(4))},
    lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1, v2, v3))))),
  ({"__name__": "root", "uop": UOps.CAST, "vin":
    tuple({"uop": UOps.PHI, "vin": ({"uop": UOps.GEP, "vin": ({"__name__": "val"},), "arg": i}, {"__name__": f"v{i}"})} for i in range(2))},
    lambda root, val, v0, v1: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1))))),
  # NEG/CMPLT -> CMPLT
  # (-x) < c  ->  (-c) < x, for int constants only
  ({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"uop": UOps.ALU, "arg": UnaryOps.NEG, "vin": ({"__name__": "x"},)},
    {"__name__": "c", "uop": UOps.CONST, "dtype": dtypes.int})},
    lambda c,x: UOp(UOps.ALU, dtypes.bool, (UOp.const(c.dtype, -c.arg), x), BinaryOps.CMPLT)),
  # cast NOOP (NOTE: it's str to deal with PtrDType)
  ({"__name__": "root", "uop": UOps.CAST}, lambda root: root.vin[0] if str(root.dtype) == str(root.vin[0].dtype) else None),
])
230
+
231
+ # *** uop graph ***
232
+
233
class UOpGraph:
  """A deduplicated graph of UOps plus its linearized (ordered) form.

  Attributes:
    nodes: (uop, dtype, vin, arg) -> UOp; lets `add` and `graph_rewrite` reuse identical nodes.
    _uops: the linearized list, or None until `linearize` has run.
  """
  def __init__(self):
    self.nodes: Dict[Tuple, UOp] = {}
    self._uops: Optional[List[UOp]] = None

  def __iter__(self) -> Iterator[UOp]: return iter(self.uops)
  def __getitem__(self, index) -> UOp: return self.uops[index]

  # args of all DEFINE_VAR / DEFINE_GLOBAL uops, in linearized order
  def vars(self) -> List[Variable]: return [x.arg for x in self.uops if x.uop is UOps.DEFINE_VAR]
  def globals(self) -> List[Tuple[int, bool]]: return [x.arg for x in self.uops if x.uop is UOps.DEFINE_GLOBAL]

  @property
  def uops(self):
    """The linearized uop list; linearizes on first access."""
    if self._uops is None: self.linearize()
    return self._uops

  def graph(self):
    """Pass the linearized uops to tinygrad.engine.graph.graph_uops."""
    from tinygrad.engine.graph import graph_uops
    graph_uops(self.uops)

  def print(self):
    """Print one line per uop: index, opcode, dtype, indices of its inputs, and arg."""
    for i,u in enumerate(self):
      print(f"{i:4d} {str(u.uop):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str([self.uops.index(x) for x in u.vin]):32s} {u.arg}")

  def graph_rewrite(self, sink, pm):
    """Rewrite everything reachable from `sink` with pattern matcher `pm` until a pass changes nothing.

    The UOPS_REWRITE env var (default 1) seeds the loop condition, so 0 skips rewriting.
    Nodes are mutated in place (their `vin` is reassigned) and deduplicated through self.nodes.

    Returns:
      The rewritten sink.
    """
    # recursive rewrite
    changed = getenv("UOPS_REWRITE", 1)
    run_cnt = 0
    while changed:
      changed = 0
      # NOTE(review): a bare lru_cache uses the default maxsize of 128, while linearize's helper
      # uses lru_cache(None); confirm the bound is intended here
      @functools.lru_cache
      def rewrite(u:UOp) -> UOp:
        nonlocal changed
        recurse_cnt = 0
        up = u
        # locally recursively rewrite
        while (rewritten := pm.rewrite(up)):
          assert recurse_cnt < 100, f"recursive_rewrite looped {up} <--> {rewritten}"
          up = rewritten
          recurse_cnt += 1
        changed += recurse_cnt
        # NOTE: this changes UOp, so we have to delete caches
        up.vin = tuple(rewrite(x) for x in up.vin)
        # cached_property stores its value in the instance __dict__. Drop it from there directly:
        # hasattr() would first compute the value (walking the whole subgraph for `parents`)
        # only for it to be deleted again.
        up.__dict__.pop("parents", None)
        up.__dict__.pop("cmp_tuple", None)
        # replace with cached nodes
        if found:=self.nodes.get(key:=up.tuple()): return found
        else: self.nodes[key] = up
        return up
      sink = rewrite(sink)
      run_cnt += 1
      assert run_cnt < 100, "exceeded 100 rewrite loops!"
    return sink

  def linearize(self, extra_pm:Optional[PatternMatcher]=None, type_verify=True):
    """Rewrite the graph and order it into self._uops.

    Steps: gather every STORE (and the inputs of every SINK) under a new SINK, rewrite with
    constant_folder (and again with extra_pm's patterns appended, if given), then emit the
    nodes reachable from the sink in dependency order using a priority queue. An ENDRANGE is
    appended once the last child of a RANGE has been emitted, and an ENDIF is appended at the
    end for every IF.

    Args:
      extra_pm: optional extra patterns to apply after the first rewrite.
      type_verify: when true, run self.type_verify() on the result.
    """
    # NOTE: relinearizing should be okay
    #assert self._uops is None, "already linearized"

    # get sink
    _sinks: List[UOp] = []
    for u in self.nodes.values():
      if u.uop is UOps.STORE: _sinks.append(u)
      if u.uop is UOps.SINK: _sinks.extend(u.vin)
    sink = UOp(UOps.SINK, None, tuple(_sinks))
    del _sinks

    sink = self.graph_rewrite(sink, constant_folder)
    if extra_pm: sink = self.graph_rewrite(sink, PatternMatcher(constant_folder.patterns+extra_pm.patterns))

    # filter nodes that don't link to a sink
    # BFS toposort
    graph: DefaultDict[UOp, List[UOp]] = defaultdict(list)   # node -> nodes that consume it
    in_degree: DefaultDict[UOp, int] = defaultdict(int)      # node -> number of not-yet-emitted inputs
    loops = []
    ifs = []
    nodes: Dict[UOp, None] = {}                               # dict used as an insertion-ordered set
    def add_parents(u:UOp):
      if u in nodes: return
      nodes[u] = None
      for x in u.vin:
        add_parents(x)
        in_degree[u] += 1
        graph[x].append(u)
      if u.uop is UOps.RANGE: loops.append(u)
      if u.uop is UOps.IF: ifs.append(u)
    # NOOPs produced by the rewrite are dropped from the sink
    sink = UOp(UOps.SINK, None, tuple(x for x in sink.vin if x.uop is not UOps.NOOP))
    add_parents(sink)

    @functools.lru_cache(None)
    def get_recursive_children(x:UOp, include_self=False) -> Set[UOp]:
      # everything downstream of x; the walk stops at SINK and does not continue past a PHI
      if x.uop is UOps.SINK: return set()
      return set.union(set((x,)) if include_self else set(), *([get_recursive_children(u, True) for u in graph[x]] if x.uop is not UOps.PHI else []))
    loops_children = {l:get_recursive_children(l) for l in loops[::-1]}

    queue: List = []
    def push(u):
      priority = 0
      # prefer uops that are loop children
      for l, ss in loops_children.items():
        if u in ss: priority -= l.arg[0]*1000 + l.arg[1]
      # ties on priority fall back to UOp.__lt__
      heapq.heappush(queue, (priority, u))

    for u in nodes:
      if in_degree[u] == 0: push(u)

    if getenv("FUZZ_UOPS", 0):
      from test.external.fuzz_uops import fuzz_uops
      self.fuzz_paths = fuzz_uops(graph, in_degree.copy(), loops_children)

    self._uops = []
    while queue:
      p,x = heapq.heappop(queue)
      if DEBUG >= 7: print(p,x)
      if x.uop is UOps.DEFINE_ACC and len(x.vin):
        # an accumulator with inputs is placed just before the earliest of those inputs
        idx = min([self._uops.index(l) for l in x.vin])
        self._uops.insert(idx, x)
      else:
        self._uops.append(x)
      for u, ss in loops_children.items():
        if x in ss:
          ss.remove(x)
          # last child of this loop emitted: close the loop
          if len(ss) == 0: self._uops.append(UOp(UOps.ENDRANGE, None, (u,)))
      for u in graph[x]:
        in_degree[u] -= 1
        if in_degree[u] == 0: push(u)

    assert self._uops[-1].uop is UOps.SINK, f"didn't end with SINK, ended with {self._uops[-1]}"
    self._uops = self._uops[:-1]

    # TODO: ifs should be removed and just the store should be gated
    for u in ifs[::-1]: self._uops.append(UOp(UOps.ENDIF, None, (u,)))

    if type_verify: self.type_verify()

  def add(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(), arg:Any=None) -> UOp:
    """Return the node for (uop, dtype, vin, arg), creating and registering it if it does not exist yet."""
    if found:=self.nodes.get(key:=(uop, dtype, vin, arg)): return found
    self.nodes[key] = ret = UOp(*key)
    return ret

  # *** checker functions ***

  def flops_mem(self) -> Tuple[sint, sint]:
    """Estimate (flops, mem) by walking the linearized uops.

    Each RANGE multiplies the running multiplier by the resolved value of its second input
    until the matching ENDRANGE. ALU counts 1 (2 for MULACC), WMMA counts
    2*prod(arg[1])//32, and LOAD/STORE add the dtype's itemsize, all scaled by the multiplier.
    """
    flops: sint = 0
    mem: sint = 0
    mults: sint = 1
    mult_stack = []
    for u in self.uops:
      if u.uop is UOps.RANGE:
        mult_stack.append(mults)
        mults *= uop_alu_resolve(u.vin[1])
      elif u.uop is UOps.ENDRANGE:
        mults = mult_stack.pop(-1)
      elif u.uop is UOps.ALU:
        flops += mults * (2 if u.arg == TernaryOps.MULACC else 1)
      elif u.uop is UOps.LOAD:
        assert u.dtype is not None
        mem += u.dtype.itemsize * mults
      elif u.uop is UOps.STORE:
        assert u.vin[2].dtype is not None
        mem += u.vin[2].dtype.itemsize * mults
      elif u.uop is UOps.WMMA:
        assert u.arg[1] is not None
        flops += 2 * prod(u.arg[1]) // 32 * mults
    return flops, mem

  def type_verify(self):
    """Assert dtype consistency of the linearized uops (CONST/DEFINE_ACC values, CAST/BITCAST args, ALU operands)."""
    for u in self.uops:
      uop, arg, vin, dtype = u.uop, u.arg, u.vin, u.dtype
      if uop in {UOps.CONST, UOps.DEFINE_ACC}:
        if uop is UOps.DEFINE_ACC: arg = arg[0]
        assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
      if uop in {UOps.CAST, UOps.BITCAST}: assert arg is None # type is the output type, not an arg
      if uop is UOps.ALU:
        if arg in UnaryOps:
          assert dtype == vin[0].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=}"
        elif arg in (BinaryOps.CMPLT, BinaryOps.CMPEQ):
          assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
          assert vin[0].dtype == vin[1].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=} != {vin[1].dtype=}"
        elif arg in BinaryOps:
          assert dtype == vin[0].dtype == vin[1].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=} != {vin[1].dtype=}"
        elif arg == TernaryOps.WHERE:
          assert vin[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {vin[0].dtype=} != {dtypes.bool}"
          assert dtype == vin[1].dtype == vin[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {vin[1].dtype=} != {vin[2].dtype=}"
tinygrad/device.py ADDED
@@ -0,0 +1,183 @@
1
+ from __future__ import annotations
2
+ import multiprocessing
3
+ from dataclasses import dataclass
4
+ from collections import defaultdict
5
+ from typing import List, Optional, Dict, Tuple, Any
6
+ import importlib, inspect, functools, pathlib, os, ctypes
7
+ from tinygrad.helpers import getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv
8
+ from tinygrad.dtype import DType, ImageDType
9
+ from tinygrad.renderer import Renderer
10
+
11
+ # **************** Device ****************
12
+
13
+ class _Device:
14
+ def __init__(self) -> None: self._devices: List[str] = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")] # noqa: E501
15
+ @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
16
+ def _canonicalize(self, device:str) -> str: return (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") # noqa: E501
17
+ # NOTE: you can't cache canonicalize in case Device.DEFAULT changes
18
+ def canonicalize(self, device:Optional[str]) -> str: return self._canonicalize(device) if device is not None else Device.DEFAULT
19
+ def __getitem__(self, ix:str) -> Compiled: return self.__get_canonicalized_item(self.canonicalize(ix))
20
+ @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
21
+ def __get_canonicalized_item(self, ix:str) -> Compiled:
22
+ if DEBUG >= 1: print(f"opening device {ix} from pid:{os.getpid()}")
23
+ assert multiprocessing.current_process().name == "MainProcess" or ix.split(":")[0] in ["DISK", "NPY"], f"can only open device {ix} from parent"
24
+ x = ix.split(":")[0].upper()
25
+ return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._devices][0](ix) # noqa: E501
26
+ @functools.cached_property
27
+ def DEFAULT(self) -> str:
28
+ device_from_env: Optional[str] = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._devices, None) # type: ignore
29
+ if device_from_env: return device_from_env
30
+ for device in ["METAL", "HSA", "CUDA", "GPU", "CLANG", "LLVM"]:
31
+ try:
32
+ if self[device]:
33
+ os.environ[device] = "1" # we set this in environment for spawned children
34
+ return device
35
+ except Exception: pass
36
+ raise RuntimeError("no usable devices")
37
+ Device = _Device()
38
+
39
+ # **************** Buffer + Allocators ****************
40
+
41
@dataclass(frozen=True)
class BufferOptions:
    """Immutable, hashable bundle of allocation flags handed to an allocator.

    Instances compare by value (the dataclass default), which lets them be
    used as part of a dictionary key.
    """
    # image dtype of the buffer; Buffer fills this in when its dtype is an ImageDType
    image: Optional[ImageDType] = None
    # the three flags below are interpreted by backend allocators outside this section — TODO confirm exact meaning
    uncached: bool = False
    cpu_access: bool = False
    host: bool = False
    # when True, LRUAllocator.free releases the allocation instead of caching it
    nolru: bool = False
48
+
49
class Buffer:
    """A typed region of device memory, either a base allocation or a view into one.

    A *base* buffer (``base is None`` at construction) owns its allocation and
    a reference counter. A *view* shares the allocation of its base at a byte
    ``offset``. Allocation is lazy: ``_buf`` only exists after ``allocate()``.
    """

    def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None,
                 initial_value:Optional[bytes]=None, lb_refcount=0, base:Optional[Buffer]=None, offset:int=0, preallocate=False):
        """Create a buffer of ``size`` elements of ``dtype`` on ``device``.

        Args:
            opaque: existing backend handle to adopt instead of allocating (base buffers only).
            options: allocation flags; replaced by ``BufferOptions(image=dtype)`` for image dtypes.
            initial_value: bytes copied in right after allocating (base buffers only).
            lb_refcount: initial value of the base buffer's reference counter.
            base: base buffer to view into; must itself be a base on the same device.
            offset: byte offset into ``base``; must be 0 for base buffers.
            preallocate: allocate immediately.
        """
        assert isinstance(dtype, DType)
        if isinstance(dtype, ImageDType): options = BufferOptions(image=dtype)  # TODO: image hack shouldn't be here. where should it be?
        self.device, self.size, self.dtype, self.options, self.offset = device, size, dtype, options, offset
        if base is None:
            assert offset == 0, "base buffers can't have offset"
            self._base = None
            self._lb_refcount = lb_refcount
            if opaque is not None: self.allocate(opaque)
            if initial_value is not None:
                self.allocate()
                self.copyin(memoryview(initial_value))
        else:
            assert base._base is None, "base can't have a base"
            assert device == base.device, "base must have the same device"
            self._base = base
        if preallocate: self.allocate()

    @property
    def base(self) -> Buffer:
        """The buffer that owns the allocation (``self`` for a base buffer)."""
        return self._base if self._base is not None else self

    @property
    def lb_refcount(self):
        """Reference counter, always stored on the base buffer."""
        return self.base._lb_refcount

    def ref(self, cnt):
        """Add ``cnt`` (may be negative) to the base buffer's reference counter."""
        self.base._lb_refcount += cnt

    def is_allocated(self) -> bool:
        """True once ``allocate()`` has set ``_buf``."""
        return hasattr(self, '_buf')

    def ensure_allocated(self) -> Buffer:
        """Allocate if needed and return ``self``."""
        return self.allocate() if not hasattr(self, '_buf') else self

    def allocate(self, opaque=None) -> Buffer:
        """Bind ``_buf``: a view takes an offset slice of its base, a base adopts ``opaque`` or allocates.

        Raises:
            AssertionError: if already allocated, or a view's allocator has no ``offset`` function.
        """
        assert not hasattr(self, '_buf'), "can't allocate already allocated buffer"
        self.allocator = Device[self.device].allocator
        if self._base is not None:
            self._base.ensure_allocated()
            assert hasattr(self.allocator, "offset"), "offset function required for view"
            self._buf: Any = self.allocator.offset(self.base._buf, self.nbytes, self.offset)
        else:
            self._buf = opaque if opaque is not None else self.allocator.alloc(self.nbytes, self.options)
            # accounting mirrors __del__: only base, non-DISK buffers are counted
            if not self.device.startswith("DISK"): GlobalCounters.mem_used += self.nbytes
        return self

    def __reduce__(self):
        """Pickle support: rebuild through the constructor.

        Views carry their base and offset; NPY buffers carry their backing
        object; other allocated buffers carry a bytes copy of their contents.
        """
        buf = None
        if self._base is not None:
            return self.__class__, (self.device, self.size, self.dtype, None, None, None, 0, self.base, self.offset, hasattr(self, '_buf'))
        if self.device == "NPY": return self.__class__, (self.device, self.size, self.dtype, self._buf, self.options, None, self.lb_refcount)
        if self.is_allocated():
            buf = bytearray(self.nbytes)
            self.copyout(memoryview(buf))
        return self.__class__, (self.device, self.size, self.dtype, None, self.options, buf, self.lb_refcount)

    @property
    def nbytes(self):
        """Size in bytes: element count times ``dtype.itemsize``."""
        return self.size*self.dtype.itemsize

    def __del__(self):
        # views never own memory; only an allocated base buffer gives it back
        if not hasattr(self, '_buf'): return
        if self._base is None:
            if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
            self.allocator.free(self._buf, self.nbytes, self.options)

    def __repr__(self):
        # The offset is only meaningful for views. `base` is a property and therefore always
        # present, so the check has to look at `_base` (getattr keeps repr usable on a
        # partially constructed object).
        is_view = getattr(self, "_base", None) is not None
        return f"<buf real:{hasattr(self, '_buf')} device:{self.device} size:{self.size} dtype:{self.dtype}" + \
               (f" offset:{self.offset}" if is_view else "") + \
               (">" if self.options is None else f" {self.options=}>")

    def as_buffer(self, allow_zero_copy=False, force_zero_copy=False) -> memoryview:
        """Return the contents as a memoryview, copying unless zero-copy is requested and supported.

        Raises:
            AssertionError: if ``force_zero_copy`` is set but the allocator has no ``as_buffer``.
        """
        # zero copy with as_buffer (disabled by default due to use after free)
        if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, 'as_buffer'): return self.allocator.as_buffer(self._buf)
        assert not force_zero_copy, "force zero copy was passed, but copy is required"
        return self.copyout(memoryview(bytearray(self.nbytes)))

    def copyin(self, mv:memoryview):
        """Copy ``mv`` (exactly ``nbytes`` long) into the allocated buffer; returns ``self``."""
        mv = flat_mv(mv)
        assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
        assert self.is_allocated(), "can't copyin to unallocated buffer"
        self.allocator.copyin(self._buf, mv)
        return self

    def copyout(self, mv:memoryview) -> memoryview:
        """Copy the allocated buffer into ``mv`` (exactly ``nbytes`` long) and return the flattened view."""
        mv = flat_mv(mv)
        assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
        assert self.is_allocated(), "can't copyout unallocated buffer"
        self.allocator.copyout(mv, self._buf)
        return mv

    def view(self, size:int, dtype:DType, offset:int) -> Buffer:
        """Create a view of ``size`` elements of ``dtype`` starting ``offset`` bytes into this buffer.

        A view of a view is flattened onto the original base with the offsets added.
        """
        assert offset < self.nbytes, "offset must be less than nbytes"
        if self._base is not None: return Buffer(self.device, size, dtype, base=self._base, offset=self.offset+offset)
        return Buffer(self.device, size, dtype, base=self, offset=offset)
127
+
128
+ # TODO: size, dest, src are the same type. can we enforce this?
129
class Allocator:
    """Base class for memory allocators.

    The public ``alloc``/``free`` fill in default options and delegate to the
    ``_alloc``/``_free`` hooks; ``copyin``/``copyout`` must be provided by
    subclasses.
    """

    def alloc(self, size:int, options:Optional[BufferOptions]=None):
        """Allocate ``size`` and return whatever ``_alloc`` returns.

        Raises:
            AssertionError: if ``size`` is an int that is not positive.
        """
        assert not (isinstance(size, int) and size <= 0), f"alloc size must be positve, getting {size}"
        if options is None:
            options = BufferOptions()
        return self._alloc(size, options)

    def _alloc(self, size:int, options:BufferOptions):
        """Allocation hook; subclasses must override."""
        raise NotImplementedError("need alloc")

    def free(self, opaque, size:int, options:Optional[BufferOptions]=None):
        """Release ``opaque`` via ``_free``; ``size`` is accepted but not forwarded."""
        if options is None:
            options = BufferOptions()
        self._free(opaque, options)

    def _free(self, opaque, options:BufferOptions):
        """Release hook; the default does nothing."""
        # if opaque is a Python object, you don't need a free

    def copyin(self, dest, src:memoryview):
        """Copy from the host memoryview ``src`` into ``dest``; subclasses must override."""
        raise NotImplementedError("need copyin")

    def copyout(self, dest:memoryview, src):
        """Copy from ``src`` into the host memoryview ``dest``; subclasses must override."""
        raise NotImplementedError("need copyout")
139
+
140
class LRUAllocator(Allocator):  # pylint: disable=abstract-method
    """Allocator that keeps freed allocations in a cache keyed by ``(size, options)`` for reuse."""

    def __init__(self):
        # (size, options) -> list of cached opaque allocations
        self.cache: Dict[Tuple[int, Optional[BufferOptions]], Any] = defaultdict(list)

    def alloc(self, size:int, options:Optional[BufferOptions]=None):
        """Reuse a cached allocation when one matches, otherwise allocate.

        If allocating raises ``RuntimeError`` or ``MemoryError`` the cache is
        emptied and the allocation is attempted once more.
        """
        pool = self.cache[(size, options)]
        if len(pool):
            return pool.pop()
        try:
            return super().alloc(size, options)
        except (RuntimeError, MemoryError):
            self.free_cache()
            return super().alloc(size, options)

    def free_cache(self):
        """Really free every cached allocation and empty each cache list."""
        for key, opaques in self.cache.items():
            sz, options = key
            for opaque in opaques:
                super().free(opaque, sz, options)
            opaques.clear()

    def free(self, opaque:Any, size:int, options:Optional[BufferOptions]=None):
        """Cache ``opaque`` for reuse, or free it when caching is disabled.

        Caching is bypassed when the ``LRU`` setting is falsy or ``options.nolru`` is set.
        """
        bypass_cache = not getenv("LRU", 1) or (options is not None and options.nolru)
        if bypass_cache:
            super().free(opaque, size, options)
            return
        self.cache[(size, options)].append(opaque)
155
+
156
class _MallocAllocator(LRUAllocator):
    """Host-memory allocator whose allocations are ctypes ``c_uint8`` arrays."""

    def _alloc(self, size:int, options:BufferOptions):
        """Return a new ``c_uint8`` array of ``size`` bytes; ``options`` is unused."""
        array_type = ctypes.c_uint8 * size
        return array_type()

    def as_buffer(self, src) -> memoryview:
        """Expose ``src`` as a memoryview passed through ``flat_mv``."""
        return flat_mv(memoryview(src))

    def copyin(self, dest, src:memoryview):
        """memmove ``len(src)`` bytes from ``src`` into ``dest``."""
        ctypes.memmove(dest, from_mv(src), len(src))

    def copyout(self, dest:memoryview, src):
        """memmove ``len(dest)`` bytes from ``src`` into ``dest``."""
        ctypes.memmove(from_mv(dest), src, len(dest))

    def offset(self, buf, size:int, offset:int):
        """Return a handle for the ``size`` bytes of ``buf`` starting at byte ``offset``."""
        window = self.as_buffer(buf)[offset:offset+size]
        return from_mv(window)


# shared module-level instance
MallocAllocator = _MallocAllocator()
164
+
165
+ # **************** for Compiled Devices ****************
166
+
167
class CompileError(Exception):
    """Exception type for compilation failures; nothing in this section raises it."""
168
+
169
class Compiler:
    """Base class for source-to-binary compilers with an optional disk cache."""

    def __init__(self, cachekey:Optional[str]=None):
        """Remember ``cachekey``; the cache is turned off when ``DISABLE_COMPILER_CACHE`` is set."""
        if getenv("DISABLE_COMPILER_CACHE"):
            self.cachekey = None
        else:
            self.cachekey = cachekey

    def compile(self, src:str) -> bytes:
        """Compile ``src``; subclasses must override."""
        raise NotImplementedError("need a compile function")

    def compile_cached(self, src:str) -> bytes:
        """Return the compiled form of ``src``, using the disk cache when a cache key is set.

        Raises:
            AssertionError: if a real compile is needed while ``ASSERT_COMPILE`` is set.
        """
        if self.cachekey is not None:
            cached = diskcache_get(self.cachekey, src)
            if cached is not None:
                return cached
        # cache disabled or cache miss: compile for real
        assert not getenv("ASSERT_COMPILE"), "tried to compile with ASSERT_COMPILE set"
        lib = self.compile(src)
        if self.cachekey is not None:
            diskcache_put(self.cachekey, src, lib)
        return lib
178
+
179
class Compiled:
    """Bundle of the pieces that make up one opened device."""

    def __init__(self, device:str, allocator:Allocator, renderer:Optional[Renderer], compiler:Optional[Compiler], runtime, graph=None):
        """Store the device parts; a falsy ``compiler``/``renderer`` is replaced by a default instance."""
        # resolve the compiler default before touching any attribute
        chosen_compiler = compiler or Compiler()
        self.dname = device
        self.allocator = allocator
        self.compiler = chosen_compiler
        self.runtime = runtime
        self.graph = graph
        self.renderer = renderer or Renderer()

    def synchronize(self):
        """No-op by default."""
        # override this in your device