scikit-network 0.33.3__cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-network might be problematic. Click here for more details.

Files changed (229) hide show
  1. scikit_network-0.33.3.dist-info/METADATA +122 -0
  2. scikit_network-0.33.3.dist-info/RECORD +229 -0
  3. scikit_network-0.33.3.dist-info/WHEEL +6 -0
  4. scikit_network-0.33.3.dist-info/licenses/AUTHORS.rst +43 -0
  5. scikit_network-0.33.3.dist-info/licenses/LICENSE +34 -0
  6. scikit_network-0.33.3.dist-info/top_level.txt +1 -0
  7. scikit_network.libs/libgomp-d22c30c5.so.1.0.0 +0 -0
  8. sknetwork/__init__.py +21 -0
  9. sknetwork/base.py +67 -0
  10. sknetwork/classification/__init__.py +8 -0
  11. sknetwork/classification/base.py +142 -0
  12. sknetwork/classification/base_rank.py +133 -0
  13. sknetwork/classification/diffusion.py +134 -0
  14. sknetwork/classification/knn.py +139 -0
  15. sknetwork/classification/metrics.py +205 -0
  16. sknetwork/classification/pagerank.py +66 -0
  17. sknetwork/classification/propagation.py +152 -0
  18. sknetwork/classification/tests/__init__.py +1 -0
  19. sknetwork/classification/tests/test_API.py +30 -0
  20. sknetwork/classification/tests/test_diffusion.py +77 -0
  21. sknetwork/classification/tests/test_knn.py +23 -0
  22. sknetwork/classification/tests/test_metrics.py +53 -0
  23. sknetwork/classification/tests/test_pagerank.py +20 -0
  24. sknetwork/classification/tests/test_propagation.py +24 -0
  25. sknetwork/classification/vote.cpp +27587 -0
  26. sknetwork/classification/vote.cpython-313-aarch64-linux-gnu.so +0 -0
  27. sknetwork/classification/vote.pyx +56 -0
  28. sknetwork/clustering/__init__.py +8 -0
  29. sknetwork/clustering/base.py +172 -0
  30. sknetwork/clustering/kcenters.py +253 -0
  31. sknetwork/clustering/leiden.py +242 -0
  32. sknetwork/clustering/leiden_core.cpp +31578 -0
  33. sknetwork/clustering/leiden_core.cpython-313-aarch64-linux-gnu.so +0 -0
  34. sknetwork/clustering/leiden_core.pyx +124 -0
  35. sknetwork/clustering/louvain.py +286 -0
  36. sknetwork/clustering/louvain_core.cpp +31223 -0
  37. sknetwork/clustering/louvain_core.cpython-313-aarch64-linux-gnu.so +0 -0
  38. sknetwork/clustering/louvain_core.pyx +124 -0
  39. sknetwork/clustering/metrics.py +91 -0
  40. sknetwork/clustering/postprocess.py +66 -0
  41. sknetwork/clustering/propagation_clustering.py +104 -0
  42. sknetwork/clustering/tests/__init__.py +1 -0
  43. sknetwork/clustering/tests/test_API.py +38 -0
  44. sknetwork/clustering/tests/test_kcenters.py +60 -0
  45. sknetwork/clustering/tests/test_leiden.py +34 -0
  46. sknetwork/clustering/tests/test_louvain.py +135 -0
  47. sknetwork/clustering/tests/test_metrics.py +50 -0
  48. sknetwork/clustering/tests/test_postprocess.py +39 -0
  49. sknetwork/data/__init__.py +6 -0
  50. sknetwork/data/base.py +33 -0
  51. sknetwork/data/load.py +406 -0
  52. sknetwork/data/models.py +459 -0
  53. sknetwork/data/parse.py +644 -0
  54. sknetwork/data/test_graphs.py +84 -0
  55. sknetwork/data/tests/__init__.py +1 -0
  56. sknetwork/data/tests/test_API.py +30 -0
  57. sknetwork/data/tests/test_base.py +14 -0
  58. sknetwork/data/tests/test_load.py +95 -0
  59. sknetwork/data/tests/test_models.py +52 -0
  60. sknetwork/data/tests/test_parse.py +250 -0
  61. sknetwork/data/tests/test_test_graphs.py +29 -0
  62. sknetwork/data/tests/test_toy_graphs.py +68 -0
  63. sknetwork/data/timeout.py +38 -0
  64. sknetwork/data/toy_graphs.py +611 -0
  65. sknetwork/embedding/__init__.py +8 -0
  66. sknetwork/embedding/base.py +94 -0
  67. sknetwork/embedding/force_atlas.py +198 -0
  68. sknetwork/embedding/louvain_embedding.py +148 -0
  69. sknetwork/embedding/random_projection.py +135 -0
  70. sknetwork/embedding/spectral.py +141 -0
  71. sknetwork/embedding/spring.py +198 -0
  72. sknetwork/embedding/svd.py +359 -0
  73. sknetwork/embedding/tests/__init__.py +1 -0
  74. sknetwork/embedding/tests/test_API.py +49 -0
  75. sknetwork/embedding/tests/test_force_atlas.py +35 -0
  76. sknetwork/embedding/tests/test_louvain_embedding.py +33 -0
  77. sknetwork/embedding/tests/test_random_projection.py +28 -0
  78. sknetwork/embedding/tests/test_spectral.py +81 -0
  79. sknetwork/embedding/tests/test_spring.py +50 -0
  80. sknetwork/embedding/tests/test_svd.py +43 -0
  81. sknetwork/gnn/__init__.py +10 -0
  82. sknetwork/gnn/activation.py +117 -0
  83. sknetwork/gnn/base.py +181 -0
  84. sknetwork/gnn/base_activation.py +90 -0
  85. sknetwork/gnn/base_layer.py +109 -0
  86. sknetwork/gnn/gnn_classifier.py +305 -0
  87. sknetwork/gnn/layer.py +153 -0
  88. sknetwork/gnn/loss.py +180 -0
  89. sknetwork/gnn/neighbor_sampler.py +65 -0
  90. sknetwork/gnn/optimizer.py +164 -0
  91. sknetwork/gnn/tests/__init__.py +1 -0
  92. sknetwork/gnn/tests/test_activation.py +56 -0
  93. sknetwork/gnn/tests/test_base.py +75 -0
  94. sknetwork/gnn/tests/test_base_layer.py +37 -0
  95. sknetwork/gnn/tests/test_gnn_classifier.py +130 -0
  96. sknetwork/gnn/tests/test_layers.py +80 -0
  97. sknetwork/gnn/tests/test_loss.py +33 -0
  98. sknetwork/gnn/tests/test_neigh_sampler.py +23 -0
  99. sknetwork/gnn/tests/test_optimizer.py +43 -0
  100. sknetwork/gnn/tests/test_utils.py +41 -0
  101. sknetwork/gnn/utils.py +127 -0
  102. sknetwork/hierarchy/__init__.py +6 -0
  103. sknetwork/hierarchy/base.py +96 -0
  104. sknetwork/hierarchy/louvain_hierarchy.py +272 -0
  105. sknetwork/hierarchy/metrics.py +234 -0
  106. sknetwork/hierarchy/paris.cpp +37871 -0
  107. sknetwork/hierarchy/paris.cpython-313-aarch64-linux-gnu.so +0 -0
  108. sknetwork/hierarchy/paris.pyx +316 -0
  109. sknetwork/hierarchy/postprocess.py +350 -0
  110. sknetwork/hierarchy/tests/__init__.py +1 -0
  111. sknetwork/hierarchy/tests/test_API.py +24 -0
  112. sknetwork/hierarchy/tests/test_algos.py +34 -0
  113. sknetwork/hierarchy/tests/test_metrics.py +62 -0
  114. sknetwork/hierarchy/tests/test_postprocess.py +57 -0
  115. sknetwork/linalg/__init__.py +9 -0
  116. sknetwork/linalg/basics.py +37 -0
  117. sknetwork/linalg/diteration.cpp +27403 -0
  118. sknetwork/linalg/diteration.cpython-313-aarch64-linux-gnu.so +0 -0
  119. sknetwork/linalg/diteration.pyx +47 -0
  120. sknetwork/linalg/eig_solver.py +93 -0
  121. sknetwork/linalg/laplacian.py +15 -0
  122. sknetwork/linalg/normalizer.py +86 -0
  123. sknetwork/linalg/operators.py +225 -0
  124. sknetwork/linalg/polynome.py +76 -0
  125. sknetwork/linalg/ppr_solver.py +170 -0
  126. sknetwork/linalg/push.cpp +31075 -0
  127. sknetwork/linalg/push.cpython-313-aarch64-linux-gnu.so +0 -0
  128. sknetwork/linalg/push.pyx +71 -0
  129. sknetwork/linalg/sparse_lowrank.py +142 -0
  130. sknetwork/linalg/svd_solver.py +91 -0
  131. sknetwork/linalg/tests/__init__.py +1 -0
  132. sknetwork/linalg/tests/test_eig.py +44 -0
  133. sknetwork/linalg/tests/test_laplacian.py +18 -0
  134. sknetwork/linalg/tests/test_normalization.py +34 -0
  135. sknetwork/linalg/tests/test_operators.py +66 -0
  136. sknetwork/linalg/tests/test_polynome.py +38 -0
  137. sknetwork/linalg/tests/test_ppr.py +50 -0
  138. sknetwork/linalg/tests/test_sparse_lowrank.py +61 -0
  139. sknetwork/linalg/tests/test_svd.py +38 -0
  140. sknetwork/linkpred/__init__.py +2 -0
  141. sknetwork/linkpred/base.py +46 -0
  142. sknetwork/linkpred/nn.py +126 -0
  143. sknetwork/linkpred/tests/__init__.py +1 -0
  144. sknetwork/linkpred/tests/test_nn.py +27 -0
  145. sknetwork/log.py +19 -0
  146. sknetwork/path/__init__.py +5 -0
  147. sknetwork/path/dag.py +54 -0
  148. sknetwork/path/distances.py +98 -0
  149. sknetwork/path/search.py +31 -0
  150. sknetwork/path/shortest_path.py +61 -0
  151. sknetwork/path/tests/__init__.py +1 -0
  152. sknetwork/path/tests/test_dag.py +37 -0
  153. sknetwork/path/tests/test_distances.py +62 -0
  154. sknetwork/path/tests/test_search.py +40 -0
  155. sknetwork/path/tests/test_shortest_path.py +40 -0
  156. sknetwork/ranking/__init__.py +8 -0
  157. sknetwork/ranking/base.py +61 -0
  158. sknetwork/ranking/betweenness.cpp +9710 -0
  159. sknetwork/ranking/betweenness.cpython-313-aarch64-linux-gnu.so +0 -0
  160. sknetwork/ranking/betweenness.pyx +97 -0
  161. sknetwork/ranking/closeness.py +92 -0
  162. sknetwork/ranking/hits.py +94 -0
  163. sknetwork/ranking/katz.py +83 -0
  164. sknetwork/ranking/pagerank.py +110 -0
  165. sknetwork/ranking/postprocess.py +37 -0
  166. sknetwork/ranking/tests/__init__.py +1 -0
  167. sknetwork/ranking/tests/test_API.py +32 -0
  168. sknetwork/ranking/tests/test_betweenness.py +38 -0
  169. sknetwork/ranking/tests/test_closeness.py +30 -0
  170. sknetwork/ranking/tests/test_hits.py +20 -0
  171. sknetwork/ranking/tests/test_pagerank.py +62 -0
  172. sknetwork/ranking/tests/test_postprocess.py +26 -0
  173. sknetwork/regression/__init__.py +4 -0
  174. sknetwork/regression/base.py +61 -0
  175. sknetwork/regression/diffusion.py +210 -0
  176. sknetwork/regression/tests/__init__.py +1 -0
  177. sknetwork/regression/tests/test_API.py +32 -0
  178. sknetwork/regression/tests/test_diffusion.py +56 -0
  179. sknetwork/sknetwork.py +3 -0
  180. sknetwork/test_base.py +35 -0
  181. sknetwork/test_log.py +15 -0
  182. sknetwork/topology/__init__.py +8 -0
  183. sknetwork/topology/cliques.cpp +32568 -0
  184. sknetwork/topology/cliques.cpython-313-aarch64-linux-gnu.so +0 -0
  185. sknetwork/topology/cliques.pyx +149 -0
  186. sknetwork/topology/core.cpp +30654 -0
  187. sknetwork/topology/core.cpython-313-aarch64-linux-gnu.so +0 -0
  188. sknetwork/topology/core.pyx +90 -0
  189. sknetwork/topology/cycles.py +243 -0
  190. sknetwork/topology/minheap.cpp +27335 -0
  191. sknetwork/topology/minheap.cpython-313-aarch64-linux-gnu.so +0 -0
  192. sknetwork/topology/minheap.pxd +20 -0
  193. sknetwork/topology/minheap.pyx +109 -0
  194. sknetwork/topology/structure.py +194 -0
  195. sknetwork/topology/tests/__init__.py +1 -0
  196. sknetwork/topology/tests/test_cliques.py +28 -0
  197. sknetwork/topology/tests/test_core.py +19 -0
  198. sknetwork/topology/tests/test_cycles.py +65 -0
  199. sknetwork/topology/tests/test_structure.py +85 -0
  200. sknetwork/topology/tests/test_triangles.py +38 -0
  201. sknetwork/topology/tests/test_wl.py +72 -0
  202. sknetwork/topology/triangles.cpp +8897 -0
  203. sknetwork/topology/triangles.cpython-313-aarch64-linux-gnu.so +0 -0
  204. sknetwork/topology/triangles.pyx +151 -0
  205. sknetwork/topology/weisfeiler_lehman.py +133 -0
  206. sknetwork/topology/weisfeiler_lehman_core.cpp +27638 -0
  207. sknetwork/topology/weisfeiler_lehman_core.cpython-313-aarch64-linux-gnu.so +0 -0
  208. sknetwork/topology/weisfeiler_lehman_core.pyx +114 -0
  209. sknetwork/utils/__init__.py +7 -0
  210. sknetwork/utils/check.py +355 -0
  211. sknetwork/utils/format.py +221 -0
  212. sknetwork/utils/membership.py +82 -0
  213. sknetwork/utils/neighbors.py +115 -0
  214. sknetwork/utils/tests/__init__.py +1 -0
  215. sknetwork/utils/tests/test_check.py +190 -0
  216. sknetwork/utils/tests/test_format.py +63 -0
  217. sknetwork/utils/tests/test_membership.py +24 -0
  218. sknetwork/utils/tests/test_neighbors.py +41 -0
  219. sknetwork/utils/tests/test_tfidf.py +18 -0
  220. sknetwork/utils/tests/test_values.py +66 -0
  221. sknetwork/utils/tfidf.py +37 -0
  222. sknetwork/utils/values.py +76 -0
  223. sknetwork/visualization/__init__.py +4 -0
  224. sknetwork/visualization/colors.py +34 -0
  225. sknetwork/visualization/dendrograms.py +277 -0
  226. sknetwork/visualization/graphs.py +1039 -0
  227. sknetwork/visualization/tests/__init__.py +1 -0
  228. sknetwork/visualization/tests/test_dendrograms.py +53 -0
  229. sknetwork/visualization/tests/test_graphs.py +176 -0
@@ -0,0 +1,1039 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created in April 2020
5
+ @author: Thomas Bonald <thomas.bonald@telecom-paris.fr>
6
+ @author: Quentin Lutz <qlutz@live.fr>
7
+ """
8
+ from typing import Optional, Iterable, Union, Tuple
9
+
10
+ import numpy as np
11
+ from scipy import sparse
12
+
13
+ from sknetwork.clustering.louvain import Louvain
14
+ from sknetwork.utils.format import is_symmetric, check_format
15
+ from sknetwork.embedding.spring import Spring
16
+ from sknetwork.visualization.colors import STANDARD_COLORS, COOLWARM_RGB
17
+
18
+
19
def min_max_scaling(x: np.ndarray, x_min: Optional[float] = None, x_max: Optional[float] = None) -> np.ndarray:
    """Shift and scale a vector so that its values lie between 0 and 1.

    Parameters
    ----------
    x :
        Values to scale (array-like; the input is not modified).
    x_min :
        Value mapped to 0 (default: minimum of ``x``).
    x_max :
        Value mapped to 1 (default: maximum of ``x``).

    Returns
    -------
    scaled : np.ndarray
        Float array in [0, 1]. Values outside ``[x_min, x_max]`` are clipped to the
        nearest bound. If ``x_max`` is not greater than ``x_min``, all values are 0.5.
        An empty input returns an empty float array.
    """
    # astype always returns a new array, so the caller's data is never modified
    x = np.asarray(x).astype(float)
    if x.size == 0:
        # np.min / np.max are undefined on empty arrays
        return x
    if x_min is None:
        x_min = np.min(x)
    if x_max is None:
        x_max = np.max(x)
    if not x_max > x_min:
        return .5 * np.ones_like(x)
    # clip so that explicit bounds never produce values outside [0, 1]
    # (callers use the result to index color maps)
    return np.clip((x - x_min) / (x_max - x_min), 0, 1)
32
+
33
+
34
def rescale(position: np.ndarray, width: float, height: float, margin: float, node_size: float, node_size_max: float,
            display_node_weight: bool, names: Optional[np.ndarray] = None, name_position: str = 'right',
            font_size: int = 12):
    """Map node positions to image coordinates and compute the image size.

    Coordinates are normalized to [0, 1] (the vertical axis is flipped), then multiplied by
    ``width`` and ``height``. When only one of the two dimensions is truthy, the other one is
    derived from it using the aspect ratio of the layout (if both spans are non-zero).
    Room is then added for node names and for the margin.

    Parameters
    ----------
    position :
        Array to rescale (column 0: x, column 1: y).
    width :
        Horizontal scaling parameter.
    height :
        Vertical scaling parameter.
    margin :
        Minimal margin for the plot.
    node_size :
        Node size (used to adapt the margin).
    node_size_max :
        Maximum node size (used to adapt the margin).
    display_node_weight :
        If ``True``, display node weight (used to adapt the margin).
    names :
        Names of nodes.
    name_position :
        Position of names (left, right, above, below).
    font_size :
        Font size.

    Returns
    -------
    position :
        Rescaled positions.
    width :
        Rescaled width.
    height :
        Rescaled height.
    """
    def positive_part(value):
        # a non-positive offset means the text already fits: add nothing
        return value * (value > 0)

    x_raw = position[:, 0]
    y_raw = position[:, 1]
    span_x = np.max(x_raw) - np.min(x_raw)
    span_y = np.max(y_raw) - np.min(y_raw)
    position = np.vstack((min_max_scaling(x_raw), 1 - min_max_scaling(y_raw))).T

    # derive the missing dimension from the layout's aspect ratio
    keep_ratio = span_x and span_y
    if width and not height:
        height = width * (span_y / span_x) if keep_ratio else width
    elif height and not width:
        width = height * (span_x / span_y) if keep_ratio else height
    position = position * np.array([width, height])

    # room for the names (text width estimated from the number of characters)
    if names is not None:
        text_lengths = np.array([len(str(name)) for name in names]) * font_size
        if name_position == 'left':
            shift = positive_part(-np.min(position[:, 0] - text_lengths))
            position[:, 0] += shift
            width += shift
        elif name_position == 'right':
            width += positive_part(np.max(position[:, 0] + text_lengths - width))
        else:
            half_lengths = text_lengths / 2
            shift_left = positive_part(-np.min(position[:, 0] - half_lengths))
            shift_right = positive_part(np.max(position[:, 0] + half_lengths - width))
            position[:, 0] += shift_left
            width += shift_left + shift_right
            if name_position == 'above':
                position[:, 1] += font_size
            height += font_size

    # uniform margin, large enough for the biggest node
    margin = max(margin, node_size_max * display_node_weight, node_size)
    position += margin
    return position, width + 2 * margin, height + 2 * margin
121
+
122
+
123
def get_label_colors(label_colors: Optional[Iterable]):
    """Return label colors as an array of svg colors, indexed by label.

    Parameters
    ----------
    label_colors :
        Colors of the labels, as one of:

        * ``None``: the default palette ``STANDARD_COLORS`` is returned (copied);
        * a dict ``{label: color}``: labels not in the dict get 'black';
        * a single color as a string;
        * any other iterable of colors (list, tuple, array, generator...).

    Returns
    -------
    Array of colors.

    Examples
    --------
    >>> get_label_colors(['black'])
    array(['black'], dtype='<U5')
    >>> get_label_colors({0: 'blue'})
    array(['blue'], dtype='<U64')
    >>> get_label_colors(('red', 'blue'))
    array(['red', 'blue'], dtype='<U4')
    """
    if label_colors is None:
        return STANDARD_COLORS.copy()
    if isinstance(label_colors, dict):
        keys = list(label_colors.keys())
        values = list(label_colors.values())
        colors = np.array(['black'] * (max(keys) + 1), dtype='U64')
        colors[keys] = values
        return colors
    if isinstance(label_colors, str):
        # a single color, not a sequence of one-character colors
        return np.array([label_colors])
    if isinstance(label_colors, np.ndarray):
        return label_colors
    # lists, tuples, generators...: callers index the result with integer arrays
    return np.array(list(label_colors))
144
+
145
+
146
def get_node_colors(n: int, labels: Optional[Iterable], scores: Optional[Iterable],
                    membership: Optional[sparse.csr_matrix],
                    node_color: str, label_colors: Optional[Iterable],
                    score_min: Optional[float] = None, score_max: Optional[float] = None) -> np.ndarray:
    """Return the colors of the nodes.

    Priority: labels, then scores, then membership, then the default color.

    Parameters
    ----------
    n :
        Number of nodes.
    labels :
        Labels of the nodes (dict, list or array). Negative labels and nodes missing from a
        dict keep the default color.
    scores :
        Scores of the nodes (dict, list or array), mapped to the COOLWARM color scale.
    membership :
        Only checked for ``None``. If set, the label colors are returned (one per label)
        instead of one color per node.
    node_color :
        Default color.
    label_colors :
        Colors of the labels.
    score_min, score_max :
        Bounds of the color scale (default: min and max of scores).

    Returns
    -------
    Array of svg colors.
    """
    colors = np.array(n * [node_color]).astype('U64')

    if labels is not None:
        if isinstance(labels, dict):
            label_array = -np.ones(n, dtype=int)
            label_array[np.array(list(labels.keys()))] = np.array(list(labels.values())).astype(int)
        elif isinstance(labels, list):
            if len(labels) != n:
                raise ValueError("The number of labels must be equal to the corresponding number of nodes.")
            label_array = np.array(labels)
        else:
            label_array = labels
        palette = get_label_colors(label_colors)
        has_label = label_array >= 0
        colors[has_label] = palette[label_array[has_label] % len(palette)]
        return colors

    if scores is not None:
        rgb = COOLWARM_RGB.copy()
        n_shades = rgb.shape[0]
        palette = np.array(['rgb' + str(tuple([int(c) for c in row])) for row in rgb])
        if isinstance(scores, dict):
            nodes = np.array(list(scores.keys()))
            raw = np.array(list(scores.values()))
            shades = (min_max_scaling(raw, score_min, score_max) * (n_shades - 1)).astype(int)
            colors[nodes] = palette[shades]
            return colors
        if isinstance(scores, list):
            if len(scores) != n:
                raise ValueError("The number of scores must be equal to the corresponding number of nodes.")
            scores = np.array(scores)
        shades = (min_max_scaling(scores, score_min, score_max) * (n_shades - 1)).astype(int)
        return palette[shades]

    if membership is not None:
        if isinstance(label_colors, dict):
            raise TypeError("Label colors must be a list or an array when using a membership.")
        return get_label_colors(label_colors)

    return colors
189
+
190
+
191
def get_node_widths(n: int, seeds: Union[int, dict, list], node_width: float, node_width_max: float) -> np.ndarray:
    """Return the node widths (width of the circle contours).

    Parameters
    ----------
    n :
        Number of nodes.
    seeds :
        Nodes to highlight: a single node index, a dict (only keys are considered),
        or any iterable of node indices (list, tuple, set, array, generator...).
    node_width :
        Default width.
    node_width_max :
        Width of highlighted nodes.

    Returns
    -------
    Array of widths, one per node.
    """
    node_widths = node_width * np.ones(n)
    if seeds is None:
        return node_widths
    # isinstance (not type equality) so that dict subclasses such as OrderedDict are accepted
    if isinstance(seeds, dict):
        seeds = list(seeds.keys())
    elif np.issubdtype(type(seeds), np.integer):
        seeds = [seeds]
    elif not isinstance(seeds, np.ndarray):
        # sets, tuples, generators... cannot be used directly as a numpy index
        seeds = list(seeds)
    if len(seeds):
        node_widths[np.array(seeds)] = node_width_max
    return node_widths
202
+
203
+
204
def get_node_sizes(weights: np.ndarray, node_size: float, node_size_min: float, node_size_max: float, node_weight) \
        -> np.ndarray:
    """Return the node sizes.

    Parameters
    ----------
    weights :
        Node weights (array-like).
    node_size :
        Size used for all nodes when weights are not displayed (or all equal).
    node_size_min, node_size_max :
        Size range when weights are displayed. Sizes are
        ``node_size_min + |node_size_max - node_size_min| * weight / max(weights)``.
    node_weight :
        If ``True``, node sizes reflect the weights.

    Returns
    -------
    Array of sizes, one per node (empty for empty weights).
    """
    weights = np.asarray(weights)
    # weights.size guards np.min / np.max, which fail on empty arrays
    if node_weight and weights.size and np.min(weights) < np.max(weights):
        node_sizes = node_size_min + np.abs(node_size_max - node_size_min) * weights / np.max(weights)
    else:
        node_sizes = node_size * np.ones_like(weights)
    return node_sizes
212
+
213
+
214
def get_node_sizes_bipartite(weights_row: np.ndarray, weights_col: np.ndarray, node_size: float, node_size_min: float,
                             node_size_max: float, node_weight) -> (np.ndarray, np.ndarray):
    """Return the node sizes of both sides of a bipartite graph.

    Both sides share the same scale, given by the maximum weight over all nodes.
    If weights are not displayed (or all equal), every node gets ``node_size``.

    Returns
    -------
    node_sizes_row, node_sizes_col
    """
    all_weights = np.hstack((weights_row, weights_col))
    if not (node_weight and np.min(all_weights) < np.max(all_weights)):
        return node_size * np.ones_like(weights_row), node_size * np.ones_like(weights_col)
    top = np.max(all_weights)
    spread = np.abs(node_size_max - node_size_min)
    return node_size_min + spread * weights_row / top, node_size_min + spread * weights_col / top
225
+
226
+
227
def get_edge_colors(adjacency: sparse.csr_matrix, edge_labels: Optional[list], edge_color: str,
                    label_colors: Optional[Iterable]) -> Tuple[np.ndarray, np.ndarray, list]:
    """Return the edge colors.

    Parameters
    ----------
    adjacency :
        Adjacency (or biadjacency) matrix; converted to CSR if needed.
    edge_labels :
        List of tuples ``(source, destination, label)``.
    edge_color :
        Default color of edges.
    label_colors :
        Colors of the labels.

    Returns
    -------
    edge_colors :
        Color of each stored entry of the CSR adjacency, in the order of its data.
    edge_order :
        Order in which to draw the edges (unlabeled edges first, then by label index).
    edge_colors_residual :
        List of ``(i, j, color)`` for labeled edges that are not in the adjacency.

    Raises
    ------
    ValueError
        If a node index in ``edge_labels`` is out of range.
    """
    adjacency = sparse.csr_matrix(adjacency)
    n_row, n_col = adjacency.shape
    n_edges = adjacency.nnz
    # Label index of each stored entry, aligned with adjacency.data (-1 = no label).
    # Built from the sparsity structure (not from ``adjacency > 0``) so that entries
    # with negative weights keep their position.
    edge_label_index = -np.ones(n_edges, dtype=int)
    edge_colors_residual = []
    if edge_labels:
        label_colors = get_label_colors(label_colors)
        n_colors = len(label_colors)
        for i, j, label in edge_labels:
            if i < 0 or i >= n_row or j < 0 or j >= n_col:
                raise ValueError('Invalid node index in edge labels.')
            start, end = adjacency.indptr[i], adjacency.indptr[i + 1]
            positions = start + np.flatnonzero(adjacency.indices[start:end] == j)
            # explicitly stored zeros are not edges
            positions = positions[adjacency.data[positions] != 0]
            if len(positions):
                edge_label_index[positions] = label % n_colors
            else:
                edge_colors_residual.append((i, j, label_colors[label % n_colors]))
    edge_order = np.argsort(edge_label_index)
    edge_colors = np.array(n_edges * [edge_color]).astype('U64')
    index = np.flatnonzero(edge_label_index >= 0)
    if len(index):
        edge_colors[index] = label_colors[edge_label_index[index]]
    return edge_colors, edge_order, edge_colors_residual
251
+
252
+
253
def get_edge_widths(adjacency: sparse.coo_matrix, edge_width: float, edge_width_min: float, edge_width_max: float,
                    display_edge_weight: bool) -> np.ndarray:
    """Return the edge widths, aligned with ``adjacency.data``.

    With ``display_edge_weight`` and non-constant weights, widths are mapped linearly from
    the weight range onto ``[edge_width_min, edge_width_max]``; otherwise every edge gets
    ``edge_width``. Returns ``None`` when the matrix stores no entry.
    """
    weights = adjacency.data
    if not len(weights):
        return None
    if not (display_edge_weight and np.min(weights) < np.max(weights)):
        return edge_width * np.ones_like(weights)
    w_min = np.min(weights)
    w_max = np.max(weights)
    return edge_width_min + np.abs(edge_width_max - edge_width_min) * (weights - w_min) / (w_max - w_min)
265
+
266
+
267
def svg_node(pos_node: np.ndarray, size: float, color: str, stroke_width: float = 1, stroke_color: str = 'black') \
        -> str:
    """Return svg code for a node drawn as a disk.

    Parameters
    ----------
    pos_node :
        (x, y) coordinates of the node (truncated to integers).
    size :
        Radius of disk in pixels.
    color :
        Fill color of the disk in SVG valid format.
    stroke_width :
        Width of the contour of the disk in pixels, centered around the circle.
    stroke_color :
        Color of the contour in SVG valid format.

    Returns
    -------
    SVG code for the node.
    """
    center_x, center_y = pos_node.astype(int)
    style = f"fill:{color};stroke:{stroke_color};stroke-width:{stroke_width}"
    return f'<circle cx="{center_x}" cy="{center_y}" r="{size}" style="{style}"/>\n'
291
+
292
+
293
def svg_pie_chart_node(pos_node: np.ndarray, size: float, probs: np.ndarray, colors: np.ndarray,
                       stroke_width: float = 1, stroke_color: str = 'black') -> str:
    """Return svg code for a node drawn as a pie chart.

    Parameters
    ----------
    pos_node :
        (x, y) coordinates of the center.
    size :
        Radius in pixels.
    probs :
        Row of weights, of shape (1, n_slices); each slice is proportional to its weight.
    colors :
        Colors of the slices (cycled if shorter than the number of slices).
    stroke_width, stroke_color :
        Contour of the slices.

    Returns
    -------
    SVG code: one path per slice, or a white disk if all weights are zero.
    """
    center_x, center_y = pos_node.astype(float)
    n_colors = len(colors)
    n_slices = probs.shape[1]
    bounds = np.zeros(n_slices + 1)
    bounds[1:] = np.cumsum(probs)
    total = bounds[-1]
    if total == 0:
        return svg_node(pos_node, size, 'white', stroke_width=3)
    angles = np.multiply(bounds, (2 * np.pi) / total)
    xs = size * np.cos(angles) + center_x
    ys = size * np.sin(angles) + center_y
    # SVG large-arc flag: set for slices covering more than half of the disk
    large_arc = np.array(probs > total / 2).ravel()
    template = '<path d="M {} {} A {} {} 0 {} 1 {} {} L {} {}" style="fill:{};stroke:{};stroke-width:{}" />\n'
    return "".join(
        template.format(xs[k], ys[k], size, size, int(large_arc[k]), xs[k + 1], ys[k + 1], center_x, center_y,
                        colors[k % n_colors], stroke_color, stroke_width)
        for k in range(n_slices)
    )
313
+
314
+
315
def svg_edge(pos_1: np.ndarray, pos_2: np.ndarray, edge_width: float = 1, edge_color: str = 'black') -> str:
    """Return svg code for an undirected edge (a straight segment between integer-truncated positions)."""
    start_x, start_y = pos_1.astype(int)
    end_x, end_y = pos_2.astype(int)
    path = f"M {start_x} {start_y} {end_x} {end_y}"
    return f'<path stroke-width="{edge_width}" stroke="{edge_color}" d="{path}"/>\n'
321
+
322
+
323
def svg_edge_directed(pos_1: np.ndarray, pos_2: np.ndarray, edge_width: float = 1, edge_color: str = 'black',
                      node_size: float = 1.) -> str:
    """Return svg code for a directed edge.

    The segment is shortened by ``node_size`` at the destination so that the arrow head
    (marker ``arrow-<edge_color>``, to be defined in the image) stays visible.
    Returns an empty string when both endpoints coincide.
    """
    delta = pos_2 - pos_1
    length = np.linalg.norm(delta)
    if not length:
        return ""
    offset_x, offset_y = ((delta / length) * node_size).astype(int)
    start_x, start_y = pos_1.astype(int)
    end_x, end_y = pos_2.astype(int)
    path = f"M {start_x} {start_y} {end_x - offset_x} {end_y - offset_y}"
    return (f'<path stroke-width="{edge_width}" stroke="{edge_color}" d="{path}" '
            f'marker-end="url(#arrow-{edge_color})"/>\n')
336
+
337
+
338
def svg_text(pos, text, margin_text, font_size=12, position: str = 'right'):
    """Return svg code for text.

    Parameters
    ----------
    pos :
        (x, y) coordinates of the node the text is attached to. Not modified.
    text :
        Text to display (converted with ``str``). The XML special characters
        ``&``, ``<`` and ``>`` are escaped so that the text is rendered as is.
    margin_text :
        Distance between the node position and the text.
    font_size :
        Font size.
    position :
        Position of the text relative to the node: 'left', 'above', 'below' or 'right'
        (any other value is treated as 'right').

    Returns
    -------
    SVG code for the text.
    """
    # Work on a float copy: shifting the caller's array in place would modify it,
    # and fails for integer arrays when margin_text is a float.
    x, y = np.array(pos, dtype=float)
    if position == 'left':
        x -= margin_text
        anchor = 'end'
    elif position == 'above':
        y -= margin_text
        anchor = 'middle'
    elif position == 'below':
        y += 2 * margin_text
        anchor = 'middle'
    else:
        x += margin_text
        anchor = 'start'
    x, y = np.array([x, y]).astype(int)
    # escape '&' first so that the entities introduced below are not escaped twice
    text = str(text).replace('&', '&amp;').replace('<', '&lt;').replace('>', '&gt;')
    return """<text text-anchor="{}" x="{}" y="{}" font-size="{}">{}</text>""".format(anchor, x, y, font_size, text)
357
+
358
+
359
def visualize_graph(adjacency: Optional[sparse.csr_matrix] = None, position: Optional[np.ndarray] = None,
                    names: Optional[np.ndarray] = None, labels: Optional[Iterable] = None,
                    name_position: str = 'right', scores: Optional[Iterable] = None,
                    probs: Optional[Union[np.ndarray, sparse.csr_matrix]] = None,
                    seeds: Union[list, dict] = None, width: Optional[float] = 400, height: Optional[float] = 300,
                    margin: float = 20, margin_text: float = 3, scale: float = 1,
                    node_order: Optional[np.ndarray] = None, node_size: float = 7, node_size_min: float = 1,
                    node_size_max: float = 20,
                    display_node_weight: Optional[bool] = None, node_weights: Optional[np.ndarray] = None,
                    node_width: float = 1, node_width_max: float = 3, node_color: str = 'gray',
                    display_edges: bool = True, edge_labels: Optional[list] = None,
                    edge_width: float = 1, edge_width_min: float = 0.5,
                    edge_width_max: float = 20, display_edge_weight: bool = False,
                    edge_color: Optional[str] = None, label_colors: Optional[Iterable] = None,
                    font_size: int = 12, directed: Optional[bool] = None, filename: Optional[str] = None) -> str:
    """Return the image of a graph in SVG format.

    Parameters
    ----------
    adjacency :
        Adjacency matrix of the graph (not modified).
    position :
        Positions of the nodes.
    names :
        Names of the nodes.
    labels :
        Labels of the nodes (negative values mean no label).
    name_position :
        Position of the names (left, right, above, below)
    scores :
        Scores of the nodes (measure of importance).
    probs :
        Probability distribution over labels.
    seeds :
        Nodes to be highlighted (if dict, only keys are considered).
    width :
        Width of the image.
    height :
        Height of the image.
    margin :
        Margin of the image.
    margin_text :
        Margin between node and text.
    scale :
        Multiplicative factor on the dimensions of the image.
    node_order :
        Order in which nodes are displayed.
    node_size :
        Size of nodes.
    node_size_min :
        Minimum size of a node.
    node_size_max:
        Maximum size of a node.
    node_width :
        Width of node circle.
    node_width_max :
        Maximum width of node circle.
    node_color :
        Default color of nodes (svg color).
    display_node_weight :
        If ``True``, display node weights through node size.
    node_weights :
        Node weights.
    display_edges :
        If ``True``, display edges.
    edge_labels :
        Labels of the edges, as a list of tuples (source, destination, label)
    edge_width :
        Width of edges.
    edge_width_min :
        Minimum width of edges.
    edge_width_max :
        Maximum width of edges.
    display_edge_weight :
        If ``True``, display edge weights through edge widths.
    edge_color :
        Default color of edges (svg color).
    label_colors:
        Colors of the labels (svg colors).
    font_size :
        Font size.
    directed :
        If ``True``, considers the graph as directed.
    filename :
        Filename for saving image (optional, ``.svg`` is appended; written in UTF-8).

    Returns
    -------
    image : str
        SVG image.

    Example
    -------
    >>> from sknetwork.data import karate_club
    >>> graph = karate_club(True)
    >>> adjacency = graph.adjacency
    >>> position = graph.position
    >>> from sknetwork.visualization import visualize_graph
    >>> image = visualize_graph(adjacency, position)
    >>> image[1:4]
    'svg'
    """
    # check adjacency
    if adjacency is None:
        if position is None:
            raise ValueError("You must specify either adjacency or position.")
        n = position.shape[0]
        adjacency = sparse.csr_matrix((n, n)).astype(int)
    else:
        n = adjacency.shape[0]
        # eliminate_zeros() works in place: use a CSR copy so the caller's matrix is untouched
        adjacency = sparse.csr_matrix(adjacency, copy=True)
        adjacency.eliminate_zeros()
    if directed is None:
        directed = not is_symmetric(adjacency)

    # node order
    if node_order is None:
        node_order = np.arange(n)

    # position
    if position is None:
        spring = Spring()
        position = spring.fit_transform(adjacency)

    # node colors (probs acts as a membership: node_colors then holds one color per label)
    node_colors = get_node_colors(n, labels, scores, probs, node_color, label_colors)

    # node sizes
    if display_node_weight is None:
        display_node_weight = node_weights is not None
    if node_weights is None:
        node_weights = adjacency.T.dot(np.ones(n))
    node_sizes = get_node_sizes(node_weights, node_size, node_size_min, node_size_max, display_node_weight)

    # node widths
    node_widths = get_node_widths(n, seeds, node_width, node_width_max)

    # rescaling
    position, width, height = rescale(position, width, height, margin, node_size, node_size_max, display_node_weight,
                                      names, name_position, font_size)

    # scaling
    position *= scale
    height *= scale
    width *= scale

    # collect fragments and join once at the end (avoids quadratic string concatenation)
    parts = ["""<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">\n""".format(width, height)]

    # edges
    if display_edges:
        adjacency_coo = sparse.coo_matrix(adjacency)

        if edge_color is None:
            edge_color = 'black' if names is None else 'gray'

        edge_colors, edge_order, edge_colors_residual = get_edge_colors(adjacency, edge_labels, edge_color,
                                                                        label_colors)
        edge_widths = get_edge_widths(adjacency_coo, edge_width, edge_width_min, edge_width_max, display_edge_weight)

        if directed:
            # one arrow-head marker per color in use, including colors of labeled edges that are
            # not in the adjacency; sorted so that the output is deterministic
            marker_colors = set(edge_colors) | {color for _, _, color in edge_colors_residual}
            for marker_color in sorted(marker_colors):
                parts.append(('<defs><marker id="arrow-{}" markerWidth="10" markerHeight="10" refX="9" refY="3"\n'
                              'orient="auto" >\n').format(marker_color))
                parts.append('<path d="M0,0 L0,6 L9,3 z" fill="{}"/></marker></defs>\n'.format(marker_color))

        for ix in edge_order:
            i = adjacency_coo.row[ix]
            j = adjacency_coo.col[ix]
            color = edge_colors[ix]
            if directed:
                parts.append(svg_edge_directed(pos_1=position[i], pos_2=position[j], edge_width=edge_widths[ix],
                                               edge_color=color, node_size=node_sizes[j]))
            else:
                parts.append(svg_edge(pos_1=position[i], pos_2=position[j],
                                      edge_width=edge_widths[ix], edge_color=color))

        for i, j, color in edge_colors_residual:
            if directed:
                parts.append(svg_edge_directed(pos_1=position[i], pos_2=position[j], edge_width=edge_width,
                                               edge_color=color, node_size=node_sizes[j]))
            else:
                parts.append(svg_edge(pos_1=position[i], pos_2=position[j],
                                      edge_width=edge_width, edge_color=color))

    # nodes
    if probs is not None:
        # format once, not once per node
        probs = check_format(probs)
    for i in node_order:
        if probs is None:
            parts.append(svg_node(position[i], node_sizes[i], node_colors[i], node_widths[i]))
            continue
        row = probs[i]
        if row.nnz == 1:
            # wrap label indices like svg_pie_chart_node does, instead of raising IndexError
            color = node_colors[row.indices[0] % len(node_colors)]
            parts.append(svg_node(position[i], node_sizes[i], color, node_widths[i]))
        else:
            parts.append(svg_pie_chart_node(position[i], node_sizes[i], row.todense(), node_colors, node_widths[i]))

    # text
    if names is not None:
        for i in range(n):
            parts.append(svg_text(position[i], names[i], node_sizes[i] + margin_text, font_size, name_position))
    parts.append("""</svg>\n""")
    svg = "".join(parts)

    if filename is not None:
        # explicit encoding: names may contain non-ASCII characters
        with open(filename + '.svg', 'w', encoding='utf-8') as f:
            f.write(svg)

    return svg
570
+
571
+
572
+ def visualize_bigraph(biadjacency: sparse.csr_matrix,
573
+ names_row: Optional[np.ndarray] = None, names_col: Optional[np.ndarray] = None,
574
+ labels_row: Optional[Union[dict, np.ndarray]] = None,
575
+ labels_col: Optional[Union[dict, np.ndarray]] = None,
576
+ scores_row: Optional[Union[dict, np.ndarray]] = None,
577
+ scores_col: Optional[Union[dict, np.ndarray]] = None,
578
+ probs_row: Optional[Union[np.ndarray, sparse.csr_matrix]] = None,
579
+ probs_col: Optional[Union[np.ndarray, sparse.csr_matrix]] = None,
580
+ seeds_row: Union[list, dict] = None, seeds_col: Union[list, dict] = None,
581
+ position_row: Optional[np.ndarray] = None, position_col: Optional[np.ndarray] = None,
582
+ reorder: bool = True, width: Optional[float] = 400,
583
+ height: Optional[float] = 300, margin: float = 20, margin_text: float = 3, scale: float = 1,
584
+ node_size: float = 7, node_size_min: float = 1, node_size_max: float = 20,
585
+ display_node_weight: bool = False,
586
+ node_weights_row: Optional[np.ndarray] = None, node_weights_col: Optional[np.ndarray] = None,
587
+ node_width: float = 1, node_width_max: float = 3,
588
+ color_row: str = 'gray', color_col: str = 'gray', label_colors: Optional[Iterable] = None,
589
+ display_edges: bool = True, edge_labels: Optional[list] = None, edge_width: float = 1,
590
+ edge_width_min: float = 0.5, edge_width_max: float = 10, edge_color: str = 'black',
591
+ display_edge_weight: bool = True,
592
+ font_size: int = 12, filename: Optional[str] = None) -> str:
593
+ """Return the image of a bipartite graph in SVG format.
594
+
595
+ Parameters
596
+ ----------
597
+ biadjacency :
598
+ Biadjacency matrix of the graph.
599
+ names_row :
600
+ Names of the rows.
601
+ names_col :
602
+ Names of the columns.
603
+ labels_row :
604
+ Labels of the rows (negative values mean no label).
605
+ labels_col :
606
+ Labels of the columns (negative values mean no label).
607
+ scores_row :
608
+ Scores of the rows (measure of importance).
609
+ scores_col :
610
+ Scores of the columns (measure of importance).
611
+ probs_row :
612
+ Probability distribution over labels for rows.
613
+ probs_col :
614
+ Probability distribution over labels for columns.
615
+ seeds_row :
616
+ Rows to be highlighted (if dict, only keys are considered).
617
+ seeds_col :
618
+ Columns to be highlighted (if dict, only keys are considered).
619
+ position_row :
620
+ Positions of the rows.
621
+ position_col :
622
+ Positions of the columns.
623
+ reorder :
624
+ Use clustering to order nodes.
625
+ width :
626
+ Width of the image.
627
+ height :
628
+ Height of the image.
629
+ margin :
630
+ Margin of the image.
631
+ margin_text :
632
+ Margin between node and text.
633
+ scale :
634
+ Multiplicative factor on the dimensions of the image.
635
+ node_size :
636
+ Size of nodes.
637
+ node_size_min :
638
+ Minimum size of nodes.
639
+ node_size_max :
640
+ Maximum size of nodes.
641
+ display_node_weight :
642
+ If ``True``, display node weights through node size.
643
+ node_weights_row :
644
+ Weights of rows (used only if **display_node_weight** is ``True``).
645
+ node_weights_col :
646
+ Weights of columns (used only if **display_node_weight** is ``True``).
647
+ node_width :
648
+ Width of node circle.
649
+ node_width_max :
650
+ Maximum width of node circle.
651
+ color_row :
652
+ Default color of rows (svg color).
653
+ color_col :
654
+ Default color of cols (svg color).
655
+ label_colors :
656
+ Colors of the labels (svg color).
657
+ display_edges :
658
+ If ``True``, display edges.
659
+ edge_labels :
660
+ Labels of the edges, as a list of tuples (source, destination, label)
661
+ edge_width :
662
+ Width of edges.
663
+ edge_width_min :
664
+ Minimum width of edges.
665
+ edge_width_max :
666
+ Maximum width of edges.
667
+ display_edge_weight :
668
+ If ``True``, display edge weights through edge widths.
669
+ edge_color :
670
+ Default color of edges (svg color).
671
+ font_size :
672
+ Font size.
673
+ filename :
674
+ Filename for saving image (optional).
675
+
676
+ Returns
677
+ -------
678
+ image : str
679
+ SVG image.
680
+
681
+ Example
682
+ -------
683
+ >>> from sknetwork.data import movie_actor
684
+ >>> biadjacency = movie_actor()
685
+ >>> from sknetwork.visualization import visualize_bigraph
686
+ >>> image = visualize_bigraph(biadjacency)
687
+ >>> image[1:4]
688
+ 'svg'
689
+ """
690
+ n_row, n_col = biadjacency.shape
691
+
692
+ # node positions
693
+ if position_row is None or position_col is None:
694
+ position_row = np.zeros((n_row, 2))
695
+ position_col = np.ones((n_col, 2))
696
+ if reorder:
697
+ louvain = Louvain()
698
+ louvain.fit(biadjacency, force_bipartite=True)
699
+ index_row = np.argsort(louvain.labels_row_)
700
+ index_col = np.argsort(louvain.labels_col_)
701
+ else:
702
+ index_row = np.arange(n_row)
703
+ index_col = np.arange(n_col)
704
+ position_row[index_row, 1] = np.arange(n_row)
705
+ position_col[index_col, 1] = np.arange(n_col) + .5 * (n_row - n_col)
706
+ position = np.vstack((position_row, position_col))
707
+
708
+ # node colors
709
+ if scores_row is not None and scores_col is not None:
710
+ if isinstance(scores_row, dict):
711
+ scores_row = np.array(list(scores_row.values()))
712
+ if isinstance(scores_col, dict):
713
+ scores_col = np.array(list(scores_col.values()))
714
+ scores = np.hstack((scores_row, scores_col))
715
+ score_min = np.min(scores)
716
+ score_max = np.max(scores)
717
+ else:
718
+ score_min = None
719
+ score_max = None
720
+
721
+ colors_row = get_node_colors(n_row, labels_row, scores_row, probs_row, color_row, label_colors,
722
+ score_min, score_max)
723
+ colors_col = get_node_colors(n_col, labels_col, scores_col, probs_col, color_col, label_colors,
724
+ score_min, score_max)
725
+
726
+ # node sizes
727
+ if node_weights_row is None:
728
+ node_weights_row = biadjacency.dot(np.ones(n_col))
729
+ if node_weights_col is None:
730
+ node_weights_col = biadjacency.T.dot(np.ones(n_row))
731
+ node_sizes_row, node_sizes_col = get_node_sizes_bipartite(node_weights_row, node_weights_col,
732
+ node_size, node_size_min, node_size_max,
733
+ display_node_weight)
734
+
735
+ # node widths
736
+ node_widths_row = get_node_widths(n_row, seeds_row, node_width, node_width_max)
737
+ node_widths_col = get_node_widths(n_col, seeds_col, node_width, node_width_max)
738
+
739
+ # rescaling
740
+ if not width and not height:
741
+ raise ValueError("You must specify either the width or the height of the image.")
742
+ position, width, height = rescale(position, width, height, margin, node_size, node_size_max, display_node_weight)
743
+
744
+ # node names
745
+ if names_row is not None:
746
+ text_length = np.max(np.array([len(str(name)) for name in names_row]))
747
+ position[:, 0] += text_length * font_size * .5
748
+ width += text_length * font_size * .5
749
+ if names_col is not None:
750
+ text_length = np.max(np.array([len(str(name)) for name in names_col]))
751
+ width += text_length * font_size * .5
752
+
753
+ # scaling
754
+ position *= scale
755
+ height *= scale
756
+ width *= scale
757
+ position_row = position[:n_row]
758
+ position_col = position[n_row:]
759
+
760
+ svg = """<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">\n""".format(width, height)
761
+
762
+ # edges
763
+ if display_edges:
764
+ biadjacency_coo = sparse.coo_matrix(biadjacency)
765
+
766
+ if edge_color is None:
767
+ if names_row is None and names_col is None:
768
+ edge_color = 'black'
769
+ else:
770
+ edge_color = 'gray'
771
+
772
+ edge_colors, edge_order, edge_colors_residual = get_edge_colors(biadjacency, edge_labels, edge_color,
773
+ label_colors)
774
+ edge_widths = get_edge_widths(biadjacency_coo, edge_width, edge_width_min, edge_width_max, display_edge_weight)
775
+
776
+ for ix in edge_order:
777
+ i = biadjacency_coo.row[ix]
778
+ j = biadjacency_coo.col[ix]
779
+ color = edge_colors[ix]
780
+ svg += svg_edge(pos_1=position_row[i], pos_2=position_col[j], edge_width=edge_widths[ix], edge_color=color)
781
+
782
+ for i, j, color in edge_colors_residual:
783
+ svg += svg_edge(pos_1=position_row[i], pos_2=position_col[j], edge_width=edge_width, edge_color=color)
784
+
785
+ # nodes
786
+ for i in range(n_row):
787
+ if probs_row is None:
788
+ svg += svg_node(position_row[i], node_sizes_row[i], colors_row[i], node_widths_row[i])
789
+ else:
790
+ probs_row = check_format(probs_row)
791
+ if probs_row[i].nnz == 1:
792
+ index = probs_row[i].indices[0]
793
+ svg += svg_node(position_row[i], node_sizes_row[i], colors_row[index], node_widths_row[i])
794
+ else:
795
+ svg += svg_pie_chart_node(position_row[i], node_sizes_row[i], probs_row[i].todense(),
796
+ colors_row, node_widths_row[i])
797
+
798
+ for i in range(n_col):
799
+ if probs_col is None:
800
+ svg += svg_node(position_col[i], node_sizes_col[i], colors_col[i], node_widths_col[i])
801
+ else:
802
+ probs_col = check_format(probs_col)
803
+ if probs_col[i].nnz == 1:
804
+ index = probs_col[i].indices[0]
805
+ svg += svg_node(position_col[i], node_sizes_col[i], colors_col[index], node_widths_col[i])
806
+ else:
807
+ svg += svg_pie_chart_node(position_col[i], node_sizes_col[i], probs_col[i].todense(),
808
+ colors_col, node_widths_col[i])
809
+ # text
810
+ if names_row is not None:
811
+ for i in range(n_row):
812
+ svg += svg_text(position_row[i], names_row[i], margin_text + node_sizes_row[i], font_size, 'left')
813
+ if names_col is not None:
814
+ for i in range(n_col):
815
+ svg += svg_text(position_col[i], names_col[i], margin_text + node_sizes_col[i], font_size)
816
+ svg += """</svg>\n"""
817
+
818
+ if filename is not None:
819
+ with open(filename + '.svg', 'w') as f:
820
+ f.write(svg)
821
+
822
+ return svg
823
+
824
+
825
+ def svg_graph(adjacency: Optional[sparse.csr_matrix] = None, position: Optional[np.ndarray] = None,
826
+ names: Optional[np.ndarray] = None, labels: Optional[Iterable] = None, name_position: str = 'right',
827
+ scores: Optional[Iterable] = None, probs: Optional[Union[np.ndarray, sparse.csr_matrix]] = None,
828
+ seeds: Union[list, dict] = None, width: Optional[float] = 400, height: Optional[float] = 300,
829
+ margin: float = 20, margin_text: float = 3, scale: float = 1, node_order: Optional[np.ndarray] = None,
830
+ node_size: float = 7, node_size_min: float = 1, node_size_max: float = 20,
831
+ display_node_weight: Optional[bool] = None, node_weights: Optional[np.ndarray] = None,
832
+ node_width: float = 1, node_width_max: float = 3, node_color: str = 'gray',
833
+ display_edges: bool = True, edge_labels: Optional[list] = None,
834
+ edge_width: float = 1, edge_width_min: float = 0.5,
835
+ edge_width_max: float = 20, display_edge_weight: bool = True,
836
+ edge_color: Optional[str] = None, label_colors: Optional[Iterable] = None,
837
+ font_size: int = 12, directed: Optional[bool] = None, filename: Optional[str] = None) -> str:
838
+ """Return the image of a graph in SVG format.
839
+
840
+ Alias for visualize_graph.
841
+
842
+ Parameters
843
+ ----------
844
+ adjacency :
845
+ Adjacency matrix of the graph.
846
+ position :
847
+ Positions of the nodes.
848
+ names :
849
+ Names of the nodes.
850
+ labels :
851
+ Labels of the nodes (negative values mean no label).
852
+ name_position :
853
+ Position of the names (left, right, above, below)
854
+ scores :
855
+ Scores of the nodes (measure of importance).
856
+ probs :
857
+ Probability distribution over labels.
858
+ seeds :
859
+ Nodes to be highlighted (if dict, only keys are considered).
860
+ width :
861
+ Width of the image.
862
+ height :
863
+ Height of the image.
864
+ margin :
865
+ Margin of the image.
866
+ margin_text :
867
+ Margin between node and text.
868
+ scale :
869
+ Multiplicative factor on the dimensions of the image.
870
+ node_order :
871
+ Order in which nodes are displayed.
872
+ node_size :
873
+ Size of nodes.
874
+ node_size_min :
875
+ Minimum size of a node.
876
+ node_size_max:
877
+ Maximum size of a node.
878
+ node_width :
879
+ Width of node circle.
880
+ node_width_max :
881
+ Maximum width of node circle.
882
+ node_color :
883
+ Default color of nodes (svg color).
884
+ display_node_weight :
885
+ If ``True``, display node weights through node size.
886
+ node_weights :
887
+ Node weights.
888
+ display_edges :
889
+ If ``True``, display edges.
890
+ edge_labels :
891
+ Labels of the edges, as a list of tuples (source, destination, label)
892
+ edge_width :
893
+ Width of edges.
894
+ edge_width_min :
895
+ Minimum width of edges.
896
+ edge_width_max :
897
+ Maximum width of edges.
898
+ display_edge_weight :
899
+ If ``True``, display edge weights through edge widths.
900
+ edge_color :
901
+ Default color of edges (svg color).
902
+ label_colors:
903
+ Colors of the labels (svg colors).
904
+ font_size :
905
+ Font size.
906
+ directed :
907
+ If ``True``, considers the graph as directed.
908
+ filename :
909
+ Filename for saving image (optional).
910
+
911
+ Returns
912
+ -------
913
+ image : str
914
+ SVG image.
915
+ """
916
+ return visualize_graph(adjacency, position, names, labels, name_position, scores, probs, seeds, width, height,
917
+ margin, margin_text, scale, node_order, node_size, node_size_min, node_size_max,
918
+ display_node_weight, node_weights, node_width, node_width_max, node_color, display_edges,
919
+ edge_labels, edge_width, edge_width_min, edge_width_max, display_edge_weight, edge_color,
920
+ label_colors, font_size, directed, filename)
921
+
922
+
923
+ def svg_bigraph(biadjacency: sparse.csr_matrix,
924
+ names_row: Optional[np.ndarray] = None, names_col: Optional[np.ndarray] = None,
925
+ labels_row: Optional[Union[dict, np.ndarray]] = None,
926
+ labels_col: Optional[Union[dict, np.ndarray]] = None,
927
+ scores_row: Optional[Union[dict, np.ndarray]] = None,
928
+ scores_col: Optional[Union[dict, np.ndarray]] = None,
929
+ probs_row: Optional[Union[np.ndarray, sparse.csr_matrix]] = None,
930
+ probs_col: Optional[Union[np.ndarray, sparse.csr_matrix]] = None,
931
+ seeds_row: Union[list, dict] = None, seeds_col: Union[list, dict] = None,
932
+ position_row: Optional[np.ndarray] = None, position_col: Optional[np.ndarray] = None,
933
+ reorder: bool = True, width: Optional[float] = 400,
934
+ height: Optional[float] = 300, margin: float = 20, margin_text: float = 3, scale: float = 1,
935
+ node_size: float = 7, node_size_min: float = 1, node_size_max: float = 20,
936
+ display_node_weight: bool = False,
937
+ node_weights_row: Optional[np.ndarray] = None, node_weights_col: Optional[np.ndarray] = None,
938
+ node_width: float = 1, node_width_max: float = 3,
939
+ color_row: str = 'gray', color_col: str = 'gray', label_colors: Optional[Iterable] = None,
940
+ display_edges: bool = True, edge_labels: Optional[list] = None, edge_width: float = 1,
941
+ edge_width_min: float = 0.5, edge_width_max: float = 10, edge_color: str = 'black',
942
+ display_edge_weight: bool = True,
943
+ font_size: int = 12, filename: Optional[str] = None) -> str:
944
+ """Return the image of a bipartite graph in SVG format.
945
+
946
+ Alias for visualize_bigraph.
947
+
948
+ Parameters
949
+ ----------
950
+ biadjacency :
951
+ Biadjacency matrix of the graph.
952
+ names_row :
953
+ Names of the rows.
954
+ names_col :
955
+ Names of the columns.
956
+ labels_row :
957
+ Labels of the rows (negative values mean no label).
958
+ labels_col :
959
+ Labels of the columns (negative values mean no label).
960
+ scores_row :
961
+ Scores of the rows (measure of importance).
962
+ scores_col :
963
+ Scores of the columns (measure of importance).
964
+ probs_row :
965
+ Probability distribution over labels for rows.
966
+ probs_col :
967
+ Probability distribution over labels for columns.
968
+ seeds_row :
969
+ Rows to be highlighted (if dict, only keys are considered).
970
+ seeds_col :
971
+ Columns to be highlighted (if dict, only keys are considered).
972
+ position_row :
973
+ Positions of the rows.
974
+ position_col :
975
+ Positions of the columns.
976
+ reorder :
977
+ Use clustering to order nodes.
978
+ width :
979
+ Width of the image.
980
+ height :
981
+ Height of the image.
982
+ margin :
983
+ Margin of the image.
984
+ margin_text :
985
+ Margin between node and text.
986
+ scale :
987
+ Multiplicative factor on the dimensions of the image.
988
+ node_size :
989
+ Size of nodes.
990
+ node_size_min :
991
+ Minimum size of nodes.
992
+ node_size_max :
993
+ Maximum size of nodes.
994
+ display_node_weight :
995
+ If ``True``, display node weights through node size.
996
+ node_weights_row :
997
+ Weights of rows (used only if **display_node_weight** is ``True``).
998
+ node_weights_col :
999
+ Weights of columns (used only if **display_node_weight** is ``True``).
1000
+ node_width :
1001
+ Width of node circle.
1002
+ node_width_max :
1003
+ Maximum width of node circle.
1004
+ color_row :
1005
+ Default color of rows (svg color).
1006
+ color_col :
1007
+ Default color of cols (svg color).
1008
+ label_colors :
1009
+ Colors of the labels (svg color).
1010
+ display_edges :
1011
+ If ``True``, display edges.
1012
+ edge_labels :
1013
+ Labels of the edges, as a list of tuples (source, destination, label)
1014
+ edge_width :
1015
+ Width of edges.
1016
+ edge_width_min :
1017
+ Minimum width of edges.
1018
+ edge_width_max :
1019
+ Maximum width of edges.
1020
+ display_edge_weight :
1021
+ If ``True``, display edge weights through edge widths.
1022
+ edge_color :
1023
+ Default color of edges (svg color).
1024
+ font_size :
1025
+ Font size.
1026
+ filename :
1027
+ Filename for saving image (optional).
1028
+
1029
+ Returns
1030
+ -------
1031
+ image : str
1032
+ SVG image.
1033
+ """
1034
+ return visualize_bigraph(biadjacency, names_row, names_col, labels_row, labels_col, scores_row, scores_col,
1035
+ probs_row, probs_col, seeds_row, seeds_col, position_row, position_col, reorder,
1036
+ width, height, margin, margin_text, scale, node_size, node_size_min, node_size_max,
1037
+ display_node_weight, node_weights_row, node_weights_col, node_width, node_width_max,
1038
+ color_row, color_col, label_colors, display_edges, edge_labels, edge_width, edge_width_min,
1039
+ edge_width_max, edge_color, display_edge_weight, font_size, filename)