tinygrad-0.7.0-py3-none-any.whl → tinygrad-0.9.0-py3-none-any.whl
This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- tinygrad/__init__.py +6 -0
- tinygrad/codegen/kernel.py +572 -83
- tinygrad/codegen/linearizer.py +415 -395
- tinygrad/codegen/uops.py +415 -0
- tinygrad/device.py +183 -0
- tinygrad/dtype.py +113 -0
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +195 -0
- tinygrad/engine/realize.py +191 -0
- tinygrad/engine/schedule.py +362 -0
- tinygrad/engine/search.py +196 -0
- tinygrad/{mlops.py → function.py} +76 -55
- tinygrad/helpers.py +196 -89
- tinygrad/lazy.py +210 -371
- tinygrad/multi.py +169 -0
- tinygrad/nn/__init__.py +202 -22
- tinygrad/nn/datasets.py +7 -0
- tinygrad/nn/optim.py +112 -32
- tinygrad/nn/state.py +136 -39
- tinygrad/ops.py +119 -202
- tinygrad/renderer/__init__.py +61 -0
- tinygrad/renderer/assembly.py +276 -0
- tinygrad/renderer/cstyle.py +353 -166
- tinygrad/renderer/llvmir.py +150 -138
- tinygrad/runtime/autogen/amd_gpu.py +1900 -0
- tinygrad/runtime/autogen/comgr.py +865 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5761 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33328 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/hip_comgr.py +47 -0
- tinygrad/runtime/driver/hsa.py +143 -0
- tinygrad/runtime/graph/clang.py +38 -0
- tinygrad/runtime/graph/cuda.py +81 -0
- tinygrad/runtime/graph/hcq.py +143 -0
- tinygrad/runtime/graph/hsa.py +171 -0
- tinygrad/runtime/graph/metal.py +75 -0
- tinygrad/runtime/ops_amd.py +564 -0
- tinygrad/runtime/ops_clang.py +24 -77
- tinygrad/runtime/ops_cuda.py +175 -89
- tinygrad/runtime/ops_disk.py +56 -33
- tinygrad/runtime/ops_gpu.py +92 -95
- tinygrad/runtime/ops_hsa.py +278 -0
- tinygrad/runtime/ops_llvm.py +39 -60
- tinygrad/runtime/ops_metal.py +92 -74
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +630 -0
- tinygrad/runtime/ops_python.py +204 -0
- tinygrad/shape/shapetracker.py +86 -254
- tinygrad/shape/symbolic.py +166 -141
- tinygrad/shape/view.py +296 -0
- tinygrad/tensor.py +2619 -448
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
- tinygrad-0.9.0.dist-info/METADATA +227 -0
- tinygrad-0.9.0.dist-info/RECORD +60 -0
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/assembly.py +0 -190
- tinygrad/codegen/optimizer.py +0 -379
- tinygrad/codegen/search.py +0 -72
- tinygrad/graph.py +0 -83
- tinygrad/jit.py +0 -57
- tinygrad/nn/image.py +0 -100
- tinygrad/renderer/assembly_arm64.py +0 -169
- tinygrad/renderer/assembly_ptx.py +0 -98
- tinygrad/renderer/wgsl.py +0 -53
- tinygrad/runtime/lib.py +0 -113
- tinygrad/runtime/ops_cpu.py +0 -51
- tinygrad/runtime/ops_hip.py +0 -82
- tinygrad/runtime/ops_shm.py +0 -29
- tinygrad/runtime/ops_torch.py +0 -30
- tinygrad/runtime/ops_webgpu.py +0 -45
- tinygrad-0.7.0.dist-info/METADATA +0 -212
- tinygrad-0.7.0.dist-info/RECORD +0 -40
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
tinygrad/codegen/uops.py
ADDED
@@ -0,0 +1,415 @@
from __future__ import annotations
from typing import Iterator, Optional, Tuple, Any, Dict, List, DefaultDict, Set, Callable
import functools, itertools, heapq
from collections import defaultdict
from enum import Enum, auto
from dataclasses import dataclass
from tinygrad.dtype import dtypes, DType
from tinygrad.shape.symbolic import sint, Variable
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, exec_alu
from tinygrad.helpers import prod, DEBUG, getenv

# the order of these UOps controls the order of the toposort
class UOps(Enum):
  # ops that aren't rendered
  SINK = auto()
  DEFINE_GLOBAL = auto(); DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # noqa: E702
  CONST = auto(); SPECIAL = auto() # noqa: E702
  NOOP = auto(); UNMUL = auto(); GEP = auto() # noqa: E702
  # math ops
  CAST = auto(); BITCAST = auto() # noqa: E702
  ALU = auto(); WMMA = auto() # noqa: E702
  # memory/assignment ops
  LOAD = auto(); STORE = auto(); PHI = auto() # noqa: E702
  # control flow ops
  BARRIER = auto(); IF = auto(); RANGE = auto() # noqa: E702
  # these two are not graph nodes
  ENDRANGE = auto(); ENDIF = auto() # noqa: E702

@dataclass(eq=False)
class UOp:
  uop: UOps
  dtype: Optional[DType] = None
  vin: Tuple[UOp, ...] = tuple()
  arg: Any = None
  def tuple(self): return (self.uop, self.dtype, self.vin, self.arg)
  @functools.cached_property
  def cmp_tuple(self):
    # NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX
    return (self.uop.value, (self.arg if self.uop is not UOps.DEFINE_VAR else self.arg.expr) if self.uop is not UOps.ALU else \
      (type(self.uop), self.uop.value), self.dtype, self.vin)
  def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple
  def __repr__(self):
    return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.uop for x in self.vin]):32s} {self.arg}"
  def cast(self, dtype): return UOp(UOps.CAST, dtype, (self,))
  def __neg__(self): return UOp.alu(UnaryOps.NEG, self)
  def __add__(self, x): return UOp.alu(BinaryOps.ADD, self, x)
  def __sub__(self, x): return UOp.alu(BinaryOps.SUB, self, x)
  def __mul__(self, x): return UOp.alu(BinaryOps.MUL, self, x)
  @staticmethod
  def max(x, y): return UOp.alu(BinaryOps.MAX, x, y)
  @staticmethod
  def min(x, y): return -UOp.alu(BinaryOps.MAX, -x, -y)
  @staticmethod
  def const(dtype, val): return UOp(UOps.CONST, dtype, arg=dtypes.as_const(val, dtype))
  @staticmethod
  def alu(arg, *vin:UOp): return UOp(UOps.ALU, dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPEQ} else vin[-1].dtype, vin, arg)
  @functools.cached_property
  def parents(self) -> Set[UOp]: return set.union(set(self.vin), *[x.parents for x in self.vin])

def uop_alu_resolve(u:UOp) -> sint:
  if u.uop is UOps.CONST: return u.arg
  if u.uop is UOps.DEFINE_VAR: return u.arg
  if u.uop is UOps.SPECIAL: return u.arg[2]-1
  if u.uop is UOps.ALU and u.arg is BinaryOps.MUL: return uop_alu_resolve(u.vin[0]) * uop_alu_resolve(u.vin[1])
  if u.uop is UOps.ALU and u.arg is BinaryOps.ADD: return uop_alu_resolve(u.vin[0]) + uop_alu_resolve(u.vin[1])
  raise RuntimeError(f"ALU resolve fail @ {u.uop}")

# *** simplification logic ***

def _match(uop:UOp, pattern:Dict[str, Any], store:Dict[str, UOp]) -> bool:
  for k,v in pattern.items():
    if k == "__name__":
      if v in store and store[v] != uop: return False
      store[v] = uop
    elif k == "arg":
      if uop.arg != v: return False
    elif k == "dtype":
      if isinstance(v, set):
        if uop.dtype not in v: return False
      elif uop.dtype != v: return False
    elif k == "uop":
      if isinstance(v, set):
        if uop.uop not in v: return False
      elif uop.uop != v: return False
    elif k == "vin":
      # only one if it's a tuple
      # try all permutations if it's a list
      # repeat if it's a dict
      for vp in itertools.permutations(v) if isinstance(v, list) else ([v] if isinstance(v, tuple) else [(v,)*len(uop.vin)]):
        if len(uop.vin) != len(vp) and (len(uop.vin) not in pattern.get('__allow_len__', [])): return False
        new_store = store.copy()
        if all(_match(uu, vv, new_store) for uu, vv in zip(uop.vin, vp)):
          for k,v in new_store.items(): store[k] = v
          return True
      return False
  return True

class PatternMatcher:
  def __init__(self, patterns:List[Tuple[Dict[str, Any], Callable]]):
    self.patterns = patterns
    self.pdict: DefaultDict[Tuple[UOps, Any], List[Tuple[Dict[str, Any], Callable]]] = defaultdict(list)
    # uop is required, arg is optional
    for p,fxn in self.patterns:
      uops = p["uop"]
      if isinstance(uops, set):
        for uop in uops: self.pdict[(uop, p.get("arg", None))].append((p, fxn))
      else:
        self.pdict[(uops, p.get("arg", None))].append((p, fxn))

  def rewrite(self, uop:UOp) -> Optional[UOp]:
    for p,fxn in itertools.chain(self.pdict[(uop.uop, uop.arg)], self.pdict[(uop.uop, None)]):
      store: Dict[str, UOp] = {}
      if _match(uop, p, store): return fxn(**store)
    return None

def sum_collapse(phi_input, loop, val1, val2):
  for v1,v2 in [(val1, val2), (val2, val1)]:
    if loop not in v1.parents:
      loop_range = loop.vin[1]-loop.vin[0]
      ret = v1*loop_range.cast(v1.dtype)
      return UOp(UOps.PHI, phi_input.dtype, (phi_input, v2))+ret
  return None

def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst):
  if mval.arg >= 0 or loop_start.arg != 0:
    # TODO: support and test this with other mvals and loop_starts
    if DEBUG >= 1: print(f"WARNING, NOT FOLDING: mval:{mval.arg} loop_start:{loop_start.arg}")
    return None
  comprange = UOp.min(loop_end, UOp.max(UOp.alu(BinaryOps.DIV, idx-compval-mval, mval) + (loop_end-loop_start), loop_start))
  return UOp(UOps.UNMUL, multconst.dtype, (comprange.cast(multconst.dtype) * multconst, loop_end-loop_start))

# this is symbolic 2.0
constant_folder = PatternMatcher([
  # arange loop folding (early)
  ({"uop": UOps.ALU, "arg": TernaryOps.WHERE, "vin": ({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": (
    {"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin":
      [{"__name__": "idx"}, {"uop": UOps.ALU, "arg": BinaryOps.MUL,
        "vin": [{"__name__": "mval", "uop": UOps.CONST}, {"uop": UOps.RANGE, "vin": ({"__name__": "loop_start"}, {"__name__": "loop_end"})}]}]},
    {"__name__": "compval", "uop": UOps.CONST})}, {"__name__": "multconst", "uop": UOps.CONST}, {"uop": UOps.CONST, "arg": 0})}, loop_collapse),
  # sum collapse to mul (with possible GEP)
  ({"uop": UOps.PHI, "vin": ({"__name__": "phi_input", "uop": UOps.DEFINE_ACC, "vin": ({"uop": UOps.RANGE, "__name__": "loop"},)},
                             {"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "val1"}, {"__name__": "val2"})})}, sum_collapse),
  ({"uop": UOps.PHI, "vin": ({"__name__": "phi_input", "uop": UOps.GEP,
                              "vin": ({"uop": UOps.DEFINE_ACC, "vin":({"uop": UOps.RANGE, "__name__": "loop"},)},)},
                             {"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "val1"}, {"__name__": "val2"})})}, sum_collapse),
  # deal with UNMUL
  ({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{"uop": UOps.CONST, "__name__": "c1"},
    {"uop": UOps.UNMUL, "vin": [{"uop": UOps.CONST, "__name__": "c2"}, {"__name__": "v"}]}]},
    lambda c1,c2,v: v if c1.arg == c2.arg else None),
  ({"uop": UOps.UNMUL, "vin": ({"uop": UOps.CONST, "__name__": "zero", "arg": 0}, {})}, lambda zero: zero),
  ({"__name__": "root", "uop": UOps.CAST, "vin": ({"uop": UOps.UNMUL, "__name__": "unmul"},)},
    lambda root,unmul: UOp(UOps.UNMUL, root.dtype, (unmul.vin[0].cast(root.dtype), unmul.vin[1]))),
  # max on special can go away (TODO: special should be variable, same thing applies)
  ({"uop": UOps.ALU, "arg": BinaryOps.MAX, "vin": [{"__name__": "c", "uop": UOps.CONST}, {"__name__": "s", "uop": UOps.SPECIAL}]},
    lambda c,s: c if (s.arg[2]-1) <= c.arg else None),
  # const rules
  ({"__name__": "root", "uop": UOps.GEP, "vin": ({"__name__": "c", "uop": UOps.CONST},)}, lambda root, c: UOp.const(root.dtype, c.arg)),
  ({"__name__": "root", "uop": UOps.CAST, "vin": {"__name__": "c", "uop": UOps.CONST}}, lambda root, c: UOp.const(root.dtype, c.arg)),
  # a phi on a DEFINE_ACC without loops or a CONST is a noop. this is for correctness, not just speed
  ({"uop": UOps.PHI, "vin": ({"uop": UOps.DEFINE_ACC, "__name__": "acc"}, {"__name__": "acc"})}, lambda acc: UOp.const(acc.dtype, acc.arg[0])),
  ({"uop": UOps.PHI, "vin": ({"uop": UOps.DEFINE_ACC, "vin": tuple()}, {"__name__": "x"})}, lambda x: x),
  ({"uop": UOps.PHI, "vin": ({"uop": UOps.CONST}, {"__name__": "x"})}, lambda x: x),
  # a DEFINE_ACC without inputs is a const + GEP on a const is the const
  ({"__name__": "root", "uop": UOps.DEFINE_ACC, "vin": tuple()}, lambda root: UOp.const(root.dtype, root.arg[0])),
  ({"__name__": "root", "uop": UOps.GEP, "vin": ({"__name__": "x", "uop": UOps.CONST},)}, lambda root,x: UOp.const(root.dtype, x.arg)),
  # max -2147483648
  ({"uop": UOps.ALU, "arg": BinaryOps.MAX, "dtype": dtypes.int, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": -2147483648}]}, lambda x: x),
  # -(-x) -> x
  ({"uop": UOps.ALU, "arg": UnaryOps.NEG, "vin": ({"uop": UOps.ALU, "arg": UnaryOps.NEG, "vin": ({"__name__": "x"},)})}, lambda x: x),
  # x+-y -> x-y
  ({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "x"}, {"__name__": "my", "uop": UOps.ALU, "arg": UnaryOps.NEG})},
    lambda x, my: x-my.vin[0]),
  # -1*x -> -x
  ({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": -1}]}, lambda x: -x),
  # bool < False is always false, True < bool is always false
  ({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({}, {"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.bool, "arg": False})}, lambda x: x),
  ({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.bool, "arg": True}, {})},
    lambda x: UOp.const(dtypes.bool, False)),
  # a conditional with the same results either way is a noop, also fold const conditionals
  ({"uop": UOps.ALU, "arg": TernaryOps.WHERE, "vin": ({}, {"__name__": "val"}, {"__name__": "val"})}, lambda val: val),
  ({"uop": UOps.ALU, "arg": TernaryOps.WHERE, "vin": ({"__name__": "gate", "uop": UOps.CONST}, {"__name__": "c0"}, {"__name__": "c1"})},
    lambda gate, c0, c1: c0 if gate.arg else c1),
  # ** constant folding **
  ({"__name__": "root", "uop": UOps.ALU, "vin": {"uop": UOps.CONST}},
    lambda root: UOp.const(root.dtype, exec_alu(root.arg, root.dtype, [x.arg for x in root.vin]))),
  # ** self folding **
  ({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": 0}]}, lambda x: x), # x+0 -> x or 0+x -> x
  ({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": 1}]}, lambda x: x), # x*1 -> x or 1*x -> x
  ({"uop": UOps.ALU, "arg": BinaryOps.SUB, "vin": ({"__name__": "x"}, {"uop": UOps.CONST, "arg": 0})}, lambda x: x), # x-0 -> x
  ({"uop": UOps.ALU, "arg": BinaryOps.DIV, "vin": ({"__name__": "x"}, {"uop": UOps.CONST, "arg": 1})}, lambda x: x), # x/1 -> x
  ({"uop": UOps.ALU, "arg": BinaryOps.DIV, "vin": ({"__name__": "x"}, {"uop": UOps.CONST, "arg": -1})}, lambda x: -x), # x/-1 -> -x
  # ** zero folding **
  ({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{}, {"__name__": "c", "uop": UOps.CONST, "arg": 0}]}, lambda c: c), # x*0 -> 0 or 0*x -> 0
  ({"uop": UOps.ALU, "arg": BinaryOps.SUB, "vin": ({"__name__": "x"}, {"__name__": "x"})}, lambda x: UOp.const(x.dtype, 0)), # x-x -> 0
  # ** load/store folding **
  ({"uop": UOps.STORE, "vin": ({"__name__": "buf"}, {"__name__": "idx"},
    {"uop": UOps.LOAD, "vin": ({"__name__": "buf"}, {"__name__": "idx"})})}, lambda buf, idx: UOp(UOps.NOOP)),
  # ** two stage add/sub folding **
  ({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": [{"uop": UOps.ALU, "arg": BinaryOps.ADD,
    "vin": [{"__name__": "x"}, {"__name__": "c1", "uop": UOps.CONST}]}, {"__name__": "c2", "uop": UOps.CONST}]},
    lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, c2.arg]))),
  ({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": [{"uop": UOps.ALU, "arg": BinaryOps.SUB,
    "vin": ({"__name__": "x"}, {"__name__": "c1", "uop": UOps.CONST})}, {"__name__": "c2", "uop": UOps.CONST}]},
    lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.SUB, x.dtype, [c2.arg, c1.arg]))),
  # TODO: can do the invert of this (flip alt/load) when we fix double ops
  ({"uop": UOps.STORE, "vin": ({"__name__": "buf"}, {"__name__": "idx"}, {"uop": UOps.ALU, "arg": TernaryOps.WHERE,
    "vin": ({"__name__": "gate"}, {"__name__": "alt"}, {"uop": UOps.LOAD, "vin": ({"__name__": "buf"}, {"__name__": "idx"})})})},
    lambda buf, idx, gate, alt: UOp(UOps.STORE, None, (buf, idx, alt, gate))),
  # store float4/float2 directly (remove CAST/GEP)
  ({"uop": UOps.STORE, "vin": ({"__name__": "buf"}, {"__name__": "idx"}, {"uop": UOps.CAST, "vin":
    tuple({"uop": UOps.GEP, "vin": ({"__name__": "val"},), "arg": i} for i in range(4))})},
    lambda buf,idx,val: UOp(UOps.STORE, None, (buf, idx, val))),
  ({"uop": UOps.STORE, "vin": ({"__name__": "buf"}, {"__name__": "idx"}, {"uop": UOps.CAST, "vin":
    tuple({"uop": UOps.GEP, "vin": ({"__name__": "val"},), "arg": i} for i in range(2))})},
    lambda buf,idx,val: UOp(UOps.STORE, None, (buf, idx, val))),
  # CAST-PHI-GEP -> PHI-CAST
  ({"__name__": "root", "uop": UOps.CAST, "vin":
    tuple({"uop": UOps.PHI, "vin": ({"uop": UOps.GEP, "vin": ({"__name__": "val"},), "arg": i}, {"__name__": f"v{i}"})} for i in range(4))},
    lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1, v2, v3))))),
  ({"__name__": "root", "uop": UOps.CAST, "vin":
    tuple({"uop": UOps.PHI, "vin": ({"uop": UOps.GEP, "vin": ({"__name__": "val"},), "arg": i}, {"__name__": f"v{i}"})} for i in range(2))},
    lambda root, val, v0, v1: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1))))),
  # NEG/CMPLT -> CMPLT
  ({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"uop": UOps.ALU, "arg": UnaryOps.NEG, "vin": ({"__name__": "x"},)},
    {"__name__": "c", "uop": UOps.CONST, "dtype": dtypes.int})},
    lambda c,x: UOp(UOps.ALU, dtypes.bool, (UOp.const(c.dtype, -c.arg), x), BinaryOps.CMPLT)),
  # cast NOOP (NOTE: it's str to deal with PtrDType)
  ({"__name__": "root", "uop": UOps.CAST}, lambda root: root.vin[0] if str(root.dtype) == str(root.vin[0].dtype) else None),
])

# *** uop graph ***

class UOpGraph:
  def __init__(self):
    self.nodes: Dict[Tuple, UOp] = {}
    self._uops: Optional[List[UOp]] = None

  def __iter__(self) -> Iterator[UOp]: return iter(self.uops)
  def __getitem__(self, index) -> UOp: return self.uops[index]

  def vars(self) -> List[Variable]: return [x.arg for x in self.uops if x.uop is UOps.DEFINE_VAR]
  def globals(self) -> List[Tuple[int, bool]]: return [x.arg for x in self.uops if x.uop is UOps.DEFINE_GLOBAL]

  @property
  def uops(self):
    if self._uops is None: self.linearize()
    return self._uops

  def graph(self):
    from tinygrad.engine.graph import graph_uops
    graph_uops(self.uops)

  def print(self):
    for i,u in enumerate(self):
      print(f"{i:4d} {str(u.uop):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str([self.uops.index(x) for x in u.vin]):32s} {u.arg}")

  def graph_rewrite(self, sink, pm):
    # recursive rewrite
    changed = getenv("UOPS_REWRITE", 1)
    run_cnt = 0
    while changed:
      changed = 0
      @functools.lru_cache
      def rewrite(u:UOp) -> UOp:
        nonlocal changed
        recurse_cnt = 0
        up = u
        # locally recursively rewrite
        while (rewritten := pm.rewrite(up)):
          assert recurse_cnt < 100, f"recursive_rewrite looped {up} <--> {rewritten}"
          up = rewritten
          recurse_cnt += 1
        changed += recurse_cnt
        # NOTE: this changes UOp, so we have to delete caches
        up.vin = tuple(rewrite(x) for x in up.vin)
        if hasattr(up, "parents"): del up.parents
        if hasattr(up, "cmp_tuple"): del up.cmp_tuple
        # replace with cached nodes
        if found:=self.nodes.get(key:=up.tuple()): return found
        else: self.nodes[key] = up
        return up
      sink = rewrite(sink)
      run_cnt += 1
      assert run_cnt < 100, "exceeded 100 rewrite loops!"
    return sink

  def linearize(self, extra_pm:Optional[PatternMatcher]=None, type_verify=True):
    # NOTE: relinearizing should be okay
    #assert self._uops is None, "already linearized"

    # get sink
    _sinks: List[UOp] = []
    for u in self.nodes.values():
      if u.uop is UOps.STORE: _sinks.append(u)
      if u.uop is UOps.SINK: _sinks.extend(u.vin)
    sink = UOp(UOps.SINK, None, tuple(_sinks))
    del _sinks

    sink = self.graph_rewrite(sink, constant_folder)
    if extra_pm: sink = self.graph_rewrite(sink, PatternMatcher(constant_folder.patterns+extra_pm.patterns))

    # filter nodes that don't link to a sink
    # BFS toposort
    graph: DefaultDict[UOp, List[UOp]] = defaultdict(list)
    in_degree: DefaultDict[UOp, int] = defaultdict(int)
    loops = []
    ifs = []
    nodes: Dict[UOp, None] = {}
    def add_parents(u:UOp):
      if u in nodes: return
      nodes[u] = None
      for x in u.vin:
        add_parents(x)
        in_degree[u] += 1
        graph[x].append(u)
      if u.uop is UOps.RANGE: loops.append(u)
      if u.uop is UOps.IF: ifs.append(u)
    sink = UOp(UOps.SINK, None, tuple(x for x in sink.vin if x.uop is not UOps.NOOP))
    add_parents(sink)

    @functools.lru_cache(None)
    def get_recursive_children(x:UOp, include_self=False) -> Set[UOp]:
      if x.uop is UOps.SINK: return set()
      return set.union(set((x,)) if include_self else set(), *([get_recursive_children(u, True) for u in graph[x]] if x.uop is not UOps.PHI else []))
    loops_children = {l:get_recursive_children(l) for l in loops[::-1]}

    queue: List = []
    def push(u):
      priority = 0
      # prefer uops that are loop children
      for l, ss in loops_children.items():
        if u in ss: priority -= l.arg[0]*1000 + l.arg[1]
      heapq.heappush(queue, (priority, u))

    for u in nodes:
      if in_degree[u] == 0: push(u)

    if getenv("FUZZ_UOPS", 0):
      from test.external.fuzz_uops import fuzz_uops
      self.fuzz_paths = fuzz_uops(graph, in_degree.copy(), loops_children)

    self._uops = []
    while queue:
      p,x = heapq.heappop(queue)
      if DEBUG >= 7: print(p,x)
      if x.uop is UOps.DEFINE_ACC and len(x.vin):
        idx = min([self._uops.index(l) for l in x.vin])
        self._uops.insert(idx, x)
      else:
        self._uops.append(x)
      for u, ss in loops_children.items():
        if x in ss:
          ss.remove(x)
          if len(ss) == 0: self._uops.append(UOp(UOps.ENDRANGE, None, (u,)))
      for u in graph[x]:
        in_degree[u] -= 1
        if in_degree[u] == 0: push(u)

    assert self._uops[-1].uop is UOps.SINK, f"didn't end with SINK, ended with {self._uops[-1]}"
    self._uops = self._uops[:-1]

    # TODO: ifs should be removed and just the store should be gated
    for u in ifs[::-1]: self._uops.append(UOp(UOps.ENDIF, None, (u,)))

    if type_verify: self.type_verify()

  def add(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(), arg:Any=None) -> UOp:
    if found:=self.nodes.get(key:=(uop, dtype, vin, arg)): return found
    self.nodes[key] = ret = UOp(*key)
    return ret

  # *** checker functions ***

  def flops_mem(self) -> Tuple[sint, sint]:
    flops: sint = 0
    mem: sint = 0
    mults: sint = 1
    mult_stack = []
    for u in self.uops:
      if u.uop is UOps.RANGE:
        mult_stack.append(mults)
        mults *= uop_alu_resolve(u.vin[1])
      elif u.uop is UOps.ENDRANGE:
        mults = mult_stack.pop(-1)
      elif u.uop is UOps.ALU:
        flops += mults * (2 if u.arg == TernaryOps.MULACC else 1)
      elif u.uop is UOps.LOAD:
        assert u.dtype is not None
        mem += u.dtype.itemsize * mults
      elif u.uop is UOps.STORE:
        assert u.vin[2].dtype is not None
        mem += u.vin[2].dtype.itemsize * mults
      elif u.uop is UOps.WMMA:
        assert u.arg[1] is not None
        flops += 2 * prod(u.arg[1]) // 32 * mults
    return flops, mem

  def type_verify(self):
    for u in self.uops:
      uop, arg, vin, dtype = u.uop, u.arg, u.vin, u.dtype
      if uop in {UOps.CONST, UOps.DEFINE_ACC}:
        if uop is UOps.DEFINE_ACC: arg = arg[0]
        assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
      if uop in {UOps.CAST, UOps.BITCAST}: assert arg is None # type is the output type, not an arg
      if uop is UOps.ALU:
        if arg in UnaryOps:
          assert dtype == vin[0].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=}"
        elif arg in (BinaryOps.CMPLT, BinaryOps.CMPEQ):
          assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
          assert vin[0].dtype == vin[1].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=} != {vin[1].dtype=}"
        elif arg in BinaryOps:
          assert dtype == vin[0].dtype == vin[1].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=} != {vin[1].dtype=}"
        elif arg == TernaryOps.WHERE:
          assert vin[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {vin[0].dtype=} != {dtypes.bool}"
          assert dtype == vin[1].dtype == vin[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {vin[1].dtype=} != {vin[2].dtype=}"
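
To make the rewrite machinery above concrete, here is a minimal sketch of driving constant_folder by hand. It assumes a tinygrad 0.9.0 install, where UOp, UOps, and constant_folder are importable from tinygrad.codegen.uops as shown in this file; the bare DEFINE_VAR is only an illustrative stand-in (the real linearizer attaches a Variable arg):

# illustrative only: exercise two of the constant_folder rules above
from tinygrad.codegen.uops import UOp, UOps, constant_folder
from tinygrad.dtype import dtypes
from tinygrad.ops import BinaryOps

x = UOp(UOps.DEFINE_VAR, dtypes.int)   # stand-in for a runtime value
expr = UOp.alu(BinaryOps.ADD, x, UOp.const(dtypes.int, 0))

# "x+0 -> x": rewrite() returns the replacement UOp, or None if no pattern matched
assert constant_folder.rewrite(expr) is x

# "** constant folding **": an ALU whose inputs are all CONST folds via exec_alu
y = UOp.alu(BinaryOps.MUL, UOp.const(dtypes.int, 6), UOp.const(dtypes.int, 7))
folded = constant_folder.rewrite(y)
assert folded is not None and folded.uop is UOps.CONST and folded.arg == 42

Note that rewrite() applies only one step at the root; UOpGraph.graph_rewrite is what drives it to a fixed point over the whole graph, re-interning rewritten nodes through self.nodes so identical subtrees stay deduplicated.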
tinygrad/device.py
ADDED
@@ -0,0 +1,183 @@
from __future__ import annotations
import multiprocessing
from dataclasses import dataclass
from collections import defaultdict
from typing import List, Optional, Dict, Tuple, Any
import importlib, inspect, functools, pathlib, os, ctypes
from tinygrad.helpers import getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv
from tinygrad.dtype import DType, ImageDType
from tinygrad.renderer import Renderer

# **************** Device ****************

class _Device:
  def __init__(self) -> None: self._devices: List[str] = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")] # noqa: E501
  @functools.lru_cache(maxsize=None)  # this class is a singleton, pylint: disable=method-cache-max-size-none
  def _canonicalize(self, device:str) -> str: return (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") # noqa: E501
  # NOTE: you can't cache canonicalize in case Device.DEFAULT changes
  def canonicalize(self, device:Optional[str]) -> str: return self._canonicalize(device) if device is not None else Device.DEFAULT
  def __getitem__(self, ix:str) -> Compiled: return self.__get_canonicalized_item(self.canonicalize(ix))
  @functools.lru_cache(maxsize=None)  # this class is a singleton, pylint: disable=method-cache-max-size-none
  def __get_canonicalized_item(self, ix:str) -> Compiled:
    if DEBUG >= 1: print(f"opening device {ix} from pid:{os.getpid()}")
    assert multiprocessing.current_process().name == "MainProcess" or ix.split(":")[0] in ["DISK", "NPY"], f"can only open device {ix} from parent"
    x = ix.split(":")[0].upper()
    return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._devices][0](ix) # noqa: E501
  @functools.cached_property
  def DEFAULT(self) -> str:
    device_from_env: Optional[str] = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._devices, None) # type: ignore
    if device_from_env: return device_from_env
    for device in ["METAL", "HSA", "CUDA", "GPU", "CLANG", "LLVM"]:
      try:
        if self[device]:
          os.environ[device] = "1" # we set this in environment for spawned children
          return device
      except Exception: pass
    raise RuntimeError("no usable devices")
Device = _Device()

# **************** Buffer + Allocators ****************

@dataclass(frozen=True, eq=True)
class BufferOptions:
  image: Optional[ImageDType] = None
  uncached: bool = False
  cpu_access: bool = False
  host: bool = False
  nolru: bool = False

class Buffer:
  def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None,
               initial_value:Optional[bytes]=None, lb_refcount=0, base:Optional[Buffer]=None, offset:int=0, preallocate=False):
    assert isinstance(dtype, DType)
    if isinstance(dtype, ImageDType): options = BufferOptions(image=dtype) # TODO: image hack shouldn't be here. where should it be?
    self.device, self.size, self.dtype, self.options, self.offset = device, size, dtype, options, offset
    if base is None:
      assert offset == 0, "base buffers can't have offset"
      self._base = None
      self._lb_refcount = lb_refcount
      if opaque is not None: self.allocate(opaque)
      if initial_value is not None:
        self.allocate()
        self.copyin(memoryview(initial_value))
    else:
      assert base._base is None, "base can't have a base"
      assert device == base.device, "base must have the same device"
      self._base = base
    if preallocate: self.allocate()
  @property
  def base(self) -> Buffer: return self._base if self._base is not None else self
  @property
  def lb_refcount(self): return self.base._lb_refcount
  def ref(self, cnt): self.base._lb_refcount += cnt
  def is_allocated(self) -> bool: return hasattr(self, '_buf')
  def ensure_allocated(self) -> Buffer: return self.allocate() if not hasattr(self, '_buf') else self
  def allocate(self, opaque=None) -> Buffer:
    assert not hasattr(self, '_buf'), "can't allocate already allocated buffer"
    self.allocator = Device[self.device].allocator
    if self._base is not None:
      self._base.ensure_allocated()
      assert hasattr(self.allocator, "offset"), "offset function required for view"
      self._buf: Any = self.allocator.offset(self.base._buf, self.nbytes, self.offset)
    else:
      self._buf = opaque if opaque is not None else self.allocator.alloc(self.nbytes, self.options)
      if not self.device.startswith("DISK"): GlobalCounters.mem_used += self.nbytes
    return self
  def __reduce__(self):
    buf = None
    if self._base is not None:
      return self.__class__, (self.device, self.size, self.dtype, None, None, None, 0, self.base, self.offset, hasattr(self, '_buf'))
    if self.device == "NPY": return self.__class__, (self.device, self.size, self.dtype, self._buf, self.options, None, self.lb_refcount)
    if self.is_allocated():
      buf = bytearray(self.nbytes)
      self.copyout(memoryview(buf))
    return self.__class__, (self.device, self.size, self.dtype, None, self.options, buf, self.lb_refcount)
  @property
  def nbytes(self): return self.size*self.dtype.itemsize
  def __del__(self):
    if not hasattr(self, '_buf'): return
    if self._base is None:
      if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
      self.allocator.free(self._buf, self.nbytes, self.options)
  def __repr__(self):
    return f"<buf real:{hasattr(self, '_buf')} device:{self.device} size:{self.size} dtype:{self.dtype}" + \
           (f" offset:{self.offset}" if hasattr(self, "base") else "") + \
           (">" if self.options is None else f" {self.options=}>")
  def as_buffer(self, allow_zero_copy=False, force_zero_copy=False) -> memoryview:
    # zero copy with as_buffer (disabled by default due to use after free)
    if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, 'as_buffer'): return self.allocator.as_buffer(self._buf)
    assert not force_zero_copy, "force zero copy was passed, but copy is required"
    return self.copyout(memoryview(bytearray(self.nbytes)))
  def copyin(self, mv:memoryview):
    mv = flat_mv(mv)
    assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
    assert self.is_allocated(), "can't copyin to unallocated buffer"
    self.allocator.copyin(self._buf, mv)
    return self
  def copyout(self, mv:memoryview) -> memoryview:
    mv = flat_mv(mv)
    assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
    assert self.is_allocated(), "can't copyout unallocated buffer"
    self.allocator.copyout(mv, self._buf)
    return mv
  def view(self, size:int, dtype:DType, offset:int) -> Buffer:
    assert offset < self.nbytes, "offset must be less than nbytes"
    if self._base is not None: return Buffer(self.device, size, dtype, base=self._base, offset=self.offset+offset)
    return Buffer(self.device, size, dtype, base=self, offset=offset)

# TODO: size, dest, src are the same type. can we enforce this?
class Allocator:
  def alloc(self, size:int, options:Optional[BufferOptions]=None):
    assert not isinstance(size, int) or size > 0, f"alloc size must be positive, getting {size}"
    return self._alloc(size, options if options is not None else BufferOptions())
  def _alloc(self, size:int, options:BufferOptions): raise NotImplementedError("need alloc")
  def free(self, opaque, size:int, options:Optional[BufferOptions]=None):
    self._free(opaque, options if options is not None else BufferOptions())
  def _free(self, opaque, options:BufferOptions): pass # if opaque is a Python object, you don't need a free
  def copyin(self, dest, src:memoryview): raise NotImplementedError("need copyin")
  def copyout(self, dest:memoryview, src): raise NotImplementedError("need copyout")

class LRUAllocator(Allocator): # pylint: disable=abstract-method
  def __init__(self): self.cache: Dict[Tuple[int, Optional[BufferOptions]], Any] = defaultdict(list)
  def alloc(self, size:int, options:Optional[BufferOptions]=None):
    if len(c := self.cache[(size, options)]): return c.pop()
    try: return super().alloc(size, options)
    except (RuntimeError, MemoryError):
      self.free_cache()
      return super().alloc(size, options)
  def free_cache(self):
    for (sz,options),opaques in self.cache.items():
      for opaque in opaques: super().free(opaque, sz, options)
      opaques.clear()
  def free(self, opaque:Any, size:int, options:Optional[BufferOptions]=None):
    if getenv("LRU", 1) and (options is None or not options.nolru): self.cache[(size, options)].append(opaque)
    else: super().free(opaque, size, options)

class _MallocAllocator(LRUAllocator):
  def _alloc(self, size:int, options:BufferOptions): return (ctypes.c_uint8 * size)()
  def as_buffer(self, src) -> memoryview: return flat_mv(memoryview(src))
  def copyin(self, dest, src:memoryview): ctypes.memmove(dest, from_mv(src), len(src))
  def copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest))
  def offset(self, buf, size:int, offset:int): return from_mv(self.as_buffer(buf)[offset:offset+size])

MallocAllocator = _MallocAllocator()

# **************** for Compiled Devices ****************

class CompileError(Exception): pass

class Compiler:
  def __init__(self, cachekey:Optional[str]=None): self.cachekey = None if getenv("DISABLE_COMPILER_CACHE") else cachekey
  def compile(self, src:str) -> bytes: raise NotImplementedError("need a compile function")
  def compile_cached(self, src:str) -> bytes:
    if self.cachekey is None or (lib := diskcache_get(self.cachekey, src)) is None:
      assert not getenv("ASSERT_COMPILE"), "tried to compile with ASSERT_COMPILE set"
      lib = self.compile(src)
      if self.cachekey is not None: diskcache_put(self.cachekey, src, lib)
    return lib

class Compiled:
  def __init__(self, device:str, allocator:Allocator, renderer:Optional[Renderer], compiler:Optional[Compiler], runtime, graph=None):
    self.dname, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler if compiler else Compiler(), renderer if renderer else Renderer(), runtime, graph
    self.renderer = renderer if renderer else Renderer()
  def synchronize(self): pass # override this in your device
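
And a short usage sketch for the Buffer/Allocator split above, assuming a tinygrad 0.9.0 install on a machine where the resolved Device.DEFAULT backend can actually be opened:

# illustrative only: allocate, fill, and read back a device buffer
import struct
from tinygrad.device import Buffer, Device
from tinygrad.dtype import dtypes

buf = Buffer(Device.DEFAULT, 4, dtypes.int32).allocate()
buf.copyin(memoryview(bytearray(struct.pack("4i", 1, 2, 3, 4))))

out = buf.as_buffer()              # copies by default; zero-copy is opt-in
print(struct.unpack("4i", out))    # (1, 2, 3, 4)

Note the design choice in free(): unless LRU=0 is set (or the buffer was allocated with nolru), a freed allocation goes back into the LRUAllocator cache rather than to the backend, which is why free() takes the size, since the cache is keyed on (size, options).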