tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- tinygrad/__init__.py +1 -1
- tinygrad/apps/llm.py +206 -0
- tinygrad/codegen/__init__.py +116 -0
- tinygrad/codegen/devectorizer.py +315 -172
- tinygrad/codegen/expander.py +8 -16
- tinygrad/codegen/gpudims.py +89 -0
- tinygrad/codegen/linearize.py +205 -203
- tinygrad/codegen/lowerer.py +92 -139
- tinygrad/codegen/opt/__init__.py +38 -0
- tinygrad/codegen/opt/heuristic.py +125 -0
- tinygrad/codegen/opt/kernel.py +510 -0
- tinygrad/{engine → codegen/opt}/search.py +51 -35
- tinygrad/codegen/opt/swizzler.py +134 -0
- tinygrad/codegen/opt/tc.py +127 -0
- tinygrad/codegen/quantize.py +67 -0
- tinygrad/device.py +122 -132
- tinygrad/dtype.py +152 -35
- tinygrad/engine/jit.py +81 -54
- tinygrad/engine/memory.py +46 -27
- tinygrad/engine/realize.py +82 -41
- tinygrad/engine/schedule.py +70 -445
- tinygrad/frontend/__init__.py +0 -0
- tinygrad/frontend/onnx.py +1253 -0
- tinygrad/frontend/torch.py +5 -0
- tinygrad/gradient.py +19 -27
- tinygrad/helpers.py +95 -47
- tinygrad/nn/__init__.py +7 -8
- tinygrad/nn/optim.py +72 -41
- tinygrad/nn/state.py +37 -23
- tinygrad/renderer/__init__.py +40 -60
- tinygrad/renderer/cstyle.py +143 -128
- tinygrad/renderer/llvmir.py +113 -62
- tinygrad/renderer/ptx.py +50 -32
- tinygrad/renderer/wgsl.py +27 -23
- tinygrad/runtime/autogen/am/am.py +5861 -0
- tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
- tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
- tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
- tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
- tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
- tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
- tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
- tinygrad/runtime/autogen/comgr.py +35 -9
- tinygrad/runtime/autogen/comgr_3.py +906 -0
- tinygrad/runtime/autogen/cuda.py +2419 -494
- tinygrad/runtime/autogen/hsa.py +57 -16
- tinygrad/runtime/autogen/ib.py +7171 -0
- tinygrad/runtime/autogen/io_uring.py +917 -118
- tinygrad/runtime/autogen/kfd.py +748 -26
- tinygrad/runtime/autogen/libc.py +613 -218
- tinygrad/runtime/autogen/libusb.py +1643 -0
- tinygrad/runtime/autogen/nv/nv.py +8602 -0
- tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
- tinygrad/runtime/autogen/opencl.py +2 -4
- tinygrad/runtime/autogen/sqtt.py +1789 -0
- tinygrad/runtime/autogen/vfio.py +3 -3
- tinygrad/runtime/autogen/webgpu.py +273 -264
- tinygrad/runtime/graph/cuda.py +3 -3
- tinygrad/runtime/graph/hcq.py +68 -29
- tinygrad/runtime/graph/metal.py +29 -13
- tinygrad/runtime/graph/remote.py +114 -0
- tinygrad/runtime/ops_amd.py +537 -320
- tinygrad/runtime/ops_cpu.py +108 -7
- tinygrad/runtime/ops_cuda.py +12 -14
- tinygrad/runtime/ops_disk.py +13 -10
- tinygrad/runtime/ops_dsp.py +47 -40
- tinygrad/runtime/ops_gpu.py +13 -11
- tinygrad/runtime/ops_hip.py +6 -9
- tinygrad/runtime/ops_llvm.py +35 -15
- tinygrad/runtime/ops_metal.py +29 -19
- tinygrad/runtime/ops_npy.py +5 -3
- tinygrad/runtime/ops_null.py +28 -0
- tinygrad/runtime/ops_nv.py +306 -234
- tinygrad/runtime/ops_python.py +62 -52
- tinygrad/runtime/ops_qcom.py +28 -39
- tinygrad/runtime/ops_remote.py +482 -0
- tinygrad/runtime/ops_webgpu.py +28 -28
- tinygrad/runtime/support/am/amdev.py +114 -249
- tinygrad/runtime/support/am/ip.py +211 -172
- tinygrad/runtime/support/amd.py +138 -0
- tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
- tinygrad/runtime/support/compiler_cuda.py +8 -11
- tinygrad/runtime/support/elf.py +2 -1
- tinygrad/runtime/support/hcq.py +184 -97
- tinygrad/runtime/support/ib.py +172 -0
- tinygrad/runtime/support/llvm.py +3 -4
- tinygrad/runtime/support/memory.py +251 -0
- tinygrad/runtime/support/nv/__init__.py +0 -0
- tinygrad/runtime/support/nv/ip.py +581 -0
- tinygrad/runtime/support/nv/nvdev.py +183 -0
- tinygrad/runtime/support/system.py +170 -0
- tinygrad/runtime/support/usb.py +268 -0
- tinygrad/runtime/support/webgpu.py +18 -0
- tinygrad/schedule/__init__.py +0 -0
- tinygrad/schedule/grouper.py +119 -0
- tinygrad/schedule/kernelize.py +368 -0
- tinygrad/schedule/multi.py +231 -0
- tinygrad/shape/shapetracker.py +40 -46
- tinygrad/shape/view.py +88 -52
- tinygrad/tensor.py +968 -542
- tinygrad/uop/__init__.py +117 -0
- tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
- tinygrad/uop/mathtraits.py +169 -0
- tinygrad/uop/ops.py +1021 -0
- tinygrad/uop/spec.py +228 -0
- tinygrad/{codegen → uop}/symbolic.py +239 -216
- tinygrad/uop/upat.py +163 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
- tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
- tinygrad/viz/index.html +203 -403
- tinygrad/viz/js/index.js +718 -0
- tinygrad/viz/js/worker.js +29 -0
- tinygrad/viz/serve.py +224 -102
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
- tinygrad-0.11.0.dist-info/RECORD +141 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/kernel.py +0 -693
- tinygrad/engine/multi.py +0 -161
- tinygrad/ops.py +0 -1003
- tinygrad/runtime/ops_cloud.py +0 -220
- tinygrad/runtime/support/allocator.py +0 -94
- tinygrad/spec.py +0 -155
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
- tinygrad/viz/perfetto.html +0 -178
- tinygrad-0.10.2.dist-info/RECORD +0 -99
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
tinygrad/engine/multi.py
DELETED
```diff
@@ -1,161 +0,0 @@
-import functools, itertools, operator
-from tinygrad.helpers import all_same, all_int, dedup, prod, DEBUG, RING, getenv
-from tinygrad.ops import Ops, UOp, sint
-
-def all_reduce(bop: Ops, lbs: list[UOp]) -> list[UOp]:
-  assert all_int(lbs[0].shape), f"does not support symbolic shape {lbs[0].shape}"
-  assert all_same([lb.shape[0] for lb in lbs]), "allreduce with uneven shards is undefined"
-  n_lbs, shape, numel = len(lbs), lbs[0].shape, prod(lbs[0].shape)
-  # ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
-  # fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
-  use_ring = (RING >= 2 or (n_lbs > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
-  if DEBUG >= 2: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{numel} | {lbs[0].dtype}")
-  if not use_ring: return [functools.reduce(lambda x,y: x.alu(bop, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]
-
-  factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1)
-  base, left = (numel // factor) // n_lbs, (numel // factor) % n_lbs
-  chunk_sizes = [(base + 1) * factor] * left + [base * factor] * (n_lbs - left)
-  chunks = list(itertools.pairwise(itertools.accumulate(chunk_sizes, initial=0)))
-  chunked = [[lb.reshape((numel,)).shrink(((s,e),)) for s,e in chunks] for lb in lbs]
-
-  # scatter-reduce
-  for step in range(n_lbs-1):
-    for i in range(len(chunks)):
-      src, dest = (i+step)%n_lbs, (i+step+1)%n_lbs
-      chunked[dest][i] = chunked[dest][i].alu(bop, chunked[src][i].copy_to_device(chunked[dest][i].device))
-
-  # allgather
-  for step in range(n_lbs-1):
-    for i in range(len(chunks)):
-      src, dest = (i+step-1)%n_lbs, (i+step)%n_lbs
-      chunked[dest][i] = chunked[src][i].copy_to_device(chunked[dest][i].device)
-
-  # assemble chunks back
-  pads = [((s,numel-e),) for s,e in chunks]
-  return [functools.reduce(operator.add, [c.pad(pad) for pad,c in zip(pads,lb_c)]).reshape(shape) for lb_c in chunked]
-
-def to_sharded(lbs:list[UOp], axis:int, bounds: tuple[tuple[int, int], ...]) -> list[UOp]:
-  if lbs[0].shape[axis] % len(lbs) != 0: raise RuntimeError(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}, bounds={bounds}")
-  return [lb.shrink(tuple((0,s) if a != axis else bound for a,s in enumerate(lb.shape))) for i, (bound, lb) in enumerate(zip(bounds, lbs))]
-
-# ***** multi functions *****
-
-from tinygrad.ops import PatternMatcher, UPat, GroupOp, graph_rewrite_map, track_rewrites
-
-def alu_multi(root:UOp):
-  msrcs = root.src
-  assert all(x.op is Ops.MULTI for x in msrcs), f"all buffers must be MultiLazyBuffer {[x.op for x in msrcs]}"
-  assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}"
-
-  axis = root.axis
-  bounds = dedup([x.bounds for x in root.src if x.axis == axis])[-1] if axis is not None else None
-  srcs:list[list[UOp]] = []
-  not_all_real = not all(all(mlb.real) for mlb in msrcs)
-  new_real = tuple(all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])) if not_all_real else msrcs[0].real
-  for mlb in msrcs:
-    if (mlb.axis == axis and (mlb.axis is None or mlb.bounds == bounds)) or not_all_real: srcs.append(list(mlb.src))
-    else:
-      assert axis is not None and bounds is not None
-      if mlb.axis is None: srcs.append(to_sharded(list(mlb.src), axis, bounds))
-      else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.src], axis, bounds))
-  new_lbs = [lsrcs[0].alu(root.op, *lsrcs[1:]) for lsrcs in zip(*srcs)]
-  new_lbs = [x if r else x.const_like(0) for r,x in zip(new_real, new_lbs)] # TODO: is this needed?
-  return UOp.multi(*new_lbs, axis=axis, real=new_real)
-
-def reduce_multi(root:UOp, multi:UOp):
-  op, axis = root.arg
-  if multi.axis is not None and multi.axis in axis:
-    # all-reduce on sharded axes
-    reduced_parts = [(x if r else x.const_like(0)).r(op, axis) for x,r in zip(multi.src, multi.real)]
-    # if all partitions are real, do all_reduce
-    if all(multi.real): return UOp.multi(*all_reduce(op, reduced_parts), axis=root.axis)
-    # only one partition is real, keep it
-    return UOp.multi(*reduced_parts, axis=root.axis, real=multi.real)
-  # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
-  return UOp.multi(*[x.r(op, axis) for x in multi.src], axis=root.axis, real=multi.real)
-
-def _shape_to_single_shard(axis, shape:tuple[sint, ...], lb:UOp) -> tuple[sint, ...]:
-  return tuple(lb.shape[axis] if a == axis else s for a,s in enumerate(shape))
-
-def reshape_multi(root:UOp, multi:UOp):
-  arg = root.arg
-  if (new_axis:=root.axis) is None: return UOp.multi(*[x.reshape(arg) for x in multi.src], axis=new_axis, real=multi.real)
-  assert prod(multi.shape) == prod(arg), "reshape must maintain prod(shape)"
-  assert all(prod(lb.shape[multi.axis:])%prod(arg[new_axis+1:])==0 for lb in multi.src), \
-    f"reshape cannot move items between shards {multi.shape} -> {root.arg=}"
-  lbs = [x.reshape(tuple(s if a!=new_axis else prod(x.shape[multi.axis:])//prod(arg[new_axis+1:]) for a,s in enumerate(arg))) for x in multi.src]
-  return UOp.multi(*lbs, axis=new_axis, real=multi.real)
-
-def expand_multi(root:UOp, multi:UOp):
-  # NOTE: this assert isn't needed, sharded axis can have dim 1
-  assert multi.axis is None or root.arg[multi.axis] == multi.shape[multi.axis], f"expand not supported on sharded axis {root.arg=}"
-  return UOp.multi(*[x.expand(_shape_to_single_shard(multi.axis, root.arg, x)) for x in multi.src], axis=multi.axis, real=multi.real)
-
-def pad_multi(root:UOp, multi:UOp):
-  assert multi.axis is None or root.arg[multi.axis] == (0,0) or not all(multi.real), f"padding not supported for {root.arg=}"
-  # pad on shard axis -> fill others with zeros and set real to all True
-  if multi.axis is not None and root.arg[multi.axis] != (0,0):
-    # pad back to whole axis, remove real mask
-    assert all(root.arg[i] == (0, 0) for i in range(len(multi.shape)) if i != multi.axis), "cannot pad sharded and non-sharded axis at the same time"
-    dim, bound = sum(lb.shape[multi.axis] for lb in multi.src), multi.bounds[multi.real.index(True)]
-    assert root.arg[multi.axis] == (bound[0], dim-bound[1]), "can only pad to whole axis"
-    return UOp.multi(*[x if r else x.const_like(0) for x,r in zip(multi.src, multi.real)], axis=multi.axis)
-  return UOp.multi(*[x.pad(root.arg) for x in multi.src], axis=multi.axis, real=multi.real)
-
-def permute_multi(root:UOp, multi:UOp):
-  # all permutes supported!
-  return UOp.multi(*[x.permute(root.arg) for x in multi.src], axis=root.axis, real=multi.real)
-
-def shrink_multi(root:UOp, multi:UOp):
-  assert multi.axis is None or root.arg[multi.axis] == (0, multi.shape[multi.axis]) or root.arg[multi.axis] in multi.bounds, \
-    f"shrinking not supported for {root.arg=}"
-  if multi.axis is not None and root.arg[multi.axis] in multi.bounds and root.arg[multi.axis] != (0, multi.shape[multi.axis]):
-    assert all(root.arg[i] == (0, s) or i == multi.axis for i,s in enumerate(multi.shape)), \
-      "cannot shrink sharded and non-sharded axis at the same time"
-    # NOTE: shrink on the shard axis is only allowed when result is a single partition, denoted by the new real
-    idx = multi.bounds.index(root.arg[multi.axis])
-    # zero out other lbs to not create lb reference
-    return UOp.multi(*[lb if i==idx else lb.const_like(0) for i,lb in enumerate(multi.src)],
-                     axis=multi.axis, real=tuple(i==idx for i in range(len(multi.src))))
-  return UOp.multi(*[x.shrink(tuple((0, x.shape[multi.axis]) if a == multi.axis else s for a,s in enumerate(root.arg))) for x in multi.src],
-                   axis=multi.axis, real=multi.real)
-
-def flip_multi(root:UOp, multi:UOp):
-  assert multi.axis is None or not root.arg[multi.axis], "flipping not supported on sharded axis"
-  return UOp.multi(*[x.flip(root.arg) for x in multi.src], axis=multi.axis, real=multi.real)
-
-def copy_multi(multi:UOp, device:UOp):
-  # if we already have a copy on the device, return that
-  if multi.axis is None: return next((lb for lb in multi.real_lbs if lb.device == device.arg), multi.real_lbs[0].copy_to_device(device.arg))
-  # copy lbs to device, pad to final shape, and sum
-  llbs:list[UOp] = []
-  for lb,real,(start,end) in zip(multi.src, multi.real, multi.bounds):
-    if not real: continue
-    pad_arg = tuple((0,0) if a != multi.axis else (start, multi.bounds[-1][1]-end) for a in range(len(lb.shape)))
-    llbs.append(lb.copy_to_device(device.arg).pad(pad_arg))
-  return functools.reduce(operator.add, llbs)
-
-def assign_multi(dest:UOp, src:UOp):
-  assert dest.axis == src.axis and dest.real == src.real, f"axis/real must match in assign {dest.axis} != {src.axis} or {dest.real} != {src.real}"
-  return UOp.multi(*[x.assign(y) for x,y in zip(dest.src, src.src)], axis=src.axis, real=src.real)
-
-def passthrough_multi(root:UOp, multi:UOp): return UOp.multi(*[root.replace(src=(m,)) for m in multi.src], axis=multi.axis, real=multi.real)
-
-# NOTE: this is the same pattern as Ops.UNROLL
-multi_pm = PatternMatcher([
-  (UPat(GroupOp.ALU, name="root", custom_early_reject=set([Ops.MULTI])), alu_multi),
-  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reduce_multi),
-  (UPat(Ops.RESHAPE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reshape_multi),
-  (UPat(Ops.EXPAND, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), expand_multi),
-  (UPat(Ops.PAD, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), pad_multi),
-  (UPat(Ops.PERMUTE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), permute_multi),
-  (UPat(Ops.SHRINK, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), shrink_multi),
-  (UPat(Ops.FLIP, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), flip_multi),
-  (UPat(Ops.ASSIGN, src=(UPat(Ops.MULTI, name="dest"), UPat(Ops.MULTI, name="src"))), assign_multi),
-  (UPat(Ops.COPY, src=(UPat(Ops.DEVICE, name="device"), UPat(Ops.MULTI, name="multi"), )), copy_multi),
-  (UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD),
-    src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
-])
-
-@track_rewrites(named=True)
-def get_multi_map(big_sink:UOp) -> dict[UOp, UOp]: return {k:v for k,v in graph_rewrite_map(big_sink, multi_pm).items() if k is not v}
```
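The core of the deleted file is the ring allreduce in `all_reduce`: the flattened tensor is split into one chunk per shard (sized in multiples of a small vectorization-friendly factor), each chunk is reduced around the ring in `n_lbs-1` scatter-reduce steps, then redistributed in `n_lbs-1` allgather steps. Below is a minimal standalone sketch of just that chunking and scheduling arithmetic; the concrete sizes (numel=1000, three shards) are illustrative and not taken from the diff.

```python
# Standalone sketch (Python 3.10+): reproduces the chunk partitioning and ring
# schedule arithmetic from the deleted all_reduce. Illustrative only.
import itertools

def ring_chunks(numel: int, n_lbs: int) -> list[tuple[int, int]]:
  # largest of 32/16/8/4/2 that divides the element count, else 1
  factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1)
  base, left = (numel // factor) // n_lbs, (numel // factor) % n_lbs
  # the first `left` chunks get one extra factor-sized block so the sizes sum to numel
  chunk_sizes = [(base + 1) * factor] * left + [base * factor] * (n_lbs - left)
  # running sums -> (start, end) boundaries over the flattened tensor
  return list(itertools.pairwise(itertools.accumulate(chunk_sizes, initial=0)))

n_lbs = 3
chunks = ring_chunks(numel=1000, n_lbs=n_lbs)
print(chunks)  # [(0, 336), (336, 672), (672, 1000)]

# each scatter-reduce step moves chunk i from device (i+step)%n to (i+step+1)%n;
# the allgather pass then walks the ring once more to redistribute the reduced chunks
for step in range(n_lbs - 1):
  for i in range(len(chunks)):
    print(f"scatter-reduce step {step}: chunk {i}: device {(i+step)%n_lbs} -> {(i+step+1)%n_lbs}")
```

Judging by the file list above, this sharding logic appears to have moved to `tinygrad/schedule/multi.py` in 0.11.0 rather than being dropped.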