rainbow-tensor 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rainbow_tensor/__init__.py +18 -0
- rainbow_tensor/indexing.py +124 -0
- rainbow_tensor/layout.py +126 -0
- rainbow_tensor/notebook.py +152 -0
- rainbow_tensor/render_svg.py +156 -0
- rainbow_tensor/shape.py +92 -0
- rainbow_tensor-0.2.1.dist-info/METADATA +124 -0
- rainbow_tensor-0.2.1.dist-info/RECORD +10 -0
- rainbow_tensor-0.2.1.dist-info/WHEEL +5 -0
- rainbow_tensor-0.2.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
"""rainbow-tensor.
|
|
2
|
+
|
|
3
|
+
A small package for IPython and Jupyter notebooks that visualises tensor
|
|
4
|
+
shape, indexing, and slicing as SVG. It is meant for people learning how a
|
|
5
|
+
tensor is structured and how an indexing expression selects elements.
|
|
6
|
+
|
|
7
|
+
Public API:
|
|
8
|
+
|
|
9
|
+
from rainbow_tensor import show_shape, show_index
|
|
10
|
+
|
|
11
|
+
show_shape((2, 2, 2))
|
|
12
|
+
show_index((2, 2, 2), (0, slice(None), 1))
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from .notebook import show_index, show_shape
|
|
16
|
+
|
|
17
|
+
__version__ = "0.2.1"
|
|
18
|
+
__all__ = ["show_shape", "show_index", "__version__"]
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
"""Index validation, slice expansion, and selection computation.
|
|
2
|
+
|
|
3
|
+
This module validates indexing expressions, expands slices into integer
|
|
4
|
+
positions, computes the selected coordinates, and computes the result shape
|
|
5
|
+
after indexing.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import itertools
|
|
9
|
+
|
|
10
|
+
from .shape import format_shape
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def validate_index(index, shape):
|
|
14
|
+
"""Validate an index tuple against a tensor shape.
|
|
15
|
+
|
|
16
|
+
The index must be a tuple whose length matches the tensor rank. Each
|
|
17
|
+
entry must be a non-negative in-range integer or a slice. Negative
|
|
18
|
+
integers and other entry types are rejected in this version.
|
|
19
|
+
"""
|
|
20
|
+
if not isinstance(index, tuple):
|
|
21
|
+
raise TypeError(
|
|
22
|
+
f"index must be a tuple, got {type(index).__name__}. "
|
|
23
|
+
f"For a single axis use a one element tuple such as (0,)"
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
if len(index) != len(shape):
|
|
27
|
+
raise ValueError(
|
|
28
|
+
f"index length {len(index)} does not match tensor rank {len(shape)}"
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
for axis, (entry, size) in enumerate(zip(index, shape)):
|
|
32
|
+
if isinstance(entry, bool):
|
|
33
|
+
raise TypeError(f"axis {axis}: boolean indices are not supported")
|
|
34
|
+
if isinstance(entry, int):
|
|
35
|
+
if entry < 0:
|
|
36
|
+
raise IndexError(
|
|
37
|
+
f"axis {axis}: negative indices are not supported, got {entry}"
|
|
38
|
+
)
|
|
39
|
+
if entry >= size:
|
|
40
|
+
raise IndexError(
|
|
41
|
+
f"axis {axis}: integer index {entry} is out of range for size {size}"
|
|
42
|
+
)
|
|
43
|
+
elif isinstance(entry, slice):
|
|
44
|
+
continue
|
|
45
|
+
else:
|
|
46
|
+
raise TypeError(
|
|
47
|
+
f"axis {axis}: unsupported index entry {entry!r}. "
|
|
48
|
+
f"Only integers and slices are supported"
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
return index
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def expand_slice(sl, size):
|
|
55
|
+
"""Expand a slice into the list of integer positions it selects."""
|
|
56
|
+
return list(range(*sl.indices(size)))
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def selected_coordinates(shape, index):
|
|
60
|
+
"""Compute every coordinate selected by ``index`` in row-major order."""
|
|
61
|
+
validate_index(index, shape)
|
|
62
|
+
axes = []
|
|
63
|
+
for entry, size in zip(index, shape):
|
|
64
|
+
if isinstance(entry, int):
|
|
65
|
+
axes.append([entry])
|
|
66
|
+
else:
|
|
67
|
+
axes.append(expand_slice(entry, size))
|
|
68
|
+
return list(itertools.product(*axes))
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def result_shape(shape, index):
|
|
72
|
+
"""Compute the result shape after indexing.
|
|
73
|
+
|
|
74
|
+
Integer entries remove their axis. Slice entries keep their axis with a
|
|
75
|
+
size equal to the number of selected positions.
|
|
76
|
+
"""
|
|
77
|
+
validate_index(index, shape)
|
|
78
|
+
out = []
|
|
79
|
+
for entry, size in zip(index, shape):
|
|
80
|
+
if isinstance(entry, slice):
|
|
81
|
+
out.append(len(expand_slice(entry, size)))
|
|
82
|
+
return tuple(out)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def format_slice(sl):
|
|
86
|
+
"""Format a slice the way it would appear in indexing notation."""
|
|
87
|
+
if sl.start is None and sl.stop is None and sl.step is None:
|
|
88
|
+
return ":"
|
|
89
|
+
start = "" if sl.start is None else str(sl.start)
|
|
90
|
+
stop = "" if sl.stop is None else str(sl.stop)
|
|
91
|
+
if sl.step is None:
|
|
92
|
+
return f"{start}:{stop}"
|
|
93
|
+
return f"{start}:{stop}:{sl.step}"
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def format_index(index):
|
|
97
|
+
"""Format an index tuple as a readable string such as ``0, :, 1``."""
|
|
98
|
+
parts = []
|
|
99
|
+
for entry in index:
|
|
100
|
+
if isinstance(entry, slice):
|
|
101
|
+
parts.append(format_slice(entry))
|
|
102
|
+
else:
|
|
103
|
+
parts.append(str(entry))
|
|
104
|
+
return ", ".join(parts)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def explain_index(shape, index):
|
|
108
|
+
"""Build the lines explaining how an index transforms a shape."""
|
|
109
|
+
validate_index(index, shape)
|
|
110
|
+
lines = [
|
|
111
|
+
f"Original shape: {format_shape(shape)}",
|
|
112
|
+
f"Index: {format_index(index)}",
|
|
113
|
+
f"Result shape: {format_shape(result_shape(shape, index))}",
|
|
114
|
+
]
|
|
115
|
+
for axis, entry in enumerate(index):
|
|
116
|
+
if isinstance(entry, int):
|
|
117
|
+
lines.append(
|
|
118
|
+
f"Axis {axis} is removed because integer index {entry} is used."
|
|
119
|
+
)
|
|
120
|
+
else:
|
|
121
|
+
lines.append(
|
|
122
|
+
f"Axis {axis} is kept because slice {format_slice(entry)} is used."
|
|
123
|
+
)
|
|
124
|
+
return lines
|
rainbow_tensor/layout.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
"""Layout calculation.
|
|
2
|
+
|
|
3
|
+
This module converts a tensor shape into 2D drawing coordinates. It knows
|
|
4
|
+
nothing about SVG, so the same layout could be rendered by any backend.
|
|
5
|
+
|
|
6
|
+
The tensor is drawn as nested per-axis frames. Axis 0 is the outer frame,
|
|
7
|
+
each following non-leaf axis is an inner frame, and the leaf axis elements
|
|
8
|
+
are placed as plain text cells inside the innermost frame.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from dataclasses import dataclass, field
|
|
12
|
+
|
|
13
|
+
from .shape import flat_index
|
|
14
|
+
|
|
15
|
+
CELL_W = 48
|
|
16
|
+
CELL_H = 34
|
|
17
|
+
CELL_GAP = 8
|
|
18
|
+
ROW_PAD = 9
|
|
19
|
+
ROW_GAP = 12
|
|
20
|
+
BLOCK_PAD = 12
|
|
21
|
+
BLOCK_GAP = 30
|
|
22
|
+
PADDING = 20
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
|
|
26
|
+
class Cell:
|
|
27
|
+
"""A single tensor element placed on the canvas."""
|
|
28
|
+
|
|
29
|
+
x: float
|
|
30
|
+
y: float
|
|
31
|
+
width: float
|
|
32
|
+
height: float
|
|
33
|
+
value: object
|
|
34
|
+
coord: tuple
|
|
35
|
+
selected: bool = False
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@dataclass
|
|
39
|
+
class Frame:
|
|
40
|
+
"""A grouping rectangle for one axis.
|
|
41
|
+
|
|
42
|
+
``axis`` is the axis this frame represents, or ``None`` for a neutral
|
|
43
|
+
container that is not tied to a specific axis (used for 1D tensors).
|
|
44
|
+
``selected`` is true when the region contains at least one selected cell.
|
|
45
|
+
"""
|
|
46
|
+
|
|
47
|
+
x: float
|
|
48
|
+
y: float
|
|
49
|
+
width: float
|
|
50
|
+
height: float
|
|
51
|
+
axis: object
|
|
52
|
+
selected: bool = False
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@dataclass
|
|
56
|
+
class Layout:
|
|
57
|
+
"""All drawing data for one tensor."""
|
|
58
|
+
|
|
59
|
+
cells: list = field(default_factory=list)
|
|
60
|
+
frames: list = field(default_factory=list)
|
|
61
|
+
width: float = 0.0
|
|
62
|
+
height: float = 0.0
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def build_layout(shape, selected=None, value_fn=None):
|
|
66
|
+
"""Compute the layout for a tensor.
|
|
67
|
+
|
|
68
|
+
``selected`` is an iterable of coordinates to mark as selected.
|
|
69
|
+
``value_fn`` maps a coordinate to its display value. When it is ``None``
|
|
70
|
+
sequential row-major values starting at 0 are used.
|
|
71
|
+
"""
|
|
72
|
+
sel = {tuple(coord) for coord in (selected or [])}
|
|
73
|
+
ndim = len(shape)
|
|
74
|
+
cols = shape[-1]
|
|
75
|
+
|
|
76
|
+
row_w = cols * CELL_W + (cols - 1) * CELL_GAP + 2 * ROW_PAD
|
|
77
|
+
row_h = CELL_H + 2 * ROW_PAD
|
|
78
|
+
|
|
79
|
+
layout = Layout()
|
|
80
|
+
|
|
81
|
+
def place_cells(rx, ry, prefix):
|
|
82
|
+
for c in range(cols):
|
|
83
|
+
x = rx + ROW_PAD + c * (CELL_W + CELL_GAP)
|
|
84
|
+
coord = prefix + (c,)
|
|
85
|
+
value = value_fn(coord) if value_fn is not None else flat_index(coord, shape)
|
|
86
|
+
layout.cells.append(
|
|
87
|
+
Cell(x, ry + ROW_PAD, CELL_W, CELL_H, value, coord, coord in sel)
|
|
88
|
+
)
|
|
89
|
+
|
|
90
|
+
if ndim == 1:
|
|
91
|
+
rx, ry = PADDING, PADDING
|
|
92
|
+
layout.frames.append(Frame(rx, ry, row_w, row_h, axis=None, selected=bool(sel)))
|
|
93
|
+
place_cells(rx, ry, ())
|
|
94
|
+
layout.width = PADDING * 2 + row_w
|
|
95
|
+
layout.height = PADDING * 2 + row_h
|
|
96
|
+
return layout
|
|
97
|
+
|
|
98
|
+
if ndim == 2:
|
|
99
|
+
rows = shape[0]
|
|
100
|
+
for r in range(rows):
|
|
101
|
+
rx = PADDING
|
|
102
|
+
ry = PADDING + r * (row_h + ROW_GAP)
|
|
103
|
+
row_sel = any(co[0] == r for co in sel)
|
|
104
|
+
layout.frames.append(Frame(rx, ry, row_w, row_h, axis=0, selected=row_sel))
|
|
105
|
+
place_cells(rx, ry, (r,))
|
|
106
|
+
layout.width = PADDING * 2 + row_w
|
|
107
|
+
layout.height = PADDING * 2 + rows * row_h + (rows - 1) * ROW_GAP
|
|
108
|
+
return layout
|
|
109
|
+
|
|
110
|
+
blocks, rows, _ = shape
|
|
111
|
+
block_w = row_w + 2 * BLOCK_PAD
|
|
112
|
+
block_h = rows * row_h + (rows - 1) * ROW_GAP + 2 * BLOCK_PAD
|
|
113
|
+
for b in range(blocks):
|
|
114
|
+
bx = PADDING + b * (block_w + BLOCK_GAP)
|
|
115
|
+
by = PADDING
|
|
116
|
+
block_sel = any(co[0] == b for co in sel)
|
|
117
|
+
layout.frames.append(Frame(bx, by, block_w, block_h, axis=0, selected=block_sel))
|
|
118
|
+
for r in range(rows):
|
|
119
|
+
rx = bx + BLOCK_PAD
|
|
120
|
+
ry = by + BLOCK_PAD + r * (row_h + ROW_GAP)
|
|
121
|
+
row_sel = any(co[0] == b and co[1] == r for co in sel)
|
|
122
|
+
layout.frames.append(Frame(rx, ry, row_w, row_h, axis=1, selected=row_sel))
|
|
123
|
+
place_cells(rx, ry, (b, r))
|
|
124
|
+
layout.width = PADDING * 2 + blocks * block_w + (blocks - 1) * BLOCK_GAP
|
|
125
|
+
layout.height = PADDING * 2 + block_h
|
|
126
|
+
return layout
|
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
"""Notebook display layer.
|
|
2
|
+
|
|
3
|
+
This module exposes the public functions ``show_shape`` and ``show_index``.
|
|
4
|
+
Each returns a small object that renders as SVG in a notebook and stays
|
|
5
|
+
inspectable in plain Python, so the package is testable outside a notebook.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from .indexing import (
|
|
9
|
+
explain_index,
|
|
10
|
+
format_slice,
|
|
11
|
+
result_shape,
|
|
12
|
+
selected_coordinates,
|
|
13
|
+
validate_index,
|
|
14
|
+
)
|
|
15
|
+
from .render_svg import (
|
|
16
|
+
AXIS_FRAME_COLORS,
|
|
17
|
+
NEUTRAL_COLOR,
|
|
18
|
+
SELECT_VALUE_COLOR,
|
|
19
|
+
render_svg,
|
|
20
|
+
)
|
|
21
|
+
from .shape import extract_shape
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class TensorVisual:
|
|
25
|
+
"""A rendered tensor visualisation.
|
|
26
|
+
|
|
27
|
+
The ``svg`` attribute holds the SVG string. In a notebook the object is
|
|
28
|
+
shown as an image through ``_repr_svg_``. Outside a notebook the same
|
|
29
|
+
string is available for inspection and testing.
|
|
30
|
+
"""
|
|
31
|
+
|
|
32
|
+
def __init__(self, svg, shape, selected=None, result=None, explanation=None):
|
|
33
|
+
self.svg = svg
|
|
34
|
+
self.shape = shape
|
|
35
|
+
self.selected = selected
|
|
36
|
+
self.result_shape = result
|
|
37
|
+
self.explanation = explanation or []
|
|
38
|
+
|
|
39
|
+
def _repr_svg_(self):
|
|
40
|
+
return self.svg
|
|
41
|
+
|
|
42
|
+
def __str__(self):
|
|
43
|
+
return self.svg
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _value_fn_for(obj):
|
|
47
|
+
"""Build a coordinate to value function for array-like input.
|
|
48
|
+
|
|
49
|
+
A tuple carries no values, so generated values are used. Any object with
|
|
50
|
+
a ``.shape`` attribute is read by coordinate, which covers NumPy arrays
|
|
51
|
+
and any compatible array-like without importing a tensor library.
|
|
52
|
+
"""
|
|
53
|
+
if isinstance(obj, tuple):
|
|
54
|
+
return None
|
|
55
|
+
if not hasattr(obj, "shape"):
|
|
56
|
+
return None
|
|
57
|
+
|
|
58
|
+
def value_fn(coord):
|
|
59
|
+
value = obj[coord]
|
|
60
|
+
item = getattr(value, "item", None)
|
|
61
|
+
if callable(item):
|
|
62
|
+
return item()
|
|
63
|
+
return value
|
|
64
|
+
|
|
65
|
+
return value_fn
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _shape_label_parts(shape):
|
|
69
|
+
"""Build coloured label parts for a shape.
|
|
70
|
+
|
|
71
|
+
Each dimension is coloured to match its frame. Frame axes use their axis
|
|
72
|
+
colour and the leaf axis stays neutral, since it has no frame.
|
|
73
|
+
"""
|
|
74
|
+
ndim = len(shape)
|
|
75
|
+
parts = [("Shape (", NEUTRAL_COLOR)]
|
|
76
|
+
for axis, dim in enumerate(shape):
|
|
77
|
+
if axis < ndim - 1:
|
|
78
|
+
color = AXIS_FRAME_COLORS.get(axis, NEUTRAL_COLOR)
|
|
79
|
+
else:
|
|
80
|
+
color = NEUTRAL_COLOR
|
|
81
|
+
parts.append((str(dim), color))
|
|
82
|
+
if axis < ndim - 1:
|
|
83
|
+
parts.append((", ", NEUTRAL_COLOR))
|
|
84
|
+
parts.append((")", NEUTRAL_COLOR))
|
|
85
|
+
return parts
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _index_label_parts(index):
|
|
89
|
+
"""Build coloured label parts for an index.
|
|
90
|
+
|
|
91
|
+
Each token is coloured to match its axis. Frame axes use their axis
|
|
92
|
+
colour and the leaf axis uses the selected value colour, so the label
|
|
93
|
+
mirrors the colours drawn on the tensor.
|
|
94
|
+
"""
|
|
95
|
+
ndim = len(index)
|
|
96
|
+
parts = [("Index (", NEUTRAL_COLOR)]
|
|
97
|
+
for axis, entry in enumerate(index):
|
|
98
|
+
token = format_slice(entry) if isinstance(entry, slice) else str(entry)
|
|
99
|
+
if axis < ndim - 1:
|
|
100
|
+
color = AXIS_FRAME_COLORS.get(axis, NEUTRAL_COLOR)
|
|
101
|
+
else:
|
|
102
|
+
color = SELECT_VALUE_COLOR
|
|
103
|
+
parts.append((token, color))
|
|
104
|
+
if axis < ndim - 1:
|
|
105
|
+
parts.append((", ", NEUTRAL_COLOR))
|
|
106
|
+
parts.append((")", NEUTRAL_COLOR))
|
|
107
|
+
return parts
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def show_shape(shape):
|
|
111
|
+
"""Visualise the structure of a tensor shape.
|
|
112
|
+
|
|
113
|
+
``shape`` may be a tuple such as ``(2, 2, 2)`` or an array-like object
|
|
114
|
+
with a ``.shape`` attribute such as a NumPy array. The returned
|
|
115
|
+
:class:`TensorVisual` renders as SVG in a notebook through ``_repr_svg_``
|
|
116
|
+
and also exposes the SVG string for inspection and testing.
|
|
117
|
+
"""
|
|
118
|
+
normalized = extract_shape(shape)
|
|
119
|
+
value_fn = _value_fn_for(shape)
|
|
120
|
+
label_parts = _shape_label_parts(normalized)
|
|
121
|
+
svg = render_svg(normalized, value_fn=value_fn, label_parts=label_parts)
|
|
122
|
+
return TensorVisual(svg, normalized)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def show_index(tensor_or_shape, index):
|
|
126
|
+
"""Visualise how an index selects elements from a tensor.
|
|
127
|
+
|
|
128
|
+
``tensor_or_shape`` may be a shape tuple or an array-like object with a
|
|
129
|
+
``.shape`` attribute. ``index`` must be a tuple of integers and slices
|
|
130
|
+
whose length matches the tensor rank. Selected elements are highlighted,
|
|
131
|
+
unselected elements stay visible but de-emphasised, and an explanation of
|
|
132
|
+
the result shape is drawn below the tensor. The returned
|
|
133
|
+
:class:`TensorVisual` renders as SVG in a notebook through ``_repr_svg_``
|
|
134
|
+
and also exposes the SVG string for inspection and testing.
|
|
135
|
+
"""
|
|
136
|
+
normalized = extract_shape(tensor_or_shape)
|
|
137
|
+
validate_index(index, normalized)
|
|
138
|
+
selected = selected_coordinates(normalized, index)
|
|
139
|
+
result = result_shape(normalized, index)
|
|
140
|
+
explanation = explain_index(normalized, index)
|
|
141
|
+
value_fn = _value_fn_for(tensor_or_shape)
|
|
142
|
+
label_parts = _index_label_parts(index)
|
|
143
|
+
svg = render_svg(
|
|
144
|
+
normalized,
|
|
145
|
+
selected=selected,
|
|
146
|
+
value_fn=value_fn,
|
|
147
|
+
label_parts=label_parts,
|
|
148
|
+
explanation=explanation,
|
|
149
|
+
)
|
|
150
|
+
return TensorVisual(
|
|
151
|
+
svg, normalized, selected=selected, result=result, explanation=explanation
|
|
152
|
+
)
|
|
@@ -0,0 +1,156 @@
|
|
|
1
|
+
"""SVG rendering.
|
|
2
|
+
|
|
3
|
+
This module turns a :class:`~rainbow_tensor.layout.Layout` into an SVG
|
|
4
|
+
string. Each axis has its own colour. Axis 0 frames are red and axis 1
|
|
5
|
+
frames are orange. The leaf axis elements are plain text. Selected elements
|
|
6
|
+
are drawn green and, in an index view, only the selected frames keep their
|
|
7
|
+
axis colour while the rest are drawn in a neutral dark tone.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from .layout import build_layout
|
|
11
|
+
|
|
12
|
+
AXIS_FRAME_COLORS = {0: "#dc2626", 1: "#f97316"}
|
|
13
|
+
SELECT_VALUE_COLOR = "#16a34a"
|
|
14
|
+
NEUTRAL_COLOR = "#111827"
|
|
15
|
+
TEXT_COLOR = "#111827"
|
|
16
|
+
LABEL_COLOR = "#334155"
|
|
17
|
+
|
|
18
|
+
LABEL_HEIGHT = 30
|
|
19
|
+
LINE_HEIGHT = 20
|
|
20
|
+
FRAME_WIDTH = 3
|
|
21
|
+
TEXT_MARGIN = 20
|
|
22
|
+
LABEL_FONT_SIZE = 16
|
|
23
|
+
EXPLANATION_FONT_SIZE = 13
|
|
24
|
+
# Approximate width of one monospace character relative to the font size.
|
|
25
|
+
CHAR_WIDTH_RATIO = 0.62
|
|
26
|
+
FONT_FAMILY = "ui-monospace, Menlo, Consolas, monospace"
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def escape(text):
|
|
30
|
+
"""Escape a value so it is safe inside SVG text and attributes."""
|
|
31
|
+
return (
|
|
32
|
+
str(text)
|
|
33
|
+
.replace("&", "&")
|
|
34
|
+
.replace("<", "<")
|
|
35
|
+
.replace(">", ">")
|
|
36
|
+
.replace('"', """)
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def svg_document(content, width, height):
|
|
41
|
+
"""Wrap rendered elements in a complete SVG document."""
|
|
42
|
+
return (
|
|
43
|
+
f'<svg xmlns="http://www.w3.org/2000/svg" '
|
|
44
|
+
f'width="{width:.0f}" height="{height:.0f}" '
|
|
45
|
+
f'viewBox="0 0 {width:.0f} {height:.0f}" '
|
|
46
|
+
f'font-family="{FONT_FAMILY}">'
|
|
47
|
+
f"{content}"
|
|
48
|
+
f"</svg>"
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _frame_color(frame, has_selection):
|
|
53
|
+
"""Pick the stroke colour for a frame.
|
|
54
|
+
|
|
55
|
+
In a shape view every frame uses its axis colour. In an index view only
|
|
56
|
+
selected frames keep their axis colour, the rest are neutral.
|
|
57
|
+
"""
|
|
58
|
+
axis_color = AXIS_FRAME_COLORS.get(frame.axis, NEUTRAL_COLOR)
|
|
59
|
+
if not has_selection:
|
|
60
|
+
return axis_color
|
|
61
|
+
return axis_color if frame.selected else NEUTRAL_COLOR
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _cell_color(cell, has_selection):
|
|
65
|
+
"""Pick the text colour for a cell value."""
|
|
66
|
+
if has_selection and cell.selected:
|
|
67
|
+
return SELECT_VALUE_COLOR
|
|
68
|
+
return TEXT_COLOR
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _text_width(char_count, font_size):
|
|
72
|
+
"""Estimate the rendered width of a monospace string."""
|
|
73
|
+
return char_count * font_size * CHAR_WIDTH_RATIO
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _text_block_width(label_len, explanation):
|
|
77
|
+
"""Estimate the width needed so the label and explanation are not clipped."""
|
|
78
|
+
widest = _text_width(label_len, LABEL_FONT_SIZE)
|
|
79
|
+
for line in explanation:
|
|
80
|
+
widest = max(widest, _text_width(len(line), EXPLANATION_FONT_SIZE))
|
|
81
|
+
return 2 * TEXT_MARGIN + widest
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _label_element(x, y, parts):
|
|
85
|
+
"""Build a label text element from coloured parts."""
|
|
86
|
+
spans = "".join(
|
|
87
|
+
f'<tspan fill="{escape(color)}">{escape(text)}</tspan>' for text, color in parts
|
|
88
|
+
)
|
|
89
|
+
return (
|
|
90
|
+
f'<text x="{x:.0f}" y="{y:.0f}" font-size="16" font-weight="600">{spans}</text>'
|
|
91
|
+
)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def render_svg(
|
|
95
|
+
shape, selected=None, value_fn=None, label="", label_parts=None, explanation=None
|
|
96
|
+
):
|
|
97
|
+
"""Render a tensor as an SVG string.
|
|
98
|
+
|
|
99
|
+
``selected`` is an iterable of selected coordinates. ``value_fn`` maps a
|
|
100
|
+
coordinate to its display value. ``label`` is a plain label string while
|
|
101
|
+
``label_parts`` is an optional list of ``(text, colour)`` pairs for a
|
|
102
|
+
coloured label. ``explanation`` is an optional list of lines drawn below.
|
|
103
|
+
"""
|
|
104
|
+
explanation = explanation or []
|
|
105
|
+
selected_list = list(selected or [])
|
|
106
|
+
has_selection = len(selected_list) > 0
|
|
107
|
+
layout = build_layout(shape, selected=selected_list, value_fn=value_fn)
|
|
108
|
+
|
|
109
|
+
parts = []
|
|
110
|
+
|
|
111
|
+
for frame in layout.frames:
|
|
112
|
+
color = _frame_color(frame, has_selection)
|
|
113
|
+
parts.append(
|
|
114
|
+
f'<rect x="{frame.x:.0f}" y="{frame.y:.0f}" '
|
|
115
|
+
f'width="{frame.width:.0f}" height="{frame.height:.0f}" '
|
|
116
|
+
f'rx="6" fill="none" stroke="{color}" stroke-width="{FRAME_WIDTH}"/>'
|
|
117
|
+
)
|
|
118
|
+
|
|
119
|
+
for cell in layout.cells:
|
|
120
|
+
color = _cell_color(cell, has_selection)
|
|
121
|
+
weight = "700" if (has_selection and cell.selected) else "400"
|
|
122
|
+
cx = cell.x + cell.width / 2
|
|
123
|
+
cy = cell.y + cell.height / 2
|
|
124
|
+
parts.append(
|
|
125
|
+
f'<text x="{cx:.0f}" y="{cy:.0f}" text-anchor="middle" '
|
|
126
|
+
f'dominant-baseline="central" font-size="15" font-weight="{weight}" '
|
|
127
|
+
f'fill="{color}">{escape(cell.value)}</text>'
|
|
128
|
+
)
|
|
129
|
+
|
|
130
|
+
if label_parts:
|
|
131
|
+
label_len = sum(len(text) for text, _ in label_parts)
|
|
132
|
+
else:
|
|
133
|
+
label_len = len(label)
|
|
134
|
+
text_width = _text_block_width(label_len, explanation)
|
|
135
|
+
width = max(layout.width, text_width)
|
|
136
|
+
y = layout.height + LABEL_HEIGHT
|
|
137
|
+
|
|
138
|
+
if label_parts:
|
|
139
|
+
parts.append(_label_element(20, y, label_parts))
|
|
140
|
+
y += LINE_HEIGHT
|
|
141
|
+
elif label:
|
|
142
|
+
parts.append(
|
|
143
|
+
f'<text x="20" y="{y:.0f}" font-size="16" font-weight="600" '
|
|
144
|
+
f'fill="{LABEL_COLOR}">{escape(label)}</text>'
|
|
145
|
+
)
|
|
146
|
+
y += LINE_HEIGHT
|
|
147
|
+
|
|
148
|
+
for line in explanation:
|
|
149
|
+
y += LINE_HEIGHT - 4
|
|
150
|
+
parts.append(
|
|
151
|
+
f'<text x="20" y="{y:.0f}" font-size="13" fill="{LABEL_COLOR}">'
|
|
152
|
+
f"{escape(line)}</text>"
|
|
153
|
+
)
|
|
154
|
+
|
|
155
|
+
height = y + 16
|
|
156
|
+
return svg_document("".join(parts), width, height)
|
rainbow_tensor/shape.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
"""Shape extraction and validation.
|
|
2
|
+
|
|
3
|
+
This module turns user input into a validated ``tuple[int, ...]`` shape and
|
|
4
|
+
generates the display values used when only a shape is provided.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import itertools
|
|
8
|
+
|
|
9
|
+
SUPPORTED_NDIM = (1, 2, 3)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def extract_shape(obj):
|
|
13
|
+
"""Extract a validated shape from a tuple or an array-like object.
|
|
14
|
+
|
|
15
|
+
A tuple is treated as a literal shape. Any other object is expected to
|
|
16
|
+
expose a ``.shape`` attribute (for example a NumPy array). No tensor
|
|
17
|
+
library is imported, so this works for any object with ``.shape``.
|
|
18
|
+
"""
|
|
19
|
+
if isinstance(obj, tuple):
|
|
20
|
+
raw = obj
|
|
21
|
+
elif hasattr(obj, "shape"):
|
|
22
|
+
raw = obj.shape
|
|
23
|
+
else:
|
|
24
|
+
raw = obj
|
|
25
|
+
return validate_shape(raw)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def validate_shape(shape):
|
|
29
|
+
"""Validate a shape and return it as a ``tuple[int, ...]``.
|
|
30
|
+
|
|
31
|
+
The shape must be a non-empty sequence of positive integers with a
|
|
32
|
+
supported number of dimensions (1D, 2D, or 3D in this version).
|
|
33
|
+
"""
|
|
34
|
+
if not isinstance(shape, tuple):
|
|
35
|
+
try:
|
|
36
|
+
shape = tuple(shape)
|
|
37
|
+
except TypeError as exc:
|
|
38
|
+
raise TypeError(
|
|
39
|
+
"shape must be a tuple of integers or an object with a .shape attribute"
|
|
40
|
+
) from exc
|
|
41
|
+
|
|
42
|
+
if len(shape) == 0:
|
|
43
|
+
raise ValueError("shape must have at least one dimension, got an empty shape")
|
|
44
|
+
|
|
45
|
+
for dim in shape:
|
|
46
|
+
if isinstance(dim, bool) or not isinstance(dim, int):
|
|
47
|
+
raise TypeError(f"shape dimensions must be integers, got {dim!r}")
|
|
48
|
+
if dim <= 0:
|
|
49
|
+
raise ValueError(f"shape dimensions must be positive, got {dim!r}")
|
|
50
|
+
|
|
51
|
+
if len(shape) not in SUPPORTED_NDIM:
|
|
52
|
+
raise ValueError(
|
|
53
|
+
f"only 1D, 2D, and 3D tensors are supported in this version, "
|
|
54
|
+
f"got {len(shape)} dimensions"
|
|
55
|
+
)
|
|
56
|
+
|
|
57
|
+
return shape
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def coordinates(shape):
|
|
61
|
+
"""Yield every coordinate in row-major order for the given shape."""
|
|
62
|
+
return itertools.product(*[range(dim) for dim in shape])
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def flat_index(coord, shape):
|
|
66
|
+
"""Return the row-major flat index of ``coord`` inside ``shape``."""
|
|
67
|
+
index = 0
|
|
68
|
+
for value, dim in zip(coord, shape):
|
|
69
|
+
index = index * dim + value
|
|
70
|
+
return index
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def generate_values(shape):
|
|
74
|
+
"""Map every coordinate of ``shape`` to a sequential value from 0.
|
|
75
|
+
|
|
76
|
+
For shape ``(2, 2, 2)`` this produces the values 0 through 7 in
|
|
77
|
+
row-major order.
|
|
78
|
+
"""
|
|
79
|
+
return {coord: flat_index(coord, shape) for coord in coordinates(shape)}
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def format_shape(shape):
|
|
83
|
+
"""Format a shape the way Python prints a tuple, for example ``(2,)``."""
|
|
84
|
+
return str(tuple(shape))
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def format_shape_label(shape):
|
|
88
|
+
"""Format a shape for an on-screen label without a trailing comma.
|
|
89
|
+
|
|
90
|
+
For shape ``(3,)`` this returns ``(3)`` rather than ``(3,)``.
|
|
91
|
+
"""
|
|
92
|
+
return "(" + ", ".join(str(dim) for dim in shape) + ")"
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: rainbow-tensor
|
|
3
|
+
Version: 0.2.1
|
|
4
|
+
Summary: Visualise tensor shape, indexing, and slicing as SVG in Jupyter notebooks
|
|
5
|
+
Author-email: Zhixiang Feng <contact@zhixiangfeng.com>
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/Niox1337/rainbow-tensor
|
|
8
|
+
Classifier: Programming Language :: Python :: 3
|
|
9
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
10
|
+
Classifier: Operating System :: OS Independent
|
|
11
|
+
Classifier: Framework :: IPython
|
|
12
|
+
Classifier: Framework :: Jupyter
|
|
13
|
+
Requires-Python: >=3.9
|
|
14
|
+
Description-Content-Type: text/markdown
|
|
15
|
+
Requires-Dist: numpy>=1.19.0
|
|
16
|
+
Requires-Dist: ipython>=7.0.0
|
|
17
|
+
Provides-Extra: dev
|
|
18
|
+
Requires-Dist: pytest>=7.0; extra == "dev"
|
|
19
|
+
Requires-Dist: ruff>=0.1; extra == "dev"
|
|
20
|
+
Requires-Dist: build>=1.0; extra == "dev"
|
|
21
|
+
|
|
22
|
+
# rainbow-tensor
|
|
23
|
+
|
|
24
|
+
Visualise tensor shape, indexing, and slicing as SVG inside IPython and Jupyter notebooks.
|
|
25
|
+
|
|
26
|
+
rainbow-tensor is made for people who are learning how a tensor is structured and how an indexing expression selects elements. It draws the tensor as nested blocks, rows, and cells, then highlights exactly which elements an index picks out.
|
|
27
|
+
|
|
28
|
+
## Features
|
|
29
|
+
|
|
30
|
+
- Static SVG output that stays sharp at any zoom level in a notebook
|
|
31
|
+
- Shape visualisation for 1D, 2D, and 3D tensors
|
|
32
|
+
- Index visualisation with highlighted selections and a plain text explanation
|
|
33
|
+
- Works with shape tuples and with array-like objects that expose a `.shape` attribute, such as NumPy arrays
|
|
34
|
+
- No tensor library is imported by the core, so the package stays lightweight
|
|
35
|
+
|
|
36
|
+
## Colour scheme
|
|
37
|
+
|
|
38
|
+
Each axis has its own colour so the structure and a selection are easy to read.
|
|
39
|
+
|
|
40
|
+
- Axis 0 is the outer frame, drawn red
|
|
41
|
+
- Axis 1 is the inner row frame, drawn orange
|
|
42
|
+
- The leaf axis elements are plain text, and a selected element is drawn green
|
|
43
|
+
- The numbers in the shape label and the tokens in the index label are coloured to match
|
|
44
|
+
|
|
45
|
+
In an index view only the selected frames keep their axis colour. The rest of the tensor is drawn in a neutral dark tone so the selected path stands out.
|
|
46
|
+
|
|
47
|
+
## Installation
|
|
48
|
+
|
|
49
|
+
Install from source for development.
|
|
50
|
+
|
|
51
|
+
```bash
|
|
52
|
+
git clone https://github.com/Niox1337/rainbow-tensor.git
|
|
53
|
+
cd rainbow-tensor
|
|
54
|
+
pip install -e .
|
|
55
|
+
```
|
|
56
|
+
|
|
57
|
+
Install with the development tools (pytest, ruff, build).
|
|
58
|
+
|
|
59
|
+
```bash
|
|
60
|
+
pip install -e ".[dev]"
|
|
61
|
+
```
|
|
62
|
+
|
|
63
|
+
## Usage
|
|
64
|
+
|
|
65
|
+
Run the examples in a Jupyter notebook or an IPython shell so the SVG is displayed.
|
|
66
|
+
|
|
67
|
+
Visualise a shape.
|
|
68
|
+
|
|
69
|
+
```python
|
|
70
|
+
from rainbow_tensor import show_shape
|
|
71
|
+
|
|
72
|
+
show_shape((2, 2, 2))
|
|
73
|
+
```
|
|
74
|
+
|
|
75
|
+
Visualise how an index selects elements.
|
|
76
|
+
|
|
77
|
+
```python
|
|
78
|
+
from rainbow_tensor import show_index
|
|
79
|
+
|
|
80
|
+
show_index((2, 2, 2), (0, slice(None), 1))
|
|
81
|
+
```
|
|
82
|
+
|
|
83
|
+
For the shape `(2, 2, 2)` the index `(0, slice(None), 1)` selects the values `1` and `3`, the selected coordinates are `(0, 0, 1)` and `(0, 1, 1)`, and the result shape is `(2,)`.
|
|
84
|
+
|
|
85
|
+
Use a real NumPy array to display its actual values.
|
|
86
|
+
|
|
87
|
+
```python
|
|
88
|
+
import numpy as np
|
|
89
|
+
from rainbow_tensor import show_shape, show_index
|
|
90
|
+
|
|
91
|
+
x = np.arange(8).reshape(2, 2, 2)
|
|
92
|
+
show_shape(x)
|
|
93
|
+
show_index(x, (0, slice(None), 1))
|
|
94
|
+
```
|
|
95
|
+
|
|
96
|
+
Each function returns a small result object. Its `svg` attribute holds the SVG string, so the package can be inspected and tested outside a notebook.
|
|
97
|
+
|
|
98
|
+
## Supported
|
|
99
|
+
|
|
100
|
+
- 1D, 2D, and 3D tensors
|
|
101
|
+
- Shape tuples and array-like objects with a `.shape` attribute
|
|
102
|
+
- Integer indexing
|
|
103
|
+
- Basic slicing with `slice(None)`, `slice(start, stop)`, and `slice(start, stop, step)`
|
|
104
|
+
|
|
105
|
+
## Unsupported in the first version
|
|
106
|
+
|
|
107
|
+
- Advanced NumPy indexing
|
|
108
|
+
- Boolean masks
|
|
109
|
+
- `None` and `newaxis`
|
|
110
|
+
- Ellipsis
|
|
111
|
+
- 4D or higher tensors
|
|
112
|
+
- Interactive controls and animation
|
|
113
|
+
|
|
114
|
+
## Development
|
|
115
|
+
|
|
116
|
+
```bash
|
|
117
|
+
pytest
|
|
118
|
+
ruff check .
|
|
119
|
+
python -m build
|
|
120
|
+
```
|
|
121
|
+
|
|
122
|
+
## License
|
|
123
|
+
|
|
124
|
+
MIT
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
rainbow_tensor/__init__.py,sha256=NqgwPMsJP1om_TOSe4cAwikbY2fnW-I0WXW5TXjZJ1g,507
|
|
2
|
+
rainbow_tensor/indexing.py,sha256=Fe3gS2gEMzEGer6SxuTfMBILLT3JERZtulac2pmunrc,4061
|
|
3
|
+
rainbow_tensor/layout.py,sha256=YxqsV7ie2cnXBoxVjcN41B_Cegw0bHZwzmMShrl81yc,3854
|
|
4
|
+
rainbow_tensor/notebook.py,sha256=DdX_XWvl6SleWz985kIJgRzI4U6-qe3oEUro17YjJys,5031
|
|
5
|
+
rainbow_tensor/render_svg.py,sha256=bQxZq8LNKIYdWvlWLZIBWXDvbonSvzpuLN-Z8RJItHc,5164
|
|
6
|
+
rainbow_tensor/shape.py,sha256=0fsuvkxhf_6LwgvjBKzsmwAx1FFsFa-TJsJ2sEGM7y0,2828
|
|
7
|
+
rainbow_tensor-0.2.1.dist-info/METADATA,sha256=VUO6PZ6lv-SGLh-TF3Ifwsr3BxrGZMsrBHOI8mx5CxM,3745
|
|
8
|
+
rainbow_tensor-0.2.1.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
|
|
9
|
+
rainbow_tensor-0.2.1.dist-info/top_level.txt,sha256=86NVkchsTdikpDRTvZFN2aapL_CSfvwIlB4Xj3oeitM,15
|
|
10
|
+
rainbow_tensor-0.2.1.dist-info/RECORD,,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
rainbow_tensor
|