scikit-network 0.28.3__cp39-cp39-macosx_12_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of scikit-network might be problematic. Click here for more details.
- scikit_network-0.28.3.dist-info/AUTHORS.rst +41 -0
- scikit_network-0.28.3.dist-info/LICENSE +34 -0
- scikit_network-0.28.3.dist-info/METADATA +457 -0
- scikit_network-0.28.3.dist-info/RECORD +240 -0
- scikit_network-0.28.3.dist-info/WHEEL +5 -0
- scikit_network-0.28.3.dist-info/top_level.txt +1 -0
- sknetwork/__init__.py +21 -0
- sknetwork/classification/__init__.py +8 -0
- sknetwork/classification/base.py +84 -0
- sknetwork/classification/base_rank.py +143 -0
- sknetwork/classification/diffusion.py +134 -0
- sknetwork/classification/knn.py +162 -0
- sknetwork/classification/metrics.py +205 -0
- sknetwork/classification/pagerank.py +66 -0
- sknetwork/classification/propagation.py +152 -0
- sknetwork/classification/tests/__init__.py +1 -0
- sknetwork/classification/tests/test_API.py +35 -0
- sknetwork/classification/tests/test_diffusion.py +37 -0
- sknetwork/classification/tests/test_knn.py +24 -0
- sknetwork/classification/tests/test_metrics.py +53 -0
- sknetwork/classification/tests/test_pagerank.py +20 -0
- sknetwork/classification/tests/test_propagation.py +24 -0
- sknetwork/classification/vote.cpython-39-darwin.so +0 -0
- sknetwork/classification/vote.pyx +58 -0
- sknetwork/clustering/__init__.py +7 -0
- sknetwork/clustering/base.py +102 -0
- sknetwork/clustering/kmeans.py +142 -0
- sknetwork/clustering/louvain.py +255 -0
- sknetwork/clustering/louvain_core.cpython-39-darwin.so +0 -0
- sknetwork/clustering/louvain_core.pyx +134 -0
- sknetwork/clustering/metrics.py +91 -0
- sknetwork/clustering/postprocess.py +66 -0
- sknetwork/clustering/propagation_clustering.py +108 -0
- sknetwork/clustering/tests/__init__.py +1 -0
- sknetwork/clustering/tests/test_API.py +37 -0
- sknetwork/clustering/tests/test_kmeans.py +47 -0
- sknetwork/clustering/tests/test_louvain.py +104 -0
- sknetwork/clustering/tests/test_metrics.py +50 -0
- sknetwork/clustering/tests/test_post_processing.py +23 -0
- sknetwork/clustering/tests/test_postprocess.py +39 -0
- sknetwork/data/__init__.py +5 -0
- sknetwork/data/load.py +408 -0
- sknetwork/data/models.py +459 -0
- sknetwork/data/parse.py +621 -0
- sknetwork/data/test_graphs.py +84 -0
- sknetwork/data/tests/__init__.py +1 -0
- sknetwork/data/tests/test_API.py +30 -0
- sknetwork/data/tests/test_load.py +95 -0
- sknetwork/data/tests/test_models.py +52 -0
- sknetwork/data/tests/test_parse.py +253 -0
- sknetwork/data/tests/test_test_graphs.py +30 -0
- sknetwork/data/tests/test_toy_graphs.py +68 -0
- sknetwork/data/toy_graphs.py +619 -0
- sknetwork/embedding/__init__.py +10 -0
- sknetwork/embedding/base.py +90 -0
- sknetwork/embedding/force_atlas.py +197 -0
- sknetwork/embedding/louvain_embedding.py +174 -0
- sknetwork/embedding/louvain_hierarchy.py +142 -0
- sknetwork/embedding/metrics.py +66 -0
- sknetwork/embedding/random_projection.py +133 -0
- sknetwork/embedding/spectral.py +214 -0
- sknetwork/embedding/spring.py +198 -0
- sknetwork/embedding/svd.py +363 -0
- sknetwork/embedding/tests/__init__.py +1 -0
- sknetwork/embedding/tests/test_API.py +73 -0
- sknetwork/embedding/tests/test_force_atlas.py +35 -0
- sknetwork/embedding/tests/test_louvain_embedding.py +33 -0
- sknetwork/embedding/tests/test_louvain_hierarchy.py +19 -0
- sknetwork/embedding/tests/test_metrics.py +29 -0
- sknetwork/embedding/tests/test_random_projection.py +28 -0
- sknetwork/embedding/tests/test_spectral.py +84 -0
- sknetwork/embedding/tests/test_spring.py +50 -0
- sknetwork/embedding/tests/test_svd.py +37 -0
- sknetwork/flow/__init__.py +3 -0
- sknetwork/flow/flow.py +73 -0
- sknetwork/flow/tests/__init__.py +1 -0
- sknetwork/flow/tests/test_flow.py +17 -0
- sknetwork/flow/tests/test_utils.py +69 -0
- sknetwork/flow/utils.py +91 -0
- sknetwork/gnn/__init__.py +10 -0
- sknetwork/gnn/activation.py +117 -0
- sknetwork/gnn/base.py +155 -0
- sknetwork/gnn/base_activation.py +89 -0
- sknetwork/gnn/base_layer.py +109 -0
- sknetwork/gnn/gnn_classifier.py +381 -0
- sknetwork/gnn/layer.py +153 -0
- sknetwork/gnn/layers.py +127 -0
- sknetwork/gnn/loss.py +180 -0
- sknetwork/gnn/neighbor_sampler.py +65 -0
- sknetwork/gnn/optimizer.py +163 -0
- sknetwork/gnn/tests/__init__.py +1 -0
- sknetwork/gnn/tests/test_activation.py +56 -0
- sknetwork/gnn/tests/test_base.py +79 -0
- sknetwork/gnn/tests/test_base_layer.py +37 -0
- sknetwork/gnn/tests/test_gnn_classifier.py +192 -0
- sknetwork/gnn/tests/test_layers.py +80 -0
- sknetwork/gnn/tests/test_loss.py +33 -0
- sknetwork/gnn/tests/test_neigh_sampler.py +23 -0
- sknetwork/gnn/tests/test_optimizer.py +43 -0
- sknetwork/gnn/tests/test_utils.py +93 -0
- sknetwork/gnn/utils.py +219 -0
- sknetwork/hierarchy/__init__.py +7 -0
- sknetwork/hierarchy/base.py +69 -0
- sknetwork/hierarchy/louvain_hierarchy.py +264 -0
- sknetwork/hierarchy/metrics.py +234 -0
- sknetwork/hierarchy/paris.cpython-39-darwin.so +0 -0
- sknetwork/hierarchy/paris.pyx +317 -0
- sknetwork/hierarchy/postprocess.py +350 -0
- sknetwork/hierarchy/tests/__init__.py +1 -0
- sknetwork/hierarchy/tests/test_API.py +25 -0
- sknetwork/hierarchy/tests/test_algos.py +29 -0
- sknetwork/hierarchy/tests/test_metrics.py +62 -0
- sknetwork/hierarchy/tests/test_postprocess.py +57 -0
- sknetwork/hierarchy/tests/test_ward.py +25 -0
- sknetwork/hierarchy/ward.py +94 -0
- sknetwork/linalg/__init__.py +9 -0
- sknetwork/linalg/basics.py +37 -0
- sknetwork/linalg/diteration.cpython-39-darwin.so +0 -0
- sknetwork/linalg/diteration.pyx +49 -0
- sknetwork/linalg/eig_solver.py +93 -0
- sknetwork/linalg/laplacian.py +15 -0
- sknetwork/linalg/normalization.py +66 -0
- sknetwork/linalg/operators.py +225 -0
- sknetwork/linalg/polynome.py +76 -0
- sknetwork/linalg/ppr_solver.py +170 -0
- sknetwork/linalg/push.cpython-39-darwin.so +0 -0
- sknetwork/linalg/push.pyx +73 -0
- sknetwork/linalg/sparse_lowrank.py +142 -0
- sknetwork/linalg/svd_solver.py +91 -0
- sknetwork/linalg/tests/__init__.py +1 -0
- sknetwork/linalg/tests/test_eig.py +44 -0
- sknetwork/linalg/tests/test_laplacian.py +18 -0
- sknetwork/linalg/tests/test_normalization.py +38 -0
- sknetwork/linalg/tests/test_operators.py +70 -0
- sknetwork/linalg/tests/test_polynome.py +38 -0
- sknetwork/linalg/tests/test_ppr.py +50 -0
- sknetwork/linalg/tests/test_sparse_lowrank.py +61 -0
- sknetwork/linalg/tests/test_svd.py +38 -0
- sknetwork/linkpred/__init__.py +4 -0
- sknetwork/linkpred/base.py +80 -0
- sknetwork/linkpred/first_order.py +508 -0
- sknetwork/linkpred/first_order_core.cpython-39-darwin.so +0 -0
- sknetwork/linkpred/first_order_core.pyx +315 -0
- sknetwork/linkpred/postprocessing.py +98 -0
- sknetwork/linkpred/tests/__init__.py +1 -0
- sknetwork/linkpred/tests/test_API.py +49 -0
- sknetwork/linkpred/tests/test_postprocessing.py +21 -0
- sknetwork/path/__init__.py +4 -0
- sknetwork/path/metrics.py +148 -0
- sknetwork/path/search.py +65 -0
- sknetwork/path/shortest_path.py +186 -0
- sknetwork/path/tests/__init__.py +1 -0
- sknetwork/path/tests/test_metrics.py +29 -0
- sknetwork/path/tests/test_search.py +25 -0
- sknetwork/path/tests/test_shortest_path.py +45 -0
- sknetwork/ranking/__init__.py +9 -0
- sknetwork/ranking/base.py +56 -0
- sknetwork/ranking/betweenness.cpython-39-darwin.so +0 -0
- sknetwork/ranking/betweenness.pyx +99 -0
- sknetwork/ranking/closeness.py +95 -0
- sknetwork/ranking/harmonic.py +82 -0
- sknetwork/ranking/hits.py +94 -0
- sknetwork/ranking/katz.py +81 -0
- sknetwork/ranking/pagerank.py +107 -0
- sknetwork/ranking/postprocess.py +25 -0
- sknetwork/ranking/tests/__init__.py +1 -0
- sknetwork/ranking/tests/test_API.py +34 -0
- sknetwork/ranking/tests/test_betweenness.py +38 -0
- sknetwork/ranking/tests/test_closeness.py +34 -0
- sknetwork/ranking/tests/test_hits.py +20 -0
- sknetwork/ranking/tests/test_pagerank.py +69 -0
- sknetwork/regression/__init__.py +4 -0
- sknetwork/regression/base.py +56 -0
- sknetwork/regression/diffusion.py +190 -0
- sknetwork/regression/tests/__init__.py +1 -0
- sknetwork/regression/tests/test_API.py +34 -0
- sknetwork/regression/tests/test_diffusion.py +48 -0
- sknetwork/sknetwork.py +3 -0
- sknetwork/topology/__init__.py +9 -0
- sknetwork/topology/dag.py +74 -0
- sknetwork/topology/dag_core.cpython-39-darwin.so +0 -0
- sknetwork/topology/dag_core.pyx +38 -0
- sknetwork/topology/kcliques.cpython-39-darwin.so +0 -0
- sknetwork/topology/kcliques.pyx +193 -0
- sknetwork/topology/kcore.cpython-39-darwin.so +0 -0
- sknetwork/topology/kcore.pyx +120 -0
- sknetwork/topology/structure.py +234 -0
- sknetwork/topology/tests/__init__.py +1 -0
- sknetwork/topology/tests/test_cliques.py +28 -0
- sknetwork/topology/tests/test_cores.py +21 -0
- sknetwork/topology/tests/test_dag.py +26 -0
- sknetwork/topology/tests/test_structure.py +99 -0
- sknetwork/topology/tests/test_triangles.py +42 -0
- sknetwork/topology/tests/test_wl_coloring.py +49 -0
- sknetwork/topology/tests/test_wl_kernel.py +31 -0
- sknetwork/topology/triangles.cpython-39-darwin.so +0 -0
- sknetwork/topology/triangles.pyx +166 -0
- sknetwork/topology/weisfeiler_lehman.py +163 -0
- sknetwork/topology/weisfeiler_lehman_core.cpython-39-darwin.so +0 -0
- sknetwork/topology/weisfeiler_lehman_core.pyx +116 -0
- sknetwork/utils/__init__.py +40 -0
- sknetwork/utils/base.py +35 -0
- sknetwork/utils/check.py +354 -0
- sknetwork/utils/co_neighbor.py +71 -0
- sknetwork/utils/format.py +219 -0
- sknetwork/utils/kmeans.py +89 -0
- sknetwork/utils/knn.py +166 -0
- sknetwork/utils/knn1d.cpython-39-darwin.so +0 -0
- sknetwork/utils/knn1d.pyx +80 -0
- sknetwork/utils/membership.py +82 -0
- sknetwork/utils/minheap.cpython-39-darwin.so +0 -0
- sknetwork/utils/minheap.pxd +22 -0
- sknetwork/utils/minheap.pyx +111 -0
- sknetwork/utils/neighbors.py +115 -0
- sknetwork/utils/seeds.py +75 -0
- sknetwork/utils/simplex.py +140 -0
- sknetwork/utils/tests/__init__.py +1 -0
- sknetwork/utils/tests/test_base.py +28 -0
- sknetwork/utils/tests/test_bunch.py +16 -0
- sknetwork/utils/tests/test_check.py +190 -0
- sknetwork/utils/tests/test_co_neighbor.py +43 -0
- sknetwork/utils/tests/test_format.py +61 -0
- sknetwork/utils/tests/test_kmeans.py +21 -0
- sknetwork/utils/tests/test_knn.py +32 -0
- sknetwork/utils/tests/test_membership.py +24 -0
- sknetwork/utils/tests/test_neighbors.py +41 -0
- sknetwork/utils/tests/test_projection_simplex.py +33 -0
- sknetwork/utils/tests/test_seeds.py +67 -0
- sknetwork/utils/tests/test_verbose.py +15 -0
- sknetwork/utils/tests/test_ward.py +20 -0
- sknetwork/utils/timeout.py +38 -0
- sknetwork/utils/verbose.py +37 -0
- sknetwork/utils/ward.py +60 -0
- sknetwork/visualization/__init__.py +4 -0
- sknetwork/visualization/colors.py +34 -0
- sknetwork/visualization/dendrograms.py +229 -0
- sknetwork/visualization/graphs.py +819 -0
- sknetwork/visualization/tests/__init__.py +1 -0
- sknetwork/visualization/tests/test_dendrograms.py +53 -0
- sknetwork/visualization/tests/test_graphs.py +167 -0
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""Tests for the PageRank solvers."""
|
|
4
|
+
import unittest
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from sknetwork.data import house, karate_club
|
|
9
|
+
from sknetwork.data.parse import from_edge_list
|
|
10
|
+
from sknetwork.data.test_graphs import *
|
|
11
|
+
from sknetwork.linalg.operators import Regularizer
|
|
12
|
+
from sknetwork.linalg.ppr_solver import get_pagerank
|
|
13
|
+
from sknetwork.utils.check import is_proba_array
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class TestPPR(unittest.TestCase):
    """Sanity checks for the solvers available through ``get_pagerank``."""

    @staticmethod
    def _uniform_seeds(adjacency):
        """Return a uniform seed vector with one entry per row of ``adjacency``."""
        n_row = adjacency.shape[0]
        return np.ones(n_row) / n_row

    def test_diteration(self):
        options = dict(damping_factor=0.85, n_iter=100, tol=10, solver='diteration')

        # convergence by tolerance, on several test graphs
        for adjacency in (house(), test_graph(), test_digraph()):
            scores = get_pagerank(adjacency, seeds=self._uniform_seeds(adjacency), **options)
            self.assertTrue(is_proba_array(scores))

        # graph in which some nodes have a null out-degree
        adjacency = from_edge_list([(0, 1)])
        scores = get_pagerank(adjacency, seeds=self._uniform_seeds(adjacency), **options)
        self.assertTrue(is_proba_array(scores))

        # a Regularizer (SparseLR) input must be rejected by this solver
        adjacency = Regularizer(house(), 0.1)
        seeds = self._uniform_seeds(adjacency)
        with self.assertRaises(ValueError):
            get_pagerank(adjacency, seeds=seeds, **options)

    def test_push(self):
        # convergence by tolerance with the push solver
        adjacency = karate_club()
        scores = get_pagerank(adjacency, damping_factor=0.85, n_iter=100, tol=1e-1,
                              solver='push', seeds=self._uniform_seeds(adjacency))
        self.assertTrue(is_proba_array(scores))

    def test_piteration(self):
        # piteration solver on a SparseLR (Regularizer) input
        adjacency = Regularizer(house(), 0.1)
        scores = get_pagerank(adjacency, damping_factor=0.85, n_iter=100, tol=10,
                              solver='piteration', seeds=self._uniform_seeds(adjacency))
        self.assertTrue(is_proba_array(scores))
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""Tests for sparse + low-rank matrices (SparseLR)."""
|
|
4
|
+
|
|
5
|
+
import unittest
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from sknetwork.data import house, star_wars
|
|
10
|
+
from sknetwork.linalg.sparse_lowrank import SparseLR
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class TestSparseLowRank(unittest.TestCase):
    """Tests for ``SparseLR``: a sparse matrix plus low-rank terms given as ``(x, y)`` pairs."""

    def setUp(self):
        """Simple regularized adjacency and biadjacency for tests."""
        self.undirected = SparseLR(house(), [(np.ones(5), np.ones(5))])
        self.bipartite = SparseLR(star_wars(), [(np.ones(4), np.ones(3))])

    def test_init(self):
        """Low-rank factors whose lengths do not match the sparse shape raise ValueError."""
        with self.assertRaises(ValueError):
            SparseLR(house(), [(np.ones(5), np.ones(4))])
        with self.assertRaises(ValueError):
            SparseLR(house(), [(np.ones(4), np.ones(5))])

    def test_addition(self):
        """Adding a SparseLR to itself doubles both its sparse and low-rank parts."""
        addition = self.undirected + self.undirected
        expected = SparseLR(2 * house(), [(np.ones(5), 2 * np.ones(5))])
        err = (addition.sparse_mat - expected.sparse_mat).count_nonzero()
        self.assertEqual(err, 0)
        x = np.random.rand(5)
        self.assertAlmostEqual(np.linalg.norm(addition.dot(x) - expected.dot(x)), 0)

    def test_operations(self):
        """Negation, in-place sparse addition/subtraction and sparse products."""
        adjacency = self.undirected.sparse_mat
        slr = -self.undirected
        slr += adjacency
        slr -= adjacency
        # Adding then subtracting the same sparse matrix must give back -undirected.
        # Checked before the sparse products below, which may alter ``slr``.
        x = np.random.rand(5)
        self.assertAlmostEqual(np.linalg.norm(slr.dot(x) + self.undirected.dot(x)), 0)
        slr.left_sparse_dot(adjacency)
        slr.right_sparse_dot(adjacency)
        slr.astype(float)

    def test_product(self):
        """Matrix-vector products, including scaling of the operand and of the matrix."""
        prod = self.undirected.dot(np.ones(5))
        self.assertEqual(prod.shape, (5,))
        prod = self.bipartite.dot(np.ones(3))
        self.assertEqual(np.linalg.norm(prod - np.array([5., 4., 6., 5.])), 0.)
        prod = self.bipartite.dot(0.5 * np.ones(3))
        self.assertEqual(np.linalg.norm(prod - np.array([2.5, 2., 3., 2.5])), 0.)
        prod = (2 * self.bipartite).dot(0.5 * np.ones(3))
        self.assertEqual(np.linalg.norm(prod - 2 * np.array([2.5, 2., 3., 2.5])), 0.)

    def test_transposition(self):
        """Transposition keeps the symmetric sparse part and swaps the low-rank factors."""
        transposed = self.undirected.T
        error = (self.undirected.sparse_mat - transposed.sparse_mat).data
        self.assertEqual(abs(error).sum(), 0.)
        transposed = self.bipartite.T
        x, y = transposed.low_rank_tuples[0]
        self.assertTrue((x == np.ones(3)).all())
        self.assertTrue((y == np.ones(4)).all())
|
|
61
|
+
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""Tests for svd."""
|
|
4
|
+
|
|
5
|
+
import unittest
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from sknetwork.data import movie_actor
|
|
10
|
+
from sknetwork.linalg import LanczosSVD, SparseLR
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def svd_err(matrix, u, v, sigma):
    """Return the norm of the residual ``matrix.dot(v) - u * sigma``.

    A value close to zero means the columns of ``u`` and ``v`` behave as
    left/right singular vectors of ``matrix`` with singular values ``sigma``.
    """
    residual = matrix.dot(v)
    residual = residual - u * sigma
    return np.linalg.norm(residual)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# noinspection DuplicatedCode
|
|
20
|
+
class TestSolvers(unittest.TestCase):
    """Tests for the Lanczos SVD solver on a sparse matrix and on a SparseLR matrix."""

    def setUp(self):
        """Build a biadjacency matrix and a rank-one perturbation of it."""
        self.biadjacency = movie_actor()
        n_row, n_col = self.biadjacency.shape
        low_rank = [(np.random.rand(n_row), np.random.rand(n_col))]
        self.slr = SparseLR(self.biadjacency, low_rank)

    def _check_two_components(self, solver, matrix):
        """Fit ``solver`` with 2 components on ``matrix`` and check the SVD residual."""
        solver.fit(matrix, 2)
        self.assertEqual(len(solver.singular_values_), 2)
        residual = svd_err(matrix, solver.singular_vectors_left_, solver.singular_vectors_right_,
                           solver.singular_values_)
        self.assertAlmostEqual(residual, 0)

    def test_lanczos(self):
        solver = LanczosSVD()
        # the same solver instance is refitted on each input
        for matrix in (self.biadjacency, self.slr):
            self._check_two_components(solver, matrix)
|
|
@@ -0,0 +1,4 @@
|
|
|
1
|
+
"""link prediction module"""
|
|
2
|
+
from sknetwork.linkpred.first_order import CommonNeighbors, JaccardIndex, SaltonIndex, SorensenIndex, HubPromotedIndex,\
|
|
3
|
+
HubDepressedIndex, AdamicAdar, ResourceAllocation, PreferentialAttachment
|
|
4
|
+
from sknetwork.linkpred.postprocessing import is_edge, whitened_sigmoid
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""
|
|
4
|
+
Created on May, 2020
|
|
5
|
+
@author: Nathan de Lara <nathan.delara@polytechnique.org>
|
|
6
|
+
"""
|
|
7
|
+
from abc import ABC
|
|
8
|
+
from typing import Union, Iterable, Tuple
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
from sknetwork.utils.base import Algorithm
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class BaseLinkPred(Algorithm, ABC):
    """Base class for link prediction algorithms.

    Subclasses provide ``_predict_base`` (scores of one source against several
    targets) and ``_predict_node`` (scores of one node against all nodes);
    ``predict`` dispatches a query to the matching helper.
    """

    def _predict_base(self, source: int, targets: Iterable):
        """Prediction for a single source node and multiple targets (to be overridden)."""
        raise NotImplementedError

    def _predict_node(self, node: int):
        """Prediction for a single node (to be overridden)."""
        raise NotImplementedError

    def _predict_nodes(self, nodes: np.ndarray):
        """Prediction for multiple nodes, one row of scores per node."""
        return np.array([self._predict_node(node) for node in nodes])

    def _predict_edge(self, source: int, target: int):
        """Prediction for a single edge."""
        return self._predict_base(source, [target])[0]

    def _predict_edges(self, edges: np.ndarray):
        """Prediction for a list of edges, each given as (source, target)."""
        return np.array([self._predict_edge(edge[0], edge[1]) for edge in edges])

    def predict(self, query: Union[int, Iterable, Tuple]):
        """Compute similarity scores.

        Parameters
        ----------
        query : int, list, array or Tuple
            * If int i, return the similarities s(i, j) for all j.
            * If list or array integers, return s(i, j) for i in query, for all j as array.
            * If tuple (i, j), return the similarity s(i, j).
            * If list of tuples or array of shape (n_queries, 2), return s(i, j) for (i, j) in query as array.

        Returns
        -------
        predictions : int, float or array
            The prediction scores.

        Raises
        ------
        ValueError
            If the query matches none of the forms above (including a tuple
            that does not have exactly two elements).
        """
        if np.issubdtype(type(query), np.integer):
            return self._predict_node(query)
        # builtin ``tuple`` rather than ``typing.Tuple`` for the runtime check
        if isinstance(query, tuple):
            if len(query) != 2:
                raise ValueError("Query not understood.")
            return self._predict_edge(query[0], query[1])
        if isinstance(query, list):
            query = np.array(query)
        if not isinstance(query, np.ndarray):
            raise ValueError("Query not understood.")
        if query.ndim == 1:
            return self._predict_nodes(query)
        if query.ndim == 2 and query.shape[1] == 2:
            return self._predict_edges(query)
        raise ValueError("Query not understood.")

    def fit_predict(self, adjacency, query):
        """Fit algorithm to data and compute scores for requested edges."""
        self.fit(adjacency)
        return self.predict(query)
|