scikit-network 0.33.4__cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (229) hide show
  1. scikit_network-0.33.4.dist-info/METADATA +122 -0
  2. scikit_network-0.33.4.dist-info/RECORD +229 -0
  3. scikit_network-0.33.4.dist-info/WHEEL +6 -0
  4. scikit_network-0.33.4.dist-info/licenses/AUTHORS.rst +43 -0
  5. scikit_network-0.33.4.dist-info/licenses/LICENSE +34 -0
  6. scikit_network-0.33.4.dist-info/top_level.txt +1 -0
  7. scikit_network.libs/libgomp-a34b3233.so.1.0.0 +0 -0
  8. sknetwork/__init__.py +21 -0
  9. sknetwork/base.py +67 -0
  10. sknetwork/classification/__init__.py +8 -0
  11. sknetwork/classification/base.py +138 -0
  12. sknetwork/classification/base_rank.py +129 -0
  13. sknetwork/classification/diffusion.py +127 -0
  14. sknetwork/classification/knn.py +131 -0
  15. sknetwork/classification/metrics.py +205 -0
  16. sknetwork/classification/pagerank.py +58 -0
  17. sknetwork/classification/propagation.py +144 -0
  18. sknetwork/classification/tests/__init__.py +1 -0
  19. sknetwork/classification/tests/test_API.py +30 -0
  20. sknetwork/classification/tests/test_diffusion.py +77 -0
  21. sknetwork/classification/tests/test_knn.py +23 -0
  22. sknetwork/classification/tests/test_metrics.py +53 -0
  23. sknetwork/classification/tests/test_pagerank.py +20 -0
  24. sknetwork/classification/tests/test_propagation.py +24 -0
  25. sknetwork/classification/vote.cpp +27593 -0
  26. sknetwork/classification/vote.cpython-312-x86_64-linux-gnu.so +0 -0
  27. sknetwork/classification/vote.pyx +56 -0
  28. sknetwork/clustering/__init__.py +8 -0
  29. sknetwork/clustering/base.py +168 -0
  30. sknetwork/clustering/kcenters.py +251 -0
  31. sknetwork/clustering/leiden.py +238 -0
  32. sknetwork/clustering/leiden_core.cpp +31928 -0
  33. sknetwork/clustering/leiden_core.cpython-312-x86_64-linux-gnu.so +0 -0
  34. sknetwork/clustering/leiden_core.pyx +124 -0
  35. sknetwork/clustering/louvain.py +282 -0
  36. sknetwork/clustering/louvain_core.cpp +31573 -0
  37. sknetwork/clustering/louvain_core.cpython-312-x86_64-linux-gnu.so +0 -0
  38. sknetwork/clustering/louvain_core.pyx +124 -0
  39. sknetwork/clustering/metrics.py +91 -0
  40. sknetwork/clustering/postprocess.py +66 -0
  41. sknetwork/clustering/propagation_clustering.py +100 -0
  42. sknetwork/clustering/tests/__init__.py +1 -0
  43. sknetwork/clustering/tests/test_API.py +38 -0
  44. sknetwork/clustering/tests/test_kcenters.py +60 -0
  45. sknetwork/clustering/tests/test_leiden.py +34 -0
  46. sknetwork/clustering/tests/test_louvain.py +135 -0
  47. sknetwork/clustering/tests/test_metrics.py +50 -0
  48. sknetwork/clustering/tests/test_postprocess.py +39 -0
  49. sknetwork/data/__init__.py +6 -0
  50. sknetwork/data/base.py +33 -0
  51. sknetwork/data/load.py +292 -0
  52. sknetwork/data/models.py +459 -0
  53. sknetwork/data/parse.py +644 -0
  54. sknetwork/data/test_graphs.py +93 -0
  55. sknetwork/data/tests/__init__.py +1 -0
  56. sknetwork/data/tests/test_API.py +30 -0
  57. sknetwork/data/tests/test_base.py +14 -0
  58. sknetwork/data/tests/test_load.py +61 -0
  59. sknetwork/data/tests/test_models.py +52 -0
  60. sknetwork/data/tests/test_parse.py +250 -0
  61. sknetwork/data/tests/test_test_graphs.py +29 -0
  62. sknetwork/data/tests/test_toy_graphs.py +68 -0
  63. sknetwork/data/timeout.py +38 -0
  64. sknetwork/data/toy_graphs.py +611 -0
  65. sknetwork/embedding/__init__.py +8 -0
  66. sknetwork/embedding/base.py +90 -0
  67. sknetwork/embedding/force_atlas.py +198 -0
  68. sknetwork/embedding/louvain_embedding.py +142 -0
  69. sknetwork/embedding/random_projection.py +131 -0
  70. sknetwork/embedding/spectral.py +137 -0
  71. sknetwork/embedding/spring.py +198 -0
  72. sknetwork/embedding/svd.py +351 -0
  73. sknetwork/embedding/tests/__init__.py +1 -0
  74. sknetwork/embedding/tests/test_API.py +49 -0
  75. sknetwork/embedding/tests/test_force_atlas.py +35 -0
  76. sknetwork/embedding/tests/test_louvain_embedding.py +33 -0
  77. sknetwork/embedding/tests/test_random_projection.py +28 -0
  78. sknetwork/embedding/tests/test_spectral.py +81 -0
  79. sknetwork/embedding/tests/test_spring.py +50 -0
  80. sknetwork/embedding/tests/test_svd.py +43 -0
  81. sknetwork/gnn/__init__.py +10 -0
  82. sknetwork/gnn/activation.py +117 -0
  83. sknetwork/gnn/base.py +181 -0
  84. sknetwork/gnn/base_activation.py +90 -0
  85. sknetwork/gnn/base_layer.py +109 -0
  86. sknetwork/gnn/gnn_classifier.py +305 -0
  87. sknetwork/gnn/layer.py +153 -0
  88. sknetwork/gnn/loss.py +180 -0
  89. sknetwork/gnn/neighbor_sampler.py +65 -0
  90. sknetwork/gnn/optimizer.py +164 -0
  91. sknetwork/gnn/tests/__init__.py +1 -0
  92. sknetwork/gnn/tests/test_activation.py +56 -0
  93. sknetwork/gnn/tests/test_base.py +75 -0
  94. sknetwork/gnn/tests/test_base_layer.py +37 -0
  95. sknetwork/gnn/tests/test_gnn_classifier.py +130 -0
  96. sknetwork/gnn/tests/test_layers.py +80 -0
  97. sknetwork/gnn/tests/test_loss.py +33 -0
  98. sknetwork/gnn/tests/test_neigh_sampler.py +23 -0
  99. sknetwork/gnn/tests/test_optimizer.py +43 -0
  100. sknetwork/gnn/tests/test_utils.py +41 -0
  101. sknetwork/gnn/utils.py +127 -0
  102. sknetwork/hierarchy/__init__.py +6 -0
  103. sknetwork/hierarchy/base.py +90 -0
  104. sknetwork/hierarchy/louvain_hierarchy.py +260 -0
  105. sknetwork/hierarchy/metrics.py +234 -0
  106. sknetwork/hierarchy/paris.cpp +37877 -0
  107. sknetwork/hierarchy/paris.cpython-312-x86_64-linux-gnu.so +0 -0
  108. sknetwork/hierarchy/paris.pyx +310 -0
  109. sknetwork/hierarchy/postprocess.py +350 -0
  110. sknetwork/hierarchy/tests/__init__.py +1 -0
  111. sknetwork/hierarchy/tests/test_API.py +24 -0
  112. sknetwork/hierarchy/tests/test_algos.py +34 -0
  113. sknetwork/hierarchy/tests/test_metrics.py +62 -0
  114. sknetwork/hierarchy/tests/test_postprocess.py +57 -0
  115. sknetwork/linalg/__init__.py +9 -0
  116. sknetwork/linalg/basics.py +37 -0
  117. sknetwork/linalg/diteration.cpp +27409 -0
  118. sknetwork/linalg/diteration.cpython-312-x86_64-linux-gnu.so +0 -0
  119. sknetwork/linalg/diteration.pyx +47 -0
  120. sknetwork/linalg/eig_solver.py +93 -0
  121. sknetwork/linalg/laplacian.py +15 -0
  122. sknetwork/linalg/normalizer.py +86 -0
  123. sknetwork/linalg/operators.py +225 -0
  124. sknetwork/linalg/polynome.py +76 -0
  125. sknetwork/linalg/ppr_solver.py +170 -0
  126. sknetwork/linalg/push.cpp +31081 -0
  127. sknetwork/linalg/push.cpython-312-x86_64-linux-gnu.so +0 -0
  128. sknetwork/linalg/push.pyx +71 -0
  129. sknetwork/linalg/sparse_lowrank.py +142 -0
  130. sknetwork/linalg/svd_solver.py +91 -0
  131. sknetwork/linalg/tests/__init__.py +1 -0
  132. sknetwork/linalg/tests/test_eig.py +44 -0
  133. sknetwork/linalg/tests/test_laplacian.py +18 -0
  134. sknetwork/linalg/tests/test_normalization.py +34 -0
  135. sknetwork/linalg/tests/test_operators.py +66 -0
  136. sknetwork/linalg/tests/test_polynome.py +38 -0
  137. sknetwork/linalg/tests/test_ppr.py +50 -0
  138. sknetwork/linalg/tests/test_sparse_lowrank.py +61 -0
  139. sknetwork/linalg/tests/test_svd.py +38 -0
  140. sknetwork/linkpred/__init__.py +2 -0
  141. sknetwork/linkpred/base.py +46 -0
  142. sknetwork/linkpred/nn.py +126 -0
  143. sknetwork/linkpred/tests/__init__.py +1 -0
  144. sknetwork/linkpred/tests/test_nn.py +26 -0
  145. sknetwork/log.py +19 -0
  146. sknetwork/path/__init__.py +5 -0
  147. sknetwork/path/dag.py +54 -0
  148. sknetwork/path/distances.py +98 -0
  149. sknetwork/path/search.py +31 -0
  150. sknetwork/path/shortest_path.py +61 -0
  151. sknetwork/path/tests/__init__.py +1 -0
  152. sknetwork/path/tests/test_dag.py +37 -0
  153. sknetwork/path/tests/test_distances.py +62 -0
  154. sknetwork/path/tests/test_search.py +40 -0
  155. sknetwork/path/tests/test_shortest_path.py +40 -0
  156. sknetwork/ranking/__init__.py +8 -0
  157. sknetwork/ranking/base.py +57 -0
  158. sknetwork/ranking/betweenness.cpp +9716 -0
  159. sknetwork/ranking/betweenness.cpython-312-x86_64-linux-gnu.so +0 -0
  160. sknetwork/ranking/betweenness.pyx +97 -0
  161. sknetwork/ranking/closeness.py +92 -0
  162. sknetwork/ranking/hits.py +90 -0
  163. sknetwork/ranking/katz.py +79 -0
  164. sknetwork/ranking/pagerank.py +106 -0
  165. sknetwork/ranking/postprocess.py +37 -0
  166. sknetwork/ranking/tests/__init__.py +1 -0
  167. sknetwork/ranking/tests/test_API.py +32 -0
  168. sknetwork/ranking/tests/test_betweenness.py +38 -0
  169. sknetwork/ranking/tests/test_closeness.py +30 -0
  170. sknetwork/ranking/tests/test_hits.py +20 -0
  171. sknetwork/ranking/tests/test_pagerank.py +62 -0
  172. sknetwork/ranking/tests/test_postprocess.py +26 -0
  173. sknetwork/regression/__init__.py +4 -0
  174. sknetwork/regression/base.py +57 -0
  175. sknetwork/regression/diffusion.py +204 -0
  176. sknetwork/regression/tests/__init__.py +1 -0
  177. sknetwork/regression/tests/test_API.py +32 -0
  178. sknetwork/regression/tests/test_diffusion.py +56 -0
  179. sknetwork/sknetwork.py +3 -0
  180. sknetwork/test_base.py +35 -0
  181. sknetwork/test_log.py +15 -0
  182. sknetwork/topology/__init__.py +8 -0
  183. sknetwork/topology/cliques.cpp +32574 -0
  184. sknetwork/topology/cliques.cpython-312-x86_64-linux-gnu.so +0 -0
  185. sknetwork/topology/cliques.pyx +149 -0
  186. sknetwork/topology/core.cpp +30660 -0
  187. sknetwork/topology/core.cpython-312-x86_64-linux-gnu.so +0 -0
  188. sknetwork/topology/core.pyx +90 -0
  189. sknetwork/topology/cycles.py +243 -0
  190. sknetwork/topology/minheap.cpp +27341 -0
  191. sknetwork/topology/minheap.cpython-312-x86_64-linux-gnu.so +0 -0
  192. sknetwork/topology/minheap.pxd +20 -0
  193. sknetwork/topology/minheap.pyx +109 -0
  194. sknetwork/topology/structure.py +194 -0
  195. sknetwork/topology/tests/__init__.py +1 -0
  196. sknetwork/topology/tests/test_cliques.py +28 -0
  197. sknetwork/topology/tests/test_core.py +19 -0
  198. sknetwork/topology/tests/test_cycles.py +65 -0
  199. sknetwork/topology/tests/test_structure.py +85 -0
  200. sknetwork/topology/tests/test_triangles.py +38 -0
  201. sknetwork/topology/tests/test_wl.py +72 -0
  202. sknetwork/topology/triangles.cpp +8903 -0
  203. sknetwork/topology/triangles.cpython-312-x86_64-linux-gnu.so +0 -0
  204. sknetwork/topology/triangles.pyx +151 -0
  205. sknetwork/topology/weisfeiler_lehman.py +133 -0
  206. sknetwork/topology/weisfeiler_lehman_core.cpp +27644 -0
  207. sknetwork/topology/weisfeiler_lehman_core.cpython-312-x86_64-linux-gnu.so +0 -0
  208. sknetwork/topology/weisfeiler_lehman_core.pyx +114 -0
  209. sknetwork/utils/__init__.py +7 -0
  210. sknetwork/utils/check.py +355 -0
  211. sknetwork/utils/format.py +221 -0
  212. sknetwork/utils/membership.py +82 -0
  213. sknetwork/utils/neighbors.py +115 -0
  214. sknetwork/utils/tests/__init__.py +1 -0
  215. sknetwork/utils/tests/test_check.py +190 -0
  216. sknetwork/utils/tests/test_format.py +63 -0
  217. sknetwork/utils/tests/test_membership.py +24 -0
  218. sknetwork/utils/tests/test_neighbors.py +41 -0
  219. sknetwork/utils/tests/test_tfidf.py +18 -0
  220. sknetwork/utils/tests/test_values.py +66 -0
  221. sknetwork/utils/tfidf.py +37 -0
  222. sknetwork/utils/values.py +76 -0
  223. sknetwork/visualization/__init__.py +4 -0
  224. sknetwork/visualization/colors.py +34 -0
  225. sknetwork/visualization/dendrograms.py +277 -0
  226. sknetwork/visualization/graphs.py +1039 -0
  227. sknetwork/visualization/tests/__init__.py +1 -0
  228. sknetwork/visualization/tests/test_dendrograms.py +53 -0
  229. sknetwork/visualization/tests/test_graphs.py +176 -0
@@ -0,0 +1,56 @@
1
+ # distutils: language = c++
2
+ # cython: language_level=3
3
+ """
4
+ Created in April 2020
5
+ @author: Nathan de Lara <nathan.delara@polytechnique.org>
6
+ """
7
+ from libcpp.set cimport set
8
+ from libcpp.vector cimport vector
9
+
10
+ cimport cython
11
+
12
+
13
@cython.boundscheck(False)
@cython.wraparound(False)
def vote_update(int[:] indptr, int[:] indices, float[:] data, int[:] labels, int[:] index):
    """One pass of label updates over the graph by majority vote among neighbors.

    Parameters
    ----------
    indptr, indices, data :
        Arrays of the adjacency matrix in CSR layout: the neighbors of node ``i`` are
        ``indices[indptr[i]:indptr[i + 1]]`` and ``data[j]`` is the weight of the edge stored at position ``j``.
    labels :
        Current label of each node. Negative labels do not take part in the vote.
        Updated in place.
    index :
        Nodes to update, in processing order.

    Returns
    -------
    labels :
        The updated labels (same buffer as the input).
    """
    cdef int i
    cdef int ii
    cdef int j
    cdef int jj
    cdef int n_indices = index.shape[0]
    cdef int label
    cdef int label_neigh_size
    cdef float best_score

    cdef vector[int] labels_neigh
    cdef vector[float] votes_neigh, votes
    cdef set[int] labels_unique = ()

    # One vote accumulator per possible label, reset to 0 after each node.
    cdef int n = labels.shape[0]
    for i in range(n):
        votes.push_back(0)

    for ii in range(n_indices):
        i = index[ii]
        # Both per-node buffers must be reset together: they are read in lockstep below.
        labels_neigh.clear()
        votes_neigh.clear()
        for j in range(indptr[i], indptr[i + 1]):
            jj = indices[j]
            labels_neigh.push_back(labels[jj])
            # The weight belongs to edge position j, not to the neighbor id jj.
            votes_neigh.push_back(data[j])

        labels_unique.clear()
        label_neigh_size = labels_neigh.size()
        for jj in range(label_neigh_size):
            label = labels_neigh[jj]
            if label >= 0:
                labels_unique.insert(label)
                votes[label] += votes_neigh[jj]

        # Strict comparison: on ties the first label met while iterating the set is kept.
        # If no neighbor carries a non-negative label, labels[i] is left untouched.
        best_score = -1
        for label in labels_unique:
            if votes[label] > best_score:
                labels[i] = label
                best_score = votes[label]
            # Reset only the touched accumulators for the next node.
            votes[label] = 0
    return labels
@@ -0,0 +1,8 @@
1
+ """clustering module"""
2
+ from sknetwork.clustering.base import BaseClustering
3
+ from sknetwork.clustering.louvain import Louvain
4
+ from sknetwork.clustering.leiden import Leiden
5
+ from sknetwork.clustering.propagation_clustering import PropagationClustering
6
+ from sknetwork.clustering.metrics import get_modularity
7
+ from sknetwork.clustering.postprocess import reindex_labels, aggregate_graph
8
+ from sknetwork.clustering.kcenters import KCenters
@@ -0,0 +1,168 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on Nov, 2019
5
+ @author: Nathan de Lara <nathan.delara@polytechnique.org>
6
+ """
7
+ from abc import ABC
8
+
9
+ import numpy as np
10
+ from scipy import sparse
11
+
12
+ from sknetwork.linalg.normalizer import normalize
13
+ from sknetwork.base import Algorithm
14
+ from sknetwork.utils.membership import get_membership
15
+
16
+
17
class BaseClustering(Algorithm, ABC):
    r"""Base class for clustering algorithms.

    Parameters
    ----------
    sort_clusters : bool, default True
        Stored on the instance; not read by this base class.
    return_probs : bool, default False
        If ``True``, ``_secondary_outputs`` computes the probability distribution over labels.
    return_aggregate : bool, default False
        If ``True``, ``_secondary_outputs`` computes the aggregate matrix between clusters.

    Attributes
    ----------
    labels\_ : np.ndarray, shape (n_nodes,)
        Label of each node.
    probs\_ : sparse.csr_matrix, shape (n_nodes, n_labels)
        Probability distribution over labels.
    aggregate\_ : sparse.csr_matrix
        Aggregate adjacency matrix or biadjacency matrix between clusters.
    """
    def __init__(self, sort_clusters: bool = True, return_probs: bool = False, return_aggregate: bool = False):
        self.sort_clusters = sort_clusters
        self.return_probs = return_probs
        self.return_aggregate = return_aggregate
        self._init_vars()

    def predict(self, columns=False) -> np.ndarray:
        """Return the labels predicted by the algorithm.

        Parameters
        ----------
        columns : bool
            If ``True``, return the prediction for columns.

        Returns
        -------
        labels : np.ndarray
            Labels.
        """
        return self.labels_col_ if columns else self.labels_

    def fit_predict(self, *args, **kwargs) -> np.ndarray:
        """Fit algorithm to the data and return the labels. Same parameters as the ``fit`` method.

        Returns
        -------
        labels : np.ndarray
            Labels.
        """
        self.fit(*args, **kwargs)
        return self.predict()

    def predict_proba(self, columns=False) -> np.ndarray:
        """Return the probability distribution over labels as predicted by the algorithm.

        Parameters
        ----------
        columns : bool
            If ``True``, return the prediction for columns.

        Returns
        -------
        probs : np.ndarray
            Probability distribution over labels (dense copy of the sparse matrix).
        """
        probs = self.probs_col_ if columns else self.probs_
        return probs.toarray()

    def fit_predict_proba(self, *args, **kwargs) -> np.ndarray:
        """Fit algorithm to the data and return the probability distribution over labels.
        Same parameters as the ``fit`` method.

        Returns
        -------
        probs : np.ndarray
            Probability of each label.
        """
        self.fit(*args, **kwargs)
        return self.predict_proba()

    def transform(self, columns=False) -> sparse.csr_matrix:
        """Return the probability distribution over labels in sparse format.

        Parameters
        ----------
        columns : bool
            If ``True``, return the prediction for columns.

        Returns
        -------
        probs : sparse.csr_matrix
            Probability distribution over labels.
        """
        return self.probs_col_ if columns else self.probs_

    def fit_transform(self, *args, **kwargs) -> sparse.csr_matrix:
        """Fit algorithm to the data and return the membership matrix. Same parameters as the ``fit`` method.

        Returns
        -------
        membership : sparse.csr_matrix
            Membership matrix (distribution over clusters), as returned by ``transform``.
        """
        self.fit(*args, **kwargs)
        return self.transform()

    def _init_vars(self):
        """Reset all fitted attributes (and the ``bipartite`` flag) to ``None``."""
        self.labels_ = None
        self.labels_row_ = None
        self.labels_col_ = None
        self.probs_ = None
        self.probs_row_ = None
        self.probs_col_ = None
        self.aggregate_ = None
        self.bipartite = None
        return self

    def _split_vars(self, shape):
        """Split labels_ into labels_row_ and labels_col_.

        The first ``shape[0]`` entries become the row labels, the remaining ones the column labels;
        ``labels_`` is then rebound to the row labels.
        """
        n_row = shape[0]
        self.labels_row_ = self.labels_[:n_row]
        self.labels_col_ = self.labels_[n_row:]
        self.labels_ = self.labels_row_
        return self

    def _secondary_outputs(self, input_matrix: sparse.csr_matrix):
        """Compute probabilities and/or the aggregate matrix from the labels.

        Nothing is computed unless ``return_probs`` or ``return_aggregate`` is set.
        """
        if not (self.return_probs or self.return_aggregate):
            return self

        input_matrix = input_matrix.astype(float)

        if not self.bipartite:
            probs = get_membership(self.labels_)
            if self.return_probs:
                self.probs_ = normalize(input_matrix.dot(probs))
            if self.return_aggregate:
                self.aggregate_ = sparse.csr_matrix(probs.T.dot(input_matrix.dot(probs)))
            return self

        if self.labels_col_ is None:
            # Only row labels are known: derive the column membership from the rows.
            n_labels = np.max(self.labels_) + 1
            probs_row = get_membership(self.labels_, n_labels=n_labels)
            probs_col = normalize(input_matrix.T.dot(probs_row))
        else:
            # Rows and columns must share the same number of label columns.
            n_labels = max(np.max(self.labels_row_), np.max(self.labels_col_)) + 1
            probs_row = get_membership(self.labels_row_, n_labels=n_labels)
            probs_col = get_membership(self.labels_col_, n_labels=n_labels)

        if self.return_probs:
            self.probs_row_ = normalize(input_matrix.dot(probs_col))
            self.probs_col_ = normalize(input_matrix.T.dot(probs_row))
            self.probs_ = self.probs_row_
        if self.return_aggregate:
            aggregate_ = sparse.csr_matrix(probs_row.T.dot(input_matrix))
            self.aggregate_ = aggregate_.dot(probs_col)

        return self
@@ -0,0 +1,251 @@
1
+ """
2
+ Created in March 2024
3
+ @author: Laurène David <laurene.david@ip-paris.fr>
4
+ @author: Thomas Bonald <bonald@enst.fr>
5
+ """
6
+
7
+ from typing import Union
8
+
9
+ import numpy as np
10
+ from scipy import sparse
11
+
12
+ from sknetwork.clustering import BaseClustering
13
+ from sknetwork.ranking import PageRank
14
+ from sknetwork.clustering import get_modularity
15
+ from sknetwork.classification.pagerank import PageRankClassifier
16
+ from sknetwork.utils.format import get_adjacency, directed2undirected
17
+
18
+
19
+ class KCenters(BaseClustering):
20
+ """K-center clustering algorithm. The center of each cluster is obtained by the PageRank algorithm.
21
+
22
+ Parameters
23
+ ----------
24
+ n_clusters : int
25
+ Number of clusters.
26
+ directed : bool, default False
27
+ If ``True``, the graph is considered directed.
28
+ center_position : str, default "row"
29
+ Force centers to correspond to the nodes on the rows or columns of the biadjacency matrix.
30
+ Can be ``row``, ``col`` or ``both``. Only considered for bipartite graphs.
31
+ n_init : int, default 5
32
+ Number of reruns of the k-centers algorithm with different centers.
33
+ The run that produce the best modularity is chosen as the final result.
34
+ max_iter : int, default 20
35
+ Maximum number of iterations of the k-centers algorithm for a single run.
36
+
37
+ Attributes
38
+ ----------
39
+ labels\_ : np.ndarray, shape (n_nodes,)
40
+ Label of each node.
41
+ probs\_ : sparse.csr_matrix, shape (n_nodes, n_labels)
42
+ Probability distribution over labels.
43
+ aggregate\_ : sparse.csr_matrix
44
+ Aggregate adjacency matrix or biadjacency matrix between clusters.
45
+
46
+ Example
47
+ -------
48
+ >>> from sknetwork.clustering import KCenters
49
+ >>> from sknetwork.data import karate_club
50
+ >>> kcenters = KCenters(n_clusters=2)
51
+ >>> adjacency = karate_club()
52
+ >>> labels = kcenters.fit_predict(adjacency)
53
+ >>> len(set(labels))
54
+ 2
55
+
56
+ """
57
+ def __init__(self, n_clusters: int, directed: bool = False, center_position: str = "row", n_init: int = 5,
58
+ max_iter: int = 20):
59
+ super(BaseClustering, self).__init__()
60
+ self.n_clusters = n_clusters
61
+ self.directed = directed
62
+ self.bipartite = None
63
+ self.center_position = center_position
64
+ self.n_init = n_init
65
+ self.max_iter = max_iter
66
+ self.labels_ = None
67
+ self.centers_ = None
68
+ self.centers_row_ = None
69
+ self.centers_col_ = None
70
+
71
+ def _compute_mask_centers(self, input_matrix: Union[sparse.csr_matrix, np.ndarray]):
72
+ """Generate mask to filter nodes that can be cluster centers.
73
+
74
+ Parameters
75
+ ----------
76
+ input_matrix :
77
+ Adjacency matrix or biadjacency matrix of the graph.
78
+
79
+ Return
80
+ ------
81
+ mask : np.array, shape (n_nodes,)
82
+ Mask for possible cluster centers.
83
+
84
+ """
85
+ n_row, n_col = input_matrix.shape
86
+ if self.bipartite:
87
+ n_nodes = n_row + n_col
88
+ mask = np.zeros(n_nodes, dtype=bool)
89
+ if self.center_position == "row":
90
+ mask[:n_row] = True
91
+ elif self.center_position == "col":
92
+ mask[n_row:] = True
93
+ elif self.center_position == "both":
94
+ mask[:] = True
95
+ else:
96
+ raise ValueError('Unknown center position')
97
+ else:
98
+ mask = np.ones(n_row, dtype=bool)
99
+
100
+ return mask
101
+
102
+ @staticmethod
103
+ def _init_centers(adjacency: Union[sparse.csr_matrix, np.ndarray], mask: np.ndarray, n_clusters: int):
104
+ """
105
+ Kcenters++ initialization to select cluster centers.
106
+ This algorithm is an adaptation of the Kmeans++ algorithm to graphs.
107
+
108
+ Parameters
109
+ ----------
110
+ adjacency :
111
+ Adjacency matrix of the graph.
112
+ mask :
113
+ Initial mask for allowed positions of centers.
114
+ n_clusters : int
115
+ Number of centers to initialize.
116
+
117
+ Returns
118
+ ---------
119
+ centers : np.array, shape (n_clusters,)
120
+ Initial cluster centers.
121
+ """
122
+ mask = mask.copy()
123
+ n_nodes = adjacency.shape[0]
124
+ nodes = np.arange(n_nodes)
125
+ centers = []
126
+
127
+ # Choose the first center uniformly at random
128
+ center = np.random.choice(nodes[mask])
129
+ mask[center] = 0
130
+ centers.append(center)
131
+
132
+ pagerank = PageRank()
133
+ weights = {center: 1}
134
+
135
+ for k in range(n_clusters - 1):
136
+ # select nodes that are far from existing centers
137
+ ppr_scores = pagerank.fit_predict(adjacency, weights)
138
+ ppr_scores = ppr_scores[mask]
139
+
140
+ if min(ppr_scores) == 0:
141
+ center = np.random.choice(nodes[mask][ppr_scores == 0])
142
+ else:
143
+ probs = 1 / ppr_scores
144
+ probs = probs / np.sum(probs)
145
+ center = np.random.choice(nodes[mask], p=probs)
146
+
147
+ mask[center] = 0
148
+ centers.append(center)
149
+ weights.update({center: 1})
150
+
151
+ centers = np.array(centers)
152
+ return centers
153
+
154
+ def fit(self, input_matrix: Union[sparse.csr_matrix, np.ndarray], force_bipartite: bool = False) -> "KCenters":
155
+ """Compute the clustering of the graph by k-centers.
156
+
157
+ Parameters
158
+ ----------
159
+ input_matrix :
160
+ Adjacency matrix or biadjacency matrix of the graph.
161
+ force_bipartite :
162
+ If ``True``, force the input matrix to be considered as a biadjacency matrix even if square.
163
+
164
+ Returns
165
+ -------
166
+ self : :class:`KCenters`
167
+ """
168
+
169
+ if self.n_clusters < 2:
170
+ raise ValueError("The number of clusters must be at least 2.")
171
+
172
+ if self.n_init < 1:
173
+ raise ValueError("The n_init parameter must be at least 1.")
174
+
175
+ if self.directed:
176
+ input_matrix = directed2undirected(input_matrix)
177
+
178
+ adjacency, self.bipartite = get_adjacency(input_matrix, force_bipartite=force_bipartite)
179
+ n_row = input_matrix.shape[0]
180
+ n_nodes = adjacency.shape[0]
181
+ nodes = np.arange(n_nodes)
182
+
183
+ mask = self._compute_mask_centers(input_matrix)
184
+ if self.n_clusters > np.sum(mask):
185
+ raise ValueError("The number of clusters is to high. This might be due to the center_position parameter.")
186
+
187
+ pagerank_clf = PageRankClassifier()
188
+ pagerank = PageRank()
189
+
190
+ labels_ = []
191
+ centers_ = []
192
+ modularity_ = []
193
+
194
+ # Restarts
195
+ for i in range(self.n_init):
196
+
197
+ # Initialization
198
+ centers = self._init_centers(adjacency, mask, self.n_clusters)
199
+ prev_centers = None
200
+ labels = None
201
+ n_iter = 0
202
+
203
+ while not np.equal(prev_centers, centers).all() and (n_iter < self.max_iter):
204
+
205
+ # Assign nodes to centers
206
+ labels_center = {center: label for label, center in enumerate(centers)}
207
+ labels = pagerank_clf.fit_predict(adjacency, labels_center)
208
+
209
+ # Find new centers
210
+ prev_centers = centers.copy()
211
+ new_centers = []
212
+
213
+ for label in np.unique(labels):
214
+ mask_cluster = labels == label
215
+ mask_cluster &= mask
216
+ scores = pagerank.fit_predict(adjacency, weights=mask_cluster)
217
+ scores[~mask_cluster] = 0
218
+ new_centers.append(nodes[np.argmax(scores)])
219
+
220
+ n_iter += 1
221
+
222
+ # Store results
223
+ if self.bipartite:
224
+ labels_row = labels[:n_row]
225
+ labels_col = labels[n_row:]
226
+ modularity = get_modularity(input_matrix, labels_row, labels_col)
227
+ else:
228
+ modularity = get_modularity(adjacency, labels)
229
+
230
+ labels_.append(labels)
231
+ centers_.append(centers)
232
+ modularity_.append(modularity)
233
+
234
+ # Select restart with the highest modularity
235
+ idx_max = np.argmax(modularity_)
236
+ self.labels_ = np.array(labels_[idx_max])
237
+ self.centers_ = np.array(centers_[idx_max])
238
+
239
+ if self.bipartite:
240
+ self._split_vars(input_matrix.shape)
241
+
242
+ # Define centers based on center position
243
+ if self.center_position == "row":
244
+ self.centers_row_ = self.centers_
245
+ elif self.center_position == "col":
246
+ self.centers_col_ = self.centers_ - n_row
247
+ else:
248
+ self.centers_row_ = self.centers_[self.centers_ < n_row]
249
+ self.centers_col_ = self.centers_[~np.isin(self.centers_, self.centers_row_)] - n_row
250
+
251
+ return self