tensor-network-visualization 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,202 @@
1
+ """3D drawing for tensor networks."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import numpy as np
6
+ from mpl_toolkits.mplot3d.axes3d import Axes3D
7
+
8
+ from ..config import PlotConfig
9
+ from ._draw_common import _draw_scale_params
10
+ from .curves import (
11
+ _ellipse_points_3d,
12
+ _group_contractions,
13
+ _quadratic_curve,
14
+ _require_self_endpoints,
15
+ )
16
+ from .graph import _GraphData
17
+ from .layout import AxisDirections, NodePositions, _orthogonal_unit
18
+
19
+
20
def _curved_edge_points_3d(
    *,
    start: np.ndarray,
    end: np.ndarray,
    offset_index: int,
    edge_count: int,
    scale: float = 1.0,
    spread: float = 0.18,
) -> np.ndarray:
    """Sample a quadratic curve between two 3D points, bowed sideways for parallel edges.

    The control point sits on the perpendicular bisector of ``start``-``end``.
    Its signed offset is proportional to how far ``offset_index`` lies from the
    centre of the ``edge_count`` parallel edges. With a single edge, or for the
    middle edge of an odd count, the curve is straight.

    Args:
        start: First endpoint.
        end: Second endpoint.
        offset_index: Position of this edge among its parallel siblings.
        edge_count: Number of parallel edges between the same two points.
        scale: Global drawing scale multiplying the bow.
        spread: Bow per index step, as a fraction of the endpoint distance
            (defaults to the historical 0.18).

    Returns:
        The points produced by ``_quadratic_curve(start, control, end)``.
    """
    midpoint = (start + end) / 2.0
    delta = end - start
    distance = max(float(np.linalg.norm(delta)), 1e-6)
    direction = delta / distance

    # Bend in the plane containing the edge and the world z axis, unless the
    # edge is (almost) vertical, in which case use the y axis instead.
    perpendicular = np.cross(direction, np.array([0.0, 0.0, 1.0], dtype=float))
    if np.linalg.norm(perpendicular) < 1e-6:
        perpendicular = np.cross(direction, np.array([0.0, 1.0, 0.0], dtype=float))
    norm = float(np.linalg.norm(perpendicular))
    if norm < 1e-6:
        # start and end coincide, so there is no direction to be perpendicular
        # to. Pick a fixed axis rather than computing 0/0, which would produce NaNs.
        perpendicular = np.array([1.0, 0.0, 0.0], dtype=float)
    else:
        perpendicular = perpendicular / norm

    offset = (offset_index - (edge_count - 1) / 2.0) * spread * scale * distance
    control = midpoint + perpendicular * offset
    return _quadratic_curve(start, control, end)
40
+
41
+
42
def _draw_3d(
    *,
    ax: Axes3D,
    graph: _GraphData,
    positions: NodePositions,
    directions: AxisDirections,
    show_tensor_labels: bool,
    show_index_labels: bool,
    config: PlotConfig,
    scale: float = 1.0,
) -> None:
    """Render *graph* onto a 3D Matplotlib axes, replacing its previous content.

    Edges are drawn by kind:
    - dangling edges as straight stubs along their axis direction;
    - self edges as ellipses beside their node;
    - contractions as (possibly curved) lines between node positions.

    Nodes are then drawn as one scatter, optional node names are added, and
    the axes are framed via ``_style_3d_axes``.

    Args:
        ax: Target 3D axes. It is cleared with ``ax.cla()`` first.
        graph: Nodes and edges to draw.
        positions: Node id -> 3D position (indexed ``[0]``, ``[1]``, ``[2]``).
        directions: ``(node_id, axis_index)`` -> direction vector used for
            dangling stubs and self-loop orientation.
        show_tensor_labels: Draw each node's name at its position.
        show_index_labels: Draw edge labels for edges that have one.
        config: Colours and base sizes.
        scale: Drawing scale forwarded to ``_draw_scale_params`` and used to
            spread parallel contraction curves apart.
    """
    ax.cla()
    pair_groups = _group_contractions(graph)
    p = _draw_scale_params(config, scale, is_3d=True)

    for edge in graph.edges:
        if edge.kind == "dangling":
            _draw_dangling_edge_3d(
                ax=ax,
                edge=edge,
                positions=positions,
                directions=directions,
                config=config,
                p=p,
                show_label=show_index_labels,
            )
        elif edge.kind == "self":
            _draw_self_edge_3d(
                ax=ax,
                edge=edge,
                positions=positions,
                directions=directions,
                config=config,
                p=p,
                show_label=show_index_labels,
            )
        else:
            _draw_contraction_edge_3d(
                ax=ax,
                edge=edge,
                positions=positions,
                pair_groups=pair_groups,
                config=config,
                p=p,
                scale=scale,
                show_label=show_index_labels,
            )

    coords = np.stack(list(positions.values()))
    ax.scatter(
        coords[:, 0],
        coords[:, 1],
        coords[:, 2],
        s=p.scatter_s,
        # `color=` rather than `c=`: a single RGB(A) tuple passed as `c` is
        # colour-mapped as data when its length equals the number of nodes.
        color=config.node_color,
        edgecolors=config.node_edge_color,
        linewidths=p.lw,
        depthshade=False,
    )

    if show_tensor_labels:
        for node_id, node in graph.nodes.items():
            x, y, z = positions[node_id]
            ax.text(
                x, y, z,
                node.name,
                color=config.tensor_label_color,
                fontsize=p.font_node,
                ha="center",
                va="center",
                zorder=5,
            )

    _style_3d_axes(ax, coords)


def _edge_label_3d(ax: Axes3D, position: np.ndarray, text: str, *, color, fontsize) -> None:
    """Draw an edge label horizontally centred and bottom-aligned at *position*."""
    ax.text(
        position[0],
        position[1],
        position[2],
        text,
        color=color,
        fontsize=fontsize,
        zorder=5,
        ha="center",
        va="bottom",
    )


def _bond_label_side_3d(delta: np.ndarray) -> np.ndarray:
    """Return a unit vector perpendicular to *delta*, flipped if its z component is negative.

    Mirrors the perpendicular used to bow contraction curves. When *delta* is
    (near) zero, i.e. both nodes share a position, it returns +z instead of
    normalising a zero vector into NaNs.
    """
    distance = max(float(np.linalg.norm(delta)), 1e-6)
    direction = delta / distance
    side = np.cross(direction, np.array([0.0, 0.0, 1.0], dtype=float))
    if np.linalg.norm(side) < 1e-6:
        side = np.cross(direction, np.array([0.0, 1.0, 0.0], dtype=float))
    norm = float(np.linalg.norm(side))
    if norm < 1e-6:
        return np.array([0.0, 0.0, 1.0], dtype=float)
    side = side / norm
    return -side if side[2] < 0 else side


def _draw_dangling_edge_3d(*, ax, edge, positions, directions, config, p, show_label) -> None:
    """Draw an open leg as a straight stub along its axis direction.

    The stub starts ``p.r`` from the node position and is ``p.stub`` long, both
    measured in units of the direction vector.
    """
    endpoint = edge.endpoints[0]
    direction = directions[(endpoint.node_id, endpoint.axis_index)]
    start = positions[endpoint.node_id] + direction * p.r
    end = start + direction * p.stub
    ax.plot(
        [start[0], end[0]],
        [start[1], end[1]],
        [start[2], end[2]],
        color=config.dangling_edge_color,
        linewidth=p.lw,
        zorder=2,
    )
    if show_label and edge.label:
        _edge_label_3d(
            ax,
            end + direction * p.label_offset,
            edge.label,
            color=config.label_color,
            fontsize=p.font_dangling,
        )


def _draw_self_edge_3d(*, ax, edge, positions, directions, config, p, show_label) -> None:
    """Draw a self-contraction as an ellipse displaced from its node."""
    endpoint_a, endpoint_b = _require_self_endpoints(edge)
    direction_a = directions[(endpoint_a.node_id, endpoint_a.axis_index)]
    direction_b = directions[(endpoint_b.node_id, endpoint_b.axis_index)]
    # Push the loop towards the side both axes point to. If they cancel out,
    # use +x so the normalisation below is well defined.
    orientation = direction_a + direction_b
    if np.linalg.norm(orientation) < 1e-6:
        orientation = np.array([1.0, 0.0, 0.0], dtype=float)
    orientation = orientation / np.linalg.norm(orientation)
    normal = _orthogonal_unit(orientation)
    binormal = np.cross(orientation, normal)
    binormal = binormal / np.linalg.norm(binormal)
    center_pt = positions[endpoint_a.node_id] + orientation * (p.r + p.loop_r)
    curve = _ellipse_points_3d(
        center_pt, normal, binormal, width=p.ellipse_w, height=p.ellipse_h
    )
    ax.plot(
        curve[:, 0],
        curve[:, 1],
        curve[:, 2],
        color=config.bond_edge_color,
        linewidth=p.lw,
        zorder=2,
    )
    if show_label and edge.label:
        _edge_label_3d(
            ax,
            center_pt + binormal * p.ellipse_w,
            edge.label,
            color=config.label_color,
            fontsize=p.font_dangling,
        )


def _draw_contraction_edge_3d(
    *, ax, edge, positions, pair_groups, config, p, scale, show_label
) -> None:
    """Draw a bond between two nodes, bowing it apart from parallel bonds."""
    start = positions[edge.node_ids[0]]
    end = positions[edge.node_ids[1]]
    group = pair_groups[tuple(sorted(edge.node_ids))]
    curve = _curved_edge_points_3d(
        start=start,
        end=end,
        offset_index=group.index(edge),
        edge_count=len(group),
        scale=scale,
    )
    ax.plot(
        curve[:, 0],
        curve[:, 1],
        curve[:, 2],
        color=config.bond_edge_color,
        linewidth=p.lw,
        zorder=1,
    )
    if show_label and edge.label:
        label_pos = curve[len(curve) // 2] + _bond_label_side_3d(end - start) * p.label_offset
        _edge_label_3d(
            ax,
            label_pos,
            edge.label,
            color=config.label_color,
            fontsize=p.font_bond,
        )
192
+
193
+
194
def _style_3d_axes(ax: Axes3D, coords: np.ndarray) -> None:
    """Fit the axes limits around the node cloud and hide the axis decorations.

    Each axis is centred on the mean of ``coords`` along that axis and extends
    0.9 times the (padded) data range on either side. Ranges smaller than 1.0
    are raised to 1.0 so that flat or single-node layouts still get a box. The
    padded ranges also set the box aspect.
    """
    extent = np.maximum(np.ptp(coords, axis=0), 1.0)
    middle = coords.mean(axis=0)
    half_width = extent * 0.9
    lower = middle - half_width
    upper = middle + half_width

    ax.set_xlim(lower[0], upper[0])
    ax.set_ylim(lower[1], upper[1])
    ax.set_zlim(lower[2], upper[2])
    ax.set_box_aspect(extent)
    ax.set_axis_off()
@@ -0,0 +1,190 @@
1
+ """Graph data structures and construction from tensor networks."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+ from typing import Any, Literal
7
+
8
# Categories of edges the drawing code distinguishes.
EdgeKind = Literal["contraction", "dangling", "self"]


@dataclass(frozen=True)
class _EdgeEndpoint:
    """One axis slot of one node that an edge is attached to."""

    # id() of the owning node object; also its key in _GraphData.nodes.
    node_id: int
    # Position of the axis in the node's edges / axes_names sequences.
    axis_index: int
    # Name of that axis, if any.
    axis_name: str | None


@dataclass(frozen=True)
class _NodeData:
    """Display information for a single tensor node."""

    name: str
    axes_names: tuple[str, ...]
    # Number of axis slots on the node, counting unconnected (None) slots too.
    degree: int


@dataclass(frozen=True)
class _EdgeData:
    """A drawable edge, classified by how many visible nodes it touches."""

    # Edge name, or None when the edge has no (non-empty) name.
    name: str | None
    kind: EdgeKind
    # One node id for "self"/"dangling" edges, two for "contraction" edges.
    node_ids: tuple[int, ...]
    # Axis slots the edge is attached to (at most two).
    endpoints: tuple[_EdgeEndpoint, ...]
    # Pre-computed label text, or None when there is nothing to show.
    label: str | None


@dataclass(frozen=True)
class _GraphData:
    """Nodes (keyed by node id) and edges ready for plotting."""

    nodes: dict[int, _NodeData]
    edges: tuple[_EdgeData, ...]
38
+
39
+
40
def _get_network_nodes(network: Any) -> list[Any]:
    """Return the distinct, non-``None`` nodes of *network*, in first-seen order.

    Nodes are read from ``network.nodes`` if that attribute exists, otherwise
    from ``network.leaf_nodes``. If the attribute is a mapping, its values are
    used as the nodes; any other iterable is used as-is. Duplicates are detected
    by object identity (``id``), not equality.

    Raises:
        TypeError: if *network* has neither attribute, or the attribute is not
            iterable.
    """
    from collections.abc import Mapping

    if hasattr(network, "nodes"):
        raw_nodes = network.nodes
    elif hasattr(network, "leaf_nodes"):
        raw_nodes = network.leaf_nodes
    else:
        raise TypeError(
            "Tensor network must expose either a 'nodes' attribute or a 'leaf_nodes' attribute."
        )

    # Check for any Mapping, not just dict. Iterating a non-dict mapping
    # (e.g. MappingProxyType) would yield its keys, typically node names,
    # instead of the node objects.
    iterable = raw_nodes.values() if isinstance(raw_nodes, Mapping) else raw_nodes

    try:
        items = list(iterable)
    except TypeError as exc:
        raise TypeError("Tensor network nodes must be iterable.") from exc

    unique_nodes: list[Any] = []
    seen: set[int] = set()
    for node in items:
        if node is None or id(node) in seen:
            continue
        seen.add(id(node))
        unique_nodes.append(node)
    return unique_nodes
68
+
69
+
70
def _iterable_attr(obj: Any, attr_name: str, object_name: str) -> list[Any]:
    """Return ``obj.<attr_name>`` converted to a list.

    Mappings contribute their values (not their keys). Any other iterable is
    converted with ``list()``.

    Raises:
        TypeError: if the attribute is missing or is not iterable.
    """
    from collections.abc import Mapping

    value = _require_attr(obj, attr_name, object_name)
    # Any Mapping, not only dict: iterating e.g. a MappingProxyType directly
    # would yield its keys.
    if isinstance(value, Mapping):
        return list(value.values())
    try:
        return list(value)
    except TypeError as exc:
        msg = f"{object_name.capitalize()} attribute '{attr_name}' must be iterable."
        raise TypeError(msg) from exc


def _require_attr(obj: Any, attr_name: str, object_name: str) -> Any:
    """Return ``getattr(obj, attr_name)``, raising TypeError if it is absent.

    The attribute is read exactly once. A ``hasattr`` + ``getattr`` pair
    would evaluate computed properties twice.

    Raises:
        TypeError: if reading the attribute raises AttributeError. The message
            names the object kind via ``object_name`` (capitalised).
    """
    try:
        return getattr(obj, attr_name)
    except AttributeError:
        raise TypeError(
            f"{object_name.capitalize()} is missing required attribute '{attr_name}'."
        ) from None
85
+
86
+
87
def _stringify(value: Any) -> str:
    """Convert *value* with ``str``, mapping ``None`` to the empty string."""
    if value is None:
        return ""
    return str(value)
89
+
90
+
91
def _optional_string(value: Any) -> str | None:
    """Convert *value* with ``str``; return ``None`` for ``None`` or an empty result."""
    text = None if value is None else str(value)
    return text if text else None
96
+
97
+
98
def _build_edge_label(
    kind: EdgeKind,
    endpoints: tuple[_EdgeEndpoint, ...],
    edge_name: str | None,
) -> str | None:
    """Choose the display label for an edge.

    Only truthy endpoint axis names are considered.
    - Dangling edges use their first axis name, falling back to *edge_name*.
    - Other edges use ``"<first><-><second>"`` when at least two axis names
      are available, and *edge_name* otherwise.
    """
    named_axes = (endpoint.axis_name for endpoint in endpoints if endpoint.axis_name)
    first = next(named_axes, None)
    if kind == "dangling":
        return edge_name if first is None else first
    second = next(named_axes, None)
    if second is not None:
        return f"{first}<->{second}"
    return edge_name
109
+
110
+
111
def _build_graph(network: Any) -> _GraphData:
    """Convert a duck-typed tensor network into the plotting graph model.

    Nodes come from ``_get_network_nodes``. Each node must expose ``name``,
    ``edges`` and ``axes_names``, with exactly one axis name per edge slot.
    ``None`` edge slots are skipped. Each distinct edge object referenced by a
    node becomes at most one ``_EdgeData``, classified by its ``node1`` and
    ``node2`` attributes:

    - both ends on the same visible node -> ``"self"``;
    - ends on two different visible nodes -> ``"contraction"``;
    - exactly one end on a visible node -> ``"dangling"``;
    - neither end on a visible node -> the edge is dropped.

    Raises:
        ValueError: if the network exposes no nodes.
        TypeError: if a node or edge lacks a required attribute, a node's
            ``edges`` and ``axes_names`` lengths differ, or an edge is attached
            to more than two axis slots.
    """
    node_refs = _get_network_nodes(network)
    if not node_refs:
        raise ValueError("The tensor network does not expose any nodes to visualize.")

    nodes, edge_refs, edge_endpoints = _collect_nodes_and_endpoints(node_refs)

    graph_edges: list[_EdgeData] = []
    for edge_id, edge in edge_refs.items():
        # Every recorded edge id has at least one endpoint by construction.
        edge_data = _edge_data_from_ref(edge, tuple(edge_endpoints[edge_id]), nodes)
        if edge_data is not None:
            graph_edges.append(edge_data)

    return _GraphData(nodes=nodes, edges=tuple(graph_edges))


def _collect_nodes_and_endpoints(
    node_refs: list[Any],
) -> tuple[dict[int, _NodeData], dict[int, Any], dict[int, list[_EdgeEndpoint]]]:
    """Record each node, and for each distinct edge object, every axis slot it occupies.

    All three returned dicts are keyed by ``id()`` (node id, edge id, edge id).
    """
    nodes: dict[int, _NodeData] = {}
    edge_refs: dict[int, Any] = {}
    edge_endpoints: dict[int, list[_EdgeEndpoint]] = {}

    for node in node_refs:
        node_name = _stringify(_require_attr(node, "name", "node"))
        node_edges = tuple(_iterable_attr(node, "edges", "node"))
        axes_names = tuple(
            _stringify(item) for item in _iterable_attr(node, "axes_names", "node")
        )
        if len(node_edges) != len(axes_names):
            raise TypeError(
                f"Node {node_name!r} has {len(node_edges)} edges but {len(axes_names)} axes_names."
            )

        node_id = id(node)
        nodes[node_id] = _NodeData(
            name=node_name,
            axes_names=axes_names,
            degree=len(node_edges),
        )

        for axis_index, (edge, axis_name) in enumerate(zip(node_edges, axes_names)):
            if edge is None:
                continue  # unconnected axis slot
            edge_id = id(edge)
            edge_refs[edge_id] = edge
            edge_endpoints.setdefault(edge_id, []).append(
                _EdgeEndpoint(node_id=node_id, axis_index=axis_index, axis_name=axis_name)
            )

    return nodes, edge_refs, edge_endpoints


def _edge_data_from_ref(
    edge: Any,
    endpoints: tuple[_EdgeEndpoint, ...],
    nodes: dict[int, _NodeData],
) -> _EdgeData | None:
    """Classify one edge object, or return ``None`` if neither end is a visible node."""
    edge_name = _optional_string(_require_attr(edge, "name", "edge"))
    node1 = _require_attr(edge, "node1", "edge")
    node2 = _require_attr(edge, "node2", "edge")
    if len(endpoints) > 2:
        raise TypeError("Edges with more than two endpoints are not supported.")

    # Only ends attached to nodes that are part of the visualised set count.
    node1_id = id(node1) if node1 is not None and id(node1) in nodes else None
    node2_id = id(node2) if node2 is not None and id(node2) in nodes else None

    kind: EdgeKind
    node_ids: tuple[int, ...]
    if node1_id is not None and node2_id is not None:
        if node1_id == node2_id:
            kind, node_ids = "self", (node1_id,)
        else:
            kind, node_ids = "contraction", (node1_id, node2_id)
    elif node1_id is not None:
        kind, node_ids = "dangling", (node1_id,)
    elif node2_id is not None:
        kind, node_ids = "dangling", (node2_id,)
    else:
        return None

    return _EdgeData(
        name=edge_name,
        kind=kind,
        node_ids=node_ids,
        endpoints=endpoints,
        label=_build_edge_label(kind=kind, endpoints=endpoints, edge_name=edge_name),
    )