tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +1 -1
- tinygrad/apps/llm.py +206 -0
- tinygrad/codegen/__init__.py +116 -0
- tinygrad/codegen/devectorizer.py +315 -172
- tinygrad/codegen/expander.py +8 -16
- tinygrad/codegen/gpudims.py +89 -0
- tinygrad/codegen/linearize.py +205 -203
- tinygrad/codegen/lowerer.py +92 -139
- tinygrad/codegen/opt/__init__.py +38 -0
- tinygrad/codegen/opt/heuristic.py +125 -0
- tinygrad/codegen/opt/kernel.py +510 -0
- tinygrad/{engine → codegen/opt}/search.py +51 -35
- tinygrad/codegen/opt/swizzler.py +134 -0
- tinygrad/codegen/opt/tc.py +127 -0
- tinygrad/codegen/quantize.py +67 -0
- tinygrad/device.py +122 -132
- tinygrad/dtype.py +152 -35
- tinygrad/engine/jit.py +81 -54
- tinygrad/engine/memory.py +46 -27
- tinygrad/engine/realize.py +82 -41
- tinygrad/engine/schedule.py +70 -445
- tinygrad/frontend/__init__.py +0 -0
- tinygrad/frontend/onnx.py +1253 -0
- tinygrad/frontend/torch.py +5 -0
- tinygrad/gradient.py +19 -27
- tinygrad/helpers.py +95 -47
- tinygrad/nn/__init__.py +7 -8
- tinygrad/nn/optim.py +72 -41
- tinygrad/nn/state.py +37 -23
- tinygrad/renderer/__init__.py +40 -60
- tinygrad/renderer/cstyle.py +143 -128
- tinygrad/renderer/llvmir.py +113 -62
- tinygrad/renderer/ptx.py +50 -32
- tinygrad/renderer/wgsl.py +27 -23
- tinygrad/runtime/autogen/am/am.py +5861 -0
- tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
- tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
- tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
- tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
- tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
- tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
- tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
- tinygrad/runtime/autogen/comgr.py +35 -9
- tinygrad/runtime/autogen/comgr_3.py +906 -0
- tinygrad/runtime/autogen/cuda.py +2419 -494
- tinygrad/runtime/autogen/hsa.py +57 -16
- tinygrad/runtime/autogen/ib.py +7171 -0
- tinygrad/runtime/autogen/io_uring.py +917 -118
- tinygrad/runtime/autogen/kfd.py +748 -26
- tinygrad/runtime/autogen/libc.py +613 -218
- tinygrad/runtime/autogen/libusb.py +1643 -0
- tinygrad/runtime/autogen/nv/nv.py +8602 -0
- tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
- tinygrad/runtime/autogen/opencl.py +2 -4
- tinygrad/runtime/autogen/sqtt.py +1789 -0
- tinygrad/runtime/autogen/vfio.py +3 -3
- tinygrad/runtime/autogen/webgpu.py +273 -264
- tinygrad/runtime/graph/cuda.py +3 -3
- tinygrad/runtime/graph/hcq.py +68 -29
- tinygrad/runtime/graph/metal.py +29 -13
- tinygrad/runtime/graph/remote.py +114 -0
- tinygrad/runtime/ops_amd.py +537 -320
- tinygrad/runtime/ops_cpu.py +108 -7
- tinygrad/runtime/ops_cuda.py +12 -14
- tinygrad/runtime/ops_disk.py +13 -10
- tinygrad/runtime/ops_dsp.py +47 -40
- tinygrad/runtime/ops_gpu.py +13 -11
- tinygrad/runtime/ops_hip.py +6 -9
- tinygrad/runtime/ops_llvm.py +35 -15
- tinygrad/runtime/ops_metal.py +29 -19
- tinygrad/runtime/ops_npy.py +5 -3
- tinygrad/runtime/ops_null.py +28 -0
- tinygrad/runtime/ops_nv.py +306 -234
- tinygrad/runtime/ops_python.py +62 -52
- tinygrad/runtime/ops_qcom.py +28 -39
- tinygrad/runtime/ops_remote.py +482 -0
- tinygrad/runtime/ops_webgpu.py +28 -28
- tinygrad/runtime/support/am/amdev.py +114 -249
- tinygrad/runtime/support/am/ip.py +211 -172
- tinygrad/runtime/support/amd.py +138 -0
- tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
- tinygrad/runtime/support/compiler_cuda.py +8 -11
- tinygrad/runtime/support/elf.py +2 -1
- tinygrad/runtime/support/hcq.py +184 -97
- tinygrad/runtime/support/ib.py +172 -0
- tinygrad/runtime/support/llvm.py +3 -4
- tinygrad/runtime/support/memory.py +251 -0
- tinygrad/runtime/support/nv/__init__.py +0 -0
- tinygrad/runtime/support/nv/ip.py +581 -0
- tinygrad/runtime/support/nv/nvdev.py +183 -0
- tinygrad/runtime/support/system.py +170 -0
- tinygrad/runtime/support/usb.py +268 -0
- tinygrad/runtime/support/webgpu.py +18 -0
- tinygrad/schedule/__init__.py +0 -0
- tinygrad/schedule/grouper.py +119 -0
- tinygrad/schedule/kernelize.py +368 -0
- tinygrad/schedule/multi.py +231 -0
- tinygrad/shape/shapetracker.py +40 -46
- tinygrad/shape/view.py +88 -52
- tinygrad/tensor.py +968 -542
- tinygrad/uop/__init__.py +117 -0
- tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
- tinygrad/uop/mathtraits.py +169 -0
- tinygrad/uop/ops.py +1021 -0
- tinygrad/uop/spec.py +228 -0
- tinygrad/{codegen → uop}/symbolic.py +239 -216
- tinygrad/uop/upat.py +163 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
- tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
- tinygrad/viz/index.html +203 -403
- tinygrad/viz/js/index.js +718 -0
- tinygrad/viz/js/worker.js +29 -0
- tinygrad/viz/serve.py +224 -102
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
- tinygrad-0.11.0.dist-info/RECORD +141 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/kernel.py +0 -693
- tinygrad/engine/multi.py +0 -161
- tinygrad/ops.py +0 -1003
- tinygrad/runtime/ops_cloud.py +0 -220
- tinygrad/runtime/support/allocator.py +0 -94
- tinygrad/spec.py +0 -155
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
- tinygrad/viz/perfetto.html +0 -178
- tinygrad-0.10.2.dist-info/RECORD +0 -99
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,29 @@
|
|
1
|
+
// Layout worker: sizes every node from its label text, builds a compound dagre graph,
// runs the layout, posts the serialized result back and then closes itself.
const NODE_PADDING = 10;
const LINE_HEIGHT = 14;
// the canvas is never drawn to, it only provides a 2d context for measuring text
const canvas = new OffscreenCanvas(0, 0);
const ctx = canvas.getContext("2d");
ctx.font = `${LINE_HEIGHT}px sans-serif`;

onmessage = (e) => {
  const { graph, additions, ctxs } = e.data;
  const g = new dagre.graphlib.Graph({ compound: true });
  g.setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; });
  // nodes listed in `additions` are grouped under one shared parent node
  if (additions.length !== 0) g.setNode("addition", {label:"", style:"fill: rgba(26, 27, 38, 0.5);", padding:0});
  for (let [k, {label, src, ref, ...rest }] of Object.entries(graph)) {
    // adjust node dims by label size (excluding escape codes) + add padding
    let [width, height] = [0, 0];
    // NOTE: `line` must be declared here, an undeclared loop variable leaks into the worker's global scope
    for (const line of label.replace(/\u001B\[(?:K|.*?m)/g, "").split("\n")) {
      width = Math.max(width, ctx.measureText(line).width);
      height += LINE_HEIGHT;
    }
    g.setNode(k, {width:width+NODE_PADDING*2, height:height+NODE_PADDING*2, padding:NODE_PADDING, label, ref, ...rest});
    // add edges, a source that appears more than once gets its multiplicity as the edge label
    const edgeCounts = {}
    for (const s of src) edgeCounts[s] = (edgeCounts[s] || 0)+1;
    for (const s of src) g.setEdge(s, k, { label: edgeCounts[s] > 1 ? edgeCounts[s] : null });
    if (additions.includes(parseInt(k))) g.setParent(k, "addition");
  }
  dagre.layout(g);
  postMessage(dagre.graphlib.json.write(g));
  // one layout per worker: shut down after replying
  self.close();
}
|
tinygrad/viz/serve.py
CHANGED
@@ -1,118 +1,240 @@
|
|
1
1
|
#!/usr/bin/env python3
|
2
|
-
import multiprocessing, pickle,
|
3
|
-
|
2
|
+
import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, socketserver, functools, codecs, io
|
3
|
+
import subprocess, ctypes
|
4
|
+
from contextlib import redirect_stdout
|
5
|
+
from decimal import Decimal
|
6
|
+
from http.server import BaseHTTPRequestHandler
|
4
7
|
from urllib.parse import parse_qs, urlparse
|
5
|
-
from typing import Any,
|
6
|
-
from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap
|
7
|
-
from tinygrad.ops import TrackedGraphRewrite, UOp, Ops,
|
8
|
-
from tinygrad.
|
9
|
-
from tinygrad.
|
8
|
+
from typing import Any, TypedDict, Generator
|
9
|
+
from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA, ProfileEvent, ProfileRangeEvent, TracingKey, ProfilePointEvent
|
10
|
+
from tinygrad.uop.ops import TrackedGraphRewrite, UOp, Ops, printable, GroupOp, srender, sint
|
11
|
+
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device
|
12
|
+
from tinygrad.renderer import ProgramSpec
|
10
13
|
from tinygrad.dtype import dtypes
|
11
14
|
|
12
|
-
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0",
|
13
|
-
Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.
|
14
|
-
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#
|
15
|
+
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
|
16
|
+
Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_REG: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B",
|
17
|
+
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
|
15
18
|
Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55",
|
16
|
-
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
|
17
|
-
Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.
|
19
|
+
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF",
|
20
|
+
Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500",
|
21
|
+
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
|
22
|
+
Ops.CHILDREN: "#80ffc0", Ops.CHILD: "#80fff0", Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e"}
|
18
23
|
|
19
24
|
# VIZ API
|
20
25
|
|
21
|
-
# NOTE: if any extra rendering in VIZ fails, we don't crash
|
22
|
-
def pcall(fxn:Callable[..., str], *args, **kwargs) -> str:
|
23
|
-
try: return fxn(*args, **kwargs)
|
24
|
-
except Exception as e: return f"ERROR: {e}"
|
25
|
-
|
26
26
|
# ** Metadata for a track_rewrites scope
|
27
27
|
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
28
|
+
ref_map:dict[Any, int] = {}
|
29
|
+
def get_metadata(keys:list[TracingKey], contexts:list[list[TrackedGraphRewrite]]) -> list[dict]:
  """Build one metadata dict per (key, tracked rewrites) pair.

  Each dict has the key's display name and a list of step descriptors. Every entry of `k.keys` is also recorded in
  the module-level `ref_map`, pointing at the index of its context.
  """
  ret = []
  for i,(k,v) in enumerate(zip(keys, contexts)):
    steps = []
    for j,s in enumerate(v):
      steps.append({"name":s.name, "loc":s.loc, "depth":s.depth, "match_count":len(s.matches), "code_line":printable(s.loc),
                    "query":f"/ctxs?ctx={i}&idx={j}"})
    entry = {"name":k.display_name, "steps":steps}
    ret.append(entry)
    # use the first key to get runtime profiling data about this context
    if getenv("PROFILE_VALUE") >= 2 and k.keys:
      entry["runtime_stats"] = get_runtime_stats(k.keys[0])
    # program spec metadata: extra disassembly step + the program source
    if isinstance(k.ret, ProgramSpec):
      steps.append({"name":"View Disassembly", "query":f"/disasm?ctx={i}"})
      entry["fmt"] = k.ret.src
    for key in k.keys:
      ref_map[key] = i
  return ret
|
42
43
|
|
43
44
|
# ** Complete rewrite details for a graph_rewrite call
|
44
45
|
|
45
46
|
class GraphRewriteDetails(TypedDict):
|
46
47
|
graph: dict # JSON serialized UOp for this rewrite step
|
47
48
|
uop: str # stringified UOp for this rewrite step
|
48
|
-
diff: list[str]|None #
|
49
|
+
diff: list[str]|None # diff of the single UOp that changed
|
49
50
|
changed_nodes: list[int]|None # the changed UOp id + all its parents ids
|
50
51
|
upat: tuple[tuple[str, int], str]|None # [loc, source_code] of the matched UPat
|
51
52
|
|
52
|
-
def
|
53
|
+
def shape_to_str(s:tuple[sint, ...]):
  """Format `s` as a parenthesized, comma separated string, rendering every element with srender."""
  rendered = [srender(dim) for dim in s]
  return "(" + ','.join(rendered) + ")"
|
54
|
+
def mask_to_str(s:tuple[tuple[sint, sint], ...]):
  """Format `s` as a parenthesized, comma separated string, rendering every pair with shape_to_str."""
  rendered = [shape_to_str(pair) for pair in s]
  return "(" + ','.join(rendered) + ")"
|
55
|
+
|
56
|
+
def uop_to_json(x:UOp) -> dict[int, dict]:
|
53
57
|
assert isinstance(x, UOp)
|
54
|
-
|
55
|
-
graph: dict[int, tuple[str, list[int], str]] = {}
|
58
|
+
graph: dict[int, dict] = {}
|
56
59
|
excluded: set[UOp] = set()
|
57
|
-
for u in (toposort:=x.toposort):
|
60
|
+
for u in (toposort:=x.toposort()):
|
58
61
|
# always exclude DEVICE/CONST/UNIQUE
|
59
|
-
if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE}: excluded.add(u)
|
62
|
+
if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE} and u is not x: excluded.add(u)
|
60
63
|
# only exclude CONST VIEW source if it has no other children in the graph
|
61
64
|
if u.op is Ops.CONST and len(u.src) != 0 and all(cr.op is Ops.CONST for c in u.src[0].children if (cr:=c()) is not None and cr in toposort):
|
62
65
|
excluded.update(u.src)
|
63
66
|
for u in toposort:
|
64
67
|
if u in excluded: continue
|
65
|
-
argst = str(u.arg)
|
68
|
+
argst = codecs.decode(str(u.arg), "unicode_escape")
|
66
69
|
if u.op is Ops.VIEW:
|
67
|
-
argst = ("\n".join([f"{v.shape} / {v.strides}"+(
|
68
|
-
("" if v.
|
70
|
+
argst = ("\n".join([f"{shape_to_str(v.shape)} / {shape_to_str(v.strides)}"+("" if v.offset == 0 else f" / {srender(v.offset)}")+
|
71
|
+
(f"\nMASK {mask_to_str(v.mask)}" if v.mask is not None else "") for v in unwrap(u.st).views]))
|
69
72
|
label = f"{str(u.op).split('.')[1]}{(chr(10)+word_wrap(argst.replace(':', ''))) if u.arg is not None else ''}"
|
70
73
|
if u.dtype != dtypes.void: label += f"\n{u.dtype}"
|
71
74
|
for idx,x in enumerate(u.src):
|
72
75
|
if x in excluded:
|
73
|
-
if x.op is Ops.CONST and dtypes.is_float(u.dtype)
|
74
|
-
|
75
|
-
|
76
|
+
arg = f"{x.arg:g}" if x.op is Ops.CONST and dtypes.is_float(u.dtype) else f"{x.arg}"
|
77
|
+
label += f"\n{x.op.name}{idx} {arg}" + (f" {x.src[0].op}" if len(x.src) else "")
|
78
|
+
try:
|
79
|
+
if u.op not in {Ops.VIEW, Ops.BUFFER, Ops.KERNEL, Ops.ASSIGN, Ops.COPY, Ops.SINK, *GroupOp.Buffer} and u.st is not None:
|
80
|
+
label += f"\n{shape_to_str(u.shape)}"
|
81
|
+
elif len(rngs:=u.ranges):
|
82
|
+
label += f"\n{str(sorted([x.arg for x in rngs]))}"
|
83
|
+
except Exception:
|
84
|
+
label += "\n<ISSUE GETTING LABEL>"
|
85
|
+
if (ref:=ref_map.get(u.arg.ast) if u.op is Ops.KERNEL else None) is not None: label += f"\ncodegen@{ctxs[ref]['name']}"
|
86
|
+
# NOTE: kernel already has metadata in arg
|
87
|
+
if TRACEMETA >= 2 and u.metadata is not None and u.op is not Ops.KERNEL: label += "\n"+repr(u.metadata)
|
88
|
+
graph[id(u)] = {"label":label, "src":[id(x) for x in u.src if x not in excluded], "color":uops_colors.get(u.op, "#ffffff"),
|
89
|
+
"ref":ref, "tag":u.tag}
|
76
90
|
return graph
|
77
91
|
|
78
|
-
|
79
|
-
|
92
|
+
@functools.cache
def _reconstruct(a:int):
  """Rebuild the UOp stored at index `a` of `contexts[2]`, recursively rebuilding its sources.

  Results are memoized per index by functools.cache. For a KERNEL op the arg is rebuilt with the same type from its
  reconstructed `ast` and its original `metadata`.
  """
  op, dtype, src, arg, tag = contexts[2][a]
  if op is Ops.KERNEL:
    arg = type(arg)(_reconstruct(arg.ast), arg.metadata)
  rebuilt_src = tuple(_reconstruct(s) for s in src)
  return UOp(op, dtype, rebuilt_src, arg, tag)
|
97
|
+
|
98
|
+
def get_details(ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]:
|
99
|
+
yield {"graph":uop_to_json(next_sink:=_reconstruct(ctx.sink)), "uop":str(next_sink), "changed_nodes":None, "diff":None, "upat":None}
|
80
100
|
replaces: dict[UOp, UOp] = {}
|
81
|
-
for
|
82
|
-
replaces[u0] = u1
|
83
|
-
new_sink = next_sink.substitute(replaces)
|
84
|
-
|
85
|
-
|
101
|
+
for u0_num,u1_num,upat_loc in tqdm(ctx.matches):
|
102
|
+
replaces[u0:=_reconstruct(u0_num)] = u1 = _reconstruct(u1_num)
|
103
|
+
try: new_sink = next_sink.substitute(replaces)
|
104
|
+
except RuntimeError as e: new_sink = UOp(Ops.NOOP, arg=str(e))
|
105
|
+
yield {"graph":(sink_json:=uop_to_json(new_sink)), "uop":str(new_sink), "changed_nodes":[id(x) for x in u1.toposort() if id(x) in sink_json],
|
106
|
+
"diff":list(difflib.unified_diff(str(u0).splitlines(), str(u1).splitlines())), "upat":(upat_loc, printable(upat_loc))}
|
86
107
|
if not ctx.bottom_up: next_sink = new_sink
|
87
108
|
|
88
109
|
# Profiler API
|
89
|
-
|
90
|
-
|
91
|
-
def
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
110
|
+
|
111
|
+
# per device timestamp offsets, indexed as (compute, copy)
device_ts_diffs:dict[str, tuple[Decimal, Decimal]] = {}
def cpu_ts_diff(device:str, thread=0) -> Decimal:
  """Return the timestamp offset recorded for `device`.

  `thread` selects the entry of the offset pair (0 or False for the first, 1 or True for the second). A device with
  no recorded offsets gets a zero offset for either index, the fallback needs two entries so index 1 does not raise.
  """
  return device_ts_diffs.get(device, (Decimal(0), Decimal(0)))[thread]
|
113
|
+
|
114
|
+
DevEvent = ProfileRangeEvent|ProfileGraphEntry|ProfilePointEvent
|
115
|
+
def flatten_events(profile:list[ProfileEvent]) -> Generator[tuple[Decimal, Decimal, DevEvent], None, None]:
  """Yield a (start, end, event) tuple for every displayable event in `profile`.

  Range events and graph entries are shifted by cpu_ts_diff of their device, point events are yielded unshifted
  with start == end. A graph event first yields one synthetic range spanning all of its entries, then each entry.
  """
  for e in profile:
    if isinstance(e, ProfileRangeEvent):
      diff = cpu_ts_diff(e.device, e.is_copy)
      # an event without an end timestamp is treated as zero length
      end = e.en if e.en is not None else e.st
      yield (e.st+diff, end+diff, e)
    elif isinstance(e, ProfilePointEvent):
      yield (e.ts, e.ts, e)
    elif isinstance(e, ProfileGraphEvent):
      # cpu_ts holds [start0, end0, start1, end1, ...] in entry order
      cpu_ts = []
      for ent in e.ents:
        diff = cpu_ts_diff(ent.device, ent.is_copy)
        cpu_ts.extend((e.sigs[ent.st_id]+diff, e.sigs[ent.en_id]+diff))
      st, et = min(cpu_ts), max(cpu_ts)
      yield (st, et, ProfileRangeEvent(f"{e.ents[0].device.split(':')[0]} Graph", f"batched {len(e.ents)}", st, et))
      for i,ent in enumerate(e.ents):
        yield (cpu_ts[i*2], cpu_ts[i*2+1], ent)
|
124
|
+
|
125
|
+
# timeline layout stacks events in a contiguous block. When a late starter finishes late, there is whitespace in the higher levels.
|
126
|
+
def timeline_layout(events:list[tuple[int, int, float, DevEvent]]) -> dict:
  """Assign every non zero duration event to a depth level and describe it as a shape dict.

  An event goes to the first level whose last end time is <= its start, or to a new level if none is free.
  Returns {"shapes": [...], "maxDepth": number of levels used}.
  """
  shapes:list[dict] = []
  levels:list[int] = []
  for st,et,dur,e in events:
    if dur == 0: continue
    # find a free level to put the event
    depth = len(levels)
    for i,level_et in enumerate(levels):
      if st >= level_et:
        depth = i
        break
    if depth == len(levels): levels.append(et)
    else: levels[depth] = et
    name, cat, info = e.name, None, None
    ref = ref_map.get(name)
    if ref is not None:
      # the event name is a known key: show the name of the context it belongs to
      name = ctxs[ref]["name"]
      # TODO: support symbolic by capturing var_vals in profile events
      if isinstance(p:=contexts[0][ref].ret, ProgramSpec) and all(isinstance(es,int) for es in [p.estimates.ops, p.estimates.mem, p.estimates.lds]):
        info = f"{p.estimates.ops/(t:=dur*1e3):.2f} GFLOPS {p.estimates.mem/t:4.1f}|{p.estimates.lds/t:.1f} GB/s"
    elif isinstance(e.name, TracingKey):
      name, cat = e.name.display_name, e.name.cat
      # link to the first of the key's sub keys that has a known context
      ref = next((v for k in e.name.keys if (v:=ref_map.get(k)) is not None), None)
    shapes.append({"name":name, "ref":ref, "st":st, "dur":dur, "depth":depth, "cat":cat, "info":info})
  return {"shapes":shapes, "maxDepth":len(levels)}
|
146
|
+
|
147
|
+
def mem_layout(events:list[tuple[int, int, float, DevEvent]], max_ts:int) -> dict:
  """Turn the "alloc" / "free" point events into one stacked shape per allocation.

  Every alloc or free advances `step` by one and records the event timestamp. A shape stores the x (step) and
  y (bytes in use below it) coordinates of one buffer. Returns {"shapes", "peak", "timestamps"}, where
  `timestamps` ends with `max_ts`.
  """
  step, peak, mem = 0, 0, 0
  shps:dict[int, dict] = {}
  # allocations that have not been freed yet
  temp:dict[int, dict] = {}
  timestamps:list[int] = []
  for _,_,_,e in events:
    if not isinstance(e, ProfilePointEvent): continue
    if e.name == "alloc":
      shape = {"x":[step], "y":[mem], "arg":e.arg}
      shps[e.key] = shape
      temp[e.key] = shape
      timestamps.append(int(e.ts))
      step += 1
      mem += e.arg["nbytes"]
      peak = mem if mem > peak else peak
    elif e.name == "free":
      timestamps.append(int(e.ts))
      step += 1
      removed = temp.pop(e.key)
      nbytes = removed["arg"]["nbytes"]
      mem -= nbytes
      removed["x"].append(step)
      removed["y"].append(removed["y"][-1])
      # live buffers with a larger key drop down by the freed size
      for k,v in temp.items():
        if k <= e.key: continue
        last = v["y"][-1]
        v["x"].extend((step, step))
        v["y"].extend((last, last-nbytes))
  # buffers still alive at the end extend to the last step
  for v in temp.values():
    v["x"].append(step)
    v["y"].append(v["y"][-1])
  timestamps.append(max_ts)
  return {"shapes":list(shps.values()), "peak":peak, "timestamps":timestamps}
|
175
|
+
|
176
|
+
def get_profile(profile:list[ProfileEvent]) -> bytes|None:
  """Serialize the profile into the JSON payload served to the frontend.

  Returns UTF-8 encoded JSON with the per device timeline and memory layouts plus the global start ("st") and end
  ("et") timestamps, or None if the profile has no events to show.
  """
  # start by getting the time diffs
  for ev in profile:
    if not isinstance(ev, ProfileDeviceEvent): continue
    copy_tdiff = ev.copy_tdiff if ev.copy_tdiff is not None else ev.comp_tdiff
    device_ts_diffs[ev.device] = (ev.comp_tdiff, copy_tdiff)
  # map events per device
  dev_events:dict[str, list[tuple[int, int, float, DevEvent]]] = {}
  min_ts:int|None = None
  max_ts:int|None = None
  for ts,en,e in flatten_events(profile):
    st, et = int(ts), int(en)
    dev_events.setdefault(e.device, []).append((st, et, float(en-ts), e))
    min_ts = st if min_ts is None else min(min_ts, st)
    max_ts = et if max_ts is None else max(max_ts, et)
  if min_ts is None: return None
  # return layout of per device events
  layout:dict[str, dict] = {}
  for dev,evs in dev_events.items():
    evs.sort(key=lambda item:item[0])
    layout[dev] = timeline_layout(evs)
    layout[f"{dev} Memory"] = mem_layout(evs, unwrap(max_ts))
  payload = {"layout":layout, "st":min_ts, "et":max_ts}
  return json.dumps(payload).encode("utf-8")
|
196
|
+
|
197
|
+
def get_runtime_stats(key) -> list[dict]:
  """Collect a duration entry for every finished range event in the global `profile` whose name equals `key`."""
  return [{"device":e.device, "data":[{"name":"Duration", "value":float(e.en-e.st), "unit":"us"}]}
          for e in profile if isinstance(e, ProfileRangeEvent) and e.en is not None and e.name == key]
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
203
|
+
|
204
|
+
# ** Assembly analyzers
|
205
|
+
|
206
|
+
def get_llvm_mca(asm:str, mtriple:str, mcpu:str) -> dict:
  """Run llvm-mca on `asm` and reshape its JSON output into a table.

  Returns {"rows", "cols", "summary"}. Each row starts with the instruction text, followed by its latency and,
  for instructions with resource pressure info, a list of [resource index, usage, percent of max usage].
  Raises whatever subprocess.check_output raises if llvm-mca is missing or fails.
  """
  # disassembly output can include headers / metadata, skip if llvm-mca can't parse those lines
  cmd = ["llvm-mca", "-skip-unsupported-instructions=parse-failure", "--json", "-", f"-mtriple={mtriple}", f"-mcpu={mcpu}"]
  data = json.loads(subprocess.check_output(cmd, input=asm.encode()))
  region = data["CodeRegions"][0]
  resource_labels = data["TargetInfo"]["Resources"]
  rows:list = [[instr] for instr in region["Instructions"]]
  # add scheduler estimates
  for info in region["InstructionInfoView"]["InstructionList"]:
    rows[info["Instruction"]].append(info["Latency"])
  # map per instruction resource usage
  instr_usage:dict[int, dict[int, int]] = {}
  for d in region["ResourcePressureView"]["ResourcePressureInfo"]:
    per_instr = instr_usage.setdefault(d["InstructionIndex"], {})
    res = d["ResourceIndex"]
    per_instr[res] = per_instr.get(res, 0) + d["ResourceUsage"]
  # last row is the usage summary
  summary_usage = instr_usage.pop(len(rows), {})
  summary = [{"idx":k, "label":resource_labels[k], "value":v} for k,v in summary_usage.items()]
  max_usage = max([sum(v.values()) for i,v in instr_usage.items() if i<len(rows)], default=0)
  for i,usage in instr_usage.items():
    rows[i].append([[k, v, (v/max_usage)*100] for k,v in usage.items()])
  return {"rows":rows, "cols":["Opcode", "Latency", {"title":"HW Resources", "labels":resource_labels}], "summary":summary}
|
225
|
+
|
226
|
+
def get_disassembly(ctx:list[str]):
  """Compile and disassemble the program of context `ctx[0]`, returning JSON encoded bytes.

  Returns None when the context's `ret` is not a ProgramSpec. With an LLVMCompiler the disassembly is passed to
  get_llvm_mca, any other compiler returns the captured disassembly text under "src".
  """
  prg = contexts[0][int(ctx[0])].ret
  if not isinstance(prg, ProgramSpec): return None
  compiler = Device[prg.device].compiler
  lib = compiler.compile(prg.src)
  # the disassembly is captured from stdout
  buf = io.StringIO()
  with redirect_stdout(buf):
    compiler.disassemble(lib)
  disasm_str = buf.getvalue()
  from tinygrad.runtime.ops_llvm import llvm, LLVMCompiler
  if not isinstance(compiler, LLVMCompiler):
    ret = {"src":disasm_str}
  else:
    tm = compiler.target_machine
    mtriple = ctypes.string_at(llvm.LLVMGetTargetMachineTriple(tm)).decode()
    mcpu = ctypes.string_at(llvm.LLVMGetTargetMachineCPU(tm)).decode()
    ret = get_llvm_mca(disasm_str, mtriple, mcpu)
  return json.dumps(ret).encode()
|
116
238
|
|
117
239
|
# ** HTTP server
|
118
240
|
|
@@ -122,33 +244,17 @@ class Handler(BaseHTTPRequestHandler):
|
|
122
244
|
|
123
245
|
if (url:=urlparse(self.path)).path == "/":
|
124
246
|
with open(os.path.join(os.path.dirname(__file__), "index.html"), "rb") as f: ret = f.read()
|
125
|
-
elif
|
126
|
-
with open(os.path.join(os.path.dirname(__file__), "perfetto.html"), "rb") as f: ret = f.read()
|
127
|
-
elif self.path.startswith("/assets/") and '/..' not in self.path:
|
247
|
+
elif self.path.startswith(("/assets/", "/js/")) and '/..' not in self.path:
|
128
248
|
try:
|
129
249
|
with open(os.path.join(os.path.dirname(__file__), self.path.strip('/')), "rb") as f: ret = f.read()
|
130
250
|
if url.path.endswith(".js"): content_type = "application/javascript"
|
131
251
|
if url.path.endswith(".css"): content_type = "text/css"
|
132
252
|
except FileNotFoundError: status_code = 404
|
133
|
-
elif url.
|
134
|
-
if "
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
# stream details
|
139
|
-
self.send_response(200)
|
140
|
-
self.send_header("Content-Type", "text/event-stream")
|
141
|
-
self.send_header("Cache-Control", "no-cache")
|
142
|
-
self.end_headers()
|
143
|
-
for r in get_details(contexts[0][kidx], contexts[1][kidx][ridx]):
|
144
|
-
self.wfile.write(f"data: {json.dumps(r)}\n\n".encode("utf-8"))
|
145
|
-
self.wfile.flush()
|
146
|
-
self.wfile.write("data: END\n\n".encode("utf-8"))
|
147
|
-
return self.wfile.flush()
|
148
|
-
# pass if client closed connection
|
149
|
-
except (BrokenPipeError, ConnectionResetError): return
|
150
|
-
ret, content_type = json.dumps(kernels).encode(), "application/json"
|
151
|
-
elif url.path == "/get_profile" and perfetto_profile is not None: ret, content_type = perfetto_profile, "application/json"
|
253
|
+
elif (query:=parse_qs(url.query)):
|
254
|
+
if url.path == "/disasm": ret, content_type = get_disassembly(**query), "application/json"
|
255
|
+
else: return self.stream_json(get_details(contexts[1][int(query["ctx"][0])][int(query["idx"][0])]))
|
256
|
+
elif url.path == "/ctxs": ret, content_type = json.dumps(ctxs).encode(), "application/json"
|
257
|
+
elif url.path == "/get_profile" and profile_ret is not None: ret, content_type = profile_ret, "application/json"
|
152
258
|
else: status_code = 404
|
153
259
|
|
154
260
|
# send response
|
@@ -158,6 +264,19 @@ class Handler(BaseHTTPRequestHandler):
|
|
158
264
|
self.end_headers()
|
159
265
|
return self.wfile.write(ret)
|
160
266
|
|
267
|
+
def stream_json(self, source:Generator):
  """Send every item of `source` to the client as a JSON "data:" event, followed by a final "data: END" event.

  Each event is flushed as soon as it is written. A connection closed by the client ends the stream silently.
  """
  try:
    self.send_response(200)
    for header, value in (("Content-Type", "text/event-stream"), ("Cache-Control", "no-cache")):
      self.send_header(header, value)
    self.end_headers()
    for item in source:
      chunk = f"data: {json.dumps(item)}\n\n".encode("utf-8")
      self.wfile.write(chunk)
      self.wfile.flush()
    self.wfile.write("data: END\n\n".encode("utf-8"))
  except (BrokenPipeError, ConnectionResetError):
    # pass if client closed connection
    return
|
279
|
+
|
161
280
|
# ** main loop
|
162
281
|
|
163
282
|
def reloader():
|
@@ -172,6 +291,9 @@ def load_pickle(path:str):
|
|
172
291
|
if path is None or not os.path.exists(path): return None
|
173
292
|
with open(path, "rb") as f: return pickle.load(f)
|
174
293
|
|
294
|
+
# NOTE: using HTTPServer forces a potentially slow socket.getfqdn
|
295
|
+
class TCPServerWithReuse(socketserver.TCPServer):
  """socketserver.TCPServer with `allow_reuse_address` switched on."""
  allow_reuse_address = True
|
296
|
+
|
175
297
|
if __name__ == "__main__":
|
176
298
|
parser = argparse.ArgumentParser()
|
177
299
|
parser.add_argument('--kernels', type=str, help='Path to kernels', default=None)
|
@@ -189,15 +311,15 @@ if __name__ == "__main__":
|
|
189
311
|
contexts, profile = load_pickle(args.kernels), load_pickle(args.profile)
|
190
312
|
|
191
313
|
# NOTE: this context is a tuple of list[keys] and list[values]
|
192
|
-
|
314
|
+
ctxs = get_metadata(*contexts[:2]) if contexts is not None else []
|
193
315
|
|
194
|
-
|
316
|
+
profile_ret = get_profile(profile) if profile is not None else None
|
195
317
|
|
196
|
-
server =
|
318
|
+
server = TCPServerWithReuse(('', PORT), Handler)
|
197
319
|
reloader_thread = threading.Thread(target=reloader)
|
198
320
|
reloader_thread.start()
|
199
321
|
print(f"*** started viz on {HOST}:{PORT}")
|
200
|
-
print(colored(f"*** ready in {(time.perf_counter()-st)*1e3:4.2f}ms", "green"))
|
322
|
+
print(colored(f"*** ready in {(time.perf_counter()-st)*1e3:4.2f}ms", "green"), flush=True)
|
201
323
|
if len(getenv("BROWSER", "")) > 0: webbrowser.open(f"{HOST}:{PORT}{'/profiler' if contexts is None else ''}")
|
202
324
|
try: server.serve_forever()
|
203
325
|
except KeyboardInterrupt:
|
@@ -1,6 +1,6 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.4
|
2
2
|
Name: tinygrad
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.11.0
|
4
4
|
Summary: You like pytorch? You like micrograd? You love tinygrad! <3
|
5
5
|
Author: George Hotz
|
6
6
|
License: MIT
|
@@ -19,31 +19,38 @@ Requires-Dist: mypy==1.13.0; extra == "linting"
|
|
19
19
|
Requires-Dist: typing-extensions; extra == "linting"
|
20
20
|
Requires-Dist: pre-commit; extra == "linting"
|
21
21
|
Requires-Dist: ruff; extra == "linting"
|
22
|
-
Requires-Dist:
|
22
|
+
Requires-Dist: numpy; extra == "linting"
|
23
23
|
Provides-Extra: testing-minimal
|
24
24
|
Requires-Dist: numpy; extra == "testing-minimal"
|
25
|
-
Requires-Dist: torch; extra == "testing-minimal"
|
25
|
+
Requires-Dist: torch==2.7.1; extra == "testing-minimal"
|
26
26
|
Requires-Dist: pytest; extra == "testing-minimal"
|
27
27
|
Requires-Dist: pytest-xdist; extra == "testing-minimal"
|
28
28
|
Requires-Dist: hypothesis; extra == "testing-minimal"
|
29
|
+
Requires-Dist: z3-solver; extra == "testing-minimal"
|
30
|
+
Requires-Dist: ml_dtypes; extra == "testing-minimal"
|
29
31
|
Provides-Extra: testing-unit
|
30
32
|
Requires-Dist: numpy; extra == "testing-unit"
|
31
|
-
Requires-Dist: torch; extra == "testing-unit"
|
33
|
+
Requires-Dist: torch==2.7.1; extra == "testing-unit"
|
32
34
|
Requires-Dist: pytest; extra == "testing-unit"
|
33
35
|
Requires-Dist: pytest-xdist; extra == "testing-unit"
|
34
36
|
Requires-Dist: hypothesis; extra == "testing-unit"
|
37
|
+
Requires-Dist: z3-solver; extra == "testing-unit"
|
38
|
+
Requires-Dist: ml_dtypes; extra == "testing-unit"
|
35
39
|
Requires-Dist: tqdm; extra == "testing-unit"
|
36
40
|
Requires-Dist: safetensors; extra == "testing-unit"
|
37
41
|
Requires-Dist: tabulate; extra == "testing-unit"
|
38
42
|
Provides-Extra: testing
|
39
43
|
Requires-Dist: numpy; extra == "testing"
|
40
|
-
Requires-Dist: torch; extra == "testing"
|
44
|
+
Requires-Dist: torch==2.7.1; extra == "testing"
|
41
45
|
Requires-Dist: pytest; extra == "testing"
|
42
46
|
Requires-Dist: pytest-xdist; extra == "testing"
|
43
47
|
Requires-Dist: hypothesis; extra == "testing"
|
48
|
+
Requires-Dist: z3-solver; extra == "testing"
|
49
|
+
Requires-Dist: ml_dtypes; extra == "testing"
|
44
50
|
Requires-Dist: pillow; extra == "testing"
|
45
|
-
Requires-Dist: onnx==1.
|
51
|
+
Requires-Dist: onnx==1.18.0; extra == "testing"
|
46
52
|
Requires-Dist: onnx2torch; extra == "testing"
|
53
|
+
Requires-Dist: onnxruntime; extra == "testing"
|
47
54
|
Requires-Dist: opencv-python; extra == "testing"
|
48
55
|
Requires-Dist: tabulate; extra == "testing"
|
49
56
|
Requires-Dist: tqdm; extra == "testing"
|
@@ -58,6 +65,10 @@ Requires-Dist: nibabel; extra == "testing"
|
|
58
65
|
Requires-Dist: bottle; extra == "testing"
|
59
66
|
Requires-Dist: ggml-python; extra == "testing"
|
60
67
|
Requires-Dist: capstone; extra == "testing"
|
68
|
+
Requires-Dist: pycocotools; extra == "testing"
|
69
|
+
Requires-Dist: boto3; extra == "testing"
|
70
|
+
Requires-Dist: pandas; extra == "testing"
|
71
|
+
Requires-Dist: influxdb3-python; extra == "testing"
|
61
72
|
Provides-Extra: docs
|
62
73
|
Requires-Dist: mkdocs; extra == "docs"
|
63
74
|
Requires-Dist: mkdocs-material; extra == "docs"
|
@@ -66,14 +77,12 @@ Requires-Dist: markdown-callouts; extra == "docs"
|
|
66
77
|
Requires-Dist: markdown-exec[ansi]; extra == "docs"
|
67
78
|
Requires-Dist: black; extra == "docs"
|
68
79
|
Requires-Dist: numpy; extra == "docs"
|
69
|
-
Provides-Extra: testing-tf
|
70
|
-
Requires-Dist: tensorflow==2.15.1; extra == "testing-tf"
|
71
|
-
Requires-Dist: tensorflow_addons; extra == "testing-tf"
|
72
80
|
Dynamic: author
|
73
81
|
Dynamic: classifier
|
74
82
|
Dynamic: description
|
75
83
|
Dynamic: description-content-type
|
76
84
|
Dynamic: license
|
85
|
+
Dynamic: license-file
|
77
86
|
Dynamic: provides-extra
|
78
87
|
Dynamic: requires-python
|
79
88
|
Dynamic: summary
|
@@ -101,11 +110,11 @@ tinygrad: For something between [PyTorch](https://github.com/pytorch/pytorch) an
|
|
101
110
|
|
102
111
|
---
|
103
112
|
|
104
|
-
|
113
|
+
Despite tinygrad's size, it is a fully featured deep learning framework.
|
105
114
|
|
106
|
-
Due to its extreme simplicity, it
|
115
|
+
Due to its extreme simplicity, it is the easiest framework to add new accelerators to, with support for both inference and training. If XLA is CISC, tinygrad is RISC.
|
107
116
|
|
108
|
-
tinygrad is
|
117
|
+
tinygrad is now beta software, we [raised some money](https://geohot.github.io/blog/jekyll/update/2023/05/24/the-tiny-corp-raised-5M.html) to make it good. Someday, we will tape out chips.
|
109
118
|
|
110
119
|
## Features
|
111
120
|
|
@@ -119,9 +128,8 @@ Try a matmul. See how, despite the style, it is fused into one kernel with the p
|
|
119
128
|
|
120
129
|
```sh
|
121
130
|
DEBUG=3 python3 -c "from tinygrad import Tensor;
|
122
|
-
N = 1024; a, b = Tensor.
|
123
|
-
|
124
|
-
print((c.numpy() - (a.numpy() @ b.numpy())).mean())"
|
131
|
+
N = 1024; a, b = Tensor.empty(N, N), Tensor.empty(N, N);
|
132
|
+
(a.reshape(N, 1, N) * b.T.reshape(1, N, N)).sum(axis=2).realize()"
|
125
133
|
```
|
126
134
|
|
127
135
|
And we can change `DEBUG` to `4` to see the generated code.
|