tinygrad 0.10.0__py3-none-any.whl → 0.10.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/kernel.py +114 -172
- tinygrad/codegen/linearize.py +211 -81
- tinygrad/codegen/lowerer.py +30 -35
- tinygrad/codegen/{uopgraph.py → rewriter.py} +69 -59
- tinygrad/codegen/transcendental.py +12 -13
- tinygrad/device.py +170 -47
- tinygrad/dtype.py +28 -26
- tinygrad/engine/jit.py +80 -63
- tinygrad/engine/memory.py +4 -5
- tinygrad/engine/multi.py +162 -0
- tinygrad/engine/realize.py +58 -107
- tinygrad/engine/schedule.py +381 -314
- tinygrad/engine/search.py +40 -44
- tinygrad/gradient.py +70 -0
- tinygrad/helpers.py +77 -58
- tinygrad/nn/__init__.py +30 -32
- tinygrad/nn/datasets.py +1 -2
- tinygrad/nn/optim.py +22 -26
- tinygrad/nn/state.py +89 -64
- tinygrad/ops.py +562 -446
- tinygrad/renderer/__init__.py +79 -36
- tinygrad/renderer/cstyle.py +70 -84
- tinygrad/renderer/llvmir.py +32 -20
- tinygrad/renderer/ptx.py +79 -99
- tinygrad/renderer/wgsl.py +87 -0
- tinygrad/runtime/autogen/amd_gpu.py +39507 -12
- tinygrad/runtime/autogen/comgr.py +2 -0
- tinygrad/runtime/autogen/kfd.py +4 -3
- tinygrad/runtime/autogen/kgsl.py +1 -1
- tinygrad/runtime/autogen/libpciaccess.py +2023 -0
- tinygrad/runtime/autogen/llvm.py +11379 -0
- tinygrad/runtime/autogen/vfio.py +891 -0
- tinygrad/runtime/graph/cuda.py +8 -9
- tinygrad/runtime/graph/hcq.py +84 -79
- tinygrad/runtime/graph/metal.py +19 -21
- tinygrad/runtime/ops_amd.py +488 -327
- tinygrad/runtime/ops_clang.py +15 -28
- tinygrad/runtime/ops_cloud.py +34 -34
- tinygrad/runtime/ops_cuda.py +30 -27
- tinygrad/runtime/ops_disk.py +62 -63
- tinygrad/runtime/ops_dsp.py +129 -38
- tinygrad/runtime/ops_gpu.py +30 -30
- tinygrad/runtime/ops_hip.py +29 -31
- tinygrad/runtime/ops_llvm.py +45 -40
- tinygrad/runtime/ops_metal.py +93 -73
- tinygrad/runtime/ops_npy.py +2 -2
- tinygrad/runtime/ops_nv.py +232 -270
- tinygrad/runtime/ops_python.py +51 -46
- tinygrad/runtime/ops_qcom.py +129 -157
- tinygrad/runtime/ops_webgpu.py +63 -0
- tinygrad/runtime/support/allocator.py +94 -0
- tinygrad/runtime/support/am/__init__.py +0 -0
- tinygrad/runtime/support/am/amdev.py +384 -0
- tinygrad/runtime/support/am/ip.py +463 -0
- tinygrad/runtime/support/compiler_cuda.py +4 -2
- tinygrad/runtime/support/elf.py +26 -4
- tinygrad/runtime/support/hcq.py +254 -324
- tinygrad/runtime/support/llvm.py +32 -0
- tinygrad/shape/shapetracker.py +84 -53
- tinygrad/shape/view.py +103 -138
- tinygrad/spec.py +154 -0
- tinygrad/tensor.py +744 -496
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/METADATA +32 -21
- tinygrad-0.10.1.dist-info/RECORD +86 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/WHEEL +1 -1
- tinygrad/engine/lazy.py +0 -228
- tinygrad/function.py +0 -212
- tinygrad/multi.py +0 -177
- tinygrad/runtime/graph/clang.py +0 -39
- tinygrad-0.10.0.dist-info/RECORD +0 -77
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/top_level.txt +0 -0
tinygrad/shape/view.py
CHANGED
@@ -1,111 +1,101 @@
 from __future__ import annotations
-import functools, operator, itertools
+import functools, operator, itertools
 from dataclasses import dataclass
-from typing import
+from typing import Optional, cast, Sequence
 from tinygrad.dtype import dtypes
-from tinygrad.ops import resolve, UOp, Variable, sint, sym_infer, smax, smin
-from tinygrad.helpers import prod, all_int, argsort, flatten
+from tinygrad.ops import resolve, UOp, Variable, sint, sym_infer, smax, smin, sint_to_uop
+from tinygrad.helpers import prod, all_int, argsort, flatten, ceildiv

 @functools.lru_cache(maxsize=None)
-def canonicalize_strides(shape:
+def canonicalize_strides(shape:tuple[sint, ...], strides:tuple[sint, ...]) -> tuple[sint, ...]:
   return tuple(0 if s == 1 else st for s, st in zip(shape, strides))

 @functools.lru_cache(maxsize=None)
-def strides_for_shape(shape:
+def strides_for_shape(shape:tuple[sint, ...]) -> tuple[sint, ...]:
   if not shape: return ()
   strides = tuple(itertools.accumulate(reversed(shape[1:]), operator.mul, initial=1))[::-1]
   return canonicalize_strides(shape, strides)

 @functools.lru_cache(maxsize=None)
-def
-  # merge contiguous sub-parts or zero strided dims
+def merge_dims(shape:tuple[int, ...], strides:tuple[int, ...], mask:Optional[tuple[tuple[int, int], ...]]=None) -> tuple[tuple[int, int, int], ...]:
+  # merge contiguous sub-parts or zero strided dims
+  # any stride 0, masked from dim=1, or contiguous part is merged into next dim.
+  # stride != 0 to stride == 0 starts a new merging block
+  # ret = tuple[(merged_size, stride, merged size w/o zero stride), ...]
   if not shape: return ()
   assert len(shape) == len(strides) and (mask is None or len(shape) == len(mask))
   ret = [(shape[0], strides[0], shape[0] if strides[0] != 0 else 0)]
   # merge this dim to next dim if size is 1
   merging = (mask[0][1] - mask[0][0] == 1) if mask is not None else shape[0] == 1
   for i, (s, st) in enumerate(zip(shape[1:], strides[1:]), start=1):
-    last_s, last_st, last_pre_expand_s = ret[-1]
     # always merge 1
     if s == 1: continue
+    last_s, last_st, last_pre_expand_s = ret[-1]
     # merge last dim with this dim if merging or strides matched
-    if merging or last_st == s * st: ret[-1] = (last_s * s, st, (s if merging else last_pre_expand_s * s)
-    else: ret.append((s, st, s
+    if merging or last_st == s * st: ret[-1] = (last_s * s, st, (s if merging else last_pre_expand_s * s))
+    else: ret.append((s, st, s))
     # merge this dim to next dim if size is 1
     merging = (mask[i][1] - mask[i][0] == 1) if mask is not None else s == 1
   return tuple(ret)

 @functools.lru_cache(maxsize=None)
-def _reshape_mask(_mask:Optional[
-    -> Optional[
+def _reshape_mask(_mask:Optional[tuple[tuple[sint, sint], ...]], old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) \
+    -> Optional[tuple[tuple[sint, sint], ...]]:
   """Returns the new mask if reshape is possible, and None if not possible."""
   if _mask is None: return tuple((0, s) for s in new_shape)
-  if
-  if any(m[1] - m[0] < 1 for m in _mask): return ((0, 0),) * len(new_shape) # zero mask
+  if not all_int(flatten(_mask)): return None

-  new_mask:
+  new_mask: list[tuple[int, int]] = []
   # _mask is all int here
-  r_masks, r_shape, r_new_shape = reversed(cast(
+  r_masks, r_shape, r_new_shape = reversed(cast(tuple[tuple[int, int], ...], _mask)), reversed(old_shape), reversed(new_shape)
   curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1))

   while len(new_mask) < len(new_shape):
     (l, r), next_stride = mask, new_dim * curr_stride

-
-
-
-
-
-
-
-
-
-      curr_stride, new_dim = next_stride, next(r_new_shape, 1) # need to get mask for next dimension
-
+    # need to split mask
+    if old_dim == next_stride: # simply copy the mask and get next batch for merging
+      new_mask.append((l // curr_stride, (r - 1) // curr_stride + 1))
+      curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1))
+    elif old_dim > next_stride: # mask can only be splitted if reshape doesn't cut across the mask.
+      if old_dim % next_stride != 0: return None
+      if (l % next_stride != 0 or r % next_stride != 0) and l // next_stride != (r - 1) // next_stride: return None
+      new_mask.append((l % next_stride // curr_stride, (r - 1) % next_stride // curr_stride + 1))
+      curr_stride, new_dim = next_stride, next(r_new_shape, 1) # need to get mask for next dimension
     else:
       next_mask = next(r_masks, (0, 1))
       # combine if the mask can unfold continuously
-      if mask != (0, old_dim) and next_mask[1] - next_mask[0] != 1: return None
+      if mask != (0, old_dim) and l != r and next_mask[1] - next_mask[0] != 1: return None
       mask, old_dim = (next_mask[0] * old_dim + l, (next_mask[1] - 1) * old_dim + r), old_dim * next(r_shape, 1)

-  for mask in r_masks: # if the old shape has leading 1s, need to make sure their mask is (0,1)
-    if mask != (0, 1): return ((0, 0),) * len(new_shape) # invalid mask
-
   return tuple(reversed(new_mask))

-def
-
-
-
-
-
-
-  return
-
-def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.int, x) if isinstance(x, int) else x
+def unravel(shape:tuple[sint, ...], offset:sint) -> list[sint]:
+  # find the position of offset on each dimension based on shape
+  # similar to unravel_index in numpy/torch
+  acc, idxs = 1, []
+  for d in reversed(shape):
+    idxs.append((offset//acc)%d)
+    acc *= d
+  return idxs[::-1]

 @dataclass(frozen=True)
 class View:
-  shape:
-  strides:
+  shape:tuple[sint, ...]
+  strides:tuple[sint, ...]
   offset:sint
-  mask:Optional[
+  mask:Optional[tuple[tuple[sint, sint], ...]]
   contiguous:bool

-
-
-
-
-
-
-  def to_indexed_uops(self:View, _idxs:Optional[List[UOp]]=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> Tuple[UOp, UOp]:
-    idxs = [UOp.range(dtypes.int, 0, s, i) for i,s in enumerate(self.shape)] if _idxs is None else _idxs
-    iexpr = variable_to_uop(self.offset)
-    for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else [None]*len(self.shape)):
+  def to_indexed_uops(self:View, idxs:Optional[Sequence[UOp]]=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> tuple[UOp, UOp]:
+    """(idx, valid)"""
+    if idxs is None: idxs = [UOp.range(dtypes.int, 0, s, i) for i,s in enumerate(self.shape)]
+    iexpr = sint_to_uop(self.offset)
+    for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else itertools.repeat(None)):
       if resolve(sh != 1) and resolve(st != 0): iexpr = iexpr + idx*st
       if m is not None:
-        if resolve(m[0] != 0): vexpr = vexpr * idx
-        if resolve(m[1] != sh): vexpr = vexpr * idx
+        if resolve(m[0] != 0): vexpr = vexpr * (idx >= m[0])
+        if resolve(m[1] != sh): vexpr = vexpr * (idx < m[1])
     return iexpr, vexpr

   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
@@ -116,13 +106,13 @@ class View:

   @staticmethod
   @functools.lru_cache(maxsize=None)
-  def create(shape:
+  def create(shape:tuple[sint, ...], strides:Optional[tuple[sint, ...]]=None, offset:sint=0, mask:Optional[tuple[tuple[sint, sint], ...]]=None):
     # TODO: this resolve shouldn't be needed
     if not all(resolve(s >= 0) for s in shape): raise ValueError(f"Trying to create View with negative dimension: {shape=}")
     strides = canonicalize_strides(shape, strides) if strides else strides_for_shape(shape)
     # canonicalize 0 in shape
     if 0 in shape: return View(shape, (0,) * len(shape), offset=0, mask=None, contiguous=True)
-    # canonicalize
+    # canonicalize no-op mask
     if mask is not None and all(m == (0,s) for m,s in zip(mask, shape)): mask = None
     # if any dimension has size >1, but is masked such that only one index in the dimension is unmasked
     # then its stride can also be set to 0, albeit with a corresponding adjustment required to the offset
@@ -134,7 +124,7 @@ class View:
     # simplify as we go
     if isinstance(offset, UOp): offset = cast(sint, offset.ssimplify())
     shape = tuple(cast(sint, x.ssimplify()) if isinstance(x, UOp) else x for x in shape)
-    # TODO: enabling stride simplification breaks
+    # TODO: enabling stride simplification breaks symbolic jit
     """
     strides = tuple(x.ssimplify() if isinstance(x, UOp) else x for x in strides)
     if mask: mask = tuple((s.ssimplify() if isinstance(s, UOp) else s, e.ssimplify() if isinstance(e, UOp) else e) for s,e in mask)
@@ -143,15 +133,15 @@ class View:
     return View(shape, strides, offset, mask, contiguous)

   @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
-  def vars(self) ->
+  def vars(self) -> set[Variable]:
     flatten_mask = tuple(x for m in self.mask for x in m) if self.mask is not None else tuple()
     return functools.reduce(operator.or_, [x.vars() for x in self.shape+self.strides+(self.offset,)+flatten_mask if isinstance(x, UOp)], set())

   @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
-  def unbind(self) ->
+  def unbind(self) -> tuple[View, dict[Variable, int]]:
     var_unboundvar_val = [(v, v.unbind()) for v in self.vars()]
     unbound_vars = {v:uv for v,(uv,_) in var_unboundvar_val}
-    def substitute(x): return x if isinstance(x, int) else x.substitute(unbound_vars)
+    def substitute(x:sint): return x if isinstance(x, int) else x.substitute(unbound_vars)
     new_shape = tuple(map(substitute, self.shape))
     new_strides = tuple(map(substitute, self.strides))
     new_offset = substitute(self.offset)
@@ -165,27 +155,26 @@ class View:
     if vm1.contiguous and vm1.shape == vm2.shape: return vm2
     if vm1.contiguous and vm1.size() == vm2.size() and (ret := vm2.reshape(vm1.shape)) is not None: return ret
     if vm1.mask:
-
-
-
+      if (new_vm1 := vm1.shrink(vm1.mask)) == vm1 or (merged := vm2 + new_vm1) is None: return None
+      return merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape)))
+    if not all_int(vm1.shape): return None

     # Project vm1's offset and strides on to vm2.
-    origin =
-    terms:
-    strides:
+    origin = unravel(vm2.shape, vm1.offset)
+    terms: list[list[tuple[int, sint]]] = [[] for _ in vm2.shape]
+    strides: list[sint] = [0] * len(vm1.shape)
     for d1, st in enumerate(vm1.strides):
       if st == 0: continue
-      for d2, (o, s1) in enumerate(zip(origin,
+      for d2, (o, s1) in enumerate(zip(origin, unravel(vm2.shape, vm1.offset + st))):
         if (s1 := s1 - o) == 0: continue
         terms[d2].append((d1, s1))
         strides[d1] += s1 * vm2.strides[d2]

     # Merge dimensions in vm2 if required.
     # NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
-
-    idxs: List[UOp] = [UOp.variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
+    idxs: list[UOp] = [UOp.variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
     merged_size, merged_term = 1, UOp.const(dtypes.int, 0)
-    extents:
+    extents: list[tuple[sint, UOp]] = []
     for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)):
       merged_term += (sum([idxs[d1] * s1 for d1, s1 in term]) + o) * merged_size
       merged_size *= s
@@ -194,8 +183,8 @@ class View:
         merged_size, merged_term = 1, UOp.const(dtypes.int, 0)
     if resolve(merged_term != 0): return None
     if (vm2_shape := tuple(s for s,_ in reversed(extents))) != vm2.shape:
-      reshaped_vm2
-
+      if (reshaped_vm2 := vm2.reshape(vm2_shape)) is None: return None
+      # NOTE: this != to prevent infinite loop
       if reshaped_vm2.shape != vm2.shape: return reshaped_vm2 + vm1

     if vm2.mask:
@@ -203,54 +192,45 @@ class View:
       newb, newe, bad = [0] * len(vm1.shape), list(vm1.shape), False
       for (b, e), o, term, (_, t) in zip(vm2.mask, origin, terms, reversed(extents)):
         if resolve(b <= t.vmin and t.vmax < e, False): continue
-        if not all_int([o, b, e]):
-          bad = True
-          continue
         if len(term) != 1:
           if not term and newe: newe[0] = 0
           else: bad = True
           continue
         d1, s1 = term[0]
-
-          bad = True
-          continue
-        newb[d1] = max(newb[d1], math.ceil((b - o if s1 > 0 else e - o - 1) / s1))
+        newb[d1] = max(newb[d1], ceildiv(b - o if s1 > 0 else e - o - 1, s1))
         newe[d1] = min(newe[d1], (b - o if s1 < 0 else e - o - 1) // s1 + 1)

       # If any of vm1 was masked off, try again with that mask in place.
-      for b, e, s in zip(newb, newe, vm1.shape):
-
-          return vm2 + View.create(vm1.shape, vm1.strides, vm1.offset, tuple(zip(newb, newe)))
+      if any((b, e) != (0, s) for b, e, s in zip(newb, newe, vm1.shape)):
+        return vm2 + View.create(vm1.shape, vm1.strides, vm1.offset, tuple(zip(newb, newe)))
       # Otherwise if vm2's mask was violated, then cannot merge.
       if bad: return None

     return View.create(vm1.shape, tuple(strides), sum(o * s for o, s in zip(origin, vm2.strides)) + vm2.offset)

   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
-  def invert(self, out_shape:
+  def invert(self, out_shape:tuple[sint, ...]) -> Optional[View]:
     ret = View.create(self.shape)
     if self.mask: ret = ret.shrink(self.mask)
-    ret = ret.
+    ret = ret.flip(tuple(x < 0 for x in self.strides)).permute(argsort(tuple(-x if x > 0 else x for x in self.strides)))
     return ret if prod(ret.shape) == prod(out_shape) else None # don't support shrink, expand, or stride != (-1, 1)

   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
   def minify(self):
-    min_shape = tuple(x[0] for x in
+    min_shape = tuple(x[0] for x in merge_dims(self.shape, self.strides, self.mask))
     return nv if (nv := self.reshape(min_shape)) else self

-  def __unsafe_resize(self, arg:
+  def __unsafe_resize(self, arg: tuple[tuple[sint, sint], ...], mask=None) -> View:
     offset = sum([s * x[0] for s, x in zip(self.strides,arg)])
     if self.mask:
       # move the old mask
       nmask = tuple([(smax(0, smin(mx-ax,ay-ax)), smax(0, smin(my-ax,ay-ax))) for (mx,my),(ax,ay) in zip(self.mask, arg)])
       # merge the masks if we have two
       mask = tuple([(smax(mx1, mx2), smin(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask
-
-    if mask is not None and all(m[0] == 0 and m[1] == s for m,s in zip(mask, shape)): mask = None
-    return View.create(tuple(s.ssimplify() if isinstance(s, UOp) else s for s in shape), self.strides, self.offset+offset, mask)
+    return View.create(tuple([y-x for x,y in arg]), self.strides, self.offset+offset, mask)

   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
-  def pad(self, arg:
+  def pad(self, arg: tuple[tuple[sint, sint], ...]) -> View:
     assert len(arg) == len(self.shape), f"invalid pad {arg} for {self.shape}"
     # NOTE: not checking for symbolic arg
     for b,e in arg: assert not all_int([b,e]) or b>=0 and e>=0, f"invalid pad {arg} for {self.shape}"
@@ -261,58 +241,46 @@ class View:
     return self

   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
-  def shrink(self, arg:
+  def shrink(self, arg: tuple[tuple[sint, sint], ...]) -> View:
     assert len(arg) == len(self.shape), f"invalid shrink {arg} for {self.shape}"
     # NOTE: not checking for symbolic arg
     for s,(b,e) in zip(self.shape,arg): assert not all_int([s,b,e]) or (0<=b<=e<=s), f"invalid shrink {arg} for {self.shape}"
     return self.__unsafe_resize(arg)

   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
-  def expand(self, new_shape:
+  def expand(self, new_shape: tuple[sint, ...]) -> View:
     if len(new_shape) != len(self.shape): raise ValueError(f"expand arg {new_shape=} must have same number of dimensions as shape {self.shape=}")
-
-
-
-    # TODO: this resolve might be wrong
-    assert all((not resolve(s != x, False) or s == 1) for s,x in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
-    # NOTE: can the mask ever be (0,0)?
+    # NOTE: does not check multiple of symbolic shape
+    assert all(resolve(s == ns) or s == 1 for s,ns in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
+    if 0 in self.shape: return View.create(new_shape)
     # TODO: this resolve may not be needed, but it's hard because vars need to be sorted
     mask = tuple([(((0,0) if m != (0,1) else (0,ns)) if resolve(s != ns, False) else m) \
                   for m,s,ns in zip(self.mask, self.shape, new_shape)]) if self.mask else None
     return View.create(new_shape, self.strides, self.offset, mask)

   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
-  def permute(self, axis:
+  def permute(self, axis: tuple[int, ...]) -> View:
     assert sorted(axis) == list(range(len(self.shape))), f"invalid permutation {axis} of len {len(self.shape)}"
     return View.create(tuple(self.shape[a] for a in axis), tuple(self.strides[a] for a in axis), self.offset,
                        tuple(self.mask[a] for a in axis) if self.mask is not None else None)

   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
-  def
-
-
-
-    new_shape = tuple([(s+(abs(m)-1))//abs(m) for s,m in zip(self.shape, mul)])
-    offset = sum([(s-1)*z for s,z,m in zip(self.shape, self.strides, mul) if m < 0])
-    mask = tuple([(((mx if m > 0 else s-my)+(abs(m)-1))//abs(m), ((my if m > 0 else s-mx)+(abs(m)-1))//abs(m)) \
-      for (mx,my),s,m in zip(self.mask, self.shape, mul)]) if self.mask is not None else None
-    return View.create(new_shape, strides, self.offset + offset, mask)
+  def flip(self, arg: tuple[bool, ...]) -> View:
+    offset = sum((s-1)*z for s,z,f in zip(self.shape, self.strides, arg) if f)
+    mask = tuple((s-my,s-mx) if f else (mx,my) for (mx,my),s,f in zip(self.mask, self.shape, arg)) if self.mask is not None else None
+    return View.create(self.shape, tuple(-z if f else z for z,f in zip(self.strides, arg)), self.offset+offset, mask)

   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
-  def reshape(self, new_shape:
+  def reshape(self, new_shape: tuple[sint, ...]) -> Optional[View]:
     if self.shape == new_shape: return self

-
-    assert all(resolve(x >= 0) for x in new_shape), f"shape can't contain negative numbers {new_shape}"
-    if 0 in self.shape:
-      assert 0 in new_shape, f"cannot reshape 0 size to {new_shape}"
-      return View.create(new_shape)
+    assert all(x >= 0 for x in new_shape), f"shape can't contain negative numbers {new_shape}"
     # check for the same size
     if (self_all_int := all_int(self.shape)):
       assert all(isinstance(s, (int, UOp)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
-      if resolve(prod(self.shape) != prod(new_shape), False):
-        raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")
+      if resolve(prod(self.shape) != prod(new_shape), False): raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")

+    if 0 in self.shape: return View.create(new_shape)
     if new_shape == () and self.mask and any(mx==my for (mx,my) in self.mask): return None

     # after the asserts, it's okay to check contiguous
@@ -322,29 +290,26 @@ class View:
     if self_all_int and not all_int(new_shape):
       if len(self.shape) != len(new_shape): raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
       for si, so in zip(self.shape, new_shape):
-        if isinstance(so, int):
-
-        else:
-          var_vals = dict([v.unbind() for v in so.vars()])
-          if si != sym_infer(so, var_vals): raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
+        if not isinstance(so, int): so = sym_infer(so, dict([v.unbind() for v in so.vars()]))
+        if si != so: raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
       # all dimensions matched, return the new view directly
       return View(new_shape, self.strides, self.offset, self.mask, self.contiguous)

-
-    for
+    r_strides, r_new_shape = [], reversed(new_shape)
+    for merged_size, new_stride, real_size in reversed(merge_dims(self.shape, self.strides, self.mask)):
+      # TODO: write with get_contraction
       acc = 1
       # TODO: third resolve shouldn't be needed
-      while resolve(acc <=
-
-
-
-
-
-
-
-
-
-
-    return View.create(new_shape, new_strides, self.offset + extra_offset, new_mask)
+      while resolve(acc <= merged_size) and resolve(acc != merged_size) and resolve((new_dim := next(r_new_shape, 0)) > 0):
+        r_strides.append(new_stride * acc)
+        acc = acc * new_dim
+        if not resolve(acc < real_size): new_stride = 0
+      if resolve(acc != merged_size): return None
+    new_strides = (0,) * (len(new_shape) - len(r_strides)) + tuple(r_strides[::-1])
+
+    if (new_mask:=_reshape_mask(self.mask, self.shape, new_shape)) is not None:
+      extra_offset = (sum(m[0] * s for m,s in zip(self.mask, self.strides)) if self.mask else 0) - \
+                     (sum(m[0] * s for m,s in zip(new_mask, new_strides)))
+      return View.create(new_shape, new_strides, self.offset + extra_offset, new_mask)

     return None
tinygrad/spec.py
ADDED
@@ -0,0 +1,154 @@
+from typing import cast
+from tinygrad.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops
+from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType
+from tinygrad.helpers import all_int, all_same, dedup, prod
+
+# *** this is the spec of a Tensor in UOp ***
+
+tensor_uop_spec = PatternMatcher([
+  (UPat(Ops.DEVICE, dtypes.void, (), name="device"), lambda device: isinstance(device.arg, str)),
+  (UPat(Ops.BUFFER, src=(UPat(Ops.DEVICE),), name="buf"),
+   lambda buf: isinstance(buf.arg, tuple) and len(buf.arg) == 2 and all_int(buf.arg) and isinstance(buf.dtype, (DType, ImageDType))),
+
+  (UPat(GroupOp.Movement, name="mv", src=(UPat.var("x"),)),
+   # naturally correct
+   lambda mv,x: (isinstance(mv.arg, tuple) and mv.dtype == x.dtype) or
+   # "make things that can't be images not images" can change the buffer dtype
+   # this is fine as long as it's a realized buffer and base dtypes match.
+   ((isinstance(mv.dtype, ImageDType) or isinstance(x.dtype, ImageDType)) and x.dtype.base == mv.dtype.base and x.is_realized)),
+
+  # Tensor variable bindings
+  (UPat(Ops.BIND, dtypes.int, (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=dtypes.int)), arg=None), lambda: True),
+
+  # Tensor const has a device and an unmasked ShapeTracker of stride 0
+  (UPat(Ops.CONST, src=(UPat(Ops.VIEW, name="st", src=(UPat(Ops.DEVICE),)),)),
+   lambda st: st.st.views[0].mask is None and len(st.st.views) == 1 and all(s == 0 for s in st.st.views[0].strides)),
+
+  # DETACH and CONTIGUOUS change how we interpret the source UOp
+  # CONTIGUOUS ensures the source UOp realizes
+  (UPat((Ops.DETACH, Ops.CONTIGUOUS, Ops.CONTIGUOUS_BACKWARD), name="root", src=(UPat.var("x"),), arg=None), lambda root,x: root.dtype == x.dtype),
+
+  # COPY
+  # NOTE: the arg here specifies clone=True, which prevents folding same device copy
+  (UPat(Ops.COPY, name="copy", src=(UPat(Ops.DEVICE), UPat.var("x"))), lambda copy,x: isinstance(copy.arg, bool) and copy.dtype == x.dtype),
+
+  # VIEW(BUFFER) applies a ShapeTracker on top of the underlying device buffer
+  # NOTE: VIEW size exactly matches the underlying BUFFER, tensor doesn't apply movement ops to the VIEW
+  (UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),
+   lambda view,buf: view.dtype == buf.dtype and view.size == buf.size and view.st.contiguous),
+
+  # ASSIGN changes the value of a realized buffer
+  (UPat(Ops.ASSIGN, name="assign", src=(UPat.var("target"), UPat.var("new_val"))),
+   lambda assign,target,new_val: target.is_realized and (assign.dtype == target.dtype == new_val.dtype)),
+])
+
+# ***** uop type spec *****
+
+# this is the matcher for the final rendered UOps
+# matcher functions returns True or False (or None to not match)
+spec = PatternMatcher([
+  (UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and not x.dtype.local),
+  (UPat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.local),
+  (UPat(Ops.DEFINE_ACC, src=(UPat.var("c"),), name="x", allow_any_len=True),
+   lambda x,c: all(y.op is Ops.RANGE for y in x.src[1:]) and c.dtype == x.dtype),
+  (UPat(Ops.DEFINE_VAR, name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
+
+  (UPat(Ops.RANGE, src=(UPat.var("x"), UPat.var("y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype and isinstance(rng.arg, int)),
+  (UPat(Ops.SPECIAL, src=()), lambda: True),
+
+  # TODO: confirm the args of both of these are shapetrackers
+  (UPat(Ops.VIEW, dtypes.void, src=()), lambda: True),
+  (UPat(Ops.VIEW, src=(UPat.var("src"),), name="x"), lambda x,src: src.op is not Ops.STORE and x.dtype == src.dtype),
+
+  (UPat(Ops.VALID, dtypes.bool, (UPat(Ops.VIEW),)), lambda: True),
+  (UPat(Ops.CONST, name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),
+
+  # early LOAD has a <buf, shapetracker, store?>
+  (UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW))), lambda: True),
+  (UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat(Ops.STORE))), lambda: True),
+
+  # early STORE has a <buf, shapetracker, val>
+  (UPat(Ops.STORE, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat())), lambda: True),
+
+  # **** new style load/store ****
+
+  # INDEX is used in new style load/store
+  (UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat())), lambda: True),
+
+  # LOAD takes a <bufidx, alt?, gate?, barrier?>
+  (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)),)), lambda: True),
+  (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat((Ops.IF, Ops.BARRIER)))), lambda: True),
+  (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat.var("alt"), UPat(dtype=dtypes.bool)), name="ld"), lambda ld,alt: ld.dtype == alt.dtype),
+
+  # STORE takes a <bufidx, val, gate?>
+  (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat())), lambda: True),
+  (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
+  (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(), UPat(Ops.IF))), lambda: True),
+
+  # most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
+  (UPat(Ops.WHERE, name="w", src=(UPat(dtype=dtypes.bool), UPat.var("x"), UPat.var("y"))), lambda w,x,y: w.dtype == x.dtype == y.dtype),
+  (UPat((Ops.CMPLT, Ops.CMPNE), dtype=dtypes.bool, src=(UPat.var("x"), UPat.var("y"))), lambda x,y: x.dtype.base == y.dtype.base),
+  # and SHL/SHR, the shift distance can be an int
+  (UPat((Ops.SHL, Ops.SHR), src=(UPat.var("x"), UPat.var("y")), name="a"), lambda a,x,y: a.dtype == x.dtype and y.dtype in (x.dtype, dtypes.uint)),
+  (UPat(Ops.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
+  (UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype.base == y.dtype.base for y in x.src)),
+
+  (UPat(Ops.ASSIGN, src=(UPat((Ops.DEFINE_ACC, Ops.DEFINE_GLOBAL)), UPat())), lambda: True),
+  (UPat(Ops.ENDRANGE, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True),
+
+  # WMMA has a <a, b, acc>
+  (UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 8),
+  (UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
+  (UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
+
+  # if has a <gate, barrier?>
+  (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(),)), lambda: True),
+  (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(), UPat(Ops.BARRIER))), lambda: True),
+  (UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),
+
+  (UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}),
+  (UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
+  (UPat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.count and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
+  (UPat((Ops.BITCAST, Ops.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None),
+  (UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local
+
+  # NOTE: for testing, we let sinks be anything
+  #(UPat(UOps.SINK, src=UPat(UOps.STORE)), lambda: True),
+  (UPat(Ops.SINK, dtypes.void), lambda: True),
+  (UPat(Ops.NOOP), lambda: True),
+
+  # PTX LOAD/STORE
+  (UPat((Ops.LOAD, Ops.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
+])
+
+# *** this is the spec of a Kernel in UOp ***
+
+kernel_spec = PatternMatcher([
+  (UPat(Ops.DEVICE, src=()), lambda: True),
+  (UPat(Ops.BUFFER, src=(UPat(Ops.DEVICE),)), lambda: True),
+  # TODO: currently kernel only has buffer parents, this is incomplete. it should be BUFFER and ASSIGN
+  (UPat(Ops.KERNEL, src=UPat(Ops.BUFFER)), lambda: True),
+])
+
+# *** this is the UOp shape spec ***
+
+def verify_sink_dims(sink:UOp):
+  shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sink.toposort if x.op is not Ops.SINK and x.st is not None])]
+  return all_same([x.st_arg.size for x in sink.src]) and all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims)
+
+shape_spec = PatternMatcher([
+  # shapes must have either 1 or n in each dimension
+  (UPat(Ops.SINK, src=UPat(Ops.STORE), allow_any_len=True, name="sink"), verify_sink_dims),
+  # all parent UOps must have the same shape
+  (UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: all_same([x.shape for x in root.src if x.st is not None])),
+])
+
+# ***** uop helpers *****
+
+def type_verify(uops:list[UOp], *extra_specs:PatternMatcher):
+  specs = [spec, *extra_specs]
+  for i,u in enumerate(uops):
+    spec_ret = [cast(bool|None, s.rewrite(u)) for s in specs]
+    if any(ret is False for ret in spec_ret) or all(ret is None for ret in spec_ret):
+      print_uops(uops)
+      raise RuntimeError(f"UOp verification failed at {i} on {u.op} {u.dtype} {len(u.src)} {[x.op for x in u.src]} {u.arg}")