tinygrad-0.7.0-py3-none-any.whl → tinygrad-0.9.0-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.
Files changed (77)
  1. tinygrad/__init__.py +6 -0
  2. tinygrad/codegen/kernel.py +572 -83
  3. tinygrad/codegen/linearizer.py +415 -395
  4. tinygrad/codegen/uops.py +415 -0
  5. tinygrad/device.py +183 -0
  6. tinygrad/dtype.py +113 -0
  7. tinygrad/engine/__init__.py +0 -0
  8. tinygrad/engine/graph.py +100 -0
  9. tinygrad/engine/jit.py +195 -0
  10. tinygrad/engine/realize.py +191 -0
  11. tinygrad/engine/schedule.py +362 -0
  12. tinygrad/engine/search.py +196 -0
  13. tinygrad/{mlops.py → function.py} +76 -55
  14. tinygrad/helpers.py +196 -89
  15. tinygrad/lazy.py +210 -371
  16. tinygrad/multi.py +169 -0
  17. tinygrad/nn/__init__.py +202 -22
  18. tinygrad/nn/datasets.py +7 -0
  19. tinygrad/nn/optim.py +112 -32
  20. tinygrad/nn/state.py +136 -39
  21. tinygrad/ops.py +119 -202
  22. tinygrad/renderer/__init__.py +61 -0
  23. tinygrad/renderer/assembly.py +276 -0
  24. tinygrad/renderer/cstyle.py +353 -166
  25. tinygrad/renderer/llvmir.py +150 -138
  26. tinygrad/runtime/autogen/amd_gpu.py +1900 -0
  27. tinygrad/runtime/autogen/comgr.py +865 -0
  28. tinygrad/runtime/autogen/cuda.py +5923 -0
  29. tinygrad/runtime/autogen/hip.py +5909 -0
  30. tinygrad/runtime/autogen/hsa.py +5761 -0
  31. tinygrad/runtime/autogen/kfd.py +812 -0
  32. tinygrad/runtime/autogen/nv_gpu.py +33328 -0
  33. tinygrad/runtime/autogen/opencl.py +1795 -0
  34. tinygrad/runtime/driver/hip_comgr.py +47 -0
  35. tinygrad/runtime/driver/hsa.py +143 -0
  36. tinygrad/runtime/graph/clang.py +38 -0
  37. tinygrad/runtime/graph/cuda.py +81 -0
  38. tinygrad/runtime/graph/hcq.py +143 -0
  39. tinygrad/runtime/graph/hsa.py +171 -0
  40. tinygrad/runtime/graph/metal.py +75 -0
  41. tinygrad/runtime/ops_amd.py +564 -0
  42. tinygrad/runtime/ops_clang.py +24 -77
  43. tinygrad/runtime/ops_cuda.py +175 -89
  44. tinygrad/runtime/ops_disk.py +56 -33
  45. tinygrad/runtime/ops_gpu.py +92 -95
  46. tinygrad/runtime/ops_hsa.py +278 -0
  47. tinygrad/runtime/ops_llvm.py +39 -60
  48. tinygrad/runtime/ops_metal.py +92 -74
  49. tinygrad/runtime/ops_npy.py +9 -0
  50. tinygrad/runtime/ops_nv.py +630 -0
  51. tinygrad/runtime/ops_python.py +204 -0
  52. tinygrad/shape/shapetracker.py +86 -254
  53. tinygrad/shape/symbolic.py +166 -141
  54. tinygrad/shape/view.py +296 -0
  55. tinygrad/tensor.py +2619 -448
  56. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
  57. tinygrad-0.9.0.dist-info/METADATA +227 -0
  58. tinygrad-0.9.0.dist-info/RECORD +60 -0
  59. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
  60. tinygrad/codegen/assembly.py +0 -190
  61. tinygrad/codegen/optimizer.py +0 -379
  62. tinygrad/codegen/search.py +0 -72
  63. tinygrad/graph.py +0 -83
  64. tinygrad/jit.py +0 -57
  65. tinygrad/nn/image.py +0 -100
  66. tinygrad/renderer/assembly_arm64.py +0 -169
  67. tinygrad/renderer/assembly_ptx.py +0 -98
  68. tinygrad/renderer/wgsl.py +0 -53
  69. tinygrad/runtime/lib.py +0 -113
  70. tinygrad/runtime/ops_cpu.py +0 -51
  71. tinygrad/runtime/ops_hip.py +0 -82
  72. tinygrad/runtime/ops_shm.py +0 -29
  73. tinygrad/runtime/ops_torch.py +0 -30
  74. tinygrad/runtime/ops_webgpu.py +0 -45
  75. tinygrad-0.7.0.dist-info/METADATA +0 -212
  76. tinygrad-0.7.0.dist-info/RECORD +0 -40
  77. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
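The listing reflects the 0.9.0 restructuring: scheduling, the JIT and kernel search move under tinygrad/engine/, device and dtype handling get dedicated modules (device.py, dtype.py), mlops.py is renamed to function.py, and the common entry points are re-exported from the package root (tinygrad/__init__.py). A hedged sketch of what that means for user code; the exact set of names re-exported from the root is an assumption:

# Hedged sketch: imports that work after this diff (tinygrad 0.9.0), assuming
# tinygrad/__init__.py re-exports these names. In 0.7.0 the same objects lived at
# tinygrad.tensor.Tensor, tinygrad.jit.TinyJit and tinygrad.helpers.dtypes.
from tinygrad import Tensor, Device, dtypes
from tinygrad.engine.jit import TinyJit  # the replacement for the deleted tinygrad/jit.py

x = Tensor.randn(2, 3)
print(Device.DEFAULT, x.dtype == dtypes.float, x.numpy().shape)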
tinygrad/codegen/optimizer.py DELETED
@@ -1,379 +0,0 @@
-from typing import Tuple, List, cast
-import itertools, math
-from tinygrad.helpers import DEBUG, prod, getenv, ImageDType, dtypes
-from tinygrad.ops import ReduceOps, BinaryOps, UnaryOps, LazyOp
-from tinygrad.codegen.kernel import Kernel, LocalBuffer
-from tinygrad.lazy import LazyBuffer
-from tinygrad.shape.shapetracker import ShapeTracker, View
-
-class OptimizedKernel(Kernel):
-  def process(self) -> None:
-    if hasattr(self, "sts"): return # already processed
-    super().process()
-
-    # move all reduce axes to the end
-    reduce = list(enumerate(zip(self.full_shape, self.sts[0].shape)))
-    permute = tuple([i for i,(s,n) in reduce if s == n] + [i for i,(s,n) in reduce if s != n])
-    self.reshape_and_permute(None, permute)
-
-    # group simplifies
-    self.simplify_ones()
-    self.simplify_merge_adjacent()
-
-  # ******************** base simplifiers ********************
-
-  # apply reshape and permute to all shapetrackers
-  def reshape_and_permute(self, new_shape_fxn, axis):
-    for st in self.sts:
-      if new_shape_fxn is not None: st.reshape(tuple(new_shape_fxn(st.shape)))
-      if axis is not None: st.permute(tuple(axis))
-
-  # drops the final dimension
-  def upcast(self):
-    assert self.full_shape[-1] != 1, "can't upcast a dimension with size 1"
-    self.upcasted += 1
-
-  # axis : the axis to pull from
-  # amount : the amount to take
-  # top : if you want to pull that amount from the top
-  # insert_before : place to insert the new stuff
-  def shift_to(self, axis, amount, top=False, insert_before=None):
-    if insert_before is None: insert_before = self.shape_len
-    move_axis = axis if top else axis+1
-    if move_axis < insert_before: insert_before += 1
-    self.reshape_and_permute(
-      lambda x: list(x[0:axis]) + (([amount, x[axis]//amount] if top else [x[axis]//amount, amount]) if x[axis] > 1 else [1,1]) + list(x[axis+1:]),
-      [i for i in range(insert_before) if i != move_axis] + [move_axis] + [i for i in range(insert_before, self.shape_len+1) if i != move_axis])
-
-  # ******************** complex simplifiers ********************
-
-  def simplify_ones(self):
-    # remove places where the shape is all ones
-    # TODO: this should be factored in to multi shape stride
-    if self.shape_len == 0: return
-    all_ones = [s==1 for s in self.full_shape]
-    self.local_dims -= sum(all_ones[self.first_reduce-self.local_dims:self.first_reduce])
-    self.upcasted -= sum(all_ones[self.shape_len-self.upcasted:])
-    self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None)
-
-  def simplify_merge_adjacent(self):
-    if self.shape_len == 0: return
-    shapes, strides = [x.shape for x in self.sts], [x.real_strides() for x in self.sts]
-
-    # merge dimensions if we can, multi get_shape_strides
-    # TODO: does this always preserve the reduce dimension, NO
-    # TODO: move this into shapetracker, with tests!
-    rets = [[(shapes[j][0], strides[j][0])] for j in range(len(shapes))]
-    for i in range(1, len(shapes[0])):
-      can_merge = []
-      for j in range(len(shapes)):
-        # TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case
-        can_merge.append(strides[j][i] is not None and ((strides[j][i] != 0 and rets[j][-1][1] == shapes[j][i]*cast(int, strides[j][i])) or (strides[j][i] == 0 and rets[j][-1][1] == 0)))
-      # more can merge than this
-      mergeable = all(can_merge) and i != self.first_reduce
-      for j in range(len(shapes)):
-        if mergeable: rets[j][-1] = (rets[j][-1][0] * shapes[j][i], strides[j][i])
-        else: rets[j].append((shapes[j][i], strides[j][i]))
-
-    # do the reshapes
-    for i,x in enumerate(rets): self.sts[i].reshape(tuple([y[0] for y in x]))
-
-  # ******************** GPU simplifiers ********************
-  def _limit_size(self, x: Tuple[int], max_size: List) -> Tuple[int, ...]:
-    new_shape,dims = list(x), len(x)
-    for i in range(dims):
-      next_idx = (i + 1) % dims
-      while new_shape[i] > max_size[i]:
-        new_shape[i] = new_shape[i] // 2
-        if (new_shape[next_idx] <= max_size[next_idx]):
-          new_shape[next_idx] = new_shape[next_idx] * 2
-        else:
-          next_idx = (next_idx + 1) % dims
-          new_shape[next_idx] = new_shape[next_idx] * 2
-    return tuple(new_shape)
-
-  def limit_global_dims(self, limit: int, global_max: List[int], local_max: List[int]):
-    # sometimes, there's more dimensions than len(self.lang.gid).
-    # compact all the dimensions into the first
-    # NOTE: this might make multiview shapetrackers
-    if (self.first_reduce-self.local_dims) > limit:
-      num_to_merge = ((self.first_reduce-self.local_dims) - limit)+1
-      self.reshape_and_permute(lambda x: (prod(x[0:num_to_merge]),)+x[num_to_merge:], None)
-      if DEBUG >= 3: print("reshaped to", self.full_shape, "due to too many global dimensions")
-    # Check the global allocation limit, current the global_size will be flipped during codegen
-    # and then padded right with 1s if its length < 3 which makes this part a bit awkward to write
-    global_dims = self.first_reduce-self.local_dims
-    if global_dims > 0:
-      if global_max:
-        tmp = global_max[:global_dims] + (local_max[:self.local_dims] if local_max else [])
-        if max(global_max) < max(self.full_shape[:global_dims]): self.reshape_and_permute(lambda x: self._limit_size(x, tmp + [math.inf] * (len(self.full_shape)-len(tmp))), None)
-        assert max(global_max) >= max(self.full_shape[:global_dims]), f"device max allocation {max(self.full_shape[:global_dims])} exceeds global dim maximum {max(global_max)}"
-      for i in range(global_dims-1):
-        if self.full_shape[i] > global_max[i]:
-          order = list(range(len(self.full_shape)))
-          order[i], order[global_dims-1] = order[global_dims-1], order[i]
-          self.reshape_and_permute(None, order)
-          if DEBUG >= 3: print("permuted global dim", order, "due to allocation exceeds global limit")
-
-  def alias_buffer(self, i, pattern):
-    assert len(pattern) == len(self.sts[i].shape), f"must include a pattern for each shape {pattern} {self.sts[i].shape}"
-
-    bst = 1
-    real_strides = self.sts[i].real_strides()
-    shp, stride = [(s if p != 0 else 1) for s,p in zip(self.sts[i].shape, pattern)], [0]*len(pattern)
-    for priority in range(1, max(pattern)+1): # priority. 0 is non local and ignored
-      for j,p in enumerate(pattern):
-        if priority == p and real_strides[j] != 0:
-          stride[j] = bst
-          bst *= shp[j]
-
-    self.sts.append(ShapeTracker(tuple(shp), [View(tuple(shp), tuple(stride))]))
-    self.bufs.append(LocalBuffer(name=f"ldata{i}", size=self.sts[-1].size()))
-    if DEBUG >= 4: print("aliasing buffer", self.sts[i])
-    self.local_alias[i] = self.bufs[-1]
-
-  # ******************** high level optimizers ********************
-
-  def apply_auto_opt(self, x):
-    for axis, amt, typ in x:
-      if axis is None or amt == 1: continue
-      if typ == "R":
-        typ = "U"
-        axis += self.first_reduce
-      assert self.full_shape[axis] % amt == 0, "no longer valid shift"
-      if typ == "U":
-        self.shift_to(axis, amt)
-        self.upcast()
-      elif typ == "L":
-        self.shift_to(axis, amt, insert_before=self.first_reduce)
-        self.local_dims += 1
-    self.simplify_ones()
-
-  def required_optimizations(self, early_only=False):
-    for buf_index,buf in enumerate(self.bufs):
-      unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) if self.sts[buf_index].shape[i]%4 == 0]
-      if (not early_only or buf in self.earlybufs) and self.bufs[buf_index].dtype.__class__ is ImageDType:
-        assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}"
-        if all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes:
-          self.shift_to(unit_stride_axes_mul_4[0], 4)
-          self.upcast()
-
-  def hand_coded_optimizations(self):
-    self.process()
-
-    # if there's images in the earlybufs, we have to make an axis the 4 loading one
-    self.required_optimizations(early_only=True)
-
-    # simplify
-    self.simplify_ones()
-
-    # should use HIP tensor cores?
-    if getenv("TC", 1) != 0 and self.bufs[0].device == "HIP" and self.reduceop and self.reduceop.op == ReduceOps.SUM and \
-        isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == UnaryOps.CAST and \
-        isinstance(self.reduceop.src[0].src[0], LazyOp) and self.reduceop.src[0].src[0].op == BinaryOps.MUL and \
-        isinstance(self.reduceop.src[0].src[0].src[0], LazyBuffer) and isinstance(self.reduceop.src[0].src[0].src[1], LazyBuffer) and self.opts.has_local and \
-        self.reduceop.src[0].src[0].src[0].dtype == dtypes.half and self.reduceop.src[0].src[0].src[1].dtype == dtypes.half:
-      # HIP tensor cores are 16x16x16
-      buf0 = self.bufs.index(self.reduceop.src[0].src[0].src[0])
-      buf1 = self.bufs.index(self.reduceop.src[0].src[0].src[1])
-      buf0_strides = self.sts[buf0].real_strides()
-      buf1_strides = self.sts[buf1].real_strides()
-      axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides) if s == 0 and self.full_shape[i]%16 == 0 and i < self.first_reduce]
-      axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides) if s == 0 and self.full_shape[i]%16 == 0 and i < self.first_reduce]
-      if axis_buf0 and axis_buf1 and self.full_shape[self.first_reduce]%8 == 0 and (self.shape_len-self.first_reduce) == 1:
-        if DEBUG >= 3: print("HIP TENSOR CORES", axis_buf0, axis_buf1)
-        self.use_tensor_cores = getenv("TC", 1) == 1 # TC=2 will do the shape ops without the WMMA
-        self.reverse_upcast_dir = True
-
-        # TODO: select axis in smart way
-        s0, s1 = axis_buf0[-1][0], axis_buf1[-1][0]
-        global_count = self.first_reduce
-
-        # upcast first
-        if self.full_shape[self.first_reduce] > 16: self.shift_to(self.first_reduce, 16)
-        self.upcast()
-
-        # 2 locals
-        self.shift_to(s1, 16, insert_before=self.first_reduce) # axis 2
-        self.shift_to(s0, 16, insert_before=self.first_reduce) # axis 3
-        self.local_dims += 1
-
-        # output shape
-        self.shift_to(self.first_reduce-2, 8)
-        self.upcast()
-
-        # split local dim
-        self.shift_to(self.first_reduce-1, 8, insert_before=self.first_reduce) # axis 3
-
-        # final global upcast
-        for ax in [s1, s0]:
-          for upc in [4,3,2]:
-            if self.full_shape[ax]%upc == 0:
-              self.shift_to(ax, upc)
-              self.upcast()
-              break
-
-        # alias buffer
-        alias_pattern = [0]*global_count + [0,0,1] + [0] * (self.shape_len-self.upcasted-self.first_reduce) + [2,3] + [0]*(self.upcasted-2)
-        self.alias_buffer(buf0, alias_pattern)
-        self.alias_buffer(buf1, alias_pattern)
-
-        # two fake locals
-        if self.use_tensor_cores:
-          self.local_dims += 2
-          self.exclude_local_upcast += 2
-
-        # early exit
-        return
-
-    # should use METAL tensor cores?
-    # first, confirm it's a straightforward mulacc on a device with real locals
-    tensor_cores_allowed = getenv("TC", 1) != 0 and (getenv("TC", 1) == 2 or (self.bufs[0].device == "METAL" and getenv("CI", "") != "true"))
-    if tensor_cores_allowed and self.reduceop and self.reduceop.op == ReduceOps.SUM and \
-        isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == BinaryOps.MUL and \
-        isinstance(self.reduceop.src[0].src[0], LazyBuffer) and isinstance(self.reduceop.src[0].src[1], LazyBuffer) and self.opts.has_local:
-      # METAL tensor cores are 8x8x8, with 2 elements per thread in the 32 thread warp
-      buf0 = self.bufs.index(self.reduceop.src[0].src[0])
-      buf1 = self.bufs.index(self.reduceop.src[0].src[1])
-      buf0_strides = self.sts[buf0].real_strides()
-      buf1_strides = self.sts[buf1].real_strides()
-      axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides) if s == 0 and self.full_shape[i]%8 == 0 and i < self.first_reduce]
-      axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides) if s == 0 and self.full_shape[i]%8 == 0 and i < self.first_reduce]
-      if axis_buf0 and axis_buf1 and self.full_shape[self.first_reduce]%8 == 0 and (self.shape_len-self.first_reduce) == 1:
-        if DEBUG >= 3: print("METAL TENSOR CORES", axis_buf0, axis_buf1)
-        self.use_tensor_cores = getenv("TC", 1) == 1 # TC=2 will do the shape ops without the WMMA
-
-        # TODO: select axis in smart way
-        s0, s1 = axis_buf0[-1][0], axis_buf1[-1][0]
-        global_count = self.first_reduce
-
-        # upcast first
-        if self.full_shape[self.first_reduce] > 8: self.shift_to(self.first_reduce, 8)
-        self.upcast()
-
-        # 2 locals
-        self.shift_to(s1, 8, insert_before=self.first_reduce) # axis 2
-        self.shift_to(s0, 8, insert_before=self.first_reduce) # axis 3
-
-        # permuted+upcast for tensor cores
-        self.shift_to(global_count, 4, insert_before=self.first_reduce)
-        self.shift_to(global_count+1, 4, insert_before=self.first_reduce)
-        self.shift_to(self.first_reduce-1, 2)
-        self.upcast()
-
-        # final global upcast
-        for ax in [s1, s0]:
-          for upc in [4,3,2]:
-            if self.full_shape[ax]%upc == 0:
-              self.shift_to(ax, upc)
-              self.upcast()
-              break
-
-        # alias buffer
-        self.local_dims = self.first_reduce - global_count
-        alias_pattern = [0]*global_count + [2] * self.local_dims + [0] * (self.shape_len-self.upcasted-self.first_reduce) + [1,1] + [3] * (self.upcasted-2)
-        self.alias_buffer(buf0, alias_pattern)
-        self.alias_buffer(buf1, alias_pattern)
-
-        # very late upcast to run group at the same time. only if actually using real tensor cores, otherwise local isn't a simdgroup
-        if self.use_tensor_cores and self.full_shape[s0] % 2 == 0:
-          self.shift_to(s0, 2, insert_before=self.first_reduce-self.local_dims)
-          self.local_dims += 1
-          self.exclude_local_upcast += 1
-
-        # early exit
-        return
-
-    if self.opts.has_local and all(isinstance(s, int) for s in self.sts[0].shape[:self.first_reduce]):
-      # are we grouping? (requires local shape support)
-      if not self.float4_axis(0) and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048:
-        # TODO: use 1024 if it's allowed in a smarter way
-        for sz in (([256, 16]) if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
-          if all(st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts):
-            self.shift_to(self.first_reduce, sz, top=True, insert_before=self.first_reduce + len(self.group_for_reduce))
-            self.group_for_reduce.append(sz)
-            break
-
-      # are we upcasting in mid reduce? (only for images)
-      if self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1:
-        axes = self.sts[0].unit_stride_axes()
-        assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
-        if self.sts[0].shape[axes[0]]%4 == 0:
-          self.shift_to(axes[0], 4, insert_before=self.first_reduce + len(self.group_for_reduce)) # insert at the end of the grouped axis
-          self.group_for_reduce.append(4)
-
-    # now do everything required
-    self.required_optimizations()
-
-    # simplify (sets first_reduce)
-    self.simplify_ones()
-
-    # use more opencl indexing if the output buffer is an image and we have room
-    if self.bufs[0].dtype.name.startswith('image') and self.first_reduce+len(self.group_for_reduce) < 3:
-      base_shape = self.bufs[0].dtype.shape
-      if (base_shape[0]*base_shape[1]) % self.sts[0].shape[0] == 0 and self.sts[0].shape[0]//base_shape[0] != 0:
-        if DEBUG >= 4: print("split opencl", base_shape, self.sts[0].shape)
-        self.reshape_and_permute(lambda x: [base_shape[0], x[0]//base_shape[0]]+list(x[1:]), None)
-        self.simplify_ones()
-
-    # no more opt if we are grouping
-    if self.group_for_reduce: return
-
-    # no more opt if there's non ints in any shapes
-    # TODO: this is due to a bug. repro by commenting this one while running GPT-2 with the JIT
-    if self.has_variable_shape(): return
-
-    # **** below this line need to be optional and benchmarked ****
-
-    # potentially do more upcasts of non reduce axes based on a heuristic
-    upcasted_axis = set()
-    while prod(self.sts[0].shape[:self.first_reduce]) >= 1024:
-      xb_choices = []
-      for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce
-        # if we haven't upcasted it, it's not symbolic, it mods, and some buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
-        if axis not in upcasted_axis and isinstance(self.full_shape[axis], int) and self.full_shape[axis]%upcast_amount == 0 and any(self.sts[buf_index].views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index in range(len(self.sts))):
-          xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in self.sts), sum(st.views[-1].strides[axis] for st in self.sts), axis, upcast_amount))
-      if xb_choices:
-        xb_choices = sorted(xb_choices)
-        if DEBUG >= 4: print(f"float4 merging axis : {xb_choices}")
-        self.shift_to(xb_choices[0][2], amount=xb_choices[0][3])
-        self.upcast()
-        self.simplify_ones()
-        upcasted_axis.add(xb_choices[0][2])
-      else:
-        break
-
-    # if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS
-    if self.first_reduce < (self.shape_len-self.upcasted) and (len(list(self.shape_offsets(self.full_buf_index))) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))):
-      if (s:=self.full_unupcasted_shape[-1]) <= 32 and isinstance(s, int): # NOTE: cannot loop unroll symbolic axis
-        self.upcast()
-        # if it's small, upcast a second reduce dimension too
-        if self.first_reduce < (self.shape_len-self.upcasted) and s <= 3 and self.full_unupcasted_shape[-1] <= 3: self.upcast()
-      else:
-        for splits in [4]:
-          if self.full_unupcasted_shape[-1]%splits == 0:
-            self.shift_to(len(self.full_unupcasted_shape)-1, splits, insert_before=len(self.full_unupcasted_shape))
-            self.upcast()
-            break
-
-    # if nothing at all is upcasted and it's easy to, do an upcast
-    # TODO: this is breaking the tests
-    for splits in [4]:
-      if self.upcasted == 0 and self.full_unupcasted_shape and self.full_unupcasted_shape[-1] % splits == 0:
-        self.shift_to(len(self.full_unupcasted_shape)-1, splits, insert_before=len(self.full_unupcasted_shape))
-        self.upcast()
-
-    # **** local groups ****
-
-    if self.opts.has_local:
-      for axis in range(self.first_reduce - self.local_dims - 1, -1, -1):
-        local_size = prod(self.full_shape[self.first_reduce-self.local_dims:self.first_reduce])
-        if self.full_shape[axis] == 1: continue
-        last_try = self.local_dims == 0 and axis == 0
-        if any(self.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(self.sts))) or last_try:
-          for sz in [x for x in (([32] if last_try else []) + [16,8,4,3]) if self.full_shape[axis] % x == 0 and local_size*x <= 128]:
-            self.shift_to(axis, sz, insert_before=self.first_reduce-self.local_dims)
-            self.local_dims += 1
-            break
-        if self.local_dims >= 3: break
-    self.simplify_ones()
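Most of the heuristics above are combinations of shift_to() plus upcast(): split a factor off an axis, move the new axis where it is needed, then treat trailing axes as upcasted. A minimal standalone sketch of the shape arithmetic shift_to() performs, written against plain tuples (shift_to_shape is a hypothetical helper, not tinygrad API):

from typing import Optional, Tuple

def shift_to_shape(shape: Tuple[int, ...], axis: int, amount: int,
                   top: bool = False, insert_before: Optional[int] = None) -> Tuple[int, ...]:
  if insert_before is None: insert_before = len(shape)
  # split shape[axis] into (amount, rest) or (rest, amount), mirroring the reshape lambda above
  split = [amount, shape[axis] // amount] if top else [shape[axis] // amount, amount]
  new_shape = list(shape[:axis]) + (split if shape[axis] > 1 else [1, 1]) + list(shape[axis + 1:])
  move_axis = axis if top else axis + 1
  if move_axis < insert_before: insert_before += 1
  # move the split-off axis so it sits just before insert_before, mirroring the permute above
  order = [i for i in range(insert_before) if i != move_axis] + [move_axis] + \
          [i for i in range(insert_before, len(shape) + 1) if i != move_axis]
  return tuple(new_shape[i] for i in order)

# pulling a factor of 4 off axis 1 and pushing it to the end, ready to be upcast():
assert shift_to_shape((6, 256, 16), axis=1, amount=4) == (6, 64, 16, 4)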
tinygrad/codegen/search.py DELETED
@@ -1,72 +0,0 @@
-from typing import Callable
-import time
-from tinygrad.codegen.linearizer import Linearizer
-from tinygrad.helpers import DEBUG, prod, getenv
-
-UPCASTS = [1,2,3,4,5,6,7,8]
-LOCALS = [1,2,3,4,5,6,7,8,16,24,32]
-def kernel_optimize_search(k:Linearizer, create_k:Callable[[], Linearizer], to_prg, baseline):
-  import nevergrad as ng
-  def opt(x):
-    try:
-      k = create_k()
-      k.process()
-      k.apply_auto_opt(x)
-      prg = to_prg(k)
-      first_tm = prg.exec(k.bufs, force_wait=True, optimizing=True)
-      if baseline*5 < first_tm*1000: return first_tm*1000 # very slow
-      tm = min([first_tm]+[prg.exec(k.bufs, force_wait=True, optimizing=True) for _ in range(2)])*1000
-      return tm
-    except Exception:
-      if DEBUG >= 3:
-        import traceback
-        traceback.print_exc()
-      return 10000_000 # 10000 seconds is infinity
-  opts = []
-  for i in range(k.first_reduce):
-    # TODO: the upcast always happen first, you might want to reverse this?
-    # TODO: the order of the locals might improve things too
-    opts.append(ng.p.TransitionChoice([(i,s,"U") for s in UPCASTS if k.full_shape[i]%s == 0]))
-    opts.append(ng.p.TransitionChoice([(i,s,"L") for s in LOCALS if k.full_shape[i]%s == 0]))
-  for i in range(k.shape_len-k.first_reduce):
-    opts.append(ng.p.TransitionChoice([(i,s,"R") for s in UPCASTS if k.full_shape[k.first_reduce+i]%s == 0]))
-  if not opts: return "BASELINE"
-  search_space = prod([len(x.choices) for x in opts])
-  st = time.perf_counter()
-  optimizer = ng.optimizers.NGOpt(parametrization=ng.p.Tuple(*opts), budget=min(search_space, 200))
-  recommendation = optimizer.minimize(opt)
-  et = time.perf_counter() - st
-  if DEBUG >= 1: print(f"optimizer({et:6.2f} s to search) space {search_space:8d} with tm {recommendation.loss:5.2f} ms vs baseline {baseline:5.2f} ms, a {baseline/recommendation.loss:5.2f}x gain : {k.colored_shape()}")
-  return recommendation.value if recommendation.loss < baseline else "BASELINE"
-
-# optimization
-global_db = None
-def kernel_optimize(k:Linearizer, create_k:Callable[[], Linearizer], to_prg):
-  global global_db
-
-  k.process()
-  skey = str(k.key)
-
-  if getenv("KOPT") == 2 and global_db is None:
-    import shelve
-    global_db = shelve.open("/tmp/kopt_cache")
-
-  if global_db is not None and skey in global_db:
-    choice = global_db[skey]
-  elif k.has_variable_shape():
-    # don't optimize variable shapes
-    choice = "BASELINE"
-  else:
-    # get baseline
-    def get_baseline():
-      k = create_k()
-      k.hand_coded_optimizations()
-      prg = to_prg(k)
-      return min([prg.exec(k.bufs, force_wait=True, optimizing=True) for _ in range(5)])*1000
-    choice = kernel_optimize_search(k, create_k, to_prg, get_baseline())
-    if global_db is not None:
-      global_db[skey] = choice
-      global_db.sync()
-
-  if choice == "BASELINE": k.hand_coded_optimizations()
-  else: k.apply_auto_opt(choice)
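kernel_optimize() above was gated behind the KOPT environment variable; KOPT=2 additionally persisted the winning choices in a shelve cache at /tmp/kopt_cache. A hedged usage sketch against tinygrad 0.7.0, assuming nevergrad is installed and that the realize path consults KOPT via getenv as this module expects:

import os
os.environ["KOPT"] = "2"    # search kernel optimizations and cache the chosen (axis, amount, type) tuples
os.environ["DEBUG"] = "1"   # print the per-kernel summary line from kernel_optimize_search

from tinygrad.tensor import Tensor  # 0.7.0 import path

x, w = Tensor.randn(256, 256), Tensor.randn(256, 256)
(x @ w).realize()  # kernels for this matmul go through the nevergrad search instead of hand_coded_optimizations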
tinygrad/graph.py DELETED
@@ -1,83 +0,0 @@
-import os, atexit, itertools
-try:
-  import networkx as nx # type: ignore
-except ImportError:
-  nx = None # graph won't work
-from collections import defaultdict
-from typing import Dict, List, Optional, TYPE_CHECKING
-from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, TernaryOps, Op, OpType, LazyOp
-from tinygrad.helpers import GRAPH, GRAPHPATH, PRUNEGRAPH, DEBUG, GlobalCounters
-from tinygrad.runtime.lib import RawConst
-
-if TYPE_CHECKING: from tinygrad.lazy import LazyBuffer
-
-# **** debugging and graphing ****
-
-G = nx.DiGraph() if nx is not None else None
-cnts: Dict[OpType, int] = defaultdict(int)
-if DEBUG >= 2:
-  def print_globalcounters():
-    if GlobalCounters.time_sum_s == 0: return
-    print(f"avg: {GlobalCounters.global_ops*1e-9/GlobalCounters.time_sum_s:8.2f} GFLOPS {GlobalCounters.global_mem*1e-9/GlobalCounters.time_sum_s:8.2f} GB/s",
-          f"{' '*10}total: {GlobalCounters.kernel_count:5d} kernels {GlobalCounters.global_ops*1e-9:8.2f} GOPS {GlobalCounters.global_mem*1e-9:8.2f} GB {GlobalCounters.time_sum_s*1e3:8.2f} ms")
-  atexit.register(print_globalcounters)
-if GRAPH:
-  def save_graph_exit():
-    for k,v in cnts.items(): print(k, v)
-    if PRUNEGRAPH: prune_graph()
-    print("saving", G)
-    nx.drawing.nx_pydot.write_dot(G, f'{GRAPHPATH}.dot')
-    # -Gnslimit=100 can make it finish, but you won't like results
-    os.system(f'dot -Tsvg {GRAPHPATH}.dot -o {GRAPHPATH}.svg')
-  atexit.register(save_graph_exit)
-
-node_count = 0
-def nm(x):
-  global node_count
-  if not hasattr(x, 'node_id'):
-    setattr(x, 'node_id', node_count)
-    node_count += 1
-  return x.node_id
-
-def get_sop(op: List[Op]):
-  if len(op) <= 2: return '.'.join([str(y).split(".")[1] for y in op][::-1])
-  if len(op) <= 4: return '.'.join([str(y).split(".")[1][0:3] for y in op][::-1])
-  return str(len(op))
-
-def str_dtype(dtyp):
-  ret = str(dtyp)[7:]
-  return "" if ret == 'float' else f"\n{ret}"
-
-def log_op(ret: 'LazyBuffer', ast: LazyOp, show_graph: Optional[bool] = None, phantom=False):
-  if show_graph is None: show_graph = bool(GRAPH)
-  if not DEBUG and not show_graph: return
-  op: List[Op] = [x.op for x in ast.get_lazyops()]
-  inp: List['LazyBuffer'] = [x for x in ast.buffers if not isinstance(x.realized, RawConst) or GRAPH > 1]
-  oporder = [LoadOps, TernaryOps, ReduceOps, BinaryOps, UnaryOps, MovementOps]
-  optype = type(sorted(op, key=lambda x: oporder.index(type(x)))[0])
-  cnts[optype] += 1
-  if DEBUG >= 6: print(f"{op} : {', '.join([f'{x.shape}-<{nm(x)}>' for x in inp])} -> {ret.shape}-<{nm(ret)}>")
-  if show_graph:
-    top_colors = {LoadOps: '#FFFF80', UnaryOps: "#c0c0c0", ReduceOps: "#8080ff", BinaryOps: "#c0c0c0", MovementOps: "#80ff80", TernaryOps: "#ff8080"}
-    dashed = (optype == LoadOps and hasattr(ret, "_backing")) or (hasattr(ret, "st") and not ret.st.contiguous) # type: ignore
-
-    for x in inp:
-      G.add_edge(nm(x), nm(ret), label=get_sop(op), color='#00000060' if phantom else 'black')
-      if 'label' not in G.nodes[nm(x)]:
-        G.nodes[nm(x)]['label'] = str(x.shape)+str_dtype(ret.dtype)
-    if nm(ret) not in G.nodes: G.add_node(nm(ret))
-
-    G.nodes[nm(ret)]['label'] = (str(set(x.shape for x in inp))+"\n"+str(ret.shape) if optype == ReduceOps else str(ret.shape))+str_dtype(ret.dtype)
-    G.nodes[nm(ret)]['fillcolor'] = (top_colors[optype] + ('60' if phantom else ('80' if dashed else str()))) if optype in top_colors else "#ffffff"
-    G.nodes[nm(ret)]['color'] = 'white' if phantom else 'black'
-    G.nodes[nm(ret)]['style'] = ('filled, dashed' if dashed else 'filled')
-    G.nodes[nm(ret)]['prunable'] = optype in [LoadOps, MovementOps]
-
-# prune movementops and loadops
-def prune_graph():
-  dead_nodes = []
-  for n in G.nodes:
-    if 'prunable' in G.nodes[n] and G.nodes[n]['prunable']:
-      G.add_edges_from([(x, y) for (x,_),(_,y) in itertools.product(G.in_edges(n), G.out_edges(n))])
-      dead_nodes.append(n)
-  G.remove_nodes_from(dead_nodes)
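The graphing module was driven entirely by environment variables read in tinygrad.helpers, so enabling it required no code changes beyond setting them before import. A hedged sketch for 0.7.0; it assumes networkx, pydot and the graphviz dot binary are installed, and spells out the output path explicitly rather than relying on any default:

import os
os.environ["GRAPH"] = "1"             # log every realized LazyOp into the networkx graph via log_op()
os.environ["GRAPHPATH"] = "/tmp/net"  # prefix for the .dot/.svg written by save_graph_exit() at exit
os.environ["PRUNEGRAPH"] = "1"        # optional: collapse LoadOps/MovementOps nodes with prune_graph()

from tinygrad.tensor import Tensor

a, b = Tensor.randn(4, 4), Tensor.randn(4, 4)
(a @ b).sum().realize()
# at interpreter exit: /tmp/net.dot is written and rendered to /tmp/net.svg with `dot -Tsvg`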
tinygrad/jit.py DELETED
@@ -1,57 +0,0 @@
-from typing import Callable, List, Tuple, Any, Dict, cast, Union, Optional
-import functools, itertools
-from tinygrad.helpers import DEBUG, DType, merge_dicts
-from tinygrad.ops import Device
-from tinygrad.tensor import Tensor
-from tinygrad.ops import GlobalCounters, RawBuffer
-from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.shape.symbolic import Variable
-
-JIT_SUPPORTED_DEVICE = ["GPU", "CLANG", "METAL", "CUDA", "HIP", "WEBGPU"]
-
-class TinyJit:
-  def __init__(self, fxn:Callable):
-    self.fxn: Callable = fxn
-    self.cnt: int = 0
-    self.jit_cache: List[Tuple[Callable, List[Optional[RawBuffer]], Dict[Variable, int]]] = []
-    self.ret: Any = None
-    self.input_replace: Dict[Tuple[int, int], Tuple[Union[int, str], ShapeTracker, DType]]= {} # (kernel_number, buffer_number) -> (input_name, expected_shapetracker, expected_type)
-
-  # add support for instance methods
-  def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)
-
-  def __call__(self, *args, **kwargs) -> Any:
-    if Device.DEFAULT not in JIT_SUPPORTED_DEVICE: return self.fxn(*args, **kwargs) # only jit on supported device
-    # NOTE: this cast is needed since although we know realize will create a ".realized" RawBuffer, the type checker doesn't
-    input_rawbuffers: Dict[Union[int, str], Tuple[RawBuffer, ShapeTracker]] = {cast(Union[int, str], k):(cast(RawBuffer, v.realize().lazydata.realized), v.lazydata.st) for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)}
-    assert len(input_rawbuffers) != 0, "no inputs to JIT"
-    assert len(set(input_rawbuffers.values())) == len(input_rawbuffers), "duplicate inputs to JIT"
-    if self.cnt >= 2:
-      var_vals = dict(sorted(merge_dicts([arg.lazydata.st.var_vals for arg in args if isinstance(arg, Tensor)]).items(), key=lambda kv: kv[0].key))
-      for (j,i),(input_name, expected_st, expected_type) in self.input_replace.items():
-        assert input_rawbuffers[input_name][1].views == expected_st.views and input_rawbuffers[input_name][0].dtype == expected_type, f"ShapeTracker.views or type mismatch in JIT, <{input_rawbuffers[input_name][1].views}, {input_rawbuffers[input_name][0].dtype}> != <{expected_st.views}, {expected_type}>"
-        self.jit_cache[j][1][i] = input_rawbuffers[input_name][0]
-      for prg, pargs, variables in self.jit_cache: # type: Callable, List[Optional[RawBuffer]], Dict[Variable, int]
-        for v in (var_vals.keys() & variables.keys()): variables[v] = var_vals[v]
-        prg(pargs, variables, jit=True)
-      for (j,i) in self.input_replace.keys(): self.jit_cache[j][1][i] = None
-    elif self.cnt == 1:
-      GlobalCounters.cache = []
-      self.ret = self.fxn(*args, **kwargs)
-      self.jit_cache = GlobalCounters.cache
-      GlobalCounters.cache = None
-      assert len(self.jit_cache) != 0, "didn't JIT anything!"
-      if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs")
-
-      # get the inputs for replacement
-      for j_,cache in enumerate(self.jit_cache): # type: Tuple[int, Tuple[Callable, List[Optional[RawBuffer]], Dict[Variable, int]]]
-        for i,a in enumerate(cache[1]):
-          if a in [v[0] for v in input_rawbuffers.values()]:
-            self.input_replace[(j_,i)] = [(k, v[1], v[0].dtype) for k,v in input_rawbuffers.items() if v[0] == a][0]
-        #if prg.local_size is None: prg.local_size = prg.optimize_local_size(args, preserve_output=True) # the JIT can optimize local
-      assert set([x[0] for x in self.input_replace.values()]) == set(input_rawbuffers.keys()), "some input tensors not found"
-      for (j,i) in self.input_replace.keys(): self.jit_cache[j][1][i] = None
-    elif self.cnt == 0:
-      self.ret = self.fxn(*args, **kwargs)
-    self.cnt += 1
-    return self.ret
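For reference, the capture/replay behaviour encoded in __call__ above (call 0 runs eagerly, call 1 captures the kernels into jit_cache, calls 2+ replay them with the input buffers swapped in) was used like this. A hedged sketch against the 0.7.0 import paths, assuming a device from JIT_SUPPORTED_DEVICE:

from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit  # this module; its replacement lives at tinygrad/engine/jit.py in 0.9.0

@TinyJit
def step(x: Tensor) -> Tensor:
  # inputs and outputs must be Tensors, and the output should be realized inside the jitted function
  return (x @ x).relu().realize()

for _ in range(4):                # call 0: eager, call 1: capture, calls 2+: replay cached kernels
  out = step(Tensor.randn(8, 8))  # every call must use the same shapes and dtypes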