tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (131)
  1. tinygrad/__init__.py +1 -1
  2. tinygrad/apps/llm.py +206 -0
  3. tinygrad/codegen/__init__.py +116 -0
  4. tinygrad/codegen/devectorizer.py +315 -172
  5. tinygrad/codegen/expander.py +8 -16
  6. tinygrad/codegen/gpudims.py +89 -0
  7. tinygrad/codegen/linearize.py +205 -203
  8. tinygrad/codegen/lowerer.py +92 -139
  9. tinygrad/codegen/opt/__init__.py +38 -0
  10. tinygrad/codegen/opt/heuristic.py +125 -0
  11. tinygrad/codegen/opt/kernel.py +510 -0
  12. tinygrad/{engine → codegen/opt}/search.py +51 -35
  13. tinygrad/codegen/opt/swizzler.py +134 -0
  14. tinygrad/codegen/opt/tc.py +127 -0
  15. tinygrad/codegen/quantize.py +67 -0
  16. tinygrad/device.py +122 -132
  17. tinygrad/dtype.py +152 -35
  18. tinygrad/engine/jit.py +81 -54
  19. tinygrad/engine/memory.py +46 -27
  20. tinygrad/engine/realize.py +82 -41
  21. tinygrad/engine/schedule.py +70 -445
  22. tinygrad/frontend/__init__.py +0 -0
  23. tinygrad/frontend/onnx.py +1253 -0
  24. tinygrad/frontend/torch.py +5 -0
  25. tinygrad/gradient.py +19 -27
  26. tinygrad/helpers.py +95 -47
  27. tinygrad/nn/__init__.py +7 -8
  28. tinygrad/nn/optim.py +72 -41
  29. tinygrad/nn/state.py +37 -23
  30. tinygrad/renderer/__init__.py +40 -60
  31. tinygrad/renderer/cstyle.py +143 -128
  32. tinygrad/renderer/llvmir.py +113 -62
  33. tinygrad/renderer/ptx.py +50 -32
  34. tinygrad/renderer/wgsl.py +27 -23
  35. tinygrad/runtime/autogen/am/am.py +5861 -0
  36. tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
  37. tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
  38. tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
  39. tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
  40. tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
  41. tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
  42. tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
  43. tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
  44. tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
  45. tinygrad/runtime/autogen/comgr.py +35 -9
  46. tinygrad/runtime/autogen/comgr_3.py +906 -0
  47. tinygrad/runtime/autogen/cuda.py +2419 -494
  48. tinygrad/runtime/autogen/hsa.py +57 -16
  49. tinygrad/runtime/autogen/ib.py +7171 -0
  50. tinygrad/runtime/autogen/io_uring.py +917 -118
  51. tinygrad/runtime/autogen/kfd.py +748 -26
  52. tinygrad/runtime/autogen/libc.py +613 -218
  53. tinygrad/runtime/autogen/libusb.py +1643 -0
  54. tinygrad/runtime/autogen/nv/nv.py +8602 -0
  55. tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
  56. tinygrad/runtime/autogen/opencl.py +2 -4
  57. tinygrad/runtime/autogen/sqtt.py +1789 -0
  58. tinygrad/runtime/autogen/vfio.py +3 -3
  59. tinygrad/runtime/autogen/webgpu.py +273 -264
  60. tinygrad/runtime/graph/cuda.py +3 -3
  61. tinygrad/runtime/graph/hcq.py +68 -29
  62. tinygrad/runtime/graph/metal.py +29 -13
  63. tinygrad/runtime/graph/remote.py +114 -0
  64. tinygrad/runtime/ops_amd.py +537 -320
  65. tinygrad/runtime/ops_cpu.py +108 -7
  66. tinygrad/runtime/ops_cuda.py +12 -14
  67. tinygrad/runtime/ops_disk.py +13 -10
  68. tinygrad/runtime/ops_dsp.py +47 -40
  69. tinygrad/runtime/ops_gpu.py +13 -11
  70. tinygrad/runtime/ops_hip.py +6 -9
  71. tinygrad/runtime/ops_llvm.py +35 -15
  72. tinygrad/runtime/ops_metal.py +29 -19
  73. tinygrad/runtime/ops_npy.py +5 -3
  74. tinygrad/runtime/ops_null.py +28 -0
  75. tinygrad/runtime/ops_nv.py +306 -234
  76. tinygrad/runtime/ops_python.py +62 -52
  77. tinygrad/runtime/ops_qcom.py +28 -39
  78. tinygrad/runtime/ops_remote.py +482 -0
  79. tinygrad/runtime/ops_webgpu.py +28 -28
  80. tinygrad/runtime/support/am/amdev.py +114 -249
  81. tinygrad/runtime/support/am/ip.py +211 -172
  82. tinygrad/runtime/support/amd.py +138 -0
  83. tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
  84. tinygrad/runtime/support/compiler_cuda.py +8 -11
  85. tinygrad/runtime/support/elf.py +2 -1
  86. tinygrad/runtime/support/hcq.py +184 -97
  87. tinygrad/runtime/support/ib.py +172 -0
  88. tinygrad/runtime/support/llvm.py +3 -4
  89. tinygrad/runtime/support/memory.py +251 -0
  90. tinygrad/runtime/support/nv/__init__.py +0 -0
  91. tinygrad/runtime/support/nv/ip.py +581 -0
  92. tinygrad/runtime/support/nv/nvdev.py +183 -0
  93. tinygrad/runtime/support/system.py +170 -0
  94. tinygrad/runtime/support/usb.py +268 -0
  95. tinygrad/runtime/support/webgpu.py +18 -0
  96. tinygrad/schedule/__init__.py +0 -0
  97. tinygrad/schedule/grouper.py +119 -0
  98. tinygrad/schedule/kernelize.py +368 -0
  99. tinygrad/schedule/multi.py +231 -0
  100. tinygrad/shape/shapetracker.py +40 -46
  101. tinygrad/shape/view.py +88 -52
  102. tinygrad/tensor.py +968 -542
  103. tinygrad/uop/__init__.py +117 -0
  104. tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
  105. tinygrad/uop/mathtraits.py +169 -0
  106. tinygrad/uop/ops.py +1021 -0
  107. tinygrad/uop/spec.py +228 -0
  108. tinygrad/{codegen → uop}/symbolic.py +239 -216
  109. tinygrad/uop/upat.py +163 -0
  110. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
  111. tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
  112. tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
  113. tinygrad/viz/index.html +203 -403
  114. tinygrad/viz/js/index.js +718 -0
  115. tinygrad/viz/js/worker.js +29 -0
  116. tinygrad/viz/serve.py +224 -102
  117. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
  118. tinygrad-0.11.0.dist-info/RECORD +141 -0
  119. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
  120. tinygrad/codegen/kernel.py +0 -693
  121. tinygrad/engine/multi.py +0 -161
  122. tinygrad/ops.py +0 -1003
  123. tinygrad/runtime/ops_cloud.py +0 -220
  124. tinygrad/runtime/support/allocator.py +0 -94
  125. tinygrad/spec.py +0 -155
  126. tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
  127. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
  128. tinygrad/viz/perfetto.html +0 -178
  129. tinygrad-0.10.2.dist-info/RECORD +0 -99
  130. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
  131. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
@@ -1,11 +1,11 @@
  # all of symbolic lives here now
- from typing import Any, Literal, cast
+ from typing import Any, cast
  import math, operator, struct, functools
  from collections import defaultdict
- from tinygrad.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
- from tinygrad.dtype import ConstType, dtypes, PtrDType
- from tinygrad.helpers import partition, all_same, prod, getenv, DEBUG, flatten
- from tinygrad.codegen.transcendental import xpow
+ from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
+ from tinygrad.dtype import ConstType, dtypes, PtrDType, AddrSpace, can_safe_cast
+ from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, cdiv, cmod, CORRECT_DIVMOD_FOLDING
+ from tinygrad.uop.decompositions import xpow

  # ******** phase 1 of symbolic used to live in ops, it's the most generic folding rules ********

@@ -18,6 +18,7 @@ def simplify_pow(x:UOp, c:UOp) -> UOp|None:

  def fold_bitcast(root:UOp, c:UOp) -> UOp|None:
  if (from_fmt:=c.dtype.scalar().fmt) is None or (to_fmt:=root.dtype.scalar().fmt) is None: return None
+ if c.dtype.itemsize != root.dtype.itemsize: return None
  def convert(v:Any): return struct.unpack(to_fmt, struct.pack(from_fmt, v))[0]
  return root.const_like(convert(c.arg) if root.dtype.count == 1 else tuple(map(convert, c.arg)))

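Note on the itemsize guard added above: a bitcast just reinterprets bytes, so it is only meaningful when the source and destination dtypes occupy the same number of bytes. A standalone illustration with Python's struct module (which is what convert() uses); this is not tinygrad code:

import struct
bits = struct.pack("i", 0x3f800000)           # the 4-byte pattern of int32 1065353216
assert struct.unpack("f", bits)[0] == 1.0     # same itemsize: a well-defined reinterpret
try:
    struct.unpack("d", struct.pack("i", 1))   # 4 bytes fed to an 8-byte format
except struct.error:
    pass                                      # the mismatch the new guard now rejects early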
@@ -25,6 +26,7 @@ symbolic_simple = PatternMatcher([
  # ** self folding **
  (UPat.var("x") + 0, lambda x: x), # x+0 -> x
  (UPat.var("x") * 1, lambda x: x), # x*1 -> x
+ (UPat.var("x", dtype=dtypes.ints) ^ 0, lambda x: x), # x^0 -> x
  (UPat.var("x") // UPat.var("x"), lambda x: x.const_like(1)), # x//x -> 1
  (UPat.var("x") // 1, lambda x: x), # x//1 -> x
  (UPat.var("x") // -1, lambda x: -x), # x//-1 -> -x
@@ -39,8 +41,10 @@ symbolic_simple = PatternMatcher([
  (UPat(GroupOp.Idempotent, src=(UPat.var("x"), UPat.var("x"))), lambda x: x),
  (UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x),
  (UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, True), UPat.const(dtypes.bool, False)), lambda x: x),
+ (UPat.var("x", dtype=dtypes.ints+(dtypes.bool,)).trunc(), lambda x: x),
  # ** zero folding **
  (UPat.var("x") < UPat.var("x"), lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x < x -> False
+ (UPat.var("x") % UPat.var("x"), lambda x: x.const_like(0)), # x%x -> 0
  (UPat.var("x", dtype=dtypes.ints) != UPat.var("x", dtype=dtypes.ints),
  lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints)
  # x*0 -> 0 or 0*x -> 0
@@ -49,20 +53,38 @@ symbolic_simple = PatternMatcher([
  (UPat.var("x") * 0, lambda x: x.const_like(float("nan") if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
  # ** constant folding **
  # TODO: add const folding for Ops.THREEFRY
- (UPat(GroupOp.ALU, name="a", src=UPat((Ops.VCONST, Ops.CONST))),
- lambda a: a.const_like(exec_alu(a.op, a.dtype, [x.arg for x in a.src], False)) if a.op is not Ops.THREEFRY else None),
+ (UPat(GroupOp.Unary, src=(UPat((Ops.VCONST, Ops.CONST)),), name="a"), lambda a: a.const_like(exec_alu(a.op, a.dtype, [a.src[0].arg], False))),
+ (UPat(GroupOp.Binary-{Ops.THREEFRY}, src=(UPat((Ops.VCONST, Ops.CONST)),)*2, name="a"),
+ lambda a: a.const_like(exec_alu(a.op, a.dtype, [a.src[0].arg, a.src[1].arg], False))),
+ (UPat(GroupOp.Ternary, src=(UPat((Ops.VCONST, Ops.CONST)),)*3, name="a"),
+ lambda a: a.const_like(exec_alu(a.op, a.dtype, [a.src[0].arg, a.src[1].arg, a.src[2].arg], False))),
  # bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
  (UPat.var('x', dtype=dtypes.bool) * UPat.var('y', dtype=dtypes.bool), lambda x,y: x&y),
  (UPat.var('x', dtype=dtypes.bool) + UPat.var('y', dtype=dtypes.bool), lambda x,y: x|y),
  (UPat.var('x', dtype=dtypes.bool).maximum(UPat.var('y', dtype=dtypes.bool)), lambda x,y: x|y),
  # *** cast/bitcast ***
- (UPat(Ops.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)),
+ (UPat(Ops.CAST, name="root", src=(UPat.cvar("c"),)), lambda root, c: root.const_like(c.arg)),
  (UPat((Ops.CAST, Ops.BITCAST), name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
  (UPat(Ops.BITCAST, name="root", src=(UPat.cvar("c"),)), fold_bitcast),
+ # b.cast(a).cast(b) -> b if a preserves all values in b
+ (UPat.var('x').cast().named('a').cast().named('b'), lambda x,a,b: x if x.dtype == b.dtype and can_safe_cast(b.dtype, a.dtype) else None),
  # ** pow **
  (UPat.var("x").alu(Ops.POW, UPat.cvar("c", vec=False)), simplify_pow),
  # positive const ** x
  (UPat.cvar("c", vec=False).alu(Ops.POW, UPat.var("x")), lambda c,x: c if c.arg == 1 else (x*math.log2(c.arg)).exp2() if c.arg > 0 else None),
+ # rules for threefry
+ ((UPat.var('x', dtypes.uint64)&0xFFFFFFFF).cast(dtypes.uint32), lambda x: x.cast(dtypes.uint32)&0xFFFFFFFF), # TODO: why is the and needed?
+ (((UPat.var(None, dtypes.uint64)*(1<<32)) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
+ (((UPat.var('x', dtypes.uint64)*(1<<32)) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))//(1<<32), lambda x: x),
+ # hacks for threefry long removal when padded (TODO: genericize)
+ (UPat.var('x', dtypes.uint32).cast(dtypes.uint64) * UPat.var('y').where(UPat.const(dtypes.uint64, 1<<32), UPat.const(dtypes.uint64, 0)),
+ lambda x,y: y.where(x, 0).cast(dtypes.uint64) * (1<<32)),
+ ((UPat.var('x', dtypes.uint64)&(UPat.var('y').where(UPat.const(dtypes.uint64, 0xFFFFFFFF), UPat.const(dtypes.uint64, 0)))).cast(dtypes.uint32),
+ lambda x,y: y.where(x.cast(dtypes.uint32), 0)),
+ # new decomp rules for threefry
+ (((UPat.var(None, dtypes.uint64)<<32) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
+ (((UPat.var('x', dtypes.uint64)<<32) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))>>32, lambda x: x),
+ (UPat.var('b').where(UPat.var('x', dtypes.uint32).cast(dtypes.uint64), UPat.const(dtypes.uint64, 0)).cast(dtypes.uint32), lambda b,x: b.where(x,0))
  ])

  # ******** phase 2 builds on phase 1, it includes the old "symbolic", rules that match deeper ********
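The new b.cast(a).cast(b) rule above folds a round-trip cast whenever the intermediate dtype preserves every value of the original one (can_safe_cast, imported from tinygrad.dtype, is the gate). A minimal numeric illustration in plain numpy, not the tinygrad API:

import numpy as np
x = np.arange(-4, 4, dtype=np.int32)
assert (x.astype(np.int64).astype(np.int32) == x).all()   # int64 holds every int32: the cast pair is an identity
y = np.array([2**24 + 1], dtype=np.int32)
assert y.astype(np.float32).astype(np.int32)[0] != y[0]   # float32 cannot represent 2**24+1: the rule must not fire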
@@ -72,26 +94,31 @@ def split_uop(x:UOp, sep:Ops):
  for s in x.src: yield from split_uop(s, sep)
  else: yield x

- def fold_unrolled_divs(divs:UOp):
+ def fold_unrolled_divs(divs:UOp, denominator: int, fac=1) -> UOp|None:
  # div pattern in unrolled arange
  # example: (x//4+(x+1)//4+(x+2)//4+(x+3)//4 -> x
- add_chain, denominator, seen_const, ans = list(split_uop(divs, Ops.ADD)), None, [], None
- for u in add_chain:
+ seen_const, ans = [], None
+ for u in split_uop(divs, Ops.ADD):
+ if fac!=1:
+ if u.op is not Ops.MUL or u.src[1].op is not Ops.CONST or u.src[1].arg != fac: return None
+ u = u.src[0]
  if not (u.op is Ops.IDIV and u.src[1].op is Ops.CONST): return None
- if denominator is None: denominator = u.src[1].arg
  if denominator != u.src[1].arg: return None
+ if (s0:=u.src[0]).vmin < 0: return None
  # assumed CONST is the last of an ADD
- if (s0:=u.src[0]).op is Ops.ADD and s0.src[1].op is Ops.CONST and s0.src[1].op is Ops.CONST:
+ if s0.op is Ops.ADD and s0.src[1].op is Ops.CONST and s0.src[1].op is Ops.CONST:
  seen_const.append(s0.src[1].arg)
  s0 = s0.src[0]
  else: seen_const.append(0)
  if ans is None: ans = s0
  if ans is not s0: return None
- if denominator is None: return None
+ if ans is None: return None
  # the first (denominator-len(seen_const)) terms may have been folded to 0 already
  for i in range(denominator-len(seen_const)):
  if ans is not None and 0 <= ans.vmin and ans.vmax + i < denominator: seen_const.append(i)
- return ans if ans is not None and sorted(seen_const)==list(range(denominator)) else None
+ if sorted(seen_const)==list(range(denominator)):
+ return fac*ans
+ return None

  def lt_folding(x:UOp, c:int) -> UOp|None:
  p, np = partition(split_uop(x, Ops.ADD), lambda u: u.const_factor() == 1)
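The identity behind fold_unrolled_divs checks out by brute force: writing x = q*d + r with 0 <= r < d, each term (x+i)//d equals q plus one for the r largest offsets, so the d terms sum to exactly x. The new fac argument handles the same chain with a common constant multiplier on every term, and the new vmin >= 0 guard exists because the identity breaks under truncate-toward-zero division for negative numerators (x = -1, d = 4 sums to 0, not -1), assuming the rendered IDIV has C semantics. A standalone check:

d, fac = 4, 3
for x in range(1000):
    assert sum((x+i)//d for i in range(d)) == x              # plain unrolled-div chain
    assert sum(fac*((x+i)//d) for i in range(d)) == fac*x    # chain with a constant factor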
@@ -112,67 +139,139 @@ def canonicalize_simplex(X:UOp) -> UOp|None:
  ret.append(u)
  return functools.reduce(operator.add, ret) if changed else None

- def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split_rem: bool=False) -> UOp|None:
- # simplify x // y or x % y, None means no change
- # simple cancel div/mod case
- if y.vmin != 0 != y.vmax and (q:=x.vmin//y.vmin) == x.vmin//y.vmax == x.vmax//y.vmin == x.vmax//y.vmax:
- return x - q*y if which is Ops.MOD else x.const_like(q)
-
- if (y.op is not Ops.CONST) or ((c := y.arg) <= 0) or (x.dtype.count > 1): return None
+ def cancel_divmod(d: UOp, x: UOp, y: UOp) -> UOp|None:
+ # simple cancel div/mod case when the range of the numerator lies within a single denominator interval
+ x_min, x_max, y_min, y_max = x.vmin, x.vmax, y.vmin, y.vmax
+ assert isinstance(x_min, int) and isinstance(x_max, int) and isinstance(y_min, int) and isinstance(y_max, int)
+ if y_min==y_max==0: raise ZeroDivisionError(f"{'Division' if d.op is Ops.IDIV else 'Mod'} by zero trying to rewrite {x.alu(d.op, y)}")
+ if y_min*y_max > 0 and (q:=cdiv(x_min,y_min)) == cdiv(x_min,y_max) == cdiv(x_max,y_min) == cdiv(x_max,y_max):
+ return x - q*y if d.op is Ops.MOD else d.const_like(q)
+ return None

- svars, factors, quotients, remainders, gcd, div, const, offset, something_changed = [], [], [], [], c, 1, 0, 0, False
+ def remove_nested_mod(m: UOp, x: UOp, y: UOp) -> UOp|None:
+ # remove nested mod in case the inner mod is a multiple of the outer mod
+ # example: (a%4 + b)%2 -> (a+b)%2
+ if ((c := y.arg) < 0) or x.vmin<0: return None
+ new_xs = []
+ something_changed = False
  for u in split_uop(x, Ops.ADD):
- if u.op is Ops.MOD and which is Ops.MOD and u.src[1].op is Ops.CONST and u.src[1].arg%c == 0:
- u = u.src[0]
- something_changed = True
- v: UOp = u.divides(f:=u.const_factor())
- q, r = divmod(f, c)
- if r==0 or ((which is Ops.MOD or split_rem or u.op is Ops.CONST) and r!=f): something_changed = True
- offset += r*v.vmin
- if u.op is Ops.CONST: const += f
- else: # div is the smallest common divisor of all terms
- if f > 1 and c % f == 0 and (div == 1 or div > f): div = f
- gcd = math.gcd(r, gcd)
- factors.append(f); svars.append(v); quotients.append(q); remainders.append(r) # noqa: E702
-
- lbound = ubound = offset = offset % c
+ if u.op is Ops.MOD:
+ if u.src[1].divides(c) is not None:
+ something_changed = True
+ u = u.src[0]
+ new_xs.append(u)
+ new_x: UOp = functools.reduce(operator.add, new_xs)
+ if something_changed and new_x.vmin>=0: return new_x % y
+ return None
+
+ def fold_binary_numerator(d: UOp, x: UOp, y: UOp) -> UOp|None:
  # we can fold if the expression has only one non-constant term and this term can only take on two values
- if len(svars)==1 and (v:=svars[0]).vmax-v.vmin == 1:
- r = (offset+remainders[0])%c - offset%c
- offset -= r * v.vmin
- if which is Ops.MOD: return r*v + offset
- return (factors[0]-r)//c * v + (const-offset)//c
+ if ((c := y.arg) < 0) or (x.dtype.count > 1): return None
+ x,const = x.pop_const()
+ terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in split_uop(x, Ops.ADD)])
+ if len(terms)==1 and (v:=terms[0]).vmax-v.vmin == 1:
+ y1 = cmod(factors[0]*v.vmin+const, c) if d.op is Ops.MOD else cdiv(factors[0]*v.vmin+const, c) # type: ignore
+ y2 = cmod(factors[0]*v.vmax+const, c) if d.op is Ops.MOD else cdiv(factors[0]*v.vmax+const, c) # type: ignore
+ return (y2-y1)*(v-v.vmin) + y1
+ return None

- # a//c = (a-a%c)/c, if we can fold a%c, we can fold a//c
+ def fold_divmod_congruence(d: UOp, x: UOp, y: UOp) -> UOp|None:
  # within a mod we can freely subtract multiples of c, we use this to see if a is congruent to an expression whose vmin/vmax are between 0 and c
- for (r, v) in zip(remainders, svars):
- if r > c//2:
- if (lbound := lbound + (r:=r-c) * (v.vmax-v.vmin)) < 0: break
- elif (ubound := ubound + r * (v.vmax-v.vmin)) >= c: break
- offset -= r * v.vmin # determine what the new offset would be
- else: # vmin/vmax of the remainder is between 0 and c, we can remove the mod/div
- remainders = [min(r, r-c, key=abs) for r in remainders]
- if which is Ops.MOD: return functools.reduce(operator.add, [r*v for r,v in zip(remainders,svars)], x.const_like(offset))
- return functools.reduce(operator.add, [(f-r)//c * v for f,r,v in zip(factors, remainders,svars)], x.const_like((const-offset)//c))
-
- if gcd != 1: something_changed = True
- if not something_changed:
- if which is Ops.IDIV and (1 < div < c) and (newx:=div_and_mod_folding(x, UOp.const(dtypes.int, div), Ops.IDIV)) is not None: return newx//(c//div)
- return None
+ if (x.vmin<0 and CORRECT_DIVMOD_FOLDING) or ((c := y.arg) < 0) or (x.dtype.count > 1): return None
+ x,const = x.pop_const()
+ terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in split_uop(x, Ops.ADD)])
+ # a//c = (a-a%c)/c, if we can fold a%c, we can fold a//c
+ rems = [min((r:=f%c), r-c, key=abs) for f in factors]
+ if (rem:=sum(r*v for r,v in zip(rems,terms))+const%c).vmin//c==rem.vmax//c and all(f > 0 for f in factors):
+ if d.op is Ops.MOD: return rem - rem.vmin//c*c
+ return sum((f-r)//c * v for f,r,v in zip(factors,rems,terms)) + (const-const%c+rem.vmin//c*c)//c
+ return None
+
+ def divide_by_gcd(d: UOp, x: UOp, y: UOp) -> UOp|None:
+ # x//y -> (x//gcd)//(y//gcd) or x%y -> gcd*(x//gcd)%(y//gcd)
+ terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in split_uop(x, Ops.ADD)])
+ if (gcd := math.gcd(y.arg, *factors)) == 1: return None
+ ret = sum(f//gcd * v for f,v in zip(factors, terms)).alu(d.op, y.const_like(y.arg//gcd))
+ return ret*gcd if d.op is Ops.MOD else ret
+
+ def nest_div_by_smallest_factor(d: UOp, x: UOp, y: UOp) -> UOp|None:
+ # we try and nest the div and see if it allows the numerator to be simplified
+ if ((c := y.arg) < 0) or (x.dtype.count > 1): return None
+ factors = [u.const_factor() for u in split_uop(x.pop_const()[0], Ops.ADD)]
+ # div is the smallest factor of the denominator (greater than 1) out of all "factors"
+ # TODO: there are better ways to pick `div`, this sometimes adds extra divisions
+ # TODO: add same optimization for mod
+ div = min([y.arg]+[abs(f) for f in factors if abs(f) > 1 and (c%f)==0])
+ if (1 < div < c) and (newxs:=(newx:=(x//div)).simplify()) is not newx and x.vmin>=0 and newx.vmin>=0: return newxs//(c//div)
+ return None
+
+ def simplify_remainder(d: UOp, x: UOp, y: UOp) -> UOp|None:
+ # we try and take out the quotient and see if it allows the numerator to be simplified
+ if ((c := y.arg) < 0) or (x.dtype.count > 1): return None
+ x_no_const,const = x.pop_const()
+ terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in split_uop(x_no_const, Ops.ADD)])
+ quotients, remainders = zip(*[divmod(f, c) for f in factors])
+ gcd = math.gcd(c, *remainders) # gcd without const!
+ if const%c==const and gcd==1 and not any(r==0 or (r!=f and d.op is Ops.MOD) for r,f in zip(remainders, factors)): return None
+
  quo, rem = x.const_like(const//c), x.const_like((const%c)//gcd)
- for q,r,f,v in zip(quotients, remainders, factors, svars):
- if which is Ops.IDIV and (not split_rem) and r!=0:
+ for q,r,f,v in zip(quotients, remainders, factors, terms):
+ if d.op is Ops.IDIV and r!=0:
  rem += f//gcd * v
  else:
  rem += r//gcd * v
  quo += q * v

- if which is Ops.MOD: return gcd*(rem % (c//gcd)) + const%gcd
+ # if numerator before/after is negative, and it has remainder, don't simplify because C divmod is different from python divmod.
+ if (x.vmin < 0 or rem.vmin < 0) and remainders: return None
+ if d.op is Ops.MOD: return gcd*(rem % (c//gcd)) + const%gcd
  return rem//(c//gcd)+quo

- symbolic = symbolic_simple+PatternMatcher([
+ def gep_through_wmma(gep:UOp, wmma:UOp):
+ out_sz = prod(x[1] for x in wmma.arg[6][-1])
+ wmma_idxs = gep.arg[::out_sz]
+ for i in range(out_sz):
+ if tuple(x-i for x in gep.arg[i::out_sz]) != wmma_idxs: return None
+ tsrcs = []
+ for s,sz in zip(wmma.src, wmma.arg[6]):
+ src_args = []
+ ssz = prod(x[1] for x in sz)
+ for w in wmma_idxs: src_args += list(range((w//out_sz)*ssz, (w//out_sz)*ssz + ssz))
+ tsrcs.append(s.gep(tuple(src_args)))
+ return UOp(Ops.WMMA, gep.dtype, tuple(tsrcs), wmma.arg)
+
+ gep_pushing = PatternMatcher([
+ # GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
+ (UPat(Ops.GEP, name='g2').f(Ops.GEP, name='g1'),
+ lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(len(g1.arg))))),
+ (UPat(Ops.VECTORIZE, name='vec').f(Ops.GEP, name='gep'),
+ lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
+ (UPat.cvar("c", vec=False).f(Ops.GEP, name="gep"), lambda gep, c: gep.const_like(c.arg)),
+ (UPat(Ops.VCONST, name="c").f(Ops.GEP, name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
+ # GEP on void is skipped
+ (UPat(Ops.GEP, src=(UPat(dtype=dtypes.void, name="x"),)), lambda x: x),
+ # GEP in order is removed
+ (UPat(Ops.GEP, name="g"), lambda g: g.src[0] if not isinstance(g.dtype, PtrDType) and g.arg == tuple(range(g.src[0].dtype.count)) else None),
+ # push all GEPs through ALUs (fix arange stuff)
+ (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu').f(Ops.GEP, name='gep'),
+ lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \
+ if not isinstance(gep.dtype, PtrDType) else None),
+ # CAT can't be rendered. it's a VECTORIZE on vectors, we expand to a single VECTORIZEs with GEPs (TODO: move this later)
+ (UPat(Ops.CAT, name="x"), lambda x: UOp(Ops.VECTORIZE, x.dtype, tuple(y.gep(i) for y in x.src for i in range(y.dtype.count))) \
+ if not isinstance(x.dtype, PtrDType) else None),
+ # VECTORIZE on same GEP
+ (UPat(Ops.VECTORIZE, name="v", src=UPat(Ops.GEP, src=(UPat.var("x"),))), lambda v,x: x.gep(tuple(get_single_element(i.arg) for i in v.src))),
+ # push some GEPs through WMMAs
+ (UPat(Ops.WMMA, name="wmma").f(Ops.GEP, name="gep"), gep_through_wmma),
+ ])
+
+ commutative = PatternMatcher([
  # ** COMMUTATIVE flipping (only for ints) **
+ # NOTE: this can break merging vector math by only flipping some of them
  (UPat(GroupOp.Commutative, dtype=dtypes.int, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
+ ])
+
+ symbolic = symbolic_simple+commutative+PatternMatcher([
  # ** boolean algebra **
  (UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x
  # ** combine terms **
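Several of the rewrites above go through cdiv/cmod from tinygrad.helpers rather than Python's // and %. Assuming those implement C-style truncated division and remainder (the comment "C divmod is different from python divmod" points that way), the two semantics only diverge on negative operands. A reference model, not tinygrad's actual implementation:

def cdiv_ref(x: int, y: int) -> int:
    # truncate the quotient toward zero, as C integer division does
    q = abs(x) // abs(y)
    return -q if (x < 0) != (y < 0) else q

def cmod_ref(x: int, y: int) -> int:
    # C-style remainder takes the sign of the numerator
    return x - cdiv_ref(x, y) * y

assert (-7) // 2 == -4 and (-7) % 2 == 1                  # Python floors
assert cdiv_ref(-7, 2) == -3 and cmod_ref(-7, 2) == -1    # C truncates

This mismatch is why so many of the new helpers bail out when a numerator's vmin can be negative.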
@@ -187,11 +286,12 @@ symbolic = symbolic_simple+PatternMatcher([
  # a conditional with the same results either way is a noop, also fold const conditionals
  (UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
  (UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
+ (UPat.var("cond", dtype=dtypes.bool).logical_not().where(UPat.var("t"), UPat.var("f")), lambda cond, t, f: cond.where(f,t)),
  # alu of two where with same conds can combine, only do if true branch or false branch is const
  (UPat(GroupOp.Binary, name="alu", src=(UPat.var("c").where(UPat.var("t"), UPat.var("f")), UPat.var("c").where(UPat.var("tt"), UPat.var("ff")))), \
  lambda alu,c,t,tt,f,ff: c.where(t.alu(alu.op, tt), f.alu(alu.op, ff)) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None),
- # ALU min==max -> CONST (slow!)
- (UPat(GroupOp.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
+ # ALU/variable min==max -> CONST (slow!)
+ (UPat(GroupOp.ALU|{Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
  # max folding
  (UPat.maximum(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
  # TODO: why does this rule break beautiful_mnist?
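The logical_not/where rule added above is the plain branch flip, and the a.where(b.where(c, d), d) rule added further down in this diff is conjunction of gates; both are easy to confirm with a truth table over Python conditionals (illustration only, no tinygrad types):

for a in (False, True):
    for b in (False, True):
        t, f = 10, 20
        assert (t if not a else f) == (f if a else t)                    # (!a).where(t, f) == a.where(f, t)
        assert ((t if b else f) if a else f) == (t if (a and b) else f)  # a.where(b.where(t, f), f) == (a&b).where(t, f)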
@@ -209,28 +309,42 @@ symbolic = symbolic_simple+PatternMatcher([
  # c0*x<c1 for negative int c0 and non-positive c1
  ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints))<UPat.cvar("c1", vec=False),
  lambda x,c0,c1: (-x)<(-(math.floor(-c1.arg/-c0.arg))) if c0.arg < 0 and c0.arg != -1 and c1.arg <= 0 else None),
- # x//c0<c1 for positive int c0
- ((UPat.var("x", dtype=dtypes.ints)//UPat.cvar("c0", vec=False))<UPat.cvar("c1", vec=False),
- lambda x,c0,c1: x<(c1.arg*c0.arg) if c0.arg > 0 else None),
+ # x//d<c
+ ((UPat.var("x", dtype=dtypes.ints)//UPat.cvar("d", vec=False))<UPat.cvar("c", vec=False),
+ lambda x,d,c: (x<(c.arg*d.arg) if c.arg > 0 else x<(c.arg*d.arg-(d.arg-1))) if d.arg > 0 else None),
  # ** move add/mul consts to end (NOTE: this is still happening before constant folding) **
- (UPat(Ops.ADD, src=(UPat.var("x"), UPat.cvar("c1"))) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
- (UPat(Ops.MUL, src=(UPat.var("x"), UPat.cvar("c1"))) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
+ ((UPat.var("x") + UPat.cvar("c1")) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
+ ((UPat.var("x") * UPat.cvar("c1")) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
  # *** rules from symbolic ***
  # unrolled arange div folding
- (UPat(Ops.ADD, name="divs", src=[UPat(), UPat(Ops.IDIV)]), fold_unrolled_divs),
+ ((UPat() + UPat()//UPat.cvar("d", vec=False)).named("divs"), lambda divs,d: fold_unrolled_divs(divs, d.arg)),
+ ((UPat() + (UPat()//UPat.cvar("d", vec=False))*UPat.cvar("c")).named("divs"), lambda divs,d,c: fold_unrolled_divs(divs, d.arg, c.arg)),
  # generic lt folding
  (UPat.var("x", dtypes.sints)<UPat.cvar("c", vec=False), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None),
+ (UPat.var("x", dtypes.sints)*-1 < UPat.var("y", dtypes.sints)*-1, lambda x,y: y<x),
  # canonicalize a simplex with positive coefficients > 0
  # not x < 1 -> X > 0
  ((UPat.var("x", dtypes.ints)<1).ne(True), lambda x: (newx<1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None),
  # ** div **
  # div folding
- ((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"), lambda x,c,a,d: (x+a*c)//(c*d)), # (x//c+a)//d -> (x+a*c)//(c*d)
- (UPat.var("x", dtypes.sints) // UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.IDIV)),
+ ((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"), lambda x,c,a,d: (x+a*c)//(c*d)
+ if c.vmin>0 and d.vmin>0 and ((x.vmin>=0 and a.vmin>=0) or (x.vmax<=0 and a.vmax<=0)) else None), # (x//c+a)//d -> (x+a*c)//(c*d)
+ (UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.var("y"))), cancel_divmod),
+ (UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), fold_binary_numerator),
+ (UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), fold_divmod_congruence),
+ (UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), divide_by_gcd),
+ (UPat(Ops.MOD, dtypes.sints, name="m", src=(UPat.var("x"), UPat.cvar("y", vec=False))), remove_nested_mod),
+ (UPat((Ops.IDIV), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), nest_div_by_smallest_factor),
+ (UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), simplify_remainder),
+ (UPat.var("x") // UPat.var("d"), lambda x,d: -(x//(-d)) if d.vmax < 0 else None),
+ (UPat.var("x") // UPat.var("d"), lambda x,d: -((-x)//d) if x.vmax <=0 else None),
+ ((UPat.var("x", dtypes.sints)+UPat.cvar("c", vec=False)).named("n")//UPat.cvar("d", vec=False),
+ lambda x,c,n,d: (-(-(c.arg%d.arg + x - (d.arg-1))//d) + c.arg//d.arg) if x.vmax<=0 and n.vmin>=0 and d.arg>0 else None),
  # ** mod **
  # mod folding
- (UPat.var("x") % UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.MOD)),
- ])
+ (UPat.var("x") % UPat.var("d"), lambda x,d: -((-x)%d) if x.vmax <= 0 else None),
+ (UPat.var("x") % UPat.var("d"), lambda x,d: (x%(-d)) if d.vmax < 0 else None),
+ ])+gep_pushing

  symbolic_flat = symbolic+PatternMatcher([
  # ** combine terms (opinionated) **
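The reworked x//d<c rule above now also covers non-positive c: under truncate-toward-zero division the bound tightens by d-1 in that case. An exhaustive check over a small range, using the same reference model of C division as before (an assumption about what the rendered IDIV produces):

def cdiv_ref(x: int, y: int) -> int:
    q = abs(x) // abs(y)
    return -q if (x < 0) != (y < 0) else q

for d in range(1, 8):
    for c in range(-8, 9):
        for x in range(-64, 65):
            expect = x < c*d if c > 0 else x < c*d - (d-1)
            assert (cdiv_ref(x, d) < c) == expect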
@@ -262,19 +376,25 @@ def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None:
  except ValueError: return uop # give up if we cannot parse the valid
  bounds[expr][int(is_upper)] = c

+ # don't simplify any other gates, can lead to OOB, we substitute them back later
+ uop = uop.substitute((load_subs:={u: UOp(Ops.NOOP, arg=u) for u in uop.toposort() if u.op is Ops.INDEX}))
+
  # simplify uop given that valid is True
  for expr,v in bounds.items():
+ v0, v1 = (expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1])
  # some expr has lower bound > upper bound -> valid is an empty set and we return None
- if v[0] is not None and v[1] is not None and v[0] > v[1]: return None
-
- # every candidate is a set of contrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop
+ if v0 > v1: return None
+ # whole node became a const
+ if v0 == v1:
+ uop = uop.substitute({expr:expr.const_like(v0)}).simplify()
+ continue
+ # every candidate is a set of constrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop
  candidates = []
- if expr.op is Ops.ADD and v[0] == 1 and all(u.op in GroupOp.Irreducible for u in split_uop(expr, Ops.ADD)):
+ if expr.op is Ops.ADD and v0 == 1 and all(u.op in GroupOp.Irreducible for u in split_uop(expr, Ops.ADD)):
  # if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
  candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, Ops.ADD)])
  # try checking the whole clause
- if expr in uop.toposort:
- candidates.append([(expr, UOp.variable("fake", expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1], expr.dtype))])
+ if expr in uop.toposort(): candidates.append([(expr, UOp.variable("fake", v0, v1, expr.dtype))])

  for candidate in candidates:
  # if every branch in candidate gives the same simplified uop, we can rewrite the uop
@@ -284,11 +404,13 @@ def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None:
  if all_same([uops.src[1] for uops in newuops]): uop = uop.replace(src=(uop.src[0], newuops[0].src[1]))
  elif all_same(newuops): uop = newuops[0]

+ # put the loads back in
+ uop = uop.substitute({v:k for k,v in load_subs.items()})
  return uop

  def _valid_priority(v: UOp, valids:list[UOp]):
  # we want valid that's in other valids' parents to be first, so it's more likely the other valids get simplified
- try: return sum(-1 if parse_valid(v)[0] in other.toposort else 0 for other in valids)
+ try: return sum(-1 if parse_valid(v)[0] in other.toposort() else 0 for other in valids)
  except ValueError: return 0

  def simplify_valid(valid:UOp) -> UOp|None:
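The core idea of uop_given_valid can be stated without any tinygrad machinery: once a gate pins an expression into [v0, v1], any use of it may be simplified as if it only took those values, and v0 == v1 collapses it to a constant outright. A toy brute-force analogue of the "every branch simplifies the same" check (hypothetical helper, illustration only):

def simplify_under(expr, lo, hi):
    # if expr(x) is identical for every x the valid allows, fold it to that value
    vals = {expr(x) for x in range(lo, hi+1)}
    return vals.pop() if len(vals) == 1 else None

assert simplify_under(lambda x: x // 4, 0, 3) == 0     # valid 0 <= x < 4 lets x//4 fold to 0
assert simplify_under(lambda x: x // 4, 0, 7) is None  # range spans two quotients: no fold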
@@ -296,100 +418,32 @@ def simplify_valid(valid:UOp) -> UOp|None:
  something_changed = False
  valids = list(split_uop(valid, Ops.AND))
  for stmt in sorted(valids, key=lambda v: _valid_priority(v, valids)):
+ # TODO: root cause this and test_simplify_valid_from_div
+ if stmt.op is Ops.CAST: return None
  ret.append(newstmt if ret and (newstmt:=uop_given_valid(functools.reduce(operator.and_, ret), stmt)) is not None else stmt)
  if ret[-1] is not stmt: something_changed = True
  return functools.reduce(operator.and_, ret) if something_changed else None

- # ***** threefry *****
-
- def threefry2x32(x: UOp, key: UOp):
- # split x into two uint32, since x in a uint64
- x0, x1 = (x & 0xffffffff).cast(dtypes.uint32), ((x // 2**32) & 0xffffffff).cast(dtypes.uint32)
-
- rotations = [[13, 15, 26, 6], [17, 29, 16, 24]]
- key0, key1 = (key & 0xffffffff).cast(dtypes.uint32), ((key // 2**32) & 0xffffffff).cast(dtypes.uint32)
- ks = [key1, key0 ^ key1 ^ 0x1BD11BDA, key0]
- xr = [x0 + ks[-1], x1 + ks[0]]
- for i in range(5):
- for r in rotations[i % 2]: xr[0], xr[1] = (x0 := xr[0] + xr[1]), x0 ^ ((xr[1] * 2**r) + (xr[1] // 2**(32 - r)))
- xr = [(xr[0] + ks[i % 3]), (xr[1] + ks[(i + 1) % 3] + i + 1)]
-
- return xr[1].cast(dtypes.uint64) * 2**32 | xr[0].cast(dtypes.uint64)
-
  # ******** phase 3 is the complete symbolic, and deals with very complex things like loop rewriting and threefry transform ********

- def loop_collapse(compval, multconst, rng:UOp, acc:UOp, idx2=None,idx3=None,extra=None,vec=None,ne=None,
- add=UOp.const(dtypes.int, 0), mul:UOp=UOp.const(dtypes.int, 1)):
- if getenv("DISABLE_LOOP_COLLAPSE") or rng not in acc.src: return None # must be the right REDUCE
- loop_start, loop_end = rng.src
- if loop_start.arg != 0:
- # TODO: support and test this with other mul and loop_starts
- if DEBUG >= 1: print(f"WARNING, NOT FOLDING: mul:{mul.arg} loop_start:{loop_start.arg}")
- return None
- if idx2 is not None: add = add + idx2
- if idx3 is not None: add = add + idx3
- if vec is not None:
- # add, mul, loop_start, loop_end
- def dvec(x:UOp):
- if x.op is Ops.CONST: return UOp.const(x.dtype.vec(vec.dtype.count), x.arg)
- return UOp(Ops.VECTORIZE, x.dtype.vec(vec.dtype.count), src=(x,)*vec.dtype.count)
- add, mul, loop_start, loop_end = dvec(add), dvec(mul), dvec(loop_start), dvec(loop_end)
- if mul.vmin > 0 and ne is not None:
- comprange = UOp.minimum(loop_end, UOp.maximum((add-compval)//mul + (loop_end-loop_start), loop_start))
- elif mul.vmax < 0 and ne is None:
- comprange = UOp.minimum(loop_end, UOp.maximum((add-compval-mul)//mul + (loop_end-loop_start), loop_start))
- else:
- return None
- new_reduce_op = comprange.cast(multconst.dtype) * multconst
- # TODO: what does it mean to have the same numbered DEFINE_ACC with different ranges?
- new_acc = acc.replace(src=acc.src[0:1]+tuple(x for x in acc.src[1:] if x is not rng))
- ret = new_acc.assign(new_acc+new_reduce_op)
- if extra is not None: ret = ret + acc.assign(acc+extra)
- return ret
-
- def index_collapse(idx:UOp,rng:UOp,buf:UOp,ld:UOp,acc:UOp,add=UOp.const(dtypes.int, 0),mul=UOp.const(dtypes.int, 1)):
- if rng not in acc.src: return None
- new_load = UOp.load(buf.index(add+mul*idx, (idx >= rng.src[0]) & (idx < rng.src[1])), dtype=ld.dtype)
- new_acc = acc.replace(src=acc.src[0:1]+tuple(x for x in acc.src[1:] if x is not rng))
- return new_acc.assign(new_acc+new_load)
-
- def reduce_collapse(acc:UOp, ret:UOp, alu:UOp):
- reduce_parented, reduce_unparented = partition(acc.src[1:], lambda x: x in ret.toposort)
- if len(reduce_unparented) == 0: return None
- new_acc = acc.replace(src=acc.src[0:1]+tuple(reduce_parented))
- ret = new_acc.assign(new_acc.alu(alu.op, ret))
- if alu.op is Ops.ADD:
- for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
- return ret
-
- def gep_through_wmma(gep:UOp, wmma:UOp):
- out_sz = prod(x[1] for x in wmma.arg[6][-1])
- wmma_idxs = gep.arg[::out_sz]
- for i in range(out_sz):
- if tuple(x-i for x in gep.arg[i::out_sz]) != wmma_idxs: return None
- tsrcs = []
- for s,sz in zip(wmma.src, wmma.arg[6]):
- src_args = []
- ssz = prod(x[1] for x in sz)
- for w in wmma_idxs: src_args += list(range((w//out_sz)*ssz, (w//out_sz)*ssz + ssz))
- tsrcs.append(s.gep(tuple(src_args)))
- return UOp(Ops.WMMA, gep.dtype, tuple(tsrcs), wmma.arg)
-
- acc_pat, rng_pat = UPat(Ops.DEFINE_ACC, name="acc"), UPat(Ops.RANGE, name="rng")
- rng_aug = UPat.any(rng_pat, UPat.var("add")+rng_pat, UPat.var("mul")*rng_pat, UPat.var("add")+UPat.var("mul")*rng_pat)
-
- index_load = UPat.var("buf").index(rng_aug).load(name="ld")
-
- arange_augrng = UPat.any(rng_aug, rng_aug+UPat.var("idx2"), rng_aug+UPat.var("idx2")+UPat.var("idx3"), UPat(Ops.VECTORIZE, name="vec", src=rng_aug))
- arange_m = ((arange_augrng<UPat.cvar("compval"))!=UPat(Ops.CONST, name="ne", arg=True)).where(UPat.cvar("multconst"), UPat.const(None, 0))
-
- # this moves the accumulation variable down an unrolled add chain which allows for more efficient accumulation using mulacc
- mulacc_unrolled = PatternMatcher([(UPat.var("x")+UPat.var("y")+acc_pat, lambda x,y,acc: (acc+x)+y if y.op is not Ops.DEFINE_ACC else None)])
+ def reduce_mul_chain(r:UOp):
+ if r.arg not in {Ops.ADD, Ops.MAX}: return None
+ if r.dtype != r.src[0].dtype: return None
+ inside, outside = [], []
+ for m in split_uop(r.src[0], Ops.MUL):
+ m_parents = m.toposort()
+ if all(r not in m_parents for r in r.src[1:]) and (r.arg != Ops.MAX or m.vmin >= 0): outside.append(m)
+ else: inside.append(m)
+ if len(outside) == 0: return None
+ return r.replace(src=(prod(inside) if len(inside) else r.src[0].const_like(1),)+r.src[1:])*prod(outside)

  # this is symbolic 2.0
+ REMOVE_FROM_SINK = {Ops.SINK, Ops.UNROLL, Ops.PTRCAT, Ops.CAT, Ops.NOOP}
+ REMOVE_FROM_BARRIER = {Ops.VECTORIZE, Ops.SINK, Ops.CAT, Ops.PTRCAT, Ops.NOOP}
  sym = symbolic_flat+PatternMatcher([
- # self ASSIGN is just self
- (UPat(Ops.ASSIGN, src=(UPat.var('x'), UPat.var('x'))), lambda x: x),
+ # LOAD/STORE -> NOOP
+ (UPat.var('x').store(UPat.var('x').load(), allow_any_len=True), lambda x: None if x.dtype.addrspace != AddrSpace.REG else x.src[0].src[0]),
+ (UPat(Ops.LOAD, src=(UPat.cvar('c'))), lambda c: c),
  # VECTORIZE/CONST, VECTORIZE/GEP
  (UPat(Ops.VECTORIZE, src=UPat(Ops.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))),
  (UPat(Ops.VECTORIZE, src=UPat(Ops.GEP, src=(UPat.var("x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))),
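reduce_mul_chain above is the distributive law across a reduce: factors that do not depend on the reduce ranges move outside an ADD reduce unconditionally, and outside a MAX reduce only when provably non-negative, since max does not commute with multiplication by a negative. A plain-Python check of all three cases:

xs, c = [1, -2, 3, 4], 5
assert sum(c*x for x in xs) == c*sum(xs)      # ADD reduce: always legal
assert max(c*x for x in xs) == c*max(xs)      # MAX reduce: legal because c >= 0
assert max(-1*x for x in xs) != -1*max(xs)    # negative factor: folding would be wrong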
@@ -401,76 +455,45 @@ sym = symbolic_flat+PatternMatcher([
  # VECTORIZE void is SINK
  (UPat(Ops.VECTORIZE, dtype=dtypes.void, src=UPat(Ops.BARRIER, name='b')), lambda b: b),
  (UPat(Ops.VECTORIZE, dtype=dtypes.void, name='x'), lambda x: UOp(Ops.SINK, dtypes.void, x.src)),
- # GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
- (UPat(Ops.GEP, src=(UPat(Ops.GEP, name='g2'),), name='g1'),
- lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(g1.dtype.count)))),
- (UPat(Ops.GEP, src=(UPat(Ops.VECTORIZE, name="vec"),), name="gep"),
- lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
- (UPat(Ops.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
- (UPat(Ops.GEP, src=(UPat(Ops.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
- # push all GEPs through ALUs (fix arange stuff)
- (UPat(Ops.GEP, src=(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'),
- lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \
- if not isinstance(gep.dtype, PtrDType) else None),
- # push some GEPs through WMMAs
- (UPat(Ops.GEP, src=(UPat(Ops.WMMA, name="wmma"),), name="gep"), gep_through_wmma),
- # CAT can't be rendered. it's a VECTORIZE on vectors, we expand to a single VECTORIZEs with GEPs (TODO: move this later)
- (UPat(Ops.CAT, name="x"), lambda x: UOp(Ops.VECTORIZE, x.dtype, tuple(y.gep(i) for y in x.src for i in range(y.dtype.count))) \
- if not isinstance(x.dtype, PtrDType) else None),
  # tensor core with a 0 input is acc
  (UPat(Ops.WMMA, src=(UPat.const(None, 0.0), UPat.var(), UPat.var("acc"))), lambda acc: acc),
  (UPat(Ops.WMMA, src=(UPat.var(), UPat.const(None, 0.0), UPat.var("acc"))), lambda acc: acc),
- # tensor core cleanups
- (UPat.var("add") + UPat(Ops.WMMA, name="wmma"),
- lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
- # threefry + remove longs
- (UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32),
- (UPat.var('x', dtypes.uint32).cast(dtypes.uint64).cast(dtypes.uint32), lambda x: x), # cast there and back is noop (TODO: genericize)
- ((UPat.var('x', dtypes.uint64)&0xFFFFFFFF).cast(dtypes.uint32), lambda x: x.cast(dtypes.uint32)), # cast does truncation
- (((UPat.var(None, dtypes.uint64)*(1<<32)) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
- (((UPat.var('x', dtypes.uint64)*(1<<32)) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))//(1<<32), lambda x: x),
- # hacks for threefry long removal when padded (TODO: genericize)
- (UPat.var('x', dtypes.uint32).cast(dtypes.uint64) * UPat.var('y').where(UPat.const(dtypes.uint64, 1<<32), UPat.const(dtypes.uint64, 0)),
- lambda x,y: y.where(x, UOp.const(dtypes.uint32, 0)).cast(dtypes.uint64) * (1<<32)),
- ((UPat.var('x', dtypes.uint64)&(UPat.var('y').where(UPat.const(dtypes.uint64, 0xFFFFFFFF), UPat.const(dtypes.uint64, 0)))).cast(dtypes.uint32),
- lambda x,y: y.where(x.cast(dtypes.uint32), UOp.const(dtypes.uint32, 0))),
- # arange loop folding
- (acc_pat.assign(UPat.any(arange_m, arange_m+UPat.var("extra"))+acc_pat), loop_collapse),
- # indexing, with cast or where
- (acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).cast()*index_load+acc_pat), index_collapse),
- (acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).where(index_load, UPat.const(None, 0.0))+acc_pat), index_collapse),
- # parentless reduce # TODO: add MUL
- (acc_pat.assign(UPat((Ops.ADD, Ops.MAX), src=[acc_pat, UPat.var("ret")], name="alu")), reduce_collapse),
  # ** self folding **
- (UPat(Ops.DEFINE_ACC, src=(UPat.var("x"),)), lambda x: x), # a DEFINE_ACC without ranges is a CONST
- (UPat(Ops.ASSIGN, src=(UPat.cvar(),UPat.var("x"))), lambda x: x), # an ASSIGN to a const is a NOOP
  # x!=0 -> (bool)x
  (UPat.var("x")!=0, lambda x: x.cast(dtypes.bool.vec(x.dtype.count))),
  # ** where **
  # push cast to branches
  (UPat.var("s").where(UPat.var("a"), UPat.var("b")).cast().named("cast"), lambda s,a,b,cast: s.where(a.cast(cast.dtype), b.cast(cast.dtype))),
+ # a.where(b.where(c, d), d) -> (a & b).where(c, d)
+ (UPat.var("a").where(UPat.var("b").where(UPat.var("c"), UPat.var("d")), UPat.var("d")), lambda a,b,c,d: (a&b).where(c,d)),
  # ** pow **
  ((UPat(Ops.POW, name="p"), lambda p: xpow(*p.src))),
+ # index true is index without op
+ (UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"), UPat.const(dtypes.bool, True))), lambda b, idx: b.index(idx)),
  # ** load/store folding **
  (UPat.store(UPat(Ops.INDEX, name="index"), UPat.load(UPat(Ops.INDEX, name="index"))), lambda index: UOp(Ops.NOOP)),
- (UPat.store(UPat(Ops.INDEX, name="index"), UPat.var("gate").where(UPat.var("alt"), UPat.load(UPat(Ops.INDEX, name="index")))),
- lambda index, gate, alt: UOp.store(index.src[0].index(index.src[1], gate), alt)),
+ (UPat.store(UPat(Ops.INDEX, name="index"), UPat.var("gate").where(UPat.var("alt"),
+ UPat.load(UPat(Ops.INDEX, name="index"))), allow_any_len=True, name="store"),
+ lambda index, gate, alt, store: UOp.store(index.src[0].index(index.src[1], gate), alt, *store.src[2:])),
  # fold gated LOAD/STORE
  (UPat().index(UPat(), UPat.const(dtypes.bool, True)).named("idx"), lambda idx: idx.replace(src=idx.src[0:2])), # remove True
- (UPat().index(UPat(), UPat.const(dtypes.bool, False)).named("idx"), lambda idx: idx.const_like(0)), # False -> NULL pointer
- (UPat(Ops.LOAD, src=(UPat.const(None, 0),), allow_any_len=True, name="x"), lambda x: x.const_like(0)), # NULL pointer load loads 0
- (UPat(Ops.STORE, src=(UPat.const(None, 0),), allow_any_len=True), lambda: UOp(Ops.NOOP)), # NULL pointer store does nothing
- # remove NOOPs from SINK
- (UPat(Ops.SINK, name="root"),
- lambda root: UOp(Ops.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not Ops.NOOP)) != len(root.src) else None),
- # remove VECTORIZE from SINK/BARRIER
- (UPat(Ops.BARRIER, src=(UPat((Ops.VECTORIZE, Ops.SINK), name='sink'),)), lambda sink: UOp(Ops.BARRIER, dtypes.void, sink.src)),
+ (UPat((Ops.LOAD, Ops.STORE), src=(UPat().index(UPat(), UPat.const(dtypes.bool, False)).or_casted(),), allow_any_len=True, name="x"),
+ lambda x: UOp(Ops.NOOP) if x.op is Ops.STORE else x.const_like(0)), # NULL pointer store does nothing. NULL pointer load produces 0
+ # remove VECTORIZE from SINK/BARRIER. TODO: SINK/BARRIER are really the same thing at GLOBAL/LOCAL levels
+ (UPat(Ops.BARRIER, name="root"),
+ lambda root: UOp(Ops.BARRIER, root.dtype, tuple(flatten(x.src if x.op in REMOVE_FROM_BARRIER else (x,) for x in root.src)), root.arg)
+ if any(x.op in REMOVE_FROM_BARRIER for x in root.src) else None),
  (UPat(Ops.SINK, name="root"),
- lambda root: UOp(Ops.SINK, root.dtype, tuple(flatten(x.src if x.op in {Ops.SINK, Ops.UNROLL} else (x,) for x in root.src)), root.arg)
- if any(x.op in {Ops.SINK, Ops.UNROLL} for x in root.src) else None),
+ lambda root: UOp(Ops.SINK, root.dtype, tuple(flatten(x.src if x.op in REMOVE_FROM_SINK else (x,) for x in root.src)), root.arg)
+ if any(x.op in REMOVE_FROM_SINK for x in root.src) else None),
  ((UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()), # 1/(x^c) -> (1/x)^c
  ((UPat.var("x") * UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()*x.reciprocal()),
+ ((UPat.var("x") * UPat.cvar("c")).reciprocal(), lambda x,c: x.reciprocal()*c.reciprocal()), # 1/(x*c) -> (1/c)*(1/x)
  (UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")), lambda x,d: 1-d), # x*/(1+x) -> 1-1/(1+x)
  (UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")*UPat.var("y")), lambda x,y,d: y*(1-d)),
  (UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")+UPat.var("y")), lambda x,y,d: (1-d)+x*y),
+ # move const multiply after REDUCE (NOTE: the mul chain can do this, but only if it's a same dtype reduce)
+ ((UPat.var("x")*UPat.cvar("c", vec=False)).reduce(arg=Ops.ADD, name="r", allow_any_len=True), lambda x,c,r: r.replace(src=(x,)+r.src[1:])*c.arg),
+ # reduce mul chain, move muls after the reduce
+ (UPat(Ops.MUL).reduce(name="r", allow_any_len=True), reduce_mul_chain),
  ])
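The reciprocal rules at the bottom are float identities: 1/(x*x) = (1/x)*(1/x), 1/(x*c) = (1/c)*(1/x), and x/(1+x) = 1 - 1/(1+x), the last of which saves a division when 1/(1+x) is already computed. A quick numeric sanity check, exact up to float rounding:

import math
for x in (0.5, 1.0, 3.7):
    assert math.isclose(1/(x*x), (1/x)*(1/x))
    assert math.isclose(1/(x*4.0), (1/4.0)*(1/x))
    assert math.isclose(x*(1/(1+x)), 1 - 1/(1+x))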