tinygrad-0.9.1-py3-none-any.whl → tinygrad-0.9.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/kernel.py +248 -115
- tinygrad/codegen/lowerer.py +215 -0
- tinygrad/codegen/transcendental.py +310 -0
- tinygrad/codegen/uopgraph.py +622 -0
- tinygrad/codegen/uops.py +235 -393
- tinygrad/device.py +428 -69
- tinygrad/dtype.py +18 -4
- tinygrad/engine/graph.py +19 -32
- tinygrad/engine/jit.py +148 -70
- tinygrad/engine/realize.py +127 -51
- tinygrad/engine/schedule.py +259 -216
- tinygrad/engine/search.py +29 -22
- tinygrad/function.py +9 -0
- tinygrad/helpers.py +87 -49
- tinygrad/lazy.py +34 -35
- tinygrad/multi.py +41 -36
- tinygrad/nn/__init__.py +39 -22
- tinygrad/nn/state.py +3 -3
- tinygrad/ops.py +63 -62
- tinygrad/renderer/__init__.py +43 -21
- tinygrad/renderer/assembly.py +104 -106
- tinygrad/renderer/cstyle.py +87 -60
- tinygrad/renderer/llvmir.py +21 -30
- tinygrad/runtime/autogen/amd_gpu.py +25208 -5753
- tinygrad/runtime/autogen/cuda.py +6 -162
- tinygrad/runtime/autogen/kfd.py +32 -0
- tinygrad/runtime/autogen/libc.py +4260 -0
- tinygrad/runtime/autogen/nvrtc.py +579 -0
- tinygrad/runtime/graph/clang.py +2 -2
- tinygrad/runtime/graph/cuda.py +8 -11
- tinygrad/runtime/graph/hcq.py +120 -107
- tinygrad/runtime/graph/metal.py +18 -15
- tinygrad/runtime/ops_amd.py +197 -305
- tinygrad/runtime/ops_clang.py +2 -2
- tinygrad/runtime/ops_cuda.py +36 -94
- tinygrad/runtime/ops_disk.py +3 -7
- tinygrad/runtime/ops_gpu.py +4 -2
- tinygrad/runtime/ops_hip.py +70 -0
- tinygrad/runtime/ops_metal.py +38 -27
- tinygrad/runtime/ops_nv.py +283 -363
- tinygrad/runtime/ops_python.py +26 -30
- tinygrad/runtime/support/compiler_cuda.py +78 -0
- tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +15 -1
- tinygrad/runtime/support/elf.py +38 -0
- tinygrad/shape/shapetracker.py +5 -14
- tinygrad/shape/symbolic.py +4 -8
- tinygrad/shape/view.py +34 -22
- tinygrad/tensor.py +399 -97
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/METADATA +49 -48
- tinygrad-0.9.2.dist-info/RECORD +70 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/WHEEL +1 -1
- tinygrad/codegen/linearizer.py +0 -528
- tinygrad-0.9.1.dist-info/RECORD +0 -63
- /tinygrad/runtime/{driver → support}/__init__.py +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/top_level.txt +0 -0
tinygrad/engine/realize.py
CHANGED
```diff
@@ -1,42 +1,64 @@
-from typing import List, Dict, Optional, cast, Generator, Tuple
-import time
+from typing import List, Dict, Optional, cast, Generator, Tuple, Union
+import time, pprint
+from collections import defaultdict
 from dataclasses import dataclass, replace
-from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING
-from tinygrad.ops import
+from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, Context, TRACEMETA, dedup
+from tinygrad.ops import MetaOps, LazyOp
+from tinygrad.dtype import dtypes
 from tinygrad.device import Device, Buffer
 from tinygrad.shape.symbolic import Variable, sym_infer, sint
 from tinygrad.renderer import Renderer, Program
-from tinygrad.codegen.
+from tinygrad.codegen.kernel import Kernel
 from tinygrad.engine.schedule import ScheduleItem
 
 # **************** Program Creation ****************
 
 logkerns, logkerns_level = open(getenv("LOGKERNS", ""), "a") if getenv("LOGKERNS", "") else None, getenv("LOGKERNS_LEVEL", 1)
-def
-  if DEBUG >=
-
-
-  k = Linearizer(*ast, opts=renderer)
-  k.required_optimizations()
+def get_kernel(renderer:Renderer, ast:LazyOp) -> Kernel:
+  if DEBUG >= 5:
+    print(ast)
+  k = Kernel(ast, opts=renderer).required_optimizations()
   if not NOOPT:
     if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
     if BEAM >= 1:
       from tinygrad.engine.search import beam_search, time_linearizer, bufs_from_lin
-      kb, k_opt =
-      kb.required_optimizations()
+      kb, k_opt = Kernel(ast, opts=renderer).required_optimizations(), k
       rawbufs = bufs_from_lin(kb, allocate=False)
-
-
+      if BEAM.value >= 100:
+        from extra.mcts_search import mcts_search
+        k = mcts_search(kb, rawbufs, BEAM.value)
+      else:
+        k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
+      if beam_compare:=getenv("BEAM_COMPARE", 1):
         # TODO: move the HC/TC/BEAM compare to beam_search so it can be optionally cached which choice is better
-        lins: List[Tuple[str,
-        if used_tensor_cores:
-          lins.append(("hc", Linearizer(*ast, opts=renderer)))
-          lins[-1][1].hand_coded_optimizations()
+        lins: List[Tuple[str, Kernel]] = [(f"beam{BEAM.value}", k), (("tc" if used_tensor_cores else "hc"), k_opt)]
+        if used_tensor_cores: lins.append(("hc", Kernel(ast, opts=renderer).hand_coded_optimizations()))
         timed = sorted([(nm, tk, time_linearizer(tk, rawbufs, allow_test_size=False, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
         if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
         k = timed[0][1]
         if logkerns is not None and logkerns_level > 1: logkerns.writelines([f"{(lin.ast, lin.applied_opts)}\n" for (_,lin,_) in timed[1:]])
-
+        if beam_compare == 2:
+          from tinygrad import Tensor
+          all_outs: List[List[Tensor]] = []
+          with Context(DEBUG=0, BEAM=0, CAPTURING=0):
+            rand_bufs = [Tensor.normal(buf.size, std=0.1, dtype=buf.dtype).data() if dtypes.is_float(buf.dtype) else \
+                        (Tensor.randint(buf.size, low=0, high=2).cast(buf.dtype).data() if buf.dtype == dtypes.bool else \
+                         Tensor.randint(buf.size, low=dtypes.min(buf.dtype), high=dtypes.max(buf.dtype), dtype=buf.dtype).data()) \
+                         for buf in rawbufs]
+          for _, tk in lins[::-1]:
+            for buf,data in zip(rawbufs, rand_bufs): buf.ensure_allocated().copyin(data)
+            time_linearizer(tk, rawbufs, allow_test_size=False, clear_l2=True, disable_cache=True)
+            all_outs.append([Tensor(bytes(buf.as_buffer()), dtype=buf.dtype) for buf in rawbufs[:len(ast.src)]])
+          with Context(DEBUG=0, BEAM=0, CAPTURING=0):
+            for bufs in zip(*all_outs):
+              for b in bufs[1:]:
+                if dtypes.is_float(bufs[0].dtype):
+                  # we check both atol and rtol here
+                  diff_count = (((b-bufs[0]).abs() > 1e-3) * (((b-bufs[0])/bufs[0]).abs() > 1e-3)).sum().item()
+                else:
+                  diff_count = (b != bufs[0]).sum().item()
+                if diff_count != 0:
+                  raise RuntimeError(f"mismatch of {diff_count}/{b.numel()} items with type {b.dtype}, max {(b-bufs[0]).abs().max().item()}")
   if logkerns is not None: logkerns.writelines([f"{(k.ast, k.applied_opts)}\n"])
   if DEBUG >= 5: print((k.ast, k.applied_opts)) # print here to show final applied_opts for all kernels instead of just in beam_search
   return k
@@ -44,8 +66,9 @@ def get_linearizer(renderer:Renderer, ast:Tuple[LazyOp, ...]) -> Linearizer:
 # **************** Runners ****************
 
 class Runner:
-  def __init__(self, display_name:str, dname:str, op_estimate:sint=0, mem_estimate:sint=0):
-    self.first_run, self.display_name, self.dname, self.op_estimate, self.mem_estimate
+  def __init__(self, display_name:str, dname:str, op_estimate:sint=0, mem_estimate:sint=0, lds_estimate:Optional[sint]=None):
+    self.first_run, self.display_name, self.dname, self.op_estimate, self.mem_estimate, self.lds_estimate = \
+      True, display_name, dname, op_estimate, mem_estimate, mem_estimate if lds_estimate is None else lds_estimate
   @property
   def device(self): return Device[self.dname]
   def exec(self, rawbufs:List[Buffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]:
@@ -59,7 +82,7 @@ class CompiledRunner(Runner):
     self.p:Program = p
     self.lib:bytes = precompiled if precompiled is not None else Device[p.dname].compiler.compile_cached(p.src)
     self.clprg = Device[p.dname].runtime(p.function_name, self.lib)
-    super().__init__(p.name, p.dname, p.op_estimate, p.mem_estimate)
+    super().__init__(p.name, p.dname, p.op_estimate, p.mem_estimate, p.lds_estimate)
 
   def __reduce__(self): return self.__class__, (self.p, self.lib)
 
@@ -73,10 +96,10 @@ class CompiledRunner(Runner):
       self.p = replace(self.p, global_size=global_size, local_size=local_size)
     lra = {}
     if global_size:
-      lra['global_size'] = global_size
+      lra['global_size'] = tuple(global_size)
       assert len(global_size) == 3, "global size must have len 3"
     if local_size:
-      lra['local_size'] = local_size
+      lra['local_size'] = tuple(local_size)
       assert len(local_size) == 3, "local size must have len 3"
     return self.clprg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.p.vars), wait=wait)
 
@@ -119,24 +142,20 @@ class BufferCopy(Runner):
     return time.perf_counter() - st
 
 class BufferXfer(BufferCopy):
-  def copy(self, dest, src):
-    if hasattr(dest.allocator.device, "track_cross_buffer") and hasattr(src.allocator, "track_cross_device"):
-      dest.allocator.device.track_cross_buffer.append(src)
-      src.allocator.track_cross_device.add(dest.allocator.device)
-    dest.allocator.transfer(dest._buf, src._buf, dest.nbytes, src_dev=src.allocator.device, dest_dev=dest.allocator.device)
+  def copy(self, dest, src): dest.allocator.transfer(dest._buf, src._buf, dest.nbytes, src_dev=src.allocator.device, dest_dev=dest.allocator.device)
 
 # **************** method cache ****************
 
-method_cache: Dict[Tuple[str,
-def get_runner(dname:str, ast:
-  ckey = (dname, ast, BEAM.value, False)
+method_cache: Dict[Tuple[str, LazyOp, int, int, bool], CompiledRunner] = {}
+def get_runner(dname:str, ast:LazyOp) -> CompiledRunner:
+  ckey = (dname, ast, BEAM.value, NOOPT.value, False)
   if cret:=method_cache.get(ckey): return cret
-  bkey = (dname.split(":")[0], ast, BEAM.value, True)
+  bkey = (dname.split(":")[0], ast, BEAM.value, NOOPT.value, True)
   if bret:=method_cache.get(bkey):
     method_cache[ckey] = ret = CompiledRunner(replace(bret.p, dname=dname), bret.lib)
   else:
-    prg: Program =
-    if
+    prg: Program = get_kernel(Device[dname].renderer, ast).to_program()
+    if getenv("FUZZ_UOPS"):
       from test.external.fuzz_uops import UOpsFuzzerRunner
       return UOpsFuzzerRunner(replace(prg, dname=dname))
     method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, dname=dname))
@@ -148,39 +167,51 @@ def get_runner(dname:str, ast:Tuple[LazyOp, ...]) -> CompiledRunner:
 class ExecItem:
   prg: Runner
   bufs: List[Optional[Buffer]]
+  metadata: Optional[List[Metadata]] = None
   def run(self, var_vals:Optional[Dict[Variable, int]]=None, wait=False, jit=False, do_update_stats=True) -> Optional[float]:
     bufs = [cast(Buffer, x) for x in self.bufs] if jit else [cast(Buffer, x).ensure_allocated() for x in self.bufs]
     et = self.prg(bufs, var_vals if var_vals is not None else {}, wait=wait or DEBUG >= 2)
     if do_update_stats:
       GlobalCounters.kernel_count += 1
-      GlobalCounters.global_ops += (
-      GlobalCounters.global_mem += (
+      GlobalCounters.global_ops += (op_est:=sym_infer(self.prg.op_estimate, var_vals))
+      GlobalCounters.global_mem += (mem_est:=sym_infer(self.prg.mem_estimate, var_vals))
       if et is not None: GlobalCounters.time_sum_s += et
       if DEBUG >= 2:
+        lds_est = sym_infer(self.prg.lds_estimate, var_vals)
+        mem_est = min(mem_est, lds_est) # there can't be more memory accessed than loads/stores. remove this when symbolic is fixed
        ptm = (colored(f"{et*1e3:9.2f}ms", "yellow") if et > 0.01 else f"{et*1e6:9.2f}us") if et is not None else ""
-        print(f"{colored(f'*** {self.prg.dname[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(
-              (str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({
+        print(f"{colored(f'*** {self.prg.dname[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(40-ansilen(self.prg.display_name))} mem {GlobalCounters.mem_used/1e9:5.2f} GB " +  # noqa: E501
+              (str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_est/((et or 1e-20)*1e9):9.2f} GFLOPS {mem_est/((et or 1e-20)*1e9):6.1f}|{lds_est/((et or 1e-20)*1e9):<7.1f} GB/s)" +  # noqa: E501
+               f" {[repr(m) if TRACEMETA >= 2 else str(m) for m in self.metadata] if self.metadata else ''}"))
      self.prg.first_run = False
     return et
 
 def lower_schedule_item(si:ScheduleItem) -> ExecItem:
-  assert len(set(x.device for x in si.bufs)) == 1 or si.ast[0]
-  if si.ast
+  assert len(set(x.device for x in si.bufs)) == 1 or (si.ast.op is MetaOps.EXT and si.ast.arg[0] is MetaOps.COPY) or getenv("USE_COPY_KERNEL")
+  if si.ast.op is MetaOps.KERNEL:
     runner = get_runner(si.outputs[0].device, si.ast)
-    return ExecItem(runner, [si.bufs[x
-  out,
-  if
+    return ExecItem(runner, [si.bufs[x] for x in runner.p.globals], si.metadata)
+  out, (op, arg) = si.outputs[0], si.ast.arg
+  if op is MetaOps.COPY:
     kernel_type = BufferCopy
     if hasattr(Device[out.device].allocator, 'transfer') and out.device.split(":")[0] == si.inputs[0].device.split(":")[0]:
       kernel_type = BufferXfer
-    return ExecItem(kernel_type(
-  if
-  if
-  if
-  raise RuntimeError(f"don't know how to lower {ast}")
+    return ExecItem(kernel_type(arg, out.device, si.inputs[0].device), list(si.bufs))
+  if op is MetaOps.CUSTOM: return ExecItem(CustomOp(arg), list(si.bufs))
+  if op is MetaOps.EMPTY: return ExecItem(EmptyOp(out), list(si.bufs))
+  if op is MetaOps.VIEW: return ExecItem(ViewOp(out), list(si.bufs))
+  raise RuntimeError(f"don't know how to lower {si.ast}")
 
 def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, None]:
-  while len(schedule):
+  while len(schedule):
+    si = schedule.pop(0)
+    try: yield lower_schedule_item(si)
+    except Exception as e:
+      if DEBUG >= 2:
+        print(f"error lowering {si.ast.op}")
+        print("tensor operations:")
+        pprint.pprint(si.metadata, indent=2)
+      raise e
 
 # **************** main run function ****************
 
@@ -190,3 +221,48 @@ def run_schedule(schedule:List[ScheduleItem], var_vals:Optional[Dict[Variable, i
   for ei in lower_schedule(schedule):
     if len(capturing) and CAPTURING: capturing[0].add(ei)
     ei.run(var_vals, do_update_stats=do_update_stats)
+
+# **************** memory planning ****************
+
+def _internal_memory_planner(buffers:List[Union[List[Buffer], Tuple[Buffer, ...]]], noopt_buffers=None, debug_prefix="") -> Dict[Buffer, Buffer]:
+  if getenv("NO_MEMORY_PLANNER"): return {}
+  first_appearance, last_appearance = {}, {}
+  for i,u in enumerate(buffers):
+    for buf in u:
+      if buf.is_allocated() or buf.lb_refcount > 0 or (noopt_buffers is not None and buf.base in noopt_buffers): continue
+      if buf.base not in first_appearance: first_appearance[buf.base] = i
+      last_appearance[buf.base] = i
+
+  # Sort buffers by size in descending order, prioritizing largest buffers for allocation first.
+  # Track free segments, each containing (start, stop, and buffer that could be reused on this segment).
+  free_segs: Dict[Tuple, List[Tuple[int, int, Buffer]]] = defaultdict(list) # Dict[buffer key, Tuple[start, end, buffer to reuse on the seg]]
+  def find_replace_buffer(buf, st, en):
+    key = (buf.device, buf.dtype, buf.options) + ((buf.nbytes,) if not hasattr(Device[buf.device].allocator, "offset") else tuple())
+
+    default_buf = (0, len(buffers) - 1, buf) # will return the buffer itself if the replace one is not found.
+    seg_st, seg_en, seg_buf = next((free_segs[key].pop(i) for i,(sst,sen,_) in enumerate(free_segs[key]) if sst <= st and en <= sen), default_buf)
+
+    free_segs[key] += [(seg_st, st - 1, seg_buf)] if st - 1 >= seg_st else []
+    free_segs[key] += [(en + 1, seg_en, seg_buf)] if seg_en >= en + 1 else []
+
+    return seg_buf if seg_buf.nbytes == buf.nbytes else Buffer(buf.device, buf.size, buf.dtype, base=seg_buf)
+
+  buffer_requests = sorted([(first_appearance[buf], last_appearance[buf], buf) for buf in first_appearance.keys()], key=lambda x: -x[2].nbytes)
+  assigned = {buf:find_replace_buffer(buf, st, en) for st, en, buf in buffer_requests}
+
+  for i,u in enumerate(buffers):
+    for buf in u:
+      if buf.is_allocated() or buf.lb_refcount > 0 or (noopt_buffers is not None and buf.base in noopt_buffers): continue
+      if buf._base is not None: assigned[buf] = Buffer(buf.device, buf.size, buf.dtype, base=assigned.get(buf.base, buf.base).base, offset=buf.offset)
+      else: assigned[buf] = assigned.get(buf, buf)
+
+  if DEBUG >= 1 and len(ak:=dedup(x for x in assigned.keys() if x._base is None)) != len(av:=dedup(x for x in assigned.values() if x._base is None)):
+    print(debug_prefix+f"memory reduced from {sum([x.nbytes for x in ak])/1e6:.2f} MB -> {sum([x.nbytes for x in av])/1e6:.2f} MB,",
+          f"{len(ak)} -> {len(av)} bufs")
+  return assigned
+
+def memory_planner(schedule:List[ScheduleItem]) -> List[ScheduleItem]:
+  # Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs.
+  assigned = _internal_memory_planner([si.bufs for si in schedule],
+                                      noopt_buffers={b for si in schedule if si.ast.op is not MetaOps.KERNEL for b in si.bufs})
+  return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata) for si in schedule]
```