tinygrad-0.7.0-py3-none-any.whl → tinygrad-0.9.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. tinygrad/__init__.py +6 -0
  2. tinygrad/codegen/kernel.py +572 -83
  3. tinygrad/codegen/linearizer.py +415 -395
  4. tinygrad/codegen/uops.py +415 -0
  5. tinygrad/device.py +183 -0
  6. tinygrad/dtype.py +113 -0
  7. tinygrad/engine/__init__.py +0 -0
  8. tinygrad/engine/graph.py +100 -0
  9. tinygrad/engine/jit.py +195 -0
  10. tinygrad/engine/realize.py +191 -0
  11. tinygrad/engine/schedule.py +362 -0
  12. tinygrad/engine/search.py +196 -0
  13. tinygrad/{mlops.py → function.py} +76 -55
  14. tinygrad/helpers.py +196 -89
  15. tinygrad/lazy.py +210 -371
  16. tinygrad/multi.py +169 -0
  17. tinygrad/nn/__init__.py +202 -22
  18. tinygrad/nn/datasets.py +7 -0
  19. tinygrad/nn/optim.py +112 -32
  20. tinygrad/nn/state.py +136 -39
  21. tinygrad/ops.py +119 -202
  22. tinygrad/renderer/__init__.py +61 -0
  23. tinygrad/renderer/assembly.py +276 -0
  24. tinygrad/renderer/cstyle.py +353 -166
  25. tinygrad/renderer/llvmir.py +150 -138
  26. tinygrad/runtime/autogen/amd_gpu.py +1900 -0
  27. tinygrad/runtime/autogen/comgr.py +865 -0
  28. tinygrad/runtime/autogen/cuda.py +5923 -0
  29. tinygrad/runtime/autogen/hip.py +5909 -0
  30. tinygrad/runtime/autogen/hsa.py +5761 -0
  31. tinygrad/runtime/autogen/kfd.py +812 -0
  32. tinygrad/runtime/autogen/nv_gpu.py +33328 -0
  33. tinygrad/runtime/autogen/opencl.py +1795 -0
  34. tinygrad/runtime/driver/hip_comgr.py +47 -0
  35. tinygrad/runtime/driver/hsa.py +143 -0
  36. tinygrad/runtime/graph/clang.py +38 -0
  37. tinygrad/runtime/graph/cuda.py +81 -0
  38. tinygrad/runtime/graph/hcq.py +143 -0
  39. tinygrad/runtime/graph/hsa.py +171 -0
  40. tinygrad/runtime/graph/metal.py +75 -0
  41. tinygrad/runtime/ops_amd.py +564 -0
  42. tinygrad/runtime/ops_clang.py +24 -77
  43. tinygrad/runtime/ops_cuda.py +175 -89
  44. tinygrad/runtime/ops_disk.py +56 -33
  45. tinygrad/runtime/ops_gpu.py +92 -95
  46. tinygrad/runtime/ops_hsa.py +278 -0
  47. tinygrad/runtime/ops_llvm.py +39 -60
  48. tinygrad/runtime/ops_metal.py +92 -74
  49. tinygrad/runtime/ops_npy.py +9 -0
  50. tinygrad/runtime/ops_nv.py +630 -0
  51. tinygrad/runtime/ops_python.py +204 -0
  52. tinygrad/shape/shapetracker.py +86 -254
  53. tinygrad/shape/symbolic.py +166 -141
  54. tinygrad/shape/view.py +296 -0
  55. tinygrad/tensor.py +2619 -448
  56. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
  57. tinygrad-0.9.0.dist-info/METADATA +227 -0
  58. tinygrad-0.9.0.dist-info/RECORD +60 -0
  59. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
  60. tinygrad/codegen/assembly.py +0 -190
  61. tinygrad/codegen/optimizer.py +0 -379
  62. tinygrad/codegen/search.py +0 -72
  63. tinygrad/graph.py +0 -83
  64. tinygrad/jit.py +0 -57
  65. tinygrad/nn/image.py +0 -100
  66. tinygrad/renderer/assembly_arm64.py +0 -169
  67. tinygrad/renderer/assembly_ptx.py +0 -98
  68. tinygrad/renderer/wgsl.py +0 -53
  69. tinygrad/runtime/lib.py +0 -113
  70. tinygrad/runtime/ops_cpu.py +0 -51
  71. tinygrad/runtime/ops_hip.py +0 -82
  72. tinygrad/runtime/ops_shm.py +0 -29
  73. tinygrad/runtime/ops_torch.py +0 -30
  74. tinygrad/runtime/ops_webgpu.py +0 -45
  75. tinygrad-0.7.0.dist-info/METADATA +0 -212
  76. tinygrad-0.7.0.dist-info/RECORD +0 -40
  77. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
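The list already shows the shape of the public-API reshuffle between 0.7.0 and 0.9.0: tinygrad/mlops.py becomes tinygrad/function.py, tinygrad/jit.py and tinygrad/graph.py are superseded by the new tinygrad/engine/ package, and dtypes move out of tinygrad/helpers.py into the new tinygrad/dtype.py. A minimal sketch of what that means for user code, assuming the six new lines in tinygrad/__init__.py are the usual top-level re-exports:

# tinygrad 0.7.0 style imports (modules deleted or moved by this diff)
#   from tinygrad.tensor import Tensor
#   from tinygrad.jit import TinyJit        # tinygrad/jit.py is removed (+0 -57)
#   from tinygrad.helpers import dtypes     # dtypes now live in tinygrad/dtype.py

# tinygrad 0.9.0 style imports, assuming the top-level re-exports in tinygrad/__init__.py
from tinygrad import Tensor, Device, dtypes

x = Tensor.randn(4, 4, dtype=dtypes.float32)
print((x @ x).numpy(), Device.DEFAULT)

The kernel.py diff below is shown with the codegen reorganization in full.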
@@ -1,10 +1,47 @@
- from typing import NamedTuple, Optional, List, Tuple, cast, Dict
- import itertools
- from tinygrad.ops import LazyOp, MovementOps, FlopCounter, get_lazyop_info, ReduceOps
- from tinygrad.lazy import LazyBuffer
- from tinygrad.helpers import dedup, dtypes, colored, prod, ImageDType, DType
- from tinygrad.runtime.lib import buf_is_kernel_arg
- from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape
+ from __future__ import annotations
+ import math, itertools
+ from typing import NamedTuple, Optional, List, Tuple, cast, Dict, Union
+ from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, UNSAFE_PAD_OPS
+ from tinygrad.device import Device
+ from tinygrad.renderer import Renderer, TensorCore
+ from tinygrad.dtype import dtypes, ImageDType, DType
+ from tinygrad.helpers import colored, ansilen, dedup, flatten, getenv, prod, DEBUG, round_up, all_int, get_contraction
+ from tinygrad.shape.shapetracker import ShapeTracker
+ from tinygrad.shape.symbolic import sint
+ from tinygrad.shape.view import View, strides_for_shape
+ from dataclasses import dataclass
+ from enum import Enum, auto
+
+ class OptOps(Enum):
+   TC = auto(); UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
+   GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto() # noqa: E702
+   def __lt__(self, x:OptOps): return self.value < x.value
+
+ class KernelOptError(Exception): pass
+
+ def check(cond:bool, msg:str=""):
+   if not cond: raise KernelOptError(msg)
+
+ @dataclass(frozen=True, order=True)
+ class Opt:
+   op: OptOps
+   axis: Optional[int] = None
+   amt: Optional[int] = None
+   def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, amt={self.amt})"
+   def real_axis(self, k:Kernel):
+     if self.axis is None: return -1
+     if self.op is OptOps.UNROLL: return k.first_reduce+self.axis
+     if self.op in {OptOps.GROUP, OptOps.GROUPTOP}: return k.first_reduce+k.group_for_reduces+self.axis
+     return self.axis
+
+ class TensorCoreOptions(NamedTuple):
+   bufs: Tuple[int, int] # the local aliased buffers for A and B
+   axes: List[int] # the location of the original N and M axes if still in the shape
+   axes_exist: List[bool] # true if the original N and M axes are still in the shape
+   def fix_axes(self, removed_axis:int): # adjust the TC axes if necesssary when an dimension is removed
+     for tc_dim in [i for i in range(2) if self.axes_exist[i]]:
+       if removed_axis < self.axes[tc_dim]: self.axes[tc_dim] -= 1
+       elif removed_axis == self.axes[tc_dim]: self.axes_exist[tc_dim] = False

  class LocalBuffer(NamedTuple):
    name: str
@@ -13,107 +50,129 @@ class LocalBuffer(NamedTuple):
    realized: None = None
    def __str__(self): return f"localbuffer<{self.name}[{self.size}]>"

- class LinearizerOptions(NamedTuple):
-   # TODO: make this generic with a list of supported types
-   supports_float4: bool = True
-   supports_float4_alu: bool = True
-   has_local: bool = True
-   # NOTE: these two should be in z,y,x(reversed) order for cstyle backends, they are flipped when kernel is rendered
-   global_max: Optional[List[int]] = None
-   local_max: Optional[List[int]] = None
-
  class Kernel:
-   def __init__(self, ast:LazyOp, output_buffer:LazyBuffer, opts:LinearizerOptions):
-     # NOTE: if there's a RESHAPE, we skip it. the output shape is set from the reduce op or a latebuf
-     self.ast = ast.src[0] if ast.op == MovementOps.RESHAPE else ast
-     self.opts = opts
-
-     # get the output buffers
-     self.bufs = [output_buffer] + dedup(ast.buffers)
-     self.arg_bufs = {x:f"data{i}" for i,x in enumerate(dedup([x.realized for x in self.bufs if buf_is_kernel_arg(x)]))}
-
-     # key for lookup in cache (can change, str might not be right)
-     # bufs are needed because kernels like f(x) = x + x and f(x, y) = x + y have the same str(ast), but are different kernels.
-     # mapping the buffers to integers is required because a-b != b-a (and how would you tell a and b apart?)
-     self.key = (ast.map_buffers({x:(self.arg_bufs[x.realized] if x.realized in self.arg_bufs else x) for x in self.bufs}).key, tuple([x.key for x in self.bufs]))
-
-   def process(self) -> None:
-     if hasattr(self, "sts"): return # already processed
-
-     # fetch lazyop info
-     self.info: FlopCounter = get_lazyop_info(cast(LazyOp, self.ast))
-     self.mem_estimate: int = sum(x.dtype.itemsize*(x.realized.size if x.realized is not None else prod(x.shape)) for x in self.bufs if x is not None)
+   def __init__(self, *ast:LazyOp, opts:Optional[Renderer]=None):
+     self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer
+     assert all(op.op is BufferOps.STORE for op in ast), f"kernels must have stores as the output, got {ast}"
+     assert len(set(op.arg.st.size for op in ast)) == 1, f"all outbufs should have the same size, got {[op.arg.st for op in ast]}"
+     self.ast = ast
+     self.lazyops = flatten([op.lazyops for op in self.ast])

      # there's only allowed to be one reduceop
-     reduceops = [x for x in self.ast.get_lazyops() if x.op in ReduceOps]
-     assert len(dedup(reduceops)) <= 1, "max one reduce op in an ast"
-     self.reduceop = reduceops[0] if reduceops else None
+     cached_ordered_lazyops: Dict[LazyOp, List[LazyOp]] = {}
+     def ordered_lazyops(op):
+       if op not in cached_ordered_lazyops: cached_ordered_lazyops[op] = dedup([item for x in op.src for item in ordered_lazyops(x)] + [op])
+       return cached_ordered_lazyops[op]
+     self.reduceops = dedup([x for out in self.ast for x in ordered_lazyops(out) if x.op in ReduceOps])
+     assert len(self.reduceops) < 2, "Only one reduceop allowed"
+
+     self.outbufs, self.vars = [x.arg for x in self.ast], flatten([x.vars() for x in self.ast])
+     loadops = [BufferOps.LOAD, BufferOps.CONST]
+     self.bufs: List[Union[MemBuffer, ConstBuffer, LocalBuffer]] = self.outbufs + dedup([x.arg for x in self.lazyops if x.op in loadops])

      # get earlybufs, before the one reduce op
-     self.earlybufs = dedup(self.reduceop.buffers) if self.reduceop else []
+     self.earlybufs = [x.arg for reduceop in self.reduceops for x in reduceop.lazyops if x.op in BufferOps]
+     self.full_buf_index: int = self.bufs.index(self.earlybufs[0]) if self.earlybufs else 0

      # create new shapetrackers inside this kernel, we will permute them
-     self.sts: List[ShapeTracker] = [x.st.copy() for x in self.bufs]
-     for st in self.sts: st.simplify()
+     self.sts: List[ShapeTracker] = [x.st for x in cast(List[Union[MemBuffer, ConstBuffer]], self.bufs)]

-     # make the output buffer shape correct in here
-     self.sts[0].reshape(self.info.shape)
-     self.full_buf_index: int = self.bufs.index(self.earlybufs[0]) if self.earlybufs else 0
+     # move all reduce axes to the end
+     reduce = list(enumerate(zip(self.full_shape, self.output_shape)))
+     permute = tuple([i for i,(s,n) in reduce if s == n] + [i for i,(s,n) in reduce if s != n])
+     self.reshape_and_permute(None, permute)

-     # parameters
-     self.group_for_reduce: List[int] = []
+     # parameters for optimization
+     self.applied_opts: List[Opt] = []
+     self.group_for_reduces: int = 0
      self.upcasted: int = 0
      self.local_dims: int = 0
      self.local_alias: Dict[int, LocalBuffer] = {}
-     self.use_tensor_cores: bool = False
-     self.exclude_local_upcast: int = 0
-     self.reverse_upcast_dir: bool = False
+     self.tensor_core: Optional[TensorCore] = None
+     self.tensor_core_opts: Optional[TensorCoreOptions] = None
+     self.dont_use_locals: bool = False

-   def has_variable_shape(self) -> bool:
-     for b in self.bufs:
-       if any(not isinstance(x, int) for x in b.st.shape): return True
-     return False
+     # group simplifies
+     self.simplify_ones()
+     self.simplify_merge_adjacent()
+
+     # cache
+     self.applied_opts_cache: Optional[List[Opt]] = None
+
+   def copy(self):
+     ret = type(self).__new__(type(self))
+
+     # base linearizer params
+     ret.opts, ret.ast, ret.lazyops = self.opts, self.ast, self.lazyops
+
+     # things downstream of the AST
+     ret.reduceops, ret.outbufs, ret.vars, ret.bufs, ret.earlybufs, ret.full_buf_index = \
+       self.reduceops, self.outbufs, self.vars, [x for x in self.bufs if not isinstance(x, LocalBuffer)], self.earlybufs, self.full_buf_index
+     ret.sts = self.sts[:len(ret.bufs)] # NOTE: must redo the local buffers with TC in beam
+
+     # parameters for optimizations
+     ret.applied_opts, ret.group_for_reduces, ret.upcasted, ret.local_dims, ret.dont_use_locals = \
+       self.applied_opts[:], self.group_for_reduces, self.upcasted, self.local_dims, self.dont_use_locals
+     ret.tensor_core, ret.tensor_core_opts, ret.local_alias = self.tensor_core, self.tensor_core_opts, {}
+
+     # uncached since linearize didn't run
+     ret.applied_opts_cache = None

-   def shape_offsets(self, i): return itertools.product(*[list(range(s)) for s in self.sts[i].shape[self.shape_len-self.upcasted:][::-1]]) if self.upcasted > 0 else [tuple()]
-   def float4_axis(self, i): return [x-(self.shape_len-self.upcasted) for x in self.sts[i].unit_stride_axes() if x >= self.shape_len-self.upcasted and self.sts[i].shape[x]%4 == 0]
+     return ret

-   def upcasted_axis(self, i):
-     return list(zip(self.sts[i].shape[self.shape_len-self.upcasted:],
-                     self.sts[i].real_strides()[self.shape_len-self.upcasted:],
+   @property
+   def membufs(self) -> List[MemBuffer]: return [x for x in self.bufs if isinstance(x, MemBuffer)]
+
+   # TODO: these need more tests or it might silently be no-op
+   def shape_offsets(self, i:int): return itertools.product(*[list(range(cast(int, s))) for s in self.sts[i].shape[self.shape_len-self.upcasted:][::-1]]) if self.upcasted > 0 else [tuple()] # noqa: E501
+   def float4_axis(self, i:int): return [x-(self.shape_len-self.upcasted) for x in self.sts[i].unit_stride_axes() if x >= self.shape_len-self.upcasted and self.sts[i].shape[x]%4 == 0] # noqa: E501
+
+   def upcasted_axis(self, i:int) -> List[Tuple[int, Optional[sint], bool]]:
+     upcasted_shape, upcasted_stride = self.sts[i].shape[self.shape_len-self.upcasted:], self.sts[i].real_strides()[self.shape_len-self.upcasted:]
+     assert all_int(upcasted_shape), f"cannot upcast a symbolic amount {upcasted_shape=}"
+     return list(zip(upcasted_shape, upcasted_stride,
                      [x!=y for x,y in zip(self.sts[0].shape[self.shape_len-self.upcasted:], self.full_shape[self.shape_len-self.upcasted:])]))

    # TODO: is there a better way to write this?
-   def acc_offsets(self, i):
+   def acc_offsets(self, i:int) -> List[int]:
      if self.upcasted == 0: return [0]
      upcasted_i = self.upcasted_axis(i)
      acc_strides = [x*(1-upcasted_i[::-1][i][2]) for i,x in enumerate(strides_for_shape(tuple(1 if r else s for s,_,r in upcasted_i[::-1])))]
      return [sum(t) for t in itertools.product(*[[y*acc_strides[i] for y in range(x[0])] for i,x in enumerate(upcasted_i[::-1])])]

-   def get_upcast_dim(self, i) -> List[int]:
-     should_upcast = self.opts.supports_float4 and (self.bufs[i].dtype in [dtypes.float32, dtypes.float16] or isinstance(self.bufs[i].dtype, ImageDType))
-     return [x for x in self.sts[i].unit_stride_axes() if should_upcast and x >= self.shape_len-self.upcasted and self.sts[i].shape[x] > 1]
+   def get_float4_upcast_dim(self, i:int) -> List[int]:
+     should_upcast = self.opts.supports_float4 and (self.bufs[i].dtype in (dtypes.float, dtypes.half) or isinstance(self.bufs[i].dtype, ImageDType))
+     return [x for x in self.sts[i].unit_stride_axes() if x >= self.shape_len-self.upcasted and self.sts[i].shape[x] > 1] if should_upcast else []
+
+   @property
+   def first_reduce(self) -> int:
+     return [x!=y for x,y in zip(self.sts[0].shape[:self.shape_len-self.upcasted]+(0,), self.full_shape[:self.shape_len-self.upcasted]+(1,))].index(True) # noqa: E501

    @property
-   def first_reduce(self) -> int: return [x!=y for x,y in zip(self.sts[0].shape[:self.shape_len-self.upcasted]+(0,), self.full_shape[:self.shape_len-self.upcasted]+(1,))].index(True)
+   def reduceop(self) -> Optional[LazyOp]: return self.reduceops[0] if len(self.reduceops) > 0 else None

    @property
-   def output_shape(self) -> Tuple[int, ...]: return self.sts[0].shape
+   def output_shape(self) -> Tuple[sint, ...]: return self.sts[0].shape

    @property
-   def full_shape(self) -> Tuple[int, ...]: return self.sts[self.full_buf_index].shape
+   def full_shape(self) -> Tuple[sint, ...]: return self.sts[self.full_buf_index].shape

    @property
-   def full_unupcasted_shape(self) -> Tuple[int, ...]: return self.full_shape[:self.shape_len-self.upcasted]
+   def full_unupcasted_shape(self) -> Tuple[sint, ...]: return self.full_shape[:self.shape_len-self.upcasted]

    @property
    def shape_len(self) -> int: return len(self.sts[0].shape)

    @property
-   def upcast_in_mid_reduce_axes(self) -> List[int]: return [j for j in range(self.first_reduce, self.first_reduce+len(self.group_for_reduce)) if self.full_shape[j] == self.sts[0].shape[j]]
+   def upcast_in_mid_reduce_axes(self) -> List[int]:
+     return [j for j in range(self.first_reduce, self.first_reduce+self.group_for_reduces) if self.full_shape[j] == self.sts[0].shape[j]]

-   # there's seven chunks of the shape
+   @property
+   def global_dims(self) -> int: return self.first_reduce-self.local_dims
+
+   # there's eight chunks of the shape
    # blue -- global dims
-   # cyan -- local dims
+   # cyan -- local dims (warp ones first)
    # *** self.first_reduce
    # green -- reduce-local dims
    # white -- reduce-late upcasted dim (self.upcast_in_mid_reduce_axes)
@@ -122,22 +181,452 @@ class Kernel:
    # purple -- reduce upcasted
    # yellow -- normal upcasted dimensions
    def colors(self) -> List[str]:
-     # up to first_reduce, they are all global (blue)
-     colors = ["blue"] * (self.first_reduce-self.local_dims)
-     # except the local_dims, these are non-reduce locals (cyan)
-     colors += ["cyan"] * (self.local_dims)
-     # between first_reduce and first_reduce + group_for_reduce, they are either local (cyan), or late upcasted (green)
-     colors += ["white" if i in self.upcast_in_mid_reduce_axes else "green" for i in range(self.first_reduce, self.first_reduce + len(self.group_for_reduce))]
-     # between first_reduce + group_for_reduce and upcasted, they are reduce (red)
-     colors += ["red"] * ((self.shape_len-self.upcasted) - (self.first_reduce + len(self.group_for_reduce)))
+     # first non local non reduce dims are global (blue)
+     colors = ["blue"] * self.global_dims if not self.dont_use_locals else ["BLUE"] * self.global_dims
+     # after global are local_dims; warp ones used in tensor cores must be closest to first_reduce (cyan)
+     colors += ["cyan"] * self.local_dims
+     # between first_reduce and first_reduce + group_for_reduces, they are either upcast mid reduce (white), or late upcasted (green)
+     colors += ["white" if i in self.upcast_in_mid_reduce_axes else "green" for i in range(self.first_reduce, self.first_reduce + self.group_for_reduces)] # noqa: E501
+     # between first_reduce + group_for_reduces and upcasted, they are reduce (red)
+     colors += ["red"] * ((self.shape_len-self.upcasted) - (self.first_reduce + self.group_for_reduces))
      # upcasted dimensions are reduce (magenta) or normal (yellow)
      colors += ["magenta" if self.full_shape[i] != self.sts[0].shape[i] else "yellow" for i in range(self.shape_len-self.upcasted, self.shape_len)]
      assert len(colors) == self.shape_len, "colors size mismatch"
      return colors

-   def colored_shape(self) -> str: return ' '.join(colored(s, color) for s,color in zip([f"{s:4d}" if isinstance(s, int) else s for s in self.full_shape], self.colors()))
-   def printbufs(self, prefix=""):
-     for i in range(len(self.sts)):
-       print(prefix, f"{i:3d} {str(self.bufs[i].realized) if self.bufs[i].realized is not None else str(self.bufs[i]):47s}", self.sts[i].views)
-     print(self.colored_shape())
+   def colored_shape(self, pad:Optional[int]=None, dense=False) -> str:
+     ret = ' '.join(colored(s, color) for s,color in zip([f"{s:4d}" if isinstance(s, int) and not dense else s for s in self.full_shape], self.colors())) # noqa: E501
+     if pad: ret += ' '*(pad-ansilen(ret))
+     return ret
+
+   # ******************** base simplifiers ********************
+
+   # apply reshape and permute to all shapetrackers
+   def reshape_and_permute(self, new_shape_fxn, axis):
+     new_sts = []
+     for st in self.sts:
+       if new_shape_fxn is not None: st = st.reshape(tuple(new_shape_fxn(st.shape)))
+       if axis is not None: st = st.permute(tuple(axis))
+       new_sts.append(st)
+     self.sts = new_sts
+
+   # drops the final dimension
+   def upcast(self):
+     check(self.full_shape[-1] != 1, "can't upcast a dimension with size 1")
+     self.upcasted += 1
+
+   # axis : the axis to pull from
+   # amount : the amount to take
+   # top : if you want to pull that amount from the top
+   # insert_before : place to insert the new stuff
+   def shift_to(self, axis, amount, top=False, insert_before=None):
+     if insert_before is None: insert_before = self.shape_len
+     move_axis = axis if top else axis+1
+     if move_axis < insert_before: insert_before += 1
+     self.reshape_and_permute(
+       lambda x: list(x[0:axis]) + (([amount, x[axis]//amount] if top else [x[axis]//amount, amount]) if x[axis] > 1 else [1,1]) + list(x[axis+1:]),
+       [i for i in range(insert_before) if i != move_axis] + [move_axis] + [i for i in range(insert_before, self.shape_len+1) if i != move_axis])
+
+   # ******************** complex simplifiers ********************
+
+   def simplify_ones(self) -> bool:
+     # remove places where the shape is all ones
+     # TODO: this should be factored in to multi shape stride
+     if self.shape_len == 0: return False
+     all_ones = [s==1 for s in self.full_shape]
+     self.local_dims -= sum(all_ones[self.first_reduce-self.local_dims:self.first_reduce])
+     self.upcasted -= sum(all_ones[self.shape_len-self.upcasted:]) # TODO: no necessary since upcasted axis can't be un-upcasted
+     self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None)
+     return any(all_ones)
+
+   def simplify_merge_adjacent(self):
+     if self.shape_len == 0: return
+     shapes, strides = [x.shape for x in self.sts], [x.real_strides() for x in self.sts]
+
+     # if it's an image, insert fake strides such that this fusion doesn't happen across image axes
+     if isinstance(self.bufs[0].dtype, ImageDType):
+       base_shape = self.bufs[0].dtype.shape
+       if shape_idx_groups := get_contraction(self.output_shape, base_shape):
+         special_strides: Tuple[sint, ...] = tuple()
+         for i,g in enumerate(shape_idx_groups):
+           shape_piece = tuple(self.output_shape[x] for x in g)
+           assert prod(shape_piece) == base_shape[i], f"get_contraction was wrong? {shape_piece} != {base_shape[i]}"
+           special_strides += strides_for_shape(shape_piece)
+         # adding the fake image shape
+         shapes.append(self.output_shape)
+         strides.append(special_strides)
+
+     # merge dimensions if we can, multi get_shape_strides
+     # NOTE: this does not always preserve the reduce dimension
+     # TODO: move this into shapetracker, with tests!
+     rets = [[(shapes[j][0], strides[j][0])] for j in range(len(shapes))]
+     for i in range(1, len(shapes[0])):
+       can_merge = []
+       for j in range(len(shapes)):
+         # TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case
+         can_merge.append(strides[j][i] is not None and ((strides[j][i] != 0 and rets[j][-1][1] == shapes[j][i]*cast(int, strides[j][i])) or (strides[j][i] == 0 and rets[j][-1][1] == 0))) # noqa: E501
+       # more can merge than this
+       mergeable = all(can_merge) and i != self.first_reduce
+       for j in range(len(shapes)):
+         if mergeable: rets[j][-1] = (rets[j][-1][0] * shapes[j][i], strides[j][i])
+         else: rets[j].append((shapes[j][i], strides[j][i]))
+
+     # do the reshapes
+     for i,x in enumerate(rets[:len(self.sts)]): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x]))
+
+   # ******************** helpers ********************
+
+   def _limit_size(self, x: Tuple[int], max_size: List[Union[int,float]]) -> Tuple[int, ...]:
+     new_shape = list(x)
+     for i in range(len(new_shape)):
+       next_idx = (i + 1) % len(new_shape)
+       while new_shape[i] > max_size[i]:
+         # TODO: what if new_shape[i] is not a multiple of 2??
+         new_shape[i] = new_shape[i] // 2
+         next_idx = next_idx if new_shape[next_idx] <= max_size[next_idx] else (next_idx + 1) % len(new_shape)
+         new_shape[next_idx] = new_shape[next_idx] * 2
+     return tuple(new_shape)
+
+   def limit_dims_to_max(self, global_max: List[int], local_max: List[int]):
+     # Check the global allocation limit, current the global_size will be flipped during codegen
+     # and then padded right with 1s if its length < 3 which makes this part a bit awkward to write
+     if self.global_dims > 0:
+       if global_max:
+         tmp = global_max[:self.global_dims] + (local_max[:self.local_dims] if local_max else [])
+         if max(global_max) < max(self.full_shape[:self.global_dims]):
+           self.reshape_and_permute(lambda x: self._limit_size(x, tmp + [math.inf] * (len(self.full_shape)-len(tmp))), None)
+         assert max(global_max) >= max(self.full_shape[:self.global_dims]), f"device max allocation {max(self.full_shape[:self.global_dims])} exceeds global dim maximum {max(global_max)}" # noqa: E501
+       for i in range(self.global_dims-1):
+         if i < len(global_max) and self.full_shape[i] > global_max[i]:
+           order = list(range(len(self.full_shape)))
+           order[i], order[self.global_dims-1] = order[self.global_dims-1], order[i]
+           self.reshape_and_permute(None, order)
+           if DEBUG >= 3: print("permuted global dim", order, "due to allocation exceeds global limit")
+
+   def alias_buffer(self, i, pattern):
+     assert len(pattern) == len(self.sts[i].shape), f"must include a pattern for each shape {pattern} {self.sts[i].shape}"
+
+     bst = 1
+     real_strides = self.sts[i].real_strides()
+     shp, stride = [(s if p != 0 else 1) for s,p in zip(self.sts[i].shape, pattern)], [0]*len(pattern)
+     for priority in range(1, max(pattern)+1): # priority. 0 is non local and ignored
+       for j,p in enumerate(pattern):
+         if priority == p and real_strides[j] != 0:
+           stride[j] = bst
+           bst *= shp[j]
+
+     self.sts.append(ShapeTracker((View.create(tuple(shp), tuple(stride)),)))
+     self.bufs.append(LocalBuffer(name=f"ldata{i}", size=self.sts[-1].size))
+     if DEBUG >= 4: print("aliasing buffer", self.sts[i])
+     self.local_alias[i] = cast(LocalBuffer, self.bufs[-1])
+
+   # ******************** high level optimizers ********************
+
+   def _apply_tc_opt(self, use_tensor_cores:int, axis:int, opt_level:int) -> bool:
+     if use_tensor_cores and self.opts.has_local and self.reduceop is not None and self.reduceop.op is ReduceOps.SUM:
+       for tc in self.opts.tensor_cores:
+         has_cast = tc.dtype_in != tc.dtype_out
+         if has_cast and not(self.reduceop.src[0].op is UnaryOps.CAST and self.reduceop.src[0].arg == tc.dtype_out): continue
+
+         mul_op = self.reduceop.src[0].src[0] if has_cast else self.reduceop.src[0]
+         if mul_op.op is not BinaryOps.MUL: continue
+
+         def buf_index(src: LazyOp) -> Optional[int]:
+           # TODO: apply tc even if the sources are not from LOAD
+           if src.op is BufferOps.LOAD and src.arg.dtype == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.arg))
+           try:
+             if opt_level >= 1 and src.op is UnaryOps.CAST and src.arg == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg))
+           except ValueError: return None
+           return None
+         if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: continue
+
+         buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
+         axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0]
+         axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0]
+         if not(axis_buf0 and axis_buf1 and ((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1))): continue
+
+         axis_choices = list(itertools.product(axis_buf0, axis_buf1, range(self.first_reduce, self.shape_len)))
+         if not(axis < len(axis_choices)): continue
+
+         s0, s1, s2 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0], axis_choices[-(axis+1)][2] # s0 is n, s1 is m, s2 is k
+         axis_pads = [(x, tc.dims[i]) for i, x in enumerate([s0, s1, s2]) if self.full_shape[x]%tc.dims[i] != 0]
+         if axis_pads and (opt_level < 2): continue
+
+         # tensor core -- unroll the reduce dim, upcast input, then create the correct thread pattern
+         self.tensor_core_opts = (tc_opts:=TensorCoreOptions(bufs=(buf0, buf1), axes=[s0, s1], axes_exist=[True, True]))
+
+         # attempt to pad the tensor axes that require it
+         try:
+           for axis, dim in axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
+         except KernelOptError: continue
+         self.apply_opt(Opt(OptOps.UNROLL, s2-self.first_reduce, tc.dims[2]), append_opt=False)
+         for i, sz in enumerate([prod(x) for x in [[x[1] for x in tc.threads if x[0]==dim] for dim in range(2)]]): # upcast non-local'd N, M
+           if tc.dims[i] > sz: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[i], tc.dims[i]//sz), append_opt=False)
+         for (tc_dim, tc_amt) in tc.threads:
+           self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[tc_dim], tc_amt), append_opt=False)
+
+         # assert tensor core
+         if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
+         if use_tensor_cores == 1: self.tensor_core = tc # TC=2 will do the shape ops without the WMMA
+         return True
+     return False
+
+   def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[List[Opt]]=None, axis:int=0, tc_opt:int=getenv("TC_OPT")) -> bool:
+     """ Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, return true, otherwise return false.
+     Tensor cores are optimized instructions that matrix multiply-accumulate across a wave of threads: D(M, N) = A(M, K) * B(K, N) + C(M, N).
+
+     Keyword arguments:
+     use_tensor_cores -- controls how tensor cores are applied (default 1)
+       0: will disable any tensor core matching
+       1: enable tensor cores
+       2: apply tensor core shape but don't use UOp.WMMA
+     extra_opts -- additional Opt's to apply after the tensor core instead of the hand-coded additional Opt's (default None)
+     tc_opt -- controls which kinds of kernels may be eligible for tensor cores application (default 2 during BEAM, 0 otherwise)
+       0: applies to only kernels with a single reduce axis and direct BufferOps.LOAD into BinaryOps.MUL
+       1: allows kernels with multiple reduce axes and also multiplication of UnaryOps.CAST'd buffers
+       2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by applying padding those axes as needed
+     """
+     if not self.opts.tensor_cores and use_tensor_cores != 2: return False
+     try: # check TC first and apply hand-coded opts if successful
+       self.apply_opt(Opt(OptOps.TC, axis, tc_opt))
+
+       if (tc_opts:=self.tensor_core_opts) is not None:
+         if extra_opts is not None:
+           for opt in extra_opts: self.apply_opt(opt)
+         else:
+           # hand-coded TC opts
+           def late_upcast_tc(tc_dim: int):
+             if tc_opts.axes_exist[tc_dim]:
+               ax_div = [upc for upc in [5,4,3,2,1] if self.full_shape[tc_opts.axes[tc_dim]]%upc == 0][0]
+               if ax_div != 1: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[tc_dim], ax_div))
+           late_upcast_tc(1) # attempt to upcast M
+           late_upcast_tc(0) # attempt to upcast N
+
+           if self.tensor_core and tc_opts.axes_exist[0]: # attempt to local N
+             for upc in [4,2]:
+               if self.full_shape[tc_opts.axes[0]] % upc == 0:
+                 self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[0], upc))
+                 break

+       return True
+     except KernelOptError:
+       return False
+
+   def apply_opt(self, opt:Opt, append_opt:bool=True):
+     check(not self.dont_use_locals or opt.op not in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals")
+
+     if opt.op is OptOps.TC:
+       check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine
+       check(opt.axis is not None and opt.amt is not None, "tensor core opts must have an axis and amt")
+       check((use_tensor_cores:=getenv("TC", 1)) == 2 or len(self.opts.tensor_cores) > 0, "must have tensor cores or TC=2")
+       check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), cast(int, opt.amt)), "no tensor core available")
+       self.applied_opts.append(opt)
+       return
+
+     axis = opt.real_axis(self)
+     check(axis < len(self.full_shape), "invalid axis")
+
+     if opt.amt is not None:
+       amt = opt.amt if opt.amt != 0 else self.full_shape[axis]
+       check(isinstance(amt, int) and amt != 1, "shift/padto of amt 1 or Node is meaningless")
+       if opt.op is not OptOps.PADTO: check(self.full_shape[axis] % amt == 0, "no longer valid shift")
+     else: amt = -1
+
+     if self.reduceop and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or (self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})):
+       acc_sz, upcast_idx = dt.base.itemsize if isinstance((dt:=self.reduceop.dtype), ImageDType) else dt.itemsize, self.shape_len-self.upcasted
+       upcast_sz = prod([a for a,b in zip(self.full_shape[upcast_idx:], self.sts[0].shape[upcast_idx:]) if a == b])
+       local_sz = prod(self.full_shape[self.first_reduce-self.local_dims:self.first_reduce+self.group_for_reduces])
+       smem_sz = amt*acc_sz*upcast_sz*local_sz
+       check(smem_sz <= self.opts.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.opts.shared_max}")
+
+     if opt.op is OptOps.LOCAL: # cyan
+       check(self.opts.has_local, "target does not support local")
+       check(axis < self.global_dims, "local is for globals")
+       self.shift_to(axis, amt, insert_before=self.first_reduce)
+       self.local_dims += 1
+     elif opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: # green
+       check(self.opts.has_local and self.opts.has_shared, "target does not support local or shared mem")
+       check(axis >= self.first_reduce + self.group_for_reduces and axis < self.shape_len-self.upcasted, "must be reduce axis to group")
+       check(not self.tensor_core, "can't group with tensor cores")
+       self.shift_to(axis, amt, top=(opt.op is OptOps.GROUPTOP), insert_before=self.first_reduce + self.group_for_reduces)
+       self.group_for_reduces += 1
+     elif opt.op is OptOps.UNROLL: # purple
+       check(axis < self.shape_len-self.upcasted, "can't upcasted already upcasted")
+       check(amt <= 32, "don't unroll more than 32")
+       # TODO: fix upcast_count to put purples before yellows. broken because of METAL tensor cores
+       #upcast_count = sum(x == y for x,y in zip(self.full_shape[-self.upcasted:], self.output_shape[-self.upcasted:])) if self.upcasted else 0
+       #self.shift_to(axis, amt, insert_before=None if upcast_count == 0 else self.shape_len-upcast_count)
+       if self.full_shape[axis] == amt and axis == self.first_reduce: self.local_dims += 1 # first_reduce will ++, so offset loss in simplify_ones
+       if self.full_shape[axis] == amt and axis < self.first_reduce+self.group_for_reduces: self.group_for_reduces -= 1 # fully unrolling a GROUP
+       self.shift_to(axis, amt, insert_before=None)
+       self.upcast()
+     elif opt.op is OptOps.UPCAST: # yellow
+       check(axis < self.first_reduce, "upcast is for non-reduce")
+       check(not(self.tensor_core and self.global_dims <= axis < self.global_dims+len(self.tensor_core.threads)), "can't upcast TC locals")
+       check(amt <= 8, "don't upcast more than 8")
+       self.shift_to(axis, amt, insert_before=None)
+       self.upcast()
+     elif opt.op is OptOps.UPCASTMID: # white
+       check(self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces != 0 and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce") # noqa: E501
+       axes = self.sts[0].unit_stride_axes()
+       check(len(axes) == 1, f"wrong number of stride 1 axis : {axes}")
+       check(axes[0] == axis, "wrong axis")
+       check(amt == 4, "don't upcast mid anything but 4")
+       self.shift_to(axis, amt, insert_before=self.first_reduce + self.group_for_reduces)
+       self.group_for_reduces += 1
+     elif opt.op is OptOps.NOLOCALS:
+       check(self.opts.has_local and not self.dont_use_locals, "NOLOCALS is meaningless if target does not support local or already not using locals")
+       check(self.local_dims == 0 and self.group_for_reduces == 0, "can't have no locals with locals")
+       self.dont_use_locals = True
+     elif opt.op is OptOps.PADTO:
+       check(not self.vars, "does not work with symbolic shape")
+       check(axis < self.shape_len - self.upcasted, "cannot pad upcasted")
+       # ok to pad SUM if all parent ops have f(0) = 0
+       if self.first_reduce <= axis:
+         check((r:=cast(LazyOp, self.reduceop)).op is ReduceOps.SUM and \
+               all(op.op not in UNSAFE_PAD_OPS for ops in r.src for op in ops.lazyops), "cannot pad")
+       padded = False
+       for i,st in enumerate(self.sts):
+         if self.sts[i].shape[axis] == 1: continue # reduced
+         check(self.sts[i].shape[axis] > amt//4, f"pad adds more than quadruple the work {self.sts[i].shape[axis]=} > {amt//4=}")
+         if (ru := round_up(cast(int, self.sts[i].shape[axis]), cast(int, amt)) - self.sts[i].shape[axis]):
+           # pad right seems to be faster
+           self.sts[i] = st.pad(((0,0),) * axis + ((0,ru),) + ((0,0),) * (len(st.shape)-axis-1))
+           padded = True
+       check(padded, "nothing was padded")
+
+     if append_opt: self.applied_opts.append(opt)
+     if self.simplify_ones() and self.tensor_core_opts:
+       self.tensor_core_opts.fix_axes(axis) # fix up axes in TC opts if required after simplify_ones()
+
+   def required_optimizations(self):
+     if self.bufs[0].dtype.__class__ is ImageDType:
+       unit_stride_axes_mul_4 = [i for i in self.sts[0].unit_stride_axes(ignore_valid=True) if self.sts[0].shape[i]%4 == 0]
+       assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[0]}"
+       if len(unit_stride_axes_mul_4) and all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes: # noqa: E501
+         self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
+
+   def hand_coded_optimizations(self):
+     self.required_optimizations()
+
+     # should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
+     MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
+     if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
+         self.reduceop is not None and self.reduceop.op is ReduceOps.SUM and len(self.full_shape) >= 2 and self.opts.has_shared and \
+         (mulop:=self.reduceop.src[0]).op is BinaryOps.MUL and mulop.src[0].op is BufferOps.LOAD and mulop.src[1].op is BufferOps.LOAD:
+       st0, st1 = self.sts[self.bufs.index(mulop.src[0].arg)], self.sts[self.bufs.index(mulop.src[1].arg)]
+       strides0, strides1 = st0.real_strides(), st1.real_strides()
+       def has_expanded_axis(shape, strides): return any(s > 1 and st == 0 for s,st in zip(shape,strides))
+       if strides0[self.first_reduce] == 1 and not (has_expanded_axis(st0.shape, strides0) and has_expanded_axis(st1.shape, strides1)):
+         for global_idx in range(self.global_dims):
+           if self.full_shape[self.first_reduce]%MV_THREADS_PER_ROW == 0 and self.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
+             if DEBUG >= 3:
+               print(f"MATVEC: {self.full_shape=} {self.first_reduce=} {strides0=} {MV_BLOCKSIZE=} {MV_THREADS_PER_ROW=} {MV_ROWS_PER_THREAD=}")
+             if MV_THREADS_PER_ROW > 1: self.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
+             if MV_BLOCKSIZE > 1: self.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
+             if MV_ROWS_PER_THREAD > 1: self.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD))
+             return
+
+     if self.opts.has_local and self.opts.has_shared and all_int(self.sts[0].shape[:self.first_reduce]):
+       # are we grouping? (requires local shape support)
+       if not self.float4_axis(0) and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048: # noqa: E501
+         # TODO: use 1024 if it's allowed in a smarter way
+         for sz in (([256, 16]) if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
+           if all(st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts):
+             try: # may fail due to excessive smem usage
+               self.apply_opt(Opt(OptOps.GROUPTOP, 0, sz))
+               break
+             except KernelOptError: pass
+
+       # are we upcasting in mid reduce? (only for images)
+       if self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1: # noqa: E501
+         axes = self.sts[0].unit_stride_axes()
+         assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
+         if self.sts[0].shape[axes[0]]%4 == 0:
+           self.apply_opt(Opt(OptOps.UPCASTMID, axes[0], 4))
+
+     # upcast float4 images
+     for buf_index,buf in enumerate(self.bufs):
+       unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) if self.sts[buf_index].shape[i]%4 == 0]
+       if buf.dtype.__class__ is ImageDType:
+         #assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}"
+         if len(unit_stride_axes_mul_4) and all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes: # noqa: E501
+           if unit_stride_axes_mul_4[0] < self.first_reduce:
+             self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
+           else:
+             self.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-self.first_reduce, 4))
+
+     # no more opt if we are grouping
+     if self.group_for_reduces: return
+
+     # **** below this line need to be optional and benchmarked ****
+
+     # TODO: doing extra upcasts with images doesn't work for some reason (maybe has to do with to_image_idx)
+     # to trigger the above bug, remove prod(self.full_shape[self.shape_len - self.upcasted:]) from the below
+     # expression and run test/test_ops.py with IMAGE=2
+     # if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
+     # this can be made much smarter
+     to_upcast: List[int] = []
+     # upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first)
+     for axis in range(self.first_reduce):
+       # we might want to be able to split axes that are masked, or refuse to merge them in simplify_merge_adjacent
+       # for now skip upcasting here if there is a symbolic axis
+       if isinstance(self.full_shape[axis], int) and self.full_shape[axis] <= 7 and any(st.axis_is_masked(axis) for st in self.sts) and \
+         prod(self.full_shape[self.shape_len - self.upcasted:]) * prod(self.full_shape[j] for j in to_upcast) * self.full_shape[axis] <= 7 * 7:
+         if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
+         to_upcast.append(axis)
+     for axis in to_upcast[::-1]: self.apply_opt(Opt(OptOps.UPCAST, axis, 0))
+
+     # potentially do more upcasts of non reduce axes based on a heuristic
+     upcasted_axis = set()
+     while prod(self.sts[0].shape[:self.first_reduce]) >= 1024:
+       xb_choices = []
+       for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce
+         # if we haven't upcasted it, it's not symbolic, it mods, and buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
+         if axis not in upcasted_axis and isinstance(self.full_shape[axis], int) and self.full_shape[axis]%upcast_amount == 0 and any(st.views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index, st in enumerate(self.sts)): # noqa: E501
+           xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in self.sts), sum(st.views[-1].strides[axis] for st in self.sts), axis, upcast_amount)) # noqa: E501
+       if xb_choices:
+         xb_choices = sorted(xb_choices)
+         if DEBUG >= 4: print(f"float4 merging axis : {xb_choices}")
+         self.apply_opt(Opt(OptOps.UPCAST, xb_choices[0][2], xb_choices[0][3]))
+         upcasted_axis.add(xb_choices[0][2])
+       else: break
+
+     # if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast.
+     if self.first_reduce < (self.shape_len-self.upcasted) and (len(list(self.shape_offsets(self.full_buf_index))) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))) and (self.upcasted == 0 or prod(self.full_shape[-self.upcasted:]) < 64): # noqa: E501
+       if (s:=self.full_unupcasted_shape[-1]) <= 32 and isinstance(s, int): # NOTE: cannot loop unroll symbolic axis
+         self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0))
+         # if it's small, upcast a second reduce dimension too
+         if self.first_reduce < (self.shape_len-self.upcasted) and s <= 3 and (s2:=self.full_unupcasted_shape[-1]) <= 3 and isinstance(s2, int):
+           self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0))
+       else:
+         for splits in [4]:
+           if self.full_unupcasted_shape[-1]%splits == 0:
+             self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, splits))
+             break
+
+     # if nothing at all is upcasted and it's easy to, do an upcast
+     # TODO: this is breaking the tests
+     for splits in [4]:
+       if self.upcasted == 0 and self.full_unupcasted_shape and self.full_unupcasted_shape[-1] % splits == 0:
+         self.apply_opt(Opt(OptOps.UPCAST, len(self.full_unupcasted_shape)-1, splits))
+
+     # **** local groups ****

+     if self.opts.has_local:
+       if getenv("NOLOCALS") and self.local_dims == 0 and not self.group_for_reduces:
+         self.apply_opt(Opt(OptOps.NOLOCALS))
+       else:
+         # prioritize making expand axes local
+         local_axis_ranking = [(any(self.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(self.sts))), axis) for axis in range(len(self.full_shape[:self.first_reduce]))] # noqa: E501
+         to_local: List[Tuple[int, int]] = []
+         for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])):
+           local_size = prod(sz for _, sz in to_local)
+           local_sz: Optional[int] = next((x for x in ([32] * (axis == 0) + [16, 8, 4, 3, 2]) if self.full_shape[axis] % x == 0 and local_size * x <= 128), None) # noqa: E501
+           if local_sz is not None: to_local.append((axis, local_sz))
+         deleted_shape = 0
+         for axis, local_sz in sorted(to_local[:3]):
+           axis = axis - deleted_shape
+           will_delete_shape = local_sz == self.full_shape[axis]
+           self.apply_opt(Opt(OptOps.LOCAL, axis, local_sz))
+           if will_delete_shape: deleted_shape += 1
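
Taken together, the kernel.py hunks replace the loose optimizer flags of 0.7.0 (use_tensor_cores, exclude_local_upcast, reverse_upcast_dir) with a declarative log of Opt(op, axis, amt) records that apply_opt() validates through check()/KernelOptError. A rough sketch of how that API is driven, assuming a Kernel instance k obtained from the scheduler (constructing one by hand is out of scope here):

from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError

def apply_or_skip(k: Kernel, opts: list) -> Kernel:
  # apply_opt() re-checks every Opt against the current shape; invalid ones raise KernelOptError
  for opt in opts:
    try: k.apply_opt(opt)
    except KernelOptError as e: print(f"skipped {opt}: {e}")
  return k

# e.g. upcast global axis 0 by 4, then unroll the first reduce axis by 4:
#   apply_or_skip(k, [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)])
# k.applied_opts keeps the accepted log and k.colored_shape() shows the resulting loop structure.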