scikit-network 0.33.0 (cp312-cp312-win_amd64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-network might be problematic; see the registry page for this release for more details.

Files changed (228)
  1. scikit_network-0.33.0.dist-info/AUTHORS.rst +43 -0
  2. scikit_network-0.33.0.dist-info/LICENSE +34 -0
  3. scikit_network-0.33.0.dist-info/METADATA +517 -0
  4. scikit_network-0.33.0.dist-info/RECORD +228 -0
  5. scikit_network-0.33.0.dist-info/WHEEL +5 -0
  6. scikit_network-0.33.0.dist-info/top_level.txt +1 -0
  7. sknetwork/__init__.py +21 -0
  8. sknetwork/base.py +67 -0
  9. sknetwork/classification/__init__.py +8 -0
  10. sknetwork/classification/base.py +142 -0
  11. sknetwork/classification/base_rank.py +133 -0
  12. sknetwork/classification/diffusion.py +134 -0
  13. sknetwork/classification/knn.py +139 -0
  14. sknetwork/classification/metrics.py +205 -0
  15. sknetwork/classification/pagerank.py +66 -0
  16. sknetwork/classification/propagation.py +152 -0
  17. sknetwork/classification/tests/__init__.py +1 -0
  18. sknetwork/classification/tests/test_API.py +30 -0
  19. sknetwork/classification/tests/test_diffusion.py +77 -0
  20. sknetwork/classification/tests/test_knn.py +23 -0
  21. sknetwork/classification/tests/test_metrics.py +53 -0
  22. sknetwork/classification/tests/test_pagerank.py +20 -0
  23. sknetwork/classification/tests/test_propagation.py +24 -0
  24. sknetwork/classification/vote.cp312-win_amd64.pyd +0 -0
  25. sknetwork/classification/vote.cpp +27577 -0
  26. sknetwork/classification/vote.pyx +56 -0
  27. sknetwork/clustering/__init__.py +8 -0
  28. sknetwork/clustering/base.py +172 -0
  29. sknetwork/clustering/kcenters.py +253 -0
  30. sknetwork/clustering/leiden.py +242 -0
  31. sknetwork/clustering/leiden_core.cp312-win_amd64.pyd +0 -0
  32. sknetwork/clustering/leiden_core.cpp +31564 -0
  33. sknetwork/clustering/leiden_core.pyx +124 -0
  34. sknetwork/clustering/louvain.py +286 -0
  35. sknetwork/clustering/louvain_core.cp312-win_amd64.pyd +0 -0
  36. sknetwork/clustering/louvain_core.cpp +31209 -0
  37. sknetwork/clustering/louvain_core.pyx +124 -0
  38. sknetwork/clustering/metrics.py +91 -0
  39. sknetwork/clustering/postprocess.py +66 -0
  40. sknetwork/clustering/propagation_clustering.py +104 -0
  41. sknetwork/clustering/tests/__init__.py +1 -0
  42. sknetwork/clustering/tests/test_API.py +38 -0
  43. sknetwork/clustering/tests/test_kcenters.py +60 -0
  44. sknetwork/clustering/tests/test_leiden.py +34 -0
  45. sknetwork/clustering/tests/test_louvain.py +129 -0
  46. sknetwork/clustering/tests/test_metrics.py +50 -0
  47. sknetwork/clustering/tests/test_postprocess.py +39 -0
  48. sknetwork/data/__init__.py +6 -0
  49. sknetwork/data/base.py +33 -0
  50. sknetwork/data/load.py +406 -0
  51. sknetwork/data/models.py +459 -0
  52. sknetwork/data/parse.py +644 -0
  53. sknetwork/data/test_graphs.py +84 -0
  54. sknetwork/data/tests/__init__.py +1 -0
  55. sknetwork/data/tests/test_API.py +30 -0
  56. sknetwork/data/tests/test_base.py +14 -0
  57. sknetwork/data/tests/test_load.py +95 -0
  58. sknetwork/data/tests/test_models.py +52 -0
  59. sknetwork/data/tests/test_parse.py +250 -0
  60. sknetwork/data/tests/test_test_graphs.py +29 -0
  61. sknetwork/data/tests/test_toy_graphs.py +68 -0
  62. sknetwork/data/timeout.py +38 -0
  63. sknetwork/data/toy_graphs.py +611 -0
  64. sknetwork/embedding/__init__.py +8 -0
  65. sknetwork/embedding/base.py +94 -0
  66. sknetwork/embedding/force_atlas.py +198 -0
  67. sknetwork/embedding/louvain_embedding.py +148 -0
  68. sknetwork/embedding/random_projection.py +135 -0
  69. sknetwork/embedding/spectral.py +141 -0
  70. sknetwork/embedding/spring.py +198 -0
  71. sknetwork/embedding/svd.py +359 -0
  72. sknetwork/embedding/tests/__init__.py +1 -0
  73. sknetwork/embedding/tests/test_API.py +49 -0
  74. sknetwork/embedding/tests/test_force_atlas.py +35 -0
  75. sknetwork/embedding/tests/test_louvain_embedding.py +33 -0
  76. sknetwork/embedding/tests/test_random_projection.py +28 -0
  77. sknetwork/embedding/tests/test_spectral.py +81 -0
  78. sknetwork/embedding/tests/test_spring.py +50 -0
  79. sknetwork/embedding/tests/test_svd.py +43 -0
  80. sknetwork/gnn/__init__.py +10 -0
  81. sknetwork/gnn/activation.py +117 -0
  82. sknetwork/gnn/base.py +181 -0
  83. sknetwork/gnn/base_activation.py +89 -0
  84. sknetwork/gnn/base_layer.py +109 -0
  85. sknetwork/gnn/gnn_classifier.py +305 -0
  86. sknetwork/gnn/layer.py +153 -0
  87. sknetwork/gnn/loss.py +180 -0
  88. sknetwork/gnn/neighbor_sampler.py +65 -0
  89. sknetwork/gnn/optimizer.py +164 -0
  90. sknetwork/gnn/tests/__init__.py +1 -0
  91. sknetwork/gnn/tests/test_activation.py +56 -0
  92. sknetwork/gnn/tests/test_base.py +75 -0
  93. sknetwork/gnn/tests/test_base_layer.py +37 -0
  94. sknetwork/gnn/tests/test_gnn_classifier.py +130 -0
  95. sknetwork/gnn/tests/test_layers.py +80 -0
  96. sknetwork/gnn/tests/test_loss.py +33 -0
  97. sknetwork/gnn/tests/test_neigh_sampler.py +23 -0
  98. sknetwork/gnn/tests/test_optimizer.py +43 -0
  99. sknetwork/gnn/tests/test_utils.py +41 -0
  100. sknetwork/gnn/utils.py +127 -0
  101. sknetwork/hierarchy/__init__.py +6 -0
  102. sknetwork/hierarchy/base.py +96 -0
  103. sknetwork/hierarchy/louvain_hierarchy.py +272 -0
  104. sknetwork/hierarchy/metrics.py +234 -0
  105. sknetwork/hierarchy/paris.cp312-win_amd64.pyd +0 -0
  106. sknetwork/hierarchy/paris.cpp +37264 -0
  107. sknetwork/hierarchy/paris.pyx +316 -0
  108. sknetwork/hierarchy/postprocess.py +350 -0
  109. sknetwork/hierarchy/tests/__init__.py +1 -0
  110. sknetwork/hierarchy/tests/test_API.py +24 -0
  111. sknetwork/hierarchy/tests/test_algos.py +34 -0
  112. sknetwork/hierarchy/tests/test_metrics.py +62 -0
  113. sknetwork/hierarchy/tests/test_postprocess.py +57 -0
  114. sknetwork/linalg/__init__.py +9 -0
  115. sknetwork/linalg/basics.py +37 -0
  116. sknetwork/linalg/diteration.cp312-win_amd64.pyd +0 -0
  117. sknetwork/linalg/diteration.cpp +27393 -0
  118. sknetwork/linalg/diteration.pyx +47 -0
  119. sknetwork/linalg/eig_solver.py +93 -0
  120. sknetwork/linalg/laplacian.py +15 -0
  121. sknetwork/linalg/normalizer.py +86 -0
  122. sknetwork/linalg/operators.py +225 -0
  123. sknetwork/linalg/polynome.py +76 -0
  124. sknetwork/linalg/ppr_solver.py +170 -0
  125. sknetwork/linalg/push.cp312-win_amd64.pyd +0 -0
  126. sknetwork/linalg/push.cpp +30474 -0
  127. sknetwork/linalg/push.pyx +71 -0
  128. sknetwork/linalg/sparse_lowrank.py +142 -0
  129. sknetwork/linalg/svd_solver.py +91 -0
  130. sknetwork/linalg/tests/__init__.py +1 -0
  131. sknetwork/linalg/tests/test_eig.py +44 -0
  132. sknetwork/linalg/tests/test_laplacian.py +18 -0
  133. sknetwork/linalg/tests/test_normalization.py +34 -0
  134. sknetwork/linalg/tests/test_operators.py +66 -0
  135. sknetwork/linalg/tests/test_polynome.py +38 -0
  136. sknetwork/linalg/tests/test_ppr.py +50 -0
  137. sknetwork/linalg/tests/test_sparse_lowrank.py +61 -0
  138. sknetwork/linalg/tests/test_svd.py +38 -0
  139. sknetwork/linkpred/__init__.py +2 -0
  140. sknetwork/linkpred/base.py +46 -0
  141. sknetwork/linkpred/nn.py +126 -0
  142. sknetwork/linkpred/tests/__init__.py +1 -0
  143. sknetwork/linkpred/tests/test_nn.py +27 -0
  144. sknetwork/log.py +19 -0
  145. sknetwork/path/__init__.py +5 -0
  146. sknetwork/path/dag.py +54 -0
  147. sknetwork/path/distances.py +98 -0
  148. sknetwork/path/search.py +31 -0
  149. sknetwork/path/shortest_path.py +61 -0
  150. sknetwork/path/tests/__init__.py +1 -0
  151. sknetwork/path/tests/test_dag.py +37 -0
  152. sknetwork/path/tests/test_distances.py +62 -0
  153. sknetwork/path/tests/test_search.py +40 -0
  154. sknetwork/path/tests/test_shortest_path.py +40 -0
  155. sknetwork/ranking/__init__.py +8 -0
  156. sknetwork/ranking/base.py +61 -0
  157. sknetwork/ranking/betweenness.cp312-win_amd64.pyd +0 -0
  158. sknetwork/ranking/betweenness.cpp +9701 -0
  159. sknetwork/ranking/betweenness.pyx +97 -0
  160. sknetwork/ranking/closeness.py +92 -0
  161. sknetwork/ranking/hits.py +94 -0
  162. sknetwork/ranking/katz.py +83 -0
  163. sknetwork/ranking/pagerank.py +110 -0
  164. sknetwork/ranking/postprocess.py +37 -0
  165. sknetwork/ranking/tests/__init__.py +1 -0
  166. sknetwork/ranking/tests/test_API.py +32 -0
  167. sknetwork/ranking/tests/test_betweenness.py +38 -0
  168. sknetwork/ranking/tests/test_closeness.py +30 -0
  169. sknetwork/ranking/tests/test_hits.py +20 -0
  170. sknetwork/ranking/tests/test_pagerank.py +62 -0
  171. sknetwork/ranking/tests/test_postprocess.py +26 -0
  172. sknetwork/regression/__init__.py +4 -0
  173. sknetwork/regression/base.py +61 -0
  174. sknetwork/regression/diffusion.py +210 -0
  175. sknetwork/regression/tests/__init__.py +1 -0
  176. sknetwork/regression/tests/test_API.py +32 -0
  177. sknetwork/regression/tests/test_diffusion.py +56 -0
  178. sknetwork/sknetwork.py +3 -0
  179. sknetwork/test_base.py +35 -0
  180. sknetwork/test_log.py +15 -0
  181. sknetwork/topology/__init__.py +8 -0
  182. sknetwork/topology/cliques.cp312-win_amd64.pyd +0 -0
  183. sknetwork/topology/cliques.cpp +31964 -0
  184. sknetwork/topology/cliques.pyx +149 -0
  185. sknetwork/topology/core.cp312-win_amd64.pyd +0 -0
  186. sknetwork/topology/core.cpp +30053 -0
  187. sknetwork/topology/core.pyx +90 -0
  188. sknetwork/topology/cycles.py +243 -0
  189. sknetwork/topology/minheap.cp312-win_amd64.pyd +0 -0
  190. sknetwork/topology/minheap.cpp +27322 -0
  191. sknetwork/topology/minheap.pxd +20 -0
  192. sknetwork/topology/minheap.pyx +109 -0
  193. sknetwork/topology/structure.py +194 -0
  194. sknetwork/topology/tests/__init__.py +1 -0
  195. sknetwork/topology/tests/test_cliques.py +28 -0
  196. sknetwork/topology/tests/test_core.py +19 -0
  197. sknetwork/topology/tests/test_cycles.py +65 -0
  198. sknetwork/topology/tests/test_structure.py +85 -0
  199. sknetwork/topology/tests/test_triangles.py +38 -0
  200. sknetwork/topology/tests/test_wl.py +72 -0
  201. sknetwork/topology/triangles.cp312-win_amd64.pyd +0 -0
  202. sknetwork/topology/triangles.cpp +8889 -0
  203. sknetwork/topology/triangles.pyx +151 -0
  204. sknetwork/topology/weisfeiler_lehman.py +133 -0
  205. sknetwork/topology/weisfeiler_lehman_core.cp312-win_amd64.pyd +0 -0
  206. sknetwork/topology/weisfeiler_lehman_core.cpp +27628 -0
  207. sknetwork/topology/weisfeiler_lehman_core.pyx +114 -0
  208. sknetwork/utils/__init__.py +7 -0
  209. sknetwork/utils/check.py +355 -0
  210. sknetwork/utils/format.py +221 -0
  211. sknetwork/utils/membership.py +82 -0
  212. sknetwork/utils/neighbors.py +115 -0
  213. sknetwork/utils/tests/__init__.py +1 -0
  214. sknetwork/utils/tests/test_check.py +190 -0
  215. sknetwork/utils/tests/test_format.py +63 -0
  216. sknetwork/utils/tests/test_membership.py +24 -0
  217. sknetwork/utils/tests/test_neighbors.py +41 -0
  218. sknetwork/utils/tests/test_tfidf.py +18 -0
  219. sknetwork/utils/tests/test_values.py +66 -0
  220. sknetwork/utils/tfidf.py +37 -0
  221. sknetwork/utils/values.py +76 -0
  222. sknetwork/visualization/__init__.py +4 -0
  223. sknetwork/visualization/colors.py +34 -0
  224. sknetwork/visualization/dendrograms.py +277 -0
  225. sknetwork/visualization/graphs.py +1039 -0
  226. sknetwork/visualization/tests/__init__.py +1 -0
  227. sknetwork/visualization/tests/test_dendrograms.py +53 -0
  228. sknetwork/visualization/tests/test_graphs.py +176 -0
@@ -0,0 +1 @@
1
+ """tests for classification"""
@@ -0,0 +1,30 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """Tests for classification API"""
4
+
5
+ import unittest
6
+
7
+ from sknetwork.classification import *
8
+ from sknetwork.data.test_graphs import *
9
+ from sknetwork.embedding import LouvainEmbedding
10
+
11
+
class TestClassificationAPI(unittest.TestCase):
    """Checks that every classifier exposes a consistent fit/predict/transform API."""

    def test_undirected(self):
        """Seeds given as an array or as a dict must yield identical predictions.

        Despite its name, this test runs on both an undirected graph and a directed graph.
        """
        for adjacency in (test_graph(), test_digraph()):
            n_nodes = adjacency.shape[0]
            # Array form of the seeds: -1 everywhere except nodes 0 and 1, which carry
            # labels 0 and 1. The assertions below require it to match the dict form.
            seeds_array = np.full(n_nodes, -1.)
            seeds_array[:2] = [0, 1]
            seeds_dict = {0: 0, 1: 1}

            classifiers = [
                PageRankClassifier(),
                DiffusionClassifier(),
                NNClassifier(embedding_method=LouvainEmbedding(), n_neighbors=1),
                Propagation(),
            ]
            for classifier in classifiers:
                labels_from_array = classifier.fit_predict(adjacency, seeds_array)
                labels_from_dict = classifier.fit_predict(adjacency, seeds_dict)
                self.assertTrue((labels_from_array == labels_from_dict).all())
                self.assertEqual(labels_from_dict.shape, (n_nodes,))
                # One membership column per seeded label (two labels here).
                membership = classifier.fit_transform(adjacency, seeds_array)
                self.assertTupleEqual(membership.shape, (n_nodes, 2))
@@ -0,0 +1,77 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """Tests for DiffusionClassifier"""
4
+
5
+ import unittest
6
+
7
+ from sknetwork.classification import DiffusionClassifier
8
+ from sknetwork.data.test_graphs import *
9
+
10
+
class TestDiffusionClassifier(unittest.TestCase):
    """Tests for DiffusionClassifier on undirected, directed and bipartite graphs."""

    def test_graph(self):
        """Fit on a graph and a digraph; check output sizes, parameter validation and centering."""
        adjacency = test_graph()
        n_nodes = adjacency.shape[0]
        labels = {0: 0, 1: 1}
        algo = DiffusionClassifier()
        algo.fit(adjacency, labels=labels)
        self.assertEqual(len(algo.labels_), n_nodes)

        adjacency = test_digraph()
        # Recompute the size rather than assuming the digraph has as many nodes as the graph.
        n_nodes = adjacency.shape[0]
        algo = DiffusionClassifier(centering=False)
        algo.fit(adjacency, labels=labels)
        self.assertEqual(len(algo.labels_), n_nodes)

        # n_iter=0 must be rejected when the classifier is constructed.
        with self.assertRaises(ValueError):
            DiffusionClassifier(n_iter=0)

        # With centering and a large scale, some node must get label 1 with near certainty.
        algo = DiffusionClassifier(centering=True, scale=10)
        probs = algo.fit_predict_proba(adjacency, labels=labels)[:, 1]
        self.assertGreater(max(probs), 0.99)

    def test_bipartite(self):
        """Fit on a bigraph with seeds on both rows and columns."""
        biadjacency = test_bigraph()
        n_row, n_col = biadjacency.shape
        labels_row = {0: 0, 1: 1}
        labels_col = {5: 1}
        algo = DiffusionClassifier()
        algo.fit(biadjacency, labels_row=labels_row, labels_col=labels_col)
        self.assertEqual(len(algo.labels_row_), n_row)
        self.assertEqual(len(algo.labels_col_), n_col)
        # predict(columns=True) must return the column labels stored by fit.
        self.assertTrue((algo.labels_col_ == algo.predict(columns=True)).all())

    def test_predict(self):
        """Check shapes of predict / predict_proba / transform for rows and columns."""
        adjacency = test_graph()
        n_nodes = adjacency.shape[0]
        labels = {0: 0, 1: 1}
        algo = DiffusionClassifier()
        labels_pred = algo.fit_predict(adjacency, labels=labels)
        self.assertEqual(len(labels_pred), n_nodes)
        probs_pred = algo.fit_predict_proba(adjacency, labels=labels)
        self.assertEqual(probs_pred.shape, (n_nodes, 2))
        membership = algo.fit_transform(adjacency, labels=labels)
        self.assertEqual(membership.shape, (n_nodes, 2))

        biadjacency = test_bigraph()
        n_row, n_col = biadjacency.shape
        labels_row = {0: 0, 1: 1}
        algo = DiffusionClassifier()
        labels_pred = algo.fit_predict(biadjacency, labels_row=labels_row)
        self.assertEqual(len(labels_pred), n_row)
        labels_pred = algo.predict(columns=True)
        self.assertEqual(len(labels_pred), n_col)
        probs_pred = algo.fit_predict_proba(biadjacency, labels_row=labels_row)
        self.assertEqual(probs_pred.shape, (n_row, 2))
        probs_pred = algo.predict_proba(columns=True)
        self.assertEqual(probs_pred.shape, (n_col, 2))
        membership = algo.fit_transform(biadjacency, labels_row=labels_row)
        self.assertEqual(membership.shape, (n_row, 2))
        membership = algo.transform(columns=True)
        self.assertEqual(membership.shape, (n_col, 2))

    def test_reindex_label(self):
        """Non-contiguous seed labels (0, 2, 3) must be preserved in the predictions."""
        adjacency = test_graph()
        n_nodes = adjacency.shape[0]
        labels = {0: 0, 1: 2, 2: 3}
        algo = DiffusionClassifier()
        labels_pred = algo.fit_predict(adjacency, labels=labels)
        self.assertEqual(len(labels_pred), n_nodes)
        self.assertEqual(set(labels_pred), {0, 2, 3})
@@ -0,0 +1,23 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """Tests for KNN"""
4
+ import unittest
5
+
6
+ from sknetwork.classification import NNClassifier
7
+ from sknetwork.data.test_graphs import *
8
+ from sknetwork.embedding import Spectral
9
+
10
+
class TestKNNClassifier(unittest.TestCase):
    """Tests for the nearest-neighbor classifier."""

    def test_classification(self):
        """With one seed per label and a single neighbor, both labels must appear in the prediction.

        Runs the default embedding and a 2-dimensional Spectral embedding (without normalization)
        on a graph, a digraph and a bigraph.
        """
        for adjacency in [test_graph(), test_digraph(), test_bigraph()]:
            seeds = {0: 0, 1: 1}
            configurations = [
                dict(n_neighbors=1),
                dict(n_neighbors=1, embedding_method=Spectral(2), normalize=False),
            ]
            for params in configurations:
                predicted = NNClassifier(**params).fit_predict(adjacency, seeds)
                self.assertTrue(len(set(predicted)) == 2)
@@ -0,0 +1,53 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """Tests for classification metrics"""
4
+
5
+ import unittest
6
+
7
+ from sknetwork.classification.metrics import *
8
+
9
+
class TestMetrics(unittest.TestCase):
    """Tests for classification metrics.

    In these fixtures, -1 marks a sample with no true label (``labels_true``) or no prediction
    (``labels_pred``). The expected values below count only samples where both labels are
    non-negative. ``labels_pred2`` has no such sample, so every metric must reject it.
    """

    def setUp(self) -> None:
        self.labels_true = np.array([0, 1, 1, 2, 2, -1])
        self.labels_pred1 = np.array([0, -1, 1, 2, 0, 0])
        self.labels_pred2 = np.array([-1, -1, -1, -1, -1, 0])

    def test_accuracy(self):
        """Accuracy is 3 correct out of the 4 samples that have both labels."""
        self.assertEqual(get_accuracy_score(self.labels_true, self.labels_pred1), 0.75)
        with self.assertRaises(ValueError):
            get_accuracy_score(self.labels_true, self.labels_pred2)

    def test_confusion(self):
        """The confusion matrix counts the 4 usable samples, 3 of them on the diagonal."""
        confusion = get_confusion_matrix(self.labels_true, self.labels_pred1)
        self.assertEqual(confusion.data.sum(), 4)
        self.assertEqual(confusion.diagonal().sum(), 3)
        # With no usable sample, building the confusion matrix itself must fail.
        with self.assertRaises(ValueError):
            get_confusion_matrix(self.labels_true, self.labels_pred2)

    def test_f1_score(self):
        """Binary F1 score; the 3-class fixtures must be rejected."""
        f1_score = get_f1_score(np.array([0, 0, 1]), np.array([0, 1, 1]))
        self.assertAlmostEqual(f1_score, 0.67, 2)
        with self.assertRaises(ValueError):
            get_f1_score(self.labels_true, self.labels_pred1)

    def test_f1_scores(self):
        """Per-label F1 scores, optionally with precisions and recalls."""
        f1_scores = get_f1_scores(self.labels_true, self.labels_pred1)
        self.assertAlmostEqual(min(f1_scores), 0.67, 2)
        f1_scores, precisions, recalls = get_f1_scores(self.labels_true, self.labels_pred1, True)
        self.assertAlmostEqual(min(f1_scores), 0.67, 2)
        self.assertAlmostEqual(min(precisions), 0.5, 2)
        self.assertAlmostEqual(min(recalls), 0.5, 2)
        with self.assertRaises(ValueError):
            get_f1_scores(self.labels_true, self.labels_pred2)

    def test_average_f1_score(self):
        """Macro, micro and weighted averages, plus rejection of invalid inputs."""
        f1_score = get_average_f1_score(self.labels_true, self.labels_pred1)
        self.assertAlmostEqual(f1_score, 0.78, 2)
        f1_score = get_average_f1_score(self.labels_true, self.labels_pred1, average='micro')
        self.assertAlmostEqual(f1_score, 0.75)
        # Compare floats approximately: the weighted average is a ratio of float sums.
        f1_score = get_average_f1_score(self.labels_true, self.labels_pred1, average='weighted')
        self.assertAlmostEqual(f1_score, 0.80)
        # Check each invalid input on its own. Combining them lets the first error mask the second.
        with self.assertRaises(ValueError):
            get_average_f1_score(self.labels_true, self.labels_pred2)
        with self.assertRaises(ValueError):
            get_average_f1_score(self.labels_true, self.labels_pred1, average='toto')
@@ -0,0 +1,20 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """Tests for PageRankClassifier"""
4
+
5
+ import unittest
6
+
7
+ from sknetwork.classification import PageRankClassifier
8
+ from sknetwork.data.test_graphs import *
9
+
10
+
class TestPageRankClassifier(unittest.TestCase):
    """Tests for PageRankClassifier."""

    def test_solvers(self):
        """The 'lanczos' and 'bicgstab' solvers must reproduce the labels of the 'piteration' solver."""
        adjacency = test_graph()
        seeds = {0: 0, 1: 1}

        reference = PageRankClassifier(solver='piteration').fit_predict(adjacency, seeds)
        for solver_name in ('lanczos', 'bicgstab'):
            predicted = PageRankClassifier(solver=solver_name).fit_predict(adjacency, seeds)
            same_labels = (reference == predicted).all()
            self.assertTrue(same_labels)
@@ -0,0 +1,24 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """Tests for label propagation"""
4
+
5
+ import unittest
6
+
7
+ from sknetwork.classification import Propagation
8
+ from sknetwork.data.test_graphs import *
9
+
10
+
class TestLabelPropagation(unittest.TestCase):
    """Tests for label propagation."""

    def test_algo(self):
        """Predictions must have one entry per row for every graph type and every node order."""
        for adjacency in [test_graph(), test_digraph(), test_bigraph()]:
            expected_shape = (adjacency.shape[0],)
            seeds = {0: 0, 1: 1}

            unweighted = Propagation(n_iter=3, weighted=False)
            self.assertEqual(unweighted.fit_predict(adjacency, seeds).shape, expected_shape)

            for node_order in ('random', 'decreasing', 'increasing'):
                ordered = Propagation(node_order=node_order)
                self.assertEqual(ordered.fit_predict(adjacency, seeds).shape, expected_shape)
+ self.assertEqual(labels_pred.shape, (n,))