Trajectree 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- trajectree/__init__.py +0 -3
- trajectree/fock_optics/devices.py +1 -1
- trajectree/fock_optics/light_sources.py +2 -2
- trajectree/fock_optics/measurement.py +3 -3
- trajectree/fock_optics/utils.py +6 -6
- trajectree/trajectory.py +2 -2
- {trajectree-0.0.1.dist-info → trajectree-0.0.2.dist-info}/METADATA +2 -3
- trajectree-0.0.2.dist-info/RECORD +16 -0
- trajectree/quimb/docs/_pygments/_pygments_dark.py +0 -118
- trajectree/quimb/docs/_pygments/_pygments_light.py +0 -118
- trajectree/quimb/docs/conf.py +0 -158
- trajectree/quimb/docs/examples/ex_mpi_expm_evo.py +0 -62
- trajectree/quimb/quimb/__init__.py +0 -507
- trajectree/quimb/quimb/calc.py +0 -1491
- trajectree/quimb/quimb/core.py +0 -2279
- trajectree/quimb/quimb/evo.py +0 -712
- trajectree/quimb/quimb/experimental/__init__.py +0 -0
- trajectree/quimb/quimb/experimental/autojittn.py +0 -129
- trajectree/quimb/quimb/experimental/belief_propagation/__init__.py +0 -109
- trajectree/quimb/quimb/experimental/belief_propagation/bp_common.py +0 -397
- trajectree/quimb/quimb/experimental/belief_propagation/d1bp.py +0 -316
- trajectree/quimb/quimb/experimental/belief_propagation/d2bp.py +0 -653
- trajectree/quimb/quimb/experimental/belief_propagation/hd1bp.py +0 -571
- trajectree/quimb/quimb/experimental/belief_propagation/hv1bp.py +0 -775
- trajectree/quimb/quimb/experimental/belief_propagation/l1bp.py +0 -316
- trajectree/quimb/quimb/experimental/belief_propagation/l2bp.py +0 -537
- trajectree/quimb/quimb/experimental/belief_propagation/regions.py +0 -194
- trajectree/quimb/quimb/experimental/cluster_update.py +0 -286
- trajectree/quimb/quimb/experimental/merabuilder.py +0 -865
- trajectree/quimb/quimb/experimental/operatorbuilder/__init__.py +0 -15
- trajectree/quimb/quimb/experimental/operatorbuilder/operatorbuilder.py +0 -1631
- trajectree/quimb/quimb/experimental/schematic.py +0 -7
- trajectree/quimb/quimb/experimental/tn_marginals.py +0 -130
- trajectree/quimb/quimb/experimental/tnvmc.py +0 -1483
- trajectree/quimb/quimb/gates.py +0 -36
- trajectree/quimb/quimb/gen/__init__.py +0 -2
- trajectree/quimb/quimb/gen/operators.py +0 -1167
- trajectree/quimb/quimb/gen/rand.py +0 -713
- trajectree/quimb/quimb/gen/states.py +0 -479
- trajectree/quimb/quimb/linalg/__init__.py +0 -6
- trajectree/quimb/quimb/linalg/approx_spectral.py +0 -1109
- trajectree/quimb/quimb/linalg/autoblock.py +0 -258
- trajectree/quimb/quimb/linalg/base_linalg.py +0 -719
- trajectree/quimb/quimb/linalg/mpi_launcher.py +0 -397
- trajectree/quimb/quimb/linalg/numpy_linalg.py +0 -244
- trajectree/quimb/quimb/linalg/rand_linalg.py +0 -514
- trajectree/quimb/quimb/linalg/scipy_linalg.py +0 -293
- trajectree/quimb/quimb/linalg/slepc_linalg.py +0 -892
- trajectree/quimb/quimb/schematic.py +0 -1518
- trajectree/quimb/quimb/tensor/__init__.py +0 -401
- trajectree/quimb/quimb/tensor/array_ops.py +0 -610
- trajectree/quimb/quimb/tensor/circuit.py +0 -4824
- trajectree/quimb/quimb/tensor/circuit_gen.py +0 -411
- trajectree/quimb/quimb/tensor/contraction.py +0 -336
- trajectree/quimb/quimb/tensor/decomp.py +0 -1255
- trajectree/quimb/quimb/tensor/drawing.py +0 -1646
- trajectree/quimb/quimb/tensor/fitting.py +0 -385
- trajectree/quimb/quimb/tensor/geometry.py +0 -583
- trajectree/quimb/quimb/tensor/interface.py +0 -114
- trajectree/quimb/quimb/tensor/networking.py +0 -1058
- trajectree/quimb/quimb/tensor/optimize.py +0 -1818
- trajectree/quimb/quimb/tensor/tensor_1d.py +0 -4778
- trajectree/quimb/quimb/tensor/tensor_1d_compress.py +0 -1854
- trajectree/quimb/quimb/tensor/tensor_1d_tebd.py +0 -662
- trajectree/quimb/quimb/tensor/tensor_2d.py +0 -5954
- trajectree/quimb/quimb/tensor/tensor_2d_compress.py +0 -96
- trajectree/quimb/quimb/tensor/tensor_2d_tebd.py +0 -1230
- trajectree/quimb/quimb/tensor/tensor_3d.py +0 -2869
- trajectree/quimb/quimb/tensor/tensor_3d_tebd.py +0 -46
- trajectree/quimb/quimb/tensor/tensor_approx_spectral.py +0 -60
- trajectree/quimb/quimb/tensor/tensor_arbgeom.py +0 -3237
- trajectree/quimb/quimb/tensor/tensor_arbgeom_compress.py +0 -565
- trajectree/quimb/quimb/tensor/tensor_arbgeom_tebd.py +0 -1138
- trajectree/quimb/quimb/tensor/tensor_builder.py +0 -5411
- trajectree/quimb/quimb/tensor/tensor_core.py +0 -11179
- trajectree/quimb/quimb/tensor/tensor_dmrg.py +0 -1472
- trajectree/quimb/quimb/tensor/tensor_mera.py +0 -204
- trajectree/quimb/quimb/utils.py +0 -892
- trajectree/quimb/tests/__init__.py +0 -0
- trajectree/quimb/tests/test_accel.py +0 -501
- trajectree/quimb/tests/test_calc.py +0 -788
- trajectree/quimb/tests/test_core.py +0 -847
- trajectree/quimb/tests/test_evo.py +0 -565
- trajectree/quimb/tests/test_gen/__init__.py +0 -0
- trajectree/quimb/tests/test_gen/test_operators.py +0 -361
- trajectree/quimb/tests/test_gen/test_rand.py +0 -296
- trajectree/quimb/tests/test_gen/test_states.py +0 -261
- trajectree/quimb/tests/test_linalg/__init__.py +0 -0
- trajectree/quimb/tests/test_linalg/test_approx_spectral.py +0 -368
- trajectree/quimb/tests/test_linalg/test_base_linalg.py +0 -351
- trajectree/quimb/tests/test_linalg/test_mpi_linalg.py +0 -127
- trajectree/quimb/tests/test_linalg/test_numpy_linalg.py +0 -84
- trajectree/quimb/tests/test_linalg/test_rand_linalg.py +0 -134
- trajectree/quimb/tests/test_linalg/test_slepc_linalg.py +0 -283
- trajectree/quimb/tests/test_tensor/__init__.py +0 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/__init__.py +0 -0
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d1bp.py +0 -39
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d2bp.py +0 -67
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hd1bp.py +0 -64
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hv1bp.py +0 -51
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l1bp.py +0 -142
- trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l2bp.py +0 -101
- trajectree/quimb/tests/test_tensor/test_circuit.py +0 -816
- trajectree/quimb/tests/test_tensor/test_contract.py +0 -67
- trajectree/quimb/tests/test_tensor/test_decomp.py +0 -40
- trajectree/quimb/tests/test_tensor/test_mera.py +0 -52
- trajectree/quimb/tests/test_tensor/test_optimizers.py +0 -488
- trajectree/quimb/tests/test_tensor/test_tensor_1d.py +0 -1171
- trajectree/quimb/tests/test_tensor/test_tensor_2d.py +0 -606
- trajectree/quimb/tests/test_tensor/test_tensor_2d_tebd.py +0 -144
- trajectree/quimb/tests/test_tensor/test_tensor_3d.py +0 -123
- trajectree/quimb/tests/test_tensor/test_tensor_arbgeom.py +0 -226
- trajectree/quimb/tests/test_tensor/test_tensor_builder.py +0 -441
- trajectree/quimb/tests/test_tensor/test_tensor_core.py +0 -2066
- trajectree/quimb/tests/test_tensor/test_tensor_dmrg.py +0 -388
- trajectree/quimb/tests/test_tensor/test_tensor_spectral_approx.py +0 -63
- trajectree/quimb/tests/test_tensor/test_tensor_tebd.py +0 -270
- trajectree/quimb/tests/test_utils.py +0 -85
- trajectree-0.0.1.dist-info/RECORD +0 -126
- {trajectree-0.0.1.dist-info → trajectree-0.0.2.dist-info}/WHEEL +0 -0
- {trajectree-0.0.1.dist-info → trajectree-0.0.2.dist-info}/licenses/LICENSE +0 -0
- {trajectree-0.0.1.dist-info → trajectree-0.0.2.dist-info}/top_level.txt +0 -0
|
@@ -1,1646 +0,0 @@
|
|
|
1
|
-
"""Functionality for drawing tensor networks."""

import collections
import importlib
import importlib.util  # ``import importlib`` alone does not bind ``importlib.util``
import textwrap
import warnings

import numpy as np

from ..utils import autocorrect_kwargs, check_opt, valmap

# from ..schematic import average_color, darken_color, auto_colors, hash_to_color

# Whether the optional ``fa2`` package can be imported. ``find_spec`` only
# locates the package, it does not import it at module load time.
HAS_FA2 = importlib.util.find_spec("fa2") is not None
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
@autocorrect_kwargs
def draw_tn(
    tn,
    color=None,
    *,
    show_inds=None,
    show_tags=None,
    output_inds=None,
    highlight_inds=(),
    highlight_tids=(),
    highlight_inds_color=(1.0, 0.2, 0.2),
    highlight_tids_color=(1.0, 0.2, 0.2),
    custom_colors=None,
    legend="auto",
    dim=2,
    fix=None,
    layout="auto",
    initial_layout="auto",
    refine_layout="auto",
    iterations="auto",
    k=None,
    pos=None,
    node_color=None,
    node_scale=1.0,
    node_size=None,
    node_alpha=1.0,
    node_shape="o",
    node_outline_size=None,
    node_outline_darkness=0.9,
    node_hatch="",
    edge_color=None,
    edge_scale=1.0,
    edge_alpha=1 / 2,
    multi_edge_spread=0.1,
    multi_tag_style="auto",
    show_left_inds=True,
    arrow_opts=None,
    label_color=None,
    font_size=10,
    font_size_inner=7,
    font_family="monospace",
    isdark=None,
    title=None,
    backend="matplotlib",
    figsize=(6, 6),
    margin=None,
    xlims=None,
    ylims=None,
    get=None,
    return_fig=False,
    ax=None,
):
    """Plot this tensor network as a networkx graph using matplotlib,
    with edge width corresponding to bond dimension.

    Parameters
    ----------
    color : sequence of tags, optional
        If given, uniquely color any tensors which have each of the tags.
        If a tensor has more than one of the tags, its basic color is that
        of the last matching tag, and all matching colors are passed on to
        the backend (see ``multi_tag_style``).
    output_inds : sequence of str, optional
        For hyper tensor networks explicitly specify which indices should be
        drawn as outer indices. If not set, the outer indices are assumed to be
        those that only appear on a single tensor.
    highlight_inds : iterable, optional
        Highlight these edges.
    highlight_tids : iterable, optional
        Highlight these nodes.
    highlight_inds_color : tuple[float], optional
        What color to use for ``highlight_inds`` edges.
    highlight_tids_color : tuple[float], optional
        What color to use for ``highlight_tids`` nodes.
    show_inds : {None, False, True, 'outer', 'all', 'bond-size'}, optional
        Explicitly turn on labels for each tensor's indices. ``True`` is the
        same as ``'outer'``. If ``None``, outer indices are labelled only if
        there are at most 20 of them.
    show_tags : {None, False, True, 'tags', 'tids', 'shape'}, optional
        Explicitly turn on labels for each tensor. ``True`` is the same as
        ``'tags'``. If ``None``, tags are shown only if the network has at
        most 20 distinct tags.
    custom_colors : sequence of colors, optional
        Supply a custom sequence of colors to match the tags given
        in ``color``.
    title : str, optional
        Set a title for the axis.
    legend : "auto" or bool, optional
        Whether to draw a legend for the colored tags. If ``"auto"`` then
        only draw a legend if there are at most 20 colored tags.
    dim : {2, 2.5, 3}, optional
        What dimension to position the graph nodes in. 2.5 positions the nodes
        in 3D but then projects them down to 2D.
    fix : dict[tags_ind_or_tid, (float, float)], optional
        Used to specify actual relative positions for each tensor node.
        Each key should be a sequence of tags that uniquely identifies a
        tensor, a ``tid``, or a ``ind``, and each value should be a ``(x, y)``
        coordinate tuple.
    layout : str, optional
        How to layout the graph. Can be any of the following:

            - ``'auto'``: layout the graph using a networkx method then relax
              the layout using a force-directed algorithm.
            - a networkx layout method name, e.g. ``'kamada_kawai'``: just
              layout the graph using a networkx method, with no relaxation.
            - a graphviz method such as ``'dot'``, ``'neato'`` or ``'sfdp'``:
              layout the graph using ``pygraphviz``.

    initial_layout : {'auto', 'spectral', 'kamada_kawai', 'circular', \\
                      'planar', 'random', 'shell', 'bipartite', ...}, optional
        If ``layout == 'auto'`` The name of a networkx layout to use before
        iterating with the spring layout. Set `layout` directly or
        ``iterations=0`` if you don't want any spring relaxation.
    refine_layout : str, optional
        Forwarded, together with ``layout`` and ``initial_layout``, to the
        positioning routine ``get_positions``.
    iterations : int, optional
        How many iterations to perform when finding the best layout
        using node repulsion. Ramp this up if the graph is drawing messily.
    k : float, optional
        The optimal distance between nodes.
    pos : dict, optional
        Pre-computed positions for the nodes. If given, this will override
        ``layout``. The nodes should be exactly the same as the nodes in the
        graph returned by ``draw(get='graph')``.
    node_color : tuple[float], optional
        Default color of nodes.
    node_scale : float, optional
        Scale the node sizes by this factor, in addition to the automatic
        scaling based on the number of tensors.
    node_size : None, float or dict, optional
        How big to draw the tensors. Can be a global single value, or a dict
        containing values for specific tags or tids. This is in absolute
        figure units. See ``node_scale`` to simply scale the node sizes up or
        down.
    node_alpha : float, optional
        Transparency of the nodes.
    node_shape : None, str or dict, optional
        What shape to draw the tensors. Should correspond to a matplotlib
        scatter marker. Can be a global single value, or a dict containing
        values for specific tags or tids.
    node_outline_size : None, float or dict, optional
        The width of the border of each node. Can be a global single value, or
        a dict containing values for specific tags or tids.
    node_outline_darkness : float, optional
        Darkening of nodes outlines.
    node_hatch : str or dict, optional
        Hatch pattern for each node. Can be a global single value, or a dict
        containing values for specific tags or tids.
    edge_color : tuple[float] or True, optional
        Default color of edges. If ``True``, each edge gets a color derived
        by hashing its index name.
    edge_scale : float, optional
        How much to scale the width of the edges.
    edge_alpha : float, optional
        Set the alpha (opacity) of the drawn edges.
    multi_edge_spread : float, optional
        How much to spread the lines of multi-edges.
    multi_tag_style : {'auto', 'pie', 'nest', 'average', 'last'}, optional
        How to draw tensors matching more than one tag in ``color``. This
        option is forwarded to the drawing backend.
    show_left_inds : bool, optional
        Whether to show ``tensor.left_inds`` as incoming arrows.
    arrow_opts : dict, optional
        Extra options for the arrows drawn when ``show_left_inds=True``,
        forwarded unchanged to the drawing backend.
    label_color : tuple[float] or "inherit", optional
        Color to draw labels with. ``"inherit"`` uses matplotlib's
        ``axes.labelcolor`` rc setting.
    font_size : int, optional
        Font size for drawing tensor labels (tags, tids or shapes).
    font_size_inner : int, optional
        Font size for drawing index labels, both on edges and on the dummy
        nodes representing outer and hyper indices.
    font_family : str, optional
        Font family to use for all labels.
    isdark : bool, optional
        Explicitly specify that the background is dark, and use slightly
        different default drawing colors. If not specified detects
        automatically from `matplotlib.rcParams`.
    backend : {'matplotlib', 'matplotlib3d', 'plotly'}, optional
        Which drawing backend to use.
    figsize : tuple of int, optional
        The size of the drawing.
    margin : None or float, optional
        Specify an argument for ``ax.margin``, else the plot limits will try
        and be computed based on the node positions and node sizes.
    xlims : None or tuple, optional
        Explicitly set the x plot range.
    ylims : None or tuple, optional
        Explicitly set the y plot range.
    get : {None, 'pos', 'graph', 'graph,pos', 'data'}, optional
        If ``None`` then plot as normal, else if:

            - ``'pos'``, return the plotting positions of each ``tid`` and
              ``ind`` drawn as a node, this can be supplied to subsequent
              calls as ``fix=pos`` to maintain positions, even as the graph
              structure changes.
            - ``'graph'``, return the ``networkx.Graph`` object before any
              layout is computed. Note that this will potentially have extra
              nodes representing output and hyper indices.
            - ``'graph,pos'``, return both the graph (with scaled node sizes
              and coordinates attached) and the positions.
            - ``'data'``, return the ``(edges, nodes, opts)`` that would
              otherwise be passed to the drawing backend.

    return_fig : bool, optional
        If True and ``ax is None`` then return the figure created rather than
        executing ``pyplot.show()``.
    ax : matplotlib.Axis, optional
        Draw the graph on this axis rather than creating a new figure.

    Returns
    -------
    object
        The object requested via ``get`` if given, otherwise whatever the
        selected backend drawing function returns.

    Raises
    ------
    ValueError
        If ``backend`` is not recognized and nothing was returned early via
        ``get``. ``multi_tag_style`` is also validated with ``check_opt``.
    """
    import math

    import matplotlib as mpl
    import networkx as nx
    from matplotlib.colors import to_rgb, to_rgba

    from ..schematic import darken_color, hash_to_color

    check_opt(
        "multi_tag_style",
        multi_tag_style,
        ("auto", "pie", "nest", "average", "last"),
    )

    if output_inds is None:
        output_inds = set(tn.outer_inds())
    elif isinstance(output_inds, str):
        output_inds = {output_inds}
    else:
        output_inds = set(output_inds)

    # automatically decide whether to show tags and inds
    if show_inds is None:
        show_inds = len(tn.outer_inds()) <= 20
    show_inds = {False: "", True: "outer"}.get(show_inds, show_inds)

    if show_tags is None:
        show_tags = len(tn.tag_map) <= 20
    show_tags = {False: "", True: "tags"}.get(show_tags, show_tags)

    if isdark is None:
        isdark = sum(to_rgb(mpl.rcParams["figure.facecolor"])) / 3 < 0.5

    if isdark:
        default_draw_color = (0.55, 0.57, 0.60, 1.0)
        default_label_color = (0.85, 0.86, 0.87, 1.0)
    else:
        default_draw_color = (0.45, 0.47, 0.50, 1.0)
        default_label_color = (0.33, 0.34, 0.35, 1.0)

    if edge_color is None:
        edge_color = to_rgba(default_draw_color, edge_alpha)
    elif edge_color is True:
        # keep the sentinel: each edge is colored by hashing its index below
        pass
    else:
        edge_color = to_rgba(edge_color, edge_alpha)

    if node_color is None:
        node_color = to_rgba(default_draw_color, node_alpha)
    else:
        node_color = to_rgba(node_color, node_alpha)

    if label_color is None:
        label_color = default_label_color
    elif isinstance(label_color, str) and label_color == "inherit":
        # guard with isinstance so array-like colors never hit an
        # (ambiguous) elementwise comparison
        label_color = mpl.rcParams["axes.labelcolor"]

    # get colors for tagged nodes
    colors = get_colors(color, custom_colors, node_alpha)

    if legend == "auto":
        legend = len(colors) <= 20

    highlight_tids_color = to_rgba(highlight_tids_color, node_alpha)
    highlight_inds_color = to_rgba(highlight_inds_color, edge_alpha)

    # set the size of the nodes and their border
    node_size = parse_dict_to_tids_or_inds(node_size, tn, default=1)
    node_outline_size = parse_dict_to_tids_or_inds(
        node_outline_size, tn, default=1
    )
    node_shape = parse_dict_to_tids_or_inds(node_shape, tn, default="o")
    node_hatch = parse_dict_to_tids_or_inds(node_hatch, tn, default="")

    # build the graph
    edges = collections.defaultdict(lambda: collections.defaultdict(list))
    nodes = collections.defaultdict(dict)

    # parse all indices / edges
    for ix, tids in tn.ind_map.items():
        tids = sorted(tids)

        isouter = ix in output_inds
        ishyper = isouter or (len(tids) != 2)
        ind_size = tn.ind_size(ix)
        edge_size = edge_scale * math.log2(ind_size)

        # compute a color for this index (named so as not to shadow the
        # ``color`` parameter)
        if ix in highlight_inds:
            ix_color = highlight_inds_color
        elif edge_color is True:
            ix_color = to_rgba(hash_to_color(ix))
        else:
            ix_color = edge_color

        # compute a label for this index
        if ishyper:
            # each tensor connects to the dummy node representing the hyper
            # edge
            pairs = [(tid, ix) for tid in tids]
            if isouter and len(tids) > 1:
                # 'hyper outer' index
                pairs.append((("outer", ix), ix))
            # hyper labels get put on dummy node
            label = ""

            ix_node = nodes[ix]
            ix_node["ind"] = ix
            ix_node["ind_size"] = ind_size
            # make actual node invisible
            ix_node["color"] = (1.0, 1.0, 1.0, 1.0)
            ix_node["size"] = 0.0
            ix_node["outline_size"] = 0.0
            ix_node["outline_color"] = (1.0, 1.0, 1.0, 1.0)
            ix_node["marker"] = "."  # set this to avoid warning - size is 0
            ix_node["hatch"] = ""

            # set these for plotly hover info
            ix_node["tid"] = ix_node["shape"] = ix_node["tags"] = ""

            if ((show_inds == "outer") and isouter) or (show_inds == "all"):
                # show as outer index or inner index name
                ix_node["label"] = ix
            elif show_inds == "bond-size":
                # show all bond sizes
                ix_node["label"] = f"{ind_size}"
            else:
                # labels hidden or inner edge
                ix_node["label"] = ""

            ix_node["label_fontsize"] = font_size_inner
            ix_node["label_color"] = label_color
            ix_node["label_fontfamily"] = font_family

        else:
            # standard edge
            pairs = [tuple(tids)]

            if show_inds == "all":
                # show inner index name
                label = ix
            elif show_inds == "bond-size":
                # show all bond sizes
                label = f"{ind_size}"
            else:
                # labels hidden or inner edge
                label = ""

        for pair in pairs:
            edge = edges[pair]
            edge["color"].append(ix_color)
            edge["ind"].append(ix)
            edge["ind_size"].append(ind_size)
            edge["edge_size"].append(edge_size)
            edge["label"].append(label)
            edge["label_fontsize"] = font_size_inner
            edge["label_color"] = label_color
            edge["label_fontfamily"] = font_family

            if isinstance(pair[0], tuple):
                # dummy hyper outer edge - no arrows
                edge["arrow_left"].append(False)
                edge["arrow_right"].append(False)
            else:
                # tensor side can always have an incoming arrow
                tl_left_inds = tn.tensor_map[pair[0]].left_inds
                edge["arrow_left"].append(
                    show_left_inds
                    and (tl_left_inds is not None)
                    and (ix in tl_left_inds)
                )
                if ishyper:
                    # hyper edge can't have an incoming arrow
                    edge["arrow_right"].append(False)
                else:
                    # standard edge can
                    tr_left_inds = tn.tensor_map[pair[1]].left_inds
                    edge["arrow_right"].append(
                        show_left_inds
                        and (tr_left_inds is not None)
                        and (ix in tr_left_inds)
                    )

    # parse all tensors / nodes
    for tid, t in tn.tensor_map.items():
        t_node = nodes[tid]
        t_node["tid"] = tid
        t_node["tags"] = str(list(t.tags))
        t_node["shape"] = str(t.shape)
        t_node["size"] = node_size[tid]
        t_node["outline_size"] = node_outline_size[tid]
        t_node["marker"] = node_shape[tid]
        t_node["hatch"] = node_hatch[tid]

        if show_tags == "tags":
            node_label = ", ".join(map(str, t.tags))
            # make the tags appear with auto vertical extent, the width must
            # be an int: ``textwrap`` slices with it when breaking long words
            wrap_width = int(max(2 * len(node_label) ** 0.5, 16))
            t_node["label"] = "\n".join(textwrap.wrap(node_label, wrap_width))
        elif show_tags == "tids":
            t_node["label"] = str(tid)
        elif show_tags == "shape":
            t_node["label"] = t_node["shape"]
        else:
            t_node["label"] = ""

        t_node["label_fontsize"] = font_size
        t_node["label_color"] = label_color
        t_node["label_fontfamily"] = font_family

        if tid in highlight_tids:
            t_node["color"] = highlight_tids_color
            t_node["outline_color"] = darken_color(
                highlight_tids_color, node_outline_darkness
            )
        else:
            # collect all relevant tag colors
            multi_colors = []
            multi_outline_colors = []
            for tag in colors:
                if tag in t.tags:
                    multi_colors.append(colors[tag])
                    multi_outline_colors.append(
                        darken_color(colors[tag], node_outline_darkness)
                    )

            if len(multi_colors) >= 1:
                # set the basic color to the last tag
                t_node["color"] = multi_colors[-1]
                t_node["outline_color"] = multi_outline_colors[-1]
                if len(multi_colors) >= 2:
                    # have multiple relevant tags - store them alongside the
                    # basic color, for backends that support drawing them
                    t_node["multi_colors"] = multi_colors
                    t_node["multi_outline_colors"] = multi_outline_colors
            else:
                # untagged node
                t_node["color"] = node_color
                t_node["outline_color"] = darken_color(
                    node_color, node_outline_darkness**2
                )

    G = nx.Graph()
    for edge_key, edge_data in edges.items():
        G.add_edge(*edge_key, **edge_data)
    for node, node_data in nodes.items():
        G.add_node(node, **node_data)

    if get == "graph":
        # documented option: hand back the graph before any (potentially
        # expensive) layout, so callers can compute their own ``pos``
        return G

    if pos is None:
        pos = get_positions(
            tn=tn,
            G=G,
            fix=fix,
            layout=layout,
            initial_layout=initial_layout,
            refine_layout=refine_layout,
            k=k,
            dim=dim,
            iterations=iterations,
        )
    else:
        pos = _normalize_positions(pos)

    # compute a base size using the position and number of tensors
    # first get plot volume:
    node_packing_factor = tn.num_tensors**-0.45
    xs, ys, *zs = zip(*pos.values())
    xmin, xmax = min(xs), max(xs)
    ymin, ymax = min(ys), max(ys)
    # if there only a few tensors we don't want to limit the node size
    # because of flatness, also don't allow the plot volume to go to zero
    xrange = max(((xmax - xmin) / 2, node_packing_factor, 0.1))
    yrange = max(((ymax - ymin) / 2, node_packing_factor, 0.1))
    plot_volume = xrange * yrange
    if zs:
        zmin, zmax = min(zs[0]), max(zs[0])
        zrange = max(((zmax - zmin) / 2, node_packing_factor, 0.1))
        plot_volume *= zrange
    # in total we account for:
    #     - user specified scaling
    #     - number of tensors
    #     - how flat the plot area is (flatter requires smaller nodes)
    full_node_scale = 0.2 * node_scale * node_packing_factor * plot_volume**0.5

    default_outline_size = 6 * full_node_scale**0.5

    # update node size and position attributes, in both ``nodes`` and ``G``
    # (networkx stores its own copy of each node's attribute dict)
    for node, node_data in nodes.items():
        node_data["size"] = G.nodes[node]["size"] = (
            full_node_scale * node_data["size"]
        )
        node_data["outline_size"] = G.nodes[node]["outline_size"] = (
            default_outline_size * node_data["outline_size"]
        )
        node_data["coo"] = G.nodes[node]["coo"] = pos[node]

    for i, j in edges:
        edges[i, j]["coos"] = G.edges[i, j]["coos"] = pos[i], pos[j]

    if get == "pos":
        return pos
    if get == "graph,pos":
        return G, pos

    opts = {
        "colors": colors,
        "node_outline_darkness": node_outline_darkness,
        "title": title,
        "legend": legend,
        "multi_edge_spread": multi_edge_spread,
        "multi_tag_style": multi_tag_style,
        "arrow_opts": arrow_opts,
        "label_color": label_color,
        "font_family": font_family,
        "figsize": figsize,
        "margin": margin,
        "xlims": xlims,
        "ylims": ylims,
        "return_fig": return_fig,
        "ax": ax,
    }

    if get == "data":
        return edges, nodes, opts

    if backend == "matplotlib":
        return _draw_matplotlib(edges=edges, nodes=nodes, **opts)

    if backend == "matplotlib3d":
        return _draw_matplotlib3d(G, **opts)

    if backend == "plotly":
        return _draw_plotly(G, **opts)

    # previously an unknown backend silently returned None
    raise ValueError(
        f"Unknown backend {backend!r}, expected one of "
        "'matplotlib', 'matplotlib3d' or 'plotly'."
    )
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
def parse_dict_to_tids_or_inds(spec, tn, default="__NONE__"):
    """Normalize a per-node specification so that its keys are only single
    tensor ids (``tid``) or index names.

    Parameters
    ----------
    spec : None, dict or any other value
        - ``None``: nothing specified, only ``default`` (if any) applies.
        - a non-dict value: used for every key, overriding ``default``.
        - a dict: keys may be a ``tid`` (int present in ``tn.tensor_map``),
          an index name (str present in ``tn.ind_map``), or a tag / set of
          tags, in which case every matching tensor's ``tid`` receives the
          value. Keys matching no tensor are silently ignored.
    tn : TensorNetwork
        Network used to resolve tids, indices and tags.
    default : optional
        Value returned for keys missing from the result. If left as the
        ``"__NONE__"`` placeholder, a plain ``dict`` (no default) is used.

    Returns
    -------
    dict or collections.defaultdict
    """
    # a bare value means "use this for everything"
    if spec is not None and not isinstance(spec, dict):
        return collections.defaultdict(lambda: spec)

    want_default = default != "__NONE__"
    result = collections.defaultdict(lambda: default) if want_default else {}

    if spec is None:
        return result

    for key, value in spec.items():
        is_tid = isinstance(key, int) and key in tn.tensor_map
        is_ind = isinstance(key, str) and key in tn.ind_map
        if is_tid or is_ind:
            # already in the target form
            result[key] = value
            continue

        # otherwise treat the key as tag(s), fanning out to every match
        try:
            for tid in tn._get_tids_from_tags(key):
                result[tid] = value
        except KeyError:
            # deliberately skip keys that don't match any tensor
            pass

    return result
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
def _add_legend_matplotlib(
    ax, colors, legend, node_outline_darkness, label_color, font_family
):
    """Attach a legend to ``ax`` showing one marker per colored tag.

    Nothing is drawn unless ``colors`` is non-empty and ``legend`` is truthy.

    Parameters
    ----------
    ax : matplotlib axis
        Axis to add the legend to.
    colors : dict
        Mapping of tag -> color tuple; iterated in insertion order.
    legend : bool
        Whether to draw the legend at all.
    node_outline_darkness : float
        Factor applied to every channel of a color except the 4th (alpha)
        to obtain the marker edge color.
    label_color : color
        Color of the legend text.
    font_family : str
        Font family of the legend text.
    """
    import matplotlib.pyplot as plt

    if not (colors and legend):
        return

    def _edge_color_for(rgba):
        # darken all channels but leave the 4th (alpha) channel scaled by 1.0
        return tuple(
            1.0 * c if i == 3 else node_outline_darkness * c
            for i, c in enumerate(rgba)
        )

    handles = [
        plt.Line2D(
            [0],
            [0],
            marker="o",
            color=rgba,
            markeredgecolor=_edge_color_for(rgba),
            markeredgewidth=1,
            linestyle="",
            markersize=10,
        )
        for rgba in colors.values()
    ]

    # the leading space is needed in case '_' is the first character
    # (matplotlib would otherwise hide the entry)
    labels = [f" {tag}" for tag in colors]

    leg = ax.legend(
        handles,
        labels,
        ncol=max(round(len(handles) / 20), 1),
        loc="center left",
        bbox_to_anchor=(1, 0.5),
        labelcolor=label_color,
        prop={"family": font_family},
    )

    # style the frame by hand, otherwise the face can't be made transparent
    # on its own
    frame = leg.get_frame()
    frame.set_alpha(None)
    frame.set_facecolor((0.0, 0.0, 0.0, 0.0))
    frame.set_edgecolor((0.6, 0.6, 0.6, 0.2))
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
def _draw_matplotlib(
    edges,
    nodes,
    *,
    colors=None,
    node_outline_darkness=0.9,
    title=None,
    legend=True,
    multi_edge_spread=0.1,
    multi_tag_style="auto",
    arrow_opts=None,
    label_color=None,
    font_family="monospace",
    figsize=(6, 6),
    margin=None,
    xlims=None,
    ylims=None,
    return_fig=False,
    ax=None,
):
    """Render pre-computed edge and node drawing data with matplotlib.

    Parameters
    ----------
    edges : dict
        Values are per-edge dicts with ``"coos"`` (the two end coordinates),
        the per-multi-edge sequences ``"color"``, ``"edge_size"``,
        ``"label"``, ``"arrow_left"`` and ``"arrow_right"``, plus
        ``"label_fontsize"``, ``"label_color"`` and ``"label_fontfamily"``.
    nodes : dict
        Values are per-node dicts with ``"coo"``, ``"size"``, ``"color"``,
        ``"outline_color"``, ``"outline_size"``, ``"hatch"``, ``"marker"``,
        ``"label"``, ``"label_fontsize"``, ``"label_color"`` and
        ``"label_fontfamily"``, and optionally ``"multi_colors"`` /
        ``"multi_outline_colors"`` for nodes with several colored tags.
    colors : dict, optional
        Mapping of legend label -> color used to build the legend.
    node_outline_darkness : float, optional
        Factor used to darken legend marker outlines.
    title : str, optional
        Title of the axes.
    legend : bool, optional
        Whether to draw a legend for ``colors``.
    multi_edge_spread : float, optional
        Spacing between the parallel lines drawn for a multi-edge.
    multi_tag_style : {"auto", "pie", "nest", "last", "average"}, optional
        How to draw nodes having several colored tags.
    arrow_opts : dict, optional
        Extra options for edge arrowheads. Not modified by this function.
    label_color, font_family : optional
        Styling of the legend labels.
    figsize : tuple, optional
        Figure size passed to the ``Drawing``.
    margin, xlims, ylims : optional
        Axes margins / limits, only applied when ``ax`` is not supplied.
    return_fig : bool, optional
        If a new figure was created, return it instead of showing it.
    ax : matplotlib axes, optional
        Existing axes to draw onto.

    Returns
    -------
    matplotlib.figure.Figure or None
    """
    import matplotlib.pyplot as plt

    # import from this package's schematic module, consistent with the other
    # helpers here, rather than from whichever top-level ``quimb`` is found
    from ..schematic import Drawing, average_color

    d = Drawing(figsize=figsize, ax=ax)
    if ax is None:
        fig = d.fig
        ax = d.ax
        fig.patch.set_alpha(0.0)
        ax.patch.set_alpha(0.0)
    else:
        fig = None

    # copy so that filling in defaults never mutates the caller's dict
    arrow_opts = dict(arrow_opts) if arrow_opts else {}
    arrow_opts.setdefault("center", 3 / 4)
    arrow_opts.setdefault("linewidth", 1)
    arrow_opts.setdefault("width", 0.08)
    arrow_opts.setdefault("length", 0.12)

    if title is not None:
        ax.set_title(str(title))

    # (arrow_left, arrow_right) -> (draw arrowhead?, reverse option)
    arrow_styles = {
        (False, False): (None, False),  # no arrow
        (False, True): (True, False),  # arrowhead to right
        (True, False): (True, True),  # arrowhead to left
        (True, True): (True, "both"),  # arrowheads both sides
    }

    # draw the edges
    for _, edge_data in edges.items():
        cooa, coob = edge_data["coos"]
        edge_colors = edge_data["color"]
        edge_sizes = edge_data["edge_size"]
        labels = edge_data["label"]
        arrow_lefts = edge_data["arrow_left"]
        arrow_rights = edge_data["arrow_right"]
        multiplicity = len(edge_colors)

        if multiplicity > 1:
            # spread parallel lines symmetrically about the direct line
            offsets = np.linspace(
                +multiplicity * multi_edge_spread / 2,
                -multiplicity * multi_edge_spread / 2,
                multiplicity,
            )
        else:
            offsets = None

        for m in range(multiplicity):
            line_opts = dict(
                cooa=cooa,
                coob=coob,
                linewidth=edge_sizes[m],
                color=edge_colors[m],
            )

            arrowhead, reverse = arrow_styles[arrow_lefts[m], arrow_rights[m]]
            if arrowhead:
                line_opts["arrowhead"] = dict(
                    reverse=reverse,
                    **arrow_opts,
                )

            if labels[m]:
                line_opts["text"] = dict(
                    text=labels[m],
                    fontsize=edge_data["label_fontsize"],
                    color=edge_data["label_color"],
                    fontfamily=edge_data["label_fontfamily"],
                )

            if multiplicity > 1:
                d.line_offset(offset=offsets[m], **line_opts)
            else:
                d.line(**line_opts)

    # draw the tensors
    for _, node_data in nodes.items():
        patch_opts = dict(
            coo=node_data["coo"],
            radius=node_data["size"],
            facecolor=node_data["color"],
            edgecolor=node_data["outline_color"],
            linewidth=node_data["outline_size"],
            hatch=node_data["hatch"],
        )
        marker = node_data["marker"]

        if "multi_colors" in node_data:
            # tensor has multiple tags which are colored

            if multi_tag_style in ("pie", "auto"):
                # draw a mini pie chart
                if marker not in ("o", "."):
                    warnings.warn(
                        "Can only draw multi-colored nodes as circles."
                    )

                angles = np.linspace(
                    0, 360, len(node_data["multi_colors"]) + 1
                )
                for i, (color, outline_color) in enumerate(
                    zip(
                        node_data["multi_colors"],
                        node_data["multi_outline_colors"],
                    )
                ):
                    patch_opts["facecolor"] = color
                    patch_opts["edgecolor"] = outline_color
                    d.wedge(
                        theta1=angles[i] - 67.5,
                        theta2=angles[i + 1] - 67.5,
                        **patch_opts,
                    )
            elif multi_tag_style == "nest":
                # draw nested markers of decreasing size
                radii = np.linspace(
                    node_data["size"], 0, len(node_data["multi_colors"]) + 1
                )
                for i, (color, outline_color) in enumerate(
                    zip(
                        node_data["multi_colors"],
                        node_data["multi_outline_colors"],
                    )
                ):
                    patch_opts["facecolor"] = color
                    patch_opts["edgecolor"] = outline_color
                    d.marker(
                        marker=marker,
                        **{**patch_opts, "radius": radii[i], "linewidth": 0},
                    )
            elif multi_tag_style == "last":
                # draw a single marker with last tag
                patch_opts["facecolor"] = node_data["multi_colors"][-1]
                patch_opts["edgecolor"] = node_data["multi_outline_colors"][-1]
                d.marker(marker=marker, **patch_opts)
            else:  # multi_tag_style == "average":
                # draw a single marker with average color
                patch_opts["facecolor"] = average_color(
                    node_data["multi_colors"]
                )
                patch_opts["edgecolor"] = average_color(
                    node_data["multi_outline_colors"]
                )
                d.marker(marker=marker, **patch_opts)

        else:
            d.marker(marker=marker, **patch_opts)

        if node_data["label"]:
            d.text(
                node_data["coo"],
                node_data["label"],
                fontsize=node_data["label_fontsize"],
                color=node_data["label_color"],
                fontfamily=node_data["label_fontfamily"],
            )

    _add_legend_matplotlib(
        ax, colors, legend, node_outline_darkness, label_color, font_family
    )

    if fig is None:
        # ax was supplied, don't modify and simply return
        return None

    # axes and figure were created here
    if xlims is not None:
        ax.set_xlim(xlims)
    if ylims is not None:
        ax.set_ylim(ylims)
    if margin is not None:
        ax.margins(margin)

    if return_fig:
        return fig

    plt.show()
    plt.close(fig)
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
def _linearize_graph_data(G, multi_tag_style="auto"):
    """Flatten the edge and node attributes of graph ``G`` into column lists.

    Every edge becomes one record. Multi-edges are merged: colors are
    averaged, sizes summed, index names and labels joined with spaces, and
    index sizes multiplied. Nodes with an ``"ind"`` attribute are skipped.
    A node with several tag colors becomes a single marker record for
    ``multi_tag_style`` "average" or "last". For any other style it becomes
    several nested marker records of decreasing size, with zero outline.

    Returns
    -------
    edge_source, node_source : dict[str, list]
        Column-oriented edge and node data. ``"z0"``/``"z1"``/``"z"`` are
        only present when coordinates are 3D.
    """
    from ..schematic import average_color

    edge_cols = collections.defaultdict(list)
    for _, _, data in G.edges(data=True):
        (x0, y0, *zs0), (x1, y1, *zs1) = data["coos"]
        for key, val in (("x0", x0), ("y0", y0), ("x1", x1), ("y1", y1)):
            edge_cols[key].append(val)
        if zs0:
            edge_cols["z0"].extend(zs0)
            edge_cols["z1"].extend(zs1)

        # all multi-edges between a pair of nodes are aggregated into one
        edge_cols["color"].append(average_color(data["color"]))
        edge_cols["edge_size"].append(sum(data["edge_size"]))
        edge_cols["ind"].append(" ".join(data["ind"]))
        edge_cols["ind_size"].append(np.prod(data["ind_size"]))
        edge_cols["label"].append(" ".join(data["label"]))

    node_cols = collections.defaultdict(list)
    for _, data in G.nodes(data=True):
        if "ind" in data:
            # index (open edge) nodes are not drawn as markers
            continue

        x, y, *zs = data["coo"]

        if "multi_colors" not in data:
            # a single plain marker
            fills, outlines, sizes, outline_size = (
                [data["color"]],
                [data["outline_color"]],
                [data["size"]],
                data["outline_size"],
            )
        elif multi_tag_style == "average":
            # one marker in the average of the tag colors
            fills, outlines, sizes, outline_size = (
                [average_color(data["multi_colors"])],
                [average_color(data["multi_outline_colors"])],
                [data["size"]],
                data["outline_size"],
            )
        elif multi_tag_style == "last":
            # one marker in the color of the last tag
            fills, outlines, sizes, outline_size = (
                [data["multi_colors"][-1]],
                [data["multi_outline_colors"][-1]],
                [data["size"]],
                data["outline_size"],
            )
        else:
            # nested markers of decreasing size ("auto", "nest", ...)
            fills = data["multi_colors"]
            outlines = data["multi_outline_colors"]
            sizes = np.linspace(data["size"], 0, len(fills) + 1)
            outline_size = 0.0

        for fill, outline, size in zip(fills, outlines, sizes):
            node_cols["x"].append(x)
            node_cols["y"].append(y)
            if zs:
                node_cols["z"].extend(zs)

            node_cols["color"].append(fill)
            node_cols["outline_color"].append(outline)
            node_cols["size"].append(size)
            node_cols["outline_size"].append(outline_size)

            for key in ("hatch", "tags", "shape", "tid", "label"):
                node_cols[key].append(data.get(key, None))

    return dict(edge_cols), dict(node_cols)
|
|
905
|
-
|
|
906
|
-
|
|
907
|
-
def _draw_matplotlib3d(G, **kwargs):
    """Draw graph ``G`` in 3D using matplotlib.

    ``kwargs`` is the drawing options dict. The keys read here are
    ``"multi_tag_style"``, ``"figsize"``, ``"colors"``, ``"legend"``,
    ``"node_outline_darkness"``, ``"label_color"``, ``"font_family"`` and
    ``"return_fig"``, plus optionally ``"ax"``, an existing 3D axes to draw
    onto. All node and edge coordinates are expected to be 3D.

    Returns
    -------
    matplotlib.figure.Figure or None
        The new figure if one was created and ``return_fig`` is true.
        Otherwise ``None``: a created figure is shown and closed, while a
        supplied ``ax`` is left to the caller.
    """
    import matplotlib.pyplot as plt

    edge_source, node_source = _linearize_graph_data(
        G, multi_tag_style=kwargs["multi_tag_style"]
    )

    ax = kwargs.pop("ax", None)
    if ax is None:
        fig = plt.figure(figsize=kwargs["figsize"])
        fig.patch.set_alpha(0.0)
        ax = plt.axes([0, 0, 1, 1], projection="3d")
        ax.patch.set_alpha(0.0)
    else:
        # drawing onto a supplied axes: there is no figure of our own to
        # return, show or close (``fig`` was previously left unbound here)
        fig = None

    # use one common range for all three axes so the aspect ratio is equal
    xmin = min(node_source["x"])
    xmax = max(node_source["x"])
    ymin = min(node_source["y"])
    ymax = max(node_source["y"])
    zmin = min(node_source["z"])
    zmax = max(node_source["z"])
    xyzmin = min((xmin, ymin, zmin))
    xyzmax = max((xmax, ymax, zmax))

    ax.set_xlim(xyzmin, xyzmax)
    ax.set_ylim(xyzmin, xyzmax)
    ax.set_zlim(xyzmin, xyzmax)
    ax.set_aspect("equal")
    ax.axis("off")

    # draw the edges (a graph without edges has no "x0" column at all)
    # TODO: multiedges and left_inds
    for i in range(len(edge_source.get("x0", ()))):
        x0, x1 = edge_source["x0"][i], edge_source["x1"][i]
        xm = (x0 + x1) / 2
        y0, y1 = edge_source["y0"][i], edge_source["y1"][i]
        ym = (y0 + y1) / 2
        z0, z1 = edge_source["z0"][i], edge_source["z1"][i]
        zm = (z0 + z1) / 2
        ax.plot3D(
            [x0, x1],
            [y0, y1],
            [z0, z1],
            c=edge_source["color"][i],
            linewidth=edge_source["edge_size"][i],
        )
        label = edge_source["label"][i]
        if label:
            # place edge labels at the midpoint of the edge
            ax.text(
                xm,
                ym,
                zm,
                s=label,
                ha="center",
                va="center",
                color=edge_source["color"][i],
                fontsize=6,
            )

    node_source["color"] = [rgba[:3] for rgba in node_source["color"]]
    node_source["size"] = [100000 * s**2 for s in node_source["size"]]
    # NOTE(review): this (misspelled) key is never read - the scatter call
    # below uses the unscaled "outline_size". Left as-is pending a decision
    # on which line width is intended.
    node_source["linewdith"] = [lw / 50 for lw in node_source["outline_size"]]

    # draw the nodes
    ax.scatter3D(
        xs="x",
        ys="y",
        zs="z",
        c="color",
        s="size",
        alpha=1.0,
        marker="o",
        data=node_source,
        depthshade=False,
        edgecolor=node_source["outline_color"],
        linewidth=node_source["outline_size"],
    )

    for _, node_data in G.nodes(data=True):
        label = node_data["label"]
        if label:
            ax.text(
                *node_data["coo"],
                s=label,
                ha="center",
                va="center",
                color=node_data["label_color"],
                fontsize=node_data["label_fontsize"],
                fontfamily=node_data["label_fontfamily"],
            )

    _add_legend_matplotlib(
        ax,
        kwargs["colors"],
        kwargs["legend"],
        kwargs["node_outline_darkness"],
        kwargs["label_color"],
        kwargs["font_family"],
    )

    if fig is None:
        # ax was supplied, don't show or close anything
        return None

    if kwargs["return_fig"]:
        return fig

    plt.show()
    plt.close(fig)
|
|
1011
|
-
|
|
1012
|
-
|
|
1013
|
-
def _draw_plotly(G, **kwargs):
    """Draw graph ``G`` interactively using plotly.

    ``kwargs`` is the drawing options dict. ``"multi_tag_style"`` and
    ``"figsize"`` are read, and optionally ``"return_fig"``. Coordinates
    may be 2D or 3D; ``Scatter3d`` traces are used when z-coordinates are
    present. Hovering shows the index names and size of edges, and the
    tid, shape and tags of nodes.

    Returns
    -------
    plotly.graph_objects.Figure or None
        The figure if ``return_fig`` is true (consistent with the
        matplotlib backends), otherwise it is shown and ``None`` returned.
    """
    import plotly.graph_objects as go

    edge_source, node_source = _linearize_graph_data(
        G, multi_tag_style=kwargs["multi_tag_style"]
    )

    fig = go.Figure()
    fig.update_xaxes(visible=False)
    fig.update_yaxes(visible=False)
    fig.update_layout(
        width=100 * kwargs["figsize"][0],
        height=100 * kwargs["figsize"][1],
        margin=dict(l=10, r=10, b=10, t=10),
        plot_bgcolor="rgba(0,0,0,0)",
        paper_bgcolor="rgba(0,0,0,0)",
        showlegend=False,
        scene=dict(
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(visible=False),
        ),
    )

    # one trace per edge (a graph without edges has no "x0" column at all)
    for i in range(len(edge_source.get("x0", ()))):
        x0, x1 = edge_source["x0"][i], edge_source["x1"][i]
        y0, y1 = edge_source["y0"][i], edge_source["y1"][i]
        xm, ym = (x0 + x1) / 2, (y0 + y1) / 2
        *rgb, alpha = edge_source["color"][i]
        edge_kwargs = dict(
            # include the midpoint so hovering works along the whole edge
            x=[x0, xm, x1],
            y=[y0, ym, y1],
            opacity=alpha,
            line=dict(
                color=to_rgba_str(rgb, 1.0),
                width=edge_source["edge_size"][i],
            ),
            customdata=[[edge_source["ind"][i], edge_source["ind_size"][i]]]
            * 2,
            # show ind and ind_size on hover:
            hovertemplate="%{customdata[0]}<br>size: %{customdata[1]}",
            mode="lines",
            name="",
        )
        if "z0" in edge_source:
            z0, z1 = edge_source["z0"][i], edge_source["z1"][i]
            zm = (z0 + z1) / 2
            edge_kwargs["z"] = [z0, zm, z1]
            # edges appear much thinner in 3D
            edge_kwargs["line"]["width"] *= 2
            fig.add_trace(go.Scatter3d(**edge_kwargs))
        else:
            fig.add_trace(go.Scatter(**edge_kwargs))

    # a single trace containing all node markers
    node_kwargs = dict(
        x=node_source["x"],
        y=node_source["y"],
        marker=dict(
            opacity=1.0,
            color=list(map(to_rgba_str, node_source["color"])),
            size=[300 * s for s in node_source["size"]],
            line=dict(
                color=list(map(to_rgba_str, node_source["outline_color"])),
                width=2,
            ),
        ),
        customdata=list(
            zip(node_source["tid"], node_source["shape"], node_source["tags"])
        ),
        hovertemplate=(
            "tid: %{customdata[0]}<br>"
            "shape: %{customdata[1]}<br>"
            "tags: %{customdata[2]}"
        ),
        mode="markers",
        name="",
    )
    if "z" in node_source:
        node_kwargs["z"] = node_source["z"]
        fig.add_trace(go.Scatter3d(**node_kwargs))
    else:
        fig.add_trace(go.Scatter(**node_kwargs))

    if kwargs.get("return_fig", False):
        return fig

    fig.show()
|
|
1096
|
-
|
|
1097
|
-
|
|
1098
|
-
# ---------------------------- layout functions ----------------------------- #
|
|
1099
|
-
|
|
1100
|
-
|
|
1101
|
-
def _normalize_positions(pos):
|
|
1102
|
-
# normalize to unit square
|
|
1103
|
-
xmin = ymin = zmin = float("inf")
|
|
1104
|
-
xmax = ymax = zmax = float("-inf")
|
|
1105
|
-
for x, y, *maybe_z in pos.values():
|
|
1106
|
-
xmin = min(xmin, x)
|
|
1107
|
-
xmax = max(xmax, x)
|
|
1108
|
-
ymin = min(ymin, y)
|
|
1109
|
-
ymax = max(ymax, y)
|
|
1110
|
-
for z in maybe_z:
|
|
1111
|
-
zmin = min(zmin, z)
|
|
1112
|
-
zmax = max(zmax, z)
|
|
1113
|
-
|
|
1114
|
-
# maintain aspect ratio:
|
|
1115
|
-
# center each dimension separately
|
|
1116
|
-
xmid, ymid, zmid = (xmin + xmax) / 2, (ymin + ymax) / 2, (zmin + zmax) / 2
|
|
1117
|
-
# but scale all dimensions by the largest range
|
|
1118
|
-
xdiameter, ydiameter, zdiameter = xmax - xmin, ymax - ymin, zmax - zmin
|
|
1119
|
-
radius = max((xdiameter, ydiameter, zdiameter)) / 2
|
|
1120
|
-
|
|
1121
|
-
for node, (x, y, *maybe_z) in pos.items():
|
|
1122
|
-
pos[node] = (
|
|
1123
|
-
(x - xmid) / radius,
|
|
1124
|
-
(y - ymid) / radius,
|
|
1125
|
-
*((z - zmid) / radius for z in maybe_z),
|
|
1126
|
-
)
|
|
1127
|
-
|
|
1128
|
-
return pos
|
|
1129
|
-
|
|
1130
|
-
|
|
1131
|
-
def _rotate(xy, theta):
|
|
1132
|
-
"""Return a rotated set of points."""
|
|
1133
|
-
s = np.sin(theta)
|
|
1134
|
-
c = np.cos(theta)
|
|
1135
|
-
|
|
1136
|
-
xyr = np.empty_like(xy)
|
|
1137
|
-
xyr[:, 0] = c * xy[:, 0] - s * xy[:, 1]
|
|
1138
|
-
xyr[:, 1] = s * xy[:, 0] + c * xy[:, 1]
|
|
1139
|
-
|
|
1140
|
-
return xyr
|
|
1141
|
-
|
|
1142
|
-
|
|
1143
|
-
def _span(xy):
|
|
1144
|
-
"""Return the vertical span of the points."""
|
|
1145
|
-
return xy[:, 1].max() - xy[:, 1].min()
|
|
1146
|
-
|
|
1147
|
-
|
|
1148
|
-
def _massage_pos(pos, nangles=360, flatten=False):
    """Rotate 2D positions to the orientation with the smallest vertical span.

    ``nangles`` equally spaced angles in ``[0, 2 pi)`` are tried, and the
    first rotation achieving the minimal vertical extent is kept. If
    ``flatten`` is true, the resulting y-coordinates are additionally
    halved.

    Returns a new dict mapping each node of ``pos`` to a length-2 array.
    """
    xy = np.empty((len(pos), 2))
    for row, (x, y) in zip(xy, pos.values()):
        row[0] = x
        row[1] = y

    angles = np.linspace(0, 2 * np.pi, nangles, endpoint=False)
    best = min((_rotate(xy, angle) for angle in angles), key=_span)

    if flatten:
        best[:, 1] /= 2

    return dict(zip(pos, best))
|
|
1163
|
-
|
|
1164
|
-
|
|
1165
|
-
def phyllotaxis_points(n):
    """Return ``n`` approximately evenly spaced points on the unit sphere.

    Uses the spiral method from: J. Kogan, "A New Computationally Efficient
    Method for Spacing Points on a Sphere," Rose-Hulman Undergraduate
    Mathematics Journal, 18(2), 2017 Article 5.
    scholar.rose-hulman.edu/rhumj/vol18/iss2/5.

    Parameters
    ----------
    n : int
        Number of points.

    Returns
    -------
    list[list[float]]
        ``n`` Cartesian ``[x, y, z]`` coordinates on the unit sphere.
    """

    def spherical_coordinate(x, y):
        return [np.cos(x) * np.cos(y), np.sin(x) * np.cos(y), np.sin(y)]

    # the spiral formula divides by (n - 1), and for n == 2 it places both
    # points at the same location, so handle these small cases directly
    if n <= 0:
        return []
    if n == 1:
        return [[1.0, 0.0, 0.0]]
    if n == 2:
        return [[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]]

    x = 0.1 + 1.2 * n
    pts = []
    start = -1.0 + 1.0 / (n - 1.0)
    increment = (2.0 - 2.0 / (n - 1.0)) / (n - 1.0)
    for j in range(n):
        s = start + j * increment
        pts.append(
            spherical_coordinate(
                s * x,
                np.pi
                / 2.0
                * np.copysign(1, s)
                * (1.0 - np.sqrt(1.0 - abs(s))),
            )
        )
    return pts
|
|
1190
|
-
|
|
1191
|
-
|
|
1192
|
-
def layout_single_tensor(tn, dim=2):
    """Manually layout indices around a tensor either in a circle or sphere.

    The single tensor of ``tn`` is placed at the origin. For ``dim=2`` its
    indices are spread evenly around the unit circle. Otherwise they are
    placed on the unit sphere via ``phyllotaxis_points``. ``dim=2.5`` lays
    out in 3D and then drops the z-coordinate.

    Returns a dict mapping the tid and every index to a coordinate.
    """
    ((tid, t),) = tn.tensor_map.items()

    project_back_to_2d = dim == 2.5
    if project_back_to_2d:
        dim = 3

    pos = {tid: (0.0,) * dim}
    if dim == 2:
        # evenly around a circle
        thetas = np.linspace(0, 2 * np.pi, t.ndim, endpoint=False)
        pos.update(
            (ind, (-np.cos(theta), np.sin(theta)))
            for ind, theta in zip(t.inds, thetas)
        )
    else:
        # evenly around a sphere
        pos.update(zip(t.inds, phyllotaxis_points(t.ndim)))

    if project_back_to_2d:
        pos = {key: coo[:2] for key, coo in pos.items()}

    return pos
|
|
1217
|
-
|
|
1218
|
-
|
|
1219
|
-
def layout_networkx(
    G,
    layout="kamada_kawai",
    pos0=None,
    fixed=None,
    dim=2,
    **kwargs,
):
    """Lay out graph ``G`` using the networkx function ``<layout>_layout``.

    Parameters
    ----------
    G : networkx.Graph
        The graph to lay out.
    layout : str, optional
        Name of the networkx layout, without the ``"_layout"`` suffix.
    pos0 : dict, optional
        Initial positions. Forwarded as ``pos`` only for the "spring" and
        "kamada_kawai" layouts; otherwise ignored with a warning.
    fixed : iterable, optional
        Nodes to hold at their initial positions. Forwarded only for the
        "spring" layout; otherwise ignored with a warning.
    dim : int, optional
        Dimension of the layout.
    kwargs
        Passed on to the networkx layout function.

    Returns
    -------
    dict
        The result of the networkx layout function.
    """
    import networkx as nx

    layout_fn = getattr(nx, layout + "_layout")

    if pos0 is not None:
        if layout in ("spring", "kamada_kawai"):
            kwargs["pos"] = pos0
        else:
            warnings.warn(
                "Initial positions supplied but layout is not spring-based, "
                "so `pos0` is being ignored."
            )

    if fixed is not None:
        if layout == "spring":
            kwargs["fixed"] = fixed
        else:
            warnings.warn(
                "Fixed positions supplied but layout is not spring-based, "
                "so `fixed` is being ignored."
            )

    return layout_fn(G, dim=dim, **kwargs)
|
|
1250
|
-
|
|
1251
|
-
|
|
1252
|
-
def layout_pygraphviz(
    G,
    layout="neato",
    pos0=None,
    fixed=None,
    dim=2,
    iterations=None,
    k=None,
    **kwargs,
):
    """Lay out graph ``G`` with a graphviz program (via pygraphviz).

    Parameters
    ----------
    G : graph
        Graph-like object providing ``edges()``. Only nodes that appear in
        some edge are laid out.
    layout : str, optional
        Name of the graphviz program, e.g. "neato", "fdp", "sfdp", "dot".
    pos0 : dict, optional
        Initial positions, with nodes in ``fixed`` pinned. A warning is
        issued as these don't currently seem to take effect.
    fixed : container, optional
        Nodes of ``pos0`` to pin.
    dim : {2, 2.5, 3}, optional
        Layout dimension. 2.5 means a 3D layout with 2D output.
    iterations : int, optional
        Passed to graphviz as ``maxiter``.
    k : optional
        Not supported; ignored with a warning.
    kwargs
        Extra graph attributes passed as ``-G<key>=<value>`` arguments.

    Returns
    -------
    dict
        Node -> normalized position, rotated to a small vertical span when
        ``dim < 3``.
    """
    # TODO: max iters
    # TODO: spring parameter
    # TODO: work out why pos0 and fix don't work
    import pygraphviz as pgv

    if k is not None:
        warnings.warn(
            "`k` is being ignored as layout is being done by pygraphviz."
        )

    agraph = pgv.AGraph()

    # create nodes with initial (and possibly pinned) positions
    if pos0 is not None:
        pinned = fixed or set()
        for node, coo in pos0.items():
            agraph.add_node(
                str(node),
                pos=",".join(f"{w:f}" for w in coo),
                pin="true" if node in pinned else "false",
            )

        warnings.warn(
            "Initial and fixed positions don't seem "
            "to work currently with pygraphviz."
        )

    # create edges - graphviz uses string names, so remember the originals
    str_to_node = {}
    for node_a, node_b in G.edges():
        str_a, str_b = str(node_a), str(node_b)
        str_to_node[str_a] = node_a
        str_to_node[str_b] = node_b
        agraph.add_edge(str_a, str_b)

    # layout options
    if iterations is not None:
        kwargs["maxiter"] = iterations
    if dim == 2.5:
        # layout in 3 dimensions, but output ("dimen") in 2
        kwargs["dim"], kwargs["dimen"] = 3, 2
    else:
        kwargs["dim"] = kwargs["dimen"] = dim
    args = " ".join(f"-G{key}={val}" for key, val in kwargs.items())

    # run layout algorithm
    agraph.layout(prog=layout, args=args)

    # extract layout, parsing graphviz's comma separated "pos" attribute
    pos = {
        node: tuple(map(float, agraph.get_node(snode).attr["pos"].split(",")))
        for snode, node in str_to_node.items()
    }

    pos = _normalize_positions(pos)
    if dim < 3:
        pos = _massage_pos(pos)

    return pos
|
|
1320
|
-
|
|
1321
|
-
|
|
1322
|
-
def get_positions(
    tn,
    G,
    *,
    dim=2,
    fix=None,
    layout="auto",
    initial_layout="auto",
    refine_layout="auto",
    iterations="auto",
    k=None,
):
    """Compute a position for every node of graph ``G``.

    Parameters
    ----------
    tn : TensorNetwork or None
        Used for the single-tensor shortcut and to resolve the keys of
        ``fix``.
    G : networkx.Graph
        The graph to lay out.
    dim : {2, 2.5, 3}, optional
        Dimension of the result. 2.5 lays out in 3D and then drops the
        z-coordinate.
    fix : dict, optional
        Mapping of tags / tids / inds -> fixed ``(x, y)`` positions.
    layout : str, optional
        If not "auto", use this layout directly without refinement.
    initial_layout : str, optional
        Layout used as the starting point ("auto" picks "kamada_kawai" for
        up to 100 nodes, else "spectral"). "neato", "fdp", "sfdp" and
        "dot" use pygraphviz; anything else is a networkx layout name.
    refine_layout : {"auto", "spring", "fdp", "sfdp", "neato"}, optional
        Layout used to relax the initial positions ("auto" means "spring").
    iterations : int or "auto", optional
        Number of relaxation iterations; 0 skips refinement.
    k : float, optional
        Passed to the refining layout.

    Returns
    -------
    dict
        Node -> position, normalized to roughly ``[-1, 1]``.

    Raises
    ------
    ValueError
        If ``refine_layout`` is not recognised.
    """
    if (tn is not None) and (tn.num_tensors == 1) and (fix is None):
        # single tensor, layout manually
        return layout_single_tensor(tn, dim=dim)

    if layout != "auto":
        # don't use two step layout with relaxation
        initial_layout = layout
        iterations = 0

    if fix is None:
        fix = dict()
    else:
        fix = parse_dict_to_tids_or_inds(fix, tn)

    # ``fix`` can be empty even when given (e.g. none of its keys match the
    # network), in which case there is no range to compute and min/max
    # would fail on an empty sequence
    if fix:
        # find range with which to scale the fixed points with
        xmin, xmax, ymin, ymax = (
            f(fix.values(), key=lambda xy: xy[i])[i]
            for f, i in [(min, 0), (max, 0), (min, 1), (max, 1)]
        )
        if xmin == xmax:
            xmin, xmax = xmin - 1, xmax + 1
        if ymin == ymax:
            ymin, ymax = ymin - 1, ymax + 1
        xymin, xymax = min(xmin, ymin), max(xmax, ymax)

    if all(node in fix for node in G.nodes):
        # everything is already fixed -> simply normalize
        return _normalize_positions(fix)

    if initial_layout == "auto":
        # automatically select
        if len(G) <= 100:
            # usually nicest
            initial_layout = "kamada_kawai"
        else:
            # faster, but not as nice
            initial_layout = "spectral"

    if refine_layout == "auto":
        # automatically select
        refine_layout = "spring"

    if iterations == "auto":
        # the smaller the graph, the more iterations we can afford
        iterations = max(200, 1000 - len(G))

    if dim == 2.5:
        dim = 3
        project_back_to_2d = True
    else:
        project_back_to_2d = False

    # use spectral or other layout as starting point
    if initial_layout in ("neato", "fdp", "sfdp", "dot"):
        pos0 = layout_pygraphviz(G, initial_layout, dim=dim)
    else:
        pos0 = layout_networkx(G, initial_layout, dim=dim)

    # scale points to fit with specified positions
    if fix:
        # but update with fixed positions, mapped into [-1, 1]
        pos0.update(
            valmap(
                lambda xy: np.array(
                    (
                        2 * (xy[0] - xymin) / (xymax - xymin) - 1,
                        2 * (xy[1] - xymin) / (xymax - xymin) - 1,
                    )
                ),
                fix,
            )
        )
        fixed = fix.keys()
    else:
        fixed = None

    # and then relax remaining using spring layout
    if iterations:
        if refine_layout == "spring":
            pos = layout_networkx(
                G,
                "spring",
                pos0=pos0,
                fixed=fixed,
                k=k,
                dim=dim,
                iterations=iterations,
            )
        elif refine_layout in ("fdp", "sfdp", "neato"):
            # XXX: currently doesn't seem to work with pos0 and fixed
            pos = layout_pygraphviz(
                G,
                refine_layout,
                pos0=pos0,
                fixed=fixed,
                k=k,
                dim=dim,
                iterations=iterations,
            )
        else:
            raise ValueError(f"Unknown refining layout {refine_layout}.")
    else:
        # no relaxation
        pos = pos0

    if project_back_to_2d:
        # ignore z-coordinate
        pos = {k: v[:2] for k, v in pos.items()}
        dim = 2

    # map all to range [-1, +1], but preserving aspect ratio
    pos = _normalize_positions(pos)

    if (not fix) and (dim == 2):
        # finally rotate them to cover a small vertical span
        pos = _massage_pos(pos)

    return pos
|
|
1457
|
-
|
|
1458
|
-
|
|
1459
|
-
# ----------------------------- color functions ----------------------------- #
|
|
1460
|
-
|
|
1461
|
-
|
|
1462
|
-
def get_colors(color, custom_colors=None, alpha=None):
    """Map tag(s) ``color`` to RGBA colors.

    Parameters
    ----------
    color : None, str or sequence of str
        The tag(s) to assign colors to.
    custom_colors : sequence, optional
        Explicit matplotlib color specs, paired in order with ``color``.
    alpha : float, optional
        Alpha applied to the colors.

    Returns
    -------
    dict
        Mapping of each tag to an RGBA color; empty if ``color`` is None.
    """
    from matplotlib.colors import to_rgba

    from ..schematic import auto_colors

    if color is None:
        return dict()

    tags = (color,) if isinstance(color, str) else color

    if custom_colors is not None:
        rgbs = [to_rgba(c, alpha=alpha) for c in custom_colors]
    elif len(tags) <= 7:
        # few enough tags to use the default color sequence
        rgbs = auto_colors(len(tags), alpha=alpha, default_sequence=True)
    else:
        rgbs = auto_colors(len(tags), alpha)

    return dict(zip(tags, rgbs))
|
|
1485
|
-
|
|
1486
|
-
|
|
1487
|
-
def to_rgba_str(color, alpha=None):
    """Convert a matplotlib color spec to a CSS-style ``"rgba(r, g, b, a)"``.

    Float RGB channels in ``[0, 1]`` are scaled by 255 and truncated to
    ints; the alpha channel is written unchanged.
    """
    from matplotlib.colors import to_rgba

    *rgb, a = to_rgba(color, alpha)
    r, g, b = (int(c * 255) if isinstance(c, float) else c for c in rgb)
    return f"rgba({r}, {g}, {b}, {a})"
|
|
1495
|
-
|
|
1496
|
-
|
|
1497
|
-
def auto_color_html(s):
    """Automatically hash and color a string for HTML display.

    ``s`` is converted to ``str`` if needed. Its color comes from
    ``hash_to_color`` of that string, so it is stable for equal strings.
    The text is HTML-escaped so that characters such as ``<`` or ``&`` are
    displayed literally instead of being interpreted as markup.

    Returns
    -------
    str
        A bold, colored HTML snippet.
    """
    import html

    from ..schematic import hash_to_color

    if not isinstance(s, str):
        s = str(s)
    # hash the raw string (keeps colors unchanged), display the escaped one
    return f'<b style="color: {hash_to_color(s)};">{html.escape(s)}</b>'
|
|
1504
|
-
|
|
1505
|
-
|
|
1506
|
-
# ---------------------------- tensor functions ----------------------------- #
|
|
1507
|
-
|
|
1508
|
-
|
|
1509
|
-
def visualize_tensor(tensor, **kwargs):
    """Visualize all entries of a tensor, with indices mapped into the plane
    and values mapped into a color wheel.

    This is a thin wrapper around ``xyzpy.visualize_tensor``, called on the
    raw ``tensor.data``. ``legend`` and ``compass`` default to ``True``, and
    ``compass_labels`` defaults to ``tensor.inds``. All other keyword
    arguments are forwarded unchanged.

    Parameters
    ----------
    tensor : Tensor
        The tensor to visualize.
    skew_factor : float, optional
        When there are more than two dimensions, a factor to scale the
        rotations by to avoid overlapping data points.
    size_map : bool, optional
        Whether to map the tensor value magnitudes to marker size.
    size_scale : float, optional
        An overall factor to scale the marker size by.
    alpha_map : bool, optional
        Whether to map the tensor value magnitudes to marker alpha.
    alpha_pow : float, optional
        The power to raise the magnitude to when mapping to alpha.
    alpha : float, optional
        The overall alpha to use for all markers if ``not alpha_map``.
    show_lattice : bool, optional
        Show a small grey dot for every 'lattice' point regardless of value.
    lattice_opts : dict, optional
        Options to pass to ``matplotlib.Axis.scatter`` for the lattice points.
    linewidths : float, optional
        The linewidth to use for the markers.
    marker : str, optional
        The marker to use for the markers.
    figsize : tuple, optional
        The size of the figure to create, if ``ax`` is not provided.
    ax : matplotlib.Axis, optional
        The axis to draw to. If not provided, a new figure will be created.

    Returns
    -------
    fig : matplotlib.Figure
        The figure containing the plot, or ``None`` if ``ax`` was provided.
    ax : matplotlib.Axis
        The axis containing the plot.
    """
    import xyzpy as xyz

    # fill in only the options the caller did not supply explicitly
    defaults = {
        "legend": True,
        "compass": True,
        "compass_labels": tensor.inds,
    }
    for key, value in defaults.items():
        kwargs.setdefault(key, value)

    return xyz.visualize_tensor(tensor.data, **kwargs)
|
|
1556
|
-
|
|
1557
|
-
|
|
1558
|
-
def choose_squarest_grid(x):
    """Choose two integer side lengths ``(m, n)`` forming a near-square grid
    that can hold ``x`` items.

    If ``x`` is a perfect square both sides equal its square root. Otherwise
    ``m`` is the rounded square root and ``n`` is chosen so that
    ``m * n >= x``.
    """
    root = x**0.5
    if root.is_integer():
        side = int(root)
        return side, side

    first = int(round(root))
    lower = int(root)
    second = lower if first * lower >= x else lower + 1
    return first, second
|
|
1567
|
-
|
|
1568
|
-
|
|
1569
|
-
def visualize_tensors(
    tn,
    mode="network",
    r=None,
    r_scale=1.0,
    figsize=None,
    **visualize_opts,
):
    """Visualize all the entries of every tensor in this network.

    Parameters
    ----------
    tn : TensorNetwork
        The tensor network to visualize.
    mode : {'network', 'grid', 'row', 'col'}, optional
        How to arrange each tensor's visualization.

        - ``'network'``: arrange each tensor's visualization according to the
          automatic layout given by ``draw``.
        - ``'grid'``: arrange each tensor's visualization in a grid.
        - ``'row'``: arrange each tensor's visualization horizontally.
        - ``'col'``: arrange each tensor's visualization vertically.

    r : float, optional
        The absolute radius of each tensor's visualization, when
        ``mode='network'``.
    r_scale : float, optional
        A relative scaling factor for the radius of each tensor's
        visualization, when ``mode='network'``.
    figsize : tuple, optional
        The size of the figure to create, if ``ax`` is not provided.
    visualize_opts
        Supplied to ``visualize_tensor``.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the recognized options.
    """
    from matplotlib import pyplot as plt

    valid_modes = ("network", "grid", "row", "col")
    if mode not in valid_modes:
        raise ValueError(
            f"Unrecognized mode {mode!r}, should be one of {valid_modes}."
        )

    if figsize is None:
        figsize = (2 * tn.num_tensors**0.4, 2 * tn.num_tensors**0.4)
    if r is None:
        r = 1.0 / tn.num_tensors**0.5
    r *= r_scale

    visualize_opts.setdefault("max_mag", None)
    visualize_opts.setdefault("size_scale", r)

    if mode == "network":
        fig = plt.figure(figsize=figsize)
        pos = tn.draw(get="pos")
        for tid, (x, y) in pos.items():
            if tid not in tn.tensor_map:
                # hyper index node, not a tensor - nothing to visualize
                continue
            # presumably layout coords lie in [-1, 1]; (x + 1) / 2 maps them
            # to figure fractions, then offset so the inset is centered
            x = (x + 1) / 2 - r / 2
            y = (y + 1) / 2 - r / 2
            ax = fig.add_axes((x, y, r / 2, r / 2))
            tn.tensor_map[tid].visualize(ax=ax, **visualize_opts)
    else:
        if mode == "grid":
            px, py = choose_squarest_grid(tn.num_tensors)
        elif mode == "row":
            px, py = tn.num_tensors, 1
            figsize = (2 * figsize[0], figsize[1] / 2)
        else:  # mode == "col"
            px, py = 1, tn.num_tensors
            figsize = (figsize[0] / 2, 2 * figsize[1])

        # squeeze=False always returns a 2D array of axes, so ``.flat`` also
        # works for a 1x1 grid (single-tensor networks)
        fig, axs = plt.subplots(py, px, figsize=figsize, squeeze=False)
        num_drawn = 0
        for t in tn:
            t.visualize(ax=axs.flat[num_drawn], **visualize_opts)
            num_drawn += 1
        # hide only the leftover axes that received no tensor
        for ax in axs.flat[num_drawn:]:
            ax.set_axis_off()

    # transparent background
    fig.patch.set_alpha(0.0)

    plt.show()
    plt.close()
|