tinygrad 0.10.0__py3-none-any.whl → 0.10.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/devectorizer.py +247 -0
- tinygrad/codegen/expander.py +121 -0
- tinygrad/codegen/kernel.py +141 -201
- tinygrad/codegen/linearize.py +223 -84
- tinygrad/codegen/lowerer.py +60 -42
- tinygrad/codegen/symbolic.py +476 -0
- tinygrad/codegen/transcendental.py +22 -13
- tinygrad/device.py +187 -47
- tinygrad/dtype.py +39 -28
- tinygrad/engine/jit.py +83 -65
- tinygrad/engine/memory.py +4 -5
- tinygrad/engine/multi.py +161 -0
- tinygrad/engine/realize.py +62 -108
- tinygrad/engine/schedule.py +396 -357
- tinygrad/engine/search.py +55 -66
- tinygrad/gradient.py +73 -0
- tinygrad/helpers.py +81 -59
- tinygrad/nn/__init__.py +30 -32
- tinygrad/nn/datasets.py +1 -2
- tinygrad/nn/optim.py +22 -26
- tinygrad/nn/state.py +91 -66
- tinygrad/ops.py +492 -641
- tinygrad/renderer/__init__.py +95 -36
- tinygrad/renderer/cstyle.py +99 -92
- tinygrad/renderer/llvmir.py +83 -34
- tinygrad/renderer/ptx.py +83 -99
- tinygrad/renderer/wgsl.py +95 -0
- tinygrad/runtime/autogen/amd_gpu.py +39507 -12
- tinygrad/runtime/autogen/comgr.py +2 -0
- tinygrad/runtime/autogen/kfd.py +4 -3
- tinygrad/runtime/autogen/kgsl.py +1 -1
- tinygrad/runtime/autogen/libc.py +404 -71
- tinygrad/runtime/autogen/llvm.py +11379 -0
- tinygrad/runtime/autogen/pci.py +1333 -0
- tinygrad/runtime/autogen/vfio.py +891 -0
- tinygrad/runtime/autogen/webgpu.py +6985 -0
- tinygrad/runtime/graph/cuda.py +8 -9
- tinygrad/runtime/graph/hcq.py +84 -79
- tinygrad/runtime/graph/metal.py +40 -43
- tinygrad/runtime/ops_amd.py +498 -334
- tinygrad/runtime/ops_cloud.py +34 -34
- tinygrad/runtime/ops_cpu.py +24 -0
- tinygrad/runtime/ops_cuda.py +30 -27
- tinygrad/runtime/ops_disk.py +62 -63
- tinygrad/runtime/ops_dsp.py +159 -42
- tinygrad/runtime/ops_gpu.py +30 -30
- tinygrad/runtime/ops_hip.py +29 -31
- tinygrad/runtime/ops_llvm.py +48 -41
- tinygrad/runtime/ops_metal.py +149 -113
- tinygrad/runtime/ops_npy.py +2 -2
- tinygrad/runtime/ops_nv.py +238 -273
- tinygrad/runtime/ops_python.py +55 -50
- tinygrad/runtime/ops_qcom.py +129 -157
- tinygrad/runtime/ops_webgpu.py +225 -0
- tinygrad/runtime/support/allocator.py +94 -0
- tinygrad/runtime/support/am/__init__.py +0 -0
- tinygrad/runtime/support/am/amdev.py +396 -0
- tinygrad/runtime/support/am/ip.py +463 -0
- tinygrad/runtime/support/compiler_cuda.py +4 -2
- tinygrad/runtime/support/elf.py +28 -4
- tinygrad/runtime/support/hcq.py +256 -324
- tinygrad/runtime/support/llvm.py +26 -0
- tinygrad/shape/shapetracker.py +85 -53
- tinygrad/shape/view.py +104 -140
- tinygrad/spec.py +155 -0
- tinygrad/tensor.py +835 -527
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
- tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
- tinygrad/viz/index.html +544 -0
- tinygrad/viz/perfetto.html +178 -0
- tinygrad/viz/serve.py +205 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
- tinygrad-0.10.2.dist-info/RECORD +99 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
- tinygrad/codegen/uopgraph.py +0 -506
- tinygrad/engine/lazy.py +0 -228
- tinygrad/function.py +0 -212
- tinygrad/multi.py +0 -177
- tinygrad/runtime/graph/clang.py +0 -39
- tinygrad/runtime/ops_clang.py +0 -35
- tinygrad-0.10.0.dist-info/RECORD +0 -77
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/codegen/symbolic.py (new file)
@@ -0,0 +1,476 @@
+# all of symbolic lives here now
+from typing import Any, Literal, cast
+import math, operator, struct, functools
+from collections import defaultdict
+from tinygrad.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
+from tinygrad.dtype import ConstType, dtypes, PtrDType
+from tinygrad.helpers import partition, all_same, prod, getenv, DEBUG, flatten
+from tinygrad.codegen.transcendental import xpow
+
+# ******** phase 1 of symbolic used to live in ops, it's the most generic folding rules ********
+
+def simplify_pow(x:UOp, c:UOp) -> UOp|None:
+  if c.arg < 0: return x.reciprocal().pow(-c)
+  if c.arg == 0: return x.const_like(1)
+  if int(c.arg-0.5)+0.5 == c.arg: return x.pow(c.const_like(c.arg-0.5)) * x.sqrt()
+  if int(c.arg) == c.arg: return (y := x.pow(c.const_like(c.arg//2))) * y * (x if c.arg%2 == 1 else 1)
+  return None
+
+def fold_bitcast(root:UOp, c:UOp) -> UOp|None:
+  if (from_fmt:=c.dtype.scalar().fmt) is None or (to_fmt:=root.dtype.scalar().fmt) is None: return None
+  def convert(v:Any): return struct.unpack(to_fmt, struct.pack(from_fmt, v))[0]
+  return root.const_like(convert(c.arg) if root.dtype.count == 1 else tuple(map(convert, c.arg)))
+
+symbolic_simple = PatternMatcher([
+  # ** self folding **
+  (UPat.var("x") + 0, lambda x: x),    # x+0 -> x
+  (UPat.var("x") * 1, lambda x: x),    # x*1 -> x
+  (UPat.var("x") // UPat.var("x"), lambda x: x.const_like(1)), # x//x -> 1
+  (UPat.var("x") // 1, lambda x: x),   # x//1 -> x
+  (UPat.var("x") // -1, lambda x: -x), # x//-1 -> -x
+  (UPat.var("x") / UPat.var("x"), lambda x: x.const_like(1)), # x/x -> 1
+  ((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x
+  ((UPat.var() % UPat.var("y")).named("base") % UPat.var("y"), lambda base,y: base), # (x%y)%y = -> x%y (rewritten with base for speed)
+  (UPat.var("x")%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda x,c: x), # (x%c)+(x//c)*c = x
+  ((UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3")+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"),
+   lambda x,c1,c2,c3: x*c2 if c1.arg*c2.arg==c3.arg else None), # (x%c1)*c2+(x//c1)*c3 = x*c2 if c1*c2==c3
+  (UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c", vec=False), lambda x,c: x if c.arg else c),
+  (UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c", vec=False), lambda x,c: c if c.arg else x),
+  (UPat(GroupOp.Idempotent, src=(UPat.var("x"), UPat.var("x"))), lambda x: x),
+  (UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x),
+  (UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, True), UPat.const(dtypes.bool, False)), lambda x: x),
+  # ** zero folding **
+  (UPat.var("x") < UPat.var("x"), lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x < x -> False
+  (UPat.var("x", dtype=dtypes.ints) != UPat.var("x", dtype=dtypes.ints),
+   lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints)
+  # x*0 -> 0 or 0*x -> 0
+  # if x is nan or inf it should render the nan value.
+  # NOTE: this can be wrong for loaded NaN
+  (UPat.var("x") * 0, lambda x: x.const_like(float("nan") if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
+  # ** constant folding **
+  # TODO: add const folding for Ops.THREEFRY
+  (UPat(GroupOp.ALU, name="a", src=UPat((Ops.VCONST, Ops.CONST))),
+   lambda a: a.const_like(exec_alu(a.op, a.dtype, [x.arg for x in a.src], False)) if a.op is not Ops.THREEFRY else None),
+  # bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
+  (UPat.var('x', dtype=dtypes.bool) * UPat.var('y', dtype=dtypes.bool), lambda x,y: x&y),
+  (UPat.var('x', dtype=dtypes.bool) + UPat.var('y', dtype=dtypes.bool), lambda x,y: x|y),
+  (UPat.var('x', dtype=dtypes.bool).maximum(UPat.var('y', dtype=dtypes.bool)), lambda x,y: x|y),
+  # *** cast/bitcast ***
+  (UPat(Ops.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)),
+  (UPat((Ops.CAST, Ops.BITCAST), name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
+  (UPat(Ops.BITCAST, name="root", src=(UPat.cvar("c"),)), fold_bitcast),
+  # ** pow **
+  (UPat.var("x").alu(Ops.POW, UPat.cvar("c", vec=False)), simplify_pow),
+  # positive const ** x
+  (UPat.cvar("c", vec=False).alu(Ops.POW, UPat.var("x")), lambda c,x: c if c.arg == 1 else (x*math.log2(c.arg)).exp2() if c.arg > 0 else None),
+])
+
+# ******** phase 2 builds on phase 1, it includes the old "symbolic", rules that match deeper ********
+
+def split_uop(x:UOp, sep:Ops):
+  if x.op is sep:
+    for s in x.src: yield from split_uop(s, sep)
+  else: yield x
+
+def fold_unrolled_divs(divs:UOp):
+  # div pattern in unrolled arange
+  # example: (x//4+(x+1)//4+(x+2)//4+(x+3)//4 -> x
+  add_chain, denominator, seen_const, ans = list(split_uop(divs, Ops.ADD)), None, [], None
+  for u in add_chain:
+    if not (u.op is Ops.IDIV and u.src[1].op is Ops.CONST): return None
+    if denominator is None: denominator = u.src[1].arg
+    if denominator != u.src[1].arg: return None
+    # assumed CONST is the last of an ADD
+    if (s0:=u.src[0]).op is Ops.ADD and s0.src[1].op is Ops.CONST and s0.src[1].op is Ops.CONST:
+      seen_const.append(s0.src[1].arg)
+      s0 = s0.src[0]
+    else: seen_const.append(0)
+    if ans is None: ans = s0
+    if ans is not s0: return None
+  if denominator is None: return None
+  # the first (denominator-len(seen_const)) terms may have been folded to 0 already
+  for i in range(denominator-len(seen_const)):
+    if ans is not None and 0 <= ans.vmin and ans.vmax + i < denominator: seen_const.append(i)
+  return ans if ans is not None and sorted(seen_const)==list(range(denominator)) else None
+
+def lt_folding(x:UOp, c:int) -> UOp|None:
+  p, np = partition(split_uop(x, Ops.ADD), lambda u: u.const_factor() == 1)
+  if np and (d:=math.gcd(*[u.const_factor() for u in np], c)) > 1 and 0 <= sum(u.vmin for u in p) and sum(u.vmax for u in p) < d:
+    return cast(UOp, functools.reduce(operator.add, np).divides(d))<(c//d)
+  return None
+
+def canonicalize_simplex(X:UOp) -> UOp|None:
+  # (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints.
+  # returns x0 + x1 + ... in such case, or None if not
+  changed, ret = False, []
+  for u in split_uop(X, Ops.ADD):
+    # assumed the const is the last src of MUL
+    if u.op is Ops.MUL and u.src[1].op is Ops.CONST and u.src[1].arg > 0:
+      changed = True
+      u = u.src[0]
+    if not (u.op in GroupOp.Irreducible and u.vmin >= 0): return None
+    ret.append(u)
+  return functools.reduce(operator.add, ret) if changed else None
+
+def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split_rem: bool=False) -> UOp|None:
+  # simplify x // y or x % y, None means no change
+  # simple cancel div/mod case
+  if y.vmin != 0 != y.vmax and (q:=x.vmin//y.vmin) == x.vmin//y.vmax == x.vmax//y.vmin == x.vmax//y.vmax:
+    return x - q*y if which is Ops.MOD else x.const_like(q)
+
+  if (y.op is not Ops.CONST) or ((c := y.arg) <= 0) or (x.dtype.count > 1): return None
+
+  svars, factors, quotients, remainders, gcd, div, const, offset, something_changed = [], [], [], [], c, 1, 0, 0, False
+  for u in split_uop(x, Ops.ADD):
+    if u.op is Ops.MOD and which is Ops.MOD and u.src[1].op is Ops.CONST and u.src[1].arg%c == 0:
+      u = u.src[0]
+      something_changed = True
+    v: UOp = u.divides(f:=u.const_factor())
+    q, r = divmod(f, c)
+    if r==0 or ((which is Ops.MOD or split_rem or u.op is Ops.CONST) and r!=f): something_changed = True
+    offset += r*v.vmin
+    if u.op is Ops.CONST: const += f
+    else:  # div is the smallest common divisor of all terms
+      if f > 1 and c % f == 0 and (div == 1 or div > f): div = f
+      gcd = math.gcd(r, gcd)
+      factors.append(f); svars.append(v); quotients.append(q); remainders.append(r)  # noqa: E702
+
+  lbound = ubound = offset = offset % c
+  # we can fold if the expression has only one non-constant term and this term can only take on two values
+  if len(svars)==1 and (v:=svars[0]).vmax-v.vmin == 1:
+    r = (offset+remainders[0])%c - offset%c
+    offset -= r * v.vmin
+    if which is Ops.MOD: return r*v + offset
+    return (factors[0]-r)//c * v + (const-offset)//c
+
+  # a//c = (a-a%c)/c, if we can fold a%c, we can fold a//c
+  # within a mod we can freely subtract multiples of c, we use this to see if a is congruent to an expression whose vmin/vmax are between 0 and c
+  for (r, v) in zip(remainders, svars):
+    if r > c//2:
+      if (lbound := lbound + (r:=r-c) * (v.vmax-v.vmin)) < 0: break
+    elif (ubound := ubound + r * (v.vmax-v.vmin)) >= c: break
+    offset -= r * v.vmin  # determine what the new offset would be
+  else:  # vmin/vmax of the remainder is between 0 and c, we can remove the mod/div
+    remainders = [min(r, r-c, key=abs) for r in remainders]
+    if which is Ops.MOD: return functools.reduce(operator.add, [r*v for r,v in zip(remainders,svars)], x.const_like(offset))
+    return functools.reduce(operator.add, [(f-r)//c * v for f,r,v in zip(factors, remainders,svars)], x.const_like((const-offset)//c))
+
+  if gcd != 1: something_changed = True
+  if not something_changed:
+    if which is Ops.IDIV and (1 < div < c) and (newx:=div_and_mod_folding(x, UOp.const(dtypes.int, div), Ops.IDIV)) is not None: return newx//(c//div)
+    return None
+  quo, rem = x.const_like(const//c), x.const_like((const%c)//gcd)
+  for q,r,f,v in zip(quotients, remainders, factors, svars):
+    if which is Ops.IDIV and (not split_rem) and r!=0:
+      rem += f//gcd * v
+    else:
+      rem += r//gcd * v
+      quo += q * v
+
+  if which is Ops.MOD: return gcd*(rem % (c//gcd)) + const%gcd
+  return rem//(c//gcd)+quo
+
+symbolic = symbolic_simple+PatternMatcher([
+  # ** COMMUTATIVE flipping (only for ints) **
+  (UPat(GroupOp.Commutative, dtype=dtypes.int, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
+  # ** boolean algebra **
+  (UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x
+  # ** combine terms **
+  (UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1)
+  ((UPat.var("y") + UPat.var("x") * UPat.cvar("c0")) + UPat.var("x") * UPat.cvar("c1"), lambda x,y,c0,c1: y+x*(c0+c1)),
+  (UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1)
+  ((UPat.var("y") + UPat.var("x")) + UPat.var("x") * UPat.cvar("c"), lambda x,y,c: y+x*(c+1)),
+  (UPat.var("x") + UPat.var("x"), lambda x: x*2), # (x+x)-> x*2
+  ((UPat.var("y") + UPat.var("x")) + UPat.var("x"), lambda y,x: y+x*2),
+  ((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3) if x2 is not x3 else None), # (x/x2)/x3 -> x/(x2*x3)
+  (-1 * (UPat.var("x") + UPat.cvar("c")), lambda x,c: (-x)+(-c)),  # -(x+c) -> -x + -c
+  # a conditional with the same results either way is a noop, also fold const conditionals
+  (UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
+  (UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
+  # alu of two where with same conds can combine, only do if true branch or false branch is const
+  (UPat(GroupOp.Binary, name="alu", src=(UPat.var("c").where(UPat.var("t"), UPat.var("f")), UPat.var("c").where(UPat.var("tt"), UPat.var("ff")))), \
+   lambda alu,c,t,tt,f,ff: c.where(t.alu(alu.op, tt), f.alu(alu.op, ff)) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None),
+  # ALU min==max -> CONST (slow!)
+  (UPat(GroupOp.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
+  # max folding
+  (UPat.maximum(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
+  # TODO: why does this rule break beautiful_mnist?
+  #((UPat.var("x")+UPat.var("z")).maximum(UPat.var("y")+UPat.var("z")), lambda x,y,z: x.maximum(y) + z),
+  #((UPat.var("x")*UPat.cvar("c1")).maximum(UPat.var("x")*UPat.cvar("c2")), max_var_const),
+  # ** two stage ALU folding **
+  *((UPat.var("x").alu(op, UPat.cvar("c1")).alu(op, UPat.cvar("c2")).named("f"),
+     lambda f,x,c1,c2: x.alu(f.op,c1.alu(f.op,c2))) for op in GroupOp.Associative),
+  ((UPat.cvar("c0") + UPat.var("x")) < UPat.cvar("c1"), lambda x,c0,c1: x<(c1-c0)),  # c0 + x < c1 -> x < c1 - c0
+  ((UPat.var("x") // UPat.cvar("c1")) // UPat.cvar("c2"), lambda x,c1,c2: x//(c1*c2)), # (x//c1)//c2 -> x//(c1*c2)
+  # ** lt **
+  # c0*x<c1 for positive int c0,c1
+  ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints))<UPat.cvar("c1", vec=False),
+   lambda x,c0,c1: x<math.ceil(c1.arg/c0.arg) if c0.arg > 0 and c1.arg > 0 else None),
+  # c0*x<c1 for negative int c0 and non-positive c1
+  ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints))<UPat.cvar("c1", vec=False),
+   lambda x,c0,c1: (-x)<(-(math.floor(-c1.arg/-c0.arg))) if c0.arg < 0 and c0.arg != -1 and c1.arg <= 0 else None),
+  # x//c0<c1 for positive int c0
+  ((UPat.var("x", dtype=dtypes.ints)//UPat.cvar("c0", vec=False))<UPat.cvar("c1", vec=False),
+   lambda x,c0,c1: x<(c1.arg*c0.arg) if c0.arg > 0 else None),
+  # ** move add/mul consts to end (NOTE: this is still happening before constant folding) **
+  (UPat(Ops.ADD, src=(UPat.var("x"), UPat.cvar("c1"))) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
+  (UPat(Ops.MUL, src=(UPat.var("x"), UPat.cvar("c1"))) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
+  # *** rules from symbolic ***
+  # unrolled arange div folding
+  (UPat(Ops.ADD, name="divs", src=[UPat(), UPat(Ops.IDIV)]), fold_unrolled_divs),
+  # generic lt folding
+  (UPat.var("x", dtypes.sints)<UPat.cvar("c", vec=False), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None),
+  # canonicalize a simplex with positive coefficients > 0
+  # not x < 1 -> X > 0
+  ((UPat.var("x", dtypes.ints)<1).ne(True), lambda x: (newx<1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None),
+  # ** div **
+  # div folding
+  ((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"), lambda x,c,a,d: (x+a*c)//(c*d)), # (x//c+a)//d -> (x+a*c)//(c*d)
+  (UPat.var("x", dtypes.sints) // UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.IDIV)),
+  # ** mod **
+  # mod folding
+  (UPat.var("x") % UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.MOD)),
+])
+
+symbolic_flat = symbolic+PatternMatcher([
+  # ** combine terms (opinionated) **
+  (-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)),  # -(x+y) -> -x + -y
+  # (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue
+  ((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
+])
+
+# ******** we take a small aside to "simplify_valid" to rewrite valids ********
+
+def parse_valid(valid:UOp) -> tuple[UOp, bool, int]:
+  # if it's X <= c, returns X, True, c
+  # if it's X >= c, returns X, False, c
+
+  # (X < c).ne(True) -> X >= c
+  if valid.op is Ops.CMPNE and valid.src[1].op is Ops.CONST and valid.src[1].arg == 1 and \
+    (s0:=valid.src[0]).op is Ops.CMPLT and s0.src[1].op is Ops.CONST: return s0.src[0], False, s0.src[1].arg
+  # X < c -> X <= c-1
+  if valid.op is Ops.CMPLT and valid.src[1].op is Ops.CONST and dtypes.is_int(valid.src[0].dtype): return valid.src[0], True, valid.src[1].arg-1
+  raise ValueError(f"not able to parse {valid=}")
+
+def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None:
+  # return None if valid is always False, otherwise the simplified uop (might be the same as input)
+
+  # first, parse valid into {expr: (lower_bound, upper_bound)}
+  bounds:defaultdict[UOp, list[ConstType|None]] = defaultdict(lambda: [None, None])
+  for stmt in split_uop(valid, Ops.AND):
+    try: expr, is_upper, c = parse_valid(stmt)
+    except ValueError: return uop  # give up if we cannot parse the valid
+    bounds[expr][int(is_upper)] = c
+
+  # simplify uop given that valid is True
+  for expr,v in bounds.items():
+    # some expr has lower bound > upper bound -> valid is an empty set and we return None
+    if v[0] is not None and v[1] is not None and v[0] > v[1]: return None
+
+    # every candidate is a set of contrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop
+    candidates = []
+    if expr.op is Ops.ADD and v[0] == 1 and all(u.op in GroupOp.Irreducible for u in split_uop(expr, Ops.ADD)):
+      # if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
+      candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, Ops.ADD)])
+    # try checking the whole clause
+    if expr in uop.toposort:
+      candidates.append([(expr, UOp.variable("fake", expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1], expr.dtype))])
+
+    for candidate in candidates:
+      # if every branch in candidate gives the same simplified uop, we can rewrite the uop
+      newuops = [uop.substitute({X:newX}).simplify().substitute({newX:X}).simplify() for X,newX in candidate]
+      if uop.op is Ops.VECTORIZE and len(uop.src) == 2:
+        if all_same([uops.src[0] for uops in newuops]): uop = uop.replace(src=(newuops[0].src[0], uop.src[1]))
+        if all_same([uops.src[1] for uops in newuops]): uop = uop.replace(src=(uop.src[0], newuops[0].src[1]))
+      elif all_same(newuops): uop = newuops[0]
+
+  return uop
+
+def _valid_priority(v: UOp, valids:list[UOp]):
+  # we want valid that's in other valids' parents to be first, so it's more likely the other valids get simplified
+  try: return sum(-1 if parse_valid(v)[0] in other.toposort else 0 for other in valids)
+  except ValueError: return 0
+
+def simplify_valid(valid:UOp) -> UOp|None:
+  ret:list[UOp] = []
+  something_changed = False
+  valids = list(split_uop(valid, Ops.AND))
+  for stmt in sorted(valids, key=lambda v: _valid_priority(v, valids)):
+    ret.append(newstmt if ret and (newstmt:=uop_given_valid(functools.reduce(operator.and_, ret), stmt)) is not None else stmt)
+    if ret[-1] is not stmt: something_changed = True
+  return functools.reduce(operator.and_, ret) if something_changed else None
+
+# ***** threefry *****
+
+def threefry2x32(x: UOp, key: UOp):
+  # split x into two uint32, since x in a uint64
+  x0, x1 = (x & 0xffffffff).cast(dtypes.uint32), ((x // 2**32) & 0xffffffff).cast(dtypes.uint32)
+
+  rotations = [[13, 15, 26, 6], [17, 29, 16, 24]]
+  key0, key1 = (key & 0xffffffff).cast(dtypes.uint32), ((key // 2**32) & 0xffffffff).cast(dtypes.uint32)
+  ks = [key1, key0 ^ key1 ^ 0x1BD11BDA, key0]
+  xr = [x0 + ks[-1], x1 + ks[0]]
+  for i in range(5):
+    for r in rotations[i % 2]: xr[0], xr[1] = (x0 := xr[0] + xr[1]), x0 ^ ((xr[1] * 2**r) + (xr[1] // 2**(32 - r)))
+    xr = [(xr[0] + ks[i % 3]), (xr[1] + ks[(i + 1) % 3] + i + 1)]
+
+  return xr[1].cast(dtypes.uint64) * 2**32 | xr[0].cast(dtypes.uint64)
+
+# ******** phase 3 is the complete symbolic, and deals with very complex things like loop rewriting and threefry transform ********
+
+def loop_collapse(compval, multconst, rng:UOp, acc:UOp, idx2=None,idx3=None,extra=None,vec=None,ne=None,
+                  add=UOp.const(dtypes.int, 0), mul:UOp=UOp.const(dtypes.int, 1)):
+  if getenv("DISABLE_LOOP_COLLAPSE") or rng not in acc.src: return None  # must be the right REDUCE
+  loop_start, loop_end = rng.src
+  if loop_start.arg != 0:
+    # TODO: support and test this with other mul and loop_starts
+    if DEBUG >= 1: print(f"WARNING, NOT FOLDING: mul:{mul.arg} loop_start:{loop_start.arg}")
+    return None
+  if idx2 is not None: add = add + idx2
+  if idx3 is not None: add = add + idx3
+  if vec is not None:
+    # add, mul, loop_start, loop_end
+    def dvec(x:UOp):
+      if x.op is Ops.CONST: return UOp.const(x.dtype.vec(vec.dtype.count), x.arg)
+      return UOp(Ops.VECTORIZE, x.dtype.vec(vec.dtype.count), src=(x,)*vec.dtype.count)
+    add, mul, loop_start, loop_end = dvec(add), dvec(mul), dvec(loop_start), dvec(loop_end)
+  if mul.vmin > 0 and ne is not None:
+    comprange = UOp.minimum(loop_end, UOp.maximum((add-compval)//mul + (loop_end-loop_start), loop_start))
+  elif mul.vmax < 0 and ne is None:
+    comprange = UOp.minimum(loop_end, UOp.maximum((add-compval-mul)//mul + (loop_end-loop_start), loop_start))
+  else:
+    return None
+  new_reduce_op = comprange.cast(multconst.dtype) * multconst
+  # TODO: what does it mean to have the same numbered DEFINE_ACC with different ranges?
+  new_acc = acc.replace(src=acc.src[0:1]+tuple(x for x in acc.src[1:] if x is not rng))
+  ret = new_acc.assign(new_acc+new_reduce_op)
+  if extra is not None: ret = ret + acc.assign(acc+extra)
+  return ret
+
+def index_collapse(idx:UOp,rng:UOp,buf:UOp,ld:UOp,acc:UOp,add=UOp.const(dtypes.int, 0),mul=UOp.const(dtypes.int, 1)):
+  if rng not in acc.src: return None
+  new_load = UOp.load(buf.index(add+mul*idx, (idx >= rng.src[0]) & (idx < rng.src[1])), dtype=ld.dtype)
+  new_acc = acc.replace(src=acc.src[0:1]+tuple(x for x in acc.src[1:] if x is not rng))
+  return new_acc.assign(new_acc+new_load)
+
+def reduce_collapse(acc:UOp, ret:UOp, alu:UOp):
+  reduce_parented, reduce_unparented = partition(acc.src[1:], lambda x: x in ret.toposort)
+  if len(reduce_unparented) == 0: return None
+  new_acc = acc.replace(src=acc.src[0:1]+tuple(reduce_parented))
+  ret = new_acc.assign(new_acc.alu(alu.op, ret))
+  if alu.op is Ops.ADD:
+    for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
+  return ret
+
+def gep_through_wmma(gep:UOp, wmma:UOp):
+  out_sz = prod(x[1] for x in wmma.arg[6][-1])
+  wmma_idxs = gep.arg[::out_sz]
+  for i in range(out_sz):
+    if tuple(x-i for x in gep.arg[i::out_sz]) != wmma_idxs: return None
+  tsrcs = []
+  for s,sz in zip(wmma.src, wmma.arg[6]):
+    src_args = []
+    ssz = prod(x[1] for x in sz)
+    for w in wmma_idxs: src_args += list(range((w//out_sz)*ssz, (w//out_sz)*ssz + ssz))
+    tsrcs.append(s.gep(tuple(src_args)))
+  return UOp(Ops.WMMA, gep.dtype, tuple(tsrcs), wmma.arg)
+
+acc_pat, rng_pat = UPat(Ops.DEFINE_ACC, name="acc"), UPat(Ops.RANGE, name="rng")
+rng_aug = UPat.any(rng_pat, UPat.var("add")+rng_pat, UPat.var("mul")*rng_pat, UPat.var("add")+UPat.var("mul")*rng_pat)
+
+index_load = UPat.var("buf").index(rng_aug).load(name="ld")
+
+arange_augrng = UPat.any(rng_aug, rng_aug+UPat.var("idx2"), rng_aug+UPat.var("idx2")+UPat.var("idx3"), UPat(Ops.VECTORIZE, name="vec", src=rng_aug))
+arange_m = ((arange_augrng<UPat.cvar("compval"))!=UPat(Ops.CONST, name="ne", arg=True)).where(UPat.cvar("multconst"), UPat.const(None, 0))
+
+# this moves the accumulation variable down an unrolled add chain which allows for more efficient accumulation using mulacc
+mulacc_unrolled = PatternMatcher([(UPat.var("x")+UPat.var("y")+acc_pat, lambda x,y,acc: (acc+x)+y if y.op is not Ops.DEFINE_ACC else None)])
+
+# this is symbolic 2.0
+sym = symbolic_flat+PatternMatcher([
+  # self ASSIGN is just self
+  (UPat(Ops.ASSIGN, src=(UPat.var('x'), UPat.var('x'))), lambda x: x),
+  # VECTORIZE/CONST, VECTORIZE/GEP
+  (UPat(Ops.VECTORIZE, src=UPat(Ops.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))),
+  (UPat(Ops.VECTORIZE, src=UPat(Ops.GEP, src=(UPat.var("x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))),
+  # reorder ALU/VECTORIZE
+  (UPat(GroupOp.ALU, src=(UPat(Ops.VECTORIZE, src=UPat(name='x')), UPat(Ops.VECTORIZE, src=UPat(name='y'))), name='alu'),
+   lambda x,y,alu: UOp(Ops.VECTORIZE, alu.dtype, (UOp(alu.op, alu.dtype.scalar(), (x,y)),)*alu.dtype.count)),
+  # VECTORIZE of a single element is just that element
+  (UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
+  # VECTORIZE void is SINK
+  (UPat(Ops.VECTORIZE, dtype=dtypes.void, src=UPat(Ops.BARRIER, name='b')), lambda b: b),
+  (UPat(Ops.VECTORIZE, dtype=dtypes.void, name='x'), lambda x: UOp(Ops.SINK, dtypes.void, x.src)),
+  # GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
+  (UPat(Ops.GEP, src=(UPat(Ops.GEP, name='g2'),), name='g1'),
+   lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(g1.dtype.count)))),
+  (UPat(Ops.GEP, src=(UPat(Ops.VECTORIZE, name="vec"),), name="gep"),
+   lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
+  (UPat(Ops.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
+  (UPat(Ops.GEP, src=(UPat(Ops.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
+  # push all GEPs through ALUs (fix arange stuff)
+  (UPat(Ops.GEP, src=(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'),
+   lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \
+    if not isinstance(gep.dtype, PtrDType) else None),
+  # push some GEPs through WMMAs
+  (UPat(Ops.GEP, src=(UPat(Ops.WMMA, name="wmma"),), name="gep"), gep_through_wmma),
+  # CAT can't be rendered. it's a VECTORIZE on vectors, we expand to a single VECTORIZEs with GEPs (TODO: move this later)
+  (UPat(Ops.CAT, name="x"), lambda x: UOp(Ops.VECTORIZE, x.dtype, tuple(y.gep(i) for y in x.src for i in range(y.dtype.count))) \
+    if not isinstance(x.dtype, PtrDType) else None),
+  # tensor core with a 0 input is acc
+  (UPat(Ops.WMMA, src=(UPat.const(None, 0.0), UPat.var(), UPat.var("acc"))), lambda acc: acc),
+  (UPat(Ops.WMMA, src=(UPat.var(), UPat.const(None, 0.0), UPat.var("acc"))), lambda acc: acc),
+  # tensor core cleanups
+  (UPat.var("add") + UPat(Ops.WMMA, name="wmma"),
+   lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
+  # threefry + remove longs
+  (UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32),
+  (UPat.var('x', dtypes.uint32).cast(dtypes.uint64).cast(dtypes.uint32), lambda x: x), # cast there and back is noop (TODO: genericize)
+  ((UPat.var('x', dtypes.uint64)&0xFFFFFFFF).cast(dtypes.uint32), lambda x: x.cast(dtypes.uint32)), # cast does truncation
+  (((UPat.var(None, dtypes.uint64)*(1<<32)) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
+  (((UPat.var('x', dtypes.uint64)*(1<<32)) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))//(1<<32), lambda x: x),
+  # hacks for threefry long removal when padded (TODO: genericize)
+  (UPat.var('x', dtypes.uint32).cast(dtypes.uint64) * UPat.var('y').where(UPat.const(dtypes.uint64, 1<<32), UPat.const(dtypes.uint64, 0)),
+   lambda x,y: y.where(x, UOp.const(dtypes.uint32, 0)).cast(dtypes.uint64) * (1<<32)),
+  ((UPat.var('x', dtypes.uint64)&(UPat.var('y').where(UPat.const(dtypes.uint64, 0xFFFFFFFF), UPat.const(dtypes.uint64, 0)))).cast(dtypes.uint32),
+   lambda x,y: y.where(x.cast(dtypes.uint32), UOp.const(dtypes.uint32, 0))),
+  # arange loop folding
+  (acc_pat.assign(UPat.any(arange_m, arange_m+UPat.var("extra"))+acc_pat), loop_collapse),
+  # indexing, with cast or where
+  (acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).cast()*index_load+acc_pat), index_collapse),
+  (acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).where(index_load, UPat.const(None, 0.0))+acc_pat), index_collapse),
+  # parentless reduce  # TODO: add MUL
+  (acc_pat.assign(UPat((Ops.ADD, Ops.MAX), src=[acc_pat, UPat.var("ret")], name="alu")), reduce_collapse),
+  # ** self folding **
+  (UPat(Ops.DEFINE_ACC, src=(UPat.var("x"),)), lambda x: x), # a DEFINE_ACC without ranges is a CONST
+  (UPat(Ops.ASSIGN, src=(UPat.cvar(),UPat.var("x"))), lambda x: x), # an ASSIGN to a const is a NOOP
+  # x!=0 -> (bool)x
+  (UPat.var("x")!=0, lambda x: x.cast(dtypes.bool.vec(x.dtype.count))),
+  # ** where **
+  # push cast to branches
+  (UPat.var("s").where(UPat.var("a"), UPat.var("b")).cast().named("cast"), lambda s,a,b,cast: s.where(a.cast(cast.dtype), b.cast(cast.dtype))),
+  # ** pow **
+  ((UPat(Ops.POW, name="p"), lambda p: xpow(*p.src))),
+  # ** load/store folding **
+  (UPat.store(UPat(Ops.INDEX, name="index"), UPat.load(UPat(Ops.INDEX, name="index"))), lambda index: UOp(Ops.NOOP)),
+  (UPat.store(UPat(Ops.INDEX, name="index"), UPat.var("gate").where(UPat.var("alt"), UPat.load(UPat(Ops.INDEX, name="index")))),
+   lambda index, gate, alt: UOp.store(index.src[0].index(index.src[1], gate), alt)),
+  # fold gated LOAD/STORE
+  (UPat().index(UPat(), UPat.const(dtypes.bool, True)).named("idx"), lambda idx: idx.replace(src=idx.src[0:2])), # remove True
+  (UPat().index(UPat(), UPat.const(dtypes.bool, False)).named("idx"), lambda idx: idx.const_like(0)), # False -> NULL pointer
+  (UPat(Ops.LOAD, src=(UPat.const(None, 0),), allow_any_len=True, name="x"), lambda x: x.const_like(0)), # NULL pointer load loads 0
+  (UPat(Ops.STORE, src=(UPat.const(None, 0),), allow_any_len=True), lambda: UOp(Ops.NOOP)), # NULL pointer store does nothing
+  # remove NOOPs from SINK
+  (UPat(Ops.SINK, name="root"),
+   lambda root: UOp(Ops.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not Ops.NOOP)) != len(root.src) else None),
+  # remove VECTORIZE from SINK/BARRIER
+  (UPat(Ops.BARRIER, src=(UPat((Ops.VECTORIZE, Ops.SINK), name='sink'),)), lambda sink: UOp(Ops.BARRIER, dtypes.void, sink.src)),
+  (UPat(Ops.SINK, name="root"),
+   lambda root: UOp(Ops.SINK, root.dtype, tuple(flatten(x.src if x.op in {Ops.SINK, Ops.UNROLL} else (x,) for x in root.src)), root.arg)
+    if any(x.op in {Ops.SINK, Ops.UNROLL} for x in root.src) else None),
+  ((UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()), # 1/(x^c) -> (1/x)^c
+  ((UPat.var("x") * UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()*x.reciprocal()),
+  (UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")), lambda x,d: 1-d), # x*/(1+x) -> 1-1/(1+x)
+  (UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")*UPat.var("y")), lambda x,y,d: y*(1-d)),
+  (UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")+UPat.var("y")), lambda x,y,d: (1-d)+x*y),
+])
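The new symbolic.py above gathers the folding rules (symbolic_simple, symbolic, symbolic_flat, sym) that previously lived in ops.py and the removed uopgraph.py. As a minimal sketch of what these rules do, assuming tinygrad 0.10.2 is installed and that UOp.simplify() routes an expression through these pattern matchers (as the engine does internally when building kernels):

```python
# Sketch only: div_and_mod_folding from the file above should reduce (x*4 + 2)//4 to x.
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp

x = UOp.variable("x", 0, 10, dtypes.int)  # a bounded symbolic integer
expr = (x * 4 + 2) // 4                   # the div folding target
print(expr.simplify().render())           # expected to print "x"
```
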
tinygrad/codegen/transcendental.py
@@ -1,5 +1,4 @@
 import math
-from typing import Tuple
 from tinygrad.dtype import dtypes, DType
 from tinygrad.helpers import polyN
 from tinygrad.ops import UOp
@@ -22,7 +21,7 @@ def shl(x:UOp, y:int) -> UOp: return x * (2**y)
 def rintk(d:UOp) -> UOp:
   """round d:float to int away from 0"""
   out_dtype = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype]
-  return (d + d
+  return (d + (d<0.0).where(d.const_like(-0.5), d.const_like(0.5))).cast(out_dtype)

 def pow2if(q:UOp, float_dtype:DType):
   """cast(2^q, float_dtype) where q is any integer in the range of [-126, 127]"""
@@ -49,7 +48,7 @@ def ldexp2k(d:UOp, e:UOp) -> UOp:
   assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype in (dtypes.int16, dtypes.int32, dtypes.int64)
   return (d * pow2if(shr(e, 1), d.dtype)) * pow2if(e - shr(e, 1), d.dtype)

-def frexp(v:UOp) -> Tuple[UOp, UOp]:
+def frexp(v:UOp) -> tuple[UOp, UOp]:
   """frexp(v) -> (mantissa, exponent) assuming v != 0"""
   assert v.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
   # m1 = masks for mantissa, m2 = masks to normalize the mantissa.
@@ -63,7 +62,7 @@ def frexp(v:UOp) -> Tuple[UOp, UOp]:
   return mantissa, exp

 # *** reduction algorithms for sine ***
-def payne_hanek_reduction(d:UOp) -> Tuple[UOp, UOp]:
+def payne_hanek_reduction(d:UOp) -> tuple[UOp, UOp]:
   """
   Performs Payne-Hanek Reduction: computes the remainder of `d` modulo pi/2 for the values `d` where
     39800.0 <= d <= +Inf
@@ -110,9 +109,9 @@ def payne_hanek_reduction(d:UOp) -> Tuple[UOp, UOp]:
   r = (p.cast(intermediate_dtype) * (3.4061215800865545e-19)).cast(d.dtype)

   # if fraction >= 0.5, r -= pi/2, q += 1
-  return f
+  return (f<0.5).where(r, r - math.pi/2), (f<0.5).where(q, q + 1)

-def cody_waite_reduction(d:UOp) ->
+def cody_waite_reduction(d:UOp) -> tuple[UOp, UOp]:
   """
   Performs Cody-Waite Reduction: computes the reminder of `d` modulo pi/2 for the values `d` where
     0 <= abs(d) <= 39800.0
@@ -177,14 +176,14 @@ def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp:
   # mask +-inf/nan as zero
   x = _lazy_map_numbers(d, d.const_like(0.0), d.const_like(0.0), d.const_like(0.0), d)
   # x_sign = sign(x)
-  x_sign = x.ne(0).where(x
+  x_sign = x.ne(0).where((x<0).where(x.const_like(-1), x.const_like(1)), x.const_like(0))
   x_abs = x * x_sign
   r, q = (cody_waite_reduction if fast else payne_hanek_reduction)(x_abs)
   if fast: result = sin_poly_small(r, q)
   else:
     # Payne Hanek Reduction assumes abs(x) >= pi/4, so for smaller values, use cody_waite_reduction.
     r_small, q_small = cody_waite_reduction(x_abs)
-    result = x_abs
+    result = (x_abs<switch_over).where(sin_poly_small(r_small, q_small), sin_poly_large(r, q))
   # adjusts the sign for abs(x)
   result = result * x_sign
   # sin(Inf) = NaN, sin(-Inf) = NaN, sin(NaN) = NaN
@@ -210,9 +209,9 @@ def xexp2(d:UOp) -> UOp:
   u = ldexp2k(u, q) # u*2^q
   upper, lower = {dtypes.float64: (1024, -2000), dtypes.float32: (128, -150), dtypes.float16: (23, -22)}[d.dtype]
   # Replace x >= upper with +inf
-  u = d
-  # Replace x
-  u = d
+  u = (d >= upper).where(d.const_like(math.inf), u)
+  # Replace x < lower with zero.
+  u = (d<lower).where(d.const_like(0.0), u)
   # exp2(NaN) = NaN
   return d.ne(d).where(d.const_like(math.nan), u)

@@ -225,7 +224,7 @@ def xlog2(d:UOp) -> UOp:
   # TODO: float16 denormal need float32 to achieve precision
   if d.dtype == dtypes.float16: return xlog2(d.cast(dtypes.float32)).cast(dtypes.float16)
   FLT_MIN = d.const_like(1e-6 if d.dtype == dtypes.float16 else 1e-4)
-  is_denormal = d
+  is_denormal = d<FLT_MIN
   a = is_denormal.where(d * (2 ** 64), d)

   e = ilogb2k(a * (1.0 / 0.75)).cast(a.dtype)
@@ -246,7 +245,7 @@ def xlog2(d:UOp) -> UOp:
   # log2(Inf) = Inf
   r = d.ne(math.inf).where(r, r.const_like(math.inf))
   # log2(x) = NaN for x < 0
-  r = d
+  r = (d<-0.0).where(r.const_like(math.nan), r)
   # log2(0) = -Inf, but we will compare using the value of y because 1e-200==0 is true.
   # log2_zero = the value of unmasked xlog2(0.0).
   log2_zero = {dtypes.float64: -1087, dtypes.float32: -191, dtypes.float16: -79}[d.dtype]
@@ -255,3 +254,13 @@ def xlog2(d:UOp) -> UOp:
   r = d.ne(d).where(r.const_like(math.nan), r)
   # log2(-0.0) = -Inf. In certain devices like PTX, x == -0.0 won't be true. so making reciprocal.
   return d.reciprocal().ne(-math.inf).where(r, r.const_like(-math.inf))
+
+def xpow(base:UOp, exponent:UOp) -> UOp:
+  # start with b ** e = exp2(e * log2(b))
+  ret = (base < 0).where(-base, base).log2().mul(exponent).exp2()
+  # negative base adjustment: nan for non-integer exponent and -1 for odd exponent
+  non_int = exponent != exponent.cast(dtypes.int32).cast(exponent.dtype)
+  adj = non_int.where(ret.const_like(math.nan),
+    (exponent < 0).where(-exponent, exponent).cast(dtypes.int32).mod(2).cast(dtypes.bool).where(ret.const_like(-1), ret.const_like(1)))
+  # fix 0 ** 0 = 1
+  return (base.eq(0) & exponent.eq(0)).where(ret.const_like(1), ret * (base < 0).where(adj, ret.const_like(1)))
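The new xpow helper added at the end of transcendental.py decomposes pow into exp2/log2 of the absolute base, then patches up negative bases and 0 ** 0. A quick plain-Python restatement of that identity (a sketch for illustration, not tinygrad code):

```python
import math

# Sketch: b ** e computed as 2**(e*log2(|b|)), adjusted the way xpow above does it:
# NaN for a negative base with a non-integer exponent, a sign flip for odd integer
# exponents, and 0 ** 0 forced to 1.
def pow_via_exp2(base: float, exponent: float) -> float:
  if base == 0.0 and exponent == 0.0: return 1.0
  ret = 2.0 ** (exponent * math.log2(abs(base))) if base != 0.0 else 0.0
  if base < 0:
    if exponent != int(exponent): return math.nan
    if int(abs(exponent)) % 2 == 1: ret = -ret
  return ret

assert math.isclose(pow_via_exp2(-3.0, 3.0), (-3.0) ** 3)
assert pow_via_exp2(0.0, 0.0) == 1.0
```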