tinygrad 0.7.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +6 -0
- tinygrad/codegen/kernel.py +572 -83
- tinygrad/codegen/linearizer.py +415 -395
- tinygrad/codegen/uops.py +415 -0
- tinygrad/device.py +183 -0
- tinygrad/dtype.py +113 -0
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +195 -0
- tinygrad/engine/realize.py +191 -0
- tinygrad/engine/schedule.py +362 -0
- tinygrad/engine/search.py +196 -0
- tinygrad/{mlops.py → function.py} +76 -55
- tinygrad/helpers.py +196 -89
- tinygrad/lazy.py +210 -371
- tinygrad/multi.py +169 -0
- tinygrad/nn/__init__.py +202 -22
- tinygrad/nn/datasets.py +7 -0
- tinygrad/nn/optim.py +112 -32
- tinygrad/nn/state.py +136 -39
- tinygrad/ops.py +119 -202
- tinygrad/renderer/__init__.py +61 -0
- tinygrad/renderer/assembly.py +276 -0
- tinygrad/renderer/cstyle.py +353 -166
- tinygrad/renderer/llvmir.py +150 -138
- tinygrad/runtime/autogen/amd_gpu.py +1900 -0
- tinygrad/runtime/autogen/comgr.py +865 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5761 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33328 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/hip_comgr.py +47 -0
- tinygrad/runtime/driver/hsa.py +143 -0
- tinygrad/runtime/graph/clang.py +38 -0
- tinygrad/runtime/graph/cuda.py +81 -0
- tinygrad/runtime/graph/hcq.py +143 -0
- tinygrad/runtime/graph/hsa.py +171 -0
- tinygrad/runtime/graph/metal.py +75 -0
- tinygrad/runtime/ops_amd.py +564 -0
- tinygrad/runtime/ops_clang.py +24 -77
- tinygrad/runtime/ops_cuda.py +175 -89
- tinygrad/runtime/ops_disk.py +56 -33
- tinygrad/runtime/ops_gpu.py +92 -95
- tinygrad/runtime/ops_hsa.py +278 -0
- tinygrad/runtime/ops_llvm.py +39 -60
- tinygrad/runtime/ops_metal.py +92 -74
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +630 -0
- tinygrad/runtime/ops_python.py +204 -0
- tinygrad/shape/shapetracker.py +86 -254
- tinygrad/shape/symbolic.py +166 -141
- tinygrad/shape/view.py +296 -0
- tinygrad/tensor.py +2619 -448
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
- tinygrad-0.9.0.dist-info/METADATA +227 -0
- tinygrad-0.9.0.dist-info/RECORD +60 -0
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/assembly.py +0 -190
- tinygrad/codegen/optimizer.py +0 -379
- tinygrad/codegen/search.py +0 -72
- tinygrad/graph.py +0 -83
- tinygrad/jit.py +0 -57
- tinygrad/nn/image.py +0 -100
- tinygrad/renderer/assembly_arm64.py +0 -169
- tinygrad/renderer/assembly_ptx.py +0 -98
- tinygrad/renderer/wgsl.py +0 -53
- tinygrad/runtime/lib.py +0 -113
- tinygrad/runtime/ops_cpu.py +0 -51
- tinygrad/runtime/ops_hip.py +0 -82
- tinygrad/runtime/ops_shm.py +0 -29
- tinygrad/runtime/ops_torch.py +0 -30
- tinygrad/runtime/ops_webgpu.py +0 -45
- tinygrad-0.7.0.dist-info/METADATA +0 -212
- tinygrad-0.7.0.dist-info/RECORD +0 -40
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
tinygrad/codegen/optimizer.py
DELETED
@@ -1,379 +0,0 @@
-from typing import Tuple, List, cast
-import itertools, math
-from tinygrad.helpers import DEBUG, prod, getenv, ImageDType, dtypes
-from tinygrad.ops import ReduceOps, BinaryOps, UnaryOps, LazyOp
-from tinygrad.codegen.kernel import Kernel, LocalBuffer
-from tinygrad.lazy import LazyBuffer
-from tinygrad.shape.shapetracker import ShapeTracker, View
-
-class OptimizedKernel(Kernel):
-  def process(self) -> None:
-    if hasattr(self, "sts"): return # already processed
-    super().process()
-
-    # move all reduce axes to the end
-    reduce = list(enumerate(zip(self.full_shape, self.sts[0].shape)))
-    permute = tuple([i for i,(s,n) in reduce if s == n] + [i for i,(s,n) in reduce if s != n])
-    self.reshape_and_permute(None, permute)
-
-    # group simplifies
-    self.simplify_ones()
-    self.simplify_merge_adjacent()
-
-  # ******************** base simplifiers ********************
-
-  # apply reshape and permute to all shapetrackers
-  def reshape_and_permute(self, new_shape_fxn, axis):
-    for st in self.sts:
-      if new_shape_fxn is not None: st.reshape(tuple(new_shape_fxn(st.shape)))
-      if axis is not None: st.permute(tuple(axis))
-
-  # drops the final dimension
-  def upcast(self):
-    assert self.full_shape[-1] != 1, "can't upcast a dimension with size 1"
-    self.upcasted += 1
-
-  # axis : the axis to pull from
-  # amount : the amount to take
-  # top : if you want to pull that amount from the top
-  # insert_before : place to insert the new stuff
-  def shift_to(self, axis, amount, top=False, insert_before=None):
-    if insert_before is None: insert_before = self.shape_len
-    move_axis = axis if top else axis+1
-    if move_axis < insert_before: insert_before += 1
-    self.reshape_and_permute(
-      lambda x: list(x[0:axis]) + (([amount, x[axis]//amount] if top else [x[axis]//amount, amount]) if x[axis] > 1 else [1,1]) + list(x[axis+1:]),
-      [i for i in range(insert_before) if i != move_axis] + [move_axis] + [i for i in range(insert_before, self.shape_len+1) if i != move_axis])
-
-  # ******************** complex simplifiers ********************
-
-  def simplify_ones(self):
-    # remove places where the shape is all ones
-    # TODO: this should be factored in to multi shape stride
-    if self.shape_len == 0: return
-    all_ones = [s==1 for s in self.full_shape]
-    self.local_dims -= sum(all_ones[self.first_reduce-self.local_dims:self.first_reduce])
-    self.upcasted -= sum(all_ones[self.shape_len-self.upcasted:])
-    self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None)
-
-  def simplify_merge_adjacent(self):
-    if self.shape_len == 0: return
-    shapes, strides = [x.shape for x in self.sts], [x.real_strides() for x in self.sts]
-
-    # merge dimensions if we can, multi get_shape_strides
-    # TODO: does this always preserve the reduce dimension, NO
-    # TODO: move this into shapetracker, with tests!
-    rets = [[(shapes[j][0], strides[j][0])] for j in range(len(shapes))]
-    for i in range(1, len(shapes[0])):
-      can_merge = []
-      for j in range(len(shapes)):
-        # TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case
-        can_merge.append(strides[j][i] is not None and ((strides[j][i] != 0 and rets[j][-1][1] == shapes[j][i]*cast(int, strides[j][i])) or (strides[j][i] == 0 and rets[j][-1][1] == 0)))
-      # more can merge than this
-      mergeable = all(can_merge) and i != self.first_reduce
-      for j in range(len(shapes)):
-        if mergeable: rets[j][-1] = (rets[j][-1][0] * shapes[j][i], strides[j][i])
-        else: rets[j].append((shapes[j][i], strides[j][i]))
-
-    # do the reshapes
-    for i,x in enumerate(rets): self.sts[i].reshape(tuple([y[0] for y in x]))
-
-  # ******************** GPU simplifiers ********************
-  def _limit_size(self, x: Tuple[int], max_size: List) -> Tuple[int, ...]:
-    new_shape,dims = list(x), len(x)
-    for i in range(dims):
-      next_idx = (i + 1) % dims
-      while new_shape[i] > max_size[i]:
-        new_shape[i] = new_shape[i] // 2
-        if (new_shape[next_idx] <= max_size[next_idx]):
-          new_shape[next_idx] = new_shape[next_idx] * 2
-        else:
-          next_idx = (next_idx + 1) % dims
-          new_shape[next_idx] = new_shape[next_idx] * 2
-    return tuple(new_shape)
-
-  def limit_global_dims(self, limit: int, global_max: List[int], local_max: List[int]):
-    # sometimes, there's more dimensions than len(self.lang.gid).
-    # compact all the dimensions into the first
-    # NOTE: this might make multiview shapetrackers
-    if (self.first_reduce-self.local_dims) > limit:
-      num_to_merge = ((self.first_reduce-self.local_dims) - limit)+1
-      self.reshape_and_permute(lambda x: (prod(x[0:num_to_merge]),)+x[num_to_merge:], None)
-      if DEBUG >= 3: print("reshaped to", self.full_shape, "due to too many global dimensions")
-    # Check the global allocation limit, current the global_size will be flipped during codegen
-    # and then padded right with 1s if its length < 3 which makes this part a bit awkward to write
-    global_dims = self.first_reduce-self.local_dims
-    if global_dims > 0:
-      if global_max:
-        tmp = global_max[:global_dims] + (local_max[:self.local_dims] if local_max else [])
-        if max(global_max) < max(self.full_shape[:global_dims]): self.reshape_and_permute(lambda x: self._limit_size(x, tmp + [math.inf] * (len(self.full_shape)-len(tmp))), None)
-        assert max(global_max) >= max(self.full_shape[:global_dims]), f"device max allocation {max(self.full_shape[:global_dims])} exceeds global dim maximum {max(global_max)}"
-      for i in range(global_dims-1):
-        if self.full_shape[i] > global_max[i]:
-          order = list(range(len(self.full_shape)))
-          order[i], order[global_dims-1] = order[global_dims-1], order[i]
-          self.reshape_and_permute(None, order)
-          if DEBUG >= 3: print("permuted global dim", order, "due to allocation exceeds global limit")
-
-  def alias_buffer(self, i, pattern):
-    assert len(pattern) == len(self.sts[i].shape), f"must include a pattern for each shape {pattern} {self.sts[i].shape}"
-
-    bst = 1
-    real_strides = self.sts[i].real_strides()
-    shp, stride = [(s if p != 0 else 1) for s,p in zip(self.sts[i].shape, pattern)], [0]*len(pattern)
-    for priority in range(1, max(pattern)+1): # priority. 0 is non local and ignored
-      for j,p in enumerate(pattern):
-        if priority == p and real_strides[j] != 0:
-          stride[j] = bst
-          bst *= shp[j]
-
-    self.sts.append(ShapeTracker(tuple(shp), [View(tuple(shp), tuple(stride))]))
-    self.bufs.append(LocalBuffer(name=f"ldata{i}", size=self.sts[-1].size()))
-    if DEBUG >= 4: print("aliasing buffer", self.sts[i])
-    self.local_alias[i] = self.bufs[-1]
-
-  # ******************** high level optimizers ********************
-
-  def apply_auto_opt(self, x):
-    for axis, amt, typ in x:
-      if axis is None or amt == 1: continue
-      if typ == "R":
-        typ = "U"
-        axis += self.first_reduce
-      assert self.full_shape[axis] % amt == 0, "no longer valid shift"
-      if typ == "U":
-        self.shift_to(axis, amt)
-        self.upcast()
-      elif typ == "L":
-        self.shift_to(axis, amt, insert_before=self.first_reduce)
-        self.local_dims += 1
-    self.simplify_ones()
-
-  def required_optimizations(self, early_only=False):
-    for buf_index,buf in enumerate(self.bufs):
-      unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) if self.sts[buf_index].shape[i]%4 == 0]
-      if (not early_only or buf in self.earlybufs) and self.bufs[buf_index].dtype.__class__ is ImageDType:
-        assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}"
-        if all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes:
-          self.shift_to(unit_stride_axes_mul_4[0], 4)
-          self.upcast()
-
-  def hand_coded_optimizations(self):
-    self.process()
-
-    # if there's images in the earlybufs, we have to make an axis the 4 loading one
-    self.required_optimizations(early_only=True)
-
-    # simplify
-    self.simplify_ones()
-
-    # should use HIP tensor cores?
-    if getenv("TC", 1) != 0 and self.bufs[0].device == "HIP" and self.reduceop and self.reduceop.op == ReduceOps.SUM and \
-        isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == UnaryOps.CAST and \
-        isinstance(self.reduceop.src[0].src[0], LazyOp) and self.reduceop.src[0].src[0].op == BinaryOps.MUL and \
-        isinstance(self.reduceop.src[0].src[0].src[0], LazyBuffer) and isinstance(self.reduceop.src[0].src[0].src[1], LazyBuffer) and self.opts.has_local and \
-        self.reduceop.src[0].src[0].src[0].dtype == dtypes.half and self.reduceop.src[0].src[0].src[1].dtype == dtypes.half:
-      # HIP tensor cores are 16x16x16
-      buf0 = self.bufs.index(self.reduceop.src[0].src[0].src[0])
-      buf1 = self.bufs.index(self.reduceop.src[0].src[0].src[1])
-      buf0_strides = self.sts[buf0].real_strides()
-      buf1_strides = self.sts[buf1].real_strides()
-      axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides) if s == 0 and self.full_shape[i]%16 == 0 and i < self.first_reduce]
-      axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides) if s == 0 and self.full_shape[i]%16 == 0 and i < self.first_reduce]
-      if axis_buf0 and axis_buf1 and self.full_shape[self.first_reduce]%8 == 0 and (self.shape_len-self.first_reduce) == 1:
-        if DEBUG >= 3: print("HIP TENSOR CORES", axis_buf0, axis_buf1)
-        self.use_tensor_cores = getenv("TC", 1) == 1 # TC=2 will do the shape ops without the WMMA
-        self.reverse_upcast_dir = True
-
-        # TODO: select axis in smart way
-        s0, s1 = axis_buf0[-1][0], axis_buf1[-1][0]
-        global_count = self.first_reduce
-
-        # upcast first
-        if self.full_shape[self.first_reduce] > 16: self.shift_to(self.first_reduce, 16)
-        self.upcast()
-
-        # 2 locals
-        self.shift_to(s1, 16, insert_before=self.first_reduce) # axis 2
-        self.shift_to(s0, 16, insert_before=self.first_reduce) # axis 3
-        self.local_dims += 1
-
-        # output shape
-        self.shift_to(self.first_reduce-2, 8)
-        self.upcast()
-
-        # split local dim
-        self.shift_to(self.first_reduce-1, 8, insert_before=self.first_reduce) # axis 3
-
-        # final global upcast
-        for ax in [s1, s0]:
-          for upc in [4,3,2]:
-            if self.full_shape[ax]%upc == 0:
-              self.shift_to(ax, upc)
-              self.upcast()
-              break
-
-        # alias buffer
-        alias_pattern = [0]*global_count + [0,0,1] + [0] * (self.shape_len-self.upcasted-self.first_reduce) + [2,3] + [0]*(self.upcasted-2)
-        self.alias_buffer(buf0, alias_pattern)
-        self.alias_buffer(buf1, alias_pattern)
-
-        # two fake locals
-        if self.use_tensor_cores:
-          self.local_dims += 2
-          self.exclude_local_upcast += 2
-
-        # early exit
-        return
-
-    # should use METAL tensor cores?
-    # first, confirm it's a straightforward mulacc on a device with real locals
-    tensor_cores_allowed = getenv("TC", 1) != 0 and (getenv("TC", 1) == 2 or (self.bufs[0].device == "METAL" and getenv("CI", "") != "true"))
-    if tensor_cores_allowed and self.reduceop and self.reduceop.op == ReduceOps.SUM and \
-        isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == BinaryOps.MUL and \
-        isinstance(self.reduceop.src[0].src[0], LazyBuffer) and isinstance(self.reduceop.src[0].src[1], LazyBuffer) and self.opts.has_local:
-      # METAL tensor cores are 8x8x8, with 2 elements per thread in the 32 thread warp
-      buf0 = self.bufs.index(self.reduceop.src[0].src[0])
-      buf1 = self.bufs.index(self.reduceop.src[0].src[1])
-      buf0_strides = self.sts[buf0].real_strides()
-      buf1_strides = self.sts[buf1].real_strides()
-      axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides) if s == 0 and self.full_shape[i]%8 == 0 and i < self.first_reduce]
-      axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides) if s == 0 and self.full_shape[i]%8 == 0 and i < self.first_reduce]
-      if axis_buf0 and axis_buf1 and self.full_shape[self.first_reduce]%8 == 0 and (self.shape_len-self.first_reduce) == 1:
-        if DEBUG >= 3: print("METAL TENSOR CORES", axis_buf0, axis_buf1)
-        self.use_tensor_cores = getenv("TC", 1) == 1 # TC=2 will do the shape ops without the WMMA
-
-        # TODO: select axis in smart way
-        s0, s1 = axis_buf0[-1][0], axis_buf1[-1][0]
-        global_count = self.first_reduce
-
-        # upcast first
-        if self.full_shape[self.first_reduce] > 8: self.shift_to(self.first_reduce, 8)
-        self.upcast()
-
-        # 2 locals
-        self.shift_to(s1, 8, insert_before=self.first_reduce) # axis 2
-        self.shift_to(s0, 8, insert_before=self.first_reduce) # axis 3
-
-        # permuted+upcast for tensor cores
-        self.shift_to(global_count, 4, insert_before=self.first_reduce)
-        self.shift_to(global_count+1, 4, insert_before=self.first_reduce)
-        self.shift_to(self.first_reduce-1, 2)
-        self.upcast()
-
-        # final global upcast
-        for ax in [s1, s0]:
-          for upc in [4,3,2]:
-            if self.full_shape[ax]%upc == 0:
-              self.shift_to(ax, upc)
-              self.upcast()
-              break
-
-        # alias buffer
-        self.local_dims = self.first_reduce - global_count
-        alias_pattern = [0]*global_count + [2] * self.local_dims + [0] * (self.shape_len-self.upcasted-self.first_reduce) + [1,1] + [3] * (self.upcasted-2)
-        self.alias_buffer(buf0, alias_pattern)
-        self.alias_buffer(buf1, alias_pattern)
-
-        # very late upcast to run group at the same time. only if actually using real tensor cores, otherwise local isn't a simdgroup
-        if self.use_tensor_cores and self.full_shape[s0] % 2 == 0:
-          self.shift_to(s0, 2, insert_before=self.first_reduce-self.local_dims)
-          self.local_dims += 1
-          self.exclude_local_upcast += 1
-
-        # early exit
-        return
-
-    if self.opts.has_local and all(isinstance(s, int) for s in self.sts[0].shape[:self.first_reduce]):
-      # are we grouping? (requires local shape support)
-      if not self.float4_axis(0) and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048:
-        # TODO: use 1024 if it's allowed in a smarter way
-        for sz in (([256, 16]) if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
-          if all(st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts):
-            self.shift_to(self.first_reduce, sz, top=True, insert_before=self.first_reduce + len(self.group_for_reduce))
-            self.group_for_reduce.append(sz)
-            break
-
-      # are we upcasting in mid reduce? (only for images)
-      if self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1:
-        axes = self.sts[0].unit_stride_axes()
-        assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
-        if self.sts[0].shape[axes[0]]%4 == 0:
-          self.shift_to(axes[0], 4, insert_before=self.first_reduce + len(self.group_for_reduce)) # insert at the end of the grouped axis
-          self.group_for_reduce.append(4)
-
-    # now do everything required
-    self.required_optimizations()
-
-    # simplify (sets first_reduce)
-    self.simplify_ones()
-
-    # use more opencl indexing if the output buffer is an image and we have room
-    if self.bufs[0].dtype.name.startswith('image') and self.first_reduce+len(self.group_for_reduce) < 3:
-      base_shape = self.bufs[0].dtype.shape
-      if (base_shape[0]*base_shape[1]) % self.sts[0].shape[0] == 0 and self.sts[0].shape[0]//base_shape[0] != 0:
-        if DEBUG >= 4: print("split opencl", base_shape, self.sts[0].shape)
-        self.reshape_and_permute(lambda x: [base_shape[0], x[0]//base_shape[0]]+list(x[1:]), None)
-        self.simplify_ones()
-
-    # no more opt if we are grouping
-    if self.group_for_reduce: return
-
-    # no more opt if there's non ints in any shapes
-    # TODO: this is due to a bug. repro by commenting this one while running GPT-2 with the JIT
-    if self.has_variable_shape(): return
-
-    # **** below this line need to be optional and benchmarked ****
-
-    # potentially do more upcasts of non reduce axes based on a heuristic
-    upcasted_axis = set()
-    while prod(self.sts[0].shape[:self.first_reduce]) >= 1024:
-      xb_choices = []
-      for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce
-        # if we haven't upcasted it, it's not symbolic, it mods, and some buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
-        if axis not in upcasted_axis and isinstance(self.full_shape[axis], int) and self.full_shape[axis]%upcast_amount == 0 and any(self.sts[buf_index].views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index in range(len(self.sts))):
-          xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in self.sts), sum(st.views[-1].strides[axis] for st in self.sts), axis, upcast_amount))
-      if xb_choices:
-        xb_choices = sorted(xb_choices)
-        if DEBUG >= 4: print(f"float4 merging axis : {xb_choices}")
-        self.shift_to(xb_choices[0][2], amount=xb_choices[0][3])
-        self.upcast()
-        self.simplify_ones()
-        upcasted_axis.add(xb_choices[0][2])
-      else:
-        break
-
-    # if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS
-    if self.first_reduce < (self.shape_len-self.upcasted) and (len(list(self.shape_offsets(self.full_buf_index))) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))):
-      if (s:=self.full_unupcasted_shape[-1]) <= 32 and isinstance(s, int): # NOTE: cannot loop unroll symbolic axis
-        self.upcast()
-        # if it's small, upcast a second reduce dimension too
-        if self.first_reduce < (self.shape_len-self.upcasted) and s <= 3 and self.full_unupcasted_shape[-1] <= 3: self.upcast()
-      else:
-        for splits in [4]:
-          if self.full_unupcasted_shape[-1]%splits == 0:
-            self.shift_to(len(self.full_unupcasted_shape)-1, splits, insert_before=len(self.full_unupcasted_shape))
-            self.upcast()
-            break
-
-    # if nothing at all is upcasted and it's easy to, do an upcast
-    # TODO: this is breaking the tests
-    for splits in [4]:
-      if self.upcasted == 0 and self.full_unupcasted_shape and self.full_unupcasted_shape[-1] % splits == 0:
-        self.shift_to(len(self.full_unupcasted_shape)-1, splits, insert_before=len(self.full_unupcasted_shape))
-        self.upcast()
-
-    # **** local groups ****
-
-    if self.opts.has_local:
-      for axis in range(self.first_reduce - self.local_dims - 1, -1, -1):
-        local_size = prod(self.full_shape[self.first_reduce-self.local_dims:self.first_reduce])
-        if self.full_shape[axis] == 1: continue
-        last_try = self.local_dims == 0 and axis == 0
-        if any(self.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(self.sts))) or last_try:
-          for sz in [x for x in (([32] if last_try else []) + [16,8,4,3]) if self.full_shape[axis] % x == 0 and local_size*x <= 128]:
-            self.shift_to(axis, sz, insert_before=self.first_reduce-self.local_dims)
-            self.local_dims += 1
-            break
-        if self.local_dims >= 3: break
-    self.simplify_ones()
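Most of the hand-coded heuristics above reduce to two primitives: `shift_to`, which splits an axis and moves the newly created piece to a chosen position, and `upcast`, which marks the trailing axis for unrolling. As a reading aid, here is a minimal, self-contained sketch of the shape arithmetic `shift_to` performs; the `shift_shape` helper and the example shapes are made up for illustration, and it operates on a plain tuple rather than the kernel's ShapeTrackers.

```python
from typing import Optional, Tuple

def shift_shape(shape: Tuple[int, ...], axis: int, amount: int, top: bool = False,
                insert_before: Optional[int] = None) -> Tuple[int, ...]:
  # mirror of OptimizedKernel.shift_to's reshape_and_permute call, on a bare tuple
  if insert_before is None: insert_before = len(shape)
  move_axis = axis if top else axis + 1
  if move_axis < insert_before: insert_before += 1
  # split shape[axis] into two axes, one of which has size `amount`
  split = ([amount, shape[axis] // amount] if top else [shape[axis] // amount, amount]) if shape[axis] > 1 else [1, 1]
  new_shape = list(shape[:axis]) + split + list(shape[axis + 1:])
  # move the freshly created axis so it sits just before insert_before
  order = [i for i in range(insert_before) if i != move_axis] + [move_axis] + \
          [i for i in range(insert_before, len(shape) + 1) if i != move_axis]
  return tuple(new_shape[i] for i in order)

# keep the split-off axis of size 4 at the end (what an "upcast" does before unrolling)
assert shift_shape((8, 9, 64), axis=2, amount=4) == (8, 9, 16, 4)
# or move the new axis so it sits just before index 2 (how local/group axes get positioned)
assert shift_shape((8, 9, 64), axis=2, amount=4, insert_before=2) == (8, 9, 4, 16)
```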
tinygrad/codegen/search.py
DELETED
@@ -1,72 +0,0 @@
-from typing import Callable
-import time
-from tinygrad.codegen.linearizer import Linearizer
-from tinygrad.helpers import DEBUG, prod, getenv
-
-UPCASTS = [1,2,3,4,5,6,7,8]
-LOCALS = [1,2,3,4,5,6,7,8,16,24,32]
-def kernel_optimize_search(k:Linearizer, create_k:Callable[[], Linearizer], to_prg, baseline):
-  import nevergrad as ng
-  def opt(x):
-    try:
-      k = create_k()
-      k.process()
-      k.apply_auto_opt(x)
-      prg = to_prg(k)
-      first_tm = prg.exec(k.bufs, force_wait=True, optimizing=True)
-      if baseline*5 < first_tm*1000: return first_tm*1000 # very slow
-      tm = min([first_tm]+[prg.exec(k.bufs, force_wait=True, optimizing=True) for _ in range(2)])*1000
-      return tm
-    except Exception:
-      if DEBUG >= 3:
-        import traceback
-        traceback.print_exc()
-      return 10000_000 # 10000 seconds is infinity
-  opts = []
-  for i in range(k.first_reduce):
-    # TODO: the upcast always happen first, you might want to reverse this?
-    # TODO: the order of the locals might improve things too
-    opts.append(ng.p.TransitionChoice([(i,s,"U") for s in UPCASTS if k.full_shape[i]%s == 0]))
-    opts.append(ng.p.TransitionChoice([(i,s,"L") for s in LOCALS if k.full_shape[i]%s == 0]))
-  for i in range(k.shape_len-k.first_reduce):
-    opts.append(ng.p.TransitionChoice([(i,s,"R") for s in UPCASTS if k.full_shape[k.first_reduce+i]%s == 0]))
-  if not opts: return "BASELINE"
-  search_space = prod([len(x.choices) for x in opts])
-  st = time.perf_counter()
-  optimizer = ng.optimizers.NGOpt(parametrization=ng.p.Tuple(*opts), budget=min(search_space, 200))
-  recommendation = optimizer.minimize(opt)
-  et = time.perf_counter() - st
-  if DEBUG >= 1: print(f"optimizer({et:6.2f} s to search) space {search_space:8d} with tm {recommendation.loss:5.2f} ms vs baseline {baseline:5.2f} ms, a {baseline/recommendation.loss:5.2f}x gain : {k.colored_shape()}")
-  return recommendation.value if recommendation.loss < baseline else "BASELINE"
-
-# optimization
-global_db = None
-def kernel_optimize(k:Linearizer, create_k:Callable[[], Linearizer], to_prg):
-  global global_db
-
-  k.process()
-  skey = str(k.key)
-
-  if getenv("KOPT") == 2 and global_db is None:
-    import shelve
-    global_db = shelve.open("/tmp/kopt_cache")
-
-  if global_db is not None and skey in global_db:
-    choice = global_db[skey]
-  elif k.has_variable_shape():
-    # don't optimize variable shapes
-    choice = "BASELINE"
-  else:
-    # get baseline
-    def get_baseline():
-      k = create_k()
-      k.hand_coded_optimizations()
-      prg = to_prg(k)
-      return min([prg.exec(k.bufs, force_wait=True, optimizing=True) for _ in range(5)])*1000
-    choice = kernel_optimize_search(k, create_k, to_prg, get_baseline())
-    if global_db is not None:
-      global_db[skey] = choice
-      global_db.sync()
-
-  if choice == "BASELINE": k.hand_coded_optimizations()
-  else: k.apply_auto_opt(choice)
tinygrad/graph.py
DELETED
@@ -1,83 +0,0 @@
-import os, atexit, itertools
-try:
-  import networkx as nx # type: ignore
-except ImportError:
-  nx = None # graph won't work
-from collections import defaultdict
-from typing import Dict, List, Optional, TYPE_CHECKING
-from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, TernaryOps, Op, OpType, LazyOp
-from tinygrad.helpers import GRAPH, GRAPHPATH, PRUNEGRAPH, DEBUG, GlobalCounters
-from tinygrad.runtime.lib import RawConst
-
-if TYPE_CHECKING: from tinygrad.lazy import LazyBuffer
-
-# **** debugging and graphing ****
-
-G = nx.DiGraph() if nx is not None else None
-cnts: Dict[OpType, int] = defaultdict(int)
-if DEBUG >= 2:
-  def print_globalcounters():
-    if GlobalCounters.time_sum_s == 0: return
-    print(f"avg: {GlobalCounters.global_ops*1e-9/GlobalCounters.time_sum_s:8.2f} GFLOPS {GlobalCounters.global_mem*1e-9/GlobalCounters.time_sum_s:8.2f} GB/s",
-          f"{' '*10}total: {GlobalCounters.kernel_count:5d} kernels {GlobalCounters.global_ops*1e-9:8.2f} GOPS {GlobalCounters.global_mem*1e-9:8.2f} GB {GlobalCounters.time_sum_s*1e3:8.2f} ms")
-  atexit.register(print_globalcounters)
-if GRAPH:
-  def save_graph_exit():
-    for k,v in cnts.items(): print(k, v)
-    if PRUNEGRAPH: prune_graph()
-    print("saving", G)
-    nx.drawing.nx_pydot.write_dot(G, f'{GRAPHPATH}.dot')
-    # -Gnslimit=100 can make it finish, but you won't like results
-    os.system(f'dot -Tsvg {GRAPHPATH}.dot -o {GRAPHPATH}.svg')
-  atexit.register(save_graph_exit)
-
-node_count = 0
-def nm(x):
-  global node_count
-  if not hasattr(x, 'node_id'):
-    setattr(x, 'node_id', node_count)
-    node_count += 1
-  return x.node_id
-
-def get_sop(op: List[Op]):
-  if len(op) <= 2: return '.'.join([str(y).split(".")[1] for y in op][::-1])
-  if len(op) <= 4: return '.'.join([str(y).split(".")[1][0:3] for y in op][::-1])
-  return str(len(op))
-
-def str_dtype(dtyp):
-  ret = str(dtyp)[7:]
-  return "" if ret == 'float' else f"\n{ret}"
-
-def log_op(ret: 'LazyBuffer', ast: LazyOp, show_graph: Optional[bool] = None, phantom=False):
-  if show_graph is None: show_graph = bool(GRAPH)
-  if not DEBUG and not show_graph: return
-  op: List[Op] = [x.op for x in ast.get_lazyops()]
-  inp: List['LazyBuffer'] = [x for x in ast.buffers if not isinstance(x.realized, RawConst) or GRAPH > 1]
-  oporder = [LoadOps, TernaryOps, ReduceOps, BinaryOps, UnaryOps, MovementOps]
-  optype = type(sorted(op, key=lambda x: oporder.index(type(x)))[0])
-  cnts[optype] += 1
-  if DEBUG >= 6: print(f"{op} : {', '.join([f'{x.shape}-<{nm(x)}>' for x in inp])} -> {ret.shape}-<{nm(ret)}>")
-  if show_graph:
-    top_colors = {LoadOps: '#FFFF80', UnaryOps: "#c0c0c0", ReduceOps: "#8080ff", BinaryOps: "#c0c0c0", MovementOps: "#80ff80", TernaryOps: "#ff8080"}
-    dashed = (optype == LoadOps and hasattr(ret, "_backing")) or (hasattr(ret, "st") and not ret.st.contiguous) # type: ignore
-
-    for x in inp:
-      G.add_edge(nm(x), nm(ret), label=get_sop(op), color='#00000060' if phantom else 'black')
-      if 'label' not in G.nodes[nm(x)]:
-        G.nodes[nm(x)]['label'] = str(x.shape)+str_dtype(ret.dtype)
-    if nm(ret) not in G.nodes: G.add_node(nm(ret))
-
-    G.nodes[nm(ret)]['label'] = (str(set(x.shape for x in inp))+"\n"+str(ret.shape) if optype == ReduceOps else str(ret.shape))+str_dtype(ret.dtype)
-    G.nodes[nm(ret)]['fillcolor'] = (top_colors[optype] + ('60' if phantom else ('80' if dashed else str()))) if optype in top_colors else "#ffffff"
-    G.nodes[nm(ret)]['color'] = 'white' if phantom else 'black'
-    G.nodes[nm(ret)]['style'] = ('filled, dashed' if dashed else 'filled')
-    G.nodes[nm(ret)]['prunable'] = optype in [LoadOps, MovementOps]
-
-# prune movementops and loadops
-def prune_graph():
-  dead_nodes = []
-  for n in G.nodes:
-    if 'prunable' in G.nodes[n] and G.nodes[n]['prunable']:
-      G.add_edges_from([(x, y) for (x,_),(_,y) in itertools.product(G.in_edges(n), G.out_edges(n))])
-      dead_nodes.append(n)
-  G.remove_nodes_from(dead_nodes)
tinygrad/jit.py
DELETED
@@ -1,57 +0,0 @@
-from typing import Callable, List, Tuple, Any, Dict, cast, Union, Optional
-import functools, itertools
-from tinygrad.helpers import DEBUG, DType, merge_dicts
-from tinygrad.ops import Device
-from tinygrad.tensor import Tensor
-from tinygrad.ops import GlobalCounters, RawBuffer
-from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.shape.symbolic import Variable
-
-JIT_SUPPORTED_DEVICE = ["GPU", "CLANG", "METAL", "CUDA", "HIP", "WEBGPU"]
-
-class TinyJit:
-  def __init__(self, fxn:Callable):
-    self.fxn: Callable = fxn
-    self.cnt: int = 0
-    self.jit_cache: List[Tuple[Callable, List[Optional[RawBuffer]], Dict[Variable, int]]] = []
-    self.ret: Any = None
-    self.input_replace: Dict[Tuple[int, int], Tuple[Union[int, str], ShapeTracker, DType]]= {} # (kernel_number, buffer_number) -> (input_name, expected_shapetracker, expected_type)
-
-  # add support for instance methods
-  def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)
-
-  def __call__(self, *args, **kwargs) -> Any:
-    if Device.DEFAULT not in JIT_SUPPORTED_DEVICE: return self.fxn(*args, **kwargs) # only jit on supported device
-    # NOTE: this cast is needed since although we know realize will create a ".realized" RawBuffer, the type checker doesn't
-    input_rawbuffers: Dict[Union[int, str], Tuple[RawBuffer, ShapeTracker]] = {cast(Union[int, str], k):(cast(RawBuffer, v.realize().lazydata.realized), v.lazydata.st) for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)}
-    assert len(input_rawbuffers) != 0, "no inputs to JIT"
-    assert len(set(input_rawbuffers.values())) == len(input_rawbuffers), "duplicate inputs to JIT"
-    if self.cnt >= 2:
-      var_vals = dict(sorted(merge_dicts([arg.lazydata.st.var_vals for arg in args if isinstance(arg, Tensor)]).items(), key=lambda kv: kv[0].key))
-      for (j,i),(input_name, expected_st, expected_type) in self.input_replace.items():
-        assert input_rawbuffers[input_name][1].views == expected_st.views and input_rawbuffers[input_name][0].dtype == expected_type, f"ShapeTracker.views or type mismatch in JIT, <{input_rawbuffers[input_name][1].views}, {input_rawbuffers[input_name][0].dtype}> != <{expected_st.views}, {expected_type}>"
-        self.jit_cache[j][1][i] = input_rawbuffers[input_name][0]
-      for prg, pargs, variables in self.jit_cache: # type: Callable, List[Optional[RawBuffer]], Dict[Variable, int]
-        for v in (var_vals.keys() & variables.keys()): variables[v] = var_vals[v]
-        prg(pargs, variables, jit=True)
-      for (j,i) in self.input_replace.keys(): self.jit_cache[j][1][i] = None
-    elif self.cnt == 1:
-      GlobalCounters.cache = []
-      self.ret = self.fxn(*args, **kwargs)
-      self.jit_cache = GlobalCounters.cache
-      GlobalCounters.cache = None
-      assert len(self.jit_cache) != 0, "didn't JIT anything!"
-      if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs")
-
-      # get the inputs for replacement
-      for j_,cache in enumerate(self.jit_cache): # type: Tuple[int, Tuple[Callable, List[Optional[RawBuffer]], Dict[Variable, int]]]
-        for i,a in enumerate(cache[1]):
-          if a in [v[0] for v in input_rawbuffers.values()]:
-            self.input_replace[(j_,i)] = [(k, v[1], v[0].dtype) for k,v in input_rawbuffers.items() if v[0] == a][0]
-        #if prg.local_size is None: prg.local_size = prg.optimize_local_size(args, preserve_output=True) # the JIT can optimize local
-      assert set([x[0] for x in self.input_replace.values()]) == set(input_rawbuffers.keys()), "some input tensors not found"
-      for (j,i) in self.input_replace.keys(): self.jit_cache[j][1][i] = None
-    elif self.cnt == 0:
-      self.ret = self.fxn(*args, **kwargs)
-    self.cnt += 1
-    return self.ret
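The counter in `__call__` above gives TinyJit three phases: the first call runs eagerly, the second call records every kernel launched into `jit_cache`, and later calls replay that cache with only the input `RawBuffer`s swapped in, which is why the shapes, dtypes and ShapeTracker views must match on every call. A hedged usage sketch against this 0.7.0-era API; the function body and shapes are made up:

```python
from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit

@TinyJit
def step(x: Tensor) -> Tensor:
  return (x * 2 + 1).realize()   # realize inside the function so its kernels land in the cache

for i in range(4):                 # call 1 runs eagerly, call 2 captures, calls 3+ replay
  out = step(Tensor.randn(4, 4))   # same shape/dtype every call, as the asserts require
  print(i, out.numpy().sum())
```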