howso-visuals 3.1.2__tar.gz → 3.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/PKG-INFO +1 -1
- howso_visuals-3.2.0/howso/visuals/graph.py +480 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/howso/visuals/tests/test_utilities.py +53 -1
- howso_visuals-3.2.0/howso/visuals/utilities.py +165 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/howso_visuals.egg-info/PKG-INFO +1 -1
- howso_visuals-3.1.2/howso/visuals/graph.py +0 -258
- howso_visuals-3.1.2/howso/visuals/utilities.py +0 -79
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/.flake8 +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/.github/CODEOWNERS +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/.github/templates/version_summary.md +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/.github/workflows/build-pr.yml +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/.github/workflows/build-release.yml +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/.github/workflows/build.yml +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/.github/workflows/rebuild-requirements.yml +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/.gitignore +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/CONTRIBUTING.md +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/LICENSE-3RD-PARTY.txt +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/LICENSE.txt +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/README.md +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/bin/build.sh +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/config/latest-mt-debug-howso.yml +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/config/latest-mt-howso.yml +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/config/latest-mt-noavx-debug-howso.yml +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/config/latest-st-debug-howso.yml +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/config/latest-st-howso.yml +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/config/powershell/Download-Tzdata.ps1 +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/config/powershell/Helper-Functions.ps1 +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/howso/visuals/__init__.py +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/howso/visuals/colors.py +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/howso/visuals/data/iris.csv +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/howso/visuals/tests/conftest.py +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/howso/visuals/tests/test_graph.py +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/howso/visuals/tests/test_plot.py +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/howso/visuals/visuals.py +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/howso_visuals.egg-info/SOURCES.txt +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/howso_visuals.egg-info/dependency_links.txt +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/howso_visuals.egg-info/requires.txt +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/howso_visuals.egg-info/top_level.txt +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/pyproject.toml +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/requirements-3.10-dev.txt +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/requirements-3.10.txt +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/requirements-3.11-dev.txt +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/requirements-3.11.txt +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/requirements-3.12-dev.txt +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/requirements-3.12.txt +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/requirements-3.13-dev.txt +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/requirements-3.13.txt +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/ruff.toml +0 -0
- {howso_visuals-3.1.2 → howso_visuals-3.2.0}/setup.cfg +0 -0
|
@@ -0,0 +1,480 @@
|
|
|
1
|
+
from collections.abc import Callable, Collection, Mapping, Sequence
|
|
2
|
+
import math
|
|
3
|
+
from typing import Any, SupportsInt, TypeAlias
|
|
4
|
+
|
|
5
|
+
import networkx as nx
|
|
6
|
+
import numpy as np
|
|
7
|
+
import plotly.colors as pc
|
|
8
|
+
import plotly.graph_objects as go
|
|
9
|
+
from plotly.subplots import make_subplots
|
|
10
|
+
from sklearn.preprocessing import minmax_scale
|
|
11
|
+
|
|
12
|
+
from .utilities import compact_number
|
|
13
|
+
|
|
14
|
+
# Mapping of each graph node to its (x, y) position; the return type of the ``layout``
# callable used by ``plot_graph`` (default ``nx.shell_layout``).
LayoutMapping: TypeAlias = Mapping[Any, tuple[float, float]]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _create_edge_annotations(
    G: nx.Graph,  # noqa: N803
    pos: LayoutMapping,
    edge_attr: str | None = None,
    edge_attr_sigfigs: SupportsInt | None = 4,
    label_edges: bool = True,
    uncertain_edges: Collection[tuple[str, str]] | None = None,
    uncertain_edge_opacity: float = 0.3,
) -> tuple[list[go.layout.Annotation], list[dict[str, Any]]]:
    """
    Build the arrow annotations and label shapes used to draw the edges of ``G``.

    Each edge is drawn as a `Plotly` arrow annotation from ``pos[source]`` to ``pos[target]``.
    A matching invisible (opacity 0) line shape carries the edge's text label. In a directed
    graph, a pair of reciprocal edges is drawn once, with arrowheads at both ends.

    Parameters
    ----------
    G : nx.Graph
        The graph whose edges are drawn.
    pos : Mapping[Any, tuple[float, float]]
        Mapping of every node to its ``(x, y)`` position.
    edge_attr : str, optional
        Edge attribute whose values are min-max scaled into arrow widths in ``[2, 5]`` and
        used as edge labels. When None, every arrow has width 2 and labels are empty.
    edge_attr_sigfigs : SupportsInt | None, default 4
        Digits passed to ``compact_number`` when labelling. If None, the raw value is
        rendered with ``str`` formatting.
    label_edges : bool, default True
        Whether to put the edge attribute value in each edge's label.
    uncertain_edges : Collection[tuple[str, str]], optional
        Edges (matched in either orientation) drawn with ``uncertain_edge_opacity`` and no
        arrowheads.
    uncertain_edge_opacity : float, default 0.3
        Opacity of uncertain edges; all other edges use 0.8.

    Returns
    -------
    tuple[list[go.layout.Annotation], list[dict[str, Any]]]
        The edge arrow annotations, and keyword dicts for ``Figure.add_shape`` that draw
        the invisible labelled lines.
    """
    # Annotations are created to show the edges between nodes,
    # while invisible shapes with labels are created to label them with the edge weight.
    annotations = []
    shapes = []
    directed = nx.is_directed(G)

    widths = None
    unscaled_widths = None
    if edge_attr is not None:
        unscaled_widths = [d[edge_attr] for _, _, d in G.edges(data=True)]
        # ``minmax_scale`` rejects empty input, so only scale when the graph has edges.
        if unscaled_widths:
            widths = minmax_scale(np.array(unscaled_widths).reshape(-1, 1), (2, 5)).reshape(-1)

    # Reverse edges already drawn as part of a bidirectional arrow.
    edge_blacklist = set()

    # ``G.edges()`` and ``G.edges(data=True)`` iterate in the same order, so ``i`` indexes
    # ``widths`` / ``unscaled_widths`` consistently.
    for i, (s, d) in enumerate(G.edges()):
        if (s, d) in edge_blacklist:
            continue

        x0, y0 = pos[s]
        x1, y1 = pos[d]
        width = widths[i] if widths is not None else 2

        if directed and G.has_edge(d, s):
            edge_blacklist.add((d, s))
            arrowside = "end+start"
        elif not directed:
            arrowside = "none"
        else:
            arrowside = "end"

        if uncertain_edges and ((s, d) in uncertain_edges or (d, s) in uncertain_edges):
            opacity = uncertain_edge_opacity
            arrowside = "none"
        else:
            opacity = 0.8

        annotations.append(
            go.layout.Annotation(
                ax=x0,
                ay=y0,
                axref="x",
                ayref="y",
                x=x1,
                y=y1,
                xref="x",
                yref="y",
                showarrow=True,
                arrowhead=4,
                # Stand off from both endpoints so arrows stop at the node markers' edge.
                standoff=40.5,
                startstandoff=37.5,
                arrowside=arrowside,
                arrowwidth=width,
                opacity=opacity,
                captureevents=True,
            )
        )

        if label_edges and unscaled_widths is not None:
            if edge_attr_sigfigs is not None:
                shape_label = compact_number(unscaled_widths[i], edge_attr_sigfigs)
            else:
                shape_label = f"{unscaled_widths[i]}"
        else:
            shape_label = ""

        # White text-shadow keeps the label readable on top of the edge line.
        shape_label = (
            '<span style="text-shadow: -1px -1px 0 #fff, 1px -1px 0 #fff, -1px 1px 0 #fff, 1px 1px 0 #fff;">'
            f"{shape_label}</span>"
        )

        shapes.append(
            dict(
                type="line",
                x0=x0,
                y0=y0,
                x1=x1,
                y1=y1,
                xref="x",
                yref="y",
                label=dict(text=shape_label),
                opacity=0,
            )
        )

    return annotations, shapes
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def plot_graph(
    G: nx.Graph,  # noqa: N803
    *,
    colorscale: str | Sequence[tuple[float, str]] = "Bluered",
    cscale_tuple: tuple[float, float, float] | None = None,
    edge_attr_sigfigs: SupportsInt | None = 4,
    edge_attr: str | None = None,
    label_edges: bool = True,
    layout: Callable[[nx.Graph], LayoutMapping] = nx.shell_layout,
    node_color: list[float] | None = None,
    subtitle: str | None = None,
    title: str = "Causal Graph",
    uncertain_edges: Collection[tuple[str, str]] | None = None,
    uncertain_edge_opacity: float = 0.3,
) -> go.Figure:
    """
    Plot a ``networkx`` graph using `Plotly`.

    Parameters
    ----------
    G : nx.Graph
        The graph to plot.
    colorscale : str | Sequence[tuple[float, str]], default "Bluered"
        The colorscale to use when plotting nodes using ``node_color``. Defaults to `Plotly`'s
        "Bluered" colorscale.
    cscale_tuple : tuple[float, float, float], optional
        The tuple of values (``cmin``, ``cmid``, ``cmax``) to use for the colorscale. If None,
        ``(1, 3, 10)`` will be used. The color axis always starts at 0; ``cmin`` is shown as a
        colorbar tick, and the ``cmax`` tick is labelled "≥cmax".
    edge_attr : str, optional
        The name of the edge attribute to use when scaling the size of the edges. This should
        be an attribute that is contained within ``G``.
    edge_attr_sigfigs : SupportsInt | None, default 4
        The number of significant figures to round to when labelling each edge. If None, no rounding
        will be performed.
    label_edges : bool, default True
        Whether to label plotted edges.
    layout : Callable[[nx.Graph], Mapping[Any, tuple[float, float]]], default nx.shell_layout
        A callable which generates a mapping of nodes to ``(x, y)`` coordinates. It is called as
        ``layout(G, center=(1, 1))``, so it must accept a ``center`` keyword argument.
    node_color : list[float], optional
        The data to use when determining the color for each node, in the same order as
        ``G.nodes()``. The values are also shown as "Destination MIR" in the node hover text.
    subtitle : str, optional
        The subtitle of the plot.
    title : str, default "Causal Graph"
        The title of the plot.
    uncertain_edges : Collection[tuple[str, str]], optional
        Edges that are deemed uncertain by the caller. These will be plotted with an opacity equal to
        ``uncertain_edge_opacity`` and will not have directional arrows.
    uncertain_edge_opacity : float, default 0.3
        The opacity to use when plotting edges contained in ``uncertain_edges``.

    Returns
    -------
    go.Figure
        The resultant `Plotly` figure.
    """
    pos = layout(G, center=(1, 1))

    text = []
    node_x = []
    node_y = []
    for node in G.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        # This places a 1px black border around the (white) node labels.
        text.append(
            f'<span style="text-shadow: -1px -1px 0 #000, 1px -1px 0 #000, -1px 1px 0 #000, 1px 1px 0 #000;">{node}</span>'
        )

    hovertemplate = "<b>%{text}</b>"
    if node_color is not None:
        hovertemplate += "<br>Destination MIR: %{customdata[0]:.4f}</br>"

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        text=text,
        textposition="middle center",
        mode="markers+text",
        marker=dict(
            color=node_color,
            coloraxis="coloraxis",
            size=75,
        ),
        zorder=999,
        textfont=dict(color="white"),
        name="Nodes",
        customdata=[[value] for value in node_color] if node_color is not None else None,
        hovertemplate=hovertemplate,
    )

    annotations, shapes = _create_edge_annotations(
        G,
        pos,
        edge_attr=edge_attr,
        edge_attr_sigfigs=edge_attr_sigfigs,
        label_edges=label_edges,
        uncertain_edges=uncertain_edges,
        uncertain_edge_opacity=uncertain_edge_opacity,
    )
    fig = go.Figure(
        layout=go.Layout(
            # Title text is set at the end; only the font size is fixed here.
            title=dict(font=dict(size=16)),
            showlegend=False,
            hovermode="closest",
            margin=dict(b=8, l=8, r=8, t=48),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False, constrain="domain"),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, constrain="domain"),
            annotations=annotations,
        )
    )

    # The color axis always starts at 0; ``cmin`` only contributes a colorbar tick.
    cbot = 0
    if cscale_tuple is None:
        cmin, cmid, cmax = 1, 3, 10
    else:
        cmin, cmid, cmax = cscale_tuple[0], cscale_tuple[1], cscale_tuple[2]

    # NOTE(review): Plotly documents that ``cmid`` has no effect when ``cauto`` is false,
    # which explicitly setting ``cmin``/``cmax`` implies — confirm ``cmid`` is honored here.
    fig.update_layout(
        coloraxis=dict(
            colorscale=colorscale,
            cmin=cbot,
            cmid=cmid,
            cmax=cmax,
            colorbar=dict(
                title="Missing Information",
                tickvals=[cbot, cmin, cmid, cmax],
                ticktext=[f"{cbot}", f"{cmin}", f"{cmid}", f"≥{cmax}"],
            ),
        ),
    )

    for s in shapes:
        fig.add_shape(**s)
    fig.add_trace(node_trace)

    fig.update_layout(
        title=dict(text=title, subtitle=dict(text=subtitle), xref="paper", xanchor="left", x=0),
        width=1000,
        height=750,
    )
    return fig
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def _remap_axis_ref(ref: str | None, ax_idx: int) -> str | None:
|
|
267
|
+
"""Remap an axis ref given new index."""
|
|
268
|
+
if ref in (None, "pixel", "paper"):
|
|
269
|
+
return ref
|
|
270
|
+
if ref.startswith("x"):
|
|
271
|
+
return "x" if ax_idx == 1 else f"x{ax_idx}"
|
|
272
|
+
if ref.startswith("y"):
|
|
273
|
+
return "y" if ax_idx == 1 else f"y{ax_idx}"
|
|
274
|
+
return ref
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def compare_network_figures(  # noqa: PLR0912, PLR0915
    figures: Sequence[go.Figure],
    *,
    columns: int | None = None,
    per_row_colorbar: bool = True,
    subplot_titles: Sequence[str | None] | None = None,
    title: str | None = None,
    width: int = 800,
    height: int = 650,
) -> go.Figure:
    """
    Combine multiple network graphs for comparison.

    The input figures are not modified: their traces are copied before being recolored.

    Parameters
    ----------
    figures : Sequence[go.Figure]
        Network figures to compare. All must share the same colorbar scale.
    columns : int, optional
        The number of columns of figures. If unspecified, all figures will be rendered side by side on the same row.
    per_row_colorbar : bool, default True
        Show the colorbar on each row.
    subplot_titles : Sequence[str | None], optional
        Set the title of each individual figure. Using None will inherit the original figure title.
    title : str, optional
        Set an overall figure title.
    width : int, default 800
        The width of each individual figure. Clamped to a minimum of 400.
    height : int, default 650
        The height of each individual figure. Clamped to a minimum of 400.

    Returns
    -------
    go.Figure
        The resulting `Plotly` figure.

    Raises
    ------
    ValueError
        If ``figures`` is empty, ``columns`` is less than 1, or the figures do not share the
        same colorbar scale.
    TypeError
        If a non-empty figure contains a trace that is not a ``go.Scatter``.
    """
    # Validate before ``columns`` is used to size the grid.
    if columns is not None and columns < 1:
        raise ValueError("When specified, `columns` must be greater than 0.")
    n_figs = len(figures)
    if n_figs == 0:
        # Without this, ``n_figs / n_cols`` below divides by zero.
        raise ValueError("At least one figure must be provided.")

    n_cols = min(n_figs, columns) if columns else n_figs
    n_rows = math.ceil(n_figs / n_cols)
    height = max(height, 400)
    width = max(width, 400)

    uses_coloraxis = False
    coloraxis = None
    color_ranges = []
    resolved_titles: list[str] = []
    for i, fig in enumerate(figures):
        # Capture titles provided or from figures. Each title must be truthy: the left-align
        # loop below assumes one title annotation per subplot.
        override = subplot_titles[i] if subplot_titles is not None and i < len(subplot_titles) else None
        if override is None:
            resolved_titles.append(fig.layout.title.text or " ")
        else:
            resolved_titles.append(override or " ")

        if not fig.data:
            continue  # Skip empty figures

        # Validate trace and coloraxis
        for trace in fig.data:
            if not isinstance(trace, go.Scatter):
                raise TypeError("All figures must be network graph figures.")
        ca = fig.layout.coloraxis
        if ca is not None:
            color_ranges.append((ca.cmin, ca.cmid, ca.cmax))
            if coloraxis is None:
                # Capture first non empty network plot's color axis
                coloraxis = ca.to_plotly_json()
    if len(set(color_ranges)) > 1:
        raise ValueError("All figures must share the same colorbar scale.")

    horizontal_spacing = 0.02 / n_cols
    vertical_gap = 0.05 if any(t != " " for t in resolved_titles) else 0.01
    vertical_spacing = (vertical_gap / n_rows) * (800 / height)
    sub = make_subplots(
        rows=n_rows,
        cols=n_cols,
        subplot_titles=resolved_titles,
        horizontal_spacing=horizontal_spacing,
        vertical_spacing=vertical_spacing,
    )

    # Left align titles with the left edge of their subplot's x-axis domain. Subplot axes
    # are numbered row-major, so subplot ``i`` uses ``xaxis{i + 1}`` ("xaxis" for the first).
    for i, annotation in enumerate(sub.layout.annotations):
        ax_key = "xaxis" if i == 0 else f"xaxis{i + 1}"
        annotation.x = sub.layout[ax_key].domain[0]
        annotation.xanchor = "left"

    # Add all the plots
    for index, fig in enumerate(figures):
        row = (index // n_cols) + 1
        col = (index % n_cols) + 1

        if not fig.data:
            # Render empty plot
            sub.update_xaxes(range=[0, 1], row=row, col=col)  # Prepare axes for centering label
            sub.update_yaxes(range=[0, 1], row=row, col=col)
            sub.add_shape(type="rect", x0=0, y0=0, x1=1, y1=1, line=dict(width=0), row=row, col=col)
            sub.add_annotation(
                x=0.5,
                y=0.5,
                text="No Results",
                showarrow=False,
                font=dict(size=14),
                xref=_remap_axis_ref("x", index + 1),
                yref=_remap_axis_ref("y", index + 1),
            )
        else:
            # Traces
            for trace in fig.data:
                # Copy first so the recoloring below does not mutate the caller's figure.
                new_trace = go.Scatter(trace)
                if hasattr(new_trace.marker, "coloraxis"):
                    # Apply coloraxis or marker color to each trace
                    if new_trace.marker.color is not None:
                        uses_coloraxis = coloraxis is not None  # If at least one plot uses MIR
                        if per_row_colorbar:
                            # coloraxis per row
                            new_trace.marker.coloraxis = f"coloraxis{row}" if row > 1 else "coloraxis"
                        else:
                            # single shared coloraxis
                            new_trace.marker.coloraxis = "coloraxis"
                    else:
                        # Figure doesn't use MIR, use static color for all nodes
                        new_trace.marker.color = pc.qualitative.Safe_r[0]
                        new_trace.hoverlabel.font.color = "#ffffff"
                sub.add_trace(new_trace, row=row, col=col)

            # Shapes (edges); row/col assign the subplot's axis refs.
            for shape in fig.layout.shapes:
                s = shape.to_plotly_json()
                s.pop("xref", None)
                s.pop("yref", None)
                sub.add_shape(**s, row=row, col=col)

            # Annotations (edge arrows)
            for ann in fig.layout.annotations:
                a = ann.to_plotly_json()
                a["xref"] = _remap_axis_ref(a.get("xref"), index + 1)
                a["yref"] = _remap_axis_ref(a.get("yref"), index + 1)
                a["axref"] = _remap_axis_ref(a.get("axref"), index + 1)
                a["ayref"] = _remap_axis_ref(a.get("ayref"), index + 1)
                sub.add_annotation(**a)

        sub.update_xaxes(
            showgrid=False,
            zeroline=False,
            fixedrange=True,
            showticklabels=False,
            constrain="domain",
            row=row,
            col=col,
        )
        sub.update_yaxes(
            showgrid=False,
            zeroline=False,
            fixedrange=True,
            showticklabels=False,
            constrain="domain",
            row=row,
            col=col,
        )

    # Overall figure layout
    colorbar_width = 100 if uses_coloraxis else 0
    sub.update_layout(
        title=dict(text=title, xref="paper", xanchor="left", x=0),
        height=height * n_rows,
        width=(width * n_cols) + colorbar_width,
        showlegend=False,
        dragmode=False,
        hovermode="closest",
        margin=dict(b=8, l=8, r=8, t=60 if title else 40),
    )

    # Update layout of each coloraxis, positioning each colorbar at the top of its row.
    if uses_coloraxis and coloraxis is not None:
        row_height = (1 - vertical_spacing * (n_rows - 1)) / n_rows
        for row in range(1, n_rows + 1):
            coloraxis_id = "coloraxis" if row == 1 else f"coloraxis{row}"
            y_top = 1 - (row - 1) * (row_height + vertical_spacing)
            sub.update_layout(
                {
                    coloraxis_id: {
                        **coloraxis,
                        "colorbar": {
                            **coloraxis.get("colorbar", {}),
                            "len": height - 60,
                            "lenmode": "pixels",
                            "y": y_top,
                            "yanchor": "top",
                        },
                    }
                }
            )

    return sub
|
|
@@ -1,7 +1,59 @@
|
|
|
1
1
|
import plotly.graph_objects as go
|
|
2
2
|
import pytest
|
|
3
3
|
|
|
4
|
-
from howso.visuals.utilities import nice_range, normalize_axis_range
|
|
4
|
+
from howso.visuals.utilities import compact_number, nice_range, normalize_axis_range
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
# (value, digits, expected) cases for ``compact_number``, grouped by magnitude.
_COMPACT_NUMBER_CASES = [
    # Zero
    (0, 3, "0"),
    # Small decimals (milli range, handled without SI suffix)
    (0.001, 3, "0.001"),
    (0.001, 1, "1m"),  # too few digits, uses m
    (0.0099, 3, "0.01"),
    (0.0001, 4, "0.0001"),
    (0.0001, 3, "100µ"),
    (0.00001, 8, "0.00001"),  # no trailing 0s
    (-0.00001, 6, "-0.00001"),
    (0.00000009999, 4, "99.99n"),
    (-0.00000009999, 4, "-99.99n"),
    (-0.000000099999, 4, "-100n"),  # round up
    (1e-6, 3, "1µ"),
    (2.5e-6, 3, "2.5µ"),
    (1e-9, 3, "1n"),
    # Base range
    (1, 3, "1"),
    (42, 3, "42"),
    (999, 3, "999"),
    (999, 1, "999"),
    # Kilo
    (1000, 3, "1k"),
    (1500, 3, "1.5k"),
    (1234, 4, "1.234k"),
    (1234, 2, "1.2k"),
    (1234, 1, "1k"),
    # Mega
    (999_999, 3, "1M"),  # round up
    (1_000_000, 3, "1M"),
    (2_500_000, 3, "2.5M"),
    # Giga (B)
    (1_000_000_000, 3, "1B"),
    (1_234_000_000, 3, "1.23B"),
    # Tera
    (1e12, 3, "1T"),
    # Peta
    (1e15, 3, "1P"),
    # Negative values
    (-1000, 3, "-1k"),
    (-0.001, 3, "-0.001"),
    (-1_500_000, 3, "-1.5M"),
]


@pytest.mark.parametrize(("value", "digits", "expected"), _COMPACT_NUMBER_CASES)
def test_compact_number(value, digits, expected):
    """Check that each value is rendered in its expected compact form."""
    formatted = compact_number(value, digits)
    assert formatted == expected
|
|
5
57
|
|
|
6
58
|
|
|
7
59
|
@pytest.mark.parametrize(
|
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from typing import Literal, SupportsFloat, SupportsInt
|
|
3
|
+
|
|
4
|
+
import plotly.graph_objects as go
|
|
5
|
+
|
|
6
|
+
# (scale, symbol) pairs in strictly descending powers of 1000, from 1e30 down to 1e-30.
# ``compact_number`` looks entries up as ``SI_PREFIXES[(30 - exponent) // 3]``, so the
# order and the 10**3 spacing must be preserved.
SI_PREFIXES: list[tuple[float, str]] = [
    # Multiples
    (1e30, "Q"),  # quetta
    (1e27, "R"),  # ronna
    (1e24, "Y"),  # yotta
    (1e21, "Z"),  # zetta
    (1e18, "E"),  # exa
    (1e15, "P"),  # peta
    (1e12, "T"),  # tera
    (1e9, "B"),  # giga, shown as B (billion) rather than G
    (1e6, "M"),  # mega
    (1e3, "k"),  # kilo
    # Unscaled
    (1e0, ""),
    # Fractions
    (1e-3, "m"),  # milli
    (1e-6, "µ"),  # micro
    (1e-9, "n"),  # nano
    (1e-12, "p"),  # pico
    (1e-15, "f"),  # femto
    (1e-18, "a"),  # atto
    (1e-21, "z"),  # zepto
    (1e-24, "y"),  # yocto
    (1e-27, "r"),  # ronto
    (1e-30, "q"),  # quecto
]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def compact_number(value: SupportsFloat, digits: SupportsInt = 3) -> str:
    """
    Format a number to specified digits with SI prefix.

    Values below 1 in magnitude that fit within ``digits`` decimal places are written as
    plain decimals (e.g. ``"0.001"``). Everything else is scaled to an SI prefix
    (e.g. ``"1.5k"``, ``"100µ"``). Magnitudes whose decimal exponent lies outside
    ``[-30, 30]`` fall back to scientific notation. Non-finite values (NaN, ±inf) are
    returned as ``str(value)``.

    Parameters
    ----------
    value : float
        The value to format.
    digits : int, default 3
        The number of digits to format to. In the plain-decimal form this is the number of
        decimal places; in the SI-prefixed form it is the number of significant digits.

    Returns
    -------
    str
        The formatted number.
    """
    value = float(value)
    digits = int(digits)

    if value == 0:
        return "0"
    if not math.isfinite(value):
        # ``math.log10`` / ``math.floor`` below raise on NaN and infinities.
        return str(value)

    exp = math.floor(math.log10(abs(value)))

    # Fallback to sci notation if we run out of SI prefix
    if abs(exp) > 30:
        return f"{value:.{digits}g}"

    # Don't use SI prefix if decimal places can fit it
    if -digits <= exp < 0:
        rounded = round(value, digits)
        if rounded != 0:
            # Strip trailing zeros, then any dangling decimal point left when rounding
            # reaches a whole number (e.g. 0.96 with 1 digit: "1.0" -> "1." -> "1").
            return f"{rounded:.{digits}f}".rstrip("0").rstrip(".")

    # Use SI prefix
    exp_si = (exp // 3) * 3  # Snap down to nearest SI prefix boundary
    scaled = value / 10**exp_si

    if math.floor(math.log10(abs(round(scaled)))) >= 3:
        # Move up to next prefix if rounding pushes scaled to 1000
        exp_si += 3
        scaled = value / 10**exp_si

    prefix = SI_PREFIXES[(30 - exp_si) // 3][1]
    formatted = f"{scaled:.{digits}g}"
    if "e" in formatted:
        # fallback if it can't fit in digits and g produces sci notation
        formatted = f"{scaled:.0f}"
    return f"{formatted}{prefix}"
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def nice_range(lower: float, upper: float) -> tuple[float, float]:
    """
    Expand a value interval to rounded bounds.

    The bounds are snapped outward to multiples of a step chosen from 1, 2 or 5 times a
    power of ten, derived from the span of the interval. Reversed bounds are swapped, and
    a zero-width interval is returned as-is.

    Parameters
    ----------
    lower : float
        The lower bound of the data range.
    upper : float
        The upper bound of the data range.

    Returns
    -------
    tuple of float
        A (nice_min, nice_max) pair suitable for use as an axis range.
    """
    if lower > upper:
        lower, upper = upper, lower

    width = upper - lower
    if width == 0:
        return lower, upper

    # Start from the power of ten at or below the span...
    step = 10 ** math.floor(math.log10(width))
    ratio = width / step
    # ...then refine to a 0.2x or 0.5x step for spans only a few multiples wide.
    for threshold, divisor in ((2, 5), (5, 2)):
        if ratio < threshold:
            step /= divisor
            break

    return math.floor(lower / step) * step, math.ceil(upper / step) * step
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def normalize_axis_range(
    *figures: go.Figure,
    axis: Literal["y", "x"] = "y",
    bounds: tuple[float, float] | None = None,
) -> None:
    """
    Normalize the y or x axis range of all Figures.

    When ``bounds`` is not given, the range is computed from every finite numeric value on
    ``axis`` across all traces, always including 0, and then expanded with ``nice_range``.
    A trace's values stop being read at the first value that cannot be converted to
    ``float`` (e.g. categorical or datetime values). If no non-zero extent is found, the
    figures are left unchanged.

    Parameters
    ----------
    axis : {"x", "y"}, default "y"
        The axis to adjust.
    bounds : tuple of float, optional
        The axis range to normalize all figures to. If unset, calculates a range
        given the axis values across all figures.
    """
    if bounds is None:
        bounds = (0, 0)
        for fig in figures:
            for trace in fig.data:
                values = getattr(trace, axis, None)
                if values is None:
                    continue
                for val in values:
                    if val is None:
                        continue
                    try:
                        value = float(val)
                    except (TypeError, ValueError):
                        # ValueError: non-numeric strings; TypeError: e.g. datetimes.
                        break  # Not a normalizable axis
                    if math.isfinite(value):
                        bounds = min(bounds[0], value), max(bounds[1], value)
        if bounds == (0, 0):
            return  # no bounds detected
        bounds = nice_range(*bounds)

    for fig in figures:
        if axis == "x":
            fig.update_xaxes(range=bounds)
        elif axis == "y":
            fig.update_yaxes(range=bounds)
|
|
@@ -1,258 +0,0 @@
|
|
|
1
|
-
from collections.abc import Callable, Collection, Mapping, Sequence
|
|
2
|
-
from typing import Any, SupportsInt, TypeAlias
|
|
3
|
-
|
|
4
|
-
import networkx as nx
|
|
5
|
-
import numpy as np
|
|
6
|
-
import plotly.graph_objects as go
|
|
7
|
-
from sklearn.preprocessing import minmax_scale
|
|
8
|
-
|
|
9
|
-
LayoutMapping: TypeAlias = Mapping[Any, tuple[float, float]]
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
def _create_edge_annotations(
    G: nx.Graph,  # noqa: N803
    pos: LayoutMapping,
    edge_attr: str | None = None,
    edge_attr_sigfigs: SupportsInt | None = 4,
    label_edges: bool = True,
    uncertain_edges: Collection[tuple[str, str]] | None = None,
    uncertain_edge_opacity: float = 0.3,
) -> tuple[list[go.layout.Annotation], list[dict[str, Any]]]:
    """
    Build the edge arrows and edge-label shapes used by ``plot_graph``.

    Each edge is drawn as an arrow annotation. An invisible (``opacity=0``) line
    shape is laid along the same edge purely to carry its text label.

    Parameters
    ----------
    G : nx.Graph
        The graph whose edges are drawn.
    pos : LayoutMapping
        Mapping of each node to its ``(x, y)`` coordinates.
    edge_attr : str, optional
        Edge attribute whose values are min-max scaled into arrow widths in
        ``[2, 5]`` and used as edge labels. When None, every arrow has width 2
        and labels are empty.
    edge_attr_sigfigs : SupportsInt | None, default 4
        Passed to ``round`` as ``ndigits`` when labelling edges (note: this is a
        number of decimal places). If None, no rounding is performed.
    label_edges : bool, default True
        Whether to write the edge attribute value into each label.
    uncertain_edges : Collection[tuple[str, str]], optional
        Edges (matched in either direction) that are drawn with
        ``uncertain_edge_opacity`` and without arrowheads.
    uncertain_edge_opacity : float, default 0.3
        Opacity of edges in ``uncertain_edges``. Other edges use 0.8.

    Returns
    -------
    tuple[list[go.layout.Annotation], list[dict[str, Any]]]
        The arrow annotations, and the keyword-argument dicts for
        ``Figure.add_shape`` that create the label shapes.
    """
    # Annotations are created to show the edges between nodes,
    # while invisible shapes with labels are created to label them with the edge weight.
    annotations = []
    shapes = []
    directed = nx.is_directed(G)

    widths = None
    unscaled_widths = None
    if edge_attr is not None:
        unscaled_widths = [d[edge_attr] for _, _, d in G.edges(data=True)]
        # minmax_scale rejects an empty array, so only scale when there are edges.
        if unscaled_widths:
            widths = minmax_scale(np.array(unscaled_widths).reshape(-1, 1), (2, 5))
            widths = widths.reshape(-1)

    # Accept any SupportsInt (e.g. numpy integers) for the rounding precision.
    ndigits = int(edge_attr_sigfigs) if edge_attr_sigfigs is not None else None

    edge_blacklist = set()

    for i, (s, d) in enumerate(G.edges()):
        # The reverse of an already-drawn bidirectional edge is skipped.
        if (s, d) in edge_blacklist:
            continue

        x0, y0 = pos[s]
        x1, y1 = pos[d]
        width = widths[i] if widths is not None else 2

        if directed and G.has_edge(d, s):
            # Draw a bidirectional edge once, with arrowheads at both ends.
            edge_blacklist.add((d, s))
            arrowside = "end+start"
        elif not directed:
            arrowside = "none"
        else:
            arrowside = "end"

        if uncertain_edges and ((s, d) in uncertain_edges or (d, s) in uncertain_edges):
            opacity = uncertain_edge_opacity
            arrowside = "none"
        else:
            opacity = 0.8

        annotations.append(
            go.layout.Annotation(
                ax=x0,
                ay=y0,
                axref="x",
                ayref="y",
                x=x1,
                y=y1,
                xref="x",
                yref="y",
                showarrow=True,
                arrowhead=4,
                # Keep the arrow ends outside the node markers.
                standoff=40.5,
                startstandoff=37.5,
                arrowside=arrowside,
                arrowwidth=width,
                opacity=opacity,
                captureevents=True,
            )
        )

        if label_edges and unscaled_widths is not None:
            value = unscaled_widths[i]
            shape_label = f"{round(value, ndigits)}" if ndigits is not None else f"{value}"
        else:
            shape_label = ""

        # White text-shadow outline keeps the label readable over the arrow.
        shape_label = (
            '<span style="text-shadow: -1px -1px 0 #fff, 1px -1px 0 #fff, -1px 1px 0 #fff, 1px 1px 0 #fff;">'
            f"{shape_label}</span>"
        )

        shapes.append(
            dict(
                type="line",
                x0=x0,
                y0=y0,
                x1=x1,
                y1=y1,
                xref="x",
                yref="y",
                label=dict(text=shape_label),
                opacity=0,
            )
        )

    return annotations, shapes
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
def plot_graph(
    G: nx.Graph,  # noqa: N803
    *,
    colorscale: str | Sequence[tuple[float, str]] = "Bluered",
    cscale_tuple: tuple[float, float, float] | None = None,
    edge_attr_sigfigs: SupportsInt | None = 4,
    edge_attr: str | None = None,
    label_edges: bool = True,
    layout: Callable[[nx.Graph], LayoutMapping] = nx.shell_layout,
    node_color: list[float] | None = None,
    subtitle: str | None = None,
    title: str = "Causal Graph",
    uncertain_edges: Collection[tuple[str, str]] | None = None,
    uncertain_edge_opacity: float = 0.3,
) -> go.Figure:
    """
    Plot a ``networkx`` graph using `Plotly`.

    Parameters
    ----------
    G : nx.Graph
        The graph to plot.
    colorscale : str | Sequence[tuple[float, str]], default "Bluered"
        The colorscale to use when plotting nodes using ``node_color``. Defaults to `Plotly`'s "Bluered"
        colorscale.
    cscale_tuple : tuple[float, float, float], optional
        The tuple of values (``cmin``, ``cmid``, ``cmax``) to use for the colorscale. The color axis always
        starts at 0: ``cmin`` is only added as a colorbar tick, ``cmid`` is the colorscale midpoint and
        ``cmax`` its upper bound (labelled ``≥cmax``). If None, ``(1, 3, 10)`` will be used.
    edge_attr : str, optional
        The name of the edge attribute to use when scaling the size of the edges. This should
        be an attribute that is contained within ``G``.
    edge_attr_sigfigs : SupportsInt | None, default 4
        The precision passed to ``round`` when labelling each edge (i.e. decimal places). If None,
        no rounding will be performed.
    label_edges : bool, default True
        Whether to label plotted edges.
    layout : Callable[nx.Graph, Mapping[Any, tuple[float, float]]], default nx.shell_layout
        A callable which generates a mapping of nodes to ``(x, y)`` coordinates. It is called with
        ``center=(1, 1)``, so it must accept a ``center`` keyword argument.
    node_color : list[float], optional
        The data to use when determining the color for each node. When given, it is also shown in
        the node hover text.
    subtitle : str, optional
        The subtitle of the plot.
    title : str, default "Causal Graph"
        The title of the plot.
    uncertain_edges : Collection[tuple[str, str]], optional
        Edges that are deemed uncertain by the caller. These will be plotted with an opacity equal to
        ``uncertain_edge_opacity`` and will not have directional arrows.
    uncertain_edge_opacity : float, default 0.3
        The opacity to use when plotting edges contained in ``uncertain_edges``.

    Returns
    -------
    go.Figure
        The resultant `Plotly` figure.
    """
    pos = layout(G, center=(1, 1))

    text = []
    node_x = []
    node_y = []
    for node in G.nodes():
        text.append(node)
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)

    # This places a 1px black border around the node labels.
    text = [
        f'<span style="text-shadow: -1px -1px 0 #000, 1px -1px 0 #000, -1px 1px 0 #000, 1px 1px 0 #000;">{t}</span>'
        for t in text
    ]
    hovertemplate = "<b>%{text}</b>"
    if node_color is not None:
        hovertemplate += "<br>Destination MIR: %{customdata[0]:.4f}</br>"

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        text=text,
        textposition="middle center",
        mode="markers+text",
        marker=dict(
            color=node_color,
            coloraxis="coloraxis",
            size=75,
        ),
        # Draw nodes above the edge shapes.
        zorder=999,
        textfont=dict(color="white"),
        name="Nodes",
        customdata=[[x] for x in node_color] if node_color is not None else None,
        hovertemplate=hovertemplate,
    )

    annotations, shapes = _create_edge_annotations(
        G,
        pos,
        edge_attr=edge_attr,
        edge_attr_sigfigs=edge_attr_sigfigs,
        label_edges=label_edges,
        uncertain_edges=uncertain_edges,
        uncertain_edge_opacity=uncertain_edge_opacity,
    )
    fig = go.Figure(
        layout=go.Layout(
            showlegend=False,
            hovermode="closest",
            margin=dict(b=20, l=5, r=5, t=40),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False, constrain="domain"),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, constrain="domain"),
            annotations=annotations,
        )
    )

    if cscale_tuple is None:
        cscale_tuple = (1, 3, 10)
    # The color axis is always anchored at zero; cmin only contributes a colorbar tick.
    cbot = 0
    cmin = cscale_tuple[0]
    cmid = cscale_tuple[1]
    cmax = cscale_tuple[2]

    fig.update_layout(
        coloraxis=dict(
            colorscale=colorscale,
            cmin=cbot,
            cmid=cmid,
            cmax=cmax,
            colorbar=dict(
                title="Missing Information",
                tickvals=[cbot, cmin, cmid, cmax],
                ticktext=[f"{cbot}", f"{cmin}", f"{cmid}", f"≥{cmax}"],
            ),
        ),
    )

    for s in shapes:
        fig.add_shape(**s)
    fig.add_trace(node_trace)

    fig.update_layout(
        title=dict(text=title, subtitle=dict(text=subtitle), font=dict(size=16)),
        width=1000,
        height=750,
    )
    return fig
|
|
@@ -1,79 +0,0 @@
|
|
|
1
|
-
import math
|
|
2
|
-
from typing import Literal
|
|
3
|
-
|
|
4
|
-
import plotly.graph_objects as go
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
def _snap_to_integer(quotient: float) -> float:
    """Return the nearest integer if ``quotient`` is within float rounding error of it, else ``quotient``."""
    nearest = round(quotient)
    if math.isclose(quotient, nearest, rel_tol=1e-9, abs_tol=1e-9):
        return nearest
    return quotient


def nice_range(lower: float, upper: float) -> tuple[float, float]:
    """
    Expand a value interval to rounded bounds.

    The interval is widened outward to multiples of a "nice" step, which is
    1, 2 or 5 times a power of ten chosen from the span of the data.

    Parameters
    ----------
    lower : float
        The lower bound of the data range.
    upper : float
        The upper bound of the data range. If smaller than ``lower`` the two are swapped.

    Returns
    -------
    tuple of float
        A (nice_min, nice_max) pair suitable for use as an axis range. A zero-width
        interval is returned unchanged.
    """
    if upper < lower:
        lower, upper = upper, lower

    span = upper - lower
    if span == 0:
        return lower, upper

    # Find power of 10 of span
    power = math.floor(math.log10(span))
    step = 10**power

    # Shrink the step to 2x or 5x the next lower power of ten for narrower spans.
    ratio = span / step
    if ratio < 2:
        step /= 5
    elif ratio < 5:
        step /= 2

    # Quotients like 0.6 / 0.05 evaluate to 11.999999999999998; snapping them to the
    # integer first stops floor/ceil from widening the range by a whole extra step.
    nice_min = math.floor(_snap_to_integer(lower / step)) * step
    nice_max = math.ceil(_snap_to_integer(upper / step)) * step

    return nice_min, nice_max
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
def normalize_axis_range(
    *figures: go.Figure,
    axis: Literal["y", "x"] = "y",
    bounds: tuple[float, float] | None = None,
) -> None:
    """
    Normalize the y or x axis range of all Figures.

    Parameters
    ----------
    *figures : go.Figure
        The figures whose axis range is updated in place.
    axis : {"x", "y"}, default "y"
        The axis to adjust.
    bounds : tuple of float, optional
        The axis range to normalize all figures to. If unset, calculates a range
        given the axis values across all figures. A computed range always
        includes 0 and is expanded to rounded bounds with ``nice_range``. If no
        non-zero finite numeric values are found, the figures are left unchanged.
    """
    if bounds is None:
        # Seed with (0, 0) so that the computed range always contains zero.
        bounds = (0, 0)
        for fig in figures:
            for trace in fig.data:
                ax = getattr(trace, axis, None)
                if ax is None:
                    ax = []
                for val in ax:
                    if val is None:
                        continue
                    try:
                        value = float(val)
                    except (TypeError, ValueError):
                        # Categorical / datetime-like values: this trace's axis
                        # is not numeric, so skip the rest of it instead of crashing.
                        break
                    if math.isfinite(value):
                        bounds = min(bounds[0], value), max(bounds[1], value)
        if bounds == (0, 0):
            # Nothing numeric found; avoid forcing a degenerate [0, 0] range.
            return
        bounds = nice_range(*bounds)

    for fig in figures:
        if axis == "x":
            fig.update_xaxes(range=bounds)
        elif axis == "y":
            fig.update_yaxes(range=bounds)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|