vibetrack 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vibetrack/__init__.py +184 -0
- vibetrack/_graph.py +294 -0
- vibetrack/cli.py +618 -0
- vibetrack/compare.py +109 -0
- vibetrack/config.py +123 -0
- vibetrack/db.py +1457 -0
- vibetrack/default_config.py +45 -0
- vibetrack/media.py +369 -0
- vibetrack/reader.py +459 -0
- vibetrack/smoother.py +120 -0
- vibetrack/sysmetrics.py +606 -0
- vibetrack/types.py +81 -0
- vibetrack/viewers/__init__.py +65 -0
- vibetrack/viewers/_summary.py +130 -0
- vibetrack/viewers/base.py +78 -0
- vibetrack/viewers/console.py +252 -0
- vibetrack/viewers/event.py +86 -0
- vibetrack/viewers/gradio.py +1133 -0
- vibetrack/viewers/jupyter.py +158 -0
- vibetrack/viewers/mcp.py +936 -0
- vibetrack/viewers/remote.py +207 -0
- vibetrack/viewers/slack.py +583 -0
- vibetrack/viewers/telegram.py +350 -0
- vibetrack/viewers/web/css/style.css +1905 -0
- vibetrack/viewers/web/index.html +219 -0
- vibetrack/viewers/web/js/charts.js +693 -0
- vibetrack/viewers/web/js/core.js +300 -0
- vibetrack/viewers/web/js/embeddings.js +756 -0
- vibetrack/viewers/web/js/hparams.js +432 -0
- vibetrack/viewers/web/js/main.js +79 -0
- vibetrack/viewers/web/js/media.js +921 -0
- vibetrack/viewers/web/js/meshes.js +448 -0
- vibetrack/viewers/web/js/pills.js +114 -0
- vibetrack/viewers/web/js/settings.js +148 -0
- vibetrack/viewers/web/vendor/OrbitControls.js +1045 -0
- vibetrack/viewers/web/vendor/README.md +13 -0
- vibetrack/viewers/web/vendor/three.min.js +6 -0
- vibetrack/viewers/web.py +933 -0
- vibetrack/writer.py +2241 -0
- vibetrack-0.1.0.dist-info/METADATA +286 -0
- vibetrack-0.1.0.dist-info/RECORD +45 -0
- vibetrack-0.1.0.dist-info/WHEEL +5 -0
- vibetrack-0.1.0.dist-info/entry_points.txt +2 -0
- vibetrack-0.1.0.dist-info/licenses/LICENSE +183 -0
- vibetrack-0.1.0.dist-info/top_level.txt +1 -0
vibetrack/__init__.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
"""vibetrack — lightweight experiment tracking.
|
|
2
|
+
|
|
3
|
+
Example usage:
|
|
4
|
+
# TensorBoard style
|
|
5
|
+
from vibetrack import SummaryWriter
|
|
6
|
+
writer = SummaryWriter("my_project/run_1")
|
|
7
|
+
writer.add_scalar("loss", 0.5, step)
|
|
8
|
+
|
|
9
|
+
# Module-level API
|
|
10
|
+
import vibetrack
|
|
11
|
+
vibetrack.init(project="my_project", name="run_1", config={"lr": 0.01})
|
|
12
|
+
vibetrack.log({"loss": 0.5, "acc": 0.9})
|
|
13
|
+
vibetrack.finish()
|
|
14
|
+
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import re
|
|
20
|
+
import sys
|
|
21
|
+
from importlib.metadata import PackageNotFoundError, version
|
|
22
|
+
from pathlib import Path
|
|
23
|
+
from typing import Any, Dict, Optional, Union
|
|
24
|
+
|
|
25
|
+
from .writer import SummaryWriter
|
|
26
|
+
from .reader import ExperimentReader, RunReader
|
|
27
|
+
from .smoother import smooth, ema, moving_average, gaussian
|
|
28
|
+
from .compare import compare_scalars, compare_hparams, summary_table
|
|
29
|
+
from .types import Image, Audio, Video, Artifact
|
|
30
|
+
from .default_config import SYSTEM_METRICS_INTERVAL
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _read_version_from_pyproject() -> Optional[str]:
    """Return vibetrack's version from a source-checkout ``pyproject.toml``.

    Looks for ``pyproject.toml`` one directory above this package and scans
    its ``[project]`` table line by line (no TOML parser is needed).

    Returns:
        The ``version`` string when the file is readable, its ``[project]``
        table names the ``vibetrack`` project and carries a quoted
        ``version``; otherwise ``None`` so the caller can fall back to other
        sources.
    """
    pyproject = Path(__file__).resolve().parent.parent / "pyproject.toml"
    try:
        text = pyproject.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError):
        # Missing, unreadable or non-UTF-8 file: never let this break import.
        return None

    # key = "value" or key = 'value' (TOML basic and literal strings).
    assignment = re.compile(r"""(name|version)\s*=\s*(["'])(.+?)\2""")
    found = {}
    in_project = False
    for line in text.splitlines():
        stripped = line.strip()
        if stripped.startswith("[") and stripped.endswith("]"):
            # A new table header starts; only [project] is of interest.
            in_project = stripped == "[project]"
            continue
        if not in_project:
            continue
        match = assignment.match(stripped)
        if match:
            # First occurrence wins, mirroring a top-to-bottom scan.
            found.setdefault(match.group(1), match.group(3))

    # The parent directory may belong to a different project (for example a
    # vendored copy); only trust the file when it describes this package.
    name = found.get("name")
    if name is None or re.sub(r"[-_.]+", "-", name).lower() != "vibetrack":
        return None
    return found.get("version")
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# Version lookup order: a pyproject.toml next to the package, then the
# installed distribution metadata, and finally a fixed placeholder.
if (__version__ := _read_version_from_pyproject()) is None:
    try:
        __version__ = version("vibetrack")
    except PackageNotFoundError:
        __version__ = "0.0.0"
|
|
59
|
+
|
|
60
|
+
# Public names of the package, grouped by feature area. "init", "log",
# "finish" and "config" are defined in this module; the rest are re-exported
# from the submodules imported above.
__all__ = [
    # Core
    "SummaryWriter",
    "ExperimentReader",
    "RunReader",
    # Module-level logging API
    "init",
    "log",
    "finish",
    "config",
    # Smoothing
    "smooth",
    "ema",
    "moving_average",
    "gaussian",
    # Compare
    "compare_scalars",
    "compare_hparams",
    "summary_table",
    # Media types
    "Image",
    "Audio",
    "Video",
    "Artifact",
]
|
|
85
|
+
|
|
86
|
+
# ── Module-level logging API ────────────────────────────────────
|
|
87
|
+
|
|
88
|
+
# Writer created by the most recent init(); None before init() and after
# finish().
_active_writer: Optional[SummaryWriter] = None
# Step passed to the writer by log(); reset to 0 by init() and incremented
# after every log() call.
_step: int = 0

# Copy of the config dict given to init() (replaced when a non-empty config
# is passed); empty until then.
config: Dict[str, Any] = {}
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _warn(msg: str) -> None:
    """Emit *msg* on standard error, prefixed with ``vibetrack warning:``."""
    text = f"vibetrack warning: {msg}"
    print(text, file=sys.stderr)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def init(
    project: Optional[str] = None,
    name: Optional[str] = None,
    config: Optional[Dict[str, Any]] = None,
    log_dir: Optional[str] = None,
    project_folder: Optional[str] = None,
    precache_secs: float = 0,
    system_metrics_interval: float = SYSTEM_METRICS_INTERVAL,
    rank: Optional[Union[int, str]] = None,
    to: Optional[Union[str, list, tuple]] = None,
    **kwargs: Any,
) -> SummaryWriter:
    """Initialize a new run.

    Only rank 0 logs by default. Other ranks get a no-op writer.
    Set ``rank="all"`` to force every rank to log.

    ::

        import vibetrack
        vibetrack.init(project="cifar10", name="resnet18", config={"lr": 1e-3})
        vibetrack.init(..., system_metrics_interval=10)  # collect OS/GPU stats

    Any writer left over from a previous ``init()`` is closed first (a failure
    to close is reported as a warning, not raised). All arguments except
    ``to`` are forwarded to ``SummaryWriter``; ``**kwargs`` is passed through
    unchanged.

    Args:
        config: Run configuration. A copy is published as ``vibetrack.config``;
            when omitted or empty, ``vibetrack.config`` is reset to ``{}`` so
            values from an earlier run do not leak into this one.
        to: A single name, or a list/tuple of entries. String entries are
            passed as ``writer.to(entry)`` and dict entries as
            ``writer.to(**entry)``; anything else is ignored with a warning.

    Returns:
        The newly created writer, which also becomes the target of the
        module-level ``log()`` and ``finish()``.
    """
    global _active_writer, _step

    if _active_writer is not None:
        # Unregister before closing: if constructing the new writer fails
        # below, log() must not keep writing to an already-closed writer.
        previous, _active_writer = _active_writer, None
        try:
            previous.close()
        except Exception as exc:
            _warn(f"failed to close active writer during init: {exc}")

    _active_writer = SummaryWriter(
        log_dir=log_dir,
        project=project,
        name=name,
        config=config,
        project_folder=project_folder,
        precache_secs=precache_secs,
        system_metrics_interval=system_metrics_interval,
        rank=rank,
        **kwargs,
    )
    _step = 0
    # The ``config`` parameter shadows the module attribute of the same name,
    # so the attribute is published through the module namespace directly.
    globals()["config"] = dict(config) if config else {}
    if to is not None:
        names = [to] if isinstance(to, str) else list(to)
        for entry in names:
            if isinstance(entry, str):
                _active_writer.to(entry)
            elif isinstance(entry, dict):
                _active_writer.to(**entry)
            else:
                _warn(f"ignoring unknown to= entry: {entry!r}")
    return _active_writer
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def log(data: Dict[str, Any], step: Optional[int] = None, **kwargs: Any) -> None:
    """Send *data* to the active writer at the current step.

    ::

        vibetrack.log({"loss": 0.5, "acc": 0.9})

    An explicit ``step`` replaces the module-level step counter before the
    write. The counter is advanced by one after every call that reaches the
    writer, whether or not the write succeeded. Without an active writer the
    data is dropped with a warning; writer errors are reported as warnings
    rather than raised.
    """
    global _step

    writer = _active_writer
    if writer is None:
        _warn("log() called before init(); dropping data")
        return

    if step is not None:
        _step = step

    try:
        writer.log(data, step=_step, **kwargs)
    except Exception as exc:
        _warn(f"failed to log data: {exc}")

    _step = _step + 1
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def finish() -> None:
    """Close the active writer, if any, and unregister it.

    A failure while closing is reported as a warning; the writer is
    unregistered either way.
    """
    global _active_writer

    writer = _active_writer
    if writer is None:
        return

    try:
        writer.close()
    except Exception as exc:
        _warn(f"failed to close writer during finish: {exc}")
    _active_writer = None
|
vibetrack/_graph.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
1
|
+
"""Self-contained PyTorch model-graph capture & rendering.
|
|
2
|
+
|
|
3
|
+
No external dependencies beyond the ones vibetrack already needs (torch at
|
|
4
|
+
call site, matplotlib + numpy for rendering). No torchviz / graphviz / TB
|
|
5
|
+
protobuf — we draw the diagram ourselves.
|
|
6
|
+
|
|
7
|
+
Public surface:
|
|
8
|
+
capture_graph(model, input_to_model) -> List[dict]
|
|
9
|
+
render_graph_png(layers, header_text) -> np.ndarray (HWC uint8)
|
|
10
|
+
human_params(n) -> str
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
from typing import Any, List, Optional, Sequence
|
|
16
|
+
|
|
17
|
+
# ── Capture ──────────────────────────────────────────────────────────────
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _shape_of_first(obj: Any) -> Optional[List[int]]:
    """Return the shape of the first tensor-like value found in *obj*.

    *obj* may be a single object with an iterable ``shape`` attribute, a
    tuple/list, or a dict; containers are searched depth-first in iteration
    order (dicts by value). Returns ``None`` when nothing with a usable shape
    is found. Used inside hooks where the call signature is unknown.
    """
    if obj is None:
        return None

    # An object carrying an iterable ``shape`` is treated as a tensor.
    dims = getattr(obj, "shape", None)
    if dims is not None and hasattr(dims, "__iter__"):
        try:
            return [int(d) for d in dims]
        except Exception:
            # Not convertible to ints; fall through to the container checks.
            pass

    # Containers: the first nested hit wins.
    if isinstance(obj, (tuple, list)):
        candidates = obj
    elif isinstance(obj, dict):
        candidates = obj.values()
    else:
        return None

    for candidate in candidates:
        found = _shape_of_first(candidate)
        if found is not None:
            return found
    return None
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def capture_graph(model: Any, input_to_model: Any) -> List[dict]:
    """Run one forward pass under hooks; return per-leaf-module records.

    Each record: ``{path, class_name, in_shape, out_shape, n_params}``.
    Records are appended in the order the leaf modules are called, so a
    module invoked several times during forward yields several records.

    The model is switched to ``eval()`` and gradients are disabled. On the
    way out the training flag of every module is put back exactly as it was
    found, including submodules that were individually in eval mode.

    Args:
        model: Object exposing ``named_modules()`` (e.g. a ``torch.nn.Module``).
        input_to_model: A tuple/list is unpacked as positional arguments, a
            dict as keyword arguments; anything else is passed as the single
            argument.

    Raises ``RuntimeError`` (propagated from forward) if the model crashes
    on the given input — caller decides whether to fall back to a static
    walk or skip rendering altogether.
    """
    import torch  # local: torch is optional for the rest of vibetrack

    layers: List[dict] = []

    def make_hook(path: str):
        def hook(module: Any, inputs: Any, output: Any) -> None:
            # Best effort: a failure while inspecting one layer must not
            # abort the forward pass, so the record degrades to "unknown".
            try:
                in_shape = _shape_of_first(inputs)
                out_shape = _shape_of_first(output)
                n_params = sum(p.numel() for p in module.parameters(recurse=False))
            except Exception:
                in_shape = out_shape = None
                n_params = 0
            layers.append(
                {
                    "path": path or "<root>",
                    "class_name": type(module).__name__,
                    "in_shape": in_shape,
                    "out_shape": out_shape,
                    "n_params": int(n_params),
                }
            )

        return hook

    named = list(model.named_modules())
    # Remember each module's own flag: model.train() would force *every*
    # submodule into training mode, clobbering deliberately frozen ones.
    saved_modes = [(m, bool(m.training)) for _, m in named if hasattr(m, "training")]
    handles = []
    try:
        # Registered inside the try so a failure midway still removes the
        # hooks that were already attached.
        for name, mod in named:
            if not list(mod.children()):  # leaf modules only
                handles.append(mod.register_forward_hook(make_hook(name)))

        if hasattr(model, "eval"):
            model.eval()
        with torch.no_grad():
            if isinstance(input_to_model, (tuple, list)):
                model(*input_to_model)
            elif isinstance(input_to_model, dict):
                model(**input_to_model)
            else:
                model(input_to_model)
    finally:
        for h in handles:
            try:
                h.remove()
            except Exception:
                pass
        for mod, was_training in saved_modes:
            mod.training = was_training

    return layers
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def static_graph(model: Any) -> List[dict]:
    """Describe the leaf modules of *model* without running forward.

    Walks ``named_modules`` and emits one record per module that has no
    children. No shapes are recorded; ``in_shape`` and ``out_shape`` are
    ``None``. Used when forward fails or no input is provided.
    """

    def describe(path: str, module: Any) -> dict:
        # Count only the parameters owned directly by this module.
        own_params = module.parameters(recurse=False)
        return {
            "path": path or "<root>",
            "class_name": type(module).__name__,
            "in_shape": None,
            "out_shape": None,
            "n_params": int(sum(p.numel() for p in own_params)),
        }

    return [
        describe(path, module)
        for path, module in model.named_modules()
        if not list(module.children())  # keep leaves only
    ]
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
# ── Render ───────────────────────────────────────────────────────────────
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
# (substring, family-key, color) triples describing how layer classes are
# grouped for colouring.
_TYPE_FAMILIES: Sequence[tuple] = (
    ("Conv", "conv", "#bbdefb"),  # blue — Conv1d/2d/3d, ConvTranspose
    ("Linear", "linear", "#c8e6c9"),  # green
    ("Norm", "norm", "#ffe0b2"),  # orange — BatchNorm*, LayerNorm
    ("Pool", "pool", "#e1bee7"),  # purple
    ("Drop", "drop", "#ffcdd2"),  # red
    ("Embedding", "embed", "#b2dfdb"),  # teal
) + tuple(
    # Activation family — explicit list (no shared substring), all grey.
    (activation, "act", "#eeeeee")
    for activation in (
        "ReLU",
        "GELU",
        "SiLU",
        "Sigmoid",
        "Tanh",
        "Softmax",
        "LeakyReLU",
    )
)
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def _color_for(class_name: str) -> str:
    """Return the fill colour for a layer class.

    The first ``_TYPE_FAMILIES`` entry whose substring occurs in
    *class_name* decides the colour; unknown classes are white.
    """
    matches = (color for needle, _key, color in _TYPE_FAMILIES if needle in class_name)
    return next(matches, "#ffffff")  # unknown → white
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def human_params(n: int) -> str:
    """Format a parameter count as a short human-readable string.

    Counts below 1000 are printed as plain integers; larger counts are
    scaled to one decimal place with a ``K``, ``M`` or ``B`` suffix
    (for example ``1536 -> "1.5K"``). Values that would round up to
    ``1000.0`` of a unit are promoted to the next unit, so ``999_999``
    becomes ``"1.0M"`` rather than ``"1000.0K"``. ``B`` is the largest unit.
    """
    if n < 1_000:
        return str(int(n))

    text = suffix = ""
    for divisor, suffix in ((1_000, "K"), (1_000_000, "M"), (1_000_000_000, "B")):
        text = f"{n / divisor:.1f}"
        # Check the *rounded* text: rounding may carry into a fourth digit.
        if float(text) < 1_000:
            break
    return text + suffix
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def _shape_str(shape: Optional[List[int]]) -> str:
    """Render a shape as ``[d0,d1,...]``, or ``?`` when it is unknown."""
    if shape is None:
        return "?"
    dims = ",".join(map(str, shape))
    return f"[{dims}]"
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def render_graph_png(layers: List[dict], header_text: str) -> Any:
    """Render layers as a top-down stack of boxes. Returns HWC uint8 ndarray.

    Each layer record contributes one rounded box showing its path (truncated
    when longer than 48 characters), class name, ``in → out`` shapes and, when
    non-zero, its parameter count; consecutive boxes are joined by an arrow.
    *header_text* is drawn as a bold title above the stack.

    The figure is built on a private Agg canvas, so neither the caller's
    matplotlib backend nor the pyplot figure registry is touched.

    Raises ImportError if matplotlib/numpy aren't installed.
    """
    import numpy as np
    from matplotlib.backends.backend_agg import FigureCanvasAgg
    from matplotlib.figure import Figure
    from matplotlib.patches import FancyBboxPatch

    n = max(1, len(layers))
    box_h = 0.65  # inches per layer row
    fig_w = 7.0
    fig_h = 0.7 + n * box_h

    # A standalone Figure (instead of pyplot + matplotlib.use("Agg")) avoids
    # switching the process-wide backend and needs no explicit close.
    fig = Figure(figsize=(fig_w, fig_h))
    canvas = FigureCanvasAgg(fig)
    ax = fig.add_subplot(111)
    # Use unit y-axis: 1 row = 1 unit. Header occupies y in [0, 0.7].
    ax.set_xlim(0, 1)
    ax.set_ylim(0, n + 0.7)
    ax.invert_yaxis()
    ax.axis("off")

    # Header
    ax.text(
        0.5,
        0.35,
        header_text,
        ha="center",
        va="center",
        fontsize=11,
        fontweight="bold",
    )

    for i, layer in enumerate(layers):
        y0 = 0.7 + i + 0.05  # top edge of the box
        height = 0.85  # box height in y-units
        color = _color_for(layer["class_name"])
        ax.add_patch(
            FancyBboxPatch(
                (0.04, y0),
                0.92,
                height,
                boxstyle="round,pad=0.015",
                facecolor=color,
                edgecolor="#444",
                linewidth=0.8,
            )
        )

        path = layer["path"] or "<root>"
        if len(path) > 48:
            path = path[:45] + "…"
        ax.text(
            0.5,
            y0 + 0.20,
            path,
            ha="center",
            va="center",
            fontsize=8,
            fontweight="bold",
        )
        ax.text(
            0.5,
            y0 + 0.42,
            layer["class_name"],
            ha="center",
            va="center",
            fontsize=7,
            style="italic",
            color="#222",
        )
        shape_txt = (
            f"{_shape_str(layer['in_shape'])} → " f"{_shape_str(layer['out_shape'])}"
        )
        ax.text(
            0.5,
            y0 + 0.65,
            shape_txt,
            ha="center",
            va="center",
            fontsize=6.5,
            family="monospace",
            color="#333",
        )
        if layer["n_params"]:
            ax.text(
                0.96,
                y0 + 0.42,
                human_params(layer["n_params"]),
                ha="right",
                va="center",
                fontsize=6.5,
                color="#555",
            )

        # Connector from the gap above this box down to its top edge.
        if i > 0:
            ax.annotate(
                "",
                xy=(0.5, y0),
                xytext=(0.5, y0 - 0.05),
                arrowprops=dict(arrowstyle="->", color="#888", lw=0.8),
            )

    canvas.draw()
    # Drop the alpha channel and copy so the result owns its memory.
    return np.asarray(canvas.buffer_rgba())[:, :, :3].copy()
|