tinygrad 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries; it is provided for informational purposes only.
- tinygrad/__init__.py +11 -6
- tinygrad/codegen/kernel.py +308 -175
- tinygrad/codegen/linearize.py +95 -0
- tinygrad/codegen/lowerer.py +143 -0
- tinygrad/codegen/transcendental.py +257 -0
- tinygrad/codegen/uopgraph.py +506 -0
- tinygrad/device.py +72 -171
- tinygrad/dtype.py +122 -47
- tinygrad/engine/jit.py +184 -87
- tinygrad/{lazy.py → engine/lazy.py} +74 -66
- tinygrad/engine/memory.py +51 -0
- tinygrad/engine/realize.py +86 -61
- tinygrad/engine/schedule.py +366 -317
- tinygrad/engine/search.py +58 -47
- tinygrad/function.py +59 -58
- tinygrad/helpers.py +120 -102
- tinygrad/multi.py +82 -78
- tinygrad/nn/__init__.py +116 -67
- tinygrad/nn/datasets.py +12 -5
- tinygrad/nn/optim.py +1 -1
- tinygrad/nn/state.py +91 -6
- tinygrad/ops.py +1126 -143
- tinygrad/renderer/__init__.py +47 -23
- tinygrad/renderer/cstyle.py +338 -265
- tinygrad/renderer/llvmir.py +125 -143
- tinygrad/renderer/ptx.py +225 -0
- tinygrad/runtime/autogen/adreno.py +17904 -0
- tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
- tinygrad/runtime/autogen/cuda.py +6 -162
- tinygrad/runtime/autogen/io_uring.py +97 -63
- tinygrad/runtime/autogen/kfd.py +60 -47
- tinygrad/runtime/autogen/kgsl.py +1386 -0
- tinygrad/runtime/autogen/libc.py +5462 -0
- tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
- tinygrad/runtime/autogen/nvrtc.py +579 -0
- tinygrad/runtime/autogen/opencl.py +11 -11
- tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
- tinygrad/runtime/graph/clang.py +3 -3
- tinygrad/runtime/graph/cuda.py +11 -15
- tinygrad/runtime/graph/hcq.py +120 -107
- tinygrad/runtime/graph/metal.py +71 -43
- tinygrad/runtime/ops_amd.py +244 -323
- tinygrad/runtime/ops_clang.py +12 -5
- tinygrad/runtime/ops_cloud.py +220 -0
- tinygrad/runtime/ops_cuda.py +42 -99
- tinygrad/runtime/ops_disk.py +25 -26
- tinygrad/runtime/ops_dsp.py +181 -0
- tinygrad/runtime/ops_gpu.py +29 -16
- tinygrad/runtime/ops_hip.py +68 -0
- tinygrad/runtime/ops_llvm.py +15 -10
- tinygrad/runtime/ops_metal.py +147 -64
- tinygrad/runtime/ops_nv.py +356 -397
- tinygrad/runtime/ops_python.py +78 -79
- tinygrad/runtime/ops_qcom.py +405 -0
- tinygrad/runtime/support/__init__.py +0 -0
- tinygrad/runtime/support/compiler_cuda.py +77 -0
- tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
- tinygrad/runtime/support/elf.py +38 -0
- tinygrad/runtime/support/hcq.py +539 -0
- tinygrad/shape/shapetracker.py +40 -50
- tinygrad/shape/view.py +102 -63
- tinygrad/tensor.py +1109 -365
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
- tinygrad-0.10.0.dist-info/RECORD +77 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/linearizer.py +0 -528
- tinygrad/codegen/uops.py +0 -451
- tinygrad/engine/graph.py +0 -100
- tinygrad/renderer/assembly.py +0 -269
- tinygrad/shape/symbolic.py +0 -327
- tinygrad-0.9.1.dist-info/RECORD +0 -63
- /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
tinygrad/helpers.py
CHANGED
@@ -1,10 +1,8 @@
 from __future__ import annotations
-import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3,
-import
-from
-
-from typing_extensions import TypeGuard
-from tinygrad.shape.shapetracker import sint
+import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip
+import urllib.request, subprocess, shutil, math, contextvars, types, copyreg, inspect, importlib
+from dataclasses import dataclass
+from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, Callable, Sequence, TypeGuard
 
 T = TypeVar("T")
 U = TypeVar("U")
@@ -15,6 +13,9 @@ def prod(x:Iterable[T]) -> Union[T,int]: return functools.reduce(operator.mul, x
 OSX = platform.system() == "Darwin"
 CI = os.getenv("CI", "") != ""
 
+# fix colors on Windows, https://stackoverflow.com/questions/12492810/python-how-can-i-make-the-ansi-escape-codes-to-work-also-in-windows
+if sys.platform == "win32": os.system("")
+
 def dedup(x:Iterable[T]): return list(dict.fromkeys(x)) # retains list order
 def argfix(*x):
   if x and x[0].__class__ in (tuple, list):
@@ -22,51 +23,53 @@ def argfix(*x):
     return tuple(x[0])
   return x
 def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
-def all_same(items:List[T]): return all(x == items[0] for x in items)
+def all_same(items:Union[Tuple[T, ...], List[T]]): return all(x == items[0] for x in items)
 def all_int(t: Sequence[Any]) -> TypeGuard[Tuple[int, ...]]: return all(isinstance(s, int) for s in t)
 def colored(st, color:Optional[str], background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line # noqa: E501
+def colorize_float(x: float): return colored(f"{x:7.2f}x", 'green' if x < 0.75 else 'red' if x > 1.15 else 'yellow')
+def memsize_to_str(_bytes: int) -> str: return [f"{(_bytes / d):.2f} {pr}" for d,pr in [(1e9,"GB"),(1e6,"MB"),(1e3,"KB"),(1,"B")] if _bytes > d][0]
 def ansistrip(s:str): return re.sub('\x1b\\[(K|.*?m)', '', s)
 def ansilen(s:str): return len(ansistrip(s))
-def
+def make_tuple(x:Union[int, Sequence[int]], cnt:int) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else tuple(x)
 def flatten(l:Iterable[Iterable[T]]): return [item for sublist in l for item in sublist]
-def fully_flatten(l):
+def fully_flatten(l):
+  if hasattr(l, "__len__") and hasattr(l, "__getitem__") and not isinstance(l, str):
+    flattened = []
+    if hasattr(l, "shape") and l.shape == (): flattened.append(l[()])
+    else:
+      for i in range(len(l)): flattened.extend(fully_flatten(l[i]))
+    return flattened
+  return [l]
 def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
 def strip_parens(fst:str): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' and fst[1:-1].find('(') <= fst[1:-1].find(')') else fst
-def
+def ceildiv(num, amt):
+  ret = -(num//-amt)
+  return ret if not isinstance(ret, float) else int(ret)
+def round_up(num:int, amt:int) -> int: return (num+amt-1)//amt * amt
+def data64(data: int) -> Tuple[int, int]: return (data >> 32, data & 0xFFFFFFFF)
+def data64_le(data: int) -> Tuple[int, int]: return (data & 0xFFFFFFFF, data >> 32)
 def merge_dicts(ds:Iterable[Dict[T,U]]) -> Dict[T,U]:
-
+  kvs = set([(k,v) for d in ds for k,v in d.items()])
+  assert len(kvs) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
   return {k:v for d in ds for k,v in d.items()}
-def partition(
+def partition(itr:Iterable[T], fxn:Callable[[T],bool]) -> Tuple[List[T], List[T]]:
   a:List[T] = []
   b:List[T] = []
-  for s in
+  for s in itr: (a if fxn(s) else b).append(s)
   return a,b
 def unwrap(x:Optional[T]) -> T:
   assert x is not None
   return x
-def unwrap2(x:Tuple[T,Any]) -> T:
-  ret, err = x
-  assert err is None, str(err)
-  return ret
 def get_child(obj, key):
   for k in key.split('.'):
     if k.isnumeric(): obj = obj[int(k)]
     elif isinstance(obj, dict): obj = obj[k]
     else: obj = getattr(obj, k)
   return obj
+def word_wrap(x, wrap=80): return x if len(x) <= wrap else (x[0:wrap] + "\n" + word_wrap(x[wrap:], wrap))
 
-
-
-  subs = [get_shape(xi) for xi in x]
-  if not all_same(subs): raise ValueError(f"inhomogeneous shape from {x}")
-  return (len(subs),) + (subs[0] if subs else ())
-
-# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
-def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:
-  acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
-  try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new]
-  except ValueError: return None
-  return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]
+# for length N coefficients `p`, returns p[0] * x**(N-1) + p[1] * x**(N-2) + ... + p[-2] * x + p[-1]
+def polyN(x:T, p:List[float]) -> T: return functools.reduce(lambda acc,c: acc*x+c, p, 0.0) # type: ignore
 
 @functools.lru_cache(maxsize=None)
 def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+string.digits+'_') else f'{ord(c):02X}' for c in ansistrip(s)])
@@ -74,8 +77,6 @@ def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+str
 def getenv(key:str, default=0): return type(default)(os.getenv(key, default))
 def temp(x:str) -> str: return (pathlib.Path(tempfile.gettempdir()) / x).as_posix()
 
-class GraphException(Exception): pass
-
 class Context(contextlib.ContextDecorator):
   stack: ClassVar[List[dict[str, int]]] = [{}]
   def __init__(self, **kwargs): self.kwargs = kwargs
@@ -90,20 +91,31 @@ class ContextVar:
   _cache: ClassVar[Dict[str, ContextVar]] = {}
   value: int
   key: str
-  def
-
-
-
-    return instance
+  def __init__(self, key, default_value):
+    assert key not in ContextVar._cache, f"attempt to recreate ContextVar {key}"
+    ContextVar._cache[key] = self
+    self.value, self.key = getenv(key, default_value), key
   def __bool__(self): return bool(self.value)
   def __ge__(self, x): return self.value >= x
   def __gt__(self, x): return self.value > x
   def __lt__(self, x): return self.value < x
 
 DEBUG, IMAGE, BEAM, NOOPT, JIT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0), ContextVar("JIT", 1)
-WINO,
-
-
+WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1)
+PROFILE, PROFILEPATH = ContextVar("PROFILE", 0), ContextVar("PROFILEPATH", temp("tinygrad_profile.json"))
+USE_TC, TC_OPT, AMX, TRANSCENDENTAL = ContextVar("TC", 1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0), ContextVar("TRANSCENDENTAL", 1)
+FUSE_ARANGE, FUSE_CONV_BW, LAZYCACHE = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0), ContextVar("LAZYCACHE", 1)
+SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1)
+
+@dataclass(frozen=True)
+class Metadata:
+  name: str
+  caller: str
+  backward: bool = False
+  def __hash__(self): return hash(self.name)
+  def __repr__(self): return str(self) + (f" - {self.caller}" if self.caller else "")
+  def __str__(self): return self.name + (" bw" if self.backward else "")
+_METADATA: contextvars.ContextVar[Optional[Metadata]] = contextvars.ContextVar("_METADATA", default=None)
 
 # **************** global state Counters ****************
 
@@ -130,47 +142,21 @@ class Profiling(contextlib.ContextDecorator):
   def __init__(self, enabled=True, sort='cumtime', frac=0.2, fn=None, ts=1):
     self.enabled, self.sort, self.frac, self.fn, self.time_scale = enabled, sort, frac, fn, 1e3/ts
   def __enter__(self):
+    import cProfile
     self.pr = cProfile.Profile()
     if self.enabled: self.pr.enable()
   def __exit__(self, *exc):
     if self.enabled:
       self.pr.disable()
       if self.fn: self.pr.dump_stats(self.fn)
+      import pstats
       stats = pstats.Stats(self.pr).strip_dirs().sort_stats(self.sort)
       for fcn in stats.fcn_list[0:int(len(stats.fcn_list)*self.frac)]: # type: ignore[attr-defined]
         (_primitive_calls, num_calls, tottime, cumtime, callers) = stats.stats[fcn] # type: ignore[attr-defined]
         scallers = sorted(callers.items(), key=lambda x: -x[1][2])
         print(f"n:{num_calls:8d} tm:{tottime*self.time_scale:7.2f}ms tot:{cumtime*self.time_scale:7.2f}ms",
-              colored(_format_fcn(fcn), "yellow")
-              colored(f"<- {(scallers[0][1][2]/tottime)*100:3.0f}% {_format_fcn(scallers[0][0])}", "BLACK") if
-
-class ProfileLogger:
-  writers: int = 0
-  mjson: List[Dict] = []
-  actors: Dict[str, int] = {}
-  subactors: Dict[Tuple[str, str], int] = {}
-  path = getenv("PROFILE_OUTPUT_FILE", temp("tinygrad_profile.json"))
-
-  def __init__(self): self.events, ProfileLogger.writers = [], ProfileLogger.writers + 1
-
-  def add_event(self, ev_name, ev_start, ev_end, actor, subactor=None): self.events += [(ev_name, ev_start, ev_end, actor, subactor)]
-
-  def __del__(self):
-    for name,st,et,actor_name,subactor_name in self.events:
-      if actor_name not in self.actors:
-        self.actors[actor_name] = (pid:=len(self.actors))
-        self.mjson.append({"name": "process_name", "ph": "M", "pid": pid, "args": {"name": actor_name}})
-
-      if (subactor_key:=(actor_name,subactor_name)) not in self.subactors:
-        self.subactors[subactor_key] = (tid:=len(self.subactors))
-        self.mjson.append({"name": "thread_name", "ph": "M", "pid": self.actors[actor_name], "tid":tid, "args": {"name": subactor_name}})
-
-      self.mjson.append({"name": name, "ph": "X", "pid": self.actors[actor_name], "tid": self.subactors.get(subactor_key, -1), "ts":st, "dur":et-st})
-
-    ProfileLogger.writers -= 1
-    if ProfileLogger.writers == 0:
-      with open(self.path, "w") as f: f.write(json.dumps({"traceEvents": self.mjson}))
-      print(f"Saved profile to {self.path}. Use https://ui.perfetto.dev/ to open it.")
+              colored(_format_fcn(fcn).ljust(50), "yellow"),
+              colored(f"<- {(scallers[0][1][2]/tottime)*100:3.0f}% {_format_fcn(scallers[0][0])}", "BLACK") if scallers else '')
 
 # *** universal database cache ***
 
@@ -184,14 +170,17 @@ def db_connection():
   global _db_connection
   if _db_connection is None:
     os.makedirs(CACHEDB.rsplit(os.sep, 1)[0], exist_ok=True)
-    _db_connection = sqlite3.connect(CACHEDB)
+    _db_connection = sqlite3.connect(CACHEDB, timeout=60, isolation_level="IMMEDIATE")
+    # another connection has set it already or is in the process of setting it
+    # that connection will lock the database
+    with contextlib.suppress(sqlite3.OperationalError): _db_connection.execute("PRAGMA journal_mode=WAL").fetchone()
     if DEBUG >= 7: _db_connection.set_trace_callback(print)
   return _db_connection
 
 def diskcache_clear():
   cur = db_connection().cursor()
   drop_tables = cur.execute("SELECT 'DROP TABLE IF EXISTS ' || quote(name) || ';' FROM sqlite_master WHERE type = 'table';").fetchall()
-  cur.executescript("\n".join([s[0] for s in drop_tables]))
+  cur.executescript("\n".join([s[0] for s in drop_tables] + ["VACUUM;"]))
 
 def diskcache_get(table:str, key:Union[Dict, str, int]) -> Any:
   if CACHELEVEL == 0: return None
@@ -230,22 +219,36 @@ def diskcache(func):
 
 # *** http support ***
 
-def
+def _ensure_downloads_dir() -> pathlib.Path:
+  # if we are on a tinybox, use the raid array
+  if pathlib.Path("/etc/tinybox-release").is_file():
+    # try creating dir with sudo
+    if not (downloads_dir := pathlib.Path("/raid/downloads")).exists():
+      subprocess.run(["sudo", "mkdir", "-p", downloads_dir], check=True)
+      subprocess.run(["sudo", "chown", "tiny:root", downloads_dir], check=True)
+      subprocess.run(["sudo", "chmod", "775", downloads_dir], check=True)
+    return downloads_dir
+  return pathlib.Path(_cache_dir) / "tinygrad" / "downloads"
+
+def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, subdir:Optional[str]=None, gunzip:bool=False,
           allow_caching=not getenv("DISABLE_HTTP_CACHE")) -> pathlib.Path:
   if url.startswith(("/", ".")): return pathlib.Path(url)
   if name is not None and (isinstance(name, pathlib.Path) or '/' in name): fp = pathlib.Path(name)
-  else:
+  else:
+    fp = _ensure_downloads_dir() / (subdir or "") / \
+      ((name or hashlib.md5(url.encode('utf-8')).hexdigest()) + (".gunzip" if gunzip else ""))
   if not fp.is_file() or not allow_caching:
     with urllib.request.urlopen(url, timeout=10) as r:
       assert r.status == 200
-
-      progress_bar = tqdm(total=
+      length = int(r.headers.get('content-length', 0)) if not gunzip else None
+      progress_bar = tqdm(total=length, unit='B', unit_scale=True, desc=f"{url}", disable=CI)
       (path := fp.parent).mkdir(parents=True, exist_ok=True)
+      readfile = gzip.GzipFile(fileobj=r) if gunzip else r
      with tempfile.NamedTemporaryFile(dir=path, delete=False) as f:
-        while chunk :=
+        while chunk := readfile.read(16384): progress_bar.update(f.write(chunk))
         f.close()
         progress_bar.update(close=True)
-        if (file_size:=os.stat(f.name).st_size) <
+        if length and (file_size:=os.stat(f.name).st_size) < length: raise RuntimeError(f"fetch size incomplete, {file_size} < {length}")
         pathlib.Path(f.name).rename(fp)
   return fp
 
@@ -256,10 +259,10 @@ def cpu_time_execution(cb, enable):
   cb()
   if enable: return time.perf_counter()-st
 
-def cpu_objdump(lib):
+def cpu_objdump(lib, objdump_tool='objdump'):
   with tempfile.NamedTemporaryFile(delete=True) as f:
     pathlib.Path(f.name).write_bytes(lib)
-    print(subprocess.check_output([
+    print(subprocess.check_output([objdump_tool, '-d', f.name]).decode('utf-8'))
 
 # *** ctypes helpers
 
@@ -277,34 +280,49 @@ def init_c_struct_t(fields: Tuple[Tuple[str, ctypes._SimpleCData], ...]):
 def init_c_var(ctypes_var, creat_cb): return (creat_cb(ctypes_var), ctypes_var)[1]
 def flat_mv(mv:memoryview): return mv if len(mv) == 0 else mv.cast("B", shape=(mv.nbytes,))
 
+# *** tqdm
+
 class tqdm:
-  def __init__(self, iterable=None, desc:str='', disable:bool=False, unit:str='it', unit_scale=False, total:int
-    self.
-    self.st, self.i, self.n, self.skip, self.t = time.perf_counter(), -1, 0, 1,
+  def __init__(self, iterable=None, desc:str='', disable:bool=False, unit:str='it', unit_scale=False, total:Optional[int]=None, rate:int=100):
+    self.iterable, self.disable, self.unit, self.unit_scale, self.rate = iterable, disable, unit, unit_scale, rate
+    self.st, self.i, self.n, self.skip, self.t = time.perf_counter(), -1, 0, 1, getattr(iterable, "__len__", lambda:0)() if total is None else total
+    self.set_description(desc)
     self.update(0)
   def __iter__(self):
-
-
-
-
-
+    for item in self.iterable:
+      yield item
+      self.update(1)
+    self.update(close=True)
+  def __enter__(self): return self
+  def __exit__(self, *_): self.update(close=True)
   def set_description(self, desc:str): self.desc = f"{desc}: " if desc else ""
   def update(self, n:int=0, close:bool=False):
     self.n, self.i = self.n+n, self.i+1
-    if (self.i % self.skip != 0
-    prog,
-    if self.i/
-    def
-    def
-
-
-
-
-
-
-
-
-
+    if self.disable or (not close and self.i % self.skip != 0): return
+    prog, elapsed, ncols = self.n/self.t if self.t else 0, time.perf_counter()-self.st, shutil.get_terminal_size().columns
+    if self.i/elapsed > self.rate and self.i: self.skip = max(int(self.i/elapsed)//self.rate,1)
+    def HMS(t): return ':'.join(f'{x:02d}' if i else str(x) for i,x in enumerate([int(t)//3600,int(t)%3600//60,int(t)%60]) if i or x)
+    def SI(x): return (f"{x/1000**int(g:=math.log(x,1000)):.{int(3-3*math.fmod(g,1))}f}"[:4].rstrip('.')+' kMGTPEZY'[int(g)].strip()) if x else '0.00'
+    prog_text = f'{SI(self.n)}{f"/{SI(self.t)}" if self.t else self.unit}' if self.unit_scale else f'{self.n}{f"/{self.t}" if self.t else self.unit}'
+    est_text = f'<{HMS(elapsed/prog-elapsed) if self.n else "?"}' if self.t else ''
+    it_text = (SI(self.n/elapsed) if self.unit_scale else f"{self.n/elapsed:5.2f}") if self.n else "?"
+    suf = f'{prog_text} [{HMS(elapsed)}{est_text}, {it_text}{self.unit}/s]'
+    sz = max(ncols-len(self.desc)-3-2-2-len(suf), 1)
+    bar = '\r' + self.desc + (f'{100*prog:3.0f}%|{("█"*int(num:=sz*prog)+" ▏▎▍▌▋▊▉"[int(8*num)%8].strip()).ljust(sz," ")}| ' if self.t else '') + suf
+    print(bar[:ncols+1], flush=True, end='\n'*close, file=sys.stderr)
+  @classmethod
+  def write(cls, s:str): print(f"\r\033[K{s}", flush=True, file=sys.stderr)
 
 class trange(tqdm):
-  def __init__(self, n:int, **kwargs): super().__init__(iterable=range(n), total=n, **kwargs)
+  def __init__(self, n:int, **kwargs): super().__init__(iterable=range(n), total=n, **kwargs)
+
+# *** universal support for code object pickling
+
+def _reconstruct_code(*args): return types.CodeType(*args)
+def _serialize_code(code:types.CodeType):
+  args = inspect.signature(types.CodeType).parameters # NOTE: this works in Python 3.10 and up
+  return _reconstruct_code, tuple(code.__getattribute__('co_'+x.replace('codestring', 'code').replace('constants', 'consts')) for x in args)
+copyreg.pickle(types.CodeType, _serialize_code)
+
+def _serialize_module(module:types.ModuleType): return importlib.import_module, (module.__name__,)
+copyreg.pickle(types.ModuleType, _serialize_module)
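
A few of the helpers added above are easy to sanity-check in isolation. The snippet below is an illustrative sketch, not part of the diff: it assumes tinygrad 0.10.0 is installed and imports the new functions straight from tinygrad.helpers. The expected values follow directly from the one-line definitions in the hunks above (polyN is Horner's rule over the coefficient list, ceildiv is ceiling division written as -(num//-amt)).

# illustrative usage of helpers introduced in 0.10.0 (assumes tinygrad 0.10.0 is installed)
from tinygrad.helpers import polyN, make_tuple, ceildiv, round_up

# polyN evaluates p[0]*x**(N-1) + ... + p[-2]*x + p[-1] via Horner's rule
assert polyN(2.0, [3.0, 0.0, 1.0]) == 13.0                                   # 3*2**2 + 0*2 + 1
# make_tuple broadcasts an int to a cnt-length tuple and passes sequences through
assert make_tuple(3, 4) == (3, 3, 3, 3) and make_tuple((1, 2), 4) == (1, 2)
# ceildiv rounds the quotient up; round_up rounds num up to a multiple of amt
assert ceildiv(7, 2) == 4 and round_up(10, 8) == 16
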
tinygrad/multi.py
CHANGED
@@ -1,64 +1,60 @@
 from __future__ import annotations
-from typing import Optional,
+from typing import Optional, Tuple, List, Dict
 import functools, itertools, operator
-from tinygrad.helpers import all_same, all_int, dedup,
-from tinygrad.dtype import DType
-from tinygrad.ops import
-from tinygrad.lazy import LazyBuffer
+from tinygrad.helpers import all_same, all_int, dedup, prod, DEBUG, RING, getenv
+from tinygrad.dtype import DType
+from tinygrad.ops import Ops, MathTrait
+from tinygrad.engine.lazy import LazyBuffer
 from tinygrad.shape.shapetracker import sint
 
-def all_reduce(
+def all_reduce(bop: Ops, lbs: List[LazyBuffer]) -> List[LazyBuffer]:
   assert all_int(lbs[0].shape), f"does not support symbolic shape {lbs[0].shape}"
   assert all_same([lb.shape[0] for lb in lbs]), "allreduce with uneven shards is undefined"
-
-
-
-
-
-  use_ring
-
-  if
-
-
-  base, left = (dim // factor) // n_lbs, (dim // factor) % n_lbs
-  c_lens = [(base + 1) * factor if i < left else base * factor for i in range(n_lbs)]
+  n_lbs, shape, numel = len(lbs), lbs[0].shape, prod(lbs[0].shape)
+  # ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
+  # fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
+  use_ring = (RING >= 2 or (n_lbs > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
+  if DEBUG >= 2: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{numel} | {lbs[0].dtype}")
+  if not use_ring: return [functools.reduce(lambda x,y: x.alu(bop, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]
+
+  factor = next(f for f in [32, 16, 8, 4, 2, 1] if numel % f == 0)
+  base, left = (numel // factor) // n_lbs, (numel // factor) % n_lbs
+  chunk_sizes = [(base + 1) * factor] * left + [base * factor] * (n_lbs - left)
   acc = 0
-  chunks = [(acc, (acc := acc + i)) for i in
-  chunked = [[lb.reshape((
+  chunks = [(acc, (acc := acc + i)) for i in chunk_sizes if i > 0]
+  chunked = [[lb.reshape((numel,)).shrink(((s,e),)) for s,e in chunks] for lb in lbs]
 
-  #
-  for step in range(n_lbs
+  # scatter-reduce
+  for step in range(n_lbs-1):
     for i in range(len(chunks)):
-
-      chunked[
+      src, dest = (i+step)%n_lbs, (i+step+1)%n_lbs
+      chunked[dest][i] = chunked[dest][i].alu(bop, chunked[src][i].copy_to_device(chunked[dest][i].device, force=True))
 
-  #
-  for step in range(n_lbs
+  # allgather
+  for step in range(n_lbs-1):
     for i in range(len(chunks)):
-
-      chunked[
+      src, dest = (i+step-1)%n_lbs, (i+step)%n_lbs
+      chunked[dest][i] = chunked[src][i].copy_to_device(chunked[dest][i].device, force=True)
 
-  #
-  pads = [((s,
-  return [functools.reduce(
+  # assemble chunks back
+  pads = [((s,numel-e),) for s,e in chunks]
+  return [functools.reduce(operator.add, [c.pad(pad) for pad,c in zip(pads,lb_c)]).reshape(shape) for lb_c in chunked]
 
-def to_sharded(lbs:List[LazyBuffer], axis:int) -> List[LazyBuffer]:
-  if DEBUG >= 3 and lbs[0].shape[axis] % len(lbs) != 0: print(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}")
-
-  return [lb.shrink(tuple((0,s) if a != axis else (min(s,sz*i),min(s,sz*(i+1))) for a,s in enumerate(lb.shape))) for i,lb in enumerate(lbs)]
+def to_sharded(lbs:List[LazyBuffer], axis:int, bounds: Tuple[Tuple[int, int], ...]) -> List[LazyBuffer]:
+  if DEBUG >= 3 and lbs[0].shape[axis] % len(lbs) != 0: print(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}, bounds={bounds}")
+  return [lb.shrink(tuple((0,s) if a != axis else bound for a,s in enumerate(lb.shape))) for i, (bound, lb) in enumerate(zip(bounds, lbs))]
 
-class MultiLazyBuffer:
+class MultiLazyBuffer(MathTrait):
   def __init__(self, lbs:List[LazyBuffer], axis:Optional[int], real:Optional[List[bool]]=None):
     assert all(isinstance(x, LazyBuffer) for x in lbs) and len(lbs), "all lbs must be LazyBuffers, and we need at least one of them"
     assert all_same([x.dtype for x in lbs]), f"all multilazybuffer needs same dtype, getting {[x.dtype for x in lbs]}"
     self.lbs, self.axis, self.dtype, self.device, self.real = lbs, axis, lbs[0].dtype, tuple(x.device for x in lbs), real or [True]*len(lbs)
     if axis is not None:
       splits = list(itertools.accumulate([lb.shape[axis] for lb in lbs], initial=0))
-      self.bounds =
+      self.bounds = tuple(zip(splits, splits[1:]))
 
   @property
-  def shape(self):
-    return tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape))
+  def shape(self): return tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape))
 
   @property
   def size(self): return sum(x.size for x in self.real_lbs)
@@ -66,108 +62,116 @@ class MultiLazyBuffer:
   @property
   def real_lbs(self): return [lb for lb,r in zip(self.lbs, self.real) if r]
 
-  def __repr__(self):
-    return f"<MLB {self.axis=} {self.real=} {chr(10)}{chr(10).join([f'{x.device} {x.st}' for x in self.lbs])}>"
+  def __repr__(self): return f"<MLB {self.axis=} {self.real=} {chr(10)}{chr(10).join([f'{x.device} {x.st}' for x in self.lbs])}>"
 
   @staticmethod
-  def from_sharded(lb:LazyBuffer, devices:Tuple[str, ...], axis:Optional[int]
-
-
+  def from_sharded(lb:LazyBuffer, devices:Tuple[str, ...], axis:Optional[int], bounds:Optional[Tuple[Tuple[int, int], ...]]):
+    assert (axis is None) == (bounds is None), "must specify bounds iff axis is specified"
+    lbs = [lb] * len(devices)
+    sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(to_sharded(lbs, axis, bounds) if axis is not None and bounds is not None else lbs, devices)]
     return MultiLazyBuffer([lb if lb.is_unrealized_unmasked_const() else lb.contiguous(allow_buffer_view=False) for lb in sharded_lbs], axis)
 
   def copy_to_device(self, device:str) -> LazyBuffer:
     if self.axis is None:
       # if we already have a copy on the device, return that
-      for lb in self.real_lbs
-
-      return self.lbs[self.real.index(True)].copy_to_device(device)
+      return next((lb for lb in self.real_lbs if lb.device == device), self.real_lbs[0].copy_to_device(device))
+    # copy lbs to device, pad to final shape, and sum
     llbs:List[LazyBuffer] = []
     for lb,real,(start,end) in zip(self.lbs, self.real, self.bounds):
       if not real: continue
       pad_arg = tuple((0,0) if a != self.axis else (start, self.bounds[-1][1]-end) for a in range(len(lb.shape)))
       llbs.append(lb.copy_to_device(device).pad(pad_arg))
-    return functools.reduce(
+    return functools.reduce(operator.add, llbs)
 
   # passthroughs
-
-  def
-  def
+  @property
+  def is_realized(self) -> bool: return all(lb.base.realized is not None for lb in self.real_lbs)
+  def cast(self, dtype:DType, bitcast:bool=False, allow_buffer_view=True):
+    return MultiLazyBuffer([x.cast(dtype, bitcast, allow_buffer_view) for x in self.lbs], self.axis, self.real)
+  def const_like(self, b) -> MultiLazyBuffer: return MultiLazyBuffer([x.const_like(b) for x in self.lbs], self.axis, self.real)
   def assign(self, x:MultiLazyBuffer): return MultiLazyBuffer([s.assign(d) for s,d in zip(self.lbs, x.lbs)], self.axis, self.real)
   def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis, self.real)
+  def clone(self) -> MultiLazyBuffer: return MultiLazyBuffer([lb.clone() for lb in self.lbs], self.axis, self.real)
 
   # elementwise is simple
-  def
+  def alu(self, op:Ops, *in_srcs:MultiLazyBuffer) -> MultiLazyBuffer:
     msrcs = (self,)+in_srcs
     assert all(isinstance(x, MultiLazyBuffer) for x in msrcs), f"all buffers must be MultiLazyBuffer {msrcs}"
     assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}"
 
     # NOTE: they all have to share an axis, we always choose [-1]
-    axis = axes[-1] if len(axes := dedup([x.axis for x in msrcs if x.axis is not None])) else None
-    srcs = []
-    not_all_real =
+    axis, bounds = axes[-1] if len(axes := dedup([(x.axis, x.bounds) for x in msrcs if x.axis is not None])) else (None, None)
+    srcs:List[List[LazyBuffer]] = []
+    not_all_real = not all(all(mlb.real) for mlb in msrcs)
     new_real = [all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])] if not_all_real else self.real
     assert any(new_real), "output contains no real lb"
     for mlb in msrcs:
-      if mlb.axis == axis or not_all_real: srcs.append(mlb.lbs)
-      elif mlb.axis is None and axis is not None: srcs.append(to_sharded(mlb.lbs, axis))
-      else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis))
-
-
-
-
-
-
-  def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> MultiLazyBuffer:
+      if (mlb.axis == axis and (mlb.axis is None or mlb.bounds == bounds)) or not_all_real: srcs.append(mlb.lbs)
+      elif mlb.axis is None and axis is not None: srcs.append(to_sharded(mlb.lbs, axis, bounds))
+      else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis, bounds))
+    new_real_lbs:Dict[int,LazyBuffer] = {i:lsrcs[0].alu(op, *lsrcs[1:]) for i,(lsrcs,r) in enumerate(zip(zip(*srcs), new_real)) if r}
+    # NOTE: const dtype should match real
+    new_dtype = next(iter(new_real_lbs.values())).dtype
+    return MultiLazyBuffer([new_real_lbs.get(i, lsrcs[0].const_like(0).cast(new_dtype)) for i,lsrcs in enumerate(zip(*srcs))], axis, new_real)
+
+  def r(self, op:Ops, axis:Tuple[int, ...]) -> MultiLazyBuffer:
     if self.axis is not None and self.axis in axis:
       # all-reduce on sharded axes
-      reduced_parts = [(x if r else x.
+      reduced_parts = [(x if r else x.const_like(0)).r(op, axis) for x,r in zip(self.lbs, self.real)]
+      # if all partitions are real, do all_reduce
      if all(self.real): return MultiLazyBuffer(all_reduce(op, reduced_parts), None)
+      # only one partition is real, keep it
      return MultiLazyBuffer(reduced_parts, None, self.real)
     # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
     return MultiLazyBuffer([x.r(op, axis) for x in self.lbs], self.axis, self.real)
 
   # *** movement ops ***
 
+  def _shape_to_single_shard(self, shape:Tuple[sint, ...], lb:LazyBuffer) -> Tuple[sint, ...]:
+    return tuple(lb.shape[self.axis] if a == self.axis else s for a,s in enumerate(shape))
+
   def reshape(self, arg:Tuple[sint, ...]):
     if self.axis is None: return MultiLazyBuffer([x.reshape(arg) for x in self.lbs], None, self.real)
+    assert prod(self.shape) == prod(arg), "reshape must maintain prod(shape)"
     arg_acc:List[sint] = list(itertools.accumulate(arg, operator.mul, initial=1))
     # new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards
     # todo: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1?
     new_axis = len(arg_acc) - arg_acc[::-1].index(prod(self.shape[:self.axis])) - 1
-
-
-
-    return MultiLazyBuffer([x.reshape(tuple(s if a != new_axis else
-      x.shape[self.axis] if s == self.shape[self.axis] else
-      s // len(self.real_lbs) for a,s in enumerate(arg))) for x in self.lbs],
-      new_axis, self.real)
+    assert all(prod(lb.shape[self.axis:])%prod(arg[new_axis+1:])==0 for lb in self.lbs), f"reshape cannot move items between shards {self=} {arg=}"
+    lbs = [x.reshape(tuple(s if a!=new_axis else prod(x.shape[self.axis:])//prod(arg[new_axis+1:]) for a,s in enumerate(arg))) for x in self.lbs]
+    return MultiLazyBuffer(lbs, new_axis, self.real)
 
   def pad(self, arg:Tuple[Tuple[sint, sint], ...]):
     assert self.axis is None or arg[self.axis] == (0,0) or not all(self.real), f"padding not supported for {arg=}"
     # pad on shard axis -> fill others with zeros and set real to all True
     if self.axis is not None and arg[self.axis] != (0,0):
       # pad back to whole axis, remove real mask
-      assert all(arg[i] == (0, 0)
-
-
-      return MultiLazyBuffer([x if r else x.
+      assert all(arg[i] == (0, 0) for i in range(len(self.shape)) if i != self.axis), "cannot pad sharded and non-sharded axis at the same time"
+      dim, bound = sum(lb.shape[self.axis] for lb in self.lbs), self.bounds[self.real.index(True)]
+      assert arg[self.axis] == (bound[0], dim-bound[1]), "can only pad to whole axis"
+      return MultiLazyBuffer([x if r else x.const_like(0) for x,r in zip(self.lbs, self.real)], self.axis)
     return MultiLazyBuffer([x.pad(arg) for x in self.lbs], self.axis, self.real)
+
   def expand(self, arg:Tuple[sint, ...]):
     # NOTE: this assert isn't needed, sharded axis can have dim 1
     assert self.axis is None or arg[self.axis] == self.shape[self.axis], f"expand not supported on sharded axis {arg=}"
     return MultiLazyBuffer([x.expand(self._shape_to_single_shard(arg, x)) for x in self.lbs], self.axis, self.real)
+
   def permute(self, arg:Tuple[int, ...]):
     # all permutes supported!
     return MultiLazyBuffer([x.permute(arg) for x in self.lbs], arg.index(self.axis) if self.axis is not None else None, self.real)
+
   def shrink(self, arg:Tuple[Tuple[sint, sint], ...]):
     assert self.axis is None or arg[self.axis] == (0, self.shape[self.axis]) or arg[self.axis] in self.bounds, f"shrinking not supported for {arg=}"
     if self.axis is not None and arg[self.axis] in self.bounds and arg[self.axis] != (0, self.shape[self.axis]):
       assert all(arg[i] == (0, s) or i == self.axis for i,s in enumerate(self.shape)), "cannot shrink sharded and non-sharded axis at the same time"
+      # NOTE: shrink on the shard axis is only allowed when result is a single partition, denoted by the new real
      idx = self.bounds.index(arg[self.axis])
      # zero out other lbs to not create lb reference
-      return MultiLazyBuffer([lb if i==idx else lb.
+      return MultiLazyBuffer([lb if i==idx else lb.const_like(0) for i,lb in enumerate(self.lbs)], self.axis, [i==idx for i in range(len(self.lbs))])
     return MultiLazyBuffer([x.shrink(tuple((0, x.shape[self.axis]) if a == self.axis else s for a,s in enumerate(arg))) for x in self.lbs],
                            self.axis, self.real)
+
   def stride(self, arg:Tuple[int, ...]):
     assert self.axis is None or arg[self.axis] == 1, "flipping not supported on sharded axis"
     return MultiLazyBuffer([x.stride(arg) for x in self.lbs], self.axis, self.real)
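
For orientation, here is an illustrative, tinygrad-free sketch of the ring allreduce schedule used in all_reduce above: n_lbs-1 scatter-reduce steps in which each chunk is accumulated around the ring, followed by n_lbs-1 allgather steps in which the fully reduced chunks are rotated back to every device. The src/dest index arithmetic mirrors the formulas from the hunk; everything else (plain Python lists standing in for LazyBuffers) is a hypothetical stand-in, not code from the diff.

# toy model of the ring allreduce schedule (plain Python, no tinygrad)
def ring_allreduce(values_per_device):
  n = len(values_per_device)                      # number of devices in the ring
  chunks = [list(v) for v in values_per_device]   # chunks[d][i] = chunk i currently held by device d
  # scatter-reduce: after n-1 steps chunk i is fully summed on device (i+n-1) % n
  for step in range(n-1):
    for i in range(len(chunks[0])):
      src, dest = (i+step) % n, (i+step+1) % n
      chunks[dest][i] += chunks[src][i]
  # allgather: rotate each fully reduced chunk around the ring so every device holds it
  for step in range(n-1):
    for i in range(len(chunks[0])):
      src, dest = (i+step-1) % n, (i+step) % n
      chunks[dest][i] = chunks[src][i]
  return chunks

# every device ends up with the elementwise sum of all inputs
assert all(c == [6, 60] for c in ring_allreduce([[1, 10], [2, 20], [3, 30]]))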