tinygrad 0.9.0__py3-none-any.whl → 0.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/__init__.py +0 -0
- tinygrad/codegen/kernel.py +78 -90
- tinygrad/codegen/linearizer.py +237 -169
- tinygrad/codegen/uops.py +278 -242
- tinygrad/device.py +147 -10
- tinygrad/dtype.py +7 -7
- tinygrad/engine/graph.py +16 -16
- tinygrad/engine/jit.py +39 -36
- tinygrad/engine/realize.py +6 -5
- tinygrad/engine/schedule.py +15 -7
- tinygrad/engine/search.py +6 -3
- tinygrad/function.py +17 -23
- tinygrad/helpers.py +77 -8
- tinygrad/lazy.py +26 -26
- tinygrad/multi.py +13 -9
- tinygrad/nn/__init__.py +1 -1
- tinygrad/nn/datasets.py +2 -1
- tinygrad/nn/state.py +3 -4
- tinygrad/ops.py +49 -16
- tinygrad/renderer/__init__.py +8 -4
- tinygrad/renderer/assembly.py +93 -100
- tinygrad/renderer/cstyle.py +47 -42
- tinygrad/renderer/llvmir.py +30 -30
- tinygrad/runtime/__init__.py +0 -0
- tinygrad/runtime/autogen/amd_gpu.py +11504 -1
- tinygrad/runtime/autogen/comgr.py +36 -10
- tinygrad/runtime/autogen/hsa.py +146 -14
- tinygrad/runtime/autogen/io_uring.py +1486 -0
- tinygrad/runtime/autogen/nv_gpu.py +269 -0
- tinygrad/runtime/driver/__init__.py +0 -0
- tinygrad/runtime/driver/hip_comgr.py +20 -11
- tinygrad/runtime/graph/__init__.py +0 -0
- tinygrad/runtime/graph/clang.py +3 -2
- tinygrad/runtime/graph/cuda.py +2 -2
- tinygrad/runtime/graph/hcq.py +122 -78
- tinygrad/runtime/ops_amd.py +302 -316
- tinygrad/runtime/ops_cuda.py +3 -3
- tinygrad/runtime/ops_disk.py +70 -5
- tinygrad/runtime/ops_gpu.py +2 -2
- tinygrad/runtime/ops_metal.py +5 -6
- tinygrad/runtime/ops_npy.py +1 -1
- tinygrad/runtime/ops_nv.py +161 -166
- tinygrad/runtime/ops_python.py +20 -16
- tinygrad/shape/__init__.py +0 -0
- tinygrad/shape/shapetracker.py +5 -2
- tinygrad/shape/symbolic.py +1 -3
- tinygrad/shape/view.py +34 -19
- tinygrad/tensor.py +219 -135
- {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +14 -6
- tinygrad-0.9.1.dist-info/RECORD +63 -0
- {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
- tinygrad/runtime/driver/hsa.py +0 -143
- tinygrad/runtime/graph/hsa.py +0 -171
- tinygrad/runtime/ops_hsa.py +0 -278
- tinygrad-0.9.0.dist-info/RECORD +0 -60
- {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
tinygrad/function.py
CHANGED
@@ -31,7 +31,7 @@ class Neg(Function):
|
|
31
31
|
|
32
32
|
class Reciprocal(Function):
|
33
33
|
def forward(self, x:LazyBuffer) -> LazyBuffer:
|
34
|
-
self.ret = x.
|
34
|
+
self.ret = x.e(UnaryOps.RECIP)
|
35
35
|
return self.ret
|
36
36
|
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
37
37
|
return grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.ret).e(BinaryOps.MUL, self.ret)
|
@@ -42,7 +42,7 @@ class Sin(Function):
|
|
42
42
|
return x.e(UnaryOps.SIN)
|
43
43
|
|
44
44
|
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
45
|
-
return self.x.const(math.pi / 2).e(BinaryOps.
|
45
|
+
return self.x.const(math.pi / 2).e(BinaryOps.ADD, self.x.e(UnaryOps.NEG)).e(UnaryOps.SIN).e(BinaryOps.MUL, grad_output)
|
46
46
|
|
47
47
|
# NOTE: maximum(x, 0) behaves differently where x=0
|
48
48
|
class Relu(Function):
|
@@ -58,7 +58,7 @@ class Log(Function):
|
|
58
58
|
self.x = x
|
59
59
|
return x.e(UnaryOps.LOG2).e(BinaryOps.MUL, x.const(math.log(2)))
|
60
60
|
|
61
|
-
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(BinaryOps.
|
61
|
+
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(BinaryOps.MUL, self.x.e(UnaryOps.RECIP))
|
62
62
|
|
63
63
|
class Exp(Function):
|
64
64
|
def forward(self, x:LazyBuffer) -> LazyBuffer:
|
@@ -73,23 +73,23 @@ class Sqrt(Function):
|
|
73
73
|
return self.ret
|
74
74
|
|
75
75
|
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
76
|
-
return grad_output.e(BinaryOps.
|
76
|
+
return grad_output.e(BinaryOps.MUL, self.ret.e(BinaryOps.MUL, self.ret.const(2)).e(UnaryOps.RECIP))
|
77
77
|
|
78
78
|
# NOTE: the implicit derivative of sigmoid is not stable
|
79
79
|
# https://towardsdatascience.com/derivative-of-the-sigmoid-function-536880cf918e
|
80
80
|
# TODO: have the backend automatically find this
|
81
81
|
class Sigmoid(Function):
|
82
82
|
def forward(self, x:LazyBuffer) -> LazyBuffer:
|
83
|
-
self.ret = x.const(1).e(BinaryOps.
|
83
|
+
self.ret = x.const(1).e(BinaryOps.ADD, x.e(BinaryOps.MUL, x.const(-1/math.log(2))).e(UnaryOps.EXP2)).e(UnaryOps.RECIP)
|
84
84
|
return self.ret
|
85
85
|
|
86
86
|
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
87
|
-
return self.ret.e(BinaryOps.MUL, self.ret.const(1).e(BinaryOps.
|
87
|
+
return self.ret.e(BinaryOps.MUL, self.ret.const(1).e(BinaryOps.ADD, self.ret.e(UnaryOps.NEG))).e(BinaryOps.MUL, grad_output)
|
88
88
|
|
89
89
|
class Sign(Function):
|
90
90
|
def forward(self, x:LazyBuffer) -> LazyBuffer:
|
91
|
-
return x.e(BinaryOps.
|
92
|
-
|
91
|
+
return x.e(BinaryOps.CMPNE, x.const(0)).e(
|
92
|
+
TernaryOps.WHERE, x.e(BinaryOps.CMPLT, x.const(0)).e(TernaryOps.WHERE, x.const(-1), x.const(1)), x.const(0))
|
93
93
|
# backward always return 0 to match torch
|
94
94
|
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.const(0)
|
95
95
|
|
@@ -99,8 +99,8 @@ class Less(Function):
|
|
99
99
|
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPLT, y)
|
100
100
|
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None
|
101
101
|
|
102
|
-
class
|
103
|
-
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.
|
102
|
+
class Neq(Function):
|
103
|
+
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPNE, y)
|
104
104
|
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None
|
105
105
|
|
106
106
|
class Xor(Function):
|
@@ -113,13 +113,6 @@ class Add(Function):
|
|
113
113
|
return grad_output if self.needs_input_grad[0] else None, \
|
114
114
|
grad_output if self.needs_input_grad[1] else None
|
115
115
|
|
116
|
-
class Sub(Function):
|
117
|
-
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.SUB, y)
|
118
|
-
|
119
|
-
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
|
120
|
-
return grad_output if self.needs_input_grad[0] else None, \
|
121
|
-
grad_output.e(UnaryOps.NEG) if self.needs_input_grad[1] else None
|
122
|
-
|
123
116
|
class Mul(Function):
|
124
117
|
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
|
125
118
|
self.x, self.y = x, y
|
@@ -132,11 +125,11 @@ class Mul(Function):
|
|
132
125
|
class Div(Function):
|
133
126
|
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
|
134
127
|
self.x, self.y = x, y
|
135
|
-
return x.e(BinaryOps.
|
128
|
+
return x.e(BinaryOps.MUL, y.e(UnaryOps.RECIP)) if not dtypes.is_int(x.dtype) else x.e(BinaryOps.IDIV, y)
|
136
129
|
|
137
130
|
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
|
138
|
-
return grad_output.e(BinaryOps.
|
139
|
-
grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.x).e(BinaryOps.
|
131
|
+
return grad_output.e(BinaryOps.MUL, self.y.e(UnaryOps.RECIP)) if self.needs_input_grad[0] else None, \
|
132
|
+
grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.x).e(BinaryOps.MUL, self.y.e(BinaryOps.MUL, self.y).e(UnaryOps.RECIP)) if self.needs_input_grad[1] else None # noqa: E501
|
140
133
|
|
141
134
|
# ************* ternary ops *************
|
142
135
|
|
@@ -166,9 +159,10 @@ class Max(Function):
|
|
166
159
|
|
167
160
|
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
168
161
|
# 1s in locations where the max was chosen (can be two locations)
|
169
|
-
max_is_1s = self.x.e(BinaryOps.
|
162
|
+
max_is_1s = self.x.const(1.0).cast(dtypes.float).e(BinaryOps.ADD, self.x.e(BinaryOps.CMPNE, \
|
163
|
+
self.ret.expand(self.x.shape)).cast(dtypes.float).e(UnaryOps.NEG))
|
170
164
|
div = max_is_1s.r(ReduceOps.SUM, self.axis).expand(self.x.shape)
|
171
|
-
return max_is_1s.e(BinaryOps.
|
165
|
+
return max_is_1s.e(BinaryOps.MUL, div.e(UnaryOps.RECIP)).cast(grad_output.dtype).e(BinaryOps.MUL, grad_output.expand(self.x.shape))
|
172
166
|
|
173
167
|
# ************* movement ops *************
|
174
168
|
|
@@ -211,7 +205,7 @@ class Shrink(Function):
|
|
211
205
|
|
212
206
|
class Flip(Function):
|
213
207
|
def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
|
214
|
-
self.arg = tuple([-1 if i in
|
208
|
+
self.arg = tuple([-1 if i in axis else 1 for i in range(len(x.shape))])
|
215
209
|
return x.stride(self.arg)
|
216
210
|
|
217
211
|
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.stride(self.arg)
|
tinygrad/helpers.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1
1
|
from __future__ import annotations
|
2
|
-
import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, cProfile, pstats, tempfile, pathlib, string, ctypes
|
3
|
-
import itertools, urllib.request, subprocess
|
4
|
-
from tqdm import tqdm
|
2
|
+
import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, cProfile, pstats, tempfile, pathlib, string, ctypes, sys
|
3
|
+
import itertools, urllib.request, subprocess, shutil, math, json
|
5
4
|
from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Sequence
|
6
5
|
if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
|
7
6
|
from typing_extensions import TypeGuard
|
@@ -56,6 +55,12 @@ def get_child(obj, key):
|
|
56
55
|
else: obj = getattr(obj, k)
|
57
56
|
return obj
|
58
57
|
|
58
|
+
def get_shape(x) -> Tuple[int, ...]:
|
59
|
+
if not isinstance(x, (list, tuple)): return ()
|
60
|
+
subs = [get_shape(xi) for xi in x]
|
61
|
+
if not all_same(subs): raise ValueError(f"inhomogeneous shape from {x}")
|
62
|
+
return (len(subs),) + (subs[0] if subs else ())
|
63
|
+
|
59
64
|
# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
|
60
65
|
def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:
|
61
66
|
acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
|
@@ -96,9 +101,9 @@ class ContextVar:
|
|
96
101
|
def __lt__(self, x): return self.value < x
|
97
102
|
|
98
103
|
DEBUG, IMAGE, BEAM, NOOPT, JIT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0), ContextVar("JIT", 1)
|
99
|
-
WINO, THREEFRY,
|
104
|
+
WINO, THREEFRY, CAPTURING = ContextVar("WINO", 0), ContextVar("THREEFRY", 0), ContextVar("CAPTURING", 1)
|
100
105
|
GRAPH, GRAPHPATH, SAVE_SCHEDULE, RING = ContextVar("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net"), ContextVar("SAVE_SCHEDULE", 0), ContextVar("RING", 1)
|
101
|
-
MULTIOUTPUT = ContextVar("MULTIOUTPUT", 1)
|
106
|
+
MULTIOUTPUT, PROFILE = ContextVar("MULTIOUTPUT", 1), ContextVar("PROFILE", 0)
|
102
107
|
|
103
108
|
# **************** global state Counters ****************
|
104
109
|
|
@@ -139,6 +144,34 @@ class Profiling(contextlib.ContextDecorator):
|
|
139
144
|
colored(_format_fcn(fcn), "yellow") + " "*(50-len(_format_fcn(fcn))),
|
140
145
|
colored(f"<- {(scallers[0][1][2]/tottime)*100:3.0f}% {_format_fcn(scallers[0][0])}", "BLACK") if len(scallers) else '')
|
141
146
|
|
147
|
+
class ProfileLogger:
|
148
|
+
writers: int = 0
|
149
|
+
mjson: List[Dict] = []
|
150
|
+
actors: Dict[str, int] = {}
|
151
|
+
subactors: Dict[Tuple[str, str], int] = {}
|
152
|
+
path = getenv("PROFILE_OUTPUT_FILE", temp("tinygrad_profile.json"))
|
153
|
+
|
154
|
+
def __init__(self): self.events, ProfileLogger.writers = [], ProfileLogger.writers + 1
|
155
|
+
|
156
|
+
def add_event(self, ev_name, ev_start, ev_end, actor, subactor=None): self.events += [(ev_name, ev_start, ev_end, actor, subactor)]
|
157
|
+
|
158
|
+
def __del__(self):
|
159
|
+
for name,st,et,actor_name,subactor_name in self.events:
|
160
|
+
if actor_name not in self.actors:
|
161
|
+
self.actors[actor_name] = (pid:=len(self.actors))
|
162
|
+
self.mjson.append({"name": "process_name", "ph": "M", "pid": pid, "args": {"name": actor_name}})
|
163
|
+
|
164
|
+
if (subactor_key:=(actor_name,subactor_name)) not in self.subactors:
|
165
|
+
self.subactors[subactor_key] = (tid:=len(self.subactors))
|
166
|
+
self.mjson.append({"name": "thread_name", "ph": "M", "pid": self.actors[actor_name], "tid":tid, "args": {"name": subactor_name}})
|
167
|
+
|
168
|
+
self.mjson.append({"name": name, "ph": "X", "pid": self.actors[actor_name], "tid": self.subactors.get(subactor_key, -1), "ts":st, "dur":et-st})
|
169
|
+
|
170
|
+
ProfileLogger.writers -= 1
|
171
|
+
if ProfileLogger.writers == 0:
|
172
|
+
with open(self.path, "w") as f: f.write(json.dumps({"traceEvents": self.mjson}))
|
173
|
+
print(f"Saved profile to {self.path}. Use https://ui.perfetto.dev/ to open it.")
|
174
|
+
|
142
175
|
# *** universal database cache ***
|
143
176
|
|
144
177
|
_cache_dir: str = getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches" if OSX else "~/.cache"))
|
@@ -197,18 +230,21 @@ def diskcache(func):
|
|
197
230
|
|
198
231
|
# *** http support ***
|
199
232
|
|
200
|
-
def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None,
|
233
|
+
def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, subdir:Optional[str]=None,
|
234
|
+
allow_caching=not getenv("DISABLE_HTTP_CACHE")) -> pathlib.Path:
|
201
235
|
if url.startswith(("/", ".")): return pathlib.Path(url)
|
202
|
-
|
236
|
+
if name is not None and (isinstance(name, pathlib.Path) or '/' in name): fp = pathlib.Path(name)
|
237
|
+
else: fp = pathlib.Path(_cache_dir) / "tinygrad" / "downloads" / (subdir or "") / (name or hashlib.md5(url.encode('utf-8')).hexdigest())
|
203
238
|
if not fp.is_file() or not allow_caching:
|
204
239
|
with urllib.request.urlopen(url, timeout=10) as r:
|
205
240
|
assert r.status == 200
|
206
241
|
total_length = int(r.headers.get('content-length', 0))
|
207
|
-
progress_bar = tqdm(total=total_length, unit='B', unit_scale=True, desc=url)
|
242
|
+
progress_bar = tqdm(total=total_length, unit='B', unit_scale=True, desc=f"{url}: ", disable=CI)
|
208
243
|
(path := fp.parent).mkdir(parents=True, exist_ok=True)
|
209
244
|
with tempfile.NamedTemporaryFile(dir=path, delete=False) as f:
|
210
245
|
while chunk := r.read(16384): progress_bar.update(f.write(chunk))
|
211
246
|
f.close()
|
247
|
+
progress_bar.update(close=True)
|
212
248
|
if (file_size:=os.stat(f.name).st_size) < total_length: raise RuntimeError(f"fetch size incomplete, {file_size} < {total_length}")
|
213
249
|
pathlib.Path(f.name).rename(fp)
|
214
250
|
return fp
|
@@ -231,6 +267,7 @@ def cpu_objdump(lib):
|
|
231
267
|
def from_mv(mv:memoryview, to_type=ctypes.c_char):
|
232
268
|
return ctypes.cast(ctypes.addressof(to_type.from_buffer(mv)), ctypes.POINTER(to_type * len(mv))).contents
|
233
269
|
def to_mv(ptr, sz) -> memoryview: return memoryview(ctypes.cast(ptr, ctypes.POINTER(ctypes.c_uint8 * sz)).contents).cast("B")
|
270
|
+
def mv_address(mv:memoryview): return ctypes.addressof(ctypes.c_char.from_buffer(mv))
|
234
271
|
def to_char_p_p(options: List[bytes], to_type=ctypes.c_char): return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) for o in options]) # noqa: E501
|
235
272
|
@functools.lru_cache(maxsize=None)
|
236
273
|
def init_c_struct_t(fields: Tuple[Tuple[str, ctypes._SimpleCData], ...]):
|
@@ -239,3 +276,35 @@ def init_c_struct_t(fields: Tuple[Tuple[str, ctypes._SimpleCData], ...]):
|
|
239
276
|
return CStruct
|
240
277
|
def init_c_var(ctypes_var, creat_cb): return (creat_cb(ctypes_var), ctypes_var)[1]
|
241
278
|
def flat_mv(mv:memoryview): return mv if len(mv) == 0 else mv.cast("B", shape=(mv.nbytes,))
|
279
|
+
|
280
|
+
class tqdm:
|
281
|
+
def __init__(self, iterable=None, desc:str='', disable:bool=False, unit:str='it', unit_scale=False, total:int=-1, rate:int=100):
|
282
|
+
self.iter, self.desc, self.dis, self.unit, self.unit_scale, self.rate = iterable, f"{desc}: " if desc else "", disable, unit, unit_scale, rate
|
283
|
+
self.st, self.i, self.n, self.skip, self.t = time.perf_counter(), -1, 0, 1, len(iterable) if total==-1 else total
|
284
|
+
self.update(0)
|
285
|
+
def __iter__(self):
|
286
|
+
try:
|
287
|
+
for item in self.iter:
|
288
|
+
yield item
|
289
|
+
self.update(1)
|
290
|
+
finally: self.update(close=True)
|
291
|
+
def set_description(self, desc:str): self.desc = f"{desc}: " if desc else ""
|
292
|
+
def update(self, n:int=0, close:bool=False):
|
293
|
+
self.n, self.i = self.n+n, self.i+1
|
294
|
+
if (self.i % self.skip != 0 and not close) or self.dis: return
|
295
|
+
prog, dur, term = self.n/self.t if self.t else -1, time.perf_counter()-self.st, shutil.get_terminal_size().columns
|
296
|
+
if self.i/dur > self.rate and self.i: self.skip = max(int(self.i/dur)//self.rate,1) if self.i else 1
|
297
|
+
def fmt(t): return ':'.join([f'{x:02d}' for x in divmod(int(t), 60)]) if t!=-1 else '?'
|
298
|
+
def scl(x): return x/1000**int(math.log(x,1000))
|
299
|
+
def fn(x): return (f"{scl(x):.{3-math.ceil(math.log10(scl(x)))}f}"[:4]+(f"{[' ','k','M','G','T','P'][int(math.log(x,1000))]}") if x else '0.00')
|
300
|
+
if self.t: unit_text = f"{fn(self.n)}/{fn(self.t)}" if self.unit_scale else f"{self.n}/{self.t}"
|
301
|
+
else: unit_text = f"{fn(self.n)}{self.unit}" if self.unit_scale else f"{self.n}{self.unit}"
|
302
|
+
it_text = f"{fn(self.n/dur)}" if self.n and self.unit_scale else f"{self.n/dur:5.2f}" if self.n else "?"
|
303
|
+
if self.t: suf = f'| {unit_text} [{fmt(dur)}<{fmt(dur/self.n*self.t-dur if self.n else -1)}, {it_text}{self.unit}/s]'
|
304
|
+
else: suf = f'{unit_text} [{fmt(dur)}, {it_text}{self.unit}/s]'
|
305
|
+
sz = max(term-5-len(suf)-len(self.desc), 1)
|
306
|
+
bar = f'\r{self.desc}{round(100*prog):3}%|{"█"*round(sz*prog)}{" "*(sz-round(sz*prog))}{suf}' if self.t else f'\r{self.desc}{suf}{" "*term}'
|
307
|
+
print(bar[:term+1],flush=True,end='\n'*close,file=sys.stderr)
|
308
|
+
|
309
|
+
class trange(tqdm):
|
310
|
+
def __init__(self, n:int, **kwargs): super().__init__(iterable=range(n), total=n, **kwargs)
|
tinygrad/lazy.py
CHANGED
@@ -22,7 +22,7 @@ def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=
|
|
22
22
|
if enable_cache: lazycache[cache_key] = ret
|
23
23
|
return ret
|
24
24
|
|
25
|
-
view_supported_devices = {"LLVM", "CLANG", "CUDA", "DISK"}
|
25
|
+
view_supported_devices = {"LLVM", "CLANG", "CUDA", "NV", "AMD", "DISK"}
|
26
26
|
class LazyBuffer:
|
27
27
|
def __init__(self, device:str, st:ShapeTracker, dtype:DType,
|
28
28
|
op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
|
@@ -34,11 +34,9 @@ class LazyBuffer:
|
|
34
34
|
self.op, self.arg, self.srcs = op, arg, srcs # this is a LazyOp, except the src is LazyBuffers and not LazyOps
|
35
35
|
assert self.op is not LoadOps.ASSIGN or srcs[1].base.realized is not None, "assign target must be realized"
|
36
36
|
|
37
|
-
if
|
38
|
-
not srcs[0].is_unrealized_const() and device.split(":")[0] in view_supported_devices:
|
37
|
+
if self.op is LoadOps.VIEW:
|
39
38
|
# some LazyBuffers can be processed with only a view, no AST required
|
40
39
|
self.buffer: Buffer = srcs[0].base.buffer.view(st.size, dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize)
|
41
|
-
self.op = LoadOps.VIEW
|
42
40
|
else:
|
43
41
|
self.buffer = srcs[1].base.buffer if self.op is LoadOps.ASSIGN else Buffer(device, self.size, dtype)
|
44
42
|
self.buffer.ref(1)
|
@@ -74,6 +72,7 @@ class LazyBuffer:
|
|
74
72
|
return create_lazybuffer(device, ShapeTracker.from_shape(shape), dtype, op, arg, src, enable_cache=enable_cache)
|
75
73
|
|
76
74
|
def const(self, val:ConstType, shape:Optional[Tuple[sint,...]]=None) -> LazyBuffer:
|
75
|
+
assert isinstance(val, (int,float,bool)), f"{val=} has {type(val)=}, not a ConstType"
|
77
76
|
shape = self.shape if shape is None else shape
|
78
77
|
return LazyBuffer.loadop(LoadOps.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(shape)).expand(shape)
|
79
78
|
|
@@ -83,9 +82,11 @@ class LazyBuffer:
|
|
83
82
|
assert x.size == self.size, f"assign target must have same size {self.size=} != {x.size=}"
|
84
83
|
return LazyBuffer.loadop(LoadOps.ASSIGN, self.shape, self.dtype, self.device, arg=() if self.st.contiguous else (self.st,), src=(x, self.base))
|
85
84
|
|
86
|
-
def
|
85
|
+
def can_view(self): return self.st.consecutive and not self.is_unrealized_const() and self.device.split(":")[0] in view_supported_devices
|
86
|
+
|
87
|
+
def contiguous(self, allow_buffer_view=True):
|
87
88
|
if not self.st.contiguous or self.size != self.base.size or self.is_unrealized_const():
|
88
|
-
ret = self.e(LoadOps.CONTIGUOUS)
|
89
|
+
ret = self.e(LoadOps.VIEW) if allow_buffer_view and self.can_view() else self.e(LoadOps.CONTIGUOUS)
|
89
90
|
if (sti := self.st.invert(self.base.shape)) is not None: self.base.contiguous_child = ref(ret), sti
|
90
91
|
return ret
|
91
92
|
self.base.forced_realize = True
|
@@ -96,9 +97,6 @@ class LazyBuffer:
|
|
96
97
|
if self.device.startswith("DISK") and not bitcast: raise RuntimeError("attempted to cast disk buffer (bitcast only)")
|
97
98
|
if self.is_unrealized_unmasked_const() and not bitcast:
|
98
99
|
return create_lazybuffer(self.device, self.st, dtype, LoadOps.CONST, dtypes.as_const(self.base.arg, dtype))
|
99
|
-
# TODO: applying this makes gpt2 slower
|
100
|
-
if getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self != self.base:
|
101
|
-
return self.base.cast(dtype, bitcast)._view(self.st)
|
102
100
|
new_shape = self.shape
|
103
101
|
if bitcast and self.dtype.itemsize != dtype.itemsize:
|
104
102
|
if not self.device.startswith("DISK"): raise RuntimeError("shape changing bitcast only supported on DISK right now")
|
@@ -106,7 +104,10 @@ class LazyBuffer:
|
|
106
104
|
# https://pytorch.org/docs/stable/generated/torch.Tensor.view.html
|
107
105
|
if not (new_shape[-1]*self.dtype.itemsize) % dtype.itemsize == 0: raise RuntimeError("unsupported size in bitcast")
|
108
106
|
new_shape = new_shape[:-1] + ((new_shape[-1]*self.dtype.itemsize) // dtype.itemsize,)
|
109
|
-
|
107
|
+
elif getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self != self.base:
|
108
|
+
# TODO: applying this makes gpt2 slower
|
109
|
+
return self.base.cast(dtype, bitcast)._view(self.st)
|
110
|
+
cast_op: Union[LoadOps, UnaryOps] = (LoadOps.VIEW if self.can_view() else UnaryOps.BITCAST) if bitcast else UnaryOps.CAST
|
110
111
|
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, dtype, (self,))
|
111
112
|
|
112
113
|
def is_unrealized_const(self): return self.base.realized is None and self.base.op is LoadOps.CONST and not isinstance(self.base.arg, Variable)
|
@@ -145,24 +146,22 @@ class LazyBuffer:
|
|
145
146
|
if op is TernaryOps.WHERE: assert srcs[0].dtype == dtypes.bool, "TernaryOps.WHERE must have the first arg be bool"
|
146
147
|
if op is UnaryOps.NEG: assert srcs[0].dtype != dtypes.bool, "UnaryOps.NEG does not accept dtype bool"
|
147
148
|
|
148
|
-
out_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.
|
149
|
+
out_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else srcs[-1].dtype
|
149
150
|
|
150
151
|
# const folding
|
151
152
|
if op in python_alu and all(s.is_unrealized_unmasked_const() for s in srcs):
|
152
153
|
return self.cast(out_dtype).const(exec_alu(op, out_dtype, [s.base.arg for s in srcs]))
|
153
|
-
if op is UnaryOps.NEG and self.base.op is UnaryOps.NEG: return self.base.srcs[0]
|
154
|
-
if op in BinaryOps:
|
155
|
-
|
156
|
-
if
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
if op is BinaryOps.DIV and dtypes.is_float(x.dtype) and y.is_unrealized_unmasked_const() and y.base.arg != 0:
|
165
|
-
return x.e(BinaryOps.MUL, x.const(1 / y.base.arg))
|
154
|
+
if op is UnaryOps.NEG and self.base.op is UnaryOps.NEG and self.base.realized is None: return self.base.srcs[0]
|
155
|
+
if op in BinaryOps:
|
156
|
+
x, y = self, in_srcs[0]
|
157
|
+
if op is BinaryOps.ADD:
|
158
|
+
if y.is_unrealized_unmasked_const() and y.base.arg == 0: return x
|
159
|
+
if x.is_unrealized_unmasked_const() and x.base.arg == 0: return y
|
160
|
+
if op is BinaryOps.MUL:
|
161
|
+
if x.is_unrealized_unmasked_const() and (val := x.base.arg) in (1, 0, -1):
|
162
|
+
return y if val == 1 else y.const(0) if val == 0 else y.e(UnaryOps.NEG)
|
163
|
+
if y.is_unrealized_unmasked_const() and (val := y.base.arg) in (1, 0, -1):
|
164
|
+
return x if val == 1 else x.const(0) if val == 0 else x.e(UnaryOps.NEG)
|
166
165
|
|
167
166
|
return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))
|
168
167
|
|
@@ -170,7 +169,7 @@ class LazyBuffer:
|
|
170
169
|
|
171
170
|
def _reduce_op(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
|
172
171
|
assert all(0 <= x < len(self.shape) for x in axis), f"axis args {axis} out of range for shape {self.shape}"
|
173
|
-
axis = tuple(x for x in axis if self.shape[x] != 1)
|
172
|
+
axis = tuple(sorted([x for x in axis if self.shape[x] != 1]))
|
174
173
|
if len(axis) == 0: return self
|
175
174
|
new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape))
|
176
175
|
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), self.dtype, op, axis, (self,))
|
@@ -181,7 +180,8 @@ class LazyBuffer:
|
|
181
180
|
if self.size == 0 and 0 not in new_shape: return self.const({ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[op], new_shape)
|
182
181
|
|
183
182
|
# const folding
|
184
|
-
|
183
|
+
# TODO: fold this for symbolic?
|
184
|
+
if self.is_unrealized_unmasked_const() and all_int(self.shape):
|
185
185
|
return self.const(self.base.arg * {ReduceOps.SUM: prod(self.shape[i] for i in axis), ReduceOps.MAX: 1}[op], new_shape)
|
186
186
|
|
187
187
|
# TODO: can we split symbolic shape if the reduce axis is not symbolic?
|
tinygrad/multi.py
CHANGED
@@ -45,7 +45,7 @@ def all_reduce(op: ReduceOps, lbs: List[LazyBuffer]) -> List[LazyBuffer]:
|
|
45
45
|
def to_sharded(lbs:List[LazyBuffer], axis:int) -> List[LazyBuffer]:
|
46
46
|
if DEBUG >= 3 and lbs[0].shape[axis] % len(lbs) != 0: print(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}")
|
47
47
|
sz = round_up(lbs[0].shape[axis], len(lbs)) // len(lbs)
|
48
|
-
return [lb.shrink(tuple((0,s) if a != axis else (sz*i,min(s,sz*(i+1))) for a,s in enumerate(lb.shape))) for i,lb in enumerate(lbs)]
|
48
|
+
return [lb.shrink(tuple((0,s) if a != axis else (min(s,sz*i),min(s,sz*(i+1))) for a,s in enumerate(lb.shape))) for i,lb in enumerate(lbs)]
|
49
49
|
|
50
50
|
class MultiLazyBuffer:
|
51
51
|
def __init__(self, lbs:List[LazyBuffer], axis:Optional[int], real:Optional[List[bool]]=None):
|
@@ -54,7 +54,7 @@ class MultiLazyBuffer:
|
|
54
54
|
self.lbs, self.axis, self.dtype, self.device, self.real = lbs, axis, lbs[0].dtype, tuple(x.device for x in lbs), real or [True]*len(lbs)
|
55
55
|
if axis is not None:
|
56
56
|
splits = list(itertools.accumulate([lb.shape[axis] for lb in lbs], initial=0))
|
57
|
-
self.bounds =
|
57
|
+
self.bounds = list(zip(splits, splits[1:]))
|
58
58
|
|
59
59
|
@property
|
60
60
|
def shape(self):
|
@@ -73,15 +73,19 @@ class MultiLazyBuffer:
|
|
73
73
|
def from_sharded(lb:LazyBuffer, devices:Tuple[str, ...], axis:Optional[int]=None):
|
74
74
|
lbs = [lb.contiguous() if lb.base != lb and not lb.is_unrealized_unmasked_const() else lb] * len(devices)
|
75
75
|
sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(to_sharded(lbs, axis) if axis is not None else lbs, devices)]
|
76
|
-
return MultiLazyBuffer([lb if lb.is_unrealized_unmasked_const() else lb.contiguous() for lb in sharded_lbs], axis)
|
76
|
+
return MultiLazyBuffer([lb if lb.is_unrealized_unmasked_const() else lb.contiguous(allow_buffer_view=False) for lb in sharded_lbs], axis)
|
77
77
|
|
78
78
|
def copy_to_device(self, device:str) -> LazyBuffer:
|
79
|
-
if self.axis is None:
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
79
|
+
if self.axis is None:
|
80
|
+
# if we already have a copy on the device, return that
|
81
|
+
for lb in self.real_lbs:
|
82
|
+
if lb.device == device: return lb
|
83
|
+
return self.lbs[self.real.index(True)].copy_to_device(device)
|
84
|
+
llbs:List[LazyBuffer] = []
|
85
|
+
for lb,real,(start,end) in zip(self.lbs, self.real, self.bounds):
|
86
|
+
if not real: continue
|
87
|
+
pad_arg = tuple((0,0) if a != self.axis else (start, self.bounds[-1][1]-end) for a in range(len(lb.shape)))
|
88
|
+
llbs.append(lb.copy_to_device(device).pad(pad_arg))
|
85
89
|
return functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), llbs)
|
86
90
|
|
87
91
|
# passthroughs
|
tinygrad/nn/__init__.py
CHANGED
@@ -2,7 +2,7 @@ import math
|
|
2
2
|
from typing import Optional, Union, Tuple, cast
|
3
3
|
from tinygrad.tensor import Tensor
|
4
4
|
from tinygrad.helpers import prod
|
5
|
-
from tinygrad.nn import optim, state # noqa: F401
|
5
|
+
from tinygrad.nn import optim, state, datasets # noqa: F401
|
6
6
|
|
7
7
|
class BatchNorm2d:
|
8
8
|
"""
|
tinygrad/nn/datasets.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
import gzip
|
2
|
-
from tinygrad import Tensor
|
2
|
+
from tinygrad.tensor import Tensor
|
3
|
+
from tinygrad.helpers import fetch
|
3
4
|
|
4
5
|
def _fetch_mnist(file, offset): return Tensor(gzip.open(fetch("https://storage.googleapis.com/cvdf-datasets/mnist/"+file)).read()[offset:])
|
5
6
|
def mnist():
|
tinygrad/nn/state.py
CHANGED
@@ -1,9 +1,8 @@
|
|
1
1
|
import os, json, pathlib, zipfile, pickle, tarfile, struct
|
2
|
-
from tqdm import tqdm
|
3
2
|
from typing import Dict, Union, List, Optional, Any, Tuple
|
4
3
|
from tinygrad.tensor import Tensor
|
5
4
|
from tinygrad.dtype import dtypes
|
6
|
-
from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters
|
5
|
+
from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm
|
7
6
|
from tinygrad.shape.view import strides_for_shape
|
8
7
|
from tinygrad.multi import MultiLazyBuffer
|
9
8
|
|
@@ -41,7 +40,7 @@ def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any
|
|
41
40
|
Saves a state_dict to disk in a .safetensor file with optional metadata.
|
42
41
|
|
43
42
|
```python
|
44
|
-
t =
|
43
|
+
t = Tensor([1, 2, 3])
|
45
44
|
nn.state.safe_save({'t':t}, "test.safetensor")
|
46
45
|
```
|
47
46
|
"""
|
@@ -120,7 +119,7 @@ def load_state_dict(model, state_dict:Dict[str, Tensor], strict=True, verbose=Tr
|
|
120
119
|
if DEBUG >= 1 and len(state_dict) > len(model_state_dict):
|
121
120
|
print("WARNING: unused weights in state_dict", sorted(list(state_dict.keys() - model_state_dict.keys())))
|
122
121
|
for k,v in (t := tqdm(model_state_dict.items(), disable=CI or not verbose)):
|
123
|
-
t.
|
122
|
+
t.desc = f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, {k:50s}: "
|
124
123
|
if k not in state_dict and not strict:
|
125
124
|
if DEBUG >= 1: print(f"WARNING: not loading {k}")
|
126
125
|
continue
|
tinygrad/ops.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
from typing import Union, Tuple, Any, List, Dict, Callable
|
3
|
-
import functools, hashlib, math, operator, ctypes
|
3
|
+
import functools, hashlib, math, operator, ctypes, struct
|
4
4
|
from enum import Enum, auto
|
5
5
|
from dataclasses import dataclass
|
6
6
|
from tinygrad.helpers import prod, dedup
|
@@ -14,10 +14,11 @@ from tinygrad.shape.shapetracker import ShapeTracker
|
|
14
14
|
# NOTE: many GPUs don't have DIV, but UnaryOps.RECIP doesn't work for integer division
|
15
15
|
class UnaryOps(Enum):
|
16
16
|
"""A -> A (elementwise)"""
|
17
|
-
EXP2 = auto(); LOG2 = auto(); CAST = auto(); BITCAST = auto(); SIN = auto(); SQRT = auto(); NEG = auto() # noqa: E702
|
17
|
+
EXP2 = auto(); LOG2 = auto(); CAST = auto(); BITCAST = auto(); SIN = auto(); SQRT = auto(); NEG = auto(); RECIP = auto() # noqa: E702
|
18
18
|
class BinaryOps(Enum):
|
19
19
|
"""A + A -> A (elementwise)"""
|
20
|
-
ADD = auto();
|
20
|
+
ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
|
21
|
+
SHR = auto(); SHL = auto() # noqa: E702
|
21
22
|
class TernaryOps(Enum):
|
22
23
|
"""A + A + A -> A (elementwise)"""
|
23
24
|
WHERE = auto(); MULACC = auto() # noqa: E702
|
@@ -30,7 +31,7 @@ class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS =
|
|
30
31
|
Op = Union[UnaryOps, BinaryOps, ReduceOps, LoadOps, TernaryOps, BufferOps]
|
31
32
|
|
32
33
|
# do not preserve f(0) = 0
|
33
|
-
UNSAFE_PAD_OPS = {
|
34
|
+
UNSAFE_PAD_OPS = {UnaryOps.RECIP, UnaryOps.LOG2, UnaryOps.EXP2, BinaryOps.IDIV}
|
34
35
|
|
35
36
|
@dataclass(frozen=True)
|
36
37
|
class MemBuffer:
|
@@ -40,7 +41,7 @@ class MemBuffer:
|
|
40
41
|
|
41
42
|
@dataclass(frozen=True)
|
42
43
|
class ConstBuffer:
|
43
|
-
val: ConstType
|
44
|
+
val: ConstType | Variable
|
44
45
|
dtype: DType
|
45
46
|
st: ShapeTracker
|
46
47
|
|
@@ -61,7 +62,7 @@ class LazyOp:
|
|
61
62
|
def dtype(self) -> DType:
|
62
63
|
if self.op in BufferOps: return self.arg.dtype
|
63
64
|
if self.op in [UnaryOps.CAST, UnaryOps.BITCAST]: return self.arg
|
64
|
-
return dtypes.bool if self.op in {BinaryOps.CMPLT, BinaryOps.
|
65
|
+
return dtypes.bool if self.op in {BinaryOps.CMPLT, BinaryOps.CMPNE} else self.src[-1].dtype
|
65
66
|
|
66
67
|
@functools.cached_property
|
67
68
|
def key(self) -> bytes:
|
@@ -73,8 +74,8 @@ class LazyOp:
|
|
73
74
|
def lazyops(self) -> List[LazyOp]: return dedup([self] + [item for x in self.src for item in x.lazyops])
|
74
75
|
def vars(self) -> List[Variable]:
|
75
76
|
extract_vars = [x.arg.st.vars() for x in self.lazyops if x.op in BufferOps]
|
76
|
-
const_vars = [x.arg.val
|
77
|
-
return sorted(set.union(*extract_vars, set(const_vars)), key=lambda
|
77
|
+
const_vars = [x.arg.val for x in self.lazyops if x.op is BufferOps.CONST and isinstance(x.arg.val, Variable)]
|
78
|
+
return sorted(set.union(*extract_vars, set(const_vars)), key=lambda v: v.expr)
|
78
79
|
|
79
80
|
# **************** independent FlopCounter ****************
|
80
81
|
|
@@ -116,21 +117,53 @@ def hook_overflow(dv, fxn):
|
|
116
117
|
|
117
118
|
python_alu = {
|
118
119
|
UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan,
|
119
|
-
UnaryOps.EXP2: hook_overflow(math.inf, lambda x:
|
120
|
-
UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan,
|
120
|
+
UnaryOps.EXP2: hook_overflow(math.inf, lambda x: 2**x),
|
121
|
+
UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan,
|
122
|
+
UnaryOps.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan,
|
123
|
+
UnaryOps.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
|
121
124
|
UnaryOps.NEG: lambda x: (not x) if isinstance(x, bool) else -x,
|
122
|
-
BinaryOps.
|
123
|
-
BinaryOps.
|
124
|
-
BinaryOps.
|
125
|
-
BinaryOps.
|
125
|
+
BinaryOps.SHR: operator.rshift, BinaryOps.SHL: operator.lshift,
|
126
|
+
BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add,
|
127
|
+
BinaryOps.XOR: operator.xor, BinaryOps.MAX: max, BinaryOps.CMPNE: operator.ne, BinaryOps.CMPLT: operator.lt,
|
128
|
+
BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], BinaryOps.IDIV: lambda x, y: int(x/y) if y != 0 else x*math.inf,
|
126
129
|
TernaryOps.WHERE: lambda x,y,z: y if x else z}
|
127
130
|
|
131
|
+
def truncate_fp16(x):
|
132
|
+
try:
|
133
|
+
x = float(x)
|
134
|
+
struct.pack("@e", x)
|
135
|
+
return x
|
136
|
+
except OverflowError: return math.copysign(math.inf, x)
|
137
|
+
|
128
138
|
truncate: Dict[DType, Callable] = {dtypes.bool: bool,
|
129
|
-
# TODO:
|
130
|
-
dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
|
139
|
+
# TODO: bfloat16
|
140
|
+
dtypes.float16: truncate_fp16, dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
|
131
141
|
dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
|
132
142
|
dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
|
133
143
|
dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value,
|
134
144
|
dtypes.int32: lambda x: ctypes.c_int32(x).value, dtypes.int64: lambda x: ctypes.c_int64(x).value,}
|
135
145
|
|
136
146
|
def exec_alu(op:Op, dtype:DType, operands): return truncate.get(dtype, lambda x: x)(python_alu[op](*operands))
|
147
|
+
|
148
|
+
# the living definition of LazyOps
|
149
|
+
def verify_lazyop(*ast:LazyOp):
|
150
|
+
sts: Dict[LazyOp, ShapeTracker] = {}
|
151
|
+
def dfs(op:LazyOp, st:ShapeTracker):
|
152
|
+
if op in sts: return
|
153
|
+
for x in op.src: dfs(x, st)
|
154
|
+
# only reduceop is allowed to change shape, limited to turning n to 1
|
155
|
+
if op.op in ReduceOps:
|
156
|
+
expected_shape = tuple(1 if i in op.arg else s for i,s in enumerate(sts[op.src[0]].shape))
|
157
|
+
assert st.shape == expected_shape, f"unexpected reduceop shape {st.shape} != {expected_shape}"
|
158
|
+
st = ShapeTracker.from_shape(expected_shape)
|
159
|
+
else:
|
160
|
+
# movementops are pushed to the edges with LOAD
|
161
|
+
if op.op in BufferOps: st = op.arg.st
|
162
|
+
else: st = sts[op.src[0]]
|
163
|
+
for x in op.src: assert sts[x].shape == st.shape, f"found implicit movement op {x.op} {sts[x].shape} != {op.op} {st.shape}"
|
164
|
+
sts[op] = st
|
165
|
+
for i, out in enumerate(ast):
|
166
|
+
assert out.arg.idx == i, f"unexpected output buffer idx {out.arg.idx} != {i}"
|
167
|
+
assert out.op is BufferOps.STORE, f"kernels must have stores as the output, got {out.op}"
|
168
|
+
assert out.arg.st.size == ast[-1].arg.st.size, f"outputs must have the same size, got {out.arg.st.size}"
|
169
|
+
dfs(out, out.arg.st)
|
tinygrad/renderer/__init__.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
from typing import Optional, List, Tuple, Dict
|
2
2
|
import functools
|
3
3
|
from dataclasses import dataclass
|
4
|
-
from tinygrad.helpers import to_function_name
|
4
|
+
from tinygrad.helpers import getenv, to_function_name
|
5
5
|
from tinygrad.codegen.uops import UOpGraph
|
6
6
|
from tinygrad.shape.symbolic import sym_infer, sint, Variable
|
7
7
|
from tinygrad.dtype import DType
|
@@ -52,10 +52,14 @@ class Renderer:
|
|
52
52
|
supports_float4: bool = True
|
53
53
|
has_local: bool = True
|
54
54
|
has_shared: bool = True
|
55
|
-
# NOTE: these two should be in
|
56
|
-
global_max: Optional[
|
57
|
-
local_max: Optional[
|
55
|
+
# NOTE: these two should be in (x,y,z) order to match the max_sizes argument in get_grouped_dims
|
56
|
+
global_max: Optional[Tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: UOps.SPECIAL int32 indexes right now
|
57
|
+
local_max: Optional[Tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: UOps.SPECIAL int32 indexes right now
|
58
58
|
shared_max: int = 32768
|
59
59
|
tensor_cores: List[TensorCore] = []
|
60
|
+
@functools.cached_property
|
61
|
+
def tc_opt(self): return getenv("TC_OPT")
|
62
|
+
@functools.cached_property
|
63
|
+
def tc(self): return getenv("TC", 1)
|
60
64
|
|
61
65
|
def render(self, name:str, uops:UOpGraph) -> str: raise NotImplementedError("needs a renderer")
|