tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +1 -1
- tinygrad/apps/llm.py +206 -0
- tinygrad/codegen/__init__.py +116 -0
- tinygrad/codegen/devectorizer.py +315 -172
- tinygrad/codegen/expander.py +8 -16
- tinygrad/codegen/gpudims.py +89 -0
- tinygrad/codegen/linearize.py +205 -203
- tinygrad/codegen/lowerer.py +92 -139
- tinygrad/codegen/opt/__init__.py +38 -0
- tinygrad/codegen/opt/heuristic.py +125 -0
- tinygrad/codegen/opt/kernel.py +510 -0
- tinygrad/{engine → codegen/opt}/search.py +51 -35
- tinygrad/codegen/opt/swizzler.py +134 -0
- tinygrad/codegen/opt/tc.py +127 -0
- tinygrad/codegen/quantize.py +67 -0
- tinygrad/device.py +122 -132
- tinygrad/dtype.py +152 -35
- tinygrad/engine/jit.py +81 -54
- tinygrad/engine/memory.py +46 -27
- tinygrad/engine/realize.py +82 -41
- tinygrad/engine/schedule.py +70 -445
- tinygrad/frontend/__init__.py +0 -0
- tinygrad/frontend/onnx.py +1253 -0
- tinygrad/frontend/torch.py +5 -0
- tinygrad/gradient.py +19 -27
- tinygrad/helpers.py +95 -47
- tinygrad/nn/__init__.py +7 -8
- tinygrad/nn/optim.py +72 -41
- tinygrad/nn/state.py +37 -23
- tinygrad/renderer/__init__.py +40 -60
- tinygrad/renderer/cstyle.py +143 -128
- tinygrad/renderer/llvmir.py +113 -62
- tinygrad/renderer/ptx.py +50 -32
- tinygrad/renderer/wgsl.py +27 -23
- tinygrad/runtime/autogen/am/am.py +5861 -0
- tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
- tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
- tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
- tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
- tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
- tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
- tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
- tinygrad/runtime/autogen/comgr.py +35 -9
- tinygrad/runtime/autogen/comgr_3.py +906 -0
- tinygrad/runtime/autogen/cuda.py +2419 -494
- tinygrad/runtime/autogen/hsa.py +57 -16
- tinygrad/runtime/autogen/ib.py +7171 -0
- tinygrad/runtime/autogen/io_uring.py +917 -118
- tinygrad/runtime/autogen/kfd.py +748 -26
- tinygrad/runtime/autogen/libc.py +613 -218
- tinygrad/runtime/autogen/libusb.py +1643 -0
- tinygrad/runtime/autogen/nv/nv.py +8602 -0
- tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
- tinygrad/runtime/autogen/opencl.py +2 -4
- tinygrad/runtime/autogen/sqtt.py +1789 -0
- tinygrad/runtime/autogen/vfio.py +3 -3
- tinygrad/runtime/autogen/webgpu.py +273 -264
- tinygrad/runtime/graph/cuda.py +3 -3
- tinygrad/runtime/graph/hcq.py +68 -29
- tinygrad/runtime/graph/metal.py +29 -13
- tinygrad/runtime/graph/remote.py +114 -0
- tinygrad/runtime/ops_amd.py +537 -320
- tinygrad/runtime/ops_cpu.py +108 -7
- tinygrad/runtime/ops_cuda.py +12 -14
- tinygrad/runtime/ops_disk.py +13 -10
- tinygrad/runtime/ops_dsp.py +47 -40
- tinygrad/runtime/ops_gpu.py +13 -11
- tinygrad/runtime/ops_hip.py +6 -9
- tinygrad/runtime/ops_llvm.py +35 -15
- tinygrad/runtime/ops_metal.py +29 -19
- tinygrad/runtime/ops_npy.py +5 -3
- tinygrad/runtime/ops_null.py +28 -0
- tinygrad/runtime/ops_nv.py +306 -234
- tinygrad/runtime/ops_python.py +62 -52
- tinygrad/runtime/ops_qcom.py +28 -39
- tinygrad/runtime/ops_remote.py +482 -0
- tinygrad/runtime/ops_webgpu.py +28 -28
- tinygrad/runtime/support/am/amdev.py +114 -249
- tinygrad/runtime/support/am/ip.py +211 -172
- tinygrad/runtime/support/amd.py +138 -0
- tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
- tinygrad/runtime/support/compiler_cuda.py +8 -11
- tinygrad/runtime/support/elf.py +2 -1
- tinygrad/runtime/support/hcq.py +184 -97
- tinygrad/runtime/support/ib.py +172 -0
- tinygrad/runtime/support/llvm.py +3 -4
- tinygrad/runtime/support/memory.py +251 -0
- tinygrad/runtime/support/nv/__init__.py +0 -0
- tinygrad/runtime/support/nv/ip.py +581 -0
- tinygrad/runtime/support/nv/nvdev.py +183 -0
- tinygrad/runtime/support/system.py +170 -0
- tinygrad/runtime/support/usb.py +268 -0
- tinygrad/runtime/support/webgpu.py +18 -0
- tinygrad/schedule/__init__.py +0 -0
- tinygrad/schedule/grouper.py +119 -0
- tinygrad/schedule/kernelize.py +368 -0
- tinygrad/schedule/multi.py +231 -0
- tinygrad/shape/shapetracker.py +40 -46
- tinygrad/shape/view.py +88 -52
- tinygrad/tensor.py +968 -542
- tinygrad/uop/__init__.py +117 -0
- tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
- tinygrad/uop/mathtraits.py +169 -0
- tinygrad/uop/ops.py +1021 -0
- tinygrad/uop/spec.py +228 -0
- tinygrad/{codegen → uop}/symbolic.py +239 -216
- tinygrad/uop/upat.py +163 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
- tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
- tinygrad/viz/index.html +203 -403
- tinygrad/viz/js/index.js +718 -0
- tinygrad/viz/js/worker.js +29 -0
- tinygrad/viz/serve.py +224 -102
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
- tinygrad-0.11.0.dist-info/RECORD +141 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/kernel.py +0 -693
- tinygrad/engine/multi.py +0 -161
- tinygrad/ops.py +0 -1003
- tinygrad/runtime/ops_cloud.py +0 -220
- tinygrad/runtime/support/allocator.py +0 -94
- tinygrad/spec.py +0 -155
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
- tinygrad/viz/perfetto.html +0 -178
- tinygrad-0.10.2.dist-info/RECORD +0 -99
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
tinygrad/schedule/multi.py
ADDED
@@ -0,0 +1,231 @@
+from typing import cast
+import functools, itertools, operator
+from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, getenv, unwrap
+from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, resolve
+from tinygrad.device import Device
+
+# *** allreduce implementation ***
+def handle_allreduce_multirank(buf:UOp, red:UOp) -> UOp|None:
+  if not isinstance(buf.device, tuple): return None
+
+  # Group buffers
+  groups: dict[int|None, list[UOp]] = {}
+  for i,dev in enumerate(buf.device):
+    groups.setdefault(Device[dev].group_id, []).append(buf.mselect(i))
+
+  # Put reduce leader of each group first
+  reduce_leaders = set(getenv("REDUCE_LEADERS", "").split(","))
+  groups = {gid: sorted(bufs, key=lambda x: (x.device not in reduce_leaders, x.device)) for gid,bufs in groups.items()}
+
+  # Skip if only one group or if every group has only one buffer
+  if len(groups) <= 1 or not any(len(g) > 1 for g in groups.values()): return None
+
+  # Reduce inside each group
+  inner = [UOp(Ops.MSTACK, buf.dtype, tuple(bufs)).allreduce(red.arg, (cast(str, bufs[0].device),)).mselect(0) for bufs in groups.values()]
+
+  # Allreduce across groups
+  outer = UOp(Ops.MSTACK, buf.dtype, tuple(inner)).allreduce(red.arg, tuple(buf.device for buf in inner))
+
+  # Broadcast back to all devices in the group
+  gid2bid = {Device[device].group_id: i for i,device in enumerate(outer.device)}
+  return outer.mselect(gid2bid[Device[red.device].group_id]).copy_to_device(red.device) if not isinstance(red.device, tuple) else \
+    UOp(Ops.MSTACK, buf.dtype, tuple(outer.mselect(gid2bid[Device[device].group_id]).copy_to_device(device) for device in red.device))
+
+def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
+  if not isinstance(buf.device, tuple): return None
+  assert all_int(buf.shape), f"does not support symbolic shape {buf.shape}"
+  n_lbs, shape, numel = len(buf.device), buf.shape, prod(buf.shape)
+  # ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
+  # fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
+  use_ring = (RING >= 2 or (n_lbs > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
+  if DEBUG >= 2: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{numel} | {buf.dtype}")
+
+  # contiguous before we copy it
+  buf = buf.contiguous()
+
+  # copy to all devices. if you shrink later, that'll be handled
+  if not use_ring: return functools.reduce(lambda x,y: x.alu(red.arg, y),
+    [UOp(Ops.COPY, buf.dtype, (buf.mselect(i), red.src[1])) for i in range(len(buf.device))])
+
+  # new ring reduce
+  factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1)
+  base, left = (numel // factor) // n_lbs, (numel // factor) % n_lbs
+  chunk_sizes = [(base + 1) * factor] * left + [base * factor] * (n_lbs - left)
+  chunks = list(itertools.pairwise(itertools.accumulate(chunk_sizes, initial=0)))
+
+  # extract chunks and scatter-reduce
+  reduced_chunks = []
+  for i,(s,e) in enumerate(chunks):
+    chunk = buf.reshape((numel,)).shrink(((s,e),))
+    reduced_chunk = chunk
+    for step in range(n_lbs-1):
+      src, dest = (i+step)%n_lbs, (i+step+1)%n_lbs
+      # copy the chunk from the src device to the dest (operating device), and select the chunk on the dest device
+      reduced_chunk = reduced_chunk.copy_to_device(buf.device[dest], src if isinstance(reduced_chunk.device, tuple) else None) \
+        .alu(red.arg, chunk.copy_to_device(buf.device[dest], dest))
+    reduced_chunks.append(reduced_chunk)
+
+  # allgather
+  copied_chunks = []
+  for i,c in enumerate(reduced_chunks):
+    this_chunk = [None] * len(buf.device)
+    this_chunk[(i+len(buf.device)-1)%n_lbs] = c
+    for step in range(n_lbs-1):
+      dest = (i+step)%n_lbs
+      this_chunk[dest] = c = c.copy_to_device(buf.device[dest])
+    copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(cast(list[UOp], this_chunk))))
+
+  # reassemble
+  pads = [((s,numel-e),) for s,e in chunks]
+  return functools.reduce(operator.add, [c.pad(pad) for pad,c in zip(pads, copied_chunks)]).reshape(shape)
+
+# ***** multi rewrite MSELECT/MSTACK *****
+
+def _replace_dnum(st, val):
+  # replace dnum in ShapeTracker with literal const for this mselect
+  if (dnums:=[x for x in st.vars() if x.op is Ops.DEFINE_VAR and x.arg[0] == '_device_num']):
+    assert len(dnums) == 1, f"view must have exactly 0 or 1 dnum, got {dnums}"
+    st = st.substitute({dnums[0]:dnums[0].const_like(val)})
+  return st
+
+def mstack_reorder_view(ms:UOp):
+  args = [x.arg for x in ms.src]
+  if not all_same(args) or len([x for x in args[0].vars() if x.arg[0] == '_device_num']) != 0: return None
+  return UOp(Ops.MSTACK, ms.dtype, tuple(x.src[0] for x in ms.src)).view(args[0])
+
+def mstack_early_shrink(view:UOp, ms:UOp):
+  if resolve(prod(view.shape) >= prod(ms.shape)) or _replace_dnum(view.st, 0) == view.st: return None
+  ret = []
+  for i, x in enumerate(ms.src):
+    new_view = _replace_dnum(view.st, i)
+    if x.op is Ops.COPY:
+      # if src device doesn't have a renderer, we have to view after the copy
+      # TODO: a way to understand this
+      if x.src[0].device in {"DISK", "NPY"}:
+        ret.append(x.view(new_view))
+      else:
+        ret.append(x.src[0].view(new_view).copy_to_device(x.device))
+    else:
+      ret.append(x.view(new_view).contiguous())
+  return ms.replace(src=tuple(ret))
+
+replace_allreduce = PatternMatcher([
+  (UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce_multirank),
+  (UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce),
+  # BROADCAST: explicitly expand broadcast copies and combine with MSTACK
+  (UPat(Ops.COPY, name="c", src=(UPat(GroupOp.All-{Ops.CONST}, name="x"), UPat(Ops.DEVICE))), lambda c,x:
+    UOp(Ops.MSTACK, c.dtype, tuple(x.copy_to_device(d) for d in c.device)) if isinstance(c.device, tuple) and isinstance(x.device, str) else None),
+  # COPY_TO_ONE: if copying from multidevice to one, MSELECT the first (TODO: a little from each?)
+  (UPat(Ops.COPY, name="c", src=(UPat(GroupOp.All-{Ops.CONST}, name="x"), UPat(Ops.DEVICE))), lambda c,x:
+    x.mselect(0).copy_to_device(c.device) if isinstance(c.device, str) and isinstance(x.device, tuple) else None),
+  # MSELECT on MSTACK is replaced with nothing
+  (UPat(Ops.MSELECT, src=(UPat(Ops.MSTACK, name="mstack"),), name="ms"), lambda mstack, ms: mstack.src[ms.arg]),
+  # MSELECT must select a base, if there are views apply them after selecting the base
+  (UPat(Ops.MSELECT, src=(UPat(Ops.VIEW, src=(UPat.var("base"),), name="view"),), name="ms"), lambda ms, view, base:
+    base.mselect(ms.arg).view(_replace_dnum(unwrap(view.st), ms.arg))),
+  # move view through MSTACK
+  (UPat(Ops.MSTACK, src=UPat(Ops.VIEW), name="ms"), mstack_reorder_view),
+  # move shrink before MSTACK
+  (UPat(Ops.VIEW, src=(UPat(Ops.MSTACK, name="ms"),), name="view"), mstack_early_shrink),
+])
+
+# ***** multi functions *****
+
+def alu_multi(root:UOp):
+  msrcs = root.src
+  assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}"
+  axis = root.axis
+  assert axis is not None
+
+  srcs = []
+  for mlb in msrcs:
+    if mlb.axis == axis:
+      # same axis, just copy through
+      assert mlb.op is Ops.MULTI
+      srcs.append(mlb.src[0])
+    elif mlb.axis is None:
+      # no axis, shard it
+      assert mlb.op is not Ops.MULTI
+      srcs.append(mlb._shard(axis))
+    else:
+      # axis mismatch, unshard it, send it to all devices, and shard it correctly
+      assert mlb.op is Ops.MULTI
+      srcs.append(mlb.src[0]._unshard(mlb.axis).allreduce(Ops.ADD, mlb.device)._shard(axis))
+  return srcs[0].alu(root.op, *srcs[1:]).multi(axis)
+
+def reduce_multi(root:UOp, multi:UOp):
+  op, axis = root.arg
+  if multi.axis is not None and multi.axis in axis:
+    # all-reduce on sharded axes
+    return multi.src[0].r(op, axis).allreduce(op, multi.device)
+  # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
+  return multi.src[0].r(op, axis).multi(axis=multi.axis)
+
+def _shape_to_single_shard(axis, shape:tuple[sint, ...], lb:UOp) -> tuple[sint, ...]:
+  return tuple(lb.shape[axis] if a == axis else s for a,s in enumerate(shape))
+
+def reshape_multi(root:UOp, multi:UOp):
+  arg = root.arg
+  if (new_axis:=root.axis) is None: return multi.src[0].reshape(arg).multi(new_axis)
+  assert prod(multi.shape) == prod(arg), "reshape must maintain prod(shape)"
+  assert prod(multi.src[0].shape[multi.axis:])%prod(arg[new_axis+1:]) == 0, f"reshape cannot move items between shards {multi.shape} -> {root.arg=}"
+  new_shape_axis = prod(multi.src[0].shape[multi.axis:]) // prod(arg[new_axis+1:])
+  return multi.src[0].reshape(tuple(s if a!=new_axis else new_shape_axis for a,s in enumerate(arg))).multi(new_axis)
+
+def expand_multi(root:UOp, multi:UOp):
+  # NOTE: this assert isn't needed, sharded axis can have dim 1
+  assert multi.axis is None or root.arg[multi.axis] == multi.shape[multi.axis], f"expand not supported on sharded axis {root.arg=}"
+  return multi.src[0].expand(_shape_to_single_shard(multi.axis, root.arg, multi.src[0])).multi(multi.axis)
+
+def pad_multi(root:UOp, multi:UOp):
+  assert multi.axis is None or root.arg[multi.axis] == (0,0), f"padding not supported for {root.arg=}"
+  return multi.src[0].pad(root.arg).multi(multi.axis)
+
+def permute_multi(root:UOp, multi:UOp):
+  # all permutes supported!
+  return multi.src[0].permute(root.arg).multi(root.axis)
+
+def shrink_multi(root:UOp, multi:UOp):
+  assert multi.axis is None or root.arg[multi.axis] == (0, multi.shape[multi.axis]) or root.arg[multi.axis] in multi.bounds, \
+    f"shrinking not supported for {root.arg=}"
+  if multi.axis is not None and root.arg[multi.axis] in multi.bounds and root.arg[multi.axis] != (0, multi.shape[multi.axis]):
+    assert all(root.arg[i] == (0, s) or i == multi.axis for i,s in enumerate(multi.shape)), \
+      "cannot shrink sharded and non-sharded axis at the same time"
+    # NOTE: shrink on the shard axis is only allowed when result is a single partition, denoted by the new real
+    # we just copy it to all the devices, no real. this will be optimized out later
+    return multi.src[0].copy_to_device(multi.device, arg=multi.bounds.index(root.arg[multi.axis]))
+  return multi.src[0].shrink(tuple((0, multi.src[0].shape[multi.axis]) if a == multi.axis else s for a,s in enumerate(root.arg))).multi(multi.axis)
+
+def flip_multi(root:UOp, multi:UOp):
+  assert multi.axis is None or not root.arg[multi.axis], "flipping not supported on sharded axis"
+  return multi.src[0].flip(root.arg).multi(multi.axis)
+
+# from multiple devices -> one
+def copy_multi(multi:UOp, device:UOp):
+  assert multi.axis is not None, "all multi ops have axis"
+  return multi.src[0]._unshard(multi.axis).allreduce(Ops.ADD, device)
+
+def assign_multi(dest:UOp, src:UOp):
+  if dest.axis != src.axis: raise RuntimeError(f"axis must match in assign {dest.axis} != {src.axis}")
+  return dest.src[0].assign(src.src[0]).multi(src.axis)
+
+def passthrough_multi(root:UOp, multi:UOp):
+  return root.replace(src=(multi.src[0],)).multi(multi.axis)
+
+# NOTE: this is the same pattern as Ops.UNROLL
+multi_pm = PatternMatcher([
+  (UPat(GroupOp.ALU, name="root", custom_early_reject=set([Ops.MULTI])), alu_multi),
+  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reduce_multi),
+  (UPat(Ops.RESHAPE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reshape_multi),
+  (UPat(Ops.EXPAND, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), expand_multi),
+  (UPat(Ops.PAD, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), pad_multi),
+  (UPat(Ops.PERMUTE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), permute_multi),
+  (UPat(Ops.SHRINK, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), shrink_multi),
+  (UPat(Ops.FLIP, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), flip_multi),
+  (UPat(Ops.ASSIGN, src=(UPat(Ops.MULTI, name="dest"), UPat(Ops.MULTI, name="src"))), assign_multi),
+  (UPat(Ops.COPY, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device"))), copy_multi),
+  (UPat(Ops.ALLREDUCE, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device")), name="red"),
+    lambda multi,device,red: multi.src[0].allreduce(red.arg, device).multi(axis=multi.axis)),
+  (UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE),
+    src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
+])+replace_allreduce
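For reference, the ring allreduce above splits the flattened buffer into one chunk per device, keeping every chunk size a multiple of a small factor of numel. A minimal standalone sketch of that chunking arithmetic, using hypothetical sizes rather than anything taken from the diff:

import itertools

# hypothetical example: 1,000,000 elements reduced across 4 devices
numel, n_lbs = 1_000_000, 4
factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1)   # largest listed factor dividing numel -> 32
base, left = (numel // factor) // n_lbs, (numel // factor) % n_lbs   # -> 7812 groups per device, 2 left over
chunk_sizes = [(base + 1) * factor] * left + [base * factor] * (n_lbs - left)
chunks = list(itertools.pairwise(itertools.accumulate(chunk_sizes, initial=0)))
assert sum(chunk_sizes) == numel
# chunks -> [(0, 250016), (250016, 500032), (500032, 750016), (750016, 1000000)]

Each (start, end) pair then becomes a shrink on the flattened buffer; the chunks are scatter-reduced around the ring, gathered back to every device, and finally padded and summed to reassemble the full tensor.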
tinygrad/shape/shapetracker.py
CHANGED
@@ -2,58 +2,55 @@
 from __future__ import annotations
 from dataclasses import dataclass
 import functools
-from typing import
+from typing import Callable
 from tinygrad.helpers import merge_dicts, getenv
-from tinygrad.shape.view import View,
+from tinygrad.shape.view import View, unravel
 from tinygrad.dtype import dtypes
-from tinygrad.ops import UOp, Ops, graph_rewrite, Variable, sint, sint_to_uop, Context
-from tinygrad.
-
-def overflow(u: UOp): return u.vmax > dtypes.max(dtypes.int) or u.vmin < dtypes.min(dtypes.int)
+from tinygrad.uop.ops import UOp, Ops, graph_rewrite, Variable, sint, sint_to_uop, Context, PatternMatcher, UPat, GroupOp
+from tinygrad.uop.symbolic import split_uop, symbolic_flat, uop_given_valid, simplify_valid

 # If a node overflow, its srcs need to be checked to see if this overflow is the result of an ALU operation,
 # or that the node simply inherits the dtype from srcs. Upcast is either `Ops.CAST`+`replace` or just `replace`.
-def
-
-
-
-
-
-
-
-
-
-
-
-def
-  with Context(TRACK_MATCH_STATS=0):
-    return upcast(graph_rewrite(u, sym, {}))
-
-@functools.lru_cache(None)
-def views_to_indexed_uops(views: tuple[View, ...], _idxs:Optional[tuple[UOp, ...]]=None) -> tuple[UOp, UOp]:
+def handle_upcast(u: UOp) -> UOp|None:
+  dtype = dtypes.int64.vec(u.dtype.count) if u.dtype.count > 1 else dtypes.int64
+  # check for overflow, upcast this to int64
+  if u.vmax > dtypes.max(dtypes.int) or u.vmin < dtypes.min(dtypes.int):
+    return u.replace(dtype=dtype, src=tuple([x.cast(dtype) for x in u.src]))
+  # if any inputs are int64 and this *doesn't* overflow, cast back to int
+  if any(x.dtype == dtypes.int64 for x in u.src):
+    return u.replace(dtype=dtype, src=tuple([x.cast(dtype) for x in u.src])).cast(u.dtype)
+  return None
+pm_upcast = PatternMatcher([(UPat(GroupOp.ALU, dtype=dtypes.int, name="u"), handle_upcast),])
+
+@functools.cache
+def views_to_indexed_uops(views: tuple[View, ...], _idxs:tuple[UOp, ...]|None=None) -> tuple[UOp, UOp]:
   idx, valid = views[-1].to_indexed_uops(_idxs)
   for view in reversed(views[0:-1]):
     view = view.minify()
     idx, valid = view.to_indexed_uops([sint_to_uop(i) for i in unravel(view.shape, idx)], valid)
-
-
-
-
+  with Context(TRACK_MATCH_STATS=0):
+    # symbolic
+    idx, valid = graph_rewrite(UOp.sink(idx, valid), symbolic_flat, name="indexing sym @ 1").src
+    # simplify
+    if (newvalid:=simplify_valid(valid)) is not None: valid = newvalid
+    if (newidx:=uop_given_valid(valid, idx)) is not None: idx = newidx
+    # symbolic again, upcast if needed
+    return graph_rewrite(UOp.sink(idx, valid), symbolic_flat+pm_upcast, name="indexing sym @ 2").src
+
+@functools.cache
+def views_to_real_strides(views: tuple[View, ...], ignore_valid=False) -> tuple[sint|None, ...]:
   # NOTE: if a stride is not always valid, it will be None
   if len(views) == 1 and views[-1].mask is None: return views[-1].strides
-  ret: list[
-  idx, valid =
-  # TODO: always apply these in to_indexed_uops?
-  if (newvalid:=simplify_valid(valid)) is not None: valid = newvalid
-  if (newidx:=uop_given_valid(valid, idx)) is not None: idx = graph_rewrite(newidx, symbolic_flat)
+  ret: list[sint|None] = [None] * len(views[-1].shape)
+  idx, valid = views_to_indexed_uops(views)
   for c in split_uop(idx, Ops.ADD):
     if c.op is Ops.RANGE: ret[c.arg] = 1
     if c.op is Ops.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg] = c.src[1].arg
     if c.op is Ops.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg] = c.src[0].arg
-  used_ranges = [x.arg for x in idx.toposort if x.op is Ops.RANGE]
+  used_ranges = [x.arg for x in idx.toposort() if x.op is Ops.RANGE]
   ret = [x if i in used_ranges else 0 for i,x in enumerate(ret)]
   if not ignore_valid:
-    for masked_axis in [x.arg for x in valid.toposort if x.op is Ops.RANGE]: ret[masked_axis] = None
+    for masked_axis in [x.arg for x in valid.toposort() if x.op is Ops.RANGE]: ret[masked_axis] = None
   return tuple(ret)

 @dataclass(frozen=True, order=True)
@@ -65,7 +62,7 @@ class ShapeTracker:
     for v in st.views: ret = ShapeTracker(ret.views + (v,)).simplify() # one view at a time = better simplification
     return ret

-  def invert(self, out_shape:tuple[sint, ...]) ->
+  def invert(self, out_shape:tuple[sint, ...]) -> ShapeTracker|None:
     inverted_views:list[View] = []
     for v,s in zip(self.views[::-1], [x.shape for x in self.views[::-1][1:]]+[out_shape]):
       if (inverted:= v.invert(s)) is None: return None
@@ -73,14 +70,11 @@ class ShapeTracker:
     return ShapeTracker(tuple(inverted_views)).reshape(out_shape)

   @staticmethod
-  def from_shape(shape:tuple[sint, ...]) -> ShapeTracker: return ShapeTracker((View.create(shape),))
+  def from_shape(shape:tuple[sint, ...], strides:tuple[sint, ...]|None=None) -> ShapeTracker: return ShapeTracker((View.create(shape, strides),))

   @property
   def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous

-  @property
-  def consecutive(self) -> bool: return len(self.views) == 1 and (v:=self.views[0]).mask is None and v.strides == strides_for_shape(v.shape)
-
   @property
   def shape(self) -> tuple[sint, ...]: return self.views[-1].shape

@@ -89,10 +83,8 @@ class ShapeTracker:

   def reduce(self, axis:tuple[int, ...]) -> tuple[sint, ...]: return tuple(1 if i in axis else s for i,s in enumerate(self.shape))

-  def
-
-    idx, valid = views_to_indexed_uops(self.views, tuple(_idxs) if _idxs is not None else None)
-    return folded_upcast(idx), folded_upcast(valid)
+  def to_indexed_uops(self, _idxs:list[UOp]|tuple[UOp, ...]|None=None) -> tuple[UOp, UOp]:
+    return views_to_indexed_uops(self.views, tuple(_idxs) if _idxs is not None else None)

   # upper bound on buffer size required to fit this shapetracker
   def real_size(self) -> int:
@@ -111,14 +103,16 @@ class ShapeTracker:
     unbound_views, var_vals = zip(*[v.unbind() for v in self.views])
     if all(len(x) == 0 for x in var_vals): return self, {}
     return ShapeTracker(tuple(unbound_views)), merge_dicts(var_vals)
+  def substitute(self, dvars:dict[UOp, UOp]): return ShapeTracker(tuple(x.substitute(dvars) for x in self.views))

-  def real_strides(self, ignore_valid=False) -> tuple[
+  def real_strides(self, ignore_valid=False) -> tuple[sint|None, ...]:
+    with Context(TRACK_MATCH_STATS=0): return views_to_real_strides(self.views, ignore_valid)
   def unit_stride_axes(self, ignore_valid=False) -> list[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]

   def axis_is_masked(self, axis:int) -> bool:
     with Context(TRACK_MATCH_STATS=0):
       _, valid = self.to_indexed_uops()
-      return axis in [x.arg for x in graph_rewrite(valid, symbolic_flat).toposort if x.op is Ops.RANGE]
+      return axis in [x.arg for x in graph_rewrite(valid, symbolic_flat).toposort() if x.op is Ops.RANGE]

   def simplify(self) -> ShapeTracker:
     if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None:
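The handle_upcast/pm_upcast rewrite added above promotes int32 indexing expressions to int64 only when their value range can actually overflow. A quick standalone illustration of the threshold involved, with a hypothetical tensor shape that is not taken from the diff:

# a flat index can exceed the int32 range even when every dimension fits comfortably
INT32_MAX = 2**31 - 1                 # 2147483647, i.e. dtypes.max(dtypes.int)
shape = (50_000, 50_000)              # hypothetical tensor shape
flat_max = shape[0] * shape[1] - 1    # 2499999999, the largest flat index
needs_int64 = flat_max > INT32_MAX    # True -> the ALU UOps computing this index get upcast to int64

When nothing overflows but one of the inputs is already int64, the second branch of handle_upcast performs the operation in int64 and casts the result back to the original int dtype, keeping the common case in 32-bit arithmetic.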