tensor-network-visualization 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,352 @@
1
+ """Layout computation for tensor network graphs."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ from typing import TypeAlias
7
+
8
+ import networkx as nx
9
+ import numpy as np
10
+
11
+ from .graph import _GraphData
12
+
13
+ Vector: TypeAlias = np.ndarray
14
+ NodePositions: TypeAlias = dict[int, Vector]
15
+ AxisDirections: TypeAlias = dict[tuple[int, int], Vector]
16
+
17
+ # Axis names that map to fixed 2D directions (x, y). Case-insensitive.
18
+ _AXIS_DIR_2D: dict[str, tuple[float, float]] = {
19
+ "up": (0.0, 1.0),
20
+ "down": (0.0, -1.0),
21
+ "left": (-1.0, 0.0),
22
+ "right": (1.0, 0.0),
23
+ "north": (0.0, 1.0),
24
+ "south": (0.0, -1.0),
25
+ "east": (1.0, 0.0),
26
+ "west": (-1.0, 0.0),
27
+ }
28
+
29
+ _AXIS_DIR_3D: dict[str, tuple[float, float, float]] = {
30
+ "up": (0.0, 0.0, 1.0),
31
+ "down": (0.0, 0.0, -1.0),
32
+ "left": (-1.0, 0.0, 0.0),
33
+ "right": (1.0, 0.0, 0.0),
34
+ "north": (0.0, 0.0, 1.0),
35
+ "south": (0.0, 0.0, -1.0),
36
+ "east": (1.0, 0.0, 0.0),
37
+ "west": (-1.0, 0.0, 0.0),
38
+ "front": (0.0, 1.0, 0.0),
39
+ "back": (0.0, -1.0, 0.0),
40
+ "in": (0.0, 1.0, 0.0),
41
+ "out": (0.0, -1.0, 0.0),
42
+ }
43
+
44
+
45
def _direction_from_axis_name_2d(axis_name: str | None) -> np.ndarray | None:
    """Return the fixed 2D unit direction for a named axis, if it has one.

    The name is lower-cased and stripped, then looked up in ``_AXIS_DIR_2D``.

    Args:
        axis_name: Axis label such as ``"up"`` or ``" West "``; may be None or empty.

    Returns:
        A new float array ``(x, y)``, or None when the name is empty/None or is
        not a known direction.
    """
    if axis_name:
        preset = _AXIS_DIR_2D.get(axis_name.lower().strip())
        if preset is not None:
            return np.array(preset, dtype=float)
    return None
52
+
53
+
54
def _direction_from_axis_name_3d(axis_name: str | None) -> np.ndarray | None:
    """Look up the fixed 3D unit direction for a named axis.

    Matching is done on the lower-cased, stripped name against ``_AXIS_DIR_3D``.

    Args:
        axis_name: Axis label (e.g. ``"front"``, ``"Up"``); None or empty is allowed.

    Returns:
        A fresh float array ``(x, y, z)``, or None if there is no name or it has
        no fixed direction.
    """
    if axis_name:
        preset = _AXIS_DIR_3D.get(axis_name.lower().strip())
        if preset is not None:
            return np.array(preset, dtype=float)
    return None
61
+
62
+
63
def _compute_layout(graph: _GraphData, dimensions: int, seed: int) -> NodePositions:
    """Compute a position for every node of ``graph``.

    In 2D a regular-grid layout is tried first, then a planar layout. When
    neither applies (and always in 3D), a force-directed simulation runs for
    220 steps:

    * every node pair repels with magnitude ``k**2 / d**2``;
    * each contracted pair attracts with ``weight * d**2 / k``, where
      ``weight`` is the number of contraction edges between the pair.

    Each step moves every node a fixed, cooling distance along its net force,
    re-centres the layout on the origin and shrinks it back inside radius 1.6
    if needed.

    Args:
        graph: Graph data; ``graph.nodes`` is keyed by node id and its
            ``"contraction"`` edges carry ``node_ids`` pairs.
        dimensions: 2 or 3.
        seed: Seed for the jitter applied to the initial positions.

    Returns:
        Mapping from node id to an independent position array. A single node
        sits at the origin; an empty graph yields an empty mapping.
    """
    node_ids = list(graph.nodes)
    if not node_ids:
        # The force simulation below cannot reduce over an empty array.
        return {}
    if len(node_ids) == 1:
        return {node_ids[0]: np.zeros(dimensions, dtype=float)}

    if dimensions == 2:
        for structured_layout in (_try_grid_layout_2d, _try_planar_layout_2d):
            structured = structured_layout(graph)
            if structured is not None:
                return structured

    positions = _initial_positions(node_ids, dimensions=dimensions, seed=seed)
    index_by_node = {node_id: index for index, node_id in enumerate(node_ids)}

    # Number of contraction edges per unordered pair of distinct nodes; parallel
    # bonds pull harder. Self-loops are skipped: they exert no force, and the
    # 2D grid/planar attempts ignore them as well.
    pair_weights: dict[tuple[int, int], int] = {}
    for edge in graph.edges:
        if edge.kind != "contraction":
            continue
        left, right = edge.node_ids
        if left == right:
            continue
        key = tuple(sorted((left, right)))
        pair_weights[key] = pair_weights.get(key, 0) + 1
    pairs = list(pair_weights)
    left_idx = np.array([index_by_node[a] for a, _ in pairs], dtype=int)
    right_idx = np.array([index_by_node[b] for _, b in pairs], dtype=int)
    weights = np.array([pair_weights[p] for p in pairs], dtype=float)

    k = 1.6  # force scale
    max_radius = 1.6  # layout is kept within this distance of the origin
    temperature = 0.12  # step length per iteration; cools by 1.5% each step
    for _ in range(220):
        deltas = positions[:, None, :] - positions[None, :, :]
        distances = np.linalg.norm(deltas, axis=2)
        np.fill_diagonal(distances, 1.0)  # self-pairs have zero delta; avoid 0/0
        unit = deltas / np.maximum(distances[..., None], 1e-6)
        repulsion = (k * k / np.maximum(distances, 1e-6) ** 2)[..., None] * unit
        displacement = repulsion.sum(axis=1)

        # Attraction along each bond: weight * d**2 / k * (bond / d).
        # np.add.at accumulates correctly when a node index repeats.
        bond = positions[right_idx] - positions[left_idx]
        bond_len = np.maximum(np.linalg.norm(bond, axis=1), 1e-6)
        attraction = (weights * bond_len / k)[:, None] * bond
        np.add.at(displacement, left_idx, attraction)
        np.add.at(displacement, right_idx, -attraction)

        norms = np.linalg.norm(displacement, axis=1, keepdims=True)
        positions += displacement / np.maximum(norms, 1e-6) * temperature
        positions -= positions.mean(axis=0, keepdims=True)
        max_norm = np.linalg.norm(positions, axis=1).max()
        if max_norm > max_radius:
            positions /= max_norm / max_radius
        temperature *= 0.985

    # Copy rows so callers never share memory with one another.
    return {node_id: positions[index].copy() for index, node_id in enumerate(node_ids)}
117
+
118
+
119
def _try_grid_layout_2d(graph: _GraphData) -> NodePositions | None:
    """Place the nodes on a regular grid when the contraction graph is a 2D grid.

    Only contraction edges between distinct nodes are used (parallel edges
    collapse to one). Every ``rows x cols`` factorisation of the node count
    whose edge count matches is tested with a VF2++ isomorphism check; the first
    match wins. Grid cell ``(row, col)`` is drawn at ``(col, row)``, then the
    layout is centred and scaled so the farthest node lies at radius 1.6.

    Returns:
        Node id -> position, or None when there are fewer than two nodes or the
        graph is not isomorphic to any grid.
    """
    node_ids = list(graph.nodes)
    if len(node_ids) < 2:
        return None

    bonds = nx.Graph()
    bonds.add_nodes_from(node_ids)
    bonds.add_edges_from(
        (left, right)
        for left, right in (edge.node_ids for edge in graph.edges if edge.kind == "contraction")
        if left != right
    )

    n_nodes = bonds.number_of_nodes()
    n_edges = bonds.number_of_edges()
    for rows in (r for r in range(1, n_nodes + 1) if n_nodes % r == 0):
        cols = n_nodes // rows
        # A rows x cols grid has rows*(cols-1) + cols*(rows-1) edges.
        if n_edges != 2 * rows * cols - rows - cols:
            continue
        mapping = nx.vf2pp_isomorphism(bonds, nx.grid_2d_graph(rows, cols))
        if mapping is None:
            continue
        coords = np.array(
            [[mapping[nid][1], mapping[nid][0]] for nid in node_ids],
            dtype=float,
        )
        coords -= coords.mean(axis=0, keepdims=True)
        radius = np.linalg.norm(coords, axis=1).max()
        if radius > 1e-6:
            coords /= radius / 1.6
        return {nid: coords[i].copy() for i, nid in enumerate(node_ids)}
    return None
157
+
158
+
159
def _try_planar_layout_2d(graph: _GraphData) -> NodePositions | None:
    """Lay the graph out with NetworkX's planar embedding, if one exists.

    Only contraction edges between distinct nodes are considered (parallel
    edges collapse to one). The result is centred on its mean and scaled so the
    farthest node sits at radius 1.6.

    Returns:
        Node id -> position, or None for graphs with fewer than two nodes or when
        ``nx.planar_layout`` raises ``NetworkXException`` (graph not planar).
    """
    node_ids = list(graph.nodes)
    if len(node_ids) < 2:
        return None

    bonds = nx.Graph()
    bonds.add_nodes_from(node_ids)
    bonds.add_edges_from(
        (left, right)
        for left, right in (edge.node_ids for edge in graph.edges if edge.kind == "contraction")
        if left != right
    )

    try:
        layout = nx.planar_layout(bonds)
    except nx.NetworkXException:
        return None

    coords = np.array([layout[nid] for nid in node_ids], dtype=float)
    coords -= coords.mean(axis=0, keepdims=True)
    radius = np.linalg.norm(coords, axis=1).max()
    if radius > 1e-6:
        coords /= radius / 1.6
    return {nid: coords[i].copy() for i, nid in enumerate(node_ids)}
185
+
186
+
187
def _initial_positions(node_ids: list[int], dimensions: int, seed: int) -> Vector:
    """Return seeded starting positions for the force-directed layout.

    In 2D the nodes sit evenly spaced on the unit circle. For any other
    ``dimensions`` they sit on a Fibonacci (golden-angle) spiral over the unit
    sphere, running from y = 1 down to y = -1. Gaussian jitter with standard
    deviation 0.03 drawn from ``np.random.default_rng(seed)`` is then added.

    Args:
        node_ids: Nodes to place; only their count matters.
        dimensions: 2 for the circle; otherwise the 3D sphere is used.
        seed: Seed for the jitter.

    Returns:
        Float array of shape ``(len(node_ids), 2)`` or ``(len(node_ids), 3)``.
    """
    count = len(node_ids)
    rng = np.random.default_rng(seed)

    if dimensions == 2:
        theta = np.linspace(0.0, 2.0 * math.pi, count, endpoint=False)
        base = np.column_stack((np.cos(theta), np.sin(theta)))
    else:
        golden_angle = math.pi * (3.0 - math.sqrt(5.0))
        span = max(count - 1, 1)
        rows = []
        for i in range(count):
            height = 1.0 - (2.0 * i) / span
            ring = math.sqrt(max(0.0, 1.0 - height * height))
            phi = golden_angle * i
            rows.append((math.cos(phi) * ring, height, math.sin(phi) * ring))
        # reshape keeps the (0, 3) shape when there are no nodes.
        base = np.array(rows, dtype=float).reshape(count, 3)

    jitter = rng.normal(loc=0.0, scale=0.03, size=base.shape)
    return base + jitter
208
+
209
+
210
def _compute_axis_directions(
    graph: _GraphData,
    positions: NodePositions,
    dimensions: int,
) -> AxisDirections:
    """Assign a direction to every ``(node id, axis index)`` pair.

    Each endpoint of a contraction edge between two distinct, non-coincident
    nodes gets the unit vector pointing at the node on the other end. All
    remaining axes are filled in by the 2D or 3D free-direction pass. That
    includes both ends of a self-loop and bonds between nodes placed at the
    same spot, where a unit direction is undefined.

    Args:
        graph: Graph data with ``nodes`` and ``edges``; contraction edges carry
            ``node_ids`` and ``endpoints`` (each with ``node_id`` and ``axis_index``).
        positions: Node id -> position, as produced by the layout step.
        dimensions: 2 selects the 2D free-direction pass; otherwise the 3D one
            runs, with free legs pointing away from the layout's centroid.

    Returns:
        Mapping ``(node id, axis index) -> direction vector``; empty when
        ``positions`` is empty.
    """
    directions: AxisDirections = {}
    if not positions:
        # np.stack below (and in the free-direction passes) rejects empty input.
        return directions

    for edge in graph.edges:
        if edge.kind != "contraction":
            continue
        a_id, b_id = edge.node_ids
        if a_id == b_id:
            # Self-loop: there is no neighbour to point at, and looking the
            # endpoints up by node id would return the same endpoint twice
            # (previously leaving a zero vector on one axis). Let the
            # free-direction pass spread both axes instead.
            continue
        delta = positions[b_id] - positions[a_id]
        dist = float(np.linalg.norm(delta))
        if dist < 1e-6:
            # Coincident nodes (e.g. identical custom positions): no usable
            # direction; leave both axes to the free-direction pass.
            continue
        toward_b = delta / dist
        a_ep = next(ep for ep in edge.endpoints if ep.node_id == a_id)
        b_ep = next(ep for ep in edge.endpoints if ep.node_id == b_id)
        directions[(a_id, a_ep.axis_index)] = toward_b
        directions[(b_id, b_ep.axis_index)] = -toward_b

    if dimensions == 2:
        _compute_free_directions_2d(graph, positions, directions)
    else:
        center = np.mean(np.stack(list(positions.values())), axis=0)
        _compute_free_directions_3d(graph, positions, center, directions)

    return directions
239
+
240
+
241
def _compute_free_directions_2d(
    graph: _GraphData,
    positions: NodePositions,
    directions: AxisDirections,
) -> None:
    """Fill in 2D directions for every node axis that has none yet (in place).

    For each node, the axes ``0 .. max(node.degree, 1) - 1`` without an entry in
    ``directions`` are handled in order:

    * If the axis name maps to a fixed direction (up/down/left/right/...) and
      that direction's summed positive overlap with the node's already-assigned
      directions is below 0.7, it is used as-is.
    * Otherwise one of 72 evenly spaced unit directions is chosen by maximising
      ``-max_obstacle_alignment - 2 * max(0, max_used_alignment)``, where
      alignments are cosines. Obstacles are the other nodes and the midpoints
      of this node's contraction edges.

    Taking the *maximum* alignment in both terms keeps a leg from pointing
    straight at any obstacle or lying on top of any already-assigned leg.
    Minimum/summed alignments let an interior grid node draw its free leg over
    one of its bonds.

    Args:
        graph: Graph data; ``graph.nodes`` maps node id -> node with ``degree``
            and ``axes_names``.
        positions: Node id -> 2D position.
        directions: Mapping ``(node id, axis index) -> direction``; updated in place.
    """
    if not positions:
        return

    node_ids = list(positions.keys())
    pos_arr = np.stack([positions[nid] for nid in node_ids])
    index_by_node = {nid: i for i, nid in enumerate(node_ids)}
    sample_count = 72
    angles = np.linspace(0.0, 2.0 * math.pi, sample_count, endpoint=False)
    candidates = np.column_stack((np.cos(angles), np.sin(angles)))

    # Contraction partners of every node, gathered once instead of rescanning
    # all edges per node. A self-loop lists the node once as its own partner.
    partners: dict[int, list[int]] = {}
    for edge in graph.edges:
        if edge.kind != "contraction":
            continue
        left, right = edge.node_ids
        partners.setdefault(left, []).append(right)
        if right != left:
            partners.setdefault(right, []).append(left)

    for node_id, node in graph.nodes.items():
        origin = positions[node_id]
        others = np.delete(pos_arr, index_by_node[node_id], axis=0)
        midpoints = [(origin + positions[p]) / 2.0 for p in partners.get(node_id, ())]
        points = np.vstack([others, *midpoints]) if midpoints else others

        offsets = points - origin
        lengths = np.linalg.norm(offsets, axis=1)
        # Obstacles on top of the node (self-loop midpoints, coincident nodes)
        # have no direction and would otherwise cap every candidate's score.
        usable = lengths > 1e-6
        if np.any(usable):
            obstacle_dirs = offsets[usable] / lengths[usable, None]
        else:
            # Lone node: behave as if a single obstacle sat just to the right.
            obstacle_dirs = np.array([[1.0, 0.0]], dtype=float)
        # Cosine between each candidate and its most-aligned obstacle.
        obstacle_alignment = (obstacle_dirs @ candidates.T).max(axis=0)

        axis_count = max(node.degree, 1)
        for axis_index in range(axis_count):
            if (node_id, axis_index) in directions:
                continue
            used = [
                np.asarray(directions[(node_id, j)], dtype=float)[:2]
                for j in range(axis_count)
                if (node_id, j) in directions
            ]

            axis_name = node.axes_names[axis_index] if axis_index < len(node.axes_names) else None
            named_d = _direction_from_axis_name_2d(axis_name)
            if named_d is not None:
                overlap = sum(max(0.0, float(np.dot(named_d, u))) for u in used)
                if overlap < 0.7:
                    directions[(node_id, axis_index)] = named_d
                    continue

            score = -obstacle_alignment
            if used:
                sims = np.asarray(used, dtype=float) @ candidates.T
                score = score - 2.0 * np.clip(sims, 0.0, None).max(axis=0)
            directions[(node_id, axis_index)] = candidates[int(np.argmax(score))].copy()
306
+
307
+
308
def _compute_free_directions_3d(
    graph: _GraphData,
    positions: NodePositions,
    center: np.ndarray,
    directions: AxisDirections,
) -> None:
    """Fill in 3D directions for node axes that have none yet (mutates ``directions``).

    For each node, the still-unassigned axes among ``0 .. max(node.degree, 1) - 1``
    are collected up front. An axis whose name maps to a fixed 3D direction
    takes it when that direction's summed positive overlap with the node's
    assigned directions is below 0.7.

    Every other free axis gets the outward radial direction, tilted by 0.55
    towards a point on a ring around it and renormalised. The radial direction
    is node minus ``center``, or +x if the node sits on the centre. Ring angles
    are spaced evenly by the axis' index in the free list.

    Args:
        graph: Graph data; ``graph.nodes`` maps node id -> node with ``degree``
            and ``axes_names``.
        positions: Node id -> 3D position.
        center: Reference point the free legs point away from.
        directions: Mapping ``(node id, axis index) -> direction``; updated in place.
    """
    for node_id, node in graph.nodes.items():
        outward = positions[node_id] - center
        if np.linalg.norm(outward) < 1e-6:
            outward = np.array([1.0, 0.0, 0.0], dtype=float)
        outward = outward / np.linalg.norm(outward)
        # Orthonormal pair spanning the plane perpendicular to ``outward``.
        ring_u = _orthogonal_unit(outward)
        ring_v = np.cross(outward, ring_u)
        ring_v = ring_v / np.linalg.norm(ring_v)

        axis_count = max(node.degree, 1)
        pending = [j for j in range(axis_count) if (node_id, j) not in directions]
        slots = max(len(pending), 1)
        for slot, axis_index in enumerate(pending):
            label = node.axes_names[axis_index] if axis_index < len(node.axes_names) else None
            preset = _direction_from_axis_name_3d(label)
            if preset is not None:
                assigned = [
                    directions[(node_id, j)]
                    for j in range(axis_count)
                    if (node_id, j) in directions
                ]
                if sum(max(0.0, float(np.dot(preset, u))) for u in assigned) < 0.7:
                    directions[(node_id, axis_index)] = preset
                    continue
            angle = 2.0 * math.pi * slot / slots
            tilted = outward + 0.55 * (math.cos(angle) * ring_u + math.sin(angle) * ring_v)
            directions[(node_id, axis_index)] = tilted / np.linalg.norm(tilted)
345
+
346
+
347
def _orthogonal_unit(vector: Vector) -> Vector:
    """Return a unit 3-vector perpendicular to ``vector``.

    ``vector`` is crossed with the z axis. When ``vector`` is nearly parallel
    to z (|dot| > 0.9), the y axis is used instead so the cross product stays
    well conditioned. The threshold assumes ``vector`` has unit length.
    """
    z_axis = np.array([0.0, 0.0, 1.0], dtype=float)
    y_axis = np.array([0.0, 1.0, 0.0], dtype=float)
    reference = y_axis if abs(float(np.dot(vector, z_axis))) > 0.9 else z_axis
    perpendicular = np.cross(vector, reference)
    return perpendicular / np.linalg.norm(perpendicular)
@@ -0,0 +1,181 @@
1
+ """Main entry points for TensorKrowch tensor network plotting."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, cast
6
+
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ from matplotlib.axes import Axes
10
+ from matplotlib.figure import Figure
11
+ from mpl_toolkits.mplot3d.axes3d import Axes3D
12
+
13
+ from ..config import PlotConfig
14
+ from .draw_2d import _draw_2d
15
+ from .draw_3d import _draw_3d
16
+ from .graph import _build_graph, _GraphData
17
+ from .layout import _compute_axis_directions, _compute_layout
18
+
19
# Mapping from node id to its position vector (2 or 3 components).
NodePositions = dict[int, np.ndarray]


def _apply_custom_positions(
    graph: _GraphData,
    custom_positions: dict[int, tuple[float, ...]],
    dimensions: int,
) -> NodePositions:
    """Build node positions from user-supplied coordinates, then centre and scale.

    Each custom coordinate tuple is truncated or zero-padded to ``dimensions``
    components. Nodes without a custom position take their coordinates from the
    automatic layout (``_compute_layout`` with seed 0). The combined positions
    are centred on their mean. Unless all nodes coincide, they are then scaled
    so the farthest node lies at distance 1.6 from the centre.

    Note: fallback layout coordinates (which stay within radius 1.6) are mixed
    with the caller's coordinates as-is. They only line up sensibly when the
    custom positions use a similar scale.

    Args:
        graph: Graph data; ``graph.nodes`` is keyed by node id.
        custom_positions: Node id -> coordinates; ids not in the graph are ignored.
        dimensions: Number of output components (2 or 3).

    Returns:
        Mapping from node id to an independent position array of length
        ``dimensions``; empty when the graph has no nodes.
    """
    node_ids = list(graph.nodes)
    if not node_ids:
        # Nothing to place; .max() below would fail on an empty array.
        return {}

    positions_arr = np.zeros((len(node_ids), dimensions), dtype=float)
    missing: set[int] = set()
    for i, nid in enumerate(node_ids):
        if nid not in custom_positions:
            missing.add(nid)
            continue
        coords = np.array(custom_positions[nid], dtype=float)
        n = min(len(coords), dimensions)
        positions_arr[i, :n] = coords[:n]

    if missing:
        # _compute_layout is imported at module level; no local re-import needed.
        fallback = _compute_layout(graph, dimensions=dimensions, seed=0)
        for i, nid in enumerate(node_ids):
            if nid in missing:
                positions_arr[i] = fallback[nid]

    positions_arr -= positions_arr.mean(axis=0, keepdims=True)
    max_norm = np.linalg.norm(positions_arr, axis=1).max()
    if max_norm > 1e-6:
        positions_arr /= max_norm / 1.6
    return {nid: positions_arr[i].copy() for i, nid in enumerate(node_ids)}
50
+
51
+
52
def _resolve_flag(value: bool | None, default: bool) -> bool:
    """Return the explicit override ``value``, or ``default`` when it is None."""
    return default if value is None else value
56
+
57
+
58
def _compute_scale(n_nodes: int) -> float:
    """Return the size factor for visual elements given the node count.

    A single node (or none) gets 1.2. Otherwise the factor shrinks linearly,
    ``2.2 - 0.07 * n_nodes``, clamped to the range [0.5, 1.6].
    """
    if n_nodes <= 1:
        return 1.2
    shrunk = 2.2 - 0.07 * n_nodes
    if shrunk > 1.6:
        return 1.6
    if shrunk < 0.5:
        return 0.5
    return shrunk
63
+
64
+
65
def _prepare_axes_2d(
    ax: Axes | None,
    *,
    figsize: tuple[float, float] | None,
) -> tuple[Figure, Axes]:
    """Return ``(figure, axes)`` to draw a 2D plot on.

    A caller-supplied axis is reused together with its figure. It must not be
    a 3D axis (``ax.name == "3d"``), otherwise ValueError is raised. Without an
    axis, a new figure is created with ``figsize`` (14x10 inches when None).
    """
    if ax is not None:
        if getattr(ax, "name", "") == "3d":
            raise ValueError("plot_tensorkrowch_network_2d requires a 2D Matplotlib axis.")
        return ax.figure, ax
    fig, created = plt.subplots(figsize=figsize or (14, 10))
    return fig, created
77
+
78
+
79
def _prepare_axes_3d(
    ax: Axes | Axes3D | None,
    *,
    figsize: tuple[float, float] | None,
) -> tuple[Figure, Axes3D]:
    """Return ``(figure, 3D axes)`` to draw a 3D plot on.

    A caller-supplied axis must be a 3D axis (``ax.name == "3d"``), otherwise
    ValueError is raised. Without an axis, a new figure is created with
    ``figsize`` (14x10 inches when None) holding a single 3D subplot.
    """
    if ax is not None:
        if getattr(ax, "name", "") != "3d":
            raise ValueError("plot_tensorkrowch_network_3d requires a 3D Matplotlib axis.")
        return ax.figure, cast("Axes3D", ax)
    fig = plt.figure(figsize=figsize or (14, 10))
    return fig, cast("Axes3D", fig.add_subplot(111, projection="3d"))
92
+
93
+
94
def plot_tensorkrowch_network_2d(
    network: Any,
    *,
    ax: Axes | None = None,
    config: PlotConfig | None = None,
    show_tensor_labels: bool | None = None,
    show_index_labels: bool | None = None,
    seed: int = 0,
) -> tuple[Figure, Axes]:
    """Plot a TensorKrowch tensor network in 2D.

    Args:
        network: TensorKrowch TensorNetwork with nodes and edges.
        ax: Matplotlib 2D axes; if None, creates a new figure sized by
            ``config.figsize``.
        config: Styling options; uses defaults if None. When
            ``config.positions`` is set, those node coordinates are used
            instead of the automatic layout.
        show_tensor_labels: Override config; None uses config value.
        show_index_labels: Override config; None uses config value.
        seed: Random seed for the force-directed layout. Not used when custom
            positions are supplied, or when a grid/planar layout applies.

    Returns:
        Tuple of (Figure, Axes) for further customization.

    Raises:
        ValueError: If ``ax`` is a 3D axis.
    """
    style = config or PlotConfig()
    graph = _build_graph(network)
    created_figure = ax is None
    fig, ax = _prepare_axes_2d(ax=ax, figsize=style.figsize)
    if style.positions is not None:
        positions = _apply_custom_positions(graph, style.positions, dimensions=2)
    else:
        positions = _compute_layout(graph, dimensions=2, seed=seed)
    directions = _compute_axis_directions(graph, positions, dimensions=2)
    _draw_2d(
        ax=ax,
        graph=graph,
        positions=positions,
        directions=directions,
        show_tensor_labels=_resolve_flag(show_tensor_labels, style.show_tensor_labels),
        show_index_labels=_resolve_flag(show_index_labels, style.show_index_labels),
        config=style,
        scale=_compute_scale(len(graph.nodes)),
    )
    if created_figure:
        # Tighten margins only on a figure this function created. A caller's
        # axis may share its figure with other subplots or a layout engine,
        # whose layout subplots_adjust would override (or warn about).
        fig.subplots_adjust(left=0.02, right=0.98, bottom=0.02, top=0.98)
    return fig, ax
137
+
138
+
139
def plot_tensorkrowch_network_3d(
    network: Any,
    *,
    ax: Axes | Axes3D | None = None,
    config: PlotConfig | None = None,
    show_tensor_labels: bool | None = None,
    show_index_labels: bool | None = None,
    seed: int = 0,
) -> tuple[Figure, Axes3D]:
    """Plot a TensorKrowch tensor network in 3D.

    Args:
        network: TensorKrowch TensorNetwork with nodes and edges.
        ax: Matplotlib 3D axes; if None, creates a new figure with 3D projection
            sized by ``config.figsize``.
        config: Styling options; uses defaults if None. When
            ``config.positions`` is set, those node coordinates are used
            instead of the automatic layout.
        show_tensor_labels: Override config; None uses config value.
        show_index_labels: Override config; None uses config value.
        seed: Random seed for the force-directed layout. Not used when custom
            positions are supplied.

    Returns:
        Tuple of (Figure, Axes3D) for further customization.

    Raises:
        ValueError: If ``ax`` is given but is not a 3D axis.
    """
    style = config or PlotConfig()
    graph = _build_graph(network)
    created_figure = ax is None
    fig, ax = _prepare_axes_3d(ax=ax, figsize=style.figsize)
    if style.positions is not None:
        positions = _apply_custom_positions(graph, style.positions, dimensions=3)
    else:
        positions = _compute_layout(graph, dimensions=3, seed=seed)
    directions = _compute_axis_directions(graph, positions, dimensions=3)
    _draw_3d(
        ax=ax,
        graph=graph,
        positions=positions,
        directions=directions,
        show_tensor_labels=_resolve_flag(show_tensor_labels, style.show_tensor_labels),
        show_index_labels=_resolve_flag(show_index_labels, style.show_index_labels),
        config=style,
        scale=_compute_scale(len(graph.nodes)),
    )
    if created_figure:
        # Tighten margins only on a figure this function created. A caller's
        # axis may share its figure with other subplots or a layout engine,
        # whose layout subplots_adjust would override (or warn about).
        fig.subplots_adjust(left=0.02, right=0.98, bottom=0.02, top=0.98)
    return fig, ax
@@ -0,0 +1,57 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, TypeAlias
4
+
5
+ import matplotlib.pyplot as plt
6
+ from matplotlib.axes import Axes
7
+ from matplotlib.figure import Figure
8
+ from mpl_toolkits.mplot3d.axes3d import Axes3D
9
+
10
+ from .config import EngineName, PlotConfig, ViewName
11
+ from .tensorkrowch import (
12
+ plot_tensorkrowch_network_2d,
13
+ plot_tensorkrowch_network_3d,
14
+ )
15
+
16
# Axis type returned by show_tensor_network: a 2D ``Axes`` for view "2d" or a
# 3D ``Axes3D`` for view "3d".
RenderedAxes: TypeAlias = Axes | Axes3D
17
+
18
+
19
def show_tensor_network(
    network: Any,
    *,
    engine: EngineName,
    view: ViewName,
    config: PlotConfig | None = None,
    show: bool = True,
) -> tuple[Figure, RenderedAxes]:
    """Render a tensor network and optionally display the figure.

    Args:
        network: Tensor network object (must expose 'nodes' or 'leaf_nodes' with
            nodes having 'edges', 'axes_names', and 'name' attributes).
        engine: Rendering engine; currently only "tensorkrowch" is supported.
        view: "2d" or "3d" visualization mode.
        config: Optional styling; uses defaults if None.
        show: If True, call plt.show() to display the figure. Set False when
            integrating into other applications (e.g. adding a title before showing).

    Returns:
        Tuple of (Figure, Axes) for further customization.

    Raises:
        ValueError: If ``engine`` or ``view`` is not supported. Both are checked
            before any figure is created. Previously, an unknown view silently
            rendered in 3D.

    Example:
        >>> config = PlotConfig(figsize=(8, 6))
        >>> fig, ax = show_tensor_network(network, engine="tensorkrowch", view="2d", config=config)
    """
    if engine != "tensorkrowch":
        raise ValueError(f"Unsupported tensor network engine: {engine}")
    if view not in ("2d", "3d"):
        raise ValueError(f"Unsupported view: {view!r}; expected '2d' or '3d'.")

    style = config or PlotConfig()
    if view == "2d":
        fig, ax = plot_tensorkrowch_network_2d(network, config=style)
    else:
        fig, ax = plot_tensorkrowch_network_3d(network, config=style)

    if show:
        plt.show()
    return fig, ax