tinygrad 0.8.0-py3-none-any.whl → 0.9.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +6 -6
- tinygrad/codegen/kernel.py +230 -190
- tinygrad/codegen/linearizer.py +278 -384
- tinygrad/codegen/uops.py +415 -0
- tinygrad/device.py +132 -275
- tinygrad/dtype.py +53 -37
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +195 -0
- tinygrad/engine/realize.py +191 -0
- tinygrad/engine/schedule.py +362 -0
- tinygrad/engine/search.py +196 -0
- tinygrad/{mlops.py → function.py} +28 -14
- tinygrad/helpers.py +72 -43
- tinygrad/lazy.py +141 -240
- tinygrad/multi.py +169 -0
- tinygrad/nn/__init__.py +179 -8
- tinygrad/nn/datasets.py +7 -0
- tinygrad/nn/optim.py +106 -28
- tinygrad/nn/state.py +86 -17
- tinygrad/ops.py +70 -44
- tinygrad/renderer/__init__.py +61 -0
- tinygrad/renderer/assembly.py +276 -0
- tinygrad/renderer/cstyle.py +299 -206
- tinygrad/renderer/llvmir.py +118 -123
- tinygrad/runtime/autogen/amd_gpu.py +1900 -0
- tinygrad/runtime/autogen/comgr.py +865 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5761 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33328 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/hip_comgr.py +47 -0
- tinygrad/runtime/driver/hsa.py +143 -0
- tinygrad/runtime/graph/clang.py +38 -0
- tinygrad/runtime/graph/cuda.py +59 -54
- tinygrad/runtime/graph/hcq.py +143 -0
- tinygrad/runtime/graph/hsa.py +171 -0
- tinygrad/runtime/graph/metal.py +37 -41
- tinygrad/runtime/ops_amd.py +564 -0
- tinygrad/runtime/ops_clang.py +16 -14
- tinygrad/runtime/ops_cuda.py +130 -38
- tinygrad/runtime/ops_disk.py +45 -42
- tinygrad/runtime/ops_gpu.py +52 -50
- tinygrad/runtime/ops_hsa.py +278 -0
- tinygrad/runtime/ops_llvm.py +36 -56
- tinygrad/runtime/ops_metal.py +42 -24
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +630 -0
- tinygrad/runtime/ops_python.py +204 -0
- tinygrad/shape/shapetracker.py +41 -105
- tinygrad/shape/symbolic.py +98 -95
- tinygrad/shape/view.py +137 -35
- tinygrad/tensor.py +2367 -442
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/METADATA +19 -9
- tinygrad-0.9.0.dist-info/RECORD +60 -0
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
- tinygrad/features/image.py +0 -93
- tinygrad/features/multi.py +0 -103
- tinygrad/features/search.py +0 -160
- tinygrad/graph.py +0 -106
- tinygrad/jit.py +0 -152
- tinygrad/realize.py +0 -50
- tinygrad/runtime/graph/hip.py +0 -24
- tinygrad/runtime/ops_cpu.py +0 -45
- tinygrad/runtime/ops_hip.py +0 -97
- tinygrad/runtime/ops_torch.py +0 -49
- tinygrad-0.8.0.dist-info/RECORD +0 -41
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
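The file list reflects the 0.9.0 reorganization: the scheduler, JIT, graph, and search code moved into a new `tinygrad/engine/` package, `tinygrad/features/` was removed, and `mlops.py` became `function.py`. As a hedged sketch of what this can mean for downstream code (assuming the moved modules keep their public names, e.g. `TinyJit` in `jit.py` → `engine/jit.py`; nothing below is taken from tinygrad's docs), an import shim might look like:

```python
# Hypothetical compatibility shim for the module moves listed above.
# Assumes TinyJit keeps its name after tinygrad/jit.py -> tinygrad/engine/jit.py.
try:
  from tinygrad.engine.jit import TinyJit   # 0.9.0 layout
except ImportError:
  from tinygrad.jit import TinyJit          # 0.8.0 layout
```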
tinygrad/helpers.py CHANGED

```diff
@@ -1,22 +1,27 @@
 from __future__ import annotations
 import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, cProfile, pstats, tempfile, pathlib, string, ctypes
-
+import itertools, urllib.request, subprocess
 from tqdm import tqdm
 from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Sequence
 if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
   from typing_extensions import TypeGuard
+  from tinygrad.shape.shapetracker import sint
 
 T = TypeVar("T")
 U = TypeVar("U")
 # NOTE: it returns int 1 if x is empty regardless of the type of x
-def prod(x:Iterable[T]) -> Union[T,int]: return functools.reduce(operator.
+def prod(x:Iterable[T]) -> Union[T,int]: return functools.reduce(operator.mul, x, 1)
 
 # NOTE: helpers is not allowed to import from anything else in tinygrad
 OSX = platform.system() == "Darwin"
 CI = os.getenv("CI", "") != ""
 
 def dedup(x:Iterable[T]): return list(dict.fromkeys(x)) # retains list order
-def argfix(*x):
+def argfix(*x):
+  if x and x[0].__class__ in (tuple, list):
+    if len(x) != 1: raise ValueError(f"bad arg {x}")
+    return tuple(x[0])
+  return x
 def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
 def all_same(items:List[T]): return all(x == items[0] for x in items)
 def all_int(t: Sequence[Any]) -> TypeGuard[Tuple[int, ...]]: return all(isinstance(s, int) for s in t)
@@ -51,12 +56,21 @@ def get_child(obj, key):
     else: obj = getattr(obj, k)
   return obj
 
+# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
+def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:
+  acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
+  try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new]
+  except ValueError: return None
+  return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]
+
 @functools.lru_cache(maxsize=None)
 def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+string.digits+'_') else f'{ord(c):02X}' for c in ansistrip(s)])
 @functools.lru_cache(maxsize=None)
 def getenv(key:str, default=0): return type(default)(os.getenv(key, default))
 def temp(x:str) -> str: return (pathlib.Path(tempfile.gettempdir()) / x).as_posix()
 
+class GraphException(Exception): pass
+
 class Context(contextlib.ContextDecorator):
   stack: ClassVar[List[dict[str, int]]] = [{}]
   def __init__(self, **kwargs): self.kwargs = kwargs
@@ -70,18 +84,34 @@ class Context(contextlib.ContextDecorator):
 class ContextVar:
   _cache: ClassVar[Dict[str, ContextVar]] = {}
   value: int
+  key: str
   def __new__(cls, key, default_value):
     if key in ContextVar._cache: return ContextVar._cache[key]
     instance = ContextVar._cache[key] = super().__new__(cls)
-    instance.value = getenv(key, default_value)
+    instance.value, instance.key = getenv(key, default_value), key
     return instance
   def __bool__(self): return bool(self.value)
   def __ge__(self, x): return self.value >= x
   def __gt__(self, x): return self.value > x
   def __lt__(self, x): return self.value < x
 
-DEBUG, IMAGE, BEAM, NOOPT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0)
-
+DEBUG, IMAGE, BEAM, NOOPT, JIT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0), ContextVar("JIT", 1)
+WINO, THREEFRY, CACHECOLLECTING = ContextVar("WINO", 0), ContextVar("THREEFRY", 0), ContextVar("CACHECOLLECTING", 1)
+GRAPH, GRAPHPATH, SAVE_SCHEDULE, RING = ContextVar("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net"), ContextVar("SAVE_SCHEDULE", 0), ContextVar("RING", 1)
+MULTIOUTPUT = ContextVar("MULTIOUTPUT", 1)
+
+# **************** global state Counters ****************
+
+class GlobalCounters:
+  global_ops: ClassVar[int] = 0
+  global_mem: ClassVar[int] = 0
+  time_sum_s: ClassVar[float] = 0.0
+  kernel_count: ClassVar[int] = 0
+  mem_used: ClassVar[int] = 0 # NOTE: this is not reset
+  @staticmethod
+  def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count = 0,0,0.0,0
+
+# **************** timer and profiler ****************
 
 class Timing(contextlib.ContextDecorator):
   def __init__(self, prefix="", on_exit=None, enabled=True): self.prefix, self.on_exit, self.enabled = prefix, on_exit, enabled
@@ -90,16 +120,24 @@ class Timing(contextlib.ContextDecorator):
     self.et = time.perf_counter_ns() - self.st
     if self.enabled: print(f"{self.prefix}{self.et*1e-6:6.2f} ms"+(self.on_exit(self.et) if self.on_exit else ""))
 
+def _format_fcn(fcn): return f"{fcn[0]}:{fcn[1]}:{fcn[2]}"
 class Profiling(contextlib.ContextDecorator):
-  def __init__(self, enabled=True, sort='cumtime', frac=0.2, fn=None
+  def __init__(self, enabled=True, sort='cumtime', frac=0.2, fn=None, ts=1):
+    self.enabled, self.sort, self.frac, self.fn, self.time_scale = enabled, sort, frac, fn, 1e3/ts
   def __enter__(self):
-    self.pr = cProfile.Profile(
+    self.pr = cProfile.Profile()
     if self.enabled: self.pr.enable()
   def __exit__(self, *exc):
    if self.enabled:
      self.pr.disable()
      if self.fn: self.pr.dump_stats(self.fn)
-      pstats.Stats(self.pr).strip_dirs().sort_stats(self.sort)
+      stats = pstats.Stats(self.pr).strip_dirs().sort_stats(self.sort)
+      for fcn in stats.fcn_list[0:int(len(stats.fcn_list)*self.frac)]: # type: ignore[attr-defined]
+        (_primitive_calls, num_calls, tottime, cumtime, callers) = stats.stats[fcn] # type: ignore[attr-defined]
+        scallers = sorted(callers.items(), key=lambda x: -x[1][2])
+        print(f"n:{num_calls:8d} tm:{tottime*self.time_scale:7.2f}ms tot:{cumtime*self.time_scale:7.2f}ms",
+              colored(_format_fcn(fcn), "yellow") + " "*(50-len(_format_fcn(fcn))),
+              colored(f"<- {(scallers[0][1][2]/tottime)*100:3.0f}% {_format_fcn(scallers[0][0])}", "BLACK") if len(scallers) else '')
 
 # *** universal database cache ***
 
@@ -107,7 +145,7 @@ _cache_dir: str = getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches"
 CACHEDB: str = getenv("CACHEDB", os.path.abspath(os.path.join(_cache_dir, "tinygrad", "cache.db")))
 CACHELEVEL = getenv("CACHELEVEL", 2)
 
-VERSION =
+VERSION = 16
 _db_connection = None
 def db_connection():
   global _db_connection
@@ -117,13 +155,18 @@ def db_connection():
     if DEBUG >= 7: _db_connection.set_trace_callback(print)
   return _db_connection
 
+def diskcache_clear():
+  cur = db_connection().cursor()
+  drop_tables = cur.execute("SELECT 'DROP TABLE IF EXISTS ' || quote(name) || ';' FROM sqlite_master WHERE type = 'table';").fetchall()
+  cur.executescript("\n".join([s[0] for s in drop_tables]))
+
 def diskcache_get(table:str, key:Union[Dict, str, int]) -> Any:
   if CACHELEVEL == 0: return None
   if isinstance(key, (str,int)): key = {"key": key}
   conn = db_connection()
   cur = conn.cursor()
   try:
-    res = cur.execute(f"SELECT val FROM {table}_{VERSION} WHERE {' AND '.join([f'{x}=?' for x in key.keys()])}", tuple(key.values()))
+    res = cur.execute(f"SELECT val FROM '{table}_{VERSION}' WHERE {' AND '.join([f'{x}=?' for x in key.keys()])}", tuple(key.values()))
   except sqlite3.OperationalError:
     return None # table doesn't exist
   if (val:=res.fetchone()) is not None: return pickle.loads(val[0])
@@ -138,20 +181,27 @@ def diskcache_put(table:str, key:Union[Dict, str, int], val:Any):
   if table not in _db_tables:
     TYPES = {str: "text", bool: "integer", int: "integer", float: "numeric", bytes: "blob"}
     ltypes = ', '.join(f"{k} {TYPES[type(key[k])]}" for k in key.keys())
-    cur.execute(f"CREATE TABLE IF NOT EXISTS {table}_{VERSION} ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))")
+    cur.execute(f"CREATE TABLE IF NOT EXISTS '{table}_{VERSION}' ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))")
     _db_tables.add(table)
-  cur.execute(f"REPLACE INTO {table}_{VERSION} ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (pickle.dumps(val), )) # noqa: E501
+  cur.execute(f"REPLACE INTO '{table}_{VERSION}' ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (pickle.dumps(val), )) # noqa: E501
   conn.commit()
   cur.close()
   return val
 
+def diskcache(func):
+  def wrapper(*args, **kwargs) -> bytes:
+    table, key = f"cache_{func.__name__}", hashlib.sha256(pickle.dumps((args, kwargs))).hexdigest()
+    if (ret:=diskcache_get(table, key)): return ret
+    return diskcache_put(table, key, func(*args, **kwargs))
+  return wrapper
+
 # *** http support ***
 
 def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, allow_caching=not getenv("DISABLE_HTTP_CACHE")) -> pathlib.Path:
   if url.startswith(("/", ".")): return pathlib.Path(url)
   fp = pathlib.Path(name) if name is not None and (isinstance(name, pathlib.Path) or '/' in name) else pathlib.Path(_cache_dir) / "tinygrad" / "downloads" / (name if name else hashlib.md5(url.encode('utf-8')).hexdigest()) # noqa: E501
   if not fp.is_file() or not allow_caching:
-    with request.urlopen(url, timeout=10) as r:
+    with urllib.request.urlopen(url, timeout=10) as r:
       assert r.status == 200
       total_length = int(r.headers.get('content-length', 0))
       progress_bar = tqdm(total=total_length, unit='B', unit_scale=True, desc=url)
@@ -170,10 +220,16 @@ def cpu_time_execution(cb, enable):
   cb()
   if enable: return time.perf_counter()-st
 
+def cpu_objdump(lib):
+  with tempfile.NamedTemporaryFile(delete=True) as f:
+    pathlib.Path(f.name).write_bytes(lib)
+    print(subprocess.check_output(['objdump', '-d', f.name]).decode('utf-8'))
+
 # *** ctypes helpers
 
 # TODO: make this work with read only memoryviews (if possible)
-def from_mv(mv:memoryview, to_type=ctypes.c_char):
+def from_mv(mv:memoryview, to_type=ctypes.c_char):
+  return ctypes.cast(ctypes.addressof(to_type.from_buffer(mv)), ctypes.POINTER(to_type * len(mv))).contents
 def to_mv(ptr, sz) -> memoryview: return memoryview(ctypes.cast(ptr, ctypes.POINTER(ctypes.c_uint8 * sz)).contents).cast("B")
 def to_char_p_p(options: List[bytes], to_type=ctypes.c_char): return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) for o in options]) # noqa: E501
 @functools.lru_cache(maxsize=None)
@@ -182,31 +238,4 @@ def init_c_struct_t(fields: Tuple[Tuple[str, ctypes._SimpleCData], ...]):
     _pack_, _fields_ = 1, fields
   return CStruct
 def init_c_var(ctypes_var, creat_cb): return (creat_cb(ctypes_var), ctypes_var)[1]
-def
-def flat_mv(mv:memoryview):
-  if len(mv) == 0: return mv
-  return mv.cast("B", shape=(mv.nbytes,))
-
-# *** Helpers for CUDA-like APIs.
-
-def compile_cuda_style(prg, compile_options, prog_t, create_prog, compile_prog, get_code, get_code_size, get_log, get_log_size, check) -> bytes:
-  check(create_prog(ctypes.byref(prog := prog_t()), prg.encode(), "<null>".encode(), 0, None, None))
-  status = compile_prog(prog, len(compile_options), to_char_p_p([o.encode() for o in compile_options]))
-
-  if status != 0: raise RuntimeError(f"compile failed: {get_bytes(prog, get_log_size, get_log, check).decode()}")
-  return get_bytes(prog, get_code_size, get_code, check)
-
-def encode_args_cuda_style(bufs, vals, device_ptr_t, marks) -> Tuple[ctypes.Array, ctypes.Structure]:
-  c_args = init_c_struct_t(tuple([(f'f{i}', device_ptr_t) for i in range(len(bufs))] + [(f'f{i}', ctypes.c_int) for i in range(len(bufs), len(bufs)+len(vals))]))(*bufs, *vals) # noqa: E501
-  return (ctypes.c_void_p * 5)(ctypes.c_void_p(marks[0]), ctypes.cast(ctypes.pointer(c_args), ctypes.c_void_p), ctypes.c_void_p(marks[1]), ctypes.cast(ctypes.pointer(ctypes.c_size_t(ctypes.sizeof(c_args))), ctypes.c_void_p), ctypes.c_void_p(marks[2])), c_args # noqa: E501
-
-def time_execution_cuda_style(cb, ev_t, evcreate, evrecord, evsync, evdestroy, evtime, enable=False) -> Optional[float]:
-  if not enable: return cb()
-  evs = [init_c_var(ev_t(), lambda x: evcreate(ctypes.byref(x), 0)) for _ in range(2)]
-  evrecord(evs[0], None)
-  cb()
-  evrecord(evs[1], None)
-  evsync(evs[1])
-  evtime(ctypes.byref(ret := ctypes.c_float()), evs[0], evs[1])
-  for ev in evs: evdestroy(ev)
-  return ret.value * 1e-3
+def flat_mv(mv:memoryview): return mv if len(mv) == 0 else mv.cast("B", shape=(mv.nbytes,))
```