scikit-network 0.33.3__cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of scikit-network might be problematic. Click here for more details.
- scikit_network-0.33.3.dist-info/METADATA +122 -0
- scikit_network-0.33.3.dist-info/RECORD +229 -0
- scikit_network-0.33.3.dist-info/WHEEL +6 -0
- scikit_network-0.33.3.dist-info/licenses/AUTHORS.rst +43 -0
- scikit_network-0.33.3.dist-info/licenses/LICENSE +34 -0
- scikit_network-0.33.3.dist-info/top_level.txt +1 -0
- scikit_network.libs/libgomp-d22c30c5.so.1.0.0 +0 -0
- sknetwork/__init__.py +21 -0
- sknetwork/base.py +67 -0
- sknetwork/classification/__init__.py +8 -0
- sknetwork/classification/base.py +142 -0
- sknetwork/classification/base_rank.py +133 -0
- sknetwork/classification/diffusion.py +134 -0
- sknetwork/classification/knn.py +139 -0
- sknetwork/classification/metrics.py +205 -0
- sknetwork/classification/pagerank.py +66 -0
- sknetwork/classification/propagation.py +152 -0
- sknetwork/classification/tests/__init__.py +1 -0
- sknetwork/classification/tests/test_API.py +30 -0
- sknetwork/classification/tests/test_diffusion.py +77 -0
- sknetwork/classification/tests/test_knn.py +23 -0
- sknetwork/classification/tests/test_metrics.py +53 -0
- sknetwork/classification/tests/test_pagerank.py +20 -0
- sknetwork/classification/tests/test_propagation.py +24 -0
- sknetwork/classification/vote.cpp +27587 -0
- sknetwork/classification/vote.cpython-313-aarch64-linux-gnu.so +0 -0
- sknetwork/classification/vote.pyx +56 -0
- sknetwork/clustering/__init__.py +8 -0
- sknetwork/clustering/base.py +172 -0
- sknetwork/clustering/kcenters.py +253 -0
- sknetwork/clustering/leiden.py +242 -0
- sknetwork/clustering/leiden_core.cpp +31578 -0
- sknetwork/clustering/leiden_core.cpython-313-aarch64-linux-gnu.so +0 -0
- sknetwork/clustering/leiden_core.pyx +124 -0
- sknetwork/clustering/louvain.py +286 -0
- sknetwork/clustering/louvain_core.cpp +31223 -0
- sknetwork/clustering/louvain_core.cpython-313-aarch64-linux-gnu.so +0 -0
- sknetwork/clustering/louvain_core.pyx +124 -0
- sknetwork/clustering/metrics.py +91 -0
- sknetwork/clustering/postprocess.py +66 -0
- sknetwork/clustering/propagation_clustering.py +104 -0
- sknetwork/clustering/tests/__init__.py +1 -0
- sknetwork/clustering/tests/test_API.py +38 -0
- sknetwork/clustering/tests/test_kcenters.py +60 -0
- sknetwork/clustering/tests/test_leiden.py +34 -0
- sknetwork/clustering/tests/test_louvain.py +135 -0
- sknetwork/clustering/tests/test_metrics.py +50 -0
- sknetwork/clustering/tests/test_postprocess.py +39 -0
- sknetwork/data/__init__.py +6 -0
- sknetwork/data/base.py +33 -0
- sknetwork/data/load.py +406 -0
- sknetwork/data/models.py +459 -0
- sknetwork/data/parse.py +644 -0
- sknetwork/data/test_graphs.py +84 -0
- sknetwork/data/tests/__init__.py +1 -0
- sknetwork/data/tests/test_API.py +30 -0
- sknetwork/data/tests/test_base.py +14 -0
- sknetwork/data/tests/test_load.py +95 -0
- sknetwork/data/tests/test_models.py +52 -0
- sknetwork/data/tests/test_parse.py +250 -0
- sknetwork/data/tests/test_test_graphs.py +29 -0
- sknetwork/data/tests/test_toy_graphs.py +68 -0
- sknetwork/data/timeout.py +38 -0
- sknetwork/data/toy_graphs.py +611 -0
- sknetwork/embedding/__init__.py +8 -0
- sknetwork/embedding/base.py +94 -0
- sknetwork/embedding/force_atlas.py +198 -0
- sknetwork/embedding/louvain_embedding.py +148 -0
- sknetwork/embedding/random_projection.py +135 -0
- sknetwork/embedding/spectral.py +141 -0
- sknetwork/embedding/spring.py +198 -0
- sknetwork/embedding/svd.py +359 -0
- sknetwork/embedding/tests/__init__.py +1 -0
- sknetwork/embedding/tests/test_API.py +49 -0
- sknetwork/embedding/tests/test_force_atlas.py +35 -0
- sknetwork/embedding/tests/test_louvain_embedding.py +33 -0
- sknetwork/embedding/tests/test_random_projection.py +28 -0
- sknetwork/embedding/tests/test_spectral.py +81 -0
- sknetwork/embedding/tests/test_spring.py +50 -0
- sknetwork/embedding/tests/test_svd.py +43 -0
- sknetwork/gnn/__init__.py +10 -0
- sknetwork/gnn/activation.py +117 -0
- sknetwork/gnn/base.py +181 -0
- sknetwork/gnn/base_activation.py +90 -0
- sknetwork/gnn/base_layer.py +109 -0
- sknetwork/gnn/gnn_classifier.py +305 -0
- sknetwork/gnn/layer.py +153 -0
- sknetwork/gnn/loss.py +180 -0
- sknetwork/gnn/neighbor_sampler.py +65 -0
- sknetwork/gnn/optimizer.py +164 -0
- sknetwork/gnn/tests/__init__.py +1 -0
- sknetwork/gnn/tests/test_activation.py +56 -0
- sknetwork/gnn/tests/test_base.py +75 -0
- sknetwork/gnn/tests/test_base_layer.py +37 -0
- sknetwork/gnn/tests/test_gnn_classifier.py +130 -0
- sknetwork/gnn/tests/test_layers.py +80 -0
- sknetwork/gnn/tests/test_loss.py +33 -0
- sknetwork/gnn/tests/test_neigh_sampler.py +23 -0
- sknetwork/gnn/tests/test_optimizer.py +43 -0
- sknetwork/gnn/tests/test_utils.py +41 -0
- sknetwork/gnn/utils.py +127 -0
- sknetwork/hierarchy/__init__.py +6 -0
- sknetwork/hierarchy/base.py +96 -0
- sknetwork/hierarchy/louvain_hierarchy.py +272 -0
- sknetwork/hierarchy/metrics.py +234 -0
- sknetwork/hierarchy/paris.cpp +37871 -0
- sknetwork/hierarchy/paris.cpython-313-aarch64-linux-gnu.so +0 -0
- sknetwork/hierarchy/paris.pyx +316 -0
- sknetwork/hierarchy/postprocess.py +350 -0
- sknetwork/hierarchy/tests/__init__.py +1 -0
- sknetwork/hierarchy/tests/test_API.py +24 -0
- sknetwork/hierarchy/tests/test_algos.py +34 -0
- sknetwork/hierarchy/tests/test_metrics.py +62 -0
- sknetwork/hierarchy/tests/test_postprocess.py +57 -0
- sknetwork/linalg/__init__.py +9 -0
- sknetwork/linalg/basics.py +37 -0
- sknetwork/linalg/diteration.cpp +27403 -0
- sknetwork/linalg/diteration.cpython-313-aarch64-linux-gnu.so +0 -0
- sknetwork/linalg/diteration.pyx +47 -0
- sknetwork/linalg/eig_solver.py +93 -0
- sknetwork/linalg/laplacian.py +15 -0
- sknetwork/linalg/normalizer.py +86 -0
- sknetwork/linalg/operators.py +225 -0
- sknetwork/linalg/polynome.py +76 -0
- sknetwork/linalg/ppr_solver.py +170 -0
- sknetwork/linalg/push.cpp +31075 -0
- sknetwork/linalg/push.cpython-313-aarch64-linux-gnu.so +0 -0
- sknetwork/linalg/push.pyx +71 -0
- sknetwork/linalg/sparse_lowrank.py +142 -0
- sknetwork/linalg/svd_solver.py +91 -0
- sknetwork/linalg/tests/__init__.py +1 -0
- sknetwork/linalg/tests/test_eig.py +44 -0
- sknetwork/linalg/tests/test_laplacian.py +18 -0
- sknetwork/linalg/tests/test_normalization.py +34 -0
- sknetwork/linalg/tests/test_operators.py +66 -0
- sknetwork/linalg/tests/test_polynome.py +38 -0
- sknetwork/linalg/tests/test_ppr.py +50 -0
- sknetwork/linalg/tests/test_sparse_lowrank.py +61 -0
- sknetwork/linalg/tests/test_svd.py +38 -0
- sknetwork/linkpred/__init__.py +2 -0
- sknetwork/linkpred/base.py +46 -0
- sknetwork/linkpred/nn.py +126 -0
- sknetwork/linkpred/tests/__init__.py +1 -0
- sknetwork/linkpred/tests/test_nn.py +27 -0
- sknetwork/log.py +19 -0
- sknetwork/path/__init__.py +5 -0
- sknetwork/path/dag.py +54 -0
- sknetwork/path/distances.py +98 -0
- sknetwork/path/search.py +31 -0
- sknetwork/path/shortest_path.py +61 -0
- sknetwork/path/tests/__init__.py +1 -0
- sknetwork/path/tests/test_dag.py +37 -0
- sknetwork/path/tests/test_distances.py +62 -0
- sknetwork/path/tests/test_search.py +40 -0
- sknetwork/path/tests/test_shortest_path.py +40 -0
- sknetwork/ranking/__init__.py +8 -0
- sknetwork/ranking/base.py +61 -0
- sknetwork/ranking/betweenness.cpp +9710 -0
- sknetwork/ranking/betweenness.cpython-313-aarch64-linux-gnu.so +0 -0
- sknetwork/ranking/betweenness.pyx +97 -0
- sknetwork/ranking/closeness.py +92 -0
- sknetwork/ranking/hits.py +94 -0
- sknetwork/ranking/katz.py +83 -0
- sknetwork/ranking/pagerank.py +110 -0
- sknetwork/ranking/postprocess.py +37 -0
- sknetwork/ranking/tests/__init__.py +1 -0
- sknetwork/ranking/tests/test_API.py +32 -0
- sknetwork/ranking/tests/test_betweenness.py +38 -0
- sknetwork/ranking/tests/test_closeness.py +30 -0
- sknetwork/ranking/tests/test_hits.py +20 -0
- sknetwork/ranking/tests/test_pagerank.py +62 -0
- sknetwork/ranking/tests/test_postprocess.py +26 -0
- sknetwork/regression/__init__.py +4 -0
- sknetwork/regression/base.py +61 -0
- sknetwork/regression/diffusion.py +210 -0
- sknetwork/regression/tests/__init__.py +1 -0
- sknetwork/regression/tests/test_API.py +32 -0
- sknetwork/regression/tests/test_diffusion.py +56 -0
- sknetwork/sknetwork.py +3 -0
- sknetwork/test_base.py +35 -0
- sknetwork/test_log.py +15 -0
- sknetwork/topology/__init__.py +8 -0
- sknetwork/topology/cliques.cpp +32568 -0
- sknetwork/topology/cliques.cpython-313-aarch64-linux-gnu.so +0 -0
- sknetwork/topology/cliques.pyx +149 -0
- sknetwork/topology/core.cpp +30654 -0
- sknetwork/topology/core.cpython-313-aarch64-linux-gnu.so +0 -0
- sknetwork/topology/core.pyx +90 -0
- sknetwork/topology/cycles.py +243 -0
- sknetwork/topology/minheap.cpp +27335 -0
- sknetwork/topology/minheap.cpython-313-aarch64-linux-gnu.so +0 -0
- sknetwork/topology/minheap.pxd +20 -0
- sknetwork/topology/minheap.pyx +109 -0
- sknetwork/topology/structure.py +194 -0
- sknetwork/topology/tests/__init__.py +1 -0
- sknetwork/topology/tests/test_cliques.py +28 -0
- sknetwork/topology/tests/test_core.py +19 -0
- sknetwork/topology/tests/test_cycles.py +65 -0
- sknetwork/topology/tests/test_structure.py +85 -0
- sknetwork/topology/tests/test_triangles.py +38 -0
- sknetwork/topology/tests/test_wl.py +72 -0
- sknetwork/topology/triangles.cpp +8897 -0
- sknetwork/topology/triangles.cpython-313-aarch64-linux-gnu.so +0 -0
- sknetwork/topology/triangles.pyx +151 -0
- sknetwork/topology/weisfeiler_lehman.py +133 -0
- sknetwork/topology/weisfeiler_lehman_core.cpp +27638 -0
- sknetwork/topology/weisfeiler_lehman_core.cpython-313-aarch64-linux-gnu.so +0 -0
- sknetwork/topology/weisfeiler_lehman_core.pyx +114 -0
- sknetwork/utils/__init__.py +7 -0
- sknetwork/utils/check.py +355 -0
- sknetwork/utils/format.py +221 -0
- sknetwork/utils/membership.py +82 -0
- sknetwork/utils/neighbors.py +115 -0
- sknetwork/utils/tests/__init__.py +1 -0
- sknetwork/utils/tests/test_check.py +190 -0
- sknetwork/utils/tests/test_format.py +63 -0
- sknetwork/utils/tests/test_membership.py +24 -0
- sknetwork/utils/tests/test_neighbors.py +41 -0
- sknetwork/utils/tests/test_tfidf.py +18 -0
- sknetwork/utils/tests/test_values.py +66 -0
- sknetwork/utils/tfidf.py +37 -0
- sknetwork/utils/values.py +76 -0
- sknetwork/visualization/__init__.py +4 -0
- sknetwork/visualization/colors.py +34 -0
- sknetwork/visualization/dendrograms.py +277 -0
- sknetwork/visualization/graphs.py +1039 -0
- sknetwork/visualization/tests/__init__.py +1 -0
- sknetwork/visualization/tests/test_dendrograms.py +53 -0
- sknetwork/visualization/tests/test_graphs.py +176 -0
|
Binary file
|
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
# distutils: language=c++
|
|
2
|
+
# cython: language_level=3
|
|
3
|
+
"""
|
|
4
|
+
Created in June 2020
|
|
5
|
+
@author: Julien Simonnet <julien.simonnet@etu.upmc.fr>
|
|
6
|
+
@author: Yohann Robert <yohann.robert@etu.upmc.fr>
|
|
7
|
+
@author: Nathan de Lara <nathan.delara@polytechnique.org>
|
|
8
|
+
@author: Thomas Bonald <bonald@enst.fr>
|
|
9
|
+
"""
|
|
10
|
+
from libcpp.vector cimport vector
|
|
11
|
+
from scipy import sparse
|
|
12
|
+
from cython.parallel import prange
|
|
13
|
+
|
|
14
|
+
from sknetwork.path.dag import get_dag
|
|
15
|
+
from sknetwork.utils.check import check_square
|
|
16
|
+
from sknetwork.utils.format import directed2undirected
|
|
17
|
+
from sknetwork.utils.neighbors import get_degrees
|
|
18
|
+
|
|
19
|
+
cimport cython
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@cython.boundscheck(False)
@cython.wraparound(False)
cdef long count_local_triangles_from_dag(int node, vector[int]& indptr, vector[int]& indices) nogil:
    """Count the number of triangles from a given node in a directed acyclic graph.

    For each out-neighbor of ``node``, the out-neighbor lists of ``node`` and of that neighbor
    are walked in lockstep (merge step); every common entry closes one triangle.

    Parameters
    ----------
    node :
        Node.
    indptr :
        CSR format index pointer array of the adjacency matrix of the graph.
        Taken by reference: this function is called once per node, and taking the vector by
        value would copy the whole array on every call.
    indices :
        CSR format index array of the adjacency matrix of the graph (taken by reference too).

    Returns
    -------
    n_triangles :
        Number of triangles.
    """
    cdef int i, j, k
    cdef int neighbor
    cdef int neighbor_end
    cdef int node_start = indptr[node]
    cdef int node_end = indptr[node + 1]
    cdef long n_triangles = 0

    for k in range(node_start, node_end):
        neighbor = indices[k]
        i = node_start
        j = indptr[neighbor]
        neighbor_end = indptr[neighbor + 1]

        # Merge step: only correct if column indices are sorted within each row.
        # NOTE(review): presumably guaranteed by the DAG built by the caller — confirm.
        while i < node_end and j < neighbor_end:
            if indices[i] == indices[j]:
                i += 1
                j += 1
                n_triangles += 1
            elif indices[i] < indices[j]:
                i += 1
            else:
                j += 1

    return n_triangles
|
|
62
|
+
|
|
63
|
+
@cython.boundscheck(False)
@cython.wraparound(False)
cdef long count_triangles_from_dag(vector[int] indptr, vector[int] indices, bint parallelize):
    """Count the number of triangles in a directed acyclic graph.

    Sums, over every node, the count returned by ``count_local_triangles_from_dag``.
    The vectors are taken by value, so arrays passed from Python are converted (copied)
    into C++ vectors once, on entry.

    Parameters
    ----------
    indptr :
        CSR format index pointer array of the adjacency matrix of the graph.
    indices :
        CSR format index array of the adjacency matrix of the graph.
    parallelize :
        If ``True``, use a parallel range to count triangles.

    Returns
    -------
    n_triangles :
        Number of triangles in the graph.
    """
    # CSR layout: indptr holds one more entry than there are rows.
    cdef int n_nodes = indptr.size() - 1
    cdef int node
    # NOTE(review): ``long`` is 64-bit on Linux/macOS but 32-bit on Windows; very large
    # triangle counts could overflow there.
    cdef long n_triangles = 0

    if parallelize:
        # The in-place ``+=`` inside ``prange`` makes Cython treat ``n_triangles`` as a
        # reduction variable, so per-thread partial sums are combined after the loop.
        for node in prange(n_nodes, nogil=True):
            n_triangles += count_local_triangles_from_dag(node, indptr, indices)
    else:
        for node in range(n_nodes):
            n_triangles += count_local_triangles_from_dag(node, indptr, indices)

    return n_triangles
|
|
94
|
+
|
|
95
|
+
def count_triangles(adjacency: sparse.csr_matrix, parallelize: bool = False) -> int:
    """Count the triangles of a graph, ignoring edge directions.

    Parameters
    ----------
    adjacency :
        Adjacency matrix of the graph.
    parallelize :
        If ``True``, use a parallel range while listing the triangles.

    Returns
    -------
    n_triangles : int
        Number of triangles.

    Example
    -------
    >>> from sknetwork.data import karate_club
    >>> adjacency = karate_club()
    >>> count_triangles(adjacency)
    45
    """
    check_square(adjacency)
    # Symmetrize so that the graph is treated as undirected, then convert it to a DAG
    # before counting (each triangle is reported once, e.g. 45 for the karate club).
    undirected = directed2undirected(adjacency)
    dag = get_dag(undirected)
    return count_triangles_from_dag(dag.indptr, dag.indices, parallelize)
|
|
123
|
+
|
|
124
|
+
def get_clustering_coefficient(adjacency: sparse.csr_matrix, parallelize: bool = False) -> float:
    """Get the clustering coefficient of a graph.

    The coefficient is 3 times the number of triangles divided by the number of pairs of
    edges sharing a node. The graph is considered undirected.

    Parameters
    ----------
    adjacency :
        Adjacency matrix of the graph.
    parallelize :
        If ``True``, use a parallel range while listing the triangles.

    Returns
    -------
    coefficient : float
        Clustering coefficient. Equal to 0 if no node has degree 2 or more
        (e.g. empty graph), instead of the undefined 0 / 0.

    Example
    -------
    >>> from sknetwork.data import karate_club
    >>> adjacency = karate_club()
    >>> round(get_clustering_coefficient(adjacency), 2)
    0.26
    """
    n_triangles = count_triangles(adjacency, parallelize)
    degrees = get_degrees(directed2undirected(adjacency))
    # Only nodes of degree at least 2 are the center of a pair of edges.
    # Cast to float so that d * (d - 1) cannot overflow if degrees are stored as 32-bit ints.
    degrees = degrees[degrees > 1].astype(float)
    n_edge_pairs = (degrees * (degrees - 1)).sum() / 2
    if n_edge_pairs == 0:
        # No pair of adjacent edges: previously 0 / 0 gave nan with a RuntimeWarning.
        return 0.0
    # Return a plain Python float, as annotated.
    return float(3 * n_triangles / n_edge_pairs)
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""
|
|
4
|
+
Created in July 2020
|
|
5
|
+
@author: Pierre Pebereau <pierre.pebereau@telecom-paris.fr>
|
|
6
|
+
@author: Alexis Barreaux <alexis.barreaux@telecom-paris.fr>
|
|
7
|
+
"""
|
|
8
|
+
from typing import Union
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
from scipy import sparse
|
|
12
|
+
|
|
13
|
+
from sknetwork.topology.weisfeiler_lehman_core import weisfeiler_lehman_coloring
|
|
14
|
+
from sknetwork.utils.check import check_format, check_square
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def color_weisfeiler_lehman(adjacency: Union[sparse.csr_matrix, np.ndarray], max_iter: int = -1) -> np.ndarray:
    """Color nodes using Weisfeiler-Lehman algorithm.

    Parameters
    ----------
    adjacency : sparse.csr_matrix or np.ndarray
        Adjacency matrix of the graph (must be square).
    max_iter : int
        Maximum number of iterations. Negative value means no limit (until convergence);
        in any case at most ``n_nodes`` iterations are run.

    Returns
    -------
    labels : np.ndarray
        Label of each node.

    Example
    -------
    >>> from sknetwork.data import house
    >>> adjacency = house()
    >>> labels = color_weisfeiler_lehman(adjacency)
    >>> print(labels)
    [0 2 1 1 2]

    References
    ----------
    * Douglas, B. L. (2011).
      `The Weisfeiler-Lehman Method and Graph Isomorphism Testing.
      <https://arxiv.org/pdf/1101.5211.pdf>`_

    * Shervashidze, N., Schweitzer, P., van Leeuwen, E. J., Melhorn, K., Borgwardt, K. M. (2011)
      `Weisfeiler-Lehman graph kernels.
      <https://www.jmlr.org/papers/volume12/shervashidze11a/shervashidze11a.pdf>`_
      Journal of Machine Learning Research 12, 2011.
    """
    adjacency = check_format(adjacency, allow_empty=True)
    check_square(adjacency)
    n_nodes = adjacency.shape[0]

    # Negative or too large values are clamped to the number of nodes.
    n_iter = n_nodes if (max_iter < 0 or max_iter > n_nodes) else max_iter

    # Every node starts with the same color.
    initial_labels = np.zeros(n_nodes, dtype=np.int32)
    # Geometric sequence handed to the compiled core.
    # NOTE(review): presumably used there to hash multisets of neighbor labels — confirm.
    powers = (-np.pi / 3.15) ** np.arange(n_nodes, dtype=np.double)

    colors, _ = weisfeiler_lehman_coloring(adjacency.indptr, adjacency.indices, initial_labels, powers, n_iter)
    return np.array(colors)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def are_isomorphic(adjacency1: sparse.csr_matrix, adjacency2: sparse.csr_matrix, max_iter: int = -1) -> bool:
    """Weisfeiler-Lehman isomorphism test. If the test is False, the graphs cannot be isomorphic.

    A ``True`` result is not a proof of isomorphism: some non-isomorphic graphs are not
    distinguished by the Weisfeiler-Lehman coloring.

    Parameters
    ----------
    adjacency1 :
        First adjacency matrix.
    adjacency2 :
        Second adjacency matrix.
    max_iter : int
        Maximum number of iterations. Negative value means no limit (until convergence);
        in any case at most ``n_nodes`` iterations are run.

    Returns
    -------
    test_result : bool
        ``False`` if the graphs are certainly not isomorphic, ``True`` otherwise.

    Example
    -------
    >>> from sknetwork.data import house, bow_tie
    >>> are_isomorphic(house(), bow_tie())
    False

    References
    ----------
    * Douglas, B. L. (2011).
      `The Weisfeiler-Lehman Method and Graph Isomorphism Testing.
      <https://arxiv.org/pdf/1101.5211.pdf>`_

    * Shervashidze, N., Schweitzer, P., van Leeuwen, E. J., Melhorn, K., Borgwardt, K. M. (2011)
      `Weisfeiler-Lehman graph kernels.
      <https://www.jmlr.org/papers/volume12/shervashidze11a/shervashidze11a.pdf>`_
      Journal of Machine Learning Research 12, 2011.
    """
    adjacency1 = check_format(adjacency1)
    check_square(adjacency1)
    adjacency2 = check_format(adjacency2)
    check_square(adjacency2)

    # Isomorphic graphs have the same number of nodes and the same number of edges.
    if (adjacency1.shape != adjacency2.shape) or (adjacency1.nnz != adjacency2.nnz):
        return False

    n_nodes = adjacency1.shape[0]

    if max_iter < 0 or max_iter > n_nodes:
        max_iter = n_nodes

    indptr1 = adjacency1.indptr
    indptr2 = adjacency2.indptr
    indices1 = adjacency1.indices
    indices2 = adjacency2.indices

    labels1 = np.zeros(n_nodes, dtype=np.int32)
    labels2 = np.zeros(n_nodes, dtype=np.int32)

    # The same coefficients are used for both graphs so that both colorings are computed alike.
    powers = (-np.pi / 3.15) ** np.arange(n_nodes, dtype=np.double)

    iteration = 0
    has_changed1, has_changed2 = True, True
    # Refine both colorings one step at a time, comparing the color class sizes after each step.
    while iteration < max_iter and (has_changed1 or has_changed2):
        labels1, has_changed1 = weisfeiler_lehman_coloring(indptr1, indices1, labels1, powers, max_iter=1)
        labels2, has_changed2 = weisfeiler_lehman_coloring(indptr2, indices2, labels2, powers, max_iter=1)
        _, counts1 = np.unique(np.array(labels1), return_counts=True)
        _, counts2 = np.unique(np.array(labels2), return_counts=True)
        # The two colorings may have a different number of classes; ``np.array_equal`` returns
        # False in that case, whereas an element-wise ``!=`` fails to broadcast and raises.
        if not np.array_equal(counts1, counts2):
            return False
        iteration += 1

    return True
|