tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131) hide show
  1. tinygrad/__init__.py +1 -1
  2. tinygrad/apps/llm.py +206 -0
  3. tinygrad/codegen/__init__.py +116 -0
  4. tinygrad/codegen/devectorizer.py +315 -172
  5. tinygrad/codegen/expander.py +8 -16
  6. tinygrad/codegen/gpudims.py +89 -0
  7. tinygrad/codegen/linearize.py +205 -203
  8. tinygrad/codegen/lowerer.py +92 -139
  9. tinygrad/codegen/opt/__init__.py +38 -0
  10. tinygrad/codegen/opt/heuristic.py +125 -0
  11. tinygrad/codegen/opt/kernel.py +510 -0
  12. tinygrad/{engine → codegen/opt}/search.py +51 -35
  13. tinygrad/codegen/opt/swizzler.py +134 -0
  14. tinygrad/codegen/opt/tc.py +127 -0
  15. tinygrad/codegen/quantize.py +67 -0
  16. tinygrad/device.py +122 -132
  17. tinygrad/dtype.py +152 -35
  18. tinygrad/engine/jit.py +81 -54
  19. tinygrad/engine/memory.py +46 -27
  20. tinygrad/engine/realize.py +82 -41
  21. tinygrad/engine/schedule.py +70 -445
  22. tinygrad/frontend/__init__.py +0 -0
  23. tinygrad/frontend/onnx.py +1253 -0
  24. tinygrad/frontend/torch.py +5 -0
  25. tinygrad/gradient.py +19 -27
  26. tinygrad/helpers.py +95 -47
  27. tinygrad/nn/__init__.py +7 -8
  28. tinygrad/nn/optim.py +72 -41
  29. tinygrad/nn/state.py +37 -23
  30. tinygrad/renderer/__init__.py +40 -60
  31. tinygrad/renderer/cstyle.py +143 -128
  32. tinygrad/renderer/llvmir.py +113 -62
  33. tinygrad/renderer/ptx.py +50 -32
  34. tinygrad/renderer/wgsl.py +27 -23
  35. tinygrad/runtime/autogen/am/am.py +5861 -0
  36. tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
  37. tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
  38. tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
  39. tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
  40. tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
  41. tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
  42. tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
  43. tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
  44. tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
  45. tinygrad/runtime/autogen/comgr.py +35 -9
  46. tinygrad/runtime/autogen/comgr_3.py +906 -0
  47. tinygrad/runtime/autogen/cuda.py +2419 -494
  48. tinygrad/runtime/autogen/hsa.py +57 -16
  49. tinygrad/runtime/autogen/ib.py +7171 -0
  50. tinygrad/runtime/autogen/io_uring.py +917 -118
  51. tinygrad/runtime/autogen/kfd.py +748 -26
  52. tinygrad/runtime/autogen/libc.py +613 -218
  53. tinygrad/runtime/autogen/libusb.py +1643 -0
  54. tinygrad/runtime/autogen/nv/nv.py +8602 -0
  55. tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
  56. tinygrad/runtime/autogen/opencl.py +2 -4
  57. tinygrad/runtime/autogen/sqtt.py +1789 -0
  58. tinygrad/runtime/autogen/vfio.py +3 -3
  59. tinygrad/runtime/autogen/webgpu.py +273 -264
  60. tinygrad/runtime/graph/cuda.py +3 -3
  61. tinygrad/runtime/graph/hcq.py +68 -29
  62. tinygrad/runtime/graph/metal.py +29 -13
  63. tinygrad/runtime/graph/remote.py +114 -0
  64. tinygrad/runtime/ops_amd.py +537 -320
  65. tinygrad/runtime/ops_cpu.py +108 -7
  66. tinygrad/runtime/ops_cuda.py +12 -14
  67. tinygrad/runtime/ops_disk.py +13 -10
  68. tinygrad/runtime/ops_dsp.py +47 -40
  69. tinygrad/runtime/ops_gpu.py +13 -11
  70. tinygrad/runtime/ops_hip.py +6 -9
  71. tinygrad/runtime/ops_llvm.py +35 -15
  72. tinygrad/runtime/ops_metal.py +29 -19
  73. tinygrad/runtime/ops_npy.py +5 -3
  74. tinygrad/runtime/ops_null.py +28 -0
  75. tinygrad/runtime/ops_nv.py +306 -234
  76. tinygrad/runtime/ops_python.py +62 -52
  77. tinygrad/runtime/ops_qcom.py +28 -39
  78. tinygrad/runtime/ops_remote.py +482 -0
  79. tinygrad/runtime/ops_webgpu.py +28 -28
  80. tinygrad/runtime/support/am/amdev.py +114 -249
  81. tinygrad/runtime/support/am/ip.py +211 -172
  82. tinygrad/runtime/support/amd.py +138 -0
  83. tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
  84. tinygrad/runtime/support/compiler_cuda.py +8 -11
  85. tinygrad/runtime/support/elf.py +2 -1
  86. tinygrad/runtime/support/hcq.py +184 -97
  87. tinygrad/runtime/support/ib.py +172 -0
  88. tinygrad/runtime/support/llvm.py +3 -4
  89. tinygrad/runtime/support/memory.py +251 -0
  90. tinygrad/runtime/support/nv/__init__.py +0 -0
  91. tinygrad/runtime/support/nv/ip.py +581 -0
  92. tinygrad/runtime/support/nv/nvdev.py +183 -0
  93. tinygrad/runtime/support/system.py +170 -0
  94. tinygrad/runtime/support/usb.py +268 -0
  95. tinygrad/runtime/support/webgpu.py +18 -0
  96. tinygrad/schedule/__init__.py +0 -0
  97. tinygrad/schedule/grouper.py +119 -0
  98. tinygrad/schedule/kernelize.py +368 -0
  99. tinygrad/schedule/multi.py +231 -0
  100. tinygrad/shape/shapetracker.py +40 -46
  101. tinygrad/shape/view.py +88 -52
  102. tinygrad/tensor.py +968 -542
  103. tinygrad/uop/__init__.py +117 -0
  104. tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
  105. tinygrad/uop/mathtraits.py +169 -0
  106. tinygrad/uop/ops.py +1021 -0
  107. tinygrad/uop/spec.py +228 -0
  108. tinygrad/{codegen → uop}/symbolic.py +239 -216
  109. tinygrad/uop/upat.py +163 -0
  110. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
  111. tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
  112. tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
  113. tinygrad/viz/index.html +203 -403
  114. tinygrad/viz/js/index.js +718 -0
  115. tinygrad/viz/js/worker.js +29 -0
  116. tinygrad/viz/serve.py +224 -102
  117. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
  118. tinygrad-0.11.0.dist-info/RECORD +141 -0
  119. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
  120. tinygrad/codegen/kernel.py +0 -693
  121. tinygrad/engine/multi.py +0 -161
  122. tinygrad/ops.py +0 -1003
  123. tinygrad/runtime/ops_cloud.py +0 -220
  124. tinygrad/runtime/support/allocator.py +0 -94
  125. tinygrad/spec.py +0 -155
  126. tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
  127. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
  128. tinygrad/viz/perfetto.html +0 -178
  129. tinygrad-0.10.2.dist-info/RECORD +0 -99
  130. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
  131. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,89 @@
1
+ import math
2
+ from tinygrad.uop.ops import UOp, Ops, sint, PatternMatcher, UPat, KernelInfo, ssimplify, AxisType
3
+ from tinygrad.helpers import all_int
4
+ from tinygrad.dtype import dtypes
5
+ from tinygrad.shape.view import get_contraction
6
+ from tinygrad.renderer import Renderer
7
+
8
+ def _group_dims(dims:tuple[sint, ...], max_sizes:tuple[int, ...]):
9
+ # TODO: symbolic shape
10
+ if not all_int(dims): return dims
11
+ while len(dims) > len(max_sizes) or any(d > m for d,m in zip(dims, max_sizes)):
12
+ for i,m in enumerate(max_sizes):
13
+ if i < (len(dims)-1) and dims[i] * dims[i+1] <= m:
14
+ dims = dims[:i] + (dims[i]*dims[i+1],) + dims[i+2:]
15
+ break
16
+ else: return None
17
+ return dims
18
+
19
+ def _split_dims(dims, max_sizes):
20
+ if all(d <= m for d,m in zip(dims, max_sizes)): return dims
21
+ _dims = list(dims) + [1]*(3-len(dims))
22
+ for i in range(len(_dims)):
23
+ while _dims[i] > max_sizes[i]:
24
+ div = next((d for d in range(2, math.ceil(math.sqrt(_dims[i])) + 1) if (_dims[i] % d) == 0), 1)
25
+ if div == 1: raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
26
+ _dims[i], _dims[(i+1)%len(_dims)] = _dims[i]//div, _dims[(i+1)%len(_dims)]*div
27
+ return tuple(_dims[:2] if _dims[2] == 1 else _dims[0] if _dims[1:3] == [1,1] else _dims)
28
+
29
+ def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|None, reverse=False) -> list[UOp]:
30
+ if reverse: dims = dims[::-1]
31
+ # try to group first: (a, b, c, d) -> (ab, c, d)
32
+ limited = (grouped if (grouped := _group_dims(dims, max_sizes)) else dims) if max_sizes is not None else dims
33
+ # check if grouping failed
34
+ if max_sizes is not None and len(limited) > len(max_sizes): raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
35
+ # try to split up dims: (a,) -> (b, c)
36
+ if limited == dims: limited = _split_dims(dims, max_sizes) if max_sizes is not None else dims
37
+ ret = raw_idxs = [UOp(Ops.SPECIAL, dtypes.int, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)]
38
+ if len(limited) < len(dims):
39
+ ret = []
40
+ if (contraction:=get_contraction(dims, limited)) is None: raise AssertionError(f"get_contraction should not be None {dims=} {limited=}")
41
+ for idx, contraction_group in zip(raw_idxs, contraction):
42
+ for c in contraction_group[:-1]:
43
+ ret.append(idx % dims[c])
44
+ idx //= dims[c]
45
+ ret.append(idx)
46
+ elif len(limited) > len(dims):
47
+ a, b = len(limited), len(dims)
48
+ if a == 2 and b == 1: ret = [raw_idxs[0] * limited[1] + raw_idxs[1]]
49
+ if a == 3 and b == 1: ret = [raw_idxs[0] * (limited[1] * limited[2]) + raw_idxs[1] * limited[2] + raw_idxs[2]]
50
+ if a == 3 and b == 2: ret = [raw_idxs[0] * limited[1] + raw_idxs[1], raw_idxs[2]]
51
+ return ret[::-1] if reverse else ret
52
+
53
+ def add_gpudims(ctx:Renderer, s:UOp):
54
+ if s.arg is None: return None
55
+ ki: KernelInfo = s.arg
56
+ global_dims = [i for i,x in enumerate(ki.axis_types) if x is AxisType.GLOBAL]
57
+ local_dims = [i for i,x in enumerate(ki.axis_types) if x in (AxisType.LOCAL, AxisType.GROUP_REDUCE)]
58
+ if not global_dims and not local_dims: return None
59
+ s_topo = list(s.toposort())
60
+ if any(x.op is Ops.SPECIAL for x in s_topo): return None
61
+
62
+ # get global and local shape
63
+ all_ranges = {x.arg%1000:x for x in s_topo if x.op is Ops.RANGE}
64
+ ranges = [all_ranges[r] for r in global_dims+local_dims if r in all_ranges]
65
+ global_shape = tuple([ssimplify(r.src[0]) for r in ranges if r.arg%1000 in global_dims])
66
+ local_shape = tuple([ssimplify(r.src[0]) for r in ranges if r.arg%1000 in local_dims])
67
+
68
+ # get the idxs
69
+ if ki.dont_use_locals:
70
+ assert not local_dims, "can't use locals if there's no local dims"
71
+ idxs = get_grouped_dims("idx", global_shape, ctx.global_max, reverse=True)
72
+ else:
73
+ # define indexes for GPU-like execution
74
+ idxs = get_grouped_dims("gidx", global_shape, ctx.global_max, reverse=True) + get_grouped_dims("lidx", local_shape, ctx.local_max)
75
+
76
+ # apply to multiple ranges
77
+ subs = {}
78
+ for r in s_topo:
79
+ if r.op is not Ops.RANGE: continue
80
+ try:
81
+ ii = (global_dims+local_dims).index(r.arg%1000)
82
+ if r.arg < 2000 and ki.axis_types[r.arg%1000] == AxisType.GROUP_REDUCE: continue
83
+ subs[r] = idxs[ii]
84
+ except ValueError: continue
85
+ return s.substitute(subs)
86
+
87
+ pm_add_gpudims = PatternMatcher([
88
+ (UPat(Ops.SINK, name="s"), add_gpudims),
89
+ ])
@@ -1,234 +1,236 @@
1
1
  from __future__ import annotations
2
- import collections, heapq
3
- from dataclasses import dataclass
4
- from tinygrad.ops import UOp, Ops, PatternMatcher, UPat, graph_rewrite, GroupOp
5
- from tinygrad.spec import type_verify
6
- from tinygrad.dtype import dtypes, PtrDType
7
- from tinygrad.helpers import dedup, flatten, partition
2
+ import heapq
3
+ from collections import defaultdict
4
+ from dataclasses import dataclass, replace
5
+ from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat, GroupOp
6
+ from tinygrad.helpers import dedup, all_same, flatten, getenv
8
7
 
9
- DONT_PLACE_IN_BLOCK = {Ops.NAME, Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST, *GroupOp.Block}
8
+ # NOTE: any toposort should be valid here, unlike last time this isn't required, it's just for speed
9
+ def block_reorder(lst:list[UOp]) -> list[UOp]:
10
+ in_this_block = set(lst)
11
+ local_children: defaultdict[UOp, list[UOp]] = defaultdict(list)
12
+ in_degree:dict[UOp, int] = {}
13
+ priorities:dict[UOp, int] = {}
14
+
15
+ # get local children and assign priorities
16
+ # NOTE: this requires the lst be locally toposorted
17
+ for u in reversed(lst):
18
+ in_degree[u] = 0
19
+ for s in u.src:
20
+ if s in in_this_block:
21
+ local_children[s].append(u)
22
+ in_degree[u] += 1
23
+ # put loads in the beginning of the block and prevent priority inversion. hack for BARRIER grouping too
24
+ priority = [0] + [priorities[x] for x in local_children[u]]
25
+ if u.op is Ops.LOAD: priority.append(-1000)
26
+ if u.op is Ops.BARRIER: priority.append(-1500)
27
+ priorities[u] = min(priority)
28
+
29
+ # number the uops in "ideal" order
30
+ nkey = {u:i for i,u in enumerate(sorted(lst, key=lambda x: (priorities[x],)+x.tuplize))}
31
+
32
+ # then force then to be toposorted in as close to the ideal order as possible
33
+ heapq.heapify(heap:=[(nkey[u],u) for u in lst if in_degree[u] == 0])
34
+ newlst = []
35
+ while heap:
36
+ newlst.append(u:=heapq.heappop(heap)[1])
37
+ for v in local_children[u]:
38
+ in_degree[v] -= 1
39
+ if in_degree[v] == 0: heapq.heappush(heap, (nkey[v],v))
40
+
41
+ assert len(newlst) == len(lst), f"len mismatch {len(newlst)} != {len(lst)}"
42
+ return newlst
43
+
44
+ # ***** basic block *****
10
45
 
11
46
  def disp(y:UOp) -> str:
12
- if y.op is Ops.BLOCKSTART: return "w"+disp(y.src[0])
13
47
  if y.op is Ops.IF: return f'IF{id(y)}'
14
48
  if y.op is Ops.RANGE: return str(y.arg)
15
49
  return "<NONE>"
16
50
 
17
- @dataclass(frozen=True)
51
+ @dataclass(frozen=True, eq=False)
18
52
  class BasicBlock:
19
- ctx: tuple[UOp, ...]
20
53
  lst: tuple[UOp, ...]
54
+ ctx: tuple[UOp, ...] = ()
21
55
  end: UOp|None = None
22
- def __lt__(self, o:BasicBlock): return tuple(x.tuplize for x in self.ctx+self.lst) < tuple(x.tuplize for x in o.ctx+o.lst)
56
+ cnt: int = 0
57
+ child_ctx: tuple[UOp, ...]|None = None
58
+ def __lt__(self, _:BasicBlock): raise RuntimeError("no comparing basic blocks")
23
59
  def __repr__(self):
24
- return f"{(str(disp(self.end))+' ') if self.end is not None else ''}"+\
25
- f"{[disp(y) for y in self.ctx]} {len(self.lst)}" + "\n" + '\n'.join([str(x.op) for x in self.lst])
26
-
27
- def append_to_block(ctx:tuple[dict[UOp, tuple[UOp, ...]], dict[UOp, list[UOp]]], x:UOp):
28
- block_ctxs, children = ctx
29
- in_this_block = set(x.arg.lst)
30
-
31
- # collections to build
32
- new_srcs: list[UOp] = []
33
- to_append: list[UOp] = []
34
- old_blocks: dict[tuple[UOp, ...], UOp] = {}
35
- new_blocks: dict[tuple[UOp, ...], list[UOp]] = {}
36
-
37
- for u in x.src:
38
- if u.op is Ops.BLOCK:
39
- # merge sibling blocks. NOTE: blocks must only have one output source
40
- assert u.arg.ctx not in old_blocks, "sibling should never have been created"
41
- old_blocks[u.arg.ctx] = u
42
- elif u.op not in DONT_PLACE_IN_BLOCK and set(children[u]).issubset(in_this_block):
43
- # if it can go in blocks and all its children are in the block, we add it to the block
44
- if (block_ctx:=block_ctxs[u]) == x.arg.ctx:
45
- # if it's the same context, we place the UOp in this block and append the parents to its srcs
46
- new_srcs.extend(u.src)
47
- to_append.append(u)
60
+ return f"{(str(disp(self.end))+' ') if self.end is not None else ''}"+f'f{self.cnt} '+\
61
+ f"{[disp(y) for y in self.ctx]} {[disp(y) for y in self.child_ctx] if self.child_ctx is not None else '-'} "+\
62
+ f"{len(self.lst)}" + "\n" + '\n'.join([str(x.op) for x in self.lst])
63
+ def last_ctx(self): return self.child_ctx if self.child_ctx is not None else self.ctx
64
+
65
+ def _sort_ctx(inp): return tuple(sorted(dedup(inp), key=lambda x: x.tuplize))
66
+
67
+ # ***** block context *****
68
+
69
+ @dataclass
70
+ class BlockContext:
71
+ child_count: dict[UOp, int]
72
+ block_ctxs: dict[UOp, tuple[UOp, ...]]
73
+ child_ctxs: dict[UOp, tuple[UOp, ...]]
74
+ def last_ctx(self, u): return self.child_ctxs.get(u, self.block_ctxs[u])
75
+ @staticmethod
76
+ def from_sink(sink:UOp) -> BlockContext:
77
+ # get children and all block contexts
78
+ ctx = BlockContext({}, {}, {})
79
+ for u in sink.toposort():
80
+ this_block_ctx: list[UOp] = []
81
+ ctx.child_count[u] = 0
82
+
83
+ # get children and accumulate the last_ctx
84
+ for s in u.src:
85
+ # NOTE: if a parent appears multiple times in the src, it counts multiple times as a child
86
+ ctx.child_count[s] += 1
87
+ this_block_ctx += ctx.last_ctx(s)
88
+
89
+ # save the block ctx. SINK never has anything
90
+ ctx.block_ctxs[u] = _sort_ctx(this_block_ctx) if u.op is not Ops.SINK else ()
91
+
92
+ # RANGE/IF add to the next ctx
93
+ # STORE/ASSIGN subtract from the next ctx
94
+ if u.op in {Ops.RANGE, Ops.IF}: ctx.child_ctxs[u] = _sort_ctx(ctx.block_ctxs[u] + (u,))
95
+ elif u.op is Ops.STORE: ctx.child_ctxs[u] = tuple([y for y in ctx.block_ctxs[u] if y not in u.src])
96
+ return ctx
97
+
98
+ # ***** make blocks *****
99
+
100
+ DONT_PLACE_IN_BLOCK = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST}
101
+
102
+ def add_blockends(base_block:UOp, new_ctx:tuple[UOp, ...], current_ctx:tuple[UOp, ...], cnt:int=1) -> UOp:
103
+ ends_to_add = [z for z in new_ctx if z not in current_ctx]
104
+ while len(ends_to_add):
105
+ r:UOp = ends_to_add.pop(-1)
106
+ new_ctx = tuple([z for z in new_ctx if z is not r])
107
+ end_uop = UOp(Ops.ENDIF if r.op is Ops.IF else Ops.ENDRANGE, src=(r,))
108
+ base_block = UOp(Ops.BLOCKEND, src=(base_block,)*cnt, arg=BasicBlock((end_uop,), tuple(new_ctx), end=r, cnt=cnt))
109
+ return base_block
110
+
111
+ def make_block_bottom_up(ctx:BlockContext, x:UOp):
112
+ if x.op is Ops.BLOCKSTART:
113
+ current_ctx, child_ctx = x.arg
114
+ lst = list(x.src)
115
+ child_count = 1
116
+ else:
117
+ current_ctx, child_count, child_ctx = ctx.block_ctxs[x], ctx.child_count[x], ctx.child_ctxs.get(x, None)
118
+ lst = [x]
119
+
120
+ # count of times we've seen this block, or a seed for a new block if we can't merge it
121
+ unmergable: defaultdict[UOp, int] = defaultdict(int)
122
+ blockseeds = defaultdict(list)
123
+
124
+ # add the srcs of this to the frontier
125
+ # NOTE: things may be in here multiple times, that's okay
126
+ frontier_nodes = list(flatten(y.src[::-1] for y in lst))
127
+ while len(frontier_nodes):
128
+ u = frontier_nodes.pop(0)
129
+ if u.op not in DONT_PLACE_IN_BLOCK and ctx.child_count[u] == unmergable[u]+1:
130
+ # count is correct
131
+ if (newctx:=ctx.block_ctxs[u]) == current_ctx:
132
+ # block has same context, merge it, and put the srcs on the frontier
133
+ lst.append(u)
134
+ frontier_nodes.extend(u.src[::-1])
48
135
  else:
49
- # if it's a different context, we create a new block with this UOp
50
- new_blocks.setdefault(block_ctx, []).append(u)
136
+ # block has different context, add it to blockseeds
137
+ blockseeds[(newctx, ctx.child_ctxs.get(u, None))].append(u)
138
+ del unmergable[u]
51
139
  else:
52
- # otherwise, we keep it in the srcs
53
- new_srcs.append(u)
54
- if len(to_append) == 0 and len(new_blocks) == 0: return None
55
-
56
- for rng,lst in new_blocks.items():
57
- srcs = flatten(y.src for y in lst)
58
- if (old_block:=old_blocks.pop(rng, None)) is not None:
59
- # NOTE: order shouldn't matter here
60
- srcs.extend(old_block.src)
61
- lst.extend(old_block.arg.lst)
62
- new_block = UOp(Ops.BLOCK, dtypes.void, tuple(dedup(srcs)), BasicBlock(rng, tuple(lst)))
63
- lrng = list(rng)
64
- for r in rng[::-1]:
65
- if r not in x.arg.ctx and r.op is not Ops.BLOCKSTART:
66
- lrng.remove(r)
67
- new_block = UOp(Ops.BLOCKEND, src=(new_block,),
68
- arg=BasicBlock(tuple(lrng), (UOp(Ops.ENDIF if r.op is Ops.IF else Ops.ENDRANGE, src=(r,)),), r))
69
- new_srcs.append(new_block)
70
- return UOp(Ops.BLOCK, dtypes.void, tuple(dedup(list(old_blocks.values())+new_srcs)), BasicBlock(x.arg.ctx, tuple(to_append)+x.arg.lst))
71
-
72
- make_basic_blocks = PatternMatcher([
73
- (UPat(Ops.SINK, name="x"),
74
- lambda x: UOp(Ops.BLOCK, src=x.src+((UOp(Ops.NAME, arg=x.arg.name),) if x.arg is not None else ()), arg=BasicBlock((), (x,)))),
75
- (UPat(Ops.BLOCK, name="x"), append_to_block),
76
- ])
77
-
78
- def block_merge(ctx, x:UOp):
79
- # ctx is children here
80
- if x.op is Ops.BLOCKEND:
81
- # if it's a BLOCKEND, see if we are done with placement. if all the children of the range are in here
82
- in_this_block = set(x.arg.lst)
83
- if len([y for y in ctx[x.arg.end] if y not in in_this_block]) == 0:
84
- # find the parent block that has the BLOCKSTART in the ctx
85
- parent_blocks = [y for y in x.src if y.op is Ops.BLOCK and UOp(Ops.BLOCKSTART, src=(x.arg.end,)) in y.arg.ctx]
86
- assert len(parent_blocks) <= 1, "should never have two parent blocks"
87
- if len(parent_blocks) == 1:
88
- parent_block = parent_blocks[0]
89
- # range needs DEFINE_ACC to be before the range (never in DEFINE_ACC for if)
90
- early_ops, late_ops = partition(x.arg.lst, lambda y: y.op is Ops.DEFINE_ACC and x.arg.end in y.src)
91
- return UOp(Ops.BLOCK, dtypes.void, tuple(y for y in x.src if y is not parent_block)+parent_block.src,
92
- BasicBlock(tuple(y for y in x.arg.ctx if y is not x.arg.end), tuple(early_ops)+parent_block.arg.lst+tuple(late_ops)))
93
-
94
- new_srcs: list[UOp] = []
95
- to_append: list[UOp] = []
96
- new_ctx = x.arg.ctx
97
- placed = set()
98
- for u in x.src:
99
- if u.op is Ops.BLOCK and (tuple(u.arg.ctx) == tuple(x.arg.ctx) or (x.arg.end is not None and x.arg.end in u.arg.ctx)):
100
- # NOTE: this can't appear in srcs twice or it would be a BLOCKFORK
101
- new_ctx += tuple(y for y in u.arg.ctx if y not in x.arg.ctx)
102
- new_srcs.extend(u.src)
103
- to_append.extend(u.arg.lst)
104
- elif u.op is Ops.BLOCKFORK and x.src.count(u) == u.arg: # block fork appears # of times in srcs
105
- if u not in placed:
106
- new_srcs.extend(u.src)
107
- placed.add(u)
108
- else:
109
- # keep it in srcs
110
- new_srcs.append(u)
111
- if len(to_append) == 0 and len(placed) == 0: return None
112
- return UOp(x.op, dtypes.void, tuple(new_srcs), BasicBlock(tuple(sorted(new_ctx, key=lambda x: x.tuplize)), tuple(to_append)+x.arg.lst, x.arg.end))
113
-
114
- pm_block_merge = PatternMatcher([(UPat((Ops.BLOCKEND, Ops.BLOCK), name="x"), block_merge),])
115
-
116
- def block_finalize(block:UOp):
117
- if len(block.src) == 0: return None
118
- _uops = sorted(dedup(block.src), key=lambda x: x.tuplize)
119
- assert all(len(x.src) == 0 and x.op not in {Ops.BLOCK, Ops.BLOCKSTART, Ops.BLOCKEND, Ops.BLOCKFORK} for x in _uops)
120
- _uops += block.arg.lst
121
- # strip the SINK
122
- assert _uops[-1].op is Ops.SINK, "doesn't end with SINK"
123
- return UOp(Ops.BLOCK, arg=BasicBlock((), tuple(_uops[:-1])))
140
+ # count is incorrect (or it's DONT_PLACE_IN_BLOCK), add it to unmergable
141
+ unmergable[u] += 1
124
142
 
125
- pm_block_finalize = PatternMatcher([(UPat(Ops.BLOCK, name="block"), block_finalize)])
143
+ # add unmergables to sources
144
+ srcs = []
145
+ for u,cnt in unmergable.items(): srcs += [add_blockends(u, ctx.block_ctxs[u], current_ctx, cnt=cnt)]*cnt
126
146
 
127
- # NOTE: any toposort should be valid here, unlike last time this isn't required, it's just for speed
128
- def block_reorder(in_block:UOp):
129
- in_this_block = set(in_block.arg.lst)
130
- local_children: collections.defaultdict[UOp, list[UOp]] = collections.defaultdict(list)
131
- in_degree: collections.defaultdict[UOp, int] = collections.defaultdict(int)
132
- priorities:dict[UOp, int] = {}
133
-
134
- # get local children and assign priorities
135
- for u in reversed(in_block.arg.lst):
136
- for s in u.src:
137
- if s in in_this_block:
138
- local_children[s].append(u)
139
- in_degree[u] += 1
140
- # put loads in the beginning of the block and prevent priority inversion
141
- priorities[u] = min([-1000 if u.op is Ops.LOAD else 0] + [priorities[x] for x in local_children[u]])
142
-
143
- # placement queue
144
- queue:list[tuple[int, tuple, UOp]] = []
145
- def push(u:UOp): heapq.heappush(queue, (priorities[u], u.tuplize, u))
147
+ # add blockseeds, with blockends as needed
148
+ for (new_ctx, new_child_ctx), v in blockseeds.items():
149
+ base_block = UOp(Ops.BLOCKSTART, src=tuple(v), arg=(new_ctx, new_child_ctx))
150
+ srcs.append(add_blockends(base_block, new_ctx, current_ctx))
146
151
 
147
- # place the first ones that don't have deps
148
- for u in in_block.arg.lst:
149
- if u not in in_degree: push(u)
152
+ lst = lst[::-1]
153
+ if getenv("BLOCK_REORDER", 1): lst = block_reorder(lst)
154
+ bb = BasicBlock(tuple(lst), ctx=current_ctx, cnt=child_count, child_ctx=child_ctx)
155
+ return UOp(Ops.BLOCK, src=tuple(srcs), arg=bb)
150
156
 
151
- newlst = []
152
- while queue:
153
- _,_,x = heapq.heappop(queue)
154
- newlst.append(x)
155
- for u in local_children[x]:
156
- in_degree[u] -= 1
157
- if in_degree[u] == 0: push(u)
158
-
159
- assert len(newlst) == len(in_block.arg.lst), f"len mismatch {len(newlst)} != {len(in_block.arg.lst)}"
160
- return in_block.replace(arg=BasicBlock(in_block.arg.ctx, tuple(newlst)))
161
-
162
- def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> list[UOp]:
163
- assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
164
-
165
- # get children and all block contexts
166
- temp_block_ctxs: dict[UOp, list[UOp]] = {}
167
- children: dict[UOp, list[UOp]] = {}
168
- for u in sink.toposort:
169
- this_block_ctx: list[UOp] = []
170
- for s in u.src:
171
- # save children
172
- children.setdefault(s, []).append(u)
173
- # compute block ctx
174
- if s.op in {Ops.RANGE, Ops.IF}: this_block_ctx.append(s)
175
- # don't flow (fully) through assign and store
176
- elif s.op is Ops.STORE:
177
- # ugh, deal with non-reduce locals. probably wrong
178
- if isinstance(s.src[0].dtype, PtrDType) and s.src[0].dtype.local:
179
- idx_context, store_context = temp_block_ctxs[s.src[0]], temp_block_ctxs[s]
180
- this_block_ctx += [x for x in store_context if x not in idx_context and x.op is Ops.RANGE]
181
- elif s.op is Ops.ASSIGN:
182
- # flow though assign, but remove the ranges used in the assign
183
- assert s.src[0].op is Ops.DEFINE_ACC
184
- this_block_ctx += [x for x in temp_block_ctxs[s.src[1]] if x not in s.src[0].src[1:]]
185
- else:
186
- # flow though everything else
187
- this_block_ctx += temp_block_ctxs[s]
188
- temp_block_ctxs[u] = sorted(dedup(this_block_ctx), key=lambda x: x.tuplize)
189
-
190
- # make final block_ctxs, add BLOCKSTART to block_ctxs for IF and RANGE
191
- block_ctxs: dict[UOp, tuple[UOp, ...]] = {}
192
- for u in sink.toposort:
193
- block_ctxs[u] = ((UOp(Ops.BLOCKSTART, src=(u,)),) + tuple(temp_block_ctxs[u])) if u.op in {Ops.IF, Ops.RANGE} else tuple(temp_block_ctxs[u])
194
-
195
- # TODO: there's probably a clever way to remove this while loop
196
- while 1:
197
- sink = graph_rewrite(sink, make_basic_blocks, ctx=(block_ctxs, children))
198
-
199
- # add BLOCKFORK (slow!)
200
- block_parent_count = collections.Counter(flatten([x.src for x in sink.toposort if x.op is Ops.BLOCK]))
201
- non_block_parents = set(flatten([x.src for x in sink.toposort if x.op is not Ops.BLOCK]))
202
- forks = {u:UOp(Ops.BLOCKFORK, src=(UOp(Ops.BLOCK, src=u.src, arg=BasicBlock(block_ctxs[u], (u,))),), arg=child_count)
203
- for u,child_count in block_parent_count.items() if u.op not in DONT_PLACE_IN_BLOCK and child_count > 1 and u not in non_block_parents}
157
+ block_create = PatternMatcher([
158
+ (UPat(GroupOp.All-DONT_PLACE_IN_BLOCK.union({Ops.BLOCK, Ops.BLOCKEND}), name="x"), make_block_bottom_up),
159
+ ])
204
160
 
205
- if not len(forks): break
206
- sink = sink.substitute(forks)
161
+ # ***** blockend merging ****
207
162
 
208
- # combine matching BLOCKENDS
163
+ def merge_blockends(sink:UOp) -> UOp|None:
164
+ # only run on the final BLOCK with the SINK in it
165
+ if sink.arg.lst[-1].op is not Ops.SINK: return None
166
+ # combine matching BLOCKENDS, the keys of this dictionary are the RANGE UOps, values are the BLOCKENDs
209
167
  blockends_to_arg: dict[UOp, list[UOp]] = {}
210
- for be in sink.toposort:
168
+ for be in sink.toposort():
211
169
  if be.op is Ops.BLOCKEND: blockends_to_arg.setdefault(be.arg.end, []).append(be)
212
170
  new_forks = {}
213
171
  for k,v in blockends_to_arg.items():
214
172
  # NOTE: if any BLOCKEND is the parent of any other with the same arg, this algo fails
215
173
  if len(v) > 1:
216
- out = UOp(Ops.BLOCKFORK, src=(UOp(Ops.BLOCKEND, src=tuple(flatten(x.src for x in v)),
217
- arg=BasicBlock(tuple(dedup(flatten([y.arg.ctx for y in v]))), v[0].arg.lst, k)),), arg=len(v))
174
+ bb = BasicBlock(v[0].arg.lst, _sort_ctx(flatten([y.arg.ctx for y in v])), k, cnt=sum(y.arg.cnt for y in v))
175
+ out = UOp(Ops.BLOCKEND, src=tuple(flatten([x.src for x in v])), arg=bb)
176
+ # NOTE: bb.ctx != u.arg.ctx can cause problems here
218
177
  for u in v: new_forks[u] = out
219
- sink = sink.substitute(new_forks)
220
-
221
- # reorder ops in block for speed
222
- sink = sink.substitute({u:newu for u in sink.toposort if u.op is Ops.BLOCK and (newu:=block_reorder(u)) is not u})
178
+ if len(new_forks) == 0: return None
179
+ return sink.substitute(new_forks)
180
+
181
+ pm_blockend_merge = PatternMatcher([(UPat(Ops.BLOCK, name="sink"), merge_blockends)])
182
+
183
+ # ***** block merging ****
184
+
185
+ def merge_block(x:UOp):
186
+ unmergable_blocks, mergable_blocks = [], []
187
+ mergable_dict: defaultdict[UOp, int] = defaultdict(int)
188
+ for y in x.src:
189
+ if y.op is Ops.BLOCK and x.op is Ops.BLOCK and x.arg.ctx == y.arg.ctx: mergable_dict[y] += 1
190
+ elif y.op is Ops.BLOCK and x.op is Ops.BLOCKEND and x.arg.end in y.arg.ctx: mergable_dict[y] += 1
191
+ else: unmergable_blocks.append(y)
192
+ for k,v in mergable_dict.items():
193
+ if v == k.arg.cnt: mergable_blocks.append(k)
194
+ else: unmergable_blocks.extend([k]*v)
195
+ if len(mergable_blocks) == 0: return None
196
+ del mergable_dict
197
+
198
+ # create the block
199
+ arg = replace(x.arg, lst=tuple(flatten([y.arg.lst for y in mergable_blocks]))+x.arg.lst)
200
+ return UOp(x.op, src=tuple(flatten([y.src for y in mergable_blocks])+unmergable_blocks), arg=arg)
201
+
202
+ def remove_blockend(x:UOp):
203
+ # if there's any remaining blocks that need to go in this BLOCKEND, we don't remove it
204
+ if any(x.arg.end in y.arg.ctx for y in x.src if y.op in {Ops.BLOCK, Ops.BLOCKEND}): return None
205
+
206
+ if (parent_blocks := [y for y in x.src if y.op is Ops.BLOCK and y.arg.child_ctx is not None and x.arg.end in y.arg.child_ctx]):
207
+ assert all_same(parent_blocks), f"should never have two parent blocks (has {len(parent_blocks)})"
208
+ parent_block = parent_blocks[0]
209
+ assert len(parent_blocks) == parent_block.arg.cnt
210
+ # NOTE: DEFINE_ACC doesn't have to be handled in any special way
211
+ late_ops = list(x.arg.lst)
212
+ # NOTE: we have to add a barrier at the start if barrier is used in the range
213
+ if x.op is Ops.BLOCKEND and any(y.op is Ops.BARRIER for y in late_ops) and late_ops[-1].op is Ops.ENDRANGE:
214
+ late_ops = [UOp(Ops.BARRIER)] + late_ops
215
+ # peephole opt, remove any BARRIERs next to each other
216
+ for i in range(len(late_ops)-1):
217
+ if late_ops[i].op is Ops.BARRIER and late_ops[i+1].op is Ops.BARRIER: late_ops[i+1] = UOp(Ops.NOOP)
218
+ arg = BasicBlock(parent_block.arg.lst+tuple(late_ops), tuple([y for y in x.arg.ctx if y is not x.arg.end]), cnt=x.arg.cnt)
219
+ return UOp(Ops.BLOCK, src=tuple(y for y in x.src if y is not parent_block)+parent_block.src, arg=arg)
220
+
221
+ block_merge = PatternMatcher([
222
+ (UPat((Ops.BLOCK, Ops.BLOCKEND), name="x"), merge_block),
223
+ (UPat(Ops.BLOCKEND, name="x"), remove_blockend),
224
+ ])
223
225
 
224
- # final rewrite to merge all blocks into one
225
- sink = graph_rewrite(sink, pm_block_merge, ctx=children)
226
+ # ****** finalize ******
226
227
 
227
- # there should just be one block left, with a few parents with 0 srcs (now done in a rewriter)
228
- sink = graph_rewrite(sink, pm_block_finalize)
228
+ def finalize(sink:UOp) -> UOp:
229
+ if sink.op is not Ops.BLOCK or not all(x.op in DONT_PLACE_IN_BLOCK for x in sink.src):
230
+ raise RuntimeError(f"linearize failure {sink.op} {[x.op for x in sink.src if x.op not in DONT_PLACE_IN_BLOCK]}")
229
231
 
230
- # sanity checks (NOTE: these can cause things to be skipped in BEAM)
231
- if not skip_check: type_verify(sink.arg.lst)
232
+ # place the early things
233
+ lst = sorted(dedup(sink.src), key=lambda x: x.tuplize) + list(sink.arg.lst)
234
+ return UOp(Ops.BLOCKFINAL, arg=BasicBlock(tuple(lst)))
232
235
 
233
- # return the list. TODO: refactor to return the UOp
234
- return list(sink.arg.lst)
236
+ pm_finalize = PatternMatcher([(UPat(Ops.BLOCK, name="sink"), finalize)])