tinygrad-0.8.0-py3-none-any.whl → tinygrad-0.9.0-py3-none-any.whl

This diff compares the contents of two package versions as they were publicly released to their registry. It is provided for informational purposes only.
Files changed (71)
  1. tinygrad/__init__.py +6 -6
  2. tinygrad/codegen/kernel.py +230 -190
  3. tinygrad/codegen/linearizer.py +278 -384
  4. tinygrad/codegen/uops.py +415 -0
  5. tinygrad/device.py +132 -275
  6. tinygrad/dtype.py +53 -37
  7. tinygrad/engine/__init__.py +0 -0
  8. tinygrad/engine/graph.py +100 -0
  9. tinygrad/engine/jit.py +195 -0
  10. tinygrad/engine/realize.py +191 -0
  11. tinygrad/engine/schedule.py +362 -0
  12. tinygrad/engine/search.py +196 -0
  13. tinygrad/{mlops.py → function.py} +28 -14
  14. tinygrad/helpers.py +72 -43
  15. tinygrad/lazy.py +141 -240
  16. tinygrad/multi.py +169 -0
  17. tinygrad/nn/__init__.py +179 -8
  18. tinygrad/nn/datasets.py +7 -0
  19. tinygrad/nn/optim.py +106 -28
  20. tinygrad/nn/state.py +86 -17
  21. tinygrad/ops.py +70 -44
  22. tinygrad/renderer/__init__.py +61 -0
  23. tinygrad/renderer/assembly.py +276 -0
  24. tinygrad/renderer/cstyle.py +299 -206
  25. tinygrad/renderer/llvmir.py +118 -123
  26. tinygrad/runtime/autogen/amd_gpu.py +1900 -0
  27. tinygrad/runtime/autogen/comgr.py +865 -0
  28. tinygrad/runtime/autogen/cuda.py +5923 -0
  29. tinygrad/runtime/autogen/hip.py +5909 -0
  30. tinygrad/runtime/autogen/hsa.py +5761 -0
  31. tinygrad/runtime/autogen/kfd.py +812 -0
  32. tinygrad/runtime/autogen/nv_gpu.py +33328 -0
  33. tinygrad/runtime/autogen/opencl.py +1795 -0
  34. tinygrad/runtime/driver/hip_comgr.py +47 -0
  35. tinygrad/runtime/driver/hsa.py +143 -0
  36. tinygrad/runtime/graph/clang.py +38 -0
  37. tinygrad/runtime/graph/cuda.py +59 -54
  38. tinygrad/runtime/graph/hcq.py +143 -0
  39. tinygrad/runtime/graph/hsa.py +171 -0
  40. tinygrad/runtime/graph/metal.py +37 -41
  41. tinygrad/runtime/ops_amd.py +564 -0
  42. tinygrad/runtime/ops_clang.py +16 -14
  43. tinygrad/runtime/ops_cuda.py +130 -38
  44. tinygrad/runtime/ops_disk.py +45 -42
  45. tinygrad/runtime/ops_gpu.py +52 -50
  46. tinygrad/runtime/ops_hsa.py +278 -0
  47. tinygrad/runtime/ops_llvm.py +36 -56
  48. tinygrad/runtime/ops_metal.py +42 -24
  49. tinygrad/runtime/ops_npy.py +9 -0
  50. tinygrad/runtime/ops_nv.py +630 -0
  51. tinygrad/runtime/ops_python.py +204 -0
  52. tinygrad/shape/shapetracker.py +41 -105
  53. tinygrad/shape/symbolic.py +98 -95
  54. tinygrad/shape/view.py +137 -35
  55. tinygrad/tensor.py +2367 -442
  56. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
  57. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/METADATA +19 -9
  58. tinygrad-0.9.0.dist-info/RECORD +60 -0
  59. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
  60. tinygrad/features/image.py +0 -93
  61. tinygrad/features/multi.py +0 -103
  62. tinygrad/features/search.py +0 -160
  63. tinygrad/graph.py +0 -106
  64. tinygrad/jit.py +0 -152
  65. tinygrad/realize.py +0 -50
  66. tinygrad/runtime/graph/hip.py +0 -24
  67. tinygrad/runtime/ops_cpu.py +0 -45
  68. tinygrad/runtime/ops_hip.py +0 -97
  69. tinygrad/runtime/ops_torch.py +0 -49
  70. tinygrad-0.8.0.dist-info/RECORD +0 -41
  71. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
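
The headline change in this file list is the package reorganization: tinygrad/jit.py, tinygrad/realize.py, tinygrad/graph.py, and the tinygrad/features/ package are removed, their contents moving into the new tinygrad/engine/ package (with tinygrad/features/multi.py promoted to tinygrad/multi.py), and tinygrad/mlops.py is renamed to tinygrad/function.py. As a rough sketch of what the move means for downstream imports (the new paths come from the list above; it assumes symbols such as TinyJit and beam_search keep their names across the move, which the near-identical line counts suggest but this diff does not prove):

# tinygrad 0.8.0
from tinygrad.jit import TinyJit
from tinygrad.features.search import beam_search

# tinygrad 0.9.0: same symbols, relocated under tinygrad.engine
from tinygrad.engine.jit import TinyJit
from tinygrad.engine.search import beam_search

The two largest diffs reproduced below are the new pure-Python backend (tinygrad/runtime/ops_python.py) and the ShapeTracker rewrite (tinygrad/shape/shapetracker.py).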
tinygrad/runtime/ops_python.py
@@ -0,0 +1,204 @@
+# a python uops emulator
+# works to test the tensor cores, and all the uops in general
+# this is the (living) definition of uops
+from typing import Tuple, List, Optional, Any, Dict
+import pickle, base64, itertools, time, struct
+from tinygrad.dtype import DType, dtypes, ImageDType
+from tinygrad.helpers import all_same, getenv, flatten
+from tinygrad.device import Compiled, Compiler, Allocator
+from tinygrad.codegen.uops import UOpGraph, UOps
+from tinygrad.ops import BinaryOps, TernaryOps, exec_alu
+from tinygrad.renderer import Renderer
+from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, HIPRenderer
+
+def _load(m, i):
+  if i < 0 or i >= len(m): raise IndexError(f"load out of bounds, size is {len(m)} and access is {i}")
+  return m[i]
+
+def load(inp, j=0):
+  if len(inp) == 4: return [_load(m, x+j) if gate else default for m,x,gate,default in zip(*inp)]
+  else: return [_load(m, x+j) for m,x in zip(inp[0], inp[1])]
+
+def _store(m, i, v):
+  if i < 0 or i >= len(m): raise IndexError(f"store out of bounds, size is {len(m)}, access is {i}, value is {v}")
+  m[i] = v
+
+class PythonProgram:
+  def __init__(self, name:str, lib:bytes):
+    self.uops: List[Tuple[UOps, Optional[DType], List[int], Any]] = pickle.loads(lib)
+  def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
+    st = time.perf_counter()
+    warp = list(itertools.product(*[range(x) for x in local_size[::-1]]))
+    warp_size = len(warp)
+    for idxs in itertools.product(*[range(x) for x in global_size[::-1]]):
+      ul: Dict[int, Any] = {}
+      dl: Dict[int, DType] = {}
+      pbufs: List[memoryview] = list(bufs)
+      pvals: List[int] = list(vals)
+      i = 0
+      loop_ends: Dict[int, int] = {}
+      while i < len(self.uops):
+        uop, dtype, idp, arg = self.uops[i]
+        void_ops = {UOps.STORE, UOps.ENDRANGE, UOps.BARRIER, UOps.IF, UOps.ENDIF}
+        if uop is UOps.DEFINE_ACC: idp.clear()
+        inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops]
+        dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops]
+        if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp)
+        if uop is UOps.STORE:
+          if len(inp) == 3: inp.append([True] * len(inp[0])) # set the gate to True
+          if isinstance(dtp[0], ImageDType):
+            # image store
+            assert dtp[2].count == 4
+            for j,val in enumerate(inp[2]):
+              for m,ox,oy,v,g in zip(inp[0], inp[1][0], inp[1][1], val, inp[3]):
+                assert ox >= 0 and ox < dtp[0].shape[1] and oy >= 0 and oy < dtp[0].shape[0]
+                if g: _store(m, ox*4 + oy*dtp[0].shape[1]*4 + j, v)
+          elif dtp[2].count > 1:
+            for j,val in enumerate(inp[2]):
+              for m,o,v,g in zip(inp[0], inp[1], val, inp[3]):
+                if g: _store(m, o+j, v)
+          else:
+            for m,o,v,g in zip(*inp):
+              if g: _store(m, o, v)
+          i += 1
+          continue
+        elif uop is UOps.ENDRANGE:
+          loop_ends[idp[0]] = i
+          i = idp[0]
+          continue
+        elif uop in (UOps.BARRIER, UOps.IF, UOps.ENDIF):
+          # in the python emulator, the warp is always in sync
+          i += 1
+          continue
+        assert dtype is not None, f"{uop} is missing a dtype"
+        dl[i] = dtype
+        if uop is UOps.DEFINE_GLOBAL:
+          assert dtype.fmt is not None
+          ul[i] = [pbufs.pop(0).cast(dtype.fmt)] * warp_size
+        elif uop is UOps.DEFINE_LOCAL:
+          assert dtype.fmt is not None
+          lbuf = memoryview(bytearray(arg[1]*dtype.itemsize))
+          ul[i] = [lbuf.cast(dtype.fmt)] * warp_size
+        elif uop is UOps.DEFINE_VAR:
+          ul[i] = [pvals.pop(0)] * warp_size
+        elif uop is UOps.SPECIAL:
+          if arg[1][0] == 'g':
+            ul[i] = [idxs[2-arg[0]]] * warp_size
+          elif arg[1][0] == 'l':
+            ul[i] = [x[2-arg[0]] for x in warp]
+        elif uop is UOps.CONST:
+          ul[i] = [[arg] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg] * warp_size
+        elif uop is UOps.DEFINE_ACC:
+          ul[i] = [[arg[0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg[0]] * warp_size
+        elif uop is UOps.RANGE:
+          if i not in ul: ul[i] = [inp[0][0]] * warp_size
+          else:
+            for j in range(len(ul[i])):
+              ul[i][j] += 1
+            if ul[i][0] == inp[1][0]:
+              del ul[i]
+              i = loop_ends[i] + 1
+              continue
+        elif uop in {UOps.CAST, UOps.BITCAST}:
+          if dtype.count > 1: ul[i] = inp
+          else:
+            assert dtp[0].fmt and dtype.fmt
+            pack_format, unpack_format = str(warp_size) + dtp[0].fmt, str(warp_size) + dtype.fmt
+            if uop is UOps.BITCAST: ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
+            else:
+              casted = [dtypes.as_const(x, dtype) for x in inp[0]]
+              overflow_adjust = 2**(dtype.itemsize*8 - 1) if (dtypes.is_int(dtype) and not dtypes.is_unsigned(dtype)) else 0
+              overflow_fixed = [((x + overflow_adjust) % 2**(dtype.itemsize*8) - overflow_adjust) if dtypes.is_int(dtype) else x for x in casted]
+              ul[i] = list(struct.unpack(unpack_format, struct.pack(unpack_format, *overflow_fixed)))
+        elif uop is UOps.LOAD:
+          if isinstance(dtp[0], ImageDType):
+            assert dtype.count == 4
+            ul[i] = []
+            for j in range(dtype.count):
+              ret = []
+              for m,ox,oy in zip(inp[0], inp[1][0], inp[1][1]):
+                if ox < 0 or ox >= dtp[0].shape[1] or oy < 0 or oy >= dtp[0].shape[0]: ret.append(0)
+                else: ret.append(_load(m, ox*4 + oy*dtp[0].shape[1]*4 + j))
+              ul[i].append(ret)
+          elif dtype.count > 1:
+            ul[i] = [load([inp[i][j] if dtp[i].count > 1 else inp[i] for i in range(len(inp))], j) for j in range(dtype.count)]
+          else:
+            ul[i] = load(inp)
+        elif uop is UOps.PHI:
+          for j in range(len(inp[0])): inp[0][j] = inp[1][j]
+          ul[i] = inp[0]
+        elif uop is UOps.GEP:
+          ul[i] = inp[0][arg]
+        elif uop is UOps.WMMA:
+          # here are the models for the WMMA instruction on the different hardware
+          def wmma_helper(WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_elem, b_elem, c_map):
+            assert len(inp[0]) == NUM_A, f"A must have {NUM_A} elements per thread"
+            assert len(inp[1]) == NUM_B, f"B must have {NUM_B} elements per thread"
+            assert len(inp[2]) == NUM_C, f"C must have {NUM_C} elements per thread"
+            assert len(flatten(inp[0])) == NUM_A * warp_size, f"WMMA must have {NUM_A * warp_size} total elements for A in WMMA"
+            assert len(flatten(inp[1])) == NUM_B * warp_size, f"WMMA must have {NUM_B * warp_size} total elements for B in WMMA"
+            assert len(flatten(inp[2])) == NUM_C * warp_size, f"WMMA must have {NUM_C * warp_size} total elements for C in WMMA"
+            assert warp_size > 0 and warp_size % WARP_THREADS == 0, f"must have multiples of {WARP_THREADS} warp threads"
+            out = [inp[2][elem_idx][:] for elem_idx in range(NUM_C)]
+            for goff in range(0, warp_size, WARP_THREADS):
+              for lane_id in range(WARP_THREADS):
+                for elem_idx in range(NUM_C): # calculate new muls and add to acc
+                  (c_i, c_j) = c_map(lane_id, elem_idx)
+                  out[elem_idx][goff+lane_id] += sum(a_elem(inp[0], _k, c_j, goff) * b_elem(inp[1], c_i, _k, goff) for _k in range(K))
+            return out
+
+          # TODO: refactor these to a shared TensorCoreLayout in kernel.py
+          if arg[5] == "METAL":
+            # A (2 elements on 32 threads): row major
+            def a_b_elem(x, i, j, goff): return x[(i%2)][goff+(i//2)%2+(j%4)*2+(i//4)*8+(j//4)*16]
+            # (i, j), C, D (2 elements on 32 threads): row major same as A/B
+            def c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4)
+            ul[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
+          elif arg[5] == "HSA":
+            # A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15
+            def a_elem(x, i, j, goff):
+              assert x[i][goff+j] == x[i][goff+j+16], "warp elements not duplicated properly across lanes"
+              return x[i][goff+j]
+            # B (16 elements on 32 threads): row major, lane 16-32 == lane 0-15
+            def b_elem(x, i, j, goff): return a_elem(x, j, i, goff)
+            def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major
+            ul[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
+          elif arg[5] == "CUDA":
+            # A (8 elements on 32 threads)
+            def a_elem(x, i, j, goff): return x[(i%2)+(j//8)*2+(i//8)*4][goff+((i//2)%4)+(j%8)*4]
+            # B (4 elements on 32 threads)
+            def b_elem(x, i, j, goff): return x[(j%2)+(j//8)*2][goff+(j//2)%4+(i)*4]
+            # (i, j), C, D (4 elements on 32 threads)
+            def c_map(lane, elem): return ((elem%2)+(lane%4)*2, (lane//4)+(elem//2)*8)
+            ul[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
+          else: raise NotImplementedError(f"unimplemented tensor core {arg}")
+        elif uop is UOps.ALU:
+          assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {arg}"
+          assert all_same([dtype] + dtp) or arg in {BinaryOps.CMPEQ, BinaryOps.CMPLT, TernaryOps.WHERE}, f"dtype mismatch on {arg}"
+          ul[i] = [exec_alu(arg, dtype, p) for p in zip(*inp)]
+        assert i in ul, (uop, dtype, idp, arg)
+        i += 1
+    return time.perf_counter() - st
+
+class PythonRenderer(Renderer):
+  device = "PYTHON"
+  def __init__(self):
+    if getenv("EMULATE_METAL"): self.device, self.tensor_cores = "METAL", MetalRenderer.tensor_cores
+    if getenv("EMULATE_HSA"): self.device, self.tensor_cores = "HSA", HIPRenderer.tensor_cores
+    if getenv("EMULATE_CUDA"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tensor_cores
+
+  def render(self, name:str, uops:UOpGraph) -> str:
+    lops = [(u.uop, u.dtype, [uops.uops.index(v) for v in u.vin], u.arg) for u in uops]
+    return base64.b64encode(pickle.dumps(lops)).decode()
+
+class PythonCompiler(Compiler):
+  def compile(self, src:str) -> bytes: return base64.b64decode(src)
+
+class PythonAllocator(Allocator):
+  def _alloc(self, size, options): return memoryview(bytearray(size))
+  def copyin(self, dest, src:memoryview): dest[:] = src
+  def copyout(self, dest:memoryview, src): dest[:] = src

+class PythonDevice(Compiled):
+  def __init__(self, device:str):
+    super().__init__(device, PythonAllocator(), PythonRenderer(), PythonCompiler(), PythonProgram)
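
This file registers a complete backend: PythonRenderer pickles the uop list, PythonCompiler base64-decodes it back, and PythonProgram interprets each uop across the warp as plain Python lists, with the EMULATE_* variables swapping in another backend's tensor-core layouts so WMMA code paths can be tested without the hardware. A minimal sketch of driving it through the normal Tensor API, assuming tinygrad's usual device-string selection:

from tinygrad import Tensor

# every uop of this kernel is interpreted lane by lane by PythonProgram
a = Tensor([[1.0, 2.0], [3.0, 4.0]], device="PYTHON")
b = Tensor([[5.0, 6.0], [7.0, 8.0]], device="PYTHON")
print((a @ b).numpy())  # [[19. 22.] [43. 50.]]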
tinygrad/shape/shapetracker.py
@@ -1,66 +1,28 @@
 # ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
 from __future__ import annotations
-import functools, itertools, operator
 from dataclasses import dataclass
-from typing import Tuple, List, Optional, Dict, Set, cast, Union, Iterable
-from tinygrad.ops import MovementOps
-from tinygrad.helpers import prod, DEBUG, merge_dicts, getenv
-from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode, sint
-from tinygrad.shape.view import View, _merge_dims
-
-def expr_node_mask(view:View, idx:Node, valid:Optional[Node]=None) -> Node:
-  expr = [valid] if valid is not None else []
-  if view.mask is not None:
-    acc = 1
-    for d,(x,y) in zip(reversed(view.shape), reversed(view.mask)):
-      if (x,y) != (0,d):
-        base = ((idx//acc)%d)
-        expr += [base >= x, base < y]
-      acc *= d
-  return Node.ands(expr)
-
-# generate an expression if you have a single idx variable
-def expr_node(view:View, idx:Optional[Node]=None) -> Node:
-  if idx is None: idx = Variable('idx', 0, prod(view.shape)-1)
-  ret: List[Node] = [NumNode(view.offset) if isinstance(view.offset, int) else view.offset] if view.offset else []
-  acc = 1
-  for d,s,_ in reversed(_merge_dims(view.shape, view.strides)):
-    ret.append(((idx//acc)%d)*s)
-    acc *= d
-  return Node.sum(ret)
-
-# generate an expression if you have a variable or expression for each index
-def expr_idxs(view:View, idxs:Tuple[Node, ...]) -> Node:
+from typing import Tuple, List, Optional, Dict, Set, Iterable, cast
+from tinygrad.helpers import merge_dicts, getenv
+from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode, create_lt_node, create_ge_node, sint
+from tinygrad.shape.view import View, strides_for_shape
+
+def _expr_view(view:View, idxs:List[Node], valid:Optional[Node]=None) -> Tuple[Node, Node]:
   assert len(idxs) == len(view.shape), f"need an idx for all dimensions {idxs} vs {view.shape}"
-  return Node.sum([NumNode(view.offset) if isinstance(view.offset, int) else view.offset] + [idx*st for idx,sh,st in zip(idxs, view.shape, view.strides) if sh != 1 and st != 0]) # noqa: E501
-
-@functools.lru_cache(maxsize=None)
-def merge_views(vm2:View, vm1:View) -> Optional[View]:
-  if vm1.contiguous and vm1.shape == vm2.shape: return vm2
-  if vm2.contiguous: return vm1
-  if vm2.mask or vm1.offset != 0: return None # this isn't supported yet
-  if None in (strides := ShapeTracker((vm2, vm1)).real_strides()): return None
-  return View.create(vm1.shape, cast(Tuple[sint, ...], strides), vm2.offset, vm1.mask)
-
-@functools.lru_cache(maxsize=None)
-def idxs_to_idx(shape:Tuple[int, ...], idxs:Tuple[Node, ...]) -> Node:
-  assert len(idxs) == len(shape), "need an idx for all dimensions"
-  acc, ret = 1, []
-  for tidx,d in zip(reversed(idxs), reversed(shape)):
-    ret.append(tidx * acc)
-    acc *= d
-  return Node.sum(ret)
+  iexpr: List[Node] = [NumNode(view.offset) if isinstance(view.offset, int) else view.offset]
+  vexpr: List[Node] = [valid] if valid is not None else []
+  for idx,sh,st,m in zip(idxs, view.shape, view.strides, view.mask if view.mask is not None else [None]*len(view.shape)):
+    if sh != 1 and st != 0: iexpr.append(idx*st)
+    if m is not None: vexpr += [create_ge_node(idx, m[0]), create_lt_node(idx, m[1])] # idx >= m[0], idx < m[1]
+  return Node.sum(iexpr), Node.ands(vexpr)
 
 @dataclass(frozen=True)
 class ShapeTracker:
   views: Tuple[View, ...]
-  def __post_init__(self):
-    assert isinstance(self.views, tuple) and all(isinstance(v, View) for v in self.views), "ShapeTracker must be created with a tuple of Views"
 
   def __add__(self, st:ShapeTracker) -> ShapeTracker:
-    base = ShapeTracker(self.views)
-    for v in st.views: base = ShapeTracker(base.views + (v,)).simplify() # one view at a time = better simplification
-    return base
+    ret = self
+    for v in st.views: ret = ShapeTracker(ret.views + (v,)).simplify() # one view at a time = better simplification
+    return ret
 
   def invert(self, out_shape:Tuple[sint, ...]) -> Optional[ShapeTracker]:
     ret = tuple(v.invert(s) for v,s in zip(self.views[::-1], [x.shape for x in self.views[::-1][1:]]+[out_shape]))
@@ -72,16 +34,22 @@ class ShapeTracker:
   @property
   def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous
 
+  @property
+  def consecutive(self) -> bool: return len(self.views) == 1 and (v:=self.views[0]).mask is None and v.strides == strides_for_shape(v.shape)
+
   @property
   def shape(self) -> Tuple[sint, ...]: return self.views[-1].shape
 
   @property
-  def size(self) -> int: return prod([x.max if isinstance(x, Node) else x for x in self.views[-1].shape])
+  def size(self) -> int: return self.views[-1].size()
 
   def real_size(self) -> int:
     if 0 in self.shape: return 0
-    ret = self.expr_idxs()[0].max
-    while not isinstance(ret, int): ret = ret.max # TODO: this is a while loop?!? it should be more clear what max does
+    idx, valid = self.expr_idxs()
+    if not valid: return 0
+    # TODO: it's possible that the real_size is smaller condition on valid being true
+    ret = idx.max
+    if not isinstance(ret, int): ret = ret.max # might be represent by symbolic shape, take one more max for int max
     assert isinstance(ret, int), f"ret must be integer, {ret=} isn't"
     return ret+1
 
@@ -90,30 +58,9 @@ class ShapeTracker:
   @property
   def var_vals(self) -> Dict[Variable, int]: return merge_dicts([dict([v.unbind()]) for v in self.vars()])
 
-  def unbind(self) -> ShapeTracker: return ShapeTracker(tuple(v.unbind() for v in self.views))
-
-  def to_movement_ops(self) -> List[Tuple[MovementOps, Tuple]]:
-    to_apply:List[Tuple[MovementOps, Tuple]] = []
-    for v in self.views:
-      real_shape = tuple(y-x for x,y in v.mask) if v.mask else v.shape
-      real_offset = 0 if 0 in real_shape else (v.offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0))
-      # first, we apply the offset
-      # then, we make it the correct shape
-      # then, we apply permutations
-      to_apply.append((MovementOps.AS_STRIDED, (tuple([s if st != 0 else 1 for s,st in zip(real_shape, v.strides)]), v.strides, real_offset)))
-      # then, we apply pre expand pads
-      if v.mask is not None:
-        pre_expand_pads = tuple((x,s-y) if st != 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides))
-        post_expand_pads = tuple((x,s-y) if st == 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides))
-        if any(x != (0,0) for x in pre_expand_pads):
-          to_apply.append((MovementOps.PAD, pre_expand_pads))
-          real_shape = tuple(x+s[0]+s[1] for x,s in zip(real_shape, pre_expand_pads))
-      # then, we do any expands
-      # NOTE: this is a good idea even without masks, since torch doesn't support negative strides and has to make a copy
-      if any(s != 1 and st == 0 for s,st in zip(real_shape, v.strides)): to_apply.append((MovementOps.EXPAND, real_shape))
-      # lastly, we apply post expand pads
-      if v.mask is not None and any(x != (0,0) for x in post_expand_pads): to_apply.append((MovementOps.PAD, post_expand_pads))
-    return to_apply
+  def unbind(self) -> Tuple[ShapeTracker, Dict[Variable, int]]:
+    unbound_views, var_vals = zip(*[v.unbind() for v in self.views])
+    return ShapeTracker(tuple(unbound_views)), merge_dicts(var_vals)
 
   # NOTE: if a stride is not always valid, it will be None
   def real_strides(self, ignore_valid=False) -> Tuple[Optional[sint], ...]:
@@ -124,7 +71,7 @@ class ShapeTracker:
     bad_idx_vars: Set[Variable] = set()
     for this_dim in (idx.nodes if isinstance(idx, SumNode) else [idx]):
       idx_maybe, stride_maybe = (this_dim.a, this_dim.b) if isinstance(this_dim, MulNode) else (this_dim, 1)
-      try: ret[idxs.index(idx_maybe)] = stride_maybe
+      try: ret[idxs.index(idx_maybe)] = cast(sint, stride_maybe)
       except ValueError: bad_idx_vars = bad_idx_vars.union(idx_maybe.vars())
     idx_vars, valid_vars = idx.vars(), valid.vars()
     for i,tidx in enumerate(idxs):
@@ -134,30 +81,27 @@ class ShapeTracker:
 
   def unit_stride_axes(self, ignore_valid=False) -> List[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]
 
-  def _expr_idx(self, idx:Node, valid:Node) -> Tuple[Node, Node]:
-    for v in reversed(self.views[0:-1]):
+  def expr_idxs(self, idxs:Optional[Iterable[Node]]=None) -> Tuple[Node, Node]:
+    idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)] if idxs is None else list(idxs)
+    idx, valid = _expr_view(self.views[-1], idxs)
+    for view in reversed(self.views[0:-1]):
       if valid.max == 0: return NumNode(-1), valid
-      valid = expr_node_mask(v, idx, valid)
-      idx = expr_node(v, idx)
+      view = view.minify()
+      acc, idxs = 1, []
+      for d in reversed(view.shape):
+        idxs.append((idx//acc)%d)
+        acc *= d
+      idx, valid = _expr_view(view, idxs[::-1], valid)
+    assert not isinstance(idx.min, int) or idx.min >= -2**31, f"idx.min too small. {idx=}, {idx.min=}"
+    assert not isinstance(idx.max, int) or idx.max < 2**31, f"idx.max too big. {idx=}, {idx.max=}"
     return idx, valid
 
-  def expr_idxs(self, idxs:Optional[Iterable[Node]]=None):
-    if idxs is None: idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
-    idx = expr_idxs(self.views[-1], tuple(idxs))
-    valid = expr_node_mask(self.views[-1], idxs_to_idx(self.views[-1].shape, tuple(idxs)))
-    return self._expr_idx(idx, valid)
-
-  def expr_node(self, idx:Union[Node,str]='idx'):
-    if isinstance(idx, str): idx = Variable(idx, 0, prod(self.shape)-1)
-    return self._expr_idx(expr_node(self.views[-1], idx), expr_node_mask(self.views[-1], idx))
-
   def axis_is_masked(self, axis:int) -> bool:
     _, valid = self.expr_idxs()
     return f'idx{axis}' in [v.expr for v in valid.vars()]
 
   def simplify(self) -> ShapeTracker:
-    if len(self.views) >= 2 and (new_view := merge_views(self.views[-2], self.views[-1])) is not None:
-      if DEBUG >= 5: print(f"st simplify : {self.views[-2]} + {self.views[-1]} = {new_view}")
+    if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None:
       return ShapeTracker(self.views[:-2] + (new_view,)).simplify()
     return self
 
@@ -172,11 +116,3 @@ class ShapeTracker:
   def reshape(self, new_shape: Tuple[sint, ...]) -> ShapeTracker:
     if getenv("MERGE_VIEW", 1) and (new_view := self.views[-1].reshape(new_shape)) is not None: return ShapeTracker(self.views[0:-1] + (new_view,))
     return ShapeTracker(self.views + (View.create(new_shape), ))
-
-# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
-# TODO: if we remove movementops from lazy.py we can delete this
-def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:
-  acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
-  try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new]
-  except ValueError: return None
-  return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]
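
Taken together, the ShapeTracker changes collapse the old expr_node / expr_node_mask / module-level expr_idxs helpers into the single _expr_view, move view merging from merge_views into View.__add__ (used by simplify above), and change unbind to also return the values of the variables it strips. A small sketch of the resulting API; ShapeTracker.from_shape and Variable.bind are assumed from elsewhere in tinygrad, since neither appears in this diff:

from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable

# contiguous (3, 4) view: the index expression is idx0*4 + idx1, always valid
idx, valid = ShapeTracker.from_shape((3, 4)).expr_idxs()
print(idx.render(), valid.render())  # roughly "((idx0*4)+idx1)" and "1"

# unbind now returns the stripped-out variable values alongside the tracker
st = ShapeTracker.from_shape((Variable("n", 1, 10).bind(3), 4))
unbound_st, var_vals = st.unbind()
print(var_vals)  # roughly {Variable("n", 1, 10): 3}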