tinygrad 0.10.0__py3-none-any.whl → 0.10.1__py3-none-any.whl

Files changed (72)
  1. tinygrad/codegen/kernel.py +114 -172
  2. tinygrad/codegen/linearize.py +211 -81
  3. tinygrad/codegen/lowerer.py +30 -35
  4. tinygrad/codegen/{uopgraph.py → rewriter.py} +69 -59
  5. tinygrad/codegen/transcendental.py +12 -13
  6. tinygrad/device.py +170 -47
  7. tinygrad/dtype.py +28 -26
  8. tinygrad/engine/jit.py +80 -63
  9. tinygrad/engine/memory.py +4 -5
  10. tinygrad/engine/multi.py +162 -0
  11. tinygrad/engine/realize.py +58 -107
  12. tinygrad/engine/schedule.py +381 -314
  13. tinygrad/engine/search.py +40 -44
  14. tinygrad/gradient.py +70 -0
  15. tinygrad/helpers.py +77 -58
  16. tinygrad/nn/__init__.py +30 -32
  17. tinygrad/nn/datasets.py +1 -2
  18. tinygrad/nn/optim.py +22 -26
  19. tinygrad/nn/state.py +89 -64
  20. tinygrad/ops.py +562 -446
  21. tinygrad/renderer/__init__.py +79 -36
  22. tinygrad/renderer/cstyle.py +70 -84
  23. tinygrad/renderer/llvmir.py +32 -20
  24. tinygrad/renderer/ptx.py +79 -99
  25. tinygrad/renderer/wgsl.py +87 -0
  26. tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  27. tinygrad/runtime/autogen/comgr.py +2 -0
  28. tinygrad/runtime/autogen/kfd.py +4 -3
  29. tinygrad/runtime/autogen/kgsl.py +1 -1
  30. tinygrad/runtime/autogen/libpciaccess.py +2023 -0
  31. tinygrad/runtime/autogen/llvm.py +11379 -0
  32. tinygrad/runtime/autogen/vfio.py +891 -0
  33. tinygrad/runtime/graph/cuda.py +8 -9
  34. tinygrad/runtime/graph/hcq.py +84 -79
  35. tinygrad/runtime/graph/metal.py +19 -21
  36. tinygrad/runtime/ops_amd.py +488 -327
  37. tinygrad/runtime/ops_clang.py +15 -28
  38. tinygrad/runtime/ops_cloud.py +34 -34
  39. tinygrad/runtime/ops_cuda.py +30 -27
  40. tinygrad/runtime/ops_disk.py +62 -63
  41. tinygrad/runtime/ops_dsp.py +129 -38
  42. tinygrad/runtime/ops_gpu.py +30 -30
  43. tinygrad/runtime/ops_hip.py +29 -31
  44. tinygrad/runtime/ops_llvm.py +45 -40
  45. tinygrad/runtime/ops_metal.py +93 -73
  46. tinygrad/runtime/ops_npy.py +2 -2
  47. tinygrad/runtime/ops_nv.py +232 -270
  48. tinygrad/runtime/ops_python.py +51 -46
  49. tinygrad/runtime/ops_qcom.py +129 -157
  50. tinygrad/runtime/ops_webgpu.py +63 -0
  51. tinygrad/runtime/support/allocator.py +94 -0
  52. tinygrad/runtime/support/am/__init__.py +0 -0
  53. tinygrad/runtime/support/am/amdev.py +384 -0
  54. tinygrad/runtime/support/am/ip.py +463 -0
  55. tinygrad/runtime/support/compiler_cuda.py +4 -2
  56. tinygrad/runtime/support/elf.py +26 -4
  57. tinygrad/runtime/support/hcq.py +254 -324
  58. tinygrad/runtime/support/llvm.py +32 -0
  59. tinygrad/shape/shapetracker.py +84 -53
  60. tinygrad/shape/view.py +103 -138
  61. tinygrad/spec.py +154 -0
  62. tinygrad/tensor.py +744 -496
  63. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/METADATA +32 -21
  64. tinygrad-0.10.1.dist-info/RECORD +86 -0
  65. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/WHEEL +1 -1
  66. tinygrad/engine/lazy.py +0 -228
  67. tinygrad/function.py +0 -212
  68. tinygrad/multi.py +0 -177
  69. tinygrad/runtime/graph/clang.py +0 -39
  70. tinygrad-0.10.0.dist-info/RECORD +0 -77
  71. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/LICENSE +0 -0
  72. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/top_level.txt +0 -0
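Note for downstream users: three import paths visible in this diff changed, which the first hunk of kernel.py below reflects. A minimal sketch of the 0.10.1 replacements for imports that existed in 0.10.0 (limited to names shown in this diff; other modules moved as well, e.g. multi.py into engine/):

    # 0.10.0 imports (no longer valid):
    #   from tinygrad.codegen.uopgraph import full_graph_rewrite
    #   from tinygrad.ops import type_verify
    #   from tinygrad.renderer import Program
    # 0.10.1 equivalents, per the import hunk below:
    from tinygrad.codegen.rewriter import full_graph_rewrite
    from tinygrad.spec import type_verify, shape_spec
    from tinygrad.renderer import ProgramSpec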
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -1,25 +1,25 @@
 from __future__ import annotations
-import itertools, functools
+import itertools, functools, math
 from dataclasses import dataclass
 from collections import defaultdict
-from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict, Callable, Sequence
+from typing import Optional, cast, Final, Callable, Sequence
 from enum import Enum, auto
 
-from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, PatternMatcher, can_pad, print_uops, type_verify, resolve, Variable, sint, \
-  graph_rewrite, track_rewrites, UPat
+from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, can_pad, resolve, Variable, sint, graph_rewrite, track_rewrites, view_left, print_uops
+from tinygrad.spec import type_verify, shape_spec
 from tinygrad.device import Device
-from tinygrad.renderer import Renderer, TensorCore, Program
+from tinygrad.renderer import Renderer, TensorCore, ProgramSpec
 from tinygrad.dtype import ImageDType
-from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put, unwrap
-from tinygrad.helpers import DEBUG, TC_OPT, USE_TC, AMX
+from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put, unwrap, ContextVar
+from tinygrad.helpers import DEBUG, TC_SELECT, TC_OPT, USE_TC, AMX, CAPTURE_PROCESS_REPLAY
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import strides_for_shape
 from tinygrad.codegen.linearize import linearize_uop
-from tinygrad.codegen.uopgraph import full_graph_rewrite
+from tinygrad.codegen.rewriter import full_graph_rewrite
 from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index, get_contraction
 
 class OptOps(Enum):
-  TC = auto(); UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
+  TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
   GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); SWAP = auto() # noqa: E702
   def __lt__(self, x:OptOps): return self.value < x.value
 
@@ -32,8 +32,8 @@ def check(cond:bool, msg:str=""):
32
32
  class Opt:
33
33
  op: OptOps
34
34
  axis: Optional[int] = None
35
- amt: Optional[int] = None
36
- def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, amt={self.amt})"
35
+ arg: Optional[int | tuple] = None
36
+ def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, arg={self.arg})"
37
37
  def real_axis(self, k:Kernel):
38
38
  if self.axis is None: return -1
39
39
  if self.op is OptOps.UNROLL: return k.first_reduce+self.axis
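The Opt field rename amt → arg (which now also accepts a tuple, used by TC opts) breaks keyword construction; positional construction is unchanged. A small usage sketch against 0.10.1:

    from tinygrad.codegen.kernel import Opt, OptOps

    opt = Opt(OptOps.UPCAST, 0, 4)              # positional: same in both versions
    opt = Opt(op=OptOps.UPCAST, axis=0, arg=4)  # keyword: was amt=4 in 0.10.0
    print(opt)  # Opt(op=OptOps.UPCAST, axis=0, arg=4)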
@@ -42,10 +42,10 @@ class Opt:
 
 @dataclass
 class TensorCoreOptions:
-  axes: Tuple[int, ...] # the location of the original N and M axes if still in the shape
-  axes_exist: Tuple[bool, ...] # true if the original N and M axes are still in the shape
-  axis_pads: Tuple[Tuple[int, int], ...]
-  def fix_axes(self, removed_axis:int): # adjust the TC axes if necesssary when a dimension is removed
+  axes: tuple[int, ...] # the location of the original N and M axes if still in the shape
+  axes_exist: tuple[bool, ...] # true if the original N and M axes are still in the shape
+  axis_pads: tuple[tuple[int, int], ...]
+  def fix_axes(self, removed_axis:int): # adjust the TC axes if necessary when a dimension is removed
     axes, axes_exist = list(self.axes), list(self.axes_exist)
     for tc_dim in [i for i in range(2) if axes_exist[i]]:
       if removed_axis < axes[tc_dim]: axes[tc_dim] -= 1
@@ -57,32 +57,28 @@ class Kernel:
     if ast.op is Ops.SINK: self.ast = ast
 
     self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer
-    try: uop_sts_map = verify_ast(self.ast)
-    except AssertionError as e:
-      print("INVALID AST")
-      print(self.ast)
-      raise e
+    # verify AST matches the spec
+    if __debug__: type_verify(list(self.ast.toposort), shape_spec)
 
-    @functools.lru_cache(None)
-    def ordered_parents(op:UOp) -> List[UOp]: return dedup([item for x in op.src for item in ordered_parents(x)] + [op])
-    self.reduceops = dedup([x for x in ordered_parents(self.ast) if x.op is Ops.REDUCE_AXIS])
+    self.reduceops = [x for x in self.ast.toposort if x.op is Ops.REDUCE_AXIS]
 
-    self.vars: List[Variable] = self.ast.variables()
-    self.bufs: List[UOp] = [x for x in self.ast.parents if x.op in GroupOp.Buffer]
+    self.vars: list[Variable] = self.ast.variables()
+    # NOTE: this requires a specific order with the [::-1], this is likely a bug
+    self.bufs: list[UOp] = [x for x in self.ast.toposort if x.op in GroupOp.Buffer][::-1]
 
     # get earlybufs, before any reduceops
-    earlybufs: List[UOp] = [x for reduceop in self.reduceops for x in reduceop.parents if x.op in GroupOp.Buffer]
+    earlybufs: list[UOp] = [x for reduceop in self.reduceops for x in reduceop.src[0].toposort if x.op in GroupOp.Buffer]
     self.full_buf_index: int = self.bufs.index(earlybufs[0]) if earlybufs else 0
     # NOTE: full_shape can be wrong if there's a tree of reduces
 
     # create new shapetrackers inside this kernel, we will permute them
-    self.sts: List[ShapeTracker] = [x.st_arg for x in self.bufs]
+    self.sts: list[ShapeTracker] = [x.st_arg for x in self.bufs]
 
     # add the shapetrackers for each reduce
     # we use this to track which axes are reduced in each reduce
     for x in self.reduceops:
-      self.sts.append(uop_sts_map[x])
-      self.sts.append(uop_sts_map[x.src[0]])
+      self.sts.append(unwrap(x.st))
+      self.sts.append(unwrap(x.src[0].st))
 
     # move all reduce axes to the end
     reduce = list(enumerate(zip(self.full_shape, self.output_shape)))
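The removed ordered_parents helper deduplicated a recursive source walk; UOp.toposort (a property in 0.10.1) already yields each node exactly once with sources before consumers, which is why the dedup and the lru_cache could be dropped. A sketch of the equivalence, assuming the 0.10.1 UOp API:

    from tinygrad.ops import UOp

    def ordered_parents(op: UOp) -> list[UOp]:
      # same contract as the removed helper: every transitive source plus op
      # itself, each exactly once, dependencies before dependents
      return list(op.toposort)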
@@ -90,15 +86,13 @@ class Kernel:
     self.reshape_and_permute(None, permute)
 
     # parameters for optimization
-    self.applied_opts: List[Opt] = []
+    self.applied_opts: list[Opt] = []
     self.group_for_reduces: int = 0
     self.upcasted: int = 0
     self.local_dims: int = 0
     self.tensor_core: Optional[TensorCore] = None
     self.tensor_core_opts: Optional[TensorCoreOptions] = None
     self.use_tensor_cores: int = 0
-    # the local aliased buffers for A and B
-    self.bufs_for_tensor_core: Dict[UOp, Tuple[int, int]] = {}
     self.dont_use_locals: bool = False
 
     # group simplifies
@@ -112,25 +106,23 @@ class Kernel:
     ret.opts, ret.ast = self.opts, self.ast
 
     # things downstream of the AST
-    ret.reduceops, ret.vars, ret.bufs, ret.full_buf_index = \
-      self.reduceops, self.vars, self.bufs, self.full_buf_index
+    ret.reduceops, ret.vars, ret.bufs, ret.full_buf_index = self.reduceops, self.vars, self.bufs, self.full_buf_index
     ret.sts = self.sts[:len(ret.bufs)+len(ret.reduceops)*2] # NOTE: must redo the local buffers with TC in beam
 
     # parameters for optimizations
     ret.applied_opts, ret.group_for_reduces, ret.upcasted, ret.local_dims, ret.dont_use_locals = \
       self.applied_opts[:], self.group_for_reduces, self.upcasted, self.local_dims, self.dont_use_locals
-    ret.tensor_core, ret.tensor_core_opts, ret.bufs_for_tensor_core, ret.use_tensor_cores = \
-      self.tensor_core, self.tensor_core_opts, self.bufs_for_tensor_core, self.use_tensor_cores
+    ret.tensor_core, ret.tensor_core_opts, ret.use_tensor_cores = self.tensor_core, self.tensor_core_opts, self.use_tensor_cores
 
     return ret
 
   @property
-  def membufs(self) -> List[UOp]: return dedup([x.src[0] for x in self.bufs if x.op in {Ops.LOAD, Ops.STORE}])
+  def membufs(self) -> list[UOp]: return dedup([x.src[0] for x in self.bufs if x.op in {Ops.LOAD, Ops.STORE}])
 
   # TODO: these need more tests or it might silently be no-op
   def float4_axis(self, i:int): return [x-self.first_upcast for x in self.sts[i].unit_stride_axes() if x >= self.first_upcast and self.sts[i].shape[x]%4 == 0] # noqa: E501
 
-  def upcasted_axis(self, i:int) -> List[Tuple[int, Optional[sint], bool]]:
+  def upcasted_axis(self, i:int) -> list[tuple[int, Optional[sint], bool]]:
     upcasted_shape, upcasted_stride = self.sts[i].shape[self.first_upcast:], self.sts[i].real_strides()[self.first_upcast:]
     assert all_int(upcasted_shape), f"cannot upcast a symbolic amount {upcasted_shape=}"
     return list(zip(upcasted_shape, upcasted_stride,
@@ -144,24 +136,20 @@ class Kernel:
   def first_upcast(self) -> int: return self.shape_len-self.upcasted
 
   @property
-  def reduceop(self) -> Optional[UOp]: return self.reduceops[0] if len(self.reduceops) > 0 else None
+  def reduceop(self) -> UOp|None: return self.reduceops[0] if len(self.reduceops) > 0 else None
 
   @property
-  def output_shape(self) -> Tuple[sint, ...]: return self.sts[0].shape
+  def output_shape(self) -> tuple[sint, ...]: return self.sts[0].shape
 
   @property
-  def full_shape(self) -> Tuple[sint, ...]: return self.sts[self.full_buf_index].shape
+  def full_shape(self) -> tuple[sint, ...]: return self.sts[self.full_buf_index].shape
 
   @property
-  def full_unupcasted_shape(self) -> Tuple[sint, ...]: return self.full_shape[:self.first_upcast]
+  def full_unupcasted_shape(self) -> tuple[sint, ...]: return self.full_shape[:self.first_upcast]
 
   @property
   def shape_len(self) -> int: return len(self.sts[0].shape)
 
-  @property
-  def upcast_in_mid_reduce_axes(self) -> List[int]:
-    return [j for j in range(self.first_reduce, self.first_reduce+self.group_for_reduces) if self.full_shape[j] == self.sts[0].shape[j]]
-
   @property
   def global_dims(self) -> int: return self.first_reduce-self.local_dims
 
@@ -170,18 +158,17 @@ class Kernel:
   # cyan -- local dims (warp ones first)
   #  *** self.first_reduce
   # green -- reduce-local dims
-  # white -- reduce-late upcasted dim (self.upcast_in_mid_reduce_axes)
   # red -- reduce loops
   #  *** self.upcasted
   # purple -- reduce upcasted
   # yellow -- normal upcasted dimensions
-  def colors(self) -> List[str]:
+  def colors(self) -> list[str]:
     # first non local non reduce dims are global (blue)
     colors = ["blue"] * self.global_dims if not self.dont_use_locals else ["BLUE"] * self.global_dims
     # after global are local_dims; warp ones used in tensor cores must be closest to first_reduce (cyan)
     colors += ["cyan"] * self.local_dims
-    # between first_reduce and first_reduce + group_for_reduces, they are either upcast mid reduce (white), or late upcasted (green)
-    colors += ["white" if i in self.upcast_in_mid_reduce_axes else "green" for i in range(self.first_reduce, self.first_reduce + self.group_for_reduces)] # noqa: E501
+    # between first_reduce and first_reduce + group_for_reduces, they are late upcasted (green)
+    colors += ["green"] * self.group_for_reduces
     # between first_reduce + group_for_reduces and upcasted, they are reduce (red)
     colors += ["red"] * (self.first_upcast - (self.first_reduce + self.group_for_reduces))
     # upcasted dimensions are reduce (magenta) or normal (yellow)
@@ -198,7 +185,7 @@ class Kernel:
   # ******************** base simplifiers ********************
 
   # apply reshape and permute to all shapetrackers
-  def reshape_and_permute(self, new_shape_fxn:Optional[Callable[[Tuple[sint, ...]], Sequence[sint]]], axis:Optional[Sequence[int]]):
+  def reshape_and_permute(self, new_shape_fxn:Optional[Callable[[tuple[sint, ...]], Sequence[sint]]], axis:Optional[Sequence[int]]):
     def reshape(st:ShapeTracker): return st.reshape(tuple(new_shape_fxn(st.shape))) if new_shape_fxn is not None else st
     def permute(st:ShapeTracker): return st.permute(tuple(axis)) if axis is not None else st
     self.sts = [permute(reshape(st)) for st in self.sts]
@@ -240,7 +227,7 @@ class Kernel:
     if isinstance(self.membufs[0].dtype, ImageDType):
       base_shape = self.membufs[0].dtype.shape
       if shape_idx_groups := get_contraction(self.output_shape, base_shape):
-        special_strides: Tuple[sint, ...] = tuple()
+        special_strides: tuple[sint, ...] = tuple()
         for i,g in enumerate(shape_idx_groups):
           shape_piece = tuple(self.output_shape[x] for x in g)
           assert prod(shape_piece) == base_shape[i], f"get_contraction was wrong? {shape_piece} != {base_shape[i]}"
@@ -298,37 +285,34 @@ class Kernel:
     s0, s1, s2 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0], axis_choices[-(axis+1)][2] # s0 is n, s1 is m, s2 is k
     axis_pads = tuple((x, tc.dims[i]) for i, x in enumerate([s0, s1, s2]) if resolve(self.full_shape[x]%tc.dims[i] != 0))
     if axis_pads and (opt_level < 2): return None
-    self.bufs_for_tensor_core[reduceop] = (buf0, buf1)
     if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
     return TensorCoreOptions(axes=(s0, s1, s2), axes_exist=(True, True), axis_pads=axis_pads)
 
-  def _apply_tc_opt(self, use_tensor_cores:int, axis:int, opt_level:int) -> bool:
+  def _apply_tc_opt(self, use_tensor_cores:int, axis:int, tc_select:int, opt_level:int) -> bool:
     if use_tensor_cores and self.reduceop is not None and self.reduceop.arg[0] is Ops.ADD:
-      for tc in self.opts.tensor_cores:
+      tensor_cores = self.opts.tensor_cores if tc_select == -1 else [self.opts.tensor_cores[tc_select]]
+      for tc in tensor_cores:
         tensor_core_opts = [self._create_tc_opts(reduceop, tc, axis, opt_level) for reduceop in self.reduceops]
         # can only fuse reduces with the same tc options
         assert all_same(tensor_core_opts)
         if tensor_core_opts[0] is None: continue
-        # tensor core -- unroll the reduce dim, upcast input and local the correct thread pattern
        self.tensor_core_opts = tc_opts = tensor_core_opts[0]
 
        # attempt to pad the tensor axes that require it
        try:
          for axis, dim in tc_opts.axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
        except KernelOptError: continue
-        for tc_dim, amt in tc.reduce_axes: self.apply_opt(Opt(OptOps.UNROLL,tc_opts.axes[2]-self.first_reduce,amt), append_opt=False)
-        for opt in tc.opts_seq:
-          if opt == "UP":
-            for tc_dim, amt in tc.early_upcast_axes: self.apply_opt(Opt(OptOps.UPCAST,tc_opts.axes[tc_dim],amt), append_opt=False)
-          elif opt == "LC":
-            for tc_dim, amt in tc.threads: self.apply_opt(Opt(OptOps.LOCAL,tc_opts.axes[tc_dim],amt), append_opt=False)
+        # tensor core -- unroll the reduce dim (K), upcast and local the inner and outer dims (N, M)
+        for dim, amt in tc.get_reduce_axes(): self.apply_opt(Opt(OptOps.UNROLL, tc_opts.axes[2]-self.first_reduce, amt), append_opt=False)
+        for opt in tc.opts: self.apply_opt(Opt({"u":OptOps.UPCAST, "l":OptOps.LOCAL}[opt[0]], tc_opts.axes[int(opt[1])], 2), append_opt=False)
        self.tensor_core = tc
        self.use_tensor_cores = use_tensor_cores # TC=2 will do the shape ops without the WMMA
        return True
    return False
 
-  def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[List[Opt]]=None, axis:int=0, tc_opt:Optional[int]=None) -> bool:
-    """ Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, return true, otherwise return false.
+  def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[list[Opt]]=None, axis:int=0, tc_select:Optional[int]=None,
+                         tc_opt:Optional[int]=None) -> bool:
+    """ Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, return true, otherwise return false.
     Tensor cores are optimized instructions that matrix multiply-accumulate across a wave of threads: D(M, N) = A(M, K) * B(K, N) + C(M, N).
 
     Keyword arguments:
@@ -337,15 +321,19 @@ class Kernel:
       1: enable tensor cores
       2: apply tensor core shape but don't use UOp.WMMA
     extra_opts -- additional Opt's to apply after the tensor core instead of the hand-coded additional Opt's (default None)
+    tc_select -- specifies which tensor core(s) to use for optimization (default -1)
+      -1: iterates through all available tensor cores in order and uses the first one that matches the requirements (dims and dtypes)
+      [0-N]: uses only the n'th tensor core available; useful for search
     tc_opt -- controls which kinds of kernels may be eligible for tensor cores application (default 2 during BEAM, 0 otherwise)
       0: applies to only kernels with a single reduce axis and direct UOps.LOAD into Ops.MUL
       1: allows kernels with multiple reduce axes and also multiplication of UOps.CAST'd buffers
       2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by applying padding those axes as needed
     """
+    if tc_select is None: tc_select = TC_SELECT.value
     if tc_opt is None: tc_opt = TC_OPT.value
     if not self.opts.tensor_cores and use_tensor_cores != 2: return False
     try: # check TC first and apply hand-coded opts if successful
-      self.apply_opt(Opt(OptOps.TC, axis, tc_opt))
+      self.apply_opt(Opt(OptOps.TC, axis, (tc_select, tc_opt)))
 
       if (tc_opts:=self.tensor_core_opts) is not None:
         if extra_opts is not None:
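With tc_select in the signature, a tensor core opt is now recorded as Opt(OptOps.TC, axis, (tc_select, tc_opt)) rather than Opt(OptOps.TC, axis, tc_opt). A hedged usage sketch (k stands for a Kernel built from some matmul AST; the TC_SELECT and TC_OPT context vars supply the defaults when these arguments are None):

    # pin the search to the renderer's second tensor core, allow padding (tc_opt=2)
    if k.apply_tensor_cores(use_tensor_cores=1, tc_select=1, tc_opt=2):
      # the first applied opt is the TC opt itself; arg packs (tc_select, tc_opt)
      print(k.applied_opts[0])  # Opt(op=OptOps.TC, axis=0, arg=(1, 2))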
@@ -364,24 +352,28 @@ class Kernel:
     return False
 
   def apply_opt(self, opt:Opt, append_opt:bool=True):
-    if self.dont_use_locals: check(opt.op not in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals")
+    if self.dont_use_locals: check(opt.op not in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP}, "not using locals")
 
     if opt.op is OptOps.TC:
       check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine
-      check(opt.axis is not None and opt.amt is not None, "tensor core opts must have an axis and amt")
       check((use_tensor_cores:=USE_TC.value) == 2 or len(self.opts.tensor_cores) > 0, "must have tensor cores or TC=2")
-      check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), cast(int, opt.amt)), "no tensor core available")
+      check(opt.axis is not None, "tensor core opts must have an axis")
+      check(opt.arg is not None and isinstance(opt.arg, tuple) and len(opt.arg) == 2, "tensor core opts must have tc_select and tc_opt")
+      check(-1 <= (tc_select:=cast(tuple, opt.arg)[0]) < len(self.opts.tensor_cores), "tensor core opts must have valid tc_select")
+      check(0 <= (tc_opt:=cast(tuple, opt.arg)[1]) <= 2, "tensor core opts must have valid tc_opt")
+      check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), tc_select, tc_opt), "no tensor core available")
       self.applied_opts.append(opt)
       return
 
     axis = opt.real_axis(self)
     check(axis < len(self.full_shape), "invalid axis")
 
-    if opt.op is OptOps.SWAP: amt = cast(int, opt.amt) # amt is an axis in the SWAPs
-    elif opt.amt is not None:
-      amt = opt.amt if opt.amt != 0 else self.full_shape[axis]
-      check(isinstance(amt, int) and amt != 1, "shift/padto of amt 1 or Node is meaningless")
-      if opt.op is not OptOps.PADTO: check(self.full_shape[axis] % amt == 0, "no longer valid shift")
+    if opt.op is OptOps.SWAP: amt = cast(int, opt.arg) # arg is an axis in the SWAPs
+    elif opt.arg is not None:
+      check(isinstance(opt.arg, int), "arg should be int")
+      amt = arg if (arg:=cast(int, opt.arg)) != 0 else self.full_shape[axis]
+      check(isinstance(amt, int) and amt != 1, f"shift/padto of {amt=}, 1 or symbolic amount is meaningless")
+      if opt.op is not OptOps.PADTO: check(self.full_shape[axis] % amt == 0, f"no longer valid shift {self.full_shape[axis]=}, {amt=}")
     else: amt = -1
 
     if self.reduceop is not None and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or \
@@ -416,18 +408,10 @@ class Kernel:
       self.upcast()
     elif opt.op is OptOps.UPCAST: # yellow
       check(axis < self.first_reduce, "upcast is for non-reduce")
-      check(not (self.tensor_core and self.global_dims <= axis < self.global_dims+len(self.tensor_core.threads)), "can't upcast TC locals")
+      check(not (self.tensor_core and self.global_dims <= axis < self.global_dims+len(self.tensor_core.get_local_axes())), "can't upcast TC locals")
       check(amt <= 16, "don't upcast more than 16")
       self.shift_to(axis, amt, insert_before=None)
       self.upcast()
-    elif opt.op is OptOps.UPCASTMID: # white
-      check(self.bufs[0].src[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces != 0 and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce") # noqa: E501
-      axes = self.sts[0].unit_stride_axes()
-      check(len(axes) == 1, f"wrong number of stride 1 axis : {axes}")
-      check(axes[0] == axis, "wrong axis")
-      check(amt == 4, "don't upcast mid anything but 4")
-      self.shift_to(axis, amt, insert_before=self.first_reduce + self.group_for_reduces)
-      self.group_for_reduces += 1
     elif opt.op is OptOps.NOLOCALS:
       check(self.opts.has_local and not self.dont_use_locals, "NOLOCALS is meaningless if target does not support local or already not using locals")
       check(self.local_dims == 0 and self.group_for_reduces == 0, "can't have no locals with locals")
@@ -441,7 +425,7 @@ class Kernel:
       check(not self.vars, "does not work with symbolic shape")
       check(axis < self.first_upcast, "cannot pad upcasted")
       # ok to pad SUM if all parent ALU ops have f(0) = 0
-      if (r:=self.reduceop) is not None and self.first_reduce <= axis: check(r.arg[0] is Ops.ADD and can_pad(r), f"cannot pad {r}")
+      if (r:=self.reduceop) is not None and self.first_reduce <= axis: check(r.arg[0] is Ops.ADD and can_pad(r, {}, {}), f"cannot pad {r}")
       padded = False
       for i,st in enumerate(self.sts):
         if (s:=st.shape[axis]) == 1: continue # reduced
@@ -460,8 +444,7 @@ class Kernel:
     if isinstance(self.membufs[0].dtype, ImageDType):
      unit_stride_axes_mul_4 = [i for i in self.sts[0].unit_stride_axes(ignore_valid=True) if self.sts[0].shape[i]%4 == 0]
      assert unit_stride_axes_mul_4, f"needs a unit stride axis in {self.bufs[0]}"
-      if all(x < self.first_upcast for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes:
-        self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
+      if all(x < self.first_upcast for x in unit_stride_axes_mul_4): self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
     return self
 
   def hand_coded_optimizations(self) -> Kernel:
  def hand_coded_optimizations(self) -> Kernel:
@@ -496,19 +479,12 @@ class Kernel:
496
479
  break
497
480
  except KernelOptError: pass
498
481
 
499
- # are we upcasting in mid reduce? (only for images)
500
- if self.bufs[0].src[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1: # noqa: E501
501
- axes = self.sts[0].unit_stride_axes()
502
- assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
503
- if self.sts[0].shape[axes[0]]%4 == 0:
504
- self.apply_opt(Opt(OptOps.UPCASTMID, axes[0], 4))
505
-
506
482
  # upcast float4 images
507
483
  for buf_index,buf in enumerate(self.bufs):
508
484
  unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) if self.sts[buf_index].shape[i]%4 == 0]
509
485
  if buf.src[0].dtype.__class__ is ImageDType:
510
486
  #assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}"
511
- if len(unit_stride_axes_mul_4) and all(x < self.first_upcast for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes: # noqa: E501
487
+ if len(unit_stride_axes_mul_4) and all(x < self.first_upcast for x in unit_stride_axes_mul_4):
512
488
  if unit_stride_axes_mul_4[0] < self.first_reduce:
513
489
  self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
514
490
  else:
@@ -524,7 +500,7 @@ class Kernel:
    #   expression and run test/test_ops.py with IMAGE=2
    # if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
    # this can be made much smarter
-    to_upcast: List[int] = []
+    to_upcast: list[int] = []
    # upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first)
    for axis in range(self.first_reduce):
      # we might want to be able to split axes that are masked, or refuse to merge them in simplify_merge_adjacent
@@ -577,7 +553,7 @@ class Kernel:
    else:
      # prioritize making expand axes local
      local_axis_ranking = [(any(self.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(self.sts))), axis) for axis in range(len(self.full_shape[:self.first_reduce]))] # noqa: E501
-      to_local: List[Tuple[int, int]] = []
+      to_local: list[tuple[int, int]] = []
      for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])):
        local_size = prod(sz for _, sz in to_local)
        local_sz: Optional[int] = next((x for x in ([32] * (axis == 0) + [16, 8, 4, 3, 2]) if self.full_shape[axis] % x == 0 and local_size * x <= 128), None) # noqa: E501
@@ -593,11 +569,11 @@ class Kernel:
 
   # **** kernel outputs ****
 
-  kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
+  kernel_cnt: Final[defaultdict[str, int]] = defaultdict(int)
   @functools.cached_property
   def name(self) -> str:
     # kernel name (before late upcast)
-    kernel_type = "r" if self.reduceop is not None else ("C" if all(x.op in GroupOp.Buffer for x in self.ast.parents) else "E")
+    kernel_type = "r" if self.reduceop is not None else ("C" if all(x.op is Ops.SINK or x.op in GroupOp.Buffer for x in self.ast.toposort) else "E")
     suffix = colored('_', 'BLACK').join([colored(x.render() if isinstance(x, UOp) else str(x), c) for x,c in zip(self.full_shape, self.colors())])
     name = kernel_type + (f"{len(self.ast.src)}" if len(self.ast.src) > 1 else "") + "_" + suffix
@@ -612,7 +588,10 @@ class Kernel:
      ret = op.replace(src=tuple(fixup_ast(x) for x in op.src))
      if op.op in GroupOp.Buffer and op in self.bufs:
        st_uop = self.sts[self.bufs.index(op)].to_uop()
-        return ret.replace(src=(st_uop,)) if op.op is Ops.VALID else ret.replace(src=(ret.src[0], st_uop, *ret.src[2:]))
+        # NOTE: if CONST got masked after applying opts, we create a new VALID
+        if op.op is Ops.CONST and any(v.mask is not None for v in unwrap(st_uop.st).views): return op.valid(unwrap(st_uop.st))
+        # otherwise we just replace the VIEW source
+        return ret.replace(src=(st_uop,)) if len(op.src) == 1 else ret.replace(src=(ret.src[0], st_uop, *ret.src[2:]))
      if op.op is Ops.SINK: return ret.replace(arg = KernelInfo(self.local_dims, self.upcasted, self.dont_use_locals))
      if op.op is Ops.REDUCE_AXIS:
        reduce_idx = len(self.bufs) + self.reduceops.index(op) * 2
@@ -623,47 +602,43 @@ class Kernel:
        grouped_axes = reduced_axes(self.first_reduce, self.first_reduce + self.group_for_reduces)
 
        if (tc := self.tensor_core) and (self.use_tensor_cores == 1 or self.use_tensor_cores == 3):
-          def fix_st(st: ShapeTracker, wd_pattern, tcd_pattern):
-            wd, warp_dims = self.global_dims, tuple(sz for _, sz in tc.threads)
-            tcd, tcd_dims = self.first_upcast, tuple(sz for _, sz in tc.reduce_axes + tc.early_upcast_axes)
-
-            assert st.shape[wd:wd+len(warp_dims)] == warp_dims, f"warp dims wrong: {st.shape[wd:wd+len(warp_dims)]=} != {warp_dims=}"
-            assert st.shape[tcd:tcd+len(tcd_dims)] == tcd_dims, f"tcd dims wrong: {st.shape[tcd:tcd+len(tcd_dims)]=} != {tcd_dims=}"
-            assert tc.expanded_shape is not None
-
-            new_shape = st.shape[:tcd] + tc.expanded_shape + st.shape[tcd+len(tcd_dims):] # expand the tcd
-            permaxis = list(range(wd)) + [y + (wd if x == 0 else tcd) for x,y in wd_pattern] + list(range(wd+len(warp_dims),tcd)) + \
-              [y + (wd if x == 0 else tcd) for x,y in tcd_pattern] + list(range(tcd+len(tc.expanded_shape),len(new_shape)))
-            return st.reshape(new_shape).permute(tuple(permaxis)).reshape(st.shape).simplify()
+          wd, tcd = self.global_dims, self.first_upcast
+          def get_upcast_axes(buf): # upcast along non-zero dimensions of (tc_reduce + tc_upcast)
+            upcast_axes = int(math.log2(tc.elements_per_thread[buf]))
+            return tuple((tcd + len(tc.get_reduce_axes()) + len(tc.get_upcast_axes()) - (i+1), 2) for i in range(upcast_axes))
+          def get_tc_swizzle_st(shape, local_perm, upcast_perm):
+            offset = (tcd - (wd + len(local_perm)))
+            permaxis = list(range(wd)) \
+              + [wd + x + (offset if x >= len(local_perm) else 0) for x in local_perm] + list(range(wd + len(local_perm), tcd)) \
+              + [wd + x + (offset if x >= len(local_perm) else 0) for x in upcast_perm] + list(range(tcd + len(upcast_perm), len(shape)))
+            return ShapeTracker.from_shape(shape).permute(tuple(permaxis))
 
          srcs = list((ret.src[0] if ret.src[0].op is not Ops.CAST else ret.src[0].src[0]).src)
-          for i, tc_pattern in enumerate([tc.st1_pattern, tc.st2_pattern]):
-            if tc_pattern: srcs[i] = srcs[i].view(fix_st(unwrap(srcs[i].st), *tc_pattern))
+          for i, (src, swizzle) in enumerate(zip(srcs, tc.swizzle)):
+            if swizzle: srcs[i] = src.view(get_tc_swizzle_st((src if src.op is Ops.LOAD else src.src[0]).st_arg.shape, *swizzle))
 
            if self.use_tensor_cores == 3: # for TC=3, emulate the warp addressing with locals
              local_shape = tuple(1 if i >= self.first_reduce and i < self.first_upcast else s for i, s in enumerate(self.full_shape))
              st = store_st = ShapeTracker.from_shape(local_shape)
-              local_buffer = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(local=True), (), (f"temp{i + 1}", st.real_size()))
-              if tc_pattern: store_st = fix_st(store_st, *tc_pattern)
+              local_buffer = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(size=st.real_size(), local=True), (), f"temp{i}")
+              if swizzle: store_st = get_tc_swizzle_st(store_st.shape, *swizzle)
              local_store = UOp.store(local_buffer, store_st.to_uop(), srcs[i])
              srcs[i] = UOp(Ops.LOAD, tc.dtype_in, (local_buffer, st.to_uop(), local_store))
 
-          tc_reduce_axes = tuple(self.first_upcast + ax for ax, _ in tc.reduce_axes)
-          if self.use_tensor_cores == 1: # real WMMA, use CONTRACT/EXPAND to get the vectorization right
-            upcast_axes = tuple(tuple((self.first_upcast + ax, sz) for ax, sz in up) for up in tc.upcast_axes)
-            wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device, prod(sz for _, sz in tc.threads), upcast_axes, tc_reduce_axes)
-            wmma_sz = [prod(x[1] for x in l) for l in upcast_axes]
-            wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(wmma_sz[2]), src=(
-              UOp(Ops.CONTRACT, dtype=srcs[0].dtype.vec(wmma_sz[0]), src=(srcs[0],), arg=upcast_axes[0]),
-              UOp(Ops.CONTRACT, dtype=srcs[1].dtype.vec(wmma_sz[1]), src=(srcs[1],), arg=upcast_axes[1]),
-              UOp.const(tc.dtype_out.vec(wmma_sz[2]), 0.0)), arg=wmma_arg)
-            tc_uop = UOp(Ops.EXPAND, tc.dtype_out, (wmma,), arg=upcast_axes[2])
+          tc_reduce_axes = tuple(tcd + ax for ax, _ in tc.get_reduce_axes())
+          if self.use_tensor_cores == 1: # real WMMA, use CONTRACT/UNROLL to get the vectorization right
+            tc_upcast_axes = (get_upcast_axes(0), get_upcast_axes(1), get_upcast_axes(2))
+            wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device, tc.threads, tc_upcast_axes, tc_reduce_axes)
+            wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(tc.elements_per_thread[2]), src=(
+              UOp(Ops.CONTRACT, dtype=srcs[0].dtype.vec(tc.elements_per_thread[0]), src=(srcs[0],), arg=tc_upcast_axes[0]),
+              UOp(Ops.CONTRACT, dtype=srcs[1].dtype.vec(tc.elements_per_thread[1]), src=(srcs[1],), arg=tc_upcast_axes[1]),
+              UOp.const(tc.dtype_out.vec(tc.elements_per_thread[2]), 0.0)), arg=wmma_arg)
+            tc_uop = UOp(Ops.UNROLL, tc.dtype_out, (wmma,), arg=tc_upcast_axes[2])
 
          else: # for TC=3 MUL/SUM instead of WMMA
            tc_uop = UOp(Ops.REDUCE_AXIS, tc.dtype_out, ((srcs[0] * srcs[1]).cast(tc.dtype_out),), (Ops.ADD, tc_reduce_axes))
 
-          new_reduce_axes = tuple(i for i in axes if i not in tc_reduce_axes)
-          return ret.replace(src=(tc_uop,), arg=(Ops.ADD, new_reduce_axes)) if new_reduce_axes else tc_uop
+          return ret.replace(src=(tc_uop,), arg=(Ops.ADD, new_axes)) if (new_axes := tuple(i for i in axes if i not in tc_reduce_axes)) else tc_uop
 
        ret = ret.replace(arg = (op.arg[0], axes))
        if self.group_for_reduces and grouped_axes:
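The new get_upcast_axes derives the CONTRACT axes arithmetically from elements_per_thread instead of reading stored upcast_axes off the TensorCore: a thread holding 2^n elements of a buffer contracts the last n of the (reduce + upcast) unit axes, each of size 2. A worked toy example of that formula (tcd, the axis counts, and elements_per_thread are all hypothetical numbers):

    import math

    tcd = 6                    # first upcasted dim
    n_reduce, n_upcast = 2, 3  # len(tc.get_reduce_axes()), len(tc.get_upcast_axes())
    elements_per_thread = 8    # 8 = 2**3 elements of this buffer per thread
    n = int(math.log2(elements_per_thread))
    upcast_axes = tuple((tcd + n_reduce + n_upcast - (i+1), 2) for i in range(n))
    print(upcast_axes)  # ((10, 2), (9, 2), (8, 2)): the last three axes, each size 2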
@@ -672,7 +647,8 @@ class Kernel:
                   for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces)]) + \
            (1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
          st_uop = ShapeTracker.from_shape(local_shape).to_uop()
-          local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local=True), (), (f"temp{self.reduceops.index(op)+1}", st_uop.arg.real_size()))
+          local_size = st_uop.arg.real_size()
+          local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, local=True), (), f"temp{self.reduceops.index(op)}")
          local_load = UOp(Ops.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, ret)))
          grouped_reduce = UOp(Ops.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], grouped_axes))
          if op is self.reduceops[-1]: return grouped_reduce
@@ -681,9 +657,7 @@ class Kernel:
 
      return ret
 
-    return graph_rewrite(fixup_ast(self.ast), PatternMatcher([
-      (UPat({*GroupOp.ALU,Ops.CAST,Ops.BITCAST,Ops.ASSIGN}, name="e").view(name="v"), lambda e,v: e.replace(src=tuple(s.view(v.st) for s in e.src))),
-      (UPat(Ops.LOAD, name="b").view(name="v"), lambda b,v: b.replace(src=tuple((v.arg).to_uop() if s.op is Ops.VIEW else s for s in b.src)))]))
+    return graph_rewrite(fixup_ast(self.ast), view_left)
 
   # **** this is the lowerer ****
 
@@ -696,58 +670,26 @@ class Kernel:
      if getenv("RAWAST"): print(self.ast)
      print(modified_ast)
      print(self.applied_opts)
-    verify_ast(modified_ast)
+    # verify AST matches the spec after applying opts
+    if __debug__: type_verify(list(modified_ast.toposort))
+    # TODO: sadly modified_ast doesn't pass the shape spec because of how group_for_reduces constructs UOps, there's probably a way to fix this
+    #if __debug__: type_verify(list(modified_ast.toposort), shape_spec)
 
-    self.uops:List[UOp] = linearize_uop(full_graph_rewrite(rewrite_shapetracker_with_index(modified_ast, self.opts), self.opts))
+    self.uops:list[UOp] = linearize_uop(full_graph_rewrite(rewrite_shapetracker_with_index(modified_ast, self.opts), self.opts))
    if DEBUG >= 5: print_uops(self.uops)
    return self
 
-  def to_program(self, name_override:Optional[str]=None) -> Program:
+  def to_program(self, name_override:Optional[str]=None) -> ProgramSpec:
    self.linearize()
    src = self.opts.render(name:=to_function_name(ansiname:=(name_override if name_override is not None else self.name)), self.uops)
 
-    if getenv("RUN_PROCESS_REPLAY"):
-      from test.external.process_replay.helpers import get_process_replay_ctx
-      diskcache_put("kernel_process_replay", str(id(self)), (self.ast, self.opts, self.applied_opts, name, *get_process_replay_ctx(), src))
+    if CAPTURE_PROCESS_REPLAY:
+      diskcache_put("kernel_process_replay", str(id(self)), (self.ast, self.opts, self.applied_opts, name, ContextVar._cache, src))
 
    # group non-local bufs by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
    # TODO: these max and min don't work on symbolic, and results are very wrong.
    mem_bytes = sum(max(x.src[0].dtype.itemsize * x.st_arg.real_size() for x in group)
-      for _, group in itertools.groupby([x for x in self.ast.parents if x.op in GroupOp.Buffer and x.src[0].op is Ops.DEFINE_GLOBAL],
+      for _, group in itertools.groupby([x for x in self.ast.toposort if x.op in GroupOp.Buffer and x.src[0].op is Ops.DEFINE_GLOBAL],
                                        key=lambda x: (x.op, x.src[0].arg)))
-    return Program(ansiname, src, self.opts.device, self.uops, mem_estimate=mem_bytes,
+    return ProgramSpec(ansiname, src, self.opts.device, self.uops, mem_estimate=mem_bytes,
                   global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)
-
-# the living definition of intermediate UOps
-
-def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) -> None:
-  if not uop.has_st or uop in sts: return
-  # restore globals from the two stage reduce
-  if uop.op is Ops.LOAD and uop.src[0].op is Ops.DEFINE_LOCAL:
-    _assert_valid_uop(local_reduce:=uop.src[2].src[2], uop.st_arg, sts)
-    sts[uop] = sts[local_reduce]
-    return
-  for x in uop.src: _assert_valid_uop(x, st, sts)
-  # only reduceuop is allowed to change shape, limited to turning n to 1
-  if uop.op in {Ops.REDUCE_AXIS, Ops.WMMA}: st = ShapeTracker.from_shape(sts[uop.src[0]].reduce(uop.axis_arg))
-  # movementops are pushed to VIEW
-  elif uop.op is Ops.VIEW:
-    assert len(uop.src) == 0, f"can't swizzle in kernel yet {uop}"
-    st = uop.arg
-  # everything else inherits shape
-  else:
-    st = (src_sts:=[sts[x] for x in uop.src if x.has_st])[0]
-    if not all_same(shapes:=[x.shape for x in src_sts]):
-      if all_same(sizes:=[prod(x) for x in shapes]): raise AssertionError(f"found implicit reshape {shapes}")
-      raise AssertionError(f"found implicit expand {sizes} {shapes}")
-  sts[uop] = st
-
-def verify_ast(ast:UOp) -> Dict[UOp, ShapeTracker]:
-  assert ast.op is Ops.SINK and all(x.op is Ops.STORE for x in ast.src), "must be SINK"
-  assert all_same([x.st_arg.size for x in ast.src]), "outputs must be exactly the same size"
-  sts: Dict[UOp, ShapeTracker] = {}
-  for out in ast.src: _assert_valid_uop(out, out.st_arg, sts)
-  shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])]
-  assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims), f"shapes must have either 1 or n in each dimension, {shape_dims}"
-  type_verify(list(sts))
-  return sts
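The verify_ast/_assert_valid_uop shape checker removed above is replaced by the declarative checks in the new tinygrad/spec.py (see the file list). Validating an AST the way the 0.10.1 Kernel constructor now does is one line; a sketch, assuming ast is a SINK-rooted UOp:

    from tinygrad.spec import type_verify, shape_spec

    # raises on any UOp in the graph that fails the type spec or the kernel shape spec
    type_verify(list(ast.toposort), shape_spec)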