tinygrad 0.10.0-py3-none-any.whl → 0.10.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. tinygrad/codegen/devectorizer.py +247 -0
  2. tinygrad/codegen/expander.py +121 -0
  3. tinygrad/codegen/kernel.py +141 -201
  4. tinygrad/codegen/linearize.py +223 -84
  5. tinygrad/codegen/lowerer.py +60 -42
  6. tinygrad/codegen/symbolic.py +476 -0
  7. tinygrad/codegen/transcendental.py +22 -13
  8. tinygrad/device.py +187 -47
  9. tinygrad/dtype.py +39 -28
  10. tinygrad/engine/jit.py +83 -65
  11. tinygrad/engine/memory.py +4 -5
  12. tinygrad/engine/multi.py +161 -0
  13. tinygrad/engine/realize.py +62 -108
  14. tinygrad/engine/schedule.py +396 -357
  15. tinygrad/engine/search.py +55 -66
  16. tinygrad/gradient.py +73 -0
  17. tinygrad/helpers.py +81 -59
  18. tinygrad/nn/__init__.py +30 -32
  19. tinygrad/nn/datasets.py +1 -2
  20. tinygrad/nn/optim.py +22 -26
  21. tinygrad/nn/state.py +91 -66
  22. tinygrad/ops.py +492 -641
  23. tinygrad/renderer/__init__.py +95 -36
  24. tinygrad/renderer/cstyle.py +99 -92
  25. tinygrad/renderer/llvmir.py +83 -34
  26. tinygrad/renderer/ptx.py +83 -99
  27. tinygrad/renderer/wgsl.py +95 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  29. tinygrad/runtime/autogen/comgr.py +2 -0
  30. tinygrad/runtime/autogen/kfd.py +4 -3
  31. tinygrad/runtime/autogen/kgsl.py +1 -1
  32. tinygrad/runtime/autogen/libc.py +404 -71
  33. tinygrad/runtime/autogen/llvm.py +11379 -0
  34. tinygrad/runtime/autogen/pci.py +1333 -0
  35. tinygrad/runtime/autogen/vfio.py +891 -0
  36. tinygrad/runtime/autogen/webgpu.py +6985 -0
  37. tinygrad/runtime/graph/cuda.py +8 -9
  38. tinygrad/runtime/graph/hcq.py +84 -79
  39. tinygrad/runtime/graph/metal.py +40 -43
  40. tinygrad/runtime/ops_amd.py +498 -334
  41. tinygrad/runtime/ops_cloud.py +34 -34
  42. tinygrad/runtime/ops_cpu.py +24 -0
  43. tinygrad/runtime/ops_cuda.py +30 -27
  44. tinygrad/runtime/ops_disk.py +62 -63
  45. tinygrad/runtime/ops_dsp.py +159 -42
  46. tinygrad/runtime/ops_gpu.py +30 -30
  47. tinygrad/runtime/ops_hip.py +29 -31
  48. tinygrad/runtime/ops_llvm.py +48 -41
  49. tinygrad/runtime/ops_metal.py +149 -113
  50. tinygrad/runtime/ops_npy.py +2 -2
  51. tinygrad/runtime/ops_nv.py +238 -273
  52. tinygrad/runtime/ops_python.py +55 -50
  53. tinygrad/runtime/ops_qcom.py +129 -157
  54. tinygrad/runtime/ops_webgpu.py +225 -0
  55. tinygrad/runtime/support/allocator.py +94 -0
  56. tinygrad/runtime/support/am/__init__.py +0 -0
  57. tinygrad/runtime/support/am/amdev.py +396 -0
  58. tinygrad/runtime/support/am/ip.py +463 -0
  59. tinygrad/runtime/support/compiler_cuda.py +4 -2
  60. tinygrad/runtime/support/elf.py +28 -4
  61. tinygrad/runtime/support/hcq.py +256 -324
  62. tinygrad/runtime/support/llvm.py +26 -0
  63. tinygrad/shape/shapetracker.py +85 -53
  64. tinygrad/shape/view.py +104 -140
  65. tinygrad/spec.py +155 -0
  66. tinygrad/tensor.py +835 -527
  67. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
  68. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
  69. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
  70. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
  71. tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
  72. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
  73. tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
  74. tinygrad/viz/index.html +544 -0
  75. tinygrad/viz/perfetto.html +178 -0
  76. tinygrad/viz/serve.py +205 -0
  77. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
  78. tinygrad-0.10.2.dist-info/RECORD +99 -0
  79. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
  80. tinygrad/codegen/uopgraph.py +0 -506
  81. tinygrad/engine/lazy.py +0 -228
  82. tinygrad/function.py +0 -212
  83. tinygrad/multi.py +0 -177
  84. tinygrad/runtime/graph/clang.py +0 -39
  85. tinygrad/runtime/ops_clang.py +0 -35
  86. tinygrad-0.10.0.dist-info/RECORD +0 -77
  87. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
  88. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/codegen/devectorizer.py (new file)
@@ -0,0 +1,247 @@
+ from typing import Optional, Any, Callable
+ import functools, operator
+ from collections import defaultdict
+ from tinygrad.dtype import dtypes, ImageDType, PtrDType
+ from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, resolve
+ from tinygrad.ops import graph_rewrite, GroupOp
+ from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym, mulacc_unrolled
+ from tinygrad.helpers import getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, DEVECTORIZE
+ from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES
+ from tinygrad.renderer import Renderer
+
+ # ***** float4/image store handling *****
+
+ def fold_expanded(ex, buf):
+   new_srcs = dedup(list(ex.src))
+   old_new_srcs = new_srcs[:]
+   is_load, is_image = new_srcs[0].op is Ops.LOAD, isinstance(buf.dtype, ImageDType)
+
+   # TODO: get the device from the buffer somehow
+   # NOTE: this can't be Device.DEFAULT because it opens devices
+   if buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType): return None
+   lengths = [4] if is_image else ([8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else ([16,8,4,2] if AMX else [4,2]))
+
+   # first, extract all the relevant offsets
+   offsets_rootsrc: defaultdict[Any, dict] = defaultdict(dict)
+   for i,s in enumerate(new_srcs):
+     idx = s.src[0].src[1]
+     if s.dtype.count != 1 or (is_image and idx.dtype.count == 2): continue
+     if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
+     elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
+     else: root_src, arg = idx, 0
+     # add gates for gated
+     if len(s.src[0].src) == 3: root_src = (s.src[0].src[2], root_src)
+     assert arg not in offsets_rootsrc[root_src], f"{offsets_rootsrc[root_src][arg]} != {i} with {len(s.src)} sources"
+     offsets_rootsrc[root_src][arg] = i
+
+   # then rewrite everything we can
+   used: set[tuple[UOp, UOp]] = set()
+   for rootsrc, offsets in offsets_rootsrc.items():
+     for o in offsets:
+       for fold_length in lengths:
+         if all((rootsrc,o+i) not in used and o+i in offsets for i in range(fold_length)):
+           load_1 = new_srcs[offsets[o]]
+           new_src = list(load_1.src)
+           oidx = new_src[0].src[1]
+           if oidx.divides(fold_length) is None: continue
+           if is_image:
+             # for images, we rewrite the index. it must evenly divide 4 from the above check
+             new_src[0] = buf.index(
+               UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((oidx // 4) % buf.dtype.shape[1], (oidx // (4*buf.dtype.shape[1])))),
+               rootsrc[0] if isinstance(rootsrc, tuple) else None)
+           else:
+             # for non image, we upcast the index pointer
+             new_src[0] = new_src[0].cast(new_src[0].dtype.base.vec(fold_length).ptr(size=new_src[0].dtype.size, local=new_src[0].dtype.local))
+           # generate the folded new_srcs
+           if is_load:
+             new_load = UOp(Ops.LOAD, load_1.dtype.vec(fold_length), tuple(new_src))
+             for i in range(fold_length): new_srcs[offsets[o+i]] = new_load.gep(i)
+           else: # vectorize the store
+             new_src[1] = UOp(Ops.VECTORIZE, new_src[1].dtype.vec(fold_length), tuple(new_srcs[offsets[o+i]].src[1] for i in range(fold_length)))
+             for i in range(fold_length): new_srcs[offsets[o+i]] = UOp(Ops.STORE, dtypes.void, tuple(new_src)) if i == 0 else None
+           used.update((rootsrc,o+i) for i in range(fold_length))
+
+   # dedup expand for LOAD
+   if is_load and len(old_new_srcs) != len(ex.src): new_srcs = [new_srcs[old_new_srcs.index(s)] for s in ex.src]
+   # remove Nones for STORE
+   return UOp(ex.op, ex.dtype, tuple(x for x in new_srcs if x is not None), ex.arg) if len(used) else None
+
+ def fix_unfoldable_image_load(load:UOp, buf:UOp):
+   if not isinstance(buf.dtype, ImageDType) or (oidx:=load.src[0].src[1]).dtype.count == 2: return None
+   id4 = oidx % 4
+   new_src = list(load.src)
+   # TODO: copied logic from above
+   new_src[0] = load.src[0].src[0].index(
+     UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((oidx // 4) % buf.dtype.shape[1], (oidx // (4*buf.dtype.shape[1])))),
+     load.src[0].src[2] if len(load.src[0].src) == 3 else None)
+   vec_load = UOp(Ops.LOAD, load.dtype.vec(4), tuple(new_src))
+   return functools.reduce(lambda ret, i: id4.ne(i).where(ret, vec_load.gep(i)), range(4), load.const_like(float('nan')))
+
+ buf_idx_pat = UPat(Ops.INDEX, src=(UPat.var("buf"),), allow_any_len=True)
+ float4_folding = PatternMatcher([
+   (UPat(Ops.VECTORIZE, src=UPat(Ops.LOAD, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded),
+   (UPat((Ops.BARRIER, Ops.SINK), src=UPat(Ops.STORE, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded),
+ ])
+
+ # ***** image load valid simplification *****
+
+ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
+   if (idx:=uop_given_valid(valid, start_idx)) is None: return buf.const_like(0)
+   if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx, valid)
+
+   # wait for it to be image indexed before running simplification
+   if start_idx.dtype.count != 2: return None
+
+   # can drop valid if idx is out of bound when valid is False
+   drop_stmt = []
+   for stmt in split_uop(valid, Ops.AND):
+     X, is_upper_bound, c = parse_valid(stmt)
+
+     # for X0 + X1 + ... >= 1, check if it's out of bound when Xi = 0 for all i
+     if not is_upper_bound and c == 1 and all(u.op in GroupOp.Irreducible and u.vmin == 0 for u in split_uop(X, Ops.ADD)):
+       testidx = functools.reduce(lambda nowidx,u: nowidx.substitute({u:u.const_like(0)}), split_uop(X, Ops.ADD), idx)
+       testidx = testidx.simplify()
+       if testidx.gep(0).vmax < 0 or testidx.gep(1).vmax < 0:
+         drop_stmt.append(stmt)
+         continue
+
+     # if X <= c, check if it's out of bound when X = c+1
+     # if X >= c, check if it's out of bound when X = c-1
+     test_value = c + 1 if is_upper_bound else c - 1
+     for i,b in zip(idx.src, (buf.dtype.shape[1], buf.dtype.shape[0])):
+       if i.is_increasing():
+         rw = i.substitute({X:X.const_like(test_value)}).simplify()
+         if rw.vmin >= b or rw.vmax < 0:
+           drop_stmt.append(stmt)
+           break
+
+   if not drop_stmt and idx is start_idx: return None
+   new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in split_uop(valid, Ops.AND) if s not in drop_stmt]) else None
+   return buf.index(idx, new_valid)
+
+ # ***** optional patterns *****
+
+ powers_of_two = {2**i:i for i in range(64)}
+ @functools.lru_cache(None)
+ def get_late_rewrite_patterns(ops, force_transcendental=False):
+   pat: list[tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \
+     ((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)) if op not in ops or force_transcendental]
+   # rewrite SQRT to xpow 0.5
+   if Ops.SQRT not in ops: pat.append((UPat(Ops.SQRT, src=UPat.var("d")), lambda d: xpow(d, d.const_like(0.5))))
+   # rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1)
+   if Ops.AND in ops: pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)]
+   # rewrite MUL/IDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y)
+   if Ops.SHL in ops: pat += [(UPat.var("x", dtypes.ints)*UPat.cvar("c"), lambda c,x: x << v if (v:=powers_of_two.get(c.arg, 0)) else None)]
+   if Ops.SHR in ops:
+     pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("c"), lambda x,c: x >> v if (v:=powers_of_two.get(c.arg, 0)) and resolve(x>=0,False) else None)]
+   if Ops.NEG in ops:
+     pat += [(UPat.var('x')*-1, lambda x: x.alu(Ops.NEG))]
+     if Ops.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(Ops.NEG), lambda x,y: x.alu(Ops.SUB, y))]
+   if Ops.MULACC in ops: pat += [(UPat.var('a')*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(Ops.MULACC, b, c))]
+   return PatternMatcher(pat)
+
+
+ # *** uop expander ***
+
+ # TODO: there's a lot shared with gep_through_wmma here
+ def no_vectorized_wmma(wmma:UOp):
+   out_sz = prod(x[1] for x in wmma.arg[6][-1])
+   if wmma.dtype.count == out_sz: return None
+   tsrcs = []
+   for s,sz in zip(wmma.src, wmma.arg[6]):
+     ssz = prod(x[1] for x in sz)
+     tsrcs.append([s.gep(tuple(range(grp, grp+ssz))) for grp in range(0, s.dtype.count, ssz)])
+   wmmas = [UOp(Ops.WMMA, wmma.dtype.scalar().vec(out_sz), tsrc, wmma.arg) for tsrc in zip(*tsrcs)]
+   wmma_ex = flatten([[e.gep(i) for i in range(out_sz)] for e in wmmas])
+   return UOp(Ops.VECTORIZE, wmma.dtype, tuple(wmma_ex))
+
+ def no_vectorized_alu(alu):
+   if alu.dtype.vcount == 1: return None
+   alus = tuple(UOp(alu.op, alu.dtype.scalar(), tuple(s.gep(i) for s in alu.src), alu.arg) for i in range(alu.dtype.vcount))
+   return UOp(Ops.VECTORIZE, alu.dtype, alus)
+
+ def no_vectorized_load_store(ls:UOp):
+   idx = ls.src[0]
+   assert isinstance(idx.dtype, PtrDType)
+   if idx.dtype.v == 1: return None
+   tv = [UOp(ls.op, ls.dtype.scalar(), tuple(j.gep(i) for j in ls.src)) for i in range(idx.dtype.v)]
+   return UOp(Ops.VECTORIZE, ls.dtype, tuple(tv))
+
+ def no_vectorized_acc(acc:UOp):
+   if acc.dtype.count == 1: return None
+   alus = tuple(UOp(acc.op, acc.dtype.scalar(),
+     tuple(s.gep(i) if j == 0 else s for j,s in enumerate(acc.src)), acc.arg+(i,)) for i in range(acc.dtype.count))
+   return UOp(Ops.VECTORIZE, acc.dtype, alus)
+
+ devectorize = PatternMatcher([
+   # no ALU on vectorized dtypes
+   (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.INDEX), name="alu"), no_vectorized_alu),
+   (UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma),
+   (UPat(Ops.DEFINE_ACC, name="acc"), no_vectorized_acc),
+   (UPat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store),
+ ])
+
+ devectorize_load_store = PatternMatcher([
+   # TODO: add vectorized support to transcendental
+   (UPat((Ops.INDEX, Ops.EXP2, Ops.LOG2, Ops.SIN), name="alu"), no_vectorized_alu),
+   (UPat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store),
+ ])
+
+ def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:UOp|None=None) -> UOp|None:
+   if store_gate not in [gate.src[0] for gate in val.toposort if gate.op is Ops.IF]: return None
+   # remove the gate from the index
+   return UOp.store(buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx), val)
+
+ load_store_indexing = PatternMatcher([
+   # late fixup of unfoldable image loads
+   (UPat(Ops.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load),
+   # simplify valid
+   (UPat(Ops.AND, name="valid"), simplify_valid),
+   # image load valid idx simplification
+   (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load),
+   # delete_redundant_gates (after expand)
+   (UPat(Ops.STORE, src=(UPat.any(stidx:=UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")), stidx.cast().named("cast")),
+     UPat.var("val"))), delete_redundant_gates),
+ ])
+
+ def move_mask(x:UOp, buf:UOp, idx:UOp, mask:UOp, cast:UOp|None=None) -> UOp:
+   # this moves the mask from the indexing to the load/store op for rendering
+   nidx = buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx)
+   return UOp.load(nidx, x.const_like(0), mask, *x.src[1:], dtype=x.dtype) if x.op is Ops.LOAD else UOp.store(nidx, x.src[1], mask, *x.src[2:])
+
+ pm_render = PatternMatcher([
+   # for rendering, we use explicit VECTORIZE
+   (UPat(Ops.CONST, name='c'),
+     lambda c: UOp(Ops.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.vcount) if c.dtype.vcount > 1 else None),
+   (UPat(Ops.VCONST, name='c'), lambda c: UOp(Ops.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
+   (UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
+   (UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
+   # move masks of loads/stores
+   (UPat((Ops.LOAD, Ops.STORE), src=(UPat.any(masked_index:=UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"), UPat.var("mask"))),
+     masked_index.cast(None).named("cast")),), allow_any_len=True, name="x"), move_mask),
+   # gate any stores that aren't gated with ifs
+   (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
+     lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
+ ])
+
+ # *** uop graph ***
+
+ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
+   assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
+   supported_ops = tuple(opts.code_for_op.keys()) if opts is not None else ()
+   extra_matcher = opts.extra_matcher if opts is not None and opts.extra_matcher is not None else PatternMatcher([])
+
+   if DEVECTORIZE:
+     # devectorize + load_store_indexing + mulacc_unrolled, mulacc_unrolled must be last because it can break loop_collapse
+     sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)+load_store_indexing+
+       mulacc_unrolled)
+   else:
+     # new devectorize only for load/store
+     sink = graph_rewrite(sink, sym+devectorize_load_store+mulacc_unrolled)
+
+   # optional pre matcher
+   if opts is not None and opts.pre_matcher is not None: sink = graph_rewrite(sink, opts.pre_matcher)
+
+   # final rules for the renderer (without sym)
+   sink = graph_rewrite(sink, symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)+pm_render+extra_matcher)
+   return sink
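
Note on the "optional patterns" above: get_late_rewrite_patterns rewrites MOD/MUL/IDIV on power-of-two constants into AND/SHL/SHR (the IDIV case only fires when the dividend is provably non-negative, via resolve(x>=0,False)). A minimal standalone check of the integer identities those rewrites rely on, in plain Python with no tinygrad imports:

  # sanity check of the power-of-two identities behind get_late_rewrite_patterns;
  # x stays non-negative because the IDIV->SHR rewrite only applies when x >= 0
  for x in range(256):
    for y in range(8):
      assert x % (2**y) == x & (2**y - 1)   # MOD  -> AND
      assert x * (2**y) == x << y           # MUL  -> SHL
      assert x // (2**y) == x >> y          # IDIV -> SHR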
tinygrad/codegen/expander.py (new file)
@@ -0,0 +1,121 @@
+ # this converts a lowerer program into a vectorized program
+
+ import functools, itertools, operator
+ from tinygrad.helpers import AMX, dedup, flatten, all_same, prod
+ from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, GroupOp, graph_rewrite
+ from tinygrad.codegen.symbolic import sym
+
+ def _expand_arg_to_idx(args:tuple[tuple[int, int], ...], rpk:dict[int, int]) -> int:
+   idx, mul = 0, 1
+   for axis,m in args[::-1]:
+     idx += rpk[axis] * mul
+     mul *= m
+   return idx
+
+ def _choices_from_args(args:tuple[tuple[int, int], ...]) -> list[dict[int, int]]:
+   return [dict(x) for x in itertools.product(*[zip(itertools.repeat(axis), range(m)) for axis,m in args])]
+
+ @functools.lru_cache(None)
+ def _swizzle_args(cargs:tuple[tuple[int, int], ...], eargs:tuple[tuple[int, int], ...], exclude_args:tuple[int, ...]) -> list[int]:
+   return [_expand_arg_to_idx(eargs, {**rpk, **{x:0 for x in exclude_args}} if exclude_args else rpk) for rpk in _choices_from_args(cargs)]
+
+ def do_expand(root:UOp):
+   expands = [x for x in root.src if x.op is Ops.UNROLL]
+   if len(expands) == 0: return None
+   # NOTE: we 0 out the reduce axis for WMMA. in theory they should all be the same, but is this always correct?
+   exclude_args = tuple(dedup(root.arg[-1] + tuple(y[0] for y in flatten(root.arg[-2])))) if root.op is Ops.WMMA else ()
+   if all_same(expands_args:=[x.arg for x in expands]) and len(exclude_args) == 0:
+     # if there's only one expand arg, it's okay to use it (optimization)
+     expand_args = expands[0].arg
+   else:
+     # otherwise, we sort them and GEP
+     expand_args = tuple(x for x in sorted(dedup(flatten(expands_args))) if x[0] not in exclude_args)
+   expand_sz = prod([x[1] for x in expand_args])
+   new_srcs = []
+   for i,src in enumerate(root.src):
+     if src.op is Ops.UNROLL:
+       if root.op is Ops.IF and i == 0:
+         # IF means OR on first arg to IF
+         new_srcs.append(functools.reduce(operator.__or__, [src.src[0].gep(i) for i in range(expand_sz)]))
+       elif expand_args == src.arg:
+         # just remove the expand
+         new_srcs.append(src.src[0])
+       else:
+         lst = _swizzle_args(expand_args, src.arg, exclude_args)
+         # if the base dtype is > 1, put those at the end
+         if src.dtype.count > 1: lst = flatten([[i*src.dtype.count+j for j in range(src.dtype.count)] for i in lst])
+         new_srcs.append(src.src[0].gep(tuple(lst)))
+     else:
+       # non-UNROLL input
+       if root.op is Ops.IF:
+         # for the first arg of IF, just pass them through ignoring UNROLLS
+         new_srcs.append(src)
+       elif src.dtype.count > 1:
+         # put any input dtype > 1 grouped together
+         new_srcs.append(UOp(Ops.CAT, src.dtype.scalar().vec(expand_sz*src.dtype.count), (src,)*expand_sz))
+       else:
+         # repeat the arg
+         new_srcs.append(src.broadcast(expand_sz))
+
+   new_arg = root.arg
+   if root.op is Ops.GEP:
+     assert root.dtype.count == 1
+     # is this right?
+     new_arg = tuple(range(root.arg[0], new_srcs[0].dtype.count, new_srcs[0].dtype.count // expand_sz))
+   nsrc = UOp(root.op, root.dtype.scalar().vec(root.dtype.count*expand_sz), tuple(new_srcs), new_arg)
+   return UOp(Ops.UNROLL, root.dtype, (nsrc,), expand_args)
+
+ def do_contract(con:UOp):
+   ex = con.src[0]
+   # CONTRACT without UNROLL repeats the element VECTORIZED
+   if ex.op is not Ops.UNROLL: return UOp(Ops.VECTORIZE, con.dtype, con.src*con.dtype.count)
+   # CONTRACT may remove several axes from UNROLL
+   assert con.dtype.count == prod([x[1] for x in con.arg]), "dtype is wrong"
+   idxs = []
+   for rpk in _choices_from_args(new_ex_args:=tuple(x for x in ex.arg if x not in con.arg)):
+     idxs += [_expand_arg_to_idx(ex.arg, {**rpk, **lrpk}) for lrpk in _choices_from_args(con.arg)]
+   return UOp(Ops.UNROLL, con.dtype, (ex.src[0].gep(tuple(idxs)),), new_ex_args)
+
+ expander = PatternMatcher([
+   # double expand
+   (UPat(Ops.UNROLL, name="outer", src=(UPat(Ops.UNROLL, name="inner"),)),
+     lambda outer, inner: UOp(Ops.UNROLL, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
+   # do expansion
+   (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.ASSIGN,
+     Ops.VECTORIZE, Ops.IF), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
+   (UPat(Ops.CONTRACT, name="con"), do_contract),
+   # vectorize DEFINE_ACC
+   (UPat(Ops.VECTORIZE, src=UPat(Ops.DEFINE_ACC, name="acc"), name="v"),
+     lambda acc,v: acc.replace(dtype=v.dtype, src=(acc.src[0].broadcast(v.dtype.count),)+acc.src[1:])),
+   # BARRIERs aren't actually expanded
+   (UPat(Ops.BARRIER, src=(UPat(Ops.UNROLL, name="ex"),)),
+     lambda ex: UOp(Ops.UNROLL, src=(UOp(Ops.BARRIER, src=ex.src),)*len(ex.src), arg=ex.arg)),
+   # empty UNROLL is NOOP
+   (UPat(Ops.UNROLL, src=(UPat.var('x'),), arg=()), lambda x: x),
+   # UNROLL GEP (needed for WMMA, generalize this) -> vectorized ALU
+   (UPat(Ops.UNROLL, name="ex", src=tuple(UPat.var('x').gep(i)+UPat.var('y').gep(i) for i in range(256 if AMX else 8))),
+     lambda ex,x,y: UOp(Ops.UNROLL, ex.dtype, tuple((x+y).gep(i) for i in range(256 if AMX else 8)), ex.arg)),
+ ])
+
+ def create_gate(root:UOp) -> UOp|None:
+   @functools.lru_cache(None)
+   def _gate_srcs(u:UOp, gate:UOp) -> UOp:
+     if u.op is Ops.BARRIER: return u
+     if u.op is Ops.LOAD and u.src[-1].op is Ops.BARRIER:
+       return UOp(u.op, u.dtype, u.src[:-1]+(UOp(Ops.IF, src=(gate, u.src[-1])),), arg=u.arg)
+     return u if (replace_source:=tuple(_gate_srcs(x, gate) for x in u.src)) == u.src else UOp(u.op, u.dtype, replace_source, u.arg)
+   idx = root.src[0]
+   if idx.op is Ops.CAST: idx = idx.src[0]
+   return None if idx.op is not Ops.INDEX or len(idx.src) == 2 or (ret:=_gate_srcs(root, idx.src[2])) is root else ret
+
+ migrate_indexing = PatternMatcher([
+   # create gate MUST BE BEFORE expander
+   (UPat(Ops.STORE, name="root"), create_gate),
+ ])
+
+ def expand_rewrite(sink:UOp) -> UOp:
+   # initial symbolic + migrate indexing (remove this)
+   sink = graph_rewrite(sink, sym+migrate_indexing)
+
+   # expand
+   return graph_rewrite(sink, sym+expander)
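
The UNROLL bookkeeping in this file is driven by two pure helpers: _choices_from_args enumerates every point of the unrolled (axis, length) grid, and _expand_arg_to_idx flattens one such point back to a linear index with the first axis outermost. A small sketch of their behavior (assumes tinygrad 0.10.2 is installed so the new module is importable):

  from tinygrad.codegen.expander import _choices_from_args, _expand_arg_to_idx

  args = ((0, 2), (1, 3))               # two unrolled axes: axis 0 of length 2, axis 1 of length 3
  choices = _choices_from_args(args)    # 2*3 = 6 assignments of the unrolled axes
  assert len(choices) == 6 and choices[0] == {0: 0, 1: 0}
  assert _expand_arg_to_idx(args, {0: 1, 1: 2}) == 1*3 + 2   # row-major flattening -> 5
  # walking the choices in order recovers the linear indices 0..5
  assert [_expand_arg_to_idx(args, rpk) for rpk in choices] == list(range(6))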