tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +1 -1
- tinygrad/apps/llm.py +206 -0
- tinygrad/codegen/__init__.py +116 -0
- tinygrad/codegen/devectorizer.py +315 -172
- tinygrad/codegen/expander.py +8 -16
- tinygrad/codegen/gpudims.py +89 -0
- tinygrad/codegen/linearize.py +205 -203
- tinygrad/codegen/lowerer.py +92 -139
- tinygrad/codegen/opt/__init__.py +38 -0
- tinygrad/codegen/opt/heuristic.py +125 -0
- tinygrad/codegen/opt/kernel.py +510 -0
- tinygrad/{engine → codegen/opt}/search.py +51 -35
- tinygrad/codegen/opt/swizzler.py +134 -0
- tinygrad/codegen/opt/tc.py +127 -0
- tinygrad/codegen/quantize.py +67 -0
- tinygrad/device.py +122 -132
- tinygrad/dtype.py +152 -35
- tinygrad/engine/jit.py +81 -54
- tinygrad/engine/memory.py +46 -27
- tinygrad/engine/realize.py +82 -41
- tinygrad/engine/schedule.py +70 -445
- tinygrad/frontend/__init__.py +0 -0
- tinygrad/frontend/onnx.py +1253 -0
- tinygrad/frontend/torch.py +5 -0
- tinygrad/gradient.py +19 -27
- tinygrad/helpers.py +95 -47
- tinygrad/nn/__init__.py +7 -8
- tinygrad/nn/optim.py +72 -41
- tinygrad/nn/state.py +37 -23
- tinygrad/renderer/__init__.py +40 -60
- tinygrad/renderer/cstyle.py +143 -128
- tinygrad/renderer/llvmir.py +113 -62
- tinygrad/renderer/ptx.py +50 -32
- tinygrad/renderer/wgsl.py +27 -23
- tinygrad/runtime/autogen/am/am.py +5861 -0
- tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
- tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
- tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
- tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
- tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
- tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
- tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
- tinygrad/runtime/autogen/comgr.py +35 -9
- tinygrad/runtime/autogen/comgr_3.py +906 -0
- tinygrad/runtime/autogen/cuda.py +2419 -494
- tinygrad/runtime/autogen/hsa.py +57 -16
- tinygrad/runtime/autogen/ib.py +7171 -0
- tinygrad/runtime/autogen/io_uring.py +917 -118
- tinygrad/runtime/autogen/kfd.py +748 -26
- tinygrad/runtime/autogen/libc.py +613 -218
- tinygrad/runtime/autogen/libusb.py +1643 -0
- tinygrad/runtime/autogen/nv/nv.py +8602 -0
- tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
- tinygrad/runtime/autogen/opencl.py +2 -4
- tinygrad/runtime/autogen/sqtt.py +1789 -0
- tinygrad/runtime/autogen/vfio.py +3 -3
- tinygrad/runtime/autogen/webgpu.py +273 -264
- tinygrad/runtime/graph/cuda.py +3 -3
- tinygrad/runtime/graph/hcq.py +68 -29
- tinygrad/runtime/graph/metal.py +29 -13
- tinygrad/runtime/graph/remote.py +114 -0
- tinygrad/runtime/ops_amd.py +537 -320
- tinygrad/runtime/ops_cpu.py +108 -7
- tinygrad/runtime/ops_cuda.py +12 -14
- tinygrad/runtime/ops_disk.py +13 -10
- tinygrad/runtime/ops_dsp.py +47 -40
- tinygrad/runtime/ops_gpu.py +13 -11
- tinygrad/runtime/ops_hip.py +6 -9
- tinygrad/runtime/ops_llvm.py +35 -15
- tinygrad/runtime/ops_metal.py +29 -19
- tinygrad/runtime/ops_npy.py +5 -3
- tinygrad/runtime/ops_null.py +28 -0
- tinygrad/runtime/ops_nv.py +306 -234
- tinygrad/runtime/ops_python.py +62 -52
- tinygrad/runtime/ops_qcom.py +28 -39
- tinygrad/runtime/ops_remote.py +482 -0
- tinygrad/runtime/ops_webgpu.py +28 -28
- tinygrad/runtime/support/am/amdev.py +114 -249
- tinygrad/runtime/support/am/ip.py +211 -172
- tinygrad/runtime/support/amd.py +138 -0
- tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
- tinygrad/runtime/support/compiler_cuda.py +8 -11
- tinygrad/runtime/support/elf.py +2 -1
- tinygrad/runtime/support/hcq.py +184 -97
- tinygrad/runtime/support/ib.py +172 -0
- tinygrad/runtime/support/llvm.py +3 -4
- tinygrad/runtime/support/memory.py +251 -0
- tinygrad/runtime/support/nv/__init__.py +0 -0
- tinygrad/runtime/support/nv/ip.py +581 -0
- tinygrad/runtime/support/nv/nvdev.py +183 -0
- tinygrad/runtime/support/system.py +170 -0
- tinygrad/runtime/support/usb.py +268 -0
- tinygrad/runtime/support/webgpu.py +18 -0
- tinygrad/schedule/__init__.py +0 -0
- tinygrad/schedule/grouper.py +119 -0
- tinygrad/schedule/kernelize.py +368 -0
- tinygrad/schedule/multi.py +231 -0
- tinygrad/shape/shapetracker.py +40 -46
- tinygrad/shape/view.py +88 -52
- tinygrad/tensor.py +968 -542
- tinygrad/uop/__init__.py +117 -0
- tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
- tinygrad/uop/mathtraits.py +169 -0
- tinygrad/uop/ops.py +1021 -0
- tinygrad/uop/spec.py +228 -0
- tinygrad/{codegen → uop}/symbolic.py +239 -216
- tinygrad/uop/upat.py +163 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
- tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
- tinygrad/viz/index.html +203 -403
- tinygrad/viz/js/index.js +718 -0
- tinygrad/viz/js/worker.js +29 -0
- tinygrad/viz/serve.py +224 -102
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
- tinygrad-0.11.0.dist-info/RECORD +141 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/kernel.py +0 -693
- tinygrad/engine/multi.py +0 -161
- tinygrad/ops.py +0 -1003
- tinygrad/runtime/ops_cloud.py +0 -220
- tinygrad/runtime/support/allocator.py +0 -94
- tinygrad/spec.py +0 -155
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
- tinygrad/viz/perfetto.html +0 -178
- tinygrad-0.10.2.dist-info/RECORD +0 -99
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
tinygrad/{codegen → uop}/symbolic.py
@@ -1,11 +1,11 @@
 # all of symbolic lives here now
-from typing import Any,
+from typing import Any, cast
 import math, operator, struct, functools
 from collections import defaultdict
-from tinygrad.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
-from tinygrad.dtype import ConstType, dtypes, PtrDType
-from tinygrad.helpers import partition, all_same, prod,
-from tinygrad.
+from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
+from tinygrad.dtype import ConstType, dtypes, PtrDType, AddrSpace, can_safe_cast
+from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, cdiv, cmod, CORRECT_DIVMOD_FOLDING
+from tinygrad.uop.decompositions import xpow

 # ******** phase 1 of symbolic used to live in ops, it's the most generic folding rules ********

@@ -18,6 +18,7 @@ def simplify_pow(x:UOp, c:UOp) -> UOp|None:

 def fold_bitcast(root:UOp, c:UOp) -> UOp|None:
   if (from_fmt:=c.dtype.scalar().fmt) is None or (to_fmt:=root.dtype.scalar().fmt) is None: return None
+  if c.dtype.itemsize != root.dtype.itemsize: return None
   def convert(v:Any): return struct.unpack(to_fmt, struct.pack(from_fmt, v))[0]
   return root.const_like(convert(c.arg) if root.dtype.count == 1 else tuple(map(convert, c.arg)))

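The new guard in `fold_bitcast` exists because a bitcast only reinterprets bits: folding it at compile time is only sound when the source and destination dtypes have the same width. A minimal plain-Python sketch of that invariant (`const_bitcast` is a hypothetical stand-in, not tinygrad API):

```python
import struct

# reinterpret a constant's bits in another format, refusing width mismatches
def const_bitcast(value, from_fmt: str, to_fmt: str):
  if struct.calcsize(from_fmt) != struct.calcsize(to_fmt): return None  # the itemsize guard
  return struct.unpack(to_fmt, struct.pack(from_fmt, value))[0]

print(const_bitcast(1.0, "f", "I"))  # 1065353216 == 0x3F800000, the bits of float32 1.0
print(const_bitcast(1.0, "f", "Q"))  # None: float32 -> uint64 widths differ
```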
@@ -25,6 +26,7 @@ symbolic_simple = PatternMatcher([
   # ** self folding **
   (UPat.var("x") + 0, lambda x: x),    # x+0 -> x
   (UPat.var("x") * 1, lambda x: x),    # x*1 -> x
+  (UPat.var("x", dtype=dtypes.ints) ^ 0, lambda x: x), # x^0 -> x
   (UPat.var("x") // UPat.var("x"), lambda x: x.const_like(1)), # x//x -> 1
   (UPat.var("x") // 1, lambda x: x),   # x//1 -> x
   (UPat.var("x") // -1, lambda x: -x), # x//-1 -> -x
@@ -39,8 +41,10 @@ symbolic_simple = PatternMatcher([
   (UPat(GroupOp.Idempotent, src=(UPat.var("x"), UPat.var("x"))), lambda x: x),
   (UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x),
   (UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, True), UPat.const(dtypes.bool, False)), lambda x: x),
+  (UPat.var("x", dtype=dtypes.ints+(dtypes.bool,)).trunc(), lambda x: x),
   # ** zero folding **
   (UPat.var("x") < UPat.var("x"), lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x < x -> False
+  (UPat.var("x") % UPat.var("x"), lambda x: x.const_like(0)), # x%x -> 0
   (UPat.var("x", dtype=dtypes.ints) != UPat.var("x", dtype=dtypes.ints),
    lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints)
   # x*0 -> 0 or 0*x -> 0
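The added self-folding and zero-folding rules are ordinary integer identities; they can be spot-checked with plain Python ints:

```python
import math

for x in range(-8, 9):
  assert x ^ 0 == x                # x^0 -> x (ints)
  if x != 0: assert x % x == 0     # x%x -> 0
  assert math.trunc(x) == x        # trunc() on an int/bool dtype is the identity
```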
@@ -49,20 +53,38 @@ symbolic_simple = PatternMatcher([
   (UPat.var("x") * 0, lambda x: x.const_like(float("nan") if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
   # ** constant folding **
   # TODO: add const folding for Ops.THREEFRY
-  (UPat(GroupOp.
-
+  (UPat(GroupOp.Unary, src=(UPat((Ops.VCONST, Ops.CONST)),), name="a"), lambda a: a.const_like(exec_alu(a.op, a.dtype, [a.src[0].arg], False))),
+  (UPat(GroupOp.Binary-{Ops.THREEFRY}, src=(UPat((Ops.VCONST, Ops.CONST)),)*2, name="a"),
+   lambda a: a.const_like(exec_alu(a.op, a.dtype, [a.src[0].arg, a.src[1].arg], False))),
+  (UPat(GroupOp.Ternary, src=(UPat((Ops.VCONST, Ops.CONST)),)*3, name="a"),
+   lambda a: a.const_like(exec_alu(a.op, a.dtype, [a.src[0].arg, a.src[1].arg, a.src[2].arg], False))),
   # bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
   (UPat.var('x', dtype=dtypes.bool) * UPat.var('y', dtype=dtypes.bool), lambda x,y: x&y),
   (UPat.var('x', dtype=dtypes.bool) + UPat.var('y', dtype=dtypes.bool), lambda x,y: x|y),
   (UPat.var('x', dtype=dtypes.bool).maximum(UPat.var('y', dtype=dtypes.bool)), lambda x,y: x|y),
   # *** cast/bitcast ***
-  (UPat(Ops.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)),
+  (UPat(Ops.CAST, name="root", src=(UPat.cvar("c"),)), lambda root, c: root.const_like(c.arg)),
   (UPat((Ops.CAST, Ops.BITCAST), name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
   (UPat(Ops.BITCAST, name="root", src=(UPat.cvar("c"),)), fold_bitcast),
+  # b.cast(a).cast(b) -> b if a preserves all values in b
+  (UPat.var('x').cast().named('a').cast().named('b'), lambda x,a,b: x if x.dtype == b.dtype and can_safe_cast(b.dtype, a.dtype) else None),
   # ** pow **
   (UPat.var("x").alu(Ops.POW, UPat.cvar("c", vec=False)), simplify_pow),
   # positive const ** x
   (UPat.cvar("c", vec=False).alu(Ops.POW, UPat.var("x")), lambda c,x: c if c.arg == 1 else (x*math.log2(c.arg)).exp2() if c.arg > 0 else None),
+  # rules for threefry
+  ((UPat.var('x', dtypes.uint64)&0xFFFFFFFF).cast(dtypes.uint32), lambda x: x.cast(dtypes.uint32)&0xFFFFFFFF), # TODO: why is the and needed?
+  (((UPat.var(None, dtypes.uint64)*(1<<32)) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
+  (((UPat.var('x', dtypes.uint64)*(1<<32)) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))//(1<<32), lambda x: x),
+  # hacks for threefry long removal when padded (TODO: genericize)
+  (UPat.var('x', dtypes.uint32).cast(dtypes.uint64) * UPat.var('y').where(UPat.const(dtypes.uint64, 1<<32), UPat.const(dtypes.uint64, 0)),
+   lambda x,y: y.where(x, 0).cast(dtypes.uint64) * (1<<32)),
+  ((UPat.var('x', dtypes.uint64)&(UPat.var('y').where(UPat.const(dtypes.uint64, 0xFFFFFFFF), UPat.const(dtypes.uint64, 0)))).cast(dtypes.uint32),
+   lambda x,y: y.where(x.cast(dtypes.uint32), 0)),
+  # new decomp rules for threefry
+  (((UPat.var(None, dtypes.uint64)<<32) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
+  (((UPat.var('x', dtypes.uint64)<<32) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))>>32, lambda x: x),
+  (UPat.var('b').where(UPat.var('x', dtypes.uint32).cast(dtypes.uint64), UPat.const(dtypes.uint64, 0)).cast(dtypes.uint32), lambda b,x: b.where(x,0))
 ])

 # ******** phase 2 builds on phase 1, it includes the old "symbolic", rules that match deeper ********
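The threefry rules above all simplify 64-bit words that were packed from two 32-bit halves. The bit identities they rely on are easy to confirm with plain Python ints standing in for uint64/uint32 values:

```python
MASK32 = 0xFFFFFFFF
for x in (0, 1, 0xDEADBEEFCAFEBABE, (1 << 64) - 1):
  lo, hi = x & MASK32, (x >> 32) & MASK32
  assert ((hi << 32) | lo) & MASK32 == lo  # casting the packed value down to uint32 yields lo
  assert ((hi << 32) | lo) >> 32 == hi     # shifting (or dividing by 2**32) yields hi
```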
@@ -72,26 +94,31 @@ def split_uop(x:UOp, sep:Ops):
     for s in x.src: yield from split_uop(s, sep)
   else: yield x

-def fold_unrolled_divs(divs:UOp):
+def fold_unrolled_divs(divs:UOp, denominator: int, fac=1) -> UOp|None:
   # div pattern in unrolled arange
   # example: (x//4+(x+1)//4+(x+2)//4+(x+3)//4 -> x
-
-  for u in
+  seen_const, ans = [], None
+  for u in split_uop(divs, Ops.ADD):
+    if fac!=1:
+      if u.op is not Ops.MUL or u.src[1].op is not Ops.CONST or u.src[1].arg != fac: return None
+      u = u.src[0]
     if not (u.op is Ops.IDIV and u.src[1].op is Ops.CONST): return None
-    if denominator is None: denominator = u.src[1].arg
     if denominator != u.src[1].arg: return None
+    if (s0:=u.src[0]).vmin < 0: return None
     # assumed CONST is the last of an ADD
-    if
+    if s0.op is Ops.ADD and s0.src[1].op is Ops.CONST and s0.src[1].op is Ops.CONST:
       seen_const.append(s0.src[1].arg)
       s0 = s0.src[0]
     else: seen_const.append(0)
     if ans is None: ans = s0
     if ans is not s0: return None
-  if
+  if ans is None: return None
   # the first (denominator-len(seen_const)) terms may have been folded to 0 already
   for i in range(denominator-len(seen_const)):
     if ans is not None and 0 <= ans.vmin and ans.vmax + i < denominator: seen_const.append(i)
-
+  if sorted(seen_const)==list(range(denominator)):
+    return fac*ans
+  return None

 def lt_folding(x:UOp, c:int) -> UOp|None:
   p, np = partition(split_uop(x, Ops.ADD), lambda u: u.const_factor() == 1)
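`fold_unrolled_divs` recovers x from the sum an unrolled arange produces. The underlying identity, sum over i in [0, d) of (x+i)//d == x for x >= 0 (and fac times that when every term carries a common factor fac), checks out numerically:

```python
for d in (2, 3, 4, 8):
  for x in range(100):
    assert sum((x + i) // d for i in range(d)) == x            # the fac=1 case
    assert sum(3 * ((x + i) // d) for i in range(d)) == 3 * x  # common factor fac=3
```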
@@ -112,67 +139,139 @@ def canonicalize_simplex(X:UOp) -> UOp|None:
     ret.append(u)
   return functools.reduce(operator.add, ret) if changed else None

-def
-  #
-
-
-
-
-
+def cancel_divmod(d: UOp, x: UOp, y: UOp) -> UOp|None:
+  # simple cancel div/mod case when the range of the numerator lies within a single denominator interval
+  x_min, x_max, y_min, y_max = x.vmin, x.vmax, y.vmin, y.vmax
+  assert isinstance(x_min, int) and isinstance(x_max, int) and isinstance(y_min, int) and isinstance(y_max, int)
+  if y_min==y_max==0: raise ZeroDivisionError(f"{'Division' if d.op is Ops.IDIV else 'Mod'} by zero trying to rewrite {x.alu(d.op, y)}")
+  if y_min*y_max > 0 and (q:=cdiv(x_min,y_min)) == cdiv(x_min,y_max) == cdiv(x_max,y_min) == cdiv(x_max,y_max):
+    return x - q*y if d.op is Ops.MOD else d.const_like(q)
+  return None

-
+def remove_nested_mod(m: UOp, x: UOp, y: UOp) -> UOp|None:
+  # remove nested mod in case the inner mod is a multiple of the outer mod
+  # example: (a%4 + b)%2 -> (a+b)%2
+  if ((c := y.arg) < 0) or x.vmin<0: return None
+  new_xs = []
+  something_changed = False
   for u in split_uop(x, Ops.ADD):
-    if u.op is Ops.MOD
-
-
-
-
-
-
-
-
-
-      gcd = math.gcd(r, gcd)
-      factors.append(f); svars.append(v); quotients.append(q); remainders.append(r)  # noqa: E702
-
-  lbound = ubound = offset = offset % c
+    if u.op is Ops.MOD:
+      if u.src[1].divides(c) is not None:
+        something_changed = True
+        u = u.src[0]
+    new_xs.append(u)
+  new_x: UOp = functools.reduce(operator.add, new_xs)
+  if something_changed and new_x.vmin>=0: return new_x % y
+  return None
+
+def fold_binary_numerator(d: UOp, x: UOp, y: UOp) -> UOp|None:
   # we can fold if the expression has only one non-constant term and this term can only take on two values
-  if
-
-
-
-
+  if ((c := y.arg) < 0) or (x.dtype.count > 1): return None
+  x,const = x.pop_const()
+  terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in split_uop(x, Ops.ADD)])
+  if len(terms)==1 and (v:=terms[0]).vmax-v.vmin == 1:
+    y1 = cmod(factors[0]*v.vmin+const, c) if d.op is Ops.MOD else cdiv(factors[0]*v.vmin+const, c)  # type: ignore
+    y2 = cmod(factors[0]*v.vmax+const, c) if d.op is Ops.MOD else cdiv(factors[0]*v.vmax+const, c)  # type: ignore
+    return (y2-y1)*(v-v.vmin) + y1
+  return None

-
+def fold_divmod_congruence(d: UOp, x: UOp, y: UOp) -> UOp|None:
   # within a mod we can freely subtract multiples of c, we use this to see if a is congruent to an expression whose vmin/vmax are between 0 and c
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+  if (x.vmin<0 and CORRECT_DIVMOD_FOLDING) or ((c := y.arg) < 0) or (x.dtype.count > 1): return None
+  x,const = x.pop_const()
+  terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in split_uop(x, Ops.ADD)])
+  # a//c = (a-a%c)/c, if we can fold a%c, we can fold a//c
+  rems = [min((r:=f%c), r-c, key=abs) for f in factors]
+  if (rem:=sum(r*v for r,v in zip(rems,terms))+const%c).vmin//c==rem.vmax//c and all(f > 0 for f in factors):
+    if d.op is Ops.MOD: return rem - rem.vmin//c*c
+    return sum((f-r)//c * v for f,r,v in zip(factors,rems,terms)) + (const-const%c+rem.vmin//c*c)//c
+  return None
+
+def divide_by_gcd(d: UOp, x: UOp, y: UOp) -> UOp|None:
+  # x//y -> (x//gcd)//(y//gcd) or x%y -> gcd*(x//gcd)%(y//gcd)
+  terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in split_uop(x, Ops.ADD)])
+  if (gcd := math.gcd(y.arg, *factors)) == 1: return None
+  ret = sum(f//gcd * v for f,v in zip(factors, terms)).alu(d.op, y.const_like(y.arg//gcd))
+  return ret*gcd if d.op is Ops.MOD else ret
+
+def nest_div_by_smallest_factor(d: UOp, x: UOp, y: UOp) -> UOp|None:
+  # we try and nest the div and see if it allows the numerator to be simplified
+  if ((c := y.arg) < 0) or (x.dtype.count > 1): return None
+  factors = [u.const_factor() for u in split_uop(x.pop_const()[0], Ops.ADD)]
+  # div is the smallest factor of the denominator (greater than 1) out of all "factors"
+  # TODO: there are better ways to pick `div`, this sometimes adds extra divisions
+  # TODO: add same optimization for mod
+  div = min([y.arg]+[abs(f) for f in factors if abs(f) > 1 and (c%f)==0])
+  if (1 < div < c) and (newxs:=(newx:=(x//div)).simplify()) is not newx and x.vmin>=0 and newx.vmin>=0: return newxs//(c//div)
+  return None
+
+def simplify_remainder(d: UOp, x: UOp, y: UOp) -> UOp|None:
+  # we try and take out the quotient and see if it allows the numerator to be simplified
+  if ((c := y.arg) < 0) or (x.dtype.count > 1): return None
+  x_no_const,const = x.pop_const()
+  terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in split_uop(x_no_const, Ops.ADD)])
+  quotients, remainders = zip(*[divmod(f, c) for f in factors])
+  gcd = math.gcd(c, *remainders)  # gcd without const!
+  if const%c==const and gcd==1 and not any(r==0 or (r!=f and d.op is Ops.MOD) for r,f in zip(remainders, factors)): return None

   quo, rem = x.const_like(const//c), x.const_like((const%c)//gcd)
-  for q,r,f,v in zip(quotients, remainders, factors,
-    if
+  for q,r,f,v in zip(quotients, remainders, factors, terms):
+    if d.op is Ops.IDIV and r!=0:
       rem += f//gcd * v
     else:
       rem += r//gcd * v
       quo += q * v

-  if
+  # if numerator before/after is negative, and it has remainder, don't simplify because C divmod is different from python divmod.
+  if (x.vmin < 0 or rem.vmin < 0) and remainders: return None
+  if d.op is Ops.MOD: return gcd*(rem % (c//gcd)) + const%gcd
   return rem//(c//gcd)+quo

-
+def gep_through_wmma(gep:UOp, wmma:UOp):
+  out_sz = prod(x[1] for x in wmma.arg[6][-1])
+  wmma_idxs = gep.arg[::out_sz]
+  for i in range(out_sz):
+    if tuple(x-i for x in gep.arg[i::out_sz]) != wmma_idxs: return None
+  tsrcs = []
+  for s,sz in zip(wmma.src, wmma.arg[6]):
+    src_args = []
+    ssz = prod(x[1] for x in sz)
+    for w in wmma_idxs: src_args += list(range((w//out_sz)*ssz, (w//out_sz)*ssz + ssz))
+    tsrcs.append(s.gep(tuple(src_args)))
+  return UOp(Ops.WMMA, gep.dtype, tuple(tsrcs), wmma.arg)
+
+gep_pushing = PatternMatcher([
+  # GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
+  (UPat(Ops.GEP, name='g2').f(Ops.GEP, name='g1'),
+   lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(len(g1.arg))))),
+  (UPat(Ops.VECTORIZE, name='vec').f(Ops.GEP, name='gep'),
+   lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
+  (UPat.cvar("c", vec=False).f(Ops.GEP, name="gep"), lambda gep, c: gep.const_like(c.arg)),
+  (UPat(Ops.VCONST, name="c").f(Ops.GEP, name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
+  # GEP on void is skipped
+  (UPat(Ops.GEP, src=(UPat(dtype=dtypes.void, name="x"),)), lambda x: x),
+  # GEP in order is removed
+  (UPat(Ops.GEP, name="g"), lambda g: g.src[0] if not isinstance(g.dtype, PtrDType) and g.arg == tuple(range(g.src[0].dtype.count)) else None),
+  # push all GEPs through ALUs (fix arange stuff)
+  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu').f(Ops.GEP, name='gep'),
+   lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \
+    if not isinstance(gep.dtype, PtrDType) else None),
+  # CAT can't be rendered. it's a VECTORIZE on vectors, we expand to a single VECTORIZEs with GEPs (TODO: move this later)
+  (UPat(Ops.CAT, name="x"), lambda x: UOp(Ops.VECTORIZE, x.dtype, tuple(y.gep(i) for y in x.src for i in range(y.dtype.count))) \
+    if not isinstance(x.dtype, PtrDType) else None),
+  # VECTORIZE on same GEP
+  (UPat(Ops.VECTORIZE, name="v", src=UPat(Ops.GEP, src=(UPat.var("x"),))), lambda v,x: x.gep(tuple(get_single_element(i.arg) for i in v.src))),
+  # push some GEPs through WMMAs
+  (UPat(Ops.WMMA, name="wmma").f(Ops.GEP, name="gep"), gep_through_wmma),
+])
+
+commutative = PatternMatcher([
   # ** COMMUTATIVE flipping (only for ints) **
+  # NOTE: this can break merging vector math by only flipping some of them
   (UPat(GroupOp.Commutative, dtype=dtypes.int, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
+])
+
+symbolic = symbolic_simple+commutative+PatternMatcher([
   # ** boolean algebra **
   (UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x
   # ** combine terms **
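Most of these new div/mod helpers guard on non-negative numerators (bailing on `x.vmin < 0`, or deferring to `CORRECT_DIVMOD_FOLDING`) because C-style truncating divmod only agrees with Python's floor divmod there. Under that restriction the core identities can be sanity-checked with plain integers:

```python
# cancel_divmod: the numerator range [8, 11] spans a single interval of y=4,
# so the quotient is the constant 2 and the mod becomes x - q*y
assert all(x // 4 == 2 and x % 4 == x - 2 * 4 for x in range(8, 12))

# fold_divmod_congruence: factors can be reduced mod c inside a mod,
# e.g. 5 ≡ 1 (mod 4), so (5*v) % 4 == v % 4
assert all((5 * v) % 4 == v % 4 for v in range(16))

# divide_by_gcd: with gcd(6, 4) == 2, (6*v) % 4 == 2 * ((3*v) % 2)
assert all((6 * v) % 4 == 2 * ((3 * v) % 2) for v in range(16))
```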
@@ -187,11 +286,12 @@ symbolic = symbolic_simple+PatternMatcher([
   # a conditional with the same results either way is a noop, also fold const conditionals
   (UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
   (UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
+  (UPat.var("cond", dtype=dtypes.bool).logical_not().where(UPat.var("t"), UPat.var("f")), lambda cond, t, f: cond.where(f,t)),
   # alu of two where with same conds can combine, only do if true branch or false branch is const
   (UPat(GroupOp.Binary, name="alu", src=(UPat.var("c").where(UPat.var("t"), UPat.var("f")), UPat.var("c").where(UPat.var("tt"), UPat.var("ff")))), \
    lambda alu,c,t,tt,f,ff: c.where(t.alu(alu.op, tt), f.alu(alu.op, ff)) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None),
-  # ALU min==max -> CONST (slow!)
-  (UPat(GroupOp.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
+  # ALU/variable min==max -> CONST (slow!)
+  (UPat(GroupOp.ALU|{Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
   # max folding
   (UPat.maximum(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
   # TODO: why does this rule break beautiful_mnist?
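The widened min==max rule is plain interval arithmetic: vmin/vmax bounds propagate through each op, and any node whose interval collapses to a point is a constant; the rule now also covers DEFINE_VAR/SPECIAL/RANGE nodes whose declared range is a single value. For example, for x >= 0, x % 4 has bounds [0, 3], so (x % 4) // 4 has bounds [0, 0] and folds to 0:

```python
assert all((x % 4) // 4 == 0 for x in range(1000))
```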
@@ -209,28 +309,42 @@ symbolic = symbolic_simple+PatternMatcher([
   # c0*x<c1 for negative int c0 and non-positive c1
   ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints))<UPat.cvar("c1", vec=False),
    lambda x,c0,c1: (-x)<(-(math.floor(-c1.arg/-c0.arg))) if c0.arg < 0 and c0.arg != -1 and c1.arg <= 0 else None),
-  # x//
-  ((UPat.var("x", dtype=dtypes.ints)//UPat.cvar("
-   lambda x,
+  # x//d<c
+  ((UPat.var("x", dtype=dtypes.ints)//UPat.cvar("d", vec=False))<UPat.cvar("c", vec=False),
+   lambda x,d,c: (x<(c.arg*d.arg) if c.arg > 0 else x<(c.arg*d.arg-(d.arg-1))) if d.arg > 0 else None),
   # ** move add/mul consts to end (NOTE: this is still happening before constant folding) **
-  (
-  (
+  ((UPat.var("x") + UPat.cvar("c1")) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
+  ((UPat.var("x") * UPat.cvar("c1")) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
   # *** rules from symbolic ***
   # unrolled arange div folding
-  (UPat(
+  ((UPat() + UPat()//UPat.cvar("d", vec=False)).named("divs"), lambda divs,d: fold_unrolled_divs(divs, d.arg)),
+  ((UPat() + (UPat()//UPat.cvar("d", vec=False))*UPat.cvar("c")).named("divs"), lambda divs,d,c: fold_unrolled_divs(divs, d.arg, c.arg)),
   # generic lt folding
   (UPat.var("x", dtypes.sints)<UPat.cvar("c", vec=False), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None),
+  (UPat.var("x", dtypes.sints)*-1 < UPat.var("y", dtypes.sints)*-1, lambda x,y: y<x),
   # canonicalize a simplex with positive coefficients > 0
   # not x < 1 -> X > 0
   ((UPat.var("x", dtypes.ints)<1).ne(True), lambda x: (newx<1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None),
   # ** div **
   # div folding
-  ((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"), lambda x,c,a,d: (x+a*c)//(c*d)
-
+  ((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"), lambda x,c,a,d: (x+a*c)//(c*d)
+   if c.vmin>0 and d.vmin>0 and ((x.vmin>=0 and a.vmin>=0) or (x.vmax<=0 and a.vmax<=0)) else None), # (x//c+a)//d -> (x+a*c)//(c*d)
+  (UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.var("y"))), cancel_divmod),
+  (UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), fold_binary_numerator),
+  (UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), fold_divmod_congruence),
+  (UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), divide_by_gcd),
+  (UPat(Ops.MOD, dtypes.sints, name="m", src=(UPat.var("x"), UPat.cvar("y", vec=False))), remove_nested_mod),
+  (UPat((Ops.IDIV), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), nest_div_by_smallest_factor),
+  (UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), simplify_remainder),
+  (UPat.var("x") // UPat.var("d"), lambda x,d: -(x//(-d)) if d.vmax < 0 else None),
+  (UPat.var("x") // UPat.var("d"), lambda x,d: -((-x)//d) if x.vmax <=0 else None),
+  ((UPat.var("x", dtypes.sints)+UPat.cvar("c", vec=False)).named("n")//UPat.cvar("d", vec=False),
+   lambda x,c,n,d: (-(-(c.arg%d.arg + x - (d.arg-1))//d) + c.arg//d.arg) if x.vmax<=0 and n.vmin>=0 and d.arg>0 else None),
   # ** mod **
   # mod folding
-  (UPat.var("x") % UPat.var("
-
+  (UPat.var("x") % UPat.var("d"), lambda x,d: -((-x)%d) if x.vmax <= 0 else None),
+  (UPat.var("x") % UPat.var("d"), lambda x,d: (x%(-d)) if d.vmax < 0 else None),
+])+gep_pushing

 symbolic_flat = symbolic+PatternMatcher([
   # ** combine terms (opinionated) **
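The sign-flip rewrites at the end of the div/mod block assume truncating (C-style) division, which is what the IDIV/MOD UOps and the `cdiv`/`cmod` helpers implement. A sketch with a local truncating divmod (these two functions are illustrations, assumed to match tinygrad's helpers):

```python
def cdiv(a: int, b: int) -> int:  # truncate toward zero, like C
  return -(-a // b) if (a < 0) != (b < 0) else a // b

def cmod(a: int, b: int) -> int:  # remainder takes the dividend's sign
  return a - cdiv(a, b) * b

for x in range(-20, 21):
  for d in (-7, -3, 3, 7):
    assert cdiv(x, d) == -cdiv(x, -d)   # x//d -> -(x//(-d)) when d < 0
    assert cdiv(x, d) == -cdiv(-x, d)   # x//d -> -((-x)//d) when x <= 0
    assert cmod(x, d) == -cmod(-x, d)   # x%d  -> -((-x)%d)  when x <= 0
    assert cmod(x, d) == cmod(x, -d)    # x%d  ->  x%(-d)    when d < 0
```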
@@ -262,19 +376,25 @@ def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None:
     except ValueError: return uop  # give up if we cannot parse the valid
     bounds[expr][int(is_upper)] = c

+  # don't simplify any other gates, can lead to OOB, we substitute them back later
+  uop = uop.substitute((load_subs:={u: UOp(Ops.NOOP, arg=u) for u in uop.toposort() if u.op is Ops.INDEX}))
+
   # simplify uop given that valid is True
   for expr,v in bounds.items():
+    v0, v1 = (expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1])
     # some expr has lower bound > upper bound -> valid is an empty set and we return None
-    if
-
-
+    if v0 > v1: return None
+    # whole node became a const
+    if v0 == v1:
+      uop = uop.substitute({expr:expr.const_like(v0)}).simplify()
+      continue
+    # every candidate is a set of constrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop
     candidates = []
-    if expr.op is Ops.ADD and
+    if expr.op is Ops.ADD and v0 == 1 and all(u.op in GroupOp.Irreducible for u in split_uop(expr, Ops.ADD)):
       # if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
       candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, Ops.ADD)])
     # try checking the whole clause
-    if expr in uop.toposort:
-      candidates.append([(expr, UOp.variable("fake", expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1], expr.dtype))])
+    if expr in uop.toposort(): candidates.append([(expr, UOp.variable("fake", v0, v1, expr.dtype))])

     for candidate in candidates:
       # if every branch in candidate gives the same simplified uop, we can rewrite the uop
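The candidate check in `uop_given_valid` amounts to: substitute the constrained value range for the gated expression and see whether the result is the same everywhere. A toy analogue of that idea (plain Python, not the UOp machinery):

```python
def simplify_under_bounds(f, lo: int, hi: int):
  # if f is constant over [lo, hi], return that constant, else None
  vals = {f(i) for i in range(lo, hi + 1)}
  return vals.pop() if len(vals) == 1 else None

assert simplify_under_bounds(lambda i: i // 4, 0, 3) == 0     # i//4 folds to 0 given 0 <= i < 4
assert simplify_under_bounds(lambda i: i // 4, 0, 7) is None  # but not given 0 <= i < 8
```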
@@ -284,11 +404,13 @@ def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None:
       if all_same([uops.src[1] for uops in newuops]): uop = uop.replace(src=(uop.src[0], newuops[0].src[1]))
       elif all_same(newuops): uop = newuops[0]

+  # put the loads back in
+  uop = uop.substitute({v:k for k,v in load_subs.items()})
   return uop

 def _valid_priority(v: UOp, valids:list[UOp]):
   # we want valid that's in other valids' parents to be first, so it's more likely the other valids get simplified
-  try: return sum(-1 if parse_valid(v)[0] in other.toposort else 0 for other in valids)
+  try: return sum(-1 if parse_valid(v)[0] in other.toposort() else 0 for other in valids)
   except ValueError: return 0

 def simplify_valid(valid:UOp) -> UOp|None:
@@ -296,100 +418,32 @@ def simplify_valid(valid:UOp) -> UOp|None:
   something_changed = False
   valids = list(split_uop(valid, Ops.AND))
   for stmt in sorted(valids, key=lambda v: _valid_priority(v, valids)):
+    # TODO: root cause this and test_simplify_valid_from_div
+    if stmt.op is Ops.CAST: return None
     ret.append(newstmt if ret and (newstmt:=uop_given_valid(functools.reduce(operator.and_, ret), stmt)) is not None else stmt)
     if ret[-1] is not stmt: something_changed = True
   return functools.reduce(operator.and_, ret) if something_changed else None

-# ***** threefry *****
-
-def threefry2x32(x: UOp, key: UOp):
-  # split x into two uint32, since x in a uint64
-  x0, x1 = (x & 0xffffffff).cast(dtypes.uint32), ((x // 2**32) & 0xffffffff).cast(dtypes.uint32)
-
-  rotations = [[13, 15, 26, 6], [17, 29, 16, 24]]
-  key0, key1 = (key & 0xffffffff).cast(dtypes.uint32), ((key // 2**32) & 0xffffffff).cast(dtypes.uint32)
-  ks = [key1, key0 ^ key1 ^ 0x1BD11BDA, key0]
-  xr = [x0 + ks[-1], x1 + ks[0]]
-  for i in range(5):
-    for r in rotations[i % 2]: xr[0], xr[1] = (x0 := xr[0] + xr[1]), x0 ^ ((xr[1] * 2**r) + (xr[1] // 2**(32 - r)))
-    xr = [(xr[0] + ks[i % 3]), (xr[1] + ks[(i + 1) % 3] + i + 1)]
-
-  return xr[1].cast(dtypes.uint64) * 2**32 | xr[0].cast(dtypes.uint64)
-
 # ******** phase 3 is the complete symbolic, and deals with very complex things like loop rewriting and threefry transform ********

-def
-
-  if
-
-
-
-  if
-
-  if
-
-  if vec is not None:
-    # add, mul, loop_start, loop_end
-    def dvec(x:UOp):
-      if x.op is Ops.CONST: return UOp.const(x.dtype.vec(vec.dtype.count), x.arg)
-      return UOp(Ops.VECTORIZE, x.dtype.vec(vec.dtype.count), src=(x,)*vec.dtype.count)
-    add, mul, loop_start, loop_end = dvec(add), dvec(mul), dvec(loop_start), dvec(loop_end)
-  if mul.vmin > 0 and ne is not None:
-    comprange = UOp.minimum(loop_end, UOp.maximum((add-compval)//mul + (loop_end-loop_start), loop_start))
-  elif mul.vmax < 0 and ne is None:
-    comprange = UOp.minimum(loop_end, UOp.maximum((add-compval-mul)//mul + (loop_end-loop_start), loop_start))
-  else:
-    return None
-  new_reduce_op = comprange.cast(multconst.dtype) * multconst
-  # TODO: what does it mean to have the same numbered DEFINE_ACC with different ranges?
-  new_acc = acc.replace(src=acc.src[0:1]+tuple(x for x in acc.src[1:] if x is not rng))
-  ret = new_acc.assign(new_acc+new_reduce_op)
-  if extra is not None: ret = ret + acc.assign(acc+extra)
-  return ret
-
-def index_collapse(idx:UOp,rng:UOp,buf:UOp,ld:UOp,acc:UOp,add=UOp.const(dtypes.int, 0),mul=UOp.const(dtypes.int, 1)):
-  if rng not in acc.src: return None
-  new_load = UOp.load(buf.index(add+mul*idx, (idx >= rng.src[0]) & (idx < rng.src[1])), dtype=ld.dtype)
-  new_acc = acc.replace(src=acc.src[0:1]+tuple(x for x in acc.src[1:] if x is not rng))
-  return new_acc.assign(new_acc+new_load)
-
-def reduce_collapse(acc:UOp, ret:UOp, alu:UOp):
-  reduce_parented, reduce_unparented = partition(acc.src[1:], lambda x: x in ret.toposort)
-  if len(reduce_unparented) == 0: return None
-  new_acc = acc.replace(src=acc.src[0:1]+tuple(reduce_parented))
-  ret = new_acc.assign(new_acc.alu(alu.op, ret))
-  if alu.op is Ops.ADD:
-    for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
-  return ret
-
-def gep_through_wmma(gep:UOp, wmma:UOp):
-  out_sz = prod(x[1] for x in wmma.arg[6][-1])
-  wmma_idxs = gep.arg[::out_sz]
-  for i in range(out_sz):
-    if tuple(x-i for x in gep.arg[i::out_sz]) != wmma_idxs: return None
-  tsrcs = []
-  for s,sz in zip(wmma.src, wmma.arg[6]):
-    src_args = []
-    ssz = prod(x[1] for x in sz)
-    for w in wmma_idxs: src_args += list(range((w//out_sz)*ssz, (w//out_sz)*ssz + ssz))
-    tsrcs.append(s.gep(tuple(src_args)))
-  return UOp(Ops.WMMA, gep.dtype, tuple(tsrcs), wmma.arg)
-
-acc_pat, rng_pat = UPat(Ops.DEFINE_ACC, name="acc"), UPat(Ops.RANGE, name="rng")
-rng_aug = UPat.any(rng_pat, UPat.var("add")+rng_pat, UPat.var("mul")*rng_pat, UPat.var("add")+UPat.var("mul")*rng_pat)
-
-index_load = UPat.var("buf").index(rng_aug).load(name="ld")
-
-arange_augrng = UPat.any(rng_aug, rng_aug+UPat.var("idx2"), rng_aug+UPat.var("idx2")+UPat.var("idx3"), UPat(Ops.VECTORIZE, name="vec", src=rng_aug))
-arange_m = ((arange_augrng<UPat.cvar("compval"))!=UPat(Ops.CONST, name="ne", arg=True)).where(UPat.cvar("multconst"), UPat.const(None, 0))
-
-# this moves the accumulation variable down an unrolled add chain which allows for more efficient accumulation using mulacc
-mulacc_unrolled = PatternMatcher([(UPat.var("x")+UPat.var("y")+acc_pat, lambda x,y,acc: (acc+x)+y if y.op is not Ops.DEFINE_ACC else None)])
+def reduce_mul_chain(r:UOp):
+  if r.arg not in {Ops.ADD, Ops.MAX}: return None
+  if r.dtype != r.src[0].dtype: return None
+  inside, outside = [], []
+  for m in split_uop(r.src[0], Ops.MUL):
+    m_parents = m.toposort()
+    if all(r not in m_parents for r in r.src[1:]) and (r.arg != Ops.MAX or m.vmin >= 0): outside.append(m)
+    else: inside.append(m)
+  if len(outside) == 0: return None
+  return r.replace(src=(prod(inside) if len(inside) else r.src[0].const_like(1),)+r.src[1:])*prod(outside)

 # this is symbolic 2.0
+REMOVE_FROM_SINK = {Ops.SINK, Ops.UNROLL, Ops.PTRCAT, Ops.CAT, Ops.NOOP}
+REMOVE_FROM_BARRIER = {Ops.VECTORIZE, Ops.SINK, Ops.CAT, Ops.PTRCAT, Ops.NOOP}
 sym = symbolic_flat+PatternMatcher([
-  #
-  (UPat
+  # LOAD/STORE -> NOOP
+  (UPat.var('x').store(UPat.var('x').load(), allow_any_len=True), lambda x: None if x.dtype.addrspace != AddrSpace.REG else x.src[0].src[0]),
+  (UPat(Ops.LOAD, src=(UPat.cvar('c'))), lambda c: c),
   # VECTORIZE/CONST, VECTORIZE/GEP
   (UPat(Ops.VECTORIZE, src=UPat(Ops.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))),
   (UPat(Ops.VECTORIZE, src=UPat(Ops.GEP, src=(UPat.var("x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))),
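The deleted `threefry2x32` built the Threefry-2x32 round function directly out of UOps; the replacement leaves Ops.THREEFRY to be decomposed later, with the phase-1 rules above cleaning up the resulting 64-bit packing. For reference, the removed arithmetic ports to plain Python ints (a sketch: uint32 casts become explicit 32-bit masks, and the `*2**r + //2**(32-r)` pair is a left-rotate):

```python
M32 = 0xFFFFFFFF

def threefry2x32(x: int, key: int) -> int:
  x0, x1 = x & M32, (x >> 32) & M32
  key0, key1 = key & M32, (key >> 32) & M32
  rotations = [[13, 15, 26, 6], [17, 29, 16, 24]]
  ks = [key1, key0 ^ key1 ^ 0x1BD11BDA, key0]
  xr = [(x0 + ks[-1]) & M32, (x1 + ks[0]) & M32]
  for i in range(5):
    for r in rotations[i % 2]:
      new0 = (xr[0] + xr[1]) & M32                                                # add
      xr[0], xr[1] = new0, new0 ^ (((xr[1] << r) & M32) | (xr[1] >> (32 - r)))    # rotate + xor
    xr = [(xr[0] + ks[i % 3]) & M32, (xr[1] + ks[(i + 1) % 3] + i + 1) & M32]     # key injection
  return (xr[1] << 32) | xr[0]
```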
@@ -401,76 +455,45 @@ sym = symbolic_flat+PatternMatcher([
   # VECTORIZE void is SINK
   (UPat(Ops.VECTORIZE, dtype=dtypes.void, src=UPat(Ops.BARRIER, name='b')), lambda b: b),
   (UPat(Ops.VECTORIZE, dtype=dtypes.void, name='x'), lambda x: UOp(Ops.SINK, dtypes.void, x.src)),
-  # GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
-  (UPat(Ops.GEP, src=(UPat(Ops.GEP, name='g2'),), name='g1'),
-   lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(g1.dtype.count)))),
-  (UPat(Ops.GEP, src=(UPat(Ops.VECTORIZE, name="vec"),), name="gep"),
-   lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
-  (UPat(Ops.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
-  (UPat(Ops.GEP, src=(UPat(Ops.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
-  # push all GEPs through ALUs (fix arange stuff)
-  (UPat(Ops.GEP, src=(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'),
-   lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \
-    if not isinstance(gep.dtype, PtrDType) else None),
-  # push some GEPs through WMMAs
-  (UPat(Ops.GEP, src=(UPat(Ops.WMMA, name="wmma"),), name="gep"), gep_through_wmma),
-  # CAT can't be rendered. it's a VECTORIZE on vectors, we expand to a single VECTORIZEs with GEPs (TODO: move this later)
-  (UPat(Ops.CAT, name="x"), lambda x: UOp(Ops.VECTORIZE, x.dtype, tuple(y.gep(i) for y in x.src for i in range(y.dtype.count))) \
-    if not isinstance(x.dtype, PtrDType) else None),
   # tensor core with a 0 input is acc
   (UPat(Ops.WMMA, src=(UPat.const(None, 0.0), UPat.var(), UPat.var("acc"))), lambda acc: acc),
   (UPat(Ops.WMMA, src=(UPat.var(), UPat.const(None, 0.0), UPat.var("acc"))), lambda acc: acc),
-  # tensor core cleanups
-  (UPat.var("add") + UPat(Ops.WMMA, name="wmma"),
-   lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
-  # threefry + remove longs
-  (UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32),
-  (UPat.var('x', dtypes.uint32).cast(dtypes.uint64).cast(dtypes.uint32), lambda x: x), # cast there and back is noop (TODO: genericize)
-  ((UPat.var('x', dtypes.uint64)&0xFFFFFFFF).cast(dtypes.uint32), lambda x: x.cast(dtypes.uint32)), # cast does truncation
-  (((UPat.var(None, dtypes.uint64)*(1<<32)) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
-  (((UPat.var('x', dtypes.uint64)*(1<<32)) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))//(1<<32), lambda x: x),
-  # hacks for threefry long removal when padded (TODO: genericize)
-  (UPat.var('x', dtypes.uint32).cast(dtypes.uint64) * UPat.var('y').where(UPat.const(dtypes.uint64, 1<<32), UPat.const(dtypes.uint64, 0)),
-   lambda x,y: y.where(x, UOp.const(dtypes.uint32, 0)).cast(dtypes.uint64) * (1<<32)),
-  ((UPat.var('x', dtypes.uint64)&(UPat.var('y').where(UPat.const(dtypes.uint64, 0xFFFFFFFF), UPat.const(dtypes.uint64, 0)))).cast(dtypes.uint32),
-   lambda x,y: y.where(x.cast(dtypes.uint32), UOp.const(dtypes.uint32, 0))),
-  # arange loop folding
-  (acc_pat.assign(UPat.any(arange_m, arange_m+UPat.var("extra"))+acc_pat), loop_collapse),
-  # indexing, with cast or where
-  (acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).cast()*index_load+acc_pat), index_collapse),
-  (acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).where(index_load, UPat.const(None, 0.0))+acc_pat), index_collapse),
-  # parentless reduce # TODO: add MUL
-  (acc_pat.assign(UPat((Ops.ADD, Ops.MAX), src=[acc_pat, UPat.var("ret")], name="alu")), reduce_collapse),
   # ** self folding **
-  (UPat(Ops.DEFINE_ACC, src=(UPat.var("x"),)), lambda x: x), # a DEFINE_ACC without ranges is a CONST
-  (UPat(Ops.ASSIGN, src=(UPat.cvar(),UPat.var("x"))), lambda x: x), # an ASSIGN to a const is a NOOP
   # x!=0 -> (bool)x
   (UPat.var("x")!=0, lambda x: x.cast(dtypes.bool.vec(x.dtype.count))),
   # ** where **
   # push cast to branches
   (UPat.var("s").where(UPat.var("a"), UPat.var("b")).cast().named("cast"), lambda s,a,b,cast: s.where(a.cast(cast.dtype), b.cast(cast.dtype))),
+  # a.where(b.where(c, d), d) -> (a & b).where(c, d)
+  (UPat.var("a").where(UPat.var("b").where(UPat.var("c"), UPat.var("d")), UPat.var("d")), lambda a,b,c,d: (a&b).where(c,d)),
   # ** pow **
   ((UPat(Ops.POW, name="p"), lambda p: xpow(*p.src))),
+  # index true is index without op
+  (UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"), UPat.const(dtypes.bool, True))), lambda b, idx: b.index(idx)),
   # ** load/store folding **
   (UPat.store(UPat(Ops.INDEX, name="index"), UPat.load(UPat(Ops.INDEX, name="index"))), lambda index: UOp(Ops.NOOP)),
-  (UPat.store(UPat(Ops.INDEX, name="index"), UPat.var("gate").where(UPat.var("alt"),
-
+  (UPat.store(UPat(Ops.INDEX, name="index"), UPat.var("gate").where(UPat.var("alt"),
+   UPat.load(UPat(Ops.INDEX, name="index"))), allow_any_len=True, name="store"),
+   lambda index, gate, alt, store: UOp.store(index.src[0].index(index.src[1], gate), alt, *store.src[2:])),
   # fold gated LOAD/STORE
   (UPat().index(UPat(), UPat.const(dtypes.bool, True)).named("idx"), lambda idx: idx.replace(src=idx.src[0:2])), # remove True
-  (UPat().index(UPat(), UPat.const(dtypes.bool, False)).
-
-
-
-
-
-  # remove VECTORIZE from SINK/BARRIER
-  (UPat(Ops.BARRIER, src=(UPat((Ops.VECTORIZE, Ops.SINK), name='sink'),)), lambda sink: UOp(Ops.BARRIER, dtypes.void, sink.src)),
+  (UPat((Ops.LOAD, Ops.STORE), src=(UPat().index(UPat(), UPat.const(dtypes.bool, False)).or_casted(),), allow_any_len=True, name="x"),
+   lambda x: UOp(Ops.NOOP) if x.op is Ops.STORE else x.const_like(0)), # NULL pointer store does nothing. NULL pointer load produces 0
+  # remove VECTORIZE from SINK/BARRIER. TODO: SINK/BARRIER are really the same thing at GLOBAL/LOCAL levels
+  (UPat(Ops.BARRIER, name="root"),
+   lambda root: UOp(Ops.BARRIER, root.dtype, tuple(flatten(x.src if x.op in REMOVE_FROM_BARRIER else (x,) for x in root.src)), root.arg)
+   if any(x.op in REMOVE_FROM_BARRIER for x in root.src) else None),
   (UPat(Ops.SINK, name="root"),
-   lambda root: UOp(Ops.SINK, root.dtype, tuple(flatten(x.src if x.op in
-   if any(x.op in
+   lambda root: UOp(Ops.SINK, root.dtype, tuple(flatten(x.src if x.op in REMOVE_FROM_SINK else (x,) for x in root.src)), root.arg)
+   if any(x.op in REMOVE_FROM_SINK for x in root.src) else None),
   ((UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()), # 1/(x^c) -> (1/x)^c
   ((UPat.var("x") * UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()*x.reciprocal()),
+  ((UPat.var("x") * UPat.cvar("c")).reciprocal(), lambda x,c: x.reciprocal()*c.reciprocal()), # 1/(x*c) -> (1/c)*(1/x)
   (UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")), lambda x,d: 1-d), # x*/(1+x) -> 1-1/(1+x)
   (UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")*UPat.var("y")), lambda x,y,d: y*(1-d)),
   (UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")+UPat.var("y")), lambda x,y,d: (1-d)+x*y),
+  # move const multiply after REDUCE (NOTE: the mul chain can do this, but only if it's a same dtype reduce)
+  ((UPat.var("x")*UPat.cvar("c", vec=False)).reduce(arg=Ops.ADD, name="r", allow_any_len=True), lambda x,c,r: r.replace(src=(x,)+r.src[1:])*c.arg),
+  # reduce mul chain, move muls after the reduce
+  (UPat(Ops.MUL).reduce(name="r", allow_any_len=True), reduce_mul_chain),
 ])
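`reduce_mul_chain` and the const-multiply rule both rest on distributivity: factors that do not depend on the reduce ranges can move outside an ADD reduce, and outside a MAX reduce only when they are non-negative (the `m.vmin >= 0` check), since a negative factor flips the ordering:

```python
xs, c = [1.0, 2.0, 3.0, 4.0], 0.5
assert sum(x * c for x in xs) == c * sum(xs)      # ADD: always sound
assert max(x * c for x in xs) == c * max(xs)      # MAX: sound because c >= 0
assert max(x * -1 for x in xs) != -1 * max(xs)    # MAX with a negative factor is not
```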