rainbow-tensor 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,18 @@
1
+ """rainbow-tensor.
2
+
3
+ A small package for IPython and Jupyter notebooks that visualises tensor
4
+ shape, indexing, and slicing as SVG. It is meant for people learning how a
5
+ tensor is structured and how an indexing expression selects elements.
6
+
7
+ Public API:
8
+
9
+ from rainbow_tensor import show_shape, show_index
10
+
11
+ show_shape((2, 2, 2))
12
+ show_index((2, 2, 2), (0, slice(None), 1))
13
+ """
14
+
15
+ from .notebook import show_index, show_shape
16
+
17
__version__ = "0.2.1"

# Public names re-exported by ``from rainbow_tensor import *``.
__all__ = [
    "show_shape",
    "show_index",
    "__version__",
]
@@ -0,0 +1,124 @@
1
+ """Index validation, slice expansion, and selection computation.
2
+
3
+ This module validates indexing expressions, expands slices into integer
4
+ positions, computes the selected coordinates, and computes the result shape
5
+ after indexing.
6
+ """
7
+
8
+ import itertools
9
+
10
+ from .shape import format_shape
11
+
12
+
13
def validate_index(index, shape):
    """Validate an index tuple against a tensor shape.

    The index must be a tuple whose length matches the tensor rank. Each
    entry must be a non-negative in-range integer or a slice. Negative
    integers and other entry types are rejected in this version.

    Slices are checked here too, so a bad slice (non-integer bounds or a zero
    step) is reported with its axis instead of failing later when expanded.

    Args:
        index: Tuple of ``int`` and ``slice`` entries, one per axis.
        shape: Shape tuple the index is applied to.

    Returns:
        The ``index`` argument, unchanged.

    Raises:
        TypeError: if ``index`` is not a tuple, an entry is a ``bool`` or an
            unsupported type, or a slice bound is not an integer or ``None``.
        ValueError: if the index length differs from the rank, or a slice
            step is zero.
        IndexError: if an integer entry is negative or out of range.
    """
    if not isinstance(index, tuple):
        raise TypeError(
            f"index must be a tuple, got {type(index).__name__}. "
            "For a single axis use a one element tuple such as (0,)"
        )

    if len(index) != len(shape):
        raise ValueError(
            f"index length {len(index)} does not match tensor rank {len(shape)}"
        )

    for axis, (entry, size) in enumerate(zip(index, shape)):
        # bool subclasses int, so it must be rejected before the int branch.
        if isinstance(entry, bool):
            raise TypeError(f"axis {axis}: boolean indices are not supported")
        if isinstance(entry, int):
            if entry < 0:
                raise IndexError(
                    f"axis {axis}: negative indices are not supported, got {entry}"
                )
            if entry >= size:
                raise IndexError(
                    f"axis {axis}: integer index {entry} is out of range for size {size}"
                )
        elif isinstance(entry, slice):
            # slice.indices is what expand_slice uses later. Calling it here
            # surfaces its TypeError (bad bounds) / ValueError (zero step)
            # early, with the offending axis in the message.
            try:
                entry.indices(size)
            except TypeError as exc:
                raise TypeError(
                    f"axis {axis}: slice bounds must be integers or None, got {entry!r}"
                ) from exc
            except ValueError as exc:
                raise ValueError(
                    f"axis {axis}: invalid slice {entry!r}: {exc}"
                ) from exc
        else:
            raise TypeError(
                f"axis {axis}: unsupported index entry {entry!r}. "
                "Only integers and slices are supported"
            )

    return index
52
+
53
+
54
def expand_slice(sl, size):
    """Return the integer positions ``sl`` picks out of an axis of ``size``.

    Bounds are clipped and negative values resolved by ``slice.indices``, so
    the result matches what ``range(size)[sl]`` would contain.
    """
    start, stop, step = sl.indices(size)
    return [*range(start, stop, step)]
57
+
58
+
59
def selected_coordinates(shape, index):
    """Return every coordinate picked by ``index``, in row-major order.

    An integer entry pins its axis to a single position; a slice entry
    contributes every position it expands to. The selection is the cartesian
    product of those per-axis choices.
    """
    validate_index(index, shape)
    choices = [
        [entry] if isinstance(entry, int) else expand_slice(entry, size)
        for entry, size in zip(index, shape)
    ]
    return list(itertools.product(*choices))
69
+
70
+
71
def result_shape(shape, index):
    """Return the shape produced by applying ``index`` to ``shape``.

    Axes indexed by an integer disappear. Axes indexed by a slice survive,
    sized by how many positions the slice selects.
    """
    validate_index(index, shape)
    return tuple(
        len(expand_slice(entry, size))
        for entry, size in zip(index, shape)
        if isinstance(entry, slice)
    )
83
+
84
+
85
def format_slice(sl):
    """Render a slice in bracket notation, e.g. ``:``, ``1:3`` or ``::2``.

    Missing bounds are left blank, and the step part is only shown when the
    slice has an explicit step.
    """
    pieces = ["" if bound is None else str(bound) for bound in (sl.start, sl.stop)]
    if sl.step is not None:
        pieces.append(str(sl.step))
    return ":".join(pieces)
94
+
95
+
96
def format_index(index):
    """Render an index tuple as comma-separated tokens, e.g. ``0, :, 1``.

    Slices are rendered with ``format_slice``; any other entry uses ``str``.
    """
    return ", ".join(
        format_slice(entry) if isinstance(entry, slice) else str(entry)
        for entry in index
    )
105
+
106
+
107
def explain_index(shape, index):
    """Return human-readable lines describing how ``index`` reshapes ``shape``.

    The first three lines summarise the original shape, the index and the
    result shape. One further line per axis states whether that axis is
    removed (integer entry) or kept (slice entry).
    """
    validate_index(index, shape)
    summary = [
        f"Original shape: {format_shape(shape)}",
        f"Index: {format_index(index)}",
        f"Result shape: {format_shape(result_shape(shape, index))}",
    ]
    per_axis = [
        f"Axis {axis} is removed because integer index {entry} is used."
        if isinstance(entry, int)
        else f"Axis {axis} is kept because slice {format_slice(entry)} is used."
        for axis, entry in enumerate(index)
    ]
    return summary + per_axis
@@ -0,0 +1,126 @@
1
+ """Layout calculation.
2
+
3
+ This module converts a tensor shape into 2D drawing coordinates. It knows
4
+ nothing about SVG, so the same layout could be rendered by any backend.
5
+
6
+ The tensor is drawn as nested per-axis frames. Axis 0 is the outer frame,
7
+ each following non-leaf axis is an inner frame, and the leaf axis elements
8
+ are placed as plain text cells inside the innermost frame.
9
+ """
10
+
11
+ from dataclasses import dataclass, field
12
+
13
+ from .shape import flat_index
14
+
15
# Size of one leaf cell (a single tensor element), in canvas units.
CELL_W = 48
CELL_H = 34
# Horizontal space between neighbouring cells within a row.
CELL_GAP = 8
# Inset between a row frame and the cells it contains.
ROW_PAD = 9
# Vertical space between stacked row frames.
ROW_GAP = 12
# Inset between a 3D block frame and the row frames inside it.
BLOCK_PAD = 12
# Horizontal space between neighbouring 3D blocks.
BLOCK_GAP = 30
# Margin between the canvas edge and the outermost frames.
PADDING = 20
23
+
24
+
25
@dataclass
class Cell:
    """One tensor element positioned on the canvas.

    ``x``/``y`` locate the cell box and ``width``/``height`` give its size.
    ``value`` is what is displayed, ``coord`` is the element's coordinate
    tuple, and ``selected`` records whether an index picks this element.
    """

    x: float
    y: float
    width: float
    height: float
    value: object  # display value for this element
    coord: tuple  # coordinate of the element, one entry per axis
    selected: bool = False
36
+
37
+
38
@dataclass
class Frame:
    """Rectangle that groups the cells belonging to one axis position.

    ``axis`` names the axis the frame stands for, or is ``None`` for a
    neutral container with no axis of its own (the single frame of a 1D
    tensor). ``selected`` is set when at least one selected cell falls
    inside the frame.
    """

    x: float
    y: float
    width: float
    height: float
    axis: object  # axis number, or None for the neutral 1D container
    selected: bool = False
53
+
54
+
55
@dataclass
class Layout:
    """Everything needed to draw one tensor: placed cells, frames and the
    overall canvas size."""

    # Positioned leaf elements, in placement order.
    cells: list = field(default_factory=list)
    # Axis frames, in placement order.
    frames: list = field(default_factory=list)
    # Total canvas extent.
    width: float = 0.0
    height: float = 0.0
63
+
64
+
65
def build_layout(shape, selected=None, value_fn=None):
    """Compute the layout for a tensor.

    ``shape`` is a shape tuple of rank 1, 2 or 3.
    ``selected`` is an iterable of coordinates to mark as selected.
    ``value_fn`` maps a coordinate to its display value. When it is ``None``
    sequential row-major values starting at 0 are used.

    Arrangement:

    * 1D: a single neutral frame (``axis=None``) holding one row of cells.
    * 2D: one axis-0 frame per row, rows stacked vertically.
    * 3D: one axis-0 frame per block, blocks placed side by side, each
      containing one axis-1 frame per row stacked vertically. Each block
      frame is appended before the row frames it contains.

    Raises:
        ValueError: if ``shape`` has a rank other than 1, 2 or 3.
    """
    ndim = len(shape)
    if ndim not in (1, 2, 3):
        raise ValueError(
            f"build_layout supports 1D, 2D, and 3D shapes, got {ndim} dimensions"
        )

    sel = {tuple(coord) for coord in (selected or [])}
    # Coordinate prefixes of the selection, so "does this frame contain a
    # selected cell" is a set lookup instead of a scan of every coordinate.
    sel_axis0 = {coord[:1] for coord in sel}
    sel_axis01 = {coord[:2] for coord in sel}
    cols = shape[-1]

    row_w = cols * CELL_W + (cols - 1) * CELL_GAP + 2 * ROW_PAD
    row_h = CELL_H + 2 * ROW_PAD

    layout = Layout()

    def place_cells(rx, ry, prefix):
        # Append one row of leaf-axis cells inside the row frame at (rx, ry).
        for c in range(cols):
            x = rx + ROW_PAD + c * (CELL_W + CELL_GAP)
            coord = prefix + (c,)
            value = value_fn(coord) if value_fn is not None else flat_index(coord, shape)
            layout.cells.append(
                Cell(x, ry + ROW_PAD, CELL_W, CELL_H, value, coord, coord in sel)
            )

    if ndim == 1:
        rx, ry = PADDING, PADDING
        # A 1D tensor has no frame axis, so its only frame is neutral.
        layout.frames.append(Frame(rx, ry, row_w, row_h, axis=None, selected=bool(sel)))
        place_cells(rx, ry, ())
        layout.width = PADDING * 2 + row_w
        layout.height = PADDING * 2 + row_h
        return layout

    if ndim == 2:
        rows = shape[0]
        for r in range(rows):
            rx = PADDING
            ry = PADDING + r * (row_h + ROW_GAP)
            layout.frames.append(
                Frame(rx, ry, row_w, row_h, axis=0, selected=(r,) in sel_axis0)
            )
            place_cells(rx, ry, (r,))
        layout.width = PADDING * 2 + row_w
        layout.height = PADDING * 2 + rows * row_h + (rows - 1) * ROW_GAP
        return layout

    blocks, rows, _ = shape
    block_w = row_w + 2 * BLOCK_PAD
    block_h = rows * row_h + (rows - 1) * ROW_GAP + 2 * BLOCK_PAD
    for b in range(blocks):
        bx = PADDING + b * (block_w + BLOCK_GAP)
        by = PADDING
        layout.frames.append(
            Frame(bx, by, block_w, block_h, axis=0, selected=(b,) in sel_axis0)
        )
        for r in range(rows):
            rx = bx + BLOCK_PAD
            ry = by + BLOCK_PAD + r * (row_h + ROW_GAP)
            layout.frames.append(
                Frame(rx, ry, row_w, row_h, axis=1, selected=(b, r) in sel_axis01)
            )
            place_cells(rx, ry, (b, r))
    layout.width = PADDING * 2 + blocks * block_w + (blocks - 1) * BLOCK_GAP
    layout.height = PADDING * 2 + block_h
    return layout
@@ -0,0 +1,152 @@
1
+ """Notebook display layer.
2
+
3
+ This module exposes the public functions ``show_shape`` and ``show_index``.
4
+ Each returns a small object that renders as SVG in a notebook and stays
5
+ inspectable in plain Python, so the package is testable outside a notebook.
6
+ """
7
+
8
+ from .indexing import (
9
+ explain_index,
10
+ format_slice,
11
+ result_shape,
12
+ selected_coordinates,
13
+ validate_index,
14
+ )
15
+ from .render_svg import (
16
+ AXIS_FRAME_COLORS,
17
+ NEUTRAL_COLOR,
18
+ SELECT_VALUE_COLOR,
19
+ render_svg,
20
+ )
21
+ from .shape import extract_shape
22
+
23
+
24
class TensorVisual:
    """A rendered tensor visualisation.

    The ``svg`` attribute holds the SVG string. In a notebook the object is
    shown as an image through ``_repr_svg_``. Outside a notebook the same
    string is available for inspection and testing, and ``repr`` gives a
    short text summary instead of the default object address.

    Attributes:
        svg: The rendered SVG document string.
        shape: The tensor shape that was drawn.
        selected: Selected coordinates for an index view, ``None`` otherwise.
        result_shape: Shape after indexing, ``None`` otherwise.
        explanation: Explanation lines (an empty list when none were given).
    """

    def __init__(self, svg, shape, selected=None, result=None, explanation=None):
        self.svg = svg
        self.shape = shape
        self.selected = selected
        self.result_shape = result
        self.explanation = explanation or []

    def _repr_svg_(self):
        """Return the SVG markup for IPython's rich display."""
        return self.svg

    def __str__(self):
        return self.svg

    def __repr__(self):
        # Plain-text summary for terminals and debuggers, where the SVG is
        # not rendered. Never raises, even for an unsized ``selected``.
        fields = [f"shape={self.shape!r}"]
        if self.result_shape is not None:
            fields.append(f"result_shape={self.result_shape!r}")
        if self.selected is not None and hasattr(self.selected, "__len__"):
            fields.append(f"n_selected={len(self.selected)}")
        return f"{type(self).__name__}({', '.join(fields)})"
44
+
45
+
46
+ def _value_fn_for(obj):
47
+ """Build a coordinate to value function for array-like input.
48
+
49
+ A tuple carries no values, so generated values are used. Any object with
50
+ a ``.shape`` attribute is read by coordinate, which covers NumPy arrays
51
+ and any compatible array-like without importing a tensor library.
52
+ """
53
+ if isinstance(obj, tuple):
54
+ return None
55
+ if not hasattr(obj, "shape"):
56
+ return None
57
+
58
+ def value_fn(coord):
59
+ value = obj[coord]
60
+ item = getattr(value, "item", None)
61
+ if callable(item):
62
+ return item()
63
+ return value
64
+
65
+ return value_fn
66
+
67
+
68
def _shape_label_parts(shape):
    """Return ``(text, colour)`` pairs spelling ``Shape (d0, d1, ...)``.

    Every dimension except the last uses its axis frame colour; the last
    (leaf) dimension has no frame and so stays neutral, as does punctuation.
    """
    leaf_axis = len(shape) - 1
    parts = [("Shape (", NEUTRAL_COLOR)]
    for axis, dim in enumerate(shape):
        if axis > 0:
            parts.append((", ", NEUTRAL_COLOR))
        if axis == leaf_axis:
            color = NEUTRAL_COLOR
        else:
            color = AXIS_FRAME_COLORS.get(axis, NEUTRAL_COLOR)
        parts.append((str(dim), color))
    parts.append((")", NEUTRAL_COLOR))
    return parts
86
+
87
+
88
def _index_label_parts(index):
    """Return ``(text, colour)`` pairs spelling ``Index (e0, e1, ...)``.

    Tokens on frame axes take their axis colour and the leaf-axis token takes
    the selected-value colour, mirroring how the tensor itself is drawn.
    Slices are written in ``start:stop:step`` notation.
    """
    leaf_axis = len(index) - 1
    parts = [("Index (", NEUTRAL_COLOR)]
    for axis, entry in enumerate(index):
        if axis > 0:
            parts.append((", ", NEUTRAL_COLOR))
        if isinstance(entry, slice):
            token = format_slice(entry)
        else:
            token = str(entry)
        if axis == leaf_axis:
            color = SELECT_VALUE_COLOR
        else:
            color = AXIS_FRAME_COLORS.get(axis, NEUTRAL_COLOR)
        parts.append((token, color))
    parts.append((")", NEUTRAL_COLOR))
    return parts
108
+
109
+
110
def show_shape(shape):
    """Draw the structure of a tensor shape.

    ``shape`` is either a shape tuple such as ``(2, 2, 2)`` or an array-like
    object exposing ``.shape`` (for example a NumPy array, whose element
    values are then displayed). The result is a :class:`TensorVisual`, which
    a notebook displays through ``_repr_svg_`` and which keeps the SVG string
    available for inspection and testing.
    """
    dims = extract_shape(shape)
    svg = render_svg(
        dims,
        value_fn=_value_fn_for(shape),
        label_parts=_shape_label_parts(dims),
    )
    return TensorVisual(svg, dims)
123
+
124
+
125
def show_index(tensor_or_shape, index):
    """Draw which elements ``index`` selects from a tensor.

    ``tensor_or_shape`` is a shape tuple or an array-like with ``.shape``.
    ``index`` is a tuple of integers and slices, one per axis. Selected
    elements are highlighted, the rest remain visible but de-emphasised, and
    an explanation of the resulting shape is drawn under the tensor. The
    returned :class:`TensorVisual` renders in a notebook through
    ``_repr_svg_`` and also exposes the SVG string plus the selection,
    result shape and explanation.
    """
    dims = extract_shape(tensor_or_shape)
    validate_index(index, dims)
    picked = selected_coordinates(dims, index)
    out_shape = result_shape(dims, index)
    notes = explain_index(dims, index)
    svg = render_svg(
        dims,
        selected=picked,
        value_fn=_value_fn_for(tensor_or_shape),
        label_parts=_index_label_parts(index),
        explanation=notes,
    )
    return TensorVisual(svg, dims, selected=picked, result=out_shape, explanation=notes)
@@ -0,0 +1,156 @@
1
+ """SVG rendering.
2
+
3
+ This module turns a :class:`~rainbow_tensor.layout.Layout` into an SVG
4
+ string. Each axis has its own colour. Axis 0 frames are red and axis 1
5
+ frames are orange. The leaf axis elements are plain text. Selected elements
6
+ are drawn green and, in an index view, only the selected frames keep their
7
+ axis colour while the rest are drawn in a neutral dark tone.
8
+ """
9
+
10
+ from .layout import build_layout
11
+
12
# Frame stroke colour per axis: axis 0 red, axis 1 orange.
AXIS_FRAME_COLORS = {0: "#dc2626", 1: "#f97316"}
# Text colour of selected leaf values (green).
SELECT_VALUE_COLOR = "#16a34a"
# Dark tone for de-emphasised frames and neutral label text.
NEUTRAL_COLOR = "#111827"
# Default colour of cell values.
TEXT_COLOR = "#111827"
# Colour of the plain label and explanation lines.
LABEL_COLOR = "#334155"

# Vertical offsets of the text area below the tensor.
LABEL_HEIGHT = 30
LINE_HEIGHT = 20
# Stroke width of frame rectangles.
FRAME_WIDTH = 3
# Left margin of label/explanation text, also used for width estimation.
TEXT_MARGIN = 20
LABEL_FONT_SIZE = 16
EXPLANATION_FONT_SIZE = 13
# Approximate width of one monospace character relative to the font size.
CHAR_WIDTH_RATIO = 0.62
FONT_FAMILY = "ui-monospace, Menlo, Consolas, monospace"
27
+
28
+
29
# Single-pass replacement table for the characters that are unsafe in SVG
# text content and double-quoted attribute values.
_SVG_ESCAPES = str.maketrans({"&": "&amp;", "<": "&lt;", ">": "&gt;", '"': "&quot;"})


def escape(text):
    """Convert ``text`` to ``str`` and escape ``& < > "`` for use in SVG."""
    return str(text).translate(_SVG_ESCAPES)
38
+
39
+
40
def svg_document(content, width, height):
    """Wrap already-rendered SVG elements in a root ``<svg>`` element.

    Width and height are rounded to whole units and reused for the
    ``viewBox`` so the drawing maps one-to-one onto the canvas.
    """
    w = f"{width:.0f}"
    h = f"{height:.0f}"
    open_tag = (
        '<svg xmlns="http://www.w3.org/2000/svg" '
        f'width="{w}" height="{h}" viewBox="0 0 {w} {h}" '
        f'font-family="{FONT_FAMILY}">'
    )
    return open_tag + f"{content}" + "</svg>"
50
+
51
+
52
def _frame_color(frame, has_selection):
    """Return the stroke colour for ``frame``.

    Without a selection (shape view) every frame shows its axis colour. With
    a selection (index view) unselected frames are drawn neutral.
    """
    if has_selection and not frame.selected:
        return NEUTRAL_COLOR
    return AXIS_FRAME_COLORS.get(frame.axis, NEUTRAL_COLOR)
62
+
63
+
64
def _cell_color(cell, has_selection):
    """Return the fill colour for a cell's value text."""
    highlighted = has_selection and cell.selected
    return SELECT_VALUE_COLOR if highlighted else TEXT_COLOR
69
+
70
+
71
def _text_width(char_count, font_size):
    """Approximate the width of ``char_count`` monospace characters."""
    em_total = char_count * font_size
    return em_total * CHAR_WIDTH_RATIO
74
+
75
+
76
def _text_block_width(label_len, explanation):
    """Return the canvas width needed for the label and explanation lines.

    The widest estimated line is padded by ``TEXT_MARGIN`` on both sides.
    """
    candidates = [_text_width(label_len, LABEL_FONT_SIZE)]
    candidates.extend(
        _text_width(len(line), EXPLANATION_FONT_SIZE) for line in explanation
    )
    return max(candidates) + 2 * TEXT_MARGIN
82
+
83
+
84
def _label_element(x, y, parts):
    """Build a ``<text>`` element whose content is one ``<tspan>`` per part.

    Args:
        x, y: Text anchor position, rounded to whole units.
        parts: Sequence of ``(text, colour)`` pairs; both are escaped.

    The font size comes from ``LABEL_FONT_SIZE`` (previously a hard-coded
    16) so the rendered label stays in sync with the width estimate made by
    ``_text_block_width``.
    """
    spans = "".join(
        f'<tspan fill="{escape(color)}">{escape(text)}</tspan>' for text, color in parts
    )
    return (
        f'<text x="{x:.0f}" y="{y:.0f}" font-size="{LABEL_FONT_SIZE}" '
        f'font-weight="600">{spans}</text>'
    )
92
+
93
+
94
def render_svg(
    shape, selected=None, value_fn=None, label="", label_parts=None, explanation=None
):
    """Render a tensor as an SVG string.

    Args:
        shape: Shape tuple passed to ``build_layout``.
        selected: Iterable of selected coordinates for an index view, or
            ``None`` for a plain shape view. Any non-``None`` value, including
            an empty iterable, uses index-view colouring: only frames that
            contain a selected cell keep their axis colour.
        value_fn: Optional function mapping a coordinate to its display value.
        label: Plain label string, used when ``label_parts`` is not given.
        label_parts: Optional list of ``(text, colour)`` pairs for a coloured
            label.
        explanation: Optional list of text lines drawn below the label.

    Returns:
        A complete SVG document string.
    """
    explanation = explanation or []
    # ``None`` means a shape view. An empty selection (for example from an
    # empty slice such as ``slice(2, 1)``) is still an index view; deriving
    # the mode from "non-empty" would fall back to shape-view colouring and
    # wrongly draw every frame in its axis colour.
    index_view = selected is not None
    selected_list = list(selected) if index_view else []
    layout = build_layout(shape, selected=selected_list, value_fn=value_fn)

    parts = []

    for frame in layout.frames:
        color = _frame_color(frame, index_view)
        parts.append(
            f'<rect x="{frame.x:.0f}" y="{frame.y:.0f}" '
            f'width="{frame.width:.0f}" height="{frame.height:.0f}" '
            f'rx="6" fill="none" stroke="{color}" stroke-width="{FRAME_WIDTH}"/>'
        )

    for cell in layout.cells:
        color = _cell_color(cell, index_view)
        weight = "700" if (index_view and cell.selected) else "400"
        cx = cell.x + cell.width / 2
        cy = cell.y + cell.height / 2
        parts.append(
            f'<text x="{cx:.0f}" y="{cy:.0f}" text-anchor="middle" '
            f'dominant-baseline="central" font-size="15" font-weight="{weight}" '
            f'fill="{color}">{escape(cell.value)}</text>'
        )

    if label_parts:
        label_len = sum(len(text) for text, _ in label_parts)
    else:
        label_len = len(label)
    # Widen the canvas when the label or explanation would otherwise be clipped.
    width = max(layout.width, _text_block_width(label_len, explanation))
    y = layout.height + LABEL_HEIGHT

    if label_parts:
        parts.append(_label_element(TEXT_MARGIN, y, label_parts))
        y += LINE_HEIGHT
    elif label:
        parts.append(
            f'<text x="{TEXT_MARGIN}" y="{y:.0f}" font-size="{LABEL_FONT_SIZE}" '
            f'font-weight="600" fill="{LABEL_COLOR}">{escape(label)}</text>'
        )
        y += LINE_HEIGHT

    for line in explanation:
        y += LINE_HEIGHT - 4
        parts.append(
            f'<text x="{TEXT_MARGIN}" y="{y:.0f}" '
            f'font-size="{EXPLANATION_FONT_SIZE}" fill="{LABEL_COLOR}">'
            f"{escape(line)}</text>"
        )

    # Small bottom margin below the last text baseline.
    height = y + 16
    return svg_document("".join(parts), width, height)
@@ -0,0 +1,92 @@
1
+ """Shape extraction and validation.
2
+
3
+ This module turns user input into a validated ``tuple[int, ...]`` shape and
4
+ generates the display values used when only a shape is provided.
5
+ """
6
+
7
+ import itertools
8
+
9
# Tensor ranks this version can lay out and draw: 1D, 2D and 3D.
SUPPORTED_NDIM = tuple(range(1, 4))
10
+
11
+
12
def extract_shape(obj):
    """Return the validated shape described by ``obj``.

    A tuple is taken literally as a shape. A non-tuple with a ``.shape``
    attribute (such as a NumPy array) contributes that attribute. Anything
    else is handed to ``validate_shape`` as-is. No tensor library is imported.
    """
    use_attribute = not isinstance(obj, tuple) and hasattr(obj, "shape")
    return validate_shape(obj.shape if use_attribute else obj)
26
+
27
+
28
def validate_shape(shape):
    """Validate a shape and return it as a ``tuple[int, ...]``.

    The shape must be a non-empty sequence of positive integers with a
    supported number of dimensions (1D, 2D, or 3D in this version).
    Integer-like dimensions such as NumPy integer scalars are accepted and
    normalised to plain Python ``int``. ``bool`` is rejected even though it
    subclasses ``int``. The result is always a plain ``tuple``, even when the
    input is a tuple subclass.

    Raises:
        TypeError: if ``shape`` is not iterable or a dimension is not an
            integer.
        ValueError: if ``shape`` is empty, a dimension is not positive, or
            the number of dimensions is unsupported.
    """
    # Local import keeps this change self-contained; numbers.Integral covers
    # int and NumPy integer scalars, which NumPy registers as Integral.
    import numbers

    if not isinstance(shape, tuple):
        try:
            shape = tuple(shape)
        except TypeError as exc:
            raise TypeError(
                "shape must be a tuple of integers or an object with a .shape attribute"
            ) from exc

    if len(shape) == 0:
        raise ValueError("shape must have at least one dimension, got an empty shape")

    dims = []
    for dim in shape:
        if isinstance(dim, bool) or not isinstance(dim, numbers.Integral):
            raise TypeError(f"shape dimensions must be integers, got {dim!r}")
        if dim <= 0:
            raise ValueError(f"shape dimensions must be positive, got {dim!r}")
        # Normalise to int so labels print as "(2, 3)" rather than NumPy reprs.
        dims.append(int(dim))

    if len(dims) not in SUPPORTED_NDIM:
        raise ValueError(
            f"only 1D, 2D, and 3D tensors are supported in this version, "
            f"got {len(dims)} dimensions"
        )

    return tuple(dims)
58
+
59
+
60
def coordinates(shape):
    """Return an iterator over every coordinate of ``shape``, row-major.

    The last axis varies fastest, e.g. ``(0, 0), (0, 1), (1, 0), (1, 1)``
    for shape ``(2, 2)``.
    """
    return itertools.product(*map(range, shape))
63
+
64
+
65
def flat_index(coord, shape):
    """Return the row-major (C-order) flat offset of ``coord`` in ``shape``.

    Coordinate and shape entries are paired positionally; the offset is the
    sum of each coordinate times the product of the extents after it.
    """
    total = 0
    stride = 1
    for position, extent in reversed(list(zip(coord, shape))):
        total += position * stride
        stride *= extent
    return total
71
+
72
+
73
def generate_values(shape):
    """Map every coordinate of ``shape`` to a sequential value starting at 0.

    For shape ``(2, 2, 2)`` the coordinates receive 0 through 7 in
    row-major order.
    """
    # coordinates() is row-major, so each coordinate's position in that
    # sequence equals its row-major flat index.
    return {coord: position for position, coord in enumerate(coordinates(shape))}
80
+
81
+
82
def format_shape(shape):
    """Return ``shape`` exactly as Python displays a tuple, e.g. ``(2,)``."""
    return repr(tuple(shape))
85
+
86
+
87
def format_shape_label(shape):
    """Return ``shape`` as an on-screen label with no trailing comma.

    Shape ``(3,)`` becomes ``(3)`` and ``(2, 2)`` becomes ``(2, 2)``.
    """
    return "({})".format(", ".join(map(str, shape)))
@@ -0,0 +1,124 @@
1
+ Metadata-Version: 2.4
2
+ Name: rainbow-tensor
3
+ Version: 0.2.1
4
+ Summary: Visualise tensor shape, indexing, and slicing as SVG in Jupyter notebooks
5
+ Author-email: Zhixiang Feng <contact@zhixiangfeng.com>
6
+ License: MIT
7
+ Project-URL: Homepage, https://github.com/Niox1337/rainbow-tensor
8
+ Classifier: Programming Language :: Python :: 3
9
+ Classifier: License :: OSI Approved :: MIT License
10
+ Classifier: Operating System :: OS Independent
11
+ Classifier: Framework :: IPython
12
+ Classifier: Framework :: Jupyter
13
+ Requires-Python: >=3.9
14
+ Description-Content-Type: text/markdown
15
+ Requires-Dist: numpy>=1.19.0
16
+ Requires-Dist: ipython>=7.0.0
17
+ Provides-Extra: dev
18
+ Requires-Dist: pytest>=7.0; extra == "dev"
19
+ Requires-Dist: ruff>=0.1; extra == "dev"
20
+ Requires-Dist: build>=1.0; extra == "dev"
21
+
22
+ # rainbow-tensor
23
+
24
+ Visualise tensor shape, indexing, and slicing as SVG inside IPython and Jupyter notebooks.
25
+
26
+ rainbow-tensor is made for people who are learning how a tensor is structured and how an indexing expression selects elements. It draws the tensor as nested blocks, rows, and cells, then highlights exactly which elements an index picks out.
27
+
28
+ ## Features
29
+
30
+ - Static SVG output that stays sharp at any zoom level in a notebook
31
+ - Shape visualisation for 1D, 2D, and 3D tensors
32
+ - Index visualisation with highlighted selections and a plain text explanation
33
+ - Works with shape tuples and with array-like objects that expose a `.shape` attribute, such as NumPy arrays
34
+ - No tensor library is imported by the core, so the package stays lightweight
35
+
36
+ ## Colour scheme
37
+
38
+ Each axis has its own colour so the structure and a selection are easy to read.
39
+
40
+ - Axis 0 is the outer frame, drawn red
41
+ - Axis 1 is the inner row frame, drawn orange
42
+ - The leaf axis elements are plain text, and a selected element is drawn green
43
+ - The numbers in the shape label and the tokens in the index label are coloured to match
44
+
45
+ In an index view only the selected frames keep their axis colour. The rest of the tensor is drawn in a neutral dark tone so the selected path stands out.
46
+
47
+ ## Installation
48
+
49
+ Install from source for development.
50
+
51
+ ```bash
52
+ git clone https://github.com/Niox1337/rainbow-tensor.git
53
+ cd rainbow-tensor
54
+ pip install -e .
55
+ ```
56
+
57
+ Install with the development tools (pytest, ruff, build).
58
+
59
+ ```bash
60
+ pip install -e ".[dev]"
61
+ ```
62
+
63
+ ## Usage
64
+
65
+ Run the examples in a Jupyter notebook or an IPython shell so the SVG is displayed.
66
+
67
+ Visualise a shape.
68
+
69
+ ```python
70
+ from rainbow_tensor import show_shape
71
+
72
+ show_shape((2, 2, 2))
73
+ ```
74
+
75
+ Visualise how an index selects elements.
76
+
77
+ ```python
78
+ from rainbow_tensor import show_index
79
+
80
+ show_index((2, 2, 2), (0, slice(None), 1))
81
+ ```
82
+
83
+ For the shape `(2, 2, 2)` the index `(0, slice(None), 1)` selects the values `1` and `3`, the selected coordinates are `(0, 0, 1)` and `(0, 1, 1)`, and the result shape is `(2,)`.
84
+
85
+ Use a real NumPy array to display its actual values.
86
+
87
+ ```python
88
+ import numpy as np
89
+ from rainbow_tensor import show_shape, show_index
90
+
91
+ x = np.arange(8).reshape(2, 2, 2)
92
+ show_shape(x)
93
+ show_index(x, (0, slice(None), 1))
94
+ ```
95
+
96
+ Each function returns a small result object. Its `svg` attribute holds the SVG string, so the package can be inspected and tested outside a notebook.
97
+
98
+ ## Supported
99
+
100
+ - 1D, 2D, and 3D tensors
101
+ - Shape tuples and array-like objects with a `.shape` attribute
102
+ - Integer indexing with non-negative, in-range indices
103
+ - Basic slicing with `slice(None)`, `slice(start, stop)`, and `slice(start, stop, step)`
104
+
105
+ ## Unsupported in this version
106
+
107
+ - Advanced NumPy indexing
108
+ - Boolean masks
109
+ - `None` and `newaxis`
110
+ - Ellipsis
111
+ - 4D or higher tensors
112
+ - Interactive controls and animation
113
+
114
+ ## Development
115
+
116
+ ```bash
117
+ pytest
118
+ ruff check .
119
+ python -m build
120
+ ```
121
+
122
+ ## License
123
+
124
+ MIT
@@ -0,0 +1,10 @@
1
+ rainbow_tensor/__init__.py,sha256=NqgwPMsJP1om_TOSe4cAwikbY2fnW-I0WXW5TXjZJ1g,507
2
+ rainbow_tensor/indexing.py,sha256=Fe3gS2gEMzEGer6SxuTfMBILLT3JERZtulac2pmunrc,4061
3
+ rainbow_tensor/layout.py,sha256=YxqsV7ie2cnXBoxVjcN41B_Cegw0bHZwzmMShrl81yc,3854
4
+ rainbow_tensor/notebook.py,sha256=DdX_XWvl6SleWz985kIJgRzI4U6-qe3oEUro17YjJys,5031
5
+ rainbow_tensor/render_svg.py,sha256=bQxZq8LNKIYdWvlWLZIBWXDvbonSvzpuLN-Z8RJItHc,5164
6
+ rainbow_tensor/shape.py,sha256=0fsuvkxhf_6LwgvjBKzsmwAx1FFsFa-TJsJ2sEGM7y0,2828
7
+ rainbow_tensor-0.2.1.dist-info/METADATA,sha256=VUO6PZ6lv-SGLh-TF3Ifwsr3BxrGZMsrBHOI8mx5CxM,3745
8
+ rainbow_tensor-0.2.1.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
9
+ rainbow_tensor-0.2.1.dist-info/top_level.txt,sha256=86NVkchsTdikpDRTvZFN2aapL_CSfvwIlB4Xj3oeitM,15
10
+ rainbow_tensor-0.2.1.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (82.0.1)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1 @@
1
+ rainbow_tensor