tinygrad 0.8.0__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. tinygrad/__init__.py +6 -6
  2. tinygrad/codegen/__init__.py +0 -0
  3. tinygrad/codegen/kernel.py +253 -225
  4. tinygrad/codegen/linearizer.py +398 -436
  5. tinygrad/codegen/uops.py +451 -0
  6. tinygrad/device.py +268 -274
  7. tinygrad/dtype.py +56 -40
  8. tinygrad/engine/__init__.py +0 -0
  9. tinygrad/engine/graph.py +100 -0
  10. tinygrad/engine/jit.py +198 -0
  11. tinygrad/engine/realize.py +192 -0
  12. tinygrad/engine/schedule.py +370 -0
  13. tinygrad/engine/search.py +199 -0
  14. tinygrad/{mlops.py → function.py} +40 -32
  15. tinygrad/helpers.py +144 -46
  16. tinygrad/lazy.py +143 -242
  17. tinygrad/multi.py +173 -0
  18. tinygrad/nn/__init__.py +180 -9
  19. tinygrad/nn/datasets.py +8 -0
  20. tinygrad/nn/optim.py +106 -28
  21. tinygrad/nn/state.py +87 -19
  22. tinygrad/ops.py +104 -45
  23. tinygrad/renderer/__init__.py +65 -0
  24. tinygrad/renderer/assembly.py +269 -0
  25. tinygrad/renderer/cstyle.py +308 -210
  26. tinygrad/renderer/llvmir.py +119 -124
  27. tinygrad/runtime/__init__.py +0 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +13403 -0
  29. tinygrad/runtime/autogen/comgr.py +891 -0
  30. tinygrad/runtime/autogen/cuda.py +5923 -0
  31. tinygrad/runtime/autogen/hip.py +5909 -0
  32. tinygrad/runtime/autogen/hsa.py +5893 -0
  33. tinygrad/runtime/autogen/io_uring.py +1486 -0
  34. tinygrad/runtime/autogen/kfd.py +812 -0
  35. tinygrad/runtime/autogen/nv_gpu.py +33597 -0
  36. tinygrad/runtime/autogen/opencl.py +1795 -0
  37. tinygrad/runtime/driver/__init__.py +0 -0
  38. tinygrad/runtime/driver/hip_comgr.py +56 -0
  39. tinygrad/runtime/graph/__init__.py +0 -0
  40. tinygrad/runtime/graph/clang.py +39 -0
  41. tinygrad/runtime/graph/cuda.py +59 -54
  42. tinygrad/runtime/graph/hcq.py +187 -0
  43. tinygrad/runtime/graph/metal.py +37 -41
  44. tinygrad/runtime/ops_amd.py +550 -0
  45. tinygrad/runtime/ops_clang.py +16 -14
  46. tinygrad/runtime/ops_cuda.py +129 -37
  47. tinygrad/runtime/ops_disk.py +111 -43
  48. tinygrad/runtime/ops_gpu.py +52 -50
  49. tinygrad/runtime/ops_llvm.py +36 -56
  50. tinygrad/runtime/ops_metal.py +41 -24
  51. tinygrad/runtime/ops_npy.py +9 -0
  52. tinygrad/runtime/ops_nv.py +625 -0
  53. tinygrad/runtime/ops_python.py +208 -0
  54. tinygrad/shape/__init__.py +0 -0
  55. tinygrad/shape/shapetracker.py +46 -107
  56. tinygrad/shape/symbolic.py +99 -98
  57. tinygrad/shape/view.py +162 -45
  58. tinygrad/tensor.py +2492 -483
  59. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +1 -1
  60. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +31 -13
  61. tinygrad-0.9.1.dist-info/RECORD +63 -0
  62. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
  63. tinygrad/features/image.py +0 -93
  64. tinygrad/features/multi.py +0 -103
  65. tinygrad/features/search.py +0 -160
  66. tinygrad/graph.py +0 -106
  67. tinygrad/jit.py +0 -152
  68. tinygrad/realize.py +0 -50
  69. tinygrad/runtime/graph/hip.py +0 -24
  70. tinygrad/runtime/ops_cpu.py +0 -45
  71. tinygrad/runtime/ops_hip.py +0 -97
  72. tinygrad/runtime/ops_torch.py +0 -49
  73. tinygrad-0.8.0.dist-info/RECORD +0 -41
  74. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,208 @@
1
+ # pylint: disable=cell-var-from-loop
2
+ # a python uops emulator
3
+ # works to test the tensor cores, and all the uops in general
4
+ # this is the (living) definition of uops
5
+ from typing import Tuple, List, Optional, Any, Dict
6
+ import pickle, base64, itertools, time, struct
7
+ from tinygrad.dtype import DType, dtypes, ImageDType
8
+ from tinygrad.helpers import all_same, getenv, flatten
9
+ from tinygrad.device import Compiled, Compiler, Allocator
10
+ from tinygrad.codegen.uops import UOpGraph, UOps
11
+ from tinygrad.ops import BinaryOps, TernaryOps, exec_alu, truncate
12
+ from tinygrad.renderer import Renderer
13
+ from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, AMDRenderer
14
+
15
+ def _load(m, i):
16
+ if i < 0 or i >= len(m): raise IndexError(f"load out of bounds, size is {len(m)} and access is {i}")
17
+ return m[i]
18
+
19
def load(inp, j=0):
  """Gather one value per lane from (buffers, offsets[, gates, defaults]), reading at offset+j.

  With four operand lists, lanes whose gate is falsy yield their default and touch no memory;
  otherwise only the buffers and offsets are used and every lane reads.
  """
  if len(inp) != 4:
    bufs, offsets = inp[0], inp[1]
    return [_load(buf, off+j) for buf, off in zip(bufs, offsets)]
  bufs, offsets, gates, defaults = inp
  out = []
  for buf, off, gate, default in zip(bufs, offsets, gates, defaults):
    out.append(_load(buf, off+j) if gate else default)
  return out
22
+
23
+ def _store(m, i, v):
24
+ if i < 0 or i >= len(m): raise IndexError(f"store out of bounds, size is {len(m)}, access is {i}, value is {v}")
25
+ m[i] = v
26
+
27
class PythonProgram:
  """Interpret a pickled, linearized uop list in pure Python.

  All lanes of a warp run in lock step. A scalar uop value in `ul` is a list with one entry per
  lane; a vector-typed value (dtype.count > 1) is a list of dtype.count such per-lane lists.
  """
  def __init__(self, name:str, lib:bytes):
    # NOTE(review): pickle.loads can execute arbitrary code. `lib` is presumably the output of
    # PythonRenderer/PythonCompiler in this process -- never pass bytes from an untrusted source here.
    self.uops: List[Tuple[UOps, Optional[DType], List[int], Any]] = pickle.loads(lib)

  def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
    """Execute the program once per global index and return the elapsed time in seconds.

    bufs are bound in order to the DEFINE_GLOBAL uops; vals are consumed in order by DEFINE_VAR.
    `wait` is accepted for interface compatibility and is not used here.
    Raises IndexError on out-of-bounds loads/stores, NotImplementedError for an empty RANGE
    or an unknown tensor core.
    """
    st = time.perf_counter()
    # uops that produce no value; they are skipped when gathering a uop's inputs
    void_ops = {UOps.STORE, UOps.ENDRANGE, UOps.BARRIER, UOps.IF, UOps.ENDIF}
    # one tuple of local ids per lane; sizes are reversed, so SPECIAL reads axis a at position 2-a
    warp = list(itertools.product(*[range(x) for x in local_size[::-1]]))
    warp_size = len(warp)
    for idxs in itertools.product(*[range(x) for x in global_size[::-1]]):
      ul: Dict[int, Any] = {}    # uop index -> per-lane value(s)
      dl: Dict[int, DType] = {}  # uop index -> dtype
      pbufs: List[memoryview] = list(bufs)
      pvals: List[int] = list(vals)
      i = 0
      loop_ends: Dict[int, int] = {}  # RANGE index -> its ENDRANGE index, learned at the first ENDRANGE
      while i < len(self.uops):
        uop, dtype, idp, arg = self.uops[i]
        # only the initial value source of DEFINE_ACC is treated as an input here
        if uop is UOps.DEFINE_ACC: idp = [idp[0]]
        inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops]
        dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops]
        if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp)
        if uop is UOps.STORE:
          if len(inp) == 3: inp.append([True] * len(inp[0]))  # set the gate to True
          if isinstance(dtp[0], ImageDType):
            # image store: 4 channels per pixel, pixel (ox, oy) in a row-major image of width shape[1]
            assert dtp[2].count == 4
            for j,val in enumerate(inp[2]):
              for m,ox,oy,v,g in zip(inp[0], inp[1][0], inp[1][1], val, inp[3]):
                assert ox >= 0 and ox < dtp[0].shape[1] and oy >= 0 and oy < dtp[0].shape[0]
                if g: _store(m, ox*4 + oy*dtp[0].shape[1]*4 + j, v)
          elif dtp[2].count > 1:
            # vector store: element j goes to offset+j
            for j,val in enumerate(inp[2]):
              for m,o,v,g in zip(inp[0], inp[1], val, inp[3]):
                if g: _store(m, o+j, v)
          else:
            for m,o,v,g in zip(*inp):
              if g: _store(m, o, v)
          i += 1
          continue
        if uop is UOps.ENDRANGE:
          # remember where this loop ends, then jump back to its RANGE to increment/test the counter
          loop_ends[idp[0]] = i
          i = idp[0]
          continue
        if uop in (UOps.BARRIER, UOps.IF, UOps.ENDIF):
          # in the python emulator, the warp is always in sync
          i += 1
          continue
        assert dtype is not None, f"{uop} is missing a dtype"
        dl[i] = dtype
        if uop is UOps.DEFINE_GLOBAL:
          assert dtype.fmt is not None
          ul[i] = [pbufs.pop(0).cast(dtype.fmt)] * warp_size
        elif uop is UOps.DEFINE_LOCAL:
          # one buffer of arg[1] elements, shared by every lane of the warp
          assert dtype.fmt is not None
          lbuf = memoryview(bytearray(arg[1]*dtype.itemsize))
          ul[i] = [lbuf.cast(dtype.fmt)] * warp_size
        elif uop is UOps.DEFINE_VAR:
          ul[i] = [pvals.pop(0)] * warp_size
        elif uop is UOps.SPECIAL:
          if arg[1][0] == 'g':
            ul[i] = [idxs[2-arg[0]]] * warp_size
          elif arg[1][0] == 'l':
            ul[i] = [x[2-arg[0]] for x in warp]
        elif uop is UOps.CONST:
          ul[i] = [[arg] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg] * warp_size
        elif uop is UOps.DEFINE_ACC:
          ul[i] = [[inp[0][0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [inp[0][0]] * warp_size
        elif uop is UOps.RANGE:
          if i not in ul:
            # first entry: with start >= end the `==` exit test below could never fire after the first
            # increment, so the loop would spin forever; fail loudly instead
            if inp[0][0] >= inp[1][0]:
              raise NotImplementedError(f"empty RANGE is not supported by the python emulator: start {inp[0][0]}, end {inp[1][0]}")
            ul[i] = [inp[0][0]] * warp_size
          else:
            for j in range(len(ul[i])):
              ul[i][j] += 1
            if ul[i][0] == inp[1][0]:
              # loop finished: drop the counter so a later re-entry starts fresh, then skip past ENDRANGE
              del ul[i]
              i = loop_ends[i] + 1
              continue
        elif uop in (UOps.CAST, UOps.BITCAST):
          if dtype.count > 1: ul[i] = inp
          else:
            assert dtp[0].fmt and dtype.fmt
            pack_format, unpack_format = str(warp_size) + dtp[0].fmt, str(warp_size) + dtype.fmt
            if uop is UOps.BITCAST: ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
            else:
              casted = [dtypes.as_const(x, dtype) for x in inp[0]]
              if dtypes.is_int(dtype):
                # wrap into the target bit width (two's complement range for signed types)
                overflow_adjust = 2**(dtype.itemsize*8 - 1) if not dtypes.is_unsigned(dtype) else 0
                casted = [((x + overflow_adjust) % 2**(dtype.itemsize*8) - overflow_adjust) for x in casted]
              elif dtypes.is_float(dtype):
                casted = [truncate.get(dtype, lambda dt: dt)(x) for x in casted]
              # round-trip through struct so the values carry the target format's precision
              ul[i] = list(struct.unpack(unpack_format, struct.pack(unpack_format, *casted)))
        elif uop is UOps.LOAD:
          if isinstance(dtp[0], ImageDType):
            # image load: out-of-bounds pixels read as 0 rather than raising
            assert dtype.count == 4
            ul[i] = []
            for j in range(dtype.count):
              ret = []
              for m,ox,oy in zip(inp[0], inp[1][0], inp[1][1]):
                if ox < 0 or ox >= dtp[0].shape[1] or oy < 0 or oy >= dtp[0].shape[0]: ret.append(0)
                else: ret.append(_load(m, ox*4 + oy*dtp[0].shape[1]*4 + j))
              ul[i].append(ret)
          elif dtype.count > 1:
            # vector load: element j reads offset+j; vector operands contribute their j-th element
            ul[i] = [load([inp[k][j] if dtp[k].count > 1 else inp[k] for k in range(len(inp))], j) for j in range(dtype.count)]
          else:
            ul[i] = load(inp)
        elif uop is UOps.PHI:
          # overwrite the accumulator's list in place, so the DEFINE_ACC entry in `ul` sees the update
          for j in range(len(inp[0])): inp[0][j] = inp[1][j]
          ul[i] = inp[0]
        elif uop is UOps.GEP:
          ul[i] = inp[0][arg]
        elif uop is UOps.WMMA:
          # here are the models for the WMMA instruction on the different hardware
          def wmma_helper(WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_elem, b_elem, c_map):
            # validate operand sizes, then for every lane and C element add the K-long product
            # selected by the layout functions a_elem / b_elem / c_map
            assert len(inp[0]) == NUM_A, f"A must have {NUM_A} elements per thread"
            assert len(inp[1]) == NUM_B, f"B must have {NUM_B} elements per thread"
            assert len(inp[2]) == NUM_C, f"C must have {NUM_C} elements per thread"
            assert len(flatten(inp[0])) == NUM_A * warp_size, f"WMMA must have {NUM_A * warp_size} total elements for A in WMMA"
            assert len(flatten(inp[1])) == NUM_B * warp_size, f"WMMA must have {NUM_B * warp_size} total elements for B in WMMA"
            assert len(flatten(inp[2])) == NUM_C * warp_size, f"WMMA must have {NUM_C * warp_size} total elements for C in WMMA"
            assert warp_size > 0 and warp_size % WARP_THREADS == 0, f"must have multiples of {WARP_THREADS} warp threads"
            out = [inp[2][elem_idx][:] for elem_idx in range(NUM_C)]
            for goff in range(0, warp_size, WARP_THREADS):
              for lane_id in range(WARP_THREADS):
                for elem_idx in range(NUM_C):  # calculate new muls and add to acc
                  (c_i, c_j) = c_map(lane_id, elem_idx)
                  out[elem_idx][goff+lane_id] += sum(a_elem(inp[0], _k, c_j, goff) * b_elem(inp[1], c_i, _k, goff) for _k in range(K))
            return out

          # TODO: refactor these to a shared TensorCoreLayout in kernel.py
          if arg[5] == "METAL":
            # A (2 elements on 32 threads): row major
            def a_b_elem(x, i, j, goff): return x[(i%2)][goff+(i//2)%2+(j%4)*2+(i//4)*8+(j//4)*16]
            # (i, j), C, D (2 elements on 32 threads): row major same as A/B
            def c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4)
            ul[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
          elif arg[5] == "AMD":
            # A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15
            def a_elem(x, i, j, goff):
              assert x[i][goff+j] == x[i][goff+j+16], "warp elements not duplicated properly across lanes"
              return x[i][goff+j]
            # B (16 elements on 32 threads): row major, lane 16-32 == lane 0-15
            def b_elem(x, i, j, goff): return a_elem(x, j, i, goff)  # pylint: disable=arguments-out-of-order
            def c_map(lane, elem): return (lane%16, lane//16+elem*2)  # (i, j), C, D (8 elements on 32 threads): row major
            ul[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
          elif arg[5] == "CUDA":
            # A (8 elements on 32 threads)
            def a_elem(x, i, j, goff): return x[(i%2)+(j//8)*2+(i//8)*4][goff+((i//2)%4)+(j%8)*4]
            # B (4 elements on 32 threads)
            def b_elem(x, i, j, goff): return x[(j%2)+(j//8)*2][goff+(j//2)%4+(i)*4]
            # (i, j), C, D (4 elements on 32 threads)
            def c_map(lane, elem): return ((elem%2)+(lane%4)*2, (lane//4)+(elem//2)*8)
            ul[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
          else: raise NotImplementedError(f"unimplemented tensor core {arg}")
        elif uop is UOps.ALU:
          assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {arg}"
          assert all_same([dtype] + dtp) or arg in {BinaryOps.CMPNE, BinaryOps.CMPLT, TernaryOps.WHERE}, f"dtype mismatch on {arg}"
          ul[i] = [exec_alu(arg, dtype, p) for p in zip(*inp)]
        assert i in ul, (uop, dtype, idp, arg)
        i += 1
    return time.perf_counter() - st
186
+
187
class PythonRenderer(Renderer):
  """Renderer whose "source code" is a base64 string of the pickled uop list.

  Setting EMULATE_METAL, EMULATE_AMD or EMULATE_CUDA makes it report that device name and borrow
  that backend's tensor cores; when several are set, the later one in that order wins.
  """
  device = "PYTHON"
  def __init__(self):
    emulations = (("EMULATE_METAL", "METAL", MetalRenderer), ("EMULATE_AMD", "AMD", AMDRenderer), ("EMULATE_CUDA", "CUDA", CUDARenderer))
    for flag, emulated_device, backend in emulations:
      if getenv(flag): self.device, self.tensor_cores = emulated_device, backend.tensor_cores

  def render(self, name:str, uops:UOpGraph) -> str:
    def src_positions(u): return [uops.uops.index(s) for s in u.src]
    # each uop becomes (op, dtype, positions of its sources in uops.uops, arg)
    lowered = [(u.op, u.dtype, src_positions(u), u.arg) for u in uops]
    return base64.b64encode(pickle.dumps(lowered)).decode()
197
+
198
class PythonCompiler(Compiler):
  """Compiling only strips the base64 wrapping; the result is the raw pickled uop bytes."""
  def compile(self, src:str) -> bytes:
    lib = base64.b64decode(src)
    return lib
200
+
201
class PythonAllocator(Allocator):
  """Host-memory allocator: each buffer is a zero-filled bytearray exposed as a memoryview."""
  def _alloc(self, size, options):
    backing = bytearray(size)
    return memoryview(backing)

  def copyin(self, dest, src:memoryview):
    # overwrite the whole destination with the source contents
    dest[:] = src

  def copyout(self, dest:memoryview, src):
    # overwrite the whole host destination with the buffer contents
    dest[:] = src
205
+
206
class PythonDevice(Compiled):
  """Device that wires the pure-Python allocator, renderer, compiler and interpreter together."""
  def __init__(self, device:str):
    allocator = PythonAllocator()
    renderer = PythonRenderer()
    compiler = PythonCompiler()
    super().__init__(device, allocator, renderer, compiler, PythonProgram)
File without changes
@@ -1,70 +1,35 @@
1
1
  # ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
2
2
  from __future__ import annotations
3
- import functools, itertools, operator
4
3
  from dataclasses import dataclass
5
- from typing import Tuple, List, Optional, Dict, Set, cast, Union, Iterable
6
- from tinygrad.ops import MovementOps
7
- from tinygrad.helpers import prod, DEBUG, merge_dicts, getenv
8
- from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode, sint
9
- from tinygrad.shape.view import View, _merge_dims
10
-
11
- def expr_node_mask(view:View, idx:Node, valid:Optional[Node]=None) -> Node:
12
- expr = [valid] if valid is not None else []
13
- if view.mask is not None:
14
- acc = 1
15
- for d,(x,y) in zip(reversed(view.shape), reversed(view.mask)):
16
- if (x,y) != (0,d):
17
- base = ((idx//acc)%d)
18
- expr += [base >= x, base < y]
19
- acc *= d
20
- return Node.ands(expr)
21
-
22
- # generate an expression if you have a single idx variable
23
- def expr_node(view:View, idx:Optional[Node]=None) -> Node:
24
- if idx is None: idx = Variable('idx', 0, prod(view.shape)-1)
25
- ret: List[Node] = [NumNode(view.offset) if isinstance(view.offset, int) else view.offset] if view.offset else []
26
- acc = 1
27
- for d,s,_ in reversed(_merge_dims(view.shape, view.strides)):
28
- ret.append(((idx//acc)%d)*s)
29
- acc *= d
30
- return Node.sum(ret)
31
-
32
- # generate an expression if you have a variable or expression for each index
33
- def expr_idxs(view:View, idxs:Tuple[Node, ...]) -> Node:
4
+ from typing import Tuple, List, Optional, Dict, Set, Iterable, cast
5
+ from tinygrad.helpers import merge_dicts, getenv
6
+ from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode, create_lt_node, create_ge_node, sint
7
+ from tinygrad.shape.view import View, strides_for_shape
8
+
9
+ def _expr_view(view:View, idxs:List[Node], valid:Optional[Node]=None) -> Tuple[Node, Node]:
34
10
  assert len(idxs) == len(view.shape), f"need an idx for all dimensions {idxs} vs {view.shape}"
35
- return Node.sum([NumNode(view.offset) if isinstance(view.offset, int) else view.offset] + [idx*st for idx,sh,st in zip(idxs, view.shape, view.strides) if sh != 1 and st != 0]) # noqa: E501
36
-
37
- @functools.lru_cache(maxsize=None)
38
- def merge_views(vm2:View, vm1:View) -> Optional[View]:
39
- if vm1.contiguous and vm1.shape == vm2.shape: return vm2
40
- if vm2.contiguous: return vm1
41
- if vm2.mask or vm1.offset != 0: return None # this isn't supported yet
42
- if None in (strides := ShapeTracker((vm2, vm1)).real_strides()): return None
43
- return View.create(vm1.shape, cast(Tuple[sint, ...], strides), vm2.offset, vm1.mask)
44
-
45
- @functools.lru_cache(maxsize=None)
46
- def idxs_to_idx(shape:Tuple[int, ...], idxs:Tuple[Node, ...]) -> Node:
47
- assert len(idxs) == len(shape), "need an idx for all dimensions"
48
- acc, ret = 1, []
49
- for tidx,d in zip(reversed(idxs), reversed(shape)):
50
- ret.append(tidx * acc)
51
- acc *= d
52
- return Node.sum(ret)
11
+ iexpr: List[Node] = [NumNode(view.offset) if isinstance(view.offset, int) else view.offset]
12
+ vexpr: List[Node] = [valid] if valid is not None else []
13
+ for idx,sh,st,m in zip(idxs, view.shape, view.strides, view.mask if view.mask is not None else [None]*len(view.shape)):
14
+ if sh != 1 and st != 0: iexpr.append(idx*st)
15
+ if m is not None: vexpr += [create_ge_node(idx, m[0]), create_lt_node(idx, m[1])] # idx >= m[0], idx < m[1]
16
+ return Node.sum(iexpr), Node.ands(vexpr)
53
17
 
54
18
  @dataclass(frozen=True)
55
19
  class ShapeTracker:
56
20
  views: Tuple[View, ...]
57
- def __post_init__(self):
58
- assert isinstance(self.views, tuple) and all(isinstance(v, View) for v in self.views), "ShapeTracker must be created with a tuple of Views"
59
21
 
60
22
  def __add__(self, st:ShapeTracker) -> ShapeTracker:
61
- base = ShapeTracker(self.views)
62
- for v in st.views: base = ShapeTracker(base.views + (v,)).simplify() # one view at a time = better simplification
63
- return base
23
+ ret = self
24
+ for v in st.views: ret = ShapeTracker(ret.views + (v,)).simplify() # one view at a time = better simplification
25
+ return ret
64
26
 
65
27
  def invert(self, out_shape:Tuple[sint, ...]) -> Optional[ShapeTracker]:
66
- ret = tuple(v.invert(s) for v,s in zip(self.views[::-1], [x.shape for x in self.views[::-1][1:]]+[out_shape]))
67
- return ShapeTracker(cast(Tuple[View, ...], ret)).reshape(out_shape) if all(x is not None for x in ret) else None
28
+ inverted_views:List[View] = []
29
+ for v,s in zip(self.views[::-1], [x.shape for x in self.views[::-1][1:]]+[out_shape]):
30
+ if (inverted:= v.invert(s)) is None: return None
31
+ inverted_views.append(inverted)
32
+ return ShapeTracker(tuple(inverted_views)).reshape(out_shape)
68
33
 
69
34
  @staticmethod
70
35
  def from_shape(shape:Tuple[sint, ...]): return ShapeTracker((View.create(shape),))
@@ -72,16 +37,22 @@ class ShapeTracker:
72
37
  @property
73
38
  def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous
74
39
 
40
+ @property
41
+ def consecutive(self) -> bool: return len(self.views) == 1 and (v:=self.views[0]).mask is None and v.strides == strides_for_shape(v.shape)
42
+
75
43
  @property
76
44
  def shape(self) -> Tuple[sint, ...]: return self.views[-1].shape
77
45
 
78
46
  @property
79
- def size(self) -> int: return prod([x.max if isinstance(x, Node) else x for x in self.views[-1].shape])
47
+ def size(self) -> int: return self.views[-1].size()
80
48
 
81
49
  def real_size(self) -> int:
82
50
  if 0 in self.shape: return 0
83
- ret = self.expr_idxs()[0].max
84
- while not isinstance(ret, int): ret = ret.max # TODO: this is a while loop?!? it should be more clear what max does
51
+ idx, valid = self.expr_idxs()
52
+ if not valid: return 0
53
+ # TODO: it's possible that the real_size is smaller condition on valid being true
54
+ ret = idx.max
55
+ if not isinstance(ret, int): ret = ret.max # might be represent by symbolic shape, take one more max for int max
85
56
  assert isinstance(ret, int), f"ret must be integer, {ret=} isn't"
86
57
  return ret+1
87
58
 
@@ -90,30 +61,9 @@ class ShapeTracker:
90
61
  @property
91
62
  def var_vals(self) -> Dict[Variable, int]: return merge_dicts([dict([v.unbind()]) for v in self.vars()])
92
63
 
93
- def unbind(self) -> ShapeTracker: return ShapeTracker(tuple(v.unbind() for v in self.views))
94
-
95
- def to_movement_ops(self) -> List[Tuple[MovementOps, Tuple]]:
96
- to_apply:List[Tuple[MovementOps, Tuple]] = []
97
- for v in self.views:
98
- real_shape = tuple(y-x for x,y in v.mask) if v.mask else v.shape
99
- real_offset = 0 if 0 in real_shape else (v.offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0))
100
- # first, we apply the offset
101
- # then, we make it the correct shape
102
- # then, we apply permutations
103
- to_apply.append((MovementOps.AS_STRIDED, (tuple([s if st != 0 else 1 for s,st in zip(real_shape, v.strides)]), v.strides, real_offset)))
104
- # then, we apply pre expand pads
105
- if v.mask is not None:
106
- pre_expand_pads = tuple((x,s-y) if st != 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides))
107
- post_expand_pads = tuple((x,s-y) if st == 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides))
108
- if any(x != (0,0) for x in pre_expand_pads):
109
- to_apply.append((MovementOps.PAD, pre_expand_pads))
110
- real_shape = tuple(x+s[0]+s[1] for x,s in zip(real_shape, pre_expand_pads))
111
- # then, we do any expands
112
- # NOTE: this is a good idea even without masks, since torch doesn't support negative strides and has to make a copy
113
- if any(s != 1 and st == 0 for s,st in zip(real_shape, v.strides)): to_apply.append((MovementOps.EXPAND, real_shape))
114
- # lastly, we apply post expand pads
115
- if v.mask is not None and any(x != (0,0) for x in post_expand_pads): to_apply.append((MovementOps.PAD, post_expand_pads))
116
- return to_apply
64
+ def unbind(self) -> Tuple[ShapeTracker, Dict[Variable, int]]:
65
+ unbound_views, var_vals = zip(*[v.unbind() for v in self.views])
66
+ return ShapeTracker(tuple(unbound_views)), merge_dicts(var_vals)
117
67
 
118
68
  # NOTE: if a stride is not always valid, it will be None
119
69
  def real_strides(self, ignore_valid=False) -> Tuple[Optional[sint], ...]:
@@ -124,7 +74,7 @@ class ShapeTracker:
124
74
  bad_idx_vars: Set[Variable] = set()
125
75
  for this_dim in (idx.nodes if isinstance(idx, SumNode) else [idx]):
126
76
  idx_maybe, stride_maybe = (this_dim.a, this_dim.b) if isinstance(this_dim, MulNode) else (this_dim, 1)
127
- try: ret[idxs.index(idx_maybe)] = stride_maybe
77
+ try: ret[idxs.index(idx_maybe)] = cast(sint, stride_maybe)
128
78
  except ValueError: bad_idx_vars = bad_idx_vars.union(idx_maybe.vars())
129
79
  idx_vars, valid_vars = idx.vars(), valid.vars()
130
80
  for i,tidx in enumerate(idxs):
@@ -134,30 +84,27 @@ class ShapeTracker:
134
84
 
135
85
  def unit_stride_axes(self, ignore_valid=False) -> List[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]
136
86
 
137
- def _expr_idx(self, idx:Node, valid:Node) -> Tuple[Node, Node]:
138
- for v in reversed(self.views[0:-1]):
87
+ def expr_idxs(self, idxs:Optional[Iterable[Node]]=None) -> Tuple[Node, Node]:
88
+ idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)] if idxs is None else list(idxs)
89
+ idx, valid = _expr_view(self.views[-1], idxs)
90
+ for view in reversed(self.views[0:-1]):
139
91
  if valid.max == 0: return NumNode(-1), valid
140
- valid = expr_node_mask(v, idx, valid)
141
- idx = expr_node(v, idx)
92
+ view = view.minify()
93
+ acc, idxs = 1, []
94
+ for d in reversed(view.shape):
95
+ idxs.append((idx//acc)%d)
96
+ acc *= d
97
+ idx, valid = _expr_view(view, idxs[::-1], valid)
98
+ assert not isinstance(idx.min, int) or idx.min >= -2**31, f"idx.min too small. {idx=}, {idx.min=}"
99
+ assert not isinstance(idx.max, int) or idx.max < 2**31, f"idx.max too big. {idx=}, {idx.max=}"
142
100
  return idx, valid
143
101
 
144
- def expr_idxs(self, idxs:Optional[Iterable[Node]]=None):
145
- if idxs is None: idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
146
- idx = expr_idxs(self.views[-1], tuple(idxs))
147
- valid = expr_node_mask(self.views[-1], idxs_to_idx(self.views[-1].shape, tuple(idxs)))
148
- return self._expr_idx(idx, valid)
149
-
150
- def expr_node(self, idx:Union[Node,str]='idx'):
151
- if isinstance(idx, str): idx = Variable(idx, 0, prod(self.shape)-1)
152
- return self._expr_idx(expr_node(self.views[-1], idx), expr_node_mask(self.views[-1], idx))
153
-
154
102
  def axis_is_masked(self, axis:int) -> bool:
155
103
  _, valid = self.expr_idxs()
156
104
  return f'idx{axis}' in [v.expr for v in valid.vars()]
157
105
 
158
106
  def simplify(self) -> ShapeTracker:
159
- if len(self.views) >= 2 and (new_view := merge_views(self.views[-2], self.views[-1])) is not None:
160
- if DEBUG >= 5: print(f"st simplify : {self.views[-2]} + {self.views[-1]} = {new_view}")
107
+ if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None:
161
108
  return ShapeTracker(self.views[:-2] + (new_view,)).simplify()
162
109
  return self
163
110
 
@@ -172,11 +119,3 @@ class ShapeTracker:
172
119
  def reshape(self, new_shape: Tuple[sint, ...]) -> ShapeTracker:
173
120
  if getenv("MERGE_VIEW", 1) and (new_view := self.views[-1].reshape(new_shape)) is not None: return ShapeTracker(self.views[0:-1] + (new_view,))
174
121
  return ShapeTracker(self.views + (View.create(new_shape), ))
175
-
176
- # returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
177
- # TODO: if we remove movementops from lazy.py we can delete this
178
- def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:
179
- acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
180
- try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new]
181
- except ValueError: return None
182
- return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]