arrscope 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arrscope/__init__.py +4 -0
- arrscope/__main__.py +80 -0
- arrscope/_api.py +131 -0
- arrscope/_config.py +14 -0
- arrscope/_format.py +58 -0
- arrscope/_layout.py +151 -0
- arrscope/_types.py +39 -0
- arrscope/adapters/__init__.py +3 -0
- arrscope/adapters/_core.py +33 -0
- arrscope/renderers/__init__.py +4 -0
- arrscope/renderers/html.py +295 -0
- arrscope/renderers/terminal.py +277 -0
- arrscope-0.1.0.dist-info/METADATA +185 -0
- arrscope-0.1.0.dist-info/RECORD +16 -0
- arrscope-0.1.0.dist-info/WHEEL +4 -0
- arrscope-0.1.0.dist-info/entry_points.txt +2 -0
arrscope/__init__.py
ADDED
arrscope/__main__.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import argparse
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from arrscope import scope
|
|
8
|
+
|
|
9
|
+
# Demo presets for the CLI: preset name -> shape of the array to generate.
SHAPES: dict[str, tuple] = dict(
    [
        ("1d", (6,)),
        ("2d", (6, 6)),
        ("3d", (3, 4, 5)),
        ("4d", (2, 3, 4, 5)),
        ("5d", (2, 2, 3, 4, 5)),
        ("eye", (5, 5)),
        ("trunc", (20, 4, 5)),
    ]
)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def main(argv: list[str] | None = None) -> None:
    """Run the ``arrscope`` demo CLI.

    Builds a random array of the requested shape and prints its
    visualization.

    Args:
        argv: Arguments to parse. ``None`` (the default) lets argparse read
            ``sys.argv[1:]``, which is what the console entry point relies on.
    """
    parser = argparse.ArgumentParser(
        prog="arrscope",
        description="Visualize n-dimensional arrays in the terminal.",
    )
    parser.add_argument(
        "shape",
        nargs="?",
        default="4d",
        help=f"Shape preset or custom dims like 3x4x5. Presets: {', '.join(SHAPES)}",
    )
    parser.add_argument("--axes", "-a", nargs="*", help="Axis names")
    parser.add_argument(
        "--grid", "-g", nargs="*", default=None, help="Grid axis names"
    )
    parser.add_argument(
        "--title", "-t", default=None, help="Title for the visualization"
    )
    parser.add_argument(
        "--max-height", type=int, default=None, help="Max rows before truncation"
    )
    parser.add_argument("--no-color", action="store_true", help="Disable color")
    parser.add_argument(
        "--mode", "-m",
        default="dtype",
        choices=["dtype", "heatmap", "sparsity"],
        help="Color mode: dtype (default), heatmap (diverging), sparsity (highlight non-zero)",
    )
    parser.add_argument(
        "--no-stats", action="store_true", help="Hide stats overlay"
    )

    args = parser.parse_args(argv)

    if args.shape in SHAPES:
        shape = SHAPES[args.shape]
    else:
        # Report a malformed shape as a usage error instead of a traceback.
        try:
            shape = tuple(int(d) for d in args.shape.split("x"))
        except ValueError:
            parser.error(
                f"invalid shape {args.shape!r}: expected a preset "
                f"({', '.join(SHAPES)}) or dims like 3x4x5"
            )
        if any(d < 0 for d in shape):
            parser.error(f"invalid shape {args.shape!r}: dims must be non-negative")

    arr = np.random.rand(*shape)

    kw = {}
    if args.axes:
        kw["axes"] = args.axes
    if args.grid:
        # argparse hands every token over as a string; turn numeric tokens
        # back into integer axis indices so "--grid 1 2" works without --axes.
        kw["grid"] = [
            int(g) if g.lstrip("-").isdigit() else g for g in args.grid
        ]
    if args.title:
        kw["title"] = args.title
    if args.max_height:
        kw["max_height"] = args.max_height
    if args.no_color:
        kw["color"] = False
    kw["mode"] = args.mode
    kw["stats"] = not args.no_stats

    output = scope(arr, **kw)
    print(output)


if __name__ == "__main__":
    main()
|
arrscope/_api.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from arrscope._config import config
|
|
4
|
+
from arrscope._format import format_value
|
|
5
|
+
from arrscope._layout import build_layout, compute_stats
|
|
6
|
+
from arrscope._types import VizOutput
|
|
7
|
+
from arrscope.adapters import to_numpy
|
|
8
|
+
from arrscope.renderers import render_html, render_terminal
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def scope(
    arr,
    *,
    axes: list[str] | None = None,
    grid: list[str | int] | None = None,
    title: str | None = None,
    max_width: int | None = None,
    max_height: int | None = None,
    fmt: str | None = None,
    color: bool | None = None,
    style: str = "auto",
    mode: str | None = None,
    stats: bool | None = None,
) -> VizOutput:
    """Visualize an array-like object as HTML and/or ANSI text.

    Options left at ``None`` fall back to the module-level ``config``.
    ``style`` picks the renderers: "html", "terminal", or anything else for
    automatic selection (terminal always, HTML when running in Jupyter).

    Raises:
        ValueError: if ``mode`` is not "dtype", "heatmap" or "sparsity".
    """
    # NOTE(review): `fmt` is accepted but not used anywhere in this function,
    # and the resolved `max_width` is not forwarded to layout or renderers.
    data = to_numpy(arr)

    if color is None:
        color = config.color
    max_width = max_width or config.max_width
    max_height = max_height or config.max_height
    mode = mode or config.mode
    show_stats = stats if stats is not None else config.stats

    if mode not in ("dtype", "heatmap", "sparsity"):
        raise ValueError(f"Unknown mode '{mode}'. Expected one of: dtype, heatmap, sparsity")

    node = build_layout(
        data,
        grid_dims=_resolve_grid_dims(data, axes, grid),
        max_height=max_height,
    )

    # Heatmap colouring needs the global min/max even when the stats line
    # itself is hidden.
    global_stats = None
    if show_stats or mode == "heatmap":
        global_stats = compute_stats(data)
    if global_stats and show_stats:
        node.stats = global_stats

    want_html = _use_html(style)
    want_terminal = _use_terminal(style)

    html_text = ""
    if want_html:
        rendered = render_html(node, color=color, mode=mode, global_stats=global_stats)
        html_text = _wrap_html(rendered, title)

    ansi_text = ""
    if want_terminal:
        rendered = render_terminal(node, color=color, mode=mode, global_stats=global_stats)
        ansi_text = _wrap_ansi(rendered, title)

    return VizOutput(html=html_text, ansi=ansi_text)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _resolve_grid_dims(
|
|
67
|
+
arr: np.ndarray,
|
|
68
|
+
axes: list[str] | None,
|
|
69
|
+
grid: list[str | int] | None,
|
|
70
|
+
) -> list[int] | None:
|
|
71
|
+
if grid is None:
|
|
72
|
+
return None
|
|
73
|
+
|
|
74
|
+
if axes is None:
|
|
75
|
+
return [g if isinstance(g, int) else arr.ndim - 2 + (list(axes or []).index(g))
|
|
76
|
+
for g in grid]
|
|
77
|
+
|
|
78
|
+
name_to_idx = {name: i for i, name in enumerate(axes)}
|
|
79
|
+
resolved: list[int] = []
|
|
80
|
+
for g in grid:
|
|
81
|
+
if isinstance(g, str):
|
|
82
|
+
if g not in name_to_idx:
|
|
83
|
+
raise ValueError(
|
|
84
|
+
f"Axis '{g}' not found in axes={axes}. "
|
|
85
|
+
f"Available axes: {list(name_to_idx.keys())}"
|
|
86
|
+
)
|
|
87
|
+
resolved.append(name_to_idx[g])
|
|
88
|
+
else:
|
|
89
|
+
resolved.append(g)
|
|
90
|
+
return resolved
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def _use_html(style: str) -> bool:
|
|
94
|
+
if style == "html":
|
|
95
|
+
return True
|
|
96
|
+
if style == "terminal":
|
|
97
|
+
return False
|
|
98
|
+
return _in_jupyter()
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def _use_terminal(style: str) -> bool:
|
|
102
|
+
if style == "terminal":
|
|
103
|
+
return True
|
|
104
|
+
if style == "html":
|
|
105
|
+
return False
|
|
106
|
+
return True
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _in_jupyter() -> bool:
|
|
110
|
+
try:
|
|
111
|
+
from IPython import get_ipython
|
|
112
|
+
|
|
113
|
+
shell = get_ipython()
|
|
114
|
+
return shell is not None and "IPKernelApp" in shell.config
|
|
115
|
+
except (ImportError, AttributeError):
|
|
116
|
+
return False
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def _wrap_html(body: str, title: str | None) -> str:
|
|
120
|
+
if title:
|
|
121
|
+
return f'<h3 style="font-family: sans-serif; margin: 4px 0;">{title}</h3>\n{body}'
|
|
122
|
+
return body
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def _wrap_ansi(body: str, title: str | None) -> str:
|
|
126
|
+
if title:
|
|
127
|
+
from rich.text import Text
|
|
128
|
+
|
|
129
|
+
t = Text(title, style="bold")
|
|
130
|
+
return f"{t}\n{body}"
|
|
131
|
+
return body
|
arrscope/_config.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
@dataclass
class Config:
    """Package-wide defaults used when a call leaves an option unset."""

    # Display limits; None means "no limit configured".
    max_width: int | None = None
    max_height: int | None = None
    # Whether output is coloured by default.
    color: bool = True
    # Numeric precision setting; None leaves it unset.
    precision: int | None = None
    # Default colour mode name.
    mode: str = "dtype"
    # Whether the stats summary is shown by default.
    stats: bool = True


# Shared, mutable instance that the public API reads its defaults from.
config: Config = Config()
|
arrscope/_format.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def format_value(val, fmt: str | None = None) -> str:
    """Format one array element for display.

    Args:
        val: A Python or NumPy scalar, ``None``, or any other object.
        fmt: Optional format spec, applied to float and complex parts only.

    Returns:
        "NaN" for ``None``, "T"/"F" for booleans, a compact rendering for
        floats and complex numbers, the plain integer for NumPy integers,
        and ``str(val)`` for everything else.
    """
    if val is None:
        return "NaN"
    if isinstance(val, (bool, np.bool_)):
        return "T" if val else "F"
    if isinstance(val, (float, np.floating)):
        # NaN and infinities are handled inside _format_float.
        return _format_float(float(val), fmt)
    if isinstance(val, np.integer):
        return str(int(val))
    # Include builtin complex so values from ndarray.item() render the same
    # way as complex array cells.
    if isinstance(val, (complex, np.complexfloating)):
        return _format_complex(val, fmt)
    return str(val)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _format_float(val: float, fmt: str | None) -> str:
|
|
27
|
+
if fmt is not None:
|
|
28
|
+
return f"{val:{fmt}}"
|
|
29
|
+
if val == 0.0:
|
|
30
|
+
return "0"
|
|
31
|
+
if math.isnan(val):
|
|
32
|
+
return "NaN"
|
|
33
|
+
if math.isinf(val):
|
|
34
|
+
return "inf" if val > 0 else "-inf"
|
|
35
|
+
abs_val = abs(val)
|
|
36
|
+
if abs_val >= 1e4 or abs_val < 1e-4:
|
|
37
|
+
return f"{val:.4e}"
|
|
38
|
+
if abs_val >= 1:
|
|
39
|
+
return f"{val:.4f}".rstrip("0").rstrip(".")
|
|
40
|
+
return f"{val:.4f}".rstrip("0").rstrip(".")
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _format_complex(val: complex, fmt: str | None) -> str:
    """Render a complex number as ``<real><sign><|imag|>j``.

    Both parts go through ``_format_float``. The sign is "+" when the
    imaginary part is >= 0 and "-" otherwise.
    """
    imag_part = val.imag
    real_text = _format_float(val.real, fmt)
    imag_text = _format_float(abs(imag_part), fmt)
    if imag_part >= 0:
        return f"{real_text}+{imag_text}j"
    return f"{real_text}-{imag_text}j"
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def auto_precision(arr: np.ndarray) -> int:
    """Pick a display precision (2..8) from the array's dynamic range.

    Only finite values are considered. With no finite values the result is
    2; with finite values that are all zero the assumed log-range is 2.
    Otherwise the precision is ``int(log10(max|x|) - log10(min nonzero |x|))
    + 2``, clamped to [2, 8]. A NaN or infinite range yields 4.
    """
    finite = arr.ravel()
    finite = finite[np.isfinite(finite)]
    if len(finite) == 0:
        return 2

    nonzero_mask = finite != 0
    if np.any(nonzero_mask):
        magnitudes = np.abs(finite)
        largest = np.max(magnitudes)
        smallest = np.min(magnitudes[nonzero_mask])
        log_range = np.log10(largest) - np.log10(smallest)
    else:
        log_range = 2

    if np.isnan(log_range) or np.isinf(log_range):
        return 4
    digits = int(log_range) + 2
    return min(max(digits, 2), 8)
|
arrscope/_layout.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from arrscope._types import VizNode
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def build_layout(
    arr: np.ndarray,
    *,
    grid_dims: list[int] | None = None,
    max_height: int | None = None,
) -> VizNode:
    """Build the VizNode tree describing how ``arr`` is laid out.

    Axes listed in ``grid_dims`` (default: the last two, or fewer for
    low-rank arrays) form the leaf grids; every other axis becomes a level
    of the tree. The root is stamped with the full shape and a
    "(shape) dtype" label.
    """
    full_shape = arr.shape
    rank = len(full_shape)

    if grid_dims is None:
        first_grid_axis = max(0, rank - 2)
        grid_dims = list(range(first_grid_axis, rank))

    tree_axes = [axis for axis in range(rank) if axis not in grid_dims]

    root = _build_tree(arr, tree_axes, grid_dims, (), max_height)
    root.shape = full_shape
    root.label = _shape_label(full_shape, arr.dtype)
    return root
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def compute_stats(arr: np.ndarray) -> dict:
    """Compute summary statistics over the whole array.

    Returns:
        A dict that always contains ``nan_count``, ``zero_count``,
        ``total``, ``sparsity`` (fraction of exact zeros, 0.0 for an empty
        array) and ``dtype``. Float, integer and boolean arrays also get
        ``min``/``max``/``mean``/``std``. For float and integer arrays
        these cover finite values only. They are all NaN when there is
        nothing to summarise, including empty arrays of any of these
        dtypes.
    """
    flat = arr.ravel()
    stats: dict = {}

    is_float = np.issubdtype(arr.dtype, np.floating)
    is_numeric = is_float or np.issubdtype(arr.dtype, np.integer)
    is_bool = np.issubdtype(arr.dtype, np.bool_)

    if is_numeric or is_bool:
        sample = flat[np.isfinite(flat)] if is_numeric else flat
        if sample.size > 0:
            # Boolean extremes are reported as ints (0/1), numeric as floats.
            cast = float if is_numeric else int
            stats["min"] = cast(sample.min())
            stats["max"] = cast(sample.max())
            stats["mean"] = float(sample.mean())
            stats["std"] = float(sample.std())
        else:
            # Reductions such as min() raise on zero-size input.
            stats["min"] = stats["max"] = stats["mean"] = stats["std"] = float("nan")

    nan_count = int(np.isnan(flat).sum()) if is_float else 0
    zero_count = int((flat == 0).sum())
    total = flat.size
    stats["nan_count"] = nan_count
    stats["zero_count"] = zero_count
    stats["total"] = total
    stats["sparsity"] = zero_count / total if total > 0 else 0.0
    stats["dtype"] = str(arr.dtype)

    return stats
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _build_tree(
    arr: np.ndarray,
    non_grid: list[int],
    grid_dims: list[int],
    indices: tuple[int, ...],
    max_height: int | None,
) -> VizNode:
    """Recursively expand the non-grid axes of ``arr`` into a VizNode tree.

    Args:
        arr: The (sub)array at this level.
        non_grid: Axes of ``arr`` still to expand, expressed relative to
            ``arr``; the first entry is expanded at this level.
        grid_dims: Grid axes, passed through the recursion unchanged.
        indices: Index path from the root to this node.
        max_height: Maximum children per level before the middle is
            replaced by a single "... (N more)" placeholder; ``None``
            disables truncation.

    Returns:
        A leaf holding ``arr`` as ``grid_data`` once no non-grid axes
        remain, otherwise an inner node with one child per visible index.
    """
    if not non_grid:
        return VizNode(
            indices=indices,
            label=_indices_label(indices),
            shape=arr.shape,
            grid_data=arr,
        )

    current_dim = non_grid[0]
    size = arr.shape[current_dim]
    # Slicing away current_dim shifts every later axis down by one; this is
    # the same for every child, so compute it once.
    remaining = _adjust_indices(non_grid[1:], current_dim)

    head_count, tail_count = _truncation_limits(size, max_height)
    # Decide truncation from the limits themselves. Comparing len(children)
    # with size misses the case where exactly one element is hidden, because
    # the placeholder then makes the counts equal.
    is_truncated = (
        max_height is not None
        and size > 0
        and size > head_count + tail_count
    )

    def build_child(i: int) -> VizNode:
        return _build_tree(
            _slice_dim(arr, current_dim, i),
            remaining,
            grid_dims,
            indices + (i,),
            max_height,
        )

    children: list[VizNode] = []
    if is_truncated:
        # Visit only the visible head and tail, not every hidden index.
        children.extend(build_child(i) for i in range(head_count))
        children.append(
            VizNode(
                indices=(),
                label=f"... ({size - head_count - tail_count} more)",
                truncated=True,
            )
        )
        children.extend(build_child(i) for i in range(size - tail_count, size))
    else:
        children.extend(build_child(i) for i in range(size))

    return VizNode(
        indices=indices,
        label=_indices_label(indices),
        shape=(size,) + arr.shape[current_dim + 1 :],
        children=children,
        truncated=is_truncated,
        truncated_head=head_count,
        truncated_tail=tail_count if is_truncated else 0,
    )
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def _slice_dim(arr: np.ndarray, dim: int, index: int) -> np.ndarray:
|
|
121
|
+
slices = [slice(None)] * arr.ndim
|
|
122
|
+
slices[dim] = index
|
|
123
|
+
return arr[tuple(slices)]
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def _adjust_indices(
|
|
127
|
+
remaining: list[int], removed_dim: int
|
|
128
|
+
) -> list[int]:
|
|
129
|
+
return [d - 1 if d > removed_dim else d for d in remaining]
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def _indices_label(indices: tuple[int, ...]) -> str:
|
|
133
|
+
if not indices:
|
|
134
|
+
return ""
|
|
135
|
+
return "[" + ", ".join(str(i) for i in indices) + "]"
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def _shape_label(shape: tuple[int, ...], dtype: np.dtype) -> str:
|
|
139
|
+
return f"({', '.join(str(s) for s in shape)}) {dtype}"
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def _truncation_limits(
|
|
143
|
+
size: int, max_height: int | None
|
|
144
|
+
) -> tuple[int, int]:
|
|
145
|
+
if max_height is None:
|
|
146
|
+
return size, 0
|
|
147
|
+
if size <= max_height:
|
|
148
|
+
return size, 0
|
|
149
|
+
head = math.ceil(max_height / 2)
|
|
150
|
+
tail = max_height // 2
|
|
151
|
+
return head, tail
|
arrscope/_types.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
class VizNode:
    """One node of the layout tree handed to the renderers.

    A node is a leaf carrying ``grid_data``, an inner node carrying
    ``children``, or a placeholder with ``truncated`` set and no children.
    """

    # Index path from the root to this node.
    indices: tuple[int, ...] = ()
    # Text shown for this node.
    label: str = ""
    # Shape recorded for this node.
    shape: tuple[int, ...] = ()
    # Sub-nodes, in display order.
    children: list[VizNode] = field(default_factory=list)
    # Array rendered as a grid at a leaf; None elsewhere.
    grid_data: np.ndarray | None = None
    # Truncation bookkeeping.
    truncated: bool = False
    truncated_head: int = 0
    truncated_tail: int = 0
    # Summary statistics attached to the node, if any.
    stats: dict | None = None
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass
class VizOutput:
    """Rendered visualization in up to two flavours.

    ``html`` and ``ansi`` hold the rendered text, or "" when that renderer
    was not run. ``str()`` and ``repr()`` both return the ANSI text.
    """

    html: str = ""
    ansi: str = ""
    metadata: dict = field(default_factory=dict)

    def _repr_html_(self) -> str | None:
        """Return the HTML for notebook display, or None when there is none.

        Returning None rather than "" lets IPython skip the HTML format and
        fall back to the text representation instead of a blank cell.
        """
        return self.html or None

    def __rich_console__(self, console, options):
        """Yield the ANSI output as a Rich ``Text`` for ``rich`` printing."""
        from rich.text import Text

        yield Text.from_ansi(self.ansi)

    def __str__(self) -> str:
        return self.ansi

    def __repr__(self) -> str:
        return self.ansi
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def to_numpy(arr) -> np.ndarray:
    """Convert a supported array-like object to a NumPy array.

    Args:
        arr: An ndarray (returned as-is), a framework tensor, an object
            implementing ``__array__``, or anything ``np.asarray`` accepts.

    Raises:
        TypeError: if the generic ``np.asarray`` fallback fails.
    """
    if isinstance(arr, np.ndarray):
        return arr

    module = getattr(type(arr), "__module__", "") or ""

    # Tensor-like torch objects are handled before the generic __array__
    # path so the detach()/cpu() steps are always applied first. Requiring
    # `detach` keeps non-tensor objects from torch modules on the generic
    # path.
    if "torch" in module and hasattr(arr, "detach"):
        return arr.detach().cpu().numpy()

    if hasattr(arr, "__array__"):
        return np.asarray(arr)

    if "tensorflow" in module or "keras" in module:
        return arr.numpy()

    if "jax" in module:
        return np.asarray(arr)

    try:
        return np.asarray(arr)
    except (TypeError, ValueError) as e:
        raise TypeError(
            f"Cannot convert {type(arr).__name__} to numpy array. "
            f"Supported types: np.ndarray, torch.Tensor, tf.Tensor, jax.Array, "
            f"and objects implementing __array__."
        ) from e
|
|
@@ -0,0 +1,295 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import html as html_mod
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from arrscope._format import format_value
|
|
8
|
+
from arrscope._types import VizNode
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def render_html(
    node: VizNode,
    color: bool = True,
    mode: str = "dtype",
    global_stats: dict | None = None,
) -> str:
    """Render a layout tree as a self-contained HTML fragment.

    The fragment is a wrapper ``<div>`` holding an inline stylesheet, the
    rendered tree, and a stats line. The stats line is emitted only when
    ``global_stats`` is given and the root node carries stats.
    """
    # NOTE(review): `color` is accepted but not used in this function, so
    # cell colours are emitted regardless - confirm whether that is intended.
    stylesheet = _css()
    tree_markup = _render_node_html(node, 0, mode, global_stats)

    summary = _stats_html(global_stats) if (global_stats and node.stats) else ""

    return f"""<div class="arrscope-wrapper">
<style>{stylesheet}</style>
{tree_markup}
{summary}
</div>"""
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _css() -> str:
|
|
32
|
+
return """
|
|
33
|
+
.arrscope-wrapper {
|
|
34
|
+
font-family: 'SF Mono', 'Fira Code', 'Consolas', monospace;
|
|
35
|
+
font-size: 13px;
|
|
36
|
+
line-height: 1.5;
|
|
37
|
+
margin: 8px 0;
|
|
38
|
+
}
|
|
39
|
+
.arrscope-wrapper table {
|
|
40
|
+
border-collapse: collapse;
|
|
41
|
+
margin: 4px 0;
|
|
42
|
+
}
|
|
43
|
+
.arrscope-wrapper td, .arrscope-th {
|
|
44
|
+
padding: 2px 6px;
|
|
45
|
+
text-align: right;
|
|
46
|
+
white-space: pre;
|
|
47
|
+
border: 1px solid #d0d0d0;
|
|
48
|
+
}
|
|
49
|
+
.arrscope-th {
|
|
50
|
+
font-weight: 600;
|
|
51
|
+
background: #f5f5f5;
|
|
52
|
+
text-align: center;
|
|
53
|
+
}
|
|
54
|
+
.arrscope-label {
|
|
55
|
+
font-weight: 600;
|
|
56
|
+
margin: 4px 0 2px;
|
|
57
|
+
color: #333;
|
|
58
|
+
}
|
|
59
|
+
.arrscope-shape {
|
|
60
|
+
color: #888;
|
|
61
|
+
font-size: 12px;
|
|
62
|
+
margin-left: 8px;
|
|
63
|
+
}
|
|
64
|
+
.arrscope-row { display: flex; flex-wrap: wrap; gap: 12px; }
|
|
65
|
+
.arrscope-tree { margin-left: 16px; }
|
|
66
|
+
.arrscope-ellipsis { color: #999; font-style: italic; margin: 2px 0; }
|
|
67
|
+
.arrscope-header {
|
|
68
|
+
font-weight: 600;
|
|
69
|
+
padding: 4px 0;
|
|
70
|
+
border-bottom: 1px solid #e0e0e0;
|
|
71
|
+
margin-bottom: 4px;
|
|
72
|
+
}
|
|
73
|
+
.arrscope-stats {
|
|
74
|
+
font-size: 12px;
|
|
75
|
+
color: #666;
|
|
76
|
+
margin: 2px 0 4px;
|
|
77
|
+
}
|
|
78
|
+
details.arrscope-details > summary {
|
|
79
|
+
cursor: pointer;
|
|
80
|
+
list-style: '⊞ ';
|
|
81
|
+
}
|
|
82
|
+
details.arrscope-details[open] > summary {
|
|
83
|
+
list-style: '⊟ ';
|
|
84
|
+
}
|
|
85
|
+
|
|
86
|
+
@media (prefers-color-scheme: dark) {
|
|
87
|
+
.arrscope-wrapper td {
|
|
88
|
+
border-color: #444;
|
|
89
|
+
}
|
|
90
|
+
.arrscope-th {
|
|
91
|
+
background: #2a2a2a;
|
|
92
|
+
color: #ddd;
|
|
93
|
+
}
|
|
94
|
+
.arrscope-label { color: #ccc; }
|
|
95
|
+
.arrscope-shape { color: #777; }
|
|
96
|
+
.arrscope-stats { color: #999; }
|
|
97
|
+
.arrscope-header { border-bottom-color: #444; }
|
|
98
|
+
}
|
|
99
|
+
"""
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _stats_html(stats: dict) -> str:
|
|
103
|
+
items = []
|
|
104
|
+
for key in ("min", "max", "mean", "std"):
|
|
105
|
+
if key in stats:
|
|
106
|
+
v = stats[key]
|
|
107
|
+
if isinstance(v, float):
|
|
108
|
+
items.append(f"{key}={format_value(v)}")
|
|
109
|
+
else:
|
|
110
|
+
items.append(f"{key}={v}")
|
|
111
|
+
if "zero_count" in stats and "total" in stats:
|
|
112
|
+
pct = stats["zero_count"] / stats["total"] * 100
|
|
113
|
+
items.append(f"zeros={pct:.1f}%")
|
|
114
|
+
if "nan_count" in stats and stats["nan_count"] > 0:
|
|
115
|
+
items.append(f"NaN={stats['nan_count']}")
|
|
116
|
+
return f'<div class="arrscope-stats">{" · ".join(items)}</div>'
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def _render_node_html(
|
|
120
|
+
node: VizNode, depth: int, mode: str, global_stats: dict | None
|
|
121
|
+
) -> str:
|
|
122
|
+
if node.grid_data is not None:
|
|
123
|
+
return _render_grid_html(node.grid_data, node.label, mode, global_stats)
|
|
124
|
+
|
|
125
|
+
if not node.children and node.truncated:
|
|
126
|
+
return f'<div class="arrscope-ellipsis">{html_mod.escape(node.label)}</div>'
|
|
127
|
+
|
|
128
|
+
is_large_tree = depth > 0 and len(node.children) > 4
|
|
129
|
+
|
|
130
|
+
parts: list[str] = []
|
|
131
|
+
if depth == 0:
|
|
132
|
+
parts.append(
|
|
133
|
+
f'<div class="arrscope-header">{html_mod.escape(_format_header(node))}</div>'
|
|
134
|
+
)
|
|
135
|
+
|
|
136
|
+
if is_large_tree:
|
|
137
|
+
parts.append(
|
|
138
|
+
f'<details class="arrscope-details" {"open" if depth == 0 else ""}>'
|
|
139
|
+
)
|
|
140
|
+
parts.append(
|
|
141
|
+
f"<summary>{html_mod.escape(node.label or _format_header(node))}</summary>"
|
|
142
|
+
)
|
|
143
|
+
parts.append('<div class="arrscope-tree">')
|
|
144
|
+
|
|
145
|
+
for child in node.children:
|
|
146
|
+
parts.append(_render_node_html(child, depth + 1, mode, global_stats))
|
|
147
|
+
|
|
148
|
+
if is_large_tree:
|
|
149
|
+
parts.append("</div></details>")
|
|
150
|
+
|
|
151
|
+
return "".join(parts)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def _format_header(node: VizNode) -> str:
|
|
155
|
+
label = node.label or ""
|
|
156
|
+
shape = " × ".join(str(s) for s in node.shape)
|
|
157
|
+
if shape:
|
|
158
|
+
return f"{label} ({shape})"
|
|
159
|
+
return label
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def _render_grid_html(
    data: np.ndarray, label: str, mode: str, global_stats: dict | None
) -> str:
    """Render leaf data to HTML according to its rank.

    Rank 0 becomes a ``<span>``, rank 1 goes to the single-row renderer,
    and anything else goes to the 2-D table renderer.
    """
    rank = data.ndim
    if rank == 0:
        scalar_text = format_value(data.item())
        return f"<span>{html_mod.escape(scalar_text)}</span>"
    if rank == 1:
        return _render_1d_html(data, mode, global_stats)
    return _render_2d_html(data, label, mode, global_stats)
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def _render_1d_html(
    data: np.ndarray, mode: str, global_stats: dict | None
) -> str:
    """Render a 1-D array as a single-row HTML table of styled cells."""
    cells: list[str] = []
    for value in data:
        css = _style_html(value, mode, global_stats)
        text = html_mod.escape(_format_cell_html(value, mode))
        cells.append(f'<td style="{css}">{text}</td>')
    row = "".join(cells)
    return f"<table><tr>{row}</tr></table>"
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def _render_2d_html(
    data: np.ndarray, label: str, mode: str, global_stats: dict | None
) -> str:
    """Render a 2-D array as an HTML table.

    When ``label`` is non-empty, the table is preceded by a label ``<div>``
    that also shows the rows×cols shape. ``data`` must be exactly 2-D.
    """
    nrows, ncols = data.shape
    chunks: list[str] = []
    if label:
        chunks.append(
            f'<div class="arrscope-label">{html_mod.escape(label)} '
            f'<span class="arrscope-shape">{nrows}×{ncols}</span></div>'
        )

    chunks.append("<table>")
    for r in range(nrows):
        cells: list[str] = []
        for c in range(ncols):
            value = data[r, c]
            text = html_mod.escape(_format_cell_html(value, mode))
            css = _style_html(value, mode, global_stats)
            cells.append(f'<td style="{css}">{text}</td>')
        chunks.append("<tr>" + "".join(cells) + "</tr>")
    chunks.append("</table>")
    return "".join(chunks)
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def _format_cell_html(val, mode: str) -> str:
    """Return the cell text: "·" for zeros in sparsity mode, else the value."""
    show_dot = mode == "sparsity" and _is_zero(val)
    return "·" if show_dot else format_value(val)
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def _is_zero(val) -> bool:
|
|
215
|
+
if isinstance(val, (np.floating, float)):
|
|
216
|
+
return val == 0.0 or abs(val) < 1e-15
|
|
217
|
+
if isinstance(val, (np.integer, int)):
|
|
218
|
+
return val == 0
|
|
219
|
+
if isinstance(val, (np.bool_, bool)):
|
|
220
|
+
return not val
|
|
221
|
+
return val == 0
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def _style_html(val, mode: str, global_stats: dict | None) -> str:
    """Pick the inline CSS for one cell according to the colour mode.

    "heatmap" and "sparsity" use their own styles; every other mode uses
    the dtype-based style.
    """
    if mode not in ("heatmap", "sparsity"):
        return _dtype_style(val)
    if mode == "heatmap":
        return _heatmap_style(val, global_stats)
    return _sparsity_style(val)
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
def _heatmap_style(val, global_stats: dict | None) -> str:
|
|
233
|
+
if not global_stats:
|
|
234
|
+
return ""
|
|
235
|
+
if isinstance(val, (np.floating, float)) and (np.isnan(val) or np.isinf(val)):
|
|
236
|
+
return "color: #e74c3c; font-weight: bold;"
|
|
237
|
+
if not isinstance(val, (np.floating, float, np.integer, int)):
|
|
238
|
+
return _dtype_style(val)
|
|
239
|
+
|
|
240
|
+
vmin = global_stats.get("min", 0.0)
|
|
241
|
+
vmax = global_stats.get("max", 1.0)
|
|
242
|
+
if vmax == vmin:
|
|
243
|
+
return "color: #999;"
|
|
244
|
+
|
|
245
|
+
t = (float(val) - vmin) / (vmax - vmin)
|
|
246
|
+
t = max(0.0, min(1.0, t))
|
|
247
|
+
|
|
248
|
+
if t < 0.5:
|
|
249
|
+
t2 = t / 0.5
|
|
250
|
+
r = int(180 - 140 * t2)
|
|
251
|
+
g = int(30 + 210 * t2)
|
|
252
|
+
b = int(30 + 100 * t2)
|
|
253
|
+
else:
|
|
254
|
+
t2 = (t - 0.5) / 0.5
|
|
255
|
+
r = int(40 - 10 * t2)
|
|
256
|
+
g = int(240 - 200 * t2)
|
|
257
|
+
b = int(130 + 130 * t2)
|
|
258
|
+
|
|
259
|
+
return f"color: rgb({r},{g},{b});"
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def _sparsity_style(val) -> str:
    """Return inline CSS for sparsity mode.

    Zero values are faded grey. Non-zero integers, floats and booleans each
    get their own bold colour, checked in that order. Anything else falls
    back to the dtype-based style.
    """
    if _is_zero(val):
        return "color: #999; opacity: 0.4;"

    bold_colors = (
        ((np.integer, int), "#27ae60"),
        ((np.floating, float), "#2980b9"),
        ((np.bool_, bool), "#8e44ad"),
    )
    for kinds, hex_color in bold_colors:
        if isinstance(val, kinds):
            return f"color: {hex_color}; font-weight: bold;"
    return _dtype_style(val)
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
def _dtype_style(val) -> str:
|
|
275
|
+
if isinstance(val, (np.floating, float)):
|
|
276
|
+
if np.isnan(val) or np.isinf(val):
|
|
277
|
+
return "color: #e74c3c; font-weight: bold;"
|
|
278
|
+
if val > 0:
|
|
279
|
+
intensity = min(abs(val) / 10, 1.0)
|
|
280
|
+
b = int(180 + 75 * intensity)
|
|
281
|
+
return f"color: rgb(40, 40, {b});"
|
|
282
|
+
if val < 0:
|
|
283
|
+
intensity = min(abs(val) / 10, 1.0)
|
|
284
|
+
r = int(180 + 75 * intensity)
|
|
285
|
+
return f"color: rgb({r}, 40, 40);"
|
|
286
|
+
return "color: #2980b9;"
|
|
287
|
+
if isinstance(val, (np.integer, int)):
|
|
288
|
+
return "color: #27ae60;"
|
|
289
|
+
if isinstance(val, (np.bool_, bool)):
|
|
290
|
+
return "color: #8e44ad;" if val else "color: #8e44ad; opacity: 0.5;"
|
|
291
|
+
if isinstance(val, str):
|
|
292
|
+
return "color: #f39c12;"
|
|
293
|
+
if isinstance(val, (np.complexfloating, complex)):
|
|
294
|
+
return "color: #c0392b;"
|
|
295
|
+
return ""
|
|
@@ -0,0 +1,277 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from rich.box import SQUARE
|
|
5
|
+
from rich.color import Color
|
|
6
|
+
from rich.style import Style
|
|
7
|
+
from rich.table import Table
|
|
8
|
+
from rich.text import Text
|
|
9
|
+
from rich.tree import Tree
|
|
10
|
+
|
|
11
|
+
from arrscope._format import format_value
|
|
12
|
+
from arrscope._types import VizNode
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def render_terminal(
    node: VizNode,
    color: bool = True,
    mode: str = "dtype",
    global_stats: dict | None = None,
) -> str:
    """Render a layout tree to text by capturing Rich console output.

    The console is 120 columns wide and uses truecolor when ``color`` is
    set, or no colour system otherwise. A stats line follows the tree when
    ``global_stats`` is given and the root node carries stats.
    """
    from rich.console import Console

    console = Console(width=120, color_system="truecolor" if color else None)

    renderables: list = [_render_node(node, 0, color, mode, global_stats)]
    if global_stats and node.stats:
        renderables.append(_stats_text(global_stats, color))

    with console.capture() as captured:
        for renderable in renderables:
            console.print(renderable)
    return captured.get()
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _stats_text(stats: dict, color: bool) -> Text:
    """Build the dimmed one-line stats summary for terminal output.

    Shows whichever of min/max/mean/std are present (floats formatted with
    four decimals), the percentage of zeros when both ``zero_count`` and
    ``total`` are present, and the NaN count when it is positive.
    """
    dim = Style(dim=True, color="grey50") if color else Style(dim=True)
    t = Text()
    # Floats use the same format for every dtype.
    fmt = ".4f"

    def fmt_val(v):
        if isinstance(v, float):
            return format_value(v, fmt)
        return str(v)

    items = [
        f"{key}={fmt_val(stats[key])}"
        for key in ("min", "max", "mean", "std")
        if key in stats
    ]

    if "zero_count" in stats and "total" in stats:
        total = stats["total"]
        # An empty array has total == 0; report 0% instead of dividing by zero.
        pct = stats["zero_count"] / total * 100 if total else 0.0
        items.append(f"zeros={pct:.1f}%")

    if "nan_count" in stats and stats["nan_count"] > 0:
        items.append(f"NaN={stats['nan_count']}")

    t.append(" " + " ".join(items), style=dim)
    return t
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _render_node(
    node: VizNode,
    depth: int,
    color: bool,
    mode: str,
    global_stats: dict | None,
) -> Tree | Table | Text:
    """Turn one layout node, and its subtree, into a Rich renderable.

    Leaves render their grid and childless truncated nodes render a dimmed
    label. Every other node becomes a ``Tree`` with one branch per child.
    """
    if node.grid_data is not None:
        return _render_grid(node.grid_data, node.label, color, mode, global_stats)

    if node.truncated and not node.children:
        placeholder = Text(f" {node.label}")
        placeholder.stylize(Style(dim=True))
        return placeholder

    guide = Style(dim=True, color="grey50") if color else Style(dim=True)
    branch = Tree(_make_header(node, color), guide_style=guide)
    for child in node.children:
        branch.add(_render_node(child, depth + 1, color, mode, global_stats))
    return branch
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _make_header(node: VizNode, color: bool) -> Text:
    """Build the header text: the label, then the dimmed " × "-joined shape.

    The label is bold only when ``color`` is set. Either part is omitted
    when empty.
    """
    header = Text()
    if node.label:
        header.append(node.label, style="bold" if color else "")

    dims = " × ".join(map(str, node.shape))
    if dims:
        shape_style = Style(dim=True, color="grey50") if color else Style(dim=True)
        header.append(f" ({dims})", style=shape_style)
    return header
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def _render_grid(
    data: np.ndarray,
    label: str,
    color: bool,
    mode: str,
    global_stats: dict | None,
) -> Text | Table:
    """Render leaf array data, choosing the renderer from ``data.ndim``.

    * 0-D: the single item, formatted with ``format_value``, as plain ``Text``;
    * 1-D: the bracketed list produced by ``_render_1d``;
    * otherwise: the table produced by ``_render_2d`` (which unpacks
      ``data.shape`` into exactly two values).
    """
    rank = data.ndim
    if rank == 0:
        return Text(format_value(data.item()))
    if rank == 1:
        return _render_1d(data, color, mode, global_stats)
    return _render_2d(data, label, color, mode, global_stats)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def _render_1d(
    data: np.ndarray,
    color: bool,
    mode: str,
    global_stats: dict | None,
) -> Text:
    """Render a 1-D array as a bracketed, comma-separated ``Text`` line.

    Each element is formatted with ``_format_cell`` and styled with
    ``_style_value``; the ``", "`` separators are dimmed.
    """
    cells = [
        _style_value(_format_cell(value, mode), value, color, mode, global_stats)
        for value in data
    ]

    line = Text("[")
    for position, cell in enumerate(cells):
        if position:
            # Separator goes before every element except the first.
            line.append(", ", Style(dim=True))
        line.append(cell)
    line.append("]")
    return line
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def _render_2d(
    data: np.ndarray,
    label: str,
    color: bool,
    mode: str,
    global_stats: dict | None,
) -> Table:
    """Render a 2-D array as a boxed Rich ``Table`` with one cell per element.

    The table has no header row, uses the ``SQUARE`` box, and takes *label* as
    its title when the label is non-empty. Its minimum width grows with the
    column count (8 per column plus 4) and is capped at 120.
    """
    row_count, col_count = data.shape

    grid = Table(
        show_header=False,
        show_edge=True,
        pad_edge=False,
        box=SQUARE,
        title=label or None,
        min_width=min(col_count * 8 + 4, 120),
    )

    for _ in range(col_count):
        grid.add_column()

    for r in range(row_count):
        rendered = []
        for c in range(col_count):
            value = data[r, c]
            rendered.append(
                _style_value(
                    _format_cell(value, mode), value, color, mode, global_stats
                )
            )
        grid.add_row(*rendered)

    return grid
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def _format_cell(val, mode: str) -> str:
    """Return the display text for one array element.

    In ``"sparsity"`` mode a value that ``_is_zero`` accepts is shown as a
    middle dot; every other value is formatted by ``format_value``.
    """
    show_as_dot = mode == "sparsity" and _is_zero(val)
    return "·" if show_as_dot else format_value(val)
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def _is_zero(val) -> bool:
|
|
182
|
+
if isinstance(val, (np.floating, float)):
|
|
183
|
+
return val == 0.0 or np.abs(val) < 1e-15
|
|
184
|
+
if isinstance(val, (np.integer, int)):
|
|
185
|
+
return val == 0
|
|
186
|
+
if isinstance(val, (np.bool_, bool)):
|
|
187
|
+
return not val
|
|
188
|
+
return val == 0
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def _style_value(
    text: str,
    val,
    color: bool,
    mode: str,
    global_stats: dict | None,
) -> Text:
    """Wrap already-formatted *text* in a ``Text`` styled according to *mode*.

    With colour disabled the text is returned unstyled. Otherwise the style
    comes from ``_heatmap_color`` (``"heatmap"``), ``_sparsity_color``
    (``"sparsity"``), or ``_dtype_color`` (any other mode), each evaluated on
    the raw value *val*.
    """
    if not color:
        return Text(text)

    if mode == "heatmap":
        return Text(text, style=_heatmap_color(val, global_stats))
    if mode == "sparsity":
        return Text(text, style=_sparsity_color(val))
    return Text(text, style=_dtype_color(val))
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def _heatmap_color(val, global_stats: dict | None) -> Style:
    """Pick the heatmap style for *val* from its position between min and max.

    * No stats available -> an empty ``Style``.
    * NaN or infinite floats -> bold red.
    * Values that are neither floats nor integers -> ``_dtype_color``.
    * ``min == max`` -> flat grey50.
    * Otherwise the value is normalised to ``[0, 1]`` using the ``"min"`` and
      ``"max"`` entries of *global_stats* (defaults 0.0 and 1.0), clamped, and
      mapped through a two-segment linear RGB ramp:
      (180, 30, 30) -> (40, 240, 130) -> (30, 40, 255).
    """
    if not global_stats:
        return Style()

    is_float = isinstance(val, (np.floating, float))
    if is_float and (np.isnan(val) or np.isinf(val)):
        return Style(color="red", bold=True)
    if not (is_float or isinstance(val, (np.integer, int))):
        return _dtype_color(val)

    low = global_stats.get("min", 0.0)
    high = global_stats.get("max", 1.0)
    if high == low:
        return Style(color="grey50")

    position = (float(val) - low) / (high - low)
    position = max(0.0, min(1.0, position))

    # Each channel is ``base + slope * ramp`` with ramp running 0..1 inside
    # the selected half of the colormap.
    if position < 0.5:
        ramp = position / 0.5
        channels = ((180, -140), (30, 210), (30, 100))
    else:
        ramp = (position - 0.5) / 0.5
        channels = ((40, -10), (240, -200), (130, 130))

    red, green, blue = (int(base + slope * ramp) for base, slope in channels)
    return Style(color=Color.from_rgb(red, green, blue))
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def _sparsity_color(val) -> Style:
    """Return the style used for *val* in ``"sparsity"`` mode.

    Zeros (as judged by ``_is_zero``) are dimmed grey50. Non-zero values are
    bold and coloured by kind: magenta for booleans, green for integers, cyan
    for floats. Any other type falls back to ``_dtype_color``.
    """
    if _is_zero(val):
        return Style(dim=True, color="grey50")
    if isinstance(val, (np.bool_, bool)):
        # Must come before the integer test: Python's ``bool`` is a subclass
        # of ``int``, so a plain ``True`` would otherwise be coloured as an
        # integer and never reach this branch.
        return Style(color="magenta", bold=True)
    if isinstance(val, (np.integer, int)):
        return Style(color="green", bold=True)
    if isinstance(val, (np.floating, float)):
        return Style(color="cyan", bold=True)
    return _dtype_color(val)
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def _dtype_color(val) -> Style:
    """Return the style used for *val* in the default ``"dtype"`` mode.

    * Booleans: magenta, dimmed when false.
    * Floats: bold red for NaN/infinity; blue for positive and red for
      negative values, with the dominant channel scaled from 180 to 255 as
      ``abs(val)`` goes from 0 to 10 (saturating beyond); cyan for zero.
    * Integers: green.
    * Strings: yellow.
    * Complex numbers: red.
    * Anything else: an empty ``Style``.
    """
    if isinstance(val, (np.bool_, bool)):
        # Checked first because Python's ``bool`` subclasses ``int``; testing
        # integers first would colour plain ``True``/``False`` green.
        return Style(color="magenta") if val else Style(color="magenta", dim=True)

    if isinstance(val, (np.floating, float)):
        if np.isnan(val) or np.isinf(val):
            return Style(color="red", bold=True)
        if val > 0:
            intensity = min(abs(val) / 10, 1.0)
            b = int(180 + 75 * intensity)
            return Style(color=f"rgb(40,40,{b})")
        if val < 0:
            intensity = min(abs(val) / 10, 1.0)
            r = int(180 + 75 * intensity)
            return Style(color=f"rgb({r},40,40)")
        # Exactly zero (positive or negative).
        return Style(color="cyan")

    if isinstance(val, (np.integer, int)):
        return Style(color="green")
    if isinstance(val, str):
        return Style(color="yellow")
    if isinstance(val, (np.complexfloating, complex)):
        return Style(color="red")
    return Style()
|
|
@@ -0,0 +1,185 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: arrscope
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Beautiful n-dimensional array visualization for Python
|
|
5
|
+
Project-URL: Homepage, https://github.com/vizarray/arrscope
|
|
6
|
+
Project-URL: Repository, https://github.com/vizarray/arrscope
|
|
7
|
+
Author: arrscope contributors
|
|
8
|
+
License: MIT
|
|
9
|
+
Keywords: arrays,data-science,machine-learning,numpy,visualization
|
|
10
|
+
Classifier: Development Status :: 3 - Alpha
|
|
11
|
+
Classifier: Intended Audience :: Science/Research
|
|
12
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
16
|
+
Classifier: Topic :: Scientific/Engineering :: Visualization
|
|
17
|
+
Requires-Python: >=3.11
|
|
18
|
+
Requires-Dist: numpy>=1.24
|
|
19
|
+
Requires-Dist: rich>=13.0
|
|
20
|
+
Provides-Extra: jax
|
|
21
|
+
Requires-Dist: jax>=0.4; extra == 'jax'
|
|
22
|
+
Provides-Extra: tf
|
|
23
|
+
Requires-Dist: tensorflow>=2.12; extra == 'tf'
|
|
24
|
+
Provides-Extra: torch
|
|
25
|
+
Requires-Dist: torch>=2.0; extra == 'torch'
|
|
26
|
+
Description-Content-Type: text/markdown
|
|
27
|
+
|
|
28
|
+
# arrscope
|
|
29
|
+
|
|
30
|
+
Beautiful n-dimensional array visualization for Python — in the terminal and Jupyter.
|
|
31
|
+
|
|
32
|
+
```python
|
|
33
|
+
from arrscope import scope
|
|
34
|
+
import numpy as np
|
|
35
|
+
|
|
36
|
+
scope(np.random.rand(2, 3, 8, 8), axes=['batch', 'heads', 'h', 'w'])
|
|
37
|
+
```
|
|
38
|
+
|
|
39
|
+
```
|
|
40
|
+
├── [0] (3, 8, 8)
|
|
41
|
+
│ ├── [0,0]: 8×8 grid
|
|
42
|
+
│ ├── [0,1]: 8×8 grid
|
|
43
|
+
│ └── [0,2]: 8×8 grid
|
|
44
|
+
└── [1] (3, 8, 8)
|
|
45
|
+
...
|
|
46
|
+
min=0.001 max=0.999 mean=0.5 std=0.29 zeros=0.0%
|
|
47
|
+
```
|
|
48
|
+
|
|
49
|
+
## Features
|
|
50
|
+
|
|
51
|
+
- **1D → 6D+**: Tiered visual grammar — lists, grids, trees, collapsed hierarchies
|
|
52
|
+
- **Named axes**: Attach semantics to dimensions (`batch`, `heads`, `h`, `w`)
|
|
53
|
+
- **Configurable grid**: Pick which axes form the leaf 2D matrix
|
|
54
|
+
- **Three color modes**:
|
|
55
|
+
- `dtype` — semantic colors by data type (positive float=blue, negative float=red, zero float=cyan, int=green, bool=magenta, …)
|
|
56
|
+
- `heatmap` — diverging colormap (red→light→blue) by value magnitude
|
|
57
|
+
- `sparsity` — zeros as `·`, non-zeros highlighted in bold
|
|
58
|
+
- **Stats overlay**: min, max, mean, std, zero%, NaN count
|
|
59
|
+
- **Head/tail truncation**: large dimensions show first/last N slices with `…`
|
|
60
|
+
- **Smart precision**: auto-detects significant figures for floats
|
|
61
|
+
- **Terminal + Jupyter**: Rich ANSI output + static HTML/CSS with dark mode
|
|
62
|
+
- **Multi-framework**: NumPy, PyTorch, TensorFlow, JAX (lazy imports, no hard deps)
|
|
63
|
+
|
|
64
|
+
## Install
|
|
65
|
+
|
|
66
|
+
```bash
|
|
67
|
+
pip install arrscope
|
|
68
|
+
```
|
|
69
|
+
|
|
70
|
+
Only requires `numpy` + `rich`. Torch/TF/JAX are optional — pass any array type, it just works.
|
|
71
|
+
|
|
72
|
+
## Quick start
|
|
73
|
+
|
|
74
|
+
```python
|
|
75
|
+
from arrscope import scope
|
|
76
|
+
import numpy as np
|
|
77
|
+
|
|
78
|
+
# Auto-detect — last 2 dims form the grid
|
|
79
|
+
scope(np.random.rand(3, 4, 5))
|
|
80
|
+
|
|
81
|
+
# Named axes for clarity
|
|
82
|
+
scope(
|
|
83
|
+
np.random.rand(2, 8, 32, 32),
|
|
84
|
+
axes=['batch', 'heads', 'h', 'w'],
|
|
85
|
+
grid=['h', 'w'],
|
|
86
|
+
title='Attention heads',
|
|
87
|
+
)
|
|
88
|
+
|
|
89
|
+
# Pick any dims as the leaf grid
|
|
90
|
+
scope(data, axes=['a', 'b', 'c', 'd'], grid=['a', 'b'])
|
|
91
|
+
```
|
|
92
|
+
|
|
93
|
+
## Color modes
|
|
94
|
+
|
|
95
|
+
```python
|
|
96
|
+
scope(arr, mode='dtype') # default — blue floats, green ints, etc.
|
|
97
|
+
scope(arr, mode='heatmap') # diverging colormap by value
|
|
98
|
+
scope(arr, mode='sparsity') # · for zeros, bold for non-zeros
|
|
99
|
+
```
|
|
100
|
+
|
|
101
|
+
## Stats overlay
|
|
102
|
+
|
|
103
|
+
Stats are shown by default. Hide with:
|
|
104
|
+
|
|
105
|
+
```python
|
|
106
|
+
scope(arr, stats=False)
|
|
107
|
+
```
|
|
108
|
+
|
|
109
|
+
Shows: `min=0.0 max=1.0 mean=0.5 std=0.29 zeros=12.5%`
|
|
110
|
+
|
|
111
|
+
## Truncation
|
|
112
|
+
|
|
113
|
+
```python
|
|
114
|
+
scope(np.random.rand(100, 32, 32), max_height=8)
|
|
115
|
+
```
|
|
116
|
+
|
|
117
|
+
Shows first 4 + last 4 slices with `… (92 more)` in between.
|
|
118
|
+
|
|
119
|
+
## Custom formatting
|
|
120
|
+
|
|
121
|
+
```python
|
|
122
|
+
scope(arr, fmt='.2f') # fixed precision
|
|
123
|
+
scope(arr, color=False) # monochrome
|
|
124
|
+
scope(arr, style='html') # force HTML output in terminal
|
|
125
|
+
```
|
|
126
|
+
|
|
127
|
+
## CLI
|
|
128
|
+
|
|
129
|
+
```bash
|
|
130
|
+
# Install globally (ships with the library)
|
|
131
|
+
pip install arrscope
|
|
132
|
+
|
|
133
|
+
# Use from anywhere
|
|
134
|
+
arrscope 3x4x5
|
|
135
|
+
arrscope 2x3x32x32 --axes batch heads h w --grid h w --mode heatmap
|
|
136
|
+
arrscope 20x4x5 --max-height 6 --no-stats
|
|
137
|
+
```
|
|
138
|
+
|
|
139
|
+
## Framework support
|
|
140
|
+
|
|
141
|
+
Pass any array-like — conversion is automatic:
|
|
142
|
+
|
|
143
|
+
```python
|
|
144
|
+
import torch
|
|
145
|
+
scope(torch.randn(2, 3, 4)) # PyTorch
|
|
146
|
+
|
|
147
|
+
import tensorflow as tf
|
|
148
|
+
scope(tf.random.uniform((2, 3, 4))) # TensorFlow
|
|
149
|
+
|
|
150
|
+
import jax.numpy as jnp
|
|
151
|
+
scope(jnp.array([[1, 2], [3, 4]])) # JAX
|
|
152
|
+
```
|
|
153
|
+
|
|
154
|
+
## API
|
|
155
|
+
|
|
156
|
+
```python
|
|
157
|
+
scope(
|
|
158
|
+
arr, # np.ndarray | torch.Tensor | tf.Tensor | jax.Array
|
|
159
|
+
axes=None, # list[str] — name each dimension
|
|
160
|
+
grid=None, # list[str | int] — which dims form the leaf grid
|
|
161
|
+
title=None, # str — heading above the visualization
|
|
162
|
+
max_width=None, # int — max characters wide
|
|
163
|
+
max_height=None, # int — max rows before truncation
|
|
164
|
+
fmt=None, # str — format spec like '.4f'
|
|
165
|
+
color=True, # bool — enable/disable color
|
|
166
|
+
mode='dtype', # 'dtype' | 'heatmap' | 'sparsity'
|
|
167
|
+
stats=True, # bool — show min/max/mean/std/zeros
|
|
168
|
+
style='auto', # 'auto' | 'terminal' | 'html'
|
|
169
|
+
)
|
|
170
|
+
```
|
|
171
|
+
|
|
172
|
+
## Development
|
|
173
|
+
|
|
174
|
+
```bash
|
|
175
|
+
git clone https://github.com/vizarray/arrscope
|
|
176
|
+
cd arrscope
|
|
177
|
+
uv sync
|
|
178
|
+
uv run pytest
|
|
179
|
+
uv run python main.py # demo script
|
|
180
|
+
uv run jupyter notebook test.ipynb # notebook demo
|
|
181
|
+
```
|
|
182
|
+
|
|
183
|
+
## License
|
|
184
|
+
|
|
185
|
+
MIT
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
arrscope/__init__.py,sha256=WxxdMd9gq_DC4AoZ5PMJLtVVBcrcUCsHJRQ5F1qZunc,99
|
|
2
|
+
arrscope/__main__.py,sha256=oLxjmmAXtsPmhHO6SBSAu0Z3J3wmGScHyACOqTUL1K4,2051
|
|
3
|
+
arrscope/_api.py,sha256=d75HggCQr8HhXMsJFn7-ARUw8Zv0NVL00DUiRzjoPAo,3521
|
|
4
|
+
arrscope/_config.py,sha256=ZuaKcIqXbxRhikceYV5zM41OVV83js4O3mlbHd5ECTk,266
|
|
5
|
+
arrscope/_format.py,sha256=fEnOuu_Q1jKJcivMiMX7Lt1I3IWYKE2KUSJwvmBnHGg,1741
|
|
6
|
+
arrscope/_layout.py,sha256=EGnArvkslIRBBw_4RUSafgYnxHLtyRHnlkLwXmnEjvg,4270
|
|
7
|
+
arrscope/_types.py,sha256=0ASV_Tlo6cKIdDX2L9deCYTxVS26VqTKDexesoYtyIE,840
|
|
8
|
+
arrscope/adapters/__init__.py,sha256=y1WQxJmvO4rrIjqCFYgQLvnYcVwoHMEWSU8eIP9v1fU,69
|
|
9
|
+
arrscope/adapters/_core.py,sha256=Q-zonVPFlWunNn_kIreFQPl_ePk2eLbeOyK5FxfsfdQ,841
|
|
10
|
+
arrscope/renderers/__init__.py,sha256=he1exAHx6e375HpLPJ7kAq6AXC7Vi3lVmPyl_Hko32M,150
|
|
11
|
+
arrscope/renderers/html.py,sha256=vEfai1eDOPDavAMn6Oeq2AMIKeI56J8Jt40SPaC945I,8251
|
|
12
|
+
arrscope/renderers/terminal.py,sha256=lfhgYJfD8QGehwy65SAOmiD1cvCayFF8sWJnp4W20ZA,7493
|
|
13
|
+
arrscope-0.1.0.dist-info/METADATA,sha256=lxSNw9gchfuuvJRO-cnkK5VZ9x4gYy8iqtAvz2wItDg,5177
|
|
14
|
+
arrscope-0.1.0.dist-info/WHEEL,sha256=mffPy8wBnZQn2VnJUU5jE99KsxaSfiyMHV9Yt0aLVxs,87
|
|
15
|
+
arrscope-0.1.0.dist-info/entry_points.txt,sha256=PD1xqmqsAS6HwUKGcU6lCI-kifhDhsSW_gajRG-2pjM,52
|
|
16
|
+
arrscope-0.1.0.dist-info/RECORD,,
|