tinygrad 0.10.0__py3-none-any.whl → 0.10.2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in one of the supported public registries. It is provided for informational purposes only.
Files changed (88)
  1. tinygrad/codegen/devectorizer.py +247 -0
  2. tinygrad/codegen/expander.py +121 -0
  3. tinygrad/codegen/kernel.py +141 -201
  4. tinygrad/codegen/linearize.py +223 -84
  5. tinygrad/codegen/lowerer.py +60 -42
  6. tinygrad/codegen/symbolic.py +476 -0
  7. tinygrad/codegen/transcendental.py +22 -13
  8. tinygrad/device.py +187 -47
  9. tinygrad/dtype.py +39 -28
  10. tinygrad/engine/jit.py +83 -65
  11. tinygrad/engine/memory.py +4 -5
  12. tinygrad/engine/multi.py +161 -0
  13. tinygrad/engine/realize.py +62 -108
  14. tinygrad/engine/schedule.py +396 -357
  15. tinygrad/engine/search.py +55 -66
  16. tinygrad/gradient.py +73 -0
  17. tinygrad/helpers.py +81 -59
  18. tinygrad/nn/__init__.py +30 -32
  19. tinygrad/nn/datasets.py +1 -2
  20. tinygrad/nn/optim.py +22 -26
  21. tinygrad/nn/state.py +91 -66
  22. tinygrad/ops.py +492 -641
  23. tinygrad/renderer/__init__.py +95 -36
  24. tinygrad/renderer/cstyle.py +99 -92
  25. tinygrad/renderer/llvmir.py +83 -34
  26. tinygrad/renderer/ptx.py +83 -99
  27. tinygrad/renderer/wgsl.py +95 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  29. tinygrad/runtime/autogen/comgr.py +2 -0
  30. tinygrad/runtime/autogen/kfd.py +4 -3
  31. tinygrad/runtime/autogen/kgsl.py +1 -1
  32. tinygrad/runtime/autogen/libc.py +404 -71
  33. tinygrad/runtime/autogen/llvm.py +11379 -0
  34. tinygrad/runtime/autogen/pci.py +1333 -0
  35. tinygrad/runtime/autogen/vfio.py +891 -0
  36. tinygrad/runtime/autogen/webgpu.py +6985 -0
  37. tinygrad/runtime/graph/cuda.py +8 -9
  38. tinygrad/runtime/graph/hcq.py +84 -79
  39. tinygrad/runtime/graph/metal.py +40 -43
  40. tinygrad/runtime/ops_amd.py +498 -334
  41. tinygrad/runtime/ops_cloud.py +34 -34
  42. tinygrad/runtime/ops_cpu.py +24 -0
  43. tinygrad/runtime/ops_cuda.py +30 -27
  44. tinygrad/runtime/ops_disk.py +62 -63
  45. tinygrad/runtime/ops_dsp.py +159 -42
  46. tinygrad/runtime/ops_gpu.py +30 -30
  47. tinygrad/runtime/ops_hip.py +29 -31
  48. tinygrad/runtime/ops_llvm.py +48 -41
  49. tinygrad/runtime/ops_metal.py +149 -113
  50. tinygrad/runtime/ops_npy.py +2 -2
  51. tinygrad/runtime/ops_nv.py +238 -273
  52. tinygrad/runtime/ops_python.py +55 -50
  53. tinygrad/runtime/ops_qcom.py +129 -157
  54. tinygrad/runtime/ops_webgpu.py +225 -0
  55. tinygrad/runtime/support/allocator.py +94 -0
  56. tinygrad/runtime/support/am/__init__.py +0 -0
  57. tinygrad/runtime/support/am/amdev.py +396 -0
  58. tinygrad/runtime/support/am/ip.py +463 -0
  59. tinygrad/runtime/support/compiler_cuda.py +4 -2
  60. tinygrad/runtime/support/elf.py +28 -4
  61. tinygrad/runtime/support/hcq.py +256 -324
  62. tinygrad/runtime/support/llvm.py +26 -0
  63. tinygrad/shape/shapetracker.py +85 -53
  64. tinygrad/shape/view.py +104 -140
  65. tinygrad/spec.py +155 -0
  66. tinygrad/tensor.py +835 -527
  67. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
  68. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
  69. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
  70. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
  71. tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
  72. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
  73. tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
  74. tinygrad/viz/index.html +544 -0
  75. tinygrad/viz/perfetto.html +178 -0
  76. tinygrad/viz/serve.py +205 -0
  77. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
  78. tinygrad-0.10.2.dist-info/RECORD +99 -0
  79. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
  80. tinygrad/codegen/uopgraph.py +0 -506
  81. tinygrad/engine/lazy.py +0 -228
  82. tinygrad/function.py +0 -212
  83. tinygrad/multi.py +0 -177
  84. tinygrad/runtime/graph/clang.py +0 -39
  85. tinygrad/runtime/ops_clang.py +0 -35
  86. tinygrad-0.10.0.dist-info/RECORD +0 -77
  87. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
  88. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/shape/view.py CHANGED
@@ -1,111 +1,101 @@
 from __future__ import annotations
-import functools, operator, itertools, math
+import functools, operator, itertools
 from dataclasses import dataclass
-from typing import Tuple, List, Optional, Dict, Set, cast
+from typing import Optional, cast, Sequence
 from tinygrad.dtype import dtypes
-from tinygrad.ops import resolve, UOp, Variable, sint, sym_infer, smax, smin
-from tinygrad.helpers import prod, all_int, argsort, flatten
+from tinygrad.ops import resolve, UOp, Variable, sint, sym_infer, smax, smin, sint_to_uop
+from tinygrad.helpers import prod, all_int, argsort, flatten, ceildiv

 @functools.lru_cache(maxsize=None)
-def canonicalize_strides(shape:Tuple[sint, ...], strides:Tuple[sint, ...]) -> Tuple[sint, ...]:
+def canonicalize_strides(shape:tuple[sint, ...], strides:tuple[sint, ...]) -> tuple[sint, ...]:
   return tuple(0 if s == 1 else st for s, st in zip(shape, strides))

 @functools.lru_cache(maxsize=None)
-def strides_for_shape(shape:Tuple[sint, ...]) -> Tuple[sint, ...]:
+def strides_for_shape(shape:tuple[sint, ...]) -> tuple[sint, ...]:
   if not shape: return ()
   strides = tuple(itertools.accumulate(reversed(shape[1:]), operator.mul, initial=1))[::-1]
   return canonicalize_strides(shape, strides)

 @functools.lru_cache(maxsize=None)
-def _merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tuple[Tuple[int, int], ...]]=None) -> Tuple[Tuple[int, int, int], ...]:
-  # merge contiguous sub-parts or zero strided dims. ret = Tuple[(merged_size, stride, merged size w/o zero stride), ...]
+def merge_dims(shape:tuple[int, ...], strides:tuple[int, ...], mask:Optional[tuple[tuple[int, int], ...]]=None) -> tuple[tuple[int, int, int], ...]:
+  # merge contiguous sub-parts or zero strided dims
+  # any stride 0, masked from dim=1, or contiguous part is merged into next dim.
+  # stride != 0 to stride == 0 starts a new merging block
+  # ret = tuple[(merged_size, stride, merged size w/o zero stride), ...]
   if not shape: return ()
   assert len(shape) == len(strides) and (mask is None or len(shape) == len(mask))
   ret = [(shape[0], strides[0], shape[0] if strides[0] != 0 else 0)]
   # merge this dim to next dim if size is 1
   merging = (mask[0][1] - mask[0][0] == 1) if mask is not None else shape[0] == 1
   for i, (s, st) in enumerate(zip(shape[1:], strides[1:]), start=1):
-    last_s, last_st, last_pre_expand_s = ret[-1]
     # always merge 1
     if s == 1: continue
+    last_s, last_st, last_pre_expand_s = ret[-1]
     # merge last dim with this dim if merging or strides matched
-    if merging or last_st == s * st: ret[-1] = (last_s * s, st, (s if merging else last_pre_expand_s * s) if st != 0 else 0)
-    else: ret.append((s, st, s if st != 0 else 0))
+    if merging or last_st == s * st: ret[-1] = (last_s * s, st, (s if merging else last_pre_expand_s * s))
+    else: ret.append((s, st, s))
     # merge this dim to next dim if size is 1
     merging = (mask[i][1] - mask[i][0] == 1) if mask is not None else s == 1
   return tuple(ret)

 @functools.lru_cache(maxsize=None)
-def _reshape_mask(_mask:Optional[Tuple[Tuple[sint, sint], ...]], old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) \
-    -> Optional[Tuple[Tuple[sint, sint], ...]]:
+def _reshape_mask(_mask:Optional[tuple[tuple[sint, sint], ...]], old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) \
+    -> Optional[tuple[tuple[sint, sint], ...]]:
   """Returns the new mask if reshape is possible, and None if not possible."""
   if _mask is None: return tuple((0, s) for s in new_shape)
-  if any(not isinstance(m[0], int) or not isinstance(m[1], int) for m in _mask): return None
-  if any(m[1] - m[0] < 1 for m in _mask): return ((0, 0),) * len(new_shape) # zero mask
+  if not all_int(flatten(_mask)): return None

-  new_mask: List[Tuple[int, int]] = []
+  new_mask: list[tuple[int, int]] = []
   # _mask is all int here
-  r_masks, r_shape, r_new_shape = reversed(cast(Tuple[Tuple[int, int], ...], _mask)), reversed(old_shape), reversed(new_shape)
+  r_masks, r_shape, r_new_shape = reversed(cast(tuple[tuple[int, int], ...], _mask)), reversed(old_shape), reversed(new_shape)
   curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1))

   while len(new_mask) < len(new_shape):
     (l, r), next_stride = mask, new_dim * curr_stride

-    if old_dim >= next_stride: # need to split mask.
-      if old_dim == next_stride: # simply copy the mask and get next batch for merging
-        new_mask.append((l // curr_stride, (r - 1) // curr_stride + 1))
-        curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1))
-
-      else: # mask can only be splitted if reshape doesn't cut across the mask.
-        if (((l % next_stride != 0 or r % next_stride != 0) and l // next_stride != (r - 1) // next_stride)
-            or old_dim % next_stride != 0): return None
-        new_mask.append((l % next_stride // curr_stride, (r - 1) % next_stride // curr_stride + 1))
-        curr_stride, new_dim = next_stride, next(r_new_shape, 1) # need to get mask for next dimension
-
+    # need to split mask
+    if old_dim == next_stride: # simply copy the mask and get next batch for merging
+      new_mask.append((l // curr_stride, (r - 1) // curr_stride + 1))
+      curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1))
+    elif old_dim > next_stride: # mask can only be splitted if reshape doesn't cut across the mask.
+      if old_dim % next_stride != 0: return None
+      if (l % next_stride != 0 or r % next_stride != 0) and l // next_stride != (r - 1) // next_stride: return None
+      new_mask.append((l % next_stride // curr_stride, (r - 1) % next_stride // curr_stride + 1))
+      curr_stride, new_dim = next_stride, next(r_new_shape, 1) # need to get mask for next dimension
     else:
       next_mask = next(r_masks, (0, 1))
       # combine if the mask can unfold continuously
-      if mask != (0, old_dim) and next_mask[1] - next_mask[0] != 1: return None
+      if mask != (0, old_dim) and l != r and next_mask[1] - next_mask[0] != 1: return None
       mask, old_dim = (next_mask[0] * old_dim + l, (next_mask[1] - 1) * old_dim + r), old_dim * next(r_shape, 1)

-  for mask in r_masks: # if the old shape has leading 1s, need to make sure their mask is (0,1)
-    if mask != (0, 1): return ((0, 0),) * len(new_shape) # invalid mask
-
   return tuple(reversed(new_mask))

-def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]:
-  strides = strides_for_shape(shape)
-  result = []
-  for stride in strides:
-    here = offs // stride if stride != 0 else 0
-    result.append(here)
-    offs -= here * stride
-  return result
-
-def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.int, x) if isinstance(x, int) else x
+def unravel(shape:tuple[sint, ...], offset:sint) -> list[sint]:
+  # find the position of offset on each dimension based on shape
+  # similar to unravel_index in numpy/torch
+  acc, idxs = 1, []
+  for d in reversed(shape):
+    idxs.append((offset//acc)%d)
+    acc *= d
+  return idxs[::-1]

 @dataclass(frozen=True)
 class View:
-  shape:Tuple[sint, ...]
-  strides:Tuple[sint, ...]
+  shape:tuple[sint, ...]
+  strides:tuple[sint, ...]
   offset:sint
-  mask:Optional[Tuple[Tuple[sint, sint], ...]]
+  mask:Optional[tuple[tuple[sint, sint], ...]]
   contiguous:bool

-  @functools.cached_property
-  def t(self):
-    return tuple(x.tuplize if isinstance(x, UOp) else (x,) \
-      for x in self.shape+self.strides+(self.offset,)+(tuple(flatten(self.mask)) if self.mask is not None else tuple()))
-  def __lt__(self, o:View): return self.t < o.t
-
-  def to_indexed_uops(self:View, _idxs:Optional[List[UOp]]=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> Tuple[UOp, UOp]:
-    idxs = [UOp.range(dtypes.int, 0, s, i) for i,s in enumerate(self.shape)] if _idxs is None else _idxs
-    iexpr = variable_to_uop(self.offset)
-    for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else [None]*len(self.shape)):
+  def to_indexed_uops(self:View, idxs:Optional[Sequence[UOp]]=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> tuple[UOp, UOp]:
+    """(idx, valid)"""
+    if idxs is None: idxs = [UOp.range(dtypes.int, 0, s, i) for i,s in enumerate(self.shape)]
+    iexpr = sint_to_uop(self.offset)
+    for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else itertools.repeat(None)):
       if resolve(sh != 1) and resolve(st != 0): iexpr = iexpr + idx*st
       if m is not None:
-        if resolve(m[0] != 0): vexpr = vexpr * idx.ge(m[0])
-        if resolve(m[1] != sh): vexpr = vexpr * idx.lt(m[1])
+        if resolve(m[0] != 0): vexpr = vexpr * (idx >= m[0])
+        if resolve(m[1] != sh): vexpr = vexpr * (idx < m[1])
     return iexpr, vexpr

   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
@@ -116,13 +106,12 @@ class View:

   @staticmethod
   @functools.lru_cache(maxsize=None)
-  def create(shape:Tuple[sint, ...], strides:Optional[Tuple[sint, ...]]=None, offset:sint=0, mask:Optional[Tuple[Tuple[sint, sint], ...]]=None):
-    # TODO: this resolve shouldn't be needed
-    if not all(resolve(s >= 0) for s in shape): raise ValueError(f"Trying to create View with negative dimension: {shape=}")
+  def create(shape:tuple[sint, ...], strides:Optional[tuple[sint, ...]]=None, offset:sint=0, mask:Optional[tuple[tuple[sint, sint], ...]]=None):
+    if not all(s >= 0 for s in shape): raise ValueError(f"Trying to create View with negative dimension: {shape=}")
     strides = canonicalize_strides(shape, strides) if strides else strides_for_shape(shape)
     # canonicalize 0 in shape
     if 0 in shape: return View(shape, (0,) * len(shape), offset=0, mask=None, contiguous=True)
-    # canonicalize empty mask
+    # canonicalize no-op mask
     if mask is not None and all(m == (0,s) for m,s in zip(mask, shape)): mask = None
     # if any dimension has size >1, but is masked such that only one index in the dimension is unmasked
     # then its stride can also be set to 0, albeit with a corresponding adjustment required to the offset
@@ -134,7 +123,7 @@ class View:
     # simplify as we go
     if isinstance(offset, UOp): offset = cast(sint, offset.ssimplify())
     shape = tuple(cast(sint, x.ssimplify()) if isinstance(x, UOp) else x for x in shape)
-    # TODO: enabling stride simplification breaks it
+    # TODO: enabling stride simplification breaks symbolic jit
     """
     strides = tuple(x.ssimplify() if isinstance(x, UOp) else x for x in strides)
     if mask: mask = tuple((s.ssimplify() if isinstance(s, UOp) else s, e.ssimplify() if isinstance(e, UOp) else e) for s,e in mask)
@@ -143,15 +132,15 @@ class View:
     return View(shape, strides, offset, mask, contiguous)

   @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
-  def vars(self) -> Set[Variable]:
+  def vars(self) -> set[Variable]:
     flatten_mask = tuple(x for m in self.mask for x in m) if self.mask is not None else tuple()
     return functools.reduce(operator.or_, [x.vars() for x in self.shape+self.strides+(self.offset,)+flatten_mask if isinstance(x, UOp)], set())

   @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
-  def unbind(self) -> Tuple[View, Dict[Variable, int]]:
+  def unbind(self) -> tuple[View, dict[Variable, int]]:
     var_unboundvar_val = [(v, v.unbind()) for v in self.vars()]
     unbound_vars = {v:uv for v,(uv,_) in var_unboundvar_val}
-    def substitute(x): return x if isinstance(x, int) else x.substitute(unbound_vars)
+    def substitute(x:sint): return x if isinstance(x, int) else x.substitute(unbound_vars)
     new_shape = tuple(map(substitute, self.shape))
     new_strides = tuple(map(substitute, self.strides))
     new_offset = substitute(self.offset)
@@ -165,27 +154,26 @@ class View:
     if vm1.contiguous and vm1.shape == vm2.shape: return vm2
     if vm1.contiguous and vm1.size() == vm2.size() and (ret := vm2.reshape(vm1.shape)) is not None: return ret
     if vm1.mask:
-      for b,e in vm1.mask:
-        if resolve(b >= e, False): return View.create(vm1.shape, (0,) * len(vm1.shape), 0, ((0,0),) * len(vm1.shape))
-      return (merged := vm2 + vm1.shrink(vm1.mask)) and merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape)))
+      if (new_vm1 := vm1.shrink(vm1.mask)) == vm1 or (merged := vm2 + new_vm1) is None: return None
+      return merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape)))
+    if not all_int(vm1.shape): return None

     # Project vm1's offset and strides on to vm2.
-    origin = un1d(vm2.shape, vm1.offset)
-    terms: List[List[Tuple[int, sint]]] = [[] for _ in origin]
-    strides: List[sint] = [0] * len(vm1.shape)
+    origin = unravel(vm2.shape, vm1.offset)
+    terms: list[list[tuple[int, sint]]] = [[] for _ in vm2.shape]
+    strides: list[sint] = [0] * len(vm1.shape)
     for d1, st in enumerate(vm1.strides):
       if st == 0: continue
-      for d2, (o, s1) in enumerate(zip(origin, un1d(vm2.shape, vm1.offset + st))):
+      for d2, (o, s1) in enumerate(zip(origin, unravel(vm2.shape, vm1.offset + st))):
         if (s1 := s1 - o) == 0: continue
         terms[d2].append((d1, s1))
         strides[d1] += s1 * vm2.strides[d2]

     # Merge dimensions in vm2 if required.
     # NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
-    if not all_int(vm1.shape): return None
-    idxs: List[UOp] = [UOp.variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
+    idxs: list[UOp] = [UOp.variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
     merged_size, merged_term = 1, UOp.const(dtypes.int, 0)
-    extents: List[Tuple[sint, UOp]] = []
+    extents: list[tuple[sint, UOp]] = []
     for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)):
       merged_term += (sum([idxs[d1] * s1 for d1, s1 in term]) + o) * merged_size
       merged_size *= s
@@ -194,8 +182,8 @@ class View:
         merged_size, merged_term = 1, UOp.const(dtypes.int, 0)
     if resolve(merged_term != 0): return None
     if (vm2_shape := tuple(s for s,_ in reversed(extents))) != vm2.shape:
-      reshaped_vm2 = vm2.reshape(vm2_shape)
-      if reshaped_vm2 is None: return None
+      if (reshaped_vm2 := vm2.reshape(vm2_shape)) is None: return None
+      # NOTE: this != to prevent infinite loop
       if reshaped_vm2.shape != vm2.shape: return reshaped_vm2 + vm1

     if vm2.mask:
@@ -203,54 +191,45 @@ class View:
       newb, newe, bad = [0] * len(vm1.shape), list(vm1.shape), False
       for (b, e), o, term, (_, t) in zip(vm2.mask, origin, terms, reversed(extents)):
         if resolve(b <= t.vmin and t.vmax < e, False): continue
-        if not all_int([o, b, e]):
-          bad = True
-          continue
         if len(term) != 1:
           if not term and newe: newe[0] = 0
           else: bad = True
           continue
         d1, s1 = term[0]
-        if not isinstance(s1, int) or not isinstance(newe[d1], int):
-          bad = True
-          continue
-        newb[d1] = max(newb[d1], math.ceil((b - o if s1 > 0 else e - o - 1) / s1))
+        newb[d1] = max(newb[d1], ceildiv(b - o if s1 > 0 else e - o - 1, s1))
         newe[d1] = min(newe[d1], (b - o if s1 < 0 else e - o - 1) // s1 + 1)

       # If any of vm1 was masked off, try again with that mask in place.
-      for b, e, s in zip(newb, newe, vm1.shape):
-        if (b, e) != (0, s):
-          return vm2 + View.create(vm1.shape, vm1.strides, vm1.offset, tuple(zip(newb, newe)))
+      if any((b, e) != (0, s) for b, e, s in zip(newb, newe, vm1.shape)):
+        return vm2 + View.create(vm1.shape, vm1.strides, vm1.offset, tuple(zip(newb, newe)))
       # Otherwise if vm2's mask was violated, then cannot merge.
       if bad: return None

     return View.create(vm1.shape, tuple(strides), sum(o * s for o, s in zip(origin, vm2.strides)) + vm2.offset)

   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
-  def invert(self, out_shape:Tuple[sint, ...]) -> Optional[View]:
+  def invert(self, out_shape:tuple[sint, ...]) -> Optional[View]:
     ret = View.create(self.shape)
     if self.mask: ret = ret.shrink(self.mask)
-    ret = ret.stride(tuple(-1 if x < 0 else 1 for x in self.strides)).permute(argsort(tuple(-x if x > 0 else x for x in self.strides)))
+    ret = ret.flip(tuple(x < 0 for x in self.strides)).permute(argsort(tuple(-x if x > 0 else x for x in self.strides)))
     return ret if prod(ret.shape) == prod(out_shape) else None # don't support shrink, expand, or stride != (-1, 1)

   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
   def minify(self):
-    min_shape = tuple(x[0] for x in _merge_dims(self.shape, self.strides, self.mask))
+    min_shape = tuple(x[0] for x in merge_dims(self.shape, self.strides, self.mask))
     return nv if (nv := self.reshape(min_shape)) else self

-  def __unsafe_resize(self, arg: Tuple[Tuple[sint, sint], ...], mask=None) -> View:
+  def __unsafe_resize(self, arg: tuple[tuple[sint, sint], ...], mask=None) -> View:
     offset = sum([s * x[0] for s, x in zip(self.strides,arg)])
     if self.mask:
       # move the old mask
       nmask = tuple([(smax(0, smin(mx-ax,ay-ax)), smax(0, smin(my-ax,ay-ax))) for (mx,my),(ax,ay) in zip(self.mask, arg)])
       # merge the masks if we have two
       mask = tuple([(smax(mx1, mx2), smin(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask
-    shape = [y-x for x,y in arg]
-    if mask is not None and all(m[0] == 0 and m[1] == s for m,s in zip(mask, shape)): mask = None
-    return View.create(tuple(s.ssimplify() if isinstance(s, UOp) else s for s in shape), self.strides, self.offset+offset, mask)
+    return View.create(tuple([y-x for x,y in arg]), self.strides, self.offset+offset, mask)

   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
-  def pad(self, arg: Tuple[Tuple[sint, sint], ...]) -> View:
+  def pad(self, arg: tuple[tuple[sint, sint], ...]) -> View:
     assert len(arg) == len(self.shape), f"invalid pad {arg} for {self.shape}"
     # NOTE: not checking for symbolic arg
     for b,e in arg: assert not all_int([b,e]) or b>=0 and e>=0, f"invalid pad {arg} for {self.shape}"
@@ -261,58 +240,46 @@ class View:
     return self

   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
-  def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> View:
+  def shrink(self, arg: tuple[tuple[sint, sint], ...]) -> View:
     assert len(arg) == len(self.shape), f"invalid shrink {arg} for {self.shape}"
     # NOTE: not checking for symbolic arg
     for s,(b,e) in zip(self.shape,arg): assert not all_int([s,b,e]) or (0<=b<=e<=s), f"invalid shrink {arg} for {self.shape}"
     return self.__unsafe_resize(arg)

   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
-  def expand(self, new_shape: Tuple[sint, ...]) -> View:
+  def expand(self, new_shape: tuple[sint, ...]) -> View:
     if len(new_shape) != len(self.shape): raise ValueError(f"expand arg {new_shape=} must have same number of dimensions as shape {self.shape=}")
-    if 0 in self.shape:
-      assert all((s == x == 0) or (s > 0 and (x % s) == 0) for s,x in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
-      return View.create(new_shape)
-    # TODO: this resolve might be wrong
-    assert all((not resolve(s != x, False) or s == 1) for s,x in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
-    # NOTE: can the mask ever be (0,0)?
+    # NOTE: does not check multiple of symbolic shape
+    assert all(resolve(s == ns) or s == 1 for s,ns in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
+    if 0 in self.shape: return View.create(new_shape)
     # TODO: this resolve may not be needed, but it's hard because vars need to be sorted
     mask = tuple([(((0,0) if m != (0,1) else (0,ns)) if resolve(s != ns, False) else m) \
                   for m,s,ns in zip(self.mask, self.shape, new_shape)]) if self.mask else None
     return View.create(new_shape, self.strides, self.offset, mask)

   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
-  def permute(self, axis: Tuple[int, ...]) -> View:
+  def permute(self, axis: tuple[int, ...]) -> View:
     assert sorted(axis) == list(range(len(self.shape))), f"invalid permutation {axis} of len {len(self.shape)}"
     return View.create(tuple(self.shape[a] for a in axis), tuple(self.strides[a] for a in axis), self.offset,
                        tuple(self.mask[a] for a in axis) if self.mask is not None else None)

   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
-  def stride(self, mul: Tuple[int, ...]) -> View:
-    # except for the negative case, you can build this from the others. invertible in the negative case
-    assert all(isinstance(x, int) and x != 0 for x in mul), f"invalid stride {mul} for {self.shape}"
-    strides = tuple([z*m for z,m in zip(self.strides, mul)])
-    new_shape = tuple([(s+(abs(m)-1))//abs(m) for s,m in zip(self.shape, mul)])
-    offset = sum([(s-1)*z for s,z,m in zip(self.shape, self.strides, mul) if m < 0])
-    mask = tuple([(((mx if m > 0 else s-my)+(abs(m)-1))//abs(m), ((my if m > 0 else s-mx)+(abs(m)-1))//abs(m)) \
-                  for (mx,my),s,m in zip(self.mask, self.shape, mul)]) if self.mask is not None else None
-    return View.create(new_shape, strides, self.offset + offset, mask)
+  def flip(self, arg: tuple[bool, ...]) -> View:
+    offset = sum((s-1)*z for s,z,f in zip(self.shape, self.strides, arg) if f)
+    mask = tuple((s-my,s-mx) if f else (mx,my) for (mx,my),s,f in zip(self.mask, self.shape, arg)) if self.mask is not None else None
+    return View.create(self.shape, tuple(-z if f else z for z,f in zip(self.strides, arg)), self.offset+offset, mask)

   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
-  def reshape(self, new_shape: Tuple[sint, ...]) -> Optional[View]:
+  def reshape(self, new_shape: tuple[sint, ...]) -> Optional[View]:
     if self.shape == new_shape: return self

-    # TODO: this resolve shouldn't be needed
-    assert all(resolve(x >= 0) for x in new_shape), f"shape can't contain negative numbers {new_shape}"
-    if 0 in self.shape:
-      assert 0 in new_shape, f"cannot reshape 0 size to {new_shape}"
-      return View.create(new_shape)
+    if not all(x >= 0 for x in new_shape): raise ValueError(f"shape can't contain negative numbers {new_shape}")
     # check for the same size
     if (self_all_int := all_int(self.shape)):
       assert all(isinstance(s, (int, UOp)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
-      if resolve(prod(self.shape) != prod(new_shape), False):
-        raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")
+      if resolve(prod(self.shape) != prod(new_shape), False): raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")

+    if 0 in self.shape: return View.create(new_shape)
     if new_shape == () and self.mask and any(mx==my for (mx,my) in self.mask): return None

     # after the asserts, it's okay to check contiguous
@@ -322,29 +289,26 @@ class View:
     if self_all_int and not all_int(new_shape):
       if len(self.shape) != len(new_shape): raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
       for si, so in zip(self.shape, new_shape):
-        if isinstance(so, int):
-          if si != so: raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
-        else:
-          var_vals = dict([v.unbind() for v in so.vars()])
-          if si != sym_infer(so, var_vals): raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
+        if not isinstance(so, int): so = sym_infer(so, dict([v.unbind() for v in so.vars()]))
+        if si != so: raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
       # all dimensions matched, return the new view directly
       return View(new_shape, self.strides, self.offset, self.mask, self.contiguous)

-    strides, r_new_shape = [], reversed(new_shape)
-    for merged_dim, new_stride, real_dim in reversed(_merge_dims(self.shape, self.strides, self.mask)):
+    r_strides, r_new_shape = [], reversed(new_shape)
+    for merged_size, new_stride, real_size in reversed(merge_dims(self.shape, self.strides, self.mask)):
+      # TODO: write with get_contraction
       acc = 1
       # TODO: third resolve shouldn't be needed
-      while resolve(acc <= merged_dim) and resolve(acc != merged_dim) and resolve((new_dim := next(r_new_shape, 0)) > 0):
-        strides.append(new_stride)
-        if resolve(new_dim != 1): new_stride *= (new_dim if resolve((acc := acc * new_dim) < real_dim) else 0)
-      if resolve(acc != merged_dim): break
-    else:
-      strides += [0,] * (len(new_shape) - len(strides))
-      new_mask = _reshape_mask(self.mask, self.shape, new_shape)
-      if new_mask is not None:
-        new_strides = canonicalize_strides(tuple(e-b for b,e in new_mask), tuple(reversed(strides)))
-        extra_offset = (sum(m[0] * s for m,s in zip(self.mask, self.strides)) if self.mask else 0) - \
-                       (sum(m[0] * s for m,s in zip(new_mask, new_strides)))
-        return View.create(new_shape, new_strides, self.offset + extra_offset, new_mask)
+      while resolve(acc <= merged_size) and resolve(acc != merged_size) and resolve((new_dim := next(r_new_shape, 0)) > 0):
+        r_strides.append(new_stride * acc)
+        acc = acc * new_dim
+        if not resolve(acc < real_size): new_stride = 0
+      if resolve(acc != merged_size): return None
+    new_strides = (0,) * (len(new_shape) - len(r_strides)) + tuple(r_strides[::-1])
+
+    if (new_mask:=_reshape_mask(self.mask, self.shape, new_shape)) is not None:
+      extra_offset = (sum(m[0] * s for m,s in zip(self.mask, self.strides)) if self.mask else 0) - \
+                     (sum(m[0] * s for m,s in zip(new_mask, new_strides)))
+      return View.create(new_shape, new_strides, self.offset + extra_offset, new_mask)

     return None
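
Note on the new unravel() helper added above: it replaces un1d() and decomposes a flat offset into per-dimension indices, like numpy's unravel_index (as the diff's own comment says). A minimal standalone sketch of the same arithmetic, using plain ints only and independent of tinygrad's sint/UOp types, with a worked example:

def unravel(shape, offset):
  # walk dimensions from the innermost out, peeling off one coordinate at a time
  acc, idxs = 1, []
  for d in reversed(shape):
    idxs.append((offset // acc) % d)
    acc *= d
  return idxs[::-1]

# shape (3, 4, 5) has strides (20, 5, 1), so offset 37 = 1*20 + 3*5 + 2*1 -> [1, 3, 2]
assert unravel((3, 4, 5), 37) == [1, 3, 2]
assert unravel((2, 3), 5) == [1, 2]  # 5 = 1*3 + 2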
tinygrad/spec.py ADDED
@@ -0,0 +1,155 @@
+from typing import cast
+from tinygrad.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops
+from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType
+from tinygrad.helpers import all_same, dedup, prod
+
+buffer_spec = PatternMatcher([
+  (UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True),
+  (UPat(Ops.DEVICE, dtypes.void, (), name="device"), lambda device: isinstance(device.arg, str)),
+  (UPat(Ops.BUFFER, src=(UPat(Ops.DEVICE), UPat(Ops.UNIQUE)), name="buf"),
+   lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
+])
+
+# *** this is the spec of a Tensor in UOp ***
+
+tensor_uop_spec = buffer_spec+PatternMatcher([
+  (UPat(GroupOp.Movement, name="mv", src=(UPat.var("x"),)),
+   # naturally correct
+   lambda mv,x: (isinstance(mv.arg, tuple) and mv.dtype == x.dtype) or
+   # "make things that can't be images not images" can change the buffer dtype
+   # this is fine as long as it's a realized buffer and base dtypes match.
+   ((isinstance(mv.dtype, ImageDType) or isinstance(x.dtype, ImageDType)) and x.dtype.base == mv.dtype.base and x.is_realized)),
+  (UPat(Ops.VIEW, src=(UPat(GroupOp.All-{Ops.CONST, Ops.DEVICE}),)), lambda: False),
+
+  # Tensor variable bindings
+  (UPat(Ops.BIND, dtypes.int, (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=dtypes.int)), arg=None), lambda: True),
+
+  # Tensor const has a device and an unmasked ShapeTracker of stride 0
+  (UPat(Ops.CONST, src=(UPat(Ops.VIEW, name="st", src=(UPat(Ops.DEVICE),)),)),
+   lambda st: st.st.views[0].mask is None and len(st.st.views) == 1 and all(s == 0 for s in st.st.views[0].strides)),
+
+  # DETACH and CONTIGUOUS change how we interpret the source UOp
+  # CONTIGUOUS ensures the source UOp realizes
+  (UPat((Ops.DETACH, Ops.CONTIGUOUS, Ops.CONTIGUOUS_BACKWARD), name="root", src=(UPat.var("x"),), arg=None), lambda root,x: root.dtype == x.dtype),
+
+  # COPY
+  # NOTE: the arg here specifies clone=True, which prevents folding same device copy
+  (UPat(Ops.COPY, name="copy", src=(UPat(Ops.DEVICE), UPat.var("x"))), lambda copy,x: isinstance(copy.arg, bool) and copy.dtype == x.dtype),
+
+  # ASSIGN changes the value of a realized buffer
+  (UPat(Ops.ASSIGN, name="assign", src=(UPat.var("target"), UPat.var("new_val"))),
+   lambda assign,target,new_val: target.is_realized and (assign.dtype == target.dtype == new_val.dtype)),
+])
+
+# ***** uop type spec *****
+
+# this is the matcher for the final rendered UOps
+# matcher functions returns True or False (or None to not match)
+spec = PatternMatcher([
+  (UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and not x.dtype.local),
+  (UPat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.local),
+  (UPat(Ops.DEFINE_ACC, src=(UPat.var("c"),), name="x", allow_any_len=True),
+   lambda x,c: all(y.op is Ops.RANGE for y in x.src[1:]) and c.dtype == x.dtype),
+  (UPat(Ops.DEFINE_VAR, name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
+
+  (UPat(Ops.RANGE, src=(UPat.var("x"), UPat.var("y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype and isinstance(rng.arg, int)),
+  (UPat(Ops.SPECIAL, src=()), lambda: True),
+
+  # TODO: confirm the args of both of these are shapetrackers
+  (UPat(Ops.VIEW, dtypes.void, src=()), lambda: True),
+  (UPat(Ops.VIEW, src=(UPat.var("src"),), name="x"), lambda x,src: src.op is not Ops.STORE and x.dtype.base == src.dtype.base),
+
+  (UPat(Ops.VALID, dtypes.bool, (UPat(Ops.VIEW),)), lambda: True),
+  (UPat(Ops.CONST, name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),
+
+  # early LOAD has a <buf, shapetracker, store?>
+  (UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW))), lambda: True),
+  (UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat(Ops.STORE))), lambda: True),
+
+  # early STORE has a <buf, shapetracker, val>
+  (UPat(Ops.STORE, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat())), lambda: True),
+
+  # **** new style load/store ****
+
+  # INDEX is used in new style load/store
+  (UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat())), lambda: True),
+
+  # LOAD takes a <bufidx, alt?, gate?, barrier?>
+  (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)),)), lambda: True),
+  (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat((Ops.IF, Ops.BARRIER)))), lambda: True),
+  (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat.var("alt"), UPat(dtype=dtypes.bool)), name="ld"), lambda ld,alt: ld.dtype == alt.dtype),
+
+  # STORE takes a <bufidx, val, gate?>
+  (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat())), lambda: True),
+  (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
+  (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(), UPat(Ops.IF))), lambda: True),
+
+  # most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
+  (UPat(Ops.WHERE, name="w", src=(UPat(dtype=dtypes.bool), UPat.var("x"), UPat.var("y"))), lambda w,x,y: w.dtype == x.dtype == y.dtype),
+  (UPat((Ops.CMPLT, Ops.CMPNE), dtype=dtypes.bool, src=(UPat.var("x"), UPat.var("y"))), lambda x,y: x.dtype.base == y.dtype.base),
+  # and SHL/SHR, the shift distance can be an int
+  (UPat((Ops.SHL, Ops.SHR), src=(UPat.var("x"), UPat.var("y")), name="a"), lambda a,x,y: a.dtype == x.dtype and y.dtype in (x.dtype, dtypes.uint)),
+  (UPat(Ops.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
+  (UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype.base == y.dtype.base for y in x.src)),
+
+  (UPat(Ops.ASSIGN, src=(UPat((Ops.DEFINE_ACC, Ops.DEFINE_GLOBAL)), UPat())), lambda: True),
+  (UPat(Ops.ENDRANGE, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True),
+
+  # WMMA has a <a, b, acc>
+  (UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 8),
+  (UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
+  (UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
+
+  # if has a <gate, barrier?>
+  (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(),)), lambda: True),
+  (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(), UPat(Ops.BARRIER))), lambda: True),
+  (UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),
+
+  (UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}),
+  (UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
+  (UPat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.count and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
+  (UPat((Ops.BITCAST, Ops.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None),
+  (UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local
+
+  # NOTE: for testing, we let sinks be anything
+  #(UPat(Ops.SINK, src=UPat(Ops.STORE)), lambda: True),
+  (UPat((Ops.NAME, Ops.SINK), dtypes.void), lambda: True),
+  (UPat((Ops.NOOP, Ops.CUSTOM)), lambda: True),
+
+  # PTX LOAD/STORE
+  (UPat((Ops.LOAD, Ops.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
+])
+
+# *** this is the spec of a Kernel in UOp ***
+
+kernel_spec = buffer_spec+PatternMatcher([
+  (UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.ASSIGN))), lambda: True),
+  # assign has a buffer view and kernel source, it can optionally depend on other assigns
+  (UPat(Ops.ASSIGN, src=UPat((Ops.BUFFER, Ops.VIEW, Ops.KERNEL, Ops.ASSIGN))), lambda: True),
+  # view/sink/const can also exist in the kernel graph
+  (UPat((Ops.VIEW, Ops.SINK, Ops.CONST)), lambda: True),
+  (UPat(GroupOp.All), lambda: False),
+])
+
+# *** this is the UOp shape spec ***
+
+def verify_sink_dims(sink:UOp):
+  shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sink.toposort if x.op is not Ops.SINK and x.st is not None])]
+  return all_same([x.st_arg.size for x in sink.src]) and all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims)
+
+shape_spec = PatternMatcher([
+  # shapes must have either 1 or n in each dimension
+  (UPat(Ops.SINK, src=UPat(Ops.STORE), allow_any_len=True, name="sink"), verify_sink_dims),
+  # all parent UOps must have the same shape
+  (UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: all_same([x.shape for x in root.src if x.st is not None])),
+])
+
+# ***** uop helpers *****
+
+def type_verify(uops:list[UOp], *extra_specs:PatternMatcher):
+  specs = [spec, *extra_specs]
+  for i,u in enumerate(uops):
+    spec_ret = [cast(bool|None, s.rewrite(u)) for s in specs]
+    if any(ret is False for ret in spec_ret) or all(ret is None for ret in spec_ret):
+      print_uops(uops)
+      raise RuntimeError(f"UOp verification failed at {i} on {u.op} {u.dtype} {len(u.src)} {[x.op for x in u.src]} {u.arg}")
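
The type_verify() helper at the end of the new spec.py applies each spec PatternMatcher to every UOp: a rule may return True (accept), False (reject), or None (no opinion), and a UOp fails verification if any spec rejects it or no spec matches it at all. A rough standalone sketch of that accept/reject/abstain scheme, using plain callables and a hypothetical Node class rather than tinygrad's UOp/UPat:

from dataclasses import dataclass

@dataclass
class Node:          # stand-in for a UOp: just an op name and a dtype (illustrative only)
  op: str
  dtype: str

# each rule returns True, False, or None (no opinion), mirroring the matcher contract above
rules = [
  lambda n: True if n.op == "CONST" else None,
  lambda n: (n.dtype == "bool") if n.op == "CMPLT" else None,
]

def verify(nodes):
  for i, n in enumerate(nodes):
    results = [r(n) for r in rules]
    # fail if any rule rejects the node or if no rule matched it at all
    if any(r is False for r in results) or all(r is None for r in results):
      raise RuntimeError(f"verification failed at {i} on {n.op} {n.dtype}")

verify([Node("CONST", "int"), Node("CMPLT", "bool")])  # passes
# verify([Node("CMPLT", "int")]) would raise: CMPLT must produce bool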