scikit-network 0.30.0__cp39-cp39-win_amd64.whl → 0.32.1__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of scikit-network might be problematic. Click here for more details.
- {scikit_network-0.30.0.dist-info → scikit_network-0.32.1.dist-info}/AUTHORS.rst +3 -0
- {scikit_network-0.30.0.dist-info → scikit_network-0.32.1.dist-info}/METADATA +31 -3
- scikit_network-0.32.1.dist-info/RECORD +228 -0
- {scikit_network-0.30.0.dist-info → scikit_network-0.32.1.dist-info}/WHEEL +1 -1
- sknetwork/__init__.py +1 -1
- sknetwork/base.py +67 -0
- sknetwork/classification/base.py +24 -24
- sknetwork/classification/base_rank.py +17 -25
- sknetwork/classification/diffusion.py +35 -35
- sknetwork/classification/knn.py +24 -21
- sknetwork/classification/metrics.py +1 -1
- sknetwork/classification/pagerank.py +10 -10
- sknetwork/classification/propagation.py +23 -20
- sknetwork/classification/tests/test_diffusion.py +13 -3
- sknetwork/classification/vote.cp39-win_amd64.pyd +0 -0
- sknetwork/classification/vote.cpp +14482 -10351
- sknetwork/classification/vote.pyx +1 -3
- sknetwork/clustering/__init__.py +3 -1
- sknetwork/clustering/base.py +36 -40
- sknetwork/clustering/kcenters.py +253 -0
- sknetwork/clustering/leiden.py +241 -0
- sknetwork/clustering/leiden_core.cp39-win_amd64.pyd +0 -0
- sknetwork/clustering/leiden_core.cpp +31564 -0
- sknetwork/clustering/leiden_core.pyx +124 -0
- sknetwork/clustering/louvain.py +133 -102
- sknetwork/clustering/louvain_core.cp39-win_amd64.pyd +0 -0
- sknetwork/clustering/louvain_core.cpp +22457 -18792
- sknetwork/clustering/louvain_core.pyx +86 -96
- sknetwork/clustering/postprocess.py +2 -2
- sknetwork/clustering/propagation_clustering.py +15 -19
- sknetwork/clustering/tests/test_API.py +8 -4
- sknetwork/clustering/tests/test_kcenters.py +92 -0
- sknetwork/clustering/tests/test_leiden.py +34 -0
- sknetwork/clustering/tests/test_louvain.py +3 -4
- sknetwork/data/__init__.py +2 -1
- sknetwork/data/base.py +28 -0
- sknetwork/data/load.py +38 -37
- sknetwork/data/models.py +18 -18
- sknetwork/data/parse.py +54 -33
- sknetwork/data/test_graphs.py +2 -2
- sknetwork/data/tests/test_API.py +1 -1
- sknetwork/data/tests/test_base.py +14 -0
- sknetwork/data/tests/test_load.py +1 -1
- sknetwork/data/tests/test_parse.py +9 -12
- sknetwork/data/tests/test_test_graphs.py +1 -2
- sknetwork/data/toy_graphs.py +18 -18
- sknetwork/embedding/__init__.py +0 -1
- sknetwork/embedding/base.py +21 -20
- sknetwork/embedding/force_atlas.py +3 -2
- sknetwork/embedding/louvain_embedding.py +2 -2
- sknetwork/embedding/random_projection.py +5 -3
- sknetwork/embedding/spectral.py +0 -73
- sknetwork/embedding/tests/test_API.py +4 -28
- sknetwork/embedding/tests/test_louvain_embedding.py +4 -9
- sknetwork/embedding/tests/test_random_projection.py +2 -2
- sknetwork/embedding/tests/test_spectral.py +5 -8
- sknetwork/embedding/tests/test_svd.py +1 -1
- sknetwork/gnn/base.py +4 -4
- sknetwork/gnn/base_layer.py +3 -3
- sknetwork/gnn/gnn_classifier.py +45 -89
- sknetwork/gnn/layer.py +1 -1
- sknetwork/gnn/loss.py +1 -1
- sknetwork/gnn/optimizer.py +4 -3
- sknetwork/gnn/tests/test_base_layer.py +4 -4
- sknetwork/gnn/tests/test_gnn_classifier.py +12 -35
- sknetwork/gnn/utils.py +8 -8
- sknetwork/hierarchy/base.py +29 -2
- sknetwork/hierarchy/louvain_hierarchy.py +45 -41
- sknetwork/hierarchy/paris.cp39-win_amd64.pyd +0 -0
- sknetwork/hierarchy/paris.cpp +27369 -22852
- sknetwork/hierarchy/paris.pyx +7 -9
- sknetwork/hierarchy/postprocess.py +16 -16
- sknetwork/hierarchy/tests/test_API.py +1 -1
- sknetwork/hierarchy/tests/test_algos.py +5 -0
- sknetwork/hierarchy/tests/test_metrics.py +1 -1
- sknetwork/linalg/__init__.py +1 -1
- sknetwork/linalg/diteration.cp39-win_amd64.pyd +0 -0
- sknetwork/linalg/diteration.cpp +13474 -9454
- sknetwork/linalg/diteration.pyx +0 -2
- sknetwork/linalg/eig_solver.py +1 -1
- sknetwork/linalg/{normalization.py → normalizer.py} +18 -15
- sknetwork/linalg/operators.py +1 -1
- sknetwork/linalg/ppr_solver.py +1 -1
- sknetwork/linalg/push.cp39-win_amd64.pyd +0 -0
- sknetwork/linalg/push.cpp +22993 -18807
- sknetwork/linalg/push.pyx +0 -2
- sknetwork/linalg/svd_solver.py +1 -1
- sknetwork/linalg/tests/test_normalization.py +3 -7
- sknetwork/linalg/tests/test_operators.py +4 -8
- sknetwork/linalg/tests/test_ppr.py +1 -1
- sknetwork/linkpred/base.py +13 -2
- sknetwork/linkpred/nn.py +6 -6
- sknetwork/log.py +19 -0
- sknetwork/path/__init__.py +4 -3
- sknetwork/path/dag.py +54 -0
- sknetwork/path/distances.py +98 -0
- sknetwork/path/search.py +13 -47
- sknetwork/path/shortest_path.py +37 -162
- sknetwork/path/tests/test_dag.py +37 -0
- sknetwork/path/tests/test_distances.py +62 -0
- sknetwork/path/tests/test_search.py +26 -11
- sknetwork/path/tests/test_shortest_path.py +31 -36
- sknetwork/ranking/__init__.py +0 -1
- sknetwork/ranking/base.py +13 -8
- sknetwork/ranking/betweenness.cp39-win_amd64.pyd +0 -0
- sknetwork/ranking/betweenness.cpp +5709 -3017
- sknetwork/ranking/betweenness.pyx +0 -2
- sknetwork/ranking/closeness.py +7 -10
- sknetwork/ranking/pagerank.py +14 -14
- sknetwork/ranking/postprocess.py +12 -3
- sknetwork/ranking/tests/test_API.py +2 -4
- sknetwork/ranking/tests/test_betweenness.py +3 -3
- sknetwork/ranking/tests/test_closeness.py +3 -7
- sknetwork/ranking/tests/test_pagerank.py +11 -5
- sknetwork/ranking/tests/test_postprocess.py +5 -0
- sknetwork/regression/base.py +19 -2
- sknetwork/regression/diffusion.py +24 -10
- sknetwork/regression/tests/test_diffusion.py +8 -0
- sknetwork/test_base.py +35 -0
- sknetwork/test_log.py +15 -0
- sknetwork/topology/__init__.py +7 -8
- sknetwork/topology/cliques.cp39-win_amd64.pyd +0 -0
- sknetwork/topology/{kcliques.cpp → cliques.cpp} +23412 -20276
- sknetwork/topology/cliques.pyx +149 -0
- sknetwork/topology/core.cp39-win_amd64.pyd +0 -0
- sknetwork/topology/{kcore.cpp → core.cpp} +21732 -18867
- sknetwork/topology/core.pyx +90 -0
- sknetwork/topology/cycles.py +243 -0
- sknetwork/topology/minheap.cp39-win_amd64.pyd +0 -0
- sknetwork/{utils → topology}/minheap.cpp +19452 -15368
- sknetwork/{utils → topology}/minheap.pxd +1 -3
- sknetwork/{utils → topology}/minheap.pyx +1 -3
- sknetwork/topology/structure.py +3 -43
- sknetwork/topology/tests/test_cliques.py +11 -11
- sknetwork/topology/tests/test_core.py +19 -0
- sknetwork/topology/tests/test_cycles.py +65 -0
- sknetwork/topology/tests/test_structure.py +2 -16
- sknetwork/topology/tests/test_triangles.py +11 -15
- sknetwork/topology/tests/test_wl.py +72 -0
- sknetwork/topology/triangles.cp39-win_amd64.pyd +0 -0
- sknetwork/topology/triangles.cpp +5056 -2696
- sknetwork/topology/triangles.pyx +74 -89
- sknetwork/topology/weisfeiler_lehman.py +56 -86
- sknetwork/topology/weisfeiler_lehman_core.cp39-win_amd64.pyd +0 -0
- sknetwork/topology/weisfeiler_lehman_core.cpp +14727 -10622
- sknetwork/topology/weisfeiler_lehman_core.pyx +0 -2
- sknetwork/utils/__init__.py +1 -31
- sknetwork/utils/check.py +2 -2
- sknetwork/utils/format.py +5 -3
- sknetwork/utils/membership.py +2 -2
- sknetwork/utils/tests/test_check.py +3 -3
- sknetwork/utils/tests/test_format.py +3 -1
- sknetwork/utils/values.py +1 -1
- sknetwork/visualization/__init__.py +2 -2
- sknetwork/visualization/dendrograms.py +55 -7
- sknetwork/visualization/graphs.py +292 -72
- sknetwork/visualization/tests/test_dendrograms.py +9 -9
- sknetwork/visualization/tests/test_graphs.py +71 -62
- scikit_network-0.30.0.dist-info/RECORD +0 -227
- sknetwork/embedding/louvain_hierarchy.py +0 -142
- sknetwork/embedding/tests/test_louvain_hierarchy.py +0 -19
- sknetwork/path/metrics.py +0 -148
- sknetwork/path/tests/test_metrics.py +0 -29
- sknetwork/ranking/harmonic.py +0 -82
- sknetwork/topology/dag.py +0 -74
- sknetwork/topology/dag_core.cp39-win_amd64.pyd +0 -0
- sknetwork/topology/dag_core.cpp +0 -23350
- sknetwork/topology/dag_core.pyx +0 -38
- sknetwork/topology/kcliques.cp39-win_amd64.pyd +0 -0
- sknetwork/topology/kcliques.pyx +0 -193
- sknetwork/topology/kcore.cp39-win_amd64.pyd +0 -0
- sknetwork/topology/kcore.pyx +0 -120
- sknetwork/topology/tests/test_cores.py +0 -21
- sknetwork/topology/tests/test_dag.py +0 -26
- sknetwork/topology/tests/test_wl_coloring.py +0 -49
- sknetwork/topology/tests/test_wl_kernel.py +0 -31
- sknetwork/utils/base.py +0 -35
- sknetwork/utils/minheap.cp39-win_amd64.pyd +0 -0
- sknetwork/utils/simplex.py +0 -140
- sknetwork/utils/tests/test_base.py +0 -28
- sknetwork/utils/tests/test_bunch.py +0 -16
- sknetwork/utils/tests/test_projection_simplex.py +0 -33
- sknetwork/utils/tests/test_verbose.py +0 -15
- sknetwork/utils/verbose.py +0 -37
- {scikit_network-0.30.0.dist-info → scikit_network-0.32.1.dist-info}/LICENSE +0 -0
- {scikit_network-0.30.0.dist-info → scikit_network-0.32.1.dist-info}/top_level.txt +0 -0
- /sknetwork/{utils → data}/timeout.py +0 -0
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
""""tests for search.py"""
|
|
4
|
+
|
|
5
|
+
import unittest
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from sknetwork.data import cyclic_digraph
|
|
10
|
+
from sknetwork.data.test_graphs import *
|
|
11
|
+
from sknetwork.path import get_dag
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class TestSearch(unittest.TestCase):
|
|
15
|
+
|
|
16
|
+
def test(self):
|
|
17
|
+
adjacency = cyclic_digraph(3)
|
|
18
|
+
dag = get_dag(adjacency)
|
|
19
|
+
self.assertEqual(dag.nnz, 2)
|
|
20
|
+
|
|
21
|
+
adjacency = test_graph_empty()
|
|
22
|
+
dag = get_dag(adjacency)
|
|
23
|
+
self.assertEqual(dag.nnz, 0)
|
|
24
|
+
|
|
25
|
+
adjacency = test_graph()
|
|
26
|
+
dag = get_dag(adjacency)
|
|
27
|
+
self.assertEqual(dag.nnz, 12)
|
|
28
|
+
dag = get_dag(adjacency, order=np.arange(10) % 3)
|
|
29
|
+
self.assertEqual(dag.nnz, 10)
|
|
30
|
+
|
|
31
|
+
adjacency = test_disconnected_graph()
|
|
32
|
+
dag = get_dag(adjacency, 3)
|
|
33
|
+
self.assertEqual(dag.nnz, 1)
|
|
34
|
+
|
|
35
|
+
adjacency = test_digraph()
|
|
36
|
+
dag = get_dag(adjacency, 1)
|
|
37
|
+
self.assertEqual(dag.nnz, 4)
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
""""tests for distances.py"""
|
|
4
|
+
import unittest
|
|
5
|
+
|
|
6
|
+
from sknetwork.data.test_graphs import *
|
|
7
|
+
from sknetwork.path.distances import get_distances
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class TestDistances(unittest.TestCase):
|
|
11
|
+
|
|
12
|
+
def test_input(self):
|
|
13
|
+
adjacency = test_graph()
|
|
14
|
+
with self.assertRaises(ValueError):
|
|
15
|
+
get_distances(adjacency)
|
|
16
|
+
with self.assertRaises(ValueError):
|
|
17
|
+
get_distances(adjacency, source=0, source_row=5)
|
|
18
|
+
|
|
19
|
+
def test_algo(self):
|
|
20
|
+
adjacency = test_graph()
|
|
21
|
+
distances = get_distances(adjacency, 0)
|
|
22
|
+
distances_ = np.array([0, 1, 3, 2, 2, 3, 2, 3, 4, 4])
|
|
23
|
+
self.assertTrue(all(distances == distances_))
|
|
24
|
+
distances = get_distances(adjacency, 0, transpose=True)
|
|
25
|
+
self.assertTrue(all(distances == distances_))
|
|
26
|
+
distances = get_distances(adjacency, [0, 5])
|
|
27
|
+
distances_ = np.array([0, 1, 3, 2, 1, 0, 1, 2, 4, 3])
|
|
28
|
+
self.assertTrue(all(distances == distances_))
|
|
29
|
+
|
|
30
|
+
adjacency = test_graph_empty()
|
|
31
|
+
source = [0, 3]
|
|
32
|
+
distances = get_distances(adjacency, source)
|
|
33
|
+
distances_ = -np.ones(len(distances), dtype=int)
|
|
34
|
+
distances_[source] = 0
|
|
35
|
+
self.assertTrue(all(distances == distances_))
|
|
36
|
+
|
|
37
|
+
adjacency = test_digraph()
|
|
38
|
+
distances = get_distances(adjacency, [0])
|
|
39
|
+
distances_ = np.array([0, 1, 3, 2, 2, 3, -1, -1, -1, -1])
|
|
40
|
+
self.assertTrue(all(distances == distances_))
|
|
41
|
+
distances = get_distances(adjacency, [0], transpose=True)
|
|
42
|
+
self.assertTrue(sum(distances < 0) == 9)
|
|
43
|
+
distances = get_distances(adjacency, [0, 5], transpose=True)
|
|
44
|
+
distances_ = np.array([0, 2, -1, -1, 1, 0, 1, -1, -1, -1])
|
|
45
|
+
self.assertTrue(all(distances == distances_))
|
|
46
|
+
|
|
47
|
+
biadjacency = test_bigraph()
|
|
48
|
+
distances_row, distances_col = get_distances(biadjacency, [0])
|
|
49
|
+
distances_row_, distances_col_ = np.array([0, -1, 2, -1, -1, -1]), np.array([3, 1, -1, -1, -1, -1, -1, -1])
|
|
50
|
+
self.assertTrue(all(distances_row == distances_row_))
|
|
51
|
+
self.assertTrue(all(distances_col == distances_col_))
|
|
52
|
+
|
|
53
|
+
adjacency = test_graph()
|
|
54
|
+
distances_row, distances_col = get_distances(adjacency, source_col=[0])
|
|
55
|
+
self.assertTrue(all(distances_row % 2))
|
|
56
|
+
self.assertTrue(all((distances_col + 1) % 2))
|
|
57
|
+
|
|
58
|
+
biadjacency = test_bigraph()
|
|
59
|
+
distances_row, distances_col = get_distances(biadjacency, source=0, source_col=[1, 2])
|
|
60
|
+
distances_row_, distances_col_ = np.array([0, 1, 1, -1, -1, -1]), np.array([2, 0, 0, 2, -1, -1, -1, -1])
|
|
61
|
+
self.assertTrue(all(distances_row == distances_row_))
|
|
62
|
+
self.assertTrue(all(distances_col == distances_col_))
|
|
@@ -7,19 +7,34 @@ import unittest
|
|
|
7
7
|
import numpy as np
|
|
8
8
|
|
|
9
9
|
from sknetwork.data import cyclic_digraph
|
|
10
|
-
from sknetwork.
|
|
10
|
+
from sknetwork.data.test_graphs import *
|
|
11
|
+
from sknetwork.path import breadth_first_search
|
|
11
12
|
|
|
12
13
|
|
|
13
14
|
class TestSearch(unittest.TestCase):
|
|
14
15
|
|
|
15
|
-
def setUp(self) -> None:
|
|
16
|
-
"""Load graph for tests."""
|
|
17
|
-
self.adjacency = cyclic_digraph(3).astype(bool)
|
|
18
|
-
|
|
19
16
|
def test_bfs(self):
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
17
|
+
adjacency = cyclic_digraph(3)
|
|
18
|
+
search = breadth_first_search(adjacency, 0)
|
|
19
|
+
search_ = np.arange(3)
|
|
20
|
+
self.assertTrue(all(search == search_))
|
|
21
|
+
|
|
22
|
+
adjacency = test_graph_empty()
|
|
23
|
+
search = breadth_first_search(adjacency, 0)
|
|
24
|
+
search_ = np.array([0])
|
|
25
|
+
self.assertTrue(all(search == search_))
|
|
26
|
+
|
|
27
|
+
adjacency = test_graph()
|
|
28
|
+
search = breadth_first_search(adjacency, 3)
|
|
29
|
+
search_ = np.array([3, 1, 2, 0, 4, 6, 8, 5, 7, 9])
|
|
30
|
+
self.assertTrue(all(search == search_))
|
|
31
|
+
|
|
32
|
+
adjacency = test_disconnected_graph()
|
|
33
|
+
search = breadth_first_search(adjacency, 2)
|
|
34
|
+
search_ = np.array([2, 3])
|
|
35
|
+
self.assertTrue(all(search == search_))
|
|
36
|
+
|
|
37
|
+
adjacency = test_digraph()
|
|
38
|
+
search = breadth_first_search(adjacency, 1)
|
|
39
|
+
search_ = {1, 3, 4, 2, 5}
|
|
40
|
+
self.assertTrue(set(list(search)) == search_)
|
|
@@ -3,43 +3,38 @@
|
|
|
3
3
|
""""tests for shortest_path.py"""
|
|
4
4
|
import unittest
|
|
5
5
|
|
|
6
|
-
from sknetwork.data import
|
|
7
|
-
from sknetwork.path.shortest_path import
|
|
8
|
-
import numpy as np
|
|
9
|
-
from scipy.sparse import csr_matrix
|
|
6
|
+
from sknetwork.data.test_graphs import *
|
|
7
|
+
from sknetwork.path.shortest_path import get_shortest_path
|
|
10
8
|
|
|
11
9
|
|
|
12
10
|
class TestShortestPath(unittest.TestCase):
|
|
13
11
|
|
|
14
|
-
def
|
|
15
|
-
adjacency =
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
self.
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
self.
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
adj = csr_matrix(adj)
|
|
44
|
-
with self.assertRaises(ValueError):
|
|
45
|
-
get_distances(adj, method='BF', n_jobs=1)
|
|
12
|
+
def test_path(self):
|
|
13
|
+
adjacency = test_graph_empty()
|
|
14
|
+
path = get_shortest_path(adjacency, 0)
|
|
15
|
+
self.assertEqual(path.nnz, 0)
|
|
16
|
+
|
|
17
|
+
adjacency = test_graph()
|
|
18
|
+
path = get_shortest_path(adjacency, 0)
|
|
19
|
+
self.assertEqual(path.nnz, 10)
|
|
20
|
+
path = get_shortest_path(adjacency, [0, 4, 6])
|
|
21
|
+
self.assertEqual(path.nnz, 10)
|
|
22
|
+
path = get_shortest_path(adjacency, np.arange(10))
|
|
23
|
+
self.assertEqual(path.nnz, 0)
|
|
24
|
+
path = get_shortest_path(adjacency, [0, 5])
|
|
25
|
+
self.assertEqual(path.nnz, 9)
|
|
26
|
+
|
|
27
|
+
adjacency = test_disconnected_graph()
|
|
28
|
+
path = get_shortest_path(adjacency, 4)
|
|
29
|
+
self.assertEqual(path.nnz, 5)
|
|
30
|
+
|
|
31
|
+
adjacency = test_digraph()
|
|
32
|
+
path = get_shortest_path(adjacency, 0)
|
|
33
|
+
self.assertEqual(path.nnz, 5)
|
|
34
|
+
|
|
35
|
+
biadjacency = test_bigraph()
|
|
36
|
+
path = get_shortest_path(biadjacency, 0)
|
|
37
|
+
self.assertEqual(path.nnz, 3)
|
|
38
|
+
self.assertTrue(path.shape[0] == np.sum(biadjacency.shape))
|
|
39
|
+
path = get_shortest_path(biadjacency, source_col=np.arange(biadjacency.shape[1]))
|
|
40
|
+
self.assertEqual(path.nnz, biadjacency.nnz)
|
sknetwork/ranking/__init__.py
CHANGED
|
@@ -2,7 +2,6 @@
|
|
|
2
2
|
from sknetwork.ranking.base import BaseRanking
|
|
3
3
|
from sknetwork.ranking.betweenness import Betweenness
|
|
4
4
|
from sknetwork.ranking.closeness import Closeness
|
|
5
|
-
from sknetwork.ranking.harmonic import Harmonic
|
|
6
5
|
from sknetwork.ranking.hits import HITS
|
|
7
6
|
from sknetwork.ranking.katz import Katz
|
|
8
7
|
from sknetwork.ranking.pagerank import PageRank
|
sknetwork/ranking/base.py
CHANGED
|
@@ -1,14 +1,14 @@
|
|
|
1
1
|
#!/usr/bin/env python3
|
|
2
2
|
# -*- coding: utf-8 -*-
|
|
3
3
|
"""
|
|
4
|
-
Created
|
|
4
|
+
Created in November 2019
|
|
5
5
|
@author: Nathan de Lara <nathan.delara@polytechnique.org>
|
|
6
6
|
"""
|
|
7
7
|
from abc import ABC
|
|
8
8
|
|
|
9
9
|
import numpy as np
|
|
10
10
|
|
|
11
|
-
from sknetwork.
|
|
11
|
+
from sknetwork.base import Algorithm
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
class BaseRanking(Algorithm, ABC):
|
|
@@ -26,20 +26,25 @@ class BaseRanking(Algorithm, ABC):
|
|
|
26
26
|
def __init__(self):
|
|
27
27
|
self.scores_ = None
|
|
28
28
|
|
|
29
|
-
def
|
|
30
|
-
"""
|
|
29
|
+
def predict(self, columns: bool = False) -> np.ndarray:
|
|
30
|
+
"""Return the scores predicted by the algorithm.
|
|
31
|
+
|
|
32
|
+
Parameters
|
|
33
|
+
----------
|
|
34
|
+
columns : bool
|
|
35
|
+
If ``True``, return the prediction for columns.
|
|
31
36
|
|
|
32
37
|
Returns
|
|
33
38
|
-------
|
|
34
39
|
scores : np.ndarray
|
|
35
40
|
Scores.
|
|
36
41
|
"""
|
|
37
|
-
|
|
42
|
+
if columns:
|
|
43
|
+
return self.scores_col_
|
|
38
44
|
return self.scores_
|
|
39
45
|
|
|
40
|
-
def
|
|
41
|
-
"""Fit algorithm to data and return the scores.
|
|
42
|
-
Same parameters as the ``fit`` method.
|
|
46
|
+
def fit_predict(self, *args, **kwargs) -> np.ndarray:
|
|
47
|
+
"""Fit algorithm to data and return the scores. Same parameters as the ``fit`` method.
|
|
43
48
|
|
|
44
49
|
Returns
|
|
45
50
|
-------
|
|
Binary file
|