tinygrad 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. tinygrad/__init__.py +6 -6
  2. tinygrad/codegen/kernel.py +230 -190
  3. tinygrad/codegen/linearizer.py +278 -384
  4. tinygrad/codegen/uops.py +415 -0
  5. tinygrad/device.py +132 -275
  6. tinygrad/dtype.py +53 -37
  7. tinygrad/engine/__init__.py +0 -0
  8. tinygrad/engine/graph.py +100 -0
  9. tinygrad/engine/jit.py +195 -0
  10. tinygrad/engine/realize.py +191 -0
  11. tinygrad/engine/schedule.py +362 -0
  12. tinygrad/engine/search.py +196 -0
  13. tinygrad/{mlops.py → function.py} +28 -14
  14. tinygrad/helpers.py +72 -43
  15. tinygrad/lazy.py +141 -240
  16. tinygrad/multi.py +169 -0
  17. tinygrad/nn/__init__.py +179 -8
  18. tinygrad/nn/datasets.py +7 -0
  19. tinygrad/nn/optim.py +106 -28
  20. tinygrad/nn/state.py +86 -17
  21. tinygrad/ops.py +70 -44
  22. tinygrad/renderer/__init__.py +61 -0
  23. tinygrad/renderer/assembly.py +276 -0
  24. tinygrad/renderer/cstyle.py +299 -206
  25. tinygrad/renderer/llvmir.py +118 -123
  26. tinygrad/runtime/autogen/amd_gpu.py +1900 -0
  27. tinygrad/runtime/autogen/comgr.py +865 -0
  28. tinygrad/runtime/autogen/cuda.py +5923 -0
  29. tinygrad/runtime/autogen/hip.py +5909 -0
  30. tinygrad/runtime/autogen/hsa.py +5761 -0
  31. tinygrad/runtime/autogen/kfd.py +812 -0
  32. tinygrad/runtime/autogen/nv_gpu.py +33328 -0
  33. tinygrad/runtime/autogen/opencl.py +1795 -0
  34. tinygrad/runtime/driver/hip_comgr.py +47 -0
  35. tinygrad/runtime/driver/hsa.py +143 -0
  36. tinygrad/runtime/graph/clang.py +38 -0
  37. tinygrad/runtime/graph/cuda.py +59 -54
  38. tinygrad/runtime/graph/hcq.py +143 -0
  39. tinygrad/runtime/graph/hsa.py +171 -0
  40. tinygrad/runtime/graph/metal.py +37 -41
  41. tinygrad/runtime/ops_amd.py +564 -0
  42. tinygrad/runtime/ops_clang.py +16 -14
  43. tinygrad/runtime/ops_cuda.py +130 -38
  44. tinygrad/runtime/ops_disk.py +45 -42
  45. tinygrad/runtime/ops_gpu.py +52 -50
  46. tinygrad/runtime/ops_hsa.py +278 -0
  47. tinygrad/runtime/ops_llvm.py +36 -56
  48. tinygrad/runtime/ops_metal.py +42 -24
  49. tinygrad/runtime/ops_npy.py +9 -0
  50. tinygrad/runtime/ops_nv.py +630 -0
  51. tinygrad/runtime/ops_python.py +204 -0
  52. tinygrad/shape/shapetracker.py +41 -105
  53. tinygrad/shape/symbolic.py +98 -95
  54. tinygrad/shape/view.py +137 -35
  55. tinygrad/tensor.py +2367 -442
  56. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
  57. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/METADATA +19 -9
  58. tinygrad-0.9.0.dist-info/RECORD +60 -0
  59. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
  60. tinygrad/features/image.py +0 -93
  61. tinygrad/features/multi.py +0 -103
  62. tinygrad/features/search.py +0 -160
  63. tinygrad/graph.py +0 -106
  64. tinygrad/jit.py +0 -152
  65. tinygrad/realize.py +0 -50
  66. tinygrad/runtime/graph/hip.py +0 -24
  67. tinygrad/runtime/ops_cpu.py +0 -45
  68. tinygrad/runtime/ops_hip.py +0 -97
  69. tinygrad/runtime/ops_torch.py +0 -49
  70. tinygrad-0.8.0.dist-info/RECORD +0 -41
  71. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
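The listing above shows the 0.9.0 reorganization: the scheduler, JIT, graph, realize, and search code now live under a new tinygrad/engine/ package, features/multi.py is promoted to tinygrad/multi.py, mlops.py is renamed to function.py, and several runtimes are swapped (ops_hip, ops_cpu, ops_torch removed; ops_amd, ops_nv, ops_hsa, ops_python added). For code that imported tinygrad internals directly, the sketch below shows how the paths would move; the class and function names are assumed unchanged from 0.8.0 and are not confirmed by this diff, while top-level imports like `from tinygrad import Tensor` are unaffected by the file moves.

# hypothetical import updates for code that reached into tinygrad internals (0.8.0 -> 0.9.0)
# from tinygrad.jit import TinyJit                     # 0.8.0: tinygrad/jit.py (removed)
# from tinygrad.features.multi import MultiLazyBuffer  # 0.8.0: tinygrad/features/multi.py (removed)
# from tinygrad.lazy import create_schedule            # 0.8.0: scheduler lived in lazy.py (removed below)
from tinygrad.engine.jit import TinyJit                # 0.9.0: tinygrad/engine/jit.py (added)
from tinygrad.multi import MultiLazyBuffer             # 0.9.0: tinygrad/multi.py (added)
from tinygrad.engine.schedule import create_schedule   # 0.9.0: tinygrad/engine/schedule.py (added)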
tinygrad/lazy.py CHANGED
@@ -1,96 +1,126 @@
  from __future__ import annotations
- import sys, math
- import numpy as np
- from collections import defaultdict
- from typing import Union, Optional, Any, Tuple, List, Set, Dict, DefaultDict
- from tinygrad.dtype import dtypes, DType, ImageDType
- from tinygrad.helpers import prod, merge_dicts, flatten, getenv, dedup, DEBUG, all_int, all_same
- from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, BufferOps, Op, LazyOp, ConstBuffer, MemBuffer, ScheduleItem
+ import math
+ from typing import Union, Optional, Any, Tuple, List
+ from tinygrad.dtype import dtypes, DType, ConstType
+ from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG
+ from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu
  from tinygrad.shape.symbolic import sint, Variable
  from tinygrad.shape.shapetracker import ShapeTracker
- from tinygrad.device import Buffer, Device
- from tinygrad.graph import log_lazybuffer
- from weakref import ref, WeakValueDictionary, ReferenceType
+ from tinygrad.device import Buffer
+ from weakref import ref, ReferenceType, WeakValueDictionary

- # lazy can recurse a lot
- sys.setrecursionlimit(10000)
+ lazycache: WeakValueDictionary[Any, LazyBuffer] = WeakValueDictionary()
+ def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
+ base:Optional[LazyBuffer]=None, enable_cache=bool(getenv("LAZYCACHE", 1))):
+ if st.size == 0: op, arg, srcs, base = LoadOps.CONST, 0, (), None
+ if op is LoadOps.CONST: arg, enable_cache = dtypes.as_const(arg, dtype) if not isinstance(arg, Variable) else arg, True

- lazycache: WeakValueDictionary = WeakValueDictionary()
- def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType,
- op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
- base:Optional[LazyBuffer]=None):
- if 0 in st.shape: st, op, arg, srcs = ShapeTracker.from_shape(st.shape), LoadOps.CONST, 0, ()
-
- wop = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs), ref(base) if base else None)
- if wop in lazycache: return lazycache[wop]
+ cache_key = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs)) if base is None else (st, ref(base))
+ if enable_cache and (rret := lazycache.get(cache_key, None)): return rret

  ret = LazyBuffer(device, st, dtype, op, arg, srcs, base=base)
- # TODO: remove LoadOps.CONST here while keeping a pretty graph and working fusions
- # TODO: might be possible to remove LoadOps.COPY
- if op not in {LoadOps.EMPTY, LoadOps.CUSTOM, LoadOps.CONST, LoadOps.COPY} and getenv("LAZYCACHE", 1): lazycache[wop] = ret
+ if enable_cache: lazycache[cache_key] = ret
  return ret

+ view_supported_devices = {"LLVM", "CLANG", "CUDA", "DISK"}
  class LazyBuffer:
  def __init__(self, device:str, st:ShapeTracker, dtype:DType,
  op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
  base:Optional[LazyBuffer]=None):
- assert isinstance(device, str) and device == Device.canonicalize(device)
  self.device, self.st, self.dtype, self.shape, self.size = device, st, dtype, st.shape, st.size
+ self._base: Optional[LazyBuffer] = None
  if base is None:
  # properties on base
  self.op, self.arg, self.srcs = op, arg, srcs # this is a LazyOp, except the src is LazyBuffers and not LazyOps
- self.realized: Optional[Buffer] = None
- self.output_buffer: Optional[Buffer] = None
- self.forced_realize = False
+ assert self.op is not LoadOps.ASSIGN or srcs[1].base.realized is not None, "assign target must be realized"
+
+ if (self.op is LoadOps.CONTIGUOUS or self.op is UnaryOps.BITCAST) and srcs[0].st.consecutive and \
+ not srcs[0].is_unrealized_const() and device.split(":")[0] in view_supported_devices:
+ # some LazyBuffers can be processed with only a view, no AST required
+ self.buffer: Buffer = srcs[0].base.buffer.view(st.size, dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize)
+ self.op = LoadOps.VIEW
+ else:
+ self.buffer = srcs[1].base.buffer if self.op is LoadOps.ASSIGN else Buffer(device, self.size, dtype)
+ self.buffer.ref(1)
  self.contiguous_child: Optional[Tuple[ReferenceType[LazyBuffer], ShapeTracker]] = None
+ self.forced_realize = False
  else:
  # properties on view
  assert base.base == base, "base must be a base itself"
  self._base = base

+ def __del__(self):
+ if hasattr(self, 'buffer'): self.buffer.ref(-1)
+
  def __repr__(self) -> str:
- return f"<LB {self.device} {self.shape} contig:{self.st.contiguous} {self.st if hasattr(self, '_base') else (self.op, self.realized)}>"
+ return f"<LB {self.device} {self.shape} {str(self.dtype)[7:]} {self.st if self.base != self else (self.op, self.realized)}>"

  @property
- def base(self) -> LazyBuffer: return self._base if hasattr(self, '_base') else self
+ def realized(self) -> Optional[Buffer]:
+ # NOTE: we check for a lack of srcs instead of an allocated buffer to make unrealized assigns return None here
+ return self.buffer if self._base is None and not hasattr(self, 'srcs') else None
+
+ # NOTE: this has to be a function to prevent self reference
+ @property
+ def base(self) -> LazyBuffer: return self._base if self._base is not None else self
+
+ # same API as multi
+ @property
+ def lbs(self) -> List[LazyBuffer]: return [self]

  @staticmethod
- def loadop(op, shape:Tuple[sint,...], dtype:DType, device:str, arg=None, src:Optional[LazyBuffer]=None) -> LazyBuffer:
- return create_lazybuffer(device, ShapeTracker.from_shape(shape), dtype, op, arg, (src,) if src is not None else ())
+ def loadop(op, shape:Tuple[sint,...], dtype:DType, device:str, arg=None, src:Tuple[LazyBuffer, ...]=(), enable_cache=False) -> LazyBuffer:
+ assert isinstance(src, tuple)
+ return create_lazybuffer(device, ShapeTracker.from_shape(shape), dtype, op, arg, src, enable_cache=enable_cache)

- def const(self, val:Union[float, int]) -> LazyBuffer:
- return LazyBuffer.loadop(LoadOps.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(self.shape)).expand(self.shape)
+ def const(self, val:ConstType, shape:Optional[Tuple[sint,...]]=None) -> LazyBuffer:
+ shape = self.shape if shape is None else shape
+ return LazyBuffer.loadop(LoadOps.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(shape)).expand(shape)
+
+ def is_realized(self) -> bool: return self.base.realized is not None
+
+ def assign(self, x:LazyBuffer) -> LazyBuffer:
+ assert x.size == self.size, f"assign target must have same size {self.size=} != {x.size=}"
+ return LazyBuffer.loadop(LoadOps.ASSIGN, self.shape, self.dtype, self.device, arg=() if self.st.contiguous else (self.st,), src=(x, self.base))

  def contiguous(self):
  if not self.st.contiguous or self.size != self.base.size or self.is_unrealized_const():
  ret = self.e(LoadOps.CONTIGUOUS)
- sti = self.st.invert(self.base.shape)
- if sti is not None: self.base.contiguous_child = ref(ret), sti
+ if (sti := self.st.invert(self.base.shape)) is not None: self.base.contiguous_child = ref(ret), sti
  return ret
  self.base.forced_realize = True
  return self

  def cast(self, dtype:DType, bitcast:bool=False):
  if self.dtype == dtype: return self
- return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), dtype, UnaryOps.CAST, (dtype, bitcast), (self,))
-
- def is_unrealized_const(self): return not self.base.realized and self.base.op == LoadOps.CONST
- def is_unrealized_contiguous_const(self): return self.base == self and not self.base.realized and self.op == LoadOps.CONST
-
- def schedule(self, seen=None): return create_schedule([self], seen)
-
- @staticmethod
- def fromCPU(x: np.ndarray) -> LazyBuffer:
- ret = LazyBuffer("CPU", ShapeTracker.from_shape(x.shape), dtypes.from_np(x.dtype), op=LoadOps.EMPTY)
- ret.realized = Buffer("CPU", prod(x.shape), dtypes.from_np(x.dtype), x.flatten())
- return ret
-
- def copy_to_device(self, device:str) -> LazyBuffer:
+ if self.device.startswith("DISK") and not bitcast: raise RuntimeError("attempted to cast disk buffer (bitcast only)")
+ if self.is_unrealized_unmasked_const() and not bitcast:
+ return create_lazybuffer(self.device, self.st, dtype, LoadOps.CONST, dtypes.as_const(self.base.arg, dtype))
+ # TODO: applying this makes gpt2 slower
+ if getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self != self.base:
+ return self.base.cast(dtype, bitcast)._view(self.st)
+ new_shape = self.shape
+ if bitcast and self.dtype.itemsize != dtype.itemsize:
+ if not self.device.startswith("DISK"): raise RuntimeError("shape changing bitcast only supported on DISK right now")
+ if not all_int(new_shape): raise RuntimeError("shape changing bitcast with symbolic shape isn't supported yet")
+ # https://pytorch.org/docs/stable/generated/torch.Tensor.view.html
+ if not (new_shape[-1]*self.dtype.itemsize) % dtype.itemsize == 0: raise RuntimeError("unsupported size in bitcast")
+ new_shape = new_shape[:-1] + ((new_shape[-1]*self.dtype.itemsize) // dtype.itemsize,)
+ cast_op = UnaryOps.BITCAST if bitcast else UnaryOps.CAST
+ return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, dtype, (self,))
+
+ def is_unrealized_const(self): return self.base.realized is None and self.base.op is LoadOps.CONST and not isinstance(self.base.arg, Variable)
+ def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views)
+
+ def _copy(self, device:str) -> LazyBuffer:
+ return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, self.buffer.nbytes, (self,), enable_cache=False)
+
+ def copy_to_device(self, device:str, force: bool = False) -> LazyBuffer:
  # no COPY
  if self.device == device: return self

  # double COPY = one COPY
- if self.st.contiguous and self.size == self.base.size and not self.base.realized and self.base.op == LoadOps.COPY:
+ if not force and self.st.contiguous and self.size == self.base.size and not self.base.realized and self.base.op is LoadOps.COPY:
  return self.base.srcs[0].copy_to_device(device).reshape(self.st.shape)

  # const doesn't have to be copied (issues with disk tensor)
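In the hunk above, create_lazybuffer now dedupes through a module-level WeakValueDictionary keyed on the construction arguments (or on (st, ref(base)) for views), and LazyBuffer owns a reference-counted Buffer from construction instead of the old realized/output_buffer pair. Below is a minimal standalone sketch of that weak-value caching pattern, with a made-up Node class standing in for LazyBuffer; it illustrates the idea only and is not the tinygrad API.

import weakref
from typing import Any, Tuple

class Node:
    # stand-in for LazyBuffer: weakref-able and constructed from hashable args
    def __init__(self, key:Tuple[Any, ...]): self.key = key

_cache: "weakref.WeakValueDictionary[Tuple[Any, ...], Node]" = weakref.WeakValueDictionary()

def create_node(*args:Any, enable_cache:bool=True) -> Node:
    # return the live cached instance for identical construction args, like create_lazybuffer's lazycache
    if enable_cache and (hit := _cache.get(args, None)) is not None: return hit
    ret = Node(args)
    if enable_cache: _cache[args] = ret
    return ret

a = create_node("CLANG", (4, 4), "float32")
b = create_node("CLANG", (4, 4), "float32")
assert a is b   # identical expressions are deduped while a strong reference exists

Because the dictionary holds weak references, dedup only applies while something else still references the buffer; once the last strong reference is gone the entry disappears and an identical construction allocates a fresh object.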
@@ -98,11 +128,10 @@ class LazyBuffer:
  return LazyBuffer.loadop(LoadOps.CONST, tuple(), self.dtype, device, arg=self.base.arg)._view(self.st)

  # if it's a shrink, do the shrink before the copy with CONTIGUOUS
- if prod(self.st.shape) < prod(self.base.st.shape):
- return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, srcs=(self.contiguous(),))
+ if prod(self.st.shape) < prod(self.base.st.shape): return self.contiguous()._copy(device)

  # copy the base and apply the shapetracker on the new device
- return create_lazybuffer(device, self.base.st, self.dtype, LoadOps.COPY, srcs=(self.base,))._view(self.st)
+ return self.base._copy(device)._view(self.st)

  def e(self, op:Union[LoadOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
  srcs: List[LazyBuffer] = []
@@ -111,36 +140,75 @@ class LazyBuffer:
  srcs.append(root._view(s.base.contiguous_child[1]))
  else:
  srcs.append(s)
- assert all_same(dts:=[x.dtype.scalar() for x in (srcs if op != TernaryOps.WHERE else srcs[1:])]), f"all dtypes must match {dts} on {op}"
+ assert all_same(dts:=[x.dtype.scalar() for x in (srcs[1:] if op is TernaryOps.WHERE else srcs)]), f"all dtypes must match {dts} on {op}"
  assert all_same([x.shape for x in srcs]), f"all shapes must be the same {[x.shape for x in srcs]}"
- if op == TernaryOps.WHERE: assert srcs[0].dtype == dtypes.bool, "TernaryOps.WHERE must have the first arg be bool"
- out_dtype = srcs[-1].dtype if op not in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else dtypes.bool
- ret = create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))
- return ret
+ if op is TernaryOps.WHERE: assert srcs[0].dtype == dtypes.bool, "TernaryOps.WHERE must have the first arg be bool"
+ if op is UnaryOps.NEG: assert srcs[0].dtype != dtypes.bool, "UnaryOps.NEG does not accept dtype bool"
+
+ out_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else srcs[-1].dtype
+
+ # const folding
+ if op in python_alu and all(s.is_unrealized_unmasked_const() for s in srcs):
+ return self.cast(out_dtype).const(exec_alu(op, out_dtype, [s.base.arg for s in srcs]))
+ if op is UnaryOps.NEG and self.base.op is UnaryOps.NEG: return self.base.srcs[0]
+ if op in BinaryOps: x, y = self, in_srcs[0]
+ if op is BinaryOps.ADD:
+ if y.is_unrealized_unmasked_const() and y.base.arg == 0: return x # pylint: disable=possibly-used-before-assignment
+ if x.is_unrealized_unmasked_const() and x.base.arg == 0: return y # pylint: disable=possibly-used-before-assignment
+ if op is BinaryOps.SUB and y.is_unrealized_unmasked_const() and y.base.arg == 0: return x
+ if op is BinaryOps.MUL:
+ if x.is_unrealized_unmasked_const() and (val := x.base.arg) in (1, 0, -1):
+ return y if val == 1 else y.const(0) if val == 0 else y.e(UnaryOps.NEG)
+ if y.is_unrealized_unmasked_const() and (val := float(y.base.arg)) in (1, 0, -1):
+ return x if val == 1 else x.const(0) if val == 0 else x.e(UnaryOps.NEG)
+ if op is BinaryOps.DIV and dtypes.is_float(x.dtype) and y.is_unrealized_unmasked_const() and y.base.arg != 0:
+ return x.e(BinaryOps.MUL, x.const(1 / y.base.arg))
+
+ return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))

  # *** reduce ops ***

- def _reduce_op(self, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
- if self.shape == tuple(new_shape): return self
- unbound_new_shape = tuple(s.unbind()[0] if not isinstance(s, int) else s for s in new_shape)
- return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), self.dtype, op, unbound_new_shape, (self,))
+ def _reduce_op(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
+ assert all(0 <= x < len(self.shape) for x in axis), f"axis args {axis} out of range for shape {self.shape}"
+ axis = tuple(x for x in axis if self.shape[x] != 1)
+ if len(axis) == 0: return self
+ new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape))
+ return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), self.dtype, op, axis, (self,))
+
+ def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
+ new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape))
+ # TODO: this logic should move to the scheduler
+ if self.size == 0 and 0 not in new_shape: return self.const({ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[op], new_shape)
+
+ # const folding
+ if self.is_unrealized_unmasked_const():
+ return self.const(self.base.arg * {ReduceOps.SUM: prod(self.shape[i] for i in axis), ReduceOps.MAX: 1}[op], new_shape)

- def r(self, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
- assert len(self.shape) == len(new_shape) and all(s == ns or ns == 1 for s,ns in zip(self.shape, new_shape)), \
- f"reduce shape lens must match {self.shape} {new_shape}"
  # TODO: can we split symbolic shape if the reduce axis is not symbolic?
- if not all_int(self.shape) or (0 in self.shape) or prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
- return self._reduce_op(op, new_shape)
- heuristic, divisor, dim_to_split = max(((divisor := math.gcd(256, old))/(stride or math.inf), divisor, i) for i, (old, new, stride) in enumerate(zip(self.shape, new_shape, self.st.real_strides())) if old != new) # type: ignore # noqa: E501
- if divisor < 16 or heuristic < 0.1: return self._reduce_op(op, new_shape)
- # choose largest divisor (>=16) to split on, penalize large strides
- def splitted_shape(dim_aft_div):
- return self.shape[:dim_to_split] + (self.shape[dim_to_split]//divisor,) + dim_aft_div + self.shape[dim_to_split+1:]
- return self.reshape(splitted_shape((divisor,)))._reduce_op(op, splitted_shape((1,))).reshape(splitted_shape(()))._reduce_op(op, new_shape)
+ if not getenv("SPLIT_REDUCEOP", 1) or not all_int(self.shape) or (0 in self.shape) or \
+ prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
+ return self._reduce_op(op, axis)
+
+ # if there are few globals, make some reduces into globals by splitting into two kernels
+ # cap output buffer to 2**22: heuristic number of global outputs to achieve max occupancy with enough locals+upcasts for gemm
+ # ~2**10 should be enough if GROUP is used
+ # 256 split maximum should be "negligible reduce" for low prod(new_shape), 8 split minimum.
+ # split is moved to the end to provide maximum locality for the second phase reduce.
+ self_real_strides = self.st.real_strides(ignore_valid=True)
+ split_candidates = [(i, x) for i in axis for x in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(new_shape)),8-1,-1)
+ if self.shape[i] % x == 0 and self_real_strides[i] != 0]
+ if not split_candidates: return self._reduce_op(op, axis)
+ dim_to_split, divisor = split_candidates[0]
+ splitted_shape = self.shape[:dim_to_split] + (divisor,) + (self.shape[dim_to_split]//divisor,) + self.shape[dim_to_split+1:]
+ splitted = self.reshape(splitted_shape).permute(tuple([x for x in range(len(splitted_shape)) if x != dim_to_split]+[dim_to_split]))
+ if DEBUG >= 3: print(f"split {divisor}: {self.shape} -> {splitted.shape} -> {new_shape}")
+ return splitted._reduce_op(op, axis)._reduce_op(op, (len(new_shape),)).reshape(new_shape) # reduce original axes, then split

  # *** movement ops ***

  def _view(self, new_st:ShapeTracker) -> LazyBuffer:
+ if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)):
+ return self.const(0, new_st.shape)
  if new_st.contiguous and self.base.shape == new_st.shape: return self.base
  return create_lazybuffer(self.device, new_st, self.dtype, base=self.base)

@@ -150,170 +218,3 @@ class LazyBuffer:
  def permute(self, arg:Tuple[int, ...]): return self._view(self.st.permute(arg))
  def shrink(self, arg:Tuple[Tuple[sint, sint], ...]): return self._view(self.st.shrink(arg))
  def stride(self, arg:Tuple[int, ...]): return self._view(self.st.stride(arg))
-
- # *** schedule creation ***
-
- # recursively create a lazyop
- def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], var_vals:Dict[Variable, int], st:ShapeTracker,
- realizes:Set[LazyBuffer], cache, first=True) -> LazyOp:
- if (buf, st) in cache: return cache[(buf, st)]
- if buf != buf.base:
- st = buf.st + st
- buf = buf.base
- # all buffers here are base now
- assert buf.op is not None
-
- # consts are always fused and generated
- if buf.op == LoadOps.CONST:
- # TODO: make shapetracker unbind also return var_vals
- var_vals.update(merge_dicts([var_vals, st.var_vals]))
- return LazyOp(BufferOps.CONST, (), ConstBuffer(float(buf.arg), buf.dtype, st.simplify().unbind()))
-
- # if we aren't fusing it, it's a load and we add it to the inputs
- if buf.realized or (buf in realizes and not first):
- if buf not in inputs: inputs.append(buf)
- var_vals.update(merge_dicts([var_vals, st.var_vals]))
- return LazyOp(BufferOps.LOAD, (), MemBuffer(inputs.index(buf)+1, buf.dtype, st.simplify().unbind()))
-
- # if a CONTIGUOUS made it all the way here, just skip it
- if buf.op == LoadOps.CONTIGUOUS:
- assert first
- return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, cache, False)
-
- # if it's a reduce, we have to change the shapetracker
- if buf.op in ReduceOps:
- assert st.contiguous, "ReduceOps late fusion must be contiguous"
- st = ShapeTracker.from_shape(buf.srcs[0].shape)
-
- # otherwise we fuse it like normal
- cache[(buf, st)] = ret = LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, var_vals, st, realizes, cache, False) for x in buf.srcs), buf.arg)
- return ret
-
- # recursively walk back in the graph to create the schedule
- def _recursive_schedule(out:LazyBuffer, seen:Set[LazyBuffer], realizes:Set[LazyBuffer],
- reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> List[ScheduleItem]:
- if out in seen or out.realized or out.op == LoadOps.CONST: return []
- assert out.base == out
- seen.add(out)
-
- inputs: List[LazyBuffer] = []
- var_vals: Dict[Variable, int] = out.st.var_vals.copy()
- if out.op == LoadOps.COPY:
- op, inputs = LazyOp(LoadOps.COPY, (), out.srcs[0].base), [out.srcs[0].base]
- elif out.op == LoadOps.CUSTOM:
- op, inputs = LazyOp(LoadOps.CUSTOM, (), out.arg), list(out.srcs)
- elif out.op == LoadOps.EMPTY:
- op = LazyOp(LoadOps.EMPTY)
- else:
- output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
- op = _recursive_lazyop(out, inputs, var_vals, output_st, realizes, cache={})
- op = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_st.simplify().unbind()))
-
- return flatten(_recursive_schedule(x.base, seen, realizes, reduce_for_op) for x in inputs) + [ScheduleItem(op, out, tuple(inputs), var_vals)]
-
- # recursively search the entire graph for all LazyBuffers, insert realizes after expands
- def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffer, None],
- simple_pads:Set[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]]):
- if buf in allbufs or buf.base.realized: return
- log_lazybuffer(buf)
- if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
- not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
- if DEBUG >= 3: print(f"forcing image {buf.dtype} with shape {buf.shape} to float32")
- buf.dtype = dtypes.float32 # NOTE; this is what makes the dtype above not match
- if buf.base != buf:
- # realize all places where the buffer is expanded
- if prod(buf.base.st.shape) < prod(buf.st.shape):
- if len(buf.st.views) == 1 and buf.st.views[-1].mask and all_int(buf.base.st.shape) and \
- prod(buf.base.st.shape) == prod([y-x for x,y in buf.st.views[-1].mask]):
- simple_pads.add(buf.base)
- else:
- realizes.add(buf.base)
- return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children)
- if buf.forced_realize: realizes.add(buf)
- allbufs[buf] = None
- if buf.op in LoadOps: realizes.add(buf.base)
- if buf.op == LoadOps.COPY:
- assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig"
- realizes.add(buf.srcs[0].base)
- for x in buf.srcs:
- children[x.base][buf] = None
- _recurse_lb(x, realizes, allbufs, simple_pads, children)
-
- UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, BinaryOps.CMPEQ, UnaryOps.LOG2, UnaryOps.EXP2}
- def _is_padding_okay(buf:LazyBuffer, realizes:Set[LazyBuffer]) -> bool:
- if buf in realizes or buf.realized: return True
- # NOTE: this broke to_image_idx and coder with JIT
- if buf.op in UNSAFE_PAD_OPS: return False
- return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)
-
- def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
- if seen is None: seen = set()
- for out in outs: log_lazybuffer(out, scheduled=True)
-
- # start by just realizing the buffers passed in
- realizes: Set[LazyBuffer] = set([x.base for x in outs if not x.base.realized])
- allbufs: Dict[LazyBuffer, None] = {}
- simple_pads: Set[LazyBuffer] = set()
- children: DefaultDict[LazyBuffer, Dict[LazyBuffer, None]] = defaultdict(dict)
- for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads, children)
-
- # check if we have to realize pads
- for p in simple_pads:
- if not _is_padding_okay(p, realizes):
- realizes.add(p)
-
- # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
- reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {}
- for r in allbufs.keys():
- if r != r.base or r.op not in ReduceOps or r in realizes: continue
-
- # follow the reduce down
- child_set: Dict[LazyBuffer, ShapeTracker] = {r: r.st}
- realized_children: Dict[LazyBuffer, ShapeTracker] = {}
- forced_realize = False
- can_chase = True
- while not forced_realize and len(child_set):
- next_child_set = {}
- for tr,st in child_set.items():
- if tr in realizes:
- realized_children[tr] = st
- # can only have one output buffer
- # can only reduce contiguous
- # max one reduceop per kernel
- if len(realized_children) > 1 or not st.contiguous or st.size != r.st.size or (tr in reduce_for_op and reduce_for_op[tr] != r):
- can_chase = tr not in reduce_for_op or reduce_for_op[tr] == r
- forced_realize = True
- break
- continue
- for tr_next in children[tr].keys():
- if not tr_next.realized:
- # max one reduceop per kernel
- if tr_next.op in ReduceOps:
- forced_realize = True
- break
- st_childs = dedup([s for s in tr_next.srcs if s.base == tr])
- if len(st_childs) > 1:
- forced_realize = True
- break
- next_child_set[tr_next] = st + st_childs[0].st
- child_set = next_child_set
- if forced_realize:
- tr = r
- if can_chase:
- # can chase this down to contiguous children
- st = tr.st
- while len(children[tr]) == 1:
- tr_next = next(iter(children[tr].keys()))
- st_childs = dedup([s for s in tr_next.srcs if s.base == tr])
- if len(st_childs) > 1: break
- if st.size != st_childs[0].st.size: break
- st = st + st_childs[0].st
- if not st.contiguous or tr_next.op in ReduceOps: break
- tr = tr_next
- reduce_for_op[tr] = r
- realizes.add(tr)
- else:
- assert len(realized_children) == 1
- reduce_for_op[next(iter(realized_children.keys()))] = r
-
- return flatten(_recursive_schedule(x.base, seen, realizes, reduce_for_op) for x in outs)
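Two behavioral changes in the lazy.py diff above are worth calling out: reduces are now expressed by axis rather than by target shape (r(op, axis)), and with SPLIT_REDUCEOP=1 a reduction whose input/output ratio exceeds REDUCEOP_SPLIT_THRESHOLD is factored into two kernels, a wide partial reduce followed by a small final reduce. The numpy snippet below is only a rough illustration of why the two-stage form is numerically equivalent, not the kernel the scheduler actually generates; the 256 mirrors the maximum split in the heuristic.

import numpy as np

x = np.random.rand(8, 1_048_576).astype(np.float32)   # reduce over axis 1, size 2**20
divisor = 256                                          # a factor of the reduced dimension

direct = x.sum(axis=1)                                 # one-shot reduce

# two-stage reduce: factor axis 1 into (divisor, rest), partially reduce, then finish
split = x.reshape(8, divisor, x.shape[1] // divisor)
partial = split.sum(axis=2)        # first kernel: many smaller reduces, more parallelism
two_stage = partial.sum(axis=1)    # second kernel: tiny final reduce

assert np.allclose(direct, two_stage, rtol=1e-4)

The deleted second half of the file (the _recursive_lazyop / _recurse_lb / create_schedule machinery) is not gone from the package; per the file list it moved into the new tinygrad/engine/schedule.py.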
tinygrad/multi.py ADDED
@@ -0,0 +1,169 @@
+ from __future__ import annotations
+ from typing import Optional, Union, Any, Tuple, List
+ import functools, itertools, operator
+ from tinygrad.helpers import all_same, all_int, dedup, round_up, prod, DEBUG, RING
+ from tinygrad.dtype import DType, ConstType
+ from tinygrad.ops import BinaryOps, LoadOps, UnaryOps, TernaryOps, ReduceOps
+ from tinygrad.lazy import LazyBuffer
+ from tinygrad.shape.shapetracker import sint
+
+ def all_reduce(op: ReduceOps, lbs: List[LazyBuffer]) -> List[LazyBuffer]:
+ assert all_int(lbs[0].shape), f"does not support symbolic shape {lbs[0].shape}"
+ assert all_same([lb.shape[0] for lb in lbs]), "allreduce with uneven shards is undefined"
+ bop = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[op]
+
+ n_lbs, dim = len(lbs), prod(lbs[0].shape)
+ # Ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
+ # so just fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
+ use_ring = (RING >= 2 or (n_lbs > 2 and dim > 256_000 and RING >= 1))
+ if DEBUG >= 2: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{dim} | {lbs[0].dtype}")
+ if not use_ring:
+ return [functools.reduce(lambda x,y: x.e(bop, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]
+ factor = max(f for f in [32, 16, 8, 4, 2, 1] if dim % f == 0)
+ base, left = (dim // factor) // n_lbs, (dim // factor) % n_lbs
+ c_lens = [(base + 1) * factor if i < left else base * factor for i in range(n_lbs)]
+ acc = 0
+ chunks = [(acc, (acc := acc + i)) for i in c_lens if i > 0]
+ chunked = [[lb.reshape((dim,)).shrink(((s,e),)) for s,e in chunks] for lb in lbs]
+
+ # Scatter-reduce step
+ for step in range(n_lbs - 1):
+ for i in range(len(chunks)):
+ s, r = (i+step)%n_lbs, (i+step+1)%n_lbs
+ chunked[r][i] = chunked[r][i].e(bop, chunked[s][i].copy_to_device(chunked[r][i].device, force=True))
+
+ # Allgather step
+ for step in range(n_lbs - 1):
+ for i in range(len(chunks)):
+ s, r = (i+step-1)%n_lbs, (i+step)%n_lbs
+ chunked[r][i] = chunked[s][i].copy_to_device(chunked[r][i].device, force=True)
+
+ # Assemble chunks back
+ pads = [((s,dim-e),) for s,e in chunks]
+ return [functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), [c.pad(pads[i]) for i,c in enumerate(lb_c)]).reshape(lbs[0].shape) for lb_c in chunked]
+
+ def to_sharded(lbs:List[LazyBuffer], axis:int) -> List[LazyBuffer]:
+ if DEBUG >= 3 and lbs[0].shape[axis] % len(lbs) != 0: print(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}")
+ sz = round_up(lbs[0].shape[axis], len(lbs)) // len(lbs)
+ return [lb.shrink(tuple((0,s) if a != axis else (sz*i,min(s,sz*(i+1))) for a,s in enumerate(lb.shape))) for i,lb in enumerate(lbs)]
+
+ class MultiLazyBuffer:
+ def __init__(self, lbs:List[LazyBuffer], axis:Optional[int], real:Optional[List[bool]]=None):
+ assert all(isinstance(x, LazyBuffer) for x in lbs) and len(lbs), "all lbs must be LazyBuffers, and we need at least one of them"
+ assert all_same([x.dtype for x in lbs]), f"all multilazybuffer needs same dtype, getting {[x.dtype for x in lbs]}"
+ self.lbs, self.axis, self.dtype, self.device, self.real = lbs, axis, lbs[0].dtype, tuple(x.device for x in lbs), real or [True]*len(lbs)
+ if axis is not None:
+ splits = list(itertools.accumulate([lb.shape[axis] for lb in lbs], initial=0))
+ self.bounds = [(st,ed) for st,ed in zip(splits, splits[1:])]
+
+ @property
+ def shape(self):
+ return tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape))
+
+ @property
+ def size(self): return sum(x.size for x in self.real_lbs)
+
+ @property
+ def real_lbs(self): return [lb for lb,r in zip(self.lbs, self.real) if r]
+
+ def __repr__(self):
+ return f"<MLB {self.axis=} {self.real=} {chr(10)}{chr(10).join([f'{x.device} {x.st}' for x in self.lbs])}>"
+
+ @staticmethod
+ def from_sharded(lb:LazyBuffer, devices:Tuple[str, ...], axis:Optional[int]=None):
+ lbs = [lb.contiguous() if lb.base != lb and not lb.is_unrealized_unmasked_const() else lb] * len(devices)
+ sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(to_sharded(lbs, axis) if axis is not None else lbs, devices)]
+ return MultiLazyBuffer([lb if lb.is_unrealized_unmasked_const() else lb.contiguous() for lb in sharded_lbs], axis)
+
+ def copy_to_device(self, device:str) -> LazyBuffer:
+ if self.axis is None: return self.lbs[self.real.index(True)].copy_to_device(device)
+ sz = self.lbs[0].shape[self.axis]
+ llbs = []
+ for i,lb in enumerate([lb.copy_to_device(device) for lb in self.real_lbs]):
+ pad_arg = tuple((0,0) if a != self.axis else (sz*i, max(0, self.shape[self.axis]-sz*(i+1))) for a in range(len(lb.shape)))
+ llbs.append(lb.pad(pad_arg))
+ return functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), llbs)
+
+ # passthroughs
+ def is_realized(self) -> bool: return all(lb.base.realized is not None for lb, r in zip(self.lbs, self.real) if r is True)
+ def cast(self, dtype:DType, bitcast:bool=False): return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis, self.real)
+ def const(self, val:ConstType) -> MultiLazyBuffer: return MultiLazyBuffer([x.const(val) for x in self.lbs], self.axis, self.real)
+ def assign(self, x:MultiLazyBuffer): return MultiLazyBuffer([s.assign(d) for s,d in zip(self.lbs, x.lbs)], self.axis, self.real)
+ def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis, self.real)
+
+ # elementwise is simple
+ def e(self, op:Union[LoadOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:MultiLazyBuffer, arg:Optional[Any]=None) -> MultiLazyBuffer:
+ msrcs = (self,)+in_srcs
+ assert all(isinstance(x, MultiLazyBuffer) for x in msrcs), f"all buffers must be MultiLazyBuffer {msrcs}"
+ assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}"
+
+ # NOTE: they all have to share an axis, we always choose [-1]
+ axis = axes[-1] if len(axes := dedup([x.axis for x in msrcs if x.axis is not None])) else None
+ srcs = []
+ not_all_real = any(not all(mlb.real) for mlb in msrcs)
+ new_real = [all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])] if not_all_real else self.real
+ assert any(new_real), "output contains no real lb"
+ for mlb in msrcs:
+ if mlb.axis == axis or not_all_real: srcs.append(mlb.lbs)
+ elif mlb.axis is None and axis is not None: srcs.append(to_sharded(mlb.lbs, axis))
+ else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis))
+ # NOTE: lsrcs[-1].const(0) is correct for where
+ return MultiLazyBuffer([lsrcs[0].e(op, *lsrcs[1:], arg=arg) if r else lsrcs[-1].const(0) for lsrcs,r in zip(zip(*srcs),new_real)], axis, new_real)
+
+ def _shape_to_single_shard(self, shape:Tuple[sint, ...], lb:LazyBuffer) -> Tuple[sint, ...]:
+ return tuple(lb.shape[self.axis] if a == self.axis else s for a,s in enumerate(shape))
+
+ def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> MultiLazyBuffer:
+ if self.axis is not None and self.axis in axis:
+ # all-reduce on sharded axes
+ reduced_parts = [(x if r else x.const(0)).r(op, axis) for x,r in zip(self.lbs, self.real)]
+ if all(self.real): return MultiLazyBuffer(all_reduce(op, reduced_parts), None)
+ return MultiLazyBuffer(reduced_parts, None, self.real)
+ # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
+ return MultiLazyBuffer([x.r(op, axis) for x in self.lbs], self.axis, self.real)
+
+ # *** movement ops ***
+
+ def reshape(self, arg:Tuple[sint, ...]):
+ if self.axis is None: return MultiLazyBuffer([x.reshape(arg) for x in self.lbs], None, self.real)
+ arg_acc:List[sint] = list(itertools.accumulate(arg, operator.mul, initial=1))
+ # new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards
+ # todo: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1?
+ new_axis = len(arg_acc) - arg_acc[::-1].index(prod(self.shape[:self.axis])) - 1
+ if arg[new_axis] != self.shape[self.axis]:
+ assert self.shape[self.axis] % len(self.real_lbs) == 0, f"cannot reshape on-axis for uneven shard {self.axis} {self.shape} {len(self.real_lbs)}"
+ assert arg[new_axis] % len(self.real_lbs) == 0, f"new on-axis shape must divide evenly between devices {new_axis} {arg} {len(self.real_lbs)}"
+ return MultiLazyBuffer([x.reshape(tuple(s if a != new_axis else
+ x.shape[self.axis] if s == self.shape[self.axis] else
+ s // len(self.real_lbs) for a,s in enumerate(arg))) for x in self.lbs],
+ new_axis, self.real)
+
+ def pad(self, arg:Tuple[Tuple[sint, sint], ...]):
+ assert self.axis is None or arg[self.axis] == (0,0) or not all(self.real), f"padding not supported for {arg=}"
+ # pad on shard axis -> fill others with zeros and set real to all True
+ if self.axis is not None and arg[self.axis] != (0,0):
+ # pad back to whole axis, remove real mask
+ assert all(arg[i] == (0, 0) or i == self.axis for i in range(len(self.shape))), "cannot pad sharded and non-sharded axis at the same time"
+ assert arg[self.axis] == (sum(lb.shape[self.axis] for i,lb in enumerate(self.lbs) if i < self.real.index(True)), \
+ sum(lb.shape[self.axis] for i,lb in enumerate(self.lbs) if i > self.real.index(True))), "can only pad to whole axis"
+ return MultiLazyBuffer([x if r else x.const(0) for x,r in zip(self.lbs, self.real)], self.axis)
+ return MultiLazyBuffer([x.pad(arg) for x in self.lbs], self.axis, self.real)
+ def expand(self, arg:Tuple[sint, ...]):
+ # NOTE: this assert isn't needed, sharded axis can have dim 1
+ assert self.axis is None or arg[self.axis] == self.shape[self.axis], f"expand not supported on sharded axis {arg=}"
+ return MultiLazyBuffer([x.expand(self._shape_to_single_shard(arg, x)) for x in self.lbs], self.axis, self.real)
+ def permute(self, arg:Tuple[int, ...]):
+ # all permutes supported!
+ return MultiLazyBuffer([x.permute(arg) for x in self.lbs], arg.index(self.axis) if self.axis is not None else None, self.real)
+ def shrink(self, arg:Tuple[Tuple[sint, sint], ...]):
+ assert self.axis is None or arg[self.axis] == (0, self.shape[self.axis]) or arg[self.axis] in self.bounds, f"shrinking not supported for {arg=}"
+ if self.axis is not None and arg[self.axis] in self.bounds and arg[self.axis] != (0, self.shape[self.axis]):
+ assert all(arg[i] == (0, s) or i == self.axis for i,s in enumerate(self.shape)), "cannot shrink sharded and non-sharded axis at the same time"
+ idx = self.bounds.index(arg[self.axis])
+ # zero out other lbs to not create lb reference
+ return MultiLazyBuffer([lb if i==idx else lb.const(0) for i,lb in enumerate(self.lbs)], self.axis, [i==idx for i in range(len(self.lbs))])
+ return MultiLazyBuffer([x.shrink(tuple((0, x.shape[self.axis]) if a == self.axis else s for a,s in enumerate(arg))) for x in self.lbs],
+ self.axis, self.real)
+ def stride(self, arg:Tuple[int, ...]):
+ assert self.axis is None or arg[self.axis] == 1, "flipping not supported on sharded axis"
+ return MultiLazyBuffer([x.stride(arg) for x in self.lbs], self.axis, self.real)
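The new all_reduce above implements a ring all-reduce on top of a naive fallback: during the scatter-reduce phase each device accumulates one rotating chunk, and during the allgather phase the finished chunks are circulated until every device has all of them, so each device moves roughly 2*(n-1)/n of the data instead of n-1 full copies. Below is a plain-Python simulation of just the chunk rotation, with lists standing in for device buffers; it illustrates the communication pattern only, not the LazyBuffer implementation.

import random
from typing import List

def ring_allreduce(data: List[List[float]]) -> List[List[float]]:
    # data[d] is device d's local vector; all vectors have the same length
    n, dim = len(data), len(data[0])
    bounds = [(i * dim // n, (i + 1) * dim // n) for i in range(n)]   # n contiguous chunks
    bufs = [list(v) for v in data]                                    # working copies

    # scatter-reduce: after n-1 steps, device (i-1) % n holds the fully reduced chunk i
    for step in range(n - 1):
        for i, (lo, hi) in enumerate(bounds):
            s, r = (i + step) % n, (i + step + 1) % n
            for j in range(lo, hi): bufs[r][j] += bufs[s][j]

    # allgather: circulate the finished chunks until every device has all of them
    for step in range(n - 1):
        for i, (lo, hi) in enumerate(bounds):
            s, r = (i + step - 1) % n, (i + step) % n
            bufs[r][lo:hi] = bufs[s][lo:hi]
    return bufs

devices = [[random.random() for _ in range(12)] for _ in range(4)]
expected = [sum(col) for col in zip(*devices)]
for out in ring_allreduce(devices):
    assert all(abs(a - b) < 1e-9 for a, b in zip(out, expected))

The force=True on copy_to_device matters in the real implementation: without it, the "double COPY = one COPY" folding added to lazy.py would collapse the chunk rotation back into direct copies.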