tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. tinygrad/__init__.py +1 -1
  2. tinygrad/apps/llm.py +206 -0
  3. tinygrad/codegen/__init__.py +116 -0
  4. tinygrad/codegen/devectorizer.py +315 -172
  5. tinygrad/codegen/expander.py +8 -16
  6. tinygrad/codegen/gpudims.py +89 -0
  7. tinygrad/codegen/linearize.py +205 -203
  8. tinygrad/codegen/lowerer.py +92 -139
  9. tinygrad/codegen/opt/__init__.py +38 -0
  10. tinygrad/codegen/opt/heuristic.py +125 -0
  11. tinygrad/codegen/opt/kernel.py +510 -0
  12. tinygrad/{engine → codegen/opt}/search.py +51 -35
  13. tinygrad/codegen/opt/swizzler.py +134 -0
  14. tinygrad/codegen/opt/tc.py +127 -0
  15. tinygrad/codegen/quantize.py +67 -0
  16. tinygrad/device.py +122 -132
  17. tinygrad/dtype.py +152 -35
  18. tinygrad/engine/jit.py +81 -54
  19. tinygrad/engine/memory.py +46 -27
  20. tinygrad/engine/realize.py +82 -41
  21. tinygrad/engine/schedule.py +70 -445
  22. tinygrad/frontend/__init__.py +0 -0
  23. tinygrad/frontend/onnx.py +1253 -0
  24. tinygrad/frontend/torch.py +5 -0
  25. tinygrad/gradient.py +19 -27
  26. tinygrad/helpers.py +95 -47
  27. tinygrad/nn/__init__.py +7 -8
  28. tinygrad/nn/optim.py +72 -41
  29. tinygrad/nn/state.py +37 -23
  30. tinygrad/renderer/__init__.py +40 -60
  31. tinygrad/renderer/cstyle.py +143 -128
  32. tinygrad/renderer/llvmir.py +113 -62
  33. tinygrad/renderer/ptx.py +50 -32
  34. tinygrad/renderer/wgsl.py +27 -23
  35. tinygrad/runtime/autogen/am/am.py +5861 -0
  36. tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
  37. tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
  38. tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
  39. tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
  40. tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
  41. tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
  42. tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
  43. tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
  44. tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
  45. tinygrad/runtime/autogen/comgr.py +35 -9
  46. tinygrad/runtime/autogen/comgr_3.py +906 -0
  47. tinygrad/runtime/autogen/cuda.py +2419 -494
  48. tinygrad/runtime/autogen/hsa.py +57 -16
  49. tinygrad/runtime/autogen/ib.py +7171 -0
  50. tinygrad/runtime/autogen/io_uring.py +917 -118
  51. tinygrad/runtime/autogen/kfd.py +748 -26
  52. tinygrad/runtime/autogen/libc.py +613 -218
  53. tinygrad/runtime/autogen/libusb.py +1643 -0
  54. tinygrad/runtime/autogen/nv/nv.py +8602 -0
  55. tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
  56. tinygrad/runtime/autogen/opencl.py +2 -4
  57. tinygrad/runtime/autogen/sqtt.py +1789 -0
  58. tinygrad/runtime/autogen/vfio.py +3 -3
  59. tinygrad/runtime/autogen/webgpu.py +273 -264
  60. tinygrad/runtime/graph/cuda.py +3 -3
  61. tinygrad/runtime/graph/hcq.py +68 -29
  62. tinygrad/runtime/graph/metal.py +29 -13
  63. tinygrad/runtime/graph/remote.py +114 -0
  64. tinygrad/runtime/ops_amd.py +537 -320
  65. tinygrad/runtime/ops_cpu.py +108 -7
  66. tinygrad/runtime/ops_cuda.py +12 -14
  67. tinygrad/runtime/ops_disk.py +13 -10
  68. tinygrad/runtime/ops_dsp.py +47 -40
  69. tinygrad/runtime/ops_gpu.py +13 -11
  70. tinygrad/runtime/ops_hip.py +6 -9
  71. tinygrad/runtime/ops_llvm.py +35 -15
  72. tinygrad/runtime/ops_metal.py +29 -19
  73. tinygrad/runtime/ops_npy.py +5 -3
  74. tinygrad/runtime/ops_null.py +28 -0
  75. tinygrad/runtime/ops_nv.py +306 -234
  76. tinygrad/runtime/ops_python.py +62 -52
  77. tinygrad/runtime/ops_qcom.py +28 -39
  78. tinygrad/runtime/ops_remote.py +482 -0
  79. tinygrad/runtime/ops_webgpu.py +28 -28
  80. tinygrad/runtime/support/am/amdev.py +114 -249
  81. tinygrad/runtime/support/am/ip.py +211 -172
  82. tinygrad/runtime/support/amd.py +138 -0
  83. tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
  84. tinygrad/runtime/support/compiler_cuda.py +8 -11
  85. tinygrad/runtime/support/elf.py +2 -1
  86. tinygrad/runtime/support/hcq.py +184 -97
  87. tinygrad/runtime/support/ib.py +172 -0
  88. tinygrad/runtime/support/llvm.py +3 -4
  89. tinygrad/runtime/support/memory.py +251 -0
  90. tinygrad/runtime/support/nv/__init__.py +0 -0
  91. tinygrad/runtime/support/nv/ip.py +581 -0
  92. tinygrad/runtime/support/nv/nvdev.py +183 -0
  93. tinygrad/runtime/support/system.py +170 -0
  94. tinygrad/runtime/support/usb.py +268 -0
  95. tinygrad/runtime/support/webgpu.py +18 -0
  96. tinygrad/schedule/__init__.py +0 -0
  97. tinygrad/schedule/grouper.py +119 -0
  98. tinygrad/schedule/kernelize.py +368 -0
  99. tinygrad/schedule/multi.py +231 -0
  100. tinygrad/shape/shapetracker.py +40 -46
  101. tinygrad/shape/view.py +88 -52
  102. tinygrad/tensor.py +968 -542
  103. tinygrad/uop/__init__.py +117 -0
  104. tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
  105. tinygrad/uop/mathtraits.py +169 -0
  106. tinygrad/uop/ops.py +1021 -0
  107. tinygrad/uop/spec.py +228 -0
  108. tinygrad/{codegen → uop}/symbolic.py +239 -216
  109. tinygrad/uop/upat.py +163 -0
  110. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
  111. tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
  112. tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
  113. tinygrad/viz/index.html +203 -403
  114. tinygrad/viz/js/index.js +718 -0
  115. tinygrad/viz/js/worker.js +29 -0
  116. tinygrad/viz/serve.py +224 -102
  117. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
  118. tinygrad-0.11.0.dist-info/RECORD +141 -0
  119. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
  120. tinygrad/codegen/kernel.py +0 -693
  121. tinygrad/engine/multi.py +0 -161
  122. tinygrad/ops.py +0 -1003
  123. tinygrad/runtime/ops_cloud.py +0 -220
  124. tinygrad/runtime/support/allocator.py +0 -94
  125. tinygrad/spec.py +0 -155
  126. tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
  127. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
  128. tinygrad/viz/perfetto.html +0 -178
  129. tinygrad-0.10.2.dist-info/RECORD +0 -99
  130. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
  131. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
tinygrad/shape/view.py CHANGED
@@ -1,23 +1,48 @@
  from __future__ import annotations
  import functools, operator, itertools
  from dataclasses import dataclass
- from typing import Optional, cast, Sequence
+ from typing import cast, Sequence
  from tinygrad.dtype import dtypes
- from tinygrad.ops import resolve, UOp, Variable, sint, sym_infer, smax, smin, sint_to_uop
+ from tinygrad.uop.ops import resolve, UOp, Variable, sint, sym_infer, smax, smin, sint_to_uop, Ops, ssimplify
  from tinygrad.helpers import prod, all_int, argsort, flatten, ceildiv

- @functools.lru_cache(maxsize=None)
+ # returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
+ def get_contraction(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) -> list[list[int]]|None:
+   acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
+   try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new]
+   except ValueError: return None
+   return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]
+
+ def get_contraction_with_reduce(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...], reduce_axis:tuple[int, ...]) -> list[list[int]]|None:
+   if (contraction:=get_contraction(old_shape, new_shape)) is None: return None
+   # contraction returns the 1s as right justified as possible
+   # normally this contraction is good, but sometimes the reduce dim is empty. borrow from the next one, leaving one
+   # this ensures there's always ones available in the reduce dimension. this is also a valid contraction
+   for i in range(len(contraction)):
+     if i in reduce_axis and len(contraction[i]) == 0:
+       take_from = i+1
+       while take_from < len(contraction) and len(contraction[take_from]) == 0:
+         assert new_shape[take_from] == 1
+         take_from += 1
+       if take_from == len(contraction) or new_shape[take_from] != 1: return None # nothing to take
+       for j in range(take_from, i, -1):
+         assert len(contraction[j]) > 0
+         contraction[j-1] = contraction[j][:-1]
+         contraction[j] = contraction[j][-1:]
+   return contraction
+
+ @functools.cache
  def canonicalize_strides(shape:tuple[sint, ...], strides:tuple[sint, ...]) -> tuple[sint, ...]:
    return tuple(0 if s == 1 else st for s, st in zip(shape, strides))

- @functools.lru_cache(maxsize=None)
+ @functools.cache
  def strides_for_shape(shape:tuple[sint, ...]) -> tuple[sint, ...]:
    if not shape: return ()
    strides = tuple(itertools.accumulate(reversed(shape[1:]), operator.mul, initial=1))[::-1]
    return canonicalize_strides(shape, strides)

- @functools.lru_cache(maxsize=None)
- def merge_dims(shape:tuple[int, ...], strides:tuple[int, ...], mask:Optional[tuple[tuple[int, int], ...]]=None) -> tuple[tuple[int, int, int], ...]:
+ @functools.cache
+ def merge_dims(shape:tuple[int, ...], strides:tuple[int, ...], mask:tuple[tuple[int, int], ...]|None=None) -> tuple[tuple[int, int, int], ...]:
    # merge contiguous sub-parts or zero strided dims
    # any stride 0, masked from dim=1, or contiguous part is merged into next dim.
    # stride != 0 to stride == 0 starts a new merging block
@@ -38,9 +63,9 @@ def merge_dims(shape:tuple[int, ...], strides:tuple[int, ...], mask:Optional[tup
      merging = (mask[i][1] - mask[i][0] == 1) if mask is not None else s == 1
    return tuple(ret)

- @functools.lru_cache(maxsize=None)
- def _reshape_mask(_mask:Optional[tuple[tuple[sint, sint], ...]], old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) \
-     -> Optional[tuple[tuple[sint, sint], ...]]:
+ @functools.cache
+ def _reshape_mask(_mask:tuple[tuple[sint, sint], ...]|None, old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) \
+     -> tuple[tuple[sint, sint], ...]|None:
    """Returns the new mask if reshape is possible, and None if not possible."""
    if _mask is None: return tuple((0, s) for s in new_shape)
    if not all_int(flatten(_mask)): return None
@@ -51,7 +76,7 @@ def _reshape_mask(_mask:Optional[tuple[tuple[sint, sint], ...]], old_shape:tuple
    curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1))

    while len(new_mask) < len(new_shape):
-     (l, r), next_stride = mask, new_dim * curr_stride
+     (l, r), next_stride = mask, ssimplify(new_dim * curr_stride)

      # need to split mask
      if old_dim == next_stride: # simply copy the mask and get next batch for merging
@@ -66,7 +91,7 @@ def _reshape_mask(_mask:Optional[tuple[tuple[sint, sint], ...]], old_shape:tuple
        next_mask = next(r_masks, (0, 1))
        # combine if the mask can unfold continuously
        if mask != (0, old_dim) and l != r and next_mask[1] - next_mask[0] != 1: return None
-       mask, old_dim = (next_mask[0] * old_dim + l, (next_mask[1] - 1) * old_dim + r), old_dim * next(r_shape, 1)
+       mask, old_dim = (next_mask[0] * old_dim + l, (next_mask[1] - 1) * old_dim + r), ssimplify(old_dim * next(r_shape, 1))

    return tuple(reversed(new_mask))

@@ -84,12 +109,12 @@ class View:
    shape:tuple[sint, ...]
    strides:tuple[sint, ...]
    offset:sint
-   mask:Optional[tuple[tuple[sint, sint], ...]]
+   mask:tuple[tuple[sint, sint], ...]|None
    contiguous:bool

-   def to_indexed_uops(self:View, idxs:Optional[Sequence[UOp]]=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> tuple[UOp, UOp]:
+   def to_indexed_uops(self:View, idxs:Sequence[UOp]|None=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> tuple[UOp, UOp]:
      """(idx, valid)"""
-     if idxs is None: idxs = [UOp.range(dtypes.int, 0, s, i) for i,s in enumerate(self.shape)]
+     if idxs is None: idxs = [UOp.range(dtypes.int, s, i) for i,s in enumerate(self.shape)]
      iexpr = sint_to_uop(self.offset)
      for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else itertools.repeat(None)):
        if resolve(sh != 1) and resolve(st != 0): iexpr = iexpr + idx*st
@@ -98,16 +123,17 @@ class View:
          if resolve(m[1] != sh): vexpr = vexpr * (idx < m[1])
      return iexpr, vexpr

-   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+   @functools.cache # pylint: disable=method-cache-max-size-none
    def size(self) -> int:
      ret = prod([x.vmax if isinstance(x, UOp) else x for x in self.shape])
      assert isinstance(ret, int), f"{ret=} is not int"
      return ret

    @staticmethod
-   @functools.lru_cache(maxsize=None)
-   def create(shape:tuple[sint, ...], strides:Optional[tuple[sint, ...]]=None, offset:sint=0, mask:Optional[tuple[tuple[sint, sint], ...]]=None):
-     if not all(s >= 0 for s in shape): raise ValueError(f"Trying to create View with negative dimension: {shape=}")
+   @functools.cache
+   def create(shape:tuple[sint, ...], strides:tuple[sint, ...]|None=None, offset:sint=0, mask:tuple[tuple[sint, sint], ...]|None=None):
+     # TODO: resolve shouldn't be needed here
+     if not all(resolve(s >= 0) for s in shape): raise ValueError(f"Trying to create View with negative dimension: {shape=}")
      strides = canonicalize_strides(shape, strides) if strides else strides_for_shape(shape)
      # canonicalize 0 in shape
      if 0 in shape: return View(shape, (0,) * len(shape), offset=0, mask=None, contiguous=True)
@@ -131,43 +157,50 @@ class View:
      contiguous = offset == 0 and mask is None and strides == strides_for_shape(shape)
      return View(shape, strides, offset, mask, contiguous)

-   @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
+   @functools.cache # pylint: disable=method-cache-max-size-none
    def vars(self) -> set[Variable]:
      flatten_mask = tuple(x for m in self.mask for x in m) if self.mask is not None else tuple()
      return functools.reduce(operator.or_, [x.vars() for x in self.shape+self.strides+(self.offset,)+flatten_mask if isinstance(x, UOp)], set())

-   @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
+   @functools.cache # pylint: disable=method-cache-max-size-none
    def unbind(self) -> tuple[View, dict[Variable, int]]:
-     var_unboundvar_val = [(v, v.unbind()) for v in self.vars()]
+     var_unboundvar_val = [(v, v.unbind()) for v in self.vars() if v.op is Ops.BIND]
      unbound_vars = {v:uv for v,(uv,_) in var_unboundvar_val}
-     def substitute(x:sint): return x if isinstance(x, int) else x.substitute(unbound_vars)
-     new_shape = tuple(map(substitute, self.shape))
-     new_strides = tuple(map(substitute, self.strides))
-     new_offset = substitute(self.offset)
-     new_mask = tuple((substitute(x[0]), substitute(x[1])) for x in self.mask) if self.mask is not None else None
-     return View.create(new_shape, new_strides, new_offset, new_mask), dict(x[1] for x in var_unboundvar_val)
-
-   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
-   def __add__(self, vm1:View) -> Optional[View]:
+     return self.substitute(unbound_vars), dict(x[1] for x in var_unboundvar_val)
+
+   def substitute(self, dvars:dict[UOp, UOp]):
+     def _substitute(x:sint): return x if isinstance(x, int) else x.substitute(dvars)
+     new_shape = tuple(map(_substitute, self.shape))
+     new_strides = tuple(map(_substitute, self.strides))
+     new_offset = _substitute(self.offset)
+     new_mask = tuple((_substitute(x[0]), _substitute(x[1])) for x in self.mask) if self.mask is not None else None
+     return View.create(new_shape, new_strides, new_offset, new_mask)
+
+   @functools.cache # pylint: disable=method-cache-max-size-none
+   def __add__(self, vm1:View) -> View|None:
      vm2 = self
-     if vm2.contiguous: return vm1
+     if vm2.contiguous or vm1.size() == 0: return vm1
      if vm1.contiguous and vm1.shape == vm2.shape: return vm2
      if vm1.contiguous and vm1.size() == vm2.size() and (ret := vm2.reshape(vm1.shape)) is not None: return ret
      if vm1.mask:
        if (new_vm1 := vm1.shrink(vm1.mask)) == vm1 or (merged := vm2 + new_vm1) is None: return None
        return merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape)))
-     if not all_int(vm1.shape): return None
+     if not all_int(vm1.shape):
+       # if all strides are 0 and vm2 is unmasked, return vm1
+       if all(x == 0 for x in vm2.strides+vm1.strides) and vm2.mask is None: return vm1
+       # TODO: handle more cases
+       return None

      # Project vm1's offset and strides on to vm2.
-     origin = unravel(vm2.shape, vm1.offset)
+     origin = [ssimplify(o) for o in unravel(vm2.shape, vm1.offset)]
      terms: list[list[tuple[int, sint]]] = [[] for _ in vm2.shape]
      strides: list[sint] = [0] * len(vm1.shape)
      for d1, st in enumerate(vm1.strides):
        if st == 0: continue
        for d2, (o, s1) in enumerate(zip(origin, unravel(vm2.shape, vm1.offset + st))):
-         if (s1 := s1 - o) == 0: continue
+         if not resolve((s1 := s1 - o)!=0): continue # if s1 can possibly be 0
          terms[d2].append((d1, s1))
-         strides[d1] += s1 * vm2.strides[d2]
+         strides[d1] += ssimplify(s1 * vm2.strides[d2])

      # Merge dimensions in vm2 if required.
      # NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
@@ -190,14 +223,17 @@ class View:
      # Try to project vm2's mask on to vm1.
      newb, newe, bad = [0] * len(vm1.shape), list(vm1.shape), False
      for (b, e), o, term, (_, t) in zip(vm2.mask, origin, terms, reversed(extents)):
-       if resolve(b <= t.vmin and t.vmax < e, False): continue
+       if resolve(b <= (t := t.simplify()).vmin and t.vmax < e, False): continue
        if len(term) != 1:
-         if not term and newe: newe[0] = 0
+         if not term and newe:
+           # t should be a constant if no terms contribute to this dimension, but it might not be simplified
+           if t.vmin != t.vmax: return None
+           newe[0] = 0
          else: bad = True
          continue
        d1, s1 = term[0]
-       newb[d1] = max(newb[d1], ceildiv(b - o if s1 > 0 else e - o - 1, s1))
-       newe[d1] = min(newe[d1], (b - o if s1 < 0 else e - o - 1) // s1 + 1)
+       newb[d1] = smax(newb[d1], ceildiv(b - o if s1 > 0 else e - o - 1, s1))
+       newe[d1] = smin(newe[d1], (b - o if s1 < 0 else e - o - 1) // s1 + 1)

      # If any of vm1 was masked off, try again with that mask in place.
      if any((b, e) != (0, s) for b, e, s in zip(newb, newe, vm1.shape)):
@@ -205,16 +241,16 @@ class View:
      # Otherwise if vm2's mask was violated, then cannot merge.
      if bad: return None

-     return View.create(vm1.shape, tuple(strides), sum(o * s for o, s in zip(origin, vm2.strides)) + vm2.offset)
+     return View.create(vm1.shape, tuple(strides), ssimplify(sum(o * s for o, s in zip(origin, vm2.strides)) + vm2.offset))

-   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
-   def invert(self, out_shape:tuple[sint, ...]) -> Optional[View]:
+   @functools.cache # pylint: disable=method-cache-max-size-none
+   def invert(self, out_shape:tuple[sint, ...]) -> View|None:
      ret = View.create(self.shape)
      if self.mask: ret = ret.shrink(self.mask)
      ret = ret.flip(tuple(x < 0 for x in self.strides)).permute(argsort(tuple(-x if x > 0 else x for x in self.strides)))
      return ret if prod(ret.shape) == prod(out_shape) else None # don't support shrink, expand, or stride != (-1, 1)

-   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+   @functools.cache # pylint: disable=method-cache-max-size-none
    def minify(self):
      min_shape = tuple(x[0] for x in merge_dims(self.shape, self.strides, self.mask))
      return nv if (nv := self.reshape(min_shape)) else self
@@ -228,7 +264,7 @@ class View:
      mask = tuple([(smax(mx1, mx2), smin(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask
      return View.create(tuple([y-x for x,y in arg]), self.strides, self.offset+offset, mask)

-   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+   @functools.cache # pylint: disable=method-cache-max-size-none
    def pad(self, arg: tuple[tuple[sint, sint], ...]) -> View:
      assert len(arg) == len(self.shape), f"invalid pad {arg} for {self.shape}"
      # NOTE: not checking for symbolic arg
@@ -239,38 +275,38 @@ class View:
        return self.__unsafe_resize(zvarg, mask=mask)
      return self

-   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+   @functools.cache # pylint: disable=method-cache-max-size-none
    def shrink(self, arg: tuple[tuple[sint, sint], ...]) -> View:
      assert len(arg) == len(self.shape), f"invalid shrink {arg} for {self.shape}"
      # NOTE: not checking for symbolic arg
      for s,(b,e) in zip(self.shape,arg): assert not all_int([s,b,e]) or (0<=b<=e<=s), f"invalid shrink {arg} for {self.shape}"
      return self.__unsafe_resize(arg)

-   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+   @functools.cache # pylint: disable=method-cache-max-size-none
    def expand(self, new_shape: tuple[sint, ...]) -> View:
      if len(new_shape) != len(self.shape): raise ValueError(f"expand arg {new_shape=} must have same number of dimensions as shape {self.shape=}")
      # NOTE: does not check multiple of symbolic shape
      assert all(resolve(s == ns) or s == 1 for s,ns in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
      if 0 in self.shape: return View.create(new_shape)
-     # TODO: this resolve may not be needed, but it's hard because vars need to be sorted
-     mask = tuple([(((0,0) if m != (0,1) else (0,ns)) if resolve(s != ns, False) else m) \
+     # TODO: resolve may not be needed, but it's hard because vars need to be canonicalized
+     mask = tuple([(((0,0) if m != (0,1) else (0,ns)) if resolve(s != ns) and resolve(s == 1, False) else m) \
                    for m,s,ns in zip(self.mask, self.shape, new_shape)]) if self.mask else None
      return View.create(new_shape, self.strides, self.offset, mask)

-   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+   @functools.cache # pylint: disable=method-cache-max-size-none
    def permute(self, axis: tuple[int, ...]) -> View:
      assert sorted(axis) == list(range(len(self.shape))), f"invalid permutation {axis} of len {len(self.shape)}"
      return View.create(tuple(self.shape[a] for a in axis), tuple(self.strides[a] for a in axis), self.offset,
                         tuple(self.mask[a] for a in axis) if self.mask is not None else None)

-   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+   @functools.cache # pylint: disable=method-cache-max-size-none
    def flip(self, arg: tuple[bool, ...]) -> View:
      offset = sum((s-1)*z for s,z,f in zip(self.shape, self.strides, arg) if f)
      mask = tuple((s-my,s-mx) if f else (mx,my) for (mx,my),s,f in zip(self.mask, self.shape, arg)) if self.mask is not None else None
      return View.create(self.shape, tuple(-z if f else z for z,f in zip(self.strides, arg)), self.offset+offset, mask)

-   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
-   def reshape(self, new_shape: tuple[sint, ...]) -> Optional[View]:
+   @functools.cache # pylint: disable=method-cache-max-size-none
+   def reshape(self, new_shape: tuple[sint, ...]) -> View|None:
      if self.shape == new_shape: return self

      if not all(x >= 0 for x in new_shape): raise ValueError(f"shape can't contain negative numbers {new_shape}")
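For orientation, the new get_contraction helper added at the top of this file computes which groups of old axes fold into each new axis of a reshape. Below is a small standalone sketch of its behavior: the function body is copied from the hunk above (typed for plain ints only), and the example shapes are illustrative, not taken from the package.

import itertools, operator

# copied from the diff above, ints only for illustration
def get_contraction(old_shape, new_shape):
  acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
  try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new]
  except ValueError: return None
  return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]

# old axes (0,1) fold into new axis 0, old axis 2 maps to new axis 1
assert get_contraction((2,3,4), (6,4)) == [[0, 1], [2]]
# size-1 axes in the new shape get empty groups
assert get_contraction((2,3,4), (6,1,4)) == [[0, 1], [], [2]]
# (4,6) cuts across the (2,3) axis boundary, so no contraction exists
assert get_contraction((2,3,4), (4,6)) is None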
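The cache-decorator change visible throughout the diff (functools.lru_cache(maxsize=None) to functools.cache) is behavior-preserving; the two are equivalent unbounded caches. The stride helpers themselves are what View.create falls back to when no strides are passed. A quick sketch with integer shapes, again with the bodies copied from the hunk above and illustrative example values:

import functools, itertools, operator

@functools.cache
def canonicalize_strides(shape, strides):
  # size-1 dims always get stride 0, so equivalent views compare equal
  return tuple(0 if s == 1 else st for s, st in zip(shape, strides))

@functools.cache
def strides_for_shape(shape):
  if not shape: return ()
  # row-major: rightmost dim has stride 1, each dim to the left multiplies by the sizes to its right
  strides = tuple(itertools.accumulate(reversed(shape[1:]), operator.mul, initial=1))[::-1]
  return canonicalize_strides(shape, strides)

assert strides_for_shape((2, 3, 4)) == (12, 4, 1)
# the middle dim has size 1, so its stride is canonicalized to 0
assert strides_for_shape((2, 1, 4)) == (4, 0, 1)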