tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
Files changed (131)
  1. tinygrad/__init__.py +1 -1
  2. tinygrad/apps/llm.py +206 -0
  3. tinygrad/codegen/__init__.py +116 -0
  4. tinygrad/codegen/devectorizer.py +315 -172
  5. tinygrad/codegen/expander.py +8 -16
  6. tinygrad/codegen/gpudims.py +89 -0
  7. tinygrad/codegen/linearize.py +205 -203
  8. tinygrad/codegen/lowerer.py +92 -139
  9. tinygrad/codegen/opt/__init__.py +38 -0
  10. tinygrad/codegen/opt/heuristic.py +125 -0
  11. tinygrad/codegen/opt/kernel.py +510 -0
  12. tinygrad/{engine → codegen/opt}/search.py +51 -35
  13. tinygrad/codegen/opt/swizzler.py +134 -0
  14. tinygrad/codegen/opt/tc.py +127 -0
  15. tinygrad/codegen/quantize.py +67 -0
  16. tinygrad/device.py +122 -132
  17. tinygrad/dtype.py +152 -35
  18. tinygrad/engine/jit.py +81 -54
  19. tinygrad/engine/memory.py +46 -27
  20. tinygrad/engine/realize.py +82 -41
  21. tinygrad/engine/schedule.py +70 -445
  22. tinygrad/frontend/__init__.py +0 -0
  23. tinygrad/frontend/onnx.py +1253 -0
  24. tinygrad/frontend/torch.py +5 -0
  25. tinygrad/gradient.py +19 -27
  26. tinygrad/helpers.py +95 -47
  27. tinygrad/nn/__init__.py +7 -8
  28. tinygrad/nn/optim.py +72 -41
  29. tinygrad/nn/state.py +37 -23
  30. tinygrad/renderer/__init__.py +40 -60
  31. tinygrad/renderer/cstyle.py +143 -128
  32. tinygrad/renderer/llvmir.py +113 -62
  33. tinygrad/renderer/ptx.py +50 -32
  34. tinygrad/renderer/wgsl.py +27 -23
  35. tinygrad/runtime/autogen/am/am.py +5861 -0
  36. tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
  37. tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
  38. tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
  39. tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
  40. tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
  41. tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
  42. tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
  43. tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
  44. tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
  45. tinygrad/runtime/autogen/comgr.py +35 -9
  46. tinygrad/runtime/autogen/comgr_3.py +906 -0
  47. tinygrad/runtime/autogen/cuda.py +2419 -494
  48. tinygrad/runtime/autogen/hsa.py +57 -16
  49. tinygrad/runtime/autogen/ib.py +7171 -0
  50. tinygrad/runtime/autogen/io_uring.py +917 -118
  51. tinygrad/runtime/autogen/kfd.py +748 -26
  52. tinygrad/runtime/autogen/libc.py +613 -218
  53. tinygrad/runtime/autogen/libusb.py +1643 -0
  54. tinygrad/runtime/autogen/nv/nv.py +8602 -0
  55. tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
  56. tinygrad/runtime/autogen/opencl.py +2 -4
  57. tinygrad/runtime/autogen/sqtt.py +1789 -0
  58. tinygrad/runtime/autogen/vfio.py +3 -3
  59. tinygrad/runtime/autogen/webgpu.py +273 -264
  60. tinygrad/runtime/graph/cuda.py +3 -3
  61. tinygrad/runtime/graph/hcq.py +68 -29
  62. tinygrad/runtime/graph/metal.py +29 -13
  63. tinygrad/runtime/graph/remote.py +114 -0
  64. tinygrad/runtime/ops_amd.py +537 -320
  65. tinygrad/runtime/ops_cpu.py +108 -7
  66. tinygrad/runtime/ops_cuda.py +12 -14
  67. tinygrad/runtime/ops_disk.py +13 -10
  68. tinygrad/runtime/ops_dsp.py +47 -40
  69. tinygrad/runtime/ops_gpu.py +13 -11
  70. tinygrad/runtime/ops_hip.py +6 -9
  71. tinygrad/runtime/ops_llvm.py +35 -15
  72. tinygrad/runtime/ops_metal.py +29 -19
  73. tinygrad/runtime/ops_npy.py +5 -3
  74. tinygrad/runtime/ops_null.py +28 -0
  75. tinygrad/runtime/ops_nv.py +306 -234
  76. tinygrad/runtime/ops_python.py +62 -52
  77. tinygrad/runtime/ops_qcom.py +28 -39
  78. tinygrad/runtime/ops_remote.py +482 -0
  79. tinygrad/runtime/ops_webgpu.py +28 -28
  80. tinygrad/runtime/support/am/amdev.py +114 -249
  81. tinygrad/runtime/support/am/ip.py +211 -172
  82. tinygrad/runtime/support/amd.py +138 -0
  83. tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
  84. tinygrad/runtime/support/compiler_cuda.py +8 -11
  85. tinygrad/runtime/support/elf.py +2 -1
  86. tinygrad/runtime/support/hcq.py +184 -97
  87. tinygrad/runtime/support/ib.py +172 -0
  88. tinygrad/runtime/support/llvm.py +3 -4
  89. tinygrad/runtime/support/memory.py +251 -0
  90. tinygrad/runtime/support/nv/__init__.py +0 -0
  91. tinygrad/runtime/support/nv/ip.py +581 -0
  92. tinygrad/runtime/support/nv/nvdev.py +183 -0
  93. tinygrad/runtime/support/system.py +170 -0
  94. tinygrad/runtime/support/usb.py +268 -0
  95. tinygrad/runtime/support/webgpu.py +18 -0
  96. tinygrad/schedule/__init__.py +0 -0
  97. tinygrad/schedule/grouper.py +119 -0
  98. tinygrad/schedule/kernelize.py +368 -0
  99. tinygrad/schedule/multi.py +231 -0
  100. tinygrad/shape/shapetracker.py +40 -46
  101. tinygrad/shape/view.py +88 -52
  102. tinygrad/tensor.py +968 -542
  103. tinygrad/uop/__init__.py +117 -0
  104. tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
  105. tinygrad/uop/mathtraits.py +169 -0
  106. tinygrad/uop/ops.py +1021 -0
  107. tinygrad/uop/spec.py +228 -0
  108. tinygrad/{codegen → uop}/symbolic.py +239 -216
  109. tinygrad/uop/upat.py +163 -0
  110. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
  111. tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
  112. tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
  113. tinygrad/viz/index.html +203 -403
  114. tinygrad/viz/js/index.js +718 -0
  115. tinygrad/viz/js/worker.js +29 -0
  116. tinygrad/viz/serve.py +224 -102
  117. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
  118. tinygrad-0.11.0.dist-info/RECORD +141 -0
  119. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
  120. tinygrad/codegen/kernel.py +0 -693
  121. tinygrad/engine/multi.py +0 -161
  122. tinygrad/ops.py +0 -1003
  123. tinygrad/runtime/ops_cloud.py +0 -220
  124. tinygrad/runtime/support/allocator.py +0 -94
  125. tinygrad/spec.py +0 -155
  126. tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
  127. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
  128. tinygrad/viz/perfetto.html +0 -178
  129. tinygrad-0.10.2.dist-info/RECORD +0 -99
  130. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
  131. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
tinygrad/uop/spec.py ADDED
@@ -0,0 +1,228 @@
+from typing import cast, Callable
+from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, python_alu, graph_rewrite
+from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace
+from tinygrad.helpers import all_same, prod, DEBUG, ContextVar, Context
+from tinygrad.shape.shapetracker import ShapeTracker
+try:
+  import z3
+
+  # IDIV is truncated division, but z3 does euclidean division (floor if b>0, ceil otherwise); mod by a power of two sometimes uses Ops.AND
+  def z3_cdiv(a, b): return z3.If((a<0), z3.If(0<b, (a+(b-1))/b, (a-(b+1))/b), a/b)
+  z3_alu: dict[Ops, Callable] = python_alu | {Ops.MOD: lambda a,b: a-z3_cdiv(a,b)*b, Ops.IDIV: z3_cdiv, Ops.SHR: lambda a,b: a/(2**b.as_long()),
+    Ops.SHL: lambda a,b: a*(2**b.as_long()), Ops.AND: lambda a,b: a%(b+1) if isinstance(b, z3.ArithRef) else a&b, Ops.WHERE: z3.If,
+    Ops.MAX: lambda a,b: z3.If(a<b, b, a)}
+  def create_bounded(name:str, vmin, vmax, solver:z3.Solver) -> z3.ArithRef:
+    s = z3.Int(name, ctx=solver.ctx)
+    solver.add(vmin <= s, s <= vmax)
+    return s
+
+  # ctx is (solver, load_number_dict)
+  z3_renderer = PatternMatcher([
+    # Ops.SPECIAL can have a symbolic arg, but it won't be in the toposort because it's not a src; we need to add it manually
+    (UPat(Ops.SPECIAL, src=(), name="x"), lambda x: UOp(Ops.SPECIAL, arg=x.arg[0], src=(x.ufix(x.arg[1]),))),
+    (UPat(Ops.SPECIAL, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(x.arg, 0, x.src[0].arg-1, ctx[0]))),
+    (UPat(Ops.DEFINE_VAR, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(x.arg[0], x.arg[1], x.arg[2], ctx[0]))),
+    (UPat(Ops.RANGE, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(f"ridx{x.arg}", 0, x.src[0].arg-1, ctx[0]))),
+    (UPat(Ops.LOAD, dtypes.ints, name="x"),
+     lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(f"load{ctx[1].setdefault(x, len(ctx[1]))}", x.vmin, x.vmax, ctx[0]))),
+    (UPat(Ops.CONST, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(z3.BoolVal if dtypes.is_bool(x.dtype) else z3.IntVal)(x.arg, ctx=ctx[0].ctx))),
+    (UPat(Ops.CAST, dtype=dtypes.ints+(dtypes.bool,), src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
+    (UPat(Ops.CAST, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(f"cast{ctx[1].setdefault(x, len(ctx[1]))}", x.vmin, x.vmax, ctx[0]))),
+    (UPat(Ops.XOR, src=UPat(Ops.NOOP), name="x"),
+     lambda x: UOp(Ops.NOOP, arg=z3.BV2Int(z3_alu[x.op](*(z3.Int2BV(s.arg, x.dtype.itemsize*8) for s in x.src))))),
+    (UPat(GroupOp.ALU, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=z3_alu[x.op](*(s.arg for s in x.src)))),
+  ])
+
+  z3_imported = True
+except (ImportError, AttributeError): z3_imported = False
+
+# if you have z3 installed, by default we check the bounds
+IGNORE_OOB = ContextVar("IGNORE_OOB", int(not z3_imported))
+
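For context on the z3_cdiv shim above: z3's integer division follows the SMT-LIB euclidean convention, while tinygrad's IDIV truncates toward zero, so the two disagree on negative dividends. A minimal standalone sketch of the difference (the values are illustrative; it assumes only `pip install z3-solver`):

    import z3
    a, b = z3.IntVal(-7), z3.IntVal(2)
    print(z3.simplify(a / b))  # -4: z3 floors the quotient when the divisor is positive
    # the same correction z3_cdiv applies, rounding toward zero like Ops.IDIV
    cdiv = z3.If((a < 0), z3.If(0 < b, (a + (b - 1)) / b, (a - (b + 1)) / b), a / b)
    print(z3.simplify(cdiv))   # -3: truncated division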
+buffer_spec = PatternMatcher([
+  (UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True),
+  (UPat(Ops.DEVICE, dtypes.void, (), name="d"), lambda d:
+    isinstance(d.arg, str) or (isinstance(d.arg, tuple) and all(isinstance(s, str) for s in d.arg))),
+  (UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), allow_any_len=True, name="buf"),
+   lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
+  (UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.BUFFER),), name="buf_view"),
+   lambda buf_view: isinstance(buf_view.arg, tuple) and len(buf_view.arg) == 2 and all(isinstance(arg, (int, UOp)) for arg in buf_view.arg)),
+  (UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.MSTACK, src=UPat(Ops.BUFFER)),)), lambda: True),
+  # allow VIEW here. TODO: what views specifically are allowed? does this mess with gradient?
+  (UPat(Ops.VIEW), lambda: True),
+])
+
+assign_spec = PatternMatcher([
+  # KERNEL can attach to an ASSIGN to describe the compute required to realize a BUFFER
+  (UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True),
+
+  # ASSIGN has a target and a value. It can also optionally depend on other assigns
+  (UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])),
+
+  # MSELECT chooses one of the multi buffers
+  (UPat(Ops.MSELECT, name="x"), lambda x: isinstance(x.src[0].device, tuple) and x.arg < len(x.src[0].device)),
+
+  # MSTACK combines buffers into multi
+  (UPat(Ops.MSTACK, name="x"), lambda x: all(isinstance(x.device, str) for x in x.src)),
+])
+
+# *** this is the spec of a Tensor in UOp ***
+
+tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([
+  (UPat(GroupOp.Movement, name="mv", src=(UPat.var("x"),)),
+   # naturally correct
+   lambda mv,x: (isinstance(mv.arg, tuple) and mv.dtype == x.dtype) or
+   # "make things that can't be images not images" can change the buffer dtype
+   # this is fine as long as it's a realized buffer and the base dtypes match.
+   ((isinstance(mv.dtype, ImageDType) or isinstance(x.dtype, ImageDType)) and x.dtype.base == mv.dtype.base and x.base.op is Ops.BUFFER)),
+  (UPat(Ops.VIEW, src=(UPat.var("x"),)), lambda x: x.base.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN, Ops.CONST, Ops.DEVICE}),
+
+  # Tensor variable bindings
+  (UPat(Ops.BIND, dtypes.int, (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=dtypes.int)), arg=None), lambda: True),
+
+  # Tensor const has a device and an unmasked ShapeTracker of stride 0
+  # NOTE: variables in shape can cause multiple views in this ShapeTracker and other issues, see TestSymbolicJit.test_ones_sum
+  (UPat(Ops.CONST, src=(UPat(Ops.VIEW, name="st", src=(UPat(Ops.DEVICE),)),)),
+   lambda st: len(st.st.views) == 1 and all(v.mask is None for v in st.st.views)),
+
+  # DETACH and CONTIGUOUS change how we interpret the source UOp
+  # CONTIGUOUS ensures the source UOp realizes
+  (UPat((Ops.DETACH, Ops.CONTIGUOUS, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE), name="root", src=(UPat.var("x"),), arg=None),
+   lambda root,x: root.dtype == x.dtype),
+
+  # COPY/ALLREDUCE/MULTI
+  (UPat(Ops.COPY, name="copy", src=(UPat.var("x"), UPat(Ops.DEVICE)), arg=None), lambda copy,x: copy.dtype == x.dtype),
+  (UPat(Ops.ALLREDUCE, name="red", src=(UPat.var("x"), UPat(Ops.DEVICE))), lambda red,x: red.dtype == x.dtype and isinstance(red.arg, Ops)),
+  (UPat(Ops.MULTI, name="multi"), lambda multi: all(x.dtype == multi.dtype for x in multi.src) and isinstance(multi.arg, int)),
+])
+
+# ***** uop type spec *****
+
+def validate_index(idx:UOp, gate:UOp=UOp.const(dtypes.bool, True)):
+  if IGNORE_OOB or isinstance(idx.dtype, ImageDType) or (sz := cast(PtrDType, idx.src[0].dtype).size) == -1: return True
+  # We can use UOp min/max to do a faster check, but it can give false positives since it's not an exact bound and doesn't consider the mask
+  if 0<=idx.src[1].vmin and idx.src[1].vmax<sz: return True
+  mask = idx.src[2]&gate if len(idx.src)==3 else gate
+
+  # WEBGPU has a BITCAST in the index. TODO: fix
+  if any(x.op is Ops.BITCAST for x in idx.toposort()): return True
+
+  if not z3_imported: raise ImportError("z3 is required for bounds checking, try IGNORE_OOB=1 or \"pip install z3-solver\"")
+  solver = z3.Solver(ctx=z3.Context())
+  z3_sink = graph_rewrite(idx.src[1].sink(mask), z3_renderer, ctx=(solver, {}))
+  z3_idx = z3_sink.src[0].arg
+  solver.add(z3_sink.src[1].arg)
+  if solver.check((z3_idx<0)|(sz<=z3_idx)) == z3.sat:
+    print(f"idx={idx.src[1].render(simplify=False)}")
+    print(f"mask & gate={mask.render(simplify=False)}")
+    print(f"# OUT OF BOUNDS ACCESS: at {solver.model()} INDEX not in 0 - {sz}\nconstraints = {solver}")
+    return False
+  return True
+
+def validate_store(idx:UOp, val:UOp, gate:UOp=UOp.const(dtypes.bool, True)):
+  if gate.op is Ops.IF: gate = gate.src[0]
+  # we need to find the implicit gates, the inverse of delete_redundant_gates
+  for u in val.toposort():
+    if u.op is Ops.IF: gate &= u.src[0]
+  return validate_index(idx, gate)
+
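Since IGNORE_OOB is a ContextVar, the bounds check can be toggled per scope as well as via the environment. A short sketch, assuming tinygrad 0.11.0 is installed:

    from tinygrad.helpers import Context
    from tinygrad.uop.spec import IGNORE_OOB
    print(IGNORE_OOB.value)      # 0 when z3 imported successfully, 1 otherwise (per the default above)
    with Context(IGNORE_OOB=1):
        pass                     # within this scope validate_index returns True without consulting z3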
+index_pat = UPat(Ops.INDEX, name="idx").or_casted()
+
+# this is the matcher for the final rendered UOps
+# matcher functions return True or False (or None to not match)
+spec = PatternMatcher([
+  (UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and x.dtype.addrspace == AddrSpace.GLOBAL),
+  (UPat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.addrspace == AddrSpace.LOCAL),
+  (UPat(Ops.DEFINE_REG, src=()), lambda: True),
+  (UPat(Ops.DEFINE_VAR, name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
+
+  (UPat(Ops.RANGE, src=(UPat.var("x"),), name="rng"), lambda rng,x: rng.dtype == x.dtype and isinstance(rng.arg, int)),
+  (UPat(Ops.SPECIAL, src=()), lambda: True),
+
+  (UPat(Ops.VIEW, dtypes.void, src=(), name="x"), lambda x: isinstance(x.arg, ShapeTracker)),
+  (UPat(Ops.VIEW, src=(UPat.var("src"),), name="x"),
+   lambda x,src: isinstance(x.arg, ShapeTracker) and src.op is not Ops.STORE and x.dtype.base == src.dtype.base),
+
+  (UPat(Ops.VALID, dtypes.bool, (UPat(Ops.VIEW),)), lambda: True),
+  (UPat(Ops.CONST, name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),
+
+  # early LOAD has a <bufview, store?>
+  (UPat(Ops.LOAD, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.Defines),)),)), lambda: True),
+  (UPat(Ops.LOAD, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.Defines),)), UPat(Ops.STORE))), lambda: True),
+
+  # early STORE has a <bufview, val>
+  (UPat(Ops.STORE, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.Defines),)), UPat())), lambda: True),
+
+  # **** new style load/store ****
+
+  # INDEX is used in new style load/store
+  # INDEX takes a <buf, alu, gate?>
+  (UPat(Ops.INDEX, src=(UPat(GroupOp.Defines), UPat())), lambda: True),
+  (UPat(Ops.INDEX, src=(UPat(GroupOp.Defines), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
+
+  # LOAD on STORE
+  (UPat(Ops.LOAD, src=(UPat(Ops.STORE),), allow_any_len=True), lambda: True),
+
+  # LOAD takes a <bufidx, alt?, barrier?>
+  (UPat(Ops.LOAD, src=(index_pat, UPat(Ops.IF, name="cond")), allow_any_len=True), lambda idx,cond: validate_index(idx,cond.src[0])),
+  (UPat(Ops.LOAD, src=(index_pat,), allow_any_len=True), validate_index),
+
+  # STORE takes a <bufidx, val, gate?>
+  (UPat(Ops.STORE, dtype=dtypes.void, src=(index_pat, UPat(name="val"), UPat(Ops.IF, name="gate")), allow_any_len=True), validate_store),
+  (UPat(Ops.STORE, dtype=dtypes.void, src=(index_pat, UPat(name="val")), allow_any_len=True), validate_store),
+
+  # most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
+  (UPat(Ops.WHERE, name="w", src=(UPat(dtype=dtypes.bool), UPat.var("x"), UPat.var("y"))), lambda w,x,y: w.dtype == x.dtype == y.dtype),
+  (UPat((Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ), dtype=dtypes.bool, src=(UPat.var("x"), UPat.var("y"))), lambda x,y: x.dtype.base == y.dtype.base),
+  # and SHL/SHR, where the shift distance can be an int
+  (UPat((Ops.SHL, Ops.SHR), src=(UPat.var("x"), UPat.var("y")), name="a"), lambda a,x,y: a.dtype == x.dtype and y.dtype in (x.dtype, dtypes.uint)),
+  (UPat((Ops.IDIV, Ops.MOD), name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
+  (UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype.base == y.dtype.base for y in x.src)),
+
+  (UPat(Ops.ENDRANGE, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True),
+
+  # WMMA has a <a, b, acc>
+  (UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 8),
+  (UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
+  (UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
+
+  # IF has a <gate, barrier?>
+  (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(),)), lambda: True),
+  (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(), UPat(Ops.BARRIER))), lambda: True),
+  (UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),
+
+  (UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) >= 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}),
+  (UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
+  (UPat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.count and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
+  (UPat((Ops.BITCAST, Ops.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None),
+  (UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local
+  (UPat(Ops.BARRIER, dtypes.void), lambda: True), # BARRIERs can also happen at the end of loops
+
+  # NOTE: for testing, we let sinks be anything
+  #(UPat(Ops.SINK, src=UPat(Ops.STORE)), lambda: True),
+  (UPat(Ops.SINK, dtypes.void), lambda: True),
+  (UPat((Ops.NOOP, Ops.CUSTOMI, Ops.CUSTOM, Ops.PRECAST)), lambda: True),
+
+  # PTX LOAD/STORE
+  (UPat((Ops.LOAD, Ops.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
+])
+
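As the comment above spec notes, matcher functions return True, False, or None; PatternMatcher.rewrite yields the first non-None result, and type_verify below treats anything other than True as a failure. A small sketch of that contract, assuming tinygrad 0.11.0 is installed (toy_spec is illustrative, not part of the library):

    from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp
    from tinygrad.dtype import dtypes

    # a toy spec: CONSTs pass only when their arg is non-negative
    toy_spec = PatternMatcher([(UPat(Ops.CONST, name="x"), lambda x: x.arg >= 0)])
    print(toy_spec.rewrite(UOp.const(dtypes.int, 3)))   # True  -> verifies
    print(toy_spec.rewrite(UOp.const(dtypes.int, -3)))  # False -> fails verification
    print(toy_spec.rewrite(UOp(Ops.NOOP)))              # None  -> no pattern matched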
+# *** this is the UOp AST spec ***
+
+ast_spec = PatternMatcher([
+  # VIEW can only exist at the edges
+  (UPat(Ops.VIEW, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL),))), lambda: True),
+  (UPat(Ops.VIEW, name="view"), lambda view: len(view.src) == 0),
+  # all parent UOps must have the same shape
+  (UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: all_same([x.shape for x in root.src if x.st is not None])),
+])
+
+# ***** uop helpers *****
+
+def type_verify(uops:list[UOp], extra_spec:PatternMatcher|None=None):
+  check_spec = (extra_spec+spec) if extra_spec is not None else spec
+  for i,u in enumerate(uops):
+    with Context(TRACK_MATCH_STATS=0): ret = check_spec.rewrite(u)
+    if cast(bool|None, ret) is not True:
+      if DEBUG >= 3: print_uops(uops)
+      raise RuntimeError(f"UOp verification failed at {i} on {u.op} {u.dtype} {len(u.src)} {[(x.op, x.dtype, x.arg) for x in u.src]} {u.arg}")
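
Putting it together, a sketch of exercising the rendered-UOp spec end to end, assuming tinygrad 0.11.0 is installed (type_verify raises on the first UOp for which the spec does not return True):

    from tinygrad.uop.ops import UOp, Ops
    from tinygrad.dtype import dtypes
    from tinygrad.uop.spec import type_verify

    # a well-formed ALU graph: both ADD operands share a dtype, so `spec` returns True
    good = UOp.const(dtypes.int, 1) + UOp.const(dtypes.int, 2)
    type_verify(list(good.toposort()))  # passes silently

    # mismatched src dtypes violate the GroupOp.ALU rule above and raise RuntimeError
    bad = UOp(Ops.ADD, dtypes.int, src=(UOp.const(dtypes.int, 1), UOp.const(dtypes.float, 1.0)))
    type_verify(list(bad.toposort()))   # RuntimeError: UOp verification failed ...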