tinygrad-0.8.0-py3-none-any.whl → tinygrad-0.9.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. tinygrad/__init__.py +6 -6
  2. tinygrad/codegen/kernel.py +230 -190
  3. tinygrad/codegen/linearizer.py +278 -384
  4. tinygrad/codegen/uops.py +415 -0
  5. tinygrad/device.py +132 -275
  6. tinygrad/dtype.py +53 -37
  7. tinygrad/engine/__init__.py +0 -0
  8. tinygrad/engine/graph.py +100 -0
  9. tinygrad/engine/jit.py +195 -0
  10. tinygrad/engine/realize.py +191 -0
  11. tinygrad/engine/schedule.py +362 -0
  12. tinygrad/engine/search.py +196 -0
  13. tinygrad/{mlops.py → function.py} +28 -14
  14. tinygrad/helpers.py +72 -43
  15. tinygrad/lazy.py +141 -240
  16. tinygrad/multi.py +169 -0
  17. tinygrad/nn/__init__.py +179 -8
  18. tinygrad/nn/datasets.py +7 -0
  19. tinygrad/nn/optim.py +106 -28
  20. tinygrad/nn/state.py +86 -17
  21. tinygrad/ops.py +70 -44
  22. tinygrad/renderer/__init__.py +61 -0
  23. tinygrad/renderer/assembly.py +276 -0
  24. tinygrad/renderer/cstyle.py +299 -206
  25. tinygrad/renderer/llvmir.py +118 -123
  26. tinygrad/runtime/autogen/amd_gpu.py +1900 -0
  27. tinygrad/runtime/autogen/comgr.py +865 -0
  28. tinygrad/runtime/autogen/cuda.py +5923 -0
  29. tinygrad/runtime/autogen/hip.py +5909 -0
  30. tinygrad/runtime/autogen/hsa.py +5761 -0
  31. tinygrad/runtime/autogen/kfd.py +812 -0
  32. tinygrad/runtime/autogen/nv_gpu.py +33328 -0
  33. tinygrad/runtime/autogen/opencl.py +1795 -0
  34. tinygrad/runtime/driver/hip_comgr.py +47 -0
  35. tinygrad/runtime/driver/hsa.py +143 -0
  36. tinygrad/runtime/graph/clang.py +38 -0
  37. tinygrad/runtime/graph/cuda.py +59 -54
  38. tinygrad/runtime/graph/hcq.py +143 -0
  39. tinygrad/runtime/graph/hsa.py +171 -0
  40. tinygrad/runtime/graph/metal.py +37 -41
  41. tinygrad/runtime/ops_amd.py +564 -0
  42. tinygrad/runtime/ops_clang.py +16 -14
  43. tinygrad/runtime/ops_cuda.py +130 -38
  44. tinygrad/runtime/ops_disk.py +45 -42
  45. tinygrad/runtime/ops_gpu.py +52 -50
  46. tinygrad/runtime/ops_hsa.py +278 -0
  47. tinygrad/runtime/ops_llvm.py +36 -56
  48. tinygrad/runtime/ops_metal.py +42 -24
  49. tinygrad/runtime/ops_npy.py +9 -0
  50. tinygrad/runtime/ops_nv.py +630 -0
  51. tinygrad/runtime/ops_python.py +204 -0
  52. tinygrad/shape/shapetracker.py +41 -105
  53. tinygrad/shape/symbolic.py +98 -95
  54. tinygrad/shape/view.py +137 -35
  55. tinygrad/tensor.py +2367 -442
  56. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
  57. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/METADATA +19 -9
  58. tinygrad-0.9.0.dist-info/RECORD +60 -0
  59. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
  60. tinygrad/features/image.py +0 -93
  61. tinygrad/features/multi.py +0 -103
  62. tinygrad/features/search.py +0 -160
  63. tinygrad/graph.py +0 -106
  64. tinygrad/jit.py +0 -152
  65. tinygrad/realize.py +0 -50
  66. tinygrad/runtime/graph/hip.py +0 -24
  67. tinygrad/runtime/ops_cpu.py +0 -45
  68. tinygrad/runtime/ops_hip.py +0 -97
  69. tinygrad/runtime/ops_torch.py +0 -49
  70. tinygrad-0.8.0.dist-info/RECORD +0 -41
  71. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
tinygrad/helpers.py CHANGED
@@ -1,22 +1,27 @@
1
1
  from __future__ import annotations
2
2
  import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, cProfile, pstats, tempfile, pathlib, string, ctypes
3
- from urllib import request # NOTE: this has to be imported specifically
3
+ import itertools, urllib.request, subprocess
4
4
  from tqdm import tqdm
5
5
  from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Sequence
6
6
  if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
7
7
  from typing_extensions import TypeGuard
8
+ from tinygrad.shape.shapetracker import sint
8
9
 
9
10
  T = TypeVar("T")
10
11
  U = TypeVar("U")
11
12
  # NOTE: it returns int 1 if x is empty regardless of the type of x
12
- def prod(x:Iterable[T]) -> Union[T,int]: return functools.reduce(operator.__mul__, x, 1)
13
+ def prod(x:Iterable[T]) -> Union[T,int]: return functools.reduce(operator.mul, x, 1)
13
14
 
14
15
  # NOTE: helpers is not allowed to import from anything else in tinygrad
15
16
  OSX = platform.system() == "Darwin"
16
17
  CI = os.getenv("CI", "") != ""
17
18
 
18
19
  def dedup(x:Iterable[T]): return list(dict.fromkeys(x)) # retains list order
19
- def argfix(*x): return tuple(x[0]) if x and x[0].__class__ in (tuple, list) else x
20
+ def argfix(*x):
21
+ if x and x[0].__class__ in (tuple, list):
22
+ if len(x) != 1: raise ValueError(f"bad arg {x}")
23
+ return tuple(x[0])
24
+ return x
20
25
  def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
21
26
  def all_same(items:List[T]): return all(x == items[0] for x in items)
22
27
  def all_int(t: Sequence[Any]) -> TypeGuard[Tuple[int, ...]]: return all(isinstance(s, int) for s in t)
@@ -51,12 +56,21 @@ def get_child(obj, key):
51
56
  else: obj = getattr(obj, k)
52
57
  return obj
53
58
 
59
+ # returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
60
+ def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:
61
+ acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
62
+ try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new]
63
+ except ValueError: return None
64
+ return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]
65
+
54
66
  @functools.lru_cache(maxsize=None)
55
67
  def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+string.digits+'_') else f'{ord(c):02X}' for c in ansistrip(s)])
56
68
  @functools.lru_cache(maxsize=None)
57
69
  def getenv(key:str, default=0): return type(default)(os.getenv(key, default))
58
70
  def temp(x:str) -> str: return (pathlib.Path(tempfile.gettempdir()) / x).as_posix()
59
71
 
72
+ class GraphException(Exception): pass
73
+
60
74
  class Context(contextlib.ContextDecorator):
61
75
  stack: ClassVar[List[dict[str, int]]] = [{}]
62
76
  def __init__(self, **kwargs): self.kwargs = kwargs
@@ -70,18 +84,34 @@ class Context(contextlib.ContextDecorator):
70
84
  class ContextVar:
71
85
  _cache: ClassVar[Dict[str, ContextVar]] = {}
72
86
  value: int
87
+ key: str
73
88
  def __new__(cls, key, default_value):
74
89
  if key in ContextVar._cache: return ContextVar._cache[key]
75
90
  instance = ContextVar._cache[key] = super().__new__(cls)
76
- instance.value = getenv(key, default_value)
91
+ instance.value, instance.key = getenv(key, default_value), key
77
92
  return instance
78
93
  def __bool__(self): return bool(self.value)
79
94
  def __ge__(self, x): return self.value >= x
80
95
  def __gt__(self, x): return self.value > x
81
96
  def __lt__(self, x): return self.value < x
82
97
 
83
- DEBUG, IMAGE, BEAM, NOOPT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0)
84
- GRAPH, GRAPHPATH = ContextVar("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net")
98
+ DEBUG, IMAGE, BEAM, NOOPT, JIT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0), ContextVar("JIT", 1)
99
+ WINO, THREEFRY, CACHECOLLECTING = ContextVar("WINO", 0), ContextVar("THREEFRY", 0), ContextVar("CACHECOLLECTING", 1)
100
+ GRAPH, GRAPHPATH, SAVE_SCHEDULE, RING = ContextVar("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net"), ContextVar("SAVE_SCHEDULE", 0), ContextVar("RING", 1)
101
+ MULTIOUTPUT = ContextVar("MULTIOUTPUT", 1)
102
+
103
+ # **************** global state Counters ****************
104
+
105
+ class GlobalCounters:
106
+ global_ops: ClassVar[int] = 0
107
+ global_mem: ClassVar[int] = 0
108
+ time_sum_s: ClassVar[float] = 0.0
109
+ kernel_count: ClassVar[int] = 0
110
+ mem_used: ClassVar[int] = 0 # NOTE: this is not reset
111
+ @staticmethod
112
+ def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count = 0,0,0.0,0
113
+
114
+ # **************** timer and profiler ****************
85
115
 
86
116
  class Timing(contextlib.ContextDecorator):
87
117
  def __init__(self, prefix="", on_exit=None, enabled=True): self.prefix, self.on_exit, self.enabled = prefix, on_exit, enabled
@@ -90,16 +120,24 @@ class Timing(contextlib.ContextDecorator):
90
120
  self.et = time.perf_counter_ns() - self.st
91
121
  if self.enabled: print(f"{self.prefix}{self.et*1e-6:6.2f} ms"+(self.on_exit(self.et) if self.on_exit else ""))
92
122
 
123
+ def _format_fcn(fcn): return f"{fcn[0]}:{fcn[1]}:{fcn[2]}"
93
124
  class Profiling(contextlib.ContextDecorator):
94
- def __init__(self, enabled=True, sort='cumtime', frac=0.2, fn=None): self.enabled, self.sort, self.frac, self.fn = enabled, sort, frac, fn
125
+ def __init__(self, enabled=True, sort='cumtime', frac=0.2, fn=None, ts=1):
126
+ self.enabled, self.sort, self.frac, self.fn, self.time_scale = enabled, sort, frac, fn, 1e3/ts
95
127
  def __enter__(self):
96
- self.pr = cProfile.Profile(timer=lambda: int(time.time()*1e9), timeunit=1e-6)
128
+ self.pr = cProfile.Profile()
97
129
  if self.enabled: self.pr.enable()
98
130
  def __exit__(self, *exc):
99
131
  if self.enabled:
100
132
  self.pr.disable()
101
133
  if self.fn: self.pr.dump_stats(self.fn)
102
- pstats.Stats(self.pr).strip_dirs().sort_stats(self.sort).print_stats(self.frac)
134
+ stats = pstats.Stats(self.pr).strip_dirs().sort_stats(self.sort)
135
+ for fcn in stats.fcn_list[0:int(len(stats.fcn_list)*self.frac)]: # type: ignore[attr-defined]
136
+ (_primitive_calls, num_calls, tottime, cumtime, callers) = stats.stats[fcn] # type: ignore[attr-defined]
137
+ scallers = sorted(callers.items(), key=lambda x: -x[1][2])
138
+ print(f"n:{num_calls:8d} tm:{tottime*self.time_scale:7.2f}ms tot:{cumtime*self.time_scale:7.2f}ms",
139
+ colored(_format_fcn(fcn), "yellow") + " "*(50-len(_format_fcn(fcn))),
140
+ colored(f"<- {(scallers[0][1][2]/tottime)*100:3.0f}% {_format_fcn(scallers[0][0])}", "BLACK") if len(scallers) else '')
103
141
 
104
142
  # *** universal database cache ***
105
143
 
@@ -107,7 +145,7 @@ _cache_dir: str = getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches"
107
145
  CACHEDB: str = getenv("CACHEDB", os.path.abspath(os.path.join(_cache_dir, "tinygrad", "cache.db")))
108
146
  CACHELEVEL = getenv("CACHELEVEL", 2)
109
147
 
110
- VERSION = 10
148
+ VERSION = 16
111
149
  _db_connection = None
112
150
  def db_connection():
113
151
  global _db_connection
@@ -117,13 +155,18 @@ def db_connection():
117
155
  if DEBUG >= 7: _db_connection.set_trace_callback(print)
118
156
  return _db_connection
119
157
 
158
+ def diskcache_clear():
159
+ cur = db_connection().cursor()
160
+ drop_tables = cur.execute("SELECT 'DROP TABLE IF EXISTS ' || quote(name) || ';' FROM sqlite_master WHERE type = 'table';").fetchall()
161
+ cur.executescript("\n".join([s[0] for s in drop_tables]))
162
+
120
163
  def diskcache_get(table:str, key:Union[Dict, str, int]) -> Any:
121
164
  if CACHELEVEL == 0: return None
122
165
  if isinstance(key, (str,int)): key = {"key": key}
123
166
  conn = db_connection()
124
167
  cur = conn.cursor()
125
168
  try:
126
- res = cur.execute(f"SELECT val FROM {table}_{VERSION} WHERE {' AND '.join([f'{x}=?' for x in key.keys()])}", tuple(key.values()))
169
+ res = cur.execute(f"SELECT val FROM '{table}_{VERSION}' WHERE {' AND '.join([f'{x}=?' for x in key.keys()])}", tuple(key.values()))
127
170
  except sqlite3.OperationalError:
128
171
  return None # table doesn't exist
129
172
  if (val:=res.fetchone()) is not None: return pickle.loads(val[0])
@@ -138,20 +181,27 @@ def diskcache_put(table:str, key:Union[Dict, str, int], val:Any):
138
181
  if table not in _db_tables:
139
182
  TYPES = {str: "text", bool: "integer", int: "integer", float: "numeric", bytes: "blob"}
140
183
  ltypes = ', '.join(f"{k} {TYPES[type(key[k])]}" for k in key.keys())
141
- cur.execute(f"CREATE TABLE IF NOT EXISTS {table}_{VERSION} ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))")
184
+ cur.execute(f"CREATE TABLE IF NOT EXISTS '{table}_{VERSION}' ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))")
142
185
  _db_tables.add(table)
143
- cur.execute(f"REPLACE INTO {table}_{VERSION} ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (pickle.dumps(val), )) # noqa: E501
186
+ cur.execute(f"REPLACE INTO '{table}_{VERSION}' ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (pickle.dumps(val), )) # noqa: E501
144
187
  conn.commit()
145
188
  cur.close()
146
189
  return val
147
190
 
191
+ def diskcache(func):
192
+ def wrapper(*args, **kwargs) -> bytes:
193
+ table, key = f"cache_{func.__name__}", hashlib.sha256(pickle.dumps((args, kwargs))).hexdigest()
194
+ if (ret:=diskcache_get(table, key)): return ret
195
+ return diskcache_put(table, key, func(*args, **kwargs))
196
+ return wrapper
197
+
148
198
  # *** http support ***
149
199
 
150
200
  def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, allow_caching=not getenv("DISABLE_HTTP_CACHE")) -> pathlib.Path:
151
201
  if url.startswith(("/", ".")): return pathlib.Path(url)
152
202
  fp = pathlib.Path(name) if name is not None and (isinstance(name, pathlib.Path) or '/' in name) else pathlib.Path(_cache_dir) / "tinygrad" / "downloads" / (name if name else hashlib.md5(url.encode('utf-8')).hexdigest()) # noqa: E501
153
203
  if not fp.is_file() or not allow_caching:
154
- with request.urlopen(url, timeout=10) as r:
204
+ with urllib.request.urlopen(url, timeout=10) as r:
155
205
  assert r.status == 200
156
206
  total_length = int(r.headers.get('content-length', 0))
157
207
  progress_bar = tqdm(total=total_length, unit='B', unit_scale=True, desc=url)
@@ -170,10 +220,16 @@ def cpu_time_execution(cb, enable):
170
220
  cb()
171
221
  if enable: return time.perf_counter()-st
172
222
 
223
+ def cpu_objdump(lib):
224
+ with tempfile.NamedTemporaryFile(delete=True) as f:
225
+ pathlib.Path(f.name).write_bytes(lib)
226
+ print(subprocess.check_output(['objdump', '-d', f.name]).decode('utf-8'))
227
+
173
228
  # *** ctypes helpers
174
229
 
175
230
  # TODO: make this work with read only memoryviews (if possible)
176
- def from_mv(mv:memoryview, to_type=ctypes.c_char): return ctypes.cast(ctypes.addressof(to_type.from_buffer(mv)), ctypes.POINTER(to_type))
231
+ def from_mv(mv:memoryview, to_type=ctypes.c_char):
232
+ return ctypes.cast(ctypes.addressof(to_type.from_buffer(mv)), ctypes.POINTER(to_type * len(mv))).contents
177
233
  def to_mv(ptr, sz) -> memoryview: return memoryview(ctypes.cast(ptr, ctypes.POINTER(ctypes.c_uint8 * sz)).contents).cast("B")
178
234
  def to_char_p_p(options: List[bytes], to_type=ctypes.c_char): return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) for o in options]) # noqa: E501
179
235
  @functools.lru_cache(maxsize=None)
@@ -182,31 +238,4 @@ def init_c_struct_t(fields: Tuple[Tuple[str, ctypes._SimpleCData], ...]):
182
238
  _pack_, _fields_ = 1, fields
183
239
  return CStruct
184
240
  def init_c_var(ctypes_var, creat_cb): return (creat_cb(ctypes_var), ctypes_var)[1]
185
- def get_bytes(arg, get_sz, get_str, check) -> bytes: return (sz := init_c_var(ctypes.c_size_t(), lambda x: check(get_sz(arg, ctypes.byref(x)))), ctypes.string_at(init_c_var(ctypes.create_string_buffer(sz.value), lambda x: check(get_str(arg, x))), size=sz.value))[1] # noqa: E501
186
- def flat_mv(mv:memoryview):
187
- if len(mv) == 0: return mv
188
- return mv.cast("B", shape=(mv.nbytes,))
189
-
190
- # *** Helpers for CUDA-like APIs.
191
-
192
- def compile_cuda_style(prg, compile_options, prog_t, create_prog, compile_prog, get_code, get_code_size, get_log, get_log_size, check) -> bytes:
193
- check(create_prog(ctypes.byref(prog := prog_t()), prg.encode(), "<null>".encode(), 0, None, None))
194
- status = compile_prog(prog, len(compile_options), to_char_p_p([o.encode() for o in compile_options]))
195
-
196
- if status != 0: raise RuntimeError(f"compile failed: {get_bytes(prog, get_log_size, get_log, check).decode()}")
197
- return get_bytes(prog, get_code_size, get_code, check)
198
-
199
- def encode_args_cuda_style(bufs, vals, device_ptr_t, marks) -> Tuple[ctypes.Array, ctypes.Structure]:
200
- c_args = init_c_struct_t(tuple([(f'f{i}', device_ptr_t) for i in range(len(bufs))] + [(f'f{i}', ctypes.c_int) for i in range(len(bufs), len(bufs)+len(vals))]))(*bufs, *vals) # noqa: E501
201
- return (ctypes.c_void_p * 5)(ctypes.c_void_p(marks[0]), ctypes.cast(ctypes.pointer(c_args), ctypes.c_void_p), ctypes.c_void_p(marks[1]), ctypes.cast(ctypes.pointer(ctypes.c_size_t(ctypes.sizeof(c_args))), ctypes.c_void_p), ctypes.c_void_p(marks[2])), c_args # noqa: E501
202
-
203
- def time_execution_cuda_style(cb, ev_t, evcreate, evrecord, evsync, evdestroy, evtime, enable=False) -> Optional[float]:
204
- if not enable: return cb()
205
- evs = [init_c_var(ev_t(), lambda x: evcreate(ctypes.byref(x), 0)) for _ in range(2)]
206
- evrecord(evs[0], None)
207
- cb()
208
- evrecord(evs[1], None)
209
- evsync(evs[1])
210
- evtime(ctypes.byref(ret := ctypes.c_float()), evs[0], evs[1])
211
- for ev in evs: evdestroy(ev)
212
- return ret.value * 1e-3
241
+ def flat_mv(mv:memoryview): return mv if len(mv) == 0 else mv.cast("B", shape=(mv.nbytes,))