tinygrad 0.9.1__py3-none-any.whl → 0.9.2__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. tinygrad/codegen/kernel.py +248 -115
  2. tinygrad/codegen/lowerer.py +215 -0
  3. tinygrad/codegen/transcendental.py +310 -0
  4. tinygrad/codegen/uopgraph.py +622 -0
  5. tinygrad/codegen/uops.py +235 -393
  6. tinygrad/device.py +428 -69
  7. tinygrad/dtype.py +18 -4
  8. tinygrad/engine/graph.py +19 -32
  9. tinygrad/engine/jit.py +148 -70
  10. tinygrad/engine/realize.py +127 -51
  11. tinygrad/engine/schedule.py +259 -216
  12. tinygrad/engine/search.py +29 -22
  13. tinygrad/function.py +9 -0
  14. tinygrad/helpers.py +87 -49
  15. tinygrad/lazy.py +34 -35
  16. tinygrad/multi.py +41 -36
  17. tinygrad/nn/__init__.py +39 -22
  18. tinygrad/nn/state.py +3 -3
  19. tinygrad/ops.py +63 -62
  20. tinygrad/renderer/__init__.py +43 -21
  21. tinygrad/renderer/assembly.py +104 -106
  22. tinygrad/renderer/cstyle.py +87 -60
  23. tinygrad/renderer/llvmir.py +21 -30
  24. tinygrad/runtime/autogen/amd_gpu.py +25208 -5753
  25. tinygrad/runtime/autogen/cuda.py +6 -162
  26. tinygrad/runtime/autogen/kfd.py +32 -0
  27. tinygrad/runtime/autogen/libc.py +4260 -0
  28. tinygrad/runtime/autogen/nvrtc.py +579 -0
  29. tinygrad/runtime/graph/clang.py +2 -2
  30. tinygrad/runtime/graph/cuda.py +8 -11
  31. tinygrad/runtime/graph/hcq.py +120 -107
  32. tinygrad/runtime/graph/metal.py +18 -15
  33. tinygrad/runtime/ops_amd.py +197 -305
  34. tinygrad/runtime/ops_clang.py +2 -2
  35. tinygrad/runtime/ops_cuda.py +36 -94
  36. tinygrad/runtime/ops_disk.py +3 -7
  37. tinygrad/runtime/ops_gpu.py +4 -2
  38. tinygrad/runtime/ops_hip.py +70 -0
  39. tinygrad/runtime/ops_metal.py +38 -27
  40. tinygrad/runtime/ops_nv.py +283 -363
  41. tinygrad/runtime/ops_python.py +26 -30
  42. tinygrad/runtime/support/compiler_cuda.py +78 -0
  43. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +15 -1
  44. tinygrad/runtime/support/elf.py +38 -0
  45. tinygrad/shape/shapetracker.py +5 -14
  46. tinygrad/shape/symbolic.py +4 -8
  47. tinygrad/shape/view.py +34 -22
  48. tinygrad/tensor.py +399 -97
  49. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/METADATA +49 -48
  50. tinygrad-0.9.2.dist-info/RECORD +70 -0
  51. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/WHEEL +1 -1
  52. tinygrad/codegen/linearizer.py +0 -528
  53. tinygrad-0.9.1.dist-info/RECORD +0 -63
  54. /tinygrad/runtime/{driver → support}/__init__.py +0 -0
  55. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/LICENSE +0 -0
  56. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/top_level.txt +0 -0
@@ -1,42 +1,64 @@
1
- from typing import List, Dict, Optional, cast, Generator, Tuple
2
- import time
1
+ from typing import List, Dict, Optional, cast, Generator, Tuple, Union
2
+ import time, pprint
3
+ from collections import defaultdict
3
4
  from dataclasses import dataclass, replace
4
- from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING
5
- from tinygrad.ops import BufferOps, LoadOps, LazyOp
5
+ from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, Context, TRACEMETA, dedup
6
+ from tinygrad.ops import MetaOps, LazyOp
7
+ from tinygrad.dtype import dtypes
6
8
  from tinygrad.device import Device, Buffer
7
9
  from tinygrad.shape.symbolic import Variable, sym_infer, sint
8
10
  from tinygrad.renderer import Renderer, Program
9
- from tinygrad.codegen.linearizer import Linearizer
11
+ from tinygrad.codegen.kernel import Kernel
10
12
  from tinygrad.engine.schedule import ScheduleItem
11
13
 
12
14
  # **************** Program Creation ****************
13
15
 
14
16
  logkerns, logkerns_level = open(getenv("LOGKERNS", ""), "a") if getenv("LOGKERNS", "") else None, getenv("LOGKERNS_LEVEL", 1)
15
- def get_linearizer(renderer:Renderer, ast:Tuple[LazyOp, ...]) -> Linearizer:
16
- if DEBUG >= 3:
17
- from tinygrad.engine.graph import print_tree
18
- for op in ast: print_tree(op)
19
- k = Linearizer(*ast, opts=renderer)
20
- k.required_optimizations()
17
+ def get_kernel(renderer:Renderer, ast:LazyOp) -> Kernel:
18
+ if DEBUG >= 5:
19
+ print(ast)
20
+ k = Kernel(ast, opts=renderer).required_optimizations()
21
21
  if not NOOPT:
22
22
  if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
23
23
  if BEAM >= 1:
24
24
  from tinygrad.engine.search import beam_search, time_linearizer, bufs_from_lin
25
- kb, k_opt = Linearizer(*ast, opts=renderer), k
26
- kb.required_optimizations()
25
+ kb, k_opt = Kernel(ast, opts=renderer).required_optimizations(), k
27
26
  rawbufs = bufs_from_lin(kb, allocate=False)
28
- k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
29
- if getenv("BEAM_COMPARE", 1):
27
+ if BEAM.value >= 100:
28
+ from extra.mcts_search import mcts_search
29
+ k = mcts_search(kb, rawbufs, BEAM.value)
30
+ else:
31
+ k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
32
+ if beam_compare:=getenv("BEAM_COMPARE", 1):
30
33
  # TODO: move the HC/TC/BEAM compare to beam_search so it can be optionally cached which choice is better
31
- lins: List[Tuple[str, Linearizer]] = [(f"beam{BEAM.value}", k), (("tc" if used_tensor_cores else "hc"), k_opt)]
32
- if used_tensor_cores:
33
- lins.append(("hc", Linearizer(*ast, opts=renderer)))
34
- lins[-1][1].hand_coded_optimizations()
34
+ lins: List[Tuple[str, Kernel]] = [(f"beam{BEAM.value}", k), (("tc" if used_tensor_cores else "hc"), k_opt)]
35
+ if used_tensor_cores: lins.append(("hc", Kernel(ast, opts=renderer).hand_coded_optimizations()))
35
36
  timed = sorted([(nm, tk, time_linearizer(tk, rawbufs, allow_test_size=False, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
36
37
  if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
37
38
  k = timed[0][1]
38
39
  if logkerns is not None and logkerns_level > 1: logkerns.writelines([f"{(lin.ast, lin.applied_opts)}\n" for (_,lin,_) in timed[1:]])
39
- # TODO: check the correctness inline once compare_linearizer is in core
40
+ if beam_compare == 2:
41
+ from tinygrad import Tensor
42
+ all_outs: List[List[Tensor]] = []
43
+ with Context(DEBUG=0, BEAM=0, CAPTURING=0):
44
+ rand_bufs = [Tensor.normal(buf.size, std=0.1, dtype=buf.dtype).data() if dtypes.is_float(buf.dtype) else \
45
+ (Tensor.randint(buf.size, low=0, high=2).cast(buf.dtype).data() if buf.dtype == dtypes.bool else \
46
+ Tensor.randint(buf.size, low=dtypes.min(buf.dtype), high=dtypes.max(buf.dtype), dtype=buf.dtype).data()) \
47
+ for buf in rawbufs]
48
+ for _, tk in lins[::-1]:
49
+ for buf,data in zip(rawbufs, rand_bufs): buf.ensure_allocated().copyin(data)
50
+ time_linearizer(tk, rawbufs, allow_test_size=False, clear_l2=True, disable_cache=True)
51
+ all_outs.append([Tensor(bytes(buf.as_buffer()), dtype=buf.dtype) for buf in rawbufs[:len(ast.src)]])
52
+ with Context(DEBUG=0, BEAM=0, CAPTURING=0):
53
+ for bufs in zip(*all_outs):
54
+ for b in bufs[1:]:
55
+ if dtypes.is_float(bufs[0].dtype):
56
+ # we check both atol and rtol here
57
+ diff_count = (((b-bufs[0]).abs() > 1e-3) * (((b-bufs[0])/bufs[0]).abs() > 1e-3)).sum().item()
58
+ else:
59
+ diff_count = (b != bufs[0]).sum().item()
60
+ if diff_count != 0:
61
+ raise RuntimeError(f"mismatch of {diff_count}/{b.numel()} items with type {b.dtype}, max {(b-bufs[0]).abs().max().item()}")
40
62
  if logkerns is not None: logkerns.writelines([f"{(k.ast, k.applied_opts)}\n"])
41
63
  if DEBUG >= 5: print((k.ast, k.applied_opts)) # print here to show final applied_opts for all kernels instead of just in beam_search
42
64
  return k
@@ -44,8 +66,9 @@ def get_linearizer(renderer:Renderer, ast:Tuple[LazyOp, ...]) -> Linearizer:
44
66
  # **************** Runners ****************
45
67
 
46
68
  class Runner:
47
- def __init__(self, display_name:str, dname:str, op_estimate:sint=0, mem_estimate:sint=0):
48
- self.first_run, self.display_name, self.dname, self.op_estimate, self.mem_estimate = True, display_name, dname, op_estimate, mem_estimate
69
+ def __init__(self, display_name:str, dname:str, op_estimate:sint=0, mem_estimate:sint=0, lds_estimate:Optional[sint]=None):
70
+ self.first_run, self.display_name, self.dname, self.op_estimate, self.mem_estimate, self.lds_estimate = \
71
+ True, display_name, dname, op_estimate, mem_estimate, mem_estimate if lds_estimate is None else lds_estimate
49
72
  @property
50
73
  def device(self): return Device[self.dname]
51
74
  def exec(self, rawbufs:List[Buffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]:
@@ -59,7 +82,7 @@ class CompiledRunner(Runner):
59
82
  self.p:Program = p
60
83
  self.lib:bytes = precompiled if precompiled is not None else Device[p.dname].compiler.compile_cached(p.src)
61
84
  self.clprg = Device[p.dname].runtime(p.function_name, self.lib)
62
- super().__init__(p.name, p.dname, p.op_estimate, p.mem_estimate)
85
+ super().__init__(p.name, p.dname, p.op_estimate, p.mem_estimate, p.lds_estimate)
63
86
 
64
87
  def __reduce__(self): return self.__class__, (self.p, self.lib)
65
88
 
@@ -73,10 +96,10 @@ class CompiledRunner(Runner):
73
96
  self.p = replace(self.p, global_size=global_size, local_size=local_size)
74
97
  lra = {}
75
98
  if global_size:
76
- lra['global_size'] = global_size
99
+ lra['global_size'] = tuple(global_size)
77
100
  assert len(global_size) == 3, "global size must have len 3"
78
101
  if local_size:
79
- lra['local_size'] = local_size
102
+ lra['local_size'] = tuple(local_size)
80
103
  assert len(local_size) == 3, "local size must have len 3"
81
104
  return self.clprg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.p.vars), wait=wait)
82
105
 
@@ -119,24 +142,20 @@ class BufferCopy(Runner):
119
142
  return time.perf_counter() - st
120
143
 
121
144
  class BufferXfer(BufferCopy):
122
- def copy(self, dest, src):
123
- if hasattr(dest.allocator.device, "track_cross_buffer") and hasattr(src.allocator, "track_cross_device"):
124
- dest.allocator.device.track_cross_buffer.append(src)
125
- src.allocator.track_cross_device.add(dest.allocator.device)
126
- dest.allocator.transfer(dest._buf, src._buf, dest.nbytes, src_dev=src.allocator.device, dest_dev=dest.allocator.device)
145
+ def copy(self, dest, src): dest.allocator.transfer(dest._buf, src._buf, dest.nbytes, src_dev=src.allocator.device, dest_dev=dest.allocator.device)
127
146
 
128
147
  # **************** method cache ****************
129
148
 
130
- method_cache: Dict[Tuple[str, Tuple[LazyOp, ...], int, bool], CompiledRunner] = {}
131
- def get_runner(dname:str, ast:Tuple[LazyOp, ...]) -> CompiledRunner:
132
- ckey = (dname, ast, BEAM.value, False)
149
+ method_cache: Dict[Tuple[str, LazyOp, int, int, bool], CompiledRunner] = {}
150
+ def get_runner(dname:str, ast:LazyOp) -> CompiledRunner:
151
+ ckey = (dname, ast, BEAM.value, NOOPT.value, False)
133
152
  if cret:=method_cache.get(ckey): return cret
134
- bkey = (dname.split(":")[0], ast, BEAM.value, True)
153
+ bkey = (dname.split(":")[0], ast, BEAM.value, NOOPT.value, True)
135
154
  if bret:=method_cache.get(bkey):
136
155
  method_cache[ckey] = ret = CompiledRunner(replace(bret.p, dname=dname), bret.lib)
137
156
  else:
138
- prg: Program = get_linearizer(Device[dname].renderer, ast).to_program()
139
- if hasattr(prg.uops, "fuzz_paths"):
157
+ prg: Program = get_kernel(Device[dname].renderer, ast).to_program()
158
+ if getenv("FUZZ_UOPS"):
140
159
  from test.external.fuzz_uops import UOpsFuzzerRunner
141
160
  return UOpsFuzzerRunner(replace(prg, dname=dname))
142
161
  method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, dname=dname))
@@ -148,39 +167,51 @@ def get_runner(dname:str, ast:Tuple[LazyOp, ...]) -> CompiledRunner:
148
167
  class ExecItem:
149
168
  prg: Runner
150
169
  bufs: List[Optional[Buffer]]
170
+ metadata: Optional[List[Metadata]] = None
151
171
  def run(self, var_vals:Optional[Dict[Variable, int]]=None, wait=False, jit=False, do_update_stats=True) -> Optional[float]:
152
172
  bufs = [cast(Buffer, x) for x in self.bufs] if jit else [cast(Buffer, x).ensure_allocated() for x in self.bufs]
153
173
  et = self.prg(bufs, var_vals if var_vals is not None else {}, wait=wait or DEBUG >= 2)
154
174
  if do_update_stats:
155
175
  GlobalCounters.kernel_count += 1
156
- GlobalCounters.global_ops += (op_estimate:=sym_infer(self.prg.op_estimate, var_vals))
157
- GlobalCounters.global_mem += (mem_estimate:=sym_infer(self.prg.mem_estimate, var_vals))
176
+ GlobalCounters.global_ops += (op_est:=sym_infer(self.prg.op_estimate, var_vals))
177
+ GlobalCounters.global_mem += (mem_est:=sym_infer(self.prg.mem_estimate, var_vals))
158
178
  if et is not None: GlobalCounters.time_sum_s += et
159
179
  if DEBUG >= 2:
180
+ lds_est = sym_infer(self.prg.lds_estimate, var_vals)
181
+ mem_est = min(mem_est, lds_est) # there can't be more memory accessed than loads/stores. remove this when symbolic is fixed
160
182
  ptm = (colored(f"{et*1e3:9.2f}ms", "yellow") if et > 0.01 else f"{et*1e6:9.2f}us") if et is not None else ""
161
- print(f"{colored(f'*** {self.prg.dname[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(38-ansilen(self.prg.display_name))} arg {len(self.bufs):3d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501
162
- (str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)")) # noqa: E501
183
+ print(f"{colored(f'*** {self.prg.dname[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(40-ansilen(self.prg.display_name))} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501
184
+ (str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_est/((et or 1e-20)*1e9):9.2f} GFLOPS {mem_est/((et or 1e-20)*1e9):6.1f}|{lds_est/((et or 1e-20)*1e9):<7.1f} GB/s)" + # noqa: E501
185
+ f" {[repr(m) if TRACEMETA >= 2 else str(m) for m in self.metadata] if self.metadata else ''}"))
163
186
  self.prg.first_run = False
164
187
  return et
165
188
 
166
189
  def lower_schedule_item(si:ScheduleItem) -> ExecItem:
167
- assert len(set(x.device for x in si.bufs)) == 1 or si.ast[0].op is LoadOps.COPY or getenv("USE_COPY_KERNEL")
168
- if si.ast[0].op is BufferOps.STORE:
190
+ assert len(set(x.device for x in si.bufs)) == 1 or (si.ast.op is MetaOps.EXT and si.ast.arg[0] is MetaOps.COPY) or getenv("USE_COPY_KERNEL")
191
+ if si.ast.op is MetaOps.KERNEL:
169
192
  runner = get_runner(si.outputs[0].device, si.ast)
170
- return ExecItem(runner, [si.bufs[x[0]] for x in runner.p.globals])
171
- out, ast = si.outputs[0], si.ast[0]
172
- if ast.op is LoadOps.COPY:
193
+ return ExecItem(runner, [si.bufs[x] for x in runner.p.globals], si.metadata)
194
+ out, (op, arg) = si.outputs[0], si.ast.arg
195
+ if op is MetaOps.COPY:
173
196
  kernel_type = BufferCopy
174
197
  if hasattr(Device[out.device].allocator, 'transfer') and out.device.split(":")[0] == si.inputs[0].device.split(":")[0]:
175
198
  kernel_type = BufferXfer
176
- return ExecItem(kernel_type(ast.arg, out.device, si.inputs[0].device), list(si.bufs))
177
- if ast.op is LoadOps.CUSTOM: return ExecItem(CustomOp(ast.arg), list(si.bufs))
178
- if ast.op is LoadOps.EMPTY: return ExecItem(EmptyOp(out), list(si.bufs))
179
- if ast.op is LoadOps.VIEW: return ExecItem(ViewOp(out), list(si.bufs))
180
- raise RuntimeError(f"don't know how to lower {ast}")
199
+ return ExecItem(kernel_type(arg, out.device, si.inputs[0].device), list(si.bufs))
200
+ if op is MetaOps.CUSTOM: return ExecItem(CustomOp(arg), list(si.bufs))
201
+ if op is MetaOps.EMPTY: return ExecItem(EmptyOp(out), list(si.bufs))
202
+ if op is MetaOps.VIEW: return ExecItem(ViewOp(out), list(si.bufs))
203
+ raise RuntimeError(f"don't know how to lower {si.ast}")
181
204
 
182
205
  def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, None]:
183
- while len(schedule): yield lower_schedule_item(schedule.pop(0))
206
+ while len(schedule):
207
+ si = schedule.pop(0)
208
+ try: yield lower_schedule_item(si)
209
+ except Exception as e:
210
+ if DEBUG >= 2:
211
+ print(f"error lowering {si.ast.op}")
212
+ print("tensor operations:")
213
+ pprint.pprint(si.metadata, indent=2)
214
+ raise e
184
215
 
185
216
  # **************** main run function ****************
186
217
 
@@ -190,3 +221,48 @@ def run_schedule(schedule:List[ScheduleItem], var_vals:Optional[Dict[Variable, i
190
221
  for ei in lower_schedule(schedule):
191
222
  if len(capturing) and CAPTURING: capturing[0].add(ei)
192
223
  ei.run(var_vals, do_update_stats=do_update_stats)
224
+
225
+ # **************** memory planning ****************
226
+
227
+ def _internal_memory_planner(buffers:List[Union[List[Buffer], Tuple[Buffer, ...]]], noopt_buffers=None, debug_prefix="") -> Dict[Buffer, Buffer]:
228
+ if getenv("NO_MEMORY_PLANNER"): return {}
229
+ first_appearance, last_appearance = {}, {}
230
+ for i,u in enumerate(buffers):
231
+ for buf in u:
232
+ if buf.is_allocated() or buf.lb_refcount > 0 or (noopt_buffers is not None and buf.base in noopt_buffers): continue
233
+ if buf.base not in first_appearance: first_appearance[buf.base] = i
234
+ last_appearance[buf.base] = i
235
+
236
+ # Sort buffers by size in descending order, prioritizing largest buffers for allocation first.
237
+ # Track free segments, each containing (start, stop, and buffer that could be reused on this segment).
238
+ free_segs: Dict[Tuple, List[Tuple[int, int, Buffer]]] = defaultdict(list) # Dict[buffer key, Tuple[start, end, buffer to reuse on the seg]]
239
+ def find_replace_buffer(buf, st, en):
240
+ key = (buf.device, buf.dtype, buf.options) + ((buf.nbytes,) if not hasattr(Device[buf.device].allocator, "offset") else tuple())
241
+
242
+ default_buf = (0, len(buffers) - 1, buf) # will return the buffer itself if the replace one is not found.
243
+ seg_st, seg_en, seg_buf = next((free_segs[key].pop(i) for i,(sst,sen,_) in enumerate(free_segs[key]) if sst <= st and en <= sen), default_buf)
244
+
245
+ free_segs[key] += [(seg_st, st - 1, seg_buf)] if st - 1 >= seg_st else []
246
+ free_segs[key] += [(en + 1, seg_en, seg_buf)] if seg_en >= en + 1 else []
247
+
248
+ return seg_buf if seg_buf.nbytes == buf.nbytes else Buffer(buf.device, buf.size, buf.dtype, base=seg_buf)
249
+
250
+ buffer_requests = sorted([(first_appearance[buf], last_appearance[buf], buf) for buf in first_appearance.keys()], key=lambda x: -x[2].nbytes)
251
+ assigned = {buf:find_replace_buffer(buf, st, en) for st, en, buf in buffer_requests}
252
+
253
+ for i,u in enumerate(buffers):
254
+ for buf in u:
255
+ if buf.is_allocated() or buf.lb_refcount > 0 or (noopt_buffers is not None and buf.base in noopt_buffers): continue
256
+ if buf._base is not None: assigned[buf] = Buffer(buf.device, buf.size, buf.dtype, base=assigned.get(buf.base, buf.base).base, offset=buf.offset)
257
+ else: assigned[buf] = assigned.get(buf, buf)
258
+
259
+ if DEBUG >= 1 and len(ak:=dedup(x for x in assigned.keys() if x._base is None)) != len(av:=dedup(x for x in assigned.values() if x._base is None)):
260
+ print(debug_prefix+f"memory reduced from {sum([x.nbytes for x in ak])/1e6:.2f} MB -> {sum([x.nbytes for x in av])/1e6:.2f} MB,",
261
+ f"{len(ak)} -> {len(av)} bufs")
262
+ return assigned
263
+
264
+ def memory_planner(schedule:List[ScheduleItem]) -> List[ScheduleItem]:
265
+ # Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs.
266
+ assigned = _internal_memory_planner([si.bufs for si in schedule],
267
+ noopt_buffers={b for si in schedule if si.ast.op is not MetaOps.KERNEL for b in si.bufs})
268
+ return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata) for si in schedule]