scikit-network 0.28.3__cp39-cp39-macosx_12_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-network might be problematic. Click here for more details.

Files changed (240)
  1. scikit_network-0.28.3.dist-info/AUTHORS.rst +41 -0
  2. scikit_network-0.28.3.dist-info/LICENSE +34 -0
  3. scikit_network-0.28.3.dist-info/METADATA +457 -0
  4. scikit_network-0.28.3.dist-info/RECORD +240 -0
  5. scikit_network-0.28.3.dist-info/WHEEL +5 -0
  6. scikit_network-0.28.3.dist-info/top_level.txt +1 -0
  7. sknetwork/__init__.py +21 -0
  8. sknetwork/classification/__init__.py +8 -0
  9. sknetwork/classification/base.py +84 -0
  10. sknetwork/classification/base_rank.py +143 -0
  11. sknetwork/classification/diffusion.py +134 -0
  12. sknetwork/classification/knn.py +162 -0
  13. sknetwork/classification/metrics.py +205 -0
  14. sknetwork/classification/pagerank.py +66 -0
  15. sknetwork/classification/propagation.py +152 -0
  16. sknetwork/classification/tests/__init__.py +1 -0
  17. sknetwork/classification/tests/test_API.py +35 -0
  18. sknetwork/classification/tests/test_diffusion.py +37 -0
  19. sknetwork/classification/tests/test_knn.py +24 -0
  20. sknetwork/classification/tests/test_metrics.py +53 -0
  21. sknetwork/classification/tests/test_pagerank.py +20 -0
  22. sknetwork/classification/tests/test_propagation.py +24 -0
  23. sknetwork/classification/vote.cpython-39-darwin.so +0 -0
  24. sknetwork/classification/vote.pyx +58 -0
  25. sknetwork/clustering/__init__.py +7 -0
  26. sknetwork/clustering/base.py +102 -0
  27. sknetwork/clustering/kmeans.py +142 -0
  28. sknetwork/clustering/louvain.py +255 -0
  29. sknetwork/clustering/louvain_core.cpython-39-darwin.so +0 -0
  30. sknetwork/clustering/louvain_core.pyx +134 -0
  31. sknetwork/clustering/metrics.py +91 -0
  32. sknetwork/clustering/postprocess.py +66 -0
  33. sknetwork/clustering/propagation_clustering.py +108 -0
  34. sknetwork/clustering/tests/__init__.py +1 -0
  35. sknetwork/clustering/tests/test_API.py +37 -0
  36. sknetwork/clustering/tests/test_kmeans.py +47 -0
  37. sknetwork/clustering/tests/test_louvain.py +104 -0
  38. sknetwork/clustering/tests/test_metrics.py +50 -0
  39. sknetwork/clustering/tests/test_post_processing.py +23 -0
  40. sknetwork/clustering/tests/test_postprocess.py +39 -0
  41. sknetwork/data/__init__.py +5 -0
  42. sknetwork/data/load.py +408 -0
  43. sknetwork/data/models.py +459 -0
  44. sknetwork/data/parse.py +621 -0
  45. sknetwork/data/test_graphs.py +84 -0
  46. sknetwork/data/tests/__init__.py +1 -0
  47. sknetwork/data/tests/test_API.py +30 -0
  48. sknetwork/data/tests/test_load.py +95 -0
  49. sknetwork/data/tests/test_models.py +52 -0
  50. sknetwork/data/tests/test_parse.py +253 -0
  51. sknetwork/data/tests/test_test_graphs.py +30 -0
  52. sknetwork/data/tests/test_toy_graphs.py +68 -0
  53. sknetwork/data/toy_graphs.py +619 -0
  54. sknetwork/embedding/__init__.py +10 -0
  55. sknetwork/embedding/base.py +90 -0
  56. sknetwork/embedding/force_atlas.py +197 -0
  57. sknetwork/embedding/louvain_embedding.py +174 -0
  58. sknetwork/embedding/louvain_hierarchy.py +142 -0
  59. sknetwork/embedding/metrics.py +66 -0
  60. sknetwork/embedding/random_projection.py +133 -0
  61. sknetwork/embedding/spectral.py +214 -0
  62. sknetwork/embedding/spring.py +198 -0
  63. sknetwork/embedding/svd.py +363 -0
  64. sknetwork/embedding/tests/__init__.py +1 -0
  65. sknetwork/embedding/tests/test_API.py +73 -0
  66. sknetwork/embedding/tests/test_force_atlas.py +35 -0
  67. sknetwork/embedding/tests/test_louvain_embedding.py +33 -0
  68. sknetwork/embedding/tests/test_louvain_hierarchy.py +19 -0
  69. sknetwork/embedding/tests/test_metrics.py +29 -0
  70. sknetwork/embedding/tests/test_random_projection.py +28 -0
  71. sknetwork/embedding/tests/test_spectral.py +84 -0
  72. sknetwork/embedding/tests/test_spring.py +50 -0
  73. sknetwork/embedding/tests/test_svd.py +37 -0
  74. sknetwork/flow/__init__.py +3 -0
  75. sknetwork/flow/flow.py +73 -0
  76. sknetwork/flow/tests/__init__.py +1 -0
  77. sknetwork/flow/tests/test_flow.py +17 -0
  78. sknetwork/flow/tests/test_utils.py +69 -0
  79. sknetwork/flow/utils.py +91 -0
  80. sknetwork/gnn/__init__.py +10 -0
  81. sknetwork/gnn/activation.py +117 -0
  82. sknetwork/gnn/base.py +155 -0
  83. sknetwork/gnn/base_activation.py +89 -0
  84. sknetwork/gnn/base_layer.py +109 -0
  85. sknetwork/gnn/gnn_classifier.py +381 -0
  86. sknetwork/gnn/layer.py +153 -0
  87. sknetwork/gnn/layers.py +127 -0
  88. sknetwork/gnn/loss.py +180 -0
  89. sknetwork/gnn/neighbor_sampler.py +65 -0
  90. sknetwork/gnn/optimizer.py +163 -0
  91. sknetwork/gnn/tests/__init__.py +1 -0
  92. sknetwork/gnn/tests/test_activation.py +56 -0
  93. sknetwork/gnn/tests/test_base.py +79 -0
  94. sknetwork/gnn/tests/test_base_layer.py +37 -0
  95. sknetwork/gnn/tests/test_gnn_classifier.py +192 -0
  96. sknetwork/gnn/tests/test_layers.py +80 -0
  97. sknetwork/gnn/tests/test_loss.py +33 -0
  98. sknetwork/gnn/tests/test_neigh_sampler.py +23 -0
  99. sknetwork/gnn/tests/test_optimizer.py +43 -0
  100. sknetwork/gnn/tests/test_utils.py +93 -0
  101. sknetwork/gnn/utils.py +219 -0
  102. sknetwork/hierarchy/__init__.py +7 -0
  103. sknetwork/hierarchy/base.py +69 -0
  104. sknetwork/hierarchy/louvain_hierarchy.py +264 -0
  105. sknetwork/hierarchy/metrics.py +234 -0
  106. sknetwork/hierarchy/paris.cpython-39-darwin.so +0 -0
  107. sknetwork/hierarchy/paris.pyx +317 -0
  108. sknetwork/hierarchy/postprocess.py +350 -0
  109. sknetwork/hierarchy/tests/__init__.py +1 -0
  110. sknetwork/hierarchy/tests/test_API.py +25 -0
  111. sknetwork/hierarchy/tests/test_algos.py +29 -0
  112. sknetwork/hierarchy/tests/test_metrics.py +62 -0
  113. sknetwork/hierarchy/tests/test_postprocess.py +57 -0
  114. sknetwork/hierarchy/tests/test_ward.py +25 -0
  115. sknetwork/hierarchy/ward.py +94 -0
  116. sknetwork/linalg/__init__.py +9 -0
  117. sknetwork/linalg/basics.py +37 -0
  118. sknetwork/linalg/diteration.cpython-39-darwin.so +0 -0
  119. sknetwork/linalg/diteration.pyx +49 -0
  120. sknetwork/linalg/eig_solver.py +93 -0
  121. sknetwork/linalg/laplacian.py +15 -0
  122. sknetwork/linalg/normalization.py +66 -0
  123. sknetwork/linalg/operators.py +225 -0
  124. sknetwork/linalg/polynome.py +76 -0
  125. sknetwork/linalg/ppr_solver.py +170 -0
  126. sknetwork/linalg/push.cpython-39-darwin.so +0 -0
  127. sknetwork/linalg/push.pyx +73 -0
  128. sknetwork/linalg/sparse_lowrank.py +142 -0
  129. sknetwork/linalg/svd_solver.py +91 -0
  130. sknetwork/linalg/tests/__init__.py +1 -0
  131. sknetwork/linalg/tests/test_eig.py +44 -0
  132. sknetwork/linalg/tests/test_laplacian.py +18 -0
  133. sknetwork/linalg/tests/test_normalization.py +38 -0
  134. sknetwork/linalg/tests/test_operators.py +70 -0
  135. sknetwork/linalg/tests/test_polynome.py +38 -0
  136. sknetwork/linalg/tests/test_ppr.py +50 -0
  137. sknetwork/linalg/tests/test_sparse_lowrank.py +61 -0
  138. sknetwork/linalg/tests/test_svd.py +38 -0
  139. sknetwork/linkpred/__init__.py +4 -0
  140. sknetwork/linkpred/base.py +80 -0
  141. sknetwork/linkpred/first_order.py +508 -0
  142. sknetwork/linkpred/first_order_core.cpython-39-darwin.so +0 -0
  143. sknetwork/linkpred/first_order_core.pyx +315 -0
  144. sknetwork/linkpred/postprocessing.py +98 -0
  145. sknetwork/linkpred/tests/__init__.py +1 -0
  146. sknetwork/linkpred/tests/test_API.py +49 -0
  147. sknetwork/linkpred/tests/test_postprocessing.py +21 -0
  148. sknetwork/path/__init__.py +4 -0
  149. sknetwork/path/metrics.py +148 -0
  150. sknetwork/path/search.py +65 -0
  151. sknetwork/path/shortest_path.py +186 -0
  152. sknetwork/path/tests/__init__.py +1 -0
  153. sknetwork/path/tests/test_metrics.py +29 -0
  154. sknetwork/path/tests/test_search.py +25 -0
  155. sknetwork/path/tests/test_shortest_path.py +45 -0
  156. sknetwork/ranking/__init__.py +9 -0
  157. sknetwork/ranking/base.py +56 -0
  158. sknetwork/ranking/betweenness.cpython-39-darwin.so +0 -0
  159. sknetwork/ranking/betweenness.pyx +99 -0
  160. sknetwork/ranking/closeness.py +95 -0
  161. sknetwork/ranking/harmonic.py +82 -0
  162. sknetwork/ranking/hits.py +94 -0
  163. sknetwork/ranking/katz.py +81 -0
  164. sknetwork/ranking/pagerank.py +107 -0
  165. sknetwork/ranking/postprocess.py +25 -0
  166. sknetwork/ranking/tests/__init__.py +1 -0
  167. sknetwork/ranking/tests/test_API.py +34 -0
  168. sknetwork/ranking/tests/test_betweenness.py +38 -0
  169. sknetwork/ranking/tests/test_closeness.py +34 -0
  170. sknetwork/ranking/tests/test_hits.py +20 -0
  171. sknetwork/ranking/tests/test_pagerank.py +69 -0
  172. sknetwork/regression/__init__.py +4 -0
  173. sknetwork/regression/base.py +56 -0
  174. sknetwork/regression/diffusion.py +190 -0
  175. sknetwork/regression/tests/__init__.py +1 -0
  176. sknetwork/regression/tests/test_API.py +34 -0
  177. sknetwork/regression/tests/test_diffusion.py +48 -0
  178. sknetwork/sknetwork.py +3 -0
  179. sknetwork/topology/__init__.py +9 -0
  180. sknetwork/topology/dag.py +74 -0
  181. sknetwork/topology/dag_core.cpython-39-darwin.so +0 -0
  182. sknetwork/topology/dag_core.pyx +38 -0
  183. sknetwork/topology/kcliques.cpython-39-darwin.so +0 -0
  184. sknetwork/topology/kcliques.pyx +193 -0
  185. sknetwork/topology/kcore.cpython-39-darwin.so +0 -0
  186. sknetwork/topology/kcore.pyx +120 -0
  187. sknetwork/topology/structure.py +234 -0
  188. sknetwork/topology/tests/__init__.py +1 -0
  189. sknetwork/topology/tests/test_cliques.py +28 -0
  190. sknetwork/topology/tests/test_cores.py +21 -0
  191. sknetwork/topology/tests/test_dag.py +26 -0
  192. sknetwork/topology/tests/test_structure.py +99 -0
  193. sknetwork/topology/tests/test_triangles.py +42 -0
  194. sknetwork/topology/tests/test_wl_coloring.py +49 -0
  195. sknetwork/topology/tests/test_wl_kernel.py +31 -0
  196. sknetwork/topology/triangles.cpython-39-darwin.so +0 -0
  197. sknetwork/topology/triangles.pyx +166 -0
  198. sknetwork/topology/weisfeiler_lehman.py +163 -0
  199. sknetwork/topology/weisfeiler_lehman_core.cpython-39-darwin.so +0 -0
  200. sknetwork/topology/weisfeiler_lehman_core.pyx +116 -0
  201. sknetwork/utils/__init__.py +40 -0
  202. sknetwork/utils/base.py +35 -0
  203. sknetwork/utils/check.py +354 -0
  204. sknetwork/utils/co_neighbor.py +71 -0
  205. sknetwork/utils/format.py +219 -0
  206. sknetwork/utils/kmeans.py +89 -0
  207. sknetwork/utils/knn.py +166 -0
  208. sknetwork/utils/knn1d.cpython-39-darwin.so +0 -0
  209. sknetwork/utils/knn1d.pyx +80 -0
  210. sknetwork/utils/membership.py +82 -0
  211. sknetwork/utils/minheap.cpython-39-darwin.so +0 -0
  212. sknetwork/utils/minheap.pxd +22 -0
  213. sknetwork/utils/minheap.pyx +111 -0
  214. sknetwork/utils/neighbors.py +115 -0
  215. sknetwork/utils/seeds.py +75 -0
  216. sknetwork/utils/simplex.py +140 -0
  217. sknetwork/utils/tests/__init__.py +1 -0
  218. sknetwork/utils/tests/test_base.py +28 -0
  219. sknetwork/utils/tests/test_bunch.py +16 -0
  220. sknetwork/utils/tests/test_check.py +190 -0
  221. sknetwork/utils/tests/test_co_neighbor.py +43 -0
  222. sknetwork/utils/tests/test_format.py +61 -0
  223. sknetwork/utils/tests/test_kmeans.py +21 -0
  224. sknetwork/utils/tests/test_knn.py +32 -0
  225. sknetwork/utils/tests/test_membership.py +24 -0
  226. sknetwork/utils/tests/test_neighbors.py +41 -0
  227. sknetwork/utils/tests/test_projection_simplex.py +33 -0
  228. sknetwork/utils/tests/test_seeds.py +67 -0
  229. sknetwork/utils/tests/test_verbose.py +15 -0
  230. sknetwork/utils/tests/test_ward.py +20 -0
  231. sknetwork/utils/timeout.py +38 -0
  232. sknetwork/utils/verbose.py +37 -0
  233. sknetwork/utils/ward.py +60 -0
  234. sknetwork/visualization/__init__.py +4 -0
  235. sknetwork/visualization/colors.py +34 -0
  236. sknetwork/visualization/dendrograms.py +229 -0
  237. sknetwork/visualization/graphs.py +819 -0
  238. sknetwork/visualization/tests/__init__.py +1 -0
  239. sknetwork/visualization/tests/test_dendrograms.py +53 -0
  240. sknetwork/visualization/tests/test_graphs.py +167 -0
@@ -0,0 +1,819 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created in April 2020
5
+ @authors:
6
+ Thomas Bonald <thomas.bonald@telecom-paris.fr>
7
+ Quentin Lutz <qlutz@live.fr>
8
+ """
9
+ from typing import Optional, Iterable, Union, Tuple
10
+
11
+ import numpy as np
12
+ from scipy import sparse
13
+
14
+ from sknetwork.clustering.louvain import Louvain
15
+ from sknetwork.utils.format import is_symmetric
16
+ from sknetwork.embedding.spring import Spring
17
+ from sknetwork.visualization.colors import STANDARD_COLORS, COOLWARM_RGB
18
+
19
+
20
def min_max_scaling(x: np.ndarray, x_min: Optional[float] = None, x_max: Optional[float] = None) -> np.ndarray:
    """Shift and scale vector to be between 0 and 1.

    Parameters
    ----------
    x :
        Input vector. It is converted to a float copy, so the input itself is left untouched.
    x_min :
        Value mapped to 0 (default: minimum of ``x``).
    x_max :
        Value mapped to 1 (default: maximum of ``x``).

    Returns
    -------
    Rescaled vector. When ``x_max`` is not strictly greater than ``x_min``, every entry is 0.5.
    """
    values = x.astype(float)
    low = np.min(values) if x_min is None else x_min
    high = np.max(values) if x_max is None else x_max
    if not (high > low):
        # degenerate (or undefined) range: put everything in the middle
        return .5 * np.ones_like(values)
    return (values - low) / (high - low)
33
+
34
+
35
def rescale(position: np.ndarray, width: float, height: float, margin: float, node_size: float, node_size_max: float,
            display_node_weight: bool, names: Optional[np.ndarray] = None, name_position: str = 'right',
            font_size: int = 12):
    """Rescale position and adjust parameters.

    Positions are normalized to the unit square, with the y axis flipped (SVG coordinates grow downwards), and
    then stretched to ``width`` x ``height``. If only one dimension is given, the other one follows the aspect
    ratio of the input layout. Room is then added for node names and for the margin.

    Parameters
    ----------
    position :
        array to rescale
    width :
        Horizontal scaling parameter
    height :
        Vertical scaling parameter
    margin :
        Minimal margin for the plot
    node_size :
        Node size (used to adapt the margin)
    node_size_max :
        Maximum node size (used to adapt the margin)
    display_node_weight :
        If ``True``, display node weight (used to adapt the margin)
    names :
        Names of nodes.
    name_position :
        Position of names (left, right, above, below)
    font_size :
        Font size

    Returns
    -------
    position :
        Rescaled positions
    width :
        Rescaled width
    height :
        Rescaled height
    """
    x_raw = position[:, 0]
    y_raw = position[:, 1]
    span_x = np.max(x_raw) - np.min(x_raw)
    span_y = np.max(y_raw) - np.min(y_raw)
    position = np.column_stack((min_max_scaling(x_raw), 1 - min_max_scaling(y_raw)))

    # derive the missing dimension from the aspect ratio of the layout (when it is not degenerate)
    has_ratio = span_x and span_y
    if width and not height:
        height = width * (span_y / span_x) if has_ratio else width
    elif height and not width:
        width = height * (span_x / span_y) if has_ratio else height
    position = position * np.array([width, height])

    # make room for the names (text width approximated by number of characters x font size)
    if names is not None:
        text_lengths = np.array([len(str(name)) for name in names]) * font_size
        x_pos = position[:, 0]
        if name_position == 'left':
            shift = np.maximum(-np.min(x_pos - text_lengths), 0)
            position[:, 0] += shift
            width += shift
        elif name_position == 'right':
            width += np.maximum(np.max(x_pos + text_lengths - width), 0)
        else:
            # text centered on the node, either above or below it
            half = text_lengths / 2
            shift = np.maximum(-np.min(x_pos - half), 0)
            extra = np.maximum(np.max(x_pos + half - width), 0)
            position[:, 0] += shift
            width += shift + extra
            if name_position == 'above':
                position[:, 1] += font_size
            height += font_size

    # margins, large enough to contain the biggest node
    margin = max(margin, node_size_max * display_node_weight, node_size)
    return position + margin, width + 2 * margin, height + 2 * margin
122
+
123
+
124
def get_label_colors(label_colors: Optional[Iterable]):
    """Return label svg colors.

    Parameters
    ----------
    label_colors :
        Colors of labels. Accepted forms:

        * a dict ``{label: color}``, where missing labels up to the largest key are colored black;
        * a single color (string);
        * any iterable of colors (list, tuple, array);
        * ``None``, meaning standard colors.

    Returns
    -------
    label_colors : np.ndarray
        Array of colors indexed by label.

    Examples
    --------
    >>> get_label_colors(['black'])
    array(['black'], dtype='<U5')
    >>> get_label_colors({0: 'blue'})
    array(['blue'], dtype='<U64')
    """
    if label_colors is None:
        return STANDARD_COLORS.copy()
    if isinstance(label_colors, dict):
        keys = list(label_colors.keys())
        values = list(label_colors.values())
        colors = np.array(['black'] * (max(keys) + 1), dtype='U64')
        colors[keys] = values
        return colors
    if isinstance(label_colors, str):
        # a single color, not a sequence of characters
        return np.array([label_colors])
    if isinstance(label_colors, np.ndarray):
        return label_colors
    # lists, tuples and other iterables: callers index the result with integer arrays
    return np.array(list(label_colors))
145
+
146
+
147
def get_node_colors(n: int, labels: Optional[Iterable], scores: Optional[Iterable],
                    membership: Optional[sparse.csr_matrix],
                    node_color: str, label_colors: Optional[Iterable],
                    score_min: Optional[float] = None, score_max: Optional[float] = None) -> np.ndarray:
    """Return the colors of the nodes using either labels or scores or default color.

    Labels take priority over scores, and scores over membership.

    Parameters
    ----------
    n :
        Number of nodes.
    labels :
        Labels of nodes, as a dict ``{node: label}`` or a sequence of length ``n`` (negative values mean no label).
    scores :
        Scores of nodes, as a dict ``{node: score}`` or a sequence of length ``n``.
    membership :
        Membership matrix. When used, the returned array holds the colors of the labels (not of the nodes).
    node_color :
        Default color of nodes.
    label_colors :
        Colors of labels.
    score_min, score_max :
        Scores mapped to the first and last colors of the colormap (default: min and max of scores).
        Scores outside this range are clipped.

    Returns
    -------
    node_colors : np.ndarray
        SVG colors of nodes (or of labels when ``membership`` is used).

    Raises
    ------
    ValueError
        If a list or tuple of labels or scores does not have length ``n``.
    TypeError
        If ``label_colors`` is a dict while ``membership`` is used.
    """
    node_colors = np.array(n * [node_color]).astype('U64')
    if labels is not None:
        if isinstance(labels, dict):
            # dtype=int keeps the index array valid even for an empty dict
            keys = np.array(list(labels.keys()), dtype=int)
            values = np.array(list(labels.values())).astype(int)
            labels = -np.ones(n, dtype=int)
            labels[keys] = values
        elif not isinstance(labels, np.ndarray):
            if len(labels) != n:
                raise ValueError("The number of labels must be equal to the corresponding number of nodes.")
            labels = np.array(labels)
        index = labels >= 0
        label_colors = get_label_colors(label_colors)
        node_colors[index] = label_colors[labels[index] % len(label_colors)]
    elif scores is not None:
        colors_score = COOLWARM_RGB.copy()
        n_colors = colors_score.shape[0]
        # tolist() yields Python scalars, so the color reads "rgb(r, g, b)" whatever the NumPy version
        # (since NumPy 2, str() of a tuple of NumPy scalars shows e.g. "np.int64(59)")
        colors_score_svg = np.array(['rgb' + str(tuple(colors_score[i].tolist())) for i in range(n_colors)])

        def get_color_index(values: np.ndarray) -> np.ndarray:
            """Map scores to colormap indices; scores outside [score_min, score_max] are clipped."""
            scaled = np.clip(min_max_scaling(values, score_min, score_max), 0, 1)
            return (scaled * (n_colors - 1)).astype(int)

        if isinstance(scores, dict):
            if scores:
                keys = np.array(list(scores.keys()), dtype=int)
                values = np.array(list(scores.values()))
                node_colors[keys] = colors_score_svg[get_color_index(values)]
        else:
            if not isinstance(scores, np.ndarray):
                if len(scores) != n:
                    raise ValueError("The number of scores must be equal to the corresponding number of nodes.")
                scores = np.array(scores)
            node_colors = colors_score_svg[get_color_index(scores)]
    elif membership is not None:
        if isinstance(label_colors, dict):
            raise TypeError("Label colors must be a list or an array when using a membership.")
        label_colors = get_label_colors(label_colors)
        node_colors = label_colors
    return node_colors
190
+
191
+
192
def get_node_widths(n: int, seeds: Union[int, dict, list], node_width: float, node_width_max: float) -> np.ndarray:
    """Return the node widths.

    Parameters
    ----------
    n :
        Number of nodes.
    seeds :
        Seed nodes to highlight. Accepted forms:

        * a single node index;
        * an iterable of node indices (list, set, array) or a boolean mask;
        * a dict, whose keys are the seeds;
        * ``None``, meaning no seed.
    node_width :
        Default width of node contours.
    node_width_max :
        Width of seed contours.

    Returns
    -------
    node_widths : np.ndarray
        Width of the contour of each node.
    """
    node_widths = node_width * np.ones(n)
    if seeds is None:
        return node_widths
    if isinstance(seeds, dict):
        # isinstance also accepts dict subclasses (OrderedDict, defaultdict, ...)
        seeds = list(seeds.keys())
    elif np.issubdtype(type(seeds), np.integer):
        seeds = [seeds]
    elif not isinstance(seeds, np.ndarray):
        # sets and other iterables cannot be used directly as NumPy indices
        seeds = list(seeds)
    if len(seeds):
        node_widths[np.array(seeds)] = node_width_max
    return node_widths
203
+
204
+
205
def get_node_sizes(weights: np.ndarray, node_size: float, node_size_min: float, node_size_max: float, node_weight) \
        -> np.ndarray:
    """Return the node sizes.

    Parameters
    ----------
    weights :
        Node weights.
    node_size :
        Size used for every node when weights are not displayed (or are all equal).
    node_size_min, node_size_max :
        Bounds of the size range; sizes grow proportionally to ``weights / max(weights)``.
    node_weight :
        If truthy, node sizes reflect node weights.

    Returns
    -------
    node_sizes : np.ndarray
        Size of each node.
    """
    use_weights = node_weight and np.min(weights) < np.max(weights)
    if not use_weights:
        return node_size * np.ones_like(weights)
    spread = np.abs(node_size_max - node_size_min)
    return node_size_min + spread * weights / np.max(weights)
213
+
214
+
215
def get_node_sizes_bipartite(weights_row: np.ndarray, weights_col: np.ndarray, node_size: float, node_size_min: float,
                             node_size_max: float, node_weight) -> (np.ndarray, np.ndarray):
    """Return the node sizes for bipartite graphs.

    Rows and columns share one scale, so that their sizes are comparable.

    Parameters
    ----------
    weights_row, weights_col :
        Weights of row nodes and column nodes.
    node_size :
        Size used for every node when weights are not displayed (or are all equal).
    node_size_min, node_size_max :
        Bounds of the size range; sizes grow proportionally to the weight divided by the overall max weight.
    node_weight :
        If truthy, node sizes reflect node weights.

    Returns
    -------
    node_sizes_row, node_sizes_col :
        Sizes of row nodes and column nodes.
    """
    all_weights = np.hstack((weights_row, weights_col))
    if not (node_weight and np.min(all_weights) < np.max(all_weights)):
        return node_size * np.ones_like(weights_row), node_size * np.ones_like(weights_col)
    spread = np.abs(node_size_max - node_size_min)
    weight_max = np.max(all_weights)
    sizes_row, sizes_col = [node_size_min + spread * w / weight_max for w in (weights_row, weights_col)]
    return sizes_row, sizes_col
226
+
227
+
228
def get_edge_colors(adjacency: sparse.csr_matrix, edge_labels: Optional[list], edge_color: str,
                    label_colors: Optional[Iterable]) -> Tuple[np.ndarray, np.ndarray, list]:
    """Return the edge colors.

    Parameters
    ----------
    adjacency :
        Adjacency matrix in CSR format. The returned ``edge_colors`` and ``edge_order`` are aligned with its
        stored entries (``adjacency.data``), whatever the sign of the weights.
    edge_labels :
        Labels of edges, as a list of (source, destination, label).
    edge_color :
        Default color of edges.
    label_colors :
        Colors of labels.

    Returns
    -------
    edge_colors :
        Color of each stored edge.
    edge_order :
        Order in which edges are drawn: unlabeled edges first, so that labeled edges appear on top.
    edge_colors_residual :
        List of (source, destination, color) for labeled edges that are not in ``adjacency``.

    Raises
    ------
    ValueError
        If an edge label refers to a node index out of range.
    """
    n_row, n_col = adjacency.shape
    n_edges = adjacency.nnz
    # Copy the full sparsity structure (not only positive entries, as `adjacency > 0` would) so that the
    # output stays aligned with adjacency.data even with negative weights; -1 means "no label".
    adjacency_labels = adjacency.astype(int, copy=True)
    adjacency_labels.data[:] = -1
    edge_colors_residual = []
    if edge_labels:
        label_colors = get_label_colors(label_colors)
        for i, j, label in edge_labels:
            if i < 0 or i >= n_row or j < 0 or j >= n_col:
                raise ValueError('Invalid node index in edge labels.')
            if adjacency[i, j]:
                # existing entry: the sparsity structure is unchanged
                adjacency_labels[i, j] = label % len(label_colors)
            else:
                color = label_colors[label % len(label_colors)]
                edge_colors_residual.append((i, j, color))
    edge_order = np.argsort(adjacency_labels.data)
    edge_colors = np.array(n_edges * [edge_color]).astype('U64')
    index = np.argwhere(adjacency_labels.data >= 0).ravel()
    if len(index):
        edge_colors[index] = label_colors[adjacency_labels.data[index]]
    return edge_colors, edge_order, edge_colors_residual
252
+
253
+
254
def get_edge_widths(adjacency: sparse.coo_matrix, edge_width: float, edge_width_min: float, edge_width_max: float,
                    display_edge_weight: bool) -> np.ndarray:
    """Return the edge widths.

    Parameters
    ----------
    adjacency :
        Adjacency matrix (COO); widths are aligned with ``adjacency.data``.
    edge_width :
        Width used for every edge when weights are not displayed (or are all equal).
    edge_width_min, edge_width_max :
        Widths of the lightest and heaviest edges (linear interpolation in between).
    display_edge_weight :
        If ``True``, edge widths reflect edge weights.

    Returns
    -------
    edge_widths : np.ndarray or None
        Width of each edge, or ``None`` if there is no edge.
    """
    weights = adjacency.data
    if not len(weights):
        return None
    weight_min = np.min(weights)
    weight_max = np.max(weights)
    if not (display_edge_weight and weight_min < weight_max):
        return edge_width * np.ones_like(weights)
    spread = np.abs(edge_width_max - edge_width_min)
    return edge_width_min + spread * (weights - weight_min) / (weight_max - weight_min)
266
+
267
+
268
def svg_node(pos_node: np.ndarray, size: float, color: str, stroke_width: float = 1, stroke_color: str = 'black') \
        -> str:
    """Return svg code for a node.

    Parameters
    ----------
    pos_node :
        (x, y) coordinates of the node (truncated to integers).
    size :
        Radius of disk in pixels.
    color :
        Color of the disk in SVG valid format.
    stroke_width :
        Width of the contour of the disk in pixels, centered around the circle.
    stroke_color :
        Color of the contour in SVG valid format.

    Returns
    -------
    SVG code for the node.
    """
    center_x, center_y = pos_node.astype(int)
    template = """<circle cx="{}" cy="{}" r="{}" style="fill:{};stroke:{};stroke-width:{}"/>\n"""
    return template.format(center_x, center_y, size, color, stroke_color, stroke_width)
292
+
293
+
294
def svg_pie_chart_node(pos_node: np.ndarray, size: float, membership: np.ndarray, colors: np.ndarray,
                       stroke_width: float = 1, stroke_color: str = 'black') -> str:
    """Return svg code for a pie-chart node.

    Parameters
    ----------
    pos_node :
        (x, y) coordinates of the node.
    size :
        Radius of the pie chart in pixels.
    membership :
        Weights of the slices, as a 1D array or a single-row 2D array / matrix.
    colors :
        Colors of the slices (cycled if there are fewer colors than slices).
    stroke_width :
        Width of the contour of the slices.
    stroke_color :
        Color of the contour.

    Returns
    -------
    SVG code for the node.

    * A node with zero total weight is drawn as a white disk.
    * A node with a single non-zero slice is drawn as a plain disk of that slice's color.
    """
    x, y = pos_node.astype(float)
    n_colors = len(colors)
    weights = np.asarray(membership, dtype=float).ravel()
    cumulative = np.concatenate(([0.], np.cumsum(weights)))
    total = cumulative[-1]
    if total == 0:
        return svg_node(pos_node, size, 'white', stroke_width=3)
    nonzero = np.flatnonzero(weights)
    if len(nonzero) == 1:
        # a full-circle arc has identical start and end points, which SVG renders as no arc at all
        index = nonzero[0]
        return svg_node(pos_node, size, colors[index % n_colors], stroke_width, stroke_color)
    angles = cumulative * ((2 * np.pi) / total)
    x_array = size * np.cos(angles) + x
    y_array = size * np.sin(angles) + y
    # large-arc flag: the slice spans more than half of the circle
    large = weights > total / 2
    template = """<path d="M {} {} A {} {} 0 {} 1 {} {} L {} {}" style="fill:{};stroke:{};stroke-width:{}" />\n"""
    return "".join(template.format(x_array[index], y_array[index], size, size, int(large[index]),
                                   x_array[index + 1], y_array[index + 1], x, y, colors[index % n_colors],
                                   stroke_color, stroke_width)
                   for index in range(len(weights)))
314
+
315
+
316
def svg_edge(pos_1: np.ndarray, pos_2: np.ndarray, edge_width: float = 1, edge_color: str = 'black') -> str:
    """Return svg code for an edge.

    The edge is a straight segment between the two positions, whose coordinates are truncated to integers.
    """
    (x_start, y_start), (x_end, y_end) = pos_1.astype(int), pos_2.astype(int)
    template = """<path stroke-width="{}" stroke="{}" d="M {} {} {} {}"/>\n"""
    return template.format(edge_width, edge_color, x_start, y_start, x_end, y_end)
322
+
323
+
324
def svg_edge_directed(pos_1: np.ndarray, pos_2: np.ndarray, edge_width: float = 1, edge_color: str = 'black',
                      node_size: float = 1.) -> str:
    """Return svg code for a directed edge.

    The segment stops ``node_size`` pixels before the target, so that the arrow head (marker
    ``arrow-<edge_color>``) is not hidden by the target node. An empty string is returned when both
    positions coincide.
    """
    direction = pos_2 - pos_1
    length = np.linalg.norm(direction)
    if not length:
        return ""
    offset_x, offset_y = (direction / length * node_size).astype(int)
    (x_start, y_start), (x_end, y_end) = pos_1.astype(int), pos_2.astype(int)
    template = """<path stroke-width="{}" stroke="{}" d="M {} {} {} {}" marker-end="url(#arrow-{})"/>\n"""
    return template.format(edge_width, edge_color, x_start, y_start, x_end - offset_x, y_end - offset_y,
                           edge_color)
337
+
338
+
339
def svg_text(pos, text, margin_text, font_size=12, position: str = 'right'):
    """Return svg code for text.

    Parameters
    ----------
    pos :
        (x, y) coordinates of the anchor point (typically a node). This input is not modified.
    text :
        Text to display, converted to string. XML special characters (&, <, >) are escaped.
    margin_text :
        Distance between the anchor point and the text.
    font_size :
        Font size.
    position :
        Position of the text relative to the anchor point: 'left', 'above', 'below' or 'right'.
        Any other value is treated as 'right'.

    Returns
    -------
    SVG code for the text (without trailing newline).
    """
    from html import escape

    # work on a copy: svg_graph passes a row of its position array, which must not be shifted
    pos = np.array(pos, dtype=float)
    if position == 'left':
        pos[0] -= margin_text
        anchor = 'end'
    elif position == 'above':
        pos[1] -= margin_text
        anchor = 'middle'
    elif position == 'below':
        pos[1] += 2 * margin_text
        anchor = 'middle'
    else:
        pos[0] += margin_text
        anchor = 'start'
    x, y = pos.astype(int)
    # proper escaping keeps the text intact (instead of replacing special characters with spaces)
    text = escape(str(text), quote=False)
    return """<text text-anchor="{}" x="{}" y="{}" font-size="{}">{}</text>""".format(anchor, x, y, font_size, text)
358
+
359
+
360
def svg_graph(adjacency: Optional[sparse.csr_matrix] = None, position: Optional[np.ndarray] = None,
              names: Optional[np.ndarray] = None, labels: Optional[Iterable] = None,
              name_position: str = 'right', scores: Optional[Iterable] = None,
              membership: Optional[sparse.csr_matrix] = None,
              seeds: Union[list, dict] = None, width: Optional[float] = 400, height: Optional[float] = 300,
              margin: float = 20, margin_text: float = 3, scale: float = 1, node_order: Optional[np.ndarray] = None,
              node_size: float = 7, node_size_min: float = 1, node_size_max: float = 20,
              display_node_weight: Optional[bool] = None, node_weights: Optional[np.ndarray] = None,
              node_width: float = 1, node_width_max: float = 3, node_color: str = 'gray',
              display_edges: bool = True, edge_labels: Optional[list] = None,
              edge_width: float = 1, edge_width_min: float = 0.5,
              edge_width_max: float = 20, display_edge_weight: bool = True,
              edge_color: Optional[str] = None, label_colors: Optional[Iterable] = None,
              font_size: int = 12, directed: Optional[bool] = None, filename: Optional[str] = None) -> str:
    """Return SVG image of a graph.

    Parameters
    ----------
    adjacency :
        Adjacency matrix of the graph (any format accepted by ``scipy.sparse.csr_matrix``; not modified).
    position :
        Positions of the nodes.
    names :
        Names of the nodes.
    labels :
        Labels of the nodes (negative values mean no label).
    name_position :
        Position of the names (left, right, above, below)
    scores :
        Scores of the nodes (measure of importance).
    membership :
        Membership of the nodes (label distribution).
    seeds :
        Nodes to be highlighted (if dict, only keys are considered).
    width :
        Width of the image.
    height :
        Height of the image.
    margin :
        Margin of the image.
    margin_text :
        Margin between node and text.
    scale :
        Multiplicative factor on the dimensions of the image.
    node_order :
        Order in which nodes are displayed.
    node_size :
        Size of nodes.
    node_size_min :
        Minimum size of a node.
    node_size_max:
        Maximum size of a node.
    node_width :
        Width of node circle.
    node_width_max :
        Maximum width of node circle.
    node_color :
        Default color of nodes (svg color).
    display_node_weight :
        If ``True``, display node weights through node size.
    node_weights :
        Node weights.
    display_edges :
        If ``True``, display edges.
    edge_labels :
        Labels of the edges, as a list of tuples (source, destination, label)
    edge_width :
        Width of edges.
    edge_width_min :
        Minimum width of edges.
    edge_width_max :
        Maximum width of edges.
    display_edge_weight :
        If ``True``, display edge weights through edge widths.
    edge_color :
        Default color of edges (svg color).
    label_colors:
        Colors of the labels (svg colors).
    font_size :
        Font size.
    directed :
        If ``True``, considers the graph as directed.
    filename :
        Filename for saving image (optional, without the ``.svg`` extension; written in UTF-8).

    Returns
    -------
    image : str
        SVG image.

    Raises
    ------
    ValueError
        If neither ``adjacency`` nor ``position`` is given.

    Example
    -------
    >>> from sknetwork.data import karate_club
    >>> graph = karate_club(True)
    >>> adjacency = graph.adjacency
    >>> position = graph.position
    >>> from sknetwork.visualization import svg_graph
    >>> image = svg_graph(adjacency, position)
    >>> image[1:4]
    'svg'
    """
    # check adjacency
    if adjacency is None:
        if position is None:
            raise ValueError("You must specify either adjacency or position.")
        n = position.shape[0]
        adjacency = sparse.csr_matrix((n, n)).astype(int)
    else:
        # Work on a CSR copy: eliminate_zeros must not modify the caller's matrix, and get_edge_colors
        # relies on CSR indexing and on the order of the stored entries.
        adjacency = sparse.csr_matrix(adjacency, copy=True)
        n = adjacency.shape[0]
        adjacency.eliminate_zeros()
        if directed is None:
            directed = not is_symmetric(adjacency)

    # node order
    if node_order is None:
        node_order = np.arange(n)

    # position
    if position is None:
        spring = Spring()
        position = spring.fit_transform(adjacency)

    # node colors
    node_colors = get_node_colors(n, labels, scores, membership, node_color, label_colors)

    # node sizes (default weights: in-degrees)
    if display_node_weight is None:
        display_node_weight = node_weights is not None
    if node_weights is None:
        node_weights = adjacency.T.dot(np.ones(n))
    node_sizes = get_node_sizes(node_weights, node_size, node_size_min, node_size_max, display_node_weight)

    # node widths
    node_widths = get_node_widths(n, seeds, node_width, node_width_max)

    # rescaling
    position, width, height = rescale(position, width, height, margin, node_size, node_size_max, display_node_weight,
                                      names, name_position, font_size)

    # scaling
    position *= scale
    height *= scale
    width *= scale

    # the image is assembled from parts and joined once (repeated string += is quadratic on large graphs)
    svg_parts = ["""<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">\n""".format(width, height)]

    # edges
    if display_edges:
        # same entry order as adjacency.data, hence aligned with the outputs of get_edge_colors
        adjacency_coo = sparse.coo_matrix(adjacency)

        if edge_color is None:
            edge_color = 'black' if names is None else 'gray'

        edge_colors, edge_order, edge_colors_residual = get_edge_colors(adjacency, edge_labels, edge_color,
                                                                        label_colors)
        edge_widths = get_edge_widths(adjacency_coo, edge_width, edge_width_min, edge_width_max, display_edge_weight)

        if directed:
            # One arrow marker per color, including the colors of labeled edges absent from the adjacency.
            # Sorting makes the output deterministic.
            arrow_colors = set(edge_colors) | {color for _, _, color in edge_colors_residual}
            for arrow_color in sorted(arrow_colors):
                svg_parts.append('<defs><marker id="arrow-{}" markerWidth="10" markerHeight="10" refX="9" refY="3" '
                                 'orient="auto" >\n'.format(arrow_color))
                svg_parts.append("""<path d="M0,0 L0,6 L9,3 z" fill="{}"/></marker></defs>\n""".format(arrow_color))

        for ix in edge_order:
            i = adjacency_coo.row[ix]
            j = adjacency_coo.col[ix]
            color = edge_colors[ix]
            if directed:
                svg_parts.append(svg_edge_directed(pos_1=position[i], pos_2=position[j], edge_width=edge_widths[ix],
                                                   edge_color=color, node_size=node_sizes[j]))
            else:
                svg_parts.append(svg_edge(pos_1=position[i], pos_2=position[j],
                                          edge_width=edge_widths[ix], edge_color=color))

        for i, j, color in edge_colors_residual:
            if directed:
                svg_parts.append(svg_edge_directed(pos_1=position[i], pos_2=position[j], edge_width=edge_width,
                                                   edge_color=color, node_size=node_sizes[j]))
            else:
                svg_parts.append(svg_edge(pos_1=position[i], pos_2=position[j],
                                          edge_width=edge_width, edge_color=color))

    # nodes
    for i in node_order:
        if membership is None:
            svg_parts.append(svg_node(position[i], node_sizes[i], node_colors[i], node_widths[i]))
        elif membership[i].nnz == 1:
            # single label: plain disk, node_colors holds the colors of labels here
            index = membership[i].indices[0]
            svg_parts.append(svg_node(position[i], node_sizes[i], node_colors[index], node_widths[i]))
        else:
            svg_parts.append(svg_pie_chart_node(position[i], node_sizes[i], membership[i].todense(),
                                                node_colors, node_widths[i]))

    # text
    if names is not None:
        for i in range(n):
            svg_parts.append(svg_text(position[i], names[i], node_sizes[i] + margin_text, font_size, name_position))
    svg_parts.append("""</svg>\n""")
    svg = "".join(svg_parts)

    if filename is not None:
        # explicit encoding: node names may contain non-ASCII characters
        with open(filename + '.svg', 'w', encoding='utf-8') as f:
            f.write(svg)

    return svg
569
+
570
+
571
def svg_bigraph(biadjacency: sparse.csr_matrix,
                names_row: Optional[np.ndarray] = None, names_col: Optional[np.ndarray] = None,
                labels_row: Optional[Union[dict, np.ndarray]] = None,
                labels_col: Optional[Union[dict, np.ndarray]] = None,
                scores_row: Optional[Union[dict, np.ndarray]] = None,
                scores_col: Optional[Union[dict, np.ndarray]] = None,
                membership_row: Optional[sparse.csr_matrix] = None,
                membership_col: Optional[sparse.csr_matrix] = None,
                seeds_row: Union[list, dict] = None, seeds_col: Union[list, dict] = None,
                position_row: Optional[np.ndarray] = None, position_col: Optional[np.ndarray] = None,
                reorder: bool = True, width: Optional[float] = 400,
                height: Optional[float] = 300, margin: float = 20, margin_text: float = 3, scale: float = 1,
                node_size: float = 7, node_size_min: float = 1, node_size_max: float = 20,
                display_node_weight: bool = False,
                node_weights_row: Optional[np.ndarray] = None, node_weights_col: Optional[np.ndarray] = None,
                node_width: float = 1, node_width_max: float = 3,
                color_row: str = 'gray', color_col: str = 'gray', label_colors: Optional[Iterable] = None,
                display_edges: bool = True, edge_labels: Optional[list] = None, edge_width: float = 1,
                edge_width_min: float = 0.5, edge_width_max: float = 10, edge_color: str = 'black',
                display_edge_weight: bool = True,
                font_size: int = 12, filename: Optional[str] = None) -> str:
    """Return SVG image of a bigraph.

    Rows are drawn on a vertical line on the left, columns on a vertical line on the right,
    and each non-zero entry of the biadjacency matrix is drawn as an edge between them.

    Parameters
    ----------
    biadjacency :
        Biadjacency matrix of the graph.
    names_row :
        Names of the rows.
    names_col :
        Names of the columns.
    labels_row :
        Labels of the rows (negative values mean no label).
    labels_col :
        Labels of the columns (negative values mean no label).
    scores_row :
        Scores of the rows (measure of importance), as an array or a dict.
    scores_col :
        Scores of the columns (measure of importance), as an array or a dict.
        If both row and column scores are given, they share a common color scale.
    membership_row :
        Membership of the rows (label distribution).
    membership_col :
        Membership of the columns (label distribution).
    seeds_row :
        Rows to be highlighted (if dict, only keys are considered).
    seeds_col :
        Columns to be highlighted (if dict, only keys are considered).
    position_row :
        Positions of the rows. Used only if **position_col** is also given.
    position_col :
        Positions of the columns. Used only if **position_row** is also given.
    reorder :
        Use clustering to order nodes (ignored if positions are given).
    width :
        Width of the image.
    height :
        Height of the image.
    margin :
        Margin of the image.
    margin_text :
        Margin between node and text.
    scale :
        Multiplicative factor on the dimensions of the image.
    node_size :
        Size of nodes.
    node_size_min :
        Minimum size of nodes.
    node_size_max :
        Maximum size of nodes.
    display_node_weight :
        If ``True``, display node weights through node size.
    node_weights_row :
        Weights of rows (used only if **display_node_weight** is ``True``).
        Default: sum of each row of the biadjacency matrix.
    node_weights_col :
        Weights of columns (used only if **display_node_weight** is ``True``).
        Default: sum of each column of the biadjacency matrix.
    node_width :
        Width of node circle.
    node_width_max :
        Maximum width of node circle.
    color_row :
        Default color of rows (svg color).
    color_col :
        Default color of cols (svg color).
    label_colors :
        Colors of the labels (svg color).
    display_edges :
        If ``True``, display edges.
    edge_labels :
        Labels of the edges, as a list of tuples (source, destination, label).
    edge_width :
        Width of edges.
    edge_width_min :
        Minimum width of edges.
    edge_width_max :
        Maximum width of edges.
    display_edge_weight :
        If ``True``, display edge weights through edge widths.
    edge_color :
        Default color of edges (svg color). If ``None``, black without names, gray with names.
    font_size :
        Font size.
    filename :
        Filename for saving image (optional). The extension ``.svg`` is appended
        and the file is written in UTF-8.

    Returns
    -------
    image : str
        SVG image.

    Raises
    ------
    ValueError
        If neither the width nor the height of the image is specified.

    Example
    -------
    >>> from sknetwork.data import movie_actor
    >>> biadjacency = movie_actor()
    >>> from sknetwork.visualization import svg_bigraph
    >>> image = svg_bigraph(biadjacency)
    >>> image[1:4]
    'svg'
    """
    # validate first, before any costly work (clustering, sizes, colors)
    if not width and not height:
        raise ValueError("You must specify either the width or the height of the image.")

    n_row, n_col = biadjacency.shape

    # node positions: rows on the line x = 0, columns on the line x = 1
    if position_row is None or position_col is None:
        position_row = np.zeros((n_row, 2))
        position_col = np.ones((n_col, 2))
        if reorder:
            # sorting by cluster label makes nodes of the same cluster consecutive
            louvain = Louvain()
            louvain.fit(biadjacency, force_bipartite=True)
            index_row = np.argsort(louvain.labels_row_)
            index_col = np.argsort(louvain.labels_col_)
        else:
            index_row = np.arange(n_row)
            index_col = np.arange(n_col)
        position_row[index_row, 1] = np.arange(n_row)
        # vertically center the columns with respect to the rows
        position_col[index_col, 1] = np.arange(n_col) + .5 * (n_row - n_col)
    position = np.vstack((position_row, position_col))

    # node colors, with a score range shared by rows and columns
    if scores_row is not None and scores_col is not None:
        # Only the values are needed for the range. Dicts are passed on to get_node_colors
        # unchanged so that node indices (keys) are not lost (assumes get_node_colors maps
        # dicts by key, as svg_graph relies on; TODO confirm).
        values_row = np.array(list(scores_row.values())) if isinstance(scores_row, dict) else scores_row
        values_col = np.array(list(scores_col.values())) if isinstance(scores_col, dict) else scores_col
        all_scores = np.hstack((values_row, values_col))
        score_min = np.min(all_scores)
        score_max = np.max(all_scores)
    else:
        score_min = None
        score_max = None

    colors_row = get_node_colors(n_row, labels_row, scores_row, membership_row, color_row, label_colors,
                                 score_min, score_max)
    colors_col = get_node_colors(n_col, labels_col, scores_col, membership_col, color_col, label_colors,
                                 score_min, score_max)

    # node sizes (default weights: row / column sums of the biadjacency matrix)
    if node_weights_row is None:
        node_weights_row = biadjacency.dot(np.ones(n_col))
    if node_weights_col is None:
        node_weights_col = biadjacency.T.dot(np.ones(n_row))
    node_sizes_row, node_sizes_col = get_node_sizes_bipartite(node_weights_row, node_weights_col,
                                                              node_size, node_size_min, node_size_max,
                                                              display_node_weight)

    # node widths (seeds are highlighted)
    node_widths_row = get_node_widths(n_row, seeds_row, node_width, node_width_max)
    node_widths_col = get_node_widths(n_col, seeds_col, node_width, node_width_max)

    # rescaling
    position, width, height = rescale(position, width, height, margin, node_size, node_size_max, display_node_weight)

    # room for node names: row names are written on the left, column names on the right
    if names_row is not None:
        text_length = max(len(str(name)) for name in names_row)
        position[:, 0] += text_length * font_size * .5
        width += text_length * font_size * .5
    if names_col is not None:
        text_length = max(len(str(name)) for name in names_col)
        width += text_length * font_size * .5

    # scaling
    position *= scale
    height *= scale
    width *= scale
    position_row = position[:n_row]
    position_col = position[n_row:]

    svg_parts = ["""<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">\n""".format(width, height)]

    # edges (drawn first so that nodes are painted on top of them)
    if display_edges:
        biadjacency_coo = sparse.coo_matrix(biadjacency)

        if edge_color is None:
            edge_color = 'black' if names_row is None and names_col is None else 'gray'

        edge_colors, edge_order, edge_colors_residual = get_edge_colors(biadjacency, edge_labels, edge_color,
                                                                        label_colors)
        edge_widths = get_edge_widths(biadjacency_coo, edge_width, edge_width_min, edge_width_max, display_edge_weight)

        for ix in edge_order:
            i = biadjacency_coo.row[ix]
            j = biadjacency_coo.col[ix]
            svg_parts.append(svg_edge(pos_1=position_row[i], pos_2=position_col[j], edge_width=edge_widths[ix],
                                      edge_color=edge_colors[ix]))

        for i, j, color in edge_colors_residual:
            svg_parts.append(svg_edge(pos_1=position_row[i], pos_2=position_col[j], edge_width=edge_width,
                                      edge_color=color))

    # nodes: rows then columns; a node with a single label in its membership is a plain
    # circle in that label's color, otherwise it is drawn as a pie chart
    node_sets = ((n_row, position_row, node_sizes_row, colors_row, node_widths_row, membership_row),
                 (n_col, position_col, node_sizes_col, colors_col, node_widths_col, membership_col))
    for n_nodes, positions, sizes, colors, widths, membership in node_sets:
        for i in range(n_nodes):
            if membership is None:
                svg_parts.append(svg_node(positions[i], sizes[i], colors[i], widths[i]))
            elif membership[i].nnz == 1:
                index = membership[i].indices[0]
                svg_parts.append(svg_node(positions[i], sizes[i], colors[index], widths[i]))
            else:
                svg_parts.append(svg_pie_chart_node(positions[i], sizes[i], membership[i].todense(),
                                                    colors, widths[i]))

    # text
    if names_row is not None:
        for i in range(n_row):
            svg_parts.append(svg_text(position_row[i], names_row[i], margin_text + node_sizes_row[i], font_size,
                                      'left'))
    if names_col is not None:
        for i in range(n_col):
            svg_parts.append(svg_text(position_col[i], names_col[i], margin_text + node_sizes_col[i], font_size))
    svg_parts.append("""</svg>\n""")
    svg = ''.join(svg_parts)

    if filename is not None:
        # explicit UTF-8: names may be non-ASCII, and SVG without an XML declaration is read as UTF-8
        with open(filename + '.svg', 'w', encoding='utf-8') as f:
            f.write(svg)

    return svg