tinygrad 0.10.0__py3-none-any.whl → 0.10.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/devectorizer.py +247 -0
- tinygrad/codegen/expander.py +121 -0
- tinygrad/codegen/kernel.py +141 -201
- tinygrad/codegen/linearize.py +223 -84
- tinygrad/codegen/lowerer.py +60 -42
- tinygrad/codegen/symbolic.py +476 -0
- tinygrad/codegen/transcendental.py +22 -13
- tinygrad/device.py +187 -47
- tinygrad/dtype.py +39 -28
- tinygrad/engine/jit.py +83 -65
- tinygrad/engine/memory.py +4 -5
- tinygrad/engine/multi.py +161 -0
- tinygrad/engine/realize.py +62 -108
- tinygrad/engine/schedule.py +396 -357
- tinygrad/engine/search.py +55 -66
- tinygrad/gradient.py +73 -0
- tinygrad/helpers.py +81 -59
- tinygrad/nn/__init__.py +30 -32
- tinygrad/nn/datasets.py +1 -2
- tinygrad/nn/optim.py +22 -26
- tinygrad/nn/state.py +91 -66
- tinygrad/ops.py +492 -641
- tinygrad/renderer/__init__.py +95 -36
- tinygrad/renderer/cstyle.py +99 -92
- tinygrad/renderer/llvmir.py +83 -34
- tinygrad/renderer/ptx.py +83 -99
- tinygrad/renderer/wgsl.py +95 -0
- tinygrad/runtime/autogen/amd_gpu.py +39507 -12
- tinygrad/runtime/autogen/comgr.py +2 -0
- tinygrad/runtime/autogen/kfd.py +4 -3
- tinygrad/runtime/autogen/kgsl.py +1 -1
- tinygrad/runtime/autogen/libc.py +404 -71
- tinygrad/runtime/autogen/llvm.py +11379 -0
- tinygrad/runtime/autogen/pci.py +1333 -0
- tinygrad/runtime/autogen/vfio.py +891 -0
- tinygrad/runtime/autogen/webgpu.py +6985 -0
- tinygrad/runtime/graph/cuda.py +8 -9
- tinygrad/runtime/graph/hcq.py +84 -79
- tinygrad/runtime/graph/metal.py +40 -43
- tinygrad/runtime/ops_amd.py +498 -334
- tinygrad/runtime/ops_cloud.py +34 -34
- tinygrad/runtime/ops_cpu.py +24 -0
- tinygrad/runtime/ops_cuda.py +30 -27
- tinygrad/runtime/ops_disk.py +62 -63
- tinygrad/runtime/ops_dsp.py +159 -42
- tinygrad/runtime/ops_gpu.py +30 -30
- tinygrad/runtime/ops_hip.py +29 -31
- tinygrad/runtime/ops_llvm.py +48 -41
- tinygrad/runtime/ops_metal.py +149 -113
- tinygrad/runtime/ops_npy.py +2 -2
- tinygrad/runtime/ops_nv.py +238 -273
- tinygrad/runtime/ops_python.py +55 -50
- tinygrad/runtime/ops_qcom.py +129 -157
- tinygrad/runtime/ops_webgpu.py +225 -0
- tinygrad/runtime/support/allocator.py +94 -0
- tinygrad/runtime/support/am/__init__.py +0 -0
- tinygrad/runtime/support/am/amdev.py +396 -0
- tinygrad/runtime/support/am/ip.py +463 -0
- tinygrad/runtime/support/compiler_cuda.py +4 -2
- tinygrad/runtime/support/elf.py +28 -4
- tinygrad/runtime/support/hcq.py +256 -324
- tinygrad/runtime/support/llvm.py +26 -0
- tinygrad/shape/shapetracker.py +85 -53
- tinygrad/shape/view.py +104 -140
- tinygrad/spec.py +155 -0
- tinygrad/tensor.py +835 -527
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
- tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
- tinygrad/viz/index.html +544 -0
- tinygrad/viz/perfetto.html +178 -0
- tinygrad/viz/serve.py +205 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
- tinygrad-0.10.2.dist-info/RECORD +99 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
- tinygrad/codegen/uopgraph.py +0 -506
- tinygrad/engine/lazy.py +0 -228
- tinygrad/function.py +0 -212
- tinygrad/multi.py +0 -177
- tinygrad/runtime/graph/clang.py +0 -39
- tinygrad/runtime/ops_clang.py +0 -35
- tinygrad-0.10.0.dist-info/RECORD +0 -77
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/engine/lazy.py
DELETED
@@ -1,228 +0,0 @@
from __future__ import annotations
from typing import Optional, Any, Tuple, List, get_args
from tinygrad.dtype import dtypes, DType, ConstType, to_dtype, ImageDType
from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata, SPLIT_REDUCEOP, LAZYCACHE
from tinygrad.ops import exec_alu, python_alu
from tinygrad.ops import identity_element, MathTrait, resolve, UOp, sint, GroupOp, Ops
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.device import Buffer
from weakref import ref, ReferenceType, WeakValueDictionary

lazycache: WeakValueDictionary[Any, LazyBuffer] = WeakValueDictionary()
def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Ops]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
                      base:Optional[LazyBuffer]=None, enable_cache=bool(LAZYCACHE)):
  if st.size == 0: op, arg, srcs, base = Ops.CONST, 0, (), None
  dtype = to_dtype(dtype)
  if op is Ops.CONST: arg, enable_cache = dtypes.as_const(arg, dtype) if not isinstance(arg, UOp) else arg, True

  cache_key = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs)) if base is None else (st, ref(base))
  if enable_cache and (rret := lazycache.get(cache_key, None)) is not None: return rret

  ret = LazyBuffer(device, st, dtype, op, arg, srcs, base=base, metadata=_METADATA.get())
  if enable_cache: lazycache[cache_key] = ret
  return ret

view_supported_devices = {"LLVM", "CLANG", "CUDA", "NV", "AMD", "METAL", "QCOM", "DSP", "DISK"}
class LazyBuffer(MathTrait):
  def __init__(self, device:str, st:ShapeTracker, dtype:DType,
               op:Optional[Ops]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
               base:Optional[LazyBuffer]=None, metadata:Optional[Metadata]=None):
    self.device, self.st, self.dtype, self.shape, self.size, self.metadata = device, st, to_dtype(dtype), st.shape, st.size, metadata
    self._base: Optional[LazyBuffer] = None
    if base is None:
      # properties on base
      self.op, self.arg, self.srcs = op, arg, srcs  # this is a UOp, except the src is LazyBuffers and not UOps
      assert self.op is not Ops.ASSIGN or srcs[0].base.realized is not None, "assign target must be realized"

      if self.op is Ops.BUFFER_VIEW:
        # some LazyBuffers can be processed with only a view, no AST required
        self.buffer: Buffer = srcs[0].base.buffer.view(st.size, self.dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize)
      else:
        self.buffer = srcs[0].base.buffer if self.op is Ops.ASSIGN else Buffer(device, self.size, self.dtype)
      self.buffer.ref(1)
      self.contiguous_child: Optional[Tuple[ReferenceType[LazyBuffer], ShapeTracker]] = None
      self.forced_realize = False
    else:
      # properties on view
      assert base.base == base, "base must be a base itself"
      self._base = base

  def __del__(self):
    if hasattr(self, 'buffer'): self.buffer.ref(-1)

  def __repr__(self) -> str:
    return f"<LB {self.device} {self.shape} {str(self.dtype)[7:]} {self.st if self.base is not self else (self.op, self.realized)}>"

  @property
  def realized(self) -> Optional[Buffer]:
    # NOTE: we check for a lack of srcs instead of an allocated buffer to make unrealized assigns return None here
    return self.buffer if self._base is None and not hasattr(self, 'srcs') else None

  # NOTE: this has to be a function to prevent self reference
  @property
  def base(self) -> LazyBuffer: return self._base if self._base is not None else self

  # same API as multi
  @property
  def lbs(self) -> List[LazyBuffer]: return [self]

  @staticmethod
  def metaop(op, shape:Tuple[sint,...], dtype:DType, device:str, arg=None, src:Tuple[LazyBuffer, ...]=(), enable_cache=False) -> LazyBuffer:
    assert isinstance(src, tuple)
    return create_lazybuffer(device, ShapeTracker.from_shape(shape), dtype, op, arg, src, enable_cache=enable_cache)

  def const_like(self, b): return self.const_with_shape(b, self.shape)
  def const_with_shape(self, val:ConstType, shape:Tuple[sint,...]) -> LazyBuffer:
    assert isinstance(val, get_args(ConstType)), f"{val=} has {type(val)=}, not a ConstType"
    return LazyBuffer.metaop(Ops.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(shape)).expand(shape)

  @property
  def is_realized(self) -> bool: return self.base.realized is not None

  def assign(self, x:LazyBuffer) -> LazyBuffer:
    assert x.size == self.size, f"assign target must have same size {self.size=} != {x.size=}"
    assert self.is_realized, f"assign target must be realized {self}"
    return LazyBuffer.metaop(Ops.ASSIGN, self.shape, self.dtype, self.device, arg=() if self.st.contiguous else (self.st,),
                             src=(self.base, x), enable_cache=True)

  def can_view(self):
    return (self.st.consecutive and not self.is_unrealized_const() and not isinstance(self.dtype, ImageDType) and
            self.device.split(":")[0] in view_supported_devices)

  def contiguous(self, allow_buffer_view=True):
    if not self.st.contiguous or self.size != self.base.size or self.is_unrealized_const():
      ret = self.alu(Ops.BUFFER_VIEW) if allow_buffer_view and self.can_view() else self.alu(Ops.CONTIGUOUS)
      if (sti := self.st.invert(self.base.shape)) is not None: self.base.contiguous_child = ref(ret), sti
      return ret
    self.base.forced_realize = True
    return self

  def bitcast(self, dtype:DType) -> LazyBuffer: return self.cast(dtype, bitcast=True)
  def cast(self, dtype:DType, bitcast:bool=False, allow_buffer_view=True) -> LazyBuffer:
    if self.dtype == dtype: return self
    if self.device.startswith("DISK") and not bitcast: raise RuntimeError("attempted to cast disk buffer (bitcast only)")
    if self.is_unrealized_unmasked_const() and not bitcast:
      return create_lazybuffer(self.device, self.st, dtype, Ops.CONST, dtypes.as_const(self.base.arg, dtype))
    new_shape = self.shape
    if bitcast and self.dtype.itemsize != dtype.itemsize:
      if not self.device.startswith("DISK"): raise RuntimeError("shape changing bitcast only supported on DISK right now")
      if not all_int(new_shape): raise RuntimeError("shape changing bitcast with symbolic shape isn't supported yet")
      # https://pytorch.org/docs/stable/generated/torch.Tensor.view.html
      if not (new_shape[-1]*self.dtype.itemsize) % dtype.itemsize == 0: raise RuntimeError("unsupported size in bitcast")
      new_shape = new_shape[:-1] + ((new_shape[-1]*self.dtype.itemsize) // dtype.itemsize,)
    elif getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self is not self.base:
      # TODO: applying this makes gpt2 slower
      return self.base.cast(dtype, bitcast)._view(self.st)
    cast_op: Ops = (Ops.BUFFER_VIEW if self.can_view() and allow_buffer_view else Ops.BITCAST) if bitcast else Ops.CAST
    return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, dtype, (self,))

  def is_unrealized_const(self): return self.base.realized is None and self.base.op is Ops.CONST and not isinstance(self.base.arg, UOp)
  def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views)

  def _copy(self, device:str) -> LazyBuffer:
    assert self.st.contiguous and self.size == self.base.size, f"can only copy contig {self} {self.base}"
    return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, Ops.COPY, self.buffer.nbytes, (self,), enable_cache=False)

  def copy_to_device(self, device:str, force:bool=False, clone:bool=False) -> LazyBuffer:
    # no COPY
    if self.device == device and not clone: return self

    # double COPY = one COPY
    if not force and self.st.contiguous and self.size == self.base.size and not self.base.realized and self.base.op is Ops.COPY:
      return self.base.srcs[0].copy_to_device(device).reshape(self.st.shape)

    # const doesn't have to be copied (issues with disk tensor)
    if self.is_unrealized_const():
      return LazyBuffer.metaop(Ops.CONST, tuple(), self.dtype, device, arg=self.base.arg)._view(self.st)

    # if it's a shrink, do the shrink before the copy with CONTIGUOUS
    if prod(self.st.shape) < prod(self.base.st.shape): return self.contiguous()._copy(device)

    # copy the base and apply the shapetracker on the new device
    return self.base._copy(device)._view(self.st)

  def clone(self) -> LazyBuffer: return self.copy_to_device(self.device, clone=True)

  def alu(self, op:Ops, *in_srcs:LazyBuffer) -> LazyBuffer:
    srcs: List[LazyBuffer] = []
    for s in (self,)+in_srcs:
      if s == s.base and s.base.contiguous_child and (root:=s.base.contiguous_child[0]()) is not None:
        srcs.append(root._view(s.base.contiguous_child[1]))
      else:
        srcs.append(s)
    if not all_same(dts:=[x.dtype.base for x in (srcs[1:] if op is Ops.WHERE else srcs)]):
      raise AssertionError(f"all dtypes must match {dts} on {op}")
    assert all_same([x.shape for x in srcs]), f"all shapes must be the same {[x.shape for x in srcs]}"
    if op is Ops.WHERE: assert srcs[0].dtype == dtypes.bool, "Ops.WHERE must have the first arg be bool"

    out_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else srcs[-1].dtype

    # const folding
    if op in python_alu and all(s.is_unrealized_unmasked_const() for s in srcs):
      return self.cast(out_dtype).const_like(exec_alu(op, out_dtype, [s.base.arg for s in srcs]))
    if op in GroupOp.Binary:
      x, y = self, in_srcs[0]
      if op is Ops.ADD:
        if y.is_unrealized_unmasked_const() and y.base.arg == 0: return x
        if x.is_unrealized_unmasked_const() and x.base.arg == 0: return y
      if op is Ops.MUL:
        if x.is_unrealized_unmasked_const() and (val := x.base.arg) in (1, 0): return y if val == 1 else y.const_like(0)
        if y.is_unrealized_unmasked_const() and (val := y.base.arg) in (1, 0): return x if val == 1 else x.const_like(0)
      if op is Ops.IDIV and y.is_unrealized_unmasked_const() and y.base.arg == 1: return x

    return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, None, tuple(srcs))

  # *** reduce ops ***

  def _reduce_op(self, op:Ops, axis:Tuple[int, ...]) -> LazyBuffer:
    assert all(0 <= x < len(self.shape) for x in axis), f"axis args {axis} out of range for shape {self.shape}"
    axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
    if len(axis) == 0: return self
    return create_lazybuffer(self.device, ShapeTracker.from_shape(self.st.reduce(axis)), self.dtype, Ops.REDUCE_AXIS, (op, axis), (self,))

  def r(self, op:Ops, axis:Tuple[int, ...]) -> LazyBuffer:
    new_shape = self.st.reduce(axis)
    # TODO: this logic should move to the scheduler
    if 0 in self.shape and 0 not in new_shape: return self.const_with_shape(identity_element(op, self.dtype), new_shape)

    # const folding
    # TODO: fold this for symbolic?
    if self.is_unrealized_unmasked_const() and all_int(self.shape):
      if op is Ops.ADD: return self.const_with_shape(self.base.arg * prod(self.shape[i] for i in axis), new_shape)
      if op is Ops.MUL: return self.const_with_shape(self.base.arg ** prod(self.shape[i] for i in axis), new_shape)
      if op is Ops.MAX: return self.const_with_shape(self.base.arg, new_shape)

    # TODO: can we split symbolic shape if the reduce axis is not symbolic?
    if not SPLIT_REDUCEOP or not all_int(self.shape) or (0 in self.shape) or \
      prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
      return self._reduce_op(op, axis)

    # if there are few globals, make some reduces into globals by splitting into two kernels
    # cap output buffer to 2**22: heuristic number of global outputs to achieve max occupancy with enough locals+upcasts for gemm
    # ~2**10 should be enough if GROUP is used
    # 256 split maximum should be "negligible reduce" for low prod(new_shape), 8 split minimum.
    # split is moved to the end to provide maximum locality for the second phase reduce.
    self_real_strides = self.st.real_strides(ignore_valid=True)
    split_candidates = [(i, x) for i in axis for x in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(new_shape)),8-1,-1)
                        if self.shape[i] % x == 0 and self_real_strides[i] != 0]
    if not split_candidates: return self._reduce_op(op, axis)
    dim_to_split, divisor = split_candidates[0]
    splitted_shape = self.shape[:dim_to_split] + (divisor,) + (self.shape[dim_to_split]//divisor,) + self.shape[dim_to_split+1:]
    splitted = self.reshape(splitted_shape).permute(tuple([x for x in range(len(splitted_shape)) if x != dim_to_split]+[dim_to_split]))
    if DEBUG >= 3: print(f"split {divisor}: {self.shape} -> {splitted.shape} -> {new_shape}")
    return splitted._reduce_op(op, axis)._reduce_op(op, (len(new_shape),)).reshape(new_shape)  # reduce original axes, then split

  # *** movement ops ***

  def _view(self, new_st:ShapeTracker) -> LazyBuffer:
    if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)):
      return self.const_with_shape(0, new_st.shape)
    if new_st.contiguous and self.base.shape == new_st.shape: return self.base
    return create_lazybuffer(self.device, new_st, self.dtype, base=self.base)

  def reshape(self, arg:Tuple[sint, ...]): return self._view(self.st.reshape(arg))
  def pad(self, arg:Tuple[Tuple[sint, sint], ...]): return self._view(self.st.pad(arg))
  def expand(self, arg:Tuple[sint, ...]): return self._view(self.st.expand(arg))
  def permute(self, arg:Tuple[int, ...]): return self._view(self.st.permute(arg))
  def shrink(self, arg:Tuple[Tuple[sint, sint], ...]): return self._view(self.st.shrink(arg))
  def stride(self, arg:Tuple[int, ...]): return self._view(self.st.stride(arg))
tinygrad/function.py
DELETED
@@ -1,212 +0,0 @@
"""This is where the forwards and backwards passes live."""
import math
from typing import Tuple, Optional
from tinygrad.helpers import argsort
from tinygrad.dtype import dtypes, DType, sum_acc_dtype
from tinygrad.ops import Ops, resolve, sint
from tinygrad.tensor import Function
from tinygrad.engine.lazy import LazyBuffer

class Contiguous(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer: return x.contiguous()
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output

class ContiguousBackward(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer: return x
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.contiguous()

class Cast(Function):
  def forward(self, x:LazyBuffer, dtype:DType, bitcast:bool=False) -> LazyBuffer:
    self.input_dtype, self.bitcast = x.dtype, bitcast
    return x.bitcast(dtype) if self.bitcast else x.cast(dtype)

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    if self.bitcast: raise RuntimeError("bitcast cannot backward")
    return grad_output.cast(self.input_dtype)

# ************* unary ops *************

class Reciprocal(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.ret = x.reciprocal()
    return self.ret

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return -grad_output * self.ret * self.ret

class Sin(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.x = x
    return x.sin()

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return (math.pi/2 - self.x).sin() * grad_output

class Relu(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.ret = x.maximum(0)
    return self.ret

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return self.ret.gt(0).cast(grad_output.dtype) * grad_output

class Log(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.x = x
    return x.log2() * math.log(2)

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output / self.x

class Exp(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.ret = (x * (1/math.log(2))).exp2()
    return self.ret

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return self.ret * grad_output

class Sqrt(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.ret = x.sqrt()
    return self.ret

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output / (self.ret*2)

# NOTE: the implicit derivative of sigmoid is not stable
# https://towardsdatascience.com/derivative-of-the-sigmoid-function-536880cf918e
# TODO: have the backend automatically find this
class Sigmoid(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.ret = (1 + (x * (-1/math.log(2))).exp2()).reciprocal()
    return self.ret

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return (self.ret * (1 - self.ret)) * grad_output

class Sign(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer: return x.ne(0).where(x.lt(0).where(x.const_like(-1), x.const_like(1)), x.const_like(0))
  # backward always return 0 to match torch
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.const_like(0)

# ************* binary ops *************

class Less(Function):
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.lt(y)
  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None

class Neq(Function):
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.ne(y)
  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None

class Xor(Function):
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x^y

class BitwiseAnd(Function):
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x&y

class BitwiseOr(Function):
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x|y

class Threefry(Function):
  def forward(self, x:LazyBuffer, seed:LazyBuffer) -> LazyBuffer: return x.threefry(seed)

class Add(Function):
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x+y

  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
    return grad_output if self.needs_input_grad[0] else None, \
           grad_output if self.needs_input_grad[1] else None

class Mul(Function):
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
    self.x, self.y = x, y
    return x * y

  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
    return (self.y * grad_output) if self.needs_input_grad[0] else None, \
           (self.x * grad_output) if self.needs_input_grad[1] else None

class IDiv(Function):
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x // y

# ************* ternary ops *************

class Where(Function):
  def forward(self, x:LazyBuffer, y:LazyBuffer, z:LazyBuffer) -> LazyBuffer:
    self.x = x
    return self.x.where(y, z)

  def backward(self, grad_output:LazyBuffer) -> Tuple[None, Optional[LazyBuffer], Optional[LazyBuffer]]:
    return None, \
           self.x.where(grad_output, grad_output.const_like(0)) if self.needs_input_grad[1] else None, \
           self.x.where(grad_output.const_like(0), grad_output) if self.needs_input_grad[2] else None

# ************* reduce ops *************

class Sum(Function):
  def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
    self.input_shape = x.shape
    return x.r(Ops.ADD, axis)

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.expand(self.input_shape)

class Prod(Function):
  def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
    self.x, self.ret = x, x.r(Ops.MUL, axis)
    return self.ret

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return (grad_output * self.ret).expand(self.x.shape) / self.x

class Max(Function):
  def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
    self.x, self.ret, self.axis = x, x.r(Ops.MAX, axis), axis
    return self.ret

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    # 1s in locations where the max was chosen (can be two locations)
    max_is_1s = self.x.ne(self.ret.expand(self.x.shape)).ne(self.x.const_like(1).cast(dtypes.bool)).cast(grad_output.dtype)
    div = max_is_1s.r(Ops.ADD, self.axis).expand(self.x.shape)
    return (max_is_1s/div) * grad_output.expand(self.x.shape)

# ************* movement ops *************

# NOTE: this is sum in reverse
class Expand(Function):
  def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
    self.expanded_axis = tuple(i for i, (si, so) in enumerate(zip(x.shape, shape)) if resolve(si != so))
    return x.expand(shape)

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.cast(sum_acc_dtype(grad_output.dtype)).r(Ops.ADD, self.expanded_axis).cast(grad_output.dtype)

class Reshape(Function):
  def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
    self.input_shape = x.shape
    return x.reshape(shape)

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.reshape(self.input_shape)

class Permute(Function):
  def forward(self, x:LazyBuffer, order:Tuple[int, ...]) -> LazyBuffer:
    self.input_order = order
    return x.permute(order)

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.permute(argsort(self.input_order))

class Pad(Function):
  def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
    self.narg = tuple([(p[0], s+p[0]) for s,p in zip(x.shape, arg)])
    return x.pad(arg)

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.shrink(self.narg)

class Shrink(Function):
  def forward(self, x:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer:
    self.narg = tuple([(p[0], s-p[1]) for s,p in zip(x.shape, arg)])
    return x.shrink(arg)

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.pad(self.narg)

class Flip(Function):
  def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
    self.arg = tuple([-1 if i in axis else 1 for i in range(len(x.shape))])
    return x.stride(self.arg)

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.stride(self.arg)
tinygrad/multi.py
DELETED
@@ -1,177 +0,0 @@
from __future__ import annotations
from typing import Optional, Tuple, List, Dict
import functools, itertools, operator
from tinygrad.helpers import all_same, all_int, dedup, prod, DEBUG, RING, getenv
from tinygrad.dtype import DType
from tinygrad.ops import Ops, MathTrait
from tinygrad.engine.lazy import LazyBuffer
from tinygrad.shape.shapetracker import sint

def all_reduce(bop: Ops, lbs: List[LazyBuffer]) -> List[LazyBuffer]:
  assert all_int(lbs[0].shape), f"does not support symbolic shape {lbs[0].shape}"
  assert all_same([lb.shape[0] for lb in lbs]), "allreduce with uneven shards is undefined"
  n_lbs, shape, numel = len(lbs), lbs[0].shape, prod(lbs[0].shape)
  # ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
  # fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
  use_ring = (RING >= 2 or (n_lbs > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
  if DEBUG >= 2: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{numel} | {lbs[0].dtype}")
  if not use_ring: return [functools.reduce(lambda x,y: x.alu(bop, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]

  factor = next(f for f in [32, 16, 8, 4, 2, 1] if numel % f == 0)
  base, left = (numel // factor) // n_lbs, (numel // factor) % n_lbs
  chunk_sizes = [(base + 1) * factor] * left + [base * factor] * (n_lbs - left)
  acc = 0
  chunks = [(acc, (acc := acc + i)) for i in chunk_sizes if i > 0]
  chunked = [[lb.reshape((numel,)).shrink(((s,e),)) for s,e in chunks] for lb in lbs]

  # scatter-reduce
  for step in range(n_lbs-1):
    for i in range(len(chunks)):
      src, dest = (i+step)%n_lbs, (i+step+1)%n_lbs
      chunked[dest][i] = chunked[dest][i].alu(bop, chunked[src][i].copy_to_device(chunked[dest][i].device, force=True))

  # allgather
  for step in range(n_lbs-1):
    for i in range(len(chunks)):
      src, dest = (i+step-1)%n_lbs, (i+step)%n_lbs
      chunked[dest][i] = chunked[src][i].copy_to_device(chunked[dest][i].device, force=True)

  # assemble chunks back
  pads = [((s,numel-e),) for s,e in chunks]
  return [functools.reduce(operator.add, [c.pad(pad) for pad,c in zip(pads,lb_c)]).reshape(shape) for lb_c in chunked]

def to_sharded(lbs:List[LazyBuffer], axis:int, bounds: Tuple[Tuple[int, int], ...]) -> List[LazyBuffer]:
  if DEBUG >= 3 and lbs[0].shape[axis] % len(lbs) != 0: print(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}, bounds={bounds}")
  return [lb.shrink(tuple((0,s) if a != axis else bound for a,s in enumerate(lb.shape))) for i, (bound, lb) in enumerate(zip(bounds, lbs))]

class MultiLazyBuffer(MathTrait):
  def __init__(self, lbs:List[LazyBuffer], axis:Optional[int], real:Optional[List[bool]]=None):
    assert all(isinstance(x, LazyBuffer) for x in lbs) and len(lbs), "all lbs must be LazyBuffers, and we need at least one of them"
    assert all_same([x.dtype for x in lbs]), f"all multilazybuffer needs same dtype, getting {[x.dtype for x in lbs]}"
    self.lbs, self.axis, self.dtype, self.device, self.real = lbs, axis, lbs[0].dtype, tuple(x.device for x in lbs), real or [True]*len(lbs)
    if axis is not None:
      splits = list(itertools.accumulate([lb.shape[axis] for lb in lbs], initial=0))
      self.bounds = tuple(zip(splits, splits[1:]))

  @property
  def shape(self): return tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape))

  @property
  def size(self): return sum(x.size for x in self.real_lbs)

  @property
  def real_lbs(self): return [lb for lb,r in zip(self.lbs, self.real) if r]

  def __repr__(self): return f"<MLB {self.axis=} {self.real=} {chr(10)}{chr(10).join([f'{x.device} {x.st}' for x in self.lbs])}>"

  @staticmethod
  def from_sharded(lb:LazyBuffer, devices:Tuple[str, ...], axis:Optional[int], bounds:Optional[Tuple[Tuple[int, int], ...]]):
    assert (axis is None) == (bounds is None), "must specify bounds iff axis is specified"
    lbs = [lb] * len(devices)
    sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(to_sharded(lbs, axis, bounds) if axis is not None and bounds is not None else lbs, devices)]
    return MultiLazyBuffer([lb if lb.is_unrealized_unmasked_const() else lb.contiguous(allow_buffer_view=False) for lb in sharded_lbs], axis)

  def copy_to_device(self, device:str) -> LazyBuffer:
    if self.axis is None:
      # if we already have a copy on the device, return that
      return next((lb for lb in self.real_lbs if lb.device == device), self.real_lbs[0].copy_to_device(device))
    # copy lbs to device, pad to final shape, and sum
    llbs:List[LazyBuffer] = []
    for lb,real,(start,end) in zip(self.lbs, self.real, self.bounds):
      if not real: continue
      pad_arg = tuple((0,0) if a != self.axis else (start, self.bounds[-1][1]-end) for a in range(len(lb.shape)))
      llbs.append(lb.copy_to_device(device).pad(pad_arg))
    return functools.reduce(operator.add, llbs)

  # passthroughs
  @property
  def is_realized(self) -> bool: return all(lb.base.realized is not None for lb in self.real_lbs)
  def cast(self, dtype:DType, bitcast:bool=False, allow_buffer_view=True):
    return MultiLazyBuffer([x.cast(dtype, bitcast, allow_buffer_view) for x in self.lbs], self.axis, self.real)
  def const_like(self, b) -> MultiLazyBuffer: return MultiLazyBuffer([x.const_like(b) for x in self.lbs], self.axis, self.real)
  def assign(self, x:MultiLazyBuffer): return MultiLazyBuffer([s.assign(d) for s,d in zip(self.lbs, x.lbs)], self.axis, self.real)
  def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis, self.real)
  def clone(self) -> MultiLazyBuffer: return MultiLazyBuffer([lb.clone() for lb in self.lbs], self.axis, self.real)

  # elementwise is simple
  def alu(self, op:Ops, *in_srcs:MultiLazyBuffer) -> MultiLazyBuffer:
    msrcs = (self,)+in_srcs
    assert all(isinstance(x, MultiLazyBuffer) for x in msrcs), f"all buffers must be MultiLazyBuffer {msrcs}"
    assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}"

    # NOTE: they all have to share an axis, we always choose [-1]
    axis, bounds = axes[-1] if len(axes := dedup([(x.axis, x.bounds) for x in msrcs if x.axis is not None])) else (None, None)
    srcs:List[List[LazyBuffer]] = []
    not_all_real = not all(all(mlb.real) for mlb in msrcs)
    new_real = [all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])] if not_all_real else self.real
    assert any(new_real), "output contains no real lb"
    for mlb in msrcs:
      if (mlb.axis == axis and (mlb.axis is None or mlb.bounds == bounds)) or not_all_real: srcs.append(mlb.lbs)
      elif mlb.axis is None and axis is not None: srcs.append(to_sharded(mlb.lbs, axis, bounds))
      else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis, bounds))
    new_real_lbs:Dict[int,LazyBuffer] = {i:lsrcs[0].alu(op, *lsrcs[1:]) for i,(lsrcs,r) in enumerate(zip(zip(*srcs), new_real)) if r}
    # NOTE: const dtype should match real
    new_dtype = next(iter(new_real_lbs.values())).dtype
    return MultiLazyBuffer([new_real_lbs.get(i, lsrcs[0].const_like(0).cast(new_dtype)) for i,lsrcs in enumerate(zip(*srcs))], axis, new_real)

  def r(self, op:Ops, axis:Tuple[int, ...]) -> MultiLazyBuffer:
    if self.axis is not None and self.axis in axis:
      # all-reduce on sharded axes
      reduced_parts = [(x if r else x.const_like(0)).r(op, axis) for x,r in zip(self.lbs, self.real)]
      # if all partitions are real, do all_reduce
      if all(self.real): return MultiLazyBuffer(all_reduce(op, reduced_parts), None)
      # only one partition is real, keep it
      return MultiLazyBuffer(reduced_parts, None, self.real)
    # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
    return MultiLazyBuffer([x.r(op, axis) for x in self.lbs], self.axis, self.real)

  # *** movement ops ***

  def _shape_to_single_shard(self, shape:Tuple[sint, ...], lb:LazyBuffer) -> Tuple[sint, ...]:
    return tuple(lb.shape[self.axis] if a == self.axis else s for a,s in enumerate(shape))

  def reshape(self, arg:Tuple[sint, ...]):
    if self.axis is None: return MultiLazyBuffer([x.reshape(arg) for x in self.lbs], None, self.real)
    assert prod(self.shape) == prod(arg), "reshape must maintain prod(shape)"
    arg_acc:List[sint] = list(itertools.accumulate(arg, operator.mul, initial=1))
    # new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards
    # todo: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1?
    new_axis = len(arg_acc) - arg_acc[::-1].index(prod(self.shape[:self.axis])) - 1
    assert all(prod(lb.shape[self.axis:])%prod(arg[new_axis+1:])==0 for lb in self.lbs), f"reshape cannot move items between shards {self=} {arg=}"
    lbs = [x.reshape(tuple(s if a!=new_axis else prod(x.shape[self.axis:])//prod(arg[new_axis+1:]) for a,s in enumerate(arg))) for x in self.lbs]
    return MultiLazyBuffer(lbs, new_axis, self.real)

  def pad(self, arg:Tuple[Tuple[sint, sint], ...]):
    assert self.axis is None or arg[self.axis] == (0,0) or not all(self.real), f"padding not supported for {arg=}"
    # pad on shard axis -> fill others with zeros and set real to all True
    if self.axis is not None and arg[self.axis] != (0,0):
      # pad back to whole axis, remove real mask
      assert all(arg[i] == (0, 0) for i in range(len(self.shape)) if i != self.axis), "cannot pad sharded and non-sharded axis at the same time"
      dim, bound = sum(lb.shape[self.axis] for lb in self.lbs), self.bounds[self.real.index(True)]
      assert arg[self.axis] == (bound[0], dim-bound[1]), "can only pad to whole axis"
      return MultiLazyBuffer([x if r else x.const_like(0) for x,r in zip(self.lbs, self.real)], self.axis)
    return MultiLazyBuffer([x.pad(arg) for x in self.lbs], self.axis, self.real)

  def expand(self, arg:Tuple[sint, ...]):
    # NOTE: this assert isn't needed, sharded axis can have dim 1
    assert self.axis is None or arg[self.axis] == self.shape[self.axis], f"expand not supported on sharded axis {arg=}"
    return MultiLazyBuffer([x.expand(self._shape_to_single_shard(arg, x)) for x in self.lbs], self.axis, self.real)

  def permute(self, arg:Tuple[int, ...]):
    # all permutes supported!
    return MultiLazyBuffer([x.permute(arg) for x in self.lbs], arg.index(self.axis) if self.axis is not None else None, self.real)

  def shrink(self, arg:Tuple[Tuple[sint, sint], ...]):
    assert self.axis is None or arg[self.axis] == (0, self.shape[self.axis]) or arg[self.axis] in self.bounds, f"shrinking not supported for {arg=}"
    if self.axis is not None and arg[self.axis] in self.bounds and arg[self.axis] != (0, self.shape[self.axis]):
      assert all(arg[i] == (0, s) or i == self.axis for i,s in enumerate(self.shape)), "cannot shrink sharded and non-sharded axis at the same time"
      # NOTE: shrink on the shard axis is only allowed when result is a single partition, denoted by the new real
      idx = self.bounds.index(arg[self.axis])
      # zero out other lbs to not create lb reference
      return MultiLazyBuffer([lb if i==idx else lb.const_like(0) for i,lb in enumerate(self.lbs)], self.axis, [i==idx for i in range(len(self.lbs))])
    return MultiLazyBuffer([x.shrink(tuple((0, x.shape[self.axis]) if a == self.axis else s for a,s in enumerate(arg))) for x in self.lbs],
                           self.axis, self.real)

  def stride(self, arg:Tuple[int, ...]):
    assert self.axis is None or arg[self.axis] == 1, "flipping not supported on sharded axis"
    return MultiLazyBuffer([x.stride(arg) for x in self.lbs], self.axis, self.real)
tinygrad/runtime/graph/clang.py
DELETED
@@ -1,39 +0,0 @@
from typing import List, Dict, cast
import ctypes
from tinygrad.helpers import dedup, cpu_time_execution, DEBUG
from tinygrad.engine.jit import GraphRunner, GraphException
from tinygrad.device import Buffer, Device
from tinygrad.engine.realize import ExecItem, CompiledRunner
from tinygrad.ops import Variable
from tinygrad.runtime.ops_clang import ClangProgram
from tinygrad.renderer.cstyle import ClangRenderer
render_dtype = ClangRenderer().render_dtype

class ClangGraph(GraphRunner):
  def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
    super().__init__(jit_cache, input_rawbuffers, var_vals)
    if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException

    prgs = '\n'.join(dedup([cast(CompiledRunner, ji.prg).p.src for ji in jit_cache]))
    args = [f"{render_dtype(x.dtype)}* arg{i}" for i,x in enumerate(input_rawbuffers)]
    args += sorted([f"int {v.expr}" for v in var_vals])
    code = ["void batched("+','.join(args)+") {"]
    for ji in jit_cache:
      args = []
      for buf in ji.bufs:
        assert buf is not None
        if buf in input_rawbuffers:
          args.append(f"arg{input_rawbuffers.index(buf)}")
        else:
          args.append(f"({render_dtype(buf.dtype)}*)0x{ctypes.addressof(buf._buf):X}")
      args += [x.expr for x in cast(CompiledRunner, ji.prg).p.vars]
      code.append(f"  {cast(CompiledRunner, ji.prg).p.function_name}({','.join(args)});")
    code.append("}")
    if DEBUG >= 4: print("\n".join(code))
    compiler = Device["CLANG"].compiler
    assert compiler is not None
    self.clprg = ClangProgram("batched", compiler.compile(prgs+"\n"+"\n".join(code)))  # no point in caching the pointers

  def __call__(self, rawbufs: List[Buffer], var_vals: Dict[Variable, int], wait=False):
    return cpu_time_execution(
      lambda: self.clprg(*[x._buf for x in rawbufs], *[x[1] for x in sorted(var_vals.items(), key=lambda x: x[0].expr)]), enable=wait)