tensor-network-visualization 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tensor_network_visualization-1.0.0.dist-info/METADATA +104 -0
- tensor_network_visualization-1.0.0.dist-info/RECORD +17 -0
- tensor_network_visualization-1.0.0.dist-info/WHEEL +5 -0
- tensor_network_visualization-1.0.0.dist-info/licenses/LICENSE +21 -0
- tensor_network_visualization-1.0.0.dist-info/top_level.txt +1 -0
- tensor_network_viz/__init__.py +10 -0
- tensor_network_viz/config.py +60 -0
- tensor_network_viz/py.typed +0 -0
- tensor_network_viz/tensorkrowch/__init__.py +9 -0
- tensor_network_viz/tensorkrowch/_draw_common.py +62 -0
- tensor_network_viz/tensorkrowch/curves.py +98 -0
- tensor_network_viz/tensorkrowch/draw_2d.py +178 -0
- tensor_network_viz/tensorkrowch/draw_3d.py +202 -0
- tensor_network_viz/tensorkrowch/graph.py +190 -0
- tensor_network_viz/tensorkrowch/layout.py +352 -0
- tensor_network_viz/tensorkrowch/renderer.py +181 -0
- tensor_network_viz/viewer.py +57 -0
|
@@ -0,0 +1,352 @@
|
|
|
1
|
+
"""Layout computation for tensor network graphs."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import math
|
|
6
|
+
from typing import TypeAlias
|
|
7
|
+
|
|
8
|
+
import networkx as nx
|
|
9
|
+
import numpy as np
|
|
10
|
+
|
|
11
|
+
from .graph import _GraphData
|
|
12
|
+
|
|
13
|
+
Vector: TypeAlias = np.ndarray
|
|
14
|
+
NodePositions: TypeAlias = dict[int, Vector]
|
|
15
|
+
AxisDirections: TypeAlias = dict[tuple[int, int], Vector]
|
|
16
|
+
|
|
17
|
+
# Axis names that map to fixed 2D directions (x, y). Case-insensitive.
|
|
18
|
+
_AXIS_DIR_2D: dict[str, tuple[float, float]] = {
|
|
19
|
+
"up": (0.0, 1.0),
|
|
20
|
+
"down": (0.0, -1.0),
|
|
21
|
+
"left": (-1.0, 0.0),
|
|
22
|
+
"right": (1.0, 0.0),
|
|
23
|
+
"north": (0.0, 1.0),
|
|
24
|
+
"south": (0.0, -1.0),
|
|
25
|
+
"east": (1.0, 0.0),
|
|
26
|
+
"west": (-1.0, 0.0),
|
|
27
|
+
}
|
|
28
|
+
|
|
29
|
+
_AXIS_DIR_3D: dict[str, tuple[float, float, float]] = {
|
|
30
|
+
"up": (0.0, 0.0, 1.0),
|
|
31
|
+
"down": (0.0, 0.0, -1.0),
|
|
32
|
+
"left": (-1.0, 0.0, 0.0),
|
|
33
|
+
"right": (1.0, 0.0, 0.0),
|
|
34
|
+
"north": (0.0, 0.0, 1.0),
|
|
35
|
+
"south": (0.0, 0.0, -1.0),
|
|
36
|
+
"east": (1.0, 0.0, 0.0),
|
|
37
|
+
"west": (-1.0, 0.0, 0.0),
|
|
38
|
+
"front": (0.0, 1.0, 0.0),
|
|
39
|
+
"back": (0.0, -1.0, 0.0),
|
|
40
|
+
"in": (0.0, 1.0, 0.0),
|
|
41
|
+
"out": (0.0, -1.0, 0.0),
|
|
42
|
+
}
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _direction_from_axis_name_2d(axis_name: str | None) -> np.ndarray | None:
|
|
46
|
+
if not axis_name:
|
|
47
|
+
return None
|
|
48
|
+
key = axis_name.lower().strip()
|
|
49
|
+
if key in _AXIS_DIR_2D:
|
|
50
|
+
return np.array(_AXIS_DIR_2D[key], dtype=float)
|
|
51
|
+
return None
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _direction_from_axis_name_3d(axis_name: str | None) -> np.ndarray | None:
|
|
55
|
+
if not axis_name:
|
|
56
|
+
return None
|
|
57
|
+
key = axis_name.lower().strip()
|
|
58
|
+
if key in _AXIS_DIR_3D:
|
|
59
|
+
return np.array(_AXIS_DIR_3D[key], dtype=float)
|
|
60
|
+
return None
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _compute_layout(graph: _GraphData, dimensions: int, seed: int) -> NodePositions:
    """Compute a position for every node of ``graph``.

    An empty graph yields ``{}`` and a single node is placed at the origin.
    In 2D a regular grid layout is tried first, then a planar layout. If
    neither applies (and always for ``dimensions != 2``), a force-directed
    layout runs for a fixed number of iterations:

    - every pair of nodes repels;
    - each contracted pair attracts, in proportion to how many contraction
      edges join it;
    - each node moves by at most ``temperature`` along its net force;
    - the node cloud is re-centered and kept within radius 1.6.

    Args:
        graph: Graph whose nodes are placed; only ``"contraction"`` edges
            contribute attraction in the force-directed layout.
        dimensions: 2 for a planar drawing; any other value follows the 3D
            initial-position path of ``_initial_positions``.
        seed: Seed for the jitter applied to the initial force-layout positions.

    Returns:
        Mapping from node id to its position vector.
    """
    node_ids = list(graph.nodes)
    if not node_ids:
        # The force loop reduces with `.max()`, which fails on an empty array.
        return {}
    if len(node_ids) == 1:
        return {node_ids[0]: np.zeros(dimensions, dtype=float)}

    if dimensions == 2:
        grid_pos = _try_grid_layout_2d(graph)
        if grid_pos is not None:
            return grid_pos
        planar_pos = _try_planar_layout_2d(graph)
        if planar_pos is not None:
            return planar_pos

    positions = _initial_positions(node_ids, dimensions=dimensions, seed=seed)
    index_by_node = {node_id: index for index, node_id in enumerate(node_ids)}

    # Count parallel contraction edges per unordered node pair; each extra bond
    # strengthens the attraction between that pair.
    pair_weights: dict[tuple[int, int], int] = {}
    for edge in graph.edges:
        if edge.kind != "contraction":
            continue
        left, right = edge.node_ids
        key = tuple(sorted((left, right)))
        pair_weights[key] = pair_weights.get(key, 0) + 1

    k = 1.6  # force scale shared by the repulsion and attraction terms
    max_radius = 1.6  # positions are shrunk back inside this radius each step
    temperature = 0.12  # step length per iteration, cooled geometrically
    for _ in range(220):
        deltas = positions[:, None, :] - positions[None, :, :]
        distances = np.linalg.norm(deltas, axis=2)
        # Self-pairs have zero delta and so contribute no force; the 1.0 only
        # keeps the divisions below well defined.
        np.fill_diagonal(distances, 1.0)
        safe_distances = np.maximum(distances, 1e-6)
        directions = deltas / safe_distances[..., None]
        repulsion = (k * k / safe_distances**2)[..., None] * directions
        displacement = repulsion.sum(axis=1)

        for (left_id, right_id), weight in pair_weights.items():
            left_index = index_by_node[left_id]
            right_index = index_by_node[right_id]
            delta = positions[right_index] - positions[left_index]
            distance = max(float(np.linalg.norm(delta)), 1e-6)
            attraction = weight * distance * distance / k * (delta / distance)
            displacement[left_index] += attraction
            displacement[right_index] -= attraction

        # Normalize so every node steps by (at most) `temperature`.
        norms = np.linalg.norm(displacement, axis=1, keepdims=True)
        positions += displacement / np.maximum(norms, 1e-6) * temperature
        positions -= positions.mean(axis=0, keepdims=True)
        max_norm = np.linalg.norm(positions, axis=1).max()
        if max_norm > max_radius:
            positions /= max_norm / max_radius
        temperature *= 0.985

    return {node_id: positions[index] for index, node_id in enumerate(node_ids)}
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def _try_grid_layout_2d(graph: _GraphData) -> NodePositions | None:
    """Lay the graph out on a regular rows x cols grid if it is isomorphic to one.

    Only contraction edges between distinct nodes are considered. Grid shapes
    whose edge count matches are tried with the row count increasing, so a
    path graph becomes a single horizontal row. Grid column maps to x and grid
    row maps to y. The result is centered and scaled so the farthest node lies
    at radius 1.6.

    Returns:
        Node positions, or ``None`` when the graph has fewer than two nodes or
        is not isomorphic to any grid with its node count.
    """
    ids = list(graph.nodes)
    if len(ids) < 2:
        return None

    pairs = []
    for edge in graph.edges:
        if edge.kind == "contraction":
            u, v = edge.node_ids
            if u != v:
                pairs.append((u, v))
    simple = nx.Graph()
    simple.add_nodes_from(ids)
    simple.add_edges_from(pairs)

    node_total = simple.number_of_nodes()
    edge_total = simple.number_of_edges()
    for rows in range(1, node_total + 1):
        cols, remainder = divmod(node_total, rows)
        # An r x c grid has r*(c-1) horizontal plus c*(r-1) vertical edges.
        if remainder or edge_total != 2 * rows * cols - rows - cols:
            continue
        mapping = nx.vf2pp_isomorphism(simple, nx.grid_2d_graph(rows, cols))
        if mapping is None:
            continue
        coords = np.array([(mapping[i][1], mapping[i][0]) for i in ids], dtype=float)
        coords -= coords.mean(axis=0, keepdims=True)
        radius = np.linalg.norm(coords, axis=1).max()
        if radius > 1e-6:
            coords /= radius / 1.6
        return {i: coords[row].copy() for row, i in enumerate(ids)}
    return None
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def _try_planar_layout_2d(graph: _GraphData) -> NodePositions | None:
    """Place nodes with NetworkX's planar layout when the graph is planar.

    Only contraction edges between distinct nodes are considered. The layout
    is centered and scaled so the farthest node lies at radius 1.6.

    Returns:
        Node positions, or ``None`` when there are fewer than two nodes or
        ``nx.planar_layout`` raises ``NetworkXException`` (non-planar graph).
    """
    ids = list(graph.nodes)
    if len(ids) < 2:
        return None

    simple = nx.Graph()
    simple.add_nodes_from(ids)
    for edge in graph.edges:
        if edge.kind == "contraction":
            u, v = edge.node_ids
            if u != v:
                simple.add_edge(u, v)

    try:
        layout = nx.planar_layout(simple)
    except nx.NetworkXException:
        return None

    coords = np.array([layout[i] for i in ids], dtype=float)
    coords -= coords.mean(axis=0, keepdims=True)
    radius = np.linalg.norm(coords, axis=1).max()
    if radius > 1e-6:
        coords /= radius / 1.6
    return {i: coords[row].copy() for row, i in enumerate(ids)}
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def _initial_positions(node_ids: list[int], dimensions: int, seed: int) -> Vector:
    """Seed positions for the force-directed layout, one row per node.

    For ``dimensions == 2`` the nodes sit evenly on the unit circle. Otherwise
    they are spread over the unit sphere with a golden-angle spiral, which
    always yields three columns. Gaussian jitter (std 0.03) drawn from
    ``np.random.default_rng(seed)`` is then added, so the result is
    reproducible for a given seed.
    """
    n = len(node_ids)
    generator = np.random.default_rng(seed)

    if dimensions != 2:
        golden = math.pi * (3.0 - math.sqrt(5.0))
        rows = []
        for i in range(n):
            height = 1.0 - (2.0 * i) / max(n - 1, 1)
            ring = math.sqrt(max(0.0, 1.0 - height * height))
            angle = golden * i
            rows.append((math.cos(angle) * ring, height, math.sin(angle) * ring))
        # reshape keeps the (0, 3) shape when there are no nodes.
        base = np.array(rows, dtype=float).reshape(n, 3)
    else:
        theta = np.linspace(0.0, 2.0 * math.pi, n, endpoint=False)
        base = np.column_stack((np.cos(theta), np.sin(theta)))

    return base + generator.normal(loc=0.0, scale=0.03, size=base.shape)
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def _compute_axis_directions(
    graph: _GraphData,
    positions: NodePositions,
    dimensions: int,
) -> AxisDirections:
    """Assign a unit direction to each ``(node_id, axis_index)`` pair.

    For every contraction edge between two distinct nodes, each endpoint's
    axis points toward the other node. Axes still unassigned afterwards are
    handed to ``_compute_free_directions_2d`` (when ``dimensions == 2``) or
    ``_compute_free_directions_3d`` (otherwise, around the centroid of all
    positions).

    Args:
        graph: Graph providing nodes and edges (with their endpoints).
        positions: Position of every node, as produced by the layout.
        dimensions: 2 selects the 2D free-direction strategy, anything else 3D.

    Returns:
        Mapping from ``(node_id, axis_index)`` to a direction vector; empty
        when ``positions`` is empty.
    """
    directions: AxisDirections = {}
    if not positions:
        # Nothing to orient; np.stack below would fail on an empty list.
        return directions

    for edge in graph.edges:
        if edge.kind != "contraction":
            continue
        a_id, b_id = edge.node_ids
        if a_id == b_id:
            # Self-loop: both ends share one position, so "toward the other
            # node" is a zero vector, and the endpoint lookup below would
            # match the same endpoint twice. Leave these axes to the
            # free-direction pass instead.
            # NOTE(review): assumes node.degree counts both self-loop axes so
            # the free pass covers them - confirm in graph.py.
            continue
        a_ep = next(ep for ep in edge.endpoints if ep.node_id == a_id)
        b_ep = next(ep for ep in edge.endpoints if ep.node_id == b_id)
        delta = positions[b_id] - positions[a_id]
        toward_b = delta / max(float(np.linalg.norm(delta)), 1e-6)
        directions[(a_id, a_ep.axis_index)] = toward_b
        directions[(b_id, b_ep.axis_index)] = -toward_b

    if dimensions == 2:
        _compute_free_directions_2d(graph, positions, directions)
    else:
        center = np.mean(np.stack(list(positions.values())), axis=0)
        _compute_free_directions_3d(graph, positions, center, directions)

    return directions
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
def _compute_free_directions_2d(
    graph: _GraphData,
    positions: NodePositions,
    directions: AxisDirections,
) -> None:
    """Fill in 2D directions for node axes not yet present in ``directions``.

    For each node, axes ``0 .. max(node.degree, 1) - 1`` that lack a direction
    get one:

    - If the axis name maps to a fixed direction via
      ``_direction_from_axis_name_2d``, that direction is used, provided its
      summed positive overlap with the node's already-assigned directions
      stays below 0.7.
    - Otherwise 72 evenly spaced candidate directions are scored. The score
      rewards clearance from the closest-aligned obstacle (the other nodes and
      the midpoints of this node's contraction edges). It penalizes, with
      weight 2, positive alignment with directions already assigned on the
      node.

    Mutates ``directions`` in place.
    """
    node_ids = list(positions.keys())
    pos_arr = np.stack(list(positions.values()))
    index_by_node = {nid: i for i, nid in enumerate(node_ids)}
    sample_count = 72
    angles = np.linspace(0.0, 2.0 * math.pi, sample_count, endpoint=False)
    unit_circle = np.column_stack((np.cos(angles), np.sin(angles)))

    # Midpoints of each node's contraction edges, gathered in a single pass
    # instead of rescanning every edge for every node.
    edge_midpoints: dict[int, list[np.ndarray]] = {}
    for edge in graph.edges:
        if edge.kind != "contraction":
            continue
        left, right = edge.node_ids
        mid = (positions[left] + positions[right]) / 2.0
        edge_midpoints.setdefault(left, []).append(mid)
        if right != left:
            edge_midpoints.setdefault(right, []).append(mid)

    for node_id, node in graph.nodes.items():
        origin = positions[node_id]
        obstacle_list = list(np.delete(pos_arr, index_by_node[node_id], axis=0))
        obstacle_list.extend(edge_midpoints.get(node_id, []))
        if obstacle_list:
            obstacles = np.array(obstacle_list, dtype=float)
        else:
            # Lone node without edges: pretend something sits to its right.
            obstacles = np.array([[origin[0] + 1.0, origin[1]]], dtype=float)

        offsets = obstacles - origin
        dists = np.maximum(np.linalg.norm(offsets, axis=1, keepdims=True), 1e-6)
        dirs_to_obstacles = offsets / dists
        # Cosine between every candidate (rows) and every obstacle (columns).
        alignment = unit_circle @ dirs_to_obstacles.T
        # Clearance is judged against the obstacle the candidate points at most
        # directly; maximizing it keeps the candidate away from all obstacles.
        clearance = -alignment.max(axis=1)

        axis_count = max(node.degree, 1)
        for axis_index in range(axis_count):
            if (node_id, axis_index) in directions:
                continue
            used = [
                directions[(node_id, j)]
                for j in range(axis_count)
                if (node_id, j) in directions
            ]
            axis_name = node.axes_names[axis_index] if axis_index < len(node.axes_names) else None
            named_d = _direction_from_axis_name_2d(axis_name)
            if named_d is not None:
                overlap = sum(max(0.0, float(np.dot(named_d, u[:2]))) for u in used)
                if overlap < 0.7:
                    directions[(node_id, axis_index)] = named_d
                    continue

            if used:
                used_arr = np.array([u[:2] for u in used], dtype=float)
                penalty = 2.0 * np.clip(unit_circle @ used_arr.T, 0.0, None).sum(axis=1)
                scores = clearance - penalty
            else:
                scores = clearance
            # argmax returns the first maximum, matching a strict ">" scan.
            directions[(node_id, axis_index)] = unit_circle[int(np.argmax(scores))].copy()
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
def _compute_free_directions_3d(
    graph: _GraphData,
    positions: NodePositions,
    center: np.ndarray,
    directions: AxisDirections,
) -> None:
    """Fill in 3D directions for node axes not yet present in ``directions``.

    For each node, the unassigned axes among ``0 .. max(node.degree, 1) - 1``
    fan out around the outward vector from ``center`` to the node.
    - The outward vector is replaced by +x when the node sits at the center.
    - Slot ``s`` of ``m`` unassigned axes gets
      ``outward + 0.55 * (cos(2*pi*s/m) * u + sin(2*pi*s/m) * v)``, normalized,
      where ``u`` and ``v`` span the plane perpendicular to the outward vector.
    - An axis whose name maps to a 3D direction (via
      ``_direction_from_axis_name_3d``) uses that direction instead, provided
      its summed positive overlap with the node's assigned directions is below
      0.7.

    Mutates ``directions`` in place.
    """
    for node_id, node in graph.nodes.items():
        outward = positions[node_id] - center
        if np.linalg.norm(outward) < 1e-6:
            outward = np.array([1.0, 0.0, 0.0], dtype=float)
        outward = outward / np.linalg.norm(outward)
        tangent_u = _orthogonal_unit(outward)
        tangent_v = np.cross(outward, tangent_u)
        tangent_v = tangent_v / np.linalg.norm(tangent_v)

        axis_range = range(max(node.degree, 1))
        pending = [j for j in axis_range if (node_id, j) not in directions]
        slots = max(len(pending), 1)
        for slot, axis_index in enumerate(pending):
            name = node.axes_names[axis_index] if axis_index < len(node.axes_names) else None
            preferred = _direction_from_axis_name_3d(name)
            if preferred is not None:
                taken = [
                    directions[(node_id, j)]
                    for j in axis_range
                    if (node_id, j) in directions
                ]
                overlap = sum(max(0.0, float(np.dot(preferred, t))) for t in taken)
                if overlap < 0.7:
                    directions[(node_id, axis_index)] = preferred
                    continue
            phi = 2.0 * math.pi * slot / slots
            tilted = outward + 0.55 * (
                math.cos(phi) * tangent_u + math.sin(phi) * tangent_v
            )
            directions[(node_id, axis_index)] = tilted / np.linalg.norm(tilted)
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
def _orthogonal_unit(vector: Vector) -> Vector:
    """Return a unit-length 3D vector perpendicular to ``vector``.

    The cross product is taken with the z axis. When the dot product of
    ``vector`` with z exceeds 0.9 in magnitude, the y axis is used instead so
    the cross product does not degenerate toward zero.
    """
    z_axis = np.array([0.0, 0.0, 1.0], dtype=float)
    nearly_z = abs(float(np.dot(vector, z_axis))) > 0.9
    helper = np.array([0.0, 1.0, 0.0], dtype=float) if nearly_z else z_axis
    perpendicular = np.cross(vector, helper)
    return perpendicular / np.linalg.norm(perpendicular)
|
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
"""Main entry points for TensorKrowch tensor network plotting."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any, cast
|
|
6
|
+
|
|
7
|
+
import matplotlib.pyplot as plt
|
|
8
|
+
import numpy as np
|
|
9
|
+
from matplotlib.axes import Axes
|
|
10
|
+
from matplotlib.figure import Figure
|
|
11
|
+
from mpl_toolkits.mplot3d.axes3d import Axes3D
|
|
12
|
+
|
|
13
|
+
from ..config import PlotConfig
|
|
14
|
+
from .draw_2d import _draw_2d
|
|
15
|
+
from .draw_3d import _draw_3d
|
|
16
|
+
from .graph import _build_graph, _GraphData
|
|
17
|
+
from .layout import _compute_axis_directions, _compute_layout
|
|
18
|
+
|
|
19
|
+
# Mapping from node id to its position vector; the vector length matches the
# ``dimensions`` (2 or 3) the positions were computed for.
NodePositions = dict[int, np.ndarray]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _apply_custom_positions(
    graph: _GraphData,
    custom_positions: dict[int, tuple[float, ...]],
    dimensions: int,
) -> NodePositions:
    """Build node positions from user-supplied coordinates.

    Each node present in ``custom_positions`` uses its coordinates, truncated
    to ``dimensions`` components or zero-padded when shorter. Nodes without an
    entry take their coordinates from the automatic layout (computed with
    seed 0). The combined positions are then centered on their mean and scaled
    so the farthest node lies at radius 1.6. Scaling is skipped when all
    nodes coincide.

    Args:
        graph: Graph whose nodes are positioned.
        custom_positions: Node id -> user coordinates.
        dimensions: Number of coordinates per node (2 or 3).

    Returns:
        Mapping from node id to a fresh position array of length
        ``dimensions``; empty when the graph has no nodes.
    """
    node_ids = list(graph.nodes)
    if not node_ids:
        # mean/max below would warn and then raise on an empty array.
        return {}

    positions_arr = np.zeros((len(node_ids), dimensions), dtype=float)
    missing_rows: list[int] = []
    for row, nid in enumerate(node_ids):
        if nid not in custom_positions:
            missing_rows.append(row)
            continue
        pos = np.array(custom_positions[nid], dtype=float)
        n = min(len(pos), dimensions)
        positions_arr[row, :n] = pos[:n]

    if missing_rows:
        fallback = _compute_layout(graph, dimensions=dimensions, seed=0)
        for row in missing_rows:
            positions_arr[row] = fallback[node_ids[row]]

    positions_arr -= positions_arr.mean(axis=0, keepdims=True)
    max_norm = np.linalg.norm(positions_arr, axis=1).max()
    if max_norm > 1e-6:
        positions_arr /= max_norm / 1.6
    return {nid: positions_arr[row].copy() for row, nid in enumerate(node_ids)}
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _resolve_flag(value: bool | None, default: bool) -> bool:
    """Return the explicit ``value`` when given, falling back to ``default`` for ``None``."""
    return default if value is None else value
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _compute_scale(n_nodes: int) -> float:
    """Scale factor for visual elements: larger for few nodes, smaller for many.

    Zero or one node gives 1.2. Otherwise the value falls linearly as
    ``2.2 - 0.07 * n_nodes``, clamped to the range [0.5, 1.6].
    """
    if n_nodes > 1:
        linear = 2.2 - 0.07 * n_nodes
        return min(max(linear, 0.5), 1.6)
    return 1.2
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _prepare_axes_2d(
    ax: Axes | None,
    *,
    figsize: tuple[float, float] | None,
) -> tuple[Figure, Axes]:
    """Return a (figure, 2D axes) pair to draw into.

    With ``ax=None`` a new figure of ``figsize`` (default 14x10) and a single
    axes are created. Otherwise ``ax`` is returned together with its figure.

    Raises:
        ValueError: If ``ax`` is a 3D axis.
    """
    if ax is not None:
        if getattr(ax, "name", "") == "3d":
            raise ValueError("plot_tensorkrowch_network_2d requires a 2D Matplotlib axis.")
        return ax.figure, ax
    new_fig, new_ax = plt.subplots(figsize=figsize or (14, 10))
    return new_fig, new_ax
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _prepare_axes_3d(
    ax: Axes | Axes3D | None,
    *,
    figsize: tuple[float, float] | None,
) -> tuple[Figure, Axes3D]:
    """Return a (figure, 3D axes) pair to draw into.

    With ``ax=None`` a new figure of ``figsize`` (default 14x10) gets a single
    subplot with ``projection="3d"``. Otherwise ``ax`` must already be a 3D
    axis and is returned together with its figure.

    Raises:
        ValueError: If ``ax`` is given but is not a 3D axis.
    """
    if ax is not None:
        if getattr(ax, "name", "") != "3d":
            raise ValueError("plot_tensorkrowch_network_3d requires a 3D Matplotlib axis.")
        return ax.figure, cast(Axes3D, ax)
    new_fig = plt.figure(figsize=figsize or (14, 10))
    return new_fig, cast(Axes3D, new_fig.add_subplot(111, projection="3d"))
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def plot_tensorkrowch_network_2d(
    network: Any,
    *,
    ax: Axes | None = None,
    config: PlotConfig | None = None,
    show_tensor_labels: bool | None = None,
    show_index_labels: bool | None = None,
    seed: int = 0,
) -> tuple[Figure, Axes]:
    """Plot a TensorKrowch tensor network in 2D.

    Args:
        network: TensorKrowch TensorNetwork with nodes and edges.
        ax: Matplotlib 2D axes; if None, creates a new figure with tight
            margins. A caller-supplied axis's figure keeps its own subplot
            layout.
        config: Styling options; uses defaults if None.
        show_tensor_labels: Override config; None uses config value.
        show_index_labels: Override config; None uses config value.
        seed: Random seed for layout when using force-directed positioning.
            Not used when ``config.positions`` is set.

    Returns:
        Tuple of (Figure, Axes) for further customization.

    Raises:
        ValueError: If ``ax`` is a 3D axis.
    """
    style = config or PlotConfig()
    graph = _build_graph(network)
    # Only tighten margins on a figure created here: a caller-supplied axis may
    # share its figure with other subplots whose layout must not be clobbered.
    owns_figure = ax is None
    fig, ax = _prepare_axes_2d(ax=ax, figsize=style.figsize)
    if style.positions is not None:
        positions = _apply_custom_positions(graph, style.positions, dimensions=2)
    else:
        positions = _compute_layout(graph, dimensions=2, seed=seed)
    directions = _compute_axis_directions(graph, positions, dimensions=2)
    scale = _compute_scale(len(graph.nodes))
    _draw_2d(
        ax=ax,
        graph=graph,
        positions=positions,
        directions=directions,
        show_tensor_labels=_resolve_flag(show_tensor_labels, style.show_tensor_labels),
        show_index_labels=_resolve_flag(show_index_labels, style.show_index_labels),
        config=style,
        scale=scale,
    )
    if owns_figure:
        fig.subplots_adjust(left=0.02, right=0.98, bottom=0.02, top=0.98)
    return fig, ax
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def plot_tensorkrowch_network_3d(
    network: Any,
    *,
    ax: Axes | Axes3D | None = None,
    config: PlotConfig | None = None,
    show_tensor_labels: bool | None = None,
    show_index_labels: bool | None = None,
    seed: int = 0,
) -> tuple[Figure, Axes3D]:
    """Plot a TensorKrowch tensor network in 3D.

    Args:
        network: TensorKrowch TensorNetwork with nodes and edges.
        ax: Matplotlib 3D axes; if None, creates a new figure with 3D
            projection and tight margins. A caller-supplied axis's figure keeps
            its own subplot layout.
        config: Styling options; uses defaults if None.
        show_tensor_labels: Override config; None uses config value.
        show_index_labels: Override config; None uses config value.
        seed: Random seed for layout when using force-directed positioning.
            Not used when ``config.positions`` is set.

    Returns:
        Tuple of (Figure, Axes3D) for further customization.

    Raises:
        ValueError: If ``ax`` is given but is not a 3D axis.
    """
    style = config or PlotConfig()
    graph = _build_graph(network)
    # Only tighten margins on a figure created here: a caller-supplied axis may
    # share its figure with other subplots whose layout must not be clobbered.
    owns_figure = ax is None
    fig, ax = _prepare_axes_3d(ax=ax, figsize=style.figsize)
    if style.positions is not None:
        positions = _apply_custom_positions(graph, style.positions, dimensions=3)
    else:
        positions = _compute_layout(graph, dimensions=3, seed=seed)
    directions = _compute_axis_directions(graph, positions, dimensions=3)
    scale = _compute_scale(len(graph.nodes))
    _draw_3d(
        ax=ax,
        graph=graph,
        positions=positions,
        directions=directions,
        show_tensor_labels=_resolve_flag(show_tensor_labels, style.show_tensor_labels),
        show_index_labels=_resolve_flag(show_index_labels, style.show_index_labels),
        config=style,
        scale=scale,
    )
    if owns_figure:
        fig.subplots_adjust(left=0.02, right=0.98, bottom=0.02, top=0.98)
    return fig, ax
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any, TypeAlias
|
|
4
|
+
|
|
5
|
+
import matplotlib.pyplot as plt
|
|
6
|
+
from matplotlib.axes import Axes
|
|
7
|
+
from matplotlib.figure import Figure
|
|
8
|
+
from mpl_toolkits.mplot3d.axes3d import Axes3D
|
|
9
|
+
|
|
10
|
+
from .config import EngineName, PlotConfig, ViewName
|
|
11
|
+
from .tensorkrowch import (
|
|
12
|
+
plot_tensorkrowch_network_2d,
|
|
13
|
+
plot_tensorkrowch_network_3d,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
# Axes type returned by show_tensor_network: a 2D Axes for view "2d", an
# Axes3D for view "3d".
RenderedAxes: TypeAlias = Axes | Axes3D
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def show_tensor_network(
    network: Any,
    *,
    engine: EngineName,
    view: ViewName,
    config: PlotConfig | None = None,
    show: bool = True,
) -> tuple[Figure, RenderedAxes]:
    """Render a tensor network and optionally display the figure.

    Args:
        network: Tensor network object (must expose 'nodes' or 'leaf_nodes' with
            nodes having 'edges', 'axes_names', and 'name' attributes).
        engine: Rendering engine; currently only "tensorkrowch" is supported.
        view: "2d" or "3d" visualization mode.
        config: Optional styling; uses defaults if None.
        show: If True, call plt.show() to display the figure. Set False when
            integrating into other applications (e.g. adding a title before showing).

    Returns:
        Tuple of (Figure, Axes) for further customization.

    Raises:
        ValueError: If ``engine`` is not "tensorkrowch" or ``view`` is neither
            "2d" nor "3d".

    Example:
        >>> config = PlotConfig(figsize=(8, 6))
        >>> fig, ax = show_tensor_network(network, engine="tensorkrowch", view="2d", config=config)
    """
    style = config or PlotConfig()

    if engine != "tensorkrowch":
        raise ValueError(f"Unsupported tensor network engine: {engine}")
    if view == "2d":
        fig, ax = plot_tensorkrowch_network_2d(network, config=style)
    elif view == "3d":
        fig, ax = plot_tensorkrowch_network_3d(network, config=style)
    else:
        # Previously any unrecognized value silently fell through to 3D.
        raise ValueError(f"Unsupported view: {view!r}; expected '2d' or '3d'.")

    if show:
        plt.show()
    return fig, ax
|