tinygrad 0.10.0__py3-none-any.whl → 0.10.1__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
Files changed (72)
  1. tinygrad/codegen/kernel.py +114 -172
  2. tinygrad/codegen/linearize.py +211 -81
  3. tinygrad/codegen/lowerer.py +30 -35
  4. tinygrad/codegen/{uopgraph.py → rewriter.py} +69 -59
  5. tinygrad/codegen/transcendental.py +12 -13
  6. tinygrad/device.py +170 -47
  7. tinygrad/dtype.py +28 -26
  8. tinygrad/engine/jit.py +80 -63
  9. tinygrad/engine/memory.py +4 -5
  10. tinygrad/engine/multi.py +162 -0
  11. tinygrad/engine/realize.py +58 -107
  12. tinygrad/engine/schedule.py +381 -314
  13. tinygrad/engine/search.py +40 -44
  14. tinygrad/gradient.py +70 -0
  15. tinygrad/helpers.py +77 -58
  16. tinygrad/nn/__init__.py +30 -32
  17. tinygrad/nn/datasets.py +1 -2
  18. tinygrad/nn/optim.py +22 -26
  19. tinygrad/nn/state.py +89 -64
  20. tinygrad/ops.py +562 -446
  21. tinygrad/renderer/__init__.py +79 -36
  22. tinygrad/renderer/cstyle.py +70 -84
  23. tinygrad/renderer/llvmir.py +32 -20
  24. tinygrad/renderer/ptx.py +79 -99
  25. tinygrad/renderer/wgsl.py +87 -0
  26. tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  27. tinygrad/runtime/autogen/comgr.py +2 -0
  28. tinygrad/runtime/autogen/kfd.py +4 -3
  29. tinygrad/runtime/autogen/kgsl.py +1 -1
  30. tinygrad/runtime/autogen/libpciaccess.py +2023 -0
  31. tinygrad/runtime/autogen/llvm.py +11379 -0
  32. tinygrad/runtime/autogen/vfio.py +891 -0
  33. tinygrad/runtime/graph/cuda.py +8 -9
  34. tinygrad/runtime/graph/hcq.py +84 -79
  35. tinygrad/runtime/graph/metal.py +19 -21
  36. tinygrad/runtime/ops_amd.py +488 -327
  37. tinygrad/runtime/ops_clang.py +15 -28
  38. tinygrad/runtime/ops_cloud.py +34 -34
  39. tinygrad/runtime/ops_cuda.py +30 -27
  40. tinygrad/runtime/ops_disk.py +62 -63
  41. tinygrad/runtime/ops_dsp.py +129 -38
  42. tinygrad/runtime/ops_gpu.py +30 -30
  43. tinygrad/runtime/ops_hip.py +29 -31
  44. tinygrad/runtime/ops_llvm.py +45 -40
  45. tinygrad/runtime/ops_metal.py +93 -73
  46. tinygrad/runtime/ops_npy.py +2 -2
  47. tinygrad/runtime/ops_nv.py +232 -270
  48. tinygrad/runtime/ops_python.py +51 -46
  49. tinygrad/runtime/ops_qcom.py +129 -157
  50. tinygrad/runtime/ops_webgpu.py +63 -0
  51. tinygrad/runtime/support/allocator.py +94 -0
  52. tinygrad/runtime/support/am/__init__.py +0 -0
  53. tinygrad/runtime/support/am/amdev.py +384 -0
  54. tinygrad/runtime/support/am/ip.py +463 -0
  55. tinygrad/runtime/support/compiler_cuda.py +4 -2
  56. tinygrad/runtime/support/elf.py +26 -4
  57. tinygrad/runtime/support/hcq.py +254 -324
  58. tinygrad/runtime/support/llvm.py +32 -0
  59. tinygrad/shape/shapetracker.py +84 -53
  60. tinygrad/shape/view.py +103 -138
  61. tinygrad/spec.py +154 -0
  62. tinygrad/tensor.py +744 -496
  63. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/METADATA +32 -21
  64. tinygrad-0.10.1.dist-info/RECORD +86 -0
  65. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/WHEEL +1 -1
  66. tinygrad/engine/lazy.py +0 -228
  67. tinygrad/function.py +0 -212
  68. tinygrad/multi.py +0 -177
  69. tinygrad/runtime/graph/clang.py +0 -39
  70. tinygrad-0.10.0.dist-info/RECORD +0 -77
  71. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/LICENSE +0 -0
  72. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/top_level.txt +0 -0
@@ -1,92 +1,222 @@
-from typing import List, Set, Dict, Tuple
-import functools, heapq
-from tinygrad.ops import type_verify, END_FOR_UOP, UOp, Ops, GroupOp
-from tinygrad.dtype import dtypes
-from tinygrad.helpers import DEBUG
-
-def get_children_dfs(u:UOp, children:Dict[UOp, List[UOp]], srcs:Dict[UOp, Dict[UOp, None]], in_degree:Dict[UOp, int]):
-  if u in children: return srcs[u]
-  srcs[u] = {}
-  children[u] = []
-  for x in u.src:
-    srcs[u].update(get_children_dfs(x, children, srcs, in_degree))
-    if x.op is Ops.RANGE and x.arg[1]: srcs[u][x] = None
-    children[x].append(u)
-  in_degree[u] = len(u.src)
-  return srcs[u]
-
-def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
-  assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
-  # filter nodes that don't link to a sink
-  # BFS toposort
-  children: Dict[UOp, List[UOp]] = {}
-  range_srcs: Dict[UOp, Dict[UOp, None]] = {}
-  in_degree: Dict[UOp, int] = {}
-  get_children_dfs(sink, children, range_srcs, in_degree)
-
-  @functools.lru_cache(None)
-  def get_recursive_children(x:UOp, end:Ops, include_self=False) -> Set[UOp]:
-    if x.op is Ops.SINK: return set()
-    return set.union({x} if include_self else set(), *([get_recursive_children(u, end, True) for u in children[x] if x.op is not end]))
-
-  # scope children impact the toposort and END* insertion
-  scope_children = {p:get_recursive_children(p, END_FOR_UOP[p.op][0]) for p in reversed(in_degree) if p.op in END_FOR_UOP}
-  range_phi = {r:[p for p in scope_children[r] if p.op is Ops.ASSIGN] for r in scope_children if r.op is Ops.RANGE}
-
-  # assign priorities
-  def get_priority(u:UOp):
-    priority = 0
-    # prefer ranges that depend on the least number of independent ranges
-    if u.op is Ops.RANGE and u.arg[1]:
-      priority += u.arg[0]
-      for p in range_phi[u]:
-        priority += 10000*len([r for r in range_srcs[p] if not any(i in range_phi[u] for i in range_phi[r])])
-    elif u.op is Ops.CONST:
-      # place consts first here, they don't do anything and it can cause issues with DEFINE_ACC
-      priority -= 100000000000
+from __future__ import annotations
+import collections, heapq
+from dataclasses import dataclass
+from tinygrad.ops import UOp, Ops, PatternMatcher, UPat, graph_rewrite, GroupOp
+from tinygrad.spec import type_verify
+from tinygrad.dtype import dtypes, PtrDType
+from tinygrad.helpers import dedup, flatten, partition
+
+DONT_PLACE_IN_BLOCK = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST, *GroupOp.Block}
+
+def disp(y:UOp) -> str:
+  if y.op is Ops.BLOCKSTART: return "w"+disp(y.src[0])
+  if y.op is Ops.IF: return f'IF{id(y)}'
+  if y.op is Ops.RANGE: return str(y.arg)
+  return "<NONE>"
+
+@dataclass(frozen=True)
+class BasicBlock:
+  ctx: tuple[UOp, ...]
+  lst: tuple[UOp, ...]
+  end: UOp|None = None
+  def __lt__(self, o:BasicBlock): return tuple(x.tuplize for x in self.ctx+self.lst) < tuple(x.tuplize for x in o.ctx+o.lst)
+  def __repr__(self):
+    return f"{(str(disp(self.end))+' ') if self.end is not None else ''}"+\
+           f"{[disp(y) for y in self.ctx]} {len(self.lst)}" + "\n" + '\n'.join([str(x.op) for x in self.lst])
+
+def append_to_block(ctx:tuple[dict[UOp, tuple[UOp, ...]], dict[UOp, list[UOp]]], x:UOp):
+  block_ctxs, children = ctx
+  in_this_block = set(x.arg.lst)
+
+  # collections to build
+  new_srcs: list[UOp] = []
+  to_append: list[UOp] = []
+  old_blocks: dict[tuple[UOp, ...], UOp] = {}
+  new_blocks: dict[tuple[UOp, ...], list[UOp]] = {}
+
+  for u in x.src:
+    if u.op is Ops.BLOCK:
+      # merge sibling blocks. NOTE: blocks must only have one output source
+      assert u.arg.ctx not in old_blocks, "sibling should never have been created"
+      old_blocks[u.arg.ctx] = u
+    elif u.op not in DONT_PLACE_IN_BLOCK and set(children[u]).issubset(in_this_block):
+      # if it can go in blocks and all its children are in the block, we add it to the block
+      if (block_ctx:=block_ctxs[u]) == x.arg.ctx:
+        # if it's the same context, we place the UOp in this block and append the parents to its srcs
+        new_srcs.extend(u.src)
+        to_append.append(u)
+      else:
+        # if it's a different context, we create a new block with this UOp
+        new_blocks.setdefault(block_ctx, []).append(u)
     else:
-      # prefer uops that are loop children
-      priority -= sum([(l.arg[0]+1) + 1000*l.arg[1] for l,ss in scope_children.items() if l.op is Ops.RANGE and u in ss])
-    if u.op is Ops.IF and len(u.src) == 1: priority += 10000000 # if penalty
-    return priority
-  priorities:Dict[UOp, int] = {u:get_priority(u) for u in children}
-
-  # prevent priority inversion
-  @functools.lru_cache(None)
-  def fix_priority(u:UOp, lowest_priority):
-    if u.op in {Ops.CAST, Ops.BITCAST, *GroupOp.ALU, Ops.VECTORIZE, Ops.GEP, Ops.SPECIAL, Ops.DEFINE_LOCAL, Ops.LOAD}:
-      priorities[u] = min(priorities[u], lowest_priority)
-      if u.op is Ops.LOAD: priorities[u] += 100 # load penalty (here)
-    for x in u.src: fix_priority(x, priorities[u])
-  fix_priority(sink, 0)
-
-  # NOTE: the compare should never make it all the way to u
-  queue:List[Tuple[int, Tuple, UOp]] = []
+      # otherwise, we keep it in the srcs
+      new_srcs.append(u)
+  if len(to_append) == 0 and len(new_blocks) == 0: return None
+
+  for rng,lst in new_blocks.items():
+    srcs = flatten(y.src for y in lst)
+    if (old_block:=old_blocks.pop(rng, None)) is not None:
+      # NOTE: order shouldn't matter here
+      srcs.extend(old_block.src)
+      lst.extend(old_block.arg.lst)
+    new_block = UOp(Ops.BLOCK, dtypes.void, tuple(dedup(srcs)), BasicBlock(rng, tuple(lst)))
+    lrng = list(rng)
+    for r in rng[::-1]:
+      if r not in x.arg.ctx and r.op is not Ops.BLOCKSTART:
+        lrng.remove(r)
+        new_block = UOp(Ops.BLOCKEND, src=(new_block,),
+                        arg=BasicBlock(tuple(lrng), (UOp(Ops.ENDIF if r.op is Ops.IF else Ops.ENDRANGE, src=(r,)),), r))
+    new_srcs.append(new_block)
+  return UOp(Ops.BLOCK, dtypes.void, tuple(dedup(list(old_blocks.values())+new_srcs)), BasicBlock(x.arg.ctx, tuple(to_append)+x.arg.lst))
+
+make_basic_blocks = PatternMatcher([
+  (UPat(Ops.SINK, name="x"), lambda x: UOp(Ops.BLOCK, src=x.src, arg=BasicBlock((), (x,)))),
+  (UPat(Ops.BLOCK, name="x"), append_to_block),
+])
+
+def block_merge(ctx, x:UOp):
+  # ctx is children here
+  if x.op is Ops.BLOCKEND:
+    # if it's a BLOCKEND, see if we are done with placement. if all the children of the range are in here
+    in_this_block = set(x.arg.lst)
+    if len([y for y in ctx[x.arg.end] if y not in in_this_block]) == 0:
+      # find the parent block that has the BLOCKSTART in the ctx
+      parent_blocks = [y for y in x.src if y.op is Ops.BLOCK and UOp(Ops.BLOCKSTART, src=(x.arg.end,)) in y.arg.ctx]
+      assert len(parent_blocks) <= 1, "should never have two parent blocks"
+      if len(parent_blocks) == 1:
+        parent_block = parent_blocks[0]
+        # range needs DEFINE_ACC to be before the range (never in DEFINE_ACC for if)
+        early_ops, late_ops = partition(x.arg.lst, lambda y: y.op is Ops.DEFINE_ACC and x.arg.end in y.src)
+        return UOp(Ops.BLOCK, dtypes.void, tuple(y for y in x.src if y is not parent_block)+parent_block.src,
+                   BasicBlock(tuple(y for y in x.arg.ctx if y is not x.arg.end), tuple(early_ops)+parent_block.arg.lst+tuple(late_ops)))
+
+  new_srcs: list[UOp] = []
+  to_append: list[UOp] = []
+  new_ctx = x.arg.ctx
+  placed = set()
+  for u in x.src:
+    if u.op is Ops.BLOCK and (tuple(u.arg.ctx) == tuple(x.arg.ctx) or (x.arg.end is not None and x.arg.end in u.arg.ctx)):
+      # NOTE: this can't appear in srcs twice or it would be a BLOCKFORK
+      new_ctx += tuple(y for y in u.arg.ctx if y not in x.arg.ctx)
+      new_srcs.extend(u.src)
+      to_append.extend(u.arg.lst)
+    elif u.op is Ops.BLOCKFORK and x.src.count(u) == u.arg: # block fork appears # of times in srcs
+      if u not in placed:
+        new_srcs.extend(u.src)
+        placed.add(u)
+    else:
+      # keep it in srcs
+      new_srcs.append(u)
+  if len(to_append) == 0 and len(placed) == 0: return None
+  return UOp(x.op, dtypes.void, tuple(new_srcs), BasicBlock(tuple(sorted(new_ctx, key=lambda x: x.tuplize)), tuple(to_append)+x.arg.lst, x.arg.end))
+
+pm_block_merge = PatternMatcher([(UPat((Ops.BLOCKEND, Ops.BLOCK), name="x"), block_merge),])
+
+# NOTE: any toposort should be valid here, unlike last time this isn't required, it's just for speed
+def block_reorder(in_block:UOp):
+  in_this_block = set(in_block.arg.lst)
+  local_children: collections.defaultdict[UOp, list[UOp]] = collections.defaultdict(list)
+  in_degree: collections.defaultdict[UOp, int] = collections.defaultdict(int)
+  priorities:dict[UOp, int] = {}
+
+  # get local children and assign priorities
+  for u in reversed(in_block.arg.lst):
+    for s in u.src:
+      if s in in_this_block:
+        local_children[s].append(u)
+        in_degree[u] += 1
+    # put loads in the beginning of the block and prevent priority inversion
+    priorities[u] = min([-1000 if u.op is Ops.LOAD else 0] + [priorities[x] for x in local_children[u]])
+
+  # placement queue
+  queue:list[tuple[int, tuple, UOp]] = []
   def push(u:UOp): heapq.heappush(queue, (priorities[u], u.tuplize, u))
 
-  for u in children:
-    if in_degree[u] == 0: push(u)
+  # place the first ones that don't have deps
+  for u in in_block.arg.lst:
+    if u not in in_degree: push(u)
 
-  scope_end: Dict[UOp, UOp] = {}
-  _uops: List[UOp] = []
+  newlst = []
   while queue:
-    p,_,x = heapq.heappop(queue)
-    if DEBUG >= 7: print(f"{p:5d}", x.op, x.dtype, x.arg)
-    if x in scope_children: scope_end[x] = x
-    if x.op is Ops.DEFINE_ACC:
-      idx = min([_uops.index(l) for l in x.src if l.op is Ops.RANGE])
-      _uops.insert(idx, x)
-    else: _uops.append(x)
-    for u, ss in scope_children.items():
-      if x in ss:
-        ss.remove(x)
-        if len(ss) == 0: scope_end[u] = x
-    for u in children[x]:
+    _,_,x = heapq.heappop(queue)
+    newlst.append(x)
+    for u in local_children[x]:
       in_degree[u] -= 1
       if in_degree[u] == 0: push(u)
 
-  # end scopes in toposort order
-  for u, x in scope_end.items(): _uops.insert(_uops.index(x)+1, UOp(END_FOR_UOP[u.op][1], dtypes.void, (u,)))
+  assert len(newlst) == len(in_block.arg.lst), f"len mismatch {len(newlst)} != {len(in_block.arg.lst)}"
+  return in_block.replace(arg=BasicBlock(in_block.arg.ctx, tuple(newlst)))
+
+def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> list[UOp]:
+  assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
+
+  # get children and all block contexts
+  temp_block_ctxs: dict[UOp, list[UOp]] = {}
+  children: dict[UOp, list[UOp]] = {}
+  for u in sink.toposort:
+    this_block_ctx: list[UOp] = []
+    for s in u.src:
+      # save children
+      children.setdefault(s, []).append(u)
+      # compute block ctx
+      if s.op in {Ops.RANGE, Ops.IF}: this_block_ctx.append(s)
+      # don't flow (fully) through assign and store
+      elif s.op is Ops.STORE:
+        # ugh, deal with non-reduce locals. probably wrong
+        if isinstance(s.src[0].dtype, PtrDType) and s.src[0].dtype.local:
+          idx_context, store_context = temp_block_ctxs[s.src[0]], temp_block_ctxs[s]
+          this_block_ctx += [x for x in store_context if x not in idx_context and x.op is Ops.RANGE]
+      elif s.op is Ops.ASSIGN:
+        # flow though assign, but remove the ranges used in the assign
+        assert s.src[0].op is Ops.DEFINE_ACC
+        this_block_ctx += [x for x in temp_block_ctxs[s.src[1]] if x not in s.src[0].src[1:]]
+      else:
+        # flow though everything else
+        this_block_ctx += temp_block_ctxs[s]
+    temp_block_ctxs[u] = sorted(dedup(this_block_ctx), key=lambda x: x.tuplize)
+
+  # make final block_ctxs, add BLOCKSTART to block_ctxs for IF and RANGE
+  block_ctxs: dict[UOp, tuple[UOp, ...]] = {}
+  for u in sink.toposort:
+    block_ctxs[u] = ((UOp(Ops.BLOCKSTART, src=(u,)),) + tuple(temp_block_ctxs[u])) if u.op in {Ops.IF, Ops.RANGE} else tuple(temp_block_ctxs[u])
+
+  # TODO: there's probably a clever way to remove this while loop
+  while 1:
+    sink = graph_rewrite(sink, make_basic_blocks, ctx=(block_ctxs, children))
+
+    # add BLOCKFORK (slow!)
+    block_parent_count = collections.Counter(flatten([x.src for x in sink.toposort if x.op is Ops.BLOCK]))
+    non_block_parents = set(flatten([x.src for x in sink.toposort if x.op is not Ops.BLOCK]))
+    forks = {u:UOp(Ops.BLOCKFORK, src=(UOp(Ops.BLOCK, src=u.src, arg=BasicBlock(block_ctxs[u], (u,))),), arg=child_count)
+      for u,child_count in block_parent_count.items() if u.op not in DONT_PLACE_IN_BLOCK and child_count > 1 and u not in non_block_parents}
+
+    if not len(forks): break
+    sink = sink.substitute(forks)
+
+  # combine matching BLOCKENDS
+  blockends_to_arg: dict[UOp, list[UOp]] = {}
+  for be in sink.toposort:
+    if be.op is Ops.BLOCKEND: blockends_to_arg.setdefault(be.arg.end, []).append(be)
+  new_forks = {}
+  for k,v in blockends_to_arg.items():
+    # NOTE: if any BLOCKEND is the parent of any other with the same arg, this algo fails
+    if len(v) > 1:
+      out = UOp(Ops.BLOCKFORK, src=(UOp(Ops.BLOCKEND, src=tuple(flatten(x.src for x in v)),
+                                        arg=BasicBlock(tuple(dedup(flatten([y.arg.ctx for y in v]))), v[0].arg.lst, k)),), arg=len(v))
+      for u in v: new_forks[u] = out
+  sink = sink.substitute(new_forks)
+
+  # reorder ops in block for speed
+  sink = sink.substitute({u:newu for u in sink.toposort if u.op is Ops.BLOCK and (newu:=block_reorder(u)) is not u})
+
+  # final rewrite to merge all blocks into one
+  sink = graph_rewrite(sink, pm_block_merge, ctx=children)
+
+  # there should just be one block left, with a few parents with 0 srcs
+  assert sink.op is Ops.BLOCK
+  _uops = sorted(dedup(sink.src), key=lambda x: x.tuplize)
+  assert all(len(x.src) == 0 and x.op not in {Ops.BLOCK, Ops.BLOCKSTART, Ops.BLOCKEND, Ops.BLOCKFORK} for x in _uops)
+  _uops += sink.arg.lst
 
   # sanity checks (NOTE: these can cause things to be skipped in BEAM)
   if not skip_check: type_verify(_uops)
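
Note: the new block_reorder above replaces the old global get_priority/fix_priority scheme with a per-block heap-driven topological sort: loads get a negative priority, and each node inherits the minimum priority of its children so a cheap parent cannot delay an expensive subtree. A minimal standalone sketch of that idea, using plain strings instead of UOps and the same illustrative -1000 load bias (reorder, edges and is_load are made-up names, not tinygrad APIs):

import heapq
from collections import defaultdict

def reorder(nodes, edges, is_load):
  # nodes is given in a valid topological order; edges are (src, dst) dependency pairs
  children, in_degree = defaultdict(list), defaultdict(int)
  for s, d in edges:
    children[s].append(d)
    in_degree[d] += 1
  # priority = min over the node itself and its children, loads biased to the front
  priority = {}
  for n in reversed(nodes):
    priority[n] = min([-1000 if is_load(n) else 0] + [priority[c] for c in children[n]])
  queue = [(priority[n], n) for n in nodes if in_degree[n] == 0]
  heapq.heapify(queue)
  out = []
  while queue:
    _, n = heapq.heappop(queue)
    out.append(n)
    for c in children[n]:
      in_degree[c] -= 1
      if in_degree[c] == 0: heapq.heappush(queue, (priority[c], c))
  assert len(out) == len(nodes)
  return out

# toy graph: a const "a" and a load "ld" both feed "add", which feeds "store"
print(reorder(["a", "ld", "add", "store"], [("a", "add"), ("ld", "add"), ("add", "store")], lambda n: n == "ld"))

Running it prints ['ld', 'a', 'add', 'store']: the load is pulled ahead of the independent const while the dependency order is preserved.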
@@ -1,17 +1,14 @@
 # the job of the lowerer is to do indexing
-from __future__ import annotations
 import functools, itertools, operator
 from dataclasses import dataclass
-from typing import List, Tuple, cast, Optional
-from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.shape.view import variable_to_uop
+from typing import cast
 from tinygrad.dtype import dtypes, PtrDType
-from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element
+from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element, sint_to_uop
 from tinygrad.renderer import Renderer
-from tinygrad.helpers import all_int, prod, partition, flatten
+from tinygrad.helpers import all_int, prod, partition, flatten, unwrap
 
 # returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
-def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:
+def get_contraction(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) -> list[list[int]]|None:
   acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
   try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new]
   except ValueError: return None
@@ -19,7 +16,7 @@ def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> O
 
 # ***** indexing *****
 
-def _limit_dims(dims:Tuple[sint, ...], max_sizes:Tuple[int, ...]):
+def _limit_dims(dims:tuple[sint, ...], max_sizes:tuple[int, ...]):
   # TODO: symbolic shape
   if not all_int(dims): return dims
   while len(dims) > len(max_sizes) or any(d > m for d,m in zip(dims, max_sizes)):
@@ -30,25 +27,24 @@ def _limit_dims(dims:Tuple[sint, ...], max_sizes:Tuple[int, ...]):
     else: raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
   return dims
 
-def get_grouped_dims(prefix, dims:Tuple[sint, ...], max_sizes:Optional[Tuple[int, ...]], reverse=False) -> List[UOp]:
+def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|None, reverse=False) -> list[UOp]:
   if reverse: dims = dims[::-1]
   limited = _limit_dims(dims, max_sizes) if max_sizes is not None else dims
   ret = raw_idxs = [UOp(Ops.SPECIAL, dtypes.int, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)]
   if limited != dims:
     ret = []
-    # cast for mypy, get_contraction won't be None
-    for idx, contraction in zip(raw_idxs, cast(List[List[int]], get_contraction(dims, limited))):
-      if len(contraction) == 1: ret.append(idx)
-      else:
-        for c in contraction:
-          ret.append(idx % dims[c])
-          idx //= dims[c]
+    if (contraction:=get_contraction(dims, limited)) is None: raise AssertionError(f"get_contraction should not be None {dims=} {limited=}")
+    for idx, contraction_group in zip(raw_idxs, contraction):
+      for c in contraction_group[:-1]:
+        ret.append(idx % dims[c])
+        idx //= dims[c]
+      ret.append(idx)
   return ret[::-1] if reverse else ret
 
 @dataclass
 class IndexContext:
-  idxs: List[UOp]
-  ridxs: List[UOp]
+  idxs: list[UOp]
+  ridxs: list[UOp]
   acc_num: int = 0
 
 def get_index(ast:UOp, opts:Renderer) -> IndexContext:
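
Note: in the rewritten get_grouped_dims above, a merged index is split back into the collapsed dims by innermost-first mod/floordiv peeling, with the final quotient kept as the last component. A small self-contained check of that arithmetic using plain ints in place of UOps (split_index is an illustrative helper, not a tinygrad function):

def split_index(flat, sizes):
  # peel off each leading dim with mod/floordiv, keep the final quotient (mirrors the loop above)
  comps = []
  for s in sizes[:-1]:
    comps.append(flat % s)
    flat //= s
  comps.append(flat)
  return comps

# dims (4, 5, 6) collapsed into a single index of size 120: re-flattening the
# recovered components must reproduce the original index
for flat in range(4*5*6):
  a, b, c = split_index(flat, (4, 5, 6))
  assert flat == a + 4*(b + 5*c)
print("ok")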
@@ -56,14 +52,11 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
   # NOTE: assumes the shape is <global dims> <local dims> <group_for_reduces> <reduces> <upcasts/unrolls>
   full_shape = ast.full_shape
   first_upcasted = len(full_shape)-ki.upcasted
-  first_output_st: ShapeTracker = ast.src[0].st_arg
   # if there's no reduce, this is first_upcasted. assumes reduces are at the end
-  first_reduce = min([first_upcasted]+flatten(x.axis_arg for x in ast.sparents if x.op is Ops.REDUCE_AXIS))
-  local_loads = [x for x in ast.parents if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL]
+  first_reduce = min([first_upcasted]+flatten(x.axis_arg for x in ast.toposort if x.op is Ops.REDUCE_AXIS))
+  local_loads = [x for x in ast.toposort if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL]
   # NOTE: sum up the reduced axes looking across all local loads, yields the number of grouped reduces
-  group_for_reduces = sum([any(j!=y for j in x) for x,y in zip(
-    [[l.st_arg.shape[i] for l in local_loads] for i in range(first_reduce,first_upcasted)],
-    first_output_st.shape[first_reduce:first_upcasted])]) if local_loads else 0
+  group_for_reduces = sum([any(l.st_arg.shape[i]!=ast.src[0].st_arg.shape[i] for l in local_loads) for i in range(first_reduce,first_upcasted)])
   global_dims = first_reduce-ki.local_dims
 
   if opts.has_local:
@@ -76,22 +69,21 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
              get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
   else:
     # all loops are RANGES
-    idxs = [UOp(Ops.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, False))
-      for i,g in enumerate(full_shape[:first_reduce])]
+    idxs = [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(g)), i) for i,g in enumerate(full_shape[:first_reduce])]
 
   # reduce loops
-  idxs += [UOp(Ops.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, True))
+  idxs += [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(g)), i)
     for i,g in enumerate(full_shape[first_reduce+group_for_reduces:first_upcasted], start=first_reduce+group_for_reduces)]
 
   # upcast loops
   for i,g in enumerate(full_shape[first_upcasted:], start=first_upcasted):
     assert isinstance(g, int), "needs to be int to upcast/unroll"
-    idxs.append(UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(g), tuple(range(g))),), ((i,g),)))
+    idxs.append(UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(g), tuple(range(g))),), ((i,g),)))
 
   # late indexes (group for reduce)
   ridxs = idxs[:]
   for a in range(first_reduce, first_reduce+group_for_reduces):
-    ridxs[a] = UOp(Ops.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(full_shape[a])), (1000+a, True))
+    ridxs[a] = UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(full_shape[a])), 1000+a)
 
   return IndexContext(idxs, ridxs)
 
@@ -100,7 +92,7 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
 def lower_reduce_axis(ctx: IndexContext, x: UOp):
   # NOTE: always using ridxs is fine here
   reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.axis_arg], lambda y: y.op is Ops.RANGE)
-  assert all(x.op is Ops.EXPAND for x in reduce_expand), f"not all EXPANDS in {reduce_expand} for {x.axis_arg}"
+  assert all(x.op is Ops.UNROLL for x in reduce_expand), f"not all UNROLLS in {reduce_expand} for {x.axis_arg}"
   alu_op: Ops = x.arg[0]
   ret = x.src[0]
   if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
@@ -114,12 +106,10 @@ def lower_reduce_axis(ctx: IndexContext, x: UOp):
 
 def lower_load_store(ctx: IndexContext, x: UOp):
   idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL else ctx.idxs)
-  # TODO: check has_valid in UPat, not here
-  has_valid = valid.op is not Ops.CONST or valid.arg is not True
   buf = x.src[0]
   if x.op is Ops.LOAD:
     barrier = (UOp(Ops.BARRIER, dtypes.void, (x.src[2],)),) if x.src[0].op is Ops.DEFINE_LOCAL else ()
-    return UOp(Ops.LOAD, x.dtype, (buf.index(idx, valid if has_valid else None),) + barrier)
+    return UOp(Ops.LOAD, x.dtype, (buf.index(idx, valid),) + barrier)
   # NOTE: only store the local reduceop in the threads that are actually doing the reduce
   if cast(PtrDType, x.src[0].dtype).local and x.src[2].op is Ops.ASSIGN:
     reduce_input = x.src[2].src[1].src[1] if x.src[2].src[1].src[1] is not x.src[2].src[0] else x.src[2].src[1].src[0]
@@ -130,14 +120,19 @@ def lower_load_store(ctx: IndexContext, x: UOp):
   if (not cast(PtrDType, x.src[0].dtype).local) or store_back:
     for oidx, ridx in zip(ctx.idxs, ctx.ridxs):
       if oidx is not ridx: valid = valid * oidx.eq(0)
-    has_valid = valid.op is not Ops.CONST or valid.arg is not True
-  return UOp(Ops.STORE, dtypes.void, (buf.index(idx, valid if has_valid else None), x.src[2]))
+  return UOp(Ops.STORE, dtypes.void, (buf.index(idx, valid), x.src[2]))
+
+def lower_const(x:UOp):
+  assert all(v.mask is None for v in unwrap(x.st).views), f"VIEW in CONST/DEFINE_VAR source must be unmasked, got {x.st}"
+  return x.replace(src=())
 
 pm_lowerer = PatternMatcher([
   (UPat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis),
+  (UPat((Ops.CONST, Ops.DEFINE_VAR), src=(UPat(Ops.VIEW),), name="x"), lower_const),
   (UPat(Ops.VALID, src=(UPat(Ops.VIEW),), name="x"), lambda ctx,x: x.st_arg.to_indexed_uops(ctx.idxs)[1]),
   # rewrite LOAD/STORE VIEW to LOAD/STORE with indexed
   (UPat((Ops.LOAD, Ops.STORE), src=(UPat(), UPat(Ops.VIEW)), allow_any_len=True, name="x"), lower_load_store),
+  (UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"), UPat.const(dtypes.bool, True))), lambda b, idx: b.index(idx)),
 ])
 
 def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp: return graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts))
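
Note: the last hunk drops the has_valid special case; lower_load_store now always passes valid to buf.index, and the new Ops.INDEX pattern strips the gate in one place when it is the constant True. A tiny sketch of that "always gate, simplify later" idea with plain tuples standing in for UOps (nothing here is tinygrad API):

def strip_true_gate(node):
  # node: ("INDEX", buf, idx, gate); drop the gate only when it is literally const True
  if node[0] == "INDEX" and len(node) == 4 and node[3] == ("CONST", True):
    return ("INDEX", node[1], node[2])
  return node

assert strip_true_gate(("INDEX", "gbuf", "i", ("CONST", True))) == ("INDEX", "gbuf", "i")
assert strip_true_gate(("INDEX", "gbuf", "i", ("CONST", False)))[-1] == ("CONST", False)
print("ok")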