tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
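The same comparison can be reproduced locally from the published wheels. Below is a minimal sketch using only the Python standard library; the wheel filenames and the `pip download` invocation in the comment are assumptions about how you fetch the artifacts, not part of the registry diff itself.

```python
# Assumes both wheels were fetched first, e.g.:
#   pip download tinygrad==0.10.2 tinygrad==0.11.0 --no-deps -d .
import zipfile, difflib

OLD = "tinygrad-0.10.2-py3-none-any.whl"
NEW = "tinygrad-0.11.0-py3-none-any.whl"

def py_sources(whl: str) -> dict[str, list[str]]:
    # map of member name -> source lines, for every .py file inside the wheel
    with zipfile.ZipFile(whl) as z:
        return {n: z.read(n).decode("utf-8", "replace").splitlines()
                for n in z.namelist() if n.endswith(".py")}

old, new = py_sources(OLD), py_sources(NEW)
for name in sorted(old.keys() | new.keys()):
    diff = difflib.unified_diff(old.get(name, []), new.get(name, []),
                                fromfile=f"0.10.2/{name}", tofile=f"0.11.0/{name}", lineterm="")
    for line in diff: print(line)
```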
Files changed (131)
  1. tinygrad/__init__.py +1 -1
  2. tinygrad/apps/llm.py +206 -0
  3. tinygrad/codegen/__init__.py +116 -0
  4. tinygrad/codegen/devectorizer.py +315 -172
  5. tinygrad/codegen/expander.py +8 -16
  6. tinygrad/codegen/gpudims.py +89 -0
  7. tinygrad/codegen/linearize.py +205 -203
  8. tinygrad/codegen/lowerer.py +92 -139
  9. tinygrad/codegen/opt/__init__.py +38 -0
  10. tinygrad/codegen/opt/heuristic.py +125 -0
  11. tinygrad/codegen/opt/kernel.py +510 -0
  12. tinygrad/{engine → codegen/opt}/search.py +51 -35
  13. tinygrad/codegen/opt/swizzler.py +134 -0
  14. tinygrad/codegen/opt/tc.py +127 -0
  15. tinygrad/codegen/quantize.py +67 -0
  16. tinygrad/device.py +122 -132
  17. tinygrad/dtype.py +152 -35
  18. tinygrad/engine/jit.py +81 -54
  19. tinygrad/engine/memory.py +46 -27
  20. tinygrad/engine/realize.py +82 -41
  21. tinygrad/engine/schedule.py +70 -445
  22. tinygrad/frontend/__init__.py +0 -0
  23. tinygrad/frontend/onnx.py +1253 -0
  24. tinygrad/frontend/torch.py +5 -0
  25. tinygrad/gradient.py +19 -27
  26. tinygrad/helpers.py +95 -47
  27. tinygrad/nn/__init__.py +7 -8
  28. tinygrad/nn/optim.py +72 -41
  29. tinygrad/nn/state.py +37 -23
  30. tinygrad/renderer/__init__.py +40 -60
  31. tinygrad/renderer/cstyle.py +143 -128
  32. tinygrad/renderer/llvmir.py +113 -62
  33. tinygrad/renderer/ptx.py +50 -32
  34. tinygrad/renderer/wgsl.py +27 -23
  35. tinygrad/runtime/autogen/am/am.py +5861 -0
  36. tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
  37. tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
  38. tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
  39. tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
  40. tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
  41. tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
  42. tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
  43. tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
  44. tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
  45. tinygrad/runtime/autogen/comgr.py +35 -9
  46. tinygrad/runtime/autogen/comgr_3.py +906 -0
  47. tinygrad/runtime/autogen/cuda.py +2419 -494
  48. tinygrad/runtime/autogen/hsa.py +57 -16
  49. tinygrad/runtime/autogen/ib.py +7171 -0
  50. tinygrad/runtime/autogen/io_uring.py +917 -118
  51. tinygrad/runtime/autogen/kfd.py +748 -26
  52. tinygrad/runtime/autogen/libc.py +613 -218
  53. tinygrad/runtime/autogen/libusb.py +1643 -0
  54. tinygrad/runtime/autogen/nv/nv.py +8602 -0
  55. tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
  56. tinygrad/runtime/autogen/opencl.py +2 -4
  57. tinygrad/runtime/autogen/sqtt.py +1789 -0
  58. tinygrad/runtime/autogen/vfio.py +3 -3
  59. tinygrad/runtime/autogen/webgpu.py +273 -264
  60. tinygrad/runtime/graph/cuda.py +3 -3
  61. tinygrad/runtime/graph/hcq.py +68 -29
  62. tinygrad/runtime/graph/metal.py +29 -13
  63. tinygrad/runtime/graph/remote.py +114 -0
  64. tinygrad/runtime/ops_amd.py +537 -320
  65. tinygrad/runtime/ops_cpu.py +108 -7
  66. tinygrad/runtime/ops_cuda.py +12 -14
  67. tinygrad/runtime/ops_disk.py +13 -10
  68. tinygrad/runtime/ops_dsp.py +47 -40
  69. tinygrad/runtime/ops_gpu.py +13 -11
  70. tinygrad/runtime/ops_hip.py +6 -9
  71. tinygrad/runtime/ops_llvm.py +35 -15
  72. tinygrad/runtime/ops_metal.py +29 -19
  73. tinygrad/runtime/ops_npy.py +5 -3
  74. tinygrad/runtime/ops_null.py +28 -0
  75. tinygrad/runtime/ops_nv.py +306 -234
  76. tinygrad/runtime/ops_python.py +62 -52
  77. tinygrad/runtime/ops_qcom.py +28 -39
  78. tinygrad/runtime/ops_remote.py +482 -0
  79. tinygrad/runtime/ops_webgpu.py +28 -28
  80. tinygrad/runtime/support/am/amdev.py +114 -249
  81. tinygrad/runtime/support/am/ip.py +211 -172
  82. tinygrad/runtime/support/amd.py +138 -0
  83. tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
  84. tinygrad/runtime/support/compiler_cuda.py +8 -11
  85. tinygrad/runtime/support/elf.py +2 -1
  86. tinygrad/runtime/support/hcq.py +184 -97
  87. tinygrad/runtime/support/ib.py +172 -0
  88. tinygrad/runtime/support/llvm.py +3 -4
  89. tinygrad/runtime/support/memory.py +251 -0
  90. tinygrad/runtime/support/nv/__init__.py +0 -0
  91. tinygrad/runtime/support/nv/ip.py +581 -0
  92. tinygrad/runtime/support/nv/nvdev.py +183 -0
  93. tinygrad/runtime/support/system.py +170 -0
  94. tinygrad/runtime/support/usb.py +268 -0
  95. tinygrad/runtime/support/webgpu.py +18 -0
  96. tinygrad/schedule/__init__.py +0 -0
  97. tinygrad/schedule/grouper.py +119 -0
  98. tinygrad/schedule/kernelize.py +368 -0
  99. tinygrad/schedule/multi.py +231 -0
  100. tinygrad/shape/shapetracker.py +40 -46
  101. tinygrad/shape/view.py +88 -52
  102. tinygrad/tensor.py +968 -542
  103. tinygrad/uop/__init__.py +117 -0
  104. tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
  105. tinygrad/uop/mathtraits.py +169 -0
  106. tinygrad/uop/ops.py +1021 -0
  107. tinygrad/uop/spec.py +228 -0
  108. tinygrad/{codegen → uop}/symbolic.py +239 -216
  109. tinygrad/uop/upat.py +163 -0
  110. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
  111. tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
  112. tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
  113. tinygrad/viz/index.html +203 -403
  114. tinygrad/viz/js/index.js +718 -0
  115. tinygrad/viz/js/worker.js +29 -0
  116. tinygrad/viz/serve.py +224 -102
  117. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
  118. tinygrad-0.11.0.dist-info/RECORD +141 -0
  119. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
  120. tinygrad/codegen/kernel.py +0 -693
  121. tinygrad/engine/multi.py +0 -161
  122. tinygrad/ops.py +0 -1003
  123. tinygrad/runtime/ops_cloud.py +0 -220
  124. tinygrad/runtime/support/allocator.py +0 -94
  125. tinygrad/spec.py +0 -155
  126. tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
  127. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
  128. tinygrad/viz/perfetto.html +0 -178
  129. tinygrad-0.10.2.dist-info/RECORD +0 -99
  130. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
  131. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
tinygrad/codegen/lowerer.py
@@ -1,161 +1,114 @@
  # the job of the lowerer is to do indexing
- import functools, itertools, operator, math
- from dataclasses import dataclass
+ import functools, operator
  from typing import cast
- from tinygrad.dtype import dtypes, PtrDType
- from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element, sint_to_uop
- from tinygrad.renderer import Renderer
- from tinygrad.helpers import all_int, prod, partition, flatten, unwrap
- from tinygrad.codegen.expander import expand_rewrite
-
- # returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
- def get_contraction(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) -> list[list[int]]|None:
- acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
- try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new]
- except ValueError: return None
- return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]
+ from dataclasses import dataclass
+ from tinygrad.dtype import dtypes, AddrSpace, PtrDType
+ from tinygrad.uop.ops import KernelInfo, UOp, Ops, PatternMatcher, UPat, sint_to_uop, AxisType, graph_rewrite
+ from tinygrad.helpers import prod, partition, flatten

  # ***** indexing *****
- def _group_dims(dims:tuple[sint, ...], max_sizes:tuple[int, ...]):
- # TODO: symbolic shape
- if not all_int(dims): return dims
- while len(dims) > len(max_sizes) or any(d > m for d,m in zip(dims, max_sizes)):
- for i,m in enumerate(max_sizes):
- if i < (len(dims)-1) and dims[i] * dims[i+1] <= m:
- dims = dims[:i] + (dims[i]*dims[i+1],) + dims[i+2:]
- break
- else: return None
- return dims
-
- def _split_dims(dims, max_sizes):
- if all(d <= m for d,m in zip(dims, max_sizes)): return dims
- _dims = list(dims) + [1]*(3-len(dims))
- for i in range(len(_dims)):
- while _dims[i] > max_sizes[i]:
- div = next((d for d in range(2, math.ceil(math.sqrt(_dims[i])) + 1) if (_dims[i] % d) == 0), 1)
- if div == 1: raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
- _dims[i], _dims[(i+1)%len(_dims)] = _dims[i]//div, _dims[(i+1)%len(_dims)]*div
- return tuple(_dims[:2] if _dims[2] == 1 else _dims[0] if _dims[1:3] == [1,1] else _dims)
-
- def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|None, reverse=False) -> list[UOp]:
- if reverse: dims = dims[::-1]
- # try to group first: (a, b, c, d) -> (ab, c, d)
- limited = (grouped if (grouped := _group_dims(dims, max_sizes)) else dims) if max_sizes is not None else dims
- # check if grouping failed
- if max_sizes is not None and len(limited) > len(max_sizes): raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
- # try to split up dims: (a,) -> (b, c)
- if limited == dims: limited = _split_dims(dims, max_sizes) if max_sizes is not None else dims
- ret = raw_idxs = [UOp(Ops.SPECIAL, dtypes.int, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)]
- if len(limited) < len(dims):
- ret = []
- if (contraction:=get_contraction(dims, limited)) is None: raise AssertionError(f"get_contraction should not be None {dims=} {limited=}")
- for idx, contraction_group in zip(raw_idxs, contraction):
- for c in contraction_group[:-1]:
- ret.append(idx % dims[c])
- idx //= dims[c]
- ret.append(idx)
- elif len(limited) > len(dims):
- a, b = len(limited), len(dims)
- if a == 2 and b == 1: ret = [raw_idxs[0] * limited[1] + raw_idxs[1]]
- if a == 3 and b == 1: ret = [raw_idxs[0] * (limited[1] * limited[2]) + raw_idxs[1] * limited[2] + raw_idxs[2]]
- if a == 3 and b == 2: ret = [raw_idxs[0] * limited[1] + raw_idxs[1], raw_idxs[2]]
- return ret[::-1] if reverse else ret

  @dataclass
  class IndexContext:
+ axis_types: tuple[AxisType, ...]
  idxs: list[UOp]
- ridxs: list[UOp]
- acc_num: int = 0
-
- def get_index(ast:UOp, opts:Renderer) -> IndexContext:
- ki = ast.arg if isinstance(ast.arg, KernelInfo) else KernelInfo()
- # NOTE: assumes the shape is <global dims> <local dims> <group_for_reduces> <reduces> <upcasts/unrolls>
- full_shape = ast.full_shape
- first_upcasted = len(full_shape)-ki.upcasted
- # if there's no reduce, this is first_upcasted. assumes reduces are at the end
- first_reduce = min([first_upcasted]+flatten(x.axis_arg for x in ast.toposort if x.op is Ops.REDUCE_AXIS))
- local_loads = [x for x in ast.toposort if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL]
- # NOTE: sum up the reduced axes looking across all local loads, yields the number of grouped reduces
- group_for_reduces = sum([any(l.st_arg.shape[i]!=ast.src[0].st_arg.shape[i] for l in local_loads) for i in range(first_reduce,first_upcasted)])
- global_dims = first_reduce-ki.local_dims
-
- if opts.has_local:
- if ki.dont_use_locals:
- assert ki.local_dims == 0, "can't use locals if there's no local dims"
- idxs = get_grouped_dims("idx", full_shape[:global_dims], opts.global_max, reverse=True)
+ start: int = 0
+
+ def shape_to_idx(s, axis_types, start=0):
+ # indexes
+ idxs = []
+ for i, (s, at) in enumerate(zip(s, axis_types)):
+ if at in (AxisType.UPCAST, AxisType.UNROLL):
+ assert isinstance(s, int), "needs to be int to upcast/unroll"
+ idxs.append(UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(s), tuple(range(s))),), ((i,s),), tag=1))
  else:
- # define indexes for GPU-like execution
- idxs = get_grouped_dims("gidx", full_shape[:global_dims], opts.global_max, reverse=True) + \
- get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
- else:
- # all loops are RANGES
- idxs = [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(g)), i) for i,g in enumerate(full_shape[:first_reduce])]
+ # all others are RANGES
+ idxs.append(UOp(Ops.RANGE, dtypes.int, (sint_to_uop(s),), start+i))
+ return idxs

- # reduce loops
- idxs += [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(g)), i)
- for i,g in enumerate(full_shape[first_reduce+group_for_reduces:first_upcasted], start=first_reduce+group_for_reduces)]
+ def get_index(ast:UOp) -> IndexContext:
+ axis_types = ast.arg.axis_types if isinstance(ast.arg, KernelInfo) else ()
+ if len(ast.full_shape) != len(axis_types): axis_types = (AxisType.LOOP,)*len(ast.full_shape)
+ return IndexContext(axis_types, [], 0)

- # upcast loops
- for i,g in enumerate(full_shape[first_upcasted:], start=first_upcasted):
- assert isinstance(g, int), "needs to be int to upcast/unroll"
- idxs.append(UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(g), tuple(range(g))),), ((i,g),)))
+ # ***** lowering (given index) *****

- # late indexes (group for reduce)
- ridxs = idxs[:]
- for a in range(first_reduce, first_reduce+group_for_reduces):
- ridxs[a] = UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(full_shape[a])), 1000+a)
+ def subblock(ctx: IndexContext, full_new_idx: list[UOp], src: UOp):
+ lc = IndexContext(ctx.axis_types, full_new_idx, ctx.start+1000)
+ ctx.start = lc.start
+ return graph_rewrite(src, pm_lowerer, lc, name="subblock", bottom_up=True)

- return IndexContext(idxs, ridxs)
+ def lower_reduce_axis(ctx: IndexContext, x: UOp):
+ new_idxs = shape_to_idx(x.src[0].shape, ctx.axis_types, ctx.start)
+ full_new_idx = list(ctx.idxs)
+ for a in x.axis_arg: full_new_idx[a] = new_idxs[a]

- # ***** lowering (given index) *****
+ ret = subblock(ctx, full_new_idx, x.src[0])

- def lower_reduce_axis(ctx: IndexContext, x: UOp):
  # NOTE: always using ridxs is fine here
- reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.axis_arg], lambda y: y.op is Ops.RANGE)
+ reduce_range, reduce_expand = partition([full_new_idx[i] for i in x.axis_arg], lambda y: y.op is Ops.RANGE)
  assert all(x.op is Ops.UNROLL for x in reduce_expand), f"not all UNROLLS in {reduce_expand} for {x.axis_arg}"
- alu_op: Ops = x.arg[0]
- ret = x.src[0]
  if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
- ret = UOp(Ops.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
- ret = functools.reduce(lambda x,y: x.alu(alu_op, y), [ret.gep(i) for i in range(ret.dtype.count)])
- if not len(reduce_range): return ret
- # create ACC and assign
- acc = UOp(Ops.DEFINE_ACC, x.dtype, (x.const_like(identity_element(alu_op, x.dtype.scalar())),) + tuple(reduce_range), (ctx.acc_num,))
- ctx.acc_num += 1
- return acc.assign(acc.alu(alu_op, ret))
-
- def lower_load_store(ctx: IndexContext, x: UOp):
- idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL else ctx.idxs)
- buf = x.src[0]
- if x.op is Ops.LOAD:
- barrier = (UOp(Ops.BARRIER, dtypes.void, (x.src[2],)),) if x.src[0].op is Ops.DEFINE_LOCAL else ()
- return UOp(Ops.LOAD, x.dtype, (buf.index(idx, valid),) + barrier)
- # NOTE: only store the local reduceop in the threads that are actually doing the reduce
- if cast(PtrDType, x.src[0].dtype).local and x.src[2].op is Ops.ASSIGN:
- reduce_input = x.src[2].src[1].src[1] if x.src[2].src[1].src[1] is not x.src[2].src[0] else x.src[2].src[1].src[0]
- store_back = reduce_input.op is Ops.LOAD and cast(PtrDType, reduce_input.src[0].dtype).local
- else: store_back = False
- # NOTE: If we're storing the reduced value back into each thread, need to zero-out the reduced axes
- if store_back: idx, _ = x.st_arg.to_indexed_uops([u.const_like(0) if u in x.src[2].src else u for u in ctx.idxs])
- if (not cast(PtrDType, x.src[0].dtype).local) or store_back:
- for oidx, ridx in zip(ctx.idxs, ctx.ridxs):
- if oidx is not ridx: valid = valid * oidx.eq(0)
- return UOp(Ops.STORE, dtypes.void, (buf.index(idx, valid), x.src[2]))
-
- def lower_const(x:UOp):
- assert all(v.mask is None for v in unwrap(x.st).views), f"VIEW in CONST/DEFINE_VAR source must be unmasked, got {x.st}"
- return x.replace(src=())
+ ret = UOp(Ops.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis), tag=1)
+ # REDUCE supports both "horizontal" reduction and range reduction. the horizontal elements are taken in the nearest group
+ return UOp(Ops.REDUCE, x.dtype, (ret,)+tuple(reduce_range), x.arg[0])
+
+ def lower_store(ctx: IndexContext, x: UOp, buf: UOp):
+ # TODO: reenable after REDUCE_AXIS is fixed
+ #assert x.src[1].shape == x.src[0].shape, f"shape mismatch on store {x.src[1].shape} != {x.src[0].shape}"
+
+ new_idxs = shape_to_idx(x.src[0].shape, ctx.axis_types, ctx.start)
+ idx, valid = x.st_arg.to_indexed_uops(new_idxs)
+ used_idxs = [x for x in UOp.sink(idx, valid).toposort() if x in new_idxs]
+ real_new_idxs = []
+ for i in range(len(x.src[0].shape)):
+ if new_idxs[i] in used_idxs or len(ctx.idxs) <= i: real_new_idxs.append(new_idxs[i])
+ else: real_new_idxs.append(ctx.idxs[i])
+
+ stored = subblock(ctx, real_new_idxs, x.src[1])
+ used_ranges = [x for x in used_idxs if x.op is Ops.RANGE]
+ ret = buf.index(idx, valid).store(stored, *used_ranges)
+
+ # insert BARRIER if we are ending a LOCAL, IF if we are ending a GROUP_REDUCE
+ if cast(PtrDType, buf.dtype).addrspace == AddrSpace.LOCAL and \
+ any(ctx.axis_types[x.arg%1000] in {AxisType.GROUP_REDUCE, AxisType.LOCAL} for x in used_ranges):
+ ret = ret.barrier()
+ range_gates = [x.eq(0) for x in used_ranges if ctx.axis_types[x.arg%1000] == AxisType.GROUP_REDUCE]
+ if len(range_gates): ret = UOp(Ops.IF, src=(functools.reduce(operator.and_, range_gates), ret))
+ return ret
+
+ def fixup_wmma(ctx:IndexContext, x:UOp):
+ if x.tag is not None: return None
+ new_idxs = shape_to_idx(x.src[0].shape, ctx.axis_types, ctx.start)
+ full_new_idx = list(ctx.idxs)
+ for a in x.arg[-1]: full_new_idx[a] = new_idxs[a]
+
+ srcs = subblock(ctx, full_new_idx, UOp.sink(*x.src)).src
+
+ # NOTE: this assumes these are expanded. which now shouldn't change anything
+ new_x_arg_m2 = tuple([tuple([(full_new_idx[a].arg[0][0], sz) for a,sz in v]) for v in x.arg[-2]])
+ new_x_arg_m1 = tuple([full_new_idx[a].arg[0][0] for a in x.arg[-1]])
+ return x.replace(src=srcs, arg=x.arg[:-2]+(new_x_arg_m2, new_x_arg_m1), tag=1)

  pm_lowerer = PatternMatcher([
+ # TODO: remove these hacks
+ # hack for old style CONST(VIEW) (now it's just VIEW(CONST))
+ (UPat((Ops.DEFINE_VAR, Ops.CONST), src=(UPat(Ops.VIEW, name="v"),), name="c"), lambda c,v: c.replace(src=()).view(v.arg)),
+ # hack for old style VALID (now it's just VIEW(CONST))
+ (UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c"), UPat(Ops.CONST, arg=0)), lambda c,v: c.replace(src=()).view(v.arg)),
+
+ # consts and loads
+ (UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"),), name="view"),
+ lambda ctx,view,c: c if all(x.mask is None for x in view.arg.views) else view.arg.to_indexed_uops(ctx.idxs)[1].where(c, c.const_like(0))),
+ (UPat(Ops.LOAD, src=(UPat.var("buf").view(),), allow_any_len=True, name="x"),
+ lambda ctx,buf,x: UOp(Ops.LOAD, x.dtype, (buf.index(*x.st_arg.to_indexed_uops(ctx.idxs)),)+x.src[1:])),
+
+ # reduce/view_const
  (UPat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis),
- (UPat((Ops.CONST, Ops.DEFINE_VAR), src=(UPat(Ops.VIEW),), name="x"), lower_const),
- (UPat(Ops.VALID, src=(UPat(Ops.VIEW),), name="x"), lambda ctx,x: x.st_arg.to_indexed_uops(ctx.idxs)[1]),
- # rewrite LOAD/STORE VIEW to LOAD/STORE with indexed
- (UPat((Ops.LOAD, Ops.STORE), src=(UPat(), UPat(Ops.VIEW)), allow_any_len=True, name="x"), lower_load_store),
- (UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"), UPat.const(dtypes.bool, True))), lambda b, idx: b.index(idx)),
- ])
+ (UPat(Ops.STORE, src=(UPat.var("buf").view(),), allow_any_len=True, name="x"), lower_store),
+ (UPat(Ops.WMMA, name="x"), fixup_wmma),

- def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp:
- sink = graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts))
- # expand_rewrite turns this into a vectorized program
- return expand_rewrite(sink)
+ # axis fixups for WMMA
+ (UPat((Ops.CONTRACT, Ops.UNROLL), name="x"),
+ lambda ctx,x: x.replace(tag=1, arg=tuple([(ctx.idxs[a].arg[0][0], sz) for a,sz in x.arg])) if x.tag is None else None),
+ ])
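For orientation, the new lowerer derives per-axis index UOps from a tuple of AxisType values instead of the old global/local/reduce bookkeeping. A minimal sketch of exercising the new shape_to_idx helper directly (this is a 0.11 internal, so the import path and signature may change; the concrete shape and axis types are illustrative):

```python
# Sketch: build index UOps for a 3-axis kernel shape the way the new lowerer does.
# LOOP and REDUCE axes become Ops.RANGE; UPCAST/UNROLL axes become Ops.UNROLL vectors.
from tinygrad.uop.ops import AxisType
from tinygrad.codegen.lowerer import shape_to_idx

idxs = shape_to_idx((16, 8, 4), (AxisType.LOOP, AxisType.REDUCE, AxisType.UPCAST))
for u in idxs:
    print(u.op, u.arg)   # expect RANGE (arg 0), RANGE (arg 1), UNROLL (arg ((2, 4),))
```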
tinygrad/codegen/opt/__init__.py
@@ -0,0 +1,38 @@
+ # opt opinionatedly transforms an ast into an optimized ast using either heuristics or beam search
+
+ from tinygrad.codegen.opt.kernel import Kernel
+ from tinygrad.codegen.opt.heuristic import hand_coded_optimizations
+ from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops
+ from tinygrad.helpers import NOOPT, BEAM, USE_TC, getenv
+ from tinygrad.renderer import Renderer
+ from tinygrad.uop.spec import type_verify
+
+ def get_optimized_ast(ast:UOp, renderer:Renderer) -> UOp:
+ """
+ Optimize an AST based on heuristics or BEAM search.
+
+ Args:
+ ast: The Ops.SINK rooted AST
+ renderer: The renderer used to generate the code
+
+ Returns:
+ The Ops.SINK rooted AST transformed to apply the opts and with a KernelInfo in the arg.
+ """
+
+ k = Kernel(ast, opts=renderer)
+ if ast.arg is not None and ast.arg.opts_to_apply is not None: k.apply_opts(ast.arg.opts_to_apply)
+ elif not NOOPT:
+ if not k.apply_tensor_cores(USE_TC.value): k.apply_opts(hand_coded_optimizations(k))
+ if BEAM >= 1:
+ from tinygrad.codegen.opt.search import beam_search, bufs_from_lin
+ kb = Kernel(ast, opts=renderer)
+ rawbufs = bufs_from_lin(kb, allocate=False)
+ k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
+ ret = k.get_optimized_ast()
+ if __debug__: type_verify(list(ret.toposort()))
+ return ret
+
+ pm_optimize = PatternMatcher([
+ (UPat(Ops.SINK, name="ast"), lambda ctx,ast:
+ get_optimized_ast(ast, ctx) if (ast.arg is None or ast.arg.opts_to_apply is not None) and ast.src[0].st is not None else None),
+ ])
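get_optimized_ast is normally reached through the scheduler rather than called directly; the NOOPT, USE_TC, and BEAM context variables it imports from tinygrad.helpers are the user-facing switches. A hedged end-to-end sketch (the shapes and the BEAM value are arbitrary):

```python
# Sketch: force the BEAM path of get_optimized_ast for whatever kernels this matmul produces.
# BEAM/NOOPT are read from the environment when tinygrad is imported, so set them first.
import os
os.environ["BEAM"] = "2"     # BEAM >= 1 takes the beam_search branch above
os.environ["NOOPT"] = "0"    # leave the heuristic / tensor-core path enabled otherwise

from tinygrad import Tensor

a, b = Tensor.rand(256, 256), Tensor.rand(256, 256)
(a @ b).realize()            # kernels for this matmul are optimized via beam search
```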
tinygrad/codegen/opt/heuristic.py
@@ -0,0 +1,125 @@
+ import itertools
+ from tinygrad.codegen.opt.kernel import Kernel, Opt, OptOps, KernelOptError, AxisType
+ from tinygrad.helpers import getenv, DEBUG, prod, NOLOCALS
+ from tinygrad.dtype import ImageDType
+ from tinygrad.uop.ops import Ops, resolve
+
+ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
+ # make a copy so it does not mutate the input
+ k = k.copy()
+
+ # should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
+ MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
+ if k.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
+ k.reduceop is not None and k.reduceop.arg[0] is Ops.ADD and len(k.full_shape) >= 2 and k.opts.has_shared and \
+ (mulop:=k.reduceop.src[0]).op is Ops.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD:
+ st0, st1 = k.sts[k.bufs.index(mulop.src[0])], k.sts[k.bufs.index(mulop.src[1])]
+ strides0, strides1 = st0.real_strides(), st1.real_strides()
+ def has_expanded_axis(shape, strides): return any(resolve(s > 1) and not resolve(st != 0) for s,st in zip(shape,strides))
+ if strides0[first_reduce:=(k.axes_of(AxisType.REDUCE)[0])] == 1 and \
+ not (has_expanded_axis(st0.shape, strides0) and has_expanded_axis(st1.shape, strides1)):
+ for global_idx in k.axes_of(AxisType.GLOBAL):
+ if k.full_shape[first_reduce]%MV_THREADS_PER_ROW == 0 and k.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
+ if DEBUG >= 3:
+ print(f"MATVEC: {k.full_shape=} {first_reduce=} {strides0=} {MV_BLOCKSIZE=} {MV_THREADS_PER_ROW=} {MV_ROWS_PER_THREAD=}")
+ if MV_THREADS_PER_ROW > 1: k.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
+ if MV_BLOCKSIZE > 1: k.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
+ if MV_ROWS_PER_THREAD > 1: k.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD))
+ return k.applied_opts
+
+ # are we grouping? (requires local shape support)
+ if resolve(prod(k.sts[0].shape[i] for i in k.upcastable_dims) <= 2048, False):
+ for sz in [16]:
+ try:
+ k.apply_opt(Opt(OptOps.GROUPTOP, 0, sz))
+ break
+ except KernelOptError: pass
+
+ # upcast float4 images
+ for buf_index,buf in enumerate(k.bufs):
+ if isinstance(buf.src[0].dtype, ImageDType):
+ if (unit_stride_axes_mul_4 := [i for i in k.sts[buf_index].unit_stride_axes(ignore_valid=True) if k.sts[buf_index].shape[i]%4 == 0]):
+ if (axis:=unit_stride_axes_mul_4[0]) in k.upcastable_dims:
+ k.apply_opt(Opt(OptOps.UPCAST, axis, 4))
+ elif axis in k.unrollable_dims:
+ k.apply_opt(Opt(OptOps.UNROLL, k.unrollable_dims.index(axis), 4))
+
+ # no more opt if we are grouping
+ if k.group_for_reduces: return k.applied_opts
+
+ # **** below this line need to be optional and benchmarked ****
+
+ # if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
+ to_upcast: list[int] = []
+ # upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first)
+ for axis in k.upcastable_dims:
+ if k.full_shape[axis] <= 7 and any(st.axis_is_masked(axis) for st in k.sts) and \
+ prod(k.full_shape[j] for j in to_upcast) * k.full_shape[axis] <= 7 * 7:
+ if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
+ to_upcast.append(axis)
+ for axis in to_upcast[::-1]: k.apply_opt(Opt(OptOps.UPCAST, axis, 0))
+
+ # potentially do more upcasts of non reduce axes based on a heuristic
+ is_dsp = k.opts is not None and k.opts.device == "DSP"
+ upcasted_axis: set[int] = set()
+ while resolve(prod(k.sts[0].shape[i] for i in k.upcastable_dims) >= 1024):
+ xb_choices = []
+ # consider all upcastable axes with 3 or 4 upcast (128 on the DSP)
+ for axis, upcast_amount in itertools.product(k.upcastable_dims, ([128] if not len(upcasted_axis) else []) if is_dsp else [3,4]):
+ # if we haven't upcasted it, it mods, and buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
+ if axis in upcasted_axis or k.full_shape[axis]%upcast_amount != 0: continue
+ if any(st.views[-1].strides[axis] == 0 and \
+ all(x != 0 for t,x in zip(k.axis_types, st.real_strides()) if t in (AxisType.UPCAST, AxisType.UNROLL)) for st in k.sts):
+ xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in k.sts),
+ sum(st.views[-1].strides[axis] for st in k.sts), axis, upcast_amount))
+ if xb_choices:
+ xb_choices = sorted(xb_choices)
+ if DEBUG >= 4: print(f"more upcast axis : {xb_choices}")
+ k.apply_opt(Opt(OptOps.UPCAST, xb_choices[0][2], xb_choices[0][3]))
+ upcasted_axis.add(xb_choices[0][2])
+ else: break
+
+ # if last reduce dim is small(ish), loop unroll the reduce
+ # NOTE: this can fail on multireduce with mismatching dimensions, this is okay
+ try:
+ upcast_size = prod(k.full_shape[a] for a in k.axes_of(AxisType.UPCAST, AxisType.UNROLL))
+ if k.unrollable_dims and (upcast_size <= 4 or not k.axes_of(AxisType.UNROLL)) and (upcast_size < 64):
+ if (s:=k.full_shape[k.unrollable_dims[-1]]) <= 32:
+ k.apply_opt(Opt(OptOps.UNROLL, len(k.unrollable_dims)-1, 0))
+ # if it's small, upcast a second reduce dimension too
+ if k.unrollable_dims and s <= 3 and k.full_shape[k.unrollable_dims[-1]] <= 3:
+ k.apply_opt(Opt(OptOps.UNROLL, len(k.unrollable_dims)-1, 0))
+ else:
+ for splits in [4]:
+ if k.full_shape[axis:=k.unrollable_dims[-1]]%splits == 0:
+ k.apply_opt(Opt(OptOps.UNROLL, len(k.unrollable_dims)-1, splits))
+ break
+ except KernelOptError: pass
+
+ # if nothing at all is upcasted and it's easy to, do an upcast
+ for splits in [4]:
+ # TODO: somehow this never hits a reduce
+ if not k.upcasted and k.upcastable_dims and k.full_shape[k.upcastable_dims[-1]] % splits == 0:
+ k.apply_opt(Opt(OptOps.UPCAST, k.upcastable_dims[-1], splits))
+
+ # **** local groups ****
+
+ if k.opts.has_local:
+ if NOLOCALS:
+ k.apply_opt(Opt(OptOps.NOLOCALS))
+ else:
+ # prioritize making expand axes local
+ local_axis_ranking = [(any(st.views[-1].strides[axis] == 0 for st in k.sts), axis) for axis in k.axes_of(AxisType.GLOBAL, AxisType.LOOP)]
+ to_local: list[tuple[int, int]] = []
+ for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])):
+ local_size = prod(sz for _, sz in to_local)
+ local_sz: int|None = next((x for x in ([32] * (axis == 0) + [16,8,4,3,2]) if k.full_shape[axis] % x == 0 and local_size * x <= 128), None)
+ if local_sz is not None: to_local.append((axis, local_sz))
+ deleted_shape = 0
+ for axis, local_sz in sorted(to_local[:3]):
+ axis = axis - deleted_shape
+ will_delete_shape = local_sz == k.full_shape[axis]
+ k.apply_opt(Opt(OptOps.LOCAL, axis, local_sz))
+ if will_delete_shape: deleted_shape += 1
+
+ return k.applied_opts
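hand_coded_optimizations only returns a list of Opt values; the caller (get_optimized_ast above) replays them with k.apply_opts. A purely illustrative sketch of what such a list can look like; the specific axes and amounts are made up and depend entirely on the kernel and backend:

```python
# Illustrative only: a plausible heuristic output for a GPU reduction kernel.
from tinygrad.codegen.opt.kernel import Opt, OptOps

example_opts = [
    Opt(OptOps.UPCAST, 0, 4),    # upcast an output axis by 4
    Opt(OptOps.LOCAL, 0, 16),    # move part of a global axis into the workgroup
    Opt(OptOps.UNROLL, 0, 4),    # unroll the last reduce axis by 4
]
# replaying on a fresh Kernel built from the same AST:
# k = Kernel(ast, opts=renderer); k.apply_opts(example_opts)
```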