tinygrad-0.10.0-py3-none-any.whl → tinygrad-0.10.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
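As a rough illustration of what "changes between package versions" means here, the sketch below reproduces this kind of diff locally with pip and the Python standard library. It is an assumption-laden example, not the tooling behind this page; the directory names and the choice of view.py are arbitrary.

# hypothetical local reproduction of a wheel-to-wheel diff (not this page's tooling)
import difflib, pathlib, subprocess, zipfile

for ver in ("0.10.0", "0.10.2"):
  # download the published wheel without dependencies, then unpack it
  subprocess.run(["pip", "download", f"tinygrad=={ver}", "--no-deps", "-d", f"whl-{ver}"], check=True)
  zipfile.ZipFile(next(pathlib.Path(f"whl-{ver}").glob("*.whl"))).extractall(f"src-{ver}")

# compare one file from each extracted wheel, e.g. the reworked view.py below
old = pathlib.Path("src-0.10.0/tinygrad/shape/view.py").read_text().splitlines(keepends=True)
new = pathlib.Path("src-0.10.2/tinygrad/shape/view.py").read_text().splitlines(keepends=True)
print("".join(difflib.unified_diff(old, new, "0.10.0/view.py", "0.10.2/view.py")))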
- tinygrad/codegen/devectorizer.py +247 -0
- tinygrad/codegen/expander.py +121 -0
- tinygrad/codegen/kernel.py +141 -201
- tinygrad/codegen/linearize.py +223 -84
- tinygrad/codegen/lowerer.py +60 -42
- tinygrad/codegen/symbolic.py +476 -0
- tinygrad/codegen/transcendental.py +22 -13
- tinygrad/device.py +187 -47
- tinygrad/dtype.py +39 -28
- tinygrad/engine/jit.py +83 -65
- tinygrad/engine/memory.py +4 -5
- tinygrad/engine/multi.py +161 -0
- tinygrad/engine/realize.py +62 -108
- tinygrad/engine/schedule.py +396 -357
- tinygrad/engine/search.py +55 -66
- tinygrad/gradient.py +73 -0
- tinygrad/helpers.py +81 -59
- tinygrad/nn/__init__.py +30 -32
- tinygrad/nn/datasets.py +1 -2
- tinygrad/nn/optim.py +22 -26
- tinygrad/nn/state.py +91 -66
- tinygrad/ops.py +492 -641
- tinygrad/renderer/__init__.py +95 -36
- tinygrad/renderer/cstyle.py +99 -92
- tinygrad/renderer/llvmir.py +83 -34
- tinygrad/renderer/ptx.py +83 -99
- tinygrad/renderer/wgsl.py +95 -0
- tinygrad/runtime/autogen/amd_gpu.py +39507 -12
- tinygrad/runtime/autogen/comgr.py +2 -0
- tinygrad/runtime/autogen/kfd.py +4 -3
- tinygrad/runtime/autogen/kgsl.py +1 -1
- tinygrad/runtime/autogen/libc.py +404 -71
- tinygrad/runtime/autogen/llvm.py +11379 -0
- tinygrad/runtime/autogen/pci.py +1333 -0
- tinygrad/runtime/autogen/vfio.py +891 -0
- tinygrad/runtime/autogen/webgpu.py +6985 -0
- tinygrad/runtime/graph/cuda.py +8 -9
- tinygrad/runtime/graph/hcq.py +84 -79
- tinygrad/runtime/graph/metal.py +40 -43
- tinygrad/runtime/ops_amd.py +498 -334
- tinygrad/runtime/ops_cloud.py +34 -34
- tinygrad/runtime/ops_cpu.py +24 -0
- tinygrad/runtime/ops_cuda.py +30 -27
- tinygrad/runtime/ops_disk.py +62 -63
- tinygrad/runtime/ops_dsp.py +159 -42
- tinygrad/runtime/ops_gpu.py +30 -30
- tinygrad/runtime/ops_hip.py +29 -31
- tinygrad/runtime/ops_llvm.py +48 -41
- tinygrad/runtime/ops_metal.py +149 -113
- tinygrad/runtime/ops_npy.py +2 -2
- tinygrad/runtime/ops_nv.py +238 -273
- tinygrad/runtime/ops_python.py +55 -50
- tinygrad/runtime/ops_qcom.py +129 -157
- tinygrad/runtime/ops_webgpu.py +225 -0
- tinygrad/runtime/support/allocator.py +94 -0
- tinygrad/runtime/support/am/__init__.py +0 -0
- tinygrad/runtime/support/am/amdev.py +396 -0
- tinygrad/runtime/support/am/ip.py +463 -0
- tinygrad/runtime/support/compiler_cuda.py +4 -2
- tinygrad/runtime/support/elf.py +28 -4
- tinygrad/runtime/support/hcq.py +256 -324
- tinygrad/runtime/support/llvm.py +26 -0
- tinygrad/shape/shapetracker.py +85 -53
- tinygrad/shape/view.py +104 -140
- tinygrad/spec.py +155 -0
- tinygrad/tensor.py +835 -527
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
- tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
- tinygrad/viz/index.html +544 -0
- tinygrad/viz/perfetto.html +178 -0
- tinygrad/viz/serve.py +205 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
- tinygrad-0.10.2.dist-info/RECORD +99 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
- tinygrad/codegen/uopgraph.py +0 -506
- tinygrad/engine/lazy.py +0 -228
- tinygrad/function.py +0 -212
- tinygrad/multi.py +0 -177
- tinygrad/runtime/graph/clang.py +0 -39
- tinygrad/runtime/ops_clang.py +0 -35
- tinygrad-0.10.0.dist-info/RECORD +0 -77
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/shape/view.py
CHANGED
@@ -1,111 +1,101 @@
 from __future__ import annotations
-import functools, operator, itertools, math
+import functools, operator, itertools
 from dataclasses import dataclass
-from typing import Tuple, List, Optional, Dict, Set, cast
+from typing import Optional, cast, Sequence
 from tinygrad.dtype import dtypes
-from tinygrad.ops import resolve, UOp, Variable, sint, sym_infer, smax, smin
-from tinygrad.helpers import prod, all_int, argsort, flatten
+from tinygrad.ops import resolve, UOp, Variable, sint, sym_infer, smax, smin, sint_to_uop
+from tinygrad.helpers import prod, all_int, argsort, flatten, ceildiv
 
 @functools.lru_cache(maxsize=None)
-def canonicalize_strides(shape:Tuple[sint, ...], strides:Tuple[sint, ...]) -> Tuple[sint, ...]:
+def canonicalize_strides(shape:tuple[sint, ...], strides:tuple[sint, ...]) -> tuple[sint, ...]:
   return tuple(0 if s == 1 else st for s, st in zip(shape, strides))
 
 @functools.lru_cache(maxsize=None)
-def strides_for_shape(shape:Tuple[sint, ...]) -> Tuple[sint, ...]:
+def strides_for_shape(shape:tuple[sint, ...]) -> tuple[sint, ...]:
   if not shape: return ()
   strides = tuple(itertools.accumulate(reversed(shape[1:]), operator.mul, initial=1))[::-1]
   return canonicalize_strides(shape, strides)
 
 @functools.lru_cache(maxsize=None)
-def _merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tuple[Tuple[int, int], ...]]=None) -> Tuple[Tuple[int, int, int], ...]:
-  # merge contiguous sub-parts or zero strided dims
+def merge_dims(shape:tuple[int, ...], strides:tuple[int, ...], mask:Optional[tuple[tuple[int, int], ...]]=None) -> tuple[tuple[int, int, int], ...]:
+  # merge contiguous sub-parts or zero strided dims
+  # any stride 0, masked from dim=1, or contiguous part is merged into next dim.
+  # stride != 0 to stride == 0 starts a new merging block
+  # ret = tuple[(merged_size, stride, merged size w/o zero stride), ...]
   if not shape: return ()
   assert len(shape) == len(strides) and (mask is None or len(shape) == len(mask))
   ret = [(shape[0], strides[0], shape[0] if strides[0] != 0 else 0)]
   # merge this dim to next dim if size is 1
   merging = (mask[0][1] - mask[0][0] == 1) if mask is not None else shape[0] == 1
   for i, (s, st) in enumerate(zip(shape[1:], strides[1:]), start=1):
-    last_s, last_st, last_pre_expand_s = ret[-1]
     # always merge 1
     if s == 1: continue
+    last_s, last_st, last_pre_expand_s = ret[-1]
     # merge last dim with this dim if merging or strides matched
-    if merging or last_st == s * st: ret[-1] = (last_s * s, st, (s if merging else last_pre_expand_s * s) if st != 0 else 0)
-    else: ret.append((s, st, s if st != 0 else 0))
+    if merging or last_st == s * st: ret[-1] = (last_s * s, st, (s if merging else last_pre_expand_s * s))
+    else: ret.append((s, st, s))
    # merge this dim to next dim if size is 1
     merging = (mask[i][1] - mask[i][0] == 1) if mask is not None else s == 1
   return tuple(ret)
 
 @functools.lru_cache(maxsize=None)
-def _reshape_mask(_mask:Optional[Tuple[Tuple[sint, sint], ...]], old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) \
-  -> Optional[Tuple[Tuple[sint, sint], ...]]:
+def _reshape_mask(_mask:Optional[tuple[tuple[sint, sint], ...]], old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) \
+  -> Optional[tuple[tuple[sint, sint], ...]]:
   """Returns the new mask if reshape is possible, and None if not possible."""
   if _mask is None: return tuple((0, s) for s in new_shape)
-  if any(not isinstance(m[0], int) or not isinstance(m[1], int) for m in _mask): return None
-  if any(m[1] - m[0] < 1 for m in _mask): return ((0, 0),) * len(new_shape)  # zero mask
+  if not all_int(flatten(_mask)): return None
 
-  new_mask: List[Tuple[int, int]] = []
+  new_mask: list[tuple[int, int]] = []
   # _mask is all int here
-  r_masks, r_shape, r_new_shape = reversed(cast(Tuple[Tuple[int, int], ...], _mask)), reversed(old_shape), reversed(new_shape)
+  r_masks, r_shape, r_new_shape = reversed(cast(tuple[tuple[int, int], ...], _mask)), reversed(old_shape), reversed(new_shape)
   curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1))
 
   while len(new_mask) < len(new_shape):
     (l, r), next_stride = mask, new_dim * curr_stride
 
-    if old_dim >= next_stride: # need to split mask.
-      if old_dim == next_stride: # simply copy the mask and get next batch for merging
-        new_mask.append((l // curr_stride, (r - 1) // curr_stride + 1))
-        curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1))
-
-      else: # mask can only be splitted if reshape doesn't cut across the mask.
-        if (l % next_stride != 0 or r % next_stride != 0) and l // next_stride != (r - 1) // next_stride:
-          return None
-        new_mask.append((l % next_stride // curr_stride, (r - 1) % next_stride // curr_stride + 1))
-        curr_stride, new_dim = next_stride, next(r_new_shape, 1) # need to get mask for next dimension
-
+    # need to split mask
+    if old_dim == next_stride: # simply copy the mask and get next batch for merging
+      new_mask.append((l // curr_stride, (r - 1) // curr_stride + 1))
+      curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1))
+    elif old_dim > next_stride: # mask can only be splitted if reshape doesn't cut across the mask.
+      if old_dim % next_stride != 0: return None
+      if (l % next_stride != 0 or r % next_stride != 0) and l // next_stride != (r - 1) // next_stride: return None
+      new_mask.append((l % next_stride // curr_stride, (r - 1) % next_stride // curr_stride + 1))
+      curr_stride, new_dim = next_stride, next(r_new_shape, 1) # need to get mask for next dimension
     else:
       next_mask = next(r_masks, (0, 1))
       # combine if the mask can unfold continuously
-      if mask != (0, old_dim) and next_mask[1] - next_mask[0] != 1: return None
+      if mask != (0, old_dim) and l != r and next_mask[1] - next_mask[0] != 1: return None
       mask, old_dim = (next_mask[0] * old_dim + l, (next_mask[1] - 1) * old_dim + r), old_dim * next(r_shape, 1)
 
-  for mask in r_masks: # if the old shape has leading 1s, need to make sure their mask is (0,1)
-    if mask != (0, 1): return ((0, 0),) * len(new_shape)  # invalid mask
-
   return tuple(reversed(new_mask))
 
-def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]:
-  strides = strides_for_shape(shape)
-  result = []
-  for stride in strides:
-    here = offs // stride if stride != 0 else 0
-    result.append(here)
-    offs -= here * stride
-  return result
-
-def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.int, x) if isinstance(x, int) else x
+def unravel(shape:tuple[sint, ...], offset:sint) -> list[sint]:
+  # find the position of offset on each dimension based on shape
+  # similar to unravel_index in numpy/torch
+  acc, idxs = 1, []
+  for d in reversed(shape):
+    idxs.append((offset//acc)%d)
+    acc *= d
+  return idxs[::-1]
 
 @dataclass(frozen=True)
 class View:
-  shape:Tuple[sint, ...]
-  strides:Tuple[sint, ...]
+  shape:tuple[sint, ...]
+  strides:tuple[sint, ...]
   offset:sint
-  mask:Optional[Tuple[Tuple[sint, sint], ...]]
+  mask:Optional[tuple[tuple[sint, sint], ...]]
   contiguous:bool
 
-  @functools.cached_property
-  def t(self):
-    return tuple(x.tuplize if isinstance(x, UOp) else (x,) \
-                 for x in self.shape+self.strides+(self.offset,)+(tuple(flatten(self.mask)) if self.mask is not None else tuple()))
-  def __lt__(self, o:View): return self.t < o.t
-
-  def to_indexed_uops(self:View, _idxs:Optional[List[UOp]]=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> Tuple[UOp, UOp]:
-    idxs = [UOp.range(dtypes.int, 0, s, i) for i,s in enumerate(self.shape)] if _idxs is None else _idxs
-    iexpr = variable_to_uop(self.offset)
-    for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else [None]*len(self.shape)):
+  def to_indexed_uops(self:View, idxs:Optional[Sequence[UOp]]=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> tuple[UOp, UOp]:
+    """(idx, valid)"""
+    if idxs is None: idxs = [UOp.range(dtypes.int, 0, s, i) for i,s in enumerate(self.shape)]
+    iexpr = sint_to_uop(self.offset)
+    for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else itertools.repeat(None)):
       if resolve(sh != 1) and resolve(st != 0): iexpr = iexpr + idx*st
       if m is not None:
-        if resolve(m[0] != 0): vexpr = vexpr * idx.ge(m[0])
-        if resolve(m[1] != sh): vexpr = vexpr * idx.lt(m[1])
+        if resolve(m[0] != 0): vexpr = vexpr * (idx >= m[0])
+        if resolve(m[1] != sh): vexpr = vexpr * (idx < m[1])
     return iexpr, vexpr
 
   @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
@@ -116,13 +106,12 @@ class View:
 
   @staticmethod
   @functools.lru_cache(maxsize=None)
-  def create(shape:Tuple[sint, ...], strides:Optional[Tuple[sint, ...]]=None, offset:sint=0, mask:Optional[Tuple[Tuple[sint, sint], ...]]=None):
-    # TODO: this resolve shouldn't be needed
-    if not all(resolve(s >= 0) for s in shape): raise ValueError(f"Trying to create View with negative dimension: {shape=}")
+  def create(shape:tuple[sint, ...], strides:Optional[tuple[sint, ...]]=None, offset:sint=0, mask:Optional[tuple[tuple[sint, sint], ...]]=None):
+    if not all(s >= 0 for s in shape): raise ValueError(f"Trying to create View with negative dimension: {shape=}")
     strides = canonicalize_strides(shape, strides) if strides else strides_for_shape(shape)
     # canonicalize 0 in shape
     if 0 in shape: return View(shape, (0,) * len(shape), offset=0, mask=None, contiguous=True)
-    # canonicalize empty mask
+    # canonicalize no-op mask
     if mask is not None and all(m == (0,s) for m,s in zip(mask, shape)): mask = None
     # if any dimension has size >1, but is masked such that only one index in the dimension is unmasked
     # then its stride can also be set to 0, albeit with a corresponding adjustment required to the offset
@@ -134,7 +123,7 @@ class View:
     # simplify as we go
     if isinstance(offset, UOp): offset = cast(sint, offset.ssimplify())
     shape = tuple(cast(sint, x.ssimplify()) if isinstance(x, UOp) else x for x in shape)
-    # TODO: enabling stride simplification breaks it
+    # TODO: enabling stride simplification breaks symbolic jit
     """
     strides = tuple(x.ssimplify() if isinstance(x, UOp) else x for x in strides)
     if mask: mask = tuple((s.ssimplify() if isinstance(s, UOp) else s, e.ssimplify() if isinstance(e, UOp) else e) for s,e in mask)
@@ -143,15 +132,15 @@ class View:
     return View(shape, strides, offset, mask, contiguous)
 
   @functools.lru_cache(None)  # pylint: disable=method-cache-max-size-none
-  def vars(self) -> Set[Variable]:
+  def vars(self) -> set[Variable]:
     flatten_mask = tuple(x for m in self.mask for x in m) if self.mask is not None else tuple()
     return functools.reduce(operator.or_, [x.vars() for x in self.shape+self.strides+(self.offset,)+flatten_mask if isinstance(x, UOp)], set())
 
   @functools.lru_cache(None)  # pylint: disable=method-cache-max-size-none
-  def unbind(self) -> Tuple[View, Dict[Variable, int]]:
+  def unbind(self) -> tuple[View, dict[Variable, int]]:
     var_unboundvar_val = [(v, v.unbind()) for v in self.vars()]
     unbound_vars = {v:uv for v,(uv,_) in var_unboundvar_val}
-    def substitute(x): return x if isinstance(x, int) else x.substitute(unbound_vars)
+    def substitute(x:sint): return x if isinstance(x, int) else x.substitute(unbound_vars)
     new_shape = tuple(map(substitute, self.shape))
     new_strides = tuple(map(substitute, self.strides))
     new_offset = substitute(self.offset)
@@ -165,27 +154,26 @@ class View:
     if vm1.contiguous and vm1.shape == vm2.shape: return vm2
     if vm1.contiguous and vm1.size() == vm2.size() and (ret := vm2.reshape(vm1.shape)) is not None: return ret
     if vm1.mask:
-      if (new_vm1 := vm1.shrink(vm1.mask)) == vm1: return None
-      if (merged := vm2 + new_vm1) is None: return None
-      return merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape)))
+      if (new_vm1 := vm1.shrink(vm1.mask)) == vm1 or (merged := vm2 + new_vm1) is None: return None
+      return merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape)))
+    if not all_int(vm1.shape): return None
 
     # Project vm1's offset and strides on to vm2.
-    origin = un1d(vm2.shape, vm1.offset)
-    terms: List[List[Tuple[int, sint]]] = [[] for _ in origin]
-    strides: List[sint] = [0] * len(vm1.shape)
+    origin = unravel(vm2.shape, vm1.offset)
+    terms: list[list[tuple[int, sint]]] = [[] for _ in vm2.shape]
+    strides: list[sint] = [0] * len(vm1.shape)
     for d1, st in enumerate(vm1.strides):
       if st == 0: continue
-      for d2, (o, s1) in enumerate(zip(origin, un1d(vm2.shape, vm1.offset + st))):
+      for d2, (o, s1) in enumerate(zip(origin, unravel(vm2.shape, vm1.offset + st))):
        if (s1 := s1 - o) == 0: continue
         terms[d2].append((d1, s1))
         strides[d1] += s1 * vm2.strides[d2]
 
     # Merge dimensions in vm2 if required.
     # NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
-    if not all_int(vm1.shape): return None
-    idxs: List[UOp] = [UOp.variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
+    idxs: list[UOp] = [UOp.variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
     merged_size, merged_term = 1, UOp.const(dtypes.int, 0)
-    extents: List[Tuple[sint, UOp]] = []
+    extents: list[tuple[sint, UOp]] = []
     for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)):
       merged_term += (sum([idxs[d1] * s1 for d1, s1 in term]) + o) * merged_size
       merged_size *= s
@@ -194,8 +182,8 @@ class View:
       merged_size, merged_term = 1, UOp.const(dtypes.int, 0)
     if resolve(merged_term != 0): return None
     if (vm2_shape := tuple(s for s,_ in reversed(extents))) != vm2.shape:
-      reshaped_vm2 = vm2.reshape(vm2_shape)
-      if reshaped_vm2 is None: return None
+      if (reshaped_vm2 := vm2.reshape(vm2_shape)) is None: return None
+      # NOTE: this != to prevent infinite loop
      if reshaped_vm2.shape != vm2.shape: return reshaped_vm2 + vm1
 
     if vm2.mask:
@@ -203,54 +191,45 @@ class View:
       newb, newe, bad = [0] * len(vm1.shape), list(vm1.shape), False
       for (b, e), o, term, (_, t) in zip(vm2.mask, origin, terms, reversed(extents)):
         if resolve(b <= t.vmin and t.vmax < e, False): continue
-        if not all_int([o, b, e]):
-          bad = True
-          continue
         if len(term) != 1:
           if not term and newe: newe[0] = 0
          else: bad = True
          continue
         d1, s1 = term[0]
-        if not isinstance(s1, int):
-          bad = True
-          continue
-        newb[d1] = max(newb[d1], math.ceil((b - o if s1 > 0 else e - o - 1) / s1))
+        newb[d1] = max(newb[d1], ceildiv(b - o if s1 > 0 else e - o - 1, s1))
        newe[d1] = min(newe[d1], (b - o if s1 < 0 else e - o - 1) // s1 + 1)
 
       # If any of vm1 was masked off, try again with that mask in place.
-      for b, e, s in zip(newb, newe, vm1.shape):
-        if b != 0 or e != s:
-          return vm2 + View.create(vm1.shape, vm1.strides, vm1.offset, tuple(zip(newb, newe)))
+      if any((b, e) != (0, s) for b, e, s in zip(newb, newe, vm1.shape)):
+        return vm2 + View.create(vm1.shape, vm1.strides, vm1.offset, tuple(zip(newb, newe)))
       # Otherwise if vm2's mask was violated, then cannot merge.
       if bad: return None
 
     return View.create(vm1.shape, tuple(strides), sum(o * s for o, s in zip(origin, vm2.strides)) + vm2.offset)
 
   @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
-  def invert(self, out_shape:Tuple[sint, ...]) -> Optional[View]:
+  def invert(self, out_shape:tuple[sint, ...]) -> Optional[View]:
     ret = View.create(self.shape)
     if self.mask: ret = ret.shrink(self.mask)
-    ret = ret.stride(tuple(-1 if x < 0 else 1 for x in self.strides)).permute(argsort(tuple(-x if x > 0 else x for x in self.strides)))
+    ret = ret.flip(tuple(x < 0 for x in self.strides)).permute(argsort(tuple(-x if x > 0 else x for x in self.strides)))
     return ret if prod(ret.shape) == prod(out_shape) else None  # don't support shrink, expand, or stride != (-1, 1)
 
   @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
   def minify(self):
-    min_shape = tuple(x[0] for x in _merge_dims(self.shape, self.strides, self.mask))
+    min_shape = tuple(x[0] for x in merge_dims(self.shape, self.strides, self.mask))
     return nv if (nv := self.reshape(min_shape)) else self
 
-  def __unsafe_resize(self, arg:Tuple[Tuple[sint, sint], ...], mask=None) -> View:
+  def __unsafe_resize(self, arg: tuple[tuple[sint, sint], ...], mask=None) -> View:
     offset = sum([s * x[0] for s, x in zip(self.strides,arg)])
     if self.mask:
       # move the old mask
       nmask = tuple([(smax(0, smin(mx-ax,ay-ax)), smax(0, smin(my-ax,ay-ax))) for (mx,my),(ax,ay) in zip(self.mask, arg)])
       # merge the masks if we have two
       mask = tuple([(smax(mx1, mx2), smin(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask
-    shape = [y-x for x,y in arg]
-    if mask is not None and all(m[0] == 0 and m[1] == s for m,s in zip(mask, shape)): mask = None
-    return View.create(tuple(s.ssimplify() if isinstance(s, UOp) else s for s in shape), self.strides, self.offset+offset, mask)
+    return View.create(tuple([y-x for x,y in arg]), self.strides, self.offset+offset, mask)
 
   @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
-  def pad(self, arg:Tuple[Tuple[sint, sint], ...]) -> View:
+  def pad(self, arg: tuple[tuple[sint, sint], ...]) -> View:
     assert len(arg) == len(self.shape), f"invalid pad {arg} for {self.shape}"
     # NOTE: not checking for symbolic arg
     for b,e in arg: assert not all_int([b,e]) or b>=0 and e>=0, f"invalid pad {arg} for {self.shape}"
@@ -261,58 +240,46 @@ class View:
     return self
 
   @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
-  def shrink(self, arg:Tuple[Tuple[sint, sint], ...]) -> View:
+  def shrink(self, arg: tuple[tuple[sint, sint], ...]) -> View:
     assert len(arg) == len(self.shape), f"invalid shrink {arg} for {self.shape}"
     # NOTE: not checking for symbolic arg
     for s,(b,e) in zip(self.shape,arg): assert not all_int([s,b,e]) or (0<=b<=e<=s), f"invalid shrink {arg} for {self.shape}"
     return self.__unsafe_resize(arg)
 
   @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
-  def expand(self, new_shape:Tuple[sint, ...]) -> View:
+  def expand(self, new_shape: tuple[sint, ...]) -> View:
     if len(new_shape) != len(self.shape): raise ValueError(f"expand arg {new_shape=} must have same number of dimensions as shape {self.shape=}")
-    if 0 in self.shape:
-      assert all((s == x == 0) or (s > 0 and (x % s) == 0) for s,x in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
-      return View.create(new_shape)
-    # TODO: this resolve might be wrong
-    assert all((not resolve(s != x, False) or s == 1) for s,x in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
-    # NOTE: can the mask ever be (0,0)?
+    # NOTE: does not check multiple of symbolic shape
+    assert all(resolve(s == ns) or s == 1 for s,ns in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
+    if 0 in self.shape: return View.create(new_shape)
     # TODO: this resolve may not be needed, but it's hard because vars need to be sorted
     mask = tuple([(((0,0) if m != (0,1) else (0,ns)) if resolve(s != ns, False) else m) \
                   for m,s,ns in zip(self.mask, self.shape, new_shape)]) if self.mask else None
     return View.create(new_shape, self.strides, self.offset, mask)
 
   @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
-  def permute(self, axis:Tuple[int, ...]) -> View:
+  def permute(self, axis: tuple[int, ...]) -> View:
     assert sorted(axis) == list(range(len(self.shape))), f"invalid permutation {axis} of len {len(self.shape)}"
     return View.create(tuple(self.shape[a] for a in axis), tuple(self.strides[a] for a in axis), self.offset,
                        tuple(self.mask[a] for a in axis) if self.mask is not None else None)
 
   @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
-  def stride(self, mul:Tuple[int, ...]) -> View:
-    # except for the negative case, you can build this from the others. invertible in the negative case
-    assert all(isinstance(m, int) and m != 0 for m in mul), f"invalid stride {mul} for {self.shape}"
-    strides = tuple([z*m for z,m in zip(self.strides, mul)])
-    new_shape = tuple([(s+(abs(m)-1))//abs(m) for s,m in zip(self.shape, mul)])
-    offset = sum([(s-1)*z for s,z,m in zip(self.shape, self.strides, mul) if m < 0])
-    mask = tuple([(((mx if m > 0 else s-my)+(abs(m)-1))//abs(m), ((my if m > 0 else s-mx)+(abs(m)-1))//abs(m)) \
-                  for (mx,my),s,m in zip(self.mask, self.shape, mul)]) if self.mask is not None else None
-    return View.create(new_shape, strides, self.offset + offset, mask)
+  def flip(self, arg: tuple[bool, ...]) -> View:
+    offset = sum((s-1)*z for s,z,f in zip(self.shape, self.strides, arg) if f)
+    mask = tuple((s-my,s-mx) if f else (mx,my) for (mx,my),s,f in zip(self.mask, self.shape, arg)) if self.mask is not None else None
+    return View.create(self.shape, tuple(-z if f else z for z,f in zip(self.strides, arg)), self.offset+offset, mask)
 
   @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
-  def reshape(self, new_shape:Tuple[sint, ...]) -> Optional[View]:
+  def reshape(self, new_shape: tuple[sint, ...]) -> Optional[View]:
     if self.shape == new_shape: return self
 
-    # TODO: this resolve shouldn't be needed
-    assert all(resolve(x >= 0) for x in new_shape), f"shape can't contain negative numbers {new_shape}"
-    if 0 in self.shape:
-      assert 0 in new_shape, f"cannot reshape 0 size to {new_shape}"
-      return View.create(new_shape)
+    if not all(x >= 0 for x in new_shape): raise ValueError(f"shape can't contain negative numbers {new_shape}")
     # check for the same size
     if (self_all_int := all_int(self.shape)):
       assert all(isinstance(s, (int, UOp)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
-      if resolve(prod(self.shape) != prod(new_shape), False):
-        raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")
+      if resolve(prod(self.shape) != prod(new_shape), False): raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")
 
+    if 0 in self.shape: return View.create(new_shape)
     if new_shape == () and self.mask and any(mx==my for (mx,my) in self.mask): return None
 
     # after the asserts, it's okay to check contiguous
@@ -322,29 +289,26 @@ class View:
     if self_all_int and not all_int(new_shape):
      if len(self.shape) != len(new_shape): raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
      for si, so in zip(self.shape, new_shape):
-        if isinstance(so, int):
-          if si != so: raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
-        else:
-          var_vals = dict([v.unbind() for v in so.vars()])
-          if si != sym_infer(so, var_vals): raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
+        if not isinstance(so, int): so = sym_infer(so, dict([v.unbind() for v in so.vars()]))
+        if si != so: raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
      # all dimensions matched, return the new view directly
      return View(new_shape, self.strides, self.offset, self.mask, self.contiguous)
 
-    strides, r_new_shape = [], reversed(new_shape)
-    for merged_dim, new_stride, real_dim in reversed(_merge_dims(self.shape, self.strides, self.mask)):
+    r_strides, r_new_shape = [], reversed(new_shape)
+    for merged_size, new_stride, real_size in reversed(merge_dims(self.shape, self.strides, self.mask)):
+      # TODO: write with get_contraction
      acc = 1
      # TODO: third resolve shouldn't be needed
-      while resolve(acc <= merged_dim) and resolve(acc != merged_dim) and resolve((new_dim := next(r_new_shape, 0)) > 0):
-        strides.append(new_stride)
-        if new_dim != 1: new_stride *= (new_dim if (acc := acc * new_dim) < real_dim else 0)
-      if resolve(acc != merged_dim): break
-    else:
-      strides += [0,] * (len(new_shape) - len(strides))
-      new_mask = _reshape_mask(self.mask, self.shape, new_shape)
-      if new_mask is not None:
-        new_strides = canonicalize_strides(tuple(e-b for b,e in new_mask), tuple(reversed(strides)))
-        extra_offset = (sum(m[0] * s for m,s in zip(self.mask, self.strides)) if self.mask else 0) - \
-                       (sum(m[0] * s for m,s in zip(new_mask, new_strides)))
-        return View.create(new_shape, new_strides, self.offset + extra_offset, new_mask)
+      while resolve(acc <= merged_size) and resolve(acc != merged_size) and resolve((new_dim := next(r_new_shape, 0)) > 0):
+        r_strides.append(new_stride * acc)
+        acc = acc * new_dim
+        if not resolve(acc < real_size): new_stride = 0
+      if resolve(acc != merged_size): return None
+    new_strides = (0,) * (len(new_shape) - len(r_strides)) + tuple(r_strides[::-1])
+
+    if (new_mask:=_reshape_mask(self.mask, self.shape, new_shape)) is not None:
+      extra_offset = (sum(m[0] * s for m,s in zip(self.mask, self.strides)) if self.mask else 0) - \
+                     (sum(m[0] * s for m,s in zip(new_mask, new_strides)))
+      return View.create(new_shape, new_strides, self.offset + extra_offset, new_mask)
 
     return None
tinygrad/spec.py
ADDED
@@ -0,0 +1,155 @@
+from typing import cast
+from tinygrad.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops
+from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType
+from tinygrad.helpers import all_same, dedup, prod
+
+buffer_spec = PatternMatcher([
+  (UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True),
+  (UPat(Ops.DEVICE, dtypes.void, (), name="device"), lambda device: isinstance(device.arg, str)),
+  (UPat(Ops.BUFFER, src=(UPat(Ops.DEVICE), UPat(Ops.UNIQUE)), name="buf"),
+   lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
+])
+
+# *** this is the spec of a Tensor in UOp ***
+
+tensor_uop_spec = buffer_spec+PatternMatcher([
+  (UPat(GroupOp.Movement, name="mv", src=(UPat.var("x"),)),
+   # naturally correct
+   lambda mv,x: (isinstance(mv.arg, tuple) and mv.dtype == x.dtype) or
+   # "make things that can't be images not images" can change the buffer dtype
+   # this is fine as long as it's a realized buffer and base dtypes match.
+   ((isinstance(mv.dtype, ImageDType) or isinstance(x.dtype, ImageDType)) and x.dtype.base == mv.dtype.base and x.is_realized)),
+  (UPat(Ops.VIEW, src=(UPat(GroupOp.All-{Ops.CONST, Ops.DEVICE}),)), lambda: False),
+
+  # Tensor variable bindings
+  (UPat(Ops.BIND, dtypes.int, (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=dtypes.int)), arg=None), lambda: True),
+
+  # Tensor const has a device and an unmasked ShapeTracker of stride 0
+  (UPat(Ops.CONST, src=(UPat(Ops.VIEW, name="st", src=(UPat(Ops.DEVICE),)),)),
+   lambda st: st.st.views[0].mask is None and len(st.st.views) == 1 and all(s == 0 for s in st.st.views[0].strides)),
+
+  # DETACH and CONTIGUOUS change how we interpret the source UOp
+  # CONTIGUOUS ensures the source UOp realizes
+  (UPat((Ops.DETACH, Ops.CONTIGUOUS, Ops.CONTIGUOUS_BACKWARD), name="root", src=(UPat.var("x"),), arg=None), lambda root,x: root.dtype == x.dtype),
+
+  # COPY
+  # NOTE: the arg here specifies clone=True, which prevents folding same device copy
+  (UPat(Ops.COPY, name="copy", src=(UPat(Ops.DEVICE), UPat.var("x"))), lambda copy,x: isinstance(copy.arg, bool) and copy.dtype == x.dtype),
+
+  # ASSIGN changes the value of a realized buffer
+  (UPat(Ops.ASSIGN, name="assign", src=(UPat.var("target"), UPat.var("new_val"))),
+   lambda assign,target,new_val: target.is_realized and (assign.dtype == target.dtype == new_val.dtype)),
+])
+
+# ***** uop type spec *****
+
+# this is the matcher for the final rendered UOps
+# matcher functions returns True or False (or None to not match)
+spec = PatternMatcher([
+  (UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and not x.dtype.local),
+  (UPat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.local),
+  (UPat(Ops.DEFINE_ACC, src=(UPat.var("c"),), name="x", allow_any_len=True),
+   lambda x,c: all(y.op is Ops.RANGE for y in x.src[1:]) and c.dtype == x.dtype),
+  (UPat(Ops.DEFINE_VAR, name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
+
+  (UPat(Ops.RANGE, src=(UPat.var("x"), UPat.var("y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype and isinstance(rng.arg, int)),
+  (UPat(Ops.SPECIAL, src=()), lambda: True),
+
+  # TODO: confirm the args of both of these are shapetrackers
+  (UPat(Ops.VIEW, dtypes.void, src=()), lambda: True),
+  (UPat(Ops.VIEW, src=(UPat.var("src"),), name="x"), lambda x,src: src.op is not Ops.STORE and x.dtype.base == src.dtype.base),
+
+  (UPat(Ops.VALID, dtypes.bool, (UPat(Ops.VIEW),)), lambda: True),
+  (UPat(Ops.CONST, name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),
+
+  # early LOAD has a <buf, shapetracker, store?>
+  (UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW))), lambda: True),
+  (UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat(Ops.STORE))), lambda: True),
+
+  # early STORE has a <buf, shapetracker, val>
+  (UPat(Ops.STORE, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat())), lambda: True),
+
+  # **** new style load/store ****
+
+  # INDEX is used in new style load/store
+  (UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat())), lambda: True),
+
+  # LOAD takes a <bufidx, alt?, gate?, barrier?>
+  (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)),)), lambda: True),
+  (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat((Ops.IF, Ops.BARRIER)))), lambda: True),
+  (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat.var("alt"), UPat(dtype=dtypes.bool)), name="ld"), lambda ld,alt: ld.dtype == alt.dtype),
+
+  # STORE takes a <bufidx, val, gate?>
+  (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat())), lambda: True),
+  (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
+  (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(), UPat(Ops.IF))), lambda: True),
+
+  # most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
+  (UPat(Ops.WHERE, name="w", src=(UPat(dtype=dtypes.bool), UPat.var("x"), UPat.var("y"))), lambda w,x,y: w.dtype == x.dtype == y.dtype),
+  (UPat((Ops.CMPLT, Ops.CMPNE), dtype=dtypes.bool, src=(UPat.var("x"), UPat.var("y"))), lambda x,y: x.dtype.base == y.dtype.base),
+  # and SHL/SHR, the shift distance can be an int
+  (UPat((Ops.SHL, Ops.SHR), src=(UPat.var("x"), UPat.var("y")), name="a"), lambda a,x,y: a.dtype == x.dtype and y.dtype in (x.dtype, dtypes.uint)),
+  (UPat(Ops.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
+  (UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype.base == y.dtype.base for y in x.src)),
+
+  (UPat(Ops.ASSIGN, src=(UPat((Ops.DEFINE_ACC, Ops.DEFINE_GLOBAL)), UPat())), lambda: True),
+  (UPat(Ops.ENDRANGE, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True),
+
+  # WMMA has a <a, b, acc>
+  (UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 8),
+  (UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
+  (UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
+
+  # if has a <gate, barrier?>
+  (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(),)), lambda: True),
+  (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(), UPat(Ops.BARRIER))), lambda: True),
+  (UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),
+
+  (UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}),
+  (UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
+  (UPat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.count and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
+  (UPat((Ops.BITCAST, Ops.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None),
+  (UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local
+
+  # NOTE: for testing, we let sinks be anything
+  #(UPat(Ops.SINK, src=UPat(Ops.STORE)), lambda: True),
+  (UPat((Ops.NAME, Ops.SINK), dtypes.void), lambda: True),
+  (UPat((Ops.NOOP, Ops.CUSTOM)), lambda: True),
+
+  # PTX LOAD/STORE
+  (UPat((Ops.LOAD, Ops.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
+])
+
+# *** this is the spec of a Kernel in UOp ***
+
+kernel_spec = buffer_spec+PatternMatcher([
+  (UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.ASSIGN))), lambda: True),
+  # assign has a buffer view and kernel source, it can optionally depend on other assigns
+  (UPat(Ops.ASSIGN, src=UPat((Ops.BUFFER, Ops.VIEW, Ops.KERNEL, Ops.ASSIGN))), lambda: True),
+  # view/sink/const can also exist in the kernel graph
+  (UPat((Ops.VIEW, Ops.SINK, Ops.CONST)), lambda: True),
+  (UPat(GroupOp.All), lambda: False),
+])
+
+# *** this is the UOp shape spec ***
+
+def verify_sink_dims(sink:UOp):
+  shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sink.toposort if x.op is not Ops.SINK and x.st is not None])]
+  return all_same([x.st_arg.size for x in sink.src]) and all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims)
+
+shape_spec = PatternMatcher([
+  # shapes must have either 1 or n in each dimension
+  (UPat(Ops.SINK, src=UPat(Ops.STORE), allow_any_len=True, name="sink"), verify_sink_dims),
+  # all parent UOps must have the same shape
+  (UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: all_same([x.shape for x in root.src if x.st is not None])),
+])
+
+# ***** uop helpers *****
+
+def type_verify(uops:list[UOp], *extra_specs:PatternMatcher):
+  specs = [spec, *extra_specs]
+  for i,u in enumerate(uops):
+    spec_ret = [cast(bool|None, s.rewrite(u)) for s in specs]
+    if any(ret is False for ret in spec_ret) or all(ret is None for ret in spec_ret):
+      print_uops(uops)
+      raise RuntimeError(f"UOp verification failed at {i} on {u.op} {u.dtype} {len(u.src)} {[x.op for x in u.src]} {u.arg}")