tinygrad-0.7.0-py3-none-any.whl → tinygrad-0.9.0-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- tinygrad/__init__.py +6 -0
- tinygrad/codegen/kernel.py +572 -83
- tinygrad/codegen/linearizer.py +415 -395
- tinygrad/codegen/uops.py +415 -0
- tinygrad/device.py +183 -0
- tinygrad/dtype.py +113 -0
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +195 -0
- tinygrad/engine/realize.py +191 -0
- tinygrad/engine/schedule.py +362 -0
- tinygrad/engine/search.py +196 -0
- tinygrad/{mlops.py → function.py} +76 -55
- tinygrad/helpers.py +196 -89
- tinygrad/lazy.py +210 -371
- tinygrad/multi.py +169 -0
- tinygrad/nn/__init__.py +202 -22
- tinygrad/nn/datasets.py +7 -0
- tinygrad/nn/optim.py +112 -32
- tinygrad/nn/state.py +136 -39
- tinygrad/ops.py +119 -202
- tinygrad/renderer/__init__.py +61 -0
- tinygrad/renderer/assembly.py +276 -0
- tinygrad/renderer/cstyle.py +353 -166
- tinygrad/renderer/llvmir.py +150 -138
- tinygrad/runtime/autogen/amd_gpu.py +1900 -0
- tinygrad/runtime/autogen/comgr.py +865 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5761 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33328 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/hip_comgr.py +47 -0
- tinygrad/runtime/driver/hsa.py +143 -0
- tinygrad/runtime/graph/clang.py +38 -0
- tinygrad/runtime/graph/cuda.py +81 -0
- tinygrad/runtime/graph/hcq.py +143 -0
- tinygrad/runtime/graph/hsa.py +171 -0
- tinygrad/runtime/graph/metal.py +75 -0
- tinygrad/runtime/ops_amd.py +564 -0
- tinygrad/runtime/ops_clang.py +24 -77
- tinygrad/runtime/ops_cuda.py +175 -89
- tinygrad/runtime/ops_disk.py +56 -33
- tinygrad/runtime/ops_gpu.py +92 -95
- tinygrad/runtime/ops_hsa.py +278 -0
- tinygrad/runtime/ops_llvm.py +39 -60
- tinygrad/runtime/ops_metal.py +92 -74
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +630 -0
- tinygrad/runtime/ops_python.py +204 -0
- tinygrad/shape/shapetracker.py +86 -254
- tinygrad/shape/symbolic.py +166 -141
- tinygrad/shape/view.py +296 -0
- tinygrad/tensor.py +2619 -448
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
- tinygrad-0.9.0.dist-info/METADATA +227 -0
- tinygrad-0.9.0.dist-info/RECORD +60 -0
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/assembly.py +0 -190
- tinygrad/codegen/optimizer.py +0 -379
- tinygrad/codegen/search.py +0 -72
- tinygrad/graph.py +0 -83
- tinygrad/jit.py +0 -57
- tinygrad/nn/image.py +0 -100
- tinygrad/renderer/assembly_arm64.py +0 -169
- tinygrad/renderer/assembly_ptx.py +0 -98
- tinygrad/renderer/wgsl.py +0 -53
- tinygrad/runtime/lib.py +0 -113
- tinygrad/runtime/ops_cpu.py +0 -51
- tinygrad/runtime/ops_hip.py +0 -82
- tinygrad/runtime/ops_shm.py +0 -29
- tinygrad/runtime/ops_torch.py +0 -30
- tinygrad/runtime/ops_webgpu.py +0 -45
- tinygrad-0.7.0.dist-info/METADATA +0 -212
- tinygrad-0.7.0.dist-info/RECORD +0 -40
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
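The listing above already shows the shape of the 0.9.0 reorganization: dtypes are split out of helpers.py into dtype.py, the buffer/device plumbing that used to live in runtime/lib.py becomes device.py, mlops.py is renamed function.py, and scheduling, realization, the JIT, and kernel search move under tinygrad/engine/. As a rough orientation only, here is a minimal sketch of user code against the 0.9.0 layout; the exact top-level re-exports are assumed from this file list rather than verified against the wheel:

# hedged sketch of the 0.9.0 import layout implied by the file list above
from tinygrad import Tensor, dtypes         # dtypes are now defined in tinygrad/dtype.py
from tinygrad.device import Device, Buffer  # device/buffer plumbing that roughly replaces runtime/lib.py
from tinygrad.engine.jit import TinyJit     # the JIT moved from tinygrad/jit.py into tinygrad/engine/

x = Tensor.randn(8, 8, dtype=dtypes.float32)
y = (x @ x.T).relu().sum()
print(Device.DEFAULT, y.item())             # scheduling and realization now run through tinygrad/engine/*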
tinygrad/lazy.py
CHANGED
@@ -1,381 +1,220 @@
 from __future__ import annotations
-P2P = getenv("P2P", 0)
-
-# TODO: movement ops that only change shape are really nops. treat them as such
-REMOVE_MOVEMENT_NOPS, MERGE_ELEMENTWISE_INTO_REDUCE, SHUFFLE_MOVEMENT_OPS, MERGE_ELEMENTWISE_OPS = OPT>=1, OPT>=1, OPT>=1, OPT>=1
-MERGE_ONE_REDUCE_INTO_ELEMENTWISE, SHUFFLE_PAD_OPS, SIMPLIFY_SUM_RESHAPE_EXPAND_SUM = OPT>=2, OPT>=2, OPT>=2 # shuffle pad ops is fine now since we only push to merge binops
-PUSH_PERMUTES, PUSH_CONTIGUOUS = OPT>=3, OPT>=3
-
-def _simplify_sum_reshape_expand_sum(self:LazyBuffer, src: Any, prev_src: Any) -> Optional[LazyOp]:
-  if prev_src.op.op == MovementOps.EXPAND:
-    if src.op.op == ReduceOps.SUM:
-      if src.shape == self.shape:
-        dim_difference = [i for i, (a, b) in enumerate(zip(prev_src.shape, self.shape)) if a != b]
-        # NOTE: we can probably also handle the case where more than one dimension is different with more thought
-        if len(dim_difference) == 1:
-          expansion_index = dim_difference[0]
-          expansion_size = prev_src.shape[expansion_index]
-          return LazyOp(BinaryOps.MUL, (src, LazyBuffer.const(src, expansion_size)))
-  return None
-
-# **** realize functions ****
-def _ast_reduceops(self:LazyBuffer) -> LazyOp:
-  # TODO: this can also corealize a binary op after the reduce, not just before
-  # NOTE: mypy doesn't know that if not src.realized, then src.op must be a LazyOp so we have to ignore a bunch of warnings
-  src = self.op.src[0]
-  if not src.realized:
-    # When a tensor is reduced, reshaped/expanded back and then reduced again along the same axis,
-    # it's equivalent to performing the initial reduction and multiplying the result
-    # by the size of the expanded dimension.
-    if SIMPLIFY_SUM_RESHAPE_EXPAND_SUM and src.op.op == MovementOps.EXPAND: # type: ignore
-      expanded = src.op.src[0] # type: ignore
-      if expanded.op.op == MovementOps.RESHAPE: # type: ignore
-        reshaped = expanded.op.src[0] # type: ignore
-        simplified = _simplify_sum_reshape_expand_sum(self, reshaped, src)
-      else:
-        simplified = _simplify_sum_reshape_expand_sum(self, expanded, src)
-      if simplified: return simplified
-    if MERGE_ELEMENTWISE_INTO_REDUCE and src.optype is BinaryOps and len(src.children) <= 1:
-      # If we did remove an expand above, we might stumble back into a case where the reduction is not necessary
-      if src.shape == self.shape:
-        return src.op # type: ignore
-      src = src.op # type: ignore
-  return LazyOp(self.op.op, (src,), self.op.arg)
-
-# this supports late merging an upstream Reduce op and even an Elementwise op above that
-def _ast_binaryops(self:LazyBuffer) -> LazyOp:
-  real_srcs: Dict[LazyBuffer, Union[None, LazyOp, LazyBuffer]] = {x:None for x in self.op.buffers}
-  # NOTE: contiguous does not always mean the same size with SHRINK. this is still mergeable but requires more thought how
-  # TODO: this can also support late fusion of BinaryOps, required for test_fold_conv_sgd
-  psrcs: List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype == ReduceOps and not x.realized and prod(k.shape) == prod(x.shape) and len(x.children) <= 1 and len(k.children) <= 1]
-  intermediate_shape: Tuple[int, ...] = self.shape
-  if MERGE_ONE_REDUCE_INTO_ELEMENTWISE and psrcs:
-    psrc = psrcs[0] # NOTE: right now we can't handle multiple, as we'd have to check for loop
-    if psrc[1].optype == ReduceOps:
-      top = _ast_reduceops(psrc[1])
-    real_srcs[psrc[0]] = top
-    real_srcs.update({x:x for x in top.buffers}) # the reduce op buffers are not modified
-
-    # if the ReduceOp is followed by a reshape, we push this reshape before all the ElementwiseOp inputs
-    if psrc[0].shape != psrc[1].shape:
-      intermediate_shape = psrc[1].shape
-      assert psrc[0].shape == self.shape, f"shape mismatch {psrc[0].shape} != {self.shape}"
-
-  # reshape all the late ops into the output shape
-  # NOTE: these RESHAPEs will return self if they don't change the shape
-  for x in real_srcs.keys():
-    if not real_srcs[x]: real_srcs[x] = x.reshape(intermediate_shape)
-  ast = self.op.map_buffers(real_srcs)
-  return LazyOp(MovementOps.RESHAPE, (ast, ), self.shape) if intermediate_shape != self.shape else ast
-
-# **** lazy operations ****
-
-def get_single_root(root:LazyBuffer) -> LazyBuffer: return get_single_root(cast(LazyBuffer, root.op.src[0])) if getattr(root, 'op', None) and len(root.op.src) == 1 else root
-def get_movementroot(root:LazyBuffer, allow_contiguous=False) -> LazyBuffer: return get_movementroot(cast(LazyBuffer, root.op.src[0]), allow_contiguous) if not root.realized and (root.optype == MovementOps or (root.op.op == LoadOps.CONTIGUOUS and allow_contiguous and root.op.src[0].st.contiguous)) else root
-def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot_contiguous(cast(LazyBuffer, x.op.src[0])) if not x.realized and x.op.op == LoadOps.CONTIGUOUS else (get_movementroot(x, True) if x.optype == MovementOps and x.st.contiguous else x)
-
-lazycache: WeakValueDictionary = WeakValueDictionary()
-def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType):
-  # fromcpu aren't cached
-  if not LAZYCACHE or (optype is LoadOps and op.op in {LoadOps.EMPTY, LoadOps.RAND, LoadOps.CONST}): return LazyBuffer(device, st, optype, op, dtype)
-
-  # wop is the deduping key. i feel this used to compare more deeply
-  wop = (device, dtype, optype, ref(op))
-  if wop in lazycache:
-    for x in op.buffers: x.children.add(lazycache[wop])
-    return lazycache[wop]
-
-  lazycache[wop] = ret = LazyBuffer(device, st, optype, op, dtype)
+import math
+from typing import Union, Optional, Any, Tuple, List
+from tinygrad.dtype import dtypes, DType, ConstType
+from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG
+from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu
+from tinygrad.shape.symbolic import sint, Variable
+from tinygrad.shape.shapetracker import ShapeTracker
+from tinygrad.device import Buffer
+from weakref import ref, ReferenceType, WeakValueDictionary
+
+lazycache: WeakValueDictionary[Any, LazyBuffer] = WeakValueDictionary()
+def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
+                      base:Optional[LazyBuffer]=None, enable_cache=bool(getenv("LAZYCACHE", 1))):
+  if st.size == 0: op, arg, srcs, base = LoadOps.CONST, 0, (), None
+  if op is LoadOps.CONST: arg, enable_cache = dtypes.as_const(arg, dtype) if not isinstance(arg, Variable) else arg, True
+
+  cache_key = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs)) if base is None else (st, ref(base))
+  if enable_cache and (rret := lazycache.get(cache_key, None)): return rret
+
+  ret = LazyBuffer(device, st, dtype, op, arg, srcs, base=base)
+  if enable_cache: lazycache[cache_key] = ret
   return ret
 
+view_supported_devices = {"LLVM", "CLANG", "CUDA", "DISK"}
 class LazyBuffer:
-  def realize(self:LazyBuffer) -> LazyBuffer:
-    if not self.realized:
-      # get real ops first
-      if self.optype is BinaryOps: self.op = _ast_binaryops(self)
-      elif self.optype is ReduceOps:
-        self.op = _ast_reduceops(self)
-        if self.op.op in BinaryOps: self.op = _ast_binaryops(self)
-      elif self.optype is LoadOps: LOAD_OPS_DISPATCHER[cast(LoadOps, self.op.op)](self)
-      # run the ast if we still have to, and log the op
-      if not self.realized:
-        for x in self.op.buffers: x.realize()
-
-        # HACK: image shape can be wrong, hot cast it back to a normal float
-        if isinstance(self.dtype, ImageDType) and self.optype != MovementOps and (prod(self.shape) != prod(self.dtype.shape) or not any(self.shape[x]%4 == 0 for x in self.st.unit_stride_axes())):
-          if self.op.op == MovementOps.RESHAPE:
-            # put CAST before the final RESHAPE
-            self.op = LazyOp(MovementOps.RESHAPE, (LazyOp(UnaryOps.CAST, self.op.src, (dtypes.float32, False)),), self.op.arg)
-          else:
-            self.op = LazyOp(UnaryOps.CAST, (self.op,), (dtypes.float32, False))
-          self.dtype = dtypes.float32
-        self.realized = Device[self.device].exec_ast(self.op, output=self, **self._device_extra_args())
-
-      assert self.realized and isinstance(self.realized, (RawConst, Device[self.device].buffer)), f"device mismatch on realized got {type(self.realized)} expected {self.device}"
-      # HACK: allow hot casting of images
-      assert self.realized.dtype == self.dtype or self.dtype.__class__ is ImageDType, f"dtype mismatch on realize got {self.realized.dtype} expected {self.dtype}"
-      self.dtype = self.realized.dtype
-
-      # log to the graph
-      if (DEBUG or GRAPH) and (self.realized.__class__ is not RawConst or GRAPH >= 2):
-        log_op(self, self.op)
-
-      # no need to keep the op after realization
-      del self.op
-    return self
+  def __init__(self, device:str, st:ShapeTracker, dtype:DType,
+               op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
+               base:Optional[LazyBuffer]=None):
+    self.device, self.st, self.dtype, self.shape, self.size = device, st, dtype, st.shape, st.size
+    self._base: Optional[LazyBuffer] = None
+    if base is None:
+      # properties on base
+      self.op, self.arg, self.srcs = op, arg, srcs # this is a LazyOp, except the src is LazyBuffers and not LazyOps
+      assert self.op is not LoadOps.ASSIGN or srcs[1].base.realized is not None, "assign target must be realized"
+
+      if (self.op is LoadOps.CONTIGUOUS or self.op is UnaryOps.BITCAST) and srcs[0].st.consecutive and \
+        not srcs[0].is_unrealized_const() and device.split(":")[0] in view_supported_devices:
+        # some LazyBuffers can be processed with only a view, no AST required
+        self.buffer: Buffer = srcs[0].base.buffer.view(st.size, dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize)
+        self.op = LoadOps.VIEW
+      else:
+        self.buffer = srcs[1].base.buffer if self.op is LoadOps.ASSIGN else Buffer(device, self.size, dtype)
+      self.buffer.ref(1)
+      self.contiguous_child: Optional[Tuple[ReferenceType[LazyBuffer], ShapeTracker]] = None
+      self.forced_realize = False
+    else:
+      # properties on view
+      assert base.base == base, "base must be a base itself"
+      self._base = base
 
+  def __del__(self):
+    if hasattr(self, 'buffer'): self.buffer.ref(-1)
+
+  def __repr__(self) -> str:
+    return f"<LB {self.device} {self.shape} {str(self.dtype)[7:]} {self.st if self.base != self else (self.op, self.realized)}>"
 
+  @property
+  def realized(self) -> Optional[Buffer]:
+    # NOTE: we check for a lack of srcs instead of an allocated buffer to make unrealized assigns return None here
+    return self.buffer if self._base is None and not hasattr(self, 'srcs') else None
+
+  # NOTE: this has to be a function to prevent self reference
+  @property
+  def base(self) -> LazyBuffer: return self._base if self._base is not None else self
 
+  # same API as multi
+  @property
+  def lbs(self) -> List[LazyBuffer]: return [self]
 
   @staticmethod
-    new_srcs: List[LazyBuffer] = []
-    for x in srcs:
-      if not x.realized and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1:
-        x.op.src[0].children.discard(x)
-        new_srcs.append(cast(LazyBuffer, x.op.src[0]))
-      else:
-        new_srcs.append(x)
-    return new_srcs[0].e(op, *new_srcs[1:], arg=arg).contiguous()
-
-    if MERGE_ELEMENTWISE_OPS:
-      # remove the buffers from any (childless) BinaryOps that feed into this
-      srcs = tuple([x.op if x.optype == BinaryOps and not x.children and not x.realized else x for x in srcs]) # type: ignore
-
-    return create_lazybuffer(out_device, ShapeTracker(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype)
-
-  def shuffle_and_prune_movement_ops(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[Union[Node,int], ...], Tuple[Tuple[int, int], ...]]) -> LazyBuffer:
-    if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and not self.realized and (op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and self.op.op in UnaryOps)) and not self.children:
-      return self.op.replace_with_movement_ops([(op, arg)])
-    ret = create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg), self.dtype)
-    if REMOVE_MOVEMENT_NOPS and not self.realized and not ret.realized and ret.st.contiguous:
-      # MovementOps aren't stacked any more, they each have one parent, find the root
-      root = get_movementroot(self)
-      if root.st.contiguous and root != self and prod(ret.st.shape) == prod(root.shape):
-        return root.reshape(ret.st.shape)
-    return ret
-
-  def _reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
-    if self.shape == tuple(new_shape): return self
-    srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,)
-    return create_lazybuffer(self.device, ShapeTracker(new_shape), ReduceOps, LazyOp(op, srcs, new_shape), self.dtype)
-
-  def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
-    if any(not isinstance(s, int) for s in self.shape) or prod(self.shape) // prod(new_shape) < 32768: return self._reduce_op(op, new_shape) # The amount of work should be big enough to take the benefit of "2 kernels" approach.
-    heuristic, divisor, dim_to_split = max(((divisor := math.gcd(256, old))/(stride or math.inf), divisor, i) for i, (old, new, stride) in enumerate(zip(self.shape, new_shape, self.st.real_strides())) if old != new) # type: ignore
-    if divisor < 16 or heuristic < 0.125: return self._reduce_op(op, new_shape) # Choose largest divisor (>=16) to split on, penalize large strides.
-    def splitted_shape(dim_aft_div): return self.shape[:dim_to_split] + (self.shape[dim_to_split]//divisor,) + dim_aft_div + self.shape[dim_to_split+1:]
-    return self.reshape(splitted_shape((divisor,)))._reduce_op(op, splitted_shape((1,))).reshape(splitted_shape(()))._reduce_op(op, new_shape)
-
-  def reshape(self:LazyBuffer, arg:Tuple[Union[Node, int], ...]) -> LazyBuffer:
-    if self.shape == arg: return self
-    if not self.realized and self.op.op == MovementOps.RESHAPE:
-      self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why??
-      return self.op.src[0].reshape(arg)
-    return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).reshape(arg), MovementOps.RESHAPE, arg)
-
-  def pad(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
-    if all(b == 0 and e == 0 for b,e in arg): return self
-    if not self.realized and self.op.op == MovementOps.PAD: return self.op.src[0].pad(tuple([(b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)]))
-    return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).pad(arg), MovementOps.PAD, arg)
-
-  def expand(self: LazyBuffer, arg:Tuple[Union[Node,int], ...]) -> LazyBuffer:
-    if self.shape == arg: return self
-    if not self.realized and self.op.op == MovementOps.EXPAND:
-      return self.op.src[0].expand(arg)
-    return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).expand(arg), MovementOps.EXPAND, arg)
-
-  def permute(self: LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
-    if arg == tuple(range(len(self.shape))): return self
-    if not self.realized and self.op.op == MovementOps.PERMUTE: return self.op.src[0].permute(tuple([self.op.arg[i] for i in arg]))
-    if not self.realized:
-      if PUSH_PERMUTES and self.optype == ReduceOps:
-        # reduceops have one buffer input, permute it
-        narg = tuple([self.op.arg[arg[i]] for i in range(len(arg))])
-        src, rop = self.op.src[0], self.op.op
-        src.children.discard(self)
-        del self # TODO: why doesn't this delete remove it from the children
-        return src.permute(arg).reduce_op(cast(ReduceOps, rop), narg)
-
-      # move permutes before expands (always, this is safe)
-      if self.op.op == MovementOps.EXPAND:
-        return self.op.src[0].permute(arg).expand(tuple([self.op.arg[a] for a in arg]))
-
-      # move permutes before reshapes if we can
-      if PUSH_PERMUTES and self.op.op == MovementOps.RESHAPE and self.op.src[0].__class__ is LazyBuffer:
-        if shape_idx_groups := get_contraction(self.op.src[0].shape, self.shape):
-          self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why??
-          return self.op.src[0].permute(tuple(flatten(shape_idx_groups[i] for i in arg))).reshape(ShapeTracker(self.st).permute(arg).shape)
-    return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).permute(arg), MovementOps.PERMUTE, arg)
-
-  def shrink(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
-    if all(b - a == s for s, (a, b) in zip(self.shape, arg)): return self
-    if not self.realized and self.op.op == MovementOps.SHRINK: return self.op.src[0].shrink(tuple([(b1+b2, b1+e2) for (b1,_),(b2,e2) in zip(self.op.arg, arg)]))
-    return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).shrink(arg), MovementOps.SHRINK, arg)
-
-  def stride(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
-    local_st = ShapeTracker(self.shape).stride(arg)
-    if self.shape == local_st.shape and local_st.contiguous: return self
-    if not self.realized and self.op.op == MovementOps.STRIDE: return self.op.src[0].stride(tuple(map(operator.mul, arg, self.op.arg)))
-    return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).stride(arg), MovementOps.STRIDE, arg)
+  def loadop(op, shape:Tuple[sint,...], dtype:DType, device:str, arg=None, src:Tuple[LazyBuffer, ...]=(), enable_cache=False) -> LazyBuffer:
+    assert isinstance(src, tuple)
+    return create_lazybuffer(device, ShapeTracker.from_shape(shape), dtype, op, arg, src, enable_cache=enable_cache)
+
+  def const(self, val:ConstType, shape:Optional[Tuple[sint,...]]=None) -> LazyBuffer:
+    shape = self.shape if shape is None else shape
+    return LazyBuffer.loadop(LoadOps.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(shape)).expand(shape)
+
+  def is_realized(self) -> bool: return self.base.realized is not None
+
+  def assign(self, x:LazyBuffer) -> LazyBuffer:
+    assert x.size == self.size, f"assign target must have same size {self.size=} != {x.size=}"
+    return LazyBuffer.loadop(LoadOps.ASSIGN, self.shape, self.dtype, self.device, arg=() if self.st.contiguous else (self.st,), src=(x, self.base))
+
+  def contiguous(self):
+    if not self.st.contiguous or self.size != self.base.size or self.is_unrealized_const():
+      ret = self.e(LoadOps.CONTIGUOUS)
+      if (sti := self.st.invert(self.base.shape)) is not None: self.base.contiguous_child = ref(ret), sti
+      return ret
+    self.base.forced_realize = True
+    return self
 
+  def cast(self, dtype:DType, bitcast:bool=False):
+    if self.dtype == dtype: return self
+    if self.device.startswith("DISK") and not bitcast: raise RuntimeError("attempted to cast disk buffer (bitcast only)")
+    if self.is_unrealized_unmasked_const() and not bitcast:
+      return create_lazybuffer(self.device, self.st, dtype, LoadOps.CONST, dtypes.as_const(self.base.arg, dtype))
+    # TODO: applying this makes gpt2 slower
+    if getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self != self.base:
+      return self.base.cast(dtype, bitcast)._view(self.st)
+    new_shape = self.shape
+    if bitcast and self.dtype.itemsize != dtype.itemsize:
+      if not self.device.startswith("DISK"): raise RuntimeError("shape changing bitcast only supported on DISK right now")
+      if not all_int(new_shape): raise RuntimeError("shape changing bitcast with symbolic shape isn't supported yet")
+      # https://pytorch.org/docs/stable/generated/torch.Tensor.view.html
+      if not (new_shape[-1]*self.dtype.itemsize) % dtype.itemsize == 0: raise RuntimeError("unsupported size in bitcast")
+      new_shape = new_shape[:-1] + ((new_shape[-1]*self.dtype.itemsize) // dtype.itemsize,)
+    cast_op = UnaryOps.BITCAST if bitcast else UnaryOps.CAST
+    return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, dtype, (self,))
+
+  def is_unrealized_const(self): return self.base.realized is None and self.base.op is LoadOps.CONST and not isinstance(self.base.arg, Variable)
+  def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views)
+
+  def _copy(self, device:str) -> LazyBuffer:
+    return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, self.buffer.nbytes, (self,), enable_cache=False)
+
+  def copy_to_device(self, device:str, force: bool = False) -> LazyBuffer:
+    # no COPY
+    if self.device == device: return self
+
+    # double COPY = one COPY
+    if not force and self.st.contiguous and self.size == self.base.size and not self.base.realized and self.base.op is LoadOps.COPY:
+      return self.base.srcs[0].copy_to_device(device).reshape(self.st.shape)
+
+    # const doesn't have to be copied (issues with disk tensor)
+    if self.is_unrealized_const():
+      return LazyBuffer.loadop(LoadOps.CONST, tuple(), self.dtype, device, arg=self.base.arg)._view(self.st)
+
+    # if it's a shrink, do the shrink before the copy with CONTIGUOUS
+    if prod(self.st.shape) < prod(self.base.st.shape): return self.contiguous()._copy(device)
+
+    # copy the base and apply the shapetracker on the new device
+    return self.base._copy(device)._view(self.st)
+
+  def e(self, op:Union[LoadOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
+    srcs: List[LazyBuffer] = []
+    for s in (self,)+in_srcs:
+      if s == s.base and s.base.contiguous_child and (root:=s.base.contiguous_child[0]()) is not None:
+        srcs.append(root._view(s.base.contiguous_child[1]))
+      else:
+        srcs.append(s)
+    assert all_same(dts:=[x.dtype.scalar() for x in (srcs[1:] if op is TernaryOps.WHERE else srcs)]), f"all dtypes must match {dts} on {op}"
+    assert all_same([x.shape for x in srcs]), f"all shapes must be the same {[x.shape for x in srcs]}"
+    if op is TernaryOps.WHERE: assert srcs[0].dtype == dtypes.bool, "TernaryOps.WHERE must have the first arg be bool"
+    if op is UnaryOps.NEG: assert srcs[0].dtype != dtypes.bool, "UnaryOps.NEG does not accept dtype bool"
+
+    out_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else srcs[-1].dtype
+
+    # const folding
+    if op in python_alu and all(s.is_unrealized_unmasked_const() for s in srcs):
+      return self.cast(out_dtype).const(exec_alu(op, out_dtype, [s.base.arg for s in srcs]))
+    if op is UnaryOps.NEG and self.base.op is UnaryOps.NEG: return self.base.srcs[0]
+    if op in BinaryOps: x, y = self, in_srcs[0]
+    if op is BinaryOps.ADD:
+      if y.is_unrealized_unmasked_const() and y.base.arg == 0: return x # pylint: disable=possibly-used-before-assignment
+      if x.is_unrealized_unmasked_const() and x.base.arg == 0: return y # pylint: disable=possibly-used-before-assignment
+    if op is BinaryOps.SUB and y.is_unrealized_unmasked_const() and y.base.arg == 0: return x
+    if op is BinaryOps.MUL:
+      if x.is_unrealized_unmasked_const() and (val := x.base.arg) in (1, 0, -1):
+        return y if val == 1 else y.const(0) if val == 0 else y.e(UnaryOps.NEG)
+      if y.is_unrealized_unmasked_const() and (val := float(y.base.arg)) in (1, 0, -1):
+        return x if val == 1 else x.const(0) if val == 0 else x.e(UnaryOps.NEG)
+    if op is BinaryOps.DIV and dtypes.is_float(x.dtype) and y.is_unrealized_unmasked_const() and y.base.arg != 0:
+      return x.e(BinaryOps.MUL, x.const(1 / y.base.arg))
+
+    return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))
+
+  # *** reduce ops ***
+
+  def _reduce_op(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
+    assert all(0 <= x < len(self.shape) for x in axis), f"axis args {axis} out of range for shape {self.shape}"
+    axis = tuple(x for x in axis if self.shape[x] != 1)
+    if len(axis) == 0: return self
+    new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape))
+    return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), self.dtype, op, axis, (self,))
+
+  def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
+    new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape))
+    # TODO: this logic should move to the scheduler
+    if self.size == 0 and 0 not in new_shape: return self.const({ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[op], new_shape)
+
+    # const folding
+    if self.is_unrealized_unmasked_const():
+      return self.const(self.base.arg * {ReduceOps.SUM: prod(self.shape[i] for i in axis), ReduceOps.MAX: 1}[op], new_shape)
+
+    # TODO: can we split symbolic shape if the reduce axis is not symbolic?
+    if not getenv("SPLIT_REDUCEOP", 1) or not all_int(self.shape) or (0 in self.shape) or \
+      prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
+      return self._reduce_op(op, axis)
+
+    # if there are few globals, make some reduces into globals by splitting into two kernels
+    # cap output buffer to 2**22: heuristic number of global outputs to achieve max occupancy with enough locals+upcasts for gemm
+    #   ~2**10 should be enough if GROUP is used
+    # 256 split maximum should be "negligible reduce" for low prod(new_shape), 8 split minimum.
+    # split is moved to the end to provide maximum locality for the second phase reduce.
+    self_real_strides = self.st.real_strides(ignore_valid=True)
+    split_candidates = [(i, x) for i in axis for x in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(new_shape)),8-1,-1)
+                        if self.shape[i] % x == 0 and self_real_strides[i] != 0]
+    if not split_candidates: return self._reduce_op(op, axis)
+    dim_to_split, divisor = split_candidates[0]
+    splitted_shape = self.shape[:dim_to_split] + (divisor,) + (self.shape[dim_to_split]//divisor,) + self.shape[dim_to_split+1:]
+    splitted = self.reshape(splitted_shape).permute(tuple([x for x in range(len(splitted_shape)) if x != dim_to_split]+[dim_to_split]))
+    if DEBUG >= 3: print(f"split {divisor}: {self.shape} -> {splitted.shape} -> {new_shape}")
+    return splitted._reduce_op(op, axis)._reduce_op(op, (len(new_shape),)).reshape(new_shape) # reduce original axes, then split
+
+  # *** movement ops ***
+
+  def _view(self, new_st:ShapeTracker) -> LazyBuffer:
+    if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)):
+      return self.const(0, new_st.shape)
+    if new_st.contiguous and self.base.shape == new_st.shape: return self.base
+    return create_lazybuffer(self.device, new_st, self.dtype, base=self.base)
+
+  def reshape(self, arg:Tuple[sint, ...]): return self._view(self.st.reshape(arg))
+  def pad(self, arg:Tuple[Tuple[sint, sint], ...]): return self._view(self.st.pad(arg))
+  def expand(self, arg:Tuple[sint, ...]): return self._view(self.st.expand(arg))
+  def permute(self, arg:Tuple[int, ...]): return self._view(self.st.permute(arg))
+  def shrink(self, arg:Tuple[Tuple[sint, sint], ...]): return self._view(self.st.shrink(arg))
+  def stride(self, arg:Tuple[int, ...]): return self._view(self.st.stride(arg))