tinygrad 0.8.0-py3-none-any.whl → 0.9.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +6 -6
- tinygrad/codegen/kernel.py +230 -190
- tinygrad/codegen/linearizer.py +278 -384
- tinygrad/codegen/uops.py +415 -0
- tinygrad/device.py +132 -275
- tinygrad/dtype.py +53 -37
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +195 -0
- tinygrad/engine/realize.py +191 -0
- tinygrad/engine/schedule.py +362 -0
- tinygrad/engine/search.py +196 -0
- tinygrad/{mlops.py → function.py} +28 -14
- tinygrad/helpers.py +72 -43
- tinygrad/lazy.py +141 -240
- tinygrad/multi.py +169 -0
- tinygrad/nn/__init__.py +179 -8
- tinygrad/nn/datasets.py +7 -0
- tinygrad/nn/optim.py +106 -28
- tinygrad/nn/state.py +86 -17
- tinygrad/ops.py +70 -44
- tinygrad/renderer/__init__.py +61 -0
- tinygrad/renderer/assembly.py +276 -0
- tinygrad/renderer/cstyle.py +299 -206
- tinygrad/renderer/llvmir.py +118 -123
- tinygrad/runtime/autogen/amd_gpu.py +1900 -0
- tinygrad/runtime/autogen/comgr.py +865 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5761 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33328 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/hip_comgr.py +47 -0
- tinygrad/runtime/driver/hsa.py +143 -0
- tinygrad/runtime/graph/clang.py +38 -0
- tinygrad/runtime/graph/cuda.py +59 -54
- tinygrad/runtime/graph/hcq.py +143 -0
- tinygrad/runtime/graph/hsa.py +171 -0
- tinygrad/runtime/graph/metal.py +37 -41
- tinygrad/runtime/ops_amd.py +564 -0
- tinygrad/runtime/ops_clang.py +16 -14
- tinygrad/runtime/ops_cuda.py +130 -38
- tinygrad/runtime/ops_disk.py +45 -42
- tinygrad/runtime/ops_gpu.py +52 -50
- tinygrad/runtime/ops_hsa.py +278 -0
- tinygrad/runtime/ops_llvm.py +36 -56
- tinygrad/runtime/ops_metal.py +42 -24
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +630 -0
- tinygrad/runtime/ops_python.py +204 -0
- tinygrad/shape/shapetracker.py +41 -105
- tinygrad/shape/symbolic.py +98 -95
- tinygrad/shape/view.py +137 -35
- tinygrad/tensor.py +2367 -442
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/METADATA +19 -9
- tinygrad-0.9.0.dist-info/RECORD +60 -0
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
- tinygrad/features/image.py +0 -93
- tinygrad/features/multi.py +0 -103
- tinygrad/features/search.py +0 -160
- tinygrad/graph.py +0 -106
- tinygrad/jit.py +0 -152
- tinygrad/realize.py +0 -50
- tinygrad/runtime/graph/hip.py +0 -24
- tinygrad/runtime/ops_cpu.py +0 -45
- tinygrad/runtime/ops_hip.py +0 -97
- tinygrad/runtime/ops_torch.py +0 -49
- tinygrad-0.8.0.dist-info/RECORD +0 -41
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
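The restructure visible in this listing moves scheduling, realization, the JIT, search, and graph logging out of the package root and tinygrad/features/ into the new tinygrad/engine/ package, renames mlops.py to function.py, promotes multi-device support to tinygrad/multi.py, and removes the numpy and torch runtimes (ops_cpu.py, ops_torch.py) while adding AMD, NV, HSA, and PYTHON runtimes backed by autogenerated bindings. For code that imported these internal modules directly, the sketch below shows how import paths would shift under the file moves above; the symbol names are assumptions and should be checked against the 0.9.0 sources (public entry points such as tinygrad.Tensor are unaffected).

# hypothetical migration sketch; module paths follow the file moves listed above,
# exact symbol locations in 0.9.0 are assumptions
from tinygrad.engine.jit import TinyJit               # 0.8.0: from tinygrad.jit import TinyJit
from tinygrad.engine.schedule import create_schedule  # 0.8.0: defined in tinygrad/lazy.py (removed below)
from tinygrad.engine.realize import run_schedule      # 0.8.0: from tinygrad.realize import run_schedule
from tinygrad import function as mlops                # 0.8.0: tinygrad/mlops.py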
tinygrad/lazy.py
CHANGED
@@ -1,96 +1,126 @@
 from __future__ import annotations
-import
-import
-from
-from
-from tinygrad.
-from tinygrad.helpers import prod, merge_dicts, flatten, getenv, dedup, DEBUG, all_int, all_same
-from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, BufferOps, Op, LazyOp, ConstBuffer, MemBuffer, ScheduleItem
+import math
+from typing import Union, Optional, Any, Tuple, List
+from tinygrad.dtype import dtypes, DType, ConstType
+from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG
+from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu
 from tinygrad.shape.symbolic import sint, Variable
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.device import Buffer
-from
-from weakref import ref, WeakValueDictionary, ReferenceType
+from tinygrad.device import Buffer
+from weakref import ref, ReferenceType, WeakValueDictionary
 
-
-
+lazycache: WeakValueDictionary[Any, LazyBuffer] = WeakValueDictionary()
+def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
+                      base:Optional[LazyBuffer]=None, enable_cache=bool(getenv("LAZYCACHE", 1))):
+  if st.size == 0: op, arg, srcs, base = LoadOps.CONST, 0, (), None
+  if op is LoadOps.CONST: arg, enable_cache = dtypes.as_const(arg, dtype) if not isinstance(arg, Variable) else arg, True
 
-
-
-                      op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
-                      base:Optional[LazyBuffer]=None):
-  if 0 in st.shape: st, op, arg, srcs = ShapeTracker.from_shape(st.shape), LoadOps.CONST, 0, ()
-
-  wop = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs), ref(base) if base else None)
-  if wop in lazycache: return lazycache[wop]
+  cache_key = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs)) if base is None else (st, ref(base))
+  if enable_cache and (rret := lazycache.get(cache_key, None)): return rret
 
   ret = LazyBuffer(device, st, dtype, op, arg, srcs, base=base)
-
-  # TODO: might be possible to remove LoadOps.COPY
-  if op not in {LoadOps.EMPTY, LoadOps.CUSTOM, LoadOps.CONST, LoadOps.COPY} and getenv("LAZYCACHE", 1): lazycache[wop] = ret
+  if enable_cache: lazycache[cache_key] = ret
   return ret
 
+view_supported_devices = {"LLVM", "CLANG", "CUDA", "DISK"}
 class LazyBuffer:
   def __init__(self, device:str, st:ShapeTracker, dtype:DType,
                op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
                base:Optional[LazyBuffer]=None):
-    assert isinstance(device, str) and device == Device.canonicalize(device)
     self.device, self.st, self.dtype, self.shape, self.size = device, st, dtype, st.shape, st.size
+    self._base: Optional[LazyBuffer] = None
     if base is None:
       # properties on base
       self.op, self.arg, self.srcs = op, arg, srcs  # this is a LazyOp, except the src is LazyBuffers and not LazyOps
-      self.
-
-      self.
+      assert self.op is not LoadOps.ASSIGN or srcs[1].base.realized is not None, "assign target must be realized"
+
+      if (self.op is LoadOps.CONTIGUOUS or self.op is UnaryOps.BITCAST) and srcs[0].st.consecutive and \
+        not srcs[0].is_unrealized_const() and device.split(":")[0] in view_supported_devices:
+        # some LazyBuffers can be processed with only a view, no AST required
+        self.buffer: Buffer = srcs[0].base.buffer.view(st.size, dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize)
+        self.op = LoadOps.VIEW
+      else:
+        self.buffer = srcs[1].base.buffer if self.op is LoadOps.ASSIGN else Buffer(device, self.size, dtype)
+      self.buffer.ref(1)
       self.contiguous_child: Optional[Tuple[ReferenceType[LazyBuffer], ShapeTracker]] = None
+      self.forced_realize = False
     else:
       # properties on view
       assert base.base == base, "base must be a base itself"
       self._base = base
 
+  def __del__(self):
+    if hasattr(self, 'buffer'): self.buffer.ref(-1)
+
   def __repr__(self) -> str:
-    return f"<LB {self.device} {self.shape}
+    return f"<LB {self.device} {self.shape} {str(self.dtype)[7:]} {self.st if self.base != self else (self.op, self.realized)}>"
 
   @property
-  def
+  def realized(self) -> Optional[Buffer]:
+    # NOTE: we check for a lack of srcs instead of an allocated buffer to make unrealized assigns return None here
+    return self.buffer if self._base is None and not hasattr(self, 'srcs') else None
+
+  # NOTE: this has to be a function to prevent self reference
+  @property
+  def base(self) -> LazyBuffer: return self._base if self._base is not None else self
+
+  # same API as multi
+  @property
+  def lbs(self) -> List[LazyBuffer]: return [self]
 
   @staticmethod
-  def loadop(op, shape:Tuple[sint,...], dtype:DType, device:str, arg=None, src:
-
+  def loadop(op, shape:Tuple[sint,...], dtype:DType, device:str, arg=None, src:Tuple[LazyBuffer, ...]=(), enable_cache=False) -> LazyBuffer:
+    assert isinstance(src, tuple)
+    return create_lazybuffer(device, ShapeTracker.from_shape(shape), dtype, op, arg, src, enable_cache=enable_cache)
 
-  def const(self, val:
-
+  def const(self, val:ConstType, shape:Optional[Tuple[sint,...]]=None) -> LazyBuffer:
+    shape = self.shape if shape is None else shape
+    return LazyBuffer.loadop(LoadOps.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(shape)).expand(shape)
+
+  def is_realized(self) -> bool: return self.base.realized is not None
+
+  def assign(self, x:LazyBuffer) -> LazyBuffer:
+    assert x.size == self.size, f"assign target must have same size {self.size=} != {x.size=}"
+    return LazyBuffer.loadop(LoadOps.ASSIGN, self.shape, self.dtype, self.device, arg=() if self.st.contiguous else (self.st,), src=(x, self.base))
 
   def contiguous(self):
     if not self.st.contiguous or self.size != self.base.size or self.is_unrealized_const():
       ret = self.e(LoadOps.CONTIGUOUS)
-      sti
-      if sti is not None: self.base.contiguous_child = ref(ret), sti
+      if (sti := self.st.invert(self.base.shape)) is not None: self.base.contiguous_child = ref(ret), sti
      return ret
     self.base.forced_realize = True
    return self
 
   def cast(self, dtype:DType, bitcast:bool=False):
     if self.dtype == dtype: return self
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    if self.device.startswith("DISK") and not bitcast: raise RuntimeError("attempted to cast disk buffer (bitcast only)")
+    if self.is_unrealized_unmasked_const() and not bitcast:
+      return create_lazybuffer(self.device, self.st, dtype, LoadOps.CONST, dtypes.as_const(self.base.arg, dtype))
+    # TODO: applying this makes gpt2 slower
+    if getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self != self.base:
+      return self.base.cast(dtype, bitcast)._view(self.st)
+    new_shape = self.shape
+    if bitcast and self.dtype.itemsize != dtype.itemsize:
+      if not self.device.startswith("DISK"): raise RuntimeError("shape changing bitcast only supported on DISK right now")
+      if not all_int(new_shape): raise RuntimeError("shape changing bitcast with symbolic shape isn't supported yet")
+      # https://pytorch.org/docs/stable/generated/torch.Tensor.view.html
+      if not (new_shape[-1]*self.dtype.itemsize) % dtype.itemsize == 0: raise RuntimeError("unsupported size in bitcast")
+      new_shape = new_shape[:-1] + ((new_shape[-1]*self.dtype.itemsize) // dtype.itemsize,)
+    cast_op = UnaryOps.BITCAST if bitcast else UnaryOps.CAST
+    return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, dtype, (self,))
+
+  def is_unrealized_const(self): return self.base.realized is None and self.base.op is LoadOps.CONST and not isinstance(self.base.arg, Variable)
+  def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views)
+
+  def _copy(self, device:str) -> LazyBuffer:
+    return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, self.buffer.nbytes, (self,), enable_cache=False)
+
+  def copy_to_device(self, device:str, force: bool = False) -> LazyBuffer:
     # no COPY
     if self.device == device: return self
 
     # double COPY = one COPY
-    if self.st.contiguous and self.size == self.base.size and not self.base.realized and self.base.op
+    if not force and self.st.contiguous and self.size == self.base.size and not self.base.realized and self.base.op is LoadOps.COPY:
       return self.base.srcs[0].copy_to_device(device).reshape(self.st.shape)
 
     # const doesn't have to be copied (issues with disk tensor)
@@ -98,11 +128,10 @@ class LazyBuffer:
       return LazyBuffer.loadop(LoadOps.CONST, tuple(), self.dtype, device, arg=self.base.arg)._view(self.st)
 
     # if it's a shrink, do the shrink before the copy with CONTIGUOUS
-    if prod(self.st.shape) < prod(self.base.st.shape):
-      return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, srcs=(self.contiguous(),))
+    if prod(self.st.shape) < prod(self.base.st.shape): return self.contiguous()._copy(device)
 
     # copy the base and apply the shapetracker on the new device
-    return
+    return self.base._copy(device)._view(self.st)
 
   def e(self, op:Union[LoadOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
     srcs: List[LazyBuffer] = []
@@ -111,36 +140,75 @@
         srcs.append(root._view(s.base.contiguous_child[1]))
       else:
         srcs.append(s)
-    assert all_same(dts:=[x.dtype.scalar() for x in (srcs if op
+    assert all_same(dts:=[x.dtype.scalar() for x in (srcs[1:] if op is TernaryOps.WHERE else srcs)]), f"all dtypes must match {dts} on {op}"
     assert all_same([x.shape for x in srcs]), f"all shapes must be the same {[x.shape for x in srcs]}"
-    if op
-
-
-
+    if op is TernaryOps.WHERE: assert srcs[0].dtype == dtypes.bool, "TernaryOps.WHERE must have the first arg be bool"
+    if op is UnaryOps.NEG: assert srcs[0].dtype != dtypes.bool, "UnaryOps.NEG does not accept dtype bool"
+
+    out_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else srcs[-1].dtype
+
+    # const folding
+    if op in python_alu and all(s.is_unrealized_unmasked_const() for s in srcs):
+      return self.cast(out_dtype).const(exec_alu(op, out_dtype, [s.base.arg for s in srcs]))
+    if op is UnaryOps.NEG and self.base.op is UnaryOps.NEG: return self.base.srcs[0]
+    if op in BinaryOps: x, y = self, in_srcs[0]
+    if op is BinaryOps.ADD:
+      if y.is_unrealized_unmasked_const() and y.base.arg == 0: return x # pylint: disable=possibly-used-before-assignment
+      if x.is_unrealized_unmasked_const() and x.base.arg == 0: return y # pylint: disable=possibly-used-before-assignment
+    if op is BinaryOps.SUB and y.is_unrealized_unmasked_const() and y.base.arg == 0: return x
+    if op is BinaryOps.MUL:
+      if x.is_unrealized_unmasked_const() and (val := x.base.arg) in (1, 0, -1):
+        return y if val == 1 else y.const(0) if val == 0 else y.e(UnaryOps.NEG)
+      if y.is_unrealized_unmasked_const() and (val := float(y.base.arg)) in (1, 0, -1):
+        return x if val == 1 else x.const(0) if val == 0 else x.e(UnaryOps.NEG)
+    if op is BinaryOps.DIV and dtypes.is_float(x.dtype) and y.is_unrealized_unmasked_const() and y.base.arg != 0:
+      return x.e(BinaryOps.MUL, x.const(1 / y.base.arg))
+
+    return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))
 
   # *** reduce ops ***
 
-  def _reduce_op(self, op:ReduceOps,
-
-
-
+  def _reduce_op(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
+    assert all(0 <= x < len(self.shape) for x in axis), f"axis args {axis} out of range for shape {self.shape}"
+    axis = tuple(x for x in axis if self.shape[x] != 1)
+    if len(axis) == 0: return self
+    new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape))
+    return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), self.dtype, op, axis, (self,))
+
+  def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
+    new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape))
+    # TODO: this logic should move to the scheduler
+    if self.size == 0 and 0 not in new_shape: return self.const({ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[op], new_shape)
+
+    # const folding
+    if self.is_unrealized_unmasked_const():
+      return self.const(self.base.arg * {ReduceOps.SUM: prod(self.shape[i] for i in axis), ReduceOps.MAX: 1}[op], new_shape)
 
-  def r(self, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
-    assert len(self.shape) == len(new_shape) and all(s == ns or ns == 1 for s,ns in zip(self.shape, new_shape)), \
-      f"reduce shape lens must match {self.shape} {new_shape}"
     # TODO: can we split symbolic shape if the reduce axis is not symbolic?
-    if not all_int(self.shape) or (0 in self.shape) or
-
-
-
-    #
-
-
-
+    if not getenv("SPLIT_REDUCEOP", 1) or not all_int(self.shape) or (0 in self.shape) or \
+      prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
+      return self._reduce_op(op, axis)
+
+    # if there are few globals, make some reduces into globals by splitting into two kernels
+    # cap output buffer to 2**22: heuristic number of global outputs to achieve max occupancy with enough locals+upcasts for gemm
+    # ~2**10 should be enough if GROUP is used
+    # 256 split maximum should be "negligible reduce" for low prod(new_shape), 8 split minimum.
+    # split is moved to the end to provide maximum locality for the second phase reduce.
+    self_real_strides = self.st.real_strides(ignore_valid=True)
+    split_candidates = [(i, x) for i in axis for x in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(new_shape)),8-1,-1)
+                        if self.shape[i] % x == 0 and self_real_strides[i] != 0]
+    if not split_candidates: return self._reduce_op(op, axis)
+    dim_to_split, divisor = split_candidates[0]
+    splitted_shape = self.shape[:dim_to_split] + (divisor,) + (self.shape[dim_to_split]//divisor,) + self.shape[dim_to_split+1:]
+    splitted = self.reshape(splitted_shape).permute(tuple([x for x in range(len(splitted_shape)) if x != dim_to_split]+[dim_to_split]))
+    if DEBUG >= 3: print(f"split {divisor}: {self.shape} -> {splitted.shape} -> {new_shape}")
+    return splitted._reduce_op(op, axis)._reduce_op(op, (len(new_shape),)).reshape(new_shape)  # reduce original axes, then split
 
   # *** movement ops ***
 
   def _view(self, new_st:ShapeTracker) -> LazyBuffer:
+    if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)):
+      return self.const(0, new_st.shape)
     if new_st.contiguous and self.base.shape == new_st.shape: return self.base
     return create_lazybuffer(self.device, new_st, self.dtype, base=self.base)
 
@@ -150,170 +218,3 @@
   def permute(self, arg:Tuple[int, ...]): return self._view(self.st.permute(arg))
   def shrink(self, arg:Tuple[Tuple[sint, sint], ...]): return self._view(self.st.shrink(arg))
   def stride(self, arg:Tuple[int, ...]): return self._view(self.st.stride(arg))
-
-# *** schedule creation ***
-
-# recursively create a lazyop
-def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], var_vals:Dict[Variable, int], st:ShapeTracker,
-                      realizes:Set[LazyBuffer], cache, first=True) -> LazyOp:
-  if (buf, st) in cache: return cache[(buf, st)]
-  if buf != buf.base:
-    st = buf.st + st
-    buf = buf.base
-  # all buffers here are base now
-  assert buf.op is not None
-
-  # consts are always fused and generated
-  if buf.op == LoadOps.CONST:
-    # TODO: make shapetracker unbind also return var_vals
-    var_vals.update(merge_dicts([var_vals, st.var_vals]))
-    return LazyOp(BufferOps.CONST, (), ConstBuffer(float(buf.arg), buf.dtype, st.simplify().unbind()))
-
-  # if we aren't fusing it, it's a load and we add it to the inputs
-  if buf.realized or (buf in realizes and not first):
-    if buf not in inputs: inputs.append(buf)
-    var_vals.update(merge_dicts([var_vals, st.var_vals]))
-    return LazyOp(BufferOps.LOAD, (), MemBuffer(inputs.index(buf)+1, buf.dtype, st.simplify().unbind()))
-
-  # if a CONTIGUOUS made it all the way here, just skip it
-  if buf.op == LoadOps.CONTIGUOUS:
-    assert first
-    return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, cache, False)
-
-  # if it's a reduce, we have to change the shapetracker
-  if buf.op in ReduceOps:
-    assert st.contiguous, "ReduceOps late fusion must be contiguous"
-    st = ShapeTracker.from_shape(buf.srcs[0].shape)
-
-  # otherwise we fuse it like normal
-  cache[(buf, st)] = ret = LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, var_vals, st, realizes, cache, False) for x in buf.srcs), buf.arg)
-  return ret
-
-# recursively walk back in the graph to create the schedule
-def _recursive_schedule(out:LazyBuffer, seen:Set[LazyBuffer], realizes:Set[LazyBuffer],
-                        reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> List[ScheduleItem]:
-  if out in seen or out.realized or out.op == LoadOps.CONST: return []
-  assert out.base == out
-  seen.add(out)
-
-  inputs: List[LazyBuffer] = []
-  var_vals: Dict[Variable, int] = out.st.var_vals.copy()
-  if out.op == LoadOps.COPY:
-    op, inputs = LazyOp(LoadOps.COPY, (), out.srcs[0].base), [out.srcs[0].base]
-  elif out.op == LoadOps.CUSTOM:
-    op, inputs = LazyOp(LoadOps.CUSTOM, (), out.arg), list(out.srcs)
-  elif out.op == LoadOps.EMPTY:
-    op = LazyOp(LoadOps.EMPTY)
-  else:
-    output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
-    op = _recursive_lazyop(out, inputs, var_vals, output_st, realizes, cache={})
-    op = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_st.simplify().unbind()))
-
-  return flatten(_recursive_schedule(x.base, seen, realizes, reduce_for_op) for x in inputs) + [ScheduleItem(op, out, tuple(inputs), var_vals)]
-
-# recursively search the entire graph for all LazyBuffers, insert realizes after expands
-def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffer, None],
-                simple_pads:Set[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]]):
-  if buf in allbufs or buf.base.realized: return
-  log_lazybuffer(buf)
-  if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
-                                            not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
-    if DEBUG >= 3: print(f"forcing image {buf.dtype} with shape {buf.shape} to float32")
-    buf.dtype = dtypes.float32  # NOTE; this is what makes the dtype above not match
-  if buf.base != buf:
-    # realize all places where the buffer is expanded
-    if prod(buf.base.st.shape) < prod(buf.st.shape):
-      if len(buf.st.views) == 1 and buf.st.views[-1].mask and all_int(buf.base.st.shape) and \
-          prod(buf.base.st.shape) == prod([y-x for x,y in buf.st.views[-1].mask]):
-        simple_pads.add(buf.base)
-      else:
-        realizes.add(buf.base)
-    return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children)
-  if buf.forced_realize: realizes.add(buf)
-  allbufs[buf] = None
-  if buf.op in LoadOps: realizes.add(buf.base)
-  if buf.op == LoadOps.COPY:
-    assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig"
-    realizes.add(buf.srcs[0].base)
-  for x in buf.srcs:
-    children[x.base][buf] = None
-    _recurse_lb(x, realizes, allbufs, simple_pads, children)
-
-UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, BinaryOps.CMPEQ, UnaryOps.LOG2, UnaryOps.EXP2}
-def _is_padding_okay(buf:LazyBuffer, realizes:Set[LazyBuffer]) -> bool:
-  if buf in realizes or buf.realized: return True
-  # NOTE: this broke to_image_idx and coder with JIT
-  if buf.op in UNSAFE_PAD_OPS: return False
-  return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)
-
-def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
-  if seen is None: seen = set()
-  for out in outs: log_lazybuffer(out, scheduled=True)
-
-  # start by just realizing the buffers passed in
-  realizes: Set[LazyBuffer] = set([x.base for x in outs if not x.base.realized])
-  allbufs: Dict[LazyBuffer, None] = {}
-  simple_pads: Set[LazyBuffer] = set()
-  children: DefaultDict[LazyBuffer, Dict[LazyBuffer, None]] = defaultdict(dict)
-  for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads, children)
-
-  # check if we have to realize pads
-  for p in simple_pads:
-    if not _is_padding_okay(p, realizes):
-      realizes.add(p)
-
-  # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
-  reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {}
-  for r in allbufs.keys():
-    if r != r.base or r.op not in ReduceOps or r in realizes: continue
-
-    # follow the reduce down
-    child_set: Dict[LazyBuffer, ShapeTracker] = {r: r.st}
-    realized_children: Dict[LazyBuffer, ShapeTracker] = {}
-    forced_realize = False
-    can_chase = True
-    while not forced_realize and len(child_set):
-      next_child_set = {}
-      for tr,st in child_set.items():
-        if tr in realizes:
-          realized_children[tr] = st
-          # can only have one output buffer
-          # can only reduce contiguous
-          # max one reduceop per kernel
-          if len(realized_children) > 1 or not st.contiguous or st.size != r.st.size or (tr in reduce_for_op and reduce_for_op[tr] != r):
-            can_chase = tr not in reduce_for_op or reduce_for_op[tr] == r
-            forced_realize = True
-            break
-          continue
-        for tr_next in children[tr].keys():
-          if not tr_next.realized:
-            # max one reduceop per kernel
-            if tr_next.op in ReduceOps:
-              forced_realize = True
-              break
-            st_childs = dedup([s for s in tr_next.srcs if s.base == tr])
-            if len(st_childs) > 1:
-              forced_realize = True
-              break
-            next_child_set[tr_next] = st + st_childs[0].st
-      child_set = next_child_set
-    if forced_realize:
-      tr = r
-      if can_chase:
-        # can chase this down to contiguous children
-        st = tr.st
-        while len(children[tr]) == 1:
-          tr_next = next(iter(children[tr].keys()))
-          st_childs = dedup([s for s in tr_next.srcs if s.base == tr])
-          if len(st_childs) > 1: break
-          if st.size != st_childs[0].st.size: break
-          st = st + st_childs[0].st
-          if not st.contiguous or tr_next.op in ReduceOps: break
-          tr = tr_next
-      reduce_for_op[tr] = r
-      realizes.add(tr)
-    else:
-      assert len(realized_children) == 1
-      reduce_for_op[next(iter(realized_children.keys()))] = r
-
-  return flatten(_recursive_schedule(x.base, seen, realizes, reduce_for_op) for x in outs)
tinygrad/multi.py
ADDED
@@ -0,0 +1,169 @@
+from __future__ import annotations
+from typing import Optional, Union, Any, Tuple, List
+import functools, itertools, operator
+from tinygrad.helpers import all_same, all_int, dedup, round_up, prod, DEBUG, RING
+from tinygrad.dtype import DType, ConstType
+from tinygrad.ops import BinaryOps, LoadOps, UnaryOps, TernaryOps, ReduceOps
+from tinygrad.lazy import LazyBuffer
+from tinygrad.shape.shapetracker import sint
+
+def all_reduce(op: ReduceOps, lbs: List[LazyBuffer]) -> List[LazyBuffer]:
+  assert all_int(lbs[0].shape), f"does not support symbolic shape {lbs[0].shape}"
+  assert all_same([lb.shape[0] for lb in lbs]), "allreduce with uneven shards is undefined"
+  bop = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[op]
+
+  n_lbs, dim = len(lbs), prod(lbs[0].shape)
+  # Ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
+  # so just fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
+  use_ring = (RING >= 2 or (n_lbs > 2 and dim > 256_000 and RING >= 1))
+  if DEBUG >= 2: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{dim} | {lbs[0].dtype}")
+  if not use_ring:
+    return [functools.reduce(lambda x,y: x.e(bop, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]
+  factor = max(f for f in [32, 16, 8, 4, 2, 1] if dim % f == 0)
+  base, left = (dim // factor) // n_lbs, (dim // factor) % n_lbs
+  c_lens = [(base + 1) * factor if i < left else base * factor for i in range(n_lbs)]
+  acc = 0
+  chunks = [(acc, (acc := acc + i)) for i in c_lens if i > 0]
+  chunked = [[lb.reshape((dim,)).shrink(((s,e),)) for s,e in chunks] for lb in lbs]
+
+  # Scatter-reduce step
+  for step in range(n_lbs - 1):
+    for i in range(len(chunks)):
+      s, r = (i+step)%n_lbs, (i+step+1)%n_lbs
+      chunked[r][i] = chunked[r][i].e(bop, chunked[s][i].copy_to_device(chunked[r][i].device, force=True))
+
+  # Allgather step
+  for step in range(n_lbs - 1):
+    for i in range(len(chunks)):
+      s, r = (i+step-1)%n_lbs, (i+step)%n_lbs
+      chunked[r][i] = chunked[s][i].copy_to_device(chunked[r][i].device, force=True)
+
+  # Assemble chunks back
+  pads = [((s,dim-e),) for s,e in chunks]
+  return [functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), [c.pad(pads[i]) for i,c in enumerate(lb_c)]).reshape(lbs[0].shape) for lb_c in chunked]
+
+def to_sharded(lbs:List[LazyBuffer], axis:int) -> List[LazyBuffer]:
+  if DEBUG >= 3 and lbs[0].shape[axis] % len(lbs) != 0: print(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}")
+  sz = round_up(lbs[0].shape[axis], len(lbs)) // len(lbs)
+  return [lb.shrink(tuple((0,s) if a != axis else (sz*i,min(s,sz*(i+1))) for a,s in enumerate(lb.shape))) for i,lb in enumerate(lbs)]
+
+class MultiLazyBuffer:
+  def __init__(self, lbs:List[LazyBuffer], axis:Optional[int], real:Optional[List[bool]]=None):
+    assert all(isinstance(x, LazyBuffer) for x in lbs) and len(lbs), "all lbs must be LazyBuffers, and we need at least one of them"
+    assert all_same([x.dtype for x in lbs]), f"all multilazybuffer needs same dtype, getting {[x.dtype for x in lbs]}"
+    self.lbs, self.axis, self.dtype, self.device, self.real = lbs, axis, lbs[0].dtype, tuple(x.device for x in lbs), real or [True]*len(lbs)
+    if axis is not None:
+      splits = list(itertools.accumulate([lb.shape[axis] for lb in lbs], initial=0))
+      self.bounds = [(st,ed) for st,ed in zip(splits, splits[1:])]
+
+  @property
+  def shape(self):
+    return tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape))
+
+  @property
+  def size(self): return sum(x.size for x in self.real_lbs)
+
+  @property
+  def real_lbs(self): return [lb for lb,r in zip(self.lbs, self.real) if r]
+
+  def __repr__(self):
+    return f"<MLB {self.axis=} {self.real=} {chr(10)}{chr(10).join([f'{x.device} {x.st}' for x in self.lbs])}>"
+
+  @staticmethod
+  def from_sharded(lb:LazyBuffer, devices:Tuple[str, ...], axis:Optional[int]=None):
+    lbs = [lb.contiguous() if lb.base != lb and not lb.is_unrealized_unmasked_const() else lb] * len(devices)
+    sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(to_sharded(lbs, axis) if axis is not None else lbs, devices)]
+    return MultiLazyBuffer([lb if lb.is_unrealized_unmasked_const() else lb.contiguous() for lb in sharded_lbs], axis)
+
+  def copy_to_device(self, device:str) -> LazyBuffer:
+    if self.axis is None: return self.lbs[self.real.index(True)].copy_to_device(device)
+    sz = self.lbs[0].shape[self.axis]
+    llbs = []
+    for i,lb in enumerate([lb.copy_to_device(device) for lb in self.real_lbs]):
+      pad_arg = tuple((0,0) if a != self.axis else (sz*i, max(0, self.shape[self.axis]-sz*(i+1))) for a in range(len(lb.shape)))
+      llbs.append(lb.pad(pad_arg))
+    return functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), llbs)
+
+  # passthroughs
+  def is_realized(self) -> bool: return all(lb.base.realized is not None for lb, r in zip(self.lbs, self.real) if r is True)
+  def cast(self, dtype:DType, bitcast:bool=False): return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis, self.real)
+  def const(self, val:ConstType) -> MultiLazyBuffer: return MultiLazyBuffer([x.const(val) for x in self.lbs], self.axis, self.real)
+  def assign(self, x:MultiLazyBuffer): return MultiLazyBuffer([s.assign(d) for s,d in zip(self.lbs, x.lbs)], self.axis, self.real)
+  def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis, self.real)
+
+  # elementwise is simple
+  def e(self, op:Union[LoadOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:MultiLazyBuffer, arg:Optional[Any]=None) -> MultiLazyBuffer:
+    msrcs = (self,)+in_srcs
+    assert all(isinstance(x, MultiLazyBuffer) for x in msrcs), f"all buffers must be MultiLazyBuffer {msrcs}"
+    assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}"
+
+    # NOTE: they all have to share an axis, we always choose [-1]
+    axis = axes[-1] if len(axes := dedup([x.axis for x in msrcs if x.axis is not None])) else None
+    srcs = []
+    not_all_real = any(not all(mlb.real) for mlb in msrcs)
+    new_real = [all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])] if not_all_real else self.real
+    assert any(new_real), "output contains no real lb"
+    for mlb in msrcs:
+      if mlb.axis == axis or not_all_real: srcs.append(mlb.lbs)
+      elif mlb.axis is None and axis is not None: srcs.append(to_sharded(mlb.lbs, axis))
+      else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis))
+    # NOTE: lsrcs[-1].const(0) is correct for where
+    return MultiLazyBuffer([lsrcs[0].e(op, *lsrcs[1:], arg=arg) if r else lsrcs[-1].const(0) for lsrcs,r in zip(zip(*srcs),new_real)], axis, new_real)
+
+  def _shape_to_single_shard(self, shape:Tuple[sint, ...], lb:LazyBuffer) -> Tuple[sint, ...]:
+    return tuple(lb.shape[self.axis] if a == self.axis else s for a,s in enumerate(shape))
+
+  def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> MultiLazyBuffer:
+    if self.axis is not None and self.axis in axis:
+      # all-reduce on sharded axes
+      reduced_parts = [(x if r else x.const(0)).r(op, axis) for x,r in zip(self.lbs, self.real)]
+      if all(self.real): return MultiLazyBuffer(all_reduce(op, reduced_parts), None)
+      return MultiLazyBuffer(reduced_parts, None, self.real)
+    # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
+    return MultiLazyBuffer([x.r(op, axis) for x in self.lbs], self.axis, self.real)
+
+  # *** movement ops ***
+
+  def reshape(self, arg:Tuple[sint, ...]):
+    if self.axis is None: return MultiLazyBuffer([x.reshape(arg) for x in self.lbs], None, self.real)
+    arg_acc:List[sint] = list(itertools.accumulate(arg, operator.mul, initial=1))
+    # new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards
+    # todo: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1?
+    new_axis = len(arg_acc) - arg_acc[::-1].index(prod(self.shape[:self.axis])) - 1
+    if arg[new_axis] != self.shape[self.axis]:
+      assert self.shape[self.axis] % len(self.real_lbs) == 0, f"cannot reshape on-axis for uneven shard {self.axis} {self.shape} {len(self.real_lbs)}"
+      assert arg[new_axis] % len(self.real_lbs) == 0, f"new on-axis shape must divide evenly between devices {new_axis} {arg} {len(self.real_lbs)}"
+    return MultiLazyBuffer([x.reshape(tuple(s if a != new_axis else
+                              x.shape[self.axis] if s == self.shape[self.axis] else
+                              s // len(self.real_lbs) for a,s in enumerate(arg))) for x in self.lbs],
+                           new_axis, self.real)
+
+  def pad(self, arg:Tuple[Tuple[sint, sint], ...]):
+    assert self.axis is None or arg[self.axis] == (0,0) or not all(self.real), f"padding not supported for {arg=}"
+    # pad on shard axis -> fill others with zeros and set real to all True
+    if self.axis is not None and arg[self.axis] != (0,0):
+      # pad back to whole axis, remove real mask
+      assert all(arg[i] == (0, 0) or i == self.axis for i in range(len(self.shape))), "cannot pad sharded and non-sharded axis at the same time"
+      assert arg[self.axis] == (sum(lb.shape[self.axis] for i,lb in enumerate(self.lbs) if i < self.real.index(True)), \
+                                sum(lb.shape[self.axis] for i,lb in enumerate(self.lbs) if i > self.real.index(True))), "can only pad to whole axis"
+      return MultiLazyBuffer([x if r else x.const(0) for x,r in zip(self.lbs, self.real)], self.axis)
+    return MultiLazyBuffer([x.pad(arg) for x in self.lbs], self.axis, self.real)
+  def expand(self, arg:Tuple[sint, ...]):
+    # NOTE: this assert isn't needed, sharded axis can have dim 1
+    assert self.axis is None or arg[self.axis] == self.shape[self.axis], f"expand not supported on sharded axis {arg=}"
+    return MultiLazyBuffer([x.expand(self._shape_to_single_shard(arg, x)) for x in self.lbs], self.axis, self.real)
+  def permute(self, arg:Tuple[int, ...]):
+    # all permutes supported!
+    return MultiLazyBuffer([x.permute(arg) for x in self.lbs], arg.index(self.axis) if self.axis is not None else None, self.real)
+  def shrink(self, arg:Tuple[Tuple[sint, sint], ...]):
+    assert self.axis is None or arg[self.axis] == (0, self.shape[self.axis]) or arg[self.axis] in self.bounds, f"shrinking not supported for {arg=}"
+    if self.axis is not None and arg[self.axis] in self.bounds and arg[self.axis] != (0, self.shape[self.axis]):
+      assert all(arg[i] == (0, s) or i == self.axis for i,s in enumerate(self.shape)), "cannot shrink sharded and non-sharded axis at the same time"
+      idx = self.bounds.index(arg[self.axis])
+      # zero out other lbs to not create lb reference
+      return MultiLazyBuffer([lb if i==idx else lb.const(0) for i,lb in enumerate(self.lbs)], self.axis, [i==idx for i in range(len(self.lbs))])
+    return MultiLazyBuffer([x.shrink(tuple((0, x.shape[self.axis]) if a == self.axis else s for a,s in enumerate(arg))) for x in self.lbs],
+                           self.axis, self.real)
+  def stride(self, arg:Tuple[int, ...]):
+    assert self.axis is None or arg[self.axis] == 1, "flipping not supported on sharded axis"
+    return MultiLazyBuffer([x.stride(arg) for x in self.lbs], self.axis, self.real)