tinygrad-0.10.0-py3-none-any.whl → tinygrad-0.10.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. tinygrad/codegen/devectorizer.py +247 -0
  2. tinygrad/codegen/expander.py +121 -0
  3. tinygrad/codegen/kernel.py +141 -201
  4. tinygrad/codegen/linearize.py +223 -84
  5. tinygrad/codegen/lowerer.py +60 -42
  6. tinygrad/codegen/symbolic.py +476 -0
  7. tinygrad/codegen/transcendental.py +22 -13
  8. tinygrad/device.py +187 -47
  9. tinygrad/dtype.py +39 -28
  10. tinygrad/engine/jit.py +83 -65
  11. tinygrad/engine/memory.py +4 -5
  12. tinygrad/engine/multi.py +161 -0
  13. tinygrad/engine/realize.py +62 -108
  14. tinygrad/engine/schedule.py +396 -357
  15. tinygrad/engine/search.py +55 -66
  16. tinygrad/gradient.py +73 -0
  17. tinygrad/helpers.py +81 -59
  18. tinygrad/nn/__init__.py +30 -32
  19. tinygrad/nn/datasets.py +1 -2
  20. tinygrad/nn/optim.py +22 -26
  21. tinygrad/nn/state.py +91 -66
  22. tinygrad/ops.py +492 -641
  23. tinygrad/renderer/__init__.py +95 -36
  24. tinygrad/renderer/cstyle.py +99 -92
  25. tinygrad/renderer/llvmir.py +83 -34
  26. tinygrad/renderer/ptx.py +83 -99
  27. tinygrad/renderer/wgsl.py +95 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  29. tinygrad/runtime/autogen/comgr.py +2 -0
  30. tinygrad/runtime/autogen/kfd.py +4 -3
  31. tinygrad/runtime/autogen/kgsl.py +1 -1
  32. tinygrad/runtime/autogen/libc.py +404 -71
  33. tinygrad/runtime/autogen/llvm.py +11379 -0
  34. tinygrad/runtime/autogen/pci.py +1333 -0
  35. tinygrad/runtime/autogen/vfio.py +891 -0
  36. tinygrad/runtime/autogen/webgpu.py +6985 -0
  37. tinygrad/runtime/graph/cuda.py +8 -9
  38. tinygrad/runtime/graph/hcq.py +84 -79
  39. tinygrad/runtime/graph/metal.py +40 -43
  40. tinygrad/runtime/ops_amd.py +498 -334
  41. tinygrad/runtime/ops_cloud.py +34 -34
  42. tinygrad/runtime/ops_cpu.py +24 -0
  43. tinygrad/runtime/ops_cuda.py +30 -27
  44. tinygrad/runtime/ops_disk.py +62 -63
  45. tinygrad/runtime/ops_dsp.py +159 -42
  46. tinygrad/runtime/ops_gpu.py +30 -30
  47. tinygrad/runtime/ops_hip.py +29 -31
  48. tinygrad/runtime/ops_llvm.py +48 -41
  49. tinygrad/runtime/ops_metal.py +149 -113
  50. tinygrad/runtime/ops_npy.py +2 -2
  51. tinygrad/runtime/ops_nv.py +238 -273
  52. tinygrad/runtime/ops_python.py +55 -50
  53. tinygrad/runtime/ops_qcom.py +129 -157
  54. tinygrad/runtime/ops_webgpu.py +225 -0
  55. tinygrad/runtime/support/allocator.py +94 -0
  56. tinygrad/runtime/support/am/__init__.py +0 -0
  57. tinygrad/runtime/support/am/amdev.py +396 -0
  58. tinygrad/runtime/support/am/ip.py +463 -0
  59. tinygrad/runtime/support/compiler_cuda.py +4 -2
  60. tinygrad/runtime/support/elf.py +28 -4
  61. tinygrad/runtime/support/hcq.py +256 -324
  62. tinygrad/runtime/support/llvm.py +26 -0
  63. tinygrad/shape/shapetracker.py +85 -53
  64. tinygrad/shape/view.py +104 -140
  65. tinygrad/spec.py +155 -0
  66. tinygrad/tensor.py +835 -527
  67. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
  68. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
  69. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
  70. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
  71. tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
  72. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
  73. tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
  74. tinygrad/viz/index.html +544 -0
  75. tinygrad/viz/perfetto.html +178 -0
  76. tinygrad/viz/serve.py +205 -0
  77. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
  78. tinygrad-0.10.2.dist-info/RECORD +99 -0
  79. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
  80. tinygrad/codegen/uopgraph.py +0 -506
  81. tinygrad/engine/lazy.py +0 -228
  82. tinygrad/function.py +0 -212
  83. tinygrad/multi.py +0 -177
  84. tinygrad/runtime/graph/clang.py +0 -39
  85. tinygrad/runtime/ops_clang.py +0 -35
  86. tinygrad-0.10.0.dist-info/RECORD +0 -77
  87. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
  88. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/codegen/linearize.py
@@ -1,95 +1,234 @@
-from typing import List, Set, Dict, Tuple
-import functools, heapq
-from tinygrad.ops import type_verify, END_FOR_UOP, UOp, Ops, GroupOp
-from tinygrad.dtype import dtypes
-from tinygrad.helpers import DEBUG
-
-def get_children_dfs(u:UOp, children:Dict[UOp, List[UOp]], srcs:Dict[UOp, Dict[UOp, None]], in_degree:Dict[UOp, int]):
-  if u in children: return srcs[u]
-  srcs[u] = {}
-  children[u] = []
-  for x in u.src:
-    srcs[u].update(get_children_dfs(x, children, srcs, in_degree))
-    if x.op is Ops.RANGE and x.arg[1]: srcs[u][x] = None
-    children[x].append(u)
-  in_degree[u] = len(u.src)
-  return srcs[u]
-
-def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
-  assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
-  # filter nodes that don't link to a sink
-  # BFS toposort
-  children: Dict[UOp, List[UOp]] = {}
-  range_srcs: Dict[UOp, Dict[UOp, None]] = {}
-  in_degree: Dict[UOp, int] = {}
-  get_children_dfs(sink, children, range_srcs, in_degree)
-
-  @functools.lru_cache(None)
-  def get_recursive_children(x:UOp, end:Ops, include_self=False) -> Set[UOp]:
-    if x.op is Ops.SINK: return set()
-    return set.union({x} if include_self else set(), *([get_recursive_children(u, end, True) for u in children[x] if x.op is not end]))
-
-  # scope children impact the toposort and END* insertion
-  scope_children = {p:get_recursive_children(p, END_FOR_UOP[p.op][0]) for p in reversed(in_degree) if p.op in END_FOR_UOP}
-  range_phi = {r:[p for p in scope_children[r] if p.op is Ops.ASSIGN] for r in scope_children if r.op is Ops.RANGE}
-
-  # assign priorities
-  def get_priority(u:UOp):
-    priority = 0
-    # prefer ranges that depend on the least number of independent ranges
-    if u.op is Ops.RANGE and u.arg[1]:
-      priority += u.arg[0]
-      for p in range_phi[u]:
-        priority += 10000*len([r for r in range_srcs[p] if not any(i in range_phi[u] for i in range_phi[r])])
-    elif u.op is Ops.CONST:
-      # place consts first here, they don't do anything and it can cause issues with DEFINE_ACC
-      priority -= 100000000000
+from __future__ import annotations
+import collections, heapq
+from dataclasses import dataclass
+from tinygrad.ops import UOp, Ops, PatternMatcher, UPat, graph_rewrite, GroupOp
+from tinygrad.spec import type_verify
+from tinygrad.dtype import dtypes, PtrDType
+from tinygrad.helpers import dedup, flatten, partition
+
+DONT_PLACE_IN_BLOCK = {Ops.NAME, Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST, *GroupOp.Block}
+
+def disp(y:UOp) -> str:
+  if y.op is Ops.BLOCKSTART: return "w"+disp(y.src[0])
+  if y.op is Ops.IF: return f'IF{id(y)}'
+  if y.op is Ops.RANGE: return str(y.arg)
+  return "<NONE>"
+
+@dataclass(frozen=True)
+class BasicBlock:
+  ctx: tuple[UOp, ...]
+  lst: tuple[UOp, ...]
+  end: UOp|None = None
+  def __lt__(self, o:BasicBlock): return tuple(x.tuplize for x in self.ctx+self.lst) < tuple(x.tuplize for x in o.ctx+o.lst)
+  def __repr__(self):
+    return f"{(str(disp(self.end))+' ') if self.end is not None else ''}"+\
+      f"{[disp(y) for y in self.ctx]} {len(self.lst)}" + "\n" + '\n'.join([str(x.op) for x in self.lst])
+
+def append_to_block(ctx:tuple[dict[UOp, tuple[UOp, ...]], dict[UOp, list[UOp]]], x:UOp):
+  block_ctxs, children = ctx
+  in_this_block = set(x.arg.lst)
+
+  # collections to build
+  new_srcs: list[UOp] = []
+  to_append: list[UOp] = []
+  old_blocks: dict[tuple[UOp, ...], UOp] = {}
+  new_blocks: dict[tuple[UOp, ...], list[UOp]] = {}
+
+  for u in x.src:
+    if u.op is Ops.BLOCK:
+      # merge sibling blocks. NOTE: blocks must only have one output source
+      assert u.arg.ctx not in old_blocks, "sibling should never have been created"
+      old_blocks[u.arg.ctx] = u
+    elif u.op not in DONT_PLACE_IN_BLOCK and set(children[u]).issubset(in_this_block):
+      # if it can go in blocks and all its children are in the block, we add it to the block
+      if (block_ctx:=block_ctxs[u]) == x.arg.ctx:
+        # if it's the same context, we place the UOp in this block and append the parents to its srcs
+        new_srcs.extend(u.src)
+        to_append.append(u)
+      else:
+        # if it's a different context, we create a new block with this UOp
+        new_blocks.setdefault(block_ctx, []).append(u)
+    else:
+      # otherwise, we keep it in the srcs
+      new_srcs.append(u)
+  if len(to_append) == 0 and len(new_blocks) == 0: return None
+
+  for rng,lst in new_blocks.items():
+    srcs = flatten(y.src for y in lst)
+    if (old_block:=old_blocks.pop(rng, None)) is not None:
+      # NOTE: order shouldn't matter here
+      srcs.extend(old_block.src)
+      lst.extend(old_block.arg.lst)
+    new_block = UOp(Ops.BLOCK, dtypes.void, tuple(dedup(srcs)), BasicBlock(rng, tuple(lst)))
+    lrng = list(rng)
+    for r in rng[::-1]:
+      if r not in x.arg.ctx and r.op is not Ops.BLOCKSTART:
+        lrng.remove(r)
+        new_block = UOp(Ops.BLOCKEND, src=(new_block,),
+                        arg=BasicBlock(tuple(lrng), (UOp(Ops.ENDIF if r.op is Ops.IF else Ops.ENDRANGE, src=(r,)),), r))
+    new_srcs.append(new_block)
+  return UOp(Ops.BLOCK, dtypes.void, tuple(dedup(list(old_blocks.values())+new_srcs)), BasicBlock(x.arg.ctx, tuple(to_append)+x.arg.lst))
+
+make_basic_blocks = PatternMatcher([
+  (UPat(Ops.SINK, name="x"),
+   lambda x: UOp(Ops.BLOCK, src=x.src+((UOp(Ops.NAME, arg=x.arg.name),) if x.arg is not None else ()), arg=BasicBlock((), (x,)))),
+  (UPat(Ops.BLOCK, name="x"), append_to_block),
+])
+
+def block_merge(ctx, x:UOp):
+  # ctx is children here
+  if x.op is Ops.BLOCKEND:
+    # if it's a BLOCKEND, see if we are done with placement. if all the children of the range are in here
+    in_this_block = set(x.arg.lst)
+    if len([y for y in ctx[x.arg.end] if y not in in_this_block]) == 0:
+      # find the parent block that has the BLOCKSTART in the ctx
+      parent_blocks = [y for y in x.src if y.op is Ops.BLOCK and UOp(Ops.BLOCKSTART, src=(x.arg.end,)) in y.arg.ctx]
+      assert len(parent_blocks) <= 1, "should never have two parent blocks"
+      if len(parent_blocks) == 1:
+        parent_block = parent_blocks[0]
+        # range needs DEFINE_ACC to be before the range (never in DEFINE_ACC for if)
+        early_ops, late_ops = partition(x.arg.lst, lambda y: y.op is Ops.DEFINE_ACC and x.arg.end in y.src)
+        return UOp(Ops.BLOCK, dtypes.void, tuple(y for y in x.src if y is not parent_block)+parent_block.src,
+                   BasicBlock(tuple(y for y in x.arg.ctx if y is not x.arg.end), tuple(early_ops)+parent_block.arg.lst+tuple(late_ops)))
+
+  new_srcs: list[UOp] = []
+  to_append: list[UOp] = []
+  new_ctx = x.arg.ctx
+  placed = set()
+  for u in x.src:
+    if u.op is Ops.BLOCK and (tuple(u.arg.ctx) == tuple(x.arg.ctx) or (x.arg.end is not None and x.arg.end in u.arg.ctx)):
+      # NOTE: this can't appear in srcs twice or it would be a BLOCKFORK
+      new_ctx += tuple(y for y in u.arg.ctx if y not in x.arg.ctx)
+      new_srcs.extend(u.src)
+      to_append.extend(u.arg.lst)
+    elif u.op is Ops.BLOCKFORK and x.src.count(u) == u.arg: # block fork appears # of times in srcs
+      if u not in placed:
+        new_srcs.extend(u.src)
+        placed.add(u)
     else:
-      # prefer uops that are loop children
-      priority -= sum([(l.arg[0]+1) + 1000*l.arg[1] for l,ss in scope_children.items() if l.op is Ops.RANGE and u in ss])
-    if u.op is Ops.IF and len(u.src) == 1: priority += 10000000 # if penalty
-    return priority
-  priorities:Dict[UOp, int] = {u:get_priority(u) for u in children}
-
-  # prevent priority inversion
-  @functools.lru_cache(None)
-  def fix_priority(u:UOp, lowest_priority):
-    if u.op in {Ops.CAST, Ops.BITCAST, *GroupOp.ALU, Ops.VECTORIZE, Ops.GEP, Ops.SPECIAL, Ops.DEFINE_LOCAL, Ops.LOAD}:
-      priorities[u] = min(priorities[u], lowest_priority)
-      if u.op is Ops.LOAD: priorities[u] += 100 # load penalty (here)
-    for x in u.src: fix_priority(x, priorities[u])
-  fix_priority(sink, 0)
-
-  # NOTE: the compare should never make it all the way to u
-  queue:List[Tuple[int, Tuple, UOp]] = []
+      # keep it in srcs
+      new_srcs.append(u)
+  if len(to_append) == 0 and len(placed) == 0: return None
+  return UOp(x.op, dtypes.void, tuple(new_srcs), BasicBlock(tuple(sorted(new_ctx, key=lambda x: x.tuplize)), tuple(to_append)+x.arg.lst, x.arg.end))
+
+pm_block_merge = PatternMatcher([(UPat((Ops.BLOCKEND, Ops.BLOCK), name="x"), block_merge),])
+
+def block_finalize(block:UOp):
+  if len(block.src) == 0: return None
+  _uops = sorted(dedup(block.src), key=lambda x: x.tuplize)
+  assert all(len(x.src) == 0 and x.op not in {Ops.BLOCK, Ops.BLOCKSTART, Ops.BLOCKEND, Ops.BLOCKFORK} for x in _uops)
+  _uops += block.arg.lst
+  # strip the SINK
+  assert _uops[-1].op is Ops.SINK, "doesn't end with SINK"
+  return UOp(Ops.BLOCK, arg=BasicBlock((), tuple(_uops[:-1])))
+
+pm_block_finalize = PatternMatcher([(UPat(Ops.BLOCK, name="block"), block_finalize)])
+
+# NOTE: any toposort should be valid here, unlike last time this isn't required, it's just for speed
+def block_reorder(in_block:UOp):
+  in_this_block = set(in_block.arg.lst)
+  local_children: collections.defaultdict[UOp, list[UOp]] = collections.defaultdict(list)
+  in_degree: collections.defaultdict[UOp, int] = collections.defaultdict(int)
+  priorities:dict[UOp, int] = {}
+
+  # get local children and assign priorities
+  for u in reversed(in_block.arg.lst):
+    for s in u.src:
+      if s in in_this_block:
+        local_children[s].append(u)
+        in_degree[u] += 1
+    # put loads in the beginning of the block and prevent priority inversion
+    priorities[u] = min([-1000 if u.op is Ops.LOAD else 0] + [priorities[x] for x in local_children[u]])
+
+  # placement queue
+  queue:list[tuple[int, tuple, UOp]] = []
   def push(u:UOp): heapq.heappush(queue, (priorities[u], u.tuplize, u))

-  for u in children:
-    if in_degree[u] == 0: push(u)
+  # place the first ones that don't have deps
+  for u in in_block.arg.lst:
+    if u not in in_degree: push(u)

-  scope_end: Dict[UOp, UOp] = {}
-  _uops: List[UOp] = []
+  newlst = []
   while queue:
-    p,_,x = heapq.heappop(queue)
-    if DEBUG >= 7: print(f"{p:5d}", x.op, x.dtype, x.arg)
-    if x in scope_children: scope_end[x] = x
-    if x.op is Ops.DEFINE_ACC:
-      idx = min([_uops.index(l) for l in x.src if l.op is Ops.RANGE])
-      _uops.insert(idx, x)
-    else: _uops.append(x)
-    for u, ss in scope_children.items():
-      if x in ss:
-        ss.remove(x)
-        if len(ss) == 0: scope_end[u] = x
-    for u in children[x]:
+    _,_,x = heapq.heappop(queue)
+    newlst.append(x)
+    for u in local_children[x]:
       in_degree[u] -= 1
       if in_degree[u] == 0: push(u)

-  # end scopes in toposort order
-  for u, x in scope_end.items(): _uops.insert(_uops.index(x)+1, UOp(END_FOR_UOP[u.op][1], dtypes.void, (u,)))
+  assert len(newlst) == len(in_block.arg.lst), f"len mismatch {len(newlst)} != {len(in_block.arg.lst)}"
+  return in_block.replace(arg=BasicBlock(in_block.arg.ctx, tuple(newlst)))
+
+def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> list[UOp]:
+  assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
+
+  # get children and all block contexts
+  temp_block_ctxs: dict[UOp, list[UOp]] = {}
+  children: dict[UOp, list[UOp]] = {}
+  for u in sink.toposort:
+    this_block_ctx: list[UOp] = []
+    for s in u.src:
+      # save children
+      children.setdefault(s, []).append(u)
+      # compute block ctx
+      if s.op in {Ops.RANGE, Ops.IF}: this_block_ctx.append(s)
+      # don't flow (fully) through assign and store
+      elif s.op is Ops.STORE:
+        # ugh, deal with non-reduce locals. probably wrong
+        if isinstance(s.src[0].dtype, PtrDType) and s.src[0].dtype.local:
+          idx_context, store_context = temp_block_ctxs[s.src[0]], temp_block_ctxs[s]
+          this_block_ctx += [x for x in store_context if x not in idx_context and x.op is Ops.RANGE]
+      elif s.op is Ops.ASSIGN:
+        # flow though assign, but remove the ranges used in the assign
+        assert s.src[0].op is Ops.DEFINE_ACC
+        this_block_ctx += [x for x in temp_block_ctxs[s.src[1]] if x not in s.src[0].src[1:]]
+      else:
+        # flow though everything else
+        this_block_ctx += temp_block_ctxs[s]
+    temp_block_ctxs[u] = sorted(dedup(this_block_ctx), key=lambda x: x.tuplize)
+
+  # make final block_ctxs, add BLOCKSTART to block_ctxs for IF and RANGE
+  block_ctxs: dict[UOp, tuple[UOp, ...]] = {}
+  for u in sink.toposort:
+    block_ctxs[u] = ((UOp(Ops.BLOCKSTART, src=(u,)),) + tuple(temp_block_ctxs[u])) if u.op in {Ops.IF, Ops.RANGE} else tuple(temp_block_ctxs[u])
+
+  # TODO: there's probably a clever way to remove this while loop
+  while 1:
+    sink = graph_rewrite(sink, make_basic_blocks, ctx=(block_ctxs, children))
+
+    # add BLOCKFORK (slow!)
+    block_parent_count = collections.Counter(flatten([x.src for x in sink.toposort if x.op is Ops.BLOCK]))
+    non_block_parents = set(flatten([x.src for x in sink.toposort if x.op is not Ops.BLOCK]))
+    forks = {u:UOp(Ops.BLOCKFORK, src=(UOp(Ops.BLOCK, src=u.src, arg=BasicBlock(block_ctxs[u], (u,))),), arg=child_count)
+             for u,child_count in block_parent_count.items() if u.op not in DONT_PLACE_IN_BLOCK and child_count > 1 and u not in non_block_parents}
+
+    if not len(forks): break
+    sink = sink.substitute(forks)
+
+  # combine matching BLOCKENDS
+  blockends_to_arg: dict[UOp, list[UOp]] = {}
+  for be in sink.toposort:
+    if be.op is Ops.BLOCKEND: blockends_to_arg.setdefault(be.arg.end, []).append(be)
+  new_forks = {}
+  for k,v in blockends_to_arg.items():
+    # NOTE: if any BLOCKEND is the parent of any other with the same arg, this algo fails
+    if len(v) > 1:
+      out = UOp(Ops.BLOCKFORK, src=(UOp(Ops.BLOCKEND, src=tuple(flatten(x.src for x in v)),
+                                        arg=BasicBlock(tuple(dedup(flatten([y.arg.ctx for y in v]))), v[0].arg.lst, k)),), arg=len(v))
+      for u in v: new_forks[u] = out
+  sink = sink.substitute(new_forks)
+
+  # reorder ops in block for speed
+  sink = sink.substitute({u:newu for u in sink.toposort if u.op is Ops.BLOCK and (newu:=block_reorder(u)) is not u})
+
+  # final rewrite to merge all blocks into one
+  sink = graph_rewrite(sink, pm_block_merge, ctx=children)
+
+  # there should just be one block left, with a few parents with 0 srcs (now done in a rewriter)
+  sink = graph_rewrite(sink, pm_block_finalize)

   # sanity checks (NOTE: these can cause things to be skipped in BEAM)
-  if not skip_check: type_verify(_uops)
+  if not skip_check: type_verify(sink.arg.lst)

-  # strip the SINK
-  return _uops[:-1]
+  # return the list. TODO: refactor to return the UOp
+  return list(sink.arg.lst)
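Aside: the block_reorder pass added above is, at its core, a priority-driven topological sort. Ops whose dependencies are already placed wait in a heap keyed by priority, so LOADs (priority -1000) are emitted as early as the dependency order allows. A minimal standalone sketch of that idea follows; the node names, edges, and priority values are hypothetical stand-ins, not tinygrad's UOp graph.

import heapq
from collections import defaultdict

def reorder(nodes, edges, priority):
  # edges are (src, dst) pairs meaning dst depends on src
  children, in_degree = defaultdict(list), defaultdict(int)
  for src, dst in edges:
    children[src].append(dst)
    in_degree[dst] += 1
  # seed the heap with nodes that have no unplaced dependencies, lowest priority first
  heap = [(priority[n], n) for n in nodes if in_degree[n] == 0]
  heapq.heapify(heap)
  out = []
  while heap:
    _, n = heapq.heappop(heap)
    out.append(n)
    for c in children[n]:
      in_degree[c] -= 1
      if in_degree[c] == 0: heapq.heappush(heap, (priority[c], c))
  assert len(out) == len(nodes), "graph had a cycle"
  return out

nodes = ["load_a", "load_b", "mul", "add", "store"]
edges = [("load_a", "mul"), ("load_b", "mul"), ("mul", "add"), ("load_b", "add"), ("add", "store")]
priority = {n: (-1000 if n.startswith("load") else 0) for n in nodes}
print(reorder(nodes, edges, priority))  # ['load_a', 'load_b', 'mul', 'add', 'store']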
tinygrad/codegen/lowerer.py
@@ -1,54 +1,70 @@
 # the job of the lowerer is to do indexing
-from __future__ import annotations
-import functools, itertools, operator
+import functools, itertools, operator, math
 from dataclasses import dataclass
-from typing import List, Tuple, cast, Optional
-from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.shape.view import variable_to_uop
+from typing import cast
 from tinygrad.dtype import dtypes, PtrDType
-from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element
+from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element, sint_to_uop
 from tinygrad.renderer import Renderer
-from tinygrad.helpers import all_int, prod, partition, flatten
+from tinygrad.helpers import all_int, prod, partition, flatten, unwrap
+from tinygrad.codegen.expander import expand_rewrite

 # returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
-def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:
+def get_contraction(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) -> list[list[int]]|None:
   acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
   try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new]
   except ValueError: return None
   return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]

 # ***** indexing *****
-
-def _limit_dims(dims:Tuple[sint, ...], max_sizes:Tuple[int, ...]):
+def _group_dims(dims:tuple[sint, ...], max_sizes:tuple[int, ...]):
   # TODO: symbolic shape
   if not all_int(dims): return dims
   while len(dims) > len(max_sizes) or any(d > m for d,m in zip(dims, max_sizes)):
     for i,m in enumerate(max_sizes):
-      if dims[i] * dims[i+1] <= m:
+      if i < (len(dims)-1) and dims[i] * dims[i+1] <= m:
         dims = dims[:i] + (dims[i]*dims[i+1],) + dims[i+2:]
         break
-    else: raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
+    else: return None
   return dims

-def get_grouped_dims(prefix, dims:Tuple[sint, ...], max_sizes:Optional[Tuple[int, ...]], reverse=False) -> List[UOp]:
+def _split_dims(dims, max_sizes):
+  if all(d <= m for d,m in zip(dims, max_sizes)): return dims
+  _dims = list(dims) + [1]*(3-len(dims))
+  for i in range(len(_dims)):
+    while _dims[i] > max_sizes[i]:
+      div = next((d for d in range(2, math.ceil(math.sqrt(_dims[i])) + 1) if (_dims[i] % d) == 0), 1)
+      if div == 1: raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
+      _dims[i], _dims[(i+1)%len(_dims)] = _dims[i]//div, _dims[(i+1)%len(_dims)]*div
+  return tuple(_dims[:2] if _dims[2] == 1 else _dims[0] if _dims[1:3] == [1,1] else _dims)
+
+def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|None, reverse=False) -> list[UOp]:
   if reverse: dims = dims[::-1]
-  limited = _limit_dims(dims, max_sizes) if max_sizes is not None else dims
+  # try to group first: (a, b, c, d) -> (ab, c, d)
+  limited = (grouped if (grouped := _group_dims(dims, max_sizes)) else dims) if max_sizes is not None else dims
+  # check if grouping failed
+  if max_sizes is not None and len(limited) > len(max_sizes): raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
+  # try to split up dims: (a,) -> (b, c)
+  if limited == dims: limited = _split_dims(dims, max_sizes) if max_sizes is not None else dims
   ret = raw_idxs = [UOp(Ops.SPECIAL, dtypes.int, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)]
-  if limited != dims:
+  if len(limited) < len(dims):
     ret = []
-    # cast for mypy, get_contraction won't be None
-    for idx, contraction in zip(raw_idxs, cast(List[List[int]], get_contraction(dims, limited))):
-      if len(contraction) == 1: ret.append(idx)
-      else:
-        for c in contraction:
-          ret.append(idx % dims[c])
-          idx //= dims[c]
+    if (contraction:=get_contraction(dims, limited)) is None: raise AssertionError(f"get_contraction should not be None {dims=} {limited=}")
+    for idx, contraction_group in zip(raw_idxs, contraction):
+      for c in contraction_group[:-1]:
+        ret.append(idx % dims[c])
+        idx //= dims[c]
+      ret.append(idx)
+  elif len(limited) > len(dims):
+    a, b = len(limited), len(dims)
+    if a == 2 and b == 1: ret = [raw_idxs[0] * limited[1] + raw_idxs[1]]
+    if a == 3 and b == 1: ret = [raw_idxs[0] * (limited[1] * limited[2]) + raw_idxs[1] * limited[2] + raw_idxs[2]]
+    if a == 3 and b == 2: ret = [raw_idxs[0] * limited[1] + raw_idxs[1], raw_idxs[2]]
   return ret[::-1] if reverse else ret

 @dataclass
 class IndexContext:
-  idxs: List[UOp]
-  ridxs: List[UOp]
+  idxs: list[UOp]
+  ridxs: list[UOp]
   acc_num: int = 0

 def get_index(ast:UOp, opts:Renderer) -> IndexContext:
@@ -56,14 +72,11 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
   # NOTE: assumes the shape is <global dims> <local dims> <group_for_reduces> <reduces> <upcasts/unrolls>
   full_shape = ast.full_shape
   first_upcasted = len(full_shape)-ki.upcasted
-  first_output_st: ShapeTracker = ast.src[0].st_arg
   # if there's no reduce, this is first_upcasted. assumes reduces are at the end
-  first_reduce = min([first_upcasted]+flatten(x.axis_arg for x in ast.sparents if x.op is Ops.REDUCE_AXIS))
-  local_loads = [x for x in ast.parents if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL]
+  first_reduce = min([first_upcasted]+flatten(x.axis_arg for x in ast.toposort if x.op is Ops.REDUCE_AXIS))
+  local_loads = [x for x in ast.toposort if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL]
   # NOTE: sum up the reduced axes looking across all local loads, yields the number of grouped reduces
-  group_for_reduces = sum([any(j!=y for j in x) for x,y in zip(
-    [[l.st_arg.shape[i] for l in local_loads] for i in range(first_reduce,first_upcasted)],
-    first_output_st.shape[first_reduce:first_upcasted])]) if local_loads else 0
+  group_for_reduces = sum([any(l.st_arg.shape[i]!=ast.src[0].st_arg.shape[i] for l in local_loads) for i in range(first_reduce,first_upcasted)])
   global_dims = first_reduce-ki.local_dims

   if opts.has_local:
@@ -76,22 +89,21 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
              get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
   else:
     # all loops are RANGES
-    idxs = [UOp(Ops.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, False))
-      for i,g in enumerate(full_shape[:first_reduce])]
+    idxs = [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(g)), i) for i,g in enumerate(full_shape[:first_reduce])]

   # reduce loops
-  idxs += [UOp(Ops.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, True))
+  idxs += [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(g)), i)
     for i,g in enumerate(full_shape[first_reduce+group_for_reduces:first_upcasted], start=first_reduce+group_for_reduces)]

   # upcast loops
   for i,g in enumerate(full_shape[first_upcasted:], start=first_upcasted):
     assert isinstance(g, int), "needs to be int to upcast/unroll"
-    idxs.append(UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(g), tuple(range(g))),), ((i,g),)))
+    idxs.append(UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(g), tuple(range(g))),), ((i,g),)))

   # late indexes (group for reduce)
   ridxs = idxs[:]
   for a in range(first_reduce, first_reduce+group_for_reduces):
-    ridxs[a] = UOp(Ops.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(full_shape[a])), (1000+a, True))
+    ridxs[a] = UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(full_shape[a])), 1000+a)

   return IndexContext(idxs, ridxs)

@@ -100,7 +112,7 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
 def lower_reduce_axis(ctx: IndexContext, x: UOp):
   # NOTE: always using ridxs is fine here
   reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.axis_arg], lambda y: y.op is Ops.RANGE)
-  assert all(x.op is Ops.EXPAND for x in reduce_expand), f"not all EXPANDS in {reduce_expand} for {x.axis_arg}"
+  assert all(x.op is Ops.UNROLL for x in reduce_expand), f"not all UNROLLS in {reduce_expand} for {x.axis_arg}"
   alu_op: Ops = x.arg[0]
   ret = x.src[0]
   if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
@@ -114,12 +126,10 @@ def lower_reduce_axis(ctx: IndexContext, x: UOp):

 def lower_load_store(ctx: IndexContext, x: UOp):
   idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL else ctx.idxs)
-  # TODO: check has_valid in UPat, not here
-  has_valid = valid.op is not Ops.CONST or valid.arg is not True
   buf = x.src[0]
   if x.op is Ops.LOAD:
     barrier = (UOp(Ops.BARRIER, dtypes.void, (x.src[2],)),) if x.src[0].op is Ops.DEFINE_LOCAL else ()
-    return UOp(Ops.LOAD, x.dtype, (buf.index(idx, valid if has_valid else None),) + barrier)
+    return UOp(Ops.LOAD, x.dtype, (buf.index(idx, valid),) + barrier)
   # NOTE: only store the local reduceop in the threads that are actually doing the reduce
   if cast(PtrDType, x.src[0].dtype).local and x.src[2].op is Ops.ASSIGN:
     reduce_input = x.src[2].src[1].src[1] if x.src[2].src[1].src[1] is not x.src[2].src[0] else x.src[2].src[1].src[0]
@@ -130,14 +140,22 @@ def lower_load_store(ctx: IndexContext, x: UOp):
   if (not cast(PtrDType, x.src[0].dtype).local) or store_back:
     for oidx, ridx in zip(ctx.idxs, ctx.ridxs):
       if oidx is not ridx: valid = valid * oidx.eq(0)
-    has_valid = valid.op is not Ops.CONST or valid.arg is not True
-  return UOp(Ops.STORE, dtypes.void, (buf.index(idx, valid if has_valid else None), x.src[2]))
+  return UOp(Ops.STORE, dtypes.void, (buf.index(idx, valid), x.src[2]))
+
+def lower_const(x:UOp):
+  assert all(v.mask is None for v in unwrap(x.st).views), f"VIEW in CONST/DEFINE_VAR source must be unmasked, got {x.st}"
+  return x.replace(src=())

 pm_lowerer = PatternMatcher([
   (UPat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis),
+  (UPat((Ops.CONST, Ops.DEFINE_VAR), src=(UPat(Ops.VIEW),), name="x"), lower_const),
   (UPat(Ops.VALID, src=(UPat(Ops.VIEW),), name="x"), lambda ctx,x: x.st_arg.to_indexed_uops(ctx.idxs)[1]),
   # rewrite LOAD/STORE VIEW to LOAD/STORE with indexed
   (UPat((Ops.LOAD, Ops.STORE), src=(UPat(), UPat(Ops.VIEW)), allow_any_len=True, name="x"), lower_load_store),
+  (UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"), UPat.const(dtypes.bool, True))), lambda b, idx: b.index(idx)),
 ])

-def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp: return graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts))
+def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp:
+  sink = graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts))
+  # expand_rewrite turns this into a vectorized program
+  return expand_rewrite(sink)
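Aside: in the lowerer diff above, get_grouped_dims gains a _split_dims fallback: when a single launch dimension exceeds the device limit, it is repeatedly divided by its smallest integer factor and the factor is pushed onto the next dimension. A simplified standalone mirror of that logic follows, with hypothetical sizes; the real function also handles the degenerate trailing-1 cases shown in the diff.

import math

def split_dims(dims, max_sizes):
  # nothing to do if every dim already fits its limit
  if all(d <= m for d, m in zip(dims, max_sizes)): return tuple(dims)
  _dims = list(dims) + [1]*(3-len(dims))
  for i in range(len(_dims)):
    while _dims[i] > max_sizes[i]:
      # smallest divisor >= 2; falling back to 1 means the dim is prime and cannot be split
      div = next((d for d in range(2, math.ceil(math.sqrt(_dims[i])) + 1) if _dims[i] % d == 0), 1)
      if div == 1: raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
      _dims[i], _dims[(i+1) % len(_dims)] = _dims[i]//div, _dims[(i+1) % len(_dims)]*div
  return tuple(_dims[:2] if _dims[2] == 1 else _dims)

print(split_dims((2048,), (256, 256, 64)))  # (256, 8): 2048 is factored down until it fits the 256 limit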