tinygrad 0.8.0__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their public registry, and is provided for informational purposes only.
Files changed (74)
  1. tinygrad/__init__.py +6 -6
  2. tinygrad/codegen/__init__.py +0 -0
  3. tinygrad/codegen/kernel.py +253 -225
  4. tinygrad/codegen/linearizer.py +398 -436
  5. tinygrad/codegen/uops.py +451 -0
  6. tinygrad/device.py +268 -274
  7. tinygrad/dtype.py +56 -40
  8. tinygrad/engine/__init__.py +0 -0
  9. tinygrad/engine/graph.py +100 -0
  10. tinygrad/engine/jit.py +198 -0
  11. tinygrad/engine/realize.py +192 -0
  12. tinygrad/engine/schedule.py +370 -0
  13. tinygrad/engine/search.py +199 -0
  14. tinygrad/{mlops.py → function.py} +40 -32
  15. tinygrad/helpers.py +144 -46
  16. tinygrad/lazy.py +143 -242
  17. tinygrad/multi.py +173 -0
  18. tinygrad/nn/__init__.py +180 -9
  19. tinygrad/nn/datasets.py +8 -0
  20. tinygrad/nn/optim.py +106 -28
  21. tinygrad/nn/state.py +87 -19
  22. tinygrad/ops.py +104 -45
  23. tinygrad/renderer/__init__.py +65 -0
  24. tinygrad/renderer/assembly.py +269 -0
  25. tinygrad/renderer/cstyle.py +308 -210
  26. tinygrad/renderer/llvmir.py +119 -124
  27. tinygrad/runtime/__init__.py +0 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +13403 -0
  29. tinygrad/runtime/autogen/comgr.py +891 -0
  30. tinygrad/runtime/autogen/cuda.py +5923 -0
  31. tinygrad/runtime/autogen/hip.py +5909 -0
  32. tinygrad/runtime/autogen/hsa.py +5893 -0
  33. tinygrad/runtime/autogen/io_uring.py +1486 -0
  34. tinygrad/runtime/autogen/kfd.py +812 -0
  35. tinygrad/runtime/autogen/nv_gpu.py +33597 -0
  36. tinygrad/runtime/autogen/opencl.py +1795 -0
  37. tinygrad/runtime/driver/__init__.py +0 -0
  38. tinygrad/runtime/driver/hip_comgr.py +56 -0
  39. tinygrad/runtime/graph/__init__.py +0 -0
  40. tinygrad/runtime/graph/clang.py +39 -0
  41. tinygrad/runtime/graph/cuda.py +59 -54
  42. tinygrad/runtime/graph/hcq.py +187 -0
  43. tinygrad/runtime/graph/metal.py +37 -41
  44. tinygrad/runtime/ops_amd.py +550 -0
  45. tinygrad/runtime/ops_clang.py +16 -14
  46. tinygrad/runtime/ops_cuda.py +129 -37
  47. tinygrad/runtime/ops_disk.py +111 -43
  48. tinygrad/runtime/ops_gpu.py +52 -50
  49. tinygrad/runtime/ops_llvm.py +36 -56
  50. tinygrad/runtime/ops_metal.py +41 -24
  51. tinygrad/runtime/ops_npy.py +9 -0
  52. tinygrad/runtime/ops_nv.py +625 -0
  53. tinygrad/runtime/ops_python.py +208 -0
  54. tinygrad/shape/__init__.py +0 -0
  55. tinygrad/shape/shapetracker.py +46 -107
  56. tinygrad/shape/symbolic.py +99 -98
  57. tinygrad/shape/view.py +162 -45
  58. tinygrad/tensor.py +2492 -483
  59. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +1 -1
  60. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +31 -13
  61. tinygrad-0.9.1.dist-info/RECORD +63 -0
  62. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
  63. tinygrad/features/image.py +0 -93
  64. tinygrad/features/multi.py +0 -103
  65. tinygrad/features/search.py +0 -160
  66. tinygrad/graph.py +0 -106
  67. tinygrad/jit.py +0 -152
  68. tinygrad/realize.py +0 -50
  69. tinygrad/runtime/graph/hip.py +0 -24
  70. tinygrad/runtime/ops_cpu.py +0 -45
  71. tinygrad/runtime/ops_hip.py +0 -97
  72. tinygrad/runtime/ops_torch.py +0 -49
  73. tinygrad-0.8.0.dist-info/RECORD +0 -41
  74. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
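The listing above also captures a package reorganization: the JIT, scheduler, and search code moved under tinygrad/engine/, tinygrad/features/multi.py became tinygrad/multi.py, and tinygrad/mlops.py was renamed to tinygrad/function.py. A minimal, hypothetical import-migration sketch, inferred only from these file moves and assuming the class names were not changed by the moves:

    # 0.8.0 module paths            ->  0.9.1 module paths (per the renames above)
    # tinygrad.jit                  ->  tinygrad.engine.jit
    # tinygrad.features.multi       ->  tinygrad.multi
    # tinygrad.features.search      ->  tinygrad.engine.search
    # tinygrad.mlops                ->  tinygrad.function
    from tinygrad.engine.jit import TinyJit      # was tinygrad.jit in 0.8.0
    from tinygrad.multi import MultiLazyBuffer   # was tinygrad.features.multi in 0.8.0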
tinygrad/lazy.py CHANGED
@@ -1,96 +1,127 @@
  from __future__ import annotations
- import sys, math
- import numpy as np
- from collections import defaultdict
- from typing import Union, Optional, Any, Tuple, List, Set, Dict, DefaultDict
- from tinygrad.dtype import dtypes, DType, ImageDType
- from tinygrad.helpers import prod, merge_dicts, flatten, getenv, dedup, DEBUG, all_int, all_same
- from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, BufferOps, Op, LazyOp, ConstBuffer, MemBuffer, ScheduleItem
+ import math
+ from typing import Union, Optional, Any, Tuple, List
+ from tinygrad.dtype import dtypes, DType, ConstType
+ from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG
+ from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu
  from tinygrad.shape.symbolic import sint, Variable
  from tinygrad.shape.shapetracker import ShapeTracker
- from tinygrad.device import Buffer, Device
- from tinygrad.graph import log_lazybuffer
- from weakref import ref, WeakValueDictionary, ReferenceType
+ from tinygrad.device import Buffer
+ from weakref import ref, ReferenceType, WeakValueDictionary

- # lazy can recurse a lot
- sys.setrecursionlimit(10000)
+ lazycache: WeakValueDictionary[Any, LazyBuffer] = WeakValueDictionary()
+ def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
+ base:Optional[LazyBuffer]=None, enable_cache=bool(getenv("LAZYCACHE", 1))):
+ if st.size == 0: op, arg, srcs, base = LoadOps.CONST, 0, (), None
+ if op is LoadOps.CONST: arg, enable_cache = dtypes.as_const(arg, dtype) if not isinstance(arg, Variable) else arg, True

- lazycache: WeakValueDictionary = WeakValueDictionary()
- def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType,
- op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
- base:Optional[LazyBuffer]=None):
- if 0 in st.shape: st, op, arg, srcs = ShapeTracker.from_shape(st.shape), LoadOps.CONST, 0, ()
-
- wop = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs), ref(base) if base else None)
- if wop in lazycache: return lazycache[wop]
+ cache_key = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs)) if base is None else (st, ref(base))
+ if enable_cache and (rret := lazycache.get(cache_key, None)): return rret

  ret = LazyBuffer(device, st, dtype, op, arg, srcs, base=base)
- # TODO: remove LoadOps.CONST here while keeping a pretty graph and working fusions
- # TODO: might be possible to remove LoadOps.COPY
- if op not in {LoadOps.EMPTY, LoadOps.CUSTOM, LoadOps.CONST, LoadOps.COPY} and getenv("LAZYCACHE", 1): lazycache[wop] = ret
+ if enable_cache: lazycache[cache_key] = ret
  return ret

+ view_supported_devices = {"LLVM", "CLANG", "CUDA", "NV", "AMD", "DISK"}
  class LazyBuffer:
  def __init__(self, device:str, st:ShapeTracker, dtype:DType,
  op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
  base:Optional[LazyBuffer]=None):
- assert isinstance(device, str) and device == Device.canonicalize(device)
  self.device, self.st, self.dtype, self.shape, self.size = device, st, dtype, st.shape, st.size
+ self._base: Optional[LazyBuffer] = None
  if base is None:
  # properties on base
  self.op, self.arg, self.srcs = op, arg, srcs # this is a LazyOp, except the src is LazyBuffers and not LazyOps
- self.realized: Optional[Buffer] = None
- self.output_buffer: Optional[Buffer] = None
- self.forced_realize = False
+ assert self.op is not LoadOps.ASSIGN or srcs[1].base.realized is not None, "assign target must be realized"
+
+ if self.op is LoadOps.VIEW:
+ # some LazyBuffers can be processed with only a view, no AST required
+ self.buffer: Buffer = srcs[0].base.buffer.view(st.size, dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize)
+ else:
+ self.buffer = srcs[1].base.buffer if self.op is LoadOps.ASSIGN else Buffer(device, self.size, dtype)
+ self.buffer.ref(1)
  self.contiguous_child: Optional[Tuple[ReferenceType[LazyBuffer], ShapeTracker]] = None
+ self.forced_realize = False
  else:
  # properties on view
  assert base.base == base, "base must be a base itself"
  self._base = base

+ def __del__(self):
+ if hasattr(self, 'buffer'): self.buffer.ref(-1)
+
  def __repr__(self) -> str:
- return f"<LB {self.device} {self.shape} contig:{self.st.contiguous} {self.st if hasattr(self, '_base') else (self.op, self.realized)}>"
+ return f"<LB {self.device} {self.shape} {str(self.dtype)[7:]} {self.st if self.base != self else (self.op, self.realized)}>"

  @property
- def base(self) -> LazyBuffer: return self._base if hasattr(self, '_base') else self
+ def realized(self) -> Optional[Buffer]:
+ # NOTE: we check for a lack of srcs instead of an allocated buffer to make unrealized assigns return None here
+ return self.buffer if self._base is None and not hasattr(self, 'srcs') else None
+
+ # NOTE: this has to be a function to prevent self reference
+ @property
+ def base(self) -> LazyBuffer: return self._base if self._base is not None else self
+
+ # same API as multi
+ @property
+ def lbs(self) -> List[LazyBuffer]: return [self]

  @staticmethod
- def loadop(op, shape:Tuple[sint,...], dtype:DType, device:str, arg=None, src:Optional[LazyBuffer]=None) -> LazyBuffer:
- return create_lazybuffer(device, ShapeTracker.from_shape(shape), dtype, op, arg, (src,) if src is not None else ())
+ def loadop(op, shape:Tuple[sint,...], dtype:DType, device:str, arg=None, src:Tuple[LazyBuffer, ...]=(), enable_cache=False) -> LazyBuffer:
+ assert isinstance(src, tuple)
+ return create_lazybuffer(device, ShapeTracker.from_shape(shape), dtype, op, arg, src, enable_cache=enable_cache)
+
+ def const(self, val:ConstType, shape:Optional[Tuple[sint,...]]=None) -> LazyBuffer:
+ assert isinstance(val, (int,float,bool)), f"{val=} has {type(val)=}, not a ConstType"
+ shape = self.shape if shape is None else shape
+ return LazyBuffer.loadop(LoadOps.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(shape)).expand(shape)
+
+ def is_realized(self) -> bool: return self.base.realized is not None
+
+ def assign(self, x:LazyBuffer) -> LazyBuffer:
+ assert x.size == self.size, f"assign target must have same size {self.size=} != {x.size=}"
+ return LazyBuffer.loadop(LoadOps.ASSIGN, self.shape, self.dtype, self.device, arg=() if self.st.contiguous else (self.st,), src=(x, self.base))

- def const(self, val:Union[float, int]) -> LazyBuffer:
- return LazyBuffer.loadop(LoadOps.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(self.shape)).expand(self.shape)
+ def can_view(self): return self.st.consecutive and not self.is_unrealized_const() and self.device.split(":")[0] in view_supported_devices

- def contiguous(self):
+ def contiguous(self, allow_buffer_view=True):
  if not self.st.contiguous or self.size != self.base.size or self.is_unrealized_const():
- ret = self.e(LoadOps.CONTIGUOUS)
- sti = self.st.invert(self.base.shape)
- if sti is not None: self.base.contiguous_child = ref(ret), sti
+ ret = self.e(LoadOps.VIEW) if allow_buffer_view and self.can_view() else self.e(LoadOps.CONTIGUOUS)
+ if (sti := self.st.invert(self.base.shape)) is not None: self.base.contiguous_child = ref(ret), sti
  return ret
  self.base.forced_realize = True
  return self

  def cast(self, dtype:DType, bitcast:bool=False):
  if self.dtype == dtype: return self
- return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), dtype, UnaryOps.CAST, (dtype, bitcast), (self,))
-
- def is_unrealized_const(self): return not self.base.realized and self.base.op == LoadOps.CONST
- def is_unrealized_contiguous_const(self): return self.base == self and not self.base.realized and self.op == LoadOps.CONST
-
- def schedule(self, seen=None): return create_schedule([self], seen)
-
- @staticmethod
- def fromCPU(x: np.ndarray) -> LazyBuffer:
- ret = LazyBuffer("CPU", ShapeTracker.from_shape(x.shape), dtypes.from_np(x.dtype), op=LoadOps.EMPTY)
- ret.realized = Buffer("CPU", prod(x.shape), dtypes.from_np(x.dtype), x.flatten())
- return ret
-
- def copy_to_device(self, device:str) -> LazyBuffer:
+ if self.device.startswith("DISK") and not bitcast: raise RuntimeError("attempted to cast disk buffer (bitcast only)")
+ if self.is_unrealized_unmasked_const() and not bitcast:
+ return create_lazybuffer(self.device, self.st, dtype, LoadOps.CONST, dtypes.as_const(self.base.arg, dtype))
+ new_shape = self.shape
+ if bitcast and self.dtype.itemsize != dtype.itemsize:
+ if not self.device.startswith("DISK"): raise RuntimeError("shape changing bitcast only supported on DISK right now")
+ if not all_int(new_shape): raise RuntimeError("shape changing bitcast with symbolic shape isn't supported yet")
+ # https://pytorch.org/docs/stable/generated/torch.Tensor.view.html
+ if not (new_shape[-1]*self.dtype.itemsize) % dtype.itemsize == 0: raise RuntimeError("unsupported size in bitcast")
+ new_shape = new_shape[:-1] + ((new_shape[-1]*self.dtype.itemsize) // dtype.itemsize,)
+ elif getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self != self.base:
+ # TODO: applying this makes gpt2 slower
+ return self.base.cast(dtype, bitcast)._view(self.st)
+ cast_op: Union[LoadOps, UnaryOps] = (LoadOps.VIEW if self.can_view() else UnaryOps.BITCAST) if bitcast else UnaryOps.CAST
+ return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, dtype, (self,))
+
+ def is_unrealized_const(self): return self.base.realized is None and self.base.op is LoadOps.CONST and not isinstance(self.base.arg, Variable)
+ def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views)
+
+ def _copy(self, device:str) -> LazyBuffer:
+ return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, self.buffer.nbytes, (self,), enable_cache=False)
+
+ def copy_to_device(self, device:str, force: bool = False) -> LazyBuffer:
  # no COPY
  if self.device == device: return self

  # double COPY = one COPY
- if self.st.contiguous and self.size == self.base.size and not self.base.realized and self.base.op == LoadOps.COPY:
+ if not force and self.st.contiguous and self.size == self.base.size and not self.base.realized and self.base.op is LoadOps.COPY:
  return self.base.srcs[0].copy_to_device(device).reshape(self.st.shape)

  # const doesn't have to be copied (issues with disk tensor)
@@ -98,11 +129,10 @@ class LazyBuffer:
  return LazyBuffer.loadop(LoadOps.CONST, tuple(), self.dtype, device, arg=self.base.arg)._view(self.st)

  # if it's a shrink, do the shrink before the copy with CONTIGUOUS
- if prod(self.st.shape) < prod(self.base.st.shape):
- return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, srcs=(self.contiguous(),))
+ if prod(self.st.shape) < prod(self.base.st.shape): return self.contiguous()._copy(device)

  # copy the base and apply the shapetracker on the new device
- return create_lazybuffer(device, self.base.st, self.dtype, LoadOps.COPY, srcs=(self.base,))._view(self.st)
+ return self.base._copy(device)._view(self.st)

  def e(self, op:Union[LoadOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
  srcs: List[LazyBuffer] = []
@@ -111,36 +141,74 @@ class LazyBuffer:
  srcs.append(root._view(s.base.contiguous_child[1]))
  else:
  srcs.append(s)
- assert all_same(dts:=[x.dtype.scalar() for x in (srcs if op != TernaryOps.WHERE else srcs[1:])]), f"all dtypes must match {dts} on {op}"
+ assert all_same(dts:=[x.dtype.scalar() for x in (srcs[1:] if op is TernaryOps.WHERE else srcs)]), f"all dtypes must match {dts} on {op}"
  assert all_same([x.shape for x in srcs]), f"all shapes must be the same {[x.shape for x in srcs]}"
- if op == TernaryOps.WHERE: assert srcs[0].dtype == dtypes.bool, "TernaryOps.WHERE must have the first arg be bool"
- out_dtype = srcs[-1].dtype if op not in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else dtypes.bool
- ret = create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))
- return ret
+ if op is TernaryOps.WHERE: assert srcs[0].dtype == dtypes.bool, "TernaryOps.WHERE must have the first arg be bool"
+ if op is UnaryOps.NEG: assert srcs[0].dtype != dtypes.bool, "UnaryOps.NEG does not accept dtype bool"
+
+ out_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else srcs[-1].dtype
+
+ # const folding
+ if op in python_alu and all(s.is_unrealized_unmasked_const() for s in srcs):
+ return self.cast(out_dtype).const(exec_alu(op, out_dtype, [s.base.arg for s in srcs]))
+ if op is UnaryOps.NEG and self.base.op is UnaryOps.NEG and self.base.realized is None: return self.base.srcs[0]
+ if op in BinaryOps:
+ x, y = self, in_srcs[0]
+ if op is BinaryOps.ADD:
+ if y.is_unrealized_unmasked_const() and y.base.arg == 0: return x
+ if x.is_unrealized_unmasked_const() and x.base.arg == 0: return y
+ if op is BinaryOps.MUL:
+ if x.is_unrealized_unmasked_const() and (val := x.base.arg) in (1, 0, -1):
+ return y if val == 1 else y.const(0) if val == 0 else y.e(UnaryOps.NEG)
+ if y.is_unrealized_unmasked_const() and (val := y.base.arg) in (1, 0, -1):
+ return x if val == 1 else x.const(0) if val == 0 else x.e(UnaryOps.NEG)
+
+ return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))

  # *** reduce ops ***

- def _reduce_op(self, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
- if self.shape == tuple(new_shape): return self
- unbound_new_shape = tuple(s.unbind()[0] if not isinstance(s, int) else s for s in new_shape)
- return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), self.dtype, op, unbound_new_shape, (self,))
+ def _reduce_op(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
+ assert all(0 <= x < len(self.shape) for x in axis), f"axis args {axis} out of range for shape {self.shape}"
+ axis = tuple(sorted([x for x in axis if self.shape[x] != 1]))
+ if len(axis) == 0: return self
+ new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape))
+ return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), self.dtype, op, axis, (self,))
+
+ def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
+ new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape))
+ # TODO: this logic should move to the scheduler
+ if self.size == 0 and 0 not in new_shape: return self.const({ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[op], new_shape)
+
+ # const folding
+ # TODO: fold this for symbolic?
+ if self.is_unrealized_unmasked_const() and all_int(self.shape):
+ return self.const(self.base.arg * {ReduceOps.SUM: prod(self.shape[i] for i in axis), ReduceOps.MAX: 1}[op], new_shape)

- def r(self, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
- assert len(self.shape) == len(new_shape) and all(s == ns or ns == 1 for s,ns in zip(self.shape, new_shape)), \
- f"reduce shape lens must match {self.shape} {new_shape}"
  # TODO: can we split symbolic shape if the reduce axis is not symbolic?
- if not all_int(self.shape) or (0 in self.shape) or prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
- return self._reduce_op(op, new_shape)
- heuristic, divisor, dim_to_split = max(((divisor := math.gcd(256, old))/(stride or math.inf), divisor, i) for i, (old, new, stride) in enumerate(zip(self.shape, new_shape, self.st.real_strides())) if old != new) # type: ignore # noqa: E501
- if divisor < 16 or heuristic < 0.1: return self._reduce_op(op, new_shape)
- # choose largest divisor (>=16) to split on, penalize large strides
- def splitted_shape(dim_aft_div):
- return self.shape[:dim_to_split] + (self.shape[dim_to_split]//divisor,) + dim_aft_div + self.shape[dim_to_split+1:]
- return self.reshape(splitted_shape((divisor,)))._reduce_op(op, splitted_shape((1,))).reshape(splitted_shape(()))._reduce_op(op, new_shape)
+ if not getenv("SPLIT_REDUCEOP", 1) or not all_int(self.shape) or (0 in self.shape) or \
+ prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
+ return self._reduce_op(op, axis)
+
+ # if there are few globals, make some reduces into globals by splitting into two kernels
+ # cap output buffer to 2**22: heuristic number of global outputs to achieve max occupancy with enough locals+upcasts for gemm
+ # ~2**10 should be enough if GROUP is used
+ # 256 split maximum should be "negligible reduce" for low prod(new_shape), 8 split minimum.
+ # split is moved to the end to provide maximum locality for the second phase reduce.
+ self_real_strides = self.st.real_strides(ignore_valid=True)
+ split_candidates = [(i, x) for i in axis for x in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(new_shape)),8-1,-1)
+ if self.shape[i] % x == 0 and self_real_strides[i] != 0]
+ if not split_candidates: return self._reduce_op(op, axis)
+ dim_to_split, divisor = split_candidates[0]
+ splitted_shape = self.shape[:dim_to_split] + (divisor,) + (self.shape[dim_to_split]//divisor,) + self.shape[dim_to_split+1:]
+ splitted = self.reshape(splitted_shape).permute(tuple([x for x in range(len(splitted_shape)) if x != dim_to_split]+[dim_to_split]))
+ if DEBUG >= 3: print(f"split {divisor}: {self.shape} -> {splitted.shape} -> {new_shape}")
+ return splitted._reduce_op(op, axis)._reduce_op(op, (len(new_shape),)).reshape(new_shape) # reduce original axes, then split

  # *** movement ops ***

  def _view(self, new_st:ShapeTracker) -> LazyBuffer:
+ if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)):
+ return self.const(0, new_st.shape)
  if new_st.contiguous and self.base.shape == new_st.shape: return self.base
  return create_lazybuffer(self.device, new_st, self.dtype, base=self.base)

@@ -150,170 +218,3 @@ class LazyBuffer:
  def permute(self, arg:Tuple[int, ...]): return self._view(self.st.permute(arg))
  def shrink(self, arg:Tuple[Tuple[sint, sint], ...]): return self._view(self.st.shrink(arg))
  def stride(self, arg:Tuple[int, ...]): return self._view(self.st.stride(arg))
-
- # *** schedule creation ***
-
- # recursively create a lazyop
- def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], var_vals:Dict[Variable, int], st:ShapeTracker,
- realizes:Set[LazyBuffer], cache, first=True) -> LazyOp:
- if (buf, st) in cache: return cache[(buf, st)]
- if buf != buf.base:
- st = buf.st + st
- buf = buf.base
- # all buffers here are base now
- assert buf.op is not None
-
- # consts are always fused and generated
- if buf.op == LoadOps.CONST:
- # TODO: make shapetracker unbind also return var_vals
- var_vals.update(merge_dicts([var_vals, st.var_vals]))
- return LazyOp(BufferOps.CONST, (), ConstBuffer(float(buf.arg), buf.dtype, st.simplify().unbind()))
-
- # if we aren't fusing it, it's a load and we add it to the inputs
- if buf.realized or (buf in realizes and not first):
- if buf not in inputs: inputs.append(buf)
- var_vals.update(merge_dicts([var_vals, st.var_vals]))
- return LazyOp(BufferOps.LOAD, (), MemBuffer(inputs.index(buf)+1, buf.dtype, st.simplify().unbind()))
-
- # if a CONTIGUOUS made it all the way here, just skip it
- if buf.op == LoadOps.CONTIGUOUS:
- assert first
- return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, cache, False)
-
- # if it's a reduce, we have to change the shapetracker
- if buf.op in ReduceOps:
- assert st.contiguous, "ReduceOps late fusion must be contiguous"
- st = ShapeTracker.from_shape(buf.srcs[0].shape)
-
- # otherwise we fuse it like normal
- cache[(buf, st)] = ret = LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, var_vals, st, realizes, cache, False) for x in buf.srcs), buf.arg)
- return ret
-
- # recursively walk back in the graph to create the schedule
- def _recursive_schedule(out:LazyBuffer, seen:Set[LazyBuffer], realizes:Set[LazyBuffer],
- reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> List[ScheduleItem]:
- if out in seen or out.realized or out.op == LoadOps.CONST: return []
- assert out.base == out
- seen.add(out)
-
- inputs: List[LazyBuffer] = []
- var_vals: Dict[Variable, int] = out.st.var_vals.copy()
- if out.op == LoadOps.COPY:
- op, inputs = LazyOp(LoadOps.COPY, (), out.srcs[0].base), [out.srcs[0].base]
- elif out.op == LoadOps.CUSTOM:
- op, inputs = LazyOp(LoadOps.CUSTOM, (), out.arg), list(out.srcs)
- elif out.op == LoadOps.EMPTY:
- op = LazyOp(LoadOps.EMPTY)
- else:
- output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
- op = _recursive_lazyop(out, inputs, var_vals, output_st, realizes, cache={})
- op = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_st.simplify().unbind()))
-
- return flatten(_recursive_schedule(x.base, seen, realizes, reduce_for_op) for x in inputs) + [ScheduleItem(op, out, tuple(inputs), var_vals)]
-
- # recursively search the entire graph for all LazyBuffers, insert realizes after expands
- def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffer, None],
- simple_pads:Set[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]]):
- if buf in allbufs or buf.base.realized: return
- log_lazybuffer(buf)
- if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
- not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
- if DEBUG >= 3: print(f"forcing image {buf.dtype} with shape {buf.shape} to float32")
- buf.dtype = dtypes.float32 # NOTE; this is what makes the dtype above not match
- if buf.base != buf:
- # realize all places where the buffer is expanded
- if prod(buf.base.st.shape) < prod(buf.st.shape):
- if len(buf.st.views) == 1 and buf.st.views[-1].mask and all_int(buf.base.st.shape) and \
- prod(buf.base.st.shape) == prod([y-x for x,y in buf.st.views[-1].mask]):
- simple_pads.add(buf.base)
- else:
- realizes.add(buf.base)
- return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children)
- if buf.forced_realize: realizes.add(buf)
- allbufs[buf] = None
- if buf.op in LoadOps: realizes.add(buf.base)
- if buf.op == LoadOps.COPY:
- assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig"
- realizes.add(buf.srcs[0].base)
- for x in buf.srcs:
- children[x.base][buf] = None
- _recurse_lb(x, realizes, allbufs, simple_pads, children)
-
- UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, BinaryOps.CMPEQ, UnaryOps.LOG2, UnaryOps.EXP2}
- def _is_padding_okay(buf:LazyBuffer, realizes:Set[LazyBuffer]) -> bool:
- if buf in realizes or buf.realized: return True
- # NOTE: this broke to_image_idx and coder with JIT
- if buf.op in UNSAFE_PAD_OPS: return False
- return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)
-
- def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
- if seen is None: seen = set()
- for out in outs: log_lazybuffer(out, scheduled=True)
-
- # start by just realizing the buffers passed in
- realizes: Set[LazyBuffer] = set([x.base for x in outs if not x.base.realized])
- allbufs: Dict[LazyBuffer, None] = {}
- simple_pads: Set[LazyBuffer] = set()
- children: DefaultDict[LazyBuffer, Dict[LazyBuffer, None]] = defaultdict(dict)
- for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads, children)
-
- # check if we have to realize pads
- for p in simple_pads:
- if not _is_padding_okay(p, realizes):
- realizes.add(p)
-
- # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
- reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {}
- for r in allbufs.keys():
- if r != r.base or r.op not in ReduceOps or r in realizes: continue
-
- # follow the reduce down
- child_set: Dict[LazyBuffer, ShapeTracker] = {r: r.st}
- realized_children: Dict[LazyBuffer, ShapeTracker] = {}
- forced_realize = False
- can_chase = True
- while not forced_realize and len(child_set):
- next_child_set = {}
- for tr,st in child_set.items():
- if tr in realizes:
- realized_children[tr] = st
- # can only have one output buffer
- # can only reduce contiguous
- # max one reduceop per kernel
- if len(realized_children) > 1 or not st.contiguous or st.size != r.st.size or (tr in reduce_for_op and reduce_for_op[tr] != r):
- can_chase = tr not in reduce_for_op or reduce_for_op[tr] == r
- forced_realize = True
- break
- continue
- for tr_next in children[tr].keys():
- if not tr_next.realized:
- # max one reduceop per kernel
- if tr_next.op in ReduceOps:
- forced_realize = True
- break
- st_childs = dedup([s for s in tr_next.srcs if s.base == tr])
- if len(st_childs) > 1:
- forced_realize = True
- break
- next_child_set[tr_next] = st + st_childs[0].st
- child_set = next_child_set
- if forced_realize:
- tr = r
- if can_chase:
- # can chase this down to contiguous children
- st = tr.st
- while len(children[tr]) == 1:
- tr_next = next(iter(children[tr].keys()))
- st_childs = dedup([s for s in tr_next.srcs if s.base == tr])
- if len(st_childs) > 1: break
- if st.size != st_childs[0].st.size: break
- st = st + st_childs[0].st
- if not st.contiguous or tr_next.op in ReduceOps: break
- tr = tr_next
- reduce_for_op[tr] = r
- realizes.add(tr)
- else:
- assert len(realized_children) == 1
- reduce_for_op[next(iter(realized_children.keys()))] = r
-
- return flatten(_recursive_schedule(x.base, seen, realizes, reduce_for_op) for x in outs)
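Besides moving schedule creation out to tinygrad/engine/schedule.py, the biggest behavioral change above is in LazyBuffer.r: it now takes reduce axes instead of a target shape, and the old gcd-based split heuristic is replaced by the SPLIT_REDUCEOP candidate search. A standalone sketch of that candidate search, using illustrative shapes and strides rather than a real LazyBuffer (thresholds mirror the defaults in the diff):

    from math import prod

    def pick_split(shape, axis, real_strides, new_shape, split_size=22, threshold=32768):
      # mirrors the split_candidates comprehension above: small reduces are left alone; otherwise take
      # the first reduced dim with a divisor in [8, 256] and a nonzero stride, preferring larger divisors
      if prod(shape) // prod(new_shape) < threshold: return None
      candidates = [(i, x) for i in axis for x in range(min(256, 2**split_size // prod(new_shape)), 8-1, -1)
                    if shape[i] % x == 0 and real_strides[i] != 0]
      return candidates[0] if candidates else None  # (dim_to_split, divisor)

    # summing a (8, 1048576) buffer over axis 1: the reduce is split by 256, i.e. a first-phase
    # reduce over the (8, 4096, 256) view followed by a second reduce over the 256 split chunks.
    print(pick_split((8, 1048576), (1,), (1048576, 1), (8, 1)))   # -> (1, 256)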
tinygrad/multi.py ADDED
@@ -0,0 +1,173 @@
+ from __future__ import annotations
+ from typing import Optional, Union, Any, Tuple, List
+ import functools, itertools, operator
+ from tinygrad.helpers import all_same, all_int, dedup, round_up, prod, DEBUG, RING
+ from tinygrad.dtype import DType, ConstType
+ from tinygrad.ops import BinaryOps, LoadOps, UnaryOps, TernaryOps, ReduceOps
+ from tinygrad.lazy import LazyBuffer
+ from tinygrad.shape.shapetracker import sint
+
+ def all_reduce(op: ReduceOps, lbs: List[LazyBuffer]) -> List[LazyBuffer]:
+ assert all_int(lbs[0].shape), f"does not support symbolic shape {lbs[0].shape}"
+ assert all_same([lb.shape[0] for lb in lbs]), "allreduce with uneven shards is undefined"
+ bop = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[op]
+
+ n_lbs, dim = len(lbs), prod(lbs[0].shape)
+ # Ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
+ # so just fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
+ use_ring = (RING >= 2 or (n_lbs > 2 and dim > 256_000 and RING >= 1))
+ if DEBUG >= 2: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{dim} | {lbs[0].dtype}")
+ if not use_ring:
+ return [functools.reduce(lambda x,y: x.e(bop, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]
+ factor = max(f for f in [32, 16, 8, 4, 2, 1] if dim % f == 0)
+ base, left = (dim // factor) // n_lbs, (dim // factor) % n_lbs
+ c_lens = [(base + 1) * factor if i < left else base * factor for i in range(n_lbs)]
+ acc = 0
+ chunks = [(acc, (acc := acc + i)) for i in c_lens if i > 0]
+ chunked = [[lb.reshape((dim,)).shrink(((s,e),)) for s,e in chunks] for lb in lbs]
+
+ # Scatter-reduce step
+ for step in range(n_lbs - 1):
+ for i in range(len(chunks)):
+ s, r = (i+step)%n_lbs, (i+step+1)%n_lbs
+ chunked[r][i] = chunked[r][i].e(bop, chunked[s][i].copy_to_device(chunked[r][i].device, force=True))
+
+ # Allgather step
+ for step in range(n_lbs - 1):
+ for i in range(len(chunks)):
+ s, r = (i+step-1)%n_lbs, (i+step)%n_lbs
+ chunked[r][i] = chunked[s][i].copy_to_device(chunked[r][i].device, force=True)
+
+ # Assemble chunks back
+ pads = [((s,dim-e),) for s,e in chunks]
+ return [functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), [c.pad(pads[i]) for i,c in enumerate(lb_c)]).reshape(lbs[0].shape) for lb_c in chunked]
+
+ def to_sharded(lbs:List[LazyBuffer], axis:int) -> List[LazyBuffer]:
+ if DEBUG >= 3 and lbs[0].shape[axis] % len(lbs) != 0: print(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}")
+ sz = round_up(lbs[0].shape[axis], len(lbs)) // len(lbs)
+ return [lb.shrink(tuple((0,s) if a != axis else (min(s,sz*i),min(s,sz*(i+1))) for a,s in enumerate(lb.shape))) for i,lb in enumerate(lbs)]
+
+ class MultiLazyBuffer:
+ def __init__(self, lbs:List[LazyBuffer], axis:Optional[int], real:Optional[List[bool]]=None):
+ assert all(isinstance(x, LazyBuffer) for x in lbs) and len(lbs), "all lbs must be LazyBuffers, and we need at least one of them"
+ assert all_same([x.dtype for x in lbs]), f"all multilazybuffer needs same dtype, getting {[x.dtype for x in lbs]}"
+ self.lbs, self.axis, self.dtype, self.device, self.real = lbs, axis, lbs[0].dtype, tuple(x.device for x in lbs), real or [True]*len(lbs)
+ if axis is not None:
+ splits = list(itertools.accumulate([lb.shape[axis] for lb in lbs], initial=0))
+ self.bounds = list(zip(splits, splits[1:]))
+
+ @property
+ def shape(self):
+ return tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape))
+
+ @property
+ def size(self): return sum(x.size for x in self.real_lbs)
+
+ @property
+ def real_lbs(self): return [lb for lb,r in zip(self.lbs, self.real) if r]
+
+ def __repr__(self):
+ return f"<MLB {self.axis=} {self.real=} {chr(10)}{chr(10).join([f'{x.device} {x.st}' for x in self.lbs])}>"
+
+ @staticmethod
+ def from_sharded(lb:LazyBuffer, devices:Tuple[str, ...], axis:Optional[int]=None):
+ lbs = [lb.contiguous() if lb.base != lb and not lb.is_unrealized_unmasked_const() else lb] * len(devices)
+ sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(to_sharded(lbs, axis) if axis is not None else lbs, devices)]
+ return MultiLazyBuffer([lb if lb.is_unrealized_unmasked_const() else lb.contiguous(allow_buffer_view=False) for lb in sharded_lbs], axis)
+
+ def copy_to_device(self, device:str) -> LazyBuffer:
+ if self.axis is None:
+ # if we already have a copy on the device, return that
+ for lb in self.real_lbs:
+ if lb.device == device: return lb
+ return self.lbs[self.real.index(True)].copy_to_device(device)
+ llbs:List[LazyBuffer] = []
+ for lb,real,(start,end) in zip(self.lbs, self.real, self.bounds):
+ if not real: continue
+ pad_arg = tuple((0,0) if a != self.axis else (start, self.bounds[-1][1]-end) for a in range(len(lb.shape)))
+ llbs.append(lb.copy_to_device(device).pad(pad_arg))
+ return functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), llbs)
+
+ # passthroughs
+ def is_realized(self) -> bool: return all(lb.base.realized is not None for lb, r in zip(self.lbs, self.real) if r is True)
+ def cast(self, dtype:DType, bitcast:bool=False): return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis, self.real)
+ def const(self, val:ConstType) -> MultiLazyBuffer: return MultiLazyBuffer([x.const(val) for x in self.lbs], self.axis, self.real)
+ def assign(self, x:MultiLazyBuffer): return MultiLazyBuffer([s.assign(d) for s,d in zip(self.lbs, x.lbs)], self.axis, self.real)
+ def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis, self.real)
+
+ # elementwise is simple
+ def e(self, op:Union[LoadOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:MultiLazyBuffer, arg:Optional[Any]=None) -> MultiLazyBuffer:
+ msrcs = (self,)+in_srcs
+ assert all(isinstance(x, MultiLazyBuffer) for x in msrcs), f"all buffers must be MultiLazyBuffer {msrcs}"
+ assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}"
+
+ # NOTE: they all have to share an axis, we always choose [-1]
+ axis = axes[-1] if len(axes := dedup([x.axis for x in msrcs if x.axis is not None])) else None
+ srcs = []
+ not_all_real = any(not all(mlb.real) for mlb in msrcs)
+ new_real = [all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])] if not_all_real else self.real
+ assert any(new_real), "output contains no real lb"
+ for mlb in msrcs:
+ if mlb.axis == axis or not_all_real: srcs.append(mlb.lbs)
+ elif mlb.axis is None and axis is not None: srcs.append(to_sharded(mlb.lbs, axis))
+ else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis))
+ # NOTE: lsrcs[-1].const(0) is correct for where
+ return MultiLazyBuffer([lsrcs[0].e(op, *lsrcs[1:], arg=arg) if r else lsrcs[-1].const(0) for lsrcs,r in zip(zip(*srcs),new_real)], axis, new_real)
+
+ def _shape_to_single_shard(self, shape:Tuple[sint, ...], lb:LazyBuffer) -> Tuple[sint, ...]:
+ return tuple(lb.shape[self.axis] if a == self.axis else s for a,s in enumerate(shape))
+
+ def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> MultiLazyBuffer:
+ if self.axis is not None and self.axis in axis:
+ # all-reduce on sharded axes
+ reduced_parts = [(x if r else x.const(0)).r(op, axis) for x,r in zip(self.lbs, self.real)]
+ if all(self.real): return MultiLazyBuffer(all_reduce(op, reduced_parts), None)
+ return MultiLazyBuffer(reduced_parts, None, self.real)
+ # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
+ return MultiLazyBuffer([x.r(op, axis) for x in self.lbs], self.axis, self.real)
+
+ # *** movement ops ***
+
+ def reshape(self, arg:Tuple[sint, ...]):
+ if self.axis is None: return MultiLazyBuffer([x.reshape(arg) for x in self.lbs], None, self.real)
+ arg_acc:List[sint] = list(itertools.accumulate(arg, operator.mul, initial=1))
+ # new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards
+ # todo: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1?
+ new_axis = len(arg_acc) - arg_acc[::-1].index(prod(self.shape[:self.axis])) - 1
+ if arg[new_axis] != self.shape[self.axis]:
+ assert self.shape[self.axis] % len(self.real_lbs) == 0, f"cannot reshape on-axis for uneven shard {self.axis} {self.shape} {len(self.real_lbs)}"
+ assert arg[new_axis] % len(self.real_lbs) == 0, f"new on-axis shape must divide evenly between devices {new_axis} {arg} {len(self.real_lbs)}"
+ return MultiLazyBuffer([x.reshape(tuple(s if a != new_axis else
+ x.shape[self.axis] if s == self.shape[self.axis] else
+ s // len(self.real_lbs) for a,s in enumerate(arg))) for x in self.lbs],
+ new_axis, self.real)
+
+ def pad(self, arg:Tuple[Tuple[sint, sint], ...]):
+ assert self.axis is None or arg[self.axis] == (0,0) or not all(self.real), f"padding not supported for {arg=}"
+ # pad on shard axis -> fill others with zeros and set real to all True
+ if self.axis is not None and arg[self.axis] != (0,0):
+ # pad back to whole axis, remove real mask
+ assert all(arg[i] == (0, 0) or i == self.axis for i in range(len(self.shape))), "cannot pad sharded and non-sharded axis at the same time"
+ assert arg[self.axis] == (sum(lb.shape[self.axis] for i,lb in enumerate(self.lbs) if i < self.real.index(True)), \
+ sum(lb.shape[self.axis] for i,lb in enumerate(self.lbs) if i > self.real.index(True))), "can only pad to whole axis"
+ return MultiLazyBuffer([x if r else x.const(0) for x,r in zip(self.lbs, self.real)], self.axis)
+ return MultiLazyBuffer([x.pad(arg) for x in self.lbs], self.axis, self.real)
+ def expand(self, arg:Tuple[sint, ...]):
+ # NOTE: this assert isn't needed, sharded axis can have dim 1
+ assert self.axis is None or arg[self.axis] == self.shape[self.axis], f"expand not supported on sharded axis {arg=}"
+ return MultiLazyBuffer([x.expand(self._shape_to_single_shard(arg, x)) for x in self.lbs], self.axis, self.real)
+ def permute(self, arg:Tuple[int, ...]):
+ # all permutes supported!
+ return MultiLazyBuffer([x.permute(arg) for x in self.lbs], arg.index(self.axis) if self.axis is not None else None, self.real)
+ def shrink(self, arg:Tuple[Tuple[sint, sint], ...]):
+ assert self.axis is None or arg[self.axis] == (0, self.shape[self.axis]) or arg[self.axis] in self.bounds, f"shrinking not supported for {arg=}"
+ if self.axis is not None and arg[self.axis] in self.bounds and arg[self.axis] != (0, self.shape[self.axis]):
+ assert all(arg[i] == (0, s) or i == self.axis for i,s in enumerate(self.shape)), "cannot shrink sharded and non-sharded axis at the same time"
+ idx = self.bounds.index(arg[self.axis])
+ # zero out other lbs to not create lb reference
+ return MultiLazyBuffer([lb if i==idx else lb.const(0) for i,lb in enumerate(self.lbs)], self.axis, [i==idx for i in range(len(self.lbs))])
+ return MultiLazyBuffer([x.shrink(tuple((0, x.shape[self.axis]) if a == self.axis else s for a,s in enumerate(arg))) for x in self.lbs],
+ self.axis, self.real)
+ def stride(self, arg:Tuple[int, ...]):
+ assert self.axis is None or arg[self.axis] == 1, "flipping not supported on sharded axis"
+ return MultiLazyBuffer([x.stride(arg) for x in self.lbs], self.axis, self.real)
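For reference, the ring all_reduce above first carves the flattened tensor into one chunk per shard, sized so every chunk is a multiple of a small power-of-two factor. A standalone sketch of just that chunking math with illustrative sizes (4 shards, 1000 elements); it mirrors the factor/base/left computation in the diff and is not tied to real LazyBuffers:

    def ring_chunks(dim: int, n_lbs: int):
      # largest f in [32, 16, 8, 4, 2, 1] dividing dim, so chunk boundaries stay aligned
      factor = max(f for f in [32, 16, 8, 4, 2, 1] if dim % f == 0)
      base, left = (dim // factor) // n_lbs, (dim // factor) % n_lbs
      c_lens = [(base + 1) * factor if i < left else base * factor for i in range(n_lbs)]
      acc, chunks = 0, []
      for ln in c_lens:
        if ln > 0: chunks.append((acc, acc + ln)); acc += ln
      return chunks

    print(ring_chunks(1000, 4))   # -> [(0, 256), (256, 504), (504, 752), (752, 1000)]

Each shard then accumulates one chunk during the scatter-reduce loop and rebroadcasts it during the allgather loop, which is what the two step loops in all_reduce implement.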