tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported public registry. It is provided for informational purposes only.
- tinygrad/__init__.py +1 -1
- tinygrad/apps/llm.py +206 -0
- tinygrad/codegen/__init__.py +116 -0
- tinygrad/codegen/devectorizer.py +315 -172
- tinygrad/codegen/expander.py +8 -16
- tinygrad/codegen/gpudims.py +89 -0
- tinygrad/codegen/linearize.py +205 -203
- tinygrad/codegen/lowerer.py +92 -139
- tinygrad/codegen/opt/__init__.py +38 -0
- tinygrad/codegen/opt/heuristic.py +125 -0
- tinygrad/codegen/opt/kernel.py +510 -0
- tinygrad/{engine → codegen/opt}/search.py +51 -35
- tinygrad/codegen/opt/swizzler.py +134 -0
- tinygrad/codegen/opt/tc.py +127 -0
- tinygrad/codegen/quantize.py +67 -0
- tinygrad/device.py +122 -132
- tinygrad/dtype.py +152 -35
- tinygrad/engine/jit.py +81 -54
- tinygrad/engine/memory.py +46 -27
- tinygrad/engine/realize.py +82 -41
- tinygrad/engine/schedule.py +70 -445
- tinygrad/frontend/__init__.py +0 -0
- tinygrad/frontend/onnx.py +1253 -0
- tinygrad/frontend/torch.py +5 -0
- tinygrad/gradient.py +19 -27
- tinygrad/helpers.py +95 -47
- tinygrad/nn/__init__.py +7 -8
- tinygrad/nn/optim.py +72 -41
- tinygrad/nn/state.py +37 -23
- tinygrad/renderer/__init__.py +40 -60
- tinygrad/renderer/cstyle.py +143 -128
- tinygrad/renderer/llvmir.py +113 -62
- tinygrad/renderer/ptx.py +50 -32
- tinygrad/renderer/wgsl.py +27 -23
- tinygrad/runtime/autogen/am/am.py +5861 -0
- tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
- tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
- tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
- tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
- tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
- tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
- tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
- tinygrad/runtime/autogen/comgr.py +35 -9
- tinygrad/runtime/autogen/comgr_3.py +906 -0
- tinygrad/runtime/autogen/cuda.py +2419 -494
- tinygrad/runtime/autogen/hsa.py +57 -16
- tinygrad/runtime/autogen/ib.py +7171 -0
- tinygrad/runtime/autogen/io_uring.py +917 -118
- tinygrad/runtime/autogen/kfd.py +748 -26
- tinygrad/runtime/autogen/libc.py +613 -218
- tinygrad/runtime/autogen/libusb.py +1643 -0
- tinygrad/runtime/autogen/nv/nv.py +8602 -0
- tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
- tinygrad/runtime/autogen/opencl.py +2 -4
- tinygrad/runtime/autogen/sqtt.py +1789 -0
- tinygrad/runtime/autogen/vfio.py +3 -3
- tinygrad/runtime/autogen/webgpu.py +273 -264
- tinygrad/runtime/graph/cuda.py +3 -3
- tinygrad/runtime/graph/hcq.py +68 -29
- tinygrad/runtime/graph/metal.py +29 -13
- tinygrad/runtime/graph/remote.py +114 -0
- tinygrad/runtime/ops_amd.py +537 -320
- tinygrad/runtime/ops_cpu.py +108 -7
- tinygrad/runtime/ops_cuda.py +12 -14
- tinygrad/runtime/ops_disk.py +13 -10
- tinygrad/runtime/ops_dsp.py +47 -40
- tinygrad/runtime/ops_gpu.py +13 -11
- tinygrad/runtime/ops_hip.py +6 -9
- tinygrad/runtime/ops_llvm.py +35 -15
- tinygrad/runtime/ops_metal.py +29 -19
- tinygrad/runtime/ops_npy.py +5 -3
- tinygrad/runtime/ops_null.py +28 -0
- tinygrad/runtime/ops_nv.py +306 -234
- tinygrad/runtime/ops_python.py +62 -52
- tinygrad/runtime/ops_qcom.py +28 -39
- tinygrad/runtime/ops_remote.py +482 -0
- tinygrad/runtime/ops_webgpu.py +28 -28
- tinygrad/runtime/support/am/amdev.py +114 -249
- tinygrad/runtime/support/am/ip.py +211 -172
- tinygrad/runtime/support/amd.py +138 -0
- tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
- tinygrad/runtime/support/compiler_cuda.py +8 -11
- tinygrad/runtime/support/elf.py +2 -1
- tinygrad/runtime/support/hcq.py +184 -97
- tinygrad/runtime/support/ib.py +172 -0
- tinygrad/runtime/support/llvm.py +3 -4
- tinygrad/runtime/support/memory.py +251 -0
- tinygrad/runtime/support/nv/__init__.py +0 -0
- tinygrad/runtime/support/nv/ip.py +581 -0
- tinygrad/runtime/support/nv/nvdev.py +183 -0
- tinygrad/runtime/support/system.py +170 -0
- tinygrad/runtime/support/usb.py +268 -0
- tinygrad/runtime/support/webgpu.py +18 -0
- tinygrad/schedule/__init__.py +0 -0
- tinygrad/schedule/grouper.py +119 -0
- tinygrad/schedule/kernelize.py +368 -0
- tinygrad/schedule/multi.py +231 -0
- tinygrad/shape/shapetracker.py +40 -46
- tinygrad/shape/view.py +88 -52
- tinygrad/tensor.py +968 -542
- tinygrad/uop/__init__.py +117 -0
- tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
- tinygrad/uop/mathtraits.py +169 -0
- tinygrad/uop/ops.py +1021 -0
- tinygrad/uop/spec.py +228 -0
- tinygrad/{codegen → uop}/symbolic.py +239 -216
- tinygrad/uop/upat.py +163 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
- tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
- tinygrad/viz/index.html +203 -403
- tinygrad/viz/js/index.js +718 -0
- tinygrad/viz/js/worker.js +29 -0
- tinygrad/viz/serve.py +224 -102
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
- tinygrad-0.11.0.dist-info/RECORD +141 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/kernel.py +0 -693
- tinygrad/engine/multi.py +0 -161
- tinygrad/ops.py +0 -1003
- tinygrad/runtime/ops_cloud.py +0 -220
- tinygrad/runtime/support/allocator.py +0 -94
- tinygrad/spec.py +0 -155
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
- tinygrad/viz/perfetto.html +0 -178
- tinygrad-0.10.2.dist-info/RECORD +0 -99
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
tinygrad/engine/memory.py CHANGED

@@ -1,50 +1,69 @@
+from typing import cast
 from collections import defaultdict
 from tinygrad.engine.schedule import ScheduleItem
 from tinygrad.device import Device, Buffer
-from tinygrad.helpers import NO_MEMORY_PLANNER, dedup, DEBUG
-from tinygrad.ops import Ops
+from tinygrad.helpers import NO_MEMORY_PLANNER, dedup, DEBUG, round_up
+from tinygrad.uop.ops import Ops
+from tinygrad.dtype import dtypes, ImageDType
+from tinygrad.runtime.support.memory import TLSFAllocator
 
 # **************** memory planning ****************
 
-def _internal_memory_planner(buffers:list[list[Buffer] …
+def _internal_memory_planner(buffers:list[list[Buffer]], noopt_buffers=None, ignore_checks=False, debug_prefix="") -> dict[Buffer, Buffer]:
   if NO_MEMORY_PLANNER: return {}
-  first_appearance, last_appearance = {}, {}
+  first_appearance, last_appearance, buf_to_opt = {}, {}, set()
   for i,u in enumerate(buffers):
     for buf in u:
-      …
+      should_skip = buf.is_allocated() or buf.base.is_allocated() or buf.uop_refcount > 0 or (noopt_buffers is not None and buf.base in noopt_buffers)
+      if not ignore_checks and should_skip: continue
       if buf.base not in first_appearance: first_appearance[buf.base] = i
       last_appearance[buf.base] = i
+      buf_to_opt.add(buf)
 
-  # Sort …
-  …
-  def find_replace_buffer(buf, st, en):
-    key = (buf.device, buf.dtype, buf.options) + ((buf.nbytes,) if not hasattr(Device[buf.device].allocator, "offset") else tuple())
+  # Sort buffer operations in timeline order. Two events: buffer is allocated or buffer is freed.
+  buffer_requests = sorted([((first_appearance[buf], True), buf) for buf in first_appearance.keys()] + \
+                           [((last_appearance[buf] + 1, False), buf) for buf in first_appearance.keys()], key=lambda x: x[0])
 
-  …
+  # Try to suballocate from a shared buffer managed by global_planner using TLSFAllocator.
+  # Also track buffer replacements for buffers that do not support suballocation.
+  buffer_replace:dict[Buffer, tuple[Buffer|None, int|None]] = {}
+  reuse_buffers:dict[tuple, list[Buffer]] = defaultdict(list)
+  global_planner:dict[str, tuple[int, TLSFAllocator]] = defaultdict(lambda: (0, TLSFAllocator(1 << 44, block_size=0x1000, lv2_cnt=32)))
+  for (_, is_open_ev), buf in buffer_requests:
+    # Check if suballocation is possible for the given buffer and device.
+    if hasattr(Device[buf.device].allocator, "_offset") and not isinstance(buf.dtype, ImageDType):
+      if is_open_ev: buffer_replace[buf] = (None, global_planner[buf.device][1].alloc(round_up(buf.nbytes, 0x1000)))
+      else: global_planner[buf.device][1].free(cast(int, buffer_replace[buf][1]))
+      global_planner[buf.device] = (max(global_planner[buf.device][0], buffer_replace[buf][1] + buf.nbytes), global_planner[buf.device][1])
+    else:
+      key = (buf.device, buf.dtype, buf.options, buf.nbytes)
+      if is_open_ev: buffer_replace[buf] = (reuse_buffers[key].pop(), None) if key in reuse_buffers and len(reuse_buffers[key]) > 0 else (buf, None)
+      else: reuse_buffers[key].append(cast(Buffer, buffer_replace[buf][0]))
 
-  …
+  # Allocate global buffers based on the memory planner.
+  global_buffers = {dev: Buffer(dev, round_up(sz, 0x1000), dtypes.int8) for dev, (sz, _) in global_planner.items()}
+  buffer_resolve:dict[Buffer, tuple[Buffer, int|None]] = {buf: (base or global_buffers[buf.device], off) for buf,(base,off) in buffer_replace.items()}
 
-  …
+  # Assign buffers. First, assign full buffers (not sub-buffers).
+  assigned:dict[Buffer, Buffer] = {}
+  for buf, (base, off) in buffer_resolve.items():
+    if buf != base:
+      assigned[buf] = base if off is None else Buffer(buf.device, buf.size, buf.dtype, base=base, offset=off)
 
-  …
+  # Now assign sub-buffers.
+  for buf in buf_to_opt:
+    if buf._base is not None:
+      assigned[buf] = Buffer(buf.device, buf.size, buf.dtype, base=(pbuf:=assigned.get(buf.base, buf.base)).base, offset=pbuf.offset+buf.offset)
 
-  …
-  for …
-  …
-    else: assigned[buf] = assigned.get(buf, buf)
+  if DEBUG >= 1:
+    ak, av = dedup(x for x in assigned.keys() if x._base is None),dedup(x for x in assigned.values() if x._base is None)+list(global_buffers.values())
+    omem, nmem = sum([x.nbytes for x in ak])/1e6, sum([x.nbytes for x in av])/1e6
+    if omem != nmem: print(f"{debug_prefix}memory reduced from {omem:.2f} MB -> {nmem:.2f} MB,", f"{len(ak)} -> {len(av)} bufs")
 
-  if DEBUG >= 1 and len(ak:=dedup(x for x in assigned.keys() if x._base is None)) != len(av:=dedup(x for x in assigned.values() if x._base is None)):
-    print(debug_prefix+f"memory reduced from {sum([x.nbytes for x in ak])/1e6:.2f} MB -> {sum([x.nbytes for x in av])/1e6:.2f} MB,",
-          f"{len(ak)} -> {len(av)} bufs")
   return assigned
 
 def memory_planner(schedule:list[ScheduleItem]) -> list[ScheduleItem]:
   # Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs.
-  assigned = _internal_memory_planner([si.bufs for si in schedule],
+  assigned = _internal_memory_planner([list(si.bufs) for si in schedule],
                                       noopt_buffers={b for si in schedule if si.ast.op is not Ops.SINK for b in si.bufs})
-  return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata) for si in schedule]
+  return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata, si.fixedvars) for si in schedule]
tinygrad/engine/realize.py CHANGED

@@ -1,30 +1,51 @@
-from typing import …
+from typing import cast, Generator
 import time, pprint
-from dataclasses import dataclass, replace
-from tinygrad.helpers import all_same, colored, …
-from tinygrad.helpers import DEVECTORIZE, time_to_str
-from tinygrad.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer
+from dataclasses import dataclass, replace, field
+from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA, TracingKey
+from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, getenv, cpu_profile
+from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer, graph_rewrite, print_uops, track_rewrites, KernelInfo
 from tinygrad.device import Device, Buffer
 from tinygrad.renderer import Renderer, ProgramSpec, Estimates
-from tinygrad.codegen.kernel import Kernel
 from tinygrad.engine.schedule import ScheduleItem
+from tinygrad.codegen import full_rewrite
+from tinygrad.codegen.opt.kernel import Opt
 
 # **************** Program Creation ****************
 
-…
-def …
-…
+@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, (ret.function_name, ret.ast), ret=ret), replay=True)
+def get_program(ast:UOp, renderer:Renderer|None=None, opts:list[Opt]|None=None) -> ProgramSpec:
+  """
+  Transform an AST into a ProgramSpec. May trigger BEAM search.
+
+  Args:
+    ast: The Ops.SINK rooted AST
+    renderer: The renderer used to generate the code
+
+  Returns:
+    The ProgramSpec of the program.
+  """
+  if getenv("VIZ"): graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
+
+  # linearize
+  if renderer is None: renderer = Device.default.renderer
+  if opts is not None:
+    assert ast.arg is None, "can't apply opts if sink has an arg"
+    ast = ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts)))
+  try:
+    uops = full_rewrite(ast, renderer)
+  except RuntimeError:
+    print("***** LINEARIZE FAILURE *****")
+    print(f"ast = {ast}")
+    raise
+  assert uops[-1].op is Ops.SINK, "last uop must be sink"
+
+  # print and render
+  if DEBUG >= 6: print_uops(uops)
+  src = renderer.render(uops)
+
+  return ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test", src, renderer.device, ast, uops,
+                     global_size=[1,1,1] if renderer.has_local else None, local_size=[1,1,1] if renderer.has_local else None)
 
 # **************** Runners ****************
 
@@ -33,27 +54,30 @@ class Runner:
     self.first_run, self.display_name, self.device, self.estimates = True, display_name, device, estimates
   @property
   def dev(self): return Device[self.device]
-  def exec(self, rawbufs:list[Buffer], var_vals: …
+  def exec(self, rawbufs:list[Buffer], var_vals:dict[Variable, int]|None=None) -> float|None:
     return self(rawbufs, {} if var_vals is None else var_vals)
-  def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False) -> …
+  def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False) -> float|None:
     raise NotImplementedError("override this")
 
 class CompiledRunner(Runner):
-  def __init__(self, p:ProgramSpec, precompiled: …
+  def __init__(self, p:ProgramSpec, precompiled:bytes|None=None, prg=None):
    if DEBUG >= 4: print(p.src)
    self.p:ProgramSpec = p
-    …
+    if precompiled is not None: self.lib = precompiled
+    else:
+      with cpu_profile(TracingKey(f"compile {p.name}", (p.function_name,), cat="compiler"), "TINY"):
+        self.lib = Device[p.device].compiler.compile_cached(p.src)
+    if DEBUG >= 7: Device[p.device].compiler.disassemble(self.lib)
+    self._prg = Device[p.device].runtime(p.function_name, self.lib) if prg is None else prg
     super().__init__(p.name, p.device, p.estimates)
 
   def __reduce__(self): return self.__class__, (self.p, self.lib)
 
-  def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False) -> …
+  def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False) -> float|None:
     global_size, local_size = self.p.launch_dims(var_vals)
     if global_size is not None and local_size is None and all_int(self.p.global_size): # type: ignore[arg-type]
       # TODO: this is copied from get_program
-      from tinygrad. …
+      from tinygrad.codegen.opt.search import optimize_local_size
      local_size = optimize_local_size(self._prg, global_size, rawbufs)
      global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
      self.p = replace(self.p, global_size=global_size, local_size=local_size)

@@ -78,7 +102,7 @@ class BufferCopy(Runner):
     super().__init__(colored(name, "yellow"), dest_device, Estimates(lds=total_sz, mem=total_sz))
   def copy(self, dest, src):
     disk_supports_fast_copyout = src.device.startswith("DISK") and hasattr(src.allocator.dev, 'io_uring') and \
-      getattr(src.allocator.dev, 'fd', None) is not None
+      getattr(src.allocator.dev, 'fd', None) is not None and dest.allocator.supports_copy_from_disk
     if src.device.startswith("DISK") and hasattr(dest.allocator, 'copy_from_disk') and disk_supports_fast_copyout and src.nbytes >= 4096:
       dest.allocator.copy_from_disk(dest._buf, src._buf, src.nbytes)
     elif src.device.startswith("DISK") and hasattr(dest.allocator, '_as_buffer'):

@@ -110,7 +134,7 @@ def get_runner(device:str, ast:UOp) -> CompiledRunner:
     if bret:=method_cache.get(bkey):
       method_cache[ckey] = ret = CompiledRunner(replace(bret.p, device=device), bret.lib)
     else:
-      prg: ProgramSpec = …
+      prg: ProgramSpec = get_program(ast, Device[device].renderer)
       method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, device=device))
   return ret
 
@@ -119,10 +143,11 @@ def get_runner(device:str, ast:UOp) -> CompiledRunner:
 @dataclass(frozen=True)
 class ExecItem:
   prg: Runner
-  bufs: list[ …
-  metadata: …
-  …
+  bufs: list[Buffer|None]
+  metadata: tuple[Metadata, ...]|None = None
+  fixedvars: dict[Variable, int] = field(default_factory=dict)
+  def run(self, _var_vals:dict[Variable, int]|None=None, wait=False, jit=False, do_update_stats=True) -> float|None:
+    var_vals = self.fixedvars if _var_vals is None else (_var_vals|self.fixedvars)
     bufs = [cast(Buffer, x) for x in self.bufs] if jit else [cast(Buffer, x).ensure_allocated() for x in self.bufs]
     et = self.prg(bufs, var_vals, wait=wait or DEBUG >= 2)
     if do_update_stats:

@@ -134,7 +159,7 @@ class ExecItem:
       lds_est = sym_infer(self.prg.estimates.lds, var_vals)
       mem_est = min(mem_est, lds_est) # there can't be more memory accessed than loads/stores. remove this when symbolic is fixed
       ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None) if et is not None else ""
-        print(f"{colored(f'*** {self.prg.device[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*( …
+        print(f"{colored(f'*** {self.prg.device[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(44-ansilen(self.prg.display_name))} arg {len(bufs):2d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " +  # noqa: E501
              (str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_est/((et or 1e-20)*1e9):9.2f} GFLOPS {mem_est/((et or 1e-20)*1e9):6.1f}|{lds_est/((et or 1e-20)*1e9):<7.1f} GB/s)" +  # noqa: E501
              f" {[repr(m) if TRACEMETA >= 2 else str(m) for m in self.metadata] if self.metadata else ''}"))
       self.prg.first_run = False

@@ -148,12 +173,13 @@ si_lowerer = PatternMatcher([
     if hasattr(Device[ctx[0].device].allocator, '_transfer') and all_same([x.device.split(":")[0] for x in ctx]) \
     else BufferCopy(ctx[0].nbytes, ctx[0].device, ctx[1].device)), list(ctx))),
 ])
-def lower_schedule_item(si:ScheduleItem) -> ExecItem: …
+def lower_schedule_item(si:ScheduleItem) -> ExecItem:
+  return ExecItem(*cast(tuple[Runner,list], si_lowerer.rewrite(si.ast, si.bufs)), si.metadata, si.fixedvars)
 
-def lower_schedule(schedule:list[ScheduleItem]) -> Generator[ExecItem, None, None]:
+def lower_schedule(schedule:list[ScheduleItem]) -> Generator[tuple[ScheduleItem, ExecItem], None, None]:
   while len(schedule):
     si = schedule.pop(0)
-    try: yield lower_schedule_item(si)
+    try: yield (si, lower_schedule_item(si))
     except Exception as e:
       if DEBUG >= 2:
         print(f"error lowering {si.ast.op}")

@@ -165,7 +191,22 @@ def lower_schedule(schedule:list[ScheduleItem]) -> Generator[ExecItem, None, None]:
 
 capturing: list = [] # put classes with an add method in here
 
-def run_schedule(schedule:list[ScheduleItem], var_vals: …
-  for ei in lower_schedule(schedule):
+def run_schedule(schedule:list[ScheduleItem], var_vals:dict[Variable, int]|None=None, do_update_stats=True):
+  for si, ei in lower_schedule(schedule):
     if len(capturing) and CAPTURING: capturing[0].add(ei)
-    …
+    if VALIDATE_WITH_CPU and si.ast.op is Ops.SINK:
+      # copy in allocated buffers from the GPU
+      nb: tuple[Buffer, ...] = tuple(Buffer("CPU", b.size, b.dtype) for b in si.bufs)
+      for cpu_b, gpu_b in zip(nb, si.bufs):
+        if gpu_b.is_allocated(): cpu_b.ensure_allocated().copyin(gpu_b.as_buffer())
+
+      # run on GPU
+      ei.run(var_vals, do_update_stats=do_update_stats)
+
+      # validate the output buffers match (NOTE: this is assuming the output is buffer 0)
+      lower_schedule_item(ScheduleItem(si.ast, nb, si.metadata, si.fixedvars)).run(var_vals, do_update_stats=do_update_stats)
+      import numpy as np
+      np.testing.assert_allclose(si.bufs[0].numpy(), nb[0].numpy(), rtol=1e-3, atol=1e-3)
+    else:
+      ei.run(var_vals, do_update_stats=do_update_stats)
+
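
Two call-site changes in this file are worth noting: `lower_schedule` now yields `(ScheduleItem, ExecItem)` pairs rather than bare `ExecItem`s, and `run_schedule` gained a `VALIDATE_WITH_CPU=1` path that re-runs each SINK kernel on CPU copies of its buffers and compares the results with `numpy.testing.assert_allclose`. Below is a minimal sketch of driving the lowered schedule by hand; it assumes `Tensor.schedule()` is still available as in previous releases, so treat it as illustrative rather than a guaranteed recipe.

```python
from tinygrad import Tensor
from tinygrad.engine.realize import lower_schedule

t = (Tensor.ones(16, 16).contiguous() + 1).sum()
schedule = t.schedule()                  # list[ScheduleItem]; lower_schedule pops items off this list
for si, ei in lower_schedule(schedule):  # now a (ScheduleItem, ExecItem) pair per item
  print(si.ast.op, ei.prg.display_name)  # the ExecItem carries the Runner, its buffers, and fixedvars
  ei.run()                               # fixedvars are merged into var_vals inside ExecItem.run
```

In normal use this loop is driven by `run_schedule` (e.g. from `Tensor.realize()`), which is also where the CPU validation branch hooks in.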