tinygrad 0.9.1__py3-none-any.whl → 0.9.2__py3-none-any.whl

This diff shows the changes between two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (56)
  1. tinygrad/codegen/kernel.py +248 -115
  2. tinygrad/codegen/lowerer.py +215 -0
  3. tinygrad/codegen/transcendental.py +310 -0
  4. tinygrad/codegen/uopgraph.py +622 -0
  5. tinygrad/codegen/uops.py +235 -393
  6. tinygrad/device.py +428 -69
  7. tinygrad/dtype.py +18 -4
  8. tinygrad/engine/graph.py +19 -32
  9. tinygrad/engine/jit.py +148 -70
  10. tinygrad/engine/realize.py +127 -51
  11. tinygrad/engine/schedule.py +259 -216
  12. tinygrad/engine/search.py +29 -22
  13. tinygrad/function.py +9 -0
  14. tinygrad/helpers.py +87 -49
  15. tinygrad/lazy.py +34 -35
  16. tinygrad/multi.py +41 -36
  17. tinygrad/nn/__init__.py +39 -22
  18. tinygrad/nn/state.py +3 -3
  19. tinygrad/ops.py +63 -62
  20. tinygrad/renderer/__init__.py +43 -21
  21. tinygrad/renderer/assembly.py +104 -106
  22. tinygrad/renderer/cstyle.py +87 -60
  23. tinygrad/renderer/llvmir.py +21 -30
  24. tinygrad/runtime/autogen/amd_gpu.py +25208 -5753
  25. tinygrad/runtime/autogen/cuda.py +6 -162
  26. tinygrad/runtime/autogen/kfd.py +32 -0
  27. tinygrad/runtime/autogen/libc.py +4260 -0
  28. tinygrad/runtime/autogen/nvrtc.py +579 -0
  29. tinygrad/runtime/graph/clang.py +2 -2
  30. tinygrad/runtime/graph/cuda.py +8 -11
  31. tinygrad/runtime/graph/hcq.py +120 -107
  32. tinygrad/runtime/graph/metal.py +18 -15
  33. tinygrad/runtime/ops_amd.py +197 -305
  34. tinygrad/runtime/ops_clang.py +2 -2
  35. tinygrad/runtime/ops_cuda.py +36 -94
  36. tinygrad/runtime/ops_disk.py +3 -7
  37. tinygrad/runtime/ops_gpu.py +4 -2
  38. tinygrad/runtime/ops_hip.py +70 -0
  39. tinygrad/runtime/ops_metal.py +38 -27
  40. tinygrad/runtime/ops_nv.py +283 -363
  41. tinygrad/runtime/ops_python.py +26 -30
  42. tinygrad/runtime/support/compiler_cuda.py +78 -0
  43. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +15 -1
  44. tinygrad/runtime/support/elf.py +38 -0
  45. tinygrad/shape/shapetracker.py +5 -14
  46. tinygrad/shape/symbolic.py +4 -8
  47. tinygrad/shape/view.py +34 -22
  48. tinygrad/tensor.py +399 -97
  49. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/METADATA +49 -48
  50. tinygrad-0.9.2.dist-info/RECORD +70 -0
  51. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/WHEEL +1 -1
  52. tinygrad/codegen/linearizer.py +0 -528
  53. tinygrad-0.9.1.dist-info/RECORD +0 -63
  54. /tinygrad/runtime/{driver → support}/__init__.py +0 -0
  55. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/LICENSE +0 -0
  56. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/top_level.txt +0 -0
@@ -1,21 +1,25 @@
  from __future__ import annotations
+ import itertools, functools
+ from dataclasses import dataclass, replace
  from collections import defaultdict
- import itertools
- from typing import DefaultDict, Optional, List, Tuple, cast, Dict, Union
- from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, UNSAFE_PAD_OPS, verify_lazyop
+ from typing import Optional, List, Tuple, cast, Dict, Union, Final, DefaultDict, Any
+
+ from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, MetaOps, UNSAFE_PAD_OPS, verify_lazyop, KernelInfo
  from tinygrad.device import Device
- from tinygrad.renderer import Renderer, TensorCore
- from tinygrad.dtype import dtypes, ImageDType, DType
- from tinygrad.helpers import all_same, colored, ansilen, dedup, flatten, getenv, prod, DEBUG, round_up, all_int, get_contraction
+ from tinygrad.renderer import Renderer, TensorCore, Program
+ from tinygrad.dtype import ImageDType
+ from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, DEBUG, TC_OPT, USE_TC, round_up, all_int, \
+ get_contraction, to_function_name, diskcache_put, ContextVar
  from tinygrad.shape.shapetracker import ShapeTracker
  from tinygrad.shape.symbolic import sint
- from tinygrad.shape.view import View, strides_for_shape
- from dataclasses import dataclass
+ from tinygrad.shape.view import strides_for_shape
+ from tinygrad.codegen.uopgraph import UOpGraph
+ from tinygrad.codegen.lowerer import lazyop_to_uop
  from enum import Enum, auto

  class OptOps(Enum):
  TC = auto(); UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
- GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto() # noqa: E702
+ GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); SWAP = auto() # noqa: E702
  def __lt__(self, x:OptOps): return self.value < x.value

  class KernelOptError(Exception): pass
@@ -47,37 +51,42 @@ class TensorCoreOptions:
  elif removed_axis == axes[tc_dim]: axes_exist[tc_dim] = False
  self.axes, self.axes_exist = tuple(axes), tuple(axes_exist)

- @dataclass(frozen=True)
- class LocalBuffer:
- name: str
- size: int
- dtype: DType = dtypes.float32
- realized: None = None
- def __str__(self): return f"localbuffer<{self.name}[{self.size}]>"
-
  class Kernel:
  def __init__(self, *ast:LazyOp, opts:Optional[Renderer]=None):
+ if len(ast) > 1 or ast[0].op is BufferOps.STORE:
+ assert all(x.op is BufferOps.STORE for x in ast)
+ self.ast = LazyOp(MetaOps.KERNEL, ast)
+ else:
+ assert len(ast) == 1 and ast[0].op is MetaOps.KERNEL
+ self.ast = ast[0]
+
  self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer
- verify_lazyop(*ast)
- self.ast = ast
- self.lazyops = flatten([op.lazyops for op in self.ast])
+ try: lazyop_sts_map = verify_lazyop(self.ast)
+ except AssertionError as e:
+ print("INVALID AST")
+ for op in ast: print(op)
+ raise e

- cached_ordered_lazyops: Dict[LazyOp, List[LazyOp]] = {}
- def ordered_lazyops(op):
- if op not in cached_ordered_lazyops: cached_ordered_lazyops[op] = dedup([item for x in op.src for item in ordered_lazyops(x)] + [op])
- return cached_ordered_lazyops[op]
- self.reduceops = dedup([x for out in self.ast for x in ordered_lazyops(out) if x.op in ReduceOps])
+ @functools.lru_cache(None)
+ def ordered_lazyops(op): return dedup([item for x in op.src for item in ordered_lazyops(x)] + [op])
+ self.reduceops = dedup([x for x in ordered_lazyops(self.ast) if x.op in ReduceOps])

- self.outbufs, self.vars = [x.arg for x in self.ast], flatten([x.vars() for x in self.ast])
- loadops = [BufferOps.LOAD, BufferOps.CONST]
- self.bufs: List[Union[MemBuffer, ConstBuffer, LocalBuffer]] = self.outbufs + dedup([x.arg for x in self.lazyops if x.op in loadops])
+ self.vars = self.ast.vars()
+ self.bufs: List[Union[MemBuffer, ConstBuffer]] = dedup([x.arg for x in self.ast.lazyops if x.op in BufferOps])

  # get earlybufs, before any reduceops
- self.earlybufs = [x.arg for reduceop in self.reduceops for x in reduceop.lazyops if x.op in BufferOps]
- self.full_buf_index: int = self.bufs.index(self.earlybufs[0]) if self.earlybufs else 0
+ earlybufs = [x.arg for reduceop in self.reduceops for x in reduceop.lazyops if x.op in BufferOps]
+ self.full_buf_index: int = self.bufs.index(earlybufs[0]) if earlybufs else 0
+ # NOTE: full_shape can be wrong if there's a tree of reduces

  # create new shapetrackers inside this kernel, we will permute them
- self.sts: List[ShapeTracker] = [x.st for x in cast(List[Union[MemBuffer, ConstBuffer]], self.bufs)]
+ self.sts: List[ShapeTracker] = [x.st for x in self.bufs]
+
+ # add the shapetrackers for each reduce
+ # we use this to track which axes are reduced in each reduce
+ for x in self.reduceops:
+ self.sts.append(lazyop_sts_map[x])
+ self.sts.append(lazyop_sts_map[x.src[0]])

  # move all reduce axes to the end
  reduce = list(enumerate(zip(self.full_shape, self.output_shape)))
@@ -89,9 +98,9 @@ class Kernel:
  self.group_for_reduces: int = 0
  self.upcasted: int = 0
  self.local_dims: int = 0
- self.local_alias: DefaultDict[LazyOp, Dict[int, LocalBuffer]] = defaultdict(dict)
  self.tensor_core: Optional[TensorCore] = None
  self.tensor_core_opts: Optional[TensorCoreOptions] = None
+ self.use_tensor_cores: int = 0
  # the local aliased buffers for A and B
  self.bufs_for_tensor_core: Dict[LazyOp, Tuple[int, int]] = {}
  self.dont_use_locals: bool = False
@@ -100,28 +109,22 @@ class Kernel:
  self.simplify_ones()
  self.simplify_merge_adjacent()

- # cache
- self.applied_opts_cache: Optional[List[Opt]] = None
-
  def copy(self):
  ret = type(self).__new__(type(self))

  # base linearizer params
- ret.opts, ret.ast, ret.lazyops = self.opts, self.ast, self.lazyops
+ ret.opts, ret.ast = self.opts, self.ast

  # things downstream of the AST
- ret.reduceops, ret.outbufs, ret.vars, ret.bufs, ret.earlybufs, ret.full_buf_index = \
- self.reduceops, self.outbufs, self.vars, [x for x in self.bufs if not isinstance(x, LocalBuffer)], self.earlybufs, self.full_buf_index
- ret.sts = self.sts[:len(ret.bufs)] # NOTE: must redo the local buffers with TC in beam
+ ret.reduceops, ret.vars, ret.bufs, ret.full_buf_index = \
+ self.reduceops, self.vars, self.bufs, self.full_buf_index
+ ret.sts = self.sts[:len(ret.bufs)+len(ret.reduceops)*2] # NOTE: must redo the local buffers with TC in beam

  # parameters for optimizations
  ret.applied_opts, ret.group_for_reduces, ret.upcasted, ret.local_dims, ret.dont_use_locals = \
  self.applied_opts[:], self.group_for_reduces, self.upcasted, self.local_dims, self.dont_use_locals
- ret.tensor_core, ret.tensor_core_opts, ret.local_alias, ret.bufs_for_tensor_core = self.tensor_core, self.tensor_core_opts, defaultdict(dict), \
- self.bufs_for_tensor_core
-
- # uncached since linearize didn't run
- ret.applied_opts_cache = None
+ ret.tensor_core, ret.tensor_core_opts, ret.bufs_for_tensor_core, ret.use_tensor_cores = \
+ self.tensor_core, self.tensor_core_opts, self.bufs_for_tensor_core, self.use_tensor_cores

  return ret

@@ -129,29 +132,20 @@ class Kernel:
  def membufs(self) -> List[MemBuffer]: return [x for x in self.bufs if isinstance(x, MemBuffer)]

  # TODO: these need more tests or it might silently be no-op
- def shape_offsets(self, i:int): return itertools.product(*[list(range(cast(int, s))) for s in self.sts[i].shape[self.shape_len-self.upcasted:][::-1]]) if self.upcasted > 0 else [tuple()] # noqa: E501
- def float4_axis(self, i:int): return [x-(self.shape_len-self.upcasted) for x in self.sts[i].unit_stride_axes() if x >= self.shape_len-self.upcasted and self.sts[i].shape[x]%4 == 0] # noqa: E501
+ def float4_axis(self, i:int): return [x-self.first_upcast for x in self.sts[i].unit_stride_axes() if x >= self.first_upcast and self.sts[i].shape[x]%4 == 0] # noqa: E501

  def upcasted_axis(self, i:int) -> List[Tuple[int, Optional[sint], bool]]:
- upcasted_shape, upcasted_stride = self.sts[i].shape[self.shape_len-self.upcasted:], self.sts[i].real_strides()[self.shape_len-self.upcasted:]
+ upcasted_shape, upcasted_stride = self.sts[i].shape[self.first_upcast:], self.sts[i].real_strides()[self.first_upcast:]
  assert all_int(upcasted_shape), f"cannot upcast a symbolic amount {upcasted_shape=}"
  return list(zip(upcasted_shape, upcasted_stride,
- [x!=y for x,y in zip(self.sts[0].shape[self.shape_len-self.upcasted:], self.full_shape[self.shape_len-self.upcasted:])]))
-
- # TODO: is there a better way to write this?
- def acc_offsets(self, i:int) -> List[int]:
- if self.upcasted == 0: return [0]
- upcasted_i = self.upcasted_axis(i)
- acc_strides = [x*(1-upcasted_i[::-1][i][2]) for i,x in enumerate(strides_for_shape(tuple(1 if r else s for s,_,r in upcasted_i[::-1])))]
- return [sum(t) for t in itertools.product(*[[y*acc_strides[i] for y in range(x[0])] for i,x in enumerate(upcasted_i[::-1])])]
-
- def get_float4_upcast_dim(self, i:int) -> List[int]:
- should_upcast = self.opts.supports_float4 and (self.bufs[i].dtype in (dtypes.float, dtypes.half) or isinstance(self.bufs[i].dtype, ImageDType))
- return [x for x in self.sts[i].unit_stride_axes() if x >= self.shape_len-self.upcasted and self.sts[i].shape[x] > 1] if should_upcast else []
+ [x!=y for x,y in zip(self.sts[0].shape[self.first_upcast:], self.full_shape[self.first_upcast:])]))

  @property
  def first_reduce(self) -> int:
- return [x!=y for x,y in zip(self.sts[0].shape[:self.shape_len-self.upcasted]+(0,), self.full_shape[:self.shape_len-self.upcasted]+(1,))].index(True) # noqa: E501
+ return [x!=y for x,y in zip(self.sts[0].shape[:self.first_upcast]+(0,), self.full_shape[:self.first_upcast]+(1,))].index(True)
+
+ @property
+ def first_upcast(self) -> int: return self.shape_len-self.upcasted

  @property
  def reduceop(self) -> Optional[LazyOp]: return self.reduceops[0] if len(self.reduceops) > 0 else None
@@ -163,7 +157,7 @@ class Kernel:
  def full_shape(self) -> Tuple[sint, ...]: return self.sts[self.full_buf_index].shape

  @property
- def full_unupcasted_shape(self) -> Tuple[sint, ...]: return self.full_shape[:self.shape_len-self.upcasted]
+ def full_unupcasted_shape(self) -> Tuple[sint, ...]: return self.full_shape[:self.first_upcast]

  @property
  def shape_len(self) -> int: return len(self.sts[0].shape)
@@ -193,9 +187,9 @@ class Kernel:
  # between first_reduce and first_reduce + group_for_reduces, they are either upcast mid reduce (white), or late upcasted (green)
  colors += ["white" if i in self.upcast_in_mid_reduce_axes else "green" for i in range(self.first_reduce, self.first_reduce + self.group_for_reduces)] # noqa: E501
  # between first_reduce + group_for_reduces and upcasted, they are reduce (red)
- colors += ["red"] * ((self.shape_len-self.upcasted) - (self.first_reduce + self.group_for_reduces))
+ colors += ["red"] * (self.first_upcast - (self.first_reduce + self.group_for_reduces))
  # upcasted dimensions are reduce (magenta) or normal (yellow)
- colors += ["magenta" if self.full_shape[i] != self.sts[0].shape[i] else "yellow" for i in range(self.shape_len-self.upcasted, self.shape_len)]
+ colors += ["magenta" if self.full_shape[i] != self.sts[0].shape[i] else "yellow" for i in range(self.first_upcast, self.shape_len)]
  assert len(colors) == self.shape_len, "colors size mismatch"
  return colors

@@ -229,7 +223,7 @@ class Kernel:
  move_axis = axis if top else axis+1
  if move_axis < insert_before: insert_before += 1
  self.reshape_and_permute(
- lambda x: list(x[0:axis]) + (([amount, x[axis]//amount] if top else [x[axis]//amount, amount]) if x[axis] > 1 else [1,1]) + list(x[axis+1:]),
+ lambda x: x[0:axis] + (((amount, x[axis]//amount) if top else (x[axis]//amount, amount)) if x[axis] > 1 else (1,1)) + x[axis+1:],
  [i for i in range(insert_before) if i != move_axis] + [move_axis] + [i for i in range(insert_before, self.shape_len+1) if i != move_axis])

  # ******************** complex simplifiers ********************
@@ -240,7 +234,7 @@ class Kernel:
  if self.shape_len == 0: return False
  all_ones = [s==1 for s in self.full_shape]
  self.local_dims -= sum(all_ones[self.first_reduce-self.local_dims:self.first_reduce])
- self.upcasted -= sum(all_ones[self.shape_len-self.upcasted:]) # TODO: no necessary since upcasted axis can't be un-upcasted
+ self.upcasted -= sum(all_ones[self.first_upcast:]) # TODO: no necessary since upcasted axis can't be un-upcasted
  self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None)
  return any(all_ones)

@@ -281,25 +275,6 @@ class Kernel:
  # do the reshapes
  for i,x in enumerate(rets[:len(self.sts)]): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x]))

- # ******************** helpers ********************
-
- def alias_buffer(self, op:LazyOp, i:int, pattern:List[int]) -> None:
- assert len(pattern) == len(self.sts[i].shape), f"must include a pattern for each shape {pattern} {self.sts[i].shape}"
-
- bst = 1
- real_strides = self.sts[i].real_strides()
- shp, stride = [(s if p != 0 else 1) for s,p in zip(self.sts[i].shape, pattern)], [0]*len(pattern)
- for priority in range(1, max(pattern)+1): # priority. 0 is non local and ignored
- for j,p in enumerate(pattern):
- if priority == p and real_strides[j] != 0:
- stride[j] = bst
- bst *= shp[j]
-
- self.sts.append(ShapeTracker((View.create(tuple(shp), tuple(stride)),)))
- self.bufs.append(LocalBuffer(name=f"ldata{i}", size=self.sts[-1].size))
- if DEBUG >= 4: print("aliasing buffer", self.sts[i])
- self.local_alias[op][i] = cast(LocalBuffer, self.bufs[-1])
-
  # ******************** high level optimizers ********************

  def _create_tc_opts(self, reduceop:LazyOp, tc:TensorCore, axis:int, opt_level:int) -> Optional[TensorCoreOptions]:
@@ -347,14 +322,26 @@ class Kernel:
  try:
  for axis, dim in tc_opts.axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
  except KernelOptError: continue
- self.apply_opt(Opt(OptOps.UNROLL, tc_opts.axes[2]-self.first_reduce, tc.dims[2]), append_opt=False)
- for i, sz in enumerate([prod(x) for x in [[x[1] for x in tc.threads if x[0]==dim] for dim in range(2)]]): # upcast non-local'd N, M
- if tc.dims[i] > sz: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[i], tc.dims[i]//sz), append_opt=False)
- for (tc_dim, tc_amt) in tc.threads:
- self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[tc_dim], tc_amt), append_opt=False)
-
- # assert tensor core
- if use_tensor_cores == 1: self.tensor_core = tc # TC=2 will do the shape ops without the WMMA
+ if self.opts.device in {"AMD", "HIP"}:
+ # NOTE: AMD requires locals first
+ self.apply_opt(Opt(OptOps.UNROLL, tc_opts.axes[2]-self.first_reduce, tc.dims[2]), append_opt=False)
+ for (tc_dim, tc_amt) in tc.threads: self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[tc_dim], tc_amt), append_opt=False)
+ for i, sz in enumerate([prod(x) for x in [[x[1] for x in tc.threads if x[0]==dim] for dim in range(2)]]): # upcast non-local'd N, M
+ if tc.dims[i] > sz: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[i], tc.dims[i]//sz), append_opt=False)
+ elif self.opts.device == "METAL":
+ self.apply_opt(Opt(OptOps.UNROLL, tc_opts.axes[2]-self.first_reduce, tc.dims[2]), append_opt=False)
+ for i, sz in enumerate([prod(x) for x in [[x[1] for x in tc.threads if x[0]==dim] for dim in range(2)]]): # upcast non-local'd N, M
+ if tc.dims[i] > sz: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[i], tc.dims[i]//sz), append_opt=False)
+ for (tc_dim, tc_amt) in tc.threads: self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[tc_dim], tc_amt), append_opt=False)
+ elif self.opts.device in {"CUDA", "NV"}:
+ self.apply_opt(Opt(OptOps.UNROLL, tc_opts.axes[2]-self.first_reduce, 8), append_opt=False)
+ self.apply_opt(Opt(OptOps.UNROLL, tc_opts.axes[2]-self.first_reduce, 2), append_opt=False)
+ # NOTE: LOCALS and UPCAST can be swapped here. it doesn't seem faster
+ self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[1], 2), append_opt=False)
+ self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[0], 2), append_opt=False)
+ for (tc_dim, tc_amt) in tc.threads: self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[tc_dim], tc_amt), append_opt=False)
+ self.tensor_core = tc
+ self.use_tensor_cores = use_tensor_cores # TC=2 will do the shape ops without the WMMA
  return True
  return False

@@ -373,7 +360,7 @@ class Kernel:
  1: allows kernels with multiple reduce axes and also multiplication of UnaryOps.CAST'd buffers
  2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by applying padding those axes as needed
  """
- if tc_opt is None: tc_opt = self.opts.tc_opt
+ if tc_opt is None: tc_opt = TC_OPT.value
  if not self.opts.tensor_cores and use_tensor_cores != 2: return False
  try: # check TC first and apply hand-coded opts if successful
  self.apply_opt(Opt(OptOps.TC, axis, tc_opt))
@@ -395,7 +382,6 @@ class Kernel:
  if self.full_shape[tc_opts.axes[0]] % upc == 0:
  self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[0], upc))
  break
-
  return True
  except KernelOptError:
  return False
@@ -406,7 +392,7 @@ class Kernel:
  if opt.op is OptOps.TC:
  check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine
  check(opt.axis is not None and opt.amt is not None, "tensor core opts must have an axis and amt")
- check((use_tensor_cores:=self.opts.tc) == 2 or len(self.opts.tensor_cores) > 0, "must have tensor cores or TC=2")
+ check((use_tensor_cores:=USE_TC.value) == 2 or len(self.opts.tensor_cores) > 0, "must have tensor cores or TC=2")
  check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), cast(int, opt.amt)), "no tensor core available")
  self.applied_opts.append(opt)
  return
@@ -414,15 +400,16 @@ class Kernel:
  axis = opt.real_axis(self)
  check(axis < len(self.full_shape), "invalid axis")

- if opt.amt is not None:
+ if opt.op is OptOps.SWAP: amt = cast(int, opt.amt) # amt is an axis in the SWAPs
+ elif opt.amt is not None:
  amt = opt.amt if opt.amt != 0 else self.full_shape[axis]
  check(isinstance(amt, int) and amt != 1, "shift/padto of amt 1 or Node is meaningless")
  if opt.op is not OptOps.PADTO: check(self.full_shape[axis] % amt == 0, "no longer valid shift")
  else: amt = -1

  if self.reduceop and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or (self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})):
- acc_sz, upcast_idx = dt.base.itemsize if isinstance((dt:=self.reduceop.dtype), ImageDType) else dt.itemsize, self.shape_len-self.upcasted
- upcast_sz = prod([a for a,b in zip(self.full_shape[upcast_idx:], self.sts[0].shape[upcast_idx:]) if a == b])
+ acc_sz = dt.base.itemsize if isinstance((dt:=self.reduceop.dtype), ImageDType) else dt.itemsize
+ upcast_sz = prod([a for a,b in zip(self.full_shape[self.first_upcast:], self.sts[0].shape[self.first_upcast:]) if a == b])
  local_sz = prod(self.full_shape[self.first_reduce-self.local_dims:self.first_reduce+self.group_for_reduces])
  smem_sz = amt*acc_sz*upcast_sz*local_sz
  check(smem_sz <= self.opts.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.opts.shared_max}")
@@ -434,12 +421,13 @@ class Kernel:
  self.local_dims += 1
  elif opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: # green
  check(self.opts.has_local and self.opts.has_shared, "target does not support local or shared mem")
- check(axis >= self.first_reduce + self.group_for_reduces and axis < self.shape_len-self.upcasted, "must be reduce axis to group")
+ check(self.first_reduce + self.group_for_reduces <= axis < self.first_upcast, "must be reduce axis to group")
  check(not self.tensor_core, "can't group with tensor cores")
+ check(len(self.reduceops) == 1, "can't group with multiple reduces")
  self.shift_to(axis, amt, top=(opt.op is OptOps.GROUPTOP), insert_before=self.first_reduce + self.group_for_reduces)
  self.group_for_reduces += 1
  elif opt.op is OptOps.UNROLL: # purple
- check(axis < self.shape_len-self.upcasted, "can't upcasted already upcasted")
+ check(axis < self.first_upcast, "can't upcasted already upcasted")
  check(amt <= 32, "don't unroll more than 32")
  # TODO: fix upcast_count to put purples before yellows. broken because of METAL tensor cores
  #upcast_count = sum(x == y for x,y in zip(self.full_shape[-self.upcasted:], self.output_shape[-self.upcasted:])) if self.upcasted else 0
@@ -451,7 +439,7 @@ class Kernel:
  elif opt.op is OptOps.UPCAST: # yellow
  check(axis < self.first_reduce, "upcast is for non-reduce")
  check(not(self.tensor_core and self.global_dims <= axis < self.global_dims+len(self.tensor_core.threads)), "can't upcast TC locals")
- check(amt <= 8, "don't upcast more than 8")
+ check(amt <= 16, "don't upcast more than 16")
  self.shift_to(axis, amt, insert_before=None)
  self.upcast()
  elif opt.op is OptOps.UPCASTMID: # white
@@ -466,18 +454,23 @@ class Kernel:
  check(self.opts.has_local and not self.dont_use_locals, "NOLOCALS is meaningless if target does not support local or already not using locals")
  check(self.local_dims == 0 and self.group_for_reduces == 0, "can't have no locals with locals")
  self.dont_use_locals = True
+ elif opt.op is OptOps.SWAP:
+ check(axis < amt and amt < self.global_dims, f"swap is only for globals with axis < amt, getting {amt=}, {axis=}, {self.global_dims=}")
+ permute = list(range(self.shape_len))
+ permute[axis], permute[amt] = permute[amt], permute[axis]
+ self.reshape_and_permute(None, tuple(permute))
  elif opt.op is OptOps.PADTO:
  check(not self.vars, "does not work with symbolic shape")
- check(axis < self.shape_len - self.upcasted, "cannot pad upcasted")
+ check(axis < self.first_upcast, "cannot pad upcasted")
  # ok to pad SUM if all parent ops have f(0) = 0
  if self.first_reduce <= axis:
  check((r:=cast(LazyOp, self.reduceop)).op is ReduceOps.SUM and \
- all(op.op not in UNSAFE_PAD_OPS for ops in r.src for op in ops.lazyops), "cannot pad")
+ all(op.op not in UNSAFE_PAD_OPS for sop in r.src for op in sop.lazyops), "cannot pad")
  padded = False
  for i,st in enumerate(self.sts):
  if self.sts[i].shape[axis] == 1: continue # reduced
  check(self.sts[i].shape[axis] > amt//4, f"pad adds more than quadruple the work {self.sts[i].shape[axis]=} > {amt//4=}")
- if (ru := round_up(cast(int, self.sts[i].shape[axis]), cast(int, amt)) - self.sts[i].shape[axis]):
+ if (ru := round_up(cast(int, self.sts[i].shape[axis]), amt) - self.sts[i].shape[axis]):
  # pad right seems to be faster
  self.sts[i] = st.pad(((0,0),) * axis + ((0,ru),) + ((0,0),) * (len(st.shape)-axis-1))
  padded = True
@@ -487,14 +480,15 @@ class Kernel:
  if self.simplify_ones() and self.tensor_core_opts:
  self.tensor_core_opts.fix_axes(axis) # fix up axes in TC opts if required after simplify_ones()

- def required_optimizations(self):
+ def required_optimizations(self) -> Kernel:
  if self.bufs[0].dtype.__class__ is ImageDType:
  unit_stride_axes_mul_4 = [i for i in self.sts[0].unit_stride_axes(ignore_valid=True) if self.sts[0].shape[i]%4 == 0]
  assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[0]}"
- if len(unit_stride_axes_mul_4) and all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes: # noqa: E501
+ if len(unit_stride_axes_mul_4) and all(x < self.first_upcast for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes: # noqa: E501
  self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
+ return self

- def hand_coded_optimizations(self):
+ def hand_coded_optimizations(self) -> Kernel:
  self.required_optimizations()

  # should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
@@ -513,13 +507,13 @@ class Kernel:
  if MV_THREADS_PER_ROW > 1: self.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
  if MV_BLOCKSIZE > 1: self.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
  if MV_ROWS_PER_THREAD > 1: self.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD))
- return
+ return self

  if self.opts.has_local and self.opts.has_shared and all_int(self.sts[0].shape[:self.first_reduce]):
  # are we grouping? (requires local shape support)
  if not self.float4_axis(0) and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048: # noqa: E501
  # TODO: use 1024 if it's allowed in a smarter way
- for sz in (([256, 16]) if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
+ for sz in ([256, 16] if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
  if all(st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts):
  try: # may fail due to excessive smem usage
  self.apply_opt(Opt(OptOps.GROUPTOP, 0, sz))
@@ -538,19 +532,19 @@ class Kernel:
  unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) if self.sts[buf_index].shape[i]%4 == 0]
  if buf.dtype.__class__ is ImageDType:
  #assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}"
- if len(unit_stride_axes_mul_4) and all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes: # noqa: E501
+ if len(unit_stride_axes_mul_4) and all(x < self.first_upcast for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes: # noqa: E501
  if unit_stride_axes_mul_4[0] < self.first_reduce:
  self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
  else:
  self.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-self.first_reduce, 4))

  # no more opt if we are grouping
- if self.group_for_reduces: return
+ if self.group_for_reduces: return self

  # **** below this line need to be optional and benchmarked ****

  # TODO: doing extra upcasts with images doesn't work for some reason (maybe has to do with to_image_idx)
- # to trigger the above bug, remove prod(self.full_shape[self.shape_len - self.upcasted:]) from the below
+ # to trigger the above bug, remove prod(self.full_shape[self.first_upcast:]) from the below
  # expression and run test/test_ops.py with IMAGE=2
  # if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
  # this can be made much smarter
@@ -560,7 +554,7 @@ class Kernel:
  # we might want to be able to split axes that are masked, or refuse to merge them in simplify_merge_adjacent
  # for now skip upcasting here if there is a symbolic axis
  if isinstance(self.full_shape[axis], int) and self.full_shape[axis] <= 7 and any(st.axis_is_masked(axis) for st in self.sts) and \
- prod(self.full_shape[self.shape_len - self.upcasted:]) * prod(self.full_shape[j] for j in to_upcast) * self.full_shape[axis] <= 7 * 7:
+ prod(self.full_shape[self.first_upcast:]) * prod(self.full_shape[j] for j in to_upcast) * self.full_shape[axis] <= 7 * 7:
  if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
  to_upcast.append(axis)
  for axis in to_upcast[::-1]: self.apply_opt(Opt(OptOps.UPCAST, axis, 0))
@@ -581,11 +575,11 @@ class Kernel:
  else: break

  # if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast.
- if self.first_reduce < (self.shape_len-self.upcasted) and (len(list(self.shape_offsets(self.full_buf_index))) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))) and (self.upcasted == 0 or prod(self.full_shape[-self.upcasted:]) < 64): # noqa: E501
+ if self.first_reduce < self.first_upcast and (prod(self.full_shape[self.first_upcast:]) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))) and (self.upcasted == 0 or prod(self.full_shape[-self.upcasted:]) < 64): # noqa: E501
  if (s:=self.full_unupcasted_shape[-1]) <= 32 and isinstance(s, int): # NOTE: cannot loop unroll symbolic axis
  self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0))
  # if it's small, upcast a second reduce dimension too
- if self.first_reduce < (self.shape_len-self.upcasted) and s <= 3 and (s2:=self.full_unupcasted_shape[-1]) <= 3 and isinstance(s2, int):
+ if self.first_reduce < self.first_upcast and s <= 3 and (s2:=self.full_unupcasted_shape[-1]) <= 3 and isinstance(s2, int):
  self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0))
  else:
  for splits in [4]:
@@ -618,3 +612,142 @@ class Kernel:
  will_delete_shape = local_sz == self.full_shape[axis]
  self.apply_opt(Opt(OptOps.LOCAL, axis, local_sz))
  if will_delete_shape: deleted_shape += 1
+
+ return self
+
+ # **** kernel outputs ****
+
+ kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
+ @functools.cached_property
+ def name(self) -> str:
+ # kernel name (before late upcast)
+ name = ("r" if self.reduceop else ("C" if all(x.op in BufferOps for x in self.ast.lazyops) else "E")) + \
+ (f"{len(self.ast.src)}_" if len(self.ast.src) > 1 else "_") + \
+ colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
+
+ # name the function something unique
+ Kernel.kernel_cnt[(function_name := to_function_name(name))] += 1
+ suffix = f"{'n'+str(Kernel.kernel_cnt[function_name]-1)}" if Kernel.kernel_cnt[function_name] > 1 else ""
+ return name+colored(suffix, 'BLACK')
+
+ def get_optimized_ast(self) -> LazyOp:
+ # set the shapetrackers to the optimized ones, fixup reduceop
+ # transformed to the final LazyOp
+ @functools.lru_cache(None)
+ def fixup_ast(op:LazyOp, apply_to_st=None) -> LazyOp:
+ if op.op in BufferOps:
+ if isinstance(op.arg, MemBuffer) and op.arg.idx < 0:
+ # for locals, we use the ShapeTracker that's in the MemBuffer
+ arg:Any = replace(op.arg, st=apply_to_st(op.arg.st)) if apply_to_st is not None else op.arg
+ else:
+ idx = self.bufs.index(op.arg)
+ arg = replace(op.arg, st=self.sts[idx] if apply_to_st is None else apply_to_st(self.sts[idx]))
+ elif op.op in ReduceOps:
+ reduce_idx = len(self.bufs) + self.reduceops.index(op)*2
+ arg = tuple(i for i in range(self.first_reduce+self.group_for_reduces, self.shape_len)
+ if self.sts[reduce_idx].shape[i] != self.sts[reduce_idx+1].shape[i])
+ if op in self.bufs_for_tensor_core and (tc := self.tensor_core):
+ rsrc = op.src[0]
+ if rsrc.op is UnaryOps.CAST: rsrc = rsrc.src[0]
+ assert rsrc.op is BinaryOps.MUL
+
+ def fix_st(warp_dims, tcd_dims, tcd_expand, pattern_1, pattern_2, st1):
+ wd, tcd = self.global_dims, self.first_upcast
+ assert st1.shape[wd:wd+len(warp_dims)] == warp_dims, f"warp dims wrong: {st1.shape[wd:wd+len(warp_dims)]=} != {warp_dims=}"
+ assert st1.shape[tcd:tcd+len(tcd_dims)] == tcd_dims, f"tcd dims wrong: {st1.shape[tcd:tcd+len(tcd_dims)]=} != {tcd_dims=}"
+ new_shape = st1.shape[:tcd] + tcd_expand + st1.shape[tcd+len(tcd_dims):] # expand the tcd
+ permaxis = list(range(wd)) + [y + (wd if x == 0 else tcd) for x,y in pattern_1] + list(range(wd+len(warp_dims), tcd)) + \
+ [y + (wd if x == 0 else tcd) for x,y in pattern_2] + list(range(tcd+len(tcd_expand), len(new_shape)))
+ return st1.reshape(new_shape).simplify().permute(tuple(permaxis)).reshape(st1.shape).simplify()
+
+ if self.opts.device in {"AMD", "HIP"}:
+ reduce_axes, upcast_axes = [0], [[(0, 16)], [(0, 16)], [(1, 8)]]
+ # https://gpuopen.com/learn/wmma_on_rdna3/
+ fix_st1 = functools.partial(fix_st, (8,2,2), (16,8), (16,2,4), ((1,2), (0,2), (1,1), (0,1)), ((1,0), (0,0)))
+ fix_st2 = None
+ elif self.opts.device == "METAL":
+ reduce_axes, upcast_axes = [0], [[(1, 2)], [(1, 2)], [(1, 2)]]
+ fix_st1 = functools.partial(fix_st, (2,4,2,2), (8,2), (2,2,2,2), ((1,1), (0,1), (1,0), (0,3)), ((0,0), (0,2), (1,3), (1,2)))
+ fix_st2 = functools.partial(fix_st, (2,4,2,2), (8,2), (2,2,2,2), ((0,0), (1,1), (1,2), (0,2), (1,0)), ((0,1), (0,3), (1,3)))
+ elif self.opts.device in {"CUDA", "NV"}:
+ reduce_axes, upcast_axes = [0, 1], [[(0, 8)], [(2, 2), (3, 2)], [(2, 2), (3, 2)]]
+ # https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float
+ fix_st1 = functools.partial(fix_st, (2,2,2,2,2), (8,2,2,2), (2,2,2,2,2,2),
+ ((1,1), (1,0), (0,2), (0,3), (0,4)), ((1,3), (1,4), (1,2), (0,0), (0,1), (1,5)))
+ fix_st2 = functools.partial(fix_st, (2,2,2,2,2), (8,2,2,2), (2,2,2,2,2,2),
+ ((1,1), (1,0), (1,5), (0,0), (0,1)), ((0,4), (0,2), (1,4), (0,3), (1,3), (1,2)))
+ else:
+ raise RuntimeError("unsupported device for tensor cores")
+
+ assert apply_to_st is None, "double tensor core? not supported"
+ wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device,
+ tuple(tuple((self.first_upcast+ax, sz) for ax, sz in up) for up in upcast_axes),
+ tuple(self.first_upcast+ax for ax in reduce_axes))
+ if self.use_tensor_cores >= 2:
+ if self.use_tensor_cores == 3:
+ # TC=3, emulate the warp addressing with locals
+ ex_shape = tuple(1 if i < self.global_dims or (i >= self.first_reduce and i < self.first_upcast) else s \
+ for i,s in enumerate(self.full_shape))
+ srcs = []
+ for i,(src,fix_st_fxn) in enumerate(zip(rsrc.src, [fix_st1, fix_st2])):
+ st_load = [self.sts[self.bufs.index(op.arg)].real_strides() for op in src.lazyops if op.op is BufferOps.LOAD]
+ local_shape = tuple(s if max(cast(int, x[i]) for x in st_load) != 0 else 1 for i,s in enumerate(ex_shape))
+ membuf = MemBuffer(-1-i, tc.dtype_in, ShapeTracker.from_shape(local_shape).expand(ex_shape))
+ srcs.append(LazyOp(BufferOps.LOAD, (fixup_ast(LazyOp(BufferOps.STORE, (src,), membuf), fix_st_fxn),), membuf))
+ else:
+ # for TC=2, we can't do the shapetracker fixup
+ srcs = [fixup_ast(rsrc.src[0]), fixup_ast(rsrc.src[1])]
+ # MUL/SUM instead of WMMA
+ ret = LazyOp(ReduceOps.SUM, (LazyOp(UnaryOps.CAST, (LazyOp(BinaryOps.MUL, tuple(srcs)),), tc.dtype_out),), wmma_arg[-1])
+ else:
+ ret = LazyOp(ReduceOps.WMMA, (fixup_ast(rsrc.src[0], fix_st1), fixup_ast(rsrc.src[1], fix_st2)), wmma_arg)
+ return LazyOp(op.op, (ret,), new_reduce_axes) if (new_reduce_axes:=tuple(i for i in arg if i-self.first_upcast not in reduce_axes)) else ret
+ if self.group_for_reduces:
+ start = LazyOp(op.op, tuple(fixup_ast(x, apply_to_st) for x in op.src), arg)
+ local_shape = (1,) * self.global_dims + self.full_shape[self.global_dims:self.global_dims+self.local_dims+self.group_for_reduces] + \
+ (1,) * (self.first_upcast - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
+ local_buffer = MemBuffer(-1, start.dtype, ShapeTracker.from_shape(local_shape))
+ local_store = LazyOp(BufferOps.STORE, (start,), local_buffer)
+ local_load = LazyOp(BufferOps.LOAD, (local_store,), local_buffer)
+ return LazyOp(op.op, (local_load,), tuple(range(self.first_reduce, self.first_reduce+self.group_for_reduces)))
+ elif op.op is MetaOps.KERNEL:
+ arg = KernelInfo(self.local_dims, self.upcasted, self.dont_use_locals)
+ else:
+ arg = op.arg
+ return LazyOp(op.op, tuple(fixup_ast(x, apply_to_st) for x in op.src), arg)
+ return fixup_ast(self.ast)
+
+ # **** this is the lowerer ****
+
+ def linearize(self) -> Kernel:
+ modified_ast = self.get_optimized_ast()
+
+ if DEBUG >= 3:
+ print(self.name)
+ if getenv("RAWAST"): print(self.ast)
+ print(modified_ast)
+ print(self.applied_opts)
+ verify_lazyop(modified_ast)
+
+ # generate the UOpGraph
+ self.uops:UOpGraph = UOpGraph(lazyop_to_uop(modified_ast, self.opts), self.opts)
+ if DEBUG >= 5: self.uops.print()
+ if getenv("GRAPHUOPS"): self.uops.graph()
+ return self
+
+ def to_program(self, name_override:Optional[str]=None) -> Program:
+ self.linearize()
+ self.uops.linearize(self.opts.extra_matcher)
+ src = self.opts.render(name:=to_function_name(ansiname:=(name_override if name_override is not None else self.name)), self.uops.uops)
+
+ if getenv("RUN_PROCESS_REPLAY"):
+ table_name = f"process_replay_{getenv('GITHUB_RUN_ID', 'HEAD')}_{getenv('GITHUB_RUN_ATTEMPT')}"
+ diskcache_put(table_name, id(self), (self.ast, self.opts, self.applied_opts, name, src, {k:v.value for k,v in ContextVar._cache.items()}))
+
+ # group non-local MemBuffers by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
+ # TODO: these max and min don't work on symbolic, and results are very wrong.
+ mem_bytes = sum(max(x.arg.dtype.itemsize * x.arg.st.real_size() for x in group) for _, group in
+ itertools.groupby([x for x in self.ast.lazyops if x.op in BufferOps and isinstance(x.arg, MemBuffer) and x.arg.idx >= 0],
+ key=lambda x: (x.op, x.arg.idx)))
+ return Program(ansiname, src, self.opts.device, self.uops.uops, mem_estimate=mem_bytes,
+ global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)
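
Usage note (not part of the diff): a minimal sketch of how the reworked Kernel API above chains together in 0.9.2. It assumes Tensor.schedule() still returns ScheduleItems whose .ast is either a tuple of BufferOps.STORE ops or a single MetaOps.KERNEL op; the new constructor accepts both forms, and hand_coded_optimizations()/to_program() now return Kernel/Program as shown in the diff.

from tinygrad import Tensor
from tinygrad.device import Device
from tinygrad.codegen.kernel import Kernel
from tinygrad.ops import BufferOps, MetaOps

out = (Tensor.rand(16, 16) @ Tensor.rand(16, 16)).sum()
for si in out.schedule():
  ast = si.ast if isinstance(si.ast, tuple) else (si.ast,)          # tuple of STOREs or a single KERNEL op (assumption)
  if ast[0].op not in {BufferOps.STORE, MetaOps.KERNEL}: continue   # skip non-compute items like COPY/EMPTY
  k = Kernel(*ast, opts=Device[Device.DEFAULT].renderer)
  prg = k.hand_coded_optimizations().to_program()                   # chain: Kernel -> Kernel -> Program
  print(k.name)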