tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +1 -1
- tinygrad/apps/llm.py +206 -0
- tinygrad/codegen/__init__.py +116 -0
- tinygrad/codegen/devectorizer.py +315 -172
- tinygrad/codegen/expander.py +8 -16
- tinygrad/codegen/gpudims.py +89 -0
- tinygrad/codegen/linearize.py +205 -203
- tinygrad/codegen/lowerer.py +92 -139
- tinygrad/codegen/opt/__init__.py +38 -0
- tinygrad/codegen/opt/heuristic.py +125 -0
- tinygrad/codegen/opt/kernel.py +510 -0
- tinygrad/{engine → codegen/opt}/search.py +51 -35
- tinygrad/codegen/opt/swizzler.py +134 -0
- tinygrad/codegen/opt/tc.py +127 -0
- tinygrad/codegen/quantize.py +67 -0
- tinygrad/device.py +122 -132
- tinygrad/dtype.py +152 -35
- tinygrad/engine/jit.py +81 -54
- tinygrad/engine/memory.py +46 -27
- tinygrad/engine/realize.py +82 -41
- tinygrad/engine/schedule.py +70 -445
- tinygrad/frontend/__init__.py +0 -0
- tinygrad/frontend/onnx.py +1253 -0
- tinygrad/frontend/torch.py +5 -0
- tinygrad/gradient.py +19 -27
- tinygrad/helpers.py +95 -47
- tinygrad/nn/__init__.py +7 -8
- tinygrad/nn/optim.py +72 -41
- tinygrad/nn/state.py +37 -23
- tinygrad/renderer/__init__.py +40 -60
- tinygrad/renderer/cstyle.py +143 -128
- tinygrad/renderer/llvmir.py +113 -62
- tinygrad/renderer/ptx.py +50 -32
- tinygrad/renderer/wgsl.py +27 -23
- tinygrad/runtime/autogen/am/am.py +5861 -0
- tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
- tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
- tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
- tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
- tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
- tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
- tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
- tinygrad/runtime/autogen/comgr.py +35 -9
- tinygrad/runtime/autogen/comgr_3.py +906 -0
- tinygrad/runtime/autogen/cuda.py +2419 -494
- tinygrad/runtime/autogen/hsa.py +57 -16
- tinygrad/runtime/autogen/ib.py +7171 -0
- tinygrad/runtime/autogen/io_uring.py +917 -118
- tinygrad/runtime/autogen/kfd.py +748 -26
- tinygrad/runtime/autogen/libc.py +613 -218
- tinygrad/runtime/autogen/libusb.py +1643 -0
- tinygrad/runtime/autogen/nv/nv.py +8602 -0
- tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
- tinygrad/runtime/autogen/opencl.py +2 -4
- tinygrad/runtime/autogen/sqtt.py +1789 -0
- tinygrad/runtime/autogen/vfio.py +3 -3
- tinygrad/runtime/autogen/webgpu.py +273 -264
- tinygrad/runtime/graph/cuda.py +3 -3
- tinygrad/runtime/graph/hcq.py +68 -29
- tinygrad/runtime/graph/metal.py +29 -13
- tinygrad/runtime/graph/remote.py +114 -0
- tinygrad/runtime/ops_amd.py +537 -320
- tinygrad/runtime/ops_cpu.py +108 -7
- tinygrad/runtime/ops_cuda.py +12 -14
- tinygrad/runtime/ops_disk.py +13 -10
- tinygrad/runtime/ops_dsp.py +47 -40
- tinygrad/runtime/ops_gpu.py +13 -11
- tinygrad/runtime/ops_hip.py +6 -9
- tinygrad/runtime/ops_llvm.py +35 -15
- tinygrad/runtime/ops_metal.py +29 -19
- tinygrad/runtime/ops_npy.py +5 -3
- tinygrad/runtime/ops_null.py +28 -0
- tinygrad/runtime/ops_nv.py +306 -234
- tinygrad/runtime/ops_python.py +62 -52
- tinygrad/runtime/ops_qcom.py +28 -39
- tinygrad/runtime/ops_remote.py +482 -0
- tinygrad/runtime/ops_webgpu.py +28 -28
- tinygrad/runtime/support/am/amdev.py +114 -249
- tinygrad/runtime/support/am/ip.py +211 -172
- tinygrad/runtime/support/amd.py +138 -0
- tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
- tinygrad/runtime/support/compiler_cuda.py +8 -11
- tinygrad/runtime/support/elf.py +2 -1
- tinygrad/runtime/support/hcq.py +184 -97
- tinygrad/runtime/support/ib.py +172 -0
- tinygrad/runtime/support/llvm.py +3 -4
- tinygrad/runtime/support/memory.py +251 -0
- tinygrad/runtime/support/nv/__init__.py +0 -0
- tinygrad/runtime/support/nv/ip.py +581 -0
- tinygrad/runtime/support/nv/nvdev.py +183 -0
- tinygrad/runtime/support/system.py +170 -0
- tinygrad/runtime/support/usb.py +268 -0
- tinygrad/runtime/support/webgpu.py +18 -0
- tinygrad/schedule/__init__.py +0 -0
- tinygrad/schedule/grouper.py +119 -0
- tinygrad/schedule/kernelize.py +368 -0
- tinygrad/schedule/multi.py +231 -0
- tinygrad/shape/shapetracker.py +40 -46
- tinygrad/shape/view.py +88 -52
- tinygrad/tensor.py +968 -542
- tinygrad/uop/__init__.py +117 -0
- tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
- tinygrad/uop/mathtraits.py +169 -0
- tinygrad/uop/ops.py +1021 -0
- tinygrad/uop/spec.py +228 -0
- tinygrad/{codegen → uop}/symbolic.py +239 -216
- tinygrad/uop/upat.py +163 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
- tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
- tinygrad/viz/index.html +203 -403
- tinygrad/viz/js/index.js +718 -0
- tinygrad/viz/js/worker.js +29 -0
- tinygrad/viz/serve.py +224 -102
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
- tinygrad-0.11.0.dist-info/RECORD +141 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/kernel.py +0 -693
- tinygrad/engine/multi.py +0 -161
- tinygrad/ops.py +0 -1003
- tinygrad/runtime/ops_cloud.py +0 -220
- tinygrad/runtime/support/allocator.py +0 -94
- tinygrad/spec.py +0 -155
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
- tinygrad/viz/perfetto.html +0 -178
- tinygrad-0.10.2.dist-info/RECORD +0 -99
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
tinygrad/uop/spec.py
ADDED
@@ -0,0 +1,228 @@
|
|
1
|
+
from typing import cast, Callable
|
2
|
+
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, python_alu, graph_rewrite
|
3
|
+
from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace
|
4
|
+
from tinygrad.helpers import all_same, prod, DEBUG, ContextVar, Context
|
5
|
+
from tinygrad.shape.shapetracker import ShapeTracker
|
6
|
+
try:
  import z3

  # Ops.IDIV truncates toward zero, but z3 integer division is Euclidean (floor when b>0, ceil when b<0).
  # z3_cdiv rebuilds truncated division on top of it by shifting a negative dividend before dividing.
  # NOTE: a mod by a power of two is sometimes emitted as Ops.AND, which is why AND below can fall back to %.
  def z3_cdiv(a, b):
    neg_dividend = z3.If(0<b, (a+(b-1))/b, (a-(b+1))/b)
    return z3.If((a<0), neg_dividend, a/b)

  # python_alu, overridden for the ops whose python semantics don't carry over to z3 integer terms
  z3_alu: dict[Ops, Callable] = python_alu | {
    Ops.MOD: lambda a,b: a-z3_cdiv(a,b)*b,    # remainder consistent with truncated z3_cdiv
    Ops.IDIV: z3_cdiv,
    Ops.SHR: lambda a,b: a/(2**b.as_long()),  # as_long() needs a numeral shift amount
    Ops.SHL: lambda a,b: a*(2**b.as_long()),
    # on integer terms AND is treated as "mod by b+1" (assumes b+1 is a power of two); bitvectors use real &
    Ops.AND: lambda a,b: a%(b+1) if isinstance(b, z3.ArithRef) else a&b,
    Ops.WHERE: z3.If,
    Ops.MAX: lambda a,b: z3.If(a<b, b, a),
  }

  def create_bounded(name:str, vmin, vmax, solver:z3.Solver) -> z3.ArithRef:
    """Declare z3 integer `name` in `solver`'s context and constrain it to the closed range [vmin, vmax]."""
    var = z3.Int(name, ctx=solver.ctx)
    solver.add(vmin <= var, var <= vmax)
    return var

  # Rewrites a UOp graph into z3 terms: each supported node becomes an Ops.NOOP whose arg is the z3 expression.
  # ctx is (solver, load_number_dict); the dict hands each LOAD/CAST a stable numbered variable name.
  z3_renderer = PatternMatcher([
    # SPECIAL keeps its (possibly symbolic) size in arg, which the toposort never visits since it isn't a src,
    # so first move the size into src
    (UPat(Ops.SPECIAL, src=(), name="x"), lambda x: UOp(Ops.SPECIAL, arg=x.arg[0], src=(x.ufix(x.arg[1]),))),
    # leaves become bounded z3 integer variables
    (UPat(Ops.SPECIAL, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(x.arg, 0, x.src[0].arg-1, ctx[0]))),
    (UPat(Ops.DEFINE_VAR, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(x.arg[0], x.arg[1], x.arg[2], ctx[0]))),
    (UPat(Ops.RANGE, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(f"ridx{x.arg}", 0, x.src[0].arg-1, ctx[0]))),
    # an integer LOAD is modelled as a fresh variable limited to the LOAD's vmin/vmax
    (UPat(Ops.LOAD, dtypes.ints, name="x"),
     lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(f"load{ctx[1].setdefault(x, len(ctx[1]))}", x.vmin, x.vmax, ctx[0]))),
    (UPat(Ops.CONST, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(z3.BoolVal if dtypes.is_bool(x.dtype) else z3.IntVal)(x.arg, ctx=ctx[0].ctx))),
    # a cast to an int/bool dtype passes the z3 term through; any other cast becomes a fresh bounded variable
    (UPat(Ops.CAST, dtype=dtypes.ints+(dtypes.bool,), src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
    (UPat(Ops.CAST, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(f"cast{ctx[1].setdefault(x, len(ctx[1]))}", x.vmin, x.vmax, ctx[0]))),
    # XOR is evaluated on bitvectors as wide as the dtype, then converted back to an integer term
    (UPat(Ops.XOR, src=UPat(Ops.NOOP), name="x"),
     lambda x: UOp(Ops.NOOP, arg=z3.BV2Int(z3_alu[x.op](*(z3.Int2BV(u.arg, x.dtype.itemsize*8) for u in x.src))))),
    # every other ALU maps through z3_alu once all of its sources are z3 terms
    (UPat(GroupOp.ALU, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=z3_alu[x.op](*(u.arg for u in x.src)))),
  ])
except (ImportError, AttributeError):
  # z3 is missing, or lacks an attribute used above: bounds checking is unavailable
  z3_imported = False
else:
  z3_imported = True
|
38
|
+
|
39
|
+
# Bounds checking (validate_index) needs z3. It defaults to on (0) when z3 imported, and to off (1) otherwise.
IGNORE_OOB = ContextVar("IGNORE_OOB", 0 if z3_imported else 1)
|
41
|
+
|
42
|
+
# Rules for the buffer-related UOps that may appear in a Tensor graph.
buffer_spec = PatternMatcher([
  (UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True),
  # a DEVICE names one device, or holds a tuple of device names
  (UPat(Ops.DEVICE, dtypes.void, (), name="dev"),
   lambda dev: isinstance(dev.arg, str) or (isinstance(dev.arg, tuple) and all(isinstance(n, str) for n in dev.arg))),
  # BUFFER is <UNIQUE, DEVICE, ...> with an int arg and a plain or image dtype
  (UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), allow_any_len=True, name="buf"),
   lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
  # BUFFER_VIEW over a BUFFER carries a 2-tuple of ints/UOps
  (UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.BUFFER),), name="bv"),
   lambda bv: isinstance(bv.arg, tuple) and len(bv.arg) == 2 and all(isinstance(a, (int, UOp)) for a in bv.arg)),
  (UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.MSTACK, src=UPat(Ops.BUFFER)),)), lambda: True),
  # VIEW is accepted unconditionally for now.
  # TODO: which views specifically are allowed? does this interfere with gradient?
  (UPat(Ops.VIEW), lambda: True),
])
|
54
|
+
|
55
|
+
# Rules for the UOps that realize buffers: KERNEL, ASSIGN, and the multi-buffer ops MSELECT/MSTACK.
assign_spec = PatternMatcher([
  # KERNEL can attach to an ASSIGN to describe the compute required to realize a BUFFER
  (UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True),

  # ASSIGN is <target, value, *ASSIGNs it depends on>
  (UPat(Ops.ASSIGN, name="asg"), lambda asg: len(asg.src) >= 2 and all(dep.op is Ops.ASSIGN for dep in asg.src[2:])),

  # MSELECT picks one buffer, by index, out of a multi-device source
  (UPat(Ops.MSELECT, name="sel"), lambda sel: isinstance(sel.src[0].device, tuple) and sel.arg < len(sel.src[0].device)),

  # MSTACK stacks single-device buffers into a multi buffer
  (UPat(Ops.MSTACK, name="stk"), lambda stk: all(isinstance(s.device, str) for s in stk.src)),
])
|
68
|
+
|
69
|
+
# *** this is the spec of a Tensor in UOp ***
|
70
|
+
|
71
|
+
tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([
  # Movement ops normally preserve the dtype. The exception is the "make things that can't be images not images"
  # rewrite: it may change an image dtype on a realized BUFFER, as long as the base dtypes still agree.
  (UPat(GroupOp.Movement, name="mv", src=(UPat.var("x"),)),
   lambda mv,x: (isinstance(mv.arg, tuple) and mv.dtype == x.dtype) or
   (any(isinstance(dt, ImageDType) for dt in (mv.dtype, x.dtype)) and x.dtype.base == mv.dtype.base and x.base.op is Ops.BUFFER)),
  (UPat(Ops.VIEW, src=(UPat.var("x"),)), lambda x: x.base.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN, Ops.CONST, Ops.DEVICE}),

  # variable binding: an int DEFINE_VAR bound to an int constant
  (UPat(Ops.BIND, dtypes.int, (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=dtypes.int)), arg=None), lambda: True),

  # a Tensor CONST sits on a DEVICE behind exactly one VIEW, and that view has no mask
  # NOTE: variables in shape can cause multiple views in this ShapeTracker and other issues, see TestSymbolicJit.test_ones_sum
  (UPat(Ops.CONST, src=(UPat(Ops.VIEW, name="st", src=(UPat(Ops.DEVICE),)),)),
   lambda st: len(st.st.views) == 1 and all(view.mask is None for view in st.st.views)),

  # DETACH and CONTIGUOUS(_BACKWARD)/FUSE only change how the wrapped UOp is interpreted, never its dtype
  # (CONTIGUOUS additionally forces the source to be realized)
  (UPat((Ops.DETACH, Ops.CONTIGUOUS, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE), name="wrap", src=(UPat.var("inner"),), arg=None),
   lambda wrap,inner: wrap.dtype == inner.dtype),

  # COPY / ALLREDUCE / MULTI keep the dtype of what they move or combine
  (UPat(Ops.COPY, name="copy", src=(UPat.var("x"), UPat(Ops.DEVICE)), arg=None), lambda copy,x: copy.dtype == x.dtype),
  (UPat(Ops.ALLREDUCE, name="red", src=(UPat.var("x"), UPat(Ops.DEVICE))), lambda red,x: red.dtype == x.dtype and isinstance(red.arg, Ops)),
  (UPat(Ops.MULTI, name="multi"), lambda multi: all(s.dtype == multi.dtype for s in multi.src) and isinstance(multi.arg, int)),
])
|
98
|
+
|
99
|
+
# ***** uop type spec *****
|
100
|
+
|
101
|
+
def validate_index(idx:UOp, gate:UOp=UOp.const(dtypes.bool, True)):
  """
  Check that the INDEX `idx` cannot address outside its buffer whenever `gate` (and the INDEX's own mask, if any) holds.

  Returns True if the access is proven in bounds or the check is skipped. Returns False, after printing a z3
  counterexample, if some assignment makes the index negative or >= the buffer size.
  Raises ImportError when checking is enabled (IGNORE_OOB=0) but z3 could not be imported.
  """
  # nothing to check: checking disabled, image buffers, or buffers whose size is unknown (-1)
  if IGNORE_OOB or isinstance(idx.dtype, ImageDType) or (sz := cast(PtrDType, idx.src[0].dtype).size) == -1: return True
  # fast path: if the UOp's vmin/vmax already fit, we're done. These bounds are not exact and ignore the mask,
  # so failing this check does not prove an out-of-bounds access; fall through to z3 in that case
  if 0<=idx.src[1].vmin and idx.src[1].vmax<sz: return True
  mask = idx.src[2]&gate if len(idx.src)==3 else gate

  # WEBGPU has a BITCAST in the index. TODO: fix
  if any(x.op is Ops.BITCAST for x in idx.toposort()): return True

  # only reachable with IGNORE_OOB=0, so the actionable hints are IGNORE_OOB=1 (skip checking) or installing z3
  if not z3_imported: raise ImportError("z3 is required for bounds checking, try IGNORE_OOB=1 or \"pip install z3-solver\"")
  # fresh z3 context per check; the index and the mask are rendered to z3 terms in one rewrite
  solver = z3.Solver(ctx=z3.Context())
  z3_sink = graph_rewrite(idx.src[1].sink(mask), z3_renderer, ctx=(solver, {}))
  z3_idx = z3_sink.src[0].arg
  solver.add(z3_sink.src[1].arg)
  # satisfiable means some assignment allowed by the mask puts the index outside [0, sz)
  if solver.check((z3_idx<0)|(sz<=z3_idx)) == z3.sat:
    print(f"idx={idx.src[1].render(simplify=False)}")
    print(f"mask & gate={mask.render(simplify=False)}")
    print(f"# OUT OF BOUNDS ACCESS: at {solver.model()} INDEX not in 0 - {sz}\nconstraints = {solver}")
    return False
  return True
|
121
|
+
|
122
|
+
def validate_store(idx:UOp, val:UOp, gate:UOp=UOp.const(dtypes.bool, True)):
  """
  Bounds-check a STORE to `idx`, taking into account every condition that gates it.

  `gate` may be an Ops.IF, in which case its condition is used. Any IF reachable from the stored value `val`
  gates the store implicitly (the inverse of delete_redundant_gates), so its condition is ANDed in as well.
  """
  cond = gate.src[0] if gate.op is Ops.IF else gate
  for node in val.toposort():
    if node.op is Ops.IF: cond &= node.src[0]
  return validate_index(idx, cond)

# matches an INDEX, bound as `idx`; or_casted presumably also accepts it wrapped in a CAST — confirm in upat/ops
index_pat = UPat(Ops.INDEX, name="idx").or_casted()
|
130
|
+
|
131
|
+
# Spec for the final, linearized UOps that get rendered.
# Each rule returns True (valid) or False (invalid); returning None means "no verdict, keep trying later rules".
spec = PatternMatcher([
  # ** definitions **
  (UPat(Ops.DEFINE_GLOBAL, name="buf"), lambda buf: isinstance(buf.dtype, (PtrDType, ImageDType)) and buf.dtype.addrspace == AddrSpace.GLOBAL),
  (UPat(Ops.DEFINE_LOCAL, name="buf"), lambda buf: isinstance(buf.dtype, PtrDType) and buf.dtype.addrspace == AddrSpace.LOCAL),
  (UPat(Ops.DEFINE_REG, src=()), lambda: True),
  # DEFINE_VAR's arg holds int bounds at positions 1 and 2
  (UPat(Ops.DEFINE_VAR, name="dv"), lambda dv: isinstance(dv.arg[1], int) and isinstance(dv.arg[2], int)),

  # ** indices **
  (UPat(Ops.RANGE, src=(UPat.var("end"),), name="rng"), lambda rng,end: rng.dtype == end.dtype and isinstance(rng.arg, int)),
  (UPat(Ops.SPECIAL, src=()), lambda: True),

  # ** views and constants **
  (UPat(Ops.VIEW, dtypes.void, src=(), name="vw"), lambda vw: isinstance(vw.arg, ShapeTracker)),
  (UPat(Ops.VIEW, src=(UPat.var("src"),), name="vw"),
   lambda vw,src: isinstance(vw.arg, ShapeTracker) and src.op is not Ops.STORE and vw.dtype.base == src.dtype.base),
  (UPat(Ops.VALID, dtypes.bool, (UPat(Ops.VIEW),)), lambda: True),
  # a CONST's arg must have the python type that dtypes.as_const gives for its dtype
  (UPat(Ops.CONST, name="c"), lambda c: type(c.arg) is type(dtypes.as_const(c.arg, c.dtype))),

  # ** early, view-based memory ops **
  # LOAD is <view of a define, store?>
  (UPat(Ops.LOAD, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.Defines),)),)), lambda: True),
  (UPat(Ops.LOAD, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.Defines),)), UPat(Ops.STORE))), lambda: True),
  # STORE is <view of a define, value>
  (UPat(Ops.STORE, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.Defines),)), UPat())), lambda: True),

  # ** new-style, INDEX-based memory ops **
  # INDEX is <define, index expression, bool gate?>
  (UPat(Ops.INDEX, src=(UPat(GroupOp.Defines), UPat())), lambda: True),
  (UPat(Ops.INDEX, src=(UPat(GroupOp.Defines), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
  # a LOAD may read from a STORE
  (UPat(Ops.LOAD, src=(UPat(Ops.STORE),), allow_any_len=True), lambda: True),
  # LOAD is <index, alt?, barrier?>; an IF in the second slot gates the access and feeds the bounds check
  (UPat(Ops.LOAD, src=(index_pat, UPat(Ops.IF, name="cond")), allow_any_len=True), lambda idx,cond: validate_index(idx,cond.src[0])),
  (UPat(Ops.LOAD, src=(index_pat,), allow_any_len=True), validate_index),
  # STORE is <index, value, gate?>, bounds checked through validate_store
  (UPat(Ops.STORE, dtype=dtypes.void, src=(index_pat, UPat(name="val"), UPat(Ops.IF, name="gate")), allow_any_len=True), validate_store),
  (UPat(Ops.STORE, dtype=dtypes.void, src=(index_pat, UPat(name="val")), allow_any_len=True), validate_store),

  # ** ALUs ** — all dtypes match, except for WHERE, the comparisons and the shifts
  (UPat(Ops.WHERE, name="w", src=(UPat(dtype=dtypes.bool), UPat.var("x"), UPat.var("y"))), lambda w,x,y: w.dtype == x.dtype == y.dtype),
  (UPat((Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ), dtype=dtypes.bool, src=(UPat.var("x"), UPat.var("y"))), lambda x,y: x.dtype.base == y.dtype.base),
  # the shift distance may also be a uint
  (UPat((Ops.SHL, Ops.SHR), src=(UPat.var("x"), UPat.var("y")), name="a"), lambda a,x,y: a.dtype == x.dtype and y.dtype in (x.dtype, dtypes.uint)),
  # IDIV/MOD: reject non-int dtypes, otherwise return None so the generic ALU rule below decides
  (UPat((Ops.IDIV, Ops.MOD), name="d"), lambda d: None if dtypes.is_int(d.dtype) else False),
  (UPat(GroupOp.ALU, name="alu"), lambda alu: all(alu.dtype.base == s.dtype.base for s in alu.src)),

  (UPat(Ops.ENDRANGE, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True),

  # ** tensor cores and vector expansion **
  # WMMA is <a, b, acc> with an 8-tuple arg
  (UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), name="wmma"), lambda wmma: isinstance(wmma.arg, tuple) and len(wmma.arg) == 8),
  (UPat(Ops.CONTRACT, name="con"), lambda con: con.dtype.count == prod(ax[1] for ax in con.arg)),
  (UPat(Ops.UNROLL, name="unr"), lambda unr: unr.src[0].dtype.count == prod(ax[1] for ax in unr.arg)),

  # ** control flow ** — IF is <gate, barrier?>
  (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(),)), lambda: True),
  (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(), UPat(Ops.BARRIER))), lambda: True),
  (UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),

  (UPat(Ops.REDUCE_AXIS, name="red"), lambda red: isinstance(red.arg, tuple) and len(red.arg) >= 2 and red.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}),
  (UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
  (UPat(Ops.VECTORIZE, name="vec"),
   lambda vec: len(vec.src)>1 and len(vec.src) == vec.dtype.count and all(vec.dtype == s.dtype.vec(len(vec.src)) for s in vec.src)),
  (UPat((Ops.BITCAST, Ops.CAST), src=(UPat(),), name="cst"), lambda cst: cst.arg is None),
  (UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local
  (UPat(Ops.BARRIER, dtypes.void), lambda: True), # BARRIERs can also happen at the end of loops

  # NOTE: for testing, a SINK may contain anything (a stricter rule would only accept STORE sources)
  (UPat(Ops.SINK, dtypes.void), lambda: True),
  (UPat((Ops.NOOP, Ops.CUSTOMI, Ops.CUSTOM, Ops.PRECAST)), lambda: True),

  # PTX LOAD/STORE: the first source is an int64 (presumably a computed address)
  (UPat((Ops.LOAD, Ops.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
])
|
209
|
+
|
210
|
+
# *** this is the UOp AST spec ***
|
211
|
+
|
212
|
+
ast_spec = PatternMatcher([
  # VIEW may only sit at the edges of the AST: directly on a DEFINE_GLOBAL/DEFINE_LOCAL...
  (UPat(Ops.VIEW, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL),))), lambda: True),
  # ...or as a source-less VIEW
  (UPat(Ops.VIEW, name="vw"), lambda vw: not vw.src),
  # the sources of any non-SINK UOp that carry a ShapeTracker must all have the same shape
  (UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: all_same([s.shape for s in root.src if s.st is not None])),
])
|
219
|
+
|
220
|
+
# ***** uop helpers *****
|
221
|
+
|
222
|
+
def type_verify(uops:list[UOp], extra_spec:PatternMatcher|None=None):
  """
  Verify every UOp in `uops` against `spec` (with `extra_spec`, if given, combined in front as extra_spec+spec).

  Raises RuntimeError on the first UOp that no rule accepts with True. At DEBUG>=3 the full list is printed first.
  """
  check_spec = spec if extra_spec is None else extra_spec+spec
  for i,u in enumerate(uops):
    # run the check with match-statistics tracking turned off
    with Context(TRACK_MATCH_STATS=0): ret = check_spec.rewrite(u)
    if cast(bool|None, ret) is True: continue
    if DEBUG >= 3: print_uops(uops)
    raise RuntimeError(f"UOp verification failed at {i} on {u.op} {u.dtype} {len(u.src)} {[(x.op, x.dtype, x.arg) for x in u.src]} {u.arg}")
|