tinygrad 0.10.0__py3-none-any.whl → 0.10.1__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- tinygrad/codegen/kernel.py +114 -172
- tinygrad/codegen/linearize.py +211 -81
- tinygrad/codegen/lowerer.py +30 -35
- tinygrad/codegen/{uopgraph.py → rewriter.py} +69 -59
- tinygrad/codegen/transcendental.py +12 -13
- tinygrad/device.py +170 -47
- tinygrad/dtype.py +28 -26
- tinygrad/engine/jit.py +80 -63
- tinygrad/engine/memory.py +4 -5
- tinygrad/engine/multi.py +162 -0
- tinygrad/engine/realize.py +58 -107
- tinygrad/engine/schedule.py +381 -314
- tinygrad/engine/search.py +40 -44
- tinygrad/gradient.py +70 -0
- tinygrad/helpers.py +77 -58
- tinygrad/nn/__init__.py +30 -32
- tinygrad/nn/datasets.py +1 -2
- tinygrad/nn/optim.py +22 -26
- tinygrad/nn/state.py +89 -64
- tinygrad/ops.py +562 -446
- tinygrad/renderer/__init__.py +79 -36
- tinygrad/renderer/cstyle.py +70 -84
- tinygrad/renderer/llvmir.py +32 -20
- tinygrad/renderer/ptx.py +79 -99
- tinygrad/renderer/wgsl.py +87 -0
- tinygrad/runtime/autogen/amd_gpu.py +39507 -12
- tinygrad/runtime/autogen/comgr.py +2 -0
- tinygrad/runtime/autogen/kfd.py +4 -3
- tinygrad/runtime/autogen/kgsl.py +1 -1
- tinygrad/runtime/autogen/libpciaccess.py +2023 -0
- tinygrad/runtime/autogen/llvm.py +11379 -0
- tinygrad/runtime/autogen/vfio.py +891 -0
- tinygrad/runtime/graph/cuda.py +8 -9
- tinygrad/runtime/graph/hcq.py +84 -79
- tinygrad/runtime/graph/metal.py +19 -21
- tinygrad/runtime/ops_amd.py +488 -327
- tinygrad/runtime/ops_clang.py +15 -28
- tinygrad/runtime/ops_cloud.py +34 -34
- tinygrad/runtime/ops_cuda.py +30 -27
- tinygrad/runtime/ops_disk.py +62 -63
- tinygrad/runtime/ops_dsp.py +129 -38
- tinygrad/runtime/ops_gpu.py +30 -30
- tinygrad/runtime/ops_hip.py +29 -31
- tinygrad/runtime/ops_llvm.py +45 -40
- tinygrad/runtime/ops_metal.py +93 -73
- tinygrad/runtime/ops_npy.py +2 -2
- tinygrad/runtime/ops_nv.py +232 -270
- tinygrad/runtime/ops_python.py +51 -46
- tinygrad/runtime/ops_qcom.py +129 -157
- tinygrad/runtime/ops_webgpu.py +63 -0
- tinygrad/runtime/support/allocator.py +94 -0
- tinygrad/runtime/support/am/__init__.py +0 -0
- tinygrad/runtime/support/am/amdev.py +384 -0
- tinygrad/runtime/support/am/ip.py +463 -0
- tinygrad/runtime/support/compiler_cuda.py +4 -2
- tinygrad/runtime/support/elf.py +26 -4
- tinygrad/runtime/support/hcq.py +254 -324
- tinygrad/runtime/support/llvm.py +32 -0
- tinygrad/shape/shapetracker.py +84 -53
- tinygrad/shape/view.py +103 -138
- tinygrad/spec.py +154 -0
- tinygrad/tensor.py +744 -496
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/METADATA +32 -21
- tinygrad-0.10.1.dist-info/RECORD +86 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/WHEEL +1 -1
- tinygrad/engine/lazy.py +0 -228
- tinygrad/function.py +0 -212
- tinygrad/multi.py +0 -177
- tinygrad/runtime/graph/clang.py +0 -39
- tinygrad-0.10.0.dist-info/RECORD +0 -77
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/top_level.txt +0 -0
tinygrad/engine/realize.py
CHANGED
@@ -1,11 +1,10 @@
-from typing import
+from typing import Optional, cast, Generator
 import time, pprint
 from dataclasses import dataclass, replace
-from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata,
-from tinygrad.ops import Ops, UOp, Variable, sym_infer
-from tinygrad.dtype import dtypes
+from tinygrad.helpers import all_same, colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA
+from tinygrad.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer
 from tinygrad.device import Device, Buffer
-from tinygrad.renderer import Renderer,
+from tinygrad.renderer import Renderer, ProgramSpec, Estimates
 from tinygrad.codegen.kernel import Kernel
 from tinygrad.engine.schedule import ScheduleItem

@@ -13,50 +12,15 @@ from tinygrad.engine.schedule import ScheduleItem

 logkerns, logkerns_level = open(getenv("LOGKERNS", ""), "a") if getenv("LOGKERNS", "") else None, getenv("LOGKERNS_LEVEL", 1)
 def get_kernel(renderer:Renderer, ast:UOp) -> Kernel:
-  if DEBUG >= 5:
-    print(ast)
+  if DEBUG >= 5: print(ast)
   k = Kernel(ast, opts=renderer).required_optimizations()
   if not NOOPT:
-    if not
+    if not k.apply_tensor_cores(getenv("TC", 1)): k.hand_coded_optimizations()
     if BEAM >= 1:
-      from tinygrad.engine.search import beam_search,
-      kb
+      from tinygrad.engine.search import beam_search, bufs_from_lin
+      kb = Kernel(ast, opts=renderer).required_optimizations()
       rawbufs = bufs_from_lin(kb, allocate=False)
-
-        from extra.mcts_search import mcts_search
-        k = mcts_search(kb, rawbufs, BEAM.value)
-      else:
-        k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
-      if beam_compare:=getenv("BEAM_COMPARE", 1):
-        # TODO: move the HC/TC/BEAM compare to beam_search so it can be optionally cached which choice is better
-        lins: List[Tuple[str, Kernel]] = [(f"beam{BEAM.value}", k), (("tc" if used_tensor_cores else "hc"), k_opt)]
-        if used_tensor_cores: lins.append(("hc", Kernel(ast, opts=renderer).hand_coded_optimizations()))
-        timed = sorted([(nm, tk, time_linearizer(tk, rawbufs, allow_test_size=False, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
-        if DEBUG >= 3: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
-        k = timed[0][1]
-        if logkerns is not None and logkerns_level > 1: logkerns.writelines([f"{(lin.ast, lin.applied_opts)}\n" for (_,lin,_) in timed[1:]])
-        if beam_compare == 2:
-          from tinygrad import Tensor
-          all_outs: List[List[Tensor]] = []
-          with Context(DEBUG=0, BEAM=0, CAPTURING=0):
-            rand_bufs = [Tensor.normal(buf.size, std=0.1, dtype=buf.dtype).data() if dtypes.is_float(buf.dtype) else \
-                         (Tensor.randint(buf.size, low=0, high=2).cast(buf.dtype).data() if buf.dtype == dtypes.bool else \
-                         Tensor.randint(buf.size, low=dtypes.min(buf.dtype), high=dtypes.max(buf.dtype), dtype=buf.dtype).data()) \
-                         for buf in rawbufs]
-          for _, tk in lins[::-1]:
-            for buf,data in zip(rawbufs, rand_bufs): buf.ensure_allocated().copyin(data)
-            time_linearizer(tk, rawbufs, allow_test_size=False, clear_l2=True, disable_cache=True)
-            all_outs.append([Tensor(bytes(buf.as_buffer()), dtype=buf.dtype) for buf in rawbufs[:len(ast.src)]])
-          with Context(DEBUG=0, BEAM=0, CAPTURING=0):
-            for bufs in zip(*all_outs):
-              for b in bufs[1:]:
-                if dtypes.is_float(bufs[0].dtype):
-                  # we check both atol and rtol here
-                  diff_count = (((b-bufs[0]).abs() > 1e-3) * (((b-bufs[0])/bufs[0]).abs() > 1e-3)).sum().item()
-                else:
-                  diff_count = (b != bufs[0]).sum().item()
-                if diff_count != 0:
-                  raise RuntimeError(f"mismatch of {diff_count}/{b.numel()} items with type {b.dtype}, max {(b-bufs[0]).abs().max().item()}")
+      k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
   if logkerns is not None: logkerns.writelines([f"{(k.ast, k.applied_opts)}\n"])
   if DEBUG >= 5: print((k.ast, k.applied_opts)) # print here to show final applied_opts for all kernels instead of just in beam_search
   return k
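With the MCTS and BEAM_COMPARE branches removed, BEAM >= 1 now always means a plain beam search. A minimal sketch of exercising this simplified get_kernel path through the public Tensor API (sizes and settings are illustrative, not part of the diff):

# sketch: drive the simplified get_kernel path via the BEAM/NOOPT context vars
from tinygrad import Tensor
from tinygrad.helpers import Context

a, b = Tensor.rand(256, 256), Tensor.rand(256, 256)
with Context(BEAM=2):
  (a @ b).realize()    # BEAM >= 1 now always runs beam_search; no MCTS or BEAM_COMPARE fallback
with Context(NOOPT=1):
  (a @ b).realize()    # skips apply_tensor_cores and hand_coded_optimizations entirely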
@@ -64,33 +28,32 @@ def get_kernel(renderer:Renderer, ast:UOp) -> Kernel:
 # **************** Runners ****************

 class Runner:
-  def __init__(self, display_name:str,
-    self.first_run, self.display_name, self.
-      True, display_name, dname, op_estimate, mem_estimate, mem_estimate if lds_estimate is None else lds_estimate
+  def __init__(self, display_name:str, device:str, estimates=Estimates()):
+    self.first_run, self.display_name, self.device, self.estimates = True, display_name, device, estimates
   @property
-  def
-  def exec(self, rawbufs:
+  def dev(self): return Device[self.device]
+  def exec(self, rawbufs:list[Buffer], var_vals:Optional[dict[Variable, int]]=None) -> Optional[float]:
     return self(rawbufs, {} if var_vals is None else var_vals)
-  def __call__(self, rawbufs:
+  def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False) -> Optional[float]:
     raise NotImplementedError("override this")

 class CompiledRunner(Runner):
-  def __init__(self, p:
+  def __init__(self, p:ProgramSpec, precompiled:Optional[bytes]=None):
     if DEBUG >= 4: print(p.src)
-    self.p:
-    self.lib:bytes = precompiled if precompiled is not None else Device[p.
-    if DEBUG >= 6: Device[p.
-    self.
-    super().__init__(p.name, p.
+    self.p:ProgramSpec = p
+    self.lib:bytes = precompiled if precompiled is not None else Device[p.device].compiler.compile_cached(p.src)
+    if DEBUG >= 6: Device[p.device].compiler.disassemble(self.lib)
+    self._prg = Device[p.device].runtime(p.function_name, self.lib)
+    super().__init__(p.name, p.device, p.estimates)

   def __reduce__(self): return self.__class__, (self.p, self.lib)

-  def __call__(self, rawbufs:
+  def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False) -> Optional[float]:
     global_size, local_size = self.p.launch_dims(var_vals)
     if global_size is not None and local_size is None and all_int(self.p.global_size): # type: ignore[arg-type]
       # TODO: this is copied from get_program
       from tinygrad.engine.search import optimize_local_size
-      local_size = optimize_local_size(self.
+      local_size = optimize_local_size(self._prg, global_size, rawbufs)
       global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
       self.p = replace(self.p, global_size=global_size, local_size=local_size)
     lra = {}
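Runner's constructor now takes a device name and a single Estimates value in place of the old dname/op_estimate/mem_estimate/lds_estimate arguments. A sketch of a custom runner under that assumption (CustomRunner is a hypothetical example, not part of tinygrad):

# sketch: subclass Runner with the new (display_name, device, estimates) signature
from tinygrad.renderer import Estimates
from tinygrad.engine.realize import Runner

class CustomRunner(Runner):                        # hypothetical example class
  def __init__(self, device:str):
    super().__init__("custom", device, Estimates(ops=0, lds=0, mem=0))
  def __call__(self, rawbufs, var_vals, wait=False):
    return 0.0                                     # a real runner launches a compiled program here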
@@ -100,33 +63,29 @@ class CompiledRunner(Runner):
     if local_size:
       lra['local_size'] = tuple(local_size)
       assert len(local_size) == 3, "local size must have len 3"
-    return self.
-
-class EmptyOp(Runner):
-  def __init__(self, buf:Buffer): super().__init__(colored(f"empty {buf.size:10d} {buf.dtype}", "yellow"), buf.device)
-  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False): pass
+    return self._prg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.p.vars), wait=wait)

 class ViewOp(Runner):
   def __init__(self, buf:Buffer): super().__init__(colored(f"view {buf.nbytes:8d} @ {buf.offset:<10d}", "yellow"), buf.device)
-  def __call__(self, rawbufs:
+  def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False):
     assert rawbufs[0]._base is not None and rawbufs[0]._base == rawbufs[1].base, f"must be base {rawbufs}"

 class BufferCopy(Runner):
   def __init__(self, total_sz, dest_device, src_device):
     if total_sz >= 1e6: name = f"{type(self).__name__[6:].lower()} {total_sz/1e6:7.2f}M, {dest_device[:7]:>7s} <- {src_device[:7]:7s}"
     else: name = f"{type(self).__name__[6:].lower()} {total_sz:8d}, {dest_device[:7]:>7s} <- {src_device[:7]:7s}"
-    super().__init__(colored(name, "yellow"), dest_device,
+    super().__init__(colored(name, "yellow"), dest_device, Estimates(lds=total_sz, mem=total_sz))
   def copy(self, dest, src):
-    disk_supports_fast_copyout = src.device.startswith("DISK") and hasattr(src.allocator.
-      getattr(src.allocator.
+    disk_supports_fast_copyout = src.device.startswith("DISK") and hasattr(src.allocator.dev, 'io_uring') and \
+      getattr(src.allocator.dev, 'fd', None) is not None
     if src.device.startswith("DISK") and hasattr(dest.allocator, 'copy_from_disk') and disk_supports_fast_copyout and src.nbytes >= 4096:
       dest.allocator.copy_from_disk(dest._buf, src._buf, src.nbytes)
-    elif src.device.startswith("DISK") and hasattr(dest.allocator, '
+    elif src.device.startswith("DISK") and hasattr(dest.allocator, '_as_buffer'):
       # fast(ish) path, uses readinto in diskbuffers
-      src.allocator.
+      src.allocator._copyout(dest.allocator._as_buffer(dest._buf), src._buf)
     else:
       dest.copyin(src.as_buffer(allow_zero_copy=True)) # may allocate a CPU buffer depending on allow_zero_copy
-  def __call__(self, rawbufs:
+  def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False):
     dest, src = rawbufs[0:2]
     assert dest.size == src.size and dest.dtype == src.dtype, f"buffer copy mismatch, {dest.size} != {src.size}, {dest.dtype} != {src.dtype}"
     st = time.perf_counter()
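BufferCopy.copy now reaches the allocator through .dev and the underscore-prefixed _as_buffer/_copyout helpers; which branch runs depends on what the source allocator provides. A hedged usage sketch that reaches this code through a disk-backed tensor (the file path is illustrative and assumed to exist):

# sketch: a DISK -> default-device copy is lowered to a BufferCopy, whose copy() picks a disk fast path if available
from tinygrad import Tensor, Device, dtypes

t = Tensor.empty(1024, dtype=dtypes.uint8, device="DISK:/tmp/example.bin")  # assumes a >=1KiB file at this path
u = t.to(Device.DEFAULT).realize()   # cross-device DISK copy lowers to BufferCopy rather than BufferXfer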
@@ -136,23 +95,20 @@ class BufferCopy(Runner):
     return time.perf_counter() - st

 class BufferXfer(BufferCopy):
-  def copy(self, dest, src): dest.allocator.
+  def copy(self, dest, src): dest.allocator._transfer(dest._buf, src._buf, dest.nbytes, src_dev=src.allocator.dev, dest_dev=dest.allocator.dev)

 # **************** method cache ****************

-method_cache:
-def get_runner(
-  ckey = (
+method_cache: dict[tuple[str, bytes, int, int, bool], CompiledRunner] = {}
+def get_runner(device:str, ast:UOp) -> CompiledRunner:
+  ckey = (device, ast.key, BEAM.value, NOOPT.value, False)
   if cret:=method_cache.get(ckey): return cret
-  bkey = (
+  bkey = (device.split(":")[0], ast.key, BEAM.value, NOOPT.value, True)
   if bret:=method_cache.get(bkey):
-    method_cache[ckey] = ret = CompiledRunner(replace(bret.p,
+    method_cache[ckey] = ret = CompiledRunner(replace(bret.p, device=device), bret.lib)
   else:
-    prg:
-
-      from test.external.fuzz_uops import UOpsFuzzerRunner
-      return UOpsFuzzerRunner(replace(prg, dname=dname))
-    method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, dname=dname))
+    prg: ProgramSpec = get_kernel(Device[device].renderer, ast).to_program()
+    method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, device=device))
   return ret

 # **************** lowering functions ****************
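method_cache and get_runner are now keyed on the device string, with a second base-device entry so the compiled library can be reused across instances of the same device. A sketch of calling get_runner directly, assuming the last ScheduleItem of a matmul holds a SINK kernel ast:

# sketch: fetch a cached CompiledRunner for an already-scheduled kernel ast
from tinygrad import Tensor, Device
from tinygrad.engine.realize import get_runner

si = (Tensor.rand(64, 64) @ Tensor.rand(64, 64)).schedule()[-1]
runner = get_runner(Device.DEFAULT, si.ast)   # cached under (device, ast.key, BEAM, NOOPT, False)
print(runner.p.name, runner.estimates.ops)    # ProgramSpec name and the op-count estimate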
@@ -160,43 +116,38 @@ def get_runner(dname:str, ast:UOp) -> CompiledRunner:
 @dataclass(frozen=True)
 class ExecItem:
   prg: Runner
-  bufs:
-  metadata: Optional[
-  def run(self, _var_vals:Optional[
+  bufs: list[Optional[Buffer]]
+  metadata: Optional[tuple[Metadata, ...]] = None
+  def run(self, _var_vals:Optional[dict[Variable, int]]=None, wait=False, jit=False, do_update_stats=True) -> Optional[float]:
     var_vals = {} if _var_vals is None else _var_vals
     bufs = [cast(Buffer, x) for x in self.bufs] if jit else [cast(Buffer, x).ensure_allocated() for x in self.bufs]
     et = self.prg(bufs, var_vals, wait=wait or DEBUG >= 2)
     if do_update_stats:
       GlobalCounters.kernel_count += 1
-      GlobalCounters.global_ops += (op_est:=sym_infer(self.prg.
-      GlobalCounters.global_mem += (mem_est:=sym_infer(self.prg.
+      GlobalCounters.global_ops += (op_est:=sym_infer(self.prg.estimates.ops, var_vals))
+      GlobalCounters.global_mem += (mem_est:=sym_infer(self.prg.estimates.mem, var_vals))
       if et is not None: GlobalCounters.time_sum_s += et
       if DEBUG >= 2:
-        lds_est = sym_infer(self.prg.
+        lds_est = sym_infer(self.prg.estimates.lds, var_vals)
         mem_est = min(mem_est, lds_est) # there can't be more memory accessed than loads/stores. remove this when symbolic is fixed
         ptm = (colored(f"{et*1e3:9.2f}ms", "yellow") if et > 0.01 else f"{et*1e6:9.2f}us") if et is not None else ""
-        print(f"{colored(f'*** {self.prg.
+        print(f"{colored(f'*** {self.prg.device[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(41-ansilen(self.prg.display_name))} arg {len(bufs):2d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501
               (str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_est/((et or 1e-20)*1e9):9.2f} GFLOPS {mem_est/((et or 1e-20)*1e9):6.1f}|{lds_est/((et or 1e-20)*1e9):<7.1f} GB/s)" + # noqa: E501
               f" {[repr(m) if TRACEMETA >= 2 else str(m) for m in self.metadata] if self.metadata else ''}"))
       self.prg.first_run = False
     return et

-
-
-
-
-
-
-
-
-
-
-
-  if si.ast.op is Ops.EMPTY: return ExecItem(EmptyOp(out), list(si.bufs))
-  if si.ast.op is Ops.BUFFER_VIEW: return ExecItem(ViewOp(out), list(si.bufs))
-  raise RuntimeError(f"don't know how to lower {si.ast}")
-
-def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, None]:
+# NOTE: ctx is the buffers
+si_lowerer = PatternMatcher([
+  (UPat(Ops.SINK, name="sink"), lambda ctx,sink: (runner:=get_runner(ctx[0].device, sink), [ctx[x] for x in runner.p.globals])),
+  (UPat(Ops.BUFFER_VIEW), lambda ctx: (ViewOp(ctx[0]), list(ctx))),
+  (UPat(Ops.COPY, name="copy"), lambda ctx,copy: ((BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \
+      if hasattr(Device[ctx[0].device].allocator, '_transfer') and all_same([x.device.split(":")[0] for x in ctx]) \
+      else BufferCopy(ctx[0].nbytes, ctx[0].device, ctx[1].device)), list(ctx))),
+])
+def lower_schedule_item(si:ScheduleItem) -> ExecItem: return ExecItem(*cast(tuple[Runner,list], si_lowerer.rewrite(si.ast, si.bufs)), si.metadata)
+
+def lower_schedule(schedule:list[ScheduleItem]) -> Generator[ExecItem, None, None]:
   while len(schedule):
     si = schedule.pop(0)
     try: yield lower_schedule_item(si)
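Lowering is now a PatternMatcher dispatch on the item's top-level op (SINK, BUFFER_VIEW, COPY) instead of an if/elif chain, and the EmptyOp runner is gone. A short sketch of lowering and running a schedule item by item:

# sketch: lower each ScheduleItem through the new si_lowerer-backed helper and run it
from tinygrad import Tensor
from tinygrad.engine.realize import lower_schedule_item

for si in (Tensor.ones(32, 32).contiguous() + 1).schedule():
  ei = lower_schedule_item(si)   # SINK -> CompiledRunner, COPY -> BufferXfer/BufferCopy, BUFFER_VIEW -> ViewOp
  ei.run()                       # ensures buffers are allocated, then executes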
@@ -209,9 +160,9 @@ def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, None]:

 # **************** main run function ****************

-capturing:
+capturing: list = [] # put classes with an add method in here

-def run_schedule(schedule:
+def run_schedule(schedule:list[ScheduleItem], var_vals:Optional[dict[Variable, int]]=None, do_update_stats=True):
   for ei in lower_schedule(schedule):
     if len(capturing) and CAPTURING: capturing[0].add(ei)
     ei.run(var_vals, do_update_stats=do_update_stats)
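run_schedule remains the loop that lowers and executes every item; realize() feeds it the schedule for pending tensors. A short end-to-end usage sketch:

# sketch: schedule then execute explicitly, roughly what Tensor.realize does internally
from tinygrad import Tensor
from tinygrad.engine.realize import run_schedule

x = Tensor.rand(128) * 2
run_schedule(x.schedule())   # lowers each ScheduleItem and calls ExecItem.run on it
print(x.tolist()[:4])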