arrscope 0.3.0__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. {arrscope-0.3.0 → arrscope-0.4.0}/PKG-INFO +35 -25
  2. {arrscope-0.3.0 → arrscope-0.4.0}/README.md +34 -24
  3. {arrscope-0.3.0 → arrscope-0.4.0}/main.py +5 -0
  4. {arrscope-0.3.0 → arrscope-0.4.0}/pyproject.toml +1 -1
  5. {arrscope-0.3.0 → arrscope-0.4.0}/src/arrscope/_api.py +33 -17
  6. {arrscope-0.3.0 → arrscope-0.4.0}/src/arrscope/_format.py +12 -0
  7. {arrscope-0.3.0 → arrscope-0.4.0}/src/arrscope/_layout.py +29 -3
  8. {arrscope-0.3.0 → arrscope-0.4.0}/src/arrscope/_types.py +10 -7
  9. {arrscope-0.3.0 → arrscope-0.4.0}/src/arrscope/renderers/html.py +83 -31
  10. {arrscope-0.3.0 → arrscope-0.4.0}/src/arrscope/renderers/terminal.py +83 -32
  11. {arrscope-0.3.0 → arrscope-0.4.0}/uv.lock +1 -1
  12. {arrscope-0.3.0 → arrscope-0.4.0}/.gitignore +0 -0
  13. {arrscope-0.3.0 → arrscope-0.4.0}/docs/ADR-001-visual-grammar.md +0 -0
  14. {arrscope-0.3.0 → arrscope-0.4.0}/docs/ADR-002-architecture.md +0 -0
  15. {arrscope-0.3.0 → arrscope-0.4.0}/docs/ADR-003-api-design.md +0 -0
  16. {arrscope-0.3.0 → arrscope-0.4.0}/docs/GLOSSARY.md +0 -0
  17. {arrscope-0.3.0 → arrscope-0.4.0}/examples/renders.py +0 -0
  18. {arrscope-0.3.0 → arrscope-0.4.0}/src/arrscope/__init__.py +0 -0
  19. {arrscope-0.3.0 → arrscope-0.4.0}/src/arrscope/__main__.py +0 -0
  20. {arrscope-0.3.0 → arrscope-0.4.0}/src/arrscope/_config.py +0 -0
  21. {arrscope-0.3.0 → arrscope-0.4.0}/src/arrscope/adapters/__init__.py +0 -0
  22. {arrscope-0.3.0 → arrscope-0.4.0}/src/arrscope/adapters/_core.py +0 -0
  23. {arrscope-0.3.0 → arrscope-0.4.0}/src/arrscope/renderers/__init__.py +0 -0
  24. {arrscope-0.3.0 → arrscope-0.4.0}/test.ipynb +0 -0
  25. {arrscope-0.3.0 → arrscope-0.4.0}/tests/test_adapter.py +0 -0
  26. {arrscope-0.3.0 → arrscope-0.4.0}/tests/test_format.py +0 -0
  27. {arrscope-0.3.0 → arrscope-0.4.0}/tests/test_layout.py +0 -0
  28. {arrscope-0.3.0 → arrscope-0.4.0}/tests/test_stats.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: arrscope
3
- Version: 0.3.0
3
+ Version: 0.4.0
4
4
  Summary: Beautiful n-dimensional array visualization for Python
5
5
  Project-URL: Homepage, https://github.com/vizarray/arrscope
6
6
  Project-URL: Repository, https://github.com/vizarray/arrscope
@@ -27,7 +27,7 @@ Description-Content-Type: text/markdown
27
27
 
28
28
  # arrscope
29
29
 
30
- Visualize n-dimensional arrays in the terminal and Jupyter — with structural trees, tiled mosaics, and color-coded value maps.
30
+ Visualize n-dimensional arrays in the terminal and Jupyter — with structural trees, tiled mosaics, array diffing, and distribution sparklines.
31
31
 
32
32
  ```python
33
33
  from arrscope import scope
@@ -40,15 +40,16 @@ scope(np.random.rand(3, 4, 5))
40
40
 
41
41
  - **1D → 6D+**: Tiered visual grammar — lists, grids, trees, nested layers
42
42
  - **Named axes**: Attach semantics (`batch`, `heads`, `h`, `w`)
43
- - **Three color modes**:
43
+ - **Four color modes**:
44
44
  - `dtype` — semantic colors by data type (blue=float, green=int, …)
45
45
  - `heatmap` — diverging colormap (red → light → blue) by value
46
46
  - `sparsity` — zeros as `·`, non-zeros in bold
47
+ - `diff` — compare with a reference array (red=increased, blue=decreased, grey=unchanged)
47
48
  - **Two render styles**:
48
49
  - `tree` (default) — hierarchical branch view with colored guide lines
49
50
  - `mosaic` — all 2D sub-slices tiled side by side as numeric tables
50
- - **Method chaining**: `scope(arr, mode="heatmap").tree().mosaic()` — switch styles without re-specifying
51
- - **Stats overlay**: min, max, mean, std, zero%, NaN count (always shown)
51
+ - **Distribution sparkline**: every output shows a Unicode histogram (`▁▂▃▄▅▆▇█`) of the value distribution — replaces the min/max/mean/std text summary
52
+ - **Array diffing**: pass `reference=` to compare two arrays of the same shape. Color-coded per-cell changes + aggregate metrics (MSE, % changed)
52
53
  - **Head/tail truncation**: large dims show first/last N slices with `…` (default 20)
53
54
  - **Smart precision**: auto-detects significant figures for floats
54
55
  - **Terminal + Jupyter**: Rich ANSI + static HTML/CSS with dark mode auto-detect
@@ -82,26 +83,41 @@ scope(
82
83
 
83
84
  # Custom grid dims
84
85
  scope(data, axes=['a', 'b', 'c', 'd'], grid=['a', 'b'])
86
+
87
+ # Method chaining
88
+ r = scope(arr, mode='heatmap')
89
+ print(r.tree())
90
+ print(r.mosaic())
85
91
  ```
86
92
 
87
93
  ## Color modes
88
94
 
89
95
  ```python
90
- scope(arr, mode='dtype') # default
91
- scope(arr, mode='heatmap') # diverging colormap
92
- scope(arr, mode='sparsity') # · for zeros, bold for non-zeros
96
+ scope(arr, mode='dtype') # default — blue floats, green ints, ...
97
+ scope(arr, mode='heatmap') # diverging colormap by value
98
+ scope(arr, mode='sparsity') # · for zeros, bold for non-zeros
99
+ scope(arr, mode='diff', # red/blue by change direction
100
+ reference=original_array) # Δ -0.5 ▁▂▃▄▅▆▇█ +0.5 mse=0.02 12% changed
101
+ ```
102
+
103
+ ## Array diffing
104
+
105
+ Compare any two arrays of the same shape. Values show the current array; color encodes the change:
106
+
107
+ ```python
108
+ before = np.random.rand(3, 4, 5)
109
+ after = before + np.random.normal(0, 0.1, before.shape)
110
+
111
+ scope(after, reference=before, mode="diff")
93
112
  ```
94
113
 
114
+ In diff mode the sparkline shows the distribution of per-element deltas (labelled with their range), followed by the MSE and the percentage of elements that changed.
115
+
95
116
  ## Render styles
96
117
 
97
118
  ```python
98
119
  scope(arr) # tree (default) — click-to-expand layers
99
120
  scope(arr, render_style='mosaic') # all sub-slices tiled side by side
100
-
101
- # Method chaining for switching styles
102
- r = scope(arr, mode='heatmap')
103
- print(r.tree())
104
- print(r.mosaic())
105
121
  ```
106
122
 
107
123
  ## CLI
@@ -115,17 +131,10 @@ arrscope 20x4x5 --max-height 6
115
131
  ## Framework support
116
132
 
117
133
  ```python
118
- import torch
119
- scope(torch.randn(2, 3, 4))
120
-
121
- import tensorflow as tf
122
- scope(tf.random.uniform((2, 3, 4)))
123
-
124
- import jax.numpy as jnp
125
- scope(jnp.array([[1, 2], [3, 4]]))
126
-
127
- from tinygrad import Tensor
128
- scope(Tensor.randn(3, 4))
134
+ import torch; scope(torch.randn(2, 3, 4))
135
+ import tensorflow as tf; scope(tf.random.uniform((2, 3, 4)))
136
+ import jax.numpy as jnp; scope(jnp.array([[1, 2], [3, 4]]))
137
+ from tinygrad import Tensor; scope(Tensor.randn(3, 4))
129
138
  ```
130
139
 
131
140
  ## API
@@ -138,8 +147,9 @@ scope(
138
147
  title=None, # str — heading above the visualization
139
148
  max_height=20, # int | None — rows before truncation, None disables
140
149
  fmt=None, # str — format spec like '.4f'
141
- mode='dtype', # 'dtype' | 'heatmap' | 'sparsity'
150
+ mode='dtype', # 'dtype' | 'heatmap' | 'sparsity' | 'diff'
142
151
  render_style='tree', # 'tree' | 'mosaic'
152
+ reference=None, # array-like — reference for diff mode
143
153
  )
144
154
  ```
145
155
 
@@ -1,6 +1,6 @@
1
1
  # arrscope
2
2
 
3
- Visualize n-dimensional arrays in the terminal and Jupyter — with structural trees, tiled mosaics, and color-coded value maps.
3
+ Visualize n-dimensional arrays in the terminal and Jupyter — with structural trees, tiled mosaics, array diffing, and distribution sparklines.
4
4
 
5
5
  ```python
6
6
  from arrscope import scope
@@ -13,15 +13,16 @@ scope(np.random.rand(3, 4, 5))
13
13
 
14
14
  - **1D → 6D+**: Tiered visual grammar — lists, grids, trees, nested layers
15
15
  - **Named axes**: Attach semantics (`batch`, `heads`, `h`, `w`)
16
- - **Three color modes**:
16
+ - **Four color modes**:
17
17
  - `dtype` — semantic colors by data type (blue=float, green=int, …)
18
18
  - `heatmap` — diverging colormap (red → light → blue) by value
19
19
  - `sparsity` — zeros as `·`, non-zeros in bold
20
+ - `diff` — compare with a reference array (red=increased, blue=decreased, grey=unchanged)
20
21
  - **Two render styles**:
21
22
  - `tree` (default) — hierarchical branch view with colored guide lines
22
23
  - `mosaic` — all 2D sub-slices tiled side by side as numeric tables
23
- - **Method chaining**: `scope(arr, mode="heatmap").tree().mosaic()` — switch styles without re-specifying
24
- - **Stats overlay**: min, max, mean, std, zero%, NaN count (always shown)
24
+ - **Distribution sparkline**: every output shows a Unicode histogram (`▁▂▃▄▅▆▇█`) of the value distribution — replaces the min/max/mean/std text summary
25
+ - **Array diffing**: pass `reference=` to compare two arrays of the same shape. Color-coded per-cell changes + aggregate metrics (MSE, % changed)
25
26
  - **Head/tail truncation**: large dims show first/last N slices with `…` (default 20)
26
27
  - **Smart precision**: auto-detects significant figures for floats
27
28
  - **Terminal + Jupyter**: Rich ANSI + static HTML/CSS with dark mode auto-detect
@@ -55,26 +56,41 @@ scope(
55
56
 
56
57
  # Custom grid dims
57
58
  scope(data, axes=['a', 'b', 'c', 'd'], grid=['a', 'b'])
59
+
60
+ # Method chaining
61
+ r = scope(arr, mode='heatmap')
62
+ print(r.tree())
63
+ print(r.mosaic())
58
64
  ```
59
65
 
60
66
  ## Color modes
61
67
 
62
68
  ```python
63
- scope(arr, mode='dtype') # default
64
- scope(arr, mode='heatmap') # diverging colormap
65
- scope(arr, mode='sparsity') # · for zeros, bold for non-zeros
69
+ scope(arr, mode='dtype') # default — blue floats, green ints, ...
70
+ scope(arr, mode='heatmap') # diverging colormap by value
71
+ scope(arr, mode='sparsity') # · for zeros, bold for non-zeros
72
+ scope(arr, mode='diff', # red/blue by change direction
73
+ reference=original_array) # Δ -0.5 ▁▂▃▄▅▆▇█ +0.5 mse=0.02 12% changed
74
+ ```
75
+
76
+ ## Array diffing
77
+
78
+ Compare any two arrays of the same shape. Values show the current array; color encodes the change:
79
+
80
+ ```python
81
+ before = np.random.rand(3, 4, 5)
82
+ after = before + np.random.normal(0, 0.1, before.shape)
83
+
84
+ scope(after, reference=before, mode="diff")
66
85
  ```
67
86
 
87
+ In diff mode the sparkline shows the distribution of per-element deltas (labelled with their range), followed by the MSE and the percentage of elements that changed.
88
+
68
89
  ## Render styles
69
90
 
70
91
  ```python
71
92
  scope(arr) # tree (default) — click-to-expand layers
72
93
  scope(arr, render_style='mosaic') # all sub-slices tiled side by side
73
-
74
- # Method chaining for switching styles
75
- r = scope(arr, mode='heatmap')
76
- print(r.tree())
77
- print(r.mosaic())
78
94
  ```
79
95
 
80
96
  ## CLI
@@ -88,17 +104,10 @@ arrscope 20x4x5 --max-height 6
88
104
  ## Framework support
89
105
 
90
106
  ```python
91
- import torch
92
- scope(torch.randn(2, 3, 4))
93
-
94
- import tensorflow as tf
95
- scope(tf.random.uniform((2, 3, 4)))
96
-
97
- import jax.numpy as jnp
98
- scope(jnp.array([[1, 2], [3, 4]]))
99
-
100
- from tinygrad import Tensor
101
- scope(Tensor.randn(3, 4))
107
+ import torch; scope(torch.randn(2, 3, 4))
108
+ import tensorflow as tf; scope(tf.random.uniform((2, 3, 4)))
109
+ import jax.numpy as jnp; scope(jnp.array([[1, 2], [3, 4]]))
110
+ from tinygrad import Tensor; scope(Tensor.randn(3, 4))
102
111
  ```
103
112
 
104
113
  ## API
@@ -111,8 +120,9 @@ scope(
111
120
  title=None, # str — heading above the visualization
112
121
  max_height=20, # int | None — rows before truncation, None disables
113
122
  fmt=None, # str — format spec like '.4f'
114
- mode='dtype', # 'dtype' | 'heatmap' | 'sparsity'
123
+ mode='dtype', # 'dtype' | 'heatmap' | 'sparsity' | 'diff'
115
124
  render_style='tree', # 'tree' | 'mosaic'
125
+ reference=None, # array-like — reference for diff mode
116
126
  )
117
127
  ```
118
128
 
@@ -89,4 +89,9 @@ print(
89
89
  )
90
90
  )
91
91
 
92
+ print("\n1️⃣6️⃣ DIFF — compare two arrays (value = current, color = change)")
93
+ a = np.random.rand(4, 5)
94
+ b = a + np.random.normal(0, 0.1, a.shape)
95
+ print(scope(a, reference=b, mode="diff"))
96
+
92
97
  print("\n Done!")
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "arrscope"
7
- version = "0.3.0"
7
+ version = "0.4.0"
8
8
  description = "Beautiful n-dimensional array visualization for Python"
9
9
  readme = "README.md"
10
10
  license = { text = "MIT" }
@@ -1,12 +1,11 @@
1
1
  from typing import Literal
2
2
 
3
- from arrscope._format import format_value
4
3
  from arrscope._layout import build_layout, compute_stats
5
4
  from arrscope._types import VizOutput
6
5
  from arrscope.adapters import to_numpy
7
6
  from arrscope.renderers import render_html, render_terminal
8
7
 
9
- Mode = Literal["dtype", "heatmap", "sparsity"]
8
+ Mode = Literal["dtype", "heatmap", "sparsity", "diff"]
10
9
  RenderStyle = Literal["tree", "mosaic"]
11
10
 
12
11
 
@@ -20,6 +19,7 @@ def scope(
20
19
  fmt: str | None = None,
21
20
  mode: Mode | None = None,
22
21
  render_style: RenderStyle | None = None,
22
+ reference=None,
23
23
  ) -> VizOutput:
24
24
  """Visualize an n-dimensional array in the terminal and/or Jupyter.
25
25
 
@@ -50,6 +50,9 @@ def scope(
50
50
  scaled to the array's global min/max range.
51
51
  - ``"sparsity"`` — Zeros rendered as dim ``·``, non-zero
52
52
  values in bold.
53
+ - ``"diff"`` — Compare with a reference array. Values
54
+ colored red/blue by how much they increased/decreased.
55
+ Requires ``reference``.
53
56
 
54
57
  render_style: Visual layout:
55
58
 
@@ -58,10 +61,13 @@ def scope(
58
61
  - ``"mosaic"`` — All 2D sub-slices tiled side by side as
59
62
  numeric tables.
60
63
 
64
+ reference: A reference array to compare against. Must have
65
+ the same shape as ``arr``. When provided, ``mode`` defaults
66
+ to ``"diff"`` and stats show difference metrics.
67
+
61
68
  Returns:
62
- A ``VizOutput`` with ``.ansi`` and ``.html`` attributes.
63
- Use ``.tree()`` or ``.mosaic()`` to switch render styles
64
- without re-specifying parameters.
69
+ A ``VizOutput`` with ``.tree()`` and ``.mosaic()`` methods.
70
+ Renders as ANSI in the terminal and as HTML in Jupyter.
65
71
 
66
72
  Examples:
67
73
  >>> import numpy as np
@@ -70,19 +76,29 @@ def scope(
70
76
 
71
77
  >>> scope(np.eye(5), title="Identity", mode="sparsity")
72
78
 
73
- >>> scope(
74
- ... np.random.rand(2, 8, 32, 32),
75
- ... axes=["batch", "heads", "h", "w"],
76
- ... mode="heatmap",
77
- ... )
79
+ >>> a = np.random.rand(3, 4, 5)
80
+ >>> b = a + np.random.normal(0, 0.1, a.shape)
81
+ >>> scope(a, reference=b, mode="diff")
78
82
  """
79
83
  arr_np = to_numpy(arr)
80
84
 
81
- mode = mode or "dtype"
85
+ if reference is not None:
86
+ ref_np = to_numpy(reference)
87
+ if ref_np.shape != arr_np.shape:
88
+ raise ValueError(
89
+ f"reference shape {ref_np.shape} does not match arr shape {arr_np.shape}"
90
+ )
91
+ else:
92
+ ref_np = None
93
+
94
+ mode = mode or ("diff" if ref_np is not None else "dtype")
82
95
  rs = render_style or "tree"
83
96
 
84
- if mode not in ("dtype", "heatmap", "sparsity"):
85
- raise ValueError(f"Unknown mode '{mode}'. Expected one of: dtype, heatmap, sparsity")
97
+ if mode not in ("dtype", "heatmap", "sparsity", "diff"):
98
+ raise ValueError(f"Unknown mode '{mode}'. Expected: dtype, heatmap, sparsity, diff")
99
+
100
+ if mode == "diff" and ref_np is None:
101
+ raise ValueError("mode='diff' requires a reference array via the `reference` parameter")
86
102
 
87
103
  if rs not in ("tree", "mosaic"):
88
104
  raise ValueError(f"Unknown render_style '{rs}'. Expected one of: tree, mosaic")
@@ -91,15 +107,15 @@ def scope(
91
107
  use_html = _in_jupyter()
92
108
 
93
109
  def _render(style: str) -> VizOutput:
94
- node = build_layout(arr_np, grid_dims=grid_dims, max_height=max_height)
95
- gs = compute_stats(arr_np)
110
+ node = build_layout(arr_np, grid_dims=grid_dims, max_height=max_height, reference=ref_np)
111
+ gs = compute_stats(arr_np, reference=ref_np)
96
112
  if gs:
97
113
  node.stats = gs
98
114
 
99
115
  out = VizOutput(_rerender=_render)
100
116
  if use_html:
101
- out.html = _wrap_html(render_html(node, mode=mode, global_stats=gs, render_style=style), title)
102
- out.ansi = _wrap_ansi(render_terminal(node, mode=mode, global_stats=gs, render_style=style), title)
117
+ out._html = _wrap_html(render_html(node, mode=mode, global_stats=gs, render_style=style), title)
118
+ out._ansi = _wrap_ansi(render_terminal(node, mode=mode, global_stats=gs, render_style=style), title)
103
119
  return out
104
120
 
105
121
  return _render(rs)
@@ -23,6 +23,18 @@ def format_value(val, fmt: str | None = None) -> str:
23
23
  return str(val)
24
24
 
25
25
 
26
+ def sparkline(stats: dict) -> str:
27
+ """Return a unicode histogram sparkline from pre-computed bins."""
28
+ bins = stats.get("histogram")
29
+ if not bins:
30
+ return ""
31
+ mx = max(bins)
32
+ if mx == 0:
33
+ return "▁" * len(bins)
34
+ chars = ["▁", "▂", "▃", "▄", "▅", "▆", "▇", "█"]
35
+ return "".join(chars[min(int(b / mx * 7), 7)] for b in bins)
36
+
37
+
26
38
  def _format_float(val: float, fmt: str | None) -> str:
27
39
  if fmt is not None:
28
40
  return f"{val:{fmt}}"
@@ -12,6 +12,7 @@ def build_layout(
12
12
  *,
13
13
  grid_dims: list[int] | None = None,
14
14
  max_height: int | None = None,
15
+ reference: np.ndarray | None = None,
15
16
  ) -> VizNode:
16
17
  shape = arr.shape
17
18
  ndim = len(shape)
@@ -21,16 +22,31 @@ def build_layout(
21
22
 
22
23
  non_grid = [d for d in range(ndim) if d not in grid_dims]
23
24
 
24
- root = _build_tree(arr, non_grid, grid_dims, (), max_height)
25
+ root = _build_tree(arr, non_grid, grid_dims, (), max_height, reference)
25
26
  root.shape = shape
26
27
  root.label = _shape_label(shape, arr.dtype)
27
28
  return root
28
29
 
29
30
 
30
- def compute_stats(arr: np.ndarray) -> dict:
31
+ def compute_stats(arr: np.ndarray, reference: np.ndarray | None = None) -> dict:
31
32
  flat = arr.ravel()
32
33
  stats: dict = {}
33
34
 
35
+ if reference is not None:
36
+ ref_flat = reference.ravel()
37
+ diff = flat.astype(float) - ref_flat.astype(float)
38
+ finite_diff = diff[np.isfinite(diff)]
39
+ if len(finite_diff) > 0:
40
+ stats["max_abs_diff"] = float(np.max(np.abs(finite_diff)))
41
+ stats["mse"] = float(np.mean(finite_diff ** 2))
42
+ stats["pct_changed"] = float(np.mean(np.abs(diff) > 1e-15) * 100)
43
+ hist, edges = np.histogram(finite_diff, bins=10)
44
+ stats["histogram"] = hist.tolist()
45
+ stats["hist_min"] = float(edges[0])
46
+ stats["hist_max"] = float(edges[-1])
47
+ stats["dtype"] = str(arr.dtype)
48
+ return stats
49
+
34
50
  if np.issubdtype(arr.dtype, np.floating) or np.issubdtype(arr.dtype, np.integer):
35
51
  finite = flat[np.isfinite(flat)]
36
52
  if len(finite) > 0:
@@ -38,6 +54,10 @@ def compute_stats(arr: np.ndarray) -> dict:
38
54
  stats["max"] = float(finite.max())
39
55
  stats["mean"] = float(finite.mean())
40
56
  stats["std"] = float(finite.std())
57
+ hist, edges = np.histogram(finite, bins=10)
58
+ stats["histogram"] = hist.tolist()
59
+ stats["hist_min"] = float(edges[0])
60
+ stats["hist_max"] = float(edges[-1])
41
61
  else:
42
62
  stats["min"] = stats["max"] = stats["mean"] = stats["std"] = float("nan")
43
63
  elif np.issubdtype(arr.dtype, np.bool_):
@@ -64,13 +84,18 @@ def _build_tree(
64
84
  grid_dims: list[int],
65
85
  indices: tuple[int, ...],
66
86
  max_height: int | None,
87
+ reference: np.ndarray | None = None,
67
88
  ) -> VizNode:
68
89
  if not non_grid:
90
+ ref = None
91
+ if reference is not None:
92
+ ref = reference
69
93
  return VizNode(
70
94
  indices=indices,
71
95
  label=_indices_label(indices),
72
96
  shape=arr.shape,
73
97
  grid_data=arr,
98
+ ref_data=ref,
74
99
  )
75
100
 
76
101
  current_dim = non_grid[0]
@@ -99,10 +124,11 @@ def _build_tree(
99
124
 
100
125
  child_indices = indices + (i,)
101
126
  child_arr = _slice_dim(arr, current_dim, i)
127
+ child_ref = _slice_dim(reference, current_dim, i) if reference is not None else None
102
128
 
103
129
  new_remaining = _adjust_indices(remaining, current_dim)
104
130
  child = _build_tree(
105
- child_arr, new_remaining, grid_dims, child_indices, max_height
131
+ child_arr, new_remaining, grid_dims, child_indices, max_height, child_ref
106
132
  )
107
133
  children.append(child)
108
134
 
@@ -13,6 +13,7 @@ class VizNode:
13
13
  shape: tuple[int, ...] = ()
14
14
  children: list[VizNode] = field(default_factory=list)
15
15
  grid_data: np.ndarray | None = None
16
+ ref_data: np.ndarray | None = None
16
17
  truncated: bool = False
17
18
  truncated_head: int = 0
18
19
  truncated_tail: int = 0
@@ -24,24 +25,26 @@ _RerenderFn = Callable[[str], "VizOutput"]
24
25
 
25
26
  @dataclass
26
27
  class VizOutput:
27
- html: str = ""
28
- ansi: str = ""
29
- metadata: dict = field(default_factory=dict)
28
+ _html: str = ""
29
+ _ansi: str = ""
30
30
  _rerender: _RerenderFn | None = None
31
31
 
32
32
  def _repr_html_(self) -> str:
33
- return self.html
33
+ return self._html
34
34
 
35
35
  def __rich_console__(self, console, options):
36
36
  from rich.text import Text
37
37
 
38
- yield Text.from_ansi(self.ansi)
38
+ yield Text.from_ansi(self._ansi)
39
39
 
40
40
  def __str__(self) -> str:
41
- return self.ansi
41
+ return self._ansi
42
42
 
43
43
  def __repr__(self) -> str:
44
- return self.ansi
44
+ return self._ansi
45
+
46
+ def __dir__(self) -> list[str]:
47
+ return ["tree", "mosaic"]
45
48
 
46
49
  def tree(self) -> VizOutput:
47
50
  return self._rerender("tree") if self._rerender else self
@@ -4,7 +4,7 @@ import html as html_mod
4
4
 
5
5
  import numpy as np
6
6
 
7
- from arrscope._format import format_value
7
+ from arrscope._format import format_value, sparkline
8
8
  from arrscope._types import VizNode
9
9
 
10
10
 
@@ -23,7 +23,7 @@ def render_html(
23
23
 
24
24
  stats_html = ""
25
25
  if global_stats and node.stats:
26
- stats_html = _stats_html(global_stats)
26
+ stats_html = _stats_sparkline_html(global_stats, mode)
27
27
 
28
28
  return f"""<div class="arrscope-wrapper">
29
29
  <style>{css}</style>
@@ -53,7 +53,7 @@ def _render_mosaic_html(
53
53
  tiles: list[str] = []
54
54
  for leaf in leaves:
55
55
  if leaf.grid_data is not None:
56
- tiles.append(_render_grid_html(leaf.grid_data, leaf.label, mode, global_stats))
56
+ tiles.append(_render_grid_html(leaf.grid_data, leaf.label, mode, global_stats, leaf.ref_data))
57
57
 
58
58
  return '<div style="display: flex; flex-wrap: wrap; gap: 16px; align-items: flex-start;">' + "".join(tiles) + "</div>"
59
59
 
@@ -104,6 +104,7 @@ def _css() -> str:
104
104
  font-size: 12px;
105
105
  color: #666;
106
106
  margin: 2px 0 4px;
107
+ white-space: pre;
107
108
  }
108
109
  details.arrscope-details > summary {
109
110
  cursor: pointer;
@@ -129,28 +130,42 @@ details.arrscope-details[open] > summary {
129
130
  """
130
131
 
131
132
 
132
- def _stats_html(stats: dict) -> str:
133
- items = []
134
- for key in ("min", "max", "mean", "std"):
135
- if key in stats:
136
- v = stats[key]
137
- if isinstance(v, float):
138
- items.append(f"{key}={format_value(v)}")
139
- else:
140
- items.append(f"{key}={v}")
141
- if "zero_count" in stats and "total" in stats:
142
- pct = stats["zero_count"] / stats["total"] * 100
143
- items.append(f"zeros={pct:.1f}%")
144
- if "nan_count" in stats and stats["nan_count"] > 0:
145
- items.append(f"NaN={stats['nan_count']}")
146
- return f'<div class="arrscope-stats">{" &middot; ".join(items)}</div>'
133
+ def _stats_sparkline_html(stats: dict, mode: str) -> str:
134
+ sp = sparkline(stats)
135
+ parts = []
136
+
137
+ if mode == "diff":
138
+ lo = stats.get("hist_min")
139
+ hi = stats.get("hist_max")
140
+ if lo is not None and hi is not None and sp:
141
+ parts.append(f"Δ {format_value(lo)} {sp} {format_value(hi)}")
142
+ mse = stats.get("mse")
143
+ pct = stats.get("pct_changed")
144
+ if mse is not None:
145
+ parts.append(f"mse={format_value(mse)}")
146
+ if pct is not None:
147
+ parts.append(f"{pct:.1f}% changed")
148
+ else:
149
+ lo = stats.get("hist_min")
150
+ hi = stats.get("hist_max")
151
+ if lo is not None and hi is not None and sp:
152
+ parts.append(f"{format_value(lo)} {sp} {format_value(hi)}")
153
+ if "zero_count" in stats and "total" in stats:
154
+ pct = stats["zero_count"] / stats["total"] * 100
155
+ parts.append(f"zeros={pct:.1f}%")
156
+ if "nan_count" in stats and stats["nan_count"] > 0:
157
+ parts.append(f"NaN={stats['nan_count']}")
158
+
159
+ if not parts:
160
+ return ""
161
+ return f'<div class="arrscope-stats">{" &middot; ".join(html_mod.escape(p) for p in parts)}</div>'
147
162
 
148
163
 
149
164
  def _render_node_html(
150
165
  node: VizNode, depth: int, mode: str, global_stats: dict | None
151
166
  ) -> str:
152
167
  if node.grid_data is not None:
153
- return _render_grid_html(node.grid_data, node.label, mode, global_stats)
168
+ return _render_grid_html(node.grid_data, node.label, mode, global_stats, node.ref_data)
154
169
 
155
170
  if not node.children and node.truncated:
156
171
  return f'<div class="arrscope-ellipsis">{html_mod.escape(node.label)}</div>'
@@ -190,30 +205,35 @@ def _format_header(node: VizNode) -> str:
190
205
 
191
206
 
192
207
  def _render_grid_html(
193
- data: np.ndarray, label: str, mode: str, global_stats: dict | None
208
+ data: np.ndarray, label: str, mode: str, global_stats: dict | None,
209
+ ref_data: np.ndarray | None = None,
194
210
  ) -> str:
195
211
  if data.ndim == 0:
196
212
  return f"<span>{html_mod.escape(format_value(data.item()))}</span>"
197
213
 
198
214
  if data.ndim == 1:
199
- return _render_1d_html(data, mode, global_stats)
215
+ return _render_1d_html(data, mode, global_stats, ref_data)
200
216
 
201
- return _render_2d_html(data, label, mode, global_stats)
217
+ return _render_2d_html(data, label, mode, global_stats, ref_data)
202
218
 
203
219
 
204
220
  def _render_1d_html(
205
- data: np.ndarray, mode: str, global_stats: dict | None
221
+ data: np.ndarray, mode: str, global_stats: dict | None,
222
+ ref_data: np.ndarray | None = None,
206
223
  ) -> str:
207
- cells = [
208
- f"<td style=\"{_style_html(v, mode, global_stats)}\">"
209
- f"{html_mod.escape(_format_cell_html(v, mode))}</td>"
210
- for v in data
211
- ]
224
+ cells = []
225
+ for i, v in enumerate(data):
226
+ ref = ref_data[i] if ref_data is not None else None
227
+ cells.append(
228
+ f"<td style=\"{_style_html(v, mode, global_stats, ref)}\">"
229
+ f"{html_mod.escape(_format_cell_html(v, mode))}</td>"
230
+ )
212
231
  return f"<table><tr>{''.join(cells)}</tr></table>"
213
232
 
214
233
 
215
234
  def _render_2d_html(
216
- data: np.ndarray, label: str, mode: str, global_stats: dict | None
235
+ data: np.ndarray, label: str, mode: str, global_stats: dict | None,
236
+ ref_data: np.ndarray | None = None,
217
237
  ) -> str:
218
238
  nrows, ncols = data.shape
219
239
  rows: list[str] = []
@@ -227,8 +247,9 @@ def _render_2d_html(
227
247
  rows.append("<tr>")
228
248
  for col_idx in range(ncols):
229
249
  val = data[row_idx, col_idx]
250
+ ref = ref_data[row_idx, col_idx] if ref_data is not None else None
230
251
  text = html_mod.escape(_format_cell_html(val, mode))
231
- style = _style_html(val, mode, global_stats)
252
+ style = _style_html(val, mode, global_stats, ref)
232
253
  rows.append(f"<td style=\"{style}\">{text}</td>")
233
254
  rows.append("</tr>")
234
255
  rows.append("</table>")
@@ -251,7 +272,9 @@ def _is_zero(val) -> bool:
251
272
  return val == 0
252
273
 
253
274
 
254
- def _style_html(val, mode: str, global_stats: dict | None) -> str:
275
+ def _style_html(val, mode: str, global_stats: dict | None, ref_val=None) -> str:
276
+ if mode == "diff":
277
+ return _diff_style(val, ref_val, global_stats)
255
278
  if mode == "heatmap":
256
279
  return _heatmap_style(val, global_stats)
257
280
  if mode == "sparsity":
@@ -259,6 +282,35 @@ def _style_html(val, mode: str, global_stats: dict | None) -> str:
259
282
  return _dtype_style(val)
260
283
 
261
284
 
285
+ def _diff_style(val, ref_val, global_stats: dict | None) -> str:
286
+ if ref_val is None:
287
+ return ""
288
+ if not isinstance(val, (np.floating, float, np.integer, int)):
289
+ return _dtype_style(val)
290
+ if not isinstance(ref_val, (np.floating, float, np.integer, int)):
291
+ return ""
292
+
293
+ diff = float(val) - float(ref_val)
294
+ if abs(diff) < 1e-15:
295
+ return "color: #999; opacity: 0.4;"
296
+
297
+ max_abs = global_stats.get("max_abs_diff", 1.0) if global_stats else 1.0
298
+ if max_abs == 0:
299
+ max_abs = 1.0
300
+ t = min(abs(diff) / max_abs, 1.0)
301
+
302
+ if diff > 0:
303
+ r = int(180 + 75 * t)
304
+ g = int(40 * (1 - t))
305
+ b = int(40 * (1 - t))
306
+ else:
307
+ r = int(40 * (1 - t))
308
+ g = int(40 * (1 - t))
309
+ b = int(180 + 75 * t)
310
+
311
+ return f"color: rgb({r},{g},{b}); font-weight: bold;"
312
+
313
+
262
314
  def _heatmap_style(val, global_stats: dict | None) -> str:
263
315
  if not global_stats:
264
316
  return ""
@@ -9,7 +9,7 @@ from rich.table import Table
9
9
  from rich.text import Text
10
10
  from rich.tree import Tree
11
11
 
12
- from arrscope._format import format_value
12
+ from arrscope._format import format_value, sparkline
13
13
  from arrscope._types import VizNode
14
14
 
15
15
  BRANCH_COLORS = ["grey50", "cyan", "green", "yellow", "magenta", "blue"]
@@ -33,7 +33,7 @@ def render_terminal(
33
33
  parts.append(body)
34
34
 
35
35
  if global_stats and node.stats:
36
- parts.append(_stats_text(global_stats))
36
+ parts.append(_stats_sparkline(global_stats, mode))
37
37
 
38
38
  with console.capture() as capture:
39
39
  for p in parts:
@@ -62,35 +62,47 @@ def _render_mosaic(
62
62
  tables: list[Table] = []
63
63
  for leaf in leaves:
64
64
  if leaf.grid_data is not None:
65
- t = _render_grid(leaf.grid_data, leaf.label, mode, global_stats)
65
+ t = _render_grid(leaf.grid_data, leaf.label, mode, global_stats, leaf.ref_data)
66
66
  tables.append(t)
67
67
 
68
68
  return Columns(tables, equal=True, expand=False)
69
69
 
70
70
 
71
- def _stats_text(stats: dict) -> Text:
71
+ def _stats_sparkline(stats: dict, mode: str) -> Text:
72
72
  dim = Style(dim=True, color="grey50")
73
73
  t = Text()
74
- fmt = ".4f" if stats.get("dtype", "").startswith(("float", "int")) else ".4f"
74
+ sp = sparkline(stats)
75
+
76
+ if mode == "diff":
77
+ lo = stats.get("hist_min")
78
+ hi = stats.get("hist_max")
79
+ if lo is not None and hi is not None and sp:
80
+ t.append(f"Δ {format_value(lo)} {sp} {format_value(hi)}", style=dim)
81
+ mse = stats.get("mse")
82
+ pct = stats.get("pct_changed")
83
+ parts = []
84
+ if mse is not None:
85
+ parts.append(f"mse={format_value(mse)}")
86
+ if pct is not None:
87
+ parts.append(f"{pct:.1f}% changed")
88
+ if parts:
89
+ t.append(" " + " ".join(parts), style=dim)
90
+ return t
75
91
 
76
- def fmt_val(v):
77
- if isinstance(v, float):
78
- return format_value(v, fmt)
79
- return str(v)
92
+ lo = stats.get("hist_min")
93
+ hi = stats.get("hist_max")
94
+ if lo is not None and hi is not None and sp:
95
+ t.append(f"{format_value(lo)} {sp} {format_value(hi)}", style=dim)
80
96
 
81
97
  items = []
82
- for key in ("min", "max", "mean", "std"):
83
- if key in stats:
84
- items.append(f"{key}={fmt_val(stats[key])}")
85
-
86
98
  if "zero_count" in stats and "total" in stats:
87
99
  pct = stats["zero_count"] / stats["total"] * 100
88
100
  items.append(f"zeros={pct:.1f}%")
89
-
90
101
  if "nan_count" in stats and stats["nan_count"] > 0:
91
102
  items.append(f"NaN={stats['nan_count']}")
103
+ if items:
104
+ t.append(" " + " ".join(items), style=dim)
92
105
 
93
- t.append(" " + " ".join(items), style=dim)
94
106
  return t
95
107
 
96
108
 
@@ -101,7 +113,7 @@ def _render_node(
101
113
  global_stats: dict | None,
102
114
  ) -> Tree | Table | Text:
103
115
  if node.grid_data is not None:
104
- return _render_grid(node.grid_data, node.label, mode, global_stats)
116
+ return _render_grid(node.grid_data, node.label, mode, global_stats, node.ref_data)
105
117
 
106
118
  if not node.children and node.truncated:
107
119
  t = Text(f" {node.label}")
@@ -137,25 +149,28 @@ def _render_grid(
137
149
  label: str,
138
150
  mode: str,
139
151
  global_stats: dict | None,
152
+ ref_data: np.ndarray | None = None,
140
153
  ) -> Text | Table:
141
154
  if data.ndim == 0:
142
155
  return Text(format_value(data.item()))
143
156
 
144
157
  if data.ndim == 1:
145
- return _render_1d(data, mode, global_stats)
158
+ return _render_1d(data, mode, global_stats, ref_data)
146
159
 
147
- return _render_2d(data, label, mode, global_stats)
160
+ return _render_2d(data, label, mode, global_stats, ref_data)
148
161
 
149
162
 
150
163
  def _render_1d(
151
164
  data: np.ndarray,
152
165
  mode: str,
153
166
  global_stats: dict | None,
167
+ ref_data: np.ndarray | None = None,
154
168
  ) -> Text:
155
169
  items = []
156
- for v in data:
170
+ for i, v in enumerate(data):
157
171
  text = _format_cell(v, mode)
158
- items.append(_style_value(text, v, mode, global_stats))
172
+ ref = ref_data[i] if ref_data is not None else None
173
+ items.append(_style_value(text, v, mode, global_stats, ref))
159
174
 
160
175
  t = Text("[")
161
176
  for i, item in enumerate(items):
@@ -171,6 +186,7 @@ def _render_2d(
171
186
  label: str,
172
187
  mode: str,
173
188
  global_stats: dict | None,
189
+ ref_data: np.ndarray | None = None,
174
190
  ) -> Table:
175
191
  nrows, ncols = data.shape
176
192
  table = Table(
@@ -186,15 +202,19 @@ def _render_2d(
186
202
  table.add_column()
187
203
 
188
204
  for row_idx in range(nrows):
189
- row_cells = [
190
- _style_value(
191
- _format_cell(data[row_idx, col_idx], mode),
192
- data[row_idx, col_idx],
193
- mode,
194
- global_stats,
205
+ row_cells = []
206
+ for col_idx in range(ncols):
207
+ val = data[row_idx, col_idx]
208
+ ref = ref_data[row_idx, col_idx] if ref_data is not None else None
209
+ row_cells.append(
210
+ _style_value(
211
+ _format_cell(val, mode),
212
+ val,
213
+ mode,
214
+ global_stats,
215
+ ref,
216
+ )
195
217
  )
196
- for col_idx in range(ncols)
197
- ]
198
218
  table.add_row(*row_cells)
199
219
 
200
220
  return table
@@ -221,8 +241,11 @@ def _style_value(
221
241
  val,
222
242
  mode: str,
223
243
  global_stats: dict | None,
244
+ ref_val=None,
224
245
  ) -> Text:
225
- if mode == "heatmap":
246
+ if mode == "diff":
247
+ cell_style = _diff_color(val, ref_val, global_stats)
248
+ elif mode == "heatmap":
226
249
  cell_style = _heatmap_color(val, global_stats)
227
250
  elif mode == "sparsity":
228
251
  cell_style = _sparsity_color(val)
@@ -232,6 +255,35 @@ def _style_value(
232
255
  return Text(text, style=cell_style)
233
256
 
234
257
 
258
def _diff_color(val, ref_val, global_stats: dict | None) -> Style:
    """Colour a cell by its signed difference from the reference value.

    Positive differences (``val > ref_val``) shade toward red and negative
    ones toward blue. Intensity is ``|diff| / global_stats["max_abs_diff"]``,
    clamped to ``[0, 1]``. When no usable ``max_abs_diff`` is available,
    a scale of 1.0 is used instead.

    Args:
        val: Current cell value.
        ref_val: Reference cell value at the same position, or ``None``.
        global_stats: Optional stats mapping; only ``max_abs_diff`` is read.

    Returns:
        * ``Style()`` when there is no reference, the reference is
          non-numeric, or the difference is undefined (NaN vs. number).
        * ``_dtype_color(val)`` for non-numeric ``val``.
        * A dim grey style for unchanged values.
        * Otherwise an RGB-coloured style.
    """
    numeric = (np.floating, float, np.integer, int)
    if ref_val is None:
        return Style()
    if not isinstance(val, numeric):
        return _dtype_color(val)
    if not isinstance(ref_val, numeric):
        return Style()

    unchanged = Style(dim=True, color="grey50")
    a, b = float(val), float(ref_val)
    diff = a - b
    if np.isnan(diff):
        # A NaN operand (or inf - inf) has no magnitude or sign, and letting
        # it through would make int(...) below raise ValueError. Identical
        # values (both NaN, or the same infinity) count as unchanged.
        same = a == b or (np.isnan(a) and np.isnan(b))
        return unchanged if same else Style()
    if abs(diff) < 1e-15:
        return unchanged

    max_abs = global_stats.get("max_abs_diff", 1.0) if global_stats else 1.0
    # `not max_abs > 0` also rejects NaN and negative scales, which would
    # otherwise produce a NaN or negative intensity.
    if max_abs is None or not max_abs > 0:
        max_abs = 1.0
    t = min(abs(diff) / max_abs, 1.0)

    if diff > 0:
        r = int(180 + 75 * t)
        g = int(40 * (1 - t))
        b_ = int(40 * (1 - t))
    else:
        r = int(40 * (1 - t))
        g = int(40 * (1 - t))
        b_ = int(180 + 75 * t)

    return Style(color=Color.from_rgb(r, g, b_))
285
+
286
+
235
287
  def _heatmap_color(val, global_stats: dict | None) -> Style:
236
288
  if not global_stats:
237
289
  return Style()
@@ -250,14 +302,13 @@ def _heatmap_color(val, global_stats: dict | None) -> Style:
250
302
  t = (float(val) - vmin) / (vmax - vmin)
251
303
  t = max(0.0, min(1.0, t))
252
304
 
253
- # diverging colormap: dark red -> light -> dark blue
254
305
  if t < 0.5:
255
- t2 = t / 0.5 # 0..1
306
+ t2 = t / 0.5
256
307
  r = int(180 - 140 * t2)
257
308
  g = int(30 + 210 * t2)
258
309
  b = int(30 + 100 * t2)
259
310
  else:
260
- t2 = (t - 0.5) / 0.5 # 0..1
311
+ t2 = (t - 0.5) / 0.5
261
312
  r = int(40 - 10 * t2)
262
313
  g = int(240 - 200 * t2)
263
314
  b = int(130 + 130 * t2)
@@ -18,7 +18,7 @@ wheels = [
18
18
 
19
19
  [[package]]
20
20
  name = "arrscope"
21
- version = "0.2.0"
21
+ version = "0.3.0"
22
22
  source = { editable = "." }
23
23
  dependencies = [
24
24
  { name = "numpy", version = "2.4.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.12'" },
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes