tinygrad 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. tinygrad/__init__.py +11 -6
  2. tinygrad/codegen/kernel.py +308 -175
  3. tinygrad/codegen/linearize.py +95 -0
  4. tinygrad/codegen/lowerer.py +143 -0
  5. tinygrad/codegen/transcendental.py +257 -0
  6. tinygrad/codegen/uopgraph.py +506 -0
  7. tinygrad/device.py +72 -171
  8. tinygrad/dtype.py +122 -47
  9. tinygrad/engine/jit.py +184 -87
  10. tinygrad/{lazy.py → engine/lazy.py} +74 -66
  11. tinygrad/engine/memory.py +51 -0
  12. tinygrad/engine/realize.py +86 -61
  13. tinygrad/engine/schedule.py +366 -317
  14. tinygrad/engine/search.py +58 -47
  15. tinygrad/function.py +59 -58
  16. tinygrad/helpers.py +120 -102
  17. tinygrad/multi.py +82 -78
  18. tinygrad/nn/__init__.py +116 -67
  19. tinygrad/nn/datasets.py +12 -5
  20. tinygrad/nn/optim.py +1 -1
  21. tinygrad/nn/state.py +91 -6
  22. tinygrad/ops.py +1126 -143
  23. tinygrad/renderer/__init__.py +47 -23
  24. tinygrad/renderer/cstyle.py +338 -265
  25. tinygrad/renderer/llvmir.py +125 -143
  26. tinygrad/renderer/ptx.py +225 -0
  27. tinygrad/runtime/autogen/adreno.py +17904 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
  29. tinygrad/runtime/autogen/cuda.py +6 -162
  30. tinygrad/runtime/autogen/io_uring.py +97 -63
  31. tinygrad/runtime/autogen/kfd.py +60 -47
  32. tinygrad/runtime/autogen/kgsl.py +1386 -0
  33. tinygrad/runtime/autogen/libc.py +5462 -0
  34. tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
  35. tinygrad/runtime/autogen/nvrtc.py +579 -0
  36. tinygrad/runtime/autogen/opencl.py +11 -11
  37. tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
  38. tinygrad/runtime/graph/clang.py +3 -3
  39. tinygrad/runtime/graph/cuda.py +11 -15
  40. tinygrad/runtime/graph/hcq.py +120 -107
  41. tinygrad/runtime/graph/metal.py +71 -43
  42. tinygrad/runtime/ops_amd.py +244 -323
  43. tinygrad/runtime/ops_clang.py +12 -5
  44. tinygrad/runtime/ops_cloud.py +220 -0
  45. tinygrad/runtime/ops_cuda.py +42 -99
  46. tinygrad/runtime/ops_disk.py +25 -26
  47. tinygrad/runtime/ops_dsp.py +181 -0
  48. tinygrad/runtime/ops_gpu.py +29 -16
  49. tinygrad/runtime/ops_hip.py +68 -0
  50. tinygrad/runtime/ops_llvm.py +15 -10
  51. tinygrad/runtime/ops_metal.py +147 -64
  52. tinygrad/runtime/ops_nv.py +356 -397
  53. tinygrad/runtime/ops_python.py +78 -79
  54. tinygrad/runtime/ops_qcom.py +405 -0
  55. tinygrad/runtime/support/__init__.py +0 -0
  56. tinygrad/runtime/support/compiler_cuda.py +77 -0
  57. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
  58. tinygrad/runtime/support/elf.py +38 -0
  59. tinygrad/runtime/support/hcq.py +539 -0
  60. tinygrad/shape/shapetracker.py +40 -50
  61. tinygrad/shape/view.py +102 -63
  62. tinygrad/tensor.py +1109 -365
  63. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
  64. tinygrad-0.10.0.dist-info/RECORD +77 -0
  65. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
  66. tinygrad/codegen/linearizer.py +0 -528
  67. tinygrad/codegen/uops.py +0 -451
  68. tinygrad/engine/graph.py +0 -100
  69. tinygrad/renderer/assembly.py +0 -269
  70. tinygrad/shape/symbolic.py +0 -327
  71. tinygrad-0.9.1.dist-info/RECORD +0 -63
  72. /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
  73. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
  74. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
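Several entries in the list above are renames that change import paths (for example tinygrad/lazy.py → tinygrad/engine/lazy.py), and a few internal modules are removed outright (codegen/linearizer.py, shape/symbolic.py, renderer/assembly.py). A minimal sketch of the adjustment downstream code would need if it imported these internals directly; the example is illustrative only and not part of the diff:

# tinygrad 0.9.1
# from tinygrad.lazy import LazyBuffer
# tinygrad 0.10.0: lazy.py now lives under tinygrad/engine/
from tinygrad.engine.lazy import LazyBuffer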
tinygrad/helpers.py CHANGED
@@ -1,10 +1,8 @@
  from __future__ import annotations
- import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, cProfile, pstats, tempfile, pathlib, string, ctypes, sys
- import itertools, urllib.request, subprocess, shutil, math, json
- from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Sequence
- if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
-   from typing_extensions import TypeGuard
-   from tinygrad.shape.shapetracker import sint
+ import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip
+ import urllib.request, subprocess, shutil, math, contextvars, types, copyreg, inspect, importlib
+ from dataclasses import dataclass
+ from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, Callable, Sequence, TypeGuard

  T = TypeVar("T")
  U = TypeVar("U")
@@ -15,6 +13,9 @@ def prod(x:Iterable[T]) -> Union[T,int]: return functools.reduce(operator.mul, x
  OSX = platform.system() == "Darwin"
  CI = os.getenv("CI", "") != ""

+ # fix colors on Windows, https://stackoverflow.com/questions/12492810/python-how-can-i-make-the-ansi-escape-codes-to-work-also-in-windows
+ if sys.platform == "win32": os.system("")
+
  def dedup(x:Iterable[T]): return list(dict.fromkeys(x)) # retains list order
  def argfix(*x):
    if x and x[0].__class__ in (tuple, list):
@@ -22,51 +23,53 @@ def argfix(*x):
      return tuple(x[0])
    return x
  def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
- def all_same(items:List[T]): return all(x == items[0] for x in items)
+ def all_same(items:Union[Tuple[T, ...], List[T]]): return all(x == items[0] for x in items)
  def all_int(t: Sequence[Any]) -> TypeGuard[Tuple[int, ...]]: return all(isinstance(s, int) for s in t)
  def colored(st, color:Optional[str], background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line # noqa: E501
+ def colorize_float(x: float): return colored(f"{x:7.2f}x", 'green' if x < 0.75 else 'red' if x > 1.15 else 'yellow')
+ def memsize_to_str(_bytes: int) -> str: return [f"{(_bytes / d):.2f} {pr}" for d,pr in [(1e9,"GB"),(1e6,"MB"),(1e3,"KB"),(1,"B")] if _bytes > d][0]
  def ansistrip(s:str): return re.sub('\x1b\\[(K|.*?m)', '', s)
  def ansilen(s:str): return len(ansistrip(s))
- def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
+ def make_tuple(x:Union[int, Sequence[int]], cnt:int) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else tuple(x)
  def flatten(l:Iterable[Iterable[T]]): return [item for sublist in l for item in sublist]
- def fully_flatten(l): return [item for sublist in l for item in (fully_flatten(sublist) if isinstance(sublist, (tuple, list)) else [sublist])]
+ def fully_flatten(l):
+   if hasattr(l, "__len__") and hasattr(l, "__getitem__") and not isinstance(l, str):
+     flattened = []
+     if hasattr(l, "shape") and l.shape == (): flattened.append(l[()])
+     else:
+       for i in range(len(l)): flattened.extend(fully_flatten(l[i]))
+     return flattened
+   return [l]
  def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
  def strip_parens(fst:str): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' and fst[1:-1].find('(') <= fst[1:-1].find(')') else fst
- def round_up(num, amt:int): return (num+amt-1)//amt * amt
+ def ceildiv(num, amt):
+   ret = -(num//-amt)
+   return ret if not isinstance(ret, float) else int(ret)
+ def round_up(num:int, amt:int) -> int: return (num+amt-1)//amt * amt
+ def data64(data: int) -> Tuple[int, int]: return (data >> 32, data & 0xFFFFFFFF)
+ def data64_le(data: int) -> Tuple[int, int]: return (data & 0xFFFFFFFF, data >> 32)
  def merge_dicts(ds:Iterable[Dict[T,U]]) -> Dict[T,U]:
-   assert len(kvs:=set([(k,v) for d in ds for k,v in d.items()])) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key" # noqa: E501
+   kvs = set([(k,v) for d in ds for k,v in d.items()])
+   assert len(kvs) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
    return {k:v for d in ds for k,v in d.items()}
- def partition(lst:List[T], fxn:Callable[[T],bool]) -> Tuple[List[T], List[T]]:
+ def partition(itr:Iterable[T], fxn:Callable[[T],bool]) -> Tuple[List[T], List[T]]:
    a:List[T] = []
    b:List[T] = []
-   for s in lst: (a if fxn(s) else b).append(s)
+   for s in itr: (a if fxn(s) else b).append(s)
    return a,b
  def unwrap(x:Optional[T]) -> T:
    assert x is not None
    return x
- def unwrap2(x:Tuple[T,Any]) -> T:
-   ret, err = x
-   assert err is None, str(err)
-   return ret
  def get_child(obj, key):
    for k in key.split('.'):
      if k.isnumeric(): obj = obj[int(k)]
      elif isinstance(obj, dict): obj = obj[k]
      else: obj = getattr(obj, k)
    return obj
+ def word_wrap(x, wrap=80): return x if len(x) <= wrap else (x[0:wrap] + "\n" + word_wrap(x[wrap:], wrap))

- def get_shape(x) -> Tuple[int, ...]:
-   if not isinstance(x, (list, tuple)): return ()
-   subs = [get_shape(xi) for xi in x]
-   if not all_same(subs): raise ValueError(f"inhomogeneous shape from {x}")
-   return (len(subs),) + (subs[0] if subs else ())
-
- # returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
- def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:
-   acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
-   try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new]
-   except ValueError: return None
-   return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]
+ # for length N coefficients `p`, returns p[0] * x**(N-1) + p[1] * x**(N-2) + ... + p[-2] * x + p[-1]
+ def polyN(x:T, p:List[float]) -> T: return functools.reduce(lambda acc,c: acc*x+c, p, 0.0) # type: ignore

  @functools.lru_cache(maxsize=None)
  def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+string.digits+'_') else f'{ord(c):02X}' for c in ansistrip(s)])
@@ -74,8 +77,6 @@ def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+str
  def getenv(key:str, default=0): return type(default)(os.getenv(key, default))
  def temp(x:str) -> str: return (pathlib.Path(tempfile.gettempdir()) / x).as_posix()

- class GraphException(Exception): pass
-
  class Context(contextlib.ContextDecorator):
    stack: ClassVar[List[dict[str, int]]] = [{}]
    def __init__(self, **kwargs): self.kwargs = kwargs
@@ -90,20 +91,31 @@ class ContextVar:
    _cache: ClassVar[Dict[str, ContextVar]] = {}
    value: int
    key: str
-   def __new__(cls, key, default_value):
-     if key in ContextVar._cache: return ContextVar._cache[key]
-     instance = ContextVar._cache[key] = super().__new__(cls)
-     instance.value, instance.key = getenv(key, default_value), key
-     return instance
+   def __init__(self, key, default_value):
+     assert key not in ContextVar._cache, f"attempt to recreate ContextVar {key}"
+     ContextVar._cache[key] = self
+     self.value, self.key = getenv(key, default_value), key
    def __bool__(self): return bool(self.value)
    def __ge__(self, x): return self.value >= x
    def __gt__(self, x): return self.value > x
    def __lt__(self, x): return self.value < x

  DEBUG, IMAGE, BEAM, NOOPT, JIT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0), ContextVar("JIT", 1)
- WINO, THREEFRY, CAPTURING = ContextVar("WINO", 0), ContextVar("THREEFRY", 0), ContextVar("CAPTURING", 1)
- GRAPH, GRAPHPATH, SAVE_SCHEDULE, RING = ContextVar("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net"), ContextVar("SAVE_SCHEDULE", 0), ContextVar("RING", 1)
- MULTIOUTPUT, PROFILE = ContextVar("MULTIOUTPUT", 1), ContextVar("PROFILE", 0)
+ WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1)
+ PROFILE, PROFILEPATH = ContextVar("PROFILE", 0), ContextVar("PROFILEPATH", temp("tinygrad_profile.json"))
+ USE_TC, TC_OPT, AMX, TRANSCENDENTAL = ContextVar("TC", 1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0), ContextVar("TRANSCENDENTAL", 1)
+ FUSE_ARANGE, FUSE_CONV_BW, LAZYCACHE = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0), ContextVar("LAZYCACHE", 1)
+ SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1)
+
+ @dataclass(frozen=True)
+ class Metadata:
+   name: str
+   caller: str
+   backward: bool = False
+   def __hash__(self): return hash(self.name)
+   def __repr__(self): return str(self) + (f" - {self.caller}" if self.caller else "")
+   def __str__(self): return self.name + (" bw" if self.backward else "")
+ _METADATA: contextvars.ContextVar[Optional[Metadata]] = contextvars.ContextVar("_METADATA", default=None)

  # **************** global state Counters ****************

@@ -130,47 +142,21 @@ class Profiling(contextlib.ContextDecorator):
    def __init__(self, enabled=True, sort='cumtime', frac=0.2, fn=None, ts=1):
      self.enabled, self.sort, self.frac, self.fn, self.time_scale = enabled, sort, frac, fn, 1e3/ts
    def __enter__(self):
+     import cProfile
      self.pr = cProfile.Profile()
      if self.enabled: self.pr.enable()
    def __exit__(self, *exc):
      if self.enabled:
        self.pr.disable()
        if self.fn: self.pr.dump_stats(self.fn)
+       import pstats
        stats = pstats.Stats(self.pr).strip_dirs().sort_stats(self.sort)
        for fcn in stats.fcn_list[0:int(len(stats.fcn_list)*self.frac)]: # type: ignore[attr-defined]
          (_primitive_calls, num_calls, tottime, cumtime, callers) = stats.stats[fcn] # type: ignore[attr-defined]
          scallers = sorted(callers.items(), key=lambda x: -x[1][2])
          print(f"n:{num_calls:8d} tm:{tottime*self.time_scale:7.2f}ms tot:{cumtime*self.time_scale:7.2f}ms",
-               colored(_format_fcn(fcn), "yellow") + " "*(50-len(_format_fcn(fcn))),
-               colored(f"<- {(scallers[0][1][2]/tottime)*100:3.0f}% {_format_fcn(scallers[0][0])}", "BLACK") if len(scallers) else '')
-
- class ProfileLogger:
-   writers: int = 0
-   mjson: List[Dict] = []
-   actors: Dict[str, int] = {}
-   subactors: Dict[Tuple[str, str], int] = {}
-   path = getenv("PROFILE_OUTPUT_FILE", temp("tinygrad_profile.json"))
-
-   def __init__(self): self.events, ProfileLogger.writers = [], ProfileLogger.writers + 1
-
-   def add_event(self, ev_name, ev_start, ev_end, actor, subactor=None): self.events += [(ev_name, ev_start, ev_end, actor, subactor)]
-
-   def __del__(self):
-     for name,st,et,actor_name,subactor_name in self.events:
-       if actor_name not in self.actors:
-         self.actors[actor_name] = (pid:=len(self.actors))
-         self.mjson.append({"name": "process_name", "ph": "M", "pid": pid, "args": {"name": actor_name}})
-
-       if (subactor_key:=(actor_name,subactor_name)) not in self.subactors:
-         self.subactors[subactor_key] = (tid:=len(self.subactors))
-         self.mjson.append({"name": "thread_name", "ph": "M", "pid": self.actors[actor_name], "tid":tid, "args": {"name": subactor_name}})
-
-       self.mjson.append({"name": name, "ph": "X", "pid": self.actors[actor_name], "tid": self.subactors.get(subactor_key, -1), "ts":st, "dur":et-st})
-
-     ProfileLogger.writers -= 1
-     if ProfileLogger.writers == 0:
-       with open(self.path, "w") as f: f.write(json.dumps({"traceEvents": self.mjson}))
-       print(f"Saved profile to {self.path}. Use https://ui.perfetto.dev/ to open it.")
+               colored(_format_fcn(fcn).ljust(50), "yellow"),
+               colored(f"<- {(scallers[0][1][2]/tottime)*100:3.0f}% {_format_fcn(scallers[0][0])}", "BLACK") if scallers else '')

  # *** universal database cache ***

@@ -184,14 +170,17 @@ def db_connection():
    global _db_connection
    if _db_connection is None:
      os.makedirs(CACHEDB.rsplit(os.sep, 1)[0], exist_ok=True)
-     _db_connection = sqlite3.connect(CACHEDB)
+     _db_connection = sqlite3.connect(CACHEDB, timeout=60, isolation_level="IMMEDIATE")
+     # another connection has set it already or is in the process of setting it
+     # that connection will lock the database
+     with contextlib.suppress(sqlite3.OperationalError): _db_connection.execute("PRAGMA journal_mode=WAL").fetchone()
      if DEBUG >= 7: _db_connection.set_trace_callback(print)
    return _db_connection

  def diskcache_clear():
    cur = db_connection().cursor()
    drop_tables = cur.execute("SELECT 'DROP TABLE IF EXISTS ' || quote(name) || ';' FROM sqlite_master WHERE type = 'table';").fetchall()
-   cur.executescript("\n".join([s[0] for s in drop_tables]))
+   cur.executescript("\n".join([s[0] for s in drop_tables] + ["VACUUM;"]))

  def diskcache_get(table:str, key:Union[Dict, str, int]) -> Any:
    if CACHELEVEL == 0: return None
@@ -230,22 +219,36 @@ def diskcache(func):

  # *** http support ***

- def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, subdir:Optional[str]=None,
+ def _ensure_downloads_dir() -> pathlib.Path:
+   # if we are on a tinybox, use the raid array
+   if pathlib.Path("/etc/tinybox-release").is_file():
+     # try creating dir with sudo
+     if not (downloads_dir := pathlib.Path("/raid/downloads")).exists():
+       subprocess.run(["sudo", "mkdir", "-p", downloads_dir], check=True)
+       subprocess.run(["sudo", "chown", "tiny:root", downloads_dir], check=True)
+       subprocess.run(["sudo", "chmod", "775", downloads_dir], check=True)
+     return downloads_dir
+   return pathlib.Path(_cache_dir) / "tinygrad" / "downloads"
+
+ def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, subdir:Optional[str]=None, gunzip:bool=False,
            allow_caching=not getenv("DISABLE_HTTP_CACHE")) -> pathlib.Path:
    if url.startswith(("/", ".")): return pathlib.Path(url)
    if name is not None and (isinstance(name, pathlib.Path) or '/' in name): fp = pathlib.Path(name)
-   else: fp = pathlib.Path(_cache_dir) / "tinygrad" / "downloads" / (subdir or "") / (name or hashlib.md5(url.encode('utf-8')).hexdigest())
+   else:
+     fp = _ensure_downloads_dir() / (subdir or "") / \
+       ((name or hashlib.md5(url.encode('utf-8')).hexdigest()) + (".gunzip" if gunzip else ""))
    if not fp.is_file() or not allow_caching:
      with urllib.request.urlopen(url, timeout=10) as r:
        assert r.status == 200
-       total_length = int(r.headers.get('content-length', 0))
-       progress_bar = tqdm(total=total_length, unit='B', unit_scale=True, desc=f"{url}: ", disable=CI)
+       length = int(r.headers.get('content-length', 0)) if not gunzip else None
+       progress_bar = tqdm(total=length, unit='B', unit_scale=True, desc=f"{url}", disable=CI)
        (path := fp.parent).mkdir(parents=True, exist_ok=True)
+       readfile = gzip.GzipFile(fileobj=r) if gunzip else r
        with tempfile.NamedTemporaryFile(dir=path, delete=False) as f:
-         while chunk := r.read(16384): progress_bar.update(f.write(chunk))
+         while chunk := readfile.read(16384): progress_bar.update(f.write(chunk))
          f.close()
          progress_bar.update(close=True)
-         if (file_size:=os.stat(f.name).st_size) < total_length: raise RuntimeError(f"fetch size incomplete, {file_size} < {total_length}")
+         if length and (file_size:=os.stat(f.name).st_size) < length: raise RuntimeError(f"fetch size incomplete, {file_size} < {length}")
          pathlib.Path(f.name).rename(fp)
    return fp

@@ -256,10 +259,10 @@ def cpu_time_execution(cb, enable):
    cb()
    if enable: return time.perf_counter()-st

- def cpu_objdump(lib):
+ def cpu_objdump(lib, objdump_tool='objdump'):
    with tempfile.NamedTemporaryFile(delete=True) as f:
      pathlib.Path(f.name).write_bytes(lib)
-     print(subprocess.check_output(['objdump', '-d', f.name]).decode('utf-8'))
+     print(subprocess.check_output([objdump_tool, '-d', f.name]).decode('utf-8'))

  # *** ctypes helpers

@@ -277,34 +280,49 @@ def init_c_struct_t(fields: Tuple[Tuple[str, ctypes._SimpleCData], ...]):
  def init_c_var(ctypes_var, creat_cb): return (creat_cb(ctypes_var), ctypes_var)[1]
  def flat_mv(mv:memoryview): return mv if len(mv) == 0 else mv.cast("B", shape=(mv.nbytes,))

+ # *** tqdm
+
  class tqdm:
-   def __init__(self, iterable=None, desc:str='', disable:bool=False, unit:str='it', unit_scale=False, total:int=-1, rate:int=100):
-     self.iter, self.desc, self.dis, self.unit, self.unit_scale, self.rate = iterable, f"{desc}: " if desc else "", disable, unit, unit_scale, rate
-     self.st, self.i, self.n, self.skip, self.t = time.perf_counter(), -1, 0, 1, len(iterable) if total==-1 else total
+   def __init__(self, iterable=None, desc:str='', disable:bool=False, unit:str='it', unit_scale=False, total:Optional[int]=None, rate:int=100):
+     self.iterable, self.disable, self.unit, self.unit_scale, self.rate = iterable, disable, unit, unit_scale, rate
+     self.st, self.i, self.n, self.skip, self.t = time.perf_counter(), -1, 0, 1, getattr(iterable, "__len__", lambda:0)() if total is None else total
+     self.set_description(desc)
      self.update(0)
    def __iter__(self):
-     try:
-       for item in self.iter:
-         yield item
-         self.update(1)
-     finally: self.update(close=True)
+     for item in self.iterable:
+       yield item
+       self.update(1)
+     self.update(close=True)
+   def __enter__(self): return self
+   def __exit__(self, *_): self.update(close=True)
    def set_description(self, desc:str): self.desc = f"{desc}: " if desc else ""
    def update(self, n:int=0, close:bool=False):
      self.n, self.i = self.n+n, self.i+1
-     if (self.i % self.skip != 0 and not close) or self.dis: return
-     prog, dur, term = self.n/self.t if self.t else -1, time.perf_counter()-self.st, shutil.get_terminal_size().columns
-     if self.i/dur > self.rate and self.i: self.skip = max(int(self.i/dur)//self.rate,1) if self.i else 1
-     def fmt(t): return ':'.join([f'{x:02d}' for x in divmod(int(t), 60)]) if t!=-1 else '?'
-     def scl(x): return x/1000**int(math.log(x,1000))
-     def fn(x): return (f"{scl(x):.{3-math.ceil(math.log10(scl(x)))}f}"[:4]+(f"{[' ','k','M','G','T','P'][int(math.log(x,1000))]}") if x else '0.00')
-     if self.t: unit_text = f"{fn(self.n)}/{fn(self.t)}" if self.unit_scale else f"{self.n}/{self.t}"
-     else: unit_text = f"{fn(self.n)}{self.unit}" if self.unit_scale else f"{self.n}{self.unit}"
-     it_text = f"{fn(self.n/dur)}" if self.n and self.unit_scale else f"{self.n/dur:5.2f}" if self.n else "?"
-     if self.t: suf = f'| {unit_text} [{fmt(dur)}<{fmt(dur/self.n*self.t-dur if self.n else -1)}, {it_text}{self.unit}/s]'
-     else: suf = f'{unit_text} [{fmt(dur)}, {it_text}{self.unit}/s]'
-     sz = max(term-5-len(suf)-len(self.desc), 1)
-     bar = f'\r{self.desc}{round(100*prog):3}%|{"█"*round(sz*prog)}{" "*(sz-round(sz*prog))}{suf}' if self.t else f'\r{self.desc}{suf}{" "*term}'
-     print(bar[:term+1],flush=True,end='\n'*close,file=sys.stderr)
+     if self.disable or (not close and self.i % self.skip != 0): return
+     prog, elapsed, ncols = self.n/self.t if self.t else 0, time.perf_counter()-self.st, shutil.get_terminal_size().columns
+     if self.i/elapsed > self.rate and self.i: self.skip = max(int(self.i/elapsed)//self.rate,1)
+     def HMS(t): return ':'.join(f'{x:02d}' if i else str(x) for i,x in enumerate([int(t)//3600,int(t)%3600//60,int(t)%60]) if i or x)
+     def SI(x): return (f"{x/1000**int(g:=math.log(x,1000)):.{int(3-3*math.fmod(g,1))}f}"[:4].rstrip('.')+' kMGTPEZY'[int(g)].strip()) if x else '0.00'
+     prog_text = f'{SI(self.n)}{f"/{SI(self.t)}" if self.t else self.unit}' if self.unit_scale else f'{self.n}{f"/{self.t}" if self.t else self.unit}'
+     est_text = f'<{HMS(elapsed/prog-elapsed) if self.n else "?"}' if self.t else ''
+     it_text = (SI(self.n/elapsed) if self.unit_scale else f"{self.n/elapsed:5.2f}") if self.n else "?"
+     suf = f'{prog_text} [{HMS(elapsed)}{est_text}, {it_text}{self.unit}/s]'
+     sz = max(ncols-len(self.desc)-3-2-2-len(suf), 1)
+     bar = '\r' + self.desc + (f'{100*prog:3.0f}%|{("█"*int(num:=sz*prog)+" ▏▎▍▌▋▊▉"[int(8*num)%8].strip()).ljust(sz," ")}| ' if self.t else '') + suf
+     print(bar[:ncols+1], flush=True, end='\n'*close, file=sys.stderr)
+   @classmethod
+   def write(cls, s:str): print(f"\r\033[K{s}", flush=True, file=sys.stderr)

  class trange(tqdm):
-   def __init__(self, n:int, **kwargs): super().__init__(iterable=range(n), total=n, **kwargs)
+   def __init__(self, n:int, **kwargs): super().__init__(iterable=range(n), total=n, **kwargs)
+
+ # *** universal support for code object pickling
+
+ def _reconstruct_code(*args): return types.CodeType(*args)
+ def _serialize_code(code:types.CodeType):
+   args = inspect.signature(types.CodeType).parameters # NOTE: this works in Python 3.10 and up
+   return _reconstruct_code, tuple(code.__getattribute__('co_'+x.replace('codestring', 'code').replace('constants', 'consts')) for x in args)
+ copyreg.pickle(types.CodeType, _serialize_code)
+
+ def _serialize_module(module:types.ModuleType): return importlib.import_module, (module.__name__,)
+ copyreg.pickle(types.ModuleType, _serialize_module)
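
The helpers added in this file are small and self-contained. A quick sketch of how a few of them behave, with the definitions copied from the diff above so the snippet runs on its own; the example values are assumptions chosen for illustration, not taken from the package:

import functools
from typing import List

# copied from the tinygrad/helpers.py additions above, for illustration
def ceildiv(num, amt):
  ret = -(num//-amt)
  return ret if not isinstance(ret, float) else int(ret)
def data64_le(data: int): return (data & 0xFFFFFFFF, data >> 32)
def polyN(x, p:List[float]): return functools.reduce(lambda acc,c: acc*x+c, p, 0.0)

assert ceildiv(7, 2) == 4 and ceildiv(-7, 2) == -3  # ceiling division via negated floor division
assert data64_le(0x1_0000_0002) == (2, 1)           # low 32 bits first, then high 32 bits
assert polyN(2.0, [1.0, 0.0, -3.0]) == 1.0          # Horner's rule: 1*2**2 + 0*2 - 3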
tinygrad/multi.py CHANGED
@@ -1,64 +1,60 @@
  from __future__ import annotations
- from typing import Optional, Union, Any, Tuple, List
+ from typing import Optional, Tuple, List, Dict
  import functools, itertools, operator
- from tinygrad.helpers import all_same, all_int, dedup, round_up, prod, DEBUG, RING
- from tinygrad.dtype import DType, ConstType
- from tinygrad.ops import BinaryOps, LoadOps, UnaryOps, TernaryOps, ReduceOps
- from tinygrad.lazy import LazyBuffer
+ from tinygrad.helpers import all_same, all_int, dedup, prod, DEBUG, RING, getenv
+ from tinygrad.dtype import DType
+ from tinygrad.ops import Ops, MathTrait
+ from tinygrad.engine.lazy import LazyBuffer
  from tinygrad.shape.shapetracker import sint

- def all_reduce(op: ReduceOps, lbs: List[LazyBuffer]) -> List[LazyBuffer]:
+ def all_reduce(bop: Ops, lbs: List[LazyBuffer]) -> List[LazyBuffer]:
    assert all_int(lbs[0].shape), f"does not support symbolic shape {lbs[0].shape}"
    assert all_same([lb.shape[0] for lb in lbs]), "allreduce with uneven shards is undefined"
-   bop = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[op]
-
-   n_lbs, dim = len(lbs), prod(lbs[0].shape)
-   # Ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
-   # so just fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
-   use_ring = (RING >= 2 or (n_lbs > 2 and dim > 256_000 and RING >= 1))
-   if DEBUG >= 2: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{dim} | {lbs[0].dtype}")
-   if not use_ring:
-     return [functools.reduce(lambda x,y: x.e(bop, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]
-   factor = max(f for f in [32, 16, 8, 4, 2, 1] if dim % f == 0)
-   base, left = (dim // factor) // n_lbs, (dim // factor) % n_lbs
-   c_lens = [(base + 1) * factor if i < left else base * factor for i in range(n_lbs)]
+   n_lbs, shape, numel = len(lbs), lbs[0].shape, prod(lbs[0].shape)
+   # ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
+   # fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
+   use_ring = (RING >= 2 or (n_lbs > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
+   if DEBUG >= 2: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{numel} | {lbs[0].dtype}")
+   if not use_ring: return [functools.reduce(lambda x,y: x.alu(bop, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]
+
+   factor = next(f for f in [32, 16, 8, 4, 2, 1] if numel % f == 0)
+   base, left = (numel // factor) // n_lbs, (numel // factor) % n_lbs
+   chunk_sizes = [(base + 1) * factor] * left + [base * factor] * (n_lbs - left)
    acc = 0
-   chunks = [(acc, (acc := acc + i)) for i in c_lens if i > 0]
-   chunked = [[lb.reshape((dim,)).shrink(((s,e),)) for s,e in chunks] for lb in lbs]
+   chunks = [(acc, (acc := acc + i)) for i in chunk_sizes if i > 0]
+   chunked = [[lb.reshape((numel,)).shrink(((s,e),)) for s,e in chunks] for lb in lbs]

-   # Scatter-reduce step
-   for step in range(n_lbs - 1):
+   # scatter-reduce
+   for step in range(n_lbs-1):
      for i in range(len(chunks)):
-       s, r = (i+step)%n_lbs, (i+step+1)%n_lbs
-       chunked[r][i] = chunked[r][i].e(bop, chunked[s][i].copy_to_device(chunked[r][i].device, force=True))
+       src, dest = (i+step)%n_lbs, (i+step+1)%n_lbs
+       chunked[dest][i] = chunked[dest][i].alu(bop, chunked[src][i].copy_to_device(chunked[dest][i].device, force=True))

-   # Allgather step
-   for step in range(n_lbs - 1):
+   # allgather
+   for step in range(n_lbs-1):
      for i in range(len(chunks)):
-       s, r = (i+step-1)%n_lbs, (i+step)%n_lbs
-       chunked[r][i] = chunked[s][i].copy_to_device(chunked[r][i].device, force=True)
+       src, dest = (i+step-1)%n_lbs, (i+step)%n_lbs
+       chunked[dest][i] = chunked[src][i].copy_to_device(chunked[dest][i].device, force=True)

-   # Assemble chunks back
-   pads = [((s,dim-e),) for s,e in chunks]
-   return [functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), [c.pad(pads[i]) for i,c in enumerate(lb_c)]).reshape(lbs[0].shape) for lb_c in chunked]
+   # assemble chunks back
+   pads = [((s,numel-e),) for s,e in chunks]
+   return [functools.reduce(operator.add, [c.pad(pad) for pad,c in zip(pads,lb_c)]).reshape(shape) for lb_c in chunked]

- def to_sharded(lbs:List[LazyBuffer], axis:int) -> List[LazyBuffer]:
-   if DEBUG >= 3 and lbs[0].shape[axis] % len(lbs) != 0: print(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}")
-   sz = round_up(lbs[0].shape[axis], len(lbs)) // len(lbs)
-   return [lb.shrink(tuple((0,s) if a != axis else (min(s,sz*i),min(s,sz*(i+1))) for a,s in enumerate(lb.shape))) for i,lb in enumerate(lbs)]
+ def to_sharded(lbs:List[LazyBuffer], axis:int, bounds: Tuple[Tuple[int, int], ...]) -> List[LazyBuffer]:
+   if DEBUG >= 3 and lbs[0].shape[axis] % len(lbs) != 0: print(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}, bounds={bounds}")
+   return [lb.shrink(tuple((0,s) if a != axis else bound for a,s in enumerate(lb.shape))) for i, (bound, lb) in enumerate(zip(bounds, lbs))]

- class MultiLazyBuffer:
+ class MultiLazyBuffer(MathTrait):
    def __init__(self, lbs:List[LazyBuffer], axis:Optional[int], real:Optional[List[bool]]=None):
      assert all(isinstance(x, LazyBuffer) for x in lbs) and len(lbs), "all lbs must be LazyBuffers, and we need at least one of them"
      assert all_same([x.dtype for x in lbs]), f"all multilazybuffer needs same dtype, getting {[x.dtype for x in lbs]}"
      self.lbs, self.axis, self.dtype, self.device, self.real = lbs, axis, lbs[0].dtype, tuple(x.device for x in lbs), real or [True]*len(lbs)
      if axis is not None:
        splits = list(itertools.accumulate([lb.shape[axis] for lb in lbs], initial=0))
-       self.bounds = list(zip(splits, splits[1:]))
+       self.bounds = tuple(zip(splits, splits[1:]))

    @property
-   def shape(self):
-     return tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape))
+   def shape(self): return tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape))

    @property
    def size(self): return sum(x.size for x in self.real_lbs)
@@ -66,108 +62,116 @@ class MultiLazyBuffer:
    @property
    def real_lbs(self): return [lb for lb,r in zip(self.lbs, self.real) if r]

-   def __repr__(self):
-     return f"<MLB {self.axis=} {self.real=} {chr(10)}{chr(10).join([f'{x.device} {x.st}' for x in self.lbs])}>"
+   def __repr__(self): return f"<MLB {self.axis=} {self.real=} {chr(10)}{chr(10).join([f'{x.device} {x.st}' for x in self.lbs])}>"

    @staticmethod
-   def from_sharded(lb:LazyBuffer, devices:Tuple[str, ...], axis:Optional[int]=None):
-     lbs = [lb.contiguous() if lb.base != lb and not lb.is_unrealized_unmasked_const() else lb] * len(devices)
-     sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(to_sharded(lbs, axis) if axis is not None else lbs, devices)]
+   def from_sharded(lb:LazyBuffer, devices:Tuple[str, ...], axis:Optional[int], bounds:Optional[Tuple[Tuple[int, int], ...]]):
+     assert (axis is None) == (bounds is None), "must specify bounds iff axis is specified"
+     lbs = [lb] * len(devices)
+     sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(to_sharded(lbs, axis, bounds) if axis is not None and bounds is not None else lbs, devices)]
      return MultiLazyBuffer([lb if lb.is_unrealized_unmasked_const() else lb.contiguous(allow_buffer_view=False) for lb in sharded_lbs], axis)

    def copy_to_device(self, device:str) -> LazyBuffer:
      if self.axis is None:
        # if we already have a copy on the device, return that
-       for lb in self.real_lbs:
-         if lb.device == device: return lb
-       return self.lbs[self.real.index(True)].copy_to_device(device)
+       return next((lb for lb in self.real_lbs if lb.device == device), self.real_lbs[0].copy_to_device(device))
+     # copy lbs to device, pad to final shape, and sum
      llbs:List[LazyBuffer] = []
      for lb,real,(start,end) in zip(self.lbs, self.real, self.bounds):
        if not real: continue
        pad_arg = tuple((0,0) if a != self.axis else (start, self.bounds[-1][1]-end) for a in range(len(lb.shape)))
        llbs.append(lb.copy_to_device(device).pad(pad_arg))
-     return functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), llbs)
+     return functools.reduce(operator.add, llbs)

    # passthroughs
-   def is_realized(self) -> bool: return all(lb.base.realized is not None for lb, r in zip(self.lbs, self.real) if r is True)
-   def cast(self, dtype:DType, bitcast:bool=False): return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis, self.real)
-   def const(self, val:ConstType) -> MultiLazyBuffer: return MultiLazyBuffer([x.const(val) for x in self.lbs], self.axis, self.real)
+   @property
+   def is_realized(self) -> bool: return all(lb.base.realized is not None for lb in self.real_lbs)
+   def cast(self, dtype:DType, bitcast:bool=False, allow_buffer_view=True):
+     return MultiLazyBuffer([x.cast(dtype, bitcast, allow_buffer_view) for x in self.lbs], self.axis, self.real)
+   def const_like(self, b) -> MultiLazyBuffer: return MultiLazyBuffer([x.const_like(b) for x in self.lbs], self.axis, self.real)
    def assign(self, x:MultiLazyBuffer): return MultiLazyBuffer([s.assign(d) for s,d in zip(self.lbs, x.lbs)], self.axis, self.real)
    def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis, self.real)
+   def clone(self) -> MultiLazyBuffer: return MultiLazyBuffer([lb.clone() for lb in self.lbs], self.axis, self.real)

    # elementwise is simple
-   def e(self, op:Union[LoadOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:MultiLazyBuffer, arg:Optional[Any]=None) -> MultiLazyBuffer:
+   def alu(self, op:Ops, *in_srcs:MultiLazyBuffer) -> MultiLazyBuffer:
      msrcs = (self,)+in_srcs
      assert all(isinstance(x, MultiLazyBuffer) for x in msrcs), f"all buffers must be MultiLazyBuffer {msrcs}"
      assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}"

      # NOTE: they all have to share an axis, we always choose [-1]
-     axis = axes[-1] if len(axes := dedup([x.axis for x in msrcs if x.axis is not None])) else None
-     srcs = []
-     not_all_real = any(not all(mlb.real) for mlb in msrcs)
+     axis, bounds = axes[-1] if len(axes := dedup([(x.axis, x.bounds) for x in msrcs if x.axis is not None])) else (None, None)
+     srcs:List[List[LazyBuffer]] = []
+     not_all_real = not all(all(mlb.real) for mlb in msrcs)
      new_real = [all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])] if not_all_real else self.real
      assert any(new_real), "output contains no real lb"
      for mlb in msrcs:
-       if mlb.axis == axis or not_all_real: srcs.append(mlb.lbs)
-       elif mlb.axis is None and axis is not None: srcs.append(to_sharded(mlb.lbs, axis))
-       else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis))
-     # NOTE: lsrcs[-1].const(0) is correct for where
-     return MultiLazyBuffer([lsrcs[0].e(op, *lsrcs[1:], arg=arg) if r else lsrcs[-1].const(0) for lsrcs,r in zip(zip(*srcs),new_real)], axis, new_real)
-
-   def _shape_to_single_shard(self, shape:Tuple[sint, ...], lb:LazyBuffer) -> Tuple[sint, ...]:
-     return tuple(lb.shape[self.axis] if a == self.axis else s for a,s in enumerate(shape))
-
-   def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> MultiLazyBuffer:
+       if (mlb.axis == axis and (mlb.axis is None or mlb.bounds == bounds)) or not_all_real: srcs.append(mlb.lbs)
+       elif mlb.axis is None and axis is not None: srcs.append(to_sharded(mlb.lbs, axis, bounds))
+       else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis, bounds))
+     new_real_lbs:Dict[int,LazyBuffer] = {i:lsrcs[0].alu(op, *lsrcs[1:]) for i,(lsrcs,r) in enumerate(zip(zip(*srcs), new_real)) if r}
+     # NOTE: const dtype should match real
+     new_dtype = next(iter(new_real_lbs.values())).dtype
+     return MultiLazyBuffer([new_real_lbs.get(i, lsrcs[0].const_like(0).cast(new_dtype)) for i,lsrcs in enumerate(zip(*srcs))], axis, new_real)
+
+   def r(self, op:Ops, axis:Tuple[int, ...]) -> MultiLazyBuffer:
      if self.axis is not None and self.axis in axis:
        # all-reduce on sharded axes
-       reduced_parts = [(x if r else x.const(0)).r(op, axis) for x,r in zip(self.lbs, self.real)]
+       reduced_parts = [(x if r else x.const_like(0)).r(op, axis) for x,r in zip(self.lbs, self.real)]
+       # if all partitions are real, do all_reduce
        if all(self.real): return MultiLazyBuffer(all_reduce(op, reduced_parts), None)
+       # only one partition is real, keep it
        return MultiLazyBuffer(reduced_parts, None, self.real)
      # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
      return MultiLazyBuffer([x.r(op, axis) for x in self.lbs], self.axis, self.real)

    # *** movement ops ***

+   def _shape_to_single_shard(self, shape:Tuple[sint, ...], lb:LazyBuffer) -> Tuple[sint, ...]:
+     return tuple(lb.shape[self.axis] if a == self.axis else s for a,s in enumerate(shape))
+
    def reshape(self, arg:Tuple[sint, ...]):
      if self.axis is None: return MultiLazyBuffer([x.reshape(arg) for x in self.lbs], None, self.real)
+     assert prod(self.shape) == prod(arg), "reshape must maintain prod(shape)"
      arg_acc:List[sint] = list(itertools.accumulate(arg, operator.mul, initial=1))
      # new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards
      # todo: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1?
      new_axis = len(arg_acc) - arg_acc[::-1].index(prod(self.shape[:self.axis])) - 1
-     if arg[new_axis] != self.shape[self.axis]:
-       assert self.shape[self.axis] % len(self.real_lbs) == 0, f"cannot reshape on-axis for uneven shard {self.axis} {self.shape} {len(self.real_lbs)}"
-       assert arg[new_axis] % len(self.real_lbs) == 0, f"new on-axis shape must divide evenly between devices {new_axis} {arg} {len(self.real_lbs)}"
-     return MultiLazyBuffer([x.reshape(tuple(s if a != new_axis else
-                                             x.shape[self.axis] if s == self.shape[self.axis] else
-                                             s // len(self.real_lbs) for a,s in enumerate(arg))) for x in self.lbs],
-                            new_axis, self.real)
+     assert all(prod(lb.shape[self.axis:])%prod(arg[new_axis+1:])==0 for lb in self.lbs), f"reshape cannot move items between shards {self=} {arg=}"
+     lbs = [x.reshape(tuple(s if a!=new_axis else prod(x.shape[self.axis:])//prod(arg[new_axis+1:]) for a,s in enumerate(arg))) for x in self.lbs]
+     return MultiLazyBuffer(lbs, new_axis, self.real)

    def pad(self, arg:Tuple[Tuple[sint, sint], ...]):
      assert self.axis is None or arg[self.axis] == (0,0) or not all(self.real), f"padding not supported for {arg=}"
      # pad on shard axis -> fill others with zeros and set real to all True
      if self.axis is not None and arg[self.axis] != (0,0):
        # pad back to whole axis, remove real mask
-       assert all(arg[i] == (0, 0) or i == self.axis for i in range(len(self.shape))), "cannot pad sharded and non-sharded axis at the same time"
-       assert arg[self.axis] == (sum(lb.shape[self.axis] for i,lb in enumerate(self.lbs) if i < self.real.index(True)), \
-         sum(lb.shape[self.axis] for i,lb in enumerate(self.lbs) if i > self.real.index(True))), "can only pad to whole axis"
-       return MultiLazyBuffer([x if r else x.const(0) for x,r in zip(self.lbs, self.real)], self.axis)
+       assert all(arg[i] == (0, 0) for i in range(len(self.shape)) if i != self.axis), "cannot pad sharded and non-sharded axis at the same time"
+       dim, bound = sum(lb.shape[self.axis] for lb in self.lbs), self.bounds[self.real.index(True)]
+       assert arg[self.axis] == (bound[0], dim-bound[1]), "can only pad to whole axis"
+       return MultiLazyBuffer([x if r else x.const_like(0) for x,r in zip(self.lbs, self.real)], self.axis)
      return MultiLazyBuffer([x.pad(arg) for x in self.lbs], self.axis, self.real)
+
    def expand(self, arg:Tuple[sint, ...]):
      # NOTE: this assert isn't needed, sharded axis can have dim 1
      assert self.axis is None or arg[self.axis] == self.shape[self.axis], f"expand not supported on sharded axis {arg=}"
      return MultiLazyBuffer([x.expand(self._shape_to_single_shard(arg, x)) for x in self.lbs], self.axis, self.real)
+
    def permute(self, arg:Tuple[int, ...]):
      # all permutes supported!
      return MultiLazyBuffer([x.permute(arg) for x in self.lbs], arg.index(self.axis) if self.axis is not None else None, self.real)
+
    def shrink(self, arg:Tuple[Tuple[sint, sint], ...]):
      assert self.axis is None or arg[self.axis] == (0, self.shape[self.axis]) or arg[self.axis] in self.bounds, f"shrinking not supported for {arg=}"
      if self.axis is not None and arg[self.axis] in self.bounds and arg[self.axis] != (0, self.shape[self.axis]):
        assert all(arg[i] == (0, s) or i == self.axis for i,s in enumerate(self.shape)), "cannot shrink sharded and non-sharded axis at the same time"
+       # NOTE: shrink on the shard axis is only allowed when result is a single partition, denoted by the new real
        idx = self.bounds.index(arg[self.axis])
        # zero out other lbs to not create lb reference
-       return MultiLazyBuffer([lb if i==idx else lb.const(0) for i,lb in enumerate(self.lbs)], self.axis, [i==idx for i in range(len(self.lbs))])
+       return MultiLazyBuffer([lb if i==idx else lb.const_like(0) for i,lb in enumerate(self.lbs)], self.axis, [i==idx for i in range(len(self.lbs))])
      return MultiLazyBuffer([x.shrink(tuple((0, x.shape[self.axis]) if a == self.axis else s for a,s in enumerate(arg))) for x in self.lbs],
                             self.axis, self.real)
+
    def stride(self, arg:Tuple[int, ...]):
      assert self.axis is None or arg[self.axis] == 1, "flipping not supported on sharded axis"
      return MultiLazyBuffer([x.stride(arg) for x in self.lbs], self.axis, self.real)
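
For reference, the chunking arithmetic the rewritten all_reduce uses to split a flattened buffer into per-device ring chunks can be exercised on plain integers, without any tinygrad objects. A small sketch, where the 10-element / 3-shard numbers are made-up values for illustration:

def ring_chunks(numel: int, n_lbs: int):
  # largest factor of the element count from [32,16,8,4,2,1], as in all_reduce above
  factor = next(f for f in [32, 16, 8, 4, 2, 1] if numel % f == 0)
  base, left = (numel // factor) // n_lbs, (numel // factor) % n_lbs
  chunk_sizes = [(base + 1) * factor] * left + [base * factor] * (n_lbs - left)
  acc = 0
  # (start, end) bounds of each chunk in the flattened buffer
  return [(acc, (acc := acc + i)) for i in chunk_sizes if i > 0]

# 10 elements over 3 shards: factor=2, base=1, left=2 -> chunk sizes [4, 4, 2]
assert ring_chunks(10, 3) == [(0, 4), (4, 8), (8, 10)]

Each (start, end) range is shrunk out of every shard, reduced around the ring in the scatter-reduce loop, redistributed in the allgather loop, and finally padded back to the full length and summed.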