tinygrad-0.9.1-py3-none-any.whl → tinygrad-0.10.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. tinygrad/__init__.py +11 -6
  2. tinygrad/codegen/kernel.py +308 -175
  3. tinygrad/codegen/linearize.py +95 -0
  4. tinygrad/codegen/lowerer.py +143 -0
  5. tinygrad/codegen/transcendental.py +257 -0
  6. tinygrad/codegen/uopgraph.py +506 -0
  7. tinygrad/device.py +72 -171
  8. tinygrad/dtype.py +122 -47
  9. tinygrad/engine/jit.py +184 -87
  10. tinygrad/{lazy.py → engine/lazy.py} +74 -66
  11. tinygrad/engine/memory.py +51 -0
  12. tinygrad/engine/realize.py +86 -61
  13. tinygrad/engine/schedule.py +366 -317
  14. tinygrad/engine/search.py +58 -47
  15. tinygrad/function.py +59 -58
  16. tinygrad/helpers.py +120 -102
  17. tinygrad/multi.py +82 -78
  18. tinygrad/nn/__init__.py +116 -67
  19. tinygrad/nn/datasets.py +12 -5
  20. tinygrad/nn/optim.py +1 -1
  21. tinygrad/nn/state.py +91 -6
  22. tinygrad/ops.py +1126 -143
  23. tinygrad/renderer/__init__.py +47 -23
  24. tinygrad/renderer/cstyle.py +338 -265
  25. tinygrad/renderer/llvmir.py +125 -143
  26. tinygrad/renderer/ptx.py +225 -0
  27. tinygrad/runtime/autogen/adreno.py +17904 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
  29. tinygrad/runtime/autogen/cuda.py +6 -162
  30. tinygrad/runtime/autogen/io_uring.py +97 -63
  31. tinygrad/runtime/autogen/kfd.py +60 -47
  32. tinygrad/runtime/autogen/kgsl.py +1386 -0
  33. tinygrad/runtime/autogen/libc.py +5462 -0
  34. tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
  35. tinygrad/runtime/autogen/nvrtc.py +579 -0
  36. tinygrad/runtime/autogen/opencl.py +11 -11
  37. tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
  38. tinygrad/runtime/graph/clang.py +3 -3
  39. tinygrad/runtime/graph/cuda.py +11 -15
  40. tinygrad/runtime/graph/hcq.py +120 -107
  41. tinygrad/runtime/graph/metal.py +71 -43
  42. tinygrad/runtime/ops_amd.py +244 -323
  43. tinygrad/runtime/ops_clang.py +12 -5
  44. tinygrad/runtime/ops_cloud.py +220 -0
  45. tinygrad/runtime/ops_cuda.py +42 -99
  46. tinygrad/runtime/ops_disk.py +25 -26
  47. tinygrad/runtime/ops_dsp.py +181 -0
  48. tinygrad/runtime/ops_gpu.py +29 -16
  49. tinygrad/runtime/ops_hip.py +68 -0
  50. tinygrad/runtime/ops_llvm.py +15 -10
  51. tinygrad/runtime/ops_metal.py +147 -64
  52. tinygrad/runtime/ops_nv.py +356 -397
  53. tinygrad/runtime/ops_python.py +78 -79
  54. tinygrad/runtime/ops_qcom.py +405 -0
  55. tinygrad/runtime/support/__init__.py +0 -0
  56. tinygrad/runtime/support/compiler_cuda.py +77 -0
  57. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
  58. tinygrad/runtime/support/elf.py +38 -0
  59. tinygrad/runtime/support/hcq.py +539 -0
  60. tinygrad/shape/shapetracker.py +40 -50
  61. tinygrad/shape/view.py +102 -63
  62. tinygrad/tensor.py +1109 -365
  63. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
  64. tinygrad-0.10.0.dist-info/RECORD +77 -0
  65. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
  66. tinygrad/codegen/linearizer.py +0 -528
  67. tinygrad/codegen/uops.py +0 -451
  68. tinygrad/engine/graph.py +0 -100
  69. tinygrad/renderer/assembly.py +0 -269
  70. tinygrad/shape/symbolic.py +0 -327
  71. tinygrad-0.9.1.dist-info/RECORD +0 -63
  72. /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
  73. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
  74. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
tinygrad/shape/view.py CHANGED
@@ -2,8 +2,9 @@ from __future__ import annotations
 import functools, operator, itertools, math
 from dataclasses import dataclass
 from typing import Tuple, List, Optional, Dict, Set, cast
-from tinygrad.helpers import prod, all_int, argsort
-from tinygrad.shape.symbolic import Node, NumNode, Variable, sint, sym_infer
+from tinygrad.dtype import dtypes
+from tinygrad.ops import resolve, UOp, Variable, sint, sym_infer, smax, smin
+from tinygrad.helpers import prod, all_int, argsort, flatten
 
 @functools.lru_cache(maxsize=None)
 def canonicalize_strides(shape:Tuple[sint, ...], strides:Tuple[sint, ...]) -> Tuple[sint, ...]:
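Note: tinygrad/shape/symbolic.py is deleted in this release (file 70 above); symbolic integers are now UOps exported from tinygrad.ops, which is where this file now gets Variable, sint, resolve, smax and smin. A minimal migration sketch (the vmin/vmax attributes replace the old Node.min/Node.max, per the size() hunk further down):

    # 0.9.1: from tinygrad.shape.symbolic import Variable, sint
    # 0.10.0: symbolic ints are UOps, re-exported from tinygrad.ops
    from tinygrad.ops import UOp, Variable, sint

    n = Variable("n", 1, 16)   # a bounded symbolic dimension
    assert isinstance(n, UOp)  # Variable now constructs a UOp
    print(n.vmin, n.vmax)      # 1 16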
@@ -20,7 +21,7 @@ def _merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tu
   # merge contiguous sub-parts or zero strided dims. ret = Tuple[(merged_size, stride, merged size w/o zero stride), ...]
   if not shape: return ()
   assert len(shape) == len(strides) and (mask is None or len(shape) == len(mask))
-  ret = [(shape[0], strides[0], shape[0] if strides[0] else 0)]
+  ret = [(shape[0], strides[0], shape[0] if strides[0] != 0 else 0)]
   # merge this dim to next dim if size is 1
   merging = (mask[0][1] - mask[0][0] == 1) if mask is not None else shape[0] == 1
   for i, (s, st) in enumerate(zip(shape[1:], strides[1:]), start=1):
@@ -28,21 +29,24 @@ def _merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tu
     # always merge 1
     if s == 1: continue
     # merge last dim with this dim if merging or strides matched
-    if merging or last_st == s * st: ret[-1] = (last_s * s, st, (s if merging else last_pre_expand_s * s) if st else 0)
-    else: ret.append((s, st, s if st else 0))
+    if merging or last_st == s * st: ret[-1] = (last_s * s, st, (s if merging else last_pre_expand_s * s) if st != 0 else 0)
+    else: ret.append((s, st, s if st != 0 else 0))
     # merge this dim to next dim if size is 1
     merging = (mask[i][1] - mask[i][0] == 1) if mask is not None else s == 1
   return tuple(ret)
 
 @functools.lru_cache(maxsize=None)
-def _reshape_mask(view: View, new_shape:Tuple[sint, ...]) -> Tuple[Optional[Tuple[Tuple[sint, sint], ...]], bool]:
-  if view.mask is None: return view.mask, False
-  if any(not isinstance(m[0], int) or not isinstance(m[1], int) for m in view.mask): return view.mask, True
-  new_mask: List[Tuple[int, int]] = []
+def _reshape_mask(_mask:Optional[Tuple[Tuple[sint, sint], ...]], old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) \
+  -> Optional[Tuple[Tuple[sint, sint], ...]]:
+  """Returns the new mask if reshape is possible, and None if not possible."""
+  if _mask is None: return tuple((0, s) for s in new_shape)
+  if any(not isinstance(m[0], int) or not isinstance(m[1], int) for m in _mask): return None
+  if any(m[1] - m[0] < 1 for m in _mask): return ((0, 0),) * len(new_shape)  # zero mask
 
-  r_masks, r_shape, r_new_shape = reversed(view.mask), reversed(view.shape), reversed(new_shape)
+  new_mask: List[Tuple[int, int]] = []
+  # _mask is all int here
+  r_masks, r_shape, r_new_shape = reversed(cast(Tuple[Tuple[int, int], ...], _mask)), reversed(old_shape), reversed(new_shape)
   curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1))
-  if mask[1] - mask[0] < 1: return ((0, 0),) * len(new_shape), False # invalid mask
 
   while len(new_mask) < len(new_shape):
     (l, r), next_stride = mask, new_dim * curr_stride
@@ -51,34 +55,35 @@ def _reshape_mask(view: View, new_shape:Tuple[sint, ...]) -> Tuple[Optional[Tupl
       if old_dim == next_stride: # simply copy the mask and get next batch for merging
         new_mask.append((l // curr_stride, (r - 1) // curr_stride + 1))
         curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1))
-        if mask[1] - mask[0] < 1: return ((0, 0),) * len(new_shape), False # invalid mask
 
       else: # mask can only be splitted if reshape doesn't cut across the mask.
         if (((l % next_stride != 0 or r % next_stride != 0) and l // next_stride != (r - 1) // next_stride)
-            or old_dim % next_stride != 0): return view.mask, True
+            or old_dim % next_stride != 0): return None
         new_mask.append((l % next_stride // curr_stride, (r - 1) % next_stride // curr_stride + 1))
         curr_stride, new_dim = next_stride, next(r_new_shape, 1) # need to get mask for next dimension
 
     else:
       next_mask = next(r_masks, (0, 1))
       # combine if the mask can unfold continuously
-      if mask != (0, old_dim) and next_mask[1] - next_mask[0] != 1: return view.mask, True
+      if mask != (0, old_dim) and next_mask[1] - next_mask[0] != 1: return None
      mask, old_dim = (next_mask[0] * old_dim + l, (next_mask[1] - 1) * old_dim + r), old_dim * next(r_shape, 1)
 
   for mask in r_masks: # if the old shape has leading 1s, need to make sure their mask is (0,1)
-    if mask != (0, 1): return ((0, 0),) * len(new_shape), False # invalid mask
+    if mask != (0, 1): return ((0, 0),) * len(new_shape) # invalid mask
 
-  return tuple(reversed(new_mask)), False
+  return tuple(reversed(new_mask))
 
 def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]:
   strides = strides_for_shape(shape)
   result = []
   for stride in strides:
-    here = offs // stride if stride else 0
+    here = offs // stride if stride != 0 else 0
     result.append(here)
     offs -= here * stride
   return result
 
+def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.int, x) if isinstance(x, int) else x
+
 @dataclass(frozen=True)
 class View:
   shape:Tuple[sint, ...]
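Note: _reshape_mask changes contract here: it now takes the raw mask plus the old shape instead of a View, and returns Optional[mask] rather than a (mask, extra) pair, with None meaning the mask cannot be carried through the reshape. A sketch of the new convention (internal helper, imported here only for illustration; expected values traced from the logic above):

    from tinygrad.shape.view import _reshape_mask

    # mask selects columns 2..3 of a 4x4
    print(_reshape_mask(((0, 4), (2, 4)), (4, 4), (16,)))      # None: not contiguous after a flatten
    print(_reshape_mask(((0, 4), (2, 4)), (4, 4), (4, 2, 2)))  # ((0, 4), (1, 2), (0, 2)): the split lands on the mask boundary
    print(_reshape_mask(None, (4, 4), (16,)))                  # ((0, 16),): no mask now means the full-range mask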
@@ -87,16 +92,33 @@ class View:
   mask:Optional[Tuple[Tuple[sint, sint], ...]]
   contiguous:bool
 
+  @functools.cached_property
+  def t(self):
+    return tuple(x.tuplize if isinstance(x, UOp) else (x,) \
+                 for x in self.shape+self.strides+(self.offset,)+(tuple(flatten(self.mask)) if self.mask is not None else tuple()))
+  def __lt__(self, o:View): return self.t < o.t
+
+  def to_indexed_uops(self:View, _idxs:Optional[List[UOp]]=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> Tuple[UOp, UOp]:
+    idxs = [UOp.range(dtypes.int, 0, s, i) for i,s in enumerate(self.shape)] if _idxs is None else _idxs
+    iexpr = variable_to_uop(self.offset)
+    for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else [None]*len(self.shape)):
+      if resolve(sh != 1) and resolve(st != 0): iexpr = iexpr + idx*st
+      if m is not None:
+        if resolve(m[0] != 0): vexpr = vexpr * idx.ge(m[0])
+        if resolve(m[1] != sh): vexpr = vexpr * idx.lt(m[1])
+    return iexpr, vexpr
+
   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
   def size(self) -> int:
-    # NOTE: Variable and the Node derived from it in symbolic shapes can only have int as max.
-    ret = prod([x.max if isinstance(x, Node) else x for x in self.shape])
+    ret = prod([x.vmax if isinstance(x, UOp) else x for x in self.shape])
     assert isinstance(ret, int), f"{ret=} is not int"
     return ret
 
   @staticmethod
   @functools.lru_cache(maxsize=None)
   def create(shape:Tuple[sint, ...], strides:Optional[Tuple[sint, ...]]=None, offset:sint=0, mask:Optional[Tuple[Tuple[sint, sint], ...]]=None):
+    # TODO: this resolve shouldn't be needed
+    if not all(resolve(s >= 0) for s in shape): raise ValueError(f"Trying to create View with negative dimension: {shape=}")
     strides = canonicalize_strides(shape, strides) if strides else strides_for_shape(shape)
     # canonicalize 0 in shape
     if 0 in shape: return View(shape, (0,) * len(shape), offset=0, mask=None, contiguous=True)
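Note: to_indexed_uops is new: it lowers a View directly to a pair of UOps, an index expression over per-dimension RANGEs and a boolean validity expression from the mask, consistent with the new codegen/lowerer.py and the removal of codegen/linearizer.py in the file list above. A small sketch, assuming the 0.10.0 UOp.render helper for printing:

    from tinygrad.shape.view import View

    v = View.create((4, 4)).pad(((1, 0), (0, 0)))  # shape (5,4), offset -4, mask ((1,5),(0,4))
    iexpr, vexpr = v.to_indexed_uops()
    # iexpr ~ (idx0*4 + idx1) - 4; vexpr ~ idx0 >= 1 (only the padded row is invalid)
    print(iexpr.render(), "|", vexpr.render())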
@@ -104,29 +126,36 @@ class View:
     if mask is not None and all(m == (0,s) for m,s in zip(mask, shape)): mask = None
     # if any dimension has size >1, but is masked such that only one index in the dimension is unmasked
     # then its stride can also be set to 0, albeit with a corresponding adjustment required to the offset
-    # TODO: assert comparison with LtNode to avoid mis-using symbolic
-    if mask and any(elim := [not (b+1 < e) for b,e in mask]):
-      if any(not (b < e) for b,e in mask):
+    if mask and any(elim := [not resolve(b+1 < e) for b,e in mask]):
+      if any(not resolve(b < e) for b,e in mask):
         strides, offset, mask = (0,) * len(shape), 0, ((0,0),) * len(shape)
       offset += sum((strides[i] * mask[i][0]) if e else 0 for i, e in enumerate(elim))
       strides = tuple(0 if e else st for st,e in zip(strides, elim))
+    # simplify as we go
+    if isinstance(offset, UOp): offset = cast(sint, offset.ssimplify())
+    shape = tuple(cast(sint, x.ssimplify()) if isinstance(x, UOp) else x for x in shape)
+    # TODO: enabling stride simplification breaks it
+    """
+    strides = tuple(x.ssimplify() if isinstance(x, UOp) else x for x in strides)
+    if mask: mask = tuple((s.ssimplify() if isinstance(s, UOp) else s, e.ssimplify() if isinstance(e, UOp) else e) for s,e in mask)
+    """
     contiguous = offset == 0 and mask is None and strides == strides_for_shape(shape)
     return View(shape, strides, offset, mask, contiguous)
 
   @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
   def vars(self) -> Set[Variable]:
     flatten_mask = tuple(x for m in self.mask for x in m) if self.mask is not None else tuple()
-    return functools.reduce(operator.or_, [x.vars() for x in self.shape+self.strides+(self.offset,)+flatten_mask if isinstance(x, Node)], set())
+    return functools.reduce(operator.or_, [x.vars() for x in self.shape+self.strides+(self.offset,)+flatten_mask if isinstance(x, UOp)], set())
 
   @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
   def unbind(self) -> Tuple[View, Dict[Variable, int]]:
-    var_unboundvar_val = [(v, v.unbind()) for v in self.vars() if v.val is not None]
+    var_unboundvar_val = [(v, v.unbind()) for v in self.vars()]
     unbound_vars = {v:uv for v,(uv,_) in var_unboundvar_val}
-    new_shape = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.shape])
-    new_strides = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.strides])
-    new_offset = self.offset if isinstance(self.offset, int) else self.offset.substitute(unbound_vars)
-    new_mask = tuple((a if isinstance(a, int) else a.substitute(unbound_vars),
-                      b if isinstance(b, int) else b.substitute(unbound_vars)) for (a, b) in self.mask) if self.mask is not None else None
+    def substitute(x): return x if isinstance(x, int) else x.substitute(unbound_vars)
+    new_shape = tuple(map(substitute, self.shape))
+    new_strides = tuple(map(substitute, self.strides))
+    new_offset = substitute(self.offset)
+    new_mask = tuple((substitute(x[0]), substitute(x[1])) for x in self.mask) if self.mask is not None else None
     return View.create(new_shape, new_strides, new_offset, new_mask), dict(x[1] for x in var_unboundvar_val)
 
   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
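Note: unbind no longer filters on v.val is not None: every Variable that reaches a View here is expected to be bound, and unbind() on a bound variable returns the (unbound_variable, value) pair that the new substitute closure consumes. A sketch of that pairing (bind/unbind as used elsewhere in 0.10.0):

    from tinygrad.ops import Variable

    vi = Variable("i", 1, 10).bind(3)  # a bound symbolic dim, e.g. from a Tensor shape
    var, val = vi.unbind()
    print(var, val)                    # the unbound Variable("i", 1, 10) and 3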
@@ -137,7 +166,7 @@ class View:
     if vm1.contiguous and vm1.size() == vm2.size() and (ret := vm2.reshape(vm1.shape)) is not None: return ret
     if vm1.mask:
       for b,e in vm1.mask:
-        if not (b < e): return View.create(vm1.shape, (0,) * len(vm1.shape), 0, ((0,0),) * len(vm1.shape))
+        if resolve(b >= e, False): return View.create(vm1.shape, (0,) * len(vm1.shape), 0, ((0,0),) * len(vm1.shape))
       return (merged := vm2 + vm1.shrink(vm1.mask)) and merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape)))
 
     # Project vm1's offset and strides on to vm2.
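Note: the recurring pattern in these hunks is that truthiness of symbolic comparisons is gone: not (b < e) on a Node used to yield a usable bool, but a UOp comparison is just a graph node, so resolve(expr, default) folds it to a constant when the variable bounds decide it and otherwise returns the explicit default. The False default above keeps the merge conservative: a view is only zeroed out when the mask is provably empty. A sketch:

    from tinygrad.ops import Variable, resolve

    b, e = Variable("b", 0, 3), Variable("e", 4, 8)
    print(resolve(b >= e, False))  # False: provably b < e from the bounds
    b, e = Variable("b", 0, 5), Variable("e", 4, 8)
    print(resolve(b >= e, False))  # False again, but only because it is undecidable and the default applies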
@@ -153,28 +182,30 @@ class View:
 
     # Merge dimensions in vm2 if required.
     # NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
-    idxs: List[Node] = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
-    merged_size, merged_term = 1, NumNode(0)
-    extents: List[Tuple[sint, Node]] = []
+    if not all_int(vm1.shape): return None
+    idxs: List[UOp] = [UOp.variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
+    merged_size, merged_term = 1, UOp.const(dtypes.int, 0)
+    extents: List[Tuple[sint, UOp]] = []
     for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)):
-      merged_term += Variable.sum([idxs[d1] * (s1 * merged_size) for d1, s1 in term]) + o * merged_size
+      merged_term += (sum([idxs[d1] * s1 for d1, s1 in term]) + o) * merged_size
      merged_size *= s
-      if not (merged_term >= merged_size) and not (merged_term < 0):
+      if resolve(merged_term < merged_size, False) and resolve(0 <= merged_term, False):
         extents.append((merged_size, merged_term))
-        merged_size, merged_term = 1, NumNode(0)
-    if merged_term: return None
+        merged_size, merged_term = 1, UOp.const(dtypes.int, 0)
+    if resolve(merged_term != 0): return None
     if (vm2_shape := tuple(s for s,_ in reversed(extents))) != vm2.shape:
-      return (reshaped_vm2 := vm2.reshape(vm2_shape)) and reshaped_vm2 + vm1
+      reshaped_vm2 = vm2.reshape(vm2_shape)
+      if reshaped_vm2 is None: return None
+      if reshaped_vm2.shape != vm2.shape: return reshaped_vm2 + vm1
 
     if vm2.mask:
       # Try to project vm2's mask on to vm1.
       newb, newe, bad = [0] * len(vm1.shape), list(vm1.shape), False
-      for d2, ((b, e), o, (_, t)) in enumerate(zip(vm2.mask, origin, reversed(extents))):
-        if not (t.min < b or t.max >= e): continue
-        if not isinstance(o, int) or not isinstance(b, int) or not isinstance(e, int):
+      for (b, e), o, term, (_, t) in zip(vm2.mask, origin, terms, reversed(extents)):
+        if resolve(b <= t.vmin and t.vmax < e, False): continue
+        if not all_int([o, b, e]):
           bad = True
           continue
-        term = terms[d2]
         if len(term) != 1:
           if not term and newe: newe[0] = 0
           else: bad = True
@@ -188,7 +219,7 @@ class View:
 
       # If any of vm1 was masked off, try again with that mask in place.
       for b, e, s in zip(newb, newe, vm1.shape):
-        if b != 0 or e != s:
+        if (b, e) != (0, s):
           return vm2 + View.create(vm1.shape, vm1.strides, vm1.offset, tuple(zip(newb, newe)))
       # Otherwise if vm2's mask was violated, then cannot merge.
       if bad: return None
@@ -211,17 +242,19 @@ class View:
     offset = sum([s * x[0] for s, x in zip(self.strides,arg)])
     if self.mask:
       # move the old mask
-      nmask = tuple([(max(0, min(mx-ax,ay-ax)), max(0, min(my-ax,ay-ax))) for (mx,my),(ax,ay) in zip(self.mask, arg)])
+      nmask = tuple([(smax(0, smin(mx-ax,ay-ax)), smax(0, smin(my-ax,ay-ax))) for (mx,my),(ax,ay) in zip(self.mask, arg)])
       # merge the masks if we have two
-      mask = tuple([(max(mx1, mx2), min(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask
+      mask = tuple([(smax(mx1, mx2), smin(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask
     shape = [y-x for x,y in arg]
     if mask is not None and all(m[0] == 0 and m[1] == s for m,s in zip(mask, shape)): mask = None
-    return View.create(tuple(s.b if isinstance(s, NumNode) else s for s in shape), self.strides, self.offset+offset, mask)
+    return View.create(tuple(s.ssimplify() if isinstance(s, UOp) else s for s in shape), self.strides, self.offset+offset, mask)
 
   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
   def pad(self, arg: Tuple[Tuple[sint, sint], ...]) -> View:
-    assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape), f"{self.shape=}, {arg=}"
-    if any(b or e for b, e in arg):
+    assert len(arg) == len(self.shape), f"invalid pad {arg} for {self.shape}"
+    # NOTE: not checking for symbolic arg
+    for b,e in arg: assert not all_int([b,e]) or b>=0 and e>=0, f"invalid pad {arg} for {self.shape}"
+    if any(resolve(b!=0) or resolve(e!=0) for b, e in arg):
       zvarg = tuple([(-b,s+e) for s,(b,e) in zip(self.shape, arg)])
       mask = tuple([(b,s+b) for s,(b,_) in zip(self.shape, arg)])
       return self.__unsafe_resize(zvarg, mask=mask)
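Note: __unsafe_resize switches from the builtin max/min to smax/smin for the mask arithmetic: with UOp operands the builtins cannot be trusted to order an int against a symbolic value, while smax/smin build a symbolic max/min (and can fold to one operand when the bounds already decide the answer). A sketch:

    from tinygrad.ops import Variable, smax, smin

    n = Variable("n", 1, 16)
    lo = smax(0, n - 4)  # stays symbolic: the sign of n-4 is not decided by the bounds
    hi = smin(n, 16)     # decidable: n.vmax <= 16, so this can fold to n
    print(lo, hi)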
@@ -229,7 +262,9 @@ class View:
 
   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
   def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> View:
-    assert all((0<=b<=e<=s) for s,(b,e) in zip(self.shape,arg)) and len(arg) == len(self.shape), f"invalid shrink {arg} for {self.shape}"
+    assert len(arg) == len(self.shape), f"invalid shrink {arg} for {self.shape}"
+    # NOTE: not checking for symbolic arg
+    for s,(b,e) in zip(self.shape,arg): assert not all_int([s,b,e]) or (0<=b<=e<=s), f"invalid shrink {arg} for {self.shape}"
     return self.__unsafe_resize(arg)
 
   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
@@ -238,9 +273,12 @@ class View:
     if 0 in self.shape:
       assert all((s == x == 0) or (s > 0 and (x % s) == 0) for s,x in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
       return View.create(new_shape)
-    assert all((s == x or (s == 1 and st == 0)) for s,x,st in zip(self.shape, new_shape, self.strides)), f"can't expand {self.shape} into {new_shape}"
+    # TODO: this resolve might be wrong
+    assert all((not resolve(s != x, False) or s == 1) for s,x in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
     # NOTE: can the mask ever be (0,0)?
-    mask = tuple([(((0,0) if m != (0,1) else (0,ns)) if s != ns else m) for m,s,ns in zip(self.mask, self.shape, new_shape)]) if self.mask else None
+    # TODO: this resolve may not be needed, but it's hard because vars need to be sorted
+    mask = tuple([(((0,0) if m != (0,1) else (0,ns)) if resolve(s != ns, False) else m) \
+                  for m,s,ns in zip(self.mask, self.shape, new_shape)]) if self.mask else None
     return View.create(new_shape, self.strides, self.offset, mask)
 
   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
@@ -264,14 +302,15 @@ class View:
   def reshape(self, new_shape: Tuple[sint, ...]) -> Optional[View]:
     if self.shape == new_shape: return self
 
-    assert all(x >= 0 for x in new_shape), f"shape can't contain negative numbers {new_shape}"
+    # TODO: this resolve shouldn't be needed
+    assert all(resolve(x >= 0) for x in new_shape), f"shape can't contain negative numbers {new_shape}"
     if 0 in self.shape:
       assert 0 in new_shape, f"cannot reshape 0 size to {new_shape}"
       return View.create(new_shape)
     # check for the same size
     if (self_all_int := all_int(self.shape)):
-      assert all(isinstance(s, (int, Variable)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
-      if prod(self.shape) != prod([s if isinstance(s, int) else cast(Variable,s).val for s in new_shape]):
+      assert all(isinstance(s, (int, UOp)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
+      if resolve(prod(self.shape) != prod(new_shape), False):
         raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")
 
     if new_shape == () and self.mask and any(mx==my for (mx,my) in self.mask): return None
@@ -286,7 +325,7 @@ class View:
         if isinstance(so, int):
          if si != so: raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
         else:
-          var_vals = {v: v.unbind()[1] for v in so.vars()}
+          var_vals = dict([v.unbind() for v in so.vars()])
           if si != sym_infer(so, var_vals): raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
       # all dimensions matched, return the new view directly
       return View(new_shape, self.strides, self.offset, self.mask, self.contiguous)
@@ -294,18 +333,18 @@ class View:
     strides, r_new_shape = [], reversed(new_shape)
     for merged_dim, new_stride, real_dim in reversed(_merge_dims(self.shape, self.strides, self.mask)):
       acc = 1
-      # TODO: this <= and != is for symbolic!?
-      while acc <= merged_dim and acc != merged_dim and (new_dim := next(r_new_shape, None)):
+      # TODO: third resolve shouldn't be needed
+      while resolve(acc <= merged_dim) and resolve(acc != merged_dim) and resolve((new_dim := next(r_new_shape, 0)) > 0):
         strides.append(new_stride)
-        if new_dim != 1: new_stride *= (new_dim if (acc := acc * new_dim) < real_dim else 0)
-      if acc != merged_dim: break
+        if resolve(new_dim != 1): new_stride *= (new_dim if resolve((acc := acc * new_dim) < real_dim) else 0)
+      if resolve(acc != merged_dim): break
     else:
       strides += [0,] * (len(new_shape) - len(strides))
-      new_mask, extra = _reshape_mask(self, new_shape)
-      if not extra:
-        new_strides = canonicalize_strides(tuple(e-b for b,e in new_mask) if new_mask else new_shape, tuple(reversed(strides)))
+      new_mask = _reshape_mask(self.mask, self.shape, new_shape)
+      if new_mask is not None:
+        new_strides = canonicalize_strides(tuple(e-b for b,e in new_mask), tuple(reversed(strides)))
         extra_offset = (sum(m[0] * s for m,s in zip(self.mask, self.strides)) if self.mask else 0) - \
-                       (sum(m[0] * s for m,s in zip(new_mask, new_strides)) if new_mask else 0)
+                       (sum(m[0] * s for m,s in zip(new_mask, new_strides)))
         return View.create(new_shape, new_strides, self.offset + extra_offset, new_mask)
 
     return None
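Note: the net effect of the _reshape_mask rewrite shows in this final hunk: reshape now asks for the new mask directly and, when _reshape_mask returns None, falls through to return None, leaving the caller (the ShapeTracker) to stack a second view instead. An end-to-end sketch using the same mask as the _reshape_mask example above:

    from tinygrad.shape.view import View

    v = View.create((4, 4), mask=((0, 4), (2, 4)))
    print(v.reshape((4, 2, 2)))  # a View with mask ((0,4),(1,2),(0,2)); the masked-to-one axis folds into offset/strides
    print(v.reshape((16,)))      # None: the mask can't survive the flatten, so both views must be kept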