tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package as published to their public registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that registry.
Files changed (131)
  1. tinygrad/__init__.py +1 -1
  2. tinygrad/apps/llm.py +206 -0
  3. tinygrad/codegen/__init__.py +116 -0
  4. tinygrad/codegen/devectorizer.py +315 -172
  5. tinygrad/codegen/expander.py +8 -16
  6. tinygrad/codegen/gpudims.py +89 -0
  7. tinygrad/codegen/linearize.py +205 -203
  8. tinygrad/codegen/lowerer.py +92 -139
  9. tinygrad/codegen/opt/__init__.py +38 -0
  10. tinygrad/codegen/opt/heuristic.py +125 -0
  11. tinygrad/codegen/opt/kernel.py +510 -0
  12. tinygrad/{engine → codegen/opt}/search.py +51 -35
  13. tinygrad/codegen/opt/swizzler.py +134 -0
  14. tinygrad/codegen/opt/tc.py +127 -0
  15. tinygrad/codegen/quantize.py +67 -0
  16. tinygrad/device.py +122 -132
  17. tinygrad/dtype.py +152 -35
  18. tinygrad/engine/jit.py +81 -54
  19. tinygrad/engine/memory.py +46 -27
  20. tinygrad/engine/realize.py +82 -41
  21. tinygrad/engine/schedule.py +70 -445
  22. tinygrad/frontend/__init__.py +0 -0
  23. tinygrad/frontend/onnx.py +1253 -0
  24. tinygrad/frontend/torch.py +5 -0
  25. tinygrad/gradient.py +19 -27
  26. tinygrad/helpers.py +95 -47
  27. tinygrad/nn/__init__.py +7 -8
  28. tinygrad/nn/optim.py +72 -41
  29. tinygrad/nn/state.py +37 -23
  30. tinygrad/renderer/__init__.py +40 -60
  31. tinygrad/renderer/cstyle.py +143 -128
  32. tinygrad/renderer/llvmir.py +113 -62
  33. tinygrad/renderer/ptx.py +50 -32
  34. tinygrad/renderer/wgsl.py +27 -23
  35. tinygrad/runtime/autogen/am/am.py +5861 -0
  36. tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
  37. tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
  38. tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
  39. tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
  40. tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
  41. tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
  42. tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
  43. tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
  44. tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
  45. tinygrad/runtime/autogen/comgr.py +35 -9
  46. tinygrad/runtime/autogen/comgr_3.py +906 -0
  47. tinygrad/runtime/autogen/cuda.py +2419 -494
  48. tinygrad/runtime/autogen/hsa.py +57 -16
  49. tinygrad/runtime/autogen/ib.py +7171 -0
  50. tinygrad/runtime/autogen/io_uring.py +917 -118
  51. tinygrad/runtime/autogen/kfd.py +748 -26
  52. tinygrad/runtime/autogen/libc.py +613 -218
  53. tinygrad/runtime/autogen/libusb.py +1643 -0
  54. tinygrad/runtime/autogen/nv/nv.py +8602 -0
  55. tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
  56. tinygrad/runtime/autogen/opencl.py +2 -4
  57. tinygrad/runtime/autogen/sqtt.py +1789 -0
  58. tinygrad/runtime/autogen/vfio.py +3 -3
  59. tinygrad/runtime/autogen/webgpu.py +273 -264
  60. tinygrad/runtime/graph/cuda.py +3 -3
  61. tinygrad/runtime/graph/hcq.py +68 -29
  62. tinygrad/runtime/graph/metal.py +29 -13
  63. tinygrad/runtime/graph/remote.py +114 -0
  64. tinygrad/runtime/ops_amd.py +537 -320
  65. tinygrad/runtime/ops_cpu.py +108 -7
  66. tinygrad/runtime/ops_cuda.py +12 -14
  67. tinygrad/runtime/ops_disk.py +13 -10
  68. tinygrad/runtime/ops_dsp.py +47 -40
  69. tinygrad/runtime/ops_gpu.py +13 -11
  70. tinygrad/runtime/ops_hip.py +6 -9
  71. tinygrad/runtime/ops_llvm.py +35 -15
  72. tinygrad/runtime/ops_metal.py +29 -19
  73. tinygrad/runtime/ops_npy.py +5 -3
  74. tinygrad/runtime/ops_null.py +28 -0
  75. tinygrad/runtime/ops_nv.py +306 -234
  76. tinygrad/runtime/ops_python.py +62 -52
  77. tinygrad/runtime/ops_qcom.py +28 -39
  78. tinygrad/runtime/ops_remote.py +482 -0
  79. tinygrad/runtime/ops_webgpu.py +28 -28
  80. tinygrad/runtime/support/am/amdev.py +114 -249
  81. tinygrad/runtime/support/am/ip.py +211 -172
  82. tinygrad/runtime/support/amd.py +138 -0
  83. tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
  84. tinygrad/runtime/support/compiler_cuda.py +8 -11
  85. tinygrad/runtime/support/elf.py +2 -1
  86. tinygrad/runtime/support/hcq.py +184 -97
  87. tinygrad/runtime/support/ib.py +172 -0
  88. tinygrad/runtime/support/llvm.py +3 -4
  89. tinygrad/runtime/support/memory.py +251 -0
  90. tinygrad/runtime/support/nv/__init__.py +0 -0
  91. tinygrad/runtime/support/nv/ip.py +581 -0
  92. tinygrad/runtime/support/nv/nvdev.py +183 -0
  93. tinygrad/runtime/support/system.py +170 -0
  94. tinygrad/runtime/support/usb.py +268 -0
  95. tinygrad/runtime/support/webgpu.py +18 -0
  96. tinygrad/schedule/__init__.py +0 -0
  97. tinygrad/schedule/grouper.py +119 -0
  98. tinygrad/schedule/kernelize.py +368 -0
  99. tinygrad/schedule/multi.py +231 -0
  100. tinygrad/shape/shapetracker.py +40 -46
  101. tinygrad/shape/view.py +88 -52
  102. tinygrad/tensor.py +968 -542
  103. tinygrad/uop/__init__.py +117 -0
  104. tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
  105. tinygrad/uop/mathtraits.py +169 -0
  106. tinygrad/uop/ops.py +1021 -0
  107. tinygrad/uop/spec.py +228 -0
  108. tinygrad/{codegen → uop}/symbolic.py +239 -216
  109. tinygrad/uop/upat.py +163 -0
  110. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
  111. tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
  112. tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
  113. tinygrad/viz/index.html +203 -403
  114. tinygrad/viz/js/index.js +718 -0
  115. tinygrad/viz/js/worker.js +29 -0
  116. tinygrad/viz/serve.py +224 -102
  117. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
  118. tinygrad-0.11.0.dist-info/RECORD +141 -0
  119. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
  120. tinygrad/codegen/kernel.py +0 -693
  121. tinygrad/engine/multi.py +0 -161
  122. tinygrad/ops.py +0 -1003
  123. tinygrad/runtime/ops_cloud.py +0 -220
  124. tinygrad/runtime/support/allocator.py +0 -94
  125. tinygrad/spec.py +0 -155
  126. tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
  127. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
  128. tinygrad/viz/perfetto.html +0 -178
  129. tinygrad-0.10.2.dist-info/RECORD +0 -99
  130. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
  131. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
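The headline change in this release is a module reorganization: tinygrad/ops.py and tinygrad/codegen/symbolic.py are gone, replaced by the new tinygrad/uop/ package (uop/ops.py, uop/symbolic.py, uop/spec.py, ...), and the kernel optimizer moves under tinygrad/codegen/opt/. The hunks below (from tinygrad/codegen/devectorizer.py and tinygrad/codegen/expander.py, judging by their contents) show what that means for imports. A minimal migration sketch for downstream code, covering only names that are visible in this diff:

# 0.10.2
# from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, graph_rewrite, GroupOp
# from tinygrad.codegen.symbolic import sym

# 0.11.0
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, graph_rewrite, GroupOp
from tinygrad.uop.symbolic import sym

Import paths not shown in these hunks may have moved differently; check the file list above for the full mapping.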
@@ -1,88 +1,13 @@
- from typing import Optional, Any, Callable
- import functools, operator
+ from typing import Any, cast
+ import functools, operator, itertools
  from collections import defaultdict
- from tinygrad.dtype import dtypes, ImageDType, PtrDType
- from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, resolve
- from tinygrad.ops import graph_rewrite, GroupOp
- from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym, mulacc_unrolled
- from tinygrad.helpers import getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, DEVECTORIZE
- from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES
+ from dataclasses import dataclass
+ from tinygrad.dtype import dtypes, ImageDType, PtrDType, DType, AddrSpace
+ from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, graph_rewrite, GroupOp, identity_element
+ from tinygrad.uop.symbolic import split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat
+ from tinygrad.helpers import getenv, flatten, AMX, prod, partition
  from tinygrad.renderer import Renderer

- # ***** float4/image store handling *****
-
- def fold_expanded(ex, buf):
- new_srcs = dedup(list(ex.src))
- old_new_srcs = new_srcs[:]
- is_load, is_image = new_srcs[0].op is Ops.LOAD, isinstance(buf.dtype, ImageDType)
-
- # TODO: get the device from the buffer somehow
- # NOTE: this can't be Device.DEFAULT because it opens devices
- if buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType): return None
- lengths = [4] if is_image else ([8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else ([16,8,4,2] if AMX else [4,2]))
-
- # first, extract all the relevant offsets
- offsets_rootsrc: defaultdict[Any, dict] = defaultdict(dict)
- for i,s in enumerate(new_srcs):
- idx = s.src[0].src[1]
- if s.dtype.count != 1 or (is_image and idx.dtype.count == 2): continue
- if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
- elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
- else: root_src, arg = idx, 0
- # add gates for gated
- if len(s.src[0].src) == 3: root_src = (s.src[0].src[2], root_src)
- assert arg not in offsets_rootsrc[root_src], f"{offsets_rootsrc[root_src][arg]} != {i} with {len(s.src)} sources"
- offsets_rootsrc[root_src][arg] = i
-
- # then rewrite everything we can
- used: set[tuple[UOp, UOp]] = set()
- for rootsrc, offsets in offsets_rootsrc.items():
- for o in offsets:
- for fold_length in lengths:
- if all((rootsrc,o+i) not in used and o+i in offsets for i in range(fold_length)):
- load_1 = new_srcs[offsets[o]]
- new_src = list(load_1.src)
- oidx = new_src[0].src[1]
- if oidx.divides(fold_length) is None: continue
- if is_image:
- # for images, we rewrite the index. it must evenly divide 4 from the above check
- new_src[0] = buf.index(
- UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((oidx // 4) % buf.dtype.shape[1], (oidx // (4*buf.dtype.shape[1])))),
- rootsrc[0] if isinstance(rootsrc, tuple) else None)
- else:
- # for non image, we upcast the index pointer
- new_src[0] = new_src[0].cast(new_src[0].dtype.base.vec(fold_length).ptr(size=new_src[0].dtype.size, local=new_src[0].dtype.local))
- # generate the folded new_srcs
- if is_load:
- new_load = UOp(Ops.LOAD, load_1.dtype.vec(fold_length), tuple(new_src))
- for i in range(fold_length): new_srcs[offsets[o+i]] = new_load.gep(i)
- else: # vectorize the store
- new_src[1] = UOp(Ops.VECTORIZE, new_src[1].dtype.vec(fold_length), tuple(new_srcs[offsets[o+i]].src[1] for i in range(fold_length)))
- for i in range(fold_length): new_srcs[offsets[o+i]] = UOp(Ops.STORE, dtypes.void, tuple(new_src)) if i == 0 else None
- used.update((rootsrc,o+i) for i in range(fold_length))
-
- # dedup expand for LOAD
- if is_load and len(old_new_srcs) != len(ex.src): new_srcs = [new_srcs[old_new_srcs.index(s)] for s in ex.src]
- # remove Nones for STORE
- return UOp(ex.op, ex.dtype, tuple(x for x in new_srcs if x is not None), ex.arg) if len(used) else None
-
- def fix_unfoldable_image_load(load:UOp, buf:UOp):
- if not isinstance(buf.dtype, ImageDType) or (oidx:=load.src[0].src[1]).dtype.count == 2: return None
- id4 = oidx % 4
- new_src = list(load.src)
- # TODO: copied logic from above
- new_src[0] = load.src[0].src[0].index(
- UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((oidx // 4) % buf.dtype.shape[1], (oidx // (4*buf.dtype.shape[1])))),
- load.src[0].src[2] if len(load.src[0].src) == 3 else None)
- vec_load = UOp(Ops.LOAD, load.dtype.vec(4), tuple(new_src))
- return functools.reduce(lambda ret, i: id4.ne(i).where(ret, vec_load.gep(i)), range(4), load.const_like(float('nan')))
-
- buf_idx_pat = UPat(Ops.INDEX, src=(UPat.var("buf"),), allow_any_len=True)
- float4_folding = PatternMatcher([
- (UPat(Ops.VECTORIZE, src=UPat(Ops.LOAD, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded),
- (UPat((Ops.BARRIER, Ops.SINK), src=UPat(Ops.STORE, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded),
- ])
-
  # ***** image load valid simplification *****

  def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
@@ -95,7 +20,8 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
  # can drop valid if idx is out of bound when valid is False
  drop_stmt = []
  for stmt in split_uop(valid, Ops.AND):
- X, is_upper_bound, c = parse_valid(stmt)
+ try: X, is_upper_bound, c = parse_valid(stmt)
+ except ValueError: return None

  # for X0 + X1 + ... >= 1, check if it's out of bound when Xi = 0 for all i
  if not is_upper_bound and c == 1 and all(u.op in GroupOp.Irreducible and u.vmin == 0 for u in split_uop(X, Ops.ADD)):
@@ -119,27 +45,173 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
  new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in split_uop(valid, Ops.AND) if s not in drop_stmt]) else None
  return buf.index(idx, new_valid)

- # ***** optional patterns *****
-
- powers_of_two = {2**i:i for i in range(64)}
- @functools.lru_cache(None)
- def get_late_rewrite_patterns(ops, force_transcendental=False):
- pat: list[tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \
- ((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)) if op not in ops or force_transcendental]
- # rewrite SQRT to xpow 0.5
- if Ops.SQRT not in ops: pat.append((UPat(Ops.SQRT, src=UPat.var("d")), lambda d: xpow(d, d.const_like(0.5))))
- # rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1)
- if Ops.AND in ops: pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)]
- # rewrite MUL/IDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y)
- if Ops.SHL in ops: pat += [(UPat.var("x", dtypes.ints)*UPat.cvar("c"), lambda c,x: x << v if (v:=powers_of_two.get(c.arg, 0)) else None)]
- if Ops.SHR in ops:
- pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("c"), lambda x,c: x >> v if (v:=powers_of_two.get(c.arg, 0)) and resolve(x>=0,False) else None)]
- if Ops.NEG in ops:
- pat += [(UPat.var('x')*-1, lambda x: x.alu(Ops.NEG))]
- if Ops.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(Ops.NEG), lambda x,y: x.alu(Ops.SUB, y))]
- if Ops.MULACC in ops: pat += [(UPat.var('a')*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(Ops.MULACC, b, c))]
- return PatternMatcher(pat)
+ def delete_redundant_gates(store:UOp, buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:UOp|None=None) -> UOp|None:
+ if store_gate not in [gate.src[0] for gate in val.toposort() if gate.op is Ops.IF]: return None
+ # remove the gate from the index
+ return UOp.store(buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx), val, *store.src[2:])
+
+ load_store_indexing = PatternMatcher([
+ # simplify valid
+ (UPat(Ops.AND, name="valid"), simplify_valid),
+ # image load valid idx simplification
+ (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load),
+ # index True is just Index
+ (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat(Ops.CONST, arg=True))), lambda buf,start_idx: buf.index(start_idx)),
+ # delete_redundant_gates (after expand)
+ (UPat(Ops.STORE, src=(UPat.any(stidx:=UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")), stidx.cast().named("cast")),
+ UPat.var("val")), name="store", allow_any_len=True), delete_redundant_gates),
+ ])

+ # ***** load/store grouping *****
+
+ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
+ if getenv("UNSAFE_DISABLE_MASK", 0): mask = None
+ # generate the individual indexes
+ midx = graph_rewrite(UOp.sink(*[buf.index(vec.gep(i), mask.gep(i) if mask is not None else None) for i in range(vec.dtype.count)]),
+ symbolic_flat+load_store_indexing, name=f"index_buf_{buf.arg}")
+ # extract all the relevant offsets
+ offsets_rootsrc: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict)
+ for i in range(vec.dtype.count):
+ idx: Any = midx.src[i].src[1]
+ if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
+ elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg
+ elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
+ else: root_src, arg = idx, 0
+ if len(midx.src[i].src) == 3: root_src = (midx.src[i].src[2], root_src)
+ offsets_rootsrc[root_src].setdefault(arg, []).append(i)
+
+ # the buf.dtype is always a pointer
+ ptrdtype = cast(PtrDType, buf.dtype)
+
+ # then rewrite everything we can into groups
+ ret = []
+ idxs: list[int|None] = [None]*vec.dtype.count
+ global_offset = 0
+ for offsets in offsets_rootsrc.values():
+ grouped_offsets = [[x for _,x in group] for _,group in itertools.groupby(enumerate(sorted(offsets.keys())), lambda x: x[1]-x[0])]
+ for grp in grouped_offsets:
+ # get the index offset for this element. using [0] is okay, because they are the same
+ lidx = midx.src[offsets[grp[0]][0]]
+ if len(grp) > 1: lidx = lidx.cast(ptrdtype.base.vec(len(grp)).ptr(size=ptrdtype.size, addrspace=ptrdtype.addrspace))
+ # set the idxs of the output
+ for i,g in enumerate(grp):
+ for oo in offsets[g]: idxs[oo] = global_offset+i
+ # add this lidx to the CAT
+ ret.append(lidx)
+ global_offset += len(grp)
+ assert None not in idxs, f"some idxs are missing {idxs}"
+ # this base thing is for image, we want the CAT to be a normal pointer
+ post_cat = UOp(Ops.PTRCAT, ptrdtype.base.ptr(size=ptrdtype.size, addrspace=ptrdtype.addrspace).vec(vec.dtype.count), tuple(ret))
+ return post_cat.gep(tuple(cast(list[int], idxs)))
+
+ def cat_after_store(cat:UOp, data:UOp, sto:UOp):
+ # TODO: this is written in many places
+ offset = 0
+ ret: list[UOp] = []
+ for s in cat.src:
+ ret.append(s.store(data.gep(tuple(range(offset, offset+s.dtype.count))), *sto.src[2:]))
+ offset += s.dtype.count
+ return UOp(Ops.NOOP, src=tuple(ret))
+
+ def gep_on_store(gep:UOp, st:UOp, sto:UOp):
+ # NOTE: we need to invert the gep here, but it may be an expanding gep
+ # fake argsort. TODO: handle duplicates
+ a = {}
+ for i,x in enumerate(gep.arg): a[x] = i
+ new_arg = tuple(x[1] for x in sorted(a.items()))
+ return gep.src[0].store(st.gep(new_arg), *sto.src[2:])
+
+ load_store_folding = PatternMatcher([
+ (UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat(GroupOp.Defines, name="buf")), UPat.var("vec"))), expand_index),
+ (UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat(GroupOp.Defines, name="buf")), UPat.var("vec"),
+ UPat.var("mask"))), expand_index),
+ # GEP after LOAD
+ (UPat(Ops.LOAD, src=(UPat(Ops.GEP, name="gep"),), name="ld", allow_any_len=True),
+ lambda gep, ld: ld.replace(dtype=ld.dtype.scalar().vec(gep.dtype.count), src=(gep.src[0],)+ld.src[1:]).gep(gep.arg)),
+ # GEP on data of STORE
+ (UPat(Ops.STORE, src=(UPat(Ops.GEP, name="gep"), UPat.var("st")), allow_any_len=True, name="sto"), gep_on_store),
+ # put PTRCAT after LOAD
+ (UPat(Ops.LOAD, src=(UPat(Ops.PTRCAT, name="cat"),), name="ld", allow_any_len=True),
+ lambda cat,ld: UOp(Ops.CAT, ld.dtype, tuple(ld.replace(dtype=x.dtype.base, src=(x,)+ld.src[1:]) for x in cat.src))),
+ # put PTRCAT after STORE
+ (UPat(Ops.STORE, src=(UPat(Ops.PTRCAT, name="cat"), UPat(name="data")), allow_any_len=True, name="sto"), cat_after_store),
+ ])
+
+ # *** correct load/store ***
+
+ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
+ # this splits loads and stores into multiple chunks
+
+ # if there's only one element to load/store, no splitting needed
+ if (sz:=ls.src[0].dtype.count) == 1: return None
+ buf = idx.src[0]
+
+ # determine fold lengths
+ lengths = []
+ must_divide = True
+ if ctx is not None and ctx.device == "DSP":
+ lengths = [128,64,32,16,8,4]
+ must_divide = False
+ elif buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType):
+ pass
+ elif cast(PtrDType, buf.dtype).addrspace == AddrSpace.REG:
+ pass
+ elif isinstance(buf.dtype, ImageDType):
+ lengths = [4]
+ elif ctx is not None and ctx.supports_float4:
+ # TODO: a better way to get this than ctx
+ lengths = [8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else ([16,8,4,2] if AMX else [4,2])
+ lengths.append(1) # worst case, it's not folded
+
+ # filter fold lengths that don't divide
+ if must_divide: lengths = [x for x in lengths if idx.src[1].divides(x) is not None]
+
+ # split based on the fold lengths
+ global_offset = 0
+ ret = []
+ ptrdtype = cast(PtrDType, buf.dtype)
+ while global_offset < sz:
+ # with 1 at the end of the lengths list, this will always hit
+ for fold_length in lengths:
+ if global_offset+fold_length > sz: continue
+ lidx = buf.index(idx.src[1] + global_offset, idx.src[2] if len(idx.src) > 2 else None)
+ if fold_length > 1: lidx = lidx.cast(ptrdtype.base.vec(fold_length).ptr(size=ptrdtype.size, addrspace=ptrdtype.addrspace))
+ if ls.op is Ops.STORE: ret.append(ls.replace(src=(lidx,ls.src[1].gep(tuple(range(global_offset, global_offset+fold_length))))+ls.src[2:]))
+ else: ret.append(ls.replace(src=(lidx,)+ls.src[1:], dtype=ls.dtype.scalar().vec(fold_length)))
+ global_offset += fold_length
+ break
+
+ # if it wasn't split, we return None. otherwise we CAT them
+ if len(ret) <= 1: return None
+ return UOp(Ops.CAT, ls.dtype, tuple(ret)) if ls.op is Ops.LOAD else UOp(Ops.NOOP, src=tuple(ret))
+
+ def image_fixup(ls:UOp):
+ # normal image load or store, with the CAST from expand_index
+ if ls.src[0].op is Ops.CAST and isinstance(image_dtype:=ls.src[0].src[0].dtype, ImageDType):
+ assert ls.src[0].dtype.count == 4, "image must be casted to 4"
+ idx = ls.src[0].src[0]
+ oidx = UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((idx.src[1] // 4) % image_dtype.shape[1], (idx.src[1] // (4*image_dtype.shape[1]))))
+ idx = idx.replace(src=(idx.src[0], oidx)+idx.src[2:])
+ return ls.replace(src=(idx,)+ls.src[1:])
+
+ # this is an unprocessed image without a cast, aka unfoldable image load. this doesn't work for stores
+ if isinstance(image_dtype:=ls.src[0].dtype, ImageDType) and ls.src[0].src[1].dtype != dtypes.int.vec(2):
+ assert ls.op is Ops.LOAD, "if an image store isn't upcasted to 4, we can't store it"
+ idx = ls.src[0]
+ id4 = idx.src[1] % 4
+ oidx = UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((idx.src[1] // 4) % image_dtype.shape[1], (idx.src[1] // (4*image_dtype.shape[1]))))
+ idx = idx.replace(src=(idx.src[0], oidx)+idx.src[2:])
+ vec_load = ls.replace(dtype=ls.dtype.vec(4), src=(idx,)+ls.src[1:])
+ return functools.reduce(lambda ret, i: id4.ne(i).where(ret, vec_load.gep(i)), range(4), ls.const_like(float('nan')))
+
+ return None
+
+ correct_load_store = PatternMatcher([
+ # split LOAD/STORE
+ (UPat((Ops.LOAD, Ops.STORE), src=(UPat(Ops.INDEX, name="idx").cast(),), name="ls", allow_any_len=True), split_load_store),
+ # image indexing, including unfoldable images
+ (UPat((Ops.LOAD, Ops.STORE), name="ls"), image_fixup),
+ ])

  # *** uop expander ***

@@ -155,93 +227,164 @@ def no_vectorized_wmma(wmma:UOp):
  wmma_ex = flatten([[e.gep(i) for i in range(out_sz)] for e in wmmas])
  return UOp(Ops.VECTORIZE, wmma.dtype, tuple(wmma_ex))

- def no_vectorized_alu(alu):
+ def no_vectorized_alu(alu:UOp):
  if alu.dtype.vcount == 1: return None
  alus = tuple(UOp(alu.op, alu.dtype.scalar(), tuple(s.gep(i) for s in alu.src), alu.arg) for i in range(alu.dtype.vcount))
  return UOp(Ops.VECTORIZE, alu.dtype, alus)

- def no_vectorized_load_store(ls:UOp):
- idx = ls.src[0]
- assert isinstance(idx.dtype, PtrDType)
- if idx.dtype.v == 1: return None
- tv = [UOp(ls.op, ls.dtype.scalar(), tuple(j.gep(i) for j in ls.src)) for i in range(idx.dtype.v)]
- return UOp(Ops.VECTORIZE, ls.dtype, tuple(tv))
-
- def no_vectorized_acc(acc:UOp):
+ def no_vectorized_acc(acc:UOp, c:UOp):
  if acc.dtype.count == 1: return None
- alus = tuple(UOp(acc.op, acc.dtype.scalar(),
- tuple(s.gep(i) if j == 0 else s for j,s in enumerate(acc.src)), acc.arg+(i,)) for i in range(acc.dtype.count))
- return UOp(Ops.VECTORIZE, acc.dtype, alus)
+ assert c.arg == 0, "this only supports index 0"
+ new_acc = acc.replace(dtype=acc.dtype.base.scalar().ptr(acc.dtype.count, cast(PtrDType, acc.dtype).addrspace))
+ return UOp(Ops.PTRCAT, acc.dtype, tuple([new_acc.index(UOp.const(dtypes.int, i)) for i in range(acc.dtype.count)]))

  devectorize = PatternMatcher([
  # no ALU on vectorized dtypes
- (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.INDEX), name="alu"), no_vectorized_alu),
+ (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name="alu"), no_vectorized_alu),
  (UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma),
- (UPat(Ops.DEFINE_ACC, name="acc"), no_vectorized_acc),
- (UPat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store),
+ (UPat(Ops.DEFINE_REG, name="acc").index(UPat.cvar("c")), no_vectorized_acc),
  ])

- devectorize_load_store = PatternMatcher([
- # TODO: add vectorized support to transcendental
- (UPat((Ops.INDEX, Ops.EXP2, Ops.LOG2, Ops.SIN), name="alu"), no_vectorized_alu),
- (UPat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store),
- ])
-
- def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:UOp|None=None) -> UOp|None:
- if store_gate not in [gate.src[0] for gate in val.toposort if gate.op is Ops.IF]: return None
- # remove the gate from the index
- return UOp.store(buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx), val)
-
- load_store_indexing = PatternMatcher([
- # late fixup of unfoldable image loads
- (UPat(Ops.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load),
- # simplify valid
- (UPat(Ops.AND, name="valid"), simplify_valid),
- # image load valid idx simplification
- (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load),
- # delete_redundant_gates (after expand)
- (UPat(Ops.STORE, src=(UPat.any(stidx:=UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")), stidx.cast().named("cast")),
- UPat.var("val"))), delete_redundant_gates),
- ])
-
- def move_mask(x:UOp, buf:UOp, idx:UOp, mask:UOp, cast:UOp|None=None) -> UOp:
- # this moves the mask from the indexing to the load/store op for rendering
- nidx = buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx)
- return UOp.load(nidx, x.const_like(0), mask, *x.src[1:], dtype=x.dtype) if x.op is Ops.LOAD else UOp.store(nidx, x.src[1], mask, *x.src[2:])
-
  pm_render = PatternMatcher([
  # for rendering, we use explicit VECTORIZE
  (UPat(Ops.CONST, name='c'),
  lambda c: UOp(Ops.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.vcount) if c.dtype.vcount > 1 else None),
  (UPat(Ops.VCONST, name='c'), lambda c: UOp(Ops.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
  (UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
+ (UPat(Ops.GEP, name='gep'), lambda gep: gep.src[0] if gep.src[0].dtype.vcount == 1 and gep.arg == (0,) else None),
  (UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
- # move masks of loads/stores
- (UPat((Ops.LOAD, Ops.STORE), src=(UPat.any(masked_index:=UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"), UPat.var("mask"))),
- masked_index.cast(None).named("cast")),), allow_any_len=True, name="x"), move_mask),
+ # give any loads that are masked an alt value
+ (UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat())).or_casted(),), allow_any_len=True, name="x"),
+ lambda x: x.replace(src=(x.src[0], x.const_like(0))+x.src[1:]) if len(x.src) == 1 or x.src[1].op is Ops.CUSTOM else None),
  # gate any stores that aren't gated with ifs
- (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
- lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
+ (UPat(Ops.STORE, src=(UPat(src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="idx").or_casted(), UPat()), name="store", allow_any_len=True),
+ lambda store,idx: UOp(Ops.STORE, dtype=store.dtype, src=store.src[:2]+(UOp(Ops.IF, src=(idx.src[2],)),)+store.src[2:]) if \
+ len(store.src) <= 2 or store.src[2].op != Ops.IF else None),
  ])

- # *** uop graph ***
-
- def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
- assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
- supported_ops = tuple(opts.code_for_op.keys()) if opts is not None else ()
- extra_matcher = opts.extra_matcher if opts is not None and opts.extra_matcher is not None else PatternMatcher([])
-
- if DEVECTORIZE:
- # devectorize + load_store_indexing + mulacc_unrolled, mulacc_unrolled must be last because it can break loop_collapse
- sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)+load_store_indexing+
- mulacc_unrolled)
- else:
- # new devectorize only for load/store
- sink = graph_rewrite(sink, sym+devectorize_load_store+mulacc_unrolled)
-
- # optional pre matcher
- if opts is not None and opts.pre_matcher is not None: sink = graph_rewrite(sink, opts.pre_matcher)
-
- # final rules for the renderer (without sym)
- sink = graph_rewrite(sink, symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)+pm_render+extra_matcher)
- return sink
+ # *** Ops.REDUCE -> Ops.DEFINE_ACC ***
+
+ @dataclass
+ class ReduceContext:
+ acc_num: int = 0
+
+ def horizontal_reduce(inp:UOp, out_dtype:DType) -> list[UOp]:
+ # if this has a horizontal reduction component, do that first
+ if inp.dtype != out_dtype:
+ # NOTE: [0 1 2 3 4 5 6 7] -> [0+4, 1+5, 2+6, 3+7]
+ horizontal_amount = inp.dtype.count//out_dtype.count
+ return [inp.gep(tuple(range(i, inp.dtype.count, horizontal_amount))) for i in range(0, horizontal_amount)]
+ return [inp]
+
+ def reduce_to_acc(ctx:ReduceContext, red:UOp):
+ inp, reduce_range = red.src[0], red.src[1:]
+ lst = horizontal_reduce(inp, red.dtype)
+ assert all(x.dtype == red.dtype for x in lst), f"horizontal reduction mismatch {lst[0].dtype} != {red.dtype}"
+ # if we have a range
+ if len(reduce_range) != 0:
+ topo = inp.toposort()
+ stored_ranges = flatten([x.src[2:] for x in topo if x.op is Ops.STORE])
+ input_ranges = tuple([x for x in topo if x.op is Ops.RANGE and x not in reduce_range and x not in stored_ranges])
+ identity = red.const(red.dtype, identity_element(red.arg, red.dtype.scalar()))
+ acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), arg=(ctx.acc_num,)).index(UOp.const(dtypes.int, 0))
+ do_store = acc.store(identity, UOp(Ops.NOOP, src=input_ranges)) if len(input_ranges) else acc.store(identity)
+ lst = [acc.load(do_store, *reduce_range)] + lst # put acc as the first element
+ ctx.acc_num += 1
+ ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
+ return acc.load(acc.store(ret, *reduce_range)) if len(reduce_range) != 0 else ret
+
+ def no_vectorized_reduce(inp:UOp, red:UOp):
+ if inp.dtype != red.dtype:
+ red = red.replace(src=(functools.reduce(lambda x,y: x.alu(red.arg, y), horizontal_reduce(inp, red.dtype)),)+red.src[1:])
+ if red.dtype.vcount == 1: return red
+ # no_vectorize_alu ignoring ranges
+ if red.dtype.vcount == 1: return None
+ alus = tuple(UOp(red.op, red.dtype.scalar(), (red.src[0].gep(i),)+red.src[1:], red.arg) for i in range(red.dtype.vcount))
+ return UOp(Ops.VECTORIZE, red.dtype, alus)
+
+ def reduce_rangeless(red:UOp):
+ # TODO: share code with reduce_unparented
+ if red.arg not in {Ops.ADD, Ops.MAX}: return None
+ if red.src[0].dtype != red.dtype: return None
+ if any(x.op in {Ops.RANGE} for x in red.src[0].toposort()): return None
+ ret = red.src[0]
+ if red.arg is Ops.ADD:
+ for r in red.src[1:]:
+ ret = ret * r.src[0].cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
+ return ret
+
+ def no_range(u:UOp) -> bool: return not any(x.op is Ops.RANGE for x in u.sparents)
+
+ pm_reduce_collapse = PatternMatcher([
+ # lift x+y out of reduce on lt
+ ((UPat.var("x")+UPat.var("y")) < UPat.var("c"), lambda x,y,c: (x < (c-y)) if no_range(y) and no_range(c) else None),
+ # lift x*y out of reduce
+ ((UPat.var("x")*UPat.var("y")) < UPat.var("c"),
+ lambda x,y,c: (x < ((c+y-1) // y)) if no_range(y) and no_range(c) and y.vmin > 0 else None),
+ # lift x+y out of reduce on ne
+ ((UPat.var("x")+UPat.var("y")) != UPat.var("c"), lambda x,y,c: (x != (c-y)) if no_range(y) and no_range(c) else None),
+ # fold the range
+ ((UPat(Ops.RANGE, name="r") < UPat.var("cut")).where(0, UPat.cvar("val")).reduce(arg=Ops.ADD, allow_any_len=True),
+ lambda r,cut,val: (r.src[0]-cut).maximum(0).minimum(r.src[0]).cast(val.dtype) * val),
+ ((UPat(Ops.RANGE, name="r") < UPat.var("cut")).where(UPat.cvar("val"), 0).reduce(arg=Ops.ADD, allow_any_len=True),
+ lambda r,cut,val: cut.maximum(0).minimum(r.src[0]).cast(val.dtype) * val),
+ # REDUCE on ADD
+ ((UPat.var("x")+UPat.var("y")).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
+ lambda x,y,r: x.reduce(*r.src[1:], arg=Ops.ADD) + y.reduce(*r.src[1:],arg=Ops.ADD)),
+ # MUL casted bool
+ ((UPat.var("x") * UPat.var("gate", dtype=dtypes.bool).cast().or_broadcasted(name="b")),
+ lambda x,gate,b=None: gate.broadcast(x.dtype.count).where(x, 0) if b is not None else gate.where(x, 0)),
+ # WHERE on LOAD (works on max too)
+ (UPat.var("gate").where(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load(), 0).reduce(arg=Ops.ADD, allow_any_len=True),
+ lambda buf,idx,gate: buf.index(idx, gate).load()),
+ (UPat.var("gate").where(0, UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load()).reduce(arg=Ops.ADD, allow_any_len=True),
+ lambda buf,idx,gate: buf.index(idx, gate.logical_not()).load()),
+ # INDEX on RANGE / gated RANGE
+ (UPat.var("buf").index(UPat.var("expr"), UPat.var("idx").eq(UPat(Ops.RANGE, name="r").or_casted())),
+ lambda buf,r,idx,expr: buf.index(expr.substitute({r:idx.cast(r.dtype)}), (idx.cast(r.dtype) >= 0) & (idx.cast(r.dtype) < r.src[0]))),
+ # AND on WHERE
+ ((UPat.any(UPat(Ops.DEFINE_VAR, name="x"), UPat(Ops.DEFINE_VAR).gep(name="x")) & UPat.var("y")) \
+ .where(UPat.cvar("c"), 0).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
+ lambda x,y,c,r: y.where(c, 0).reduce(*r.src[1:], arg=Ops.ADD)*x.cast(c.dtype)),
+ # remove REDUCEs that no longer have a RANGE in the src
+ (UPat(Ops.REDUCE, name="red"), reduce_rangeless),
+ # devectorize REDUCE
+ (UPat(Ops.VECTORIZE, name="inp").reduce(name="red", allow_any_len=True), no_vectorized_reduce),
+ # index/load/where. TODO: this is more aggressive than needed
+ (UPat((Ops.INDEX, Ops.LOAD, Ops.WHERE), name="alu"), no_vectorized_alu),
+ ])+sym
+
+ def reduce_collapse(red:UOp):
+ included, not_included = partition(red.parents, lambda x: any(y in x.sparents for y in red.src[1:]))
+ if any(x.op in {Ops.STORE, Ops.REDUCE} for x in included): return None
+ replaces: dict[UOp, UOp] = {}
+ for u in included:
+ for s in u.src:
+ if s in not_included and s not in replaces and s.op not in {Ops.CONST, Ops.VCONST, Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR}:
+ replaces[s] = UOp(Ops.DEFINE_VAR, dtype=s.dtype, arg=(f'in{len(replaces)}', s.vmin, s.vmax))
+ collapse_fxn = red.substitute(replaces)
+ sink = graph_rewrite(collapse_fxn, pm_reduce_collapse, name="reduce_collapse")
+ # TODO: why is REDUCE needed here and just RANGE isn't enough?
+ if any(x.op in {Ops.REDUCE, Ops.RANGE} for x in sink.toposort()): return None
+ return sink.substitute({v:k for k,v in replaces.items()})
+
+ def reduce_unparented(red:UOp):
+ if red.arg not in {Ops.ADD, Ops.MAX}: return None
+ reduce_parented, reduce_unparented = partition(red.src[1:], lambda x: x in red.src[0].sparents)
+ if len(reduce_unparented) == 0: return None
+ ret = red.replace(src=(red.src[0],)+tuple(reduce_parented)) if len(reduce_parented) or red.dtype != red.src[0].dtype else red.src[0]
+ if red.arg is Ops.ADD:
+ for r in reduce_unparented: ret = ret * r.src[0].cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
+ return ret
+
+ pm_reduce = PatternMatcher([
+ # remove any ranges from a REDUCE that aren't referenced in the reduce source
+ (UPat(Ops.REDUCE, name="red"), reduce_unparented),
+ # remove REDUCE without loads (generic arange opt / indexing). TODO: support multi range
+ (UPat(Ops.REDUCE, src=(UPat(), UPat()), name="red"), reduce_collapse),
+ # REDUCE -> DEFINE_ACC+ASSIGN
+ (UPat(Ops.REDUCE, name="red"), reduce_to_acc),
+ # tensor core built in accumulate
+ (UPat(Ops.WMMA, name="wmma") + UPat.var("add"),
+ lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
+ ])+sym
@@ -2,8 +2,7 @@

  import functools, itertools, operator
  from tinygrad.helpers import AMX, dedup, flatten, all_same, prod
- from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, GroupOp, graph_rewrite
- from tinygrad.codegen.symbolic import sym
+ from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, GroupOp

  def _expand_arg_to_idx(args:tuple[tuple[int, int], ...], rpk:dict[int, int]) -> int:
  idx, mul = 0, 1
@@ -15,7 +14,7 @@ def _expand_arg_to_idx(args:tuple[tuple[int, int], ...], rpk:dict[int, int]) ->
  def _choices_from_args(args:tuple[tuple[int, int], ...]) -> list[dict[int, int]]:
  return [dict(x) for x in itertools.product(*[zip(itertools.repeat(axis), range(m)) for axis,m in args])]

- @functools.lru_cache(None)
+ @functools.cache
  def _swizzle_args(cargs:tuple[tuple[int, int], ...], eargs:tuple[tuple[int, int], ...], exclude_args:tuple[int, ...]) -> list[int]:
  return [_expand_arg_to_idx(eargs, {**rpk, **{x:0 for x in exclude_args}} if exclude_args else rpk) for rpk in _choices_from_args(cargs)]

@@ -50,6 +49,9 @@ def do_expand(root:UOp):
  if root.op is Ops.IF:
  # for the first arg of IF, just pass them through ignoring UNROLLS
  new_srcs.append(src)
+ elif root.op in {Ops.REDUCE, Ops.STORE} and src.op is Ops.RANGE:
+ # for any range args of REDUCE, pass them through
+ new_srcs.append(src)
  elif src.dtype.count > 1:
  # put any input dtype > 1 grouped together
  new_srcs.append(UOp(Ops.CAT, src.dtype.scalar().vec(expand_sz*src.dtype.count), (src,)*expand_sz))
@@ -81,12 +83,9 @@ expander = PatternMatcher([
  (UPat(Ops.UNROLL, name="outer", src=(UPat(Ops.UNROLL, name="inner"),)),
  lambda outer, inner: UOp(Ops.UNROLL, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
  # do expansion
- (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.ASSIGN,
- Ops.VECTORIZE, Ops.IF), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
+ (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX,
+ Ops.VECTORIZE, Ops.IF, Ops.REDUCE), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
  (UPat(Ops.CONTRACT, name="con"), do_contract),
- # vectorize DEFINE_ACC
- (UPat(Ops.VECTORIZE, src=UPat(Ops.DEFINE_ACC, name="acc"), name="v"),
- lambda acc,v: acc.replace(dtype=v.dtype, src=(acc.src[0].broadcast(v.dtype.count),)+acc.src[1:])),
  # BARRIERs aren't actually expanded
  (UPat(Ops.BARRIER, src=(UPat(Ops.UNROLL, name="ex"),)),
  lambda ex: UOp(Ops.UNROLL, src=(UOp(Ops.BARRIER, src=ex.src),)*len(ex.src), arg=ex.arg)),
@@ -98,7 +97,7 @@ expander = PatternMatcher([
  ])

  def create_gate(root:UOp) -> UOp|None:
- @functools.lru_cache(None)
+ @functools.cache
  def _gate_srcs(u:UOp, gate:UOp) -> UOp:
  if u.op is Ops.BARRIER: return u
  if u.op is Ops.LOAD and u.src[-1].op is Ops.BARRIER:
@@ -112,10 +111,3 @@ migrate_indexing = PatternMatcher([
  # create gate MUST BE BEFORE expander
  (UPat(Ops.STORE, name="root"), create_gate),
  ])
-
- def expand_rewrite(sink:UOp) -> UOp:
- # initial symbolic + migrate indexing (remove this)
- sink = graph_rewrite(sink, sym+migrate_indexing)
-
- # expand
- return graph_rewrite(sink, sym+expander)
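
Note the old top-level drivers are deleted in these hunks: full_graph_rewrite (devectorizer) and expand_rewrite (expander) are gone, and the float4_folding pass is replaced by the PTRCAT/CAT based grouping above; the rewrite pipeline presumably now lives in the new tinygrad/codegen/__init__.py listed in the file table. The passes themselves remain ordinary PatternMatchers that compose with + and run through graph_rewrite. An illustrative sketch, not code from this diff (my_rules and run_pass are hypothetical names; only the imports shown in the hunks above are assumed to exist):

from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, graph_rewrite
from tinygrad.uop.symbolic import sym

# hypothetical extra rule set, composed with the stock symbolic rules as the hunks above do
my_rules = PatternMatcher([])

def run_pass(sink: UOp) -> UOp:
  # `sink` is assumed to be an Ops.SINK UOp, as in the deleted full_graph_rewrite
  return graph_rewrite(sink, sym+my_rules, name="example_pass")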