scikit-network 0.28.3 (scikit_network-0.28.3-cp39-cp39-macosx_12_0_arm64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of scikit-network might be problematic.
- scikit_network-0.28.3.dist-info/AUTHORS.rst +41 -0
- scikit_network-0.28.3.dist-info/LICENSE +34 -0
- scikit_network-0.28.3.dist-info/METADATA +457 -0
- scikit_network-0.28.3.dist-info/RECORD +240 -0
- scikit_network-0.28.3.dist-info/WHEEL +5 -0
- scikit_network-0.28.3.dist-info/top_level.txt +1 -0
- sknetwork/__init__.py +21 -0
- sknetwork/classification/__init__.py +8 -0
- sknetwork/classification/base.py +84 -0
- sknetwork/classification/base_rank.py +143 -0
- sknetwork/classification/diffusion.py +134 -0
- sknetwork/classification/knn.py +162 -0
- sknetwork/classification/metrics.py +205 -0
- sknetwork/classification/pagerank.py +66 -0
- sknetwork/classification/propagation.py +152 -0
- sknetwork/classification/tests/__init__.py +1 -0
- sknetwork/classification/tests/test_API.py +35 -0
- sknetwork/classification/tests/test_diffusion.py +37 -0
- sknetwork/classification/tests/test_knn.py +24 -0
- sknetwork/classification/tests/test_metrics.py +53 -0
- sknetwork/classification/tests/test_pagerank.py +20 -0
- sknetwork/classification/tests/test_propagation.py +24 -0
- sknetwork/classification/vote.cpython-39-darwin.so +0 -0
- sknetwork/classification/vote.pyx +58 -0
- sknetwork/clustering/__init__.py +7 -0
- sknetwork/clustering/base.py +102 -0
- sknetwork/clustering/kmeans.py +142 -0
- sknetwork/clustering/louvain.py +255 -0
- sknetwork/clustering/louvain_core.cpython-39-darwin.so +0 -0
- sknetwork/clustering/louvain_core.pyx +134 -0
- sknetwork/clustering/metrics.py +91 -0
- sknetwork/clustering/postprocess.py +66 -0
- sknetwork/clustering/propagation_clustering.py +108 -0
- sknetwork/clustering/tests/__init__.py +1 -0
- sknetwork/clustering/tests/test_API.py +37 -0
- sknetwork/clustering/tests/test_kmeans.py +47 -0
- sknetwork/clustering/tests/test_louvain.py +104 -0
- sknetwork/clustering/tests/test_metrics.py +50 -0
- sknetwork/clustering/tests/test_post_processing.py +23 -0
- sknetwork/clustering/tests/test_postprocess.py +39 -0
- sknetwork/data/__init__.py +5 -0
- sknetwork/data/load.py +408 -0
- sknetwork/data/models.py +459 -0
- sknetwork/data/parse.py +621 -0
- sknetwork/data/test_graphs.py +84 -0
- sknetwork/data/tests/__init__.py +1 -0
- sknetwork/data/tests/test_API.py +30 -0
- sknetwork/data/tests/test_load.py +95 -0
- sknetwork/data/tests/test_models.py +52 -0
- sknetwork/data/tests/test_parse.py +253 -0
- sknetwork/data/tests/test_test_graphs.py +30 -0
- sknetwork/data/tests/test_toy_graphs.py +68 -0
- sknetwork/data/toy_graphs.py +619 -0
- sknetwork/embedding/__init__.py +10 -0
- sknetwork/embedding/base.py +90 -0
- sknetwork/embedding/force_atlas.py +197 -0
- sknetwork/embedding/louvain_embedding.py +174 -0
- sknetwork/embedding/louvain_hierarchy.py +142 -0
- sknetwork/embedding/metrics.py +66 -0
- sknetwork/embedding/random_projection.py +133 -0
- sknetwork/embedding/spectral.py +214 -0
- sknetwork/embedding/spring.py +198 -0
- sknetwork/embedding/svd.py +363 -0
- sknetwork/embedding/tests/__init__.py +1 -0
- sknetwork/embedding/tests/test_API.py +73 -0
- sknetwork/embedding/tests/test_force_atlas.py +35 -0
- sknetwork/embedding/tests/test_louvain_embedding.py +33 -0
- sknetwork/embedding/tests/test_louvain_hierarchy.py +19 -0
- sknetwork/embedding/tests/test_metrics.py +29 -0
- sknetwork/embedding/tests/test_random_projection.py +28 -0
- sknetwork/embedding/tests/test_spectral.py +84 -0
- sknetwork/embedding/tests/test_spring.py +50 -0
- sknetwork/embedding/tests/test_svd.py +37 -0
- sknetwork/flow/__init__.py +3 -0
- sknetwork/flow/flow.py +73 -0
- sknetwork/flow/tests/__init__.py +1 -0
- sknetwork/flow/tests/test_flow.py +17 -0
- sknetwork/flow/tests/test_utils.py +69 -0
- sknetwork/flow/utils.py +91 -0
- sknetwork/gnn/__init__.py +10 -0
- sknetwork/gnn/activation.py +117 -0
- sknetwork/gnn/base.py +155 -0
- sknetwork/gnn/base_activation.py +89 -0
- sknetwork/gnn/base_layer.py +109 -0
- sknetwork/gnn/gnn_classifier.py +381 -0
- sknetwork/gnn/layer.py +153 -0
- sknetwork/gnn/layers.py +127 -0
- sknetwork/gnn/loss.py +180 -0
- sknetwork/gnn/neighbor_sampler.py +65 -0
- sknetwork/gnn/optimizer.py +163 -0
- sknetwork/gnn/tests/__init__.py +1 -0
- sknetwork/gnn/tests/test_activation.py +56 -0
- sknetwork/gnn/tests/test_base.py +79 -0
- sknetwork/gnn/tests/test_base_layer.py +37 -0
- sknetwork/gnn/tests/test_gnn_classifier.py +192 -0
- sknetwork/gnn/tests/test_layers.py +80 -0
- sknetwork/gnn/tests/test_loss.py +33 -0
- sknetwork/gnn/tests/test_neigh_sampler.py +23 -0
- sknetwork/gnn/tests/test_optimizer.py +43 -0
- sknetwork/gnn/tests/test_utils.py +93 -0
- sknetwork/gnn/utils.py +219 -0
- sknetwork/hierarchy/__init__.py +7 -0
- sknetwork/hierarchy/base.py +69 -0
- sknetwork/hierarchy/louvain_hierarchy.py +264 -0
- sknetwork/hierarchy/metrics.py +234 -0
- sknetwork/hierarchy/paris.cpython-39-darwin.so +0 -0
- sknetwork/hierarchy/paris.pyx +317 -0
- sknetwork/hierarchy/postprocess.py +350 -0
- sknetwork/hierarchy/tests/__init__.py +1 -0
- sknetwork/hierarchy/tests/test_API.py +25 -0
- sknetwork/hierarchy/tests/test_algos.py +29 -0
- sknetwork/hierarchy/tests/test_metrics.py +62 -0
- sknetwork/hierarchy/tests/test_postprocess.py +57 -0
- sknetwork/hierarchy/tests/test_ward.py +25 -0
- sknetwork/hierarchy/ward.py +94 -0
- sknetwork/linalg/__init__.py +9 -0
- sknetwork/linalg/basics.py +37 -0
- sknetwork/linalg/diteration.cpython-39-darwin.so +0 -0
- sknetwork/linalg/diteration.pyx +49 -0
- sknetwork/linalg/eig_solver.py +93 -0
- sknetwork/linalg/laplacian.py +15 -0
- sknetwork/linalg/normalization.py +66 -0
- sknetwork/linalg/operators.py +225 -0
- sknetwork/linalg/polynome.py +76 -0
- sknetwork/linalg/ppr_solver.py +170 -0
- sknetwork/linalg/push.cpython-39-darwin.so +0 -0
- sknetwork/linalg/push.pyx +73 -0
- sknetwork/linalg/sparse_lowrank.py +142 -0
- sknetwork/linalg/svd_solver.py +91 -0
- sknetwork/linalg/tests/__init__.py +1 -0
- sknetwork/linalg/tests/test_eig.py +44 -0
- sknetwork/linalg/tests/test_laplacian.py +18 -0
- sknetwork/linalg/tests/test_normalization.py +38 -0
- sknetwork/linalg/tests/test_operators.py +70 -0
- sknetwork/linalg/tests/test_polynome.py +38 -0
- sknetwork/linalg/tests/test_ppr.py +50 -0
- sknetwork/linalg/tests/test_sparse_lowrank.py +61 -0
- sknetwork/linalg/tests/test_svd.py +38 -0
- sknetwork/linkpred/__init__.py +4 -0
- sknetwork/linkpred/base.py +80 -0
- sknetwork/linkpred/first_order.py +508 -0
- sknetwork/linkpred/first_order_core.cpython-39-darwin.so +0 -0
- sknetwork/linkpred/first_order_core.pyx +315 -0
- sknetwork/linkpred/postprocessing.py +98 -0
- sknetwork/linkpred/tests/__init__.py +1 -0
- sknetwork/linkpred/tests/test_API.py +49 -0
- sknetwork/linkpred/tests/test_postprocessing.py +21 -0
- sknetwork/path/__init__.py +4 -0
- sknetwork/path/metrics.py +148 -0
- sknetwork/path/search.py +65 -0
- sknetwork/path/shortest_path.py +186 -0
- sknetwork/path/tests/__init__.py +1 -0
- sknetwork/path/tests/test_metrics.py +29 -0
- sknetwork/path/tests/test_search.py +25 -0
- sknetwork/path/tests/test_shortest_path.py +45 -0
- sknetwork/ranking/__init__.py +9 -0
- sknetwork/ranking/base.py +56 -0
- sknetwork/ranking/betweenness.cpython-39-darwin.so +0 -0
- sknetwork/ranking/betweenness.pyx +99 -0
- sknetwork/ranking/closeness.py +95 -0
- sknetwork/ranking/harmonic.py +82 -0
- sknetwork/ranking/hits.py +94 -0
- sknetwork/ranking/katz.py +81 -0
- sknetwork/ranking/pagerank.py +107 -0
- sknetwork/ranking/postprocess.py +25 -0
- sknetwork/ranking/tests/__init__.py +1 -0
- sknetwork/ranking/tests/test_API.py +34 -0
- sknetwork/ranking/tests/test_betweenness.py +38 -0
- sknetwork/ranking/tests/test_closeness.py +34 -0
- sknetwork/ranking/tests/test_hits.py +20 -0
- sknetwork/ranking/tests/test_pagerank.py +69 -0
- sknetwork/regression/__init__.py +4 -0
- sknetwork/regression/base.py +56 -0
- sknetwork/regression/diffusion.py +190 -0
- sknetwork/regression/tests/__init__.py +1 -0
- sknetwork/regression/tests/test_API.py +34 -0
- sknetwork/regression/tests/test_diffusion.py +48 -0
- sknetwork/sknetwork.py +3 -0
- sknetwork/topology/__init__.py +9 -0
- sknetwork/topology/dag.py +74 -0
- sknetwork/topology/dag_core.cpython-39-darwin.so +0 -0
- sknetwork/topology/dag_core.pyx +38 -0
- sknetwork/topology/kcliques.cpython-39-darwin.so +0 -0
- sknetwork/topology/kcliques.pyx +193 -0
- sknetwork/topology/kcore.cpython-39-darwin.so +0 -0
- sknetwork/topology/kcore.pyx +120 -0
- sknetwork/topology/structure.py +234 -0
- sknetwork/topology/tests/__init__.py +1 -0
- sknetwork/topology/tests/test_cliques.py +28 -0
- sknetwork/topology/tests/test_cores.py +21 -0
- sknetwork/topology/tests/test_dag.py +26 -0
- sknetwork/topology/tests/test_structure.py +99 -0
- sknetwork/topology/tests/test_triangles.py +42 -0
- sknetwork/topology/tests/test_wl_coloring.py +49 -0
- sknetwork/topology/tests/test_wl_kernel.py +31 -0
- sknetwork/topology/triangles.cpython-39-darwin.so +0 -0
- sknetwork/topology/triangles.pyx +166 -0
- sknetwork/topology/weisfeiler_lehman.py +163 -0
- sknetwork/topology/weisfeiler_lehman_core.cpython-39-darwin.so +0 -0
- sknetwork/topology/weisfeiler_lehman_core.pyx +116 -0
- sknetwork/utils/__init__.py +40 -0
- sknetwork/utils/base.py +35 -0
- sknetwork/utils/check.py +354 -0
- sknetwork/utils/co_neighbor.py +71 -0
- sknetwork/utils/format.py +219 -0
- sknetwork/utils/kmeans.py +89 -0
- sknetwork/utils/knn.py +166 -0
- sknetwork/utils/knn1d.cpython-39-darwin.so +0 -0
- sknetwork/utils/knn1d.pyx +80 -0
- sknetwork/utils/membership.py +82 -0
- sknetwork/utils/minheap.cpython-39-darwin.so +0 -0
- sknetwork/utils/minheap.pxd +22 -0
- sknetwork/utils/minheap.pyx +111 -0
- sknetwork/utils/neighbors.py +115 -0
- sknetwork/utils/seeds.py +75 -0
- sknetwork/utils/simplex.py +140 -0
- sknetwork/utils/tests/__init__.py +1 -0
- sknetwork/utils/tests/test_base.py +28 -0
- sknetwork/utils/tests/test_bunch.py +16 -0
- sknetwork/utils/tests/test_check.py +190 -0
- sknetwork/utils/tests/test_co_neighbor.py +43 -0
- sknetwork/utils/tests/test_format.py +61 -0
- sknetwork/utils/tests/test_kmeans.py +21 -0
- sknetwork/utils/tests/test_knn.py +32 -0
- sknetwork/utils/tests/test_membership.py +24 -0
- sknetwork/utils/tests/test_neighbors.py +41 -0
- sknetwork/utils/tests/test_projection_simplex.py +33 -0
- sknetwork/utils/tests/test_seeds.py +67 -0
- sknetwork/utils/tests/test_verbose.py +15 -0
- sknetwork/utils/tests/test_ward.py +20 -0
- sknetwork/utils/timeout.py +38 -0
- sknetwork/utils/verbose.py +37 -0
- sknetwork/utils/ward.py +60 -0
- sknetwork/visualization/__init__.py +4 -0
- sknetwork/visualization/colors.py +34 -0
- sknetwork/visualization/dendrograms.py +229 -0
- sknetwork/visualization/graphs.py +819 -0
- sknetwork/visualization/tests/__init__.py +1 -0
- sknetwork/visualization/tests/test_dendrograms.py +53 -0
- sknetwork/visualization/tests/test_graphs.py +167 -0

--- /dev/null
+++ b/sknetwork/classification/knn.py
@@ -0,0 +1,162 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created on November 2019
+@author: Nathan de Lara <nathan.delara@polytechnique.org>
+@author: Thomas Bonald <tbonald@enst.fr>
+"""
+from typing import Optional, Union
+
+import numpy as np
+from scipy import sparse
+from scipy.spatial import cKDTree
+
+from sknetwork.classification.base import BaseClassifier
+from sknetwork.embedding.base import BaseEmbedding
+from sknetwork.embedding.svd import GSVD
+from sknetwork.linalg.normalization import normalize
+from sknetwork.utils.check import check_n_neighbors, check_n_jobs
+from sknetwork.utils.format import get_adjacency_seeds
+
+
+class KNN(BaseClassifier):
+    """Node classification by K-nearest neighbors in the embedding space.
+
+    For bigraphs, classify rows only (see ``BiKNN`` for joint classification of rows and columns).
+
+    Parameters
+    ----------
+    embedding_method :
+        Which algorithm to use to project the nodes in vector space. Default is ``GSVD``.
+    n_neighbors :
+        Number of nearest neighbors to consider.
+    factor_distance :
+        Power weighting factor :math:`\\alpha` applied to the distance to each neighbor.
+        Neighbor at distance :math:``d`` has weight :math:`1 / d^\\alpha`. Default is 2.
+    leaf_size :
+        Leaf size passed to KDTree.
+    p :
+        Which Minkowski p-norm to use. Default is 2 (Euclidean distance).
+    tol_nn :
+        Tolerance in nearest neighbors search; the k-th returned value is guaranteed to be no further
+        than ``1 + tol_nn`` times the distance to the actual k-th nearest neighbor.
+    n_jobs :
+        Number of jobs to schedule for parallel processing. If -1 is given all processors are used.
+
+    Attributes
+    ----------
+    labels_ : np.ndarray, shape (n_labels,)
+        Label of each node.
+    membership_ : sparse.csr_matrix, shape (n_row, n_labels)
+        Membership matrix.
+    labels_row_ : np.ndarray
+        Labels of rows, for bipartite graphs.
+    labels_col_ : np.ndarray
+        Labels of columns, for bipartite graphs.
+    membership_row_ : sparse.csr_matrix, shape (n_row, n_labels)
+        Membership matrix of rows, for bipartite graphs.
+    membership_col_ : sparse.csr_matrix, shape (n_col, n_labels)
+        Membership matrix of columns, for bipartite graphs.
+    Example
+    -------
+    >>> from sknetwork.classification import KNN
+    >>> from sknetwork.embedding import GSVD
+    >>> from sknetwork.data import karate_club
+    >>> knn = KNN(GSVD(3), n_neighbors=1)
+    >>> graph = karate_club(metadata=True)
+    >>> adjacency = graph.adjacency
+    >>> labels_true = graph.labels
+    >>> seeds = {0: labels_true[0], 33: labels_true[33]}
+    >>> labels_pred = knn.fit_predict(adjacency, seeds)
+    >>> np.round(np.mean(labels_pred == labels_true), 2)
+    0.97
+    """
+    def __init__(self, embedding_method: BaseEmbedding = GSVD(10), n_neighbors: int = 5,
+                 factor_distance: float = 2, leaf_size: int = 16, p: float = 2, tol_nn: float = 0.01,
+                 n_jobs: Optional[int] = None):
+        super(KNN, self).__init__()
+
+        self.embedding_method = embedding_method
+        self.n_neighbors = n_neighbors
+        self.factor_distance = factor_distance
+        self.leaf_size = leaf_size
+        self.p = p
+        self.tol_nn = tol_nn
+        self.n_jobs = check_n_jobs(n_jobs)
+        if self.n_jobs is None:
+            self.n_jobs = -1
+        self.bipartite = None
+
+    @staticmethod
+    def _instantiate_vars(seeds: Union[np.ndarray, dict]):
+        labels = seeds.astype(int)
+        index_seed = np.argwhere(labels >= 0).ravel()
+        index_remain = np.argwhere(labels < 0).ravel()
+        labels_seed = labels[index_seed]
+        return index_seed, index_remain, labels_seed
+
+    def _fit_core(self, n, labels_seed, embedding, index_seed, index_remain):
+        n_seeds = len(labels_seed)
+        embedding_seed = embedding[index_seed]
+        embedding_remain = embedding[index_remain]
+        n_neighbors = check_n_neighbors(self.n_neighbors, n_seeds)
+        tree = cKDTree(embedding_seed, self.leaf_size)
+        distances, neighbors = tree.query(embedding_remain, n_neighbors, self.tol_nn, self.p, workers=self.n_jobs)
+
+        if n_neighbors == 1:
+            distances = distances[:, np.newaxis]
+            neighbors = neighbors[:, np.newaxis]
+
+        labels_neighbor = labels_seed[neighbors]
+        index = (np.min(distances, axis=1) == 0)
+        weights_neighbor = np.zeros_like(distances).astype(float)
+        # take all seeds at distance zero, if any
+        weights_neighbor[index] = (distances[index] == 0).astype(float)
+        # assign weights with respect to distances for other
+        weights_neighbor[~index] = 1 / np.power(distances[~index], self.factor_distance)
+
+        # form the corresponding matrix
+        row = list(np.repeat(index_remain, n_neighbors))
+        col = list(labels_neighbor.ravel())
+        data = list(weights_neighbor.ravel())
+
+        row += list(index_seed)
+        col += list(labels_seed)
+        data += list(np.ones_like(index_seed))
+
+        membership = normalize(sparse.csr_matrix((data, (row, col)), shape=(n, np.max(labels_seed) + 1)))
+        membership_dense = membership.toarray()
+        labels = np.argmax(membership_dense, axis=1)
+
+        return membership, labels
+
+    def fit(self, input_matrix: Union[sparse.csr_matrix, np.ndarray], seeds: Union[np.ndarray, dict] = None,
+            seeds_row: Union[np.ndarray, dict] = None, seeds_col: Union[np.ndarray, dict] = None) -> 'KNN':
+        """Node classification by k-nearest neighbors in the embedding space.
+
+        Parameters
+        ----------
+        input_matrix :
+            Adjacency matrix or biadjacency matrix of the graph.
+        seeds :
+            Seed nodes. Can be a dict {node: label} or an array where "-1" means no label.
+        seeds_row, seeds_col :
+            Seeds of rows and columns (for bipartite graphs).
+
+        Returns
+        -------
+        self: :class:`KNN`
+        """
+        adjacency, seeds, self.bipartite = get_adjacency_seeds(input_matrix, seeds=seeds, seeds_row=seeds_row,
+                                                               seeds_col=seeds_col)
+        index_seed, index_remain, labels_seed = self._instantiate_vars(seeds)
+        embedding = self.embedding_method.fit_transform(adjacency)
+        membership, labels = self._fit_core(adjacency.shape[0], labels_seed, embedding, index_seed, index_remain)
+
+        self.membership_ = membership
+        self.labels_ = labels
+
+        if self.bipartite:
+            self._split_vars(input_matrix.shape)
+
+        return self
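
For context, here is a minimal usage sketch of the KNN classifier added in this file, mirroring its
docstring example (it assumes scikit-network 0.28.3 and numpy are installed; the 0.97 accuracy figure
is the one quoted in the docstring, not re-run here):

    import numpy as np
    from sknetwork.classification import KNN
    from sknetwork.data import karate_club
    from sknetwork.embedding import GSVD

    graph = karate_club(metadata=True)   # toy graph shipped with the package
    adjacency = graph.adjacency          # sparse adjacency matrix
    labels_true = graph.labels           # ground-truth communities

    # Label only two seed nodes; KNN infers the rest from nearest neighbors in the GSVD embedding.
    seeds = {0: labels_true[0], 33: labels_true[33]}
    knn = KNN(GSVD(3), n_neighbors=1)
    labels_pred = knn.fit_predict(adjacency, seeds)
    print(np.round(np.mean(labels_pred == labels_true), 2))  # 0.97 according to the docstring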

--- /dev/null
+++ b/sknetwork/classification/metrics.py
@@ -0,0 +1,205 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created in July 2020
+@author: Nathan de Lara <nathan.delara@polytechnique.org>
+@author: Thomas Bonald <thomas.bonald@telecom-paris.fr>
+"""
+from typing import Union, Tuple
+
+import numpy as np
+from scipy import sparse
+
+from sknetwork.utils.check import check_vector_format
+
+
+def get_accuracy_score(labels_true: np.ndarray, labels_pred: np.ndarray) -> float:
+    """Return the proportion of correctly labeled samples.
+    Negative labels ignored.
+
+    Parameters
+    ----------
+    labels_true : np.ndarray
+        True labels.
+    labels_pred : np.ndarray
+        Predicted labels
+
+    Returns
+    -------
+    accuracy : float
+        A score between 0 and 1.
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> labels_true = np.array([0, 0, 1, 1])
+    >>> labels_pred = np.array([0, 0, 0, 1])
+    >>> get_accuracy_score(labels_true, labels_pred)
+    0.75
+    """
+    check_vector_format(labels_true, labels_pred)
+    mask = (labels_true >= 0) & (labels_pred >= 0)
+    if np.sum(mask):
+        return np.mean(labels_true[mask] == labels_pred[mask])
+    else:
+        raise ValueError('No sample with both true non-negative label and predicted non-negative label.')
+
+
+def get_confusion_matrix(labels_true: np.ndarray, labels_pred: np.ndarray) -> sparse.csr_matrix:
+    """Return the confusion matrix in sparse format (true labels on rows, predicted labels on columns).
+    Negative labels ignored.
+
+    Parameters
+    ----------
+    labels_true : np.ndarray
+        True labels.
+    labels_pred : np.ndarray
+        Predicted labels
+
+    Returns
+    -------
+    confusion matrix : sparse.csr_matrix
+        Confusion matrix.
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> labels_true = np.array([0, 0, 1, 1])
+    >>> labels_pred = np.array([0, 0, 0, 1])
+    >>> get_confusion_matrix(labels_true, labels_pred).toarray()
+    array([[2, 0],
+           [1, 1]])
+    """
+    check_vector_format(labels_true, labels_pred)
+    mask = (labels_true >= 0) & (labels_pred >= 0)
+    if np.sum(mask):
+        n_labels = max(max(labels_true), max(labels_pred)) + 1
+        row = labels_true[mask]
+        col = labels_pred[mask]
+        data = np.ones(np.sum(mask), dtype=int)
+        return sparse.csr_matrix((data, (row, col)), shape=(n_labels, n_labels))
+    else:
+        raise ValueError('No sample with both true non-negative label and predicted non-negative label.')
+
+
+def get_f1_score(labels_true: np.ndarray, labels_pred: np.ndarray, return_precision_recall: bool = False) \
+        -> Union[float, Tuple[float, float, float]]:
+    """Return the f1 score of binary classification.
+    Negative labels ignored.
+
+    Parameters
+    ----------
+    labels_true : np.ndarray
+        True labels.
+    labels_pred : np.ndarray
+        Predicted labels
+    return_precision_recall : bool
+        If ``True``, also return precision and recall.
+
+    Returns
+    -------
+    score, [precision, recall] : np.ndarray
+        F1 score (between 0 and 1). Optionally, also return precision and recall.
+    Examples
+    --------
+    >>> import numpy as np
+    >>> labels_true = np.array([0, 0, 1, 1])
+    >>> labels_pred = np.array([0, 0, 0, 1])
+    >>> np.round(get_f1_score(labels_true, labels_pred), 2)
+    0.67
+    """
+    values = set(labels_true[labels_true >= 0]) | set(labels_pred[labels_pred >= 0])
+    if values != {0, 1}:
+        raise ValueError('Labels must be binary. '
+                         'Check get_f1_scores or get_average_f1_score for multi-label classification.')
+    if return_precision_recall:
+        f1_scores, precisions, recalls = get_f1_scores(labels_true, labels_pred, True)
+        return f1_scores[1], precisions[1], recalls[1]
+    else:
+        f1_scores = get_f1_scores(labels_true, labels_pred, False)
+        return f1_scores[1]
+
+
+def get_f1_scores(labels_true: np.ndarray, labels_pred: np.ndarray, return_precision_recall: bool = False) \
+        -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray, np.ndarray]]:
+    """Return the f1 scores of multi-label classification (one per label).
+    Negative labels ignored.
+
+    Parameters
+    ----------
+    labels_true : np.ndarray
+        True labels.
+    labels_pred : np.ndarray
+        Predicted labels
+    return_precision_recall : bool
+        If ``True``, also return precisions and recalls.
+
+    Returns
+    -------
+    scores, [precisions, recalls] : np.ndarray
+        F1 scores (between 0 and 1). Optionally, also return F1 precisions and recalls.
+    Examples
+    --------
+    >>> import numpy as np
+    >>> labels_true = np.array([0, 0, 1, 1])
+    >>> labels_pred = np.array([0, 0, 0, 1])
+    >>> np.round(get_f1_scores(labels_true, labels_pred), 2)
+    array([0.8 , 0.67])
+    """
+    confusion = get_confusion_matrix(labels_true, labels_pred)
+    n_labels = confusion.shape[0]
+    counts_correct = confusion.diagonal()
+    counts_true = confusion.dot(np.ones(n_labels))
+    counts_pred = confusion.T.dot(np.ones(n_labels))
+    mask = counts_true > 0
+    recalls = np.zeros(n_labels)
+    recalls[mask] = counts_correct[mask] / counts_true[mask]
+    precisions = np.zeros(n_labels)
+    mask = counts_pred > 0
+    precisions[mask] = counts_correct[mask] / counts_pred[mask]
+    f1_scores = np.zeros(n_labels)
+    mask = (counts_true > 0) & (counts_pred > 0)
+    f1_scores[mask] = 2 / (1 / precisions[mask] + 1 / recalls[mask])
+    if return_precision_recall:
+        return f1_scores, precisions, recalls
+    else:
+        return f1_scores
+
+
+def get_average_f1_score(labels_true: np.ndarray, labels_pred: np.ndarray, average: str = 'macro') -> float:
+    """Return the average f1 score of multi-label classification.
+    Negative labels ignored.
+
+    Parameters
+    ----------
+    labels_true : np.ndarray
+        True labels.
+    labels_pred : np.ndarray
+        Predicted labels
+    average : str
+        Averaging method. Can be either ``'macro'`` (default), ``'micro'`` or ``'weighted'``.
+
+    Returns
+    -------
+    score : float
+        Average F1 score (between 0 and 1).
+    Examples
+    --------
+    >>> import numpy as np
+    >>> labels_true = np.array([0, 0, 1, 1])
+    >>> labels_pred = np.array([0, 0, 0, 1])
+    >>> np.round(get_average_f1_score(labels_true, labels_pred), 2)
+    0.73
+    """
+    if average == 'micro':
+        # micro averaging = accuracy
+        return get_accuracy_score(labels_true, labels_pred)
+    else:
+        f1_scores = get_f1_scores(labels_true, labels_pred)
+        if average == 'macro':
+            return np.mean(f1_scores)
+        elif average == 'weighted':
+            labels_unique, counts = np.unique(labels_true[labels_true >= 0], return_counts=True)
+            return np.sum(f1_scores[labels_unique] * counts) / np.sum(counts)
+        else:
+            raise ValueError('Check the ``average`` parameter.')
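
For context, the doctest snippets of the metrics added above can be combined into one short script
(a sketch only; it imports from the module path shown in this diff, since whether these names are
also re-exported by sknetwork.classification depends on an __init__ not displayed here):

    import numpy as np
    from sknetwork.classification.metrics import (get_accuracy_score, get_average_f1_score,
                                                  get_confusion_matrix, get_f1_score)

    labels_true = np.array([0, 0, 1, 1])
    labels_pred = np.array([0, 0, 0, 1])   # one sample of class 1 is misclassified

    print(get_accuracy_score(labels_true, labels_pred))                 # 0.75
    print(get_confusion_matrix(labels_true, labels_pred).toarray())     # [[2 0], [1 1]]
    print(np.round(get_f1_score(labels_true, labels_pred), 2))          # 0.67 (binary F1 for label 1)
    print(np.round(get_average_f1_score(labels_true, labels_pred), 2))  # 0.73 (macro average)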

--- /dev/null
+++ b/sknetwork/classification/pagerank.py
@@ -0,0 +1,66 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created on March 2020
+@author: Nathan de Lara <nathan.delara@polytechnique.org>
+"""
+from typing import Optional
+
+import numpy as np
+
+from sknetwork.classification.base_rank import RankClassifier
+from sknetwork.ranking.pagerank import PageRank
+
+
+class PageRankClassifier(RankClassifier):
+    """Node classification by multiple personalized PageRanks.
+
+    Parameters
+    ----------
+    damping_factor:
+        Probability to continue the random walk.
+    solver : :obj:`str`
+        Which solver to use: 'piteration', 'diteration', 'bicgstab', 'lanczos'.
+    n_iter : int
+        Number of iterations for some solvers such as ``'piteration'`` or ``'diteration'``.
+    tol : float
+        Tolerance for the convergence of some solvers such as ``'bicgstab'`` or ``'lanczos'``.
+
+    Attributes
+    ----------
+    labels_ : np.ndarray, shape (n_labels,)
+        Label of each node.
+    membership_ : sparse.csr_matrix, shape (n_row, n_labels)
+        Membership matrix.
+    labels_row_ : np.ndarray
+        Labels of rows, for bipartite graphs.
+    labels_col_ : np.ndarray
+        Labels of columns, for bipartite graphs.
+    membership_row_ : sparse.csr_matrix, shape (n_row, n_labels)
+        Membership matrix of rows, for bipartite graphs.
+    membership_col_ : sparse.csr_matrix, shape (n_col, n_labels)
+        Membership matrix of columns, for bipartite graphs.
+
+    Example
+    -------
+    >>> from sknetwork.classification import PageRankClassifier
+    >>> from sknetwork.data import karate_club
+    >>> pagerank = PageRankClassifier()
+    >>> graph = karate_club(metadata=True)
+    >>> adjacency = graph.adjacency
+    >>> labels_true = graph.labels
+    >>> seeds = {0: labels_true[0], 33: labels_true[33]}
+    >>> labels_pred = pagerank.fit_predict(adjacency, seeds)
+    >>> np.round(np.mean(labels_pred == labels_true), 2)
+    0.97
+
+    References
+    ----------
+    Lin, F., & Cohen, W. W. (2010). `Semi-supervised classification of network data using very few labels.
+    <https://lti.cs.cmu.edu/sites/default/files/research/reports/2009/cmulti09017.pdf>`_
+    In IEEE International Conference on Advances in Social Networks Analysis and Mining.
+    """
+    def __init__(self, damping_factor: float = 0.85, solver: str = 'piteration', n_iter: int = 10, tol: float = 0.,
+                 n_jobs: Optional[int] = None, verbose: bool = False):
+        algorithm = PageRank(damping_factor, solver, n_iter, tol)
+        super(PageRankClassifier, self).__init__(algorithm, n_jobs, verbose)
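
For context, a minimal usage sketch of PageRankClassifier, mirroring its docstring example (assumes
scikit-network 0.28.3 and numpy; the 0.97 figure is quoted from the docstring, not re-run here):

    import numpy as np
    from sknetwork.classification import PageRankClassifier
    from sknetwork.data import karate_club

    graph = karate_club(metadata=True)
    adjacency = graph.adjacency
    labels_true = graph.labels

    # One personalized PageRank per label (cf. the class docstring); each node gets the best-scoring label.
    seeds = {0: labels_true[0], 33: labels_true[33]}
    clf = PageRankClassifier()   # defaults shown above: damping_factor=0.85, solver='piteration', n_iter=10
    labels_pred = clf.fit_predict(adjacency, seeds)
    print(np.round(np.mean(labels_pred == labels_true), 2))  # 0.97 according to the docstring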

--- /dev/null
+++ b/sknetwork/classification/propagation.py
@@ -0,0 +1,152 @@
+#!/usr/bin/env python3
+# coding: utf-8
+"""
+Created on April 2020
+@author: Thomas Bonald <tbonald@enst.fr>
+"""
+
+from typing import Union
+
+import numpy as np
+from scipy import sparse
+
+from sknetwork.classification.base import BaseClassifier
+from sknetwork.classification.vote import vote_update
+from sknetwork.linalg.normalization import normalize
+from sknetwork.utils.format import get_adjacency_seeds
+from sknetwork.utils.membership import get_membership
+
+
+class Propagation(BaseClassifier):
+    """Node classification by label propagation.
+
+    Parameters
+    ----------
+    n_iter : float
+        Maximum number of iterations (-1 for infinity).
+    node_order : str
+        * `'random'`: node labels are updated in random order.
+        * `'increasing'`: node labels are updated by increasing order of (in-)weight.
+        * `'decreasing'`: node labels are updated by decreasing order of (in-)weight.
+        * Otherwise, node labels are updated by index order.
+    weighted : bool
+        If ``True``, the vote of each neighbor is proportional to the edge weight.
+        Otherwise, all votes have weight 1.
+
+    Attributes
+    ----------
+    labels_ : np.ndarray, shape (n_labels,)
+        Label of each node.
+    membership_ : sparse.csr_matrix, shape (n_row, n_labels)
+        Membership matrix.
+    labels_row_ : np.ndarray
+        Labels of rows, for bipartite graphs.
+    labels_col_ : np.ndarray
+        Labels of columns, for bipartite graphs.
+    membership_row_ : sparse.csr_matrix, shape (n_row, n_labels)
+        Membership matrix of rows, for bipartite graphs.
+    membership_col_ : sparse.csr_matrix, shape (n_col, n_labels)
+        Membership matrix of columns, for bipartite graphs.
+
+    Example
+    -------
+    >>> from sknetwork.classification import Propagation
+    >>> from sknetwork.data import karate_club
+    >>> propagation = Propagation()
+    >>> graph = karate_club(metadata=True)
+    >>> adjacency = graph.adjacency
+    >>> labels_true = graph.labels
+    >>> seeds = {0: labels_true[0], 33: labels_true[33]}
+    >>> labels_pred = propagation.fit_predict(adjacency, seeds)
+    >>> np.round(np.mean(labels_pred == labels_true), 2)
+    0.94
+
+    References
+    ----------
+    Raghavan, U. N., Albert, R., & Kumara, S. (2007).
+    `Near linear time algorithm to detect community structures in large-scale networks.
+    <https://arxiv.org/pdf/0709.2938.pdf>`_
+    Physical review E, 76(3), 036106.
+    """
+    def __init__(self, n_iter: float = -1, node_order: str = None, weighted: bool = True):
+        super(Propagation, self).__init__()
+
+        if n_iter < 0:
+            self.n_iter = np.inf
+        else:
+            self.n_iter = n_iter
+        self.node_order = node_order
+        self.weighted = weighted
+        self.bipartite = None
+
+    @staticmethod
+    def _instantiate_vars(seeds: np.ndarray):
+        """Instantiate variables for label propagation."""
+        n = len(seeds)
+        if len(set(seeds)) == n:
+            index_seed = np.arange(n)
+            index_remain = np.arange(n)
+            labels = seeds
+        else:
+            index_seed = np.argwhere(seeds >= 0).ravel()
+            index_remain = np.argwhere(seeds < 0).ravel()
+            labels = seeds[index_seed]
+        return index_seed.astype(np.int32), index_remain.astype(np.int32), labels.astype(np.int32)
+
+    def fit(self, input_matrix: Union[sparse.csr_matrix, np.ndarray], seeds: Union[np.ndarray, dict] = None,
+            seeds_row: Union[np.ndarray, dict] = None, seeds_col: Union[np.ndarray, dict] = None) -> 'Propagation':
+        """Node classification by label propagation.
+
+        Parameters
+        ----------
+        input_matrix :
+            Adjacency matrix or biadjacency matrix of the graph.
+        seeds :
+            Seed nodes. Can be a dict {node: label} or an array where "-1" means no label.
+        seeds_row, seeds_col :
+            Seeds of rows and columns (for bipartite graphs).
+        Returns
+        -------
+        self: :class:`Propagation`
+        """
+        adjacency, seeds, self.bipartite = get_adjacency_seeds(input_matrix, seeds=seeds, seeds_row=seeds_row,
+                                                               seeds_col=seeds_col, which='labels')
+        n = adjacency.shape[0]
+        index_seed, index_remain, labels_seed = self._instantiate_vars(seeds)
+
+        if self.node_order == 'random':
+            np.random.shuffle(index_remain)
+        elif self.node_order == 'decreasing':
+            index = np.argsort(-adjacency.T.dot(np.ones(n))).astype(np.int32)
+            index_remain = index[index_remain]
+        elif self.node_order == 'increasing':
+            index = np.argsort(adjacency.T.dot(np.ones(n))).astype(np.int32)
+            index_remain = index[index_remain]
+
+        labels = -np.ones(n, dtype=np.int32)
+        labels[index_seed] = labels_seed
+        labels_remain = np.zeros_like(index_remain, dtype=np.int32)
+
+        indptr = adjacency.indptr.astype(np.int32)
+        indices = adjacency.indices.astype(np.int32)
+        if self.weighted:
+            data = adjacency.data.astype(np.float32)
+        else:
+            data = np.ones(n, dtype=np.float32)
+
+        t = 0
+        while t < self.n_iter and not np.array_equal(labels_remain, labels[index_remain]):
+            t += 1
+            labels_remain = labels[index_remain].copy()
+            labels = np.asarray(vote_update(indptr, indices, data, labels, index_remain))
+
+        membership = get_membership(labels)
+        membership = normalize(adjacency.dot(membership))
+
+        self.labels_ = labels
+        self.membership_ = membership
+
+        if self.bipartite:
+            self._split_vars(input_matrix.shape)
+
+        return self
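
For context, a minimal usage sketch of Propagation, mirroring its docstring example (assumes
scikit-network 0.28.3 and numpy; the 0.94 figure is quoted from the docstring, not re-run here):

    import numpy as np
    from sknetwork.classification import Propagation
    from sknetwork.data import karate_club

    graph = karate_club(metadata=True)
    adjacency = graph.adjacency
    labels_true = graph.labels

    # Seed two nodes; labels then propagate by weighted neighbor votes until they stop changing.
    seeds = {0: labels_true[0], 33: labels_true[33]}
    propagation = Propagation()   # n_iter=-1 means: iterate until convergence
    labels_pred = propagation.fit_predict(adjacency, seeds)
    print(np.round(np.mean(labels_pred == labels_true), 2))  # 0.94 according to the docstring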

--- /dev/null
+++ b/sknetwork/classification/tests/__init__.py
@@ -0,0 +1 @@
+"""tests for classification"""

--- /dev/null
+++ b/sknetwork/classification/tests/test_API.py
@@ -0,0 +1,35 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""Tests for classification API"""
+
+import unittest
+
+from sknetwork.classification import *
+from sknetwork.data.test_graphs import *
+from sknetwork.embedding import LouvainEmbedding
+
+
+class TestClassificationAPI(unittest.TestCase):
+
+    def test_undirected(self):
+        for adjacency in [test_graph(), test_digraph()]:
+            n = adjacency.shape[0]
+            seeds_array = -np.ones(n)
+            seeds_array[:2] = np.arange(2)
+            seeds_dict = {0: 0, 1: 1}
+
+            classifiers = [PageRankClassifier(), DiffusionClassifier(),
+                           KNN(embedding_method=LouvainEmbedding(), n_neighbors=1), Propagation()]
+
+            with self.assertRaises(ValueError):
+                classifiers[0].score(0)
+
+            for algo in classifiers:
+                labels1 = algo.fit_predict(adjacency, seeds_array)
+                labels2 = algo.fit_predict(adjacency, seeds_dict)
+                scores = algo.score(0)
+                self.assertTrue((labels1 == labels2).all())
+                self.assertEqual(labels2.shape, (n,))
+                membership = algo.fit_transform(adjacency, seeds_array)
+                self.assertTupleEqual(membership.shape, (n, 2))
+                self.assertEqual(scores.shape, (n,))

--- /dev/null
+++ b/sknetwork/classification/tests/test_diffusion.py
@@ -0,0 +1,37 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""Tests for DiffusionClassifier"""
+
+import unittest
+
+from sknetwork.classification import DiffusionClassifier
+from sknetwork.data.test_graphs import *
+
+
+class TestDiffusionClassifier(unittest.TestCase):
+
+    def test_graph(self):
+        adjacency = test_graph()
+        seeds = {0: 0, 1: 1}
+        algo = DiffusionClassifier()
+        algo.fit(adjacency, seeds=seeds)
+        self.assertTrue(len(algo.labels_) == adjacency.shape[0])
+        adjacency = test_digraph()
+        algo = DiffusionClassifier(centering=False)
+        algo.fit(adjacency, seeds=seeds)
+        self.assertTrue(len(algo.labels_) == adjacency.shape[0])
+        with self.assertRaises(ValueError):
+            DiffusionClassifier(n_iter=0)
+        algo = DiffusionClassifier(centering=False, threshold=1)
+        algo.fit(adjacency, seeds=seeds)
+        self.assertTrue(max(algo.labels_) == -1)
+
+    def test_bipartite(self):
+        biadjacency = test_bigraph()
+        n_row, n_col = biadjacency.shape
+        seeds_row = {0: 0, 1: 1}
+        seeds_col = {5: 1}
+        algo = DiffusionClassifier()
+        algo.fit(biadjacency, seeds_row=seeds_row, seeds_col=seeds_col)
+        self.assertTrue(len(algo.labels_row_) == n_row)
+        self.assertTrue(len(algo.labels_col_) == n_col)

--- /dev/null
+++ b/sknetwork/classification/tests/test_knn.py
@@ -0,0 +1,24 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""Tests for KNN"""
+import unittest
+
+from sknetwork.classification import KNN
+from sknetwork.data.test_graphs import *
+from sknetwork.embedding import LouvainEmbedding
+
+
+class TestDiffusionClassifier(unittest.TestCase):
+
+    def test_parallel(self):
+        for adjacency in [test_graph(), test_digraph(), test_bigraph()]:
+            seeds = {0: 0, 1: 1}
+
+            algo1 = KNN(n_neighbors=1, n_jobs=None, embedding_method=LouvainEmbedding())
+            algo2 = KNN(n_neighbors=1, n_jobs=-1, embedding_method=LouvainEmbedding())
+
+            labels1 = algo1.fit_predict(adjacency, seeds)
+            labels2 = algo2.fit_predict(adjacency, seeds)
+
+            self.assertTrue(np.allclose(labels1, labels2))
+