Trajectree 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122) hide show
  1. trajectree/__init__.py +0 -3
  2. trajectree/fock_optics/devices.py +1 -1
  3. trajectree/fock_optics/light_sources.py +2 -2
  4. trajectree/fock_optics/measurement.py +3 -3
  5. trajectree/fock_optics/utils.py +6 -6
  6. trajectree/trajectory.py +2 -2
  7. {trajectree-0.0.1.dist-info → trajectree-0.0.2.dist-info}/METADATA +2 -3
  8. trajectree-0.0.2.dist-info/RECORD +16 -0
  9. trajectree/quimb/docs/_pygments/_pygments_dark.py +0 -118
  10. trajectree/quimb/docs/_pygments/_pygments_light.py +0 -118
  11. trajectree/quimb/docs/conf.py +0 -158
  12. trajectree/quimb/docs/examples/ex_mpi_expm_evo.py +0 -62
  13. trajectree/quimb/quimb/__init__.py +0 -507
  14. trajectree/quimb/quimb/calc.py +0 -1491
  15. trajectree/quimb/quimb/core.py +0 -2279
  16. trajectree/quimb/quimb/evo.py +0 -712
  17. trajectree/quimb/quimb/experimental/__init__.py +0 -0
  18. trajectree/quimb/quimb/experimental/autojittn.py +0 -129
  19. trajectree/quimb/quimb/experimental/belief_propagation/__init__.py +0 -109
  20. trajectree/quimb/quimb/experimental/belief_propagation/bp_common.py +0 -397
  21. trajectree/quimb/quimb/experimental/belief_propagation/d1bp.py +0 -316
  22. trajectree/quimb/quimb/experimental/belief_propagation/d2bp.py +0 -653
  23. trajectree/quimb/quimb/experimental/belief_propagation/hd1bp.py +0 -571
  24. trajectree/quimb/quimb/experimental/belief_propagation/hv1bp.py +0 -775
  25. trajectree/quimb/quimb/experimental/belief_propagation/l1bp.py +0 -316
  26. trajectree/quimb/quimb/experimental/belief_propagation/l2bp.py +0 -537
  27. trajectree/quimb/quimb/experimental/belief_propagation/regions.py +0 -194
  28. trajectree/quimb/quimb/experimental/cluster_update.py +0 -286
  29. trajectree/quimb/quimb/experimental/merabuilder.py +0 -865
  30. trajectree/quimb/quimb/experimental/operatorbuilder/__init__.py +0 -15
  31. trajectree/quimb/quimb/experimental/operatorbuilder/operatorbuilder.py +0 -1631
  32. trajectree/quimb/quimb/experimental/schematic.py +0 -7
  33. trajectree/quimb/quimb/experimental/tn_marginals.py +0 -130
  34. trajectree/quimb/quimb/experimental/tnvmc.py +0 -1483
  35. trajectree/quimb/quimb/gates.py +0 -36
  36. trajectree/quimb/quimb/gen/__init__.py +0 -2
  37. trajectree/quimb/quimb/gen/operators.py +0 -1167
  38. trajectree/quimb/quimb/gen/rand.py +0 -713
  39. trajectree/quimb/quimb/gen/states.py +0 -479
  40. trajectree/quimb/quimb/linalg/__init__.py +0 -6
  41. trajectree/quimb/quimb/linalg/approx_spectral.py +0 -1109
  42. trajectree/quimb/quimb/linalg/autoblock.py +0 -258
  43. trajectree/quimb/quimb/linalg/base_linalg.py +0 -719
  44. trajectree/quimb/quimb/linalg/mpi_launcher.py +0 -397
  45. trajectree/quimb/quimb/linalg/numpy_linalg.py +0 -244
  46. trajectree/quimb/quimb/linalg/rand_linalg.py +0 -514
  47. trajectree/quimb/quimb/linalg/scipy_linalg.py +0 -293
  48. trajectree/quimb/quimb/linalg/slepc_linalg.py +0 -892
  49. trajectree/quimb/quimb/schematic.py +0 -1518
  50. trajectree/quimb/quimb/tensor/__init__.py +0 -401
  51. trajectree/quimb/quimb/tensor/array_ops.py +0 -610
  52. trajectree/quimb/quimb/tensor/circuit.py +0 -4824
  53. trajectree/quimb/quimb/tensor/circuit_gen.py +0 -411
  54. trajectree/quimb/quimb/tensor/contraction.py +0 -336
  55. trajectree/quimb/quimb/tensor/decomp.py +0 -1255
  56. trajectree/quimb/quimb/tensor/drawing.py +0 -1646
  57. trajectree/quimb/quimb/tensor/fitting.py +0 -385
  58. trajectree/quimb/quimb/tensor/geometry.py +0 -583
  59. trajectree/quimb/quimb/tensor/interface.py +0 -114
  60. trajectree/quimb/quimb/tensor/networking.py +0 -1058
  61. trajectree/quimb/quimb/tensor/optimize.py +0 -1818
  62. trajectree/quimb/quimb/tensor/tensor_1d.py +0 -4778
  63. trajectree/quimb/quimb/tensor/tensor_1d_compress.py +0 -1854
  64. trajectree/quimb/quimb/tensor/tensor_1d_tebd.py +0 -662
  65. trajectree/quimb/quimb/tensor/tensor_2d.py +0 -5954
  66. trajectree/quimb/quimb/tensor/tensor_2d_compress.py +0 -96
  67. trajectree/quimb/quimb/tensor/tensor_2d_tebd.py +0 -1230
  68. trajectree/quimb/quimb/tensor/tensor_3d.py +0 -2869
  69. trajectree/quimb/quimb/tensor/tensor_3d_tebd.py +0 -46
  70. trajectree/quimb/quimb/tensor/tensor_approx_spectral.py +0 -60
  71. trajectree/quimb/quimb/tensor/tensor_arbgeom.py +0 -3237
  72. trajectree/quimb/quimb/tensor/tensor_arbgeom_compress.py +0 -565
  73. trajectree/quimb/quimb/tensor/tensor_arbgeom_tebd.py +0 -1138
  74. trajectree/quimb/quimb/tensor/tensor_builder.py +0 -5411
  75. trajectree/quimb/quimb/tensor/tensor_core.py +0 -11179
  76. trajectree/quimb/quimb/tensor/tensor_dmrg.py +0 -1472
  77. trajectree/quimb/quimb/tensor/tensor_mera.py +0 -204
  78. trajectree/quimb/quimb/utils.py +0 -892
  79. trajectree/quimb/tests/__init__.py +0 -0
  80. trajectree/quimb/tests/test_accel.py +0 -501
  81. trajectree/quimb/tests/test_calc.py +0 -788
  82. trajectree/quimb/tests/test_core.py +0 -847
  83. trajectree/quimb/tests/test_evo.py +0 -565
  84. trajectree/quimb/tests/test_gen/__init__.py +0 -0
  85. trajectree/quimb/tests/test_gen/test_operators.py +0 -361
  86. trajectree/quimb/tests/test_gen/test_rand.py +0 -296
  87. trajectree/quimb/tests/test_gen/test_states.py +0 -261
  88. trajectree/quimb/tests/test_linalg/__init__.py +0 -0
  89. trajectree/quimb/tests/test_linalg/test_approx_spectral.py +0 -368
  90. trajectree/quimb/tests/test_linalg/test_base_linalg.py +0 -351
  91. trajectree/quimb/tests/test_linalg/test_mpi_linalg.py +0 -127
  92. trajectree/quimb/tests/test_linalg/test_numpy_linalg.py +0 -84
  93. trajectree/quimb/tests/test_linalg/test_rand_linalg.py +0 -134
  94. trajectree/quimb/tests/test_linalg/test_slepc_linalg.py +0 -283
  95. trajectree/quimb/tests/test_tensor/__init__.py +0 -0
  96. trajectree/quimb/tests/test_tensor/test_belief_propagation/__init__.py +0 -0
  97. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d1bp.py +0 -39
  98. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_d2bp.py +0 -67
  99. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hd1bp.py +0 -64
  100. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_hv1bp.py +0 -51
  101. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l1bp.py +0 -142
  102. trajectree/quimb/tests/test_tensor/test_belief_propagation/test_l2bp.py +0 -101
  103. trajectree/quimb/tests/test_tensor/test_circuit.py +0 -816
  104. trajectree/quimb/tests/test_tensor/test_contract.py +0 -67
  105. trajectree/quimb/tests/test_tensor/test_decomp.py +0 -40
  106. trajectree/quimb/tests/test_tensor/test_mera.py +0 -52
  107. trajectree/quimb/tests/test_tensor/test_optimizers.py +0 -488
  108. trajectree/quimb/tests/test_tensor/test_tensor_1d.py +0 -1171
  109. trajectree/quimb/tests/test_tensor/test_tensor_2d.py +0 -606
  110. trajectree/quimb/tests/test_tensor/test_tensor_2d_tebd.py +0 -144
  111. trajectree/quimb/tests/test_tensor/test_tensor_3d.py +0 -123
  112. trajectree/quimb/tests/test_tensor/test_tensor_arbgeom.py +0 -226
  113. trajectree/quimb/tests/test_tensor/test_tensor_builder.py +0 -441
  114. trajectree/quimb/tests/test_tensor/test_tensor_core.py +0 -2066
  115. trajectree/quimb/tests/test_tensor/test_tensor_dmrg.py +0 -388
  116. trajectree/quimb/tests/test_tensor/test_tensor_spectral_approx.py +0 -63
  117. trajectree/quimb/tests/test_tensor/test_tensor_tebd.py +0 -270
  118. trajectree/quimb/tests/test_utils.py +0 -85
  119. trajectree-0.0.1.dist-info/RECORD +0 -126
  120. {trajectree-0.0.1.dist-info → trajectree-0.0.2.dist-info}/WHEEL +0 -0
  121. {trajectree-0.0.1.dist-info → trajectree-0.0.2.dist-info}/licenses/LICENSE +0 -0
  122. {trajectree-0.0.1.dist-info → trajectree-0.0.2.dist-info}/top_level.txt +0 -0
@@ -1,1646 +0,0 @@
1
- """Functionailty for drawing tensor networks.
2
- """
3
- import collections
4
- import importlib
5
- import textwrap
6
- import warnings
7
-
8
- import numpy as np
9
-
10
- from ..utils import autocorrect_kwargs, check_opt, valmap
11
-
12
- # from ..schematic import average_color, darken_color, auto_colors, hash_to_color
13
-
14
- HAS_FA2 = importlib.util.find_spec("fa2") is not None
15
-
16
-
17
@autocorrect_kwargs
def draw_tn(
    tn,
    color=None,
    *,
    show_inds=None,
    show_tags=None,
    output_inds=None,
    highlight_inds=(),
    highlight_tids=(),
    highlight_inds_color=(1.0, 0.2, 0.2),
    highlight_tids_color=(1.0, 0.2, 0.2),
    custom_colors=None,
    legend="auto",
    dim=2,
    fix=None,
    layout="auto",
    initial_layout="auto",
    refine_layout="auto",
    iterations="auto",
    k=None,
    pos=None,
    node_color=None,
    node_scale=1.0,
    node_size=None,
    node_alpha=1.0,
    node_shape="o",
    node_outline_size=None,
    node_outline_darkness=0.9,
    node_hatch="",
    edge_color=None,
    edge_scale=1.0,
    edge_alpha=1 / 2,
    multi_edge_spread=0.1,
    multi_tag_style="auto",
    show_left_inds=True,
    arrow_opts=None,
    label_color=None,
    font_size=10,
    font_size_inner=7,
    font_family="monospace",
    isdark=None,
    title=None,
    backend="matplotlib",
    figsize=(6, 6),
    margin=None,
    xlims=None,
    ylims=None,
    get=None,
    return_fig=False,
    ax=None,
):
    """Plot this tensor network as a networkx graph using matplotlib,
    with edge width corresponding to bond dimension.

    Parameters
    ----------
    tn : TensorNetwork
        The tensor network to draw.
    color : sequence of tags, optional
        If given, uniquely color any tensors which have each of the tags.
        If some tensors have more than one of the tags, the colors are
        combined according to ``multi_tag_style``.
    show_inds : {None, False, True, 'all', 'bond-size'}, optional
        Explicitly turn on labels for each tensor's indices. ``True`` labels
        only the outer indices. If ``None``, outer indices are labelled only
        if there are at most 20 of them.
    show_tags : {None, False, True, 'tids', 'shape'}, optional
        Explicitly turn on labels for each tensor: its tags (``True``), its
        ``tid`` (``'tids'``) or its shape (``'shape'``). If ``None``, tags are
        shown only if the network has at most 20 distinct tags.
    output_inds : str or sequence of str, optional
        For hyper tensor networks explicitly specify which indices should be
        drawn as outer indices. If not set, ``tn.outer_inds()`` is used.
    highlight_inds : iterable, optional
        Highlight these edges.
    highlight_tids : iterable, optional
        Highlight these nodes.
    highlight_inds_color : tuple[float], optional
        What color to use for ``highlight_inds`` edges.
    highlight_tids_color : tuple[float], optional
        What color to use for ``highlight_tids`` nodes.
    custom_colors : sequence of colors, optional
        Supply a custom sequence of colors to match the tags given
        in ``color``.
    legend : "auto" or bool, optional
        Whether to draw a legend for the colored tags. If ``"auto"`` then
        only draw a legend if there are at most 20 colored tags.
    dim : {2, 2.5, 3}, optional
        What dimension to position the graph nodes in. 2.5 positions the nodes
        in 3D but then projects them down to 2D.
    fix : dict[tags_ind_or_tid, (float, float)], optional
        Used to specify actual relative positions for each tensor node.
        Each key should be a sequence of tags that uniquely identifies a
        tensor, a ``tid``, or an ``ind``, and each value should be a
        ``(x, y)`` coordinate tuple.
    layout : str, optional
        How to layout the graph. Can be any of the following:

        - ``'auto'``: layout the graph using a networkx method then relax
          the layout using a force-directed algorithm.
        - a networkx layout method name, e.g. ``'kamada_kawai'``: just
          layout the graph using a networkx method, with no relaxation.
        - a graphviz method such as ``'dot'``, ``'neato'`` or ``'sfdp'``:
          layout the graph using ``pygraphviz``.

    initial_layout : {'auto', 'spectral', 'kamada_kawai', 'circular', \
'planar', 'random', 'shell', 'bipartite', ...}, optional
        If ``layout == 'auto'``, the name of a networkx layout to use before
        iterating with the spring layout. Set ``layout`` directly or
        ``iterations=0`` if you don't want any spring relaxation.
    refine_layout : str, optional
        Forwarded unchanged to ``get_positions``, which computes the layout.
    iterations : int or "auto", optional
        How many iterations to perform when finding the best layout
        using node repulsion. Ramp this up if the graph is drawing messily.
    k : float, optional
        The optimal distance between nodes.
    pos : dict, optional
        Pre-computed positions for the nodes. If given, this will override
        ``layout``. The nodes should be exactly the same as the nodes in the
        graph returned by ``draw(get='graph')``.
    node_color : tuple[float], optional
        Default color of nodes.
    node_scale : float, optional
        Scale the node sizes by this factor, in addition to the automatic
        scaling based on the number of tensors.
    node_size : None, float or dict, optional
        How big to draw the tensors. Can be a global single value, or a dict
        containing values for specific tags or tids. This is in absolute
        figure units. See ``node_scale`` to simply scale the node sizes up or
        down.
    node_alpha : float, optional
        Transparency of the nodes.
    node_shape : None, str or dict, optional
        What shape to draw the tensors. Should correspond to a matplotlib
        scatter marker. Can be a global single value, or a dict containing
        values for specific tags or tids.
    node_outline_size : None, float or dict, optional
        The width of the border of each node. Can be a global single value, or
        a dict containing values for specific tags or tids.
    node_outline_darkness : float, optional
        Darkening of nodes outlines.
    node_hatch : None, str or dict, optional
        Hatch pattern for each node. Can be a global single value, or a dict
        containing values for specific tags or tids.
    edge_color : None, True or color, optional
        Default color of edges. If ``True``, each edge is given a color
        derived from hashing its index name.
    edge_scale : float, optional
        How much to scale the width of the edges.
    edge_alpha : float, optional
        Set the alpha (opacity) of the drawn edges.
    multi_edge_spread : float, optional
        How much to spread the lines of multi-edges.
    multi_tag_style : {'auto', 'pie', 'nest', 'average', 'last'}, optional
        How to draw a node that matches several colored tags. Forwarded to
        the drawing backend.
    show_left_inds : bool, optional
        Whether to show ``tensor.left_inds`` as incoming arrows.
    arrow_opts : dict, optional
        Options controlling how arrows are drawn, forwarded unchanged to the
        drawing backend.
    label_color : tuple[float] or "inherit", optional
        Color to draw labels with. ``"inherit"`` uses
        ``matplotlib.rcParams["axes.labelcolor"]``.
    font_size : int, optional
        Font size for drawing tags and outer indices.
    font_size_inner : int, optional
        Font size for drawing inner indices.
    font_family : str, optional
        Font family to use for all labels.
    isdark : bool, optional
        Explicitly specify that the background is dark, and use slightly
        different default drawing colors. If not specified detects
        automatically from ``matplotlib.rcParams``.
    title : str, optional
        Set a title for the axis.
    backend : {'matplotlib', 'matplotlib3d', 'plotly'}, optional
        Which drawing backend to hand the prepared data to. Any other value
        results in nothing being drawn and ``None`` being returned.
    figsize : tuple of int, optional
        The size of the drawing.
    margin : None or float, optional
        Specify an argument for ``ax.margin``, else the plot limits will try
        and be computed based on the node positions and node sizes.
    xlims : None or tuple, optional
        Explicitly set the x plot range.
    ylims : None or tuple, optional
        Explicitly set the y plot range.
    get : {None, 'pos', 'graph', 'graph,pos', 'data'}, optional
        If ``None`` then plot as normal, else if:

        - ``'pos'``, return the plotting positions of each ``tid`` and
          ``ind`` drawn as a node, this can be supplied to subsequent calls
          as ``fix=pos`` to maintain positions, even as the graph structure
          changes.
        - ``'graph'``, return the ``networkx.Graph`` object. Note that this
          will potentially have extra nodes representing output and hyper
          indices.
        - ``'graph,pos'``, return the tuple ``(graph, pos)``.
        - ``'data'``, return the tuple ``(edges, nodes, opts)`` that would
          otherwise be handed to the drawing backend.

    return_fig : bool, optional
        If True and ``ax is None`` then return the figure created rather than
        executing ``pyplot.show()``.
    ax : matplotlib.Axis, optional
        Draw the graph on this axis rather than creating a new figure.

    Returns
    -------
    object
        Whatever ``get`` requests (see above), otherwise the return value of
        the selected drawing backend.
    """
    import math

    import matplotlib as mpl
    import networkx as nx
    from matplotlib.colors import to_rgb, to_rgba

    from ..schematic import darken_color, hash_to_color

    check_opt(
        "multi_tag_style",
        multi_tag_style,
        ("auto", "pie", "nest", "average", "last"),
    )

    if output_inds is None:
        output_inds = set(tn.outer_inds())
    elif isinstance(output_inds, str):
        output_inds = {output_inds}
    else:
        output_inds = set(output_inds)

    # automatically decide whether to show tags and inds
    if show_inds is None:
        show_inds = len(tn.outer_inds()) <= 20
    # map booleans onto the string modes, leave explicit modes untouched
    show_inds = {False: "", True: "outer"}.get(show_inds, show_inds)

    if show_tags is None:
        show_tags = len(tn.tag_map) <= 20
    show_tags = {False: "", True: "tags"}.get(show_tags, show_tags)

    if isdark is None:
        # treat the figure as dark if its mean RGB brightness is below half
        isdark = sum(to_rgb(mpl.rcParams["figure.facecolor"])) / 3 < 0.5

    if isdark:
        default_draw_color = (0.55, 0.57, 0.60, 1.0)
        default_label_color = (0.85, 0.86, 0.87, 1.0)
    else:
        default_draw_color = (0.45, 0.47, 0.50, 1.0)
        default_label_color = (0.33, 0.34, 0.35, 1.0)

    if edge_color is None:
        edge_color = to_rgba(default_draw_color, edge_alpha)
    elif edge_color is not True:
        # ``True`` is kept as a flag meaning "hash each index to a color"
        edge_color = to_rgba(edge_color, edge_alpha)

    if node_color is None:
        node_color = to_rgba(default_draw_color, node_alpha)
    else:
        node_color = to_rgba(node_color, node_alpha)

    if label_color is None:
        label_color = default_label_color
    elif label_color == "inherit":
        label_color = mpl.rcParams["axes.labelcolor"]

    # get colors for tagged nodes
    colors = get_colors(color, custom_colors, node_alpha)

    if legend == "auto":
        legend = len(colors) <= 20

    highlight_tids_color = to_rgba(highlight_tids_color, node_alpha)
    highlight_inds_color = to_rgba(highlight_inds_color, edge_alpha)

    # set the size of the nodes and their border
    node_size = parse_dict_to_tids_or_inds(node_size, tn, default=1)
    node_outline_size = parse_dict_to_tids_or_inds(
        node_outline_size, tn, default=1
    )
    node_shape = parse_dict_to_tids_or_inds(node_shape, tn, default="o")
    node_hatch = parse_dict_to_tids_or_inds(node_hatch, tn, default="")

    # build the graph: per-edge attributes are lists, one entry per index
    # shared by the same pair of nodes (multi-edges)
    edges = collections.defaultdict(lambda: collections.defaultdict(list))
    nodes = collections.defaultdict(dict)

    # parse all indices / edges
    for ix, tids in tn.ind_map.items():
        tids = sorted(tids)

        isouter = ix in output_inds
        ishyper = isouter or (len(tids) != 2)
        ind_size = tn.ind_size(ix)
        edge_size = edge_scale * math.log2(ind_size)

        # compute a color for this index
        if ix in highlight_inds:
            edge_rgba = highlight_inds_color
        elif edge_color is True:
            edge_rgba = to_rgba(hash_to_color(ix))
        else:
            edge_rgba = edge_color

        # compute a label for this index
        if ishyper:
            # each tensor connects to the dummy node representing the hyper
            # edge
            pairs = [(tid, ix) for tid in tids]
            if isouter and len(tids) > 1:
                # 'hyper outer' index
                pairs.append((("outer", ix), ix))
            # hyper labels get put on dummy node
            label = ""

            ind_node = nodes[ix]
            ind_node["ind"] = ix
            ind_node["ind_size"] = ind_size
            # make actual node invisible
            ind_node["color"] = (1.0, 1.0, 1.0, 1.0)
            ind_node["size"] = 0.0
            ind_node["outline_size"] = 0.0
            ind_node["outline_color"] = (1.0, 1.0, 1.0, 1.0)
            ind_node["marker"] = "."  # set this to avoid warning - size is 0
            ind_node["hatch"] = ""

            # set these for plotly hover info
            ind_node["tid"] = ind_node["shape"] = ind_node["tags"] = ""

            if ((show_inds == "outer") and isouter) or (show_inds == "all"):
                # show as outer index or inner index name
                ind_node["label"] = ix
            elif show_inds == "bond-size":
                # show all bond sizes
                ind_node["label"] = f"{ind_size}"
            else:
                # labels hidden or inner edge
                ind_node["label"] = ""

            ind_node["label_fontsize"] = font_size_inner
            ind_node["label_color"] = label_color
            ind_node["label_fontfamily"] = font_family

        else:
            # standard edge
            pairs = [tuple(tids)]

            if show_inds == "all":
                # show inner index name
                label = ix
            elif show_inds == "bond-size":
                # show all bond sizes
                label = f"{ind_size}"
            else:
                # labels hidden or inner edge
                label = ""

        for pair in pairs:
            edge = edges[pair]
            edge["color"].append(edge_rgba)
            edge["ind"].append(ix)
            edge["ind_size"].append(ind_size)
            edge["edge_size"].append(edge_size)
            edge["label"].append(label)
            edge["label_fontsize"] = font_size_inner
            edge["label_color"] = label_color
            edge["label_fontfamily"] = font_family

            if isinstance(pair[0], tuple):
                # dummy hyper outer edge - no arrows
                edge["arrow_left"].append(False)
                edge["arrow_right"].append(False)
            else:
                # tensor side can always have an incoming arrow
                tl_left_inds = tn.tensor_map[pair[0]].left_inds
                edge["arrow_left"].append(
                    show_left_inds
                    and (tl_left_inds is not None)
                    and (ix in tl_left_inds)
                )
                if ishyper:
                    # hyper edge can't have an incoming arrow
                    edge["arrow_right"].append(False)
                else:
                    # standard edge can
                    tr_left_inds = tn.tensor_map[pair[1]].left_inds
                    edge["arrow_right"].append(
                        show_left_inds
                        and (tr_left_inds is not None)
                        and (ix in tr_left_inds)
                    )

    # parse all tensors / nodes
    for tid, t in tn.tensor_map.items():
        tensor_node = nodes[tid]
        tensor_node["tid"] = tid
        tensor_node["tags"] = str(list(t.tags))
        tensor_node["shape"] = str(t.shape)
        tensor_node["size"] = node_size[tid]
        tensor_node["outline_size"] = node_outline_size[tid]
        tensor_node["marker"] = node_shape[tid]
        tensor_node["hatch"] = node_hatch[tid]

        if show_tags == "tags":
            node_label = ", ".join(map(str, t.tags))
            # make the tags appear with auto vertical extent; the width must
            # be an int, since textwrap uses it as a slice index whenever it
            # has to break an over-long word
            wrap_width = int(max(2 * len(node_label) ** 0.5, 16))
            tensor_node["label"] = "\n".join(
                textwrap.wrap(node_label, wrap_width)
            )
        elif show_tags == "tids":
            tensor_node["label"] = str(tid)
        elif show_tags == "shape":
            tensor_node["label"] = tensor_node["shape"]
        else:
            tensor_node["label"] = ""

        tensor_node["label_fontsize"] = font_size
        tensor_node["label_color"] = label_color
        tensor_node["label_fontfamily"] = font_family

        if tid in highlight_tids:
            tensor_node["color"] = highlight_tids_color
            tensor_node["outline_color"] = darken_color(
                highlight_tids_color, node_outline_darkness
            )
        else:
            # collect all relevant tag colors
            multi_colors = []
            multi_outline_colors = []
            for tag in colors:
                if tag in t.tags:
                    multi_colors.append(colors[tag])
                    multi_outline_colors.append(
                        darken_color(colors[tag], node_outline_darkness)
                    )

            if len(multi_colors) >= 1:
                # set the basic color to the last tag
                tensor_node["color"] = multi_colors[-1]
                tensor_node["outline_color"] = multi_outline_colors[-1]
                if len(multi_colors) >= 2:
                    # multiple relevant tags - not every backend can draw
                    # these, so store them alongside the basic color
                    tensor_node["multi_colors"] = multi_colors
                    tensor_node["multi_outline_colors"] = multi_outline_colors
            else:
                # untagged node
                tensor_node["color"] = node_color
                tensor_node["outline_color"] = darken_color(
                    node_color, node_outline_darkness**2
                )

    G = nx.Graph()
    for edge_key, edge_data in edges.items():
        G.add_edge(*edge_key, **edge_data)
    for node, node_data in nodes.items():
        G.add_node(node, **node_data)

    if pos is None:
        pos = get_positions(
            tn=tn,
            G=G,
            fix=fix,
            layout=layout,
            initial_layout=initial_layout,
            refine_layout=refine_layout,
            k=k,
            dim=dim,
            iterations=iterations,
        )
    else:
        pos = _normalize_positions(pos)

    # compute a base size using the position and number of tensors
    # first get plot volume:
    node_packing_factor = tn.num_tensors**-0.45
    xs, ys, *zs = zip(*pos.values())
    xmin, xmax = min(xs), max(xs)
    ymin, ymax = min(ys), max(ys)
    # if there only a few tensors we don't want to limit the node size
    # because of flatness, also don't allow the plot volume to go to zero
    xrange = max(((xmax - xmin) / 2, node_packing_factor, 0.1))
    yrange = max(((ymax - ymin) / 2, node_packing_factor, 0.1))
    plot_volume = xrange * yrange
    if zs:
        zmin, zmax = min(zs[0]), max(zs[0])
        zrange = max(((zmax - zmin) / 2, node_packing_factor, 0.1))
        plot_volume *= zrange
    # in total we account for:
    #     - user specified scaling
    #     - number of tensors
    #     - how flat the plot area is (flatter requires smaller nodes)
    full_node_scale = 0.2 * node_scale * node_packing_factor * plot_volume**0.5

    default_outline_size = 6 * full_node_scale**0.5

    # update node size and position attributes, keeping the plain ``nodes``
    # mapping and the graph's node attributes in sync
    for node, node_data in nodes.items():
        nodes[node]["size"] = G.nodes[node]["size"] = (
            full_node_scale * node_data["size"]
        )
        nodes[node]["outline_size"] = G.nodes[node]["outline_size"] = (
            default_outline_size * node_data["outline_size"]
        )
        nodes[node]["coo"] = G.nodes[node]["coo"] = pos[node]

    for i, j in edges:
        edges[i, j]["coos"] = G.edges[i, j]["coos"] = pos[i], pos[j]

    if get == "pos":
        return pos
    if get == "graph":
        # documented option that previously fell through to drawing
        return G
    if get == "graph,pos":
        return G, pos

    opts = {
        "colors": colors,
        "node_outline_darkness": node_outline_darkness,
        "title": title,
        "legend": legend,
        "multi_edge_spread": multi_edge_spread,
        "multi_tag_style": multi_tag_style,
        "arrow_opts": arrow_opts,
        "label_color": label_color,
        "font_family": font_family,
        "figsize": figsize,
        "margin": margin,
        "xlims": xlims,
        "ylims": ylims,
        "return_fig": return_fig,
        "ax": ax,
    }

    if get == "data":
        return edges, nodes, opts

    if backend == "matplotlib":
        return _draw_matplotlib(edges=edges, nodes=nodes, **opts)

    if backend == "matplotlib3d":
        return _draw_matplotlib3d(G, **opts)

    if backend == "plotly":
        return _draw_plotly(G, **opts)
548
-
549
-
550
def parse_dict_to_tids_or_inds(spec, tn, default="__NONE__"):
    """Expand a per-node option into a mapping keyed by single tids and inds.

    Parameters
    ----------
    spec : None, dict or any other value
        If a dict, its keys may be tids, inds, or tags / sets of tags. Any
        non-dict, non-``None`` value is treated as the value for every key.
    tn : TensorNetwork
        Supplies ``tensor_map``, ``ind_map`` and ``_get_tids_from_tags`` used
        to interpret the keys of ``spec``.
    default : object, optional
        Fallback value for keys not set by ``spec``. The string
        ``"__NONE__"`` means "no fallback".

    Returns
    -------
    dict or collections.defaultdict
        A plain ``dict`` when no fallback applies, otherwise a
        ``defaultdict`` producing the fallback value for missing keys.
    """
    spec_is_dict = isinstance(spec, dict)

    if spec is not None and not spec_is_dict:
        # a bare value becomes the answer for every possible key
        return collections.defaultdict(lambda: spec)

    # the sentinel string allows callers to opt out of having a fallback
    if default == "__NONE__":
        resolved = {}
    else:
        resolved = collections.defaultdict(lambda: default)

    if not spec_is_dict:
        # spec is None: nothing to override
        return resolved

    for key, value in spec.items():
        names_tensor = isinstance(key, int) and key in tn.tensor_map
        names_index = isinstance(key, str) and key in tn.ind_map

        if names_tensor or names_index:
            # key already identifies a single node, use it directly
            resolved[key] = value
        else:
            # otherwise interpret the key as tag(s) and fan the value out
            try:
                for tid in tn._get_tids_from_tags(key):
                    resolved[tid] = value
            except KeyError:
                # just ignore keys that don't match any tensor
                pass

    return resolved
590
-
591
-
592
def _add_legend_matplotlib(
    ax, colors, legend, node_outline_darkness, label_color, font_family
):
    """Attach a legend of colored tags to a matplotlib axis.

    Parameters
    ----------
    ax : matplotlib axis
        The axis to add the legend to.
    colors : dict
        Mapping of legend label to RGBA color; one entry is drawn per item.
    legend : bool
        Whether to draw the legend at all.
    node_outline_darkness : float
        Factor applied to the non-alpha channels of each color to get the
        marker edge color.
    label_color : color
        Color of the legend text.
    font_family : str
        Font family of the legend text.
    """
    import matplotlib.pyplot as plt

    # nothing to do without both tags to show and the legend switched on
    if not (colors and legend):
        return

    def _edge_color(rgba):
        # scale every channel except alpha (position 3), which is kept
        factors = (
            1.0 if position == 3 else node_outline_darkness
            for position in range(len(rgba))
        )
        return tuple(factor * channel for factor, channel in zip(factors, rgba))

    handles = [
        plt.Line2D(
            [0],
            [0],
            marker="o",
            color=rgba,
            markeredgecolor=_edge_color(rgba),
            markeredgewidth=1,
            linestyle="",
            markersize=10,
        )
        for rgba in colors.values()
    ]

    # needed in case '_' is the first character
    labels = [f" {name}" for name in colors]

    # add a further column for roughly every 20 entries
    num_columns = max(round(len(handles) / 20), 1)

    legend_artist = ax.legend(
        handles,
        labels,
        ncol=num_columns,
        loc="center left",
        bbox_to_anchor=(1, 0.5),
        labelcolor=label_color,
        prop={"family": font_family},
    )

    # do this manually as otherwise can't make only face transparent
    frame = legend_artist.get_frame()
    frame.set_alpha(None)
    frame.set_facecolor((0.0, 0.0, 0.0, 0.0))
    frame.set_edgecolor((0.6, 0.6, 0.6, 0.2))
634
-
635
-
636
- def _draw_matplotlib(
637
- edges,
638
- nodes,
639
- *,
640
- colors=None,
641
- node_outline_darkness=0.9,
642
- title=None,
643
- legend=True,
644
- multi_edge_spread=0.1,
645
- multi_tag_style="auto",
646
- arrow_opts=None,
647
- label_color=None,
648
- font_family="monospace",
649
- figsize=(6, 6),
650
- margin=None,
651
- xlims=None,
652
- ylims=None,
653
- return_fig=False,
654
- ax=None,
655
- ):
656
- import matplotlib.pyplot as plt
657
-
658
- from quimb.schematic import Drawing, average_color
659
-
660
- d = Drawing(figsize=figsize, ax=ax)
661
- if ax is None:
662
- fig = d.fig
663
- ax = d.ax
664
- fig.patch.set_alpha(0.0)
665
- ax.patch.set_alpha(0.0)
666
- else:
667
- fig = None
668
-
669
- arrow_opts = arrow_opts or {}
670
- arrow_opts.setdefault("center", 3 / 4)
671
- arrow_opts.setdefault("linewidth", 1)
672
- arrow_opts.setdefault("width", 0.08)
673
- arrow_opts.setdefault("length", 0.12)
674
-
675
- if title is not None:
676
- ax.set_title(str(title))
677
-
678
- for _, edge_data in edges.items():
679
- cooa, coob = edge_data["coos"]
680
- edge_colors = edge_data["color"]
681
- edge_sizes = edge_data["edge_size"]
682
- labels = edge_data["label"]
683
- arrow_lefts = edge_data["arrow_left"]
684
- arrow_rights = edge_data["arrow_right"]
685
- multiplicity = len(edge_colors)
686
-
687
- if multiplicity > 1:
688
- offsets = np.linspace(
689
- +multiplicity * multi_edge_spread / 2,
690
- -multiplicity * multi_edge_spread / 2,
691
- multiplicity,
692
- )
693
- else:
694
- offsets = None
695
-
696
- for m in range(multiplicity):
697
- line_opts = dict(
698
- cooa=cooa,
699
- coob=coob,
700
- linewidth=edge_sizes[m],
701
- color=edge_colors[m],
702
- )
703
-
704
- arrowhead, reverse = {
705
- (False, False): (None, False), # no arrow
706
- (False, True): (True, False), # arrowhead to right
707
- (True, False): (True, True), # arrowhead to left
708
- (True, True): (True, "both"), # arrowheads both sides
709
- }[arrow_lefts[m], arrow_rights[m]]
710
-
711
- if arrowhead:
712
- line_opts["arrowhead"] = dict(
713
- reverse=reverse,
714
- **arrow_opts,
715
- )
716
-
717
- if labels[m]:
718
- line_opts["text"] = dict(
719
- text=labels[m],
720
- fontsize=edge_data["label_fontsize"],
721
- color=edge_data["label_color"],
722
- fontfamily=edge_data["label_fontfamily"],
723
- )
724
-
725
- if multiplicity > 1:
726
- d.line_offset(offset=offsets[m], **line_opts)
727
- else:
728
- d.line(**line_opts)
729
-
730
- # draw the tensors
731
- for _, node_data in nodes.items():
732
- patch_opts = dict(
733
- coo=node_data["coo"],
734
- radius=node_data["size"],
735
- facecolor=node_data["color"],
736
- edgecolor=node_data["outline_color"],
737
- linewidth=node_data["outline_size"],
738
- hatch=node_data["hatch"],
739
- )
740
- marker = node_data["marker"]
741
-
742
- if "multi_colors" in node_data:
743
- # tensor has multiple tags which are colored
744
-
745
- if multi_tag_style in ("pie", "auto"):
746
- # draw a mini pie chart
747
- if marker not in ("o", "."):
748
- warnings.warn(
749
- "Can only draw multi-colored nodes as circles."
750
- )
751
-
752
- angles = np.linspace(
753
- 0, 360, len(node_data["multi_colors"]) + 1
754
- )
755
- for i, (color, outline_color) in enumerate(
756
- zip(
757
- node_data["multi_colors"],
758
- node_data["multi_outline_colors"],
759
- )
760
- ):
761
- patch_opts["facecolor"] = color
762
- patch_opts["edgecolor"] = outline_color
763
- d.wedge(
764
- theta1=angles[i] - 67.5,
765
- theta2=angles[i + 1] - 67.5,
766
- **patch_opts,
767
- )
768
- elif multi_tag_style == "nest":
769
- # draw nested markers of decreasing size
770
- radii = np.linspace(
771
- node_data["size"], 0, len(node_data["multi_colors"]) + 1
772
- )
773
- for i, (color, outline_color) in enumerate(
774
- zip(
775
- node_data["multi_colors"],
776
- node_data["multi_outline_colors"],
777
- )
778
- ):
779
- patch_opts["facecolor"] = color
780
- patch_opts["edgecolor"] = outline_color
781
- d.marker(
782
- marker=marker,
783
- **{**patch_opts, "radius": radii[i], "linewidth": 0},
784
- )
785
- elif multi_tag_style == "last":
786
- # draw a single marker with last tag
787
- patch_opts["facecolor"] = node_data["multi_colors"][-1]
788
- patch_opts["edgecolor"] = node_data["multi_outline_colors"][-1]
789
- d.marker(marker=marker, **patch_opts)
790
- else: # multi_tag_style == "average":
791
- # draw a single marker with average color
792
- patch_opts["facecolor"] = average_color(
793
- node_data["multi_colors"]
794
- )
795
- patch_opts["edgecolor"] = average_color(
796
- node_data["multi_outline_colors"]
797
- )
798
- d.marker(marker=marker, **patch_opts)
799
-
800
- else:
801
- d.marker(marker=marker, **patch_opts)
802
-
803
- if node_data["label"]:
804
- d.text(
805
- node_data["coo"],
806
- node_data["label"],
807
- fontsize=node_data["label_fontsize"],
808
- color=node_data["label_color"],
809
- fontfamily=node_data["label_fontfamily"],
810
- )
811
-
812
- _add_legend_matplotlib(
813
- ax, colors, legend, node_outline_darkness, label_color, font_family
814
- )
815
-
816
- if fig is None:
817
- # ax was supplied, don't modify and simply return
818
- return
819
- else:
820
- # axes and figure were created
821
- if xlims is not None:
822
- ax.set_xlim(xlims)
823
- if ylims is not None:
824
- ax.set_ylim(ylims)
825
- if margin is not None:
826
- ax.margins(margin)
827
-
828
- if return_fig:
829
- return fig
830
- else:
831
- plt.show()
832
- plt.close(fig)
833
-
834
-
835
def _linearize_graph_data(G, multi_tag_style="auto"):
    """Flatten the edge and node attributes of ``G`` into column dicts.

    Each edge contributes exactly one entry per edge column (multi-edges are
    merged into a single aggregated entry). Each node without an ``"ind"``
    attribute contributes one entry per marker that should be plotted for it.

    Parameters
    ----------
    G : graph
        Object providing ``edges(data=True)`` and ``nodes(data=True)``.
    multi_tag_style : str, optional
        How nodes that have a ``"multi_colors"`` attribute are expanded:
        ``"average"`` and ``"last"`` give a single marker, any other value
        gives one nested marker per color.

    Returns
    -------
    edge_source, node_source : dict[str, list]
        Plain dicts mapping column name to the list of per-entry values.
    """
    from ..schematic import average_color

    edge_columns = collections.defaultdict(list)
    for _u, _v, attrs in G.edges(data=True):
        start, stop = attrs["coos"]
        xa, ya, *za = start
        xb, yb, *zb = stop
        for column, value in (("x0", xa), ("y0", ya), ("x1", xb), ("y1", yb)):
            edge_columns[column].append(value)
        if za:
            # only present for 3D coordinates
            edge_columns["z0"].extend(za)
            edge_columns["z1"].extend(zb)

        # all multi-edges are collapsed into one aggregated entry
        edge_columns["color"].append(average_color(attrs["color"]))
        edge_columns["edge_size"].append(sum(attrs["edge_size"]))
        edge_columns["ind"].append(" ".join(attrs["ind"]))
        edge_columns["ind_size"].append(np.prod(attrs["ind_size"]))
        edge_columns["label"].append(" ".join(attrs["label"]))

    node_columns = collections.defaultdict(list)
    for _name, attrs in G.nodes(data=True):
        if "ind" in attrs:
            # nodes carrying an "ind" attribute get no marker
            continue

        x, y, *z = attrs["coo"]

        has_multi = "multi_colors" in attrs
        if has_multi and multi_tag_style not in ("average", "last"):
            # multi_tag_style in ("auto", "nest"): several nested markers of
            # decreasing size, drawn without outline
            faces = attrs["multi_colors"]
            outlines = attrs["multi_outline_colors"]
            sizes = np.linspace(attrs["size"], 0, len(faces) + 1)
            outline_width = 0.0
        else:
            if not has_multi:
                # plain single marker
                faces = [attrs["color"]]
                outlines = [attrs["outline_color"]]
            elif multi_tag_style == "average":
                # single marker using the average color
                faces = [average_color(attrs["multi_colors"])]
                outlines = [average_color(attrs["multi_outline_colors"])]
            else:
                # single marker using the last tag's color
                faces = [attrs["multi_colors"][-1]]
                outlines = [attrs["multi_outline_colors"][-1]]
            sizes = [attrs["size"]]
            outline_width = attrs["outline_size"]

        # zip stops at the shortest input, so the trailing zero size produced
        # by linspace in the nested case is never emitted
        for face, outline, size in zip(faces, outlines, sizes):
            node_columns["x"].append(x)
            node_columns["y"].append(y)
            if z:
                node_columns["z"].extend(z)

            node_columns["color"].append(face)
            node_columns["outline_color"].append(outline)
            node_columns["size"].append(size)
            node_columns["outline_size"].append(outline_width)

            for key in ("hatch", "tags", "shape", "tid", "label"):
                node_columns[key].append(attrs.get(key, None))

    return dict(edge_columns), dict(node_columns)
905
-
906
-
907
def _draw_matplotlib3d(G, **kwargs):
    """Draw the graph ``G`` on a 3D matplotlib axis.

    Parameters
    ----------
    G : graph
        Graph whose node/edge attributes are flattened with
        ``_linearize_graph_data``; nodes need 3D coordinates.
    kwargs
        Drawing options. The keys read here are ``"multi_tag_style"``,
        ``"ax"``, ``"figsize"``, ``"colors"``, ``"legend"``,
        ``"node_outline_darkness"``, ``"label_color"``, ``"font_family"`` and
        ``"return_fig"``.

    Returns
    -------
    matplotlib figure or None
        The created figure if ``return_fig`` is true and no ``ax`` was
        supplied, otherwise ``None``.
    """
    import matplotlib.pyplot as plt

    edge_source, node_source = _linearize_graph_data(
        G, multi_tag_style=kwargs["multi_tag_style"]
    )

    ax = kwargs.pop("ax")
    if ax is None:
        fig = plt.figure(figsize=kwargs["figsize"])
        fig.patch.set_alpha(0.0)
        ax = plt.axes([0, 0, 1, 1], projection="3d")
        ax.patch.set_alpha(0.0)
    else:
        # the caller owns the figure; previously ``fig`` was left unbound
        # here, which raised a NameError at the end of the function
        fig = None

    # use one common cubic range so that all three axes share a scale
    xyzmin = min(
        min(node_source["x"]), min(node_source["y"]), min(node_source["z"])
    )
    xyzmax = max(
        max(node_source["x"]), max(node_source["y"]), max(node_source["z"])
    )

    ax.set_xlim(xyzmin, xyzmax)
    ax.set_ylim(xyzmin, xyzmax)
    ax.set_zlim(xyzmin, xyzmax)
    ax.set_aspect("equal")
    ax.axis("off")

    # draw the edges
    # TODO: multiedges and left_inds
    # the edge columns only exist if the graph has at least one edge
    for i in range(len(edge_source.get("x0", ()))):
        x0, x1 = edge_source["x0"][i], edge_source["x1"][i]
        y0, y1 = edge_source["y0"][i], edge_source["y1"][i]
        z0, z1 = edge_source["z0"][i], edge_source["z1"][i]
        ax.plot3D(
            [x0, x1],
            [y0, y1],
            [z0, z1],
            c=edge_source["color"][i],
            linewidth=edge_source["edge_size"][i],
        )
        label = edge_source["label"][i]
        if label:
            # place the label at the midpoint of the edge
            ax.text(
                (x0 + x1) / 2,
                (y0 + y1) / 2,
                (z0 + z1) / 2,
                s=label,
                ha="center",
                va="center",
                color=edge_source["color"][i],
                fontsize=6,
            )

    node_source["color"] = [rgba[:3] for rgba in node_source["color"]]
    node_source["size"] = [100000 * s**2 for s in node_source["size"]]
    # NOTE(review): the original also stored ``outline_size / 50`` under the
    # misspelled key "linewdith", which nothing ever read. The dead column is
    # dropped and the unscaled outline sizes are still what gets drawn --
    # confirm whether the scaled widths were the intended ones.

    # draw the nodes
    ax.scatter3D(
        xs="x",
        ys="y",
        zs="z",
        c="color",
        s="size",
        alpha=1.0,
        marker="o",
        data=node_source,
        depthshade=False,
        edgecolor=node_source["outline_color"],
        linewidth=node_source["outline_size"],
    )

    for _, node_data in G.nodes(data=True):
        label = node_data["label"]
        if label:
            ax.text(
                *node_data["coo"],
                s=label,
                ha="center",
                va="center",
                color=node_data["label_color"],
                fontsize=node_data["label_fontsize"],
                fontfamily=node_data["label_fontfamily"],
            )

    _add_legend_matplotlib(
        ax,
        kwargs["colors"],
        kwargs["legend"],
        kwargs["node_outline_darkness"],
        kwargs["label_color"],
        kwargs["font_family"],
    )

    if fig is None:
        # ax was supplied: showing / closing is left to the caller
        return None

    if kwargs["return_fig"]:
        return fig

    plt.show()
    plt.close(fig)
1011
-
1012
-
1013
def _draw_plotly(G, **kwargs):
    """Draw the graph ``G`` as an interactive plotly figure and show it.

    Edges are drawn as three-point lines (start, midpoint, end) and nodes as
    markers. If z-coordinates are present, 3D scatter traces are used.

    Parameters
    ----------
    G : graph
        Graph whose node/edge attributes are flattened with
        ``_linearize_graph_data``.
    kwargs
        Drawing options. The keys read here are ``"multi_tag_style"`` and
        ``"figsize"``.
    """
    import plotly.graph_objects as go

    edge_source, node_source = _linearize_graph_data(
        G, multi_tag_style=kwargs["multi_tag_style"]
    )

    fig = go.Figure()
    fig.update_xaxes(visible=False)
    fig.update_yaxes(visible=False)
    fig.update_layout(
        width=100 * kwargs["figsize"][0],
        height=100 * kwargs["figsize"][1],
        margin=dict(l=10, r=10, b=10, t=10),
        plot_bgcolor="rgba(0,0,0,0)",
        paper_bgcolor="rgba(0,0,0,0)",
        showlegend=False,
        scene=dict(
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(visible=False),
        ),
    )

    # the edge columns only exist if the graph has at least one edge
    for i in range(len(edge_source.get("x0", ()))):
        x0, x1 = edge_source["x0"][i], edge_source["x1"][i]
        y0, y1 = edge_source["y0"][i], edge_source["y1"][i]
        xs = [x0, (x0 + x1) / 2, x1]
        ys = [y0, (y0 + y1) / 2, y1]
        *rgb, alpha = edge_source["color"][i]
        edge_kwargs = dict(
            x=xs,
            y=ys,
            opacity=alpha,
            line=dict(
                color=to_rgba_str(rgb, 1.0),
                width=edge_source["edge_size"][i],
            ),
            # one customdata row per plotted point, so that hovering over
            # any of the three points shows the index information
            customdata=[[edge_source["ind"][i], edge_source["ind_size"][i]]]
            * len(xs),
            # show ind and ind_size on hover:
            hovertemplate="%{customdata[0]}<br>size: %{customdata[1]}",
            mode="lines",
            name="",
        )
        if "z0" in edge_source:
            z0, z1 = edge_source["z0"][i], edge_source["z1"][i]
            edge_kwargs["z"] = [z0, (z0 + z1) / 2, z1]
            # edges appear much thinner in 3D
            edge_kwargs["line"]["width"] *= 2
            fig.add_trace(go.Scatter3d(**edge_kwargs))
        else:
            fig.add_trace(go.Scatter(**edge_kwargs))

    node_kwargs = dict(
        x=node_source["x"],
        y=node_source["y"],
        marker=dict(
            opacity=1.0,
            color=list(map(to_rgba_str, node_source["color"])),
            size=[300 * s for s in node_source["size"]],
            line=dict(
                color=list(map(to_rgba_str, node_source["outline_color"])),
                width=2,
            ),
        ),
        customdata=list(
            zip(node_source["tid"], node_source["shape"], node_source["tags"])
        ),
        hovertemplate=(
            "tid: %{customdata[0]}<br>"
            "shape: %{customdata[1]}<br>"
            "tags: %{customdata[2]}"
        ),
        mode="markers",
        name="",
    )
    if "z" in node_source:
        node_kwargs["z"] = node_source["z"]
        fig.add_trace(go.Scatter3d(**node_kwargs))
    else:
        fig.add_trace(go.Scatter(**node_kwargs))
    fig.show()
1096
-
1097
-
1098
- # ---------------------------- layout functions ----------------------------- #
1099
-
1100
-
1101
- def _normalize_positions(pos):
1102
- # normalize to unit square
1103
- xmin = ymin = zmin = float("inf")
1104
- xmax = ymax = zmax = float("-inf")
1105
- for x, y, *maybe_z in pos.values():
1106
- xmin = min(xmin, x)
1107
- xmax = max(xmax, x)
1108
- ymin = min(ymin, y)
1109
- ymax = max(ymax, y)
1110
- for z in maybe_z:
1111
- zmin = min(zmin, z)
1112
- zmax = max(zmax, z)
1113
-
1114
- # maintain aspect ratio:
1115
- # center each dimension separately
1116
- xmid, ymid, zmid = (xmin + xmax) / 2, (ymin + ymax) / 2, (zmin + zmax) / 2
1117
- # but scale all dimensions by the largest range
1118
- xdiameter, ydiameter, zdiameter = xmax - xmin, ymax - ymin, zmax - zmin
1119
- radius = max((xdiameter, ydiameter, zdiameter)) / 2
1120
-
1121
- for node, (x, y, *maybe_z) in pos.items():
1122
- pos[node] = (
1123
- (x - xmid) / radius,
1124
- (y - ymid) / radius,
1125
- *((z - zmid) / radius for z in maybe_z),
1126
- )
1127
-
1128
- return pos
1129
-
1130
-
1131
- def _rotate(xy, theta):
1132
- """Return a rotated set of points."""
1133
- s = np.sin(theta)
1134
- c = np.cos(theta)
1135
-
1136
- xyr = np.empty_like(xy)
1137
- xyr[:, 0] = c * xy[:, 0] - s * xy[:, 1]
1138
- xyr[:, 1] = s * xy[:, 0] + c * xy[:, 1]
1139
-
1140
- return xyr
1141
-
1142
-
1143
- def _span(xy):
1144
- """Return the vertical span of the points."""
1145
- return xy[:, 1].max() - xy[:, 1].min()
1146
-
1147
-
1148
def _massage_pos(pos, nangles=360, flatten=False):
    """Rotate 2D positions so that they cover a small vertical span.

    Parameters
    ----------
    pos : dict[hashable, (float, float)]
        Mapping of node to 2D coordinates.
    nangles : int, optional
        Number of evenly spaced rotation angles in ``[0, 2 * pi)`` to try.
    flatten : bool, optional
        If true, additionally halve the y-coordinates of the chosen rotation.

    Returns
    -------
    dict
        Mapping of the same keys to rows of the rotated coordinate array.
    """
    points = np.empty((len(pos), 2))
    for row, (x, y) in enumerate(pos.values()):
        points[row] = x, y

    # try every candidate angle and keep the first rotation with minimal span
    angles = np.linspace(0, 2 * np.pi, nangles, endpoint=False)
    best = min((_rotate(points, angle) for angle in angles), key=_span)

    if flatten:
        best[:, 1] /= 2

    return dict(zip(pos, best))
1163
-
1164
-
1165
def phyllotaxis_points(n):
    """Return ``n`` points spread over the unit sphere.

    J. Kogan, "A New Computationally Efficient Method for Spacing Points on
    a Sphere," Rose-Hulman Undergraduate Mathematics Journal, 18(2), 2017
    Article 5. scholar.rose-hulman.edu/rhumj/vol18/iss2/5.

    Parameters
    ----------
    n : int
        Number of points to generate.

    Returns
    -------
    list[list[float]]
        ``n`` points given as ``[x, y, z]`` lists, each of unit length.

    Notes
    -----
    The spiral formula divides by ``n - 1`` and collapses both points onto
    each other for ``n == 2``, so ``n == 1`` and ``n == 2`` are handled
    explicitly with points on the x-axis.
    """
    if n == 1:
        return [[-1.0, 0.0, 0.0]]
    if n == 2:
        # antipodal pair
        return [[-1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]

    def spherical_coordinate(x, y):
        # x is the longitude-like angle, y the latitude-like angle
        return [np.cos(x) * np.cos(y), np.sin(x) * np.cos(y), np.sin(y)]

    x = 0.1 + 1.2 * n
    start = -1.0 + 1.0 / (n - 1.0)
    increment = (2.0 - 2.0 / (n - 1.0)) / (n - 1.0)

    pts = []
    for j in range(n):
        s = start + j * increment
        latitude = np.pi / 2.0 * np.copysign(1, s) * (1.0 - np.sqrt(1.0 - abs(s)))
        pts.append(spherical_coordinate(s * x, latitude))
    return pts
1190
-
1191
-
1192
def layout_single_tensor(tn, dim=2):
    """Place a lone tensor at the origin with its indices arranged around it.

    Parameters
    ----------
    tn : TensorNetwork
        Network whose ``tensor_map`` contains exactly one tensor.
    dim : {2, 2.5, 3}, optional
        For ``2`` the indices are spaced evenly on the unit circle. Otherwise
        they are placed with ``phyllotaxis_points``; for ``2.5`` only the
        first two coordinates of that 3D layout are kept.

    Returns
    -------
    dict
        Mapping of the tensor id and each of its indices to coordinates.
    """
    ((tid, tensor),) = tn.tensor_map.items()

    project_back_to_2d = dim == 2.5
    if project_back_to_2d:
        dim = 3

    pos = {tid: (0.0,) * dim}
    if dim == 2:
        # evenly spaced around a circle
        angles = np.linspace(0, 2 * np.pi, tensor.ndim, endpoint=False)
        pos.update(
            (ind, (-np.cos(angle), np.sin(angle)))
            for ind, angle in zip(tensor.inds, angles)
        )
    else:
        # spread over a sphere
        pos.update(zip(tensor.inds, phyllotaxis_points(tensor.ndim)))

    if project_back_to_2d:
        return {key: coo[:2] for key, coo in pos.items()}
    return pos
1217
-
1218
-
1219
def layout_networkx(
    G,
    layout="kamada_kawai",
    pos0=None,
    fixed=None,
    dim=2,
    **kwargs,
):
    """Compute node positions with the networkx function ``<layout>_layout``.

    Parameters
    ----------
    G : networkx graph
        The graph to lay out.
    layout : str, optional
        Name of the networkx layout, without the ``"_layout"`` suffix.
    pos0 : dict, optional
        Initial positions, forwarded as ``pos`` only for the ``"spring"`` and
        ``"kamada_kawai"`` layouts; otherwise a warning is emitted.
    fixed : iterable, optional
        Nodes to keep fixed, forwarded only for the ``"spring"`` layout;
        otherwise a warning is emitted.
    dim : int, optional
        Dimension passed to the layout function.
    kwargs
        Additional options passed to the layout function.

    Returns
    -------
    The result of the selected networkx layout function.
    """
    import networkx as nx

    layout_fn = getattr(nx, layout + "_layout")

    if pos0 is not None:
        if layout in ("spring", "kamada_kawai"):
            kwargs["pos"] = pos0
        else:
            warnings.warn(
                "Initial positions supplied but layout is not spring-based, "
                "so `pos0` is being ignored."
            )

    if fixed is not None:
        if layout == "spring":
            kwargs["fixed"] = fixed
        else:
            warnings.warn(
                "Fixed positions supplied but layout is not spring-based, "
                "so `fixed` is being ignored."
            )

    return layout_fn(G, dim=dim, **kwargs)
1250
-
1251
-
1252
def layout_pygraphviz(
    G,
    layout="neato",
    pos0=None,
    fixed=None,
    dim=2,
    iterations=None,
    k=None,
    **kwargs,
):
    """Compute node positions by running a graphviz program via pygraphviz.

    Parameters
    ----------
    G : graph
        Graph providing ``edges()``; only nodes that appear in an edge are
        given a position in the result.
    layout : str, optional
        Name of the graphviz program, passed as ``prog``.
    pos0 : dict, optional
        Initial positions, written to the graphviz nodes as ``pos``. A
        warning is emitted since this does not seem to work currently.
    fixed : container, optional
        Nodes of ``pos0`` to mark as pinned.
    dim : {2, 2.5, 3}, optional
        Passed as the graph attributes ``dim`` and ``dimen``; ``2.5`` is
        translated to ``dim=3, dimen=2``.
    iterations : int, optional
        If given, passed as the graph attribute ``maxiter``.
    k : optional
        Ignored apart from emitting a warning if it is not ``None``.
    kwargs
        Further graph attributes, passed as ``-G<name>=<value>``.

    Returns
    -------
    dict
        Mapping of node to normalized position, additionally rotated with
        ``_massage_pos`` when ``dim < 3``.
    """
    # TODO: max iters
    # TODO: spring parameter
    # TODO: work out why pos0 and fix don't work
    import pygraphviz as pgv

    if k is not None:
        warnings.warn(
            "`k` is being ignored as layout is being done by pygraphviz."
        )

    agraph = pgv.AGraph()

    # create nodes
    if pos0 is not None:
        pinned = fixed or set()
        for node, coo in pos0.items():
            agraph.add_node(
                str(node),
                pos=",".join(f"{w:f}" for w in coo),
                pin="true" if node in pinned else "false",
            )

        warnings.warn(
            "Initial and fixed positions don't seem "
            "to work currently with pygraphviz."
        )

    # create edges, remembering which node each string name stands for
    node_of = {}
    for nodea, nodeb in G.edges():
        name_a, name_b = str(nodea), str(nodeb)
        node_of[name_a] = nodea
        node_of[name_b] = nodeb
        agraph.add_edge(name_a, name_b)

    # layout options
    if iterations is not None:
        kwargs["maxiter"] = iterations
    if dim == 2.5:
        kwargs.update(dim=3, dimen=2)
    else:
        kwargs.update(dim=dim, dimen=dim)
    args = " ".join(f"-G{opt}={val}" for opt, val in kwargs.items())

    # run layout algorithm
    agraph.layout(prog=layout, args=args)

    # extract layout from the comma separated "pos" attribute
    pos = {
        node: tuple(map(float, agraph.get_node(name).attr["pos"].split(",")))
        for name, node in node_of.items()
    }

    pos = _normalize_positions(pos)
    if dim < 3:
        pos = _massage_pos(pos)

    return pos
1320
-
1321
-
1322
def get_positions(
    tn,
    G,
    *,
    dim=2,
    fix=None,
    layout="auto",
    initial_layout="auto",
    refine_layout="auto",
    iterations="auto",
    k=None,
):
    """Compute positions for all nodes of the graph ``G``.

    By default an initial layout is computed and then relaxed with a second,
    force-based layout, optionally keeping some nodes at fixed positions.

    Parameters
    ----------
    tn : TensorNetwork or None
        The network ``G`` was built from. A network with a single tensor and
        nothing fixed is laid out with ``layout_single_tensor``.
    G : graph
        The graph to lay out.
    dim : {2, 2.5, 3}, optional
        Layout dimension; ``2.5`` lays out in 3D and then drops the
        z-coordinate.
    fix : dict, optional
        Positions of nodes that should stay fixed. ``None`` and an empty
        mapping both mean that nothing is fixed.
    layout : str, optional
        If not ``"auto"``, use this layout only, without relaxation.
    initial_layout : str, optional
        Layout for the starting positions. ``"auto"`` picks
        ``"kamada_kawai"`` for up to 100 nodes, else ``"spectral"``.
    refine_layout : str, optional
        Layout used for relaxation, ``"auto"`` means ``"spring"``.
    iterations : int or "auto", optional
        Number of relaxation iterations, ``0`` disables relaxation.
    k : optional
        Passed on to the refining layout.

    Returns
    -------
    dict
        Mapping of node to position.

    Raises
    ------
    ValueError
        If ``refine_layout`` is not a known layout.
    """
    if not fix:
        # an empty mapping used to reach ``min()`` / ``max()`` of an empty
        # sequence below and crash: treat it exactly like ``None``
        fix = None

    if (tn is not None) and (tn.num_tensors == 1) and (fix is None):
        # single tensor, layout manually
        return layout_single_tensor(tn, dim=dim)

    if layout != "auto":
        # don't use two step layout with relaxation
        initial_layout = layout
        iterations = 0

    if fix is None:
        fix = dict()
    else:
        fix = parse_dict_to_tids_or_inds(fix, tn)
        # find range with which to scale spectral points with
        fixed_xs = [xy[0] for xy in fix.values()]
        fixed_ys = [xy[1] for xy in fix.values()]
        xmin, xmax = min(fixed_xs), max(fixed_xs)
        ymin, ymax = min(fixed_ys), max(fixed_ys)
        if xmin == xmax:
            xmin, xmax = xmin - 1, xmax + 1
        if ymin == ymax:
            ymin, ymax = ymin - 1, ymax + 1
        xymin, xymax = min(xmin, ymin), max(xmax, ymax)

    if all(node in fix for node in G.nodes):
        # everything is already fixed -> simply normalize
        return _normalize_positions(fix)

    if initial_layout == "auto":
        # "kamada_kawai" is usually nicest, "spectral" is faster
        initial_layout = "kamada_kawai" if len(G) <= 100 else "spectral"

    if refine_layout == "auto":
        refine_layout = "spring"

    if iterations == "auto":
        # the smaller the graph, the more iterations we can afford
        iterations = max(200, 1000 - len(G))

    if dim == 2.5:
        dim = 3
        project_back_to_2d = True
    else:
        project_back_to_2d = False

    # use spectral or other layout as starting point
    if initial_layout in ("neato", "fdp", "sfdp", "dot"):
        pos0 = layout_pygraphviz(G, initial_layout, dim=dim)
    else:
        pos0 = layout_networkx(G, initial_layout, dim=dim)

    if fix:
        # overwrite with the fixed positions, rescaled to [-1, +1] with one
        # common factor for x and y
        scale = xymax - xymin
        pos0.update(
            {
                node: np.array(
                    (
                        2 * (xy[0] - xymin) / scale - 1,
                        2 * (xy[1] - xymin) / scale - 1,
                    )
                )
                for node, xy in fix.items()
            }
        )
        fixed = fix.keys()
    else:
        fixed = None

    # and then relax remaining using spring layout
    if iterations:
        if refine_layout == "spring":
            pos = layout_networkx(
                G,
                "spring",
                pos0=pos0,
                fixed=fixed,
                k=k,
                dim=dim,
                iterations=iterations,
            )
        elif refine_layout in ("fdp", "sfdp", "neato"):
            # XXX: currently doesn't seem to work with pos0 and fixed
            pos = layout_pygraphviz(
                G,
                refine_layout,
                pos0=pos0,
                fixed=fixed,
                k=k,
                dim=dim,
                iterations=iterations,
            )
        else:
            raise ValueError(f"Unknown refining layout {refine_layout}.")
    else:
        # no relaxation
        pos = pos0

    if project_back_to_2d:
        # ignore z-coordinate
        pos = {node: coo[:2] for node, coo in pos.items()}
        dim = 2

    # map all to range [-1, +1], but preserving aspect ratio
    pos = _normalize_positions(pos)

    if (not fix) and (dim == 2):
        # finally rotate them to cover a small vertical span
        pos = _massage_pos(pos)

    return pos
1457
-
1458
-
1459
- # ----------------------------- color functions ----------------------------- #
1460
-
1461
-
1462
def get_colors(color, custom_colors=None, alpha=None):
    """Build a mapping from each tag in ``color`` to an rgba color.

    Parameters
    ----------
    color : None, str or sequence of str
        The tag(s) to assign colors to. ``None`` gives an empty mapping and
        a single string is treated as one tag.
    custom_colors : sequence, optional
        Explicit colors to use; each is converted with
        ``matplotlib.colors.to_rgba`` and paired with the tags in order.
    alpha : float, optional
        Alpha value passed on when creating the colors.

    Returns
    -------
    dict
        Mapping of tag to color.
    """
    from matplotlib.colors import to_rgba

    from ..schematic import auto_colors

    if color is None:
        return dict()

    tags = (color,) if isinstance(color, str) else color

    if custom_colors is not None:
        rgbs = [to_rgba(c, alpha=alpha) for c in custom_colors]
    else:
        nc = len(tags)
        if nc <= 7:
            # few tags: request the default color sequence
            rgbs = auto_colors(nc, alpha=alpha, default_sequence=True)
        else:
            rgbs = auto_colors(nc, alpha)

    return dict(zip(tags, rgbs))
1485
-
1486
-
1487
def to_rgba_str(color, alpha=None):
    """Format ``color`` as a CSS-style ``"rgba(r, g, b, a)"`` string.

    The color is first converted with ``matplotlib.colors.to_rgba``. Float
    red, green and blue channels are scaled by 255 and truncated to ``int``;
    the alpha channel is inserted unchanged.
    """
    from matplotlib.colors import to_rgba

    *channels, a = to_rgba(color, alpha)
    r, g, b = (
        int(value * 255) if isinstance(value, float) else value
        for value in channels
    )
    return f"rgba({r}, {g}, {b}, {a})"
1495
-
1496
-
1497
def auto_color_html(s):
    """Wrap ``s`` in a bold HTML tag colored by a hash of its string form.

    Non-string input is converted with ``str`` first. The text is inserted
    into the markup as is, without HTML escaping.
    """
    from ..schematic import hash_to_color

    text = s if isinstance(s, str) else str(s)
    return f'<b style="color: {hash_to_color(text)};">{text}</b>'
1504
-
1505
-
1506
- # ---------------------------- tensor functions ----------------------------- #
1507
-
1508
-
1509
def visualize_tensor(tensor, **kwargs):
    """Visualize all entries of a tensor, with indices mapped into the plane
    and values mapped into a color wheel.

    This forwards ``tensor.data`` and all keyword options to
    ``xyzpy.visualize_tensor``, after defaulting ``legend`` and ``compass``
    to ``True`` and ``compass_labels`` to ``tensor.inds``.

    Parameters
    ----------
    tensor : Tensor
        The tensor to visualize.
    skew_factor : float, optional
        When there are more than two dimensions, a factor to scale the
        rotations by to avoid overlapping data points.
    size_map : bool, optional
        Whether to map the tensor value magnitudes to marker size.
    size_scale : float, optional
        An overall factor to scale the marker size by.
    alpha_map : bool, optional
        Whether to map the tensor value magnitudes to marker alpha.
    alpha_pow : float, optional
        The power to raise the magnitude to when mapping to alpha.
    alpha : float, optional
        The overall alpha to use for all markers if ``not alpha_map``.
    show_lattice : bool, optional
        Show a small grey dot for every 'lattice' point regardless of value.
    lattice_opts : dict, optional
        Options to pass to ``matplotlib.Axis.scatter`` for the lattice points.
    linewidths : float, optional
        The linewidth to use for the markers.
    marker : str, optional
        The marker to use for the markers.
    figsize : tuple, optional
        The size of the figure to create, if ``ax`` is not provided.
    ax : matplotlib.Axis, optional
        The axis to draw to. If not provided, a new figure will be created.

    Returns
    -------
    fig : matplotlib.Figure
        The figure containing the plot, or ``None`` if ``ax`` was provided.
    ax : matplotlib.Axis
        The axis containing the plot.
    """
    import xyzpy as xyz

    defaults = (
        ("legend", True),
        ("compass", True),
        ("compass_labels", tensor.inds),
    )
    for option, value in defaults:
        # only fill in options the caller did not specify
        kwargs.setdefault(option, value)

    return xyz.visualize_tensor(tensor.data, **kwargs)
1556
-
1557
-
1558
def choose_squarest_grid(x):
    """Choose near-square grid dimensions ``(m, n)`` based on ``x ** 0.5``.

    If the square root of ``x`` is a whole number both dimensions equal it.
    Otherwise ``m`` is the rounded root and ``n`` is the truncated root,
    increased by one if ``m`` times the truncated root is less than ``x``.
    """
    root = x**0.5
    if root.is_integer():
        side = int(root)
        return side, side

    m = int(round(root))
    floor_root = int(root)
    n = floor_root if m * floor_root >= x else floor_root + 1
    return m, n
1567
-
1568
-
1569
def visualize_tensors(
    tn,
    mode="network",
    r=None,
    r_scale=1.0,
    figsize=None,
    **visualize_opts,
):
    """Visualize all the entries of every tensor in this network.

    Parameters
    ----------
    tn : TensorNetwork
        The tensor network to visualize.
    mode : {'network', 'grid', 'row', 'col'}, optional
        How to arrange each tensor's visualization.

        - ``'network'``: arrange each tensor's visualization according to the
          automatic layout given by ``draw``.
        - ``'grid'``: arrange each tensor's visualization in a grid.
        - ``'row'``: arrange each tensor's visualization horizontally.
        - ``'col'``: arrange each tensor's visualization vertically.

    r : float, optional
        The absolute radius of each tensor's visualization, when
        ``mode='network'``.
    r_scale : float, optional
        A relative scaling factor for the radius of each tensor's
        visualization, when ``mode='network'``.
    figsize : tuple, optional
        The size of the figure to create.
    visualize_opts
        Supplied to ``visualize_tensor``.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported values.
    """
    from matplotlib import pyplot as plt

    if figsize is None:
        figsize = (2 * tn.num_tensors**0.4, 2 * tn.num_tensors**0.4)
    if r is None:
        r = 1.0 / tn.num_tensors**0.5
    r *= r_scale

    visualize_opts.setdefault("max_mag", None)
    visualize_opts.setdefault("size_scale", r)

    if mode == "network":
        fig = plt.figure(figsize=figsize)
        pos = tn.draw(get="pos")
        for tid, (x, y) in pos.items():
            if tid not in tn.tensor_map:
                # hyper index
                continue
            # map the position from [-1, +1] to figure coordinates
            x = (x + 1) / 2 - r / 2
            y = (y + 1) / 2 - r / 2
            ax = fig.add_axes((x, y, r / 2, r / 2))
            tn.tensor_map[tid].visualize(ax=ax, **visualize_opts)
    else:
        if mode == "grid":
            px, py = choose_squarest_grid(tn.num_tensors)
        elif mode == "row":
            px, py = tn.num_tensors, 1
            figsize = (2 * figsize[0], figsize[1] / 2)
        elif mode == "col":
            px, py = 1, tn.num_tensors
            figsize = (figsize[0] / 2, 2 * figsize[1])
        else:
            raise ValueError(
                f"Unknown mode {mode!r}, should be one of "
                "'network', 'grid', 'row' or 'col'."
            )

        # squeeze=False: always get a 2D array of axes, even for a 1x1 grid
        # where a single bare Axes (without ``.flat``) would be returned
        fig, axs = plt.subplots(py, px, figsize=figsize, squeeze=False)
        axes = list(axs.flat)

        num_drawn = 0
        for t in tn:
            t.visualize(ax=axes[num_drawn], **visualize_opts)
            num_drawn += 1

        # hide the axes that no tensor was drawn into
        for ax in axes[num_drawn:]:
            ax.set_axis_off()

    # transparent background
    fig.patch.set_alpha(0.0)

    plt.show()
    plt.close()