tinygrad 0.10.1__py3-none-any.whl → 0.10.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. tinygrad/codegen/devectorizer.py +247 -0
  2. tinygrad/codegen/expander.py +121 -0
  3. tinygrad/codegen/kernel.py +35 -37
  4. tinygrad/codegen/linearize.py +19 -10
  5. tinygrad/codegen/lowerer.py +31 -8
  6. tinygrad/codegen/symbolic.py +476 -0
  7. tinygrad/codegen/transcendental.py +10 -0
  8. tinygrad/device.py +28 -11
  9. tinygrad/dtype.py +12 -3
  10. tinygrad/engine/jit.py +3 -2
  11. tinygrad/engine/multi.py +0 -1
  12. tinygrad/engine/realize.py +7 -4
  13. tinygrad/engine/schedule.py +227 -255
  14. tinygrad/engine/search.py +20 -27
  15. tinygrad/gradient.py +3 -0
  16. tinygrad/helpers.py +7 -4
  17. tinygrad/nn/state.py +2 -2
  18. tinygrad/ops.py +64 -329
  19. tinygrad/renderer/__init__.py +19 -3
  20. tinygrad/renderer/cstyle.py +39 -18
  21. tinygrad/renderer/llvmir.py +55 -18
  22. tinygrad/renderer/ptx.py +6 -2
  23. tinygrad/renderer/wgsl.py +20 -12
  24. tinygrad/runtime/autogen/libc.py +404 -71
  25. tinygrad/runtime/autogen/{libpciaccess.py → pci.py} +25 -715
  26. tinygrad/runtime/autogen/webgpu.py +6985 -0
  27. tinygrad/runtime/graph/metal.py +28 -29
  28. tinygrad/runtime/ops_amd.py +37 -34
  29. tinygrad/runtime/{ops_clang.py → ops_cpu.py} +4 -2
  30. tinygrad/runtime/ops_disk.py +1 -1
  31. tinygrad/runtime/ops_dsp.py +59 -33
  32. tinygrad/runtime/ops_llvm.py +14 -12
  33. tinygrad/runtime/ops_metal.py +78 -62
  34. tinygrad/runtime/ops_nv.py +9 -6
  35. tinygrad/runtime/ops_python.py +5 -5
  36. tinygrad/runtime/ops_webgpu.py +200 -38
  37. tinygrad/runtime/support/am/amdev.py +23 -11
  38. tinygrad/runtime/support/am/ip.py +10 -10
  39. tinygrad/runtime/support/elf.py +2 -0
  40. tinygrad/runtime/support/hcq.py +7 -5
  41. tinygrad/runtime/support/llvm.py +8 -14
  42. tinygrad/shape/shapetracker.py +3 -2
  43. tinygrad/shape/view.py +2 -3
  44. tinygrad/spec.py +21 -20
  45. tinygrad/tensor.py +150 -90
  46. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
  47. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
  48. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
  49. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
  50. tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
  51. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
  52. tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
  53. tinygrad/viz/index.html +544 -0
  54. tinygrad/viz/perfetto.html +178 -0
  55. tinygrad/viz/serve.py +205 -0
  56. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/METADATA +20 -8
  57. tinygrad-0.10.2.dist-info/RECORD +99 -0
  58. tinygrad/codegen/rewriter.py +0 -516
  59. tinygrad-0.10.1.dist-info/RECORD +0 -86
  60. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
  61. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +0 -0
  62. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/codegen/lowerer.py

@@ -1,11 +1,12 @@
  # the job of the lowerer is to do indexing
- import functools, itertools, operator
+ import functools, itertools, operator, math
  from dataclasses import dataclass
  from typing import cast
  from tinygrad.dtype import dtypes, PtrDType
  from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element, sint_to_uop
  from tinygrad.renderer import Renderer
  from tinygrad.helpers import all_int, prod, partition, flatten, unwrap
+ from tinygrad.codegen.expander import expand_rewrite

  # returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
  def get_contraction(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) -> list[list[int]]|None:
@@ -15,23 +16,37 @@ def get_contraction(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) -> l
    return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]

  # ***** indexing *****
-
- def _limit_dims(dims:tuple[sint, ...], max_sizes:tuple[int, ...]):
+ def _group_dims(dims:tuple[sint, ...], max_sizes:tuple[int, ...]):
    # TODO: symbolic shape
    if not all_int(dims): return dims
    while len(dims) > len(max_sizes) or any(d > m for d,m in zip(dims, max_sizes)):
      for i,m in enumerate(max_sizes):
-       if dims[i] * dims[i+1] <= m:
+       if i < (len(dims)-1) and dims[i] * dims[i+1] <= m:
          dims = dims[:i] + (dims[i]*dims[i+1],) + dims[i+2:]
          break
-     else: raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
+     else: return None
    return dims

+ def _split_dims(dims, max_sizes):
+   if all(d <= m for d,m in zip(dims, max_sizes)): return dims
+   _dims = list(dims) + [1]*(3-len(dims))
+   for i in range(len(_dims)):
+     while _dims[i] > max_sizes[i]:
+       div = next((d for d in range(2, math.ceil(math.sqrt(_dims[i])) + 1) if (_dims[i] % d) == 0), 1)
+       if div == 1: raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
+       _dims[i], _dims[(i+1)%len(_dims)] = _dims[i]//div, _dims[(i+1)%len(_dims)]*div
+   return tuple(_dims[:2] if _dims[2] == 1 else _dims[0] if _dims[1:3] == [1,1] else _dims)
+
  def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|None, reverse=False) -> list[UOp]:
    if reverse: dims = dims[::-1]
-   limited = _limit_dims(dims, max_sizes) if max_sizes is not None else dims
+   # try to group first: (a, b, c, d) -> (ab, c, d)
+   limited = (grouped if (grouped := _group_dims(dims, max_sizes)) else dims) if max_sizes is not None else dims
+   # check if grouping failed
+   if max_sizes is not None and len(limited) > len(max_sizes): raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
+   # try to split up dims: (a,) -> (b, c)
+   if limited == dims: limited = _split_dims(dims, max_sizes) if max_sizes is not None else dims
    ret = raw_idxs = [UOp(Ops.SPECIAL, dtypes.int, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)]
-   if limited != dims:
+   if len(limited) < len(dims):
      ret = []
      if (contraction:=get_contraction(dims, limited)) is None: raise AssertionError(f"get_contraction should not be None {dims=} {limited=}")
      for idx, contraction_group in zip(raw_idxs, contraction):
@@ -39,6 +54,11 @@ def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|No
          ret.append(idx % dims[c])
          idx //= dims[c]
        ret.append(idx)
+   elif len(limited) > len(dims):
+     a, b = len(limited), len(dims)
+     if a == 2 and b == 1: ret = [raw_idxs[0] * limited[1] + raw_idxs[1]]
+     if a == 3 and b == 1: ret = [raw_idxs[0] * (limited[1] * limited[2]) + raw_idxs[1] * limited[2] + raw_idxs[2]]
+     if a == 3 and b == 2: ret = [raw_idxs[0] * limited[1] + raw_idxs[1], raw_idxs[2]]
    return ret[::-1] if reverse else ret

  @dataclass
@@ -135,4 +155,7 @@ pm_lowerer = PatternMatcher([
    (UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"), UPat.const(dtypes.bool, True))), lambda b, idx: b.index(idx)),
  ])

- def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp: return graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts))
+ def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp:
+   sink = graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts))
+   # expand_rewrite turns this into a vectorized program
+   return expand_rewrite(sink)
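
Note (not part of the diff): the new get_grouped_dims first tries to merge adjacent dims under the launch-size limits and, if that fails, factors an oversized dim across neighboring axes. A minimal sketch of the two private helpers added above; the shapes below are illustrative assumptions, not taken from the package:

  # sketch: exercises the helpers introduced in this release
  from tinygrad.codegen.lowerer import _group_dims, _split_dims

  max_sizes = (1024, 1024, 1024)
  print(_group_dims((256, 8, 4, 4), max_sizes))  # (256, 32, 4): too many dims, so 8 and 4 are merged
  print(_split_dims((4096,), max_sizes))         # (1024, 4): the oversized dim is factored onto the next axis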

tinygrad/codegen/symbolic.py

@@ -0,0 +1,476 @@
+ # all of symbolic lives here now
+ from typing import Any, Literal, cast
+ import math, operator, struct, functools
+ from collections import defaultdict
+ from tinygrad.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
+ from tinygrad.dtype import ConstType, dtypes, PtrDType
+ from tinygrad.helpers import partition, all_same, prod, getenv, DEBUG, flatten
+ from tinygrad.codegen.transcendental import xpow
+
+ # ******** phase 1 of symbolic used to live in ops, it's the most generic folding rules ********
+
+ def simplify_pow(x:UOp, c:UOp) -> UOp|None:
+   if c.arg < 0: return x.reciprocal().pow(-c)
+   if c.arg == 0: return x.const_like(1)
+   if int(c.arg-0.5)+0.5 == c.arg: return x.pow(c.const_like(c.arg-0.5)) * x.sqrt()
+   if int(c.arg) == c.arg: return (y := x.pow(c.const_like(c.arg//2))) * y * (x if c.arg%2 == 1 else 1)
+   return None
+
+ def fold_bitcast(root:UOp, c:UOp) -> UOp|None:
+   if (from_fmt:=c.dtype.scalar().fmt) is None or (to_fmt:=root.dtype.scalar().fmt) is None: return None
+   def convert(v:Any): return struct.unpack(to_fmt, struct.pack(from_fmt, v))[0]
+   return root.const_like(convert(c.arg) if root.dtype.count == 1 else tuple(map(convert, c.arg)))
+
+ symbolic_simple = PatternMatcher([
+   # ** self folding **
+   (UPat.var("x") + 0, lambda x: x),    # x+0 -> x
+   (UPat.var("x") * 1, lambda x: x),    # x*1 -> x
+   (UPat.var("x") // UPat.var("x"), lambda x: x.const_like(1)), # x//x -> 1
+   (UPat.var("x") // 1, lambda x: x),   # x//1 -> x
+   (UPat.var("x") // -1, lambda x: -x), # x//-1 -> -x
+   (UPat.var("x") / UPat.var("x"), lambda x: x.const_like(1)), # x/x -> 1
+   ((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x
+   ((UPat.var() % UPat.var("y")).named("base") % UPat.var("y"), lambda base,y: base), # (x%y)%y = -> x%y (rewritten with base for speed)
+   (UPat.var("x")%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda x,c: x), # (x%c)+(x//c)*c = x
+   ((UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3")+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"),
+    lambda x,c1,c2,c3: x*c2 if c1.arg*c2.arg==c3.arg else None), # (x%c1)*c2+(x//c1)*c3 = x*c2 if c1*c2==c3
+   (UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c", vec=False), lambda x,c: x if c.arg else c),
+   (UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c", vec=False), lambda x,c: c if c.arg else x),
+   (UPat(GroupOp.Idempotent, src=(UPat.var("x"), UPat.var("x"))), lambda x: x),
+   (UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x),
+   (UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, True), UPat.const(dtypes.bool, False)), lambda x: x),
+   # ** zero folding **
+   (UPat.var("x") < UPat.var("x"), lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x < x -> False
+   (UPat.var("x", dtype=dtypes.ints) != UPat.var("x", dtype=dtypes.ints),
+    lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints)
+   # x*0 -> 0 or 0*x -> 0
+   # if x is nan or inf it should render the nan value.
+   # NOTE: this can be wrong for loaded NaN
+   (UPat.var("x") * 0, lambda x: x.const_like(float("nan") if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
+   # ** constant folding **
+   # TODO: add const folding for Ops.THREEFRY
+   (UPat(GroupOp.ALU, name="a", src=UPat((Ops.VCONST, Ops.CONST))),
+    lambda a: a.const_like(exec_alu(a.op, a.dtype, [x.arg for x in a.src], False)) if a.op is not Ops.THREEFRY else None),
+   # bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
+   (UPat.var('x', dtype=dtypes.bool) * UPat.var('y', dtype=dtypes.bool), lambda x,y: x&y),
+   (UPat.var('x', dtype=dtypes.bool) + UPat.var('y', dtype=dtypes.bool), lambda x,y: x|y),
+   (UPat.var('x', dtype=dtypes.bool).maximum(UPat.var('y', dtype=dtypes.bool)), lambda x,y: x|y),
+   # *** cast/bitcast ***
+   (UPat(Ops.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)),
+   (UPat((Ops.CAST, Ops.BITCAST), name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
+   (UPat(Ops.BITCAST, name="root", src=(UPat.cvar("c"),)), fold_bitcast),
+   # ** pow **
+   (UPat.var("x").alu(Ops.POW, UPat.cvar("c", vec=False)), simplify_pow),
+   # positive const ** x
+   (UPat.cvar("c", vec=False).alu(Ops.POW, UPat.var("x")), lambda c,x: c if c.arg == 1 else (x*math.log2(c.arg)).exp2() if c.arg > 0 else None),
+ ])
+
+ # ******** phase 2 builds on phase 1, it includes the old "symbolic", rules that match deeper ********
+
+ def split_uop(x:UOp, sep:Ops):
+   if x.op is sep:
+     for s in x.src: yield from split_uop(s, sep)
+   else: yield x
+
+ def fold_unrolled_divs(divs:UOp):
+   # div pattern in unrolled arange
+   # example: (x//4+(x+1)//4+(x+2)//4+(x+3)//4 -> x
+   add_chain, denominator, seen_const, ans = list(split_uop(divs, Ops.ADD)), None, [], None
+   for u in add_chain:
+     if not (u.op is Ops.IDIV and u.src[1].op is Ops.CONST): return None
+     if denominator is None: denominator = u.src[1].arg
+     if denominator != u.src[1].arg: return None
+     # assumed CONST is the last of an ADD
+     if (s0:=u.src[0]).op is Ops.ADD and s0.src[1].op is Ops.CONST and s0.src[1].op is Ops.CONST:
+       seen_const.append(s0.src[1].arg)
+       s0 = s0.src[0]
+     else: seen_const.append(0)
+     if ans is None: ans = s0
+     if ans is not s0: return None
+   if denominator is None: return None
+   # the first (denominator-len(seen_const)) terms may have been folded to 0 already
+   for i in range(denominator-len(seen_const)):
+     if ans is not None and 0 <= ans.vmin and ans.vmax + i < denominator: seen_const.append(i)
+   return ans if ans is not None and sorted(seen_const)==list(range(denominator)) else None
+
+ def lt_folding(x:UOp, c:int) -> UOp|None:
+   p, np = partition(split_uop(x, Ops.ADD), lambda u: u.const_factor() == 1)
+   if np and (d:=math.gcd(*[u.const_factor() for u in np], c)) > 1 and 0 <= sum(u.vmin for u in p) and sum(u.vmax for u in p) < d:
+     return cast(UOp, functools.reduce(operator.add, np).divides(d))<(c//d)
+   return None
+
+ def canonicalize_simplex(X:UOp) -> UOp|None:
+   # (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints.
+   # returns x0 + x1 + ... in such case, or None if not
+   changed, ret = False, []
+   for u in split_uop(X, Ops.ADD):
+     # assumed the const is the last src of MUL
+     if u.op is Ops.MUL and u.src[1].op is Ops.CONST and u.src[1].arg > 0:
+       changed = True
+       u = u.src[0]
+     if not (u.op in GroupOp.Irreducible and u.vmin >= 0): return None
+     ret.append(u)
+   return functools.reduce(operator.add, ret) if changed else None
+
+ def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split_rem: bool=False) -> UOp|None:
+   # simplify x // y or x % y, None means no change
+   # simple cancel div/mod case
+   if y.vmin != 0 != y.vmax and (q:=x.vmin//y.vmin) == x.vmin//y.vmax == x.vmax//y.vmin == x.vmax//y.vmax:
+     return x - q*y if which is Ops.MOD else x.const_like(q)
+
+   if (y.op is not Ops.CONST) or ((c := y.arg) <= 0) or (x.dtype.count > 1): return None
+
+   svars, factors, quotients, remainders, gcd, div, const, offset, something_changed = [], [], [], [], c, 1, 0, 0, False
+   for u in split_uop(x, Ops.ADD):
+     if u.op is Ops.MOD and which is Ops.MOD and u.src[1].op is Ops.CONST and u.src[1].arg%c == 0:
+       u = u.src[0]
+       something_changed = True
+     v: UOp = u.divides(f:=u.const_factor())
+     q, r = divmod(f, c)
+     if r==0 or ((which is Ops.MOD or split_rem or u.op is Ops.CONST) and r!=f): something_changed = True
+     offset += r*v.vmin
+     if u.op is Ops.CONST: const += f
+     else:  # div is the smallest common divisor of all terms
+       if f > 1 and c % f == 0 and (div == 1 or div > f): div = f
+       gcd = math.gcd(r, gcd)
+       factors.append(f); svars.append(v); quotients.append(q); remainders.append(r)  # noqa: E702
+
+   lbound = ubound = offset = offset % c
+   # we can fold if the expression has only one non-constant term and this term can only take on two values
+   if len(svars)==1 and (v:=svars[0]).vmax-v.vmin == 1:
+     r = (offset+remainders[0])%c - offset%c
+     offset -= r * v.vmin
+     if which is Ops.MOD: return r*v + offset
+     return (factors[0]-r)//c * v + (const-offset)//c
+
+   # a//c = (a-a%c)/c, if we can fold a%c, we can fold a//c
+   # within a mod we can freely subtract multiples of c, we use this to see if a is congruent to an expression whose vmin/vmax are between 0 and c
+   for (r, v) in zip(remainders, svars):
+     if r > c//2:
+       if (lbound := lbound + (r:=r-c) * (v.vmax-v.vmin)) < 0: break
+     elif (ubound := ubound + r * (v.vmax-v.vmin)) >= c: break
+     offset -= r * v.vmin  # determine what the new offset would be
+   else:  # vmin/vmax of the remainder is between 0 and c, we can remove the mod/div
+     remainders = [min(r, r-c, key=abs) for r in remainders]
+     if which is Ops.MOD: return functools.reduce(operator.add, [r*v for r,v in zip(remainders,svars)], x.const_like(offset))
+     return functools.reduce(operator.add, [(f-r)//c * v for f,r,v in zip(factors, remainders,svars)], x.const_like((const-offset)//c))
+
+   if gcd != 1: something_changed = True
+   if not something_changed:
+     if which is Ops.IDIV and (1 < div < c) and (newx:=div_and_mod_folding(x, UOp.const(dtypes.int, div), Ops.IDIV)) is not None: return newx//(c//div)
+     return None
+   quo, rem = x.const_like(const//c), x.const_like((const%c)//gcd)
+   for q,r,f,v in zip(quotients, remainders, factors, svars):
+     if which is Ops.IDIV and (not split_rem) and r!=0:
+       rem += f//gcd * v
+     else:
+       rem += r//gcd * v
+       quo += q * v
+
+   if which is Ops.MOD: return gcd*(rem % (c//gcd)) + const%gcd
+   return rem//(c//gcd)+quo
+
+ symbolic = symbolic_simple+PatternMatcher([
+   # ** COMMUTATIVE flipping (only for ints) **
+   (UPat(GroupOp.Commutative, dtype=dtypes.int, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
+   # ** boolean algebra **
+   (UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x
+   # ** combine terms **
+   (UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1)
+   ((UPat.var("y") + UPat.var("x") * UPat.cvar("c0")) + UPat.var("x") * UPat.cvar("c1"), lambda x,y,c0,c1: y+x*(c0+c1)),
+   (UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1)
+   ((UPat.var("y") + UPat.var("x")) + UPat.var("x") * UPat.cvar("c"), lambda x,y,c: y+x*(c+1)),
+   (UPat.var("x") + UPat.var("x"), lambda x: x*2), # (x+x)-> x*2
+   ((UPat.var("y") + UPat.var("x")) + UPat.var("x"), lambda y,x: y+x*2),
+   ((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3) if x2 is not x3 else None), # (x/x2)/x3 -> x/(x2*x3)
+   (-1 * (UPat.var("x") + UPat.cvar("c")), lambda x,c: (-x)+(-c)),  # -(x+c) -> -x + -c
+   # a conditional with the same results either way is a noop, also fold const conditionals
+   (UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
+   (UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
+   # alu of two where with same conds can combine, only do if true branch or false branch is const
+   (UPat(GroupOp.Binary, name="alu", src=(UPat.var("c").where(UPat.var("t"), UPat.var("f")), UPat.var("c").where(UPat.var("tt"), UPat.var("ff")))), \
+    lambda alu,c,t,tt,f,ff: c.where(t.alu(alu.op, tt), f.alu(alu.op, ff)) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None),
+   # ALU min==max -> CONST (slow!)
+   (UPat(GroupOp.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
+   # max folding
+   (UPat.maximum(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
+   # TODO: why does this rule break beautiful_mnist?
+   #((UPat.var("x")+UPat.var("z")).maximum(UPat.var("y")+UPat.var("z")), lambda x,y,z: x.maximum(y) + z),
+   #((UPat.var("x")*UPat.cvar("c1")).maximum(UPat.var("x")*UPat.cvar("c2")), max_var_const),
+   # ** two stage ALU folding **
+   *((UPat.var("x").alu(op, UPat.cvar("c1")).alu(op, UPat.cvar("c2")).named("f"),
+      lambda f,x,c1,c2: x.alu(f.op,c1.alu(f.op,c2))) for op in GroupOp.Associative),
+   ((UPat.cvar("c0") + UPat.var("x")) < UPat.cvar("c1"), lambda x,c0,c1: x<(c1-c0)),  # c0 + x < c1 -> x < c1 - c0
+   ((UPat.var("x") // UPat.cvar("c1")) // UPat.cvar("c2"), lambda x,c1,c2: x//(c1*c2)), # (x//c1)//c2 -> x//(c1*c2)
+   # ** lt **
+   # c0*x<c1 for positive int c0,c1
+   ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints))<UPat.cvar("c1", vec=False),
+    lambda x,c0,c1: x<math.ceil(c1.arg/c0.arg) if c0.arg > 0 and c1.arg > 0 else None),
+   # c0*x<c1 for negative int c0 and non-positive c1
+   ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints))<UPat.cvar("c1", vec=False),
+    lambda x,c0,c1: (-x)<(-(math.floor(-c1.arg/-c0.arg))) if c0.arg < 0 and c0.arg != -1 and c1.arg <= 0 else None),
+   # x//c0<c1 for positive int c0
+   ((UPat.var("x", dtype=dtypes.ints)//UPat.cvar("c0", vec=False))<UPat.cvar("c1", vec=False),
+    lambda x,c0,c1: x<(c1.arg*c0.arg) if c0.arg > 0 else None),
+   # ** move add/mul consts to end (NOTE: this is still happening before constant folding) **
+   (UPat(Ops.ADD, src=(UPat.var("x"), UPat.cvar("c1"))) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
+   (UPat(Ops.MUL, src=(UPat.var("x"), UPat.cvar("c1"))) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
+   # *** rules from symbolic ***
+   # unrolled arange div folding
+   (UPat(Ops.ADD, name="divs", src=[UPat(), UPat(Ops.IDIV)]), fold_unrolled_divs),
+   # generic lt folding
+   (UPat.var("x", dtypes.sints)<UPat.cvar("c", vec=False), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None),
+   # canonicalize a simplex with positive coefficients > 0
+   # not x < 1 -> X > 0
+   ((UPat.var("x", dtypes.ints)<1).ne(True), lambda x: (newx<1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None),
+   # ** div **
+   # div folding
+   ((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"), lambda x,c,a,d: (x+a*c)//(c*d)),  # (x//c+a)//d -> (x+a*c)//(c*d)
+   (UPat.var("x", dtypes.sints) // UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.IDIV)),
+   # ** mod **
+   # mod folding
+   (UPat.var("x") % UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.MOD)),
+ ])
+
+ symbolic_flat = symbolic+PatternMatcher([
+   # ** combine terms (opinionated) **
+   (-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)),  # -(x+y) -> -x + -y
+   # (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue
+   ((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
+ ])
+
+ # ******** we take a small aside to "simplify_valid" to rewrite valids ********
+
+ def parse_valid(valid:UOp) -> tuple[UOp, bool, int]:
+   # if it's X <= c, returns X, True, c
+   # if it's X >= c, returns X, False, c
+
+   # (X < c).ne(True) -> X >= c
+   if valid.op is Ops.CMPNE and valid.src[1].op is Ops.CONST and valid.src[1].arg == 1 and \
+     (s0:=valid.src[0]).op is Ops.CMPLT and s0.src[1].op is Ops.CONST: return s0.src[0], False, s0.src[1].arg
+   # X < c -> X <= c-1
+   if valid.op is Ops.CMPLT and valid.src[1].op is Ops.CONST and dtypes.is_int(valid.src[0].dtype): return valid.src[0], True, valid.src[1].arg-1
+   raise ValueError(f"not able to parse {valid=}")
+
+ def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None:
+   # return None if valid is always False, otherwise the simplified uop (might be the same as input)
+
+   # first, parse valid into {expr: (lower_bound, upper_bound)}
+   bounds:defaultdict[UOp, list[ConstType|None]] = defaultdict(lambda: [None, None])
+   for stmt in split_uop(valid, Ops.AND):
+     try: expr, is_upper, c = parse_valid(stmt)
+     except ValueError: return uop  # give up if we cannot parse the valid
+     bounds[expr][int(is_upper)] = c
+
+   # simplify uop given that valid is True
+   for expr,v in bounds.items():
+     # some expr has lower bound > upper bound -> valid is an empty set and we return None
+     if v[0] is not None and v[1] is not None and v[0] > v[1]: return None
+
+     # every candidate is a set of contrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop
+     candidates = []
+     if expr.op is Ops.ADD and v[0] == 1 and all(u.op in GroupOp.Irreducible for u in split_uop(expr, Ops.ADD)):
+       # if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
+       candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, Ops.ADD)])
+     # try checking the whole clause
+     if expr in uop.toposort:
+       candidates.append([(expr, UOp.variable("fake", expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1], expr.dtype))])
+
+     for candidate in candidates:
+       # if every branch in candidate gives the same simplified uop, we can rewrite the uop
+       newuops = [uop.substitute({X:newX}).simplify().substitute({newX:X}).simplify() for X,newX in candidate]
+       if uop.op is Ops.VECTORIZE and len(uop.src) == 2:
+         if all_same([uops.src[0] for uops in newuops]): uop = uop.replace(src=(newuops[0].src[0], uop.src[1]))
+         if all_same([uops.src[1] for uops in newuops]): uop = uop.replace(src=(uop.src[0], newuops[0].src[1]))
+       elif all_same(newuops): uop = newuops[0]
+
+   return uop
+
+ def _valid_priority(v: UOp, valids:list[UOp]):
+   # we want valid that's in other valids' parents to be first, so it's more likely the other valids get simplified
+   try: return sum(-1 if parse_valid(v)[0] in other.toposort else 0 for other in valids)
+   except ValueError: return 0
+
+ def simplify_valid(valid:UOp) -> UOp|None:
+   ret:list[UOp] = []
+   something_changed = False
+   valids = list(split_uop(valid, Ops.AND))
+   for stmt in sorted(valids, key=lambda v: _valid_priority(v, valids)):
+     ret.append(newstmt if ret and (newstmt:=uop_given_valid(functools.reduce(operator.and_, ret), stmt)) is not None else stmt)
+     if ret[-1] is not stmt: something_changed = True
+   return functools.reduce(operator.and_, ret) if something_changed else None
+
+ # ***** threefry *****
+
+ def threefry2x32(x: UOp, key: UOp):
+   # split x into two uint32, since x in a uint64
+   x0, x1 = (x & 0xffffffff).cast(dtypes.uint32), ((x // 2**32) & 0xffffffff).cast(dtypes.uint32)
+
+   rotations = [[13, 15, 26, 6], [17, 29, 16, 24]]
+   key0, key1 = (key & 0xffffffff).cast(dtypes.uint32), ((key // 2**32) & 0xffffffff).cast(dtypes.uint32)
+   ks = [key1, key0 ^ key1 ^ 0x1BD11BDA, key0]
+   xr = [x0 + ks[-1], x1 + ks[0]]
+   for i in range(5):
+     for r in rotations[i % 2]: xr[0], xr[1] = (x0 := xr[0] + xr[1]), x0 ^ ((xr[1] * 2**r) + (xr[1] // 2**(32 - r)))
+     xr = [(xr[0] + ks[i % 3]), (xr[1] + ks[(i + 1) % 3] + i + 1)]
+
+   return xr[1].cast(dtypes.uint64) * 2**32 | xr[0].cast(dtypes.uint64)
+
+ # ******** phase 3 is the complete symbolic, and deals with very complex things like loop rewriting and threefry transform ********
+
+ def loop_collapse(compval, multconst, rng:UOp, acc:UOp, idx2=None,idx3=None,extra=None,vec=None,ne=None,
+                   add=UOp.const(dtypes.int, 0), mul:UOp=UOp.const(dtypes.int, 1)):
+   if getenv("DISABLE_LOOP_COLLAPSE") or rng not in acc.src: return None  # must be the right REDUCE
+   loop_start, loop_end = rng.src
+   if loop_start.arg != 0:
+     # TODO: support and test this with other mul and loop_starts
+     if DEBUG >= 1: print(f"WARNING, NOT FOLDING: mul:{mul.arg} loop_start:{loop_start.arg}")
+     return None
+   if idx2 is not None: add = add + idx2
+   if idx3 is not None: add = add + idx3
+   if vec is not None:
+     # add, mul, loop_start, loop_end
+     def dvec(x:UOp):
+       if x.op is Ops.CONST: return UOp.const(x.dtype.vec(vec.dtype.count), x.arg)
+       return UOp(Ops.VECTORIZE, x.dtype.vec(vec.dtype.count), src=(x,)*vec.dtype.count)
+     add, mul, loop_start, loop_end = dvec(add), dvec(mul), dvec(loop_start), dvec(loop_end)
+   if mul.vmin > 0 and ne is not None:
+     comprange = UOp.minimum(loop_end, UOp.maximum((add-compval)//mul + (loop_end-loop_start), loop_start))
+   elif mul.vmax < 0 and ne is None:
+     comprange = UOp.minimum(loop_end, UOp.maximum((add-compval-mul)//mul + (loop_end-loop_start), loop_start))
+   else:
+     return None
+   new_reduce_op = comprange.cast(multconst.dtype) * multconst
+   # TODO: what does it mean to have the same numbered DEFINE_ACC with different ranges?
+   new_acc = acc.replace(src=acc.src[0:1]+tuple(x for x in acc.src[1:] if x is not rng))
+   ret = new_acc.assign(new_acc+new_reduce_op)
+   if extra is not None: ret = ret + acc.assign(acc+extra)
+   return ret
+
+ def index_collapse(idx:UOp,rng:UOp,buf:UOp,ld:UOp,acc:UOp,add=UOp.const(dtypes.int, 0),mul=UOp.const(dtypes.int, 1)):
+   if rng not in acc.src: return None
+   new_load = UOp.load(buf.index(add+mul*idx, (idx >= rng.src[0]) & (idx < rng.src[1])), dtype=ld.dtype)
+   new_acc = acc.replace(src=acc.src[0:1]+tuple(x for x in acc.src[1:] if x is not rng))
+   return new_acc.assign(new_acc+new_load)
+
+ def reduce_collapse(acc:UOp, ret:UOp, alu:UOp):
+   reduce_parented, reduce_unparented = partition(acc.src[1:], lambda x: x in ret.toposort)
+   if len(reduce_unparented) == 0: return None
+   new_acc = acc.replace(src=acc.src[0:1]+tuple(reduce_parented))
+   ret = new_acc.assign(new_acc.alu(alu.op, ret))
+   if alu.op is Ops.ADD:
+     for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
+   return ret
+
+ def gep_through_wmma(gep:UOp, wmma:UOp):
+   out_sz = prod(x[1] for x in wmma.arg[6][-1])
+   wmma_idxs = gep.arg[::out_sz]
+   for i in range(out_sz):
+     if tuple(x-i for x in gep.arg[i::out_sz]) != wmma_idxs: return None
+   tsrcs = []
+   for s,sz in zip(wmma.src, wmma.arg[6]):
+     src_args = []
+     ssz = prod(x[1] for x in sz)
+     for w in wmma_idxs: src_args += list(range((w//out_sz)*ssz, (w//out_sz)*ssz + ssz))
+     tsrcs.append(s.gep(tuple(src_args)))
+   return UOp(Ops.WMMA, gep.dtype, tuple(tsrcs), wmma.arg)
+
+ acc_pat, rng_pat = UPat(Ops.DEFINE_ACC, name="acc"), UPat(Ops.RANGE, name="rng")
+ rng_aug = UPat.any(rng_pat, UPat.var("add")+rng_pat, UPat.var("mul")*rng_pat, UPat.var("add")+UPat.var("mul")*rng_pat)
+
+ index_load = UPat.var("buf").index(rng_aug).load(name="ld")
+
+ arange_augrng = UPat.any(rng_aug, rng_aug+UPat.var("idx2"), rng_aug+UPat.var("idx2")+UPat.var("idx3"), UPat(Ops.VECTORIZE, name="vec", src=rng_aug))
+ arange_m = ((arange_augrng<UPat.cvar("compval"))!=UPat(Ops.CONST, name="ne", arg=True)).where(UPat.cvar("multconst"), UPat.const(None, 0))
+
+ # this moves the accumulation variable down an unrolled add chain which allows for more efficient accumulation using mulacc
+ mulacc_unrolled = PatternMatcher([(UPat.var("x")+UPat.var("y")+acc_pat, lambda x,y,acc: (acc+x)+y if y.op is not Ops.DEFINE_ACC else None)])
+
+ # this is symbolic 2.0
+ sym = symbolic_flat+PatternMatcher([
+   # self ASSIGN is just self
+   (UPat(Ops.ASSIGN, src=(UPat.var('x'), UPat.var('x'))), lambda x: x),
+   # VECTORIZE/CONST, VECTORIZE/GEP
+   (UPat(Ops.VECTORIZE, src=UPat(Ops.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))),
+   (UPat(Ops.VECTORIZE, src=UPat(Ops.GEP, src=(UPat.var("x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))),
+   # reorder ALU/VECTORIZE
+   (UPat(GroupOp.ALU, src=(UPat(Ops.VECTORIZE, src=UPat(name='x')), UPat(Ops.VECTORIZE, src=UPat(name='y'))), name='alu'),
+    lambda x,y,alu: UOp(Ops.VECTORIZE, alu.dtype, (UOp(alu.op, alu.dtype.scalar(), (x,y)),)*alu.dtype.count)),
+   # VECTORIZE of a single element is just that element
+   (UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
+   # VECTORIZE void is SINK
+   (UPat(Ops.VECTORIZE, dtype=dtypes.void, src=UPat(Ops.BARRIER, name='b')), lambda b: b),
+   (UPat(Ops.VECTORIZE, dtype=dtypes.void, name='x'), lambda x: UOp(Ops.SINK, dtypes.void, x.src)),
+   # GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
+   (UPat(Ops.GEP, src=(UPat(Ops.GEP, name='g2'),), name='g1'),
+    lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(g1.dtype.count)))),
+   (UPat(Ops.GEP, src=(UPat(Ops.VECTORIZE, name="vec"),), name="gep"),
+    lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
+   (UPat(Ops.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
+   (UPat(Ops.GEP, src=(UPat(Ops.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
+   # push all GEPs through ALUs (fix arange stuff)
+   (UPat(Ops.GEP, src=(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'),
+    lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \
+      if not isinstance(gep.dtype, PtrDType) else None),
+   # push some GEPs through WMMAs
+   (UPat(Ops.GEP, src=(UPat(Ops.WMMA, name="wmma"),), name="gep"), gep_through_wmma),
+   # CAT can't be rendered. it's a VECTORIZE on vectors, we expand to a single VECTORIZEs with GEPs (TODO: move this later)
+   (UPat(Ops.CAT, name="x"), lambda x: UOp(Ops.VECTORIZE, x.dtype, tuple(y.gep(i) for y in x.src for i in range(y.dtype.count))) \
+      if not isinstance(x.dtype, PtrDType) else None),
+   # tensor core with a 0 input is acc
+   (UPat(Ops.WMMA, src=(UPat.const(None, 0.0), UPat.var(), UPat.var("acc"))), lambda acc: acc),
+   (UPat(Ops.WMMA, src=(UPat.var(), UPat.const(None, 0.0), UPat.var("acc"))), lambda acc: acc),
+   # tensor core cleanups
+   (UPat.var("add") + UPat(Ops.WMMA, name="wmma"),
+    lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
+   # threefry + remove longs
+   (UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32),
+   (UPat.var('x', dtypes.uint32).cast(dtypes.uint64).cast(dtypes.uint32), lambda x: x), # cast there and back is noop (TODO: genericize)
+   ((UPat.var('x', dtypes.uint64)&0xFFFFFFFF).cast(dtypes.uint32), lambda x: x.cast(dtypes.uint32)), # cast does truncation
+   (((UPat.var(None, dtypes.uint64)*(1<<32)) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
+   (((UPat.var('x', dtypes.uint64)*(1<<32)) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))//(1<<32), lambda x: x),
+   # hacks for threefry long removal when padded (TODO: genericize)
+   (UPat.var('x', dtypes.uint32).cast(dtypes.uint64) * UPat.var('y').where(UPat.const(dtypes.uint64, 1<<32), UPat.const(dtypes.uint64, 0)),
+    lambda x,y: y.where(x, UOp.const(dtypes.uint32, 0)).cast(dtypes.uint64) * (1<<32)),
+   ((UPat.var('x', dtypes.uint64)&(UPat.var('y').where(UPat.const(dtypes.uint64, 0xFFFFFFFF), UPat.const(dtypes.uint64, 0)))).cast(dtypes.uint32),
+    lambda x,y: y.where(x.cast(dtypes.uint32), UOp.const(dtypes.uint32, 0))),
+   # arange loop folding
+   (acc_pat.assign(UPat.any(arange_m, arange_m+UPat.var("extra"))+acc_pat), loop_collapse),
+   # indexing, with cast or where
+   (acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).cast()*index_load+acc_pat), index_collapse),
+   (acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).where(index_load, UPat.const(None, 0.0))+acc_pat), index_collapse),
+   # parentless reduce # TODO: add MUL
+   (acc_pat.assign(UPat((Ops.ADD, Ops.MAX), src=[acc_pat, UPat.var("ret")], name="alu")), reduce_collapse),
+   # ** self folding **
+   (UPat(Ops.DEFINE_ACC, src=(UPat.var("x"),)), lambda x: x), # a DEFINE_ACC without ranges is a CONST
+   (UPat(Ops.ASSIGN, src=(UPat.cvar(),UPat.var("x"))), lambda x: x), # an ASSIGN to a const is a NOOP
+   # x!=0 -> (bool)x
+   (UPat.var("x")!=0, lambda x: x.cast(dtypes.bool.vec(x.dtype.count))),
+   # ** where **
+   # push cast to branches
+   (UPat.var("s").where(UPat.var("a"), UPat.var("b")).cast().named("cast"), lambda s,a,b,cast: s.where(a.cast(cast.dtype), b.cast(cast.dtype))),
+   # ** pow **
+   ((UPat(Ops.POW, name="p"), lambda p: xpow(*p.src))),
+   # ** load/store folding **
+   (UPat.store(UPat(Ops.INDEX, name="index"), UPat.load(UPat(Ops.INDEX, name="index"))), lambda index: UOp(Ops.NOOP)),
+   (UPat.store(UPat(Ops.INDEX, name="index"), UPat.var("gate").where(UPat.var("alt"), UPat.load(UPat(Ops.INDEX, name="index")))),
+    lambda index, gate, alt: UOp.store(index.src[0].index(index.src[1], gate), alt)),
+   # fold gated LOAD/STORE
+   (UPat().index(UPat(), UPat.const(dtypes.bool, True)).named("idx"), lambda idx: idx.replace(src=idx.src[0:2])), # remove True
+   (UPat().index(UPat(), UPat.const(dtypes.bool, False)).named("idx"), lambda idx: idx.const_like(0)), # False -> NULL pointer
+   (UPat(Ops.LOAD, src=(UPat.const(None, 0),), allow_any_len=True, name="x"), lambda x: x.const_like(0)), # NULL pointer load loads 0
+   (UPat(Ops.STORE, src=(UPat.const(None, 0),), allow_any_len=True), lambda: UOp(Ops.NOOP)), # NULL pointer store does nothing
+   # remove NOOPs from SINK
+   (UPat(Ops.SINK, name="root"),
+    lambda root: UOp(Ops.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not Ops.NOOP)) != len(root.src) else None),
+   # remove VECTORIZE from SINK/BARRIER
+   (UPat(Ops.BARRIER, src=(UPat((Ops.VECTORIZE, Ops.SINK), name='sink'),)), lambda sink: UOp(Ops.BARRIER, dtypes.void, sink.src)),
+   (UPat(Ops.SINK, name="root"),
+    lambda root: UOp(Ops.SINK, root.dtype, tuple(flatten(x.src if x.op in {Ops.SINK, Ops.UNROLL} else (x,) for x in root.src)), root.arg)
+      if any(x.op in {Ops.SINK, Ops.UNROLL} for x in root.src) else None),
+   ((UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()), # 1/(x^c) -> (1/x)^c
+   ((UPat.var("x") * UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()*x.reciprocal()),
+   (UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")), lambda x,d: 1-d), # x*/(1+x) -> 1-1/(1+x)
+   (UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")*UPat.var("y")), lambda x,y,d: y*(1-d)),
+   (UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")+UPat.var("y")), lambda x,y,d: (1-d)+x*y),
+ ])
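
Note (not part of the diff): the PatternMatchers defined above are what UOp.simplify applies, so the div/mod folding can be exercised directly. A small sketch, assuming the 0.10.2 module layout and the UOp.variable/simplify helpers already used in the file itself:

  # sketch: div_and_mod_folding cancels the mod/div pair introduced by the multiply
  from tinygrad.ops import UOp
  from tinygrad.dtype import dtypes

  x = UOp.variable("x", 0, 255, dtypes.int)
  print(((x*4 + 2) // 4).simplify())  # expected to fold back to the variable x
  print(((x*4 + 2) % 4).simplify())   # expected to fold to the constant 2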

tinygrad/codegen/transcendental.py

@@ -254,3 +254,13 @@ def xlog2(d:UOp) -> UOp:
    r = d.ne(d).where(r.const_like(math.nan), r)
    # log2(-0.0) = -Inf. In certain devices like PTX, x == -0.0 won't be true. so making reciprocal.
    return d.reciprocal().ne(-math.inf).where(r, r.const_like(-math.inf))
+
+ def xpow(base:UOp, exponent:UOp) -> UOp:
+   # start with b ** e = exp2(e * log2(b))
+   ret = (base < 0).where(-base, base).log2().mul(exponent).exp2()
+   # negative base adjustment: nan for non-integer exponent and -1 for odd exponent
+   non_int = exponent != exponent.cast(dtypes.int32).cast(exponent.dtype)
+   adj = non_int.where(ret.const_like(math.nan),
+     (exponent < 0).where(-exponent, exponent).cast(dtypes.int32).mod(2).cast(dtypes.bool).where(ret.const_like(-1), ret.const_like(1)))
+   # fix 0 ** 0 = 1
+   return (base.eq(0) & exponent.eq(0)).where(ret.const_like(1), ret * (base < 0).where(adj, ret.const_like(1)))
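
Note (not part of the diff): a plain-float sketch of the xpow decomposition above, assuming finite inputs. It mirrors the UOp construction, b ** e = exp2(e * log2(|b|)), with a nan/sign adjustment for negative bases and the 0 ** 0 = 1 fixup; it is a reference illustration, not tinygrad API:

  import math

  def xpow_ref(base: float, exponent: float) -> float:
    # reference-only mirror of the xpow graph built above
    ret = 2.0 ** (exponent * math.log2(abs(base))) if base != 0 else 0.0
    non_int = exponent != float(int(exponent))
    adj = math.nan if non_int else (-1.0 if int(abs(exponent)) % 2 == 1 else 1.0)
    if base == 0 and exponent == 0: return 1.0
    return ret * (adj if base < 0 else 1.0)

  print(xpow_ref(-2.0, 3.0))  # -8.0: odd integer exponent flips the sign
  print(xpow_ref(-2.0, 0.5))  # nan: non-integer exponent of a negative base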