tinygrad-0.7.0-py3-none-any.whl → tinygrad-0.9.0-py3-none-any.whl

This diff compares the contents of two package versions as published to their public registry. It is provided for informational purposes only.
Files changed (77)
  1. tinygrad/__init__.py +6 -0
  2. tinygrad/codegen/kernel.py +572 -83
  3. tinygrad/codegen/linearizer.py +415 -395
  4. tinygrad/codegen/uops.py +415 -0
  5. tinygrad/device.py +183 -0
  6. tinygrad/dtype.py +113 -0
  7. tinygrad/engine/__init__.py +0 -0
  8. tinygrad/engine/graph.py +100 -0
  9. tinygrad/engine/jit.py +195 -0
  10. tinygrad/engine/realize.py +191 -0
  11. tinygrad/engine/schedule.py +362 -0
  12. tinygrad/engine/search.py +196 -0
  13. tinygrad/{mlops.py → function.py} +76 -55
  14. tinygrad/helpers.py +196 -89
  15. tinygrad/lazy.py +210 -371
  16. tinygrad/multi.py +169 -0
  17. tinygrad/nn/__init__.py +202 -22
  18. tinygrad/nn/datasets.py +7 -0
  19. tinygrad/nn/optim.py +112 -32
  20. tinygrad/nn/state.py +136 -39
  21. tinygrad/ops.py +119 -202
  22. tinygrad/renderer/__init__.py +61 -0
  23. tinygrad/renderer/assembly.py +276 -0
  24. tinygrad/renderer/cstyle.py +353 -166
  25. tinygrad/renderer/llvmir.py +150 -138
  26. tinygrad/runtime/autogen/amd_gpu.py +1900 -0
  27. tinygrad/runtime/autogen/comgr.py +865 -0
  28. tinygrad/runtime/autogen/cuda.py +5923 -0
  29. tinygrad/runtime/autogen/hip.py +5909 -0
  30. tinygrad/runtime/autogen/hsa.py +5761 -0
  31. tinygrad/runtime/autogen/kfd.py +812 -0
  32. tinygrad/runtime/autogen/nv_gpu.py +33328 -0
  33. tinygrad/runtime/autogen/opencl.py +1795 -0
  34. tinygrad/runtime/driver/hip_comgr.py +47 -0
  35. tinygrad/runtime/driver/hsa.py +143 -0
  36. tinygrad/runtime/graph/clang.py +38 -0
  37. tinygrad/runtime/graph/cuda.py +81 -0
  38. tinygrad/runtime/graph/hcq.py +143 -0
  39. tinygrad/runtime/graph/hsa.py +171 -0
  40. tinygrad/runtime/graph/metal.py +75 -0
  41. tinygrad/runtime/ops_amd.py +564 -0
  42. tinygrad/runtime/ops_clang.py +24 -77
  43. tinygrad/runtime/ops_cuda.py +175 -89
  44. tinygrad/runtime/ops_disk.py +56 -33
  45. tinygrad/runtime/ops_gpu.py +92 -95
  46. tinygrad/runtime/ops_hsa.py +278 -0
  47. tinygrad/runtime/ops_llvm.py +39 -60
  48. tinygrad/runtime/ops_metal.py +92 -74
  49. tinygrad/runtime/ops_npy.py +9 -0
  50. tinygrad/runtime/ops_nv.py +630 -0
  51. tinygrad/runtime/ops_python.py +204 -0
  52. tinygrad/shape/shapetracker.py +86 -254
  53. tinygrad/shape/symbolic.py +166 -141
  54. tinygrad/shape/view.py +296 -0
  55. tinygrad/tensor.py +2619 -448
  56. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
  57. tinygrad-0.9.0.dist-info/METADATA +227 -0
  58. tinygrad-0.9.0.dist-info/RECORD +60 -0
  59. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
  60. tinygrad/codegen/assembly.py +0 -190
  61. tinygrad/codegen/optimizer.py +0 -379
  62. tinygrad/codegen/search.py +0 -72
  63. tinygrad/graph.py +0 -83
  64. tinygrad/jit.py +0 -57
  65. tinygrad/nn/image.py +0 -100
  66. tinygrad/renderer/assembly_arm64.py +0 -169
  67. tinygrad/renderer/assembly_ptx.py +0 -98
  68. tinygrad/renderer/wgsl.py +0 -53
  69. tinygrad/runtime/lib.py +0 -113
  70. tinygrad/runtime/ops_cpu.py +0 -51
  71. tinygrad/runtime/ops_hip.py +0 -82
  72. tinygrad/runtime/ops_shm.py +0 -29
  73. tinygrad/runtime/ops_torch.py +0 -30
  74. tinygrad/runtime/ops_webgpu.py +0 -45
  75. tinygrad-0.7.0.dist-info/METADATA +0 -212
  76. tinygrad-0.7.0.dist-info/RECORD +0 -40
  77. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
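
The layout changes stand out in the list: the JIT, graph, schedule/realize, and search machinery moved under tinygrad/engine/, mlops.py became function.py, dtypes split out into tinygrad/dtype.py, and the View logic left shapetracker.py for the new tinygrad/shape/view.py. A minimal sketch of what imports look like against the 0.9.0 layout above (the exact re-exports live in tinygrad/__init__.py, which gained six lines; verify against the release):

    # a sketch of 0.9.0-era imports, assuming the layout in the file list above
    from tinygrad import Tensor, dtypes        # top-level re-exports from tinygrad/__init__.py
    from tinygrad.engine.jit import TinyJit    # moved here from the deleted tinygrad/jit.py
    from tinygrad.nn.optim import Adam         # optimizers stayed in tinygrad/nn/optim.py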
tinygrad/runtime/ops_python.py
@@ -0,0 +1,204 @@
+ # a python uops emulator
+ # works to test the tensor cores, and all the uops in general
+ # this is the (living) definition of uops
+ from typing import Tuple, List, Optional, Any, Dict
+ import pickle, base64, itertools, time, struct
+ from tinygrad.dtype import DType, dtypes, ImageDType
+ from tinygrad.helpers import all_same, getenv, flatten
+ from tinygrad.device import Compiled, Compiler, Allocator
+ from tinygrad.codegen.uops import UOpGraph, UOps
+ from tinygrad.ops import BinaryOps, TernaryOps, exec_alu
+ from tinygrad.renderer import Renderer
+ from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, HIPRenderer
+
+ def _load(m, i):
+   if i < 0 or i >= len(m): raise IndexError(f"load out of bounds, size is {len(m)} and access is {i}")
+   return m[i]
+
+ def load(inp, j=0):
+   if len(inp) == 4: return [_load(m, x+j) if gate else default for m,x,gate,default in zip(*inp)]
+   else: return [_load(m, x+j) for m,x in zip(inp[0], inp[1])]
+
+ def _store(m, i, v):
+   if i < 0 or i >= len(m): raise IndexError(f"store out of bounds, size is {len(m)}, access is {i}, value is {v}")
+   m[i] = v
+
+ class PythonProgram:
+   def __init__(self, name:str, lib:bytes):
+     self.uops: List[Tuple[UOps, Optional[DType], List[int], Any]] = pickle.loads(lib)
+   def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
+     st = time.perf_counter()
+     warp = list(itertools.product(*[range(x) for x in local_size[::-1]]))
+     warp_size = len(warp)
+     for idxs in itertools.product(*[range(x) for x in global_size[::-1]]):
+       ul: Dict[int, Any] = {}
+       dl: Dict[int, DType] = {}
+       pbufs: List[memoryview] = list(bufs)
+       pvals: List[int] = list(vals)
+       i = 0
+       loop_ends: Dict[int, int] = {}
+       while i < len(self.uops):
+         uop, dtype, idp, arg = self.uops[i]
+         void_ops = {UOps.STORE, UOps.ENDRANGE, UOps.BARRIER, UOps.IF, UOps.ENDIF}
+         if uop is UOps.DEFINE_ACC: idp.clear()
+         inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops]
+         dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops]
+         if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp)
+         if uop is UOps.STORE:
+           if len(inp) == 3: inp.append([True] * len(inp[0]))  # set the gate to True
+           if isinstance(dtp[0], ImageDType):
+             # image store
+             assert dtp[2].count == 4
+             for j,val in enumerate(inp[2]):
+               for m,ox,oy,v,g in zip(inp[0], inp[1][0], inp[1][1], val, inp[3]):
+                 assert ox >= 0 and ox < dtp[0].shape[1] and oy >= 0 and oy < dtp[0].shape[0]
+                 if g: _store(m, ox*4 + oy*dtp[0].shape[1]*4 + j, v)
+           elif dtp[2].count > 1:
+             for j,val in enumerate(inp[2]):
+               for m,o,v,g in zip(inp[0], inp[1], val, inp[3]):
+                 if g: _store(m, o+j, v)
+           else:
+             for m,o,v,g in zip(*inp):
+               if g: _store(m, o, v)
+           i += 1
+           continue
+         elif uop is UOps.ENDRANGE:
+           loop_ends[idp[0]] = i
+           i = idp[0]
+           continue
+         elif uop in (UOps.BARRIER, UOps.IF, UOps.ENDIF):
+           # in the python emulator, the warp is always in sync
+           i += 1
+           continue
+         assert dtype is not None, f"{uop} is missing a dtype"
+         dl[i] = dtype
+         if uop is UOps.DEFINE_GLOBAL:
+           assert dtype.fmt is not None
+           ul[i] = [pbufs.pop(0).cast(dtype.fmt)] * warp_size
+         elif uop is UOps.DEFINE_LOCAL:
+           assert dtype.fmt is not None
+           lbuf = memoryview(bytearray(arg[1]*dtype.itemsize))
+           ul[i] = [lbuf.cast(dtype.fmt)] * warp_size
+         elif uop is UOps.DEFINE_VAR:
+           ul[i] = [pvals.pop(0)] * warp_size
+         elif uop is UOps.SPECIAL:
+           if arg[1][0] == 'g':
+             ul[i] = [idxs[2-arg[0]]] * warp_size
+           elif arg[1][0] == 'l':
+             ul[i] = [x[2-arg[0]] for x in warp]
+         elif uop is UOps.CONST:
+           ul[i] = [[arg] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg] * warp_size
+         elif uop is UOps.DEFINE_ACC:
+           ul[i] = [[arg[0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg[0]] * warp_size
+         elif uop is UOps.RANGE:
+           if i not in ul: ul[i] = [inp[0][0]] * warp_size
+           else:
+             for j in range(len(ul[i])):
+               ul[i][j] += 1
+             if ul[i][0] == inp[1][0]:
+               del ul[i]
+               i = loop_ends[i] + 1
+               continue
+         elif uop in {UOps.CAST, UOps.BITCAST}:
+           if dtype.count > 1: ul[i] = inp
+           else:
+             assert dtp[0].fmt and dtype.fmt
+             pack_format, unpack_format = str(warp_size) + dtp[0].fmt, str(warp_size) + dtype.fmt
+             if uop is UOps.BITCAST: ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
+             else:
+               casted = [dtypes.as_const(x, dtype) for x in inp[0]]
+               overflow_adjust = 2**(dtype.itemsize*8 - 1) if (dtypes.is_int(dtype) and not dtypes.is_unsigned(dtype)) else 0
+               overflow_fixed = [((x + overflow_adjust) % 2**(dtype.itemsize*8) - overflow_adjust) if dtypes.is_int(dtype) else x for x in casted]
+               ul[i] = list(struct.unpack(unpack_format, struct.pack(unpack_format, *overflow_fixed)))
+         elif uop is UOps.LOAD:
+           if isinstance(dtp[0], ImageDType):
+             assert dtype.count == 4
+             ul[i] = []
+             for j in range(dtype.count):
+               ret = []
+               for m,ox,oy in zip(inp[0], inp[1][0], inp[1][1]):
+                 if ox < 0 or ox >= dtp[0].shape[1] or oy < 0 or oy >= dtp[0].shape[0]: ret.append(0)
+                 else: ret.append(_load(m, ox*4 + oy*dtp[0].shape[1]*4 + j))
+               ul[i].append(ret)
+           elif dtype.count > 1:
+             ul[i] = [load([inp[i][j] if dtp[i].count > 1 else inp[i] for i in range(len(inp))], j) for j in range(dtype.count)]
+           else:
+             ul[i] = load(inp)
+         elif uop is UOps.PHI:
+           for j in range(len(inp[0])): inp[0][j] = inp[1][j]
+           ul[i] = inp[0]
+         elif uop is UOps.GEP:
+           ul[i] = inp[0][arg]
+         elif uop is UOps.WMMA:
+           # here are the models for the WMMA instruction on the different hardware
+           def wmma_helper(WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_elem, b_elem, c_map):
+             assert len(inp[0]) == NUM_A, f"A must have {NUM_A} elements per thread"
+             assert len(inp[1]) == NUM_B, f"B must have {NUM_B} elements per thread"
+             assert len(inp[2]) == NUM_C, f"C must have {NUM_C} elements per thread"
+             assert len(flatten(inp[0])) == NUM_A * warp_size, f"WMMA must have {NUM_A * warp_size} total elements for A in WMMA"
+             assert len(flatten(inp[1])) == NUM_B * warp_size, f"WMMA must have {NUM_B * warp_size} total elements for B in WMMA"
+             assert len(flatten(inp[2])) == NUM_C * warp_size, f"WMMA must have {NUM_C * warp_size} total elements for C in WMMA"
+             assert warp_size > 0 and warp_size % WARP_THREADS == 0, f"must have multiples of {WARP_THREADS} warp threads"
+             out = [inp[2][elem_idx][:] for elem_idx in range(NUM_C)]
+             for goff in range(0, warp_size, WARP_THREADS):
+               for lane_id in range(WARP_THREADS):
+                 for elem_idx in range(NUM_C):  # calculate new muls and add to acc
+                   (c_i, c_j) = c_map(lane_id, elem_idx)
+                   out[elem_idx][goff+lane_id] += sum(a_elem(inp[0], _k, c_j, goff) * b_elem(inp[1], c_i, _k, goff) for _k in range(K))
+             return out
+
+           # TODO: refactor these to a shared TensorCoreLayout in kernel.py
+           if arg[5] == "METAL":
+             # A (2 elements on 32 threads): row major
+             def a_b_elem(x, i, j, goff): return x[(i%2)][goff+(i//2)%2+(j%4)*2+(i//4)*8+(j//4)*16]
+             # (i, j), C, D (2 elements on 32 threads): row major same as A/B
+             def c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4)
+             ul[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
+           elif arg[5] == "HSA":
+             # A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15
+             def a_elem(x, i, j, goff):
+               assert x[i][goff+j] == x[i][goff+j+16], "warp elements not duplicated properly across lanes"
+               return x[i][goff+j]
+             # B (16 elements on 32 threads): row major, lane 16-32 == lane 0-15
+             def b_elem(x, i, j, goff): return a_elem(x, j, i, goff)
+             def c_map(lane, elem): return (lane%16, lane//16+elem*2)  # (i, j), C, D (8 elements on 32 threads): row major
+             ul[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
+           elif arg[5] == "CUDA":
+             # A (8 elements on 32 threads)
+             def a_elem(x, i, j, goff): return x[(i%2)+(j//8)*2+(i//8)*4][goff+((i//2)%4)+(j%8)*4]
+             # B (4 elements on 32 threads)
+             def b_elem(x, i, j, goff): return x[(j%2)+(j//8)*2][goff+(j//2)%4+(i)*4]
+             # (i, j), C, D (4 elements on 32 threads)
+             def c_map(lane, elem): return ((elem%2)+(lane%4)*2, (lane//4)+(elem//2)*8)
+             ul[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
+           else: raise NotImplementedError(f"unimplemented tensor core {arg}")
+         elif uop is UOps.ALU:
+           assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {arg}"
+           assert all_same([dtype] + dtp) or arg in {BinaryOps.CMPEQ, BinaryOps.CMPLT, TernaryOps.WHERE}, f"dtype mismatch on {arg}"
+           ul[i] = [exec_alu(arg, dtype, p) for p in zip(*inp)]
+         assert i in ul, (uop, dtype, idp, arg)
+         i += 1
+     return time.perf_counter() - st
+
+ class PythonRenderer(Renderer):
+   device = "PYTHON"
+   def __init__(self):
+     if getenv("EMULATE_METAL"): self.device, self.tensor_cores = "METAL", MetalRenderer.tensor_cores
+     if getenv("EMULATE_HSA"): self.device, self.tensor_cores = "HSA", HIPRenderer.tensor_cores
+     if getenv("EMULATE_CUDA"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tensor_cores
+
+   def render(self, name:str, uops:UOpGraph) -> str:
+     lops = [(u.uop, u.dtype, [uops.uops.index(v) for v in u.vin], u.arg) for u in uops]
+     return base64.b64encode(pickle.dumps(lops)).decode()
+
+ class PythonCompiler(Compiler):
+   def compile(self, src:str) -> bytes: return base64.b64decode(src)
+
+ class PythonAllocator(Allocator):
+   def _alloc(self, size, options): return memoryview(bytearray(size))
+   def copyin(self, dest, src:memoryview): dest[:] = src
+   def copyout(self, dest:memoryview, src): dest[:] = src
+
+ class PythonDevice(Compiled):
+   def __init__(self, device:str):
+     super().__init__(device, PythonAllocator(), PythonRenderer(), PythonCompiler(), PythonProgram)
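
This new file backs the PYTHON device, an interpreter for pickled uop programs that the test suite uses to check uop semantics and the tensor-core models above without real hardware. A minimal sketch of how it might be exercised (selecting devices via environment variables is tinygrad's convention; treat the exact flags as assumptions to check against the release):

    # hypothetical usage sketch: run a tensor op through the pure-Python emulator
    import os
    os.environ["PYTHON"] = "1"           # select the PYTHON device before importing tinygrad
    # os.environ["EMULATE_METAL"] = "1"  # optionally emulate Metal tensor cores (see PythonRenderer)
    from tinygrad import Tensor
    print((Tensor([1.0, 2.0, 3.0]) + 1).numpy())  # executes via PythonProgram above, slowly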
tinygrad/shape/shapetracker.py
@@ -1,286 +1,118 @@
  # ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
  from __future__ import annotations
- import functools
- from typing import Dict, Tuple, Union, List, Optional, cast, NamedTuple
- from tinygrad.helpers import prod, DEBUG, partition
- from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode, is_sym_int
-
- @functools.lru_cache(maxsize=None)
- def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[Tuple[int, int], ...]:
-   assert len(shape) == len(strides)
-   ret = [(shape[0], strides[0])] if shape else []
-   for i in range(1, len(shape)):
-     if ret[-1][1] == shape[i]*strides[i] or ret[-1][0] == 1:
-       ret[-1] = (ret[-1][0] * shape[i], strides[i])
-     elif shape[i] == 1:
-       continue
-     else:
-       ret.append((shape[i], strides[i]))
-   return tuple(ret)
-
- @functools.lru_cache(maxsize=None)
- def is_contiguous(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> bool: return all(s1 == s2 or s == 1 for s,s1,s2 in zip(shape, strides, strides_for_shape(shape)))
-
- @functools.lru_cache(maxsize=None)
- def filter_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[int, ...]:
-   return tuple(stride if shp != 1 else 0 for stride, shp in zip(strides, shape))
-
- class ViewInternal(NamedTuple):
-   shape:Tuple[int, ...]
-   strides:Tuple[int, ...]
-   offset:int
-   mask:Optional[Tuple[Tuple[int, int]]]
-   contiguous:bool
-   shape_strides:Tuple[Tuple[int, int], ...]
-
- @functools.lru_cache(maxsize=None)
- class View(ViewInternal):
-   def __new__(cls, shape, strides=None, offset=0, mask=None):
-     strides_from_shape = strides_for_shape(shape)
-     strides = strides_from_shape if not strides else filter_strides(shape, strides)
-     contiguous = offset == 0 and is_contiguous(shape, strides) and mask is None
-     return super().__new__(cls, shape, strides, offset, mask, contiguous, to_shape_strides(shape, strides))
-   def __init__(self, shape, strides=None, offset=0, mask=None, contiguous=False, shape_strides=()): super().__init__()
-
-   def expr_node_mask(self, idx, valid=None) -> Node:
-     expr = [valid] if valid is not None else []
-     if self.mask is not None:
-       acc = 1
-       for ns,(x,y) in reversed(list(zip(self.shape, self.mask))):
-         base = ((idx//acc) % ns)
-         expr += [base >= x, base < y]
-         acc *= ns
-     return Variable.ands(expr)
-
-   # generate an expression if you have a single idx variable
-   def expr_node(self, idx=None) -> Node:
-     if idx is None: idx = Variable('idx', 0, prod(self.shape)-1)
-     ret: List[Node] = [Variable.num(self.offset) if isinstance(self.offset, int) else self.offset] if self.offset else []
-     acc = 1
-     for d,s in reversed(self.shape_strides):
-       ret.append(((idx//acc)%d)*s)
-       acc *= d
-     return Variable.sum(ret)
-
-   # generate an expression if you have a variable or expression for each index
-   def expr_idxs(self, idxs) -> Node:
-     assert len(idxs) == len(self.shape), f"need an idx for all dimensions {idxs} vs {self.shape}"
-     return Variable.sum([Variable.num(self.offset) if isinstance(self.offset, int) else self.offset] + [idx*st for idx,sh,st in zip(idxs, self.shape, self.strides) if sh != 1 and st != 0])
-
- @functools.lru_cache(maxsize=None)
- def idxs_to_idx(shape:Tuple[int, ...], idxs) -> Node:
-   assert len(idxs) == len(shape), "need an idx for all dimensions"
-   acc = 1
-   ret = []
-   for tidx,d in reversed(list(zip(idxs, shape))):
-     ret.append(tidx * acc)
-     acc *= d
-   return Variable.sum(ret)
-
- @functools.lru_cache(maxsize=None)
- def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]:
-   strides = [1] if shape else []
-   for d in shape[::-1][:-1]: strides = [d*strides[0]] + strides
-   return filter_strides(shape, tuple(strides))
-
- @functools.lru_cache(maxsize=None)
- def merge_views(vm2:View, vm1:View) -> Optional[View]:
-   if vm2.mask: return None  # this isn't supported yet
-   mst = ShapeTracker(vm1.shape, [vm2, vm1])
-   strides = mst.real_strides()
-   if None in strides: return None
-   return View(vm1.shape, strides, mst.real_offset(), vm1.mask)
-
- @functools.lru_cache(maxsize=None)
- def _reshape(view: View, new_shape:Tuple[int, ...]) -> Tuple[View, bool]:
-   shape, mask, strides, offset = view.shape, view.mask, view.strides, view.offset
-   # check if this is adding or removing 1s (only)
-   # NOTE: this is optional, but removes most calls to (expensive!) merge_views (with mask, not optional)
-   if [x for x in shape if x != 1] == [x for x in new_shape if x != 1]:
-     new_strides: List[int] = [y for x,y in zip(shape, strides) if x != 1]
-     new_strides_tuple: Tuple[int, ...] = tuple([0 if x == 1 else new_strides.pop(0) for x in new_shape])
-     new_mask_tuple = None
-     if mask:
-       for x,y in zip(shape, mask):
-         if x == 1 and y != (0, 1):
-           new_mask_tuple = ((0,0),) * len(new_shape)
-           break
-       else:
-         new_mask: List[Tuple[int, int]] = [y for x,y in zip(shape, mask) if x != 1]
-         new_mask_tuple = tuple([(0,1) if x == 1 else new_mask.pop(0) for x in new_shape])
-     return View(new_shape, new_strides_tuple, offset, new_mask_tuple), False
+ from dataclasses import dataclass
+ from typing import Tuple, List, Optional, Dict, Set, Iterable, cast
+ from tinygrad.helpers import merge_dicts, getenv
+ from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode, create_lt_node, create_ge_node, sint
+ from tinygrad.shape.view import View, strides_for_shape
+
+ def _expr_view(view:View, idxs:List[Node], valid:Optional[Node]=None) -> Tuple[Node, Node]:
+   assert len(idxs) == len(view.shape), f"need an idx for all dimensions {idxs} vs {view.shape}"
+   iexpr: List[Node] = [NumNode(view.offset) if isinstance(view.offset, int) else view.offset]
+   vexpr: List[Node] = [valid] if valid is not None else []
+   for idx,sh,st,m in zip(idxs, view.shape, view.strides, view.mask if view.mask is not None else [None]*len(view.shape)):
+     if sh != 1 and st != 0: iexpr.append(idx*st)
+     if m is not None: vexpr += [create_ge_node(idx, m[0]), create_lt_node(idx, m[1])]  # idx >= m[0], idx < m[1]
+   return Node.sum(iexpr), Node.ands(vexpr)
+
+ @dataclass(frozen=True)
+ class ShapeTracker:
+   views: Tuple[View, ...]

-   new_view = View(new_shape)
-   if view.contiguous: return new_view, False  # NOTE: if it's contiguous it can't have an offset
-   if (merged_view := merge_views(view, new_view)) is not None: return merged_view, False
-   if DEBUG >= 4: print(f"WARNING: creating new view with reshape {view} -> {new_shape}")
-   return new_view, True
+   def __add__(self, st:ShapeTracker) -> ShapeTracker:
+     ret = self
+     for v in st.views: ret = ShapeTracker(ret.views + (v,)).simplify()  # one view at a time = better simplification
+     return ret

- @functools.lru_cache(maxsize=None)
- def get_pad_args(shape:Tuple[int,...], arg:Tuple[Tuple[int, int], ...]):
-   return tuple([(-b,s+e) for s,(b,e) in zip(shape, arg)]), tuple([(b,s+b) for s,(b,_) in zip(shape, arg)])
+   def invert(self, out_shape:Tuple[sint, ...]) -> Optional[ShapeTracker]:
+     ret = tuple(v.invert(s) for v,s in zip(self.views[::-1], [x.shape for x in self.views[::-1][1:]]+[out_shape]))
+     return ShapeTracker(cast(Tuple[View, ...], ret)).reshape(out_shape) if all(x is not None for x in ret) else None

- @functools.lru_cache(maxsize=None)
- def get_unsafe_resize_offset(strides, arg):
-   return sum([s * x[0] for s, x in zip(strides,arg)])
-
- class ShapeTracker:
-   __slots__ = "views", "var_vals"
-   def __init__(self, shape:Union[ShapeTracker, Tuple[Union[Node,int], ...]], views:Optional[List[View]]=None):
-     self.views: List[View] = views if views is not None else ([*cast(ShapeTracker, shape).views] if shape.__class__ is ShapeTracker else [View(shape)])
-     self.var_vals: Dict[Variable, int] = shape.var_vals if isinstance(shape, ShapeTracker) else {}
-   def __repr__(self): return f"ShapeTracker(shape={self.views[-1].shape}, views={self.views}, var_vals={self.var_vals})"
-   def copy(self) -> ShapeTracker: return ShapeTracker(self.views[-1].shape, [*self.views])
+   @staticmethod
+   def from_shape(shape:Tuple[sint, ...]): return ShapeTracker((View.create(shape),))

    @property
    def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous

    @property
-   def shape(self) -> Tuple[int, ...]: return self.views[-1].shape  # NOTE: real type is Tuple[Union[Node, int], ...] but mypy complains about prod(shape)
+   def consecutive(self) -> bool: return len(self.views) == 1 and (v:=self.views[0]).mask is None and v.strides == strides_for_shape(v.shape)
+
+   @property
+   def shape(self) -> Tuple[sint, ...]: return self.views[-1].shape

    @property
-   def key(self) -> Tuple[View, ...]: return tuple(self.views)
+   def size(self) -> int: return self.views[-1].size()
+
+   def real_size(self) -> int:
+     if 0 in self.shape: return 0
+     idx, valid = self.expr_idxs()
+     if not valid: return 0
+     # TODO: it's possible that the real_size is smaller condition on valid being true
+     ret = idx.max
+     if not isinstance(ret, int): ret = ret.max  # might be represent by symbolic shape, take one more max for int max
+     assert isinstance(ret, int), f"ret must be integer, {ret=} isn't"
+     return ret+1

-   # this is the real size (ish)
-   def size(self): return prod([s for s,st in zip(self.views[-1].shape, self.views[-1].strides) if st != 0])
+   def vars(self) -> Set[Variable]: return set.union(*[v.vars() for v in self.views], set())

-   # these are multiview strides, value is None if it's not a simple strided dimension
-   # TODO: this can be shared code between simplify and merge_views
-   def real_offset(self) -> int:
-     real_offset, mask = self.expr_node(Variable('zero', 0, 0))
-     assert real_offset.__class__ is NumNode, f"how is the offset not a number? {real_offset} {mask}"
-     return real_offset.b
+   @property
+   def var_vals(self) -> Dict[Variable, int]: return merge_dicts([dict([v.unbind()]) for v in self.vars()])
+
+   def unbind(self) -> Tuple[ShapeTracker, Dict[Variable, int]]:
+     unbound_views, var_vals = zip(*[v.unbind() for v in self.views])
+     return ShapeTracker(tuple(unbound_views)), merge_dicts(var_vals)

    # NOTE: if a stride is not always valid, it will be None
-   def real_strides(self, ignore_valid=False) -> Tuple[Optional[Union[Node, int]], ...]:
+   def real_strides(self, ignore_valid=False) -> Tuple[Optional[sint], ...]:
      if len(self.views) == 1 and self.views[-1].mask is None: return self.views[-1].strides
-     idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
+     idxs: List[Node] = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
      idx, valid = self.expr_idxs(idxs)
-     ret: List[Optional[Union[Node, int]]] = [None] * len(self.views[-1].shape)
+     ret: List[Optional[sint]] = [None] * len(self.views[-1].shape)
+     bad_idx_vars: Set[Variable] = set()
      for this_dim in (idx.nodes if isinstance(idx, SumNode) else [idx]):
-       if isinstance(this_dim, MulNode) and isinstance(this_dim.a, Variable) and this_dim.a in idxs:
-         ret[idxs.index(this_dim.a)] = this_dim.b
-       elif isinstance(this_dim, Variable):
-         ret[idxs.index(this_dim)] = 1
+       idx_maybe, stride_maybe = (this_dim.a, this_dim.b) if isinstance(this_dim, MulNode) else (this_dim, 1)
+       try: ret[idxs.index(idx_maybe)] = cast(sint, stride_maybe)
+       except ValueError: bad_idx_vars = bad_idx_vars.union(idx_maybe.vars())
      idx_vars, valid_vars = idx.vars(), valid.vars()
      for i,tidx in enumerate(idxs):
-       if tidx in valid_vars and not ignore_valid: ret[i] = None
+       if tidx in bad_idx_vars or (tidx in valid_vars and not ignore_valid): ret[i] = None
        elif tidx not in idx_vars: ret[i] = 0
      return tuple(ret)
+
    def unit_stride_axes(self, ignore_valid=False) -> List[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]

-   def _expr_idx(self, idx, valid):
-     for v in reversed(self.views[0:-1]):
-       valid = v.expr_node_mask(idx, valid)
-       idx = v.expr_node(idx)
+   def expr_idxs(self, idxs:Optional[Iterable[Node]]=None) -> Tuple[Node, Node]:
+     idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)] if idxs is None else list(idxs)
+     idx, valid = _expr_view(self.views[-1], idxs)
+     for view in reversed(self.views[0:-1]):
+       if valid.max == 0: return NumNode(-1), valid
+       view = view.minify()
+       acc, idxs = 1, []
+       for d in reversed(view.shape):
+         idxs.append((idx//acc)%d)
+         acc *= d
+       idx, valid = _expr_view(view, idxs[::-1], valid)
+     assert not isinstance(idx.min, int) or idx.min >= -2**31, f"idx.min too small. {idx=}, {idx.min=}"
+     assert not isinstance(idx.max, int) or idx.max < 2**31, f"idx.max too big. {idx=}, {idx.max=}"
      return idx, valid

-   def simplify(self):
-     if len(self.views) >= 2:
-       new_view = merge_views(self.views[-2], self.views[-1])
-       if new_view:
-         if DEBUG >= 4: print(f"st simplify : {self.views[-2]} + {self.views[-1]} = {new_view}")
-         self.views = self.views[:-2] + [new_view]
-         self.simplify()
-
-   def expr_idxs(self, idxs=None):
-     if idxs is None: idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
-     idx = self.views[-1].expr_idxs(tuple(idxs))
-     valid = self.views[-1].expr_node_mask(idxs_to_idx(self.views[-1].shape, tuple(idxs)))
-     return self._expr_idx(idx, valid)
-
-   def expr_node(self, idx='idx'):
-     if idx.__class__ is str: idx = Variable(idx, 0, prod(self.shape)-1)
-     return self._expr_idx(self.views[-1].expr_node(idx), self.views[-1].expr_node_mask(idx))
-
-   def needs_valid(self) -> bool:
-     return any(v.mask is not None for v in self.views)
-
-   # *** under this line are the movement ops ***
-
-   def __unsafe_resize(self, arg: Tuple[Tuple[int, int], ...], mask=None):
-     offset = get_unsafe_resize_offset(self.views[-1].strides, arg)
-     if self.views[-1].mask:
-       # move the old mask
-       nmask = tuple([(max(mx-ax, 0), min(my-ax, ay-ax)) for (mx,my),(ax,ay) in zip(self.views[-1].mask, arg)])
-       # merge the masks if we have two
-       mask = tuple([(max(mx1, mx2), min(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask
-     self.views[-1] = View(tuple([y-x for x,y in arg]), self.views[-1].strides, self.views[-1].offset+offset, mask)
+   def axis_is_masked(self, axis:int) -> bool:
+     _, valid = self.expr_idxs()
+     return f'idx{axis}' in [v.expr for v in valid.vars()]

-   def pad(self, arg: Tuple[Tuple[int, int], ...]):
-     assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape)
-     if any(b or e for b, e in arg):
-       zvarg, mask = get_pad_args(self.shape, arg)
-       self.__unsafe_resize(zvarg, mask=mask)
+   def simplify(self) -> ShapeTracker:
+     if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None:
+       return ShapeTracker(self.views[:-2] + (new_view,)).simplify()
      return self

-   def shrink(self, arg: Tuple[Tuple[int, int], ...]):
-     assert all((b>=0 and e<=s) for s,(b,e) in zip(self.shape,arg)) and len(arg) == len(self.shape)
-     self.__unsafe_resize(arg)
-     return self
-
-   def expand(self, new_shape: Tuple[Union[Node,int], ...]) -> ShapeTracker:
-     assert len(new_shape) == len(self.views[-1].shape)
-     assert all(is_sym_int(x) and (s == x or (s == 1 and st == 0)) for s,x,st in zip(self.shape, new_shape, self.views[-1].strides)), f"can't expand {self.shape} into {new_shape}"
-     # NOTE: can the mask ever be (0,0)?
-     mask = tuple([(((0,0) if m != (0,1) else (0,ns)) if s != ns else m) for m,s,ns in zip(self.views[-1].mask, self.shape, new_shape)]) if self.views[-1].mask else None
-     self.views[-1] = View(new_shape, self.views[-1].strides, self.views[-1].offset, mask)
-     return self
-
-   def reshape(self, new_shape: Tuple[Union[Node,int], ...]):
-     if self.views[-1].shape == new_shape: return self
-     new_ints, new_nodes = partition(new_shape, lambda s: isinstance(s, int))
-     if new_nodes and all(isinstance(s, int) for s in self.shape):
-       # reshape from all int shape into shape with a variable, update the variable value
-       assert len(new_nodes) == 1 and isinstance(new_nodes[0], Variable), "only support adding one Variable to the int shape"
-       new_var, new_val = new_nodes[0], prod(self.shape) // prod(new_ints)
-       if new_var not in self.var_vals:
-         assert new_var.min <= new_val <= new_var.max, f"variable value {new_val} out of range [{new_var.min}, {new_var.max}]"
-         self.var_vals[new_var] = new_val
-       else: assert self.var_vals[new_var] == new_val, f"value conflicts, was {self.var_vals[new_var]}, set to {new_val}"
-     elif not new_nodes: self.var_vals = {}
-     assert all(is_sym_int(x) and x > 0 for x in new_shape), f"shape must be symbolic ints and can't contain 0 or negative numbers {new_shape}"
-     # only check size for int shapes. we don't check symbolic here as long as the reshape itself can be done
-     if all(isinstance(s, int) for s in self.shape) and all(isinstance(s, int) for s in new_shape):
-       assert prod(self.shape) == prod(new_shape), f"can't reshape {self.shape} -> {new_shape}"  # type: ignore  # mypy cannot resolve, all ints here
-     new_view, extra = _reshape(self.views[-1], new_shape)
-     if extra: self.views.append(new_view)
-     else: self.views[-1] = new_view
-     return self
-
-   def permute(self, axis: Tuple[int, ...]):
-     assert all(isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis), f"invalid permute {axis} for {self.shape}"
-     assert len(set(axis)) == len(axis) and len(axis) == len(self.shape), f"can't permute {self.shape} with {axis}"
-     self.views[-1] = View(tuple([self.views[-1].shape[a] for a in axis]), tuple([self.views[-1].strides[a] for a in axis]), self.views[-1].offset, tuple([self.views[-1].mask[a] for a in axis]) if self.views[-1].mask is not None else None)
-     return self
+   # *** under this line are the movement ops ***

-   # except for the negative case, you can build this from the others. invertible in the negative case
-   def stride(self, mul: Tuple[int, ...]):
-     assert all(isinstance(x, int) and x != 0 for x in mul), f"invalid stride {mul} for {self.shape}"
-     strides = tuple([z*m for z,m in zip(self.views[-1].strides, mul)])
-     new_shape = tuple([(s+(abs(m)-1))//abs(m) for s,m in zip(self.views[-1].shape, mul)])
-     offset = sum([(s-1)*z for s,z,m in zip(self.views[-1].shape, self.views[-1].strides, mul) if m < 0])
-     mask = tuple([(((mx if m > 0 else s-my)+(abs(m)-1))//abs(m), ((my if m > 0 else s-mx)+(abs(m)-1))//abs(m)) for (mx,my),s,m in zip(self.views[-1].mask, self.views[-1].shape, mul)]) if self.views[-1].mask is not None else None
-     self.views[-1] = View(new_shape, strides, self.views[-1].offset + offset, mask)
-     return self
+   def pad(self, arg: Tuple[Tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].pad(arg), ))
+   def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].shrink(arg), ))
+   def expand(self, new_shape: Tuple[sint, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].expand(new_shape), ))
+   def permute(self, axis: Tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].permute(axis), ))
+   def stride(self, mul: Tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].stride(mul), ))

-   # returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
- def get_contraction(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Optional[List[List[int]]]:
-   # Pre-allocate all groups.
-   axis_groups: List[List[int]] = [[] for _ in range(len(new_shape))]
-   # Index for new_shape and axis_groups.
-   i: int = 0
-   old_shape_i: int = 0
-   while old_shape_i < len(old_shape):
-     # 1s exist in new_shape only will lead to empty axes group creations.
-     if new_shape[i] == 1 and old_shape[old_shape_i] != 1:
-       if i < len(new_shape) - 1: i += 1
-     else:
-       axis_groups[i].append(old_shape_i)
-       axis_group_size = prod([old_shape[x] for x in axis_groups[i]])
-       # Move to next axes group if total size of all dimensions match.
-       if axis_group_size == new_shape[i]:
-         if i < len(new_shape) - 1: i += 1
-       elif axis_group_size > new_shape[i]: return None
-     old_shape_i += 1
-   return axis_groups
+   def reshape(self, new_shape: Tuple[sint, ...]) -> ShapeTracker:
+     if getenv("MERGE_VIEW", 1) and (new_view := self.views[-1].reshape(new_shape)) is not None: return ShapeTracker(self.views[0:-1] + (new_view,))
+     return ShapeTracker(self.views + (View.create(new_shape), ))
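
The net effect of this rewrite: ShapeTracker is now an immutable frozen dataclass wrapping a tuple of Views, every movement op returns a new tracker instead of mutating the view list in place, and the per-View math moved to the new tinygrad/shape/view.py. A minimal sketch of the resulting API, using only names visible in the hunk above:

    # sketch of the 0.9.0 ShapeTracker API shown in the diff
    from tinygrad.shape.shapetracker import ShapeTracker

    st = ShapeTracker.from_shape((4, 4))      # a single contiguous View
    st = st.permute((1, 0)).reshape((16,))    # movement ops return new ShapeTrackers
    print(len(st.views))                      # 2: a transposed view can't merge with the reshape
    idx, valid = st.expr_idxs()               # symbolic index/validity expressions for codegen
    print(st.real_strides())                  # (None,): the flattened axis isn't simply strided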