tinygrad-0.10.0-py3-none-any.whl → tinygrad-0.10.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/devectorizer.py +247 -0
- tinygrad/codegen/expander.py +121 -0
- tinygrad/codegen/kernel.py +141 -201
- tinygrad/codegen/linearize.py +223 -84
- tinygrad/codegen/lowerer.py +60 -42
- tinygrad/codegen/symbolic.py +476 -0
- tinygrad/codegen/transcendental.py +22 -13
- tinygrad/device.py +187 -47
- tinygrad/dtype.py +39 -28
- tinygrad/engine/jit.py +83 -65
- tinygrad/engine/memory.py +4 -5
- tinygrad/engine/multi.py +161 -0
- tinygrad/engine/realize.py +62 -108
- tinygrad/engine/schedule.py +396 -357
- tinygrad/engine/search.py +55 -66
- tinygrad/gradient.py +73 -0
- tinygrad/helpers.py +81 -59
- tinygrad/nn/__init__.py +30 -32
- tinygrad/nn/datasets.py +1 -2
- tinygrad/nn/optim.py +22 -26
- tinygrad/nn/state.py +91 -66
- tinygrad/ops.py +492 -641
- tinygrad/renderer/__init__.py +95 -36
- tinygrad/renderer/cstyle.py +99 -92
- tinygrad/renderer/llvmir.py +83 -34
- tinygrad/renderer/ptx.py +83 -99
- tinygrad/renderer/wgsl.py +95 -0
- tinygrad/runtime/autogen/amd_gpu.py +39507 -12
- tinygrad/runtime/autogen/comgr.py +2 -0
- tinygrad/runtime/autogen/kfd.py +4 -3
- tinygrad/runtime/autogen/kgsl.py +1 -1
- tinygrad/runtime/autogen/libc.py +404 -71
- tinygrad/runtime/autogen/llvm.py +11379 -0
- tinygrad/runtime/autogen/pci.py +1333 -0
- tinygrad/runtime/autogen/vfio.py +891 -0
- tinygrad/runtime/autogen/webgpu.py +6985 -0
- tinygrad/runtime/graph/cuda.py +8 -9
- tinygrad/runtime/graph/hcq.py +84 -79
- tinygrad/runtime/graph/metal.py +40 -43
- tinygrad/runtime/ops_amd.py +498 -334
- tinygrad/runtime/ops_cloud.py +34 -34
- tinygrad/runtime/ops_cpu.py +24 -0
- tinygrad/runtime/ops_cuda.py +30 -27
- tinygrad/runtime/ops_disk.py +62 -63
- tinygrad/runtime/ops_dsp.py +159 -42
- tinygrad/runtime/ops_gpu.py +30 -30
- tinygrad/runtime/ops_hip.py +29 -31
- tinygrad/runtime/ops_llvm.py +48 -41
- tinygrad/runtime/ops_metal.py +149 -113
- tinygrad/runtime/ops_npy.py +2 -2
- tinygrad/runtime/ops_nv.py +238 -273
- tinygrad/runtime/ops_python.py +55 -50
- tinygrad/runtime/ops_qcom.py +129 -157
- tinygrad/runtime/ops_webgpu.py +225 -0
- tinygrad/runtime/support/allocator.py +94 -0
- tinygrad/runtime/support/am/__init__.py +0 -0
- tinygrad/runtime/support/am/amdev.py +396 -0
- tinygrad/runtime/support/am/ip.py +463 -0
- tinygrad/runtime/support/compiler_cuda.py +4 -2
- tinygrad/runtime/support/elf.py +28 -4
- tinygrad/runtime/support/hcq.py +256 -324
- tinygrad/runtime/support/llvm.py +26 -0
- tinygrad/shape/shapetracker.py +85 -53
- tinygrad/shape/view.py +104 -140
- tinygrad/spec.py +155 -0
- tinygrad/tensor.py +835 -527
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
- tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
- tinygrad/viz/index.html +544 -0
- tinygrad/viz/perfetto.html +178 -0
- tinygrad/viz/serve.py +205 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
- tinygrad-0.10.2.dist-info/RECORD +99 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
- tinygrad/codegen/uopgraph.py +0 -506
- tinygrad/engine/lazy.py +0 -228
- tinygrad/function.py +0 -212
- tinygrad/multi.py +0 -177
- tinygrad/runtime/graph/clang.py +0 -39
- tinygrad/runtime/ops_clang.py +0 -35
- tinygrad-0.10.0.dist-info/RECORD +0 -77
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/engine/realize.py
CHANGED
@@ -1,11 +1,11 @@
|
|
1
|
-
from typing import
|
1
|
+
from typing import Optional, cast, Generator
|
2
2
|
import time, pprint
|
3
3
|
from dataclasses import dataclass, replace
|
4
|
-
from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata,
|
5
|
-
from tinygrad.
|
6
|
-
from tinygrad.
|
4
|
+
from tinygrad.helpers import all_same, colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA
|
5
|
+
from tinygrad.helpers import DEVECTORIZE, time_to_str
|
6
|
+
from tinygrad.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer
|
7
7
|
from tinygrad.device import Device, Buffer
|
8
|
-
from tinygrad.renderer import Renderer,
|
8
|
+
from tinygrad.renderer import Renderer, ProgramSpec, Estimates
|
9
9
|
from tinygrad.codegen.kernel import Kernel
|
10
10
|
from tinygrad.engine.schedule import ScheduleItem
|
11
11
|
|
@@ -13,50 +13,15 @@ from tinygrad.engine.schedule import ScheduleItem
|
|
13
13
|
|
14
14
|
logkerns, logkerns_level = open(getenv("LOGKERNS", ""), "a") if getenv("LOGKERNS", "") else None, getenv("LOGKERNS_LEVEL", 1)
|
15
15
|
def get_kernel(renderer:Renderer, ast:UOp) -> Kernel:
|
16
|
-
if DEBUG >= 5:
|
17
|
-
print(ast)
|
16
|
+
if DEBUG >= 5: print(ast)
|
18
17
|
k = Kernel(ast, opts=renderer).required_optimizations()
|
19
18
|
if not NOOPT:
|
20
|
-
if not
|
19
|
+
if not k.apply_tensor_cores(getenv("TC", 1)): k.hand_coded_optimizations()
|
21
20
|
if BEAM >= 1:
|
22
|
-
from tinygrad.engine.search import beam_search,
|
23
|
-
kb
|
21
|
+
from tinygrad.engine.search import beam_search, bufs_from_lin
|
22
|
+
kb = Kernel(ast, opts=renderer).required_optimizations()
|
24
23
|
rawbufs = bufs_from_lin(kb, allocate=False)
|
25
|
-
|
26
|
-
from extra.mcts_search import mcts_search
|
27
|
-
k = mcts_search(kb, rawbufs, BEAM.value)
|
28
|
-
else:
|
29
|
-
k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
|
30
|
-
if beam_compare:=getenv("BEAM_COMPARE", 1):
|
31
|
-
# TODO: move the HC/TC/BEAM compare to beam_search so it can be optionally cached which choice is better
|
32
|
-
lins: List[Tuple[str, Kernel]] = [(f"beam{BEAM.value}", k), (("tc" if used_tensor_cores else "hc"), k_opt)]
|
33
|
-
if used_tensor_cores: lins.append(("hc", Kernel(ast, opts=renderer).hand_coded_optimizations()))
|
34
|
-
timed = sorted([(nm, tk, time_linearizer(tk, rawbufs, allow_test_size=False, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
|
35
|
-
if DEBUG >= 3: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
|
36
|
-
k = timed[0][1]
|
37
|
-
if logkerns is not None and logkerns_level > 1: logkerns.writelines([f"{(lin.ast, lin.applied_opts)}\n" for (_,lin,_) in timed[1:]])
|
38
|
-
if beam_compare == 2:
|
39
|
-
from tinygrad import Tensor
|
40
|
-
all_outs: List[List[Tensor]] = []
|
41
|
-
with Context(DEBUG=0, BEAM=0, CAPTURING=0):
|
42
|
-
rand_bufs = [Tensor.normal(buf.size, std=0.1, dtype=buf.dtype).data() if dtypes.is_float(buf.dtype) else \
|
43
|
-
(Tensor.randint(buf.size, low=0, high=2).cast(buf.dtype).data() if buf.dtype == dtypes.bool else \
|
44
|
-
Tensor.randint(buf.size, low=dtypes.min(buf.dtype), high=dtypes.max(buf.dtype), dtype=buf.dtype).data()) \
|
45
|
-
for buf in rawbufs]
|
46
|
-
for _, tk in lins[::-1]:
|
47
|
-
for buf,data in zip(rawbufs, rand_bufs): buf.ensure_allocated().copyin(data)
|
48
|
-
time_linearizer(tk, rawbufs, allow_test_size=False, clear_l2=True, disable_cache=True)
|
49
|
-
all_outs.append([Tensor(bytes(buf.as_buffer()), dtype=buf.dtype) for buf in rawbufs[:len(ast.src)]])
|
50
|
-
with Context(DEBUG=0, BEAM=0, CAPTURING=0):
|
51
|
-
for bufs in zip(*all_outs):
|
52
|
-
for b in bufs[1:]:
|
53
|
-
if dtypes.is_float(bufs[0].dtype):
|
54
|
-
# we check both atol and rtol here
|
55
|
-
diff_count = (((b-bufs[0]).abs() > 1e-3) * (((b-bufs[0])/bufs[0]).abs() > 1e-3)).sum().item()
|
56
|
-
else:
|
57
|
-
diff_count = (b != bufs[0]).sum().item()
|
58
|
-
if diff_count != 0:
|
59
|
-
raise RuntimeError(f"mismatch of {diff_count}/{b.numel()} items with type {b.dtype}, max {(b-bufs[0]).abs().max().item()}")
|
24
|
+
k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
|
60
25
|
if logkerns is not None: logkerns.writelines([f"{(k.ast, k.applied_opts)}\n"])
|
61
26
|
if DEBUG >= 5: print((k.ast, k.applied_opts)) # print here to show final applied_opts for all kernels instead of just in beam_search
|
62
27
|
return k
|
@@ -64,33 +29,32 @@ def get_kernel(renderer:Renderer, ast:UOp) -> Kernel:
|
|
64
29
|
# **************** Runners ****************
|
65
30
|
|
66
31
|
class Runner:
|
67
|
-
def __init__(self, display_name:str,
|
68
|
-
self.first_run, self.display_name, self.
|
69
|
-
True, display_name, dname, op_estimate, mem_estimate, mem_estimate if lds_estimate is None else lds_estimate
|
32
|
+
def __init__(self, display_name:str, device:str, estimates=Estimates()):
|
33
|
+
self.first_run, self.display_name, self.device, self.estimates = True, display_name, device, estimates
|
70
34
|
@property
|
71
|
-
def
|
72
|
-
def exec(self, rawbufs:
|
35
|
+
def dev(self): return Device[self.device]
|
36
|
+
def exec(self, rawbufs:list[Buffer], var_vals:Optional[dict[Variable, int]]=None) -> Optional[float]:
|
73
37
|
return self(rawbufs, {} if var_vals is None else var_vals)
|
74
|
-
def __call__(self, rawbufs:
|
38
|
+
def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False) -> Optional[float]:
|
75
39
|
raise NotImplementedError("override this")
|
76
40
|
|
77
41
|
class CompiledRunner(Runner):
|
78
|
-
def __init__(self, p:
|
42
|
+
def __init__(self, p:ProgramSpec, precompiled:Optional[bytes]=None):
|
79
43
|
if DEBUG >= 4: print(p.src)
|
80
|
-
self.p:
|
81
|
-
self.lib:bytes = precompiled if precompiled is not None else Device[p.
|
82
|
-
if DEBUG >= 6: Device[p.
|
83
|
-
self.
|
84
|
-
super().__init__(p.name, p.
|
44
|
+
self.p:ProgramSpec = p
|
45
|
+
self.lib:bytes = precompiled if precompiled is not None else Device[p.device].compiler.compile_cached(p.src)
|
46
|
+
if DEBUG >= 6: Device[p.device].compiler.disassemble(self.lib)
|
47
|
+
self._prg = Device[p.device].runtime(p.function_name, self.lib)
|
48
|
+
super().__init__(p.name, p.device, p.estimates)
|
85
49
|
|
86
50
|
def __reduce__(self): return self.__class__, (self.p, self.lib)
|
87
51
|
|
88
|
-
def __call__(self, rawbufs:
|
52
|
+
def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False) -> Optional[float]:
|
89
53
|
global_size, local_size = self.p.launch_dims(var_vals)
|
90
54
|
if global_size is not None and local_size is None and all_int(self.p.global_size): # type: ignore[arg-type]
|
91
55
|
# TODO: this is copied from get_program
|
92
56
|
from tinygrad.engine.search import optimize_local_size
|
93
|
-
local_size = optimize_local_size(self.
|
57
|
+
local_size = optimize_local_size(self._prg, global_size, rawbufs)
|
94
58
|
global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
|
95
59
|
self.p = replace(self.p, global_size=global_size, local_size=local_size)
|
96
60
|
lra = {}
|
@@ -100,33 +64,29 @@ class CompiledRunner(Runner):
|
|
100
64
|
if local_size:
|
101
65
|
lra['local_size'] = tuple(local_size)
|
102
66
|
assert len(local_size) == 3, "local size must have len 3"
|
103
|
-
return self.
|
104
|
-
|
105
|
-
class EmptyOp(Runner):
|
106
|
-
def __init__(self, buf:Buffer): super().__init__(colored(f"empty {buf.size:10d} {buf.dtype}", "yellow"), buf.device)
|
107
|
-
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False): pass
|
67
|
+
return self._prg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.p.vars), wait=wait)
|
108
68
|
|
109
69
|
class ViewOp(Runner):
|
110
70
|
def __init__(self, buf:Buffer): super().__init__(colored(f"view {buf.nbytes:8d} @ {buf.offset:<10d}", "yellow"), buf.device)
|
111
|
-
def __call__(self, rawbufs:
|
71
|
+
def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False):
|
112
72
|
assert rawbufs[0]._base is not None and rawbufs[0]._base == rawbufs[1].base, f"must be base {rawbufs}"
|
113
73
|
|
114
74
|
class BufferCopy(Runner):
|
115
75
|
def __init__(self, total_sz, dest_device, src_device):
|
116
76
|
if total_sz >= 1e6: name = f"{type(self).__name__[6:].lower()} {total_sz/1e6:7.2f}M, {dest_device[:7]:>7s} <- {src_device[:7]:7s}"
|
117
77
|
else: name = f"{type(self).__name__[6:].lower()} {total_sz:8d}, {dest_device[:7]:>7s} <- {src_device[:7]:7s}"
|
118
|
-
super().__init__(colored(name, "yellow"), dest_device,
|
78
|
+
super().__init__(colored(name, "yellow"), dest_device, Estimates(lds=total_sz, mem=total_sz))
|
119
79
|
def copy(self, dest, src):
|
120
|
-
disk_supports_fast_copyout = src.device.startswith("DISK") and hasattr(src.allocator.
|
121
|
-
getattr(src.allocator.
|
80
|
+
disk_supports_fast_copyout = src.device.startswith("DISK") and hasattr(src.allocator.dev, 'io_uring') and \
|
81
|
+
getattr(src.allocator.dev, 'fd', None) is not None
|
122
82
|
if src.device.startswith("DISK") and hasattr(dest.allocator, 'copy_from_disk') and disk_supports_fast_copyout and src.nbytes >= 4096:
|
123
83
|
dest.allocator.copy_from_disk(dest._buf, src._buf, src.nbytes)
|
124
|
-
elif src.device.startswith("DISK") and hasattr(dest.allocator, '
|
84
|
+
elif src.device.startswith("DISK") and hasattr(dest.allocator, '_as_buffer'):
|
125
85
|
# fast(ish) path, uses readinto in diskbuffers
|
126
|
-
src.allocator.
|
86
|
+
src.allocator._copyout(dest.allocator._as_buffer(dest._buf), src._buf)
|
127
87
|
else:
|
128
88
|
dest.copyin(src.as_buffer(allow_zero_copy=True)) # may allocate a CPU buffer depending on allow_zero_copy
|
129
|
-
def __call__(self, rawbufs:
|
89
|
+
def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False):
|
130
90
|
dest, src = rawbufs[0:2]
|
131
91
|
assert dest.size == src.size and dest.dtype == src.dtype, f"buffer copy mismatch, {dest.size} != {src.size}, {dest.dtype} != {src.dtype}"
|
132
92
|
st = time.perf_counter()
|
@@ -136,23 +96,22 @@ class BufferCopy(Runner):
|
|
136
96
|
return time.perf_counter() - st
|
137
97
|
|
138
98
|
class BufferXfer(BufferCopy):
|
139
|
-
def copy(self, dest, src): dest.allocator.
|
99
|
+
def copy(self, dest, src): dest.allocator._transfer(dest._buf, src._buf, dest.nbytes, src_dev=src.allocator.dev, dest_dev=dest.allocator.dev)
|
140
100
|
|
141
101
|
# **************** method cache ****************
|
142
102
|
|
143
|
-
method_cache:
|
144
|
-
def get_runner(
|
145
|
-
|
103
|
+
method_cache: dict[tuple[str, bytes, tuple[int, ...], bool], CompiledRunner] = {}
|
104
|
+
def get_runner(device:str, ast:UOp) -> CompiledRunner:
|
105
|
+
# TODO: this should be all context relevant to rendering
|
106
|
+
context = (BEAM.value, NOOPT.value, DEVECTORIZE.value)
|
107
|
+
ckey = (device, ast.key, context, False)
|
146
108
|
if cret:=method_cache.get(ckey): return cret
|
147
|
-
bkey = (
|
109
|
+
bkey = (device.split(":")[0], ast.key, context, True)
|
148
110
|
if bret:=method_cache.get(bkey):
|
149
|
-
method_cache[ckey] = ret = CompiledRunner(replace(bret.p,
|
111
|
+
method_cache[ckey] = ret = CompiledRunner(replace(bret.p, device=device), bret.lib)
|
150
112
|
else:
|
151
|
-
prg:
|
152
|
-
|
153
|
-
from test.external.fuzz_uops import UOpsFuzzerRunner
|
154
|
-
return UOpsFuzzerRunner(replace(prg, dname=dname))
|
155
|
-
method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, dname=dname))
|
113
|
+
prg: ProgramSpec = get_kernel(Device[device].renderer, ast).to_program()
|
114
|
+
method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, device=device))
|
156
115
|
return ret
|
157
116
|
|
158
117
|
# **************** lowering functions ****************
|
@@ -160,43 +119,38 @@ def get_runner(dname:str, ast:UOp) -> CompiledRunner:
|
|
160
119
|
@dataclass(frozen=True)
|
161
120
|
class ExecItem:
|
162
121
|
prg: Runner
|
163
|
-
bufs:
|
164
|
-
metadata: Optional[
|
165
|
-
def run(self, _var_vals:Optional[
|
122
|
+
bufs: list[Optional[Buffer]]
|
123
|
+
metadata: Optional[tuple[Metadata, ...]] = None
|
124
|
+
def run(self, _var_vals:Optional[dict[Variable, int]]=None, wait=False, jit=False, do_update_stats=True) -> Optional[float]:
|
166
125
|
var_vals = {} if _var_vals is None else _var_vals
|
167
126
|
bufs = [cast(Buffer, x) for x in self.bufs] if jit else [cast(Buffer, x).ensure_allocated() for x in self.bufs]
|
168
127
|
et = self.prg(bufs, var_vals, wait=wait or DEBUG >= 2)
|
169
128
|
if do_update_stats:
|
170
129
|
GlobalCounters.kernel_count += 1
|
171
|
-
GlobalCounters.global_ops += (op_est:=sym_infer(self.prg.
|
172
|
-
GlobalCounters.global_mem += (mem_est:=sym_infer(self.prg.
|
130
|
+
GlobalCounters.global_ops += (op_est:=sym_infer(self.prg.estimates.ops, var_vals))
|
131
|
+
GlobalCounters.global_mem += (mem_est:=sym_infer(self.prg.estimates.mem, var_vals))
|
173
132
|
if et is not None: GlobalCounters.time_sum_s += et
|
174
133
|
if DEBUG >= 2:
|
175
|
-
lds_est = sym_infer(self.prg.
|
134
|
+
lds_est = sym_infer(self.prg.estimates.lds, var_vals)
|
176
135
|
mem_est = min(mem_est, lds_est) # there can't be more memory accessed than loads/stores. remove this when symbolic is fixed
|
177
|
-
ptm =
|
178
|
-
print(f"{colored(f'*** {self.prg.
|
136
|
+
ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None) if et is not None else ""
|
137
|
+
print(f"{colored(f'*** {self.prg.device[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(41-ansilen(self.prg.display_name))} arg {len(bufs):2d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501
|
179
138
|
(str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_est/((et or 1e-20)*1e9):9.2f} GFLOPS {mem_est/((et or 1e-20)*1e9):6.1f}|{lds_est/((et or 1e-20)*1e9):<7.1f} GB/s)" + # noqa: E501
|
180
139
|
f" {[repr(m) if TRACEMETA >= 2 else str(m) for m in self.metadata] if self.metadata else ''}"))
|
181
140
|
self.prg.first_run = False
|
182
141
|
return et
|
183
142
|
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
if si.ast.op is Ops.EMPTY: return ExecItem(EmptyOp(out), list(si.bufs))
|
196
|
-
if si.ast.op is Ops.BUFFER_VIEW: return ExecItem(ViewOp(out), list(si.bufs))
|
197
|
-
raise RuntimeError(f"don't know how to lower {si.ast}")
|
198
|
-
|
199
|
-
def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, None]:
|
143
|
+
# NOTE: ctx is the buffers
|
144
|
+
si_lowerer = PatternMatcher([
|
145
|
+
(UPat(Ops.SINK, name="sink"), lambda ctx,sink: (runner:=get_runner(ctx[0].device, sink), [ctx[x] for x in runner.p.globals])),
|
146
|
+
(UPat(Ops.BUFFER_VIEW), lambda ctx: (ViewOp(ctx[0]), list(ctx))),
|
147
|
+
(UPat(Ops.COPY, name="copy"), lambda ctx,copy: ((BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \
|
148
|
+
if hasattr(Device[ctx[0].device].allocator, '_transfer') and all_same([x.device.split(":")[0] for x in ctx]) \
|
149
|
+
else BufferCopy(ctx[0].nbytes, ctx[0].device, ctx[1].device)), list(ctx))),
|
150
|
+
])
|
151
|
+
def lower_schedule_item(si:ScheduleItem) -> ExecItem: return ExecItem(*cast(tuple[Runner,list], si_lowerer.rewrite(si.ast, si.bufs)), si.metadata)
|
152
|
+
|
153
|
+
def lower_schedule(schedule:list[ScheduleItem]) -> Generator[ExecItem, None, None]:
|
200
154
|
while len(schedule):
|
201
155
|
si = schedule.pop(0)
|
202
156
|
try: yield lower_schedule_item(si)
|
@@ -209,9 +163,9 @@ def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, Non
|
|
209
163
|
|
210
164
|
# **************** main run function ****************
|
211
165
|
|
212
|
-
capturing:
|
166
|
+
capturing: list = [] # put classes with an add method in here
|
213
167
|
|
214
|
-
def run_schedule(schedule:
|
168
|
+
def run_schedule(schedule:list[ScheduleItem], var_vals:Optional[dict[Variable, int]]=None, do_update_stats=True):
|
215
169
|
for ei in lower_schedule(schedule):
|
216
170
|
if len(capturing) and CAPTURING: capturing[0].add(ei)
|
217
171
|
ei.run(var_vals, do_update_stats=do_update_stats)
|