tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131) hide show
  1. tinygrad/__init__.py +1 -1
  2. tinygrad/apps/llm.py +206 -0
  3. tinygrad/codegen/__init__.py +116 -0
  4. tinygrad/codegen/devectorizer.py +315 -172
  5. tinygrad/codegen/expander.py +8 -16
  6. tinygrad/codegen/gpudims.py +89 -0
  7. tinygrad/codegen/linearize.py +205 -203
  8. tinygrad/codegen/lowerer.py +92 -139
  9. tinygrad/codegen/opt/__init__.py +38 -0
  10. tinygrad/codegen/opt/heuristic.py +125 -0
  11. tinygrad/codegen/opt/kernel.py +510 -0
  12. tinygrad/{engine → codegen/opt}/search.py +51 -35
  13. tinygrad/codegen/opt/swizzler.py +134 -0
  14. tinygrad/codegen/opt/tc.py +127 -0
  15. tinygrad/codegen/quantize.py +67 -0
  16. tinygrad/device.py +122 -132
  17. tinygrad/dtype.py +152 -35
  18. tinygrad/engine/jit.py +81 -54
  19. tinygrad/engine/memory.py +46 -27
  20. tinygrad/engine/realize.py +82 -41
  21. tinygrad/engine/schedule.py +70 -445
  22. tinygrad/frontend/__init__.py +0 -0
  23. tinygrad/frontend/onnx.py +1253 -0
  24. tinygrad/frontend/torch.py +5 -0
  25. tinygrad/gradient.py +19 -27
  26. tinygrad/helpers.py +95 -47
  27. tinygrad/nn/__init__.py +7 -8
  28. tinygrad/nn/optim.py +72 -41
  29. tinygrad/nn/state.py +37 -23
  30. tinygrad/renderer/__init__.py +40 -60
  31. tinygrad/renderer/cstyle.py +143 -128
  32. tinygrad/renderer/llvmir.py +113 -62
  33. tinygrad/renderer/ptx.py +50 -32
  34. tinygrad/renderer/wgsl.py +27 -23
  35. tinygrad/runtime/autogen/am/am.py +5861 -0
  36. tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
  37. tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
  38. tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
  39. tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
  40. tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
  41. tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
  42. tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
  43. tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
  44. tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
  45. tinygrad/runtime/autogen/comgr.py +35 -9
  46. tinygrad/runtime/autogen/comgr_3.py +906 -0
  47. tinygrad/runtime/autogen/cuda.py +2419 -494
  48. tinygrad/runtime/autogen/hsa.py +57 -16
  49. tinygrad/runtime/autogen/ib.py +7171 -0
  50. tinygrad/runtime/autogen/io_uring.py +917 -118
  51. tinygrad/runtime/autogen/kfd.py +748 -26
  52. tinygrad/runtime/autogen/libc.py +613 -218
  53. tinygrad/runtime/autogen/libusb.py +1643 -0
  54. tinygrad/runtime/autogen/nv/nv.py +8602 -0
  55. tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
  56. tinygrad/runtime/autogen/opencl.py +2 -4
  57. tinygrad/runtime/autogen/sqtt.py +1789 -0
  58. tinygrad/runtime/autogen/vfio.py +3 -3
  59. tinygrad/runtime/autogen/webgpu.py +273 -264
  60. tinygrad/runtime/graph/cuda.py +3 -3
  61. tinygrad/runtime/graph/hcq.py +68 -29
  62. tinygrad/runtime/graph/metal.py +29 -13
  63. tinygrad/runtime/graph/remote.py +114 -0
  64. tinygrad/runtime/ops_amd.py +537 -320
  65. tinygrad/runtime/ops_cpu.py +108 -7
  66. tinygrad/runtime/ops_cuda.py +12 -14
  67. tinygrad/runtime/ops_disk.py +13 -10
  68. tinygrad/runtime/ops_dsp.py +47 -40
  69. tinygrad/runtime/ops_gpu.py +13 -11
  70. tinygrad/runtime/ops_hip.py +6 -9
  71. tinygrad/runtime/ops_llvm.py +35 -15
  72. tinygrad/runtime/ops_metal.py +29 -19
  73. tinygrad/runtime/ops_npy.py +5 -3
  74. tinygrad/runtime/ops_null.py +28 -0
  75. tinygrad/runtime/ops_nv.py +306 -234
  76. tinygrad/runtime/ops_python.py +62 -52
  77. tinygrad/runtime/ops_qcom.py +28 -39
  78. tinygrad/runtime/ops_remote.py +482 -0
  79. tinygrad/runtime/ops_webgpu.py +28 -28
  80. tinygrad/runtime/support/am/amdev.py +114 -249
  81. tinygrad/runtime/support/am/ip.py +211 -172
  82. tinygrad/runtime/support/amd.py +138 -0
  83. tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
  84. tinygrad/runtime/support/compiler_cuda.py +8 -11
  85. tinygrad/runtime/support/elf.py +2 -1
  86. tinygrad/runtime/support/hcq.py +184 -97
  87. tinygrad/runtime/support/ib.py +172 -0
  88. tinygrad/runtime/support/llvm.py +3 -4
  89. tinygrad/runtime/support/memory.py +251 -0
  90. tinygrad/runtime/support/nv/__init__.py +0 -0
  91. tinygrad/runtime/support/nv/ip.py +581 -0
  92. tinygrad/runtime/support/nv/nvdev.py +183 -0
  93. tinygrad/runtime/support/system.py +170 -0
  94. tinygrad/runtime/support/usb.py +268 -0
  95. tinygrad/runtime/support/webgpu.py +18 -0
  96. tinygrad/schedule/__init__.py +0 -0
  97. tinygrad/schedule/grouper.py +119 -0
  98. tinygrad/schedule/kernelize.py +368 -0
  99. tinygrad/schedule/multi.py +231 -0
  100. tinygrad/shape/shapetracker.py +40 -46
  101. tinygrad/shape/view.py +88 -52
  102. tinygrad/tensor.py +968 -542
  103. tinygrad/uop/__init__.py +117 -0
  104. tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
  105. tinygrad/uop/mathtraits.py +169 -0
  106. tinygrad/uop/ops.py +1021 -0
  107. tinygrad/uop/spec.py +228 -0
  108. tinygrad/{codegen → uop}/symbolic.py +239 -216
  109. tinygrad/uop/upat.py +163 -0
  110. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
  111. tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
  112. tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
  113. tinygrad/viz/index.html +203 -403
  114. tinygrad/viz/js/index.js +718 -0
  115. tinygrad/viz/js/worker.js +29 -0
  116. tinygrad/viz/serve.py +224 -102
  117. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
  118. tinygrad-0.11.0.dist-info/RECORD +141 -0
  119. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
  120. tinygrad/codegen/kernel.py +0 -693
  121. tinygrad/engine/multi.py +0 -161
  122. tinygrad/ops.py +0 -1003
  123. tinygrad/runtime/ops_cloud.py +0 -220
  124. tinygrad/runtime/support/allocator.py +0 -94
  125. tinygrad/spec.py +0 -155
  126. tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
  127. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
  128. tinygrad/viz/perfetto.html +0 -178
  129. tinygrad-0.10.2.dist-info/RECORD +0 -99
  130. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
  131. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,29 @@
1
+ const NODE_PADDING = 10;
2
+ const LINE_HEIGHT = 14;
3
+ const canvas = new OffscreenCanvas(0, 0);
4
+ const ctx = canvas.getContext("2d");
5
+ ctx.font = `${LINE_HEIGHT}px sans-serif`;
6
+
7
+ onmessage = (e) => {
8
+ const { graph, additions, ctxs } = e.data;
9
+ const g = new dagre.graphlib.Graph({ compound: true });
10
+ g.setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; });
11
+ if (additions.length !== 0) g.setNode("addition", {label:"", style:"fill: rgba(26, 27, 38, 0.5);", padding:0});
12
+ for (let [k, {label, src, ref, ...rest }] of Object.entries(graph)) {
13
+ // adjust node dims by label size (excluding escape codes) + add padding
14
+ let [width, height] = [0, 0];
15
+ for (line of label.replace(/\u001B\[(?:K|.*?m)/g, "").split("\n")) {
16
+ width = Math.max(width, ctx.measureText(line).width);
17
+ height += LINE_HEIGHT;
18
+ }
19
+ g.setNode(k, {width:width+NODE_PADDING*2, height:height+NODE_PADDING*2, padding:NODE_PADDING, label, ref, ...rest});
20
+ // add edges
21
+ const edgeCounts = {}
22
+ for (const s of src) edgeCounts[s] = (edgeCounts[s] || 0)+1;
23
+ for (const s of src) g.setEdge(s, k, { label: edgeCounts[s] > 1 ? edgeCounts[s] : null });
24
+ if (additions.includes(parseInt(k))) g.setParent(k, "addition");
25
+ }
26
+ dagre.layout(g);
27
+ postMessage(dagre.graphlib.json.write(g));
28
+ self.close();
29
+ }
tinygrad/viz/serve.py CHANGED
@@ -1,118 +1,240 @@
1
1
  #!/usr/bin/env python3
2
- import multiprocessing, pickle, functools, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, decimal
3
- from http.server import HTTPServer, BaseHTTPRequestHandler
2
+ import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, socketserver, functools, codecs, io
3
+ import subprocess, ctypes
4
+ from contextlib import redirect_stdout
5
+ from decimal import Decimal
6
+ from http.server import BaseHTTPRequestHandler
4
7
  from urllib.parse import parse_qs, urlparse
5
- from typing import Any, Callable, TypedDict, Generator
6
- from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap
7
- from tinygrad.ops import TrackedGraphRewrite, UOp, Ops, lines, GroupOp
8
- from tinygrad.codegen.kernel import Kernel
9
- from tinygrad.device import ProfileEvent, ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent
8
+ from typing import Any, TypedDict, Generator
9
+ from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA, ProfileEvent, ProfileRangeEvent, TracingKey, ProfilePointEvent
10
+ from tinygrad.uop.ops import TrackedGraphRewrite, UOp, Ops, printable, GroupOp, srender, sint
11
+ from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device
12
+ from tinygrad.renderer import ProgramSpec
10
13
  from tinygrad.dtype import dtypes
11
14
 
12
- uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0",
13
- Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_ACC: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B",
14
- Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#e0ffc0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
15
+ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
16
+ Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_REG: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B",
17
+ Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
15
18
  Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55",
16
- **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
17
- Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.NAME:"#808080"}
19
+ **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF",
20
+ Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500",
21
+ Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
22
+ Ops.CHILDREN: "#80ffc0", Ops.CHILD: "#80fff0", Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e"}
18
23
 
19
24
  # VIZ API
20
25
 
21
- # NOTE: if any extra rendering in VIZ fails, we don't crash
22
- def pcall(fxn:Callable[..., str], *args, **kwargs) -> str:
23
- try: return fxn(*args, **kwargs)
24
- except Exception as e: return f"ERROR: {e}"
25
-
26
26
  # ** Metadata for a track_rewrites scope
27
27
 
28
- class GraphRewriteMetadata(TypedDict):
29
- loc: tuple[str, int] # [path, lineno] calling graph_rewrite
30
- match_count: int # total match count in this context
31
- code_line: str # source code calling graph_rewrite
32
- kernel_code: str|None # optionally render the final kernel code
33
- name: str|None # optional name of the rewrite
34
-
35
- @functools.lru_cache(None)
36
- def _prg(k:Kernel): return k.to_program().src
37
- def to_metadata(k:Any, v:TrackedGraphRewrite) -> GraphRewriteMetadata:
38
- return {"loc":v.loc, "match_count":len(v.matches), "code_line":lines(v.loc[0])[v.loc[1]-1].strip(),
39
- "kernel_code":pcall(_prg, k) if isinstance(k, Kernel) else None, "name":v.name}
40
- def get_metadata(keys:list[Any], contexts:list[list[TrackedGraphRewrite]]) -> list[tuple[str, list[GraphRewriteMetadata]]]:
41
- return [(k.name if isinstance(k, Kernel) else str(k), [to_metadata(k, v) for v in vals]) for k,vals in zip(keys, contexts)]
28
+ ref_map:dict[Any, int] = {}
29
+ def get_metadata(keys:list[TracingKey], contexts:list[list[TrackedGraphRewrite]]) -> list[dict]:
30
+ ret = []
31
+ for i,(k,v) in enumerate(zip(keys, contexts)):
32
+ steps = [{"name":s.name, "loc":s.loc, "depth":s.depth, "match_count":len(s.matches), "code_line":printable(s.loc),
33
+ "query":f"/ctxs?ctx={i}&idx={j}"} for j,s in enumerate(v)]
34
+ ret.append(r:={"name":k.display_name, "steps":steps})
35
+ # use the first key to get runtime profiling data about this context
36
+ if getenv("PROFILE_VALUE") >= 2 and k.keys: r["runtime_stats"] = get_runtime_stats(k.keys[0])
37
+ # program spec metadata
38
+ if isinstance(k.ret, ProgramSpec):
39
+ steps.append({"name":"View Disassembly", "query":f"/disasm?ctx={i}"})
40
+ r["fmt"] = k.ret.src
41
+ for key in k.keys: ref_map[key] = i
42
+ return ret
42
43
 
43
44
  # ** Complete rewrite details for a graph_rewrite call
44
45
 
45
46
  class GraphRewriteDetails(TypedDict):
46
47
  graph: dict # JSON serialized UOp for this rewrite step
47
48
  uop: str # strigified UOp for this rewrite step
48
- diff: list[str]|None # string diff of the single UOp that changed
49
+ diff: list[str]|None # diff of the single UOp that changed
49
50
  changed_nodes: list[int]|None # the changed UOp id + all its parents ids
50
51
  upat: tuple[tuple[str, int], str]|None # [loc, source_code] of the matched UPat
51
52
 
52
- def uop_to_json(x:UOp) -> dict[int, tuple[str, list[int], str]]:
53
+ def shape_to_str(s:tuple[sint, ...]): return "(" + ','.join(srender(x) for x in s) + ")"
54
+ def mask_to_str(s:tuple[tuple[sint, sint], ...]): return "(" + ','.join(shape_to_str(x) for x in s) + ")"
55
+
56
+ def uop_to_json(x:UOp) -> dict[int, dict]:
53
57
  assert isinstance(x, UOp)
54
- # NOTE: this is [id, [label, src_ids, color]]
55
- graph: dict[int, tuple[str, list[int], str]] = {}
58
+ graph: dict[int, dict] = {}
56
59
  excluded: set[UOp] = set()
57
- for u in (toposort:=x.toposort):
60
+ for u in (toposort:=x.toposort()):
58
61
  # always exclude DEVICE/CONST/UNIQUE
59
- if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE}: excluded.add(u)
62
+ if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE} and u is not x: excluded.add(u)
60
63
  # only exclude CONST VIEW source if it has no other children in the graph
61
64
  if u.op is Ops.CONST and len(u.src) != 0 and all(cr.op is Ops.CONST for c in u.src[0].children if (cr:=c()) is not None and cr in toposort):
62
65
  excluded.update(u.src)
63
66
  for u in toposort:
64
67
  if u in excluded: continue
65
- argst = str(u.arg)
68
+ argst = codecs.decode(str(u.arg), "unicode_escape")
66
69
  if u.op is Ops.VIEW:
67
- argst = ("\n".join([f"{v.shape} / {v.strides}"+(f"\nMASK {v.mask}" if v.mask is not None else "")+
68
- ("" if v.offset == 0 else f" / {v.offset}") for v in unwrap(u.st).views]))
70
+ argst = ("\n".join([f"{shape_to_str(v.shape)} / {shape_to_str(v.strides)}"+("" if v.offset == 0 else f" / {srender(v.offset)}")+
71
+ (f"\nMASK {mask_to_str(v.mask)}" if v.mask is not None else "") for v in unwrap(u.st).views]))
69
72
  label = f"{str(u.op).split('.')[1]}{(chr(10)+word_wrap(argst.replace(':', ''))) if u.arg is not None else ''}"
70
73
  if u.dtype != dtypes.void: label += f"\n{u.dtype}"
71
74
  for idx,x in enumerate(u.src):
72
75
  if x in excluded:
73
- if x.op is Ops.CONST and dtypes.is_float(u.dtype): label += f"\nCONST{idx} {x.arg:g}"
74
- else: label += f"\n{x.op.name}{idx} {x.arg}"
75
- graph[id(u)] = (label, [id(x) for x in u.src if x not in excluded], uops_colors.get(u.op, "#ffffff"))
76
+ arg = f"{x.arg:g}" if x.op is Ops.CONST and dtypes.is_float(u.dtype) else f"{x.arg}"
77
+ label += f"\n{x.op.name}{idx} {arg}" + (f" {x.src[0].op}" if len(x.src) else "")
78
+ try:
79
+ if u.op not in {Ops.VIEW, Ops.BUFFER, Ops.KERNEL, Ops.ASSIGN, Ops.COPY, Ops.SINK, *GroupOp.Buffer} and u.st is not None:
80
+ label += f"\n{shape_to_str(u.shape)}"
81
+ elif len(rngs:=u.ranges):
82
+ label += f"\n{str(sorted([x.arg for x in rngs]))}"
83
+ except Exception:
84
+ label += "\n<ISSUE GETTING LABEL>"
85
+ if (ref:=ref_map.get(u.arg.ast) if u.op is Ops.KERNEL else None) is not None: label += f"\ncodegen@{ctxs[ref]['name']}"
86
+ # NOTE: kernel already has metadata in arg
87
+ if TRACEMETA >= 2 and u.metadata is not None and u.op is not Ops.KERNEL: label += "\n"+repr(u.metadata)
88
+ graph[id(u)] = {"label":label, "src":[id(x) for x in u.src if x not in excluded], "color":uops_colors.get(u.op, "#ffffff"),
89
+ "ref":ref, "tag":u.tag}
76
90
  return graph
77
91
 
78
- def get_details(k:Any, ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]:
79
- yield {"graph":uop_to_json(next_sink:=ctx.sink), "uop":str(ctx.sink), "changed_nodes":None, "diff":None, "upat":None}
92
+ @functools.cache
93
+ def _reconstruct(a:int):
94
+ op, dtype, src, arg, tag = contexts[2][a]
95
+ arg = type(arg)(_reconstruct(arg.ast), arg.metadata) if op is Ops.KERNEL else arg
96
+ return UOp(op, dtype, tuple(_reconstruct(s) for s in src), arg, tag)
97
+
98
+ def get_details(ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]:
99
+ yield {"graph":uop_to_json(next_sink:=_reconstruct(ctx.sink)), "uop":str(next_sink), "changed_nodes":None, "diff":None, "upat":None}
80
100
  replaces: dict[UOp, UOp] = {}
81
- for u0,u1,upat in tqdm(ctx.matches):
82
- replaces[u0] = u1
83
- new_sink = next_sink.substitute(replaces)
84
- yield {"graph": (sink_json:=uop_to_json(new_sink)), "uop":str(new_sink), "changed_nodes":[id(x) for x in u1.toposort if id(x) in sink_json],
85
- "diff":list(difflib.unified_diff(pcall(str, u0).splitlines(), pcall(str, u1).splitlines())), "upat":(upat.location, upat.printable())}
101
+ for u0_num,u1_num,upat_loc in tqdm(ctx.matches):
102
+ replaces[u0:=_reconstruct(u0_num)] = u1 = _reconstruct(u1_num)
103
+ try: new_sink = next_sink.substitute(replaces)
104
+ except RuntimeError as e: new_sink = UOp(Ops.NOOP, arg=str(e))
105
+ yield {"graph":(sink_json:=uop_to_json(new_sink)), "uop":str(new_sink), "changed_nodes":[id(x) for x in u1.toposort() if id(x) in sink_json],
106
+ "diff":list(difflib.unified_diff(str(u0).splitlines(), str(u1).splitlines())), "upat":(upat_loc, printable(upat_loc))}
86
107
  if not ctx.bottom_up: next_sink = new_sink
87
108
 
88
109
  # Profiler API
89
- devices:dict[str, tuple[decimal.Decimal, decimal.Decimal, int]] = {}
90
- def prep_ts(device:str, ts:decimal.Decimal, is_copy): return int(decimal.Decimal(ts) + devices[device][is_copy])
91
- def dev_to_pid(device:str, is_copy=False): return {"pid": devices[device][2], "tid": int(is_copy)}
92
- def dev_ev_to_perfetto_json(ev:ProfileDeviceEvent):
93
- devices[ev.device] = (ev.comp_tdiff, ev.copy_tdiff if ev.copy_tdiff is not None else ev.comp_tdiff, len(devices))
94
- return [{"name": "process_name", "ph": "M", "pid": dev_to_pid(ev.device)['pid'], "args": {"name": ev.device}},
95
- {"name": "thread_name", "ph": "M", "pid": dev_to_pid(ev.device)['pid'], "tid": 0, "args": {"name": "COMPUTE"}},
96
- {"name": "thread_name", "ph": "M", "pid": dev_to_pid(ev.device)['pid'], "tid": 1, "args": {"name": "COPY"}}]
97
- def range_ev_to_perfetto_json(ev:ProfileRangeEvent):
98
- return [{"name": ev.name, "ph": "X", "ts": prep_ts(ev.device, ev.st, ev.is_copy), "dur": float(ev.en-ev.st), **dev_to_pid(ev.device, ev.is_copy)}]
99
- def graph_ev_to_perfetto_json(ev:ProfileGraphEvent, reccnt):
100
- ret = []
101
- for i,e in enumerate(ev.ents):
102
- st, en = ev.sigs[e.st_id], ev.sigs[e.en_id]
103
- ret += [{"name": e.name, "ph": "X", "ts": prep_ts(e.device, st, e.is_copy), "dur": float(en-st), **dev_to_pid(e.device, e.is_copy)}]
104
- for dep in ev.deps[i]:
105
- d = ev.ents[dep]
106
- ret += [{"ph": "s", **dev_to_pid(d.device, d.is_copy), "id": reccnt+len(ret), "ts": prep_ts(d.device, ev.sigs[d.en_id], d.is_copy), "bp": "e"}]
107
- ret += [{"ph": "f", **dev_to_pid(e.device, e.is_copy), "id": reccnt+len(ret)-1, "ts": prep_ts(e.device, st, e.is_copy), "bp": "e"}]
110
+
111
+ device_ts_diffs:dict[str, tuple[Decimal, Decimal]] = {}
112
+ def cpu_ts_diff(device:str, thread=0) -> Decimal: return device_ts_diffs.get(device, (Decimal(0),))[thread]
113
+
114
+ DevEvent = ProfileRangeEvent|ProfileGraphEntry|ProfilePointEvent
115
+ def flatten_events(profile:list[ProfileEvent]) -> Generator[tuple[Decimal, Decimal, DevEvent], None, None]:
116
+ for e in profile:
117
+ if isinstance(e, ProfileRangeEvent): yield (e.st+(diff:=cpu_ts_diff(e.device, e.is_copy)), (e.en if e.en is not None else e.st)+diff, e)
118
+ elif isinstance(e, ProfilePointEvent): yield (e.ts, e.ts, e)
119
+ elif isinstance(e, ProfileGraphEvent):
120
+ cpu_ts = []
121
+ for ent in e.ents: cpu_ts += [e.sigs[ent.st_id]+(diff:=cpu_ts_diff(ent.device, ent.is_copy)), e.sigs[ent.en_id]+diff]
122
+ yield (st:=min(cpu_ts)), (et:=max(cpu_ts)), ProfileRangeEvent(f"{e.ents[0].device.split(':')[0]} Graph", f"batched {len(e.ents)}", st, et)
123
+ for i,ent in enumerate(e.ents): yield (cpu_ts[i*2], cpu_ts[i*2+1], ent)
124
+
125
+ # timeline layout stacks events in a contiguous block. When a late starter finishes late, there is whitespace in the higher levels.
126
+ def timeline_layout(events:list[tuple[int, int, float, DevEvent]]) -> dict:
127
+ shapes:list[dict] = []
128
+ levels:list[int] = []
129
+ for st,et,dur,e in events:
130
+ if dur == 0: continue
131
+ # find a free level to put the event
132
+ depth = next((i for i,level_et in enumerate(levels) if st>=level_et), len(levels))
133
+ if depth < len(levels): levels[depth] = et
134
+ else: levels.append(et)
135
+ name, cat, info = e.name, None, None
136
+ if (ref:=ref_map.get(name)) is not None:
137
+ name = ctxs[ref]["name"]
138
+ # TODO: support symbolic by capturing var_vals in profile events
139
+ if isinstance(p:=contexts[0][ref].ret, ProgramSpec) and all(isinstance(es,int) for es in [p.estimates.ops, p.estimates.mem, p.estimates.lds]):
140
+ info = f"{p.estimates.ops/(t:=dur*1e3):.2f} GFLOPS {p.estimates.mem/t:4.1f}|{p.estimates.lds/t:.1f} GB/s"
141
+ elif isinstance(e.name, TracingKey):
142
+ name, cat = e.name.display_name, e.name.cat
143
+ ref = next((v for k in e.name.keys if (v:=ref_map.get(k)) is not None), None)
144
+ shapes.append({"name":name, "ref":ref, "st":st, "dur":dur, "depth":depth, "cat":cat, "info":info})
145
+ return {"shapes":shapes, "maxDepth":len(levels)}
146
+
147
+ def mem_layout(events:list[tuple[int, int, float, DevEvent]], max_ts:int) -> dict:
148
+ step, peak, mem = 0, 0, 0
149
+ shps:dict[int, dict] = {}
150
+ temp:dict[int, dict] = {}
151
+ timestamps:list[int] = []
152
+ for st,_,_,e in events:
153
+ if not isinstance(e, ProfilePointEvent): continue
154
+ if e.name == "alloc":
155
+ shps[e.key] = temp[e.key] = {"x":[step], "y":[mem], "arg":e.arg}
156
+ timestamps.append(int(e.ts))
157
+ step += 1
158
+ mem += e.arg["nbytes"]
159
+ if mem > peak: peak = mem
160
+ if e.name == "free":
161
+ timestamps.append(int(e.ts))
162
+ step += 1
163
+ mem -= (removed:=temp.pop(e.key))["arg"]["nbytes"]
164
+ removed["x"].append(step)
165
+ removed["y"].append(removed["y"][-1])
166
+ for k,v in temp.items():
167
+ if k > e.key:
168
+ v["x"] += [step, step]
169
+ v["y"] += [v["y"][-1], v["y"][-1]-removed["arg"]["nbytes"]]
170
+ for v in temp.values():
171
+ v["x"].append(step)
172
+ v["y"].append(v["y"][-1])
173
+ timestamps.append(max_ts)
174
+ return {"shapes":list(shps.values()), "peak":peak, "timestamps":timestamps}
175
+
176
+ def get_profile(profile:list[ProfileEvent]) -> bytes|None:
177
+ # start by getting the time diffs
178
+ for ev in profile:
179
+ if isinstance(ev,ProfileDeviceEvent): device_ts_diffs[ev.device] = (ev.comp_tdiff, ev.copy_tdiff if ev.copy_tdiff is not None else ev.comp_tdiff)
180
+ # map events per device
181
+ dev_events:dict[str, list[tuple[int, int, float, DevEvent]]] = {}
182
+ min_ts:int|None = None
183
+ max_ts:int|None = None
184
+ for ts,en,e in flatten_events(profile):
185
+ dev_events.setdefault(e.device,[]).append((st:=int(ts), et:=int(en), float(en-ts), e))
186
+ if min_ts is None or st < min_ts: min_ts = st
187
+ if max_ts is None or et > max_ts: max_ts = et
188
+ if min_ts is None: return None
189
+ # return layout of per device events
190
+ layout:dict[str, dict] = {}
191
+ for k,v in dev_events.items():
192
+ v.sort(key=lambda e:e[0])
193
+ layout[k] = timeline_layout(v)
194
+ layout[f"{k} Memory"] = mem_layout(v, unwrap(max_ts))
195
+ return json.dumps({"layout":layout, "st":min_ts, "et":max_ts}).encode("utf-8")
196
+
197
+ def get_runtime_stats(key) -> list[dict]:
198
+ ret:list[dict] = []
199
+ for e in profile:
200
+ if isinstance(e, ProfileRangeEvent) and e.en is not None and e.name == key:
201
+ ret.append({"device":e.device, "data":[{"name":"Duration", "value":float(e.en-e.st), "unit":"us"}]})
108
202
  return ret
109
- def to_perfetto(profile:list[ProfileEvent]):
110
- # Start json with devices.
111
- prof_json = [x for ev in profile if isinstance(ev, ProfileDeviceEvent) for x in dev_ev_to_perfetto_json(ev)]
112
- for ev in tqdm(profile, desc="preparing profile"):
113
- if isinstance(ev, ProfileRangeEvent): prof_json += range_ev_to_perfetto_json(ev)
114
- elif isinstance(ev, ProfileGraphEvent): prof_json += graph_ev_to_perfetto_json(ev, reccnt=len(prof_json))
115
- return json.dumps({"traceEvents": prof_json}).encode() if len(prof_json) > 0 else None
203
+
204
+ # ** Assembly analyzers
205
+
206
+ def get_llvm_mca(asm:str, mtriple:str, mcpu:str) -> dict:
207
+ target_args = [f"-mtriple={mtriple}", f"-mcpu={mcpu}"]
208
+ # disassembly output can include headers / metadata, skip if llvm-mca can't parse those lines
209
+ data = json.loads(subprocess.check_output(["llvm-mca","-skip-unsupported-instructions=parse-failure","--json","-"]+target_args, input=asm.encode()))
210
+ cr = data["CodeRegions"][0]
211
+ resource_labels = data["TargetInfo"]["Resources"]
212
+ rows:list = [[instr] for instr in cr["Instructions"]]
213
+ # add scheduler estimates
214
+ for info in cr["InstructionInfoView"]["InstructionList"]: rows[info["Instruction"]].append(info["Latency"])
215
+ # map per instruction resource usage
216
+ instr_usage:dict[int, dict[int, int]] = {}
217
+ for d in cr["ResourcePressureView"]["ResourcePressureInfo"]:
218
+ instr_usage.setdefault(i:=d["InstructionIndex"], {}).setdefault(r:=d["ResourceIndex"], 0)
219
+ instr_usage[i][r] += d["ResourceUsage"]
220
+ # last row is the usage summary
221
+ summary = [{"idx":k, "label":resource_labels[k], "value":v} for k,v in instr_usage.pop(len(rows), {}).items()]
222
+ max_usage = max([sum(v.values()) for i,v in instr_usage.items() if i<len(rows)], default=0)
223
+ for i,usage in instr_usage.items(): rows[i].append([[k, v, (v/max_usage)*100] for k,v in usage.items()])
224
+ return {"rows":rows, "cols":["Opcode", "Latency", {"title":"HW Resources", "labels":resource_labels}], "summary":summary}
225
+
226
+ def get_disassembly(ctx:list[str]):
227
+ if not isinstance(prg:=contexts[0][int(ctx[0])].ret, ProgramSpec): return
228
+ lib = (compiler:=Device[prg.device].compiler).compile(prg.src)
229
+ with redirect_stdout(buf:=io.StringIO()): compiler.disassemble(lib)
230
+ disasm_str = buf.getvalue()
231
+ from tinygrad.runtime.ops_llvm import llvm, LLVMCompiler
232
+ if isinstance(compiler, LLVMCompiler):
233
+ mtriple = ctypes.string_at(llvm.LLVMGetTargetMachineTriple(tm:=compiler.target_machine)).decode()
234
+ mcpu = ctypes.string_at(llvm.LLVMGetTargetMachineCPU(tm)).decode()
235
+ ret = get_llvm_mca(disasm_str, mtriple, mcpu)
236
+ else: ret = {"src":disasm_str}
237
+ return json.dumps(ret).encode()
116
238
 
117
239
  # ** HTTP server
118
240
 
@@ -122,33 +244,17 @@ class Handler(BaseHTTPRequestHandler):
122
244
 
123
245
  if (url:=urlparse(self.path)).path == "/":
124
246
  with open(os.path.join(os.path.dirname(__file__), "index.html"), "rb") as f: ret = f.read()
125
- elif (url:=urlparse(self.path)).path == "/profiler":
126
- with open(os.path.join(os.path.dirname(__file__), "perfetto.html"), "rb") as f: ret = f.read()
127
- elif self.path.startswith("/assets/") and '/..' not in self.path:
247
+ elif self.path.startswith(("/assets/", "/js/")) and '/..' not in self.path:
128
248
  try:
129
249
  with open(os.path.join(os.path.dirname(__file__), self.path.strip('/')), "rb") as f: ret = f.read()
130
250
  if url.path.endswith(".js"): content_type = "application/javascript"
131
251
  if url.path.endswith(".css"): content_type = "text/css"
132
252
  except FileNotFoundError: status_code = 404
133
- elif url.path == "/kernels":
134
- if "kernel" in (query:=parse_qs(url.query)):
135
- def getarg(k:str,default=0): return int(query[k][0]) if k in query else default
136
- kidx, ridx = getarg("kernel"), getarg("idx")
137
- try:
138
- # stream details
139
- self.send_response(200)
140
- self.send_header("Content-Type", "text/event-stream")
141
- self.send_header("Cache-Control", "no-cache")
142
- self.end_headers()
143
- for r in get_details(contexts[0][kidx], contexts[1][kidx][ridx]):
144
- self.wfile.write(f"data: {json.dumps(r)}\n\n".encode("utf-8"))
145
- self.wfile.flush()
146
- self.wfile.write("data: END\n\n".encode("utf-8"))
147
- return self.wfile.flush()
148
- # pass if client closed connection
149
- except (BrokenPipeError, ConnectionResetError): return
150
- ret, content_type = json.dumps(kernels).encode(), "application/json"
151
- elif url.path == "/get_profile" and perfetto_profile is not None: ret, content_type = perfetto_profile, "application/json"
253
+ elif (query:=parse_qs(url.query)):
254
+ if url.path == "/disasm": ret, content_type = get_disassembly(**query), "application/json"
255
+ else: return self.stream_json(get_details(contexts[1][int(query["ctx"][0])][int(query["idx"][0])]))
256
+ elif url.path == "/ctxs": ret, content_type = json.dumps(ctxs).encode(), "application/json"
257
+ elif url.path == "/get_profile" and profile_ret is not None: ret, content_type = profile_ret, "application/json"
152
258
  else: status_code = 404
153
259
 
154
260
  # send response
@@ -158,6 +264,19 @@ class Handler(BaseHTTPRequestHandler):
158
264
  self.end_headers()
159
265
  return self.wfile.write(ret)
160
266
 
267
+ def stream_json(self, source:Generator):
268
+ try:
269
+ self.send_response(200)
270
+ self.send_header("Content-Type", "text/event-stream")
271
+ self.send_header("Cache-Control", "no-cache")
272
+ self.end_headers()
273
+ for r in source:
274
+ self.wfile.write(f"data: {json.dumps(r)}\n\n".encode("utf-8"))
275
+ self.wfile.flush()
276
+ self.wfile.write("data: END\n\n".encode("utf-8"))
277
+ # pass if client closed connection
278
+ except (BrokenPipeError, ConnectionResetError): return
279
+
161
280
  # ** main loop
162
281
 
163
282
  def reloader():
@@ -172,6 +291,9 @@ def load_pickle(path:str):
172
291
  if path is None or not os.path.exists(path): return None
173
292
  with open(path, "rb") as f: return pickle.load(f)
174
293
 
294
+ # NOTE: using HTTPServer forces a potentially slow socket.getfqdn
295
+ class TCPServerWithReuse(socketserver.TCPServer): allow_reuse_address = True
296
+
175
297
  if __name__ == "__main__":
176
298
  parser = argparse.ArgumentParser()
177
299
  parser.add_argument('--kernels', type=str, help='Path to kernels', default=None)
@@ -189,15 +311,15 @@ if __name__ == "__main__":
189
311
  contexts, profile = load_pickle(args.kernels), load_pickle(args.profile)
190
312
 
191
313
  # NOTE: this context is a tuple of list[keys] and list[values]
192
- kernels = get_metadata(*contexts) if contexts is not None else []
314
+ ctxs = get_metadata(*contexts[:2]) if contexts is not None else []
193
315
 
194
- perfetto_profile = to_perfetto(profile) if profile is not None else None
316
+ profile_ret = get_profile(profile) if profile is not None else None
195
317
 
196
- server = HTTPServer(('', PORT), Handler)
318
+ server = TCPServerWithReuse(('', PORT), Handler)
197
319
  reloader_thread = threading.Thread(target=reloader)
198
320
  reloader_thread.start()
199
321
  print(f"*** started viz on {HOST}:{PORT}")
200
- print(colored(f"*** ready in {(time.perf_counter()-st)*1e3:4.2f}ms", "green"))
322
+ print(colored(f"*** ready in {(time.perf_counter()-st)*1e3:4.2f}ms", "green"), flush=True)
201
323
  if len(getenv("BROWSER", "")) > 0: webbrowser.open(f"{HOST}:{PORT}{'/profiler' if contexts is None else ''}")
202
324
  try: server.serve_forever()
203
325
  except KeyboardInterrupt:
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: tinygrad
3
- Version: 0.10.2
3
+ Version: 0.11.0
4
4
  Summary: You like pytorch? You like micrograd? You love tinygrad! <3
5
5
  Author: George Hotz
6
6
  License: MIT
@@ -19,31 +19,38 @@ Requires-Dist: mypy==1.13.0; extra == "linting"
19
19
  Requires-Dist: typing-extensions; extra == "linting"
20
20
  Requires-Dist: pre-commit; extra == "linting"
21
21
  Requires-Dist: ruff; extra == "linting"
22
- Requires-Dist: types-tqdm; extra == "linting"
22
+ Requires-Dist: numpy; extra == "linting"
23
23
  Provides-Extra: testing-minimal
24
24
  Requires-Dist: numpy; extra == "testing-minimal"
25
- Requires-Dist: torch; extra == "testing-minimal"
25
+ Requires-Dist: torch==2.7.1; extra == "testing-minimal"
26
26
  Requires-Dist: pytest; extra == "testing-minimal"
27
27
  Requires-Dist: pytest-xdist; extra == "testing-minimal"
28
28
  Requires-Dist: hypothesis; extra == "testing-minimal"
29
+ Requires-Dist: z3-solver; extra == "testing-minimal"
30
+ Requires-Dist: ml_dtypes; extra == "testing-minimal"
29
31
  Provides-Extra: testing-unit
30
32
  Requires-Dist: numpy; extra == "testing-unit"
31
- Requires-Dist: torch; extra == "testing-unit"
33
+ Requires-Dist: torch==2.7.1; extra == "testing-unit"
32
34
  Requires-Dist: pytest; extra == "testing-unit"
33
35
  Requires-Dist: pytest-xdist; extra == "testing-unit"
34
36
  Requires-Dist: hypothesis; extra == "testing-unit"
37
+ Requires-Dist: z3-solver; extra == "testing-unit"
38
+ Requires-Dist: ml_dtypes; extra == "testing-unit"
35
39
  Requires-Dist: tqdm; extra == "testing-unit"
36
40
  Requires-Dist: safetensors; extra == "testing-unit"
37
41
  Requires-Dist: tabulate; extra == "testing-unit"
38
42
  Provides-Extra: testing
39
43
  Requires-Dist: numpy; extra == "testing"
40
- Requires-Dist: torch; extra == "testing"
44
+ Requires-Dist: torch==2.7.1; extra == "testing"
41
45
  Requires-Dist: pytest; extra == "testing"
42
46
  Requires-Dist: pytest-xdist; extra == "testing"
43
47
  Requires-Dist: hypothesis; extra == "testing"
48
+ Requires-Dist: z3-solver; extra == "testing"
49
+ Requires-Dist: ml_dtypes; extra == "testing"
44
50
  Requires-Dist: pillow; extra == "testing"
45
- Requires-Dist: onnx==1.16.0; extra == "testing"
51
+ Requires-Dist: onnx==1.18.0; extra == "testing"
46
52
  Requires-Dist: onnx2torch; extra == "testing"
53
+ Requires-Dist: onnxruntime; extra == "testing"
47
54
  Requires-Dist: opencv-python; extra == "testing"
48
55
  Requires-Dist: tabulate; extra == "testing"
49
56
  Requires-Dist: tqdm; extra == "testing"
@@ -58,6 +65,10 @@ Requires-Dist: nibabel; extra == "testing"
58
65
  Requires-Dist: bottle; extra == "testing"
59
66
  Requires-Dist: ggml-python; extra == "testing"
60
67
  Requires-Dist: capstone; extra == "testing"
68
+ Requires-Dist: pycocotools; extra == "testing"
69
+ Requires-Dist: boto3; extra == "testing"
70
+ Requires-Dist: pandas; extra == "testing"
71
+ Requires-Dist: influxdb3-python; extra == "testing"
61
72
  Provides-Extra: docs
62
73
  Requires-Dist: mkdocs; extra == "docs"
63
74
  Requires-Dist: mkdocs-material; extra == "docs"
@@ -66,14 +77,12 @@ Requires-Dist: markdown-callouts; extra == "docs"
66
77
  Requires-Dist: markdown-exec[ansi]; extra == "docs"
67
78
  Requires-Dist: black; extra == "docs"
68
79
  Requires-Dist: numpy; extra == "docs"
69
- Provides-Extra: testing-tf
70
- Requires-Dist: tensorflow==2.15.1; extra == "testing-tf"
71
- Requires-Dist: tensorflow_addons; extra == "testing-tf"
72
80
  Dynamic: author
73
81
  Dynamic: classifier
74
82
  Dynamic: description
75
83
  Dynamic: description-content-type
76
84
  Dynamic: license
85
+ Dynamic: license-file
77
86
  Dynamic: provides-extra
78
87
  Dynamic: requires-python
79
88
  Dynamic: summary
@@ -101,11 +110,11 @@ tinygrad: For something between [PyTorch](https://github.com/pytorch/pytorch) an
101
110
 
102
111
  ---
103
112
 
104
- This may not be the best deep learning framework, but it is a deep learning framework.
113
+ Despite tinygrad's size, it is a fully featured deep learning framework.
105
114
 
106
- Due to its extreme simplicity, it aims to be the easiest framework to add new accelerators to, with support for both inference and training. If XLA is CISC, tinygrad is RISC.
115
+ Due to its extreme simplicity, it is the easiest framework to add new accelerators to, with support for both inference and training. If XLA is CISC, tinygrad is RISC.
107
116
 
108
- tinygrad is still alpha software, but we [raised some money](https://geohot.github.io/blog/jekyll/update/2023/05/24/the-tiny-corp-raised-5M.html) to make it good. Someday, we will tape out chips.
117
+ tinygrad is now beta software, we [raised some money](https://geohot.github.io/blog/jekyll/update/2023/05/24/the-tiny-corp-raised-5M.html) to make it good. Someday, we will tape out chips.
109
118
 
110
119
  ## Features
111
120
 
@@ -119,9 +128,8 @@ Try a matmul. See how, despite the style, it is fused into one kernel with the p
119
128
 
120
129
  ```sh
121
130
  DEBUG=3 python3 -c "from tinygrad import Tensor;
122
- N = 1024; a, b = Tensor.rand(N, N), Tensor.rand(N, N);
123
- c = (a.reshape(N, 1, N) * b.T.reshape(1, N, N)).sum(axis=2);
124
- print((c.numpy() - (a.numpy() @ b.numpy())).mean())"
131
+ N = 1024; a, b = Tensor.empty(N, N), Tensor.empty(N, N);
132
+ (a.reshape(N, 1, N) * b.T.reshape(1, N, N)).sum(axis=2).realize()"
125
133
  ```
126
134
 
127
135
  And we can change `DEBUG` to `4` to see the generated code.