tensor-network-visualization 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tensor_network_visualization-1.0.0.dist-info/METADATA +104 -0
- tensor_network_visualization-1.0.0.dist-info/RECORD +17 -0
- tensor_network_visualization-1.0.0.dist-info/WHEEL +5 -0
- tensor_network_visualization-1.0.0.dist-info/licenses/LICENSE +21 -0
- tensor_network_visualization-1.0.0.dist-info/top_level.txt +1 -0
- tensor_network_viz/__init__.py +10 -0
- tensor_network_viz/config.py +60 -0
- tensor_network_viz/py.typed +0 -0
- tensor_network_viz/tensorkrowch/__init__.py +9 -0
- tensor_network_viz/tensorkrowch/_draw_common.py +62 -0
- tensor_network_viz/tensorkrowch/curves.py +98 -0
- tensor_network_viz/tensorkrowch/draw_2d.py +178 -0
- tensor_network_viz/tensorkrowch/draw_3d.py +202 -0
- tensor_network_viz/tensorkrowch/graph.py +190 -0
- tensor_network_viz/tensorkrowch/layout.py +352 -0
- tensor_network_viz/tensorkrowch/renderer.py +181 -0
- tensor_network_viz/viewer.py +57 -0
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
"""3D drawing for tensor networks."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from mpl_toolkits.mplot3d.axes3d import Axes3D
|
|
7
|
+
|
|
8
|
+
from ..config import PlotConfig
|
|
9
|
+
from ._draw_common import _draw_scale_params
|
|
10
|
+
from .curves import (
|
|
11
|
+
_ellipse_points_3d,
|
|
12
|
+
_group_contractions,
|
|
13
|
+
_quadratic_curve,
|
|
14
|
+
_require_self_endpoints,
|
|
15
|
+
)
|
|
16
|
+
from .graph import _GraphData
|
|
17
|
+
from .layout import AxisDirections, NodePositions, _orthogonal_unit
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _bond_perpendicular_3d(start: np.ndarray, end: np.ndarray) -> np.ndarray:
    """Return a unit vector perpendicular to the segment from ``start`` to ``end``.

    The vector is ``direction x e_z``. If that is (near) zero because the segment is
    parallel to the z axis, ``direction x e_y`` is used instead. If the result is
    still exactly zero, which only happens when ``start`` and ``end`` coincide, the
    x axis is returned. Callers therefore never receive the NaN vector that
    normalising a zero vector would produce.
    """
    delta = np.asarray(end, dtype=float) - np.asarray(start, dtype=float)
    distance = max(float(np.linalg.norm(delta)), 1e-6)
    direction = delta / distance
    perpendicular = np.cross(direction, np.array([0.0, 0.0, 1.0], dtype=float))
    if np.linalg.norm(perpendicular) < 1e-6:
        perpendicular = np.cross(direction, np.array([0.0, 1.0, 0.0], dtype=float))
    norm = float(np.linalg.norm(perpendicular))
    if norm == 0.0:
        # Coincident endpoints: every direction is "perpendicular", so pick a fixed one.
        return np.array([1.0, 0.0, 0.0], dtype=float)
    return perpendicular / norm


def _curved_edge_points_3d(
    *,
    start: np.ndarray,
    end: np.ndarray,
    offset_index: int,
    edge_count: int,
    scale: float = 1.0,
) -> np.ndarray:
    """Sample the curve used to draw one of ``edge_count`` parallel bonds in 3D.

    The control point of a quadratic curve is pushed away from the chord midpoint
    along ``_bond_perpendicular_3d(start, end)``. The signed amount is
    ``(offset_index - (edge_count - 1) / 2) * 0.18 * scale * distance``, so parallel
    bonds fan out symmetrically and a lone bond (``edge_count == 1``) is straight.

    Args:
        start: Position of the first node.
        end: Position of the second node.
        offset_index: Index of this bond among the parallel bonds of the pair.
        edge_count: Number of parallel bonds between the pair.
        scale: Multiplier applied to the bend amount.

    Returns:
        Whatever ``_quadratic_curve(start, control, end)`` returns. The caller
        indexes it as an array of points (``curve[:, k]``).
    """
    midpoint = (start + end) / 2.0
    distance = max(float(np.linalg.norm(end - start)), 1e-6)
    perpendicular = _bond_perpendicular_3d(start, end)
    offset = (offset_index - (edge_count - 1) / 2.0) * 0.18 * scale * distance
    control = midpoint + perpendicular * offset
    return _quadratic_curve(start, control, end)


def _draw_3d(
    *,
    ax: Axes3D,
    graph: _GraphData,
    positions: NodePositions,
    directions: AxisDirections,
    show_tensor_labels: bool,
    show_index_labels: bool,
    config: PlotConfig,
    scale: float = 1.0,
) -> None:
    """Render ``graph`` onto a cleared 3D axes.

    Edges are drawn first, by kind:

    * dangling edges as a straight stub leaving the node along its axis direction;
    * self edges as an ellipse placed beside the node;
    * contraction edges as a (possibly bent) curve between the two nodes.

    Nodes are then drawn as a scatter, optionally followed by their names, and the
    axes are framed via ``_style_3d_axes``.

    Args:
        ax: Target 3D axes; it is cleared with ``ax.cla()`` first.
        graph: Nodes and edges to draw.
        positions: Node position per node id.
        directions: Unit direction per ``(node_id, axis_index)``, used for dangling
            stubs and self-loop orientation.
        show_tensor_labels: Draw each node's name at its position.
        show_index_labels: Draw each edge's label, if it has one.
        config: Colours used for nodes, edges and labels.
        scale: Forwarded to ``_draw_scale_params`` and to the bond bend amount.
    """
    ax.cla()
    pair_groups = _group_contractions(graph)
    p = _draw_scale_params(config, scale, is_3d=True)

    for edge in graph.edges:
        if edge.kind == "dangling":
            endpoint = edge.endpoints[0]
            direction = directions[(endpoint.node_id, endpoint.axis_index)]
            # The stub starts on the node marker's rim and extends outward by p.stub.
            start = positions[endpoint.node_id] + direction * p.r
            end = start + direction * p.stub
            ax.plot(
                [start[0], end[0]],
                [start[1], end[1]],
                [start[2], end[2]],
                color=config.dangling_edge_color,
                linewidth=p.lw,
                zorder=2,
            )
            if show_index_labels and edge.label:
                label_pos = end + direction * p.label_offset
                ax.text(
                    label_pos[0],
                    label_pos[1],
                    label_pos[2],
                    edge.label,
                    color=config.label_color,
                    fontsize=p.font_dangling,
                    zorder=5,
                    ha="center",
                    va="bottom",
                )
        elif edge.kind == "self":
            endpoint_a, endpoint_b = _require_self_endpoints(edge)
            direction_a = directions[(endpoint_a.node_id, endpoint_a.axis_index)]
            direction_b = directions[(endpoint_b.node_id, endpoint_b.axis_index)]
            # Place the loop along the mean of the two axis directions; if those
            # cancel out, fall back to +x so the normalisation below is defined.
            orientation = direction_a + direction_b
            if np.linalg.norm(orientation) < 1e-6:
                orientation = np.array([1.0, 0.0, 0.0], dtype=float)
            orientation = orientation / np.linalg.norm(orientation)
            normal = _orthogonal_unit(orientation)
            binormal = np.cross(orientation, normal)
            binormal = binormal / np.linalg.norm(binormal)
            center_pt = positions[endpoint_a.node_id] + orientation * (p.r + p.loop_r)
            curve = _ellipse_points_3d(
                center_pt, normal, binormal, width=p.ellipse_w, height=p.ellipse_h
            )
            ax.plot(
                curve[:, 0],
                curve[:, 1],
                curve[:, 2],
                color=config.bond_edge_color,
                linewidth=p.lw,
                zorder=2,
            )
            if show_index_labels and edge.label:
                label_pos = center_pt + binormal * p.ellipse_w
                ax.text(
                    label_pos[0],
                    label_pos[1],
                    label_pos[2],
                    edge.label,
                    color=config.label_color,
                    fontsize=p.font_dangling,
                    zorder=5,
                    ha="center",
                    va="bottom",
                )
        else:
            key = tuple(sorted(edge.node_ids))
            group = pair_groups[key]
            offset_index = group.index(edge)
            start = positions[edge.node_ids[0]]
            end = positions[edge.node_ids[1]]
            curve = _curved_edge_points_3d(
                start=start,
                end=end,
                offset_index=offset_index,
                edge_count=len(group),
                scale=scale,
            )
            ax.plot(
                curve[:, 0],
                curve[:, 1],
                curve[:, 2],
                color=config.bond_edge_color,
                linewidth=p.lw,
                zorder=1,
            )
            if show_index_labels and edge.label:
                label_anchor = curve[len(curve) // 2]
                # Same perpendicular as the bend. It is flipped so its z component is
                # non-negative, keeping labels on a consistent side whatever the edge
                # orientation.
                perpendicular = _bond_perpendicular_3d(start, end)
                if perpendicular[2] < 0:
                    perpendicular = -perpendicular
                label_pos = label_anchor + perpendicular * p.label_offset
                ax.text(
                    label_pos[0],
                    label_pos[1],
                    label_pos[2],
                    edge.label,
                    color=config.label_color,
                    fontsize=p.font_bond,
                    zorder=5,
                    ha="center",
                    va="bottom",
                )

    coords = np.stack(list(positions.values()))
    ax.scatter(
        coords[:, 0],
        coords[:, 1],
        coords[:, 2],
        s=p.scatter_s,
        c=config.node_color,
        edgecolors=config.node_edge_color,
        linewidths=p.lw,
        depthshade=False,
    )

    if show_tensor_labels:
        for node_id, node in graph.nodes.items():
            x, y, z = positions[node_id]
            ax.text(
                x, y, z,
                node.name,
                color=config.tensor_label_color,
                fontsize=p.font_node,
                ha="center",
                va="center",
                zorder=5,
            )

    _style_3d_axes(ax, coords)
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def _style_3d_axes(ax: Axes3D, coords: np.ndarray) -> None:
    """Frame the node coordinates on ``ax`` and hide the 3D axis decorations.

    Each axis is limited to the bounding-box midpoint plus or minus ``0.9 * span``,
    where ``span`` is the coordinate extent along that axis, clamped to at least 1.0
    so flat or single-node layouts still get a usable box. Because the half-width is
    0.9 of the full extent, every node lies inside the limits with ``0.4 * span`` of
    padding on each side for stubs and labels. The box aspect is set to ``span``.

    Args:
        ax: The 3D axes to configure.
        coords: Node coordinates, one row per node and columns x, y, z.
    """
    lower = coords.min(axis=0)
    upper = coords.max(axis=0)
    span = np.maximum(upper - lower, 1.0)
    # Centre on the bounding box rather than the mean. A mean-centred window of
    # half-width 0.9*span clips outlying nodes in skewed layouts.
    center = (lower + upper) / 2.0
    half_width = span * 0.9
    ax.set_xlim(center[0] - half_width[0], center[0] + half_width[0])
    ax.set_ylim(center[1] - half_width[1], center[1] + half_width[1])
    ax.set_zlim(center[2] - half_width[2], center[2] + half_width[2])
    ax.set_box_aspect(span)
    ax.set_axis_off()
|
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
"""Graph data structures and construction from tensor networks."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Any, Literal
|
|
7
|
+
|
|
8
|
+
# Edge categories assigned by ``_build_graph``: "contraction" joins two distinct
# visible nodes, "dangling" has exactly one visible node attached, and "self"
# connects two axes of the same node.
EdgeKind = Literal["contraction", "dangling", "self"]
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass(frozen=True)
class _EdgeEndpoint:
    """One attachment point of an edge: a specific axis of a specific node."""

    # ``id()`` of the owning node object; also the key into ``_GraphData.nodes``.
    node_id: int
    # Position of the axis within the node's ``edges`` / ``axes_names`` sequences.
    axis_index: int
    # Stringified axis name as recorded by ``_build_graph`` (may be the empty string).
    axis_name: str | None
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass(frozen=True)
class _NodeData:
    """Display data extracted from one tensor-network node."""

    # Stringified node name (``None`` becomes "").
    name: str
    # Stringified axis names, in axis order.
    axes_names: tuple[str, ...]
    # Number of entries in the node's ``edges`` sequence, including ``None`` slots.
    degree: int
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass(frozen=True)
class _EdgeData:
    """A classified edge ready for drawing."""

    # Edge name, or ``None`` when the edge's name is missing or empty.
    name: str | None
    # "contraction", "dangling" or "self" (see ``EdgeKind``).
    kind: EdgeKind
    # Visible node ids: two for a contraction, one for dangling and self edges.
    node_ids: tuple[int, ...]
    # Axis attachments found on visible nodes (at most two).
    endpoints: tuple[_EdgeEndpoint, ...]
    # Text to display for the edge, or ``None`` when no label could be built.
    label: str | None
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass(frozen=True)
class _GraphData:
    """Nodes and edges of a tensor network in a drawing-friendly form.

    Note: although frozen, ``nodes`` is a plain dict. The generated ``__hash__``
    would therefore raise ``TypeError`` if an instance were hashed.
    """

    # Node data keyed by ``id()`` of the original node object.
    nodes: dict[int, _NodeData]
    # Edges in the order they were first encountered while scanning the nodes.
    edges: tuple[_EdgeData, ...]
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _get_network_nodes(network: Any) -> list[Any]:
    """Collect the distinct node objects exposed by ``network``.

    ``network.nodes`` takes precedence over ``network.leaf_nodes``. Either may be a
    dict, whose values are used, or any other iterable. ``None`` entries are
    skipped, and nodes are de-duplicated by identity in first-seen order.

    Raises:
        TypeError: If neither attribute exists or the collection is not iterable.
    """
    for attr_name in ("nodes", "leaf_nodes"):
        if hasattr(network, attr_name):
            raw_nodes = getattr(network, attr_name)
            break
    else:
        raise TypeError(
            "Tensor network must expose either a 'nodes' attribute or a 'leaf_nodes' attribute."
        )

    source = raw_nodes.values() if isinstance(raw_nodes, dict) else raw_nodes
    try:
        candidates = list(source)
    except TypeError as exc:
        raise TypeError("Tensor network nodes must be iterable.") from exc

    # Key by id() so distinct-but-equal nodes both survive. The dict keeps the
    # position of each object's first occurrence.
    by_identity = {id(candidate): candidate for candidate in candidates if candidate is not None}
    return list(by_identity.values())
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _iterable_attr(obj: Any, attr_name: str, object_name: str) -> list[Any]:
    """Fetch a required attribute of ``obj`` and return its items as a list.

    Dict values are used for dict attributes; any other iterable is listed as-is.

    Raises:
        TypeError: If the attribute is missing, or is neither a dict nor iterable.
    """
    raw = _require_attr(obj, attr_name, object_name)
    source = raw.values() if isinstance(raw, dict) else raw
    try:
        return list(source)
    except TypeError as exc:
        raise TypeError(
            f"{object_name.capitalize()} attribute '{attr_name}' must be iterable."
        ) from exc
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def _require_attr(obj: Any, attr_name: str, object_name: str) -> Any:
    """Return ``getattr(obj, attr_name)``, raising ``TypeError`` if it is absent.

    ``object_name`` (e.g. "node", "edge") is capitalised into the error message.
    """
    if hasattr(obj, attr_name):
        return getattr(obj, attr_name)
    raise TypeError(f"{object_name.capitalize()} is missing required attribute '{attr_name}'.")
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _stringify(value: Any) -> str:
    """Convert ``value`` to ``str``, mapping ``None`` to the empty string."""
    if value is None:
        return ""
    return str(value)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def _optional_string(value: Any) -> str | None:
    """Convert ``value`` to ``str``, collapsing ``None`` and empty text to ``None``."""
    text = None if value is None else str(value)
    return text if text else None
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def _build_edge_label(
    kind: EdgeKind,
    endpoints: tuple[_EdgeEndpoint, ...],
    edge_name: str | None,
) -> str | None:
    """Choose the display label for an edge.

    Only non-empty axis names are considered:

    * Dangling edges use their first axis name, falling back to ``edge_name``.
    * Other edges use ``"<first><-><second>"`` when two axis names are available,
      and ``edge_name`` otherwise.
    """
    names = [ep.axis_name for ep in endpoints if ep.axis_name]
    if kind == "dangling":
        if names:
            return names[0]
        return edge_name
    if len(names) < 2:
        return edge_name
    first, second = names[:2]
    return f"{first}<->{second}"
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def _classify_edge(
    node1_id: int | None,
    node2_id: int | None,
) -> tuple[EdgeKind, tuple[int, ...]] | None:
    """Classify an edge from the visible ids of its two end nodes.

    Returns ``(kind, node_ids)``, or ``None`` when neither end node is visible.
    """
    if node1_id is not None and node2_id is not None:
        if node1_id == node2_id:
            return "self", (node1_id,)
        return "contraction", (node1_id, node2_id)
    if node1_id is not None:
        return "dangling", (node1_id,)
    if node2_id is not None:
        return "dangling", (node2_id,)
    return None


def _build_graph(network: Any) -> _GraphData:
    """Build drawing data from a tensor network object.

    Every node exposed by ``network`` is read for ``name``, ``edges`` and
    ``axes_names``. Each non-``None`` edge is then read for ``name``, ``node1`` and
    ``node2`` and classified with ``_classify_edge``. End nodes that are not part of
    the scanned node set count as absent, so an edge to such a node is drawn as
    dangling. Edges are returned in the order they are first seen.

    Raises:
        ValueError: If the network exposes no nodes.
        TypeError: If a node or edge lacks a required attribute, if a node's
            ``edges`` and ``axes_names`` differ in length, or if an edge is attached
            to more than two axes.
    """
    node_refs = _get_network_nodes(network)
    if not node_refs:
        raise ValueError("The tensor network does not expose any nodes to visualize.")

    nodes: dict[int, _NodeData] = {}
    edge_refs: dict[int, Any] = {}
    edge_endpoints: dict[int, list[_EdgeEndpoint]] = {}

    for node in node_refs:
        node_name = _stringify(_require_attr(node, "name", "node"))
        node_edges = tuple(_iterable_attr(node, "edges", "node"))
        axes_names = tuple(_stringify(item) for item in _iterable_attr(node, "axes_names", "node"))
        if len(node_edges) != len(axes_names):
            raise TypeError(
                f"Node {node_name!r} has {len(node_edges)} edges but {len(axes_names)} axes_names."
            )

        node_id = id(node)
        nodes[node_id] = _NodeData(
            name=node_name,
            axes_names=axes_names,
            degree=len(node_edges),
        )

        for axis_index, node_edge in enumerate(node_edges):
            if node_edge is None:
                continue
            # Edges are identified by object identity; the same edge object seen
            # from two nodes (or two axes of one node) accumulates two endpoints.
            edge_id = id(node_edge)
            edge_refs[edge_id] = node_edge
            edge_endpoints.setdefault(edge_id, []).append(
                _EdgeEndpoint(
                    node_id=node_id,
                    axis_index=axis_index,
                    axis_name=axes_names[axis_index],
                )
            )

    edges: list[_EdgeData] = []
    for edge_id, edge in edge_refs.items():
        edge_name = _optional_string(_require_attr(edge, "name", "edge"))
        node1 = _require_attr(edge, "node1", "edge")
        node2 = _require_attr(edge, "node2", "edge")

        node1_id = id(node1) if node1 is not None and id(node1) in nodes else None
        node2_id = id(node2) if node2 is not None and id(node2) in nodes else None
        endpoints = tuple(edge_endpoints.get(edge_id, ()))
        if not endpoints:
            continue
        if len(endpoints) > 2:
            raise TypeError("Edges with more than two endpoints are not supported.")

        classification = _classify_edge(node1_id, node2_id)
        if classification is None:
            continue
        kind, node_ids = classification

        edges.append(
            _EdgeData(
                name=edge_name,
                kind=kind,
                node_ids=node_ids,
                endpoints=endpoints,
                label=_build_edge_label(kind=kind, endpoints=endpoints, edge_name=edge_name),
            )
        )

    return _GraphData(nodes=nodes, edges=tuple(edges))
|