tinygrad-0.10.1-py3-none-any.whl → tinygrad-0.10.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. tinygrad/codegen/devectorizer.py +247 -0
  2. tinygrad/codegen/expander.py +121 -0
  3. tinygrad/codegen/kernel.py +35 -37
  4. tinygrad/codegen/linearize.py +19 -10
  5. tinygrad/codegen/lowerer.py +31 -8
  6. tinygrad/codegen/symbolic.py +476 -0
  7. tinygrad/codegen/transcendental.py +10 -0
  8. tinygrad/device.py +28 -11
  9. tinygrad/dtype.py +12 -3
  10. tinygrad/engine/jit.py +3 -2
  11. tinygrad/engine/multi.py +0 -1
  12. tinygrad/engine/realize.py +7 -4
  13. tinygrad/engine/schedule.py +227 -255
  14. tinygrad/engine/search.py +20 -27
  15. tinygrad/gradient.py +3 -0
  16. tinygrad/helpers.py +7 -4
  17. tinygrad/nn/state.py +2 -2
  18. tinygrad/ops.py +64 -329
  19. tinygrad/renderer/__init__.py +19 -3
  20. tinygrad/renderer/cstyle.py +39 -18
  21. tinygrad/renderer/llvmir.py +55 -18
  22. tinygrad/renderer/ptx.py +6 -2
  23. tinygrad/renderer/wgsl.py +20 -12
  24. tinygrad/runtime/autogen/libc.py +404 -71
  25. tinygrad/runtime/autogen/{libpciaccess.py → pci.py} +25 -715
  26. tinygrad/runtime/autogen/webgpu.py +6985 -0
  27. tinygrad/runtime/graph/metal.py +28 -29
  28. tinygrad/runtime/ops_amd.py +37 -34
  29. tinygrad/runtime/{ops_clang.py → ops_cpu.py} +4 -2
  30. tinygrad/runtime/ops_disk.py +1 -1
  31. tinygrad/runtime/ops_dsp.py +59 -33
  32. tinygrad/runtime/ops_llvm.py +14 -12
  33. tinygrad/runtime/ops_metal.py +78 -62
  34. tinygrad/runtime/ops_nv.py +9 -6
  35. tinygrad/runtime/ops_python.py +5 -5
  36. tinygrad/runtime/ops_webgpu.py +200 -38
  37. tinygrad/runtime/support/am/amdev.py +23 -11
  38. tinygrad/runtime/support/am/ip.py +10 -10
  39. tinygrad/runtime/support/elf.py +2 -0
  40. tinygrad/runtime/support/hcq.py +7 -5
  41. tinygrad/runtime/support/llvm.py +8 -14
  42. tinygrad/shape/shapetracker.py +3 -2
  43. tinygrad/shape/view.py +2 -3
  44. tinygrad/spec.py +21 -20
  45. tinygrad/tensor.py +150 -90
  46. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
  47. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
  48. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
  49. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
  50. tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
  51. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
  52. tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
  53. tinygrad/viz/index.html +544 -0
  54. tinygrad/viz/perfetto.html +178 -0
  55. tinygrad/viz/serve.py +205 -0
  56. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/METADATA +20 -8
  57. tinygrad-0.10.2.dist-info/RECORD +99 -0
  58. tinygrad/codegen/rewriter.py +0 -516
  59. tinygrad-0.10.1.dist-info/RECORD +0 -86
  60. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
  61. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +0 -0
  62. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
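
The listing above can be reproduced locally from the two published wheels. A minimal sketch, assuming both wheel files have already been downloaded into the working directory (for example with pip download tinygrad==0.10.1 --no-deps and pip download tinygrad==0.10.2 --no-deps); the filenames below are the standard wheel names, everything else is illustrative:

# Sketch: compare two locally downloaded wheels file-by-file (paths are assumptions).
import zipfile, difflib

OLD, NEW = "tinygrad-0.10.1-py3-none-any.whl", "tinygrad-0.10.2-py3-none-any.whl"

with zipfile.ZipFile(OLD) as a, zipfile.ZipFile(NEW) as b:
  names_a, names_b = set(a.namelist()), set(b.namelist())
  for name in sorted(names_a | names_b):
    old = a.read(name).decode(errors="replace").splitlines() if name in names_a else []
    new = b.read(name).decode(errors="replace").splitlines() if name in names_b else []
    for line in difflib.unified_diff(old, new, fromfile=f"0.10.1/{name}", tofile=f"0.10.2/{name}", lineterm=""):
      print(line)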
tinygrad/codegen/devectorizer.py (new file)
@@ -0,0 +1,247 @@
+from typing import Optional, Any, Callable
+import functools, operator
+from collections import defaultdict
+from tinygrad.dtype import dtypes, ImageDType, PtrDType
+from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, resolve
+from tinygrad.ops import graph_rewrite, GroupOp
+from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym, mulacc_unrolled
+from tinygrad.helpers import getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, DEVECTORIZE
+from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES
+from tinygrad.renderer import Renderer
+
+# ***** float4/image store handling *****
+
+def fold_expanded(ex, buf):
+  new_srcs = dedup(list(ex.src))
+  old_new_srcs = new_srcs[:]
+  is_load, is_image = new_srcs[0].op is Ops.LOAD, isinstance(buf.dtype, ImageDType)
+
+  # TODO: get the device from the buffer somehow
+  # NOTE: this can't be Device.DEFAULT because it opens devices
+  if buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType): return None
+  lengths = [4] if is_image else ([8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else ([16,8,4,2] if AMX else [4,2]))
+
+  # first, extract all the relevant offsets
+  offsets_rootsrc: defaultdict[Any, dict] = defaultdict(dict)
+  for i,s in enumerate(new_srcs):
+    idx = s.src[0].src[1]
+    if s.dtype.count != 1 or (is_image and idx.dtype.count == 2): continue
+    if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
+    elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
+    else: root_src, arg = idx, 0
+    # add gates for gated
+    if len(s.src[0].src) == 3: root_src = (s.src[0].src[2], root_src)
+    assert arg not in offsets_rootsrc[root_src], f"{offsets_rootsrc[root_src][arg]} != {i} with {len(s.src)} sources"
+    offsets_rootsrc[root_src][arg] = i
+
+  # then rewrite everything we can
+  used: set[tuple[UOp, UOp]] = set()
+  for rootsrc, offsets in offsets_rootsrc.items():
+    for o in offsets:
+      for fold_length in lengths:
+        if all((rootsrc,o+i) not in used and o+i in offsets for i in range(fold_length)):
+          load_1 = new_srcs[offsets[o]]
+          new_src = list(load_1.src)
+          oidx = new_src[0].src[1]
+          if oidx.divides(fold_length) is None: continue
+          if is_image:
+            # for images, we rewrite the index. it must evenly divide 4 from the above check
+            new_src[0] = buf.index(
+              UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((oidx // 4) % buf.dtype.shape[1], (oidx // (4*buf.dtype.shape[1])))),
+              rootsrc[0] if isinstance(rootsrc, tuple) else None)
+          else:
+            # for non image, we upcast the index pointer
+            new_src[0] = new_src[0].cast(new_src[0].dtype.base.vec(fold_length).ptr(size=new_src[0].dtype.size, local=new_src[0].dtype.local))
+          # generate the folded new_srcs
+          if is_load:
+            new_load = UOp(Ops.LOAD, load_1.dtype.vec(fold_length), tuple(new_src))
+            for i in range(fold_length): new_srcs[offsets[o+i]] = new_load.gep(i)
+          else: # vectorize the store
+            new_src[1] = UOp(Ops.VECTORIZE, new_src[1].dtype.vec(fold_length), tuple(new_srcs[offsets[o+i]].src[1] for i in range(fold_length)))
+            for i in range(fold_length): new_srcs[offsets[o+i]] = UOp(Ops.STORE, dtypes.void, tuple(new_src)) if i == 0 else None
+          used.update((rootsrc,o+i) for i in range(fold_length))
+
+  # dedup expand for LOAD
+  if is_load and len(old_new_srcs) != len(ex.src): new_srcs = [new_srcs[old_new_srcs.index(s)] for s in ex.src]
+  # remove Nones for STORE
+  return UOp(ex.op, ex.dtype, tuple(x for x in new_srcs if x is not None), ex.arg) if len(used) else None
+
+def fix_unfoldable_image_load(load:UOp, buf:UOp):
+  if not isinstance(buf.dtype, ImageDType) or (oidx:=load.src[0].src[1]).dtype.count == 2: return None
+  id4 = oidx % 4
+  new_src = list(load.src)
+  # TODO: copied logic from above
+  new_src[0] = load.src[0].src[0].index(
+    UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((oidx // 4) % buf.dtype.shape[1], (oidx // (4*buf.dtype.shape[1])))),
+    load.src[0].src[2] if len(load.src[0].src) == 3 else None)
+  vec_load = UOp(Ops.LOAD, load.dtype.vec(4), tuple(new_src))
+  return functools.reduce(lambda ret, i: id4.ne(i).where(ret, vec_load.gep(i)), range(4), load.const_like(float('nan')))
+
+buf_idx_pat = UPat(Ops.INDEX, src=(UPat.var("buf"),), allow_any_len=True)
+float4_folding = PatternMatcher([
+  (UPat(Ops.VECTORIZE, src=UPat(Ops.LOAD, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded),
+  (UPat((Ops.BARRIER, Ops.SINK), src=UPat(Ops.STORE, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded),
+])
+
+# ***** image load valid simplification *****
+
+def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
+  if (idx:=uop_given_valid(valid, start_idx)) is None: return buf.const_like(0)
+  if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx, valid)
+
+  # wait for it to be image indexed before running simplification
+  if start_idx.dtype.count != 2: return None
+
+  # can drop valid if idx is out of bound when valid is False
+  drop_stmt = []
+  for stmt in split_uop(valid, Ops.AND):
+    X, is_upper_bound, c = parse_valid(stmt)
+
+    # for X0 + X1 + ... >= 1, check if it's out of bound when Xi = 0 for all i
+    if not is_upper_bound and c == 1 and all(u.op in GroupOp.Irreducible and u.vmin == 0 for u in split_uop(X, Ops.ADD)):
+      testidx = functools.reduce(lambda nowidx,u: nowidx.substitute({u:u.const_like(0)}), split_uop(X, Ops.ADD), idx)
+      testidx = testidx.simplify()
+      if testidx.gep(0).vmax < 0 or testidx.gep(1).vmax < 0:
+        drop_stmt.append(stmt)
+        continue
+
+    # if X <= c, check if it's out of bound when X = c+1
+    # if X >= c, check if it's out of bound when X = c-1
+    test_value = c + 1 if is_upper_bound else c - 1
+    for i,b in zip(idx.src, (buf.dtype.shape[1], buf.dtype.shape[0])):
+      if i.is_increasing():
+        rw = i.substitute({X:X.const_like(test_value)}).simplify()
+        if rw.vmin >= b or rw.vmax < 0:
+          drop_stmt.append(stmt)
+          break
+
+  if not drop_stmt and idx is start_idx: return None
+  new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in split_uop(valid, Ops.AND) if s not in drop_stmt]) else None
+  return buf.index(idx, new_valid)
+
+# ***** optional patterns *****
+
+powers_of_two = {2**i:i for i in range(64)}
+@functools.lru_cache(None)
+def get_late_rewrite_patterns(ops, force_transcendental=False):
+  pat: list[tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \
+    ((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)) if op not in ops or force_transcendental]
+  # rewrite SQRT to xpow 0.5
+  if Ops.SQRT not in ops: pat.append((UPat(Ops.SQRT, src=UPat.var("d")), lambda d: xpow(d, d.const_like(0.5))))
+  # rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1)
+  if Ops.AND in ops: pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)]
+  # rewrite MUL/IDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y)
+  if Ops.SHL in ops: pat += [(UPat.var("x", dtypes.ints)*UPat.cvar("c"), lambda c,x: x << v if (v:=powers_of_two.get(c.arg, 0)) else None)]
+  if Ops.SHR in ops:
+    pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("c"), lambda x,c: x >> v if (v:=powers_of_two.get(c.arg, 0)) and resolve(x>=0,False) else None)]
+  if Ops.NEG in ops:
+    pat += [(UPat.var('x')*-1, lambda x: x.alu(Ops.NEG))]
+  if Ops.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(Ops.NEG), lambda x,y: x.alu(Ops.SUB, y))]
+  if Ops.MULACC in ops: pat += [(UPat.var('a')*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(Ops.MULACC, b, c))]
+  return PatternMatcher(pat)
+
+
+# *** uop expander ***
+
+# TODO: there's a lot shared with gep_through_wmma here
+def no_vectorized_wmma(wmma:UOp):
+  out_sz = prod(x[1] for x in wmma.arg[6][-1])
+  if wmma.dtype.count == out_sz: return None
+  tsrcs = []
+  for s,sz in zip(wmma.src, wmma.arg[6]):
+    ssz = prod(x[1] for x in sz)
+    tsrcs.append([s.gep(tuple(range(grp, grp+ssz))) for grp in range(0, s.dtype.count, ssz)])
+  wmmas = [UOp(Ops.WMMA, wmma.dtype.scalar().vec(out_sz), tsrc, wmma.arg) for tsrc in zip(*tsrcs)]
+  wmma_ex = flatten([[e.gep(i) for i in range(out_sz)] for e in wmmas])
+  return UOp(Ops.VECTORIZE, wmma.dtype, tuple(wmma_ex))
+
+def no_vectorized_alu(alu):
+  if alu.dtype.vcount == 1: return None
+  alus = tuple(UOp(alu.op, alu.dtype.scalar(), tuple(s.gep(i) for s in alu.src), alu.arg) for i in range(alu.dtype.vcount))
+  return UOp(Ops.VECTORIZE, alu.dtype, alus)
+
+def no_vectorized_load_store(ls:UOp):
+  idx = ls.src[0]
+  assert isinstance(idx.dtype, PtrDType)
+  if idx.dtype.v == 1: return None
+  tv = [UOp(ls.op, ls.dtype.scalar(), tuple(j.gep(i) for j in ls.src)) for i in range(idx.dtype.v)]
+  return UOp(Ops.VECTORIZE, ls.dtype, tuple(tv))
+
+def no_vectorized_acc(acc:UOp):
+  if acc.dtype.count == 1: return None
+  alus = tuple(UOp(acc.op, acc.dtype.scalar(),
+    tuple(s.gep(i) if j == 0 else s for j,s in enumerate(acc.src)), acc.arg+(i,)) for i in range(acc.dtype.count))
+  return UOp(Ops.VECTORIZE, acc.dtype, alus)
+
+devectorize = PatternMatcher([
+  # no ALU on vectorized dtypes
+  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.INDEX), name="alu"), no_vectorized_alu),
+  (UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma),
+  (UPat(Ops.DEFINE_ACC, name="acc"), no_vectorized_acc),
+  (UPat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store),
+])
+
+devectorize_load_store = PatternMatcher([
+  # TODO: add vectorized support to transcendental
+  (UPat((Ops.INDEX, Ops.EXP2, Ops.LOG2, Ops.SIN), name="alu"), no_vectorized_alu),
+  (UPat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store),
+])
+
+def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:UOp|None=None) -> UOp|None:
+  if store_gate not in [gate.src[0] for gate in val.toposort if gate.op is Ops.IF]: return None
+  # remove the gate from the index
+  return UOp.store(buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx), val)
+
+load_store_indexing = PatternMatcher([
+  # late fixup of unfoldable image loads
+  (UPat(Ops.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load),
+  # simplify valid
+  (UPat(Ops.AND, name="valid"), simplify_valid),
+  # image load valid idx simplification
+  (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load),
+  # delete_redundant_gates (after expand)
+  (UPat(Ops.STORE, src=(UPat.any(stidx:=UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")), stidx.cast().named("cast")),
+    UPat.var("val"))), delete_redundant_gates),
+])
+
+def move_mask(x:UOp, buf:UOp, idx:UOp, mask:UOp, cast:UOp|None=None) -> UOp:
+  # this moves the mask from the indexing to the load/store op for rendering
+  nidx = buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx)
+  return UOp.load(nidx, x.const_like(0), mask, *x.src[1:], dtype=x.dtype) if x.op is Ops.LOAD else UOp.store(nidx, x.src[1], mask, *x.src[2:])
+
+pm_render = PatternMatcher([
+  # for rendering, we use explicit VECTORIZE
+  (UPat(Ops.CONST, name='c'),
+   lambda c: UOp(Ops.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.vcount) if c.dtype.vcount > 1 else None),
+  (UPat(Ops.VCONST, name='c'), lambda c: UOp(Ops.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
+  (UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
+  (UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
+  # move masks of loads/stores
+  (UPat((Ops.LOAD, Ops.STORE), src=(UPat.any(masked_index:=UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"), UPat.var("mask"))),
+    masked_index.cast(None).named("cast")),), allow_any_len=True, name="x"), move_mask),
+  # gate any stores that aren't gated with ifs
+  (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
+   lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
+])
+
+# *** uop graph ***
+
+def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
+  assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
+  supported_ops = tuple(opts.code_for_op.keys()) if opts is not None else ()
+  extra_matcher = opts.extra_matcher if opts is not None and opts.extra_matcher is not None else PatternMatcher([])
+
+  if DEVECTORIZE:
+    # devectorize + load_store_indexing + mulacc_unrolled, mulacc_unrolled must be last because it can break loop_collapse
+    sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)+load_store_indexing+
+      mulacc_unrolled)
+  else:
+    # new devectorize only for load/store
+    sink = graph_rewrite(sink, sym+devectorize_load_store+mulacc_unrolled)

+  # optional pre matcher
+  if opts is not None and opts.pre_matcher is not None: sink = graph_rewrite(sink, opts.pre_matcher)
+
+  # final rules for the renderer (without sym)
+  sink = graph_rewrite(sink, symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)+pm_render+extra_matcher)
+  return sink
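
The integer rewrites registered by get_late_rewrite_patterns above (MOD to AND, MUL to SHL, IDIV to SHR) rest on ordinary power-of-two identities. A small standalone check of those identities, plain Python with no tinygrad imports, illustrative only:

# The identities behind the MOD->AND, MUL->SHL and IDIV->SHR rewrites above,
# checked numerically for non-negative x and power-of-two constants.
powers_of_two = {2**i: i for i in range(64)}

for x in range(0, 1000, 7):
  for c in (2, 4, 8, 16):
    v = powers_of_two[c]
    assert x % c == x & (c - 1)   # x % 2**v == x & (2**v - 1)
    assert x * c == x << v        # x * 2**v == x << v
    assert x // c == x >> v       # x // 2**v == x >> v, valid for x >= 0, hence the resolve(x>=0) guard above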
tinygrad/codegen/expander.py (new file)
@@ -0,0 +1,121 @@
+# this converts a lowerer program into a vectorized program
+
+import functools, itertools, operator
+from tinygrad.helpers import AMX, dedup, flatten, all_same, prod
+from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, GroupOp, graph_rewrite
+from tinygrad.codegen.symbolic import sym
+
+def _expand_arg_to_idx(args:tuple[tuple[int, int], ...], rpk:dict[int, int]) -> int:
+  idx, mul = 0, 1
+  for axis,m in args[::-1]:
+    idx += rpk[axis] * mul
+    mul *= m
+  return idx
+
+def _choices_from_args(args:tuple[tuple[int, int], ...]) -> list[dict[int, int]]:
+  return [dict(x) for x in itertools.product(*[zip(itertools.repeat(axis), range(m)) for axis,m in args])]
+
+@functools.lru_cache(None)
+def _swizzle_args(cargs:tuple[tuple[int, int], ...], eargs:tuple[tuple[int, int], ...], exclude_args:tuple[int, ...]) -> list[int]:
+  return [_expand_arg_to_idx(eargs, {**rpk, **{x:0 for x in exclude_args}} if exclude_args else rpk) for rpk in _choices_from_args(cargs)]
+
+def do_expand(root:UOp):
+  expands = [x for x in root.src if x.op is Ops.UNROLL]
+  if len(expands) == 0: return None
+  # NOTE: we 0 out the reduce axis for WMMA. in theory they should all be the same, but is this always correct?
+  exclude_args = tuple(dedup(root.arg[-1] + tuple(y[0] for y in flatten(root.arg[-2])))) if root.op is Ops.WMMA else ()
+  if all_same(expands_args:=[x.arg for x in expands]) and len(exclude_args) == 0:
+    # if there's only one expand arg, it's okay to use it (optimization)
+    expand_args = expands[0].arg
+  else:
+    # otherwise, we sort them and GEP
+    expand_args = tuple(x for x in sorted(dedup(flatten(expands_args))) if x[0] not in exclude_args)
+  expand_sz = prod([x[1] for x in expand_args])
+  new_srcs = []
+  for i,src in enumerate(root.src):
+    if src.op is Ops.UNROLL:
+      if root.op is Ops.IF and i == 0:
+        # IF means OR on first arg to IF
+        new_srcs.append(functools.reduce(operator.__or__, [src.src[0].gep(i) for i in range(expand_sz)]))
+      elif expand_args == src.arg:
+        # just remove the expand
+        new_srcs.append(src.src[0])
+      else:
+        lst = _swizzle_args(expand_args, src.arg, exclude_args)
+        # if the base dtype is > 1, put those at the end
+        if src.dtype.count > 1: lst = flatten([[i*src.dtype.count+j for j in range(src.dtype.count)] for i in lst])
+        new_srcs.append(src.src[0].gep(tuple(lst)))
+    else:
+      # non-UNROLL input
+      if root.op is Ops.IF:
+        # for the first arg of IF, just pass them through ignoring UNROLLS
+        new_srcs.append(src)
+      elif src.dtype.count > 1:
+        # put any input dtype > 1 grouped together
+        new_srcs.append(UOp(Ops.CAT, src.dtype.scalar().vec(expand_sz*src.dtype.count), (src,)*expand_sz))
+      else:
+        # repeat the arg
+        new_srcs.append(src.broadcast(expand_sz))
+
+  new_arg = root.arg
+  if root.op is Ops.GEP:
+    assert root.dtype.count == 1
+    # is this right?
+    new_arg = tuple(range(root.arg[0], new_srcs[0].dtype.count, new_srcs[0].dtype.count // expand_sz))
+  nsrc = UOp(root.op, root.dtype.scalar().vec(root.dtype.count*expand_sz), tuple(new_srcs), new_arg)
+  return UOp(Ops.UNROLL, root.dtype, (nsrc,), expand_args)
+
+def do_contract(con:UOp):
+  ex = con.src[0]
+  # CONTRACT without UNROLL repeats the element VECTORIZED
+  if ex.op is not Ops.UNROLL: return UOp(Ops.VECTORIZE, con.dtype, con.src*con.dtype.count)
+  # CONTRACT may remove several axes from UNROLL
+  assert con.dtype.count == prod([x[1] for x in con.arg]), "dtype is wrong"
+  idxs = []
+  for rpk in _choices_from_args(new_ex_args:=tuple(x for x in ex.arg if x not in con.arg)):
+    idxs += [_expand_arg_to_idx(ex.arg, {**rpk, **lrpk}) for lrpk in _choices_from_args(con.arg)]
+  return UOp(Ops.UNROLL, con.dtype, (ex.src[0].gep(tuple(idxs)),), new_ex_args)
+
+expander = PatternMatcher([
+  # double expand
+  (UPat(Ops.UNROLL, name="outer", src=(UPat(Ops.UNROLL, name="inner"),)),
+   lambda outer, inner: UOp(Ops.UNROLL, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
+  # do expansion
+  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.ASSIGN,
+         Ops.VECTORIZE, Ops.IF), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
+  (UPat(Ops.CONTRACT, name="con"), do_contract),
+  # vectorize DEFINE_ACC
+  (UPat(Ops.VECTORIZE, src=UPat(Ops.DEFINE_ACC, name="acc"), name="v"),
+   lambda acc,v: acc.replace(dtype=v.dtype, src=(acc.src[0].broadcast(v.dtype.count),)+acc.src[1:])),
+  # BARRIERs aren't actually expanded
+  (UPat(Ops.BARRIER, src=(UPat(Ops.UNROLL, name="ex"),)),
+   lambda ex: UOp(Ops.UNROLL, src=(UOp(Ops.BARRIER, src=ex.src),)*len(ex.src), arg=ex.arg)),
+  # empty UNROLL is NOOP
+  (UPat(Ops.UNROLL, src=(UPat.var('x'),), arg=()), lambda x: x),
+  # UNROLL GEP (needed for WMMA, generalize this) -> vectorized ALU
+  (UPat(Ops.UNROLL, name="ex", src=tuple(UPat.var('x').gep(i)+UPat.var('y').gep(i) for i in range(256 if AMX else 8))),
+   lambda ex,x,y: UOp(Ops.UNROLL, ex.dtype, tuple((x+y).gep(i) for i in range(256 if AMX else 8)), ex.arg)),
+])
+
+def create_gate(root:UOp) -> UOp|None:
+  @functools.lru_cache(None)
+  def _gate_srcs(u:UOp, gate:UOp) -> UOp:
+    if u.op is Ops.BARRIER: return u
+    if u.op is Ops.LOAD and u.src[-1].op is Ops.BARRIER:
+      return UOp(u.op, u.dtype, u.src[:-1]+(UOp(Ops.IF, src=(gate, u.src[-1])),), arg=u.arg)
+    return u if (replace_source:=tuple(_gate_srcs(x, gate) for x in u.src)) == u.src else UOp(u.op, u.dtype, replace_source, u.arg)
+  idx = root.src[0]
+  if idx.op is Ops.CAST: idx = idx.src[0]
+  return None if idx.op is not Ops.INDEX or len(idx.src) == 2 or (ret:=_gate_srcs(root, idx.src[2])) is root else ret
+
+migrate_indexing = PatternMatcher([
+  # create gate MUST BE BEFORE expander
+  (UPat(Ops.STORE, name="root"), create_gate),
+])
+
+def expand_rewrite(sink:UOp) -> UOp:
+  # initial symbolic + migrate indexing (remove this)
+  sink = graph_rewrite(sink, sym+migrate_indexing)
+
+  # expand
+  return graph_rewrite(sink, sym+expander)
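
_expand_arg_to_idx and _choices_from_args at the top of this file implement mixed-radix indexing over unroll axes: _choices_from_args enumerates every assignment of the axes, and _expand_arg_to_idx maps an assignment back to its position in that enumeration. The two helpers are copied verbatim below so the ordering can be inspected standalone; the example axes are made up:

# Standalone copy of the two expander helpers above, to show the index math.
import itertools

def _expand_arg_to_idx(args, rpk):
  idx, mul = 0, 1
  for axis, m in args[::-1]:
    idx += rpk[axis] * mul
    mul *= m
  return idx

def _choices_from_args(args):
  return [dict(x) for x in itertools.product(*[zip(itertools.repeat(axis), range(m)) for axis, m in args])]

# two unroll axes: axis 0 of size 2, axis 1 of size 3 -> 6 combinations, last axis varies fastest
args = ((0, 2), (1, 3))
choices = _choices_from_args(args)
print(choices)                                         # [{0:0,1:0}, {0:0,1:1}, ..., {0:1,1:2}]
print([_expand_arg_to_idx(args, c) for c in choices])  # [0, 1, 2, 3, 4, 5]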
tinygrad/codegen/kernel.py
@@ -3,43 +3,26 @@ import itertools, functools, math
 from dataclasses import dataclass
 from collections import defaultdict
 from typing import Optional, cast, Final, Callable, Sequence
-from enum import Enum, auto

 from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, can_pad, resolve, Variable, sint, graph_rewrite, track_rewrites, view_left, print_uops
+from tinygrad.ops import PatternMatcher
 from tinygrad.spec import type_verify, shape_spec
 from tinygrad.device import Device
-from tinygrad.renderer import Renderer, TensorCore, ProgramSpec
+from tinygrad.renderer import Renderer, TensorCore, ProgramSpec, Opt, OptOps
 from tinygrad.dtype import ImageDType
 from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put, unwrap, ContextVar
 from tinygrad.helpers import DEBUG, TC_SELECT, TC_OPT, USE_TC, AMX, CAPTURE_PROCESS_REPLAY
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import strides_for_shape
 from tinygrad.codegen.linearize import linearize_uop
-from tinygrad.codegen.rewriter import full_graph_rewrite
+from tinygrad.codegen.devectorizer import full_graph_rewrite
 from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index, get_contraction

-class OptOps(Enum):
-  TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
-  GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); SWAP = auto() # noqa: E702
-  def __lt__(self, x:OptOps): return self.value < x.value
-
 class KernelOptError(Exception): pass

 def check(cond:bool, msg:str=""):
   if not cond: raise KernelOptError(msg)

-@dataclass(frozen=True, order=True)
-class Opt:
-  op: OptOps
-  axis: Optional[int] = None
-  arg: Optional[int | tuple] = None
-  def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, arg={self.arg})"
-  def real_axis(self, k:Kernel):
-    if self.axis is None: return -1
-    if self.op is OptOps.UNROLL: return k.first_reduce+self.axis
-    if self.op in {OptOps.GROUP, OptOps.GROUPTOP}: return k.first_reduce+k.group_for_reduces+self.axis
-    return self.axis
-
 @dataclass
 class TensorCoreOptions:
   axes: tuple[int, ...] # the location of the original N and M axes if still in the shape
@@ -325,8 +308,8 @@ class Kernel:
       -1: iterates through all available tensor cores in order and uses the first one that matches the requirements (dims and dtypes)
       [0-N]: uses only the n'th tensor core available; useful for search
     tc_opt -- controls which kinds of kernels may be eligible for tensor cores application (default 2 during BEAM, 0 otherwise)
-      0: applies to only kernels with a single reduce axis and direct UOps.LOAD into Ops.MUL
-      1: allows kernels with multiple reduce axes and also multiplication of UOps.CAST'd buffers
+      0: applies to only kernels with a single reduce axis and direct Ops.LOAD into Ops.MUL
+      1: allows kernels with multiple reduce axes and also multiplication of Ops.CAST'd buffers
       2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by applying padding those axes as needed
     """
     if tc_select is None: tc_select = TC_SELECT.value
@@ -339,7 +322,7 @@ class Kernel:
       if extra_opts is not None:
        for opt in extra_opts: self.apply_opt(opt)
      else:
-        if (self.opts.device == "CLANG" and AMX): return True # skip hand-coded TC opts if AMX, upcasting will make kernel slower
+        if AMX: return True # skip hand-coded TC opts if AMX, upcasting will make kernel slower
        # hand-coded TC opts
        for tc_dim in [tc_dim for tc_dim in [1,0] if tc_opts.axes_exist[tc_dim]]: # attempt to upcast M and N
          szs = [sz for sz in [5,4,3,2] if self.full_shape[tc_opts.axes[tc_dim]] % sz == 0]
@@ -351,6 +334,12 @@ class Kernel:
     except KernelOptError:
       return False

+  def real_axis(self, opt:Opt):
+    if opt.axis is None: return -1
+    if opt.op is OptOps.UNROLL: return self.first_reduce+opt.axis
+    if opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: return self.first_reduce+self.group_for_reduces+opt.axis
+    return opt.axis
+
   def apply_opt(self, opt:Opt, append_opt:bool=True):
     if self.dont_use_locals: check(opt.op not in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP}, "not using locals")

@@ -365,7 +354,7 @@ class Kernel:
       self.applied_opts.append(opt)
       return

-    axis = opt.real_axis(self)
+    axis = self.real_axis(opt)
     check(axis < len(self.full_shape), "invalid axis")

     if opt.op is OptOps.SWAP: amt = cast(int, opt.arg) # arg is an axis in the SWAPs
@@ -385,6 +374,8 @@ class Kernel:
       check(smem_sz <= self.opts.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.opts.shared_max}")

     if opt.op is OptOps.LOCAL: # cyan
+      # NOTE: LLVM/CPU can use locals too, but they are treated the same as globals (still helpful for L1 cache)
+      # it's disabled for now since it makes BEAM slow for little gain
       check(self.opts.has_local, "target does not support local")
       check(axis < self.global_dims, "local is for globals")
       self.shift_to(axis, amt, insert_before=self.first_reduce)
@@ -409,7 +400,7 @@ class Kernel:
     elif opt.op is OptOps.UPCAST: # yellow
       check(axis < self.first_reduce, "upcast is for non-reduce")
       check(not (self.tensor_core and self.global_dims <= axis < self.global_dims+len(self.tensor_core.get_local_axes())), "can't upcast TC locals")
-      check(amt <= 16, "don't upcast more than 16")
+      check((self.opts is not None and self.opts.device == "DSP") or amt <= 16, "don't upcast more than 16")
       self.shift_to(axis, amt, insert_before=None)
       self.upcast()
     elif opt.op is OptOps.NOLOCALS:
@@ -425,7 +416,7 @@ class Kernel:
       check(not self.vars, "does not work with symbolic shape")
       check(axis < self.first_upcast, "cannot pad upcasted")
       # ok to pad SUM if all parent ALU ops have f(0) = 0
-      if (r:=self.reduceop) is not None and self.first_reduce <= axis: check(r.arg[0] is Ops.ADD and can_pad(r, {}, {}), f"cannot pad {r}")
+      if (r:=self.reduceop) is not None and self.first_reduce <= axis: check(r.arg[0] is Ops.ADD and can_pad(r, {}, cache={}), f"cannot pad {r}")
       padded = False
       for i,st in enumerate(self.sts):
         if (s:=st.shape[axis]) == 1: continue # reduced
@@ -512,7 +503,7 @@ class Kernel:
     for axis in to_upcast[::-1]: self.apply_opt(Opt(OptOps.UPCAST, axis, 0))

     # potentially do more upcasts of non reduce axes based on a heuristic
-    upcasted_axis = set()
+    upcasted_axis: set[int] = set()
     while resolve(prod(self.sts[0].shape[:self.first_reduce]) >= 1024):
       xb_choices = []
       for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce
@@ -582,7 +573,7 @@ class Kernel:
     num = f"n{Kernel.kernel_cnt[function_name]-1}" if Kernel.kernel_cnt[function_name] > 1 else ""
     return name + colored(num, 'BLACK')

-  def get_optimized_ast(self) -> UOp:
+  def get_optimized_ast(self, name_override:Optional[str]=None) -> UOp:
     @functools.lru_cache(None)
     def fixup_ast(op:UOp) -> UOp:
       ret = op.replace(src=tuple(fixup_ast(x) for x in op.src))
@@ -592,7 +583,9 @@ class Kernel:
         if op.op is Ops.CONST and any(v.mask is not None for v in unwrap(st_uop.st).views): return op.valid(unwrap(st_uop.st))
         # otherwise we just replace the VIEW source
         return ret.replace(src=(st_uop,)) if len(op.src) == 1 else ret.replace(src=(ret.src[0], st_uop, *ret.src[2:]))
-      if op.op is Ops.SINK: return ret.replace(arg = KernelInfo(self.local_dims, self.upcasted, self.dont_use_locals))
+      if op.op is Ops.SINK:
+        return ret.replace(arg = KernelInfo(to_function_name(self.name) if name_override is None else name_override,
+                                            self.local_dims, self.upcasted, self.dont_use_locals))
       if op.op is Ops.REDUCE_AXIS:
         reduce_idx = len(self.bufs) + self.reduceops.index(op) * 2

@@ -662,13 +655,17 @@ class Kernel:
   # **** this is the lowerer ****

   @track_rewrites()
-  def linearize(self) -> Kernel:
-    modified_ast = self.get_optimized_ast()
+  def linearize(self, name_override:Optional[str]=None) -> Kernel:
+    # display the AST
+    if getenv("VIZ"): graph_rewrite(self.ast, PatternMatcher([]), name="View Base AST")
+
+    modified_ast = self.get_optimized_ast(name_override)

     if DEBUG >= 3:
       print(self.name)
       if getenv("RAWAST"): print(self.ast)
-      print(modified_ast)
+      for i,(buf,st) in enumerate([(buf,st) for buf,st in zip(self.bufs, self.sts) if buf.op not in {Ops.CONST, Ops.VALID}]):
+        print(f"{i:2d}: {str(st.shape):25s} {str(buf.src[0].dtype).replace('dtypes.',''):20s}", st.real_strides())
       print(self.applied_opts)
     # verify AST matches the spec after applying opts
     if __debug__: type_verify(list(modified_ast.toposort))
@@ -680,16 +677,17 @@ class Kernel:
     return self

   def to_program(self, name_override:Optional[str]=None) -> ProgramSpec:
-    self.linearize()
-    src = self.opts.render(name:=to_function_name(ansiname:=(name_override if name_override is not None else self.name)), self.uops)
+    self.linearize(name_override)
+    assert self.uops[0].op is Ops.NAME, "first uop must be name"
+    src = self.opts.render(self.uops)

     if CAPTURE_PROCESS_REPLAY:
-      diskcache_put("kernel_process_replay", str(id(self)), (self.ast, self.opts, self.applied_opts, name, ContextVar._cache, src))
+      diskcache_put("kernel_process_replay", str(id(self)), (self.ast, self.opts, self.applied_opts, self.uops[0].arg, ContextVar._cache, src))

     # group non-local bufs by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
     # TODO: these max and min don't work on symbolic, and results are very wrong.
     mem_bytes = sum(max(x.src[0].dtype.itemsize * x.st_arg.real_size() for x in group)
       for _, group in itertools.groupby([x for x in self.ast.toposort if x.op in GroupOp.Buffer and x.src[0].op is Ops.DEFINE_GLOBAL],
                                         key=lambda x: (x.op, x.src[0].arg)))
-    return ProgramSpec(ansiname, src, self.opts.device, self.uops, mem_estimate=mem_bytes,
-                       global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)
+    return ProgramSpec(self.name if not name_override else name_override, src, self.opts.device, self.ast, self.uops, self.applied_opts, mem_bytes,
+                       global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)
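
For reference, the real_axis logic that moves from Opt onto Kernel in the hunks above is a plain offset mapping over the kernel's axis layout. A hedged sketch with stand-in classes and numbers, not real tinygrad objects:

# Sketch of the axis mapping now implemented by Kernel.real_axis (stand-in values only).
from dataclasses import dataclass
from enum import Enum, auto

class OptOps(Enum):
  UPCAST = auto(); UNROLL = auto(); GROUP = auto(); GROUPTOP = auto()

@dataclass(frozen=True)
class Opt:
  op: OptOps
  axis: int | None = None

def real_axis(opt: Opt, first_reduce: int, group_for_reduces: int) -> int:
  if opt.axis is None: return -1
  if opt.op is OptOps.UNROLL: return first_reduce + opt.axis            # UNROLL axes count from the first reduce axis
  if opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: return first_reduce + group_for_reduces + opt.axis
  return opt.axis                                                       # everything else already names a full-shape axis

# with first_reduce=2 and one grouped reduce: UNROLL axis 0 -> 2, GROUP axis 0 -> 3
print(real_axis(Opt(OptOps.UNROLL, 0), 2, 1), real_axis(Opt(OptOps.GROUP, 0), 2, 1))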
tinygrad/codegen/linearize.py
@@ -6,7 +6,7 @@ from tinygrad.spec import type_verify
 from tinygrad.dtype import dtypes, PtrDType
 from tinygrad.helpers import dedup, flatten, partition

-DONT_PLACE_IN_BLOCK = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST, *GroupOp.Block}
+DONT_PLACE_IN_BLOCK = {Ops.NAME, Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST, *GroupOp.Block}

 def disp(y:UOp) -> str:
   if y.op is Ops.BLOCKSTART: return "w"+disp(y.src[0])
@@ -70,7 +70,8 @@ def append_to_block(ctx:tuple[dict[UOp, tuple[UOp, ...]], dict[UOp, list[UOp]]],
   return UOp(Ops.BLOCK, dtypes.void, tuple(dedup(list(old_blocks.values())+new_srcs)), BasicBlock(x.arg.ctx, tuple(to_append)+x.arg.lst))

 make_basic_blocks = PatternMatcher([
-  (UPat(Ops.SINK, name="x"), lambda x: UOp(Ops.BLOCK, src=x.src, arg=BasicBlock((), (x,)))),
+  (UPat(Ops.SINK, name="x"),
+   lambda x: UOp(Ops.BLOCK, src=x.src+((UOp(Ops.NAME, arg=x.arg.name),) if x.arg is not None else ()), arg=BasicBlock((), (x,)))),
   (UPat(Ops.BLOCK, name="x"), append_to_block),
 ])

@@ -112,6 +113,17 @@ def block_merge(ctx, x:UOp):

 pm_block_merge = PatternMatcher([(UPat((Ops.BLOCKEND, Ops.BLOCK), name="x"), block_merge),])

+def block_finalize(block:UOp):
+  if len(block.src) == 0: return None
+  _uops = sorted(dedup(block.src), key=lambda x: x.tuplize)
+  assert all(len(x.src) == 0 and x.op not in {Ops.BLOCK, Ops.BLOCKSTART, Ops.BLOCKEND, Ops.BLOCKFORK} for x in _uops)
+  _uops += block.arg.lst
+  # strip the SINK
+  assert _uops[-1].op is Ops.SINK, "doesn't end with SINK"
+  return UOp(Ops.BLOCK, arg=BasicBlock((), tuple(_uops[:-1])))
+
+pm_block_finalize = PatternMatcher([(UPat(Ops.BLOCK, name="block"), block_finalize)])
+
 # NOTE: any toposort should be valid here, unlike last time this isn't required, it's just for speed
 def block_reorder(in_block:UOp):
   in_this_block = set(in_block.arg.lst)
@@ -212,14 +224,11 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> list[UOp]:
   # final rewrite to merge all blocks into one
   sink = graph_rewrite(sink, pm_block_merge, ctx=children)

-  # there should just be one block left, with a few parents with 0 srcs
-  assert sink.op is Ops.BLOCK
-  _uops = sorted(dedup(sink.src), key=lambda x: x.tuplize)
-  assert all(len(x.src) == 0 and x.op not in {Ops.BLOCK, Ops.BLOCKSTART, Ops.BLOCKEND, Ops.BLOCKFORK} for x in _uops)
-  _uops += sink.arg.lst
+  # there should just be one block left, with a few parents with 0 srcs (now done in a rewriter)
+  sink = graph_rewrite(sink, pm_block_finalize)

   # sanity checks (NOTE: these can cause things to be skipped in BEAM)
-  if not skip_check: type_verify(_uops)
+  if not skip_check: type_verify(sink.arg.lst)

-  # strip the SINK
-  return _uops[:-1]
+  # return the list. TODO: refactor to return the UOp
+  return list(sink.arg.lst)
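
The new block_finalize pass above replaces inline list surgery in linearize_uop: it orders the remaining zero-source parents deterministically, appends the block body, and strips the trailing SINK. A plain-Python analogue with stand-in strings instead of UOps:

# Plain-Python analogue of block_finalize (stand-in data, not real UOps).
parents = ["DEFINE_GLOBAL(1)", "DEFINE_GLOBAL(0)", "CONST(0)", "DEFINE_GLOBAL(0)"]  # block.src stand-in
body    = ["INDEX", "LOAD", "ADD", "STORE", "SINK"]                                 # block.arg.lst stand-in

_uops = sorted(set(parents))   # dedup, then order by a stable key (tuplize in the real code)
_uops += body
assert _uops[-1] == "SINK"
print(_uops[:-1])              # the final linear program, with the SINK stripped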