scikit-network 0.28.3__cp39-cp39-macosx_12_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-network might be problematic. Click here for more details.

Files changed (240) hide show
  1. scikit_network-0.28.3.dist-info/AUTHORS.rst +41 -0
  2. scikit_network-0.28.3.dist-info/LICENSE +34 -0
  3. scikit_network-0.28.3.dist-info/METADATA +457 -0
  4. scikit_network-0.28.3.dist-info/RECORD +240 -0
  5. scikit_network-0.28.3.dist-info/WHEEL +5 -0
  6. scikit_network-0.28.3.dist-info/top_level.txt +1 -0
  7. sknetwork/__init__.py +21 -0
  8. sknetwork/classification/__init__.py +8 -0
  9. sknetwork/classification/base.py +84 -0
  10. sknetwork/classification/base_rank.py +143 -0
  11. sknetwork/classification/diffusion.py +134 -0
  12. sknetwork/classification/knn.py +162 -0
  13. sknetwork/classification/metrics.py +205 -0
  14. sknetwork/classification/pagerank.py +66 -0
  15. sknetwork/classification/propagation.py +152 -0
  16. sknetwork/classification/tests/__init__.py +1 -0
  17. sknetwork/classification/tests/test_API.py +35 -0
  18. sknetwork/classification/tests/test_diffusion.py +37 -0
  19. sknetwork/classification/tests/test_knn.py +24 -0
  20. sknetwork/classification/tests/test_metrics.py +53 -0
  21. sknetwork/classification/tests/test_pagerank.py +20 -0
  22. sknetwork/classification/tests/test_propagation.py +24 -0
  23. sknetwork/classification/vote.cpython-39-darwin.so +0 -0
  24. sknetwork/classification/vote.pyx +58 -0
  25. sknetwork/clustering/__init__.py +7 -0
  26. sknetwork/clustering/base.py +102 -0
  27. sknetwork/clustering/kmeans.py +142 -0
  28. sknetwork/clustering/louvain.py +255 -0
  29. sknetwork/clustering/louvain_core.cpython-39-darwin.so +0 -0
  30. sknetwork/clustering/louvain_core.pyx +134 -0
  31. sknetwork/clustering/metrics.py +91 -0
  32. sknetwork/clustering/postprocess.py +66 -0
  33. sknetwork/clustering/propagation_clustering.py +108 -0
  34. sknetwork/clustering/tests/__init__.py +1 -0
  35. sknetwork/clustering/tests/test_API.py +37 -0
  36. sknetwork/clustering/tests/test_kmeans.py +47 -0
  37. sknetwork/clustering/tests/test_louvain.py +104 -0
  38. sknetwork/clustering/tests/test_metrics.py +50 -0
  39. sknetwork/clustering/tests/test_post_processing.py +23 -0
  40. sknetwork/clustering/tests/test_postprocess.py +39 -0
  41. sknetwork/data/__init__.py +5 -0
  42. sknetwork/data/load.py +408 -0
  43. sknetwork/data/models.py +459 -0
  44. sknetwork/data/parse.py +621 -0
  45. sknetwork/data/test_graphs.py +84 -0
  46. sknetwork/data/tests/__init__.py +1 -0
  47. sknetwork/data/tests/test_API.py +30 -0
  48. sknetwork/data/tests/test_load.py +95 -0
  49. sknetwork/data/tests/test_models.py +52 -0
  50. sknetwork/data/tests/test_parse.py +253 -0
  51. sknetwork/data/tests/test_test_graphs.py +30 -0
  52. sknetwork/data/tests/test_toy_graphs.py +68 -0
  53. sknetwork/data/toy_graphs.py +619 -0
  54. sknetwork/embedding/__init__.py +10 -0
  55. sknetwork/embedding/base.py +90 -0
  56. sknetwork/embedding/force_atlas.py +197 -0
  57. sknetwork/embedding/louvain_embedding.py +174 -0
  58. sknetwork/embedding/louvain_hierarchy.py +142 -0
  59. sknetwork/embedding/metrics.py +66 -0
  60. sknetwork/embedding/random_projection.py +133 -0
  61. sknetwork/embedding/spectral.py +214 -0
  62. sknetwork/embedding/spring.py +198 -0
  63. sknetwork/embedding/svd.py +363 -0
  64. sknetwork/embedding/tests/__init__.py +1 -0
  65. sknetwork/embedding/tests/test_API.py +73 -0
  66. sknetwork/embedding/tests/test_force_atlas.py +35 -0
  67. sknetwork/embedding/tests/test_louvain_embedding.py +33 -0
  68. sknetwork/embedding/tests/test_louvain_hierarchy.py +19 -0
  69. sknetwork/embedding/tests/test_metrics.py +29 -0
  70. sknetwork/embedding/tests/test_random_projection.py +28 -0
  71. sknetwork/embedding/tests/test_spectral.py +84 -0
  72. sknetwork/embedding/tests/test_spring.py +50 -0
  73. sknetwork/embedding/tests/test_svd.py +37 -0
  74. sknetwork/flow/__init__.py +3 -0
  75. sknetwork/flow/flow.py +73 -0
  76. sknetwork/flow/tests/__init__.py +1 -0
  77. sknetwork/flow/tests/test_flow.py +17 -0
  78. sknetwork/flow/tests/test_utils.py +69 -0
  79. sknetwork/flow/utils.py +91 -0
  80. sknetwork/gnn/__init__.py +10 -0
  81. sknetwork/gnn/activation.py +117 -0
  82. sknetwork/gnn/base.py +155 -0
  83. sknetwork/gnn/base_activation.py +89 -0
  84. sknetwork/gnn/base_layer.py +109 -0
  85. sknetwork/gnn/gnn_classifier.py +381 -0
  86. sknetwork/gnn/layer.py +153 -0
  87. sknetwork/gnn/layers.py +127 -0
  88. sknetwork/gnn/loss.py +180 -0
  89. sknetwork/gnn/neighbor_sampler.py +65 -0
  90. sknetwork/gnn/optimizer.py +163 -0
  91. sknetwork/gnn/tests/__init__.py +1 -0
  92. sknetwork/gnn/tests/test_activation.py +56 -0
  93. sknetwork/gnn/tests/test_base.py +79 -0
  94. sknetwork/gnn/tests/test_base_layer.py +37 -0
  95. sknetwork/gnn/tests/test_gnn_classifier.py +192 -0
  96. sknetwork/gnn/tests/test_layers.py +80 -0
  97. sknetwork/gnn/tests/test_loss.py +33 -0
  98. sknetwork/gnn/tests/test_neigh_sampler.py +23 -0
  99. sknetwork/gnn/tests/test_optimizer.py +43 -0
  100. sknetwork/gnn/tests/test_utils.py +93 -0
  101. sknetwork/gnn/utils.py +219 -0
  102. sknetwork/hierarchy/__init__.py +7 -0
  103. sknetwork/hierarchy/base.py +69 -0
  104. sknetwork/hierarchy/louvain_hierarchy.py +264 -0
  105. sknetwork/hierarchy/metrics.py +234 -0
  106. sknetwork/hierarchy/paris.cpython-39-darwin.so +0 -0
  107. sknetwork/hierarchy/paris.pyx +317 -0
  108. sknetwork/hierarchy/postprocess.py +350 -0
  109. sknetwork/hierarchy/tests/__init__.py +1 -0
  110. sknetwork/hierarchy/tests/test_API.py +25 -0
  111. sknetwork/hierarchy/tests/test_algos.py +29 -0
  112. sknetwork/hierarchy/tests/test_metrics.py +62 -0
  113. sknetwork/hierarchy/tests/test_postprocess.py +57 -0
  114. sknetwork/hierarchy/tests/test_ward.py +25 -0
  115. sknetwork/hierarchy/ward.py +94 -0
  116. sknetwork/linalg/__init__.py +9 -0
  117. sknetwork/linalg/basics.py +37 -0
  118. sknetwork/linalg/diteration.cpython-39-darwin.so +0 -0
  119. sknetwork/linalg/diteration.pyx +49 -0
  120. sknetwork/linalg/eig_solver.py +93 -0
  121. sknetwork/linalg/laplacian.py +15 -0
  122. sknetwork/linalg/normalization.py +66 -0
  123. sknetwork/linalg/operators.py +225 -0
  124. sknetwork/linalg/polynome.py +76 -0
  125. sknetwork/linalg/ppr_solver.py +170 -0
  126. sknetwork/linalg/push.cpython-39-darwin.so +0 -0
  127. sknetwork/linalg/push.pyx +73 -0
  128. sknetwork/linalg/sparse_lowrank.py +142 -0
  129. sknetwork/linalg/svd_solver.py +91 -0
  130. sknetwork/linalg/tests/__init__.py +1 -0
  131. sknetwork/linalg/tests/test_eig.py +44 -0
  132. sknetwork/linalg/tests/test_laplacian.py +18 -0
  133. sknetwork/linalg/tests/test_normalization.py +38 -0
  134. sknetwork/linalg/tests/test_operators.py +70 -0
  135. sknetwork/linalg/tests/test_polynome.py +38 -0
  136. sknetwork/linalg/tests/test_ppr.py +50 -0
  137. sknetwork/linalg/tests/test_sparse_lowrank.py +61 -0
  138. sknetwork/linalg/tests/test_svd.py +38 -0
  139. sknetwork/linkpred/__init__.py +4 -0
  140. sknetwork/linkpred/base.py +80 -0
  141. sknetwork/linkpred/first_order.py +508 -0
  142. sknetwork/linkpred/first_order_core.cpython-39-darwin.so +0 -0
  143. sknetwork/linkpred/first_order_core.pyx +315 -0
  144. sknetwork/linkpred/postprocessing.py +98 -0
  145. sknetwork/linkpred/tests/__init__.py +1 -0
  146. sknetwork/linkpred/tests/test_API.py +49 -0
  147. sknetwork/linkpred/tests/test_postprocessing.py +21 -0
  148. sknetwork/path/__init__.py +4 -0
  149. sknetwork/path/metrics.py +148 -0
  150. sknetwork/path/search.py +65 -0
  151. sknetwork/path/shortest_path.py +186 -0
  152. sknetwork/path/tests/__init__.py +1 -0
  153. sknetwork/path/tests/test_metrics.py +29 -0
  154. sknetwork/path/tests/test_search.py +25 -0
  155. sknetwork/path/tests/test_shortest_path.py +45 -0
  156. sknetwork/ranking/__init__.py +9 -0
  157. sknetwork/ranking/base.py +56 -0
  158. sknetwork/ranking/betweenness.cpython-39-darwin.so +0 -0
  159. sknetwork/ranking/betweenness.pyx +99 -0
  160. sknetwork/ranking/closeness.py +95 -0
  161. sknetwork/ranking/harmonic.py +82 -0
  162. sknetwork/ranking/hits.py +94 -0
  163. sknetwork/ranking/katz.py +81 -0
  164. sknetwork/ranking/pagerank.py +107 -0
  165. sknetwork/ranking/postprocess.py +25 -0
  166. sknetwork/ranking/tests/__init__.py +1 -0
  167. sknetwork/ranking/tests/test_API.py +34 -0
  168. sknetwork/ranking/tests/test_betweenness.py +38 -0
  169. sknetwork/ranking/tests/test_closeness.py +34 -0
  170. sknetwork/ranking/tests/test_hits.py +20 -0
  171. sknetwork/ranking/tests/test_pagerank.py +69 -0
  172. sknetwork/regression/__init__.py +4 -0
  173. sknetwork/regression/base.py +56 -0
  174. sknetwork/regression/diffusion.py +190 -0
  175. sknetwork/regression/tests/__init__.py +1 -0
  176. sknetwork/regression/tests/test_API.py +34 -0
  177. sknetwork/regression/tests/test_diffusion.py +48 -0
  178. sknetwork/sknetwork.py +3 -0
  179. sknetwork/topology/__init__.py +9 -0
  180. sknetwork/topology/dag.py +74 -0
  181. sknetwork/topology/dag_core.cpython-39-darwin.so +0 -0
  182. sknetwork/topology/dag_core.pyx +38 -0
  183. sknetwork/topology/kcliques.cpython-39-darwin.so +0 -0
  184. sknetwork/topology/kcliques.pyx +193 -0
  185. sknetwork/topology/kcore.cpython-39-darwin.so +0 -0
  186. sknetwork/topology/kcore.pyx +120 -0
  187. sknetwork/topology/structure.py +234 -0
  188. sknetwork/topology/tests/__init__.py +1 -0
  189. sknetwork/topology/tests/test_cliques.py +28 -0
  190. sknetwork/topology/tests/test_cores.py +21 -0
  191. sknetwork/topology/tests/test_dag.py +26 -0
  192. sknetwork/topology/tests/test_structure.py +99 -0
  193. sknetwork/topology/tests/test_triangles.py +42 -0
  194. sknetwork/topology/tests/test_wl_coloring.py +49 -0
  195. sknetwork/topology/tests/test_wl_kernel.py +31 -0
  196. sknetwork/topology/triangles.cpython-39-darwin.so +0 -0
  197. sknetwork/topology/triangles.pyx +166 -0
  198. sknetwork/topology/weisfeiler_lehman.py +163 -0
  199. sknetwork/topology/weisfeiler_lehman_core.cpython-39-darwin.so +0 -0
  200. sknetwork/topology/weisfeiler_lehman_core.pyx +116 -0
  201. sknetwork/utils/__init__.py +40 -0
  202. sknetwork/utils/base.py +35 -0
  203. sknetwork/utils/check.py +354 -0
  204. sknetwork/utils/co_neighbor.py +71 -0
  205. sknetwork/utils/format.py +219 -0
  206. sknetwork/utils/kmeans.py +89 -0
  207. sknetwork/utils/knn.py +166 -0
  208. sknetwork/utils/knn1d.cpython-39-darwin.so +0 -0
  209. sknetwork/utils/knn1d.pyx +80 -0
  210. sknetwork/utils/membership.py +82 -0
  211. sknetwork/utils/minheap.cpython-39-darwin.so +0 -0
  212. sknetwork/utils/minheap.pxd +22 -0
  213. sknetwork/utils/minheap.pyx +111 -0
  214. sknetwork/utils/neighbors.py +115 -0
  215. sknetwork/utils/seeds.py +75 -0
  216. sknetwork/utils/simplex.py +140 -0
  217. sknetwork/utils/tests/__init__.py +1 -0
  218. sknetwork/utils/tests/test_base.py +28 -0
  219. sknetwork/utils/tests/test_bunch.py +16 -0
  220. sknetwork/utils/tests/test_check.py +190 -0
  221. sknetwork/utils/tests/test_co_neighbor.py +43 -0
  222. sknetwork/utils/tests/test_format.py +61 -0
  223. sknetwork/utils/tests/test_kmeans.py +21 -0
  224. sknetwork/utils/tests/test_knn.py +32 -0
  225. sknetwork/utils/tests/test_membership.py +24 -0
  226. sknetwork/utils/tests/test_neighbors.py +41 -0
  227. sknetwork/utils/tests/test_projection_simplex.py +33 -0
  228. sknetwork/utils/tests/test_seeds.py +67 -0
  229. sknetwork/utils/tests/test_verbose.py +15 -0
  230. sknetwork/utils/tests/test_ward.py +20 -0
  231. sknetwork/utils/timeout.py +38 -0
  232. sknetwork/utils/verbose.py +37 -0
  233. sknetwork/utils/ward.py +60 -0
  234. sknetwork/visualization/__init__.py +4 -0
  235. sknetwork/visualization/colors.py +34 -0
  236. sknetwork/visualization/dendrograms.py +229 -0
  237. sknetwork/visualization/graphs.py +819 -0
  238. sknetwork/visualization/tests/__init__.py +1 -0
  239. sknetwork/visualization/tests/test_dendrograms.py +53 -0
  240. sknetwork/visualization/tests/test_graphs.py +167 -0
@@ -0,0 +1,50 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """Tests for personalized PageRank solvers."""
4
+ import unittest
5
+
6
+ import numpy as np
7
+
8
+ from sknetwork.data import house, karate_club
9
+ from sknetwork.data.parse import from_edge_list
10
+ from sknetwork.data.test_graphs import *
11
+ from sknetwork.linalg.operators import Regularizer
12
+ from sknetwork.linalg.ppr_solver import get_pagerank
13
+ from sknetwork.utils.check import is_proba_array
14
+
15
+
16
class TestPPR(unittest.TestCase):
    """Checks that ``get_pagerank`` yields probability vectors for each solver."""

    @staticmethod
    def _uniform_seeds(adjacency):
        """Return a uniform seed vector with one entry per row of ``adjacency``."""
        n_nodes = adjacency.shape[0]
        return np.ones(n_nodes) / n_nodes

    def test_diteration(self):
        """Solver 'diteration': valid output on several graphs, ValueError on a Regularizer."""
        # test convergence by tolerance
        graphs = [house(), test_graph(), test_digraph()]
        for graph in graphs:
            uniform = self._uniform_seeds(graph)
            scores = get_pagerank(graph, damping_factor=0.85, n_iter=100, tol=10, solver='diteration',
                                  seeds=uniform)
            self.assertTrue(is_proba_array(scores))

        # test graph with some null out-degree
        single_edge = from_edge_list([(0, 1)])
        uniform = self._uniform_seeds(single_edge)
        scores = get_pagerank(single_edge, damping_factor=0.85, n_iter=100, tol=10, solver='diteration',
                              seeds=uniform)
        self.assertTrue(is_proba_array(scores))

        # test invalid entry: the test expects a Regularizer input to be rejected by this solver
        regularized = Regularizer(house(), 0.1)
        uniform = self._uniform_seeds(regularized)
        with self.assertRaises(ValueError):
            get_pagerank(regularized, damping_factor=0.85, n_iter=100, tol=10, solver='diteration', seeds=uniform)

    def test_push(self):
        """Solver 'push': valid output on the karate club graph."""
        # test convergence by tolerance
        graph = karate_club()
        uniform = self._uniform_seeds(graph)
        scores = get_pagerank(graph, damping_factor=0.85, n_iter=100, tol=1e-1, solver='push', seeds=uniform)
        self.assertTrue(is_proba_array(scores))

    def test_piteration(self):
        """Solver 'piteration': valid output on a Regularizer input."""
        # test on SparseLR matrix
        regularized = Regularizer(house(), 0.1)
        uniform = self._uniform_seeds(regularized)
        scores = get_pagerank(regularized, damping_factor=0.85, n_iter=100, tol=10, solver='piteration',
                              seeds=uniform)
        self.assertTrue(is_proba_array(scores))
@@ -0,0 +1,61 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """Tests for sparse + low-rank matrices."""
4
+
5
+ import unittest
6
+
7
+ import numpy as np
8
+
9
+ from sknetwork.data import house, star_wars
10
+ from sknetwork.linalg.sparse_lowrank import SparseLR
11
+
12
+
13
class TestSparseLowRank(unittest.TestCase):
    """Unit tests for the ``SparseLR`` operator."""

    def setUp(self):
        """Build SparseLR fixtures from the house and star_wars datasets with an all-ones low-rank term."""
        ones_5 = np.ones(5)
        self.undirected = SparseLR(house(), [(ones_5, np.ones(5))])
        self.bipartite = SparseLR(star_wars(), [(np.ones(4), np.ones(3))])

    def test_init(self):
        """Low-rank vectors with mismatched lengths must raise ValueError."""
        for n_left, n_right in ((5, 4), (4, 5)):
            with self.assertRaises(ValueError):
                SparseLR(house(), [(np.ones(n_left), np.ones(n_right))])

    def test_addition(self):
        """Adding a SparseLR to itself matches the explicitly doubled operator."""
        total = self.undirected + self.undirected
        reference = SparseLR(2 * house(), [(np.ones(5), 2 * np.ones(5))])
        difference = total.sparse_mat - reference.sparse_mat
        self.assertEqual(difference.count_nonzero(), 0)
        vector = np.random.rand(5)
        gap = np.linalg.norm(total.dot(vector) - reference.dot(vector))
        self.assertAlmostEqual(gap, 0)

    def test_operations(self):
        """Smoke test: negation, in-place add/sub, sparse dot products and astype run without error."""
        sparse_part = self.undirected.sparse_mat
        operator = -self.undirected
        operator += sparse_part
        operator -= sparse_part
        operator.left_sparse_dot(sparse_part)
        operator.right_sparse_dot(sparse_part)
        operator.astype(float)

    def test_product(self):
        """Matrix-vector products have the expected shape and values."""
        self.assertEqual(self.undirected.dot(np.ones(5)).shape, (5,))

        full = self.bipartite.dot(np.ones(3))
        self.assertEqual(np.linalg.norm(full - np.array([5., 4., 6., 5.])), 0.)

        half_target = np.array([2.5, 2., 3., 2.5])
        half = self.bipartite.dot(0.5 * np.ones(3))
        self.assertEqual(np.linalg.norm(half - half_target), 0.)

        scaled = (2 * self.bipartite).dot(0.5 * np.ones(3))
        self.assertEqual(np.linalg.norm(scaled - 2 * np.array([2.5, 2., 3., 2.5])), 0.)

    def test_transposition(self):
        """Transposing keeps the house sparse part unchanged and swaps the low-rank vectors."""
        flipped = self.undirected.T
        residual = (self.undirected.sparse_mat - flipped.sparse_mat).data
        self.assertEqual(abs(residual).sum(), 0.)

        flipped = self.bipartite.T
        left, right = flipped.low_rank_tuples[0]
        self.assertTrue((left == np.ones(3)).all())
        self.assertTrue((right == np.ones(4)).all())
61
+
@@ -0,0 +1,38 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """Tests for svd."""
4
+
5
+ import unittest
6
+
7
+ import numpy as np
8
+
9
+ from sknetwork.data import movie_actor
10
+ from sknetwork.linalg import LanczosSVD, SparseLR
11
+
12
+
13
def svd_err(matrix, u, v, sigma):
    """Return the norm of ``matrix.dot(v) - u * sigma``.

    Used as the approximation error of a singular triplet: it is zero when
    ``matrix.dot(v)`` equals ``u`` scaled by ``sigma``.
    """
    return np.linalg.norm(matrix.dot(v) - u * sigma)
17
+
18
+
19
+ # noinspection DuplicatedCode
20
class TestSolvers(unittest.TestCase):
    """Tests for the Lanczos SVD solver on sparse and SparseLR inputs."""

    def setUp(self):
        """Load the movie_actor biadjacency and wrap it in a SparseLR with a random rank-one term."""
        self.biadjacency = movie_actor()
        n_row, n_col = self.biadjacency.shape
        left = np.random.rand(n_row)
        right = np.random.rand(n_col)
        self.slr = SparseLR(self.biadjacency, [(left, right)])

    def test_lanczos(self):
        """Fitting 2 components gives 2 singular values and a near-zero residual for both inputs."""
        solver = LanczosSVD()
        # The same solver instance is refitted on each input, in this order.
        for matrix in (self.biadjacency, self.slr):
            solver.fit(matrix, 2)
            self.assertEqual(len(solver.singular_values_), 2)
            residual = svd_err(matrix, solver.singular_vectors_left_, solver.singular_vectors_right_,
                               solver.singular_values_)
            self.assertAlmostEqual(residual, 0)
@@ -0,0 +1,4 @@
1
+ """link prediction module"""
2
+ from sknetwork.linkpred.first_order import CommonNeighbors, JaccardIndex, SaltonIndex, SorensenIndex, HubPromotedIndex,\
3
+ HubDepressedIndex, AdamicAdar, ResourceAllocation, PreferentialAttachment
4
+ from sknetwork.linkpred.postprocessing import is_edge, whitened_sigmoid
@@ -0,0 +1,80 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on May, 2020
5
+ @author: Nathan de Lara <nathan.delara@polytechnique.org>
6
+ """
7
+ from abc import ABC
8
+ from typing import Union, Iterable, Tuple
9
+
10
+ import numpy as np
11
+
12
+ from sknetwork.utils.base import Algorithm
13
+
14
+
15
class BaseLinkPred(Algorithm, ABC):
    """Base class for link prediction algorithms.

    Subclasses are expected to override ``_predict_base`` and ``_predict_node``;
    the remaining helpers and ``predict`` are built on top of these two.
    """

    def _predict_base(self, source: int, targets: Iterable):
        """Prediction for a single node and multiple targets"""
        raise NotImplementedError

    def _predict_node(self, node: int):
        """Prediction for a single node."""
        raise NotImplementedError

    def _predict_nodes(self, nodes: np.ndarray):
        """Prediction for multiple nodes.

        Returns an array stacking the result of ``_predict_node`` for each node.
        """
        return np.array([self._predict_node(node) for node in nodes])

    def _predict_edge(self, source: int, target: int):
        """Prediction for a single edge."""
        # Reuse the multi-target routine with a single target and unwrap the result.
        return self._predict_base(source, [target])[0]

    def _predict_edges(self, edges: np.ndarray):
        """Prediction for a list of edges.

        Each entry of ``edges`` is indexed as ``(source, target)``.
        """
        return np.array([self._predict_edge(edge[0], edge[1]) for edge in edges])

    def predict(self, query: Union[int, Iterable, Tuple]):
        """Compute similarity scores.

        Parameters
        ----------
        query : int, list, array or Tuple
            * If int i, return the similarities s(i, j) for all j.
            * If list or array integers, return s(i, j) for i in query, for all j as array.
            * If tuple (i, j), return the similarity s(i, j).
            * If list of tuples or array of shape (n_queries, 2), return s(i, j) for (i, j) in query as array.

        Returns
        -------
        predictions : int, float or array
            The prediction scores.

        Raises
        ------
        ValueError
            If the query matches none of the supported forms.
        """
        if np.issubdtype(type(query), np.integer):
            return self._predict_node(query)
        # Check against the builtin type rather than the deprecated typing alias.
        if isinstance(query, tuple):
            return self._predict_edge(query[0], query[1])
        if isinstance(query, list):
            query = np.array(query)
        if isinstance(query, np.ndarray):
            if query.ndim == 1:
                return self._predict_nodes(query)
            if query.ndim == 2 and query.shape[1] == 2:
                return self._predict_edges(query)
        raise ValueError("Query not understood.")

    def fit_predict(self, adjacency, query):
        """Fit algorithm to data and compute scores for requested edges.

        Calls ``self.fit(adjacency)`` and then ``self.predict(query)``.
        ``fit`` is not defined in this class; it is presumably provided by a
        subclass or by ``Algorithm`` (to be confirmed).
        """
        self.fit(adjacency)
        return self.predict(query)