scikit-network 0.30.0__cp38-cp38-win_amd64.whl → 0.32.1__cp38-cp38-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-network might be problematic. Click here for more details.

Files changed (187)
  1. {scikit_network-0.30.0.dist-info → scikit_network-0.32.1.dist-info}/AUTHORS.rst +3 -0
  2. {scikit_network-0.30.0.dist-info → scikit_network-0.32.1.dist-info}/METADATA +31 -3
  3. scikit_network-0.32.1.dist-info/RECORD +228 -0
  4. {scikit_network-0.30.0.dist-info → scikit_network-0.32.1.dist-info}/WHEEL +1 -1
  5. sknetwork/__init__.py +1 -1
  6. sknetwork/base.py +67 -0
  7. sknetwork/classification/base.py +24 -24
  8. sknetwork/classification/base_rank.py +17 -25
  9. sknetwork/classification/diffusion.py +35 -35
  10. sknetwork/classification/knn.py +24 -21
  11. sknetwork/classification/metrics.py +1 -1
  12. sknetwork/classification/pagerank.py +10 -10
  13. sknetwork/classification/propagation.py +23 -20
  14. sknetwork/classification/tests/test_diffusion.py +13 -3
  15. sknetwork/classification/vote.cp38-win_amd64.pyd +0 -0
  16. sknetwork/classification/vote.cpp +14482 -10351
  17. sknetwork/classification/vote.pyx +1 -3
  18. sknetwork/clustering/__init__.py +3 -1
  19. sknetwork/clustering/base.py +36 -40
  20. sknetwork/clustering/kcenters.py +253 -0
  21. sknetwork/clustering/leiden.py +241 -0
  22. sknetwork/clustering/leiden_core.cp38-win_amd64.pyd +0 -0
  23. sknetwork/clustering/leiden_core.cpp +31564 -0
  24. sknetwork/clustering/leiden_core.pyx +124 -0
  25. sknetwork/clustering/louvain.py +133 -102
  26. sknetwork/clustering/louvain_core.cp38-win_amd64.pyd +0 -0
  27. sknetwork/clustering/louvain_core.cpp +22457 -18792
  28. sknetwork/clustering/louvain_core.pyx +86 -96
  29. sknetwork/clustering/postprocess.py +2 -2
  30. sknetwork/clustering/propagation_clustering.py +15 -19
  31. sknetwork/clustering/tests/test_API.py +8 -4
  32. sknetwork/clustering/tests/test_kcenters.py +92 -0
  33. sknetwork/clustering/tests/test_leiden.py +34 -0
  34. sknetwork/clustering/tests/test_louvain.py +3 -4
  35. sknetwork/data/__init__.py +2 -1
  36. sknetwork/data/base.py +28 -0
  37. sknetwork/data/load.py +38 -37
  38. sknetwork/data/models.py +18 -18
  39. sknetwork/data/parse.py +54 -33
  40. sknetwork/data/test_graphs.py +2 -2
  41. sknetwork/data/tests/test_API.py +1 -1
  42. sknetwork/data/tests/test_base.py +14 -0
  43. sknetwork/data/tests/test_load.py +1 -1
  44. sknetwork/data/tests/test_parse.py +9 -12
  45. sknetwork/data/tests/test_test_graphs.py +1 -2
  46. sknetwork/data/toy_graphs.py +18 -18
  47. sknetwork/embedding/__init__.py +0 -1
  48. sknetwork/embedding/base.py +21 -20
  49. sknetwork/embedding/force_atlas.py +3 -2
  50. sknetwork/embedding/louvain_embedding.py +2 -2
  51. sknetwork/embedding/random_projection.py +5 -3
  52. sknetwork/embedding/spectral.py +0 -73
  53. sknetwork/embedding/tests/test_API.py +4 -28
  54. sknetwork/embedding/tests/test_louvain_embedding.py +4 -9
  55. sknetwork/embedding/tests/test_random_projection.py +2 -2
  56. sknetwork/embedding/tests/test_spectral.py +5 -8
  57. sknetwork/embedding/tests/test_svd.py +1 -1
  58. sknetwork/gnn/base.py +4 -4
  59. sknetwork/gnn/base_layer.py +3 -3
  60. sknetwork/gnn/gnn_classifier.py +45 -89
  61. sknetwork/gnn/layer.py +1 -1
  62. sknetwork/gnn/loss.py +1 -1
  63. sknetwork/gnn/optimizer.py +4 -3
  64. sknetwork/gnn/tests/test_base_layer.py +4 -4
  65. sknetwork/gnn/tests/test_gnn_classifier.py +12 -35
  66. sknetwork/gnn/utils.py +8 -8
  67. sknetwork/hierarchy/base.py +29 -2
  68. sknetwork/hierarchy/louvain_hierarchy.py +45 -41
  69. sknetwork/hierarchy/paris.cp38-win_amd64.pyd +0 -0
  70. sknetwork/hierarchy/paris.cpp +27371 -22844
  71. sknetwork/hierarchy/paris.pyx +7 -9
  72. sknetwork/hierarchy/postprocess.py +16 -16
  73. sknetwork/hierarchy/tests/test_API.py +1 -1
  74. sknetwork/hierarchy/tests/test_algos.py +5 -0
  75. sknetwork/hierarchy/tests/test_metrics.py +1 -1
  76. sknetwork/linalg/__init__.py +1 -1
  77. sknetwork/linalg/diteration.cp38-win_amd64.pyd +0 -0
  78. sknetwork/linalg/diteration.cpp +13474 -9454
  79. sknetwork/linalg/diteration.pyx +0 -2
  80. sknetwork/linalg/eig_solver.py +1 -1
  81. sknetwork/linalg/{normalization.py → normalizer.py} +18 -15
  82. sknetwork/linalg/operators.py +1 -1
  83. sknetwork/linalg/ppr_solver.py +1 -1
  84. sknetwork/linalg/push.cp38-win_amd64.pyd +0 -0
  85. sknetwork/linalg/push.cpp +23003 -18807
  86. sknetwork/linalg/push.pyx +0 -2
  87. sknetwork/linalg/svd_solver.py +1 -1
  88. sknetwork/linalg/tests/test_normalization.py +3 -7
  89. sknetwork/linalg/tests/test_operators.py +4 -8
  90. sknetwork/linalg/tests/test_ppr.py +1 -1
  91. sknetwork/linkpred/base.py +13 -2
  92. sknetwork/linkpred/nn.py +6 -6
  93. sknetwork/log.py +19 -0
  94. sknetwork/path/__init__.py +4 -3
  95. sknetwork/path/dag.py +54 -0
  96. sknetwork/path/distances.py +98 -0
  97. sknetwork/path/search.py +13 -47
  98. sknetwork/path/shortest_path.py +37 -162
  99. sknetwork/path/tests/test_dag.py +37 -0
  100. sknetwork/path/tests/test_distances.py +62 -0
  101. sknetwork/path/tests/test_search.py +26 -11
  102. sknetwork/path/tests/test_shortest_path.py +31 -36
  103. sknetwork/ranking/__init__.py +0 -1
  104. sknetwork/ranking/base.py +13 -8
  105. sknetwork/ranking/betweenness.cp38-win_amd64.pyd +0 -0
  106. sknetwork/ranking/betweenness.cpp +5709 -3017
  107. sknetwork/ranking/betweenness.pyx +0 -2
  108. sknetwork/ranking/closeness.py +7 -10
  109. sknetwork/ranking/pagerank.py +14 -14
  110. sknetwork/ranking/postprocess.py +12 -3
  111. sknetwork/ranking/tests/test_API.py +2 -4
  112. sknetwork/ranking/tests/test_betweenness.py +3 -3
  113. sknetwork/ranking/tests/test_closeness.py +3 -7
  114. sknetwork/ranking/tests/test_pagerank.py +11 -5
  115. sknetwork/ranking/tests/test_postprocess.py +5 -0
  116. sknetwork/regression/base.py +19 -2
  117. sknetwork/regression/diffusion.py +24 -10
  118. sknetwork/regression/tests/test_diffusion.py +8 -0
  119. sknetwork/test_base.py +35 -0
  120. sknetwork/test_log.py +15 -0
  121. sknetwork/topology/__init__.py +7 -8
  122. sknetwork/topology/cliques.cp38-win_amd64.pyd +0 -0
  123. sknetwork/topology/{kcliques.cpp → cliques.cpp} +23423 -20277
  124. sknetwork/topology/cliques.pyx +149 -0
  125. sknetwork/topology/core.cp38-win_amd64.pyd +0 -0
  126. sknetwork/topology/{kcore.cpp → core.cpp} +21637 -18762
  127. sknetwork/topology/core.pyx +90 -0
  128. sknetwork/topology/cycles.py +243 -0
  129. sknetwork/topology/minheap.cp38-win_amd64.pyd +0 -0
  130. sknetwork/{utils → topology}/minheap.cpp +19452 -15368
  131. sknetwork/{utils → topology}/minheap.pxd +1 -3
  132. sknetwork/{utils → topology}/minheap.pyx +1 -3
  133. sknetwork/topology/structure.py +3 -43
  134. sknetwork/topology/tests/test_cliques.py +11 -11
  135. sknetwork/topology/tests/test_core.py +19 -0
  136. sknetwork/topology/tests/test_cycles.py +65 -0
  137. sknetwork/topology/tests/test_structure.py +2 -16
  138. sknetwork/topology/tests/test_triangles.py +11 -15
  139. sknetwork/topology/tests/test_wl.py +72 -0
  140. sknetwork/topology/triangles.cp38-win_amd64.pyd +0 -0
  141. sknetwork/topology/triangles.cpp +5056 -2696
  142. sknetwork/topology/triangles.pyx +74 -89
  143. sknetwork/topology/weisfeiler_lehman.py +56 -86
  144. sknetwork/topology/weisfeiler_lehman_core.cp38-win_amd64.pyd +0 -0
  145. sknetwork/topology/weisfeiler_lehman_core.cpp +14727 -10622
  146. sknetwork/topology/weisfeiler_lehman_core.pyx +0 -2
  147. sknetwork/utils/__init__.py +1 -31
  148. sknetwork/utils/check.py +2 -2
  149. sknetwork/utils/format.py +5 -3
  150. sknetwork/utils/membership.py +2 -2
  151. sknetwork/utils/tests/test_check.py +3 -3
  152. sknetwork/utils/tests/test_format.py +3 -1
  153. sknetwork/utils/values.py +1 -1
  154. sknetwork/visualization/__init__.py +2 -2
  155. sknetwork/visualization/dendrograms.py +55 -7
  156. sknetwork/visualization/graphs.py +292 -72
  157. sknetwork/visualization/tests/test_dendrograms.py +9 -9
  158. sknetwork/visualization/tests/test_graphs.py +71 -62
  159. scikit_network-0.30.0.dist-info/RECORD +0 -227
  160. sknetwork/embedding/louvain_hierarchy.py +0 -142
  161. sknetwork/embedding/tests/test_louvain_hierarchy.py +0 -19
  162. sknetwork/path/metrics.py +0 -148
  163. sknetwork/path/tests/test_metrics.py +0 -29
  164. sknetwork/ranking/harmonic.py +0 -82
  165. sknetwork/topology/dag.py +0 -74
  166. sknetwork/topology/dag_core.cp38-win_amd64.pyd +0 -0
  167. sknetwork/topology/dag_core.cpp +0 -23350
  168. sknetwork/topology/dag_core.pyx +0 -38
  169. sknetwork/topology/kcliques.cp38-win_amd64.pyd +0 -0
  170. sknetwork/topology/kcliques.pyx +0 -193
  171. sknetwork/topology/kcore.cp38-win_amd64.pyd +0 -0
  172. sknetwork/topology/kcore.pyx +0 -120
  173. sknetwork/topology/tests/test_cores.py +0 -21
  174. sknetwork/topology/tests/test_dag.py +0 -26
  175. sknetwork/topology/tests/test_wl_coloring.py +0 -49
  176. sknetwork/topology/tests/test_wl_kernel.py +0 -31
  177. sknetwork/utils/base.py +0 -35
  178. sknetwork/utils/minheap.cp38-win_amd64.pyd +0 -0
  179. sknetwork/utils/simplex.py +0 -140
  180. sknetwork/utils/tests/test_base.py +0 -28
  181. sknetwork/utils/tests/test_bunch.py +0 -16
  182. sknetwork/utils/tests/test_projection_simplex.py +0 -33
  183. sknetwork/utils/tests/test_verbose.py +0 -15
  184. sknetwork/utils/verbose.py +0 -37
  185. {scikit_network-0.30.0.dist-info → scikit_network-0.32.1.dist-info}/LICENSE +0 -0
  186. {scikit_network-0.30.0.dist-info → scikit_network-0.32.1.dist-info}/top_level.txt +0 -0
  187. sknetwork/{utils → data}/timeout.py +0 -0
@@ -1,54 +1,53 @@
1
- # distutils: language = c++
1
+ # distutils: language=c++
2
2
  # cython: language_level=3
3
- # cython: linetrace=True
4
- # distutils: define_macros=CYTHON_TRACE_NOGIL=1
5
3
  """
6
- Created on Jun 3, 2020
4
+ Created in June 2020
7
5
  @author: Julien Simonnet <julien.simonnet@etu.upmc.fr>
8
6
  @author: Yohann Robert <yohann.robert@etu.upmc.fr>
9
7
  @author: Nathan de Lara <nathan.delara@polytechnique.org>
8
+ @author: Thomas Bonald <bonald@enst.fr>
10
9
  """
11
10
  from libcpp.vector cimport vector
12
11
  from scipy import sparse
13
- from scipy.special import comb
14
12
  from cython.parallel import prange
15
13
 
16
- from sknetwork.topology.dag import DAG
17
- from sknetwork.utils.base import Algorithm
14
+ from sknetwork.path.dag import get_dag
15
+ from sknetwork.utils.check import check_square
16
+ from sknetwork.utils.format import directed2undirected
17
+ from sknetwork.utils.neighbors import get_degrees
18
18
 
19
19
  cimport cython
20
20
 
21
21
 
22
22
  @cython.boundscheck(False)
23
23
  @cython.wraparound(False)
24
- cdef long count_local_triangles(int source, vector[int] indptr, vector[int] indices) nogil:
25
- """Counts the number of nodes in the intersection of a node and its neighbors in a DAG.
24
+ cdef long count_local_triangles_from_dag(int node, vector[int] indptr, vector[int] indices) nogil:
25
+ """Count the number of triangles from a given node in a directed acyclic graph.
26
26
 
27
27
  Parameters
28
28
  ----------
29
- source :
30
- Index of the node to study.
29
+ node :
30
+ Node.
31
31
  indptr :
32
- CSR format index pointer array of the normalized adjacency matrix of a DAG.
32
+ CSR format index pointer array of the adjacency matrix of the graph.
33
33
  indices :
34
- CSR format index array of the normalized adjacency matrix of a DAG.
34
+ CSR format index array of the adjacency matrix of the graph.
35
35
 
36
36
  Returns
37
37
  -------
38
38
  n_triangles :
39
- Number of nodes in the intersection
39
+ Number of triangles.
40
40
  """
41
41
  cdef int i, j, k
42
- cdef int v
43
- cdef long n_triangles = 0 # number of nodes in the intersection
42
+ cdef int neighbor
43
+ cdef long n_triangles = 0
44
44
 
45
- for k in range(indptr[source], indptr[source+1]):
46
- v = indices[k]
47
- i = indptr[source]
48
- j = indptr[v]
45
+ for k in range(indptr[node], indptr[node + 1]):
46
+ neighbor = indices[k]
47
+ i = indptr[node]
48
+ j = indptr[neighbor]
49
49
 
50
- # calculates the intersection of the neighbors of u and v
51
- while (i < indptr[source+1]) and (j < indptr[v+1]):
50
+ while (i < indptr[node + 1]) and (j < indptr[neighbor + 1]):
52
51
  if indices[i] == indices[j]:
53
52
  i += 1
54
53
  j += 1
@@ -61,18 +60,17 @@ cdef long count_local_triangles(int source, vector[int] indptr, vector[int] indi
61
60
 
62
61
  return n_triangles
63
62
 
64
-
65
63
  @cython.boundscheck(False)
66
64
  @cython.wraparound(False)
67
- cdef long fit_core(vector[int] indptr, vector[int] indices, bint parallelize):
68
- """Counts the number of triangles directly without exporting the graph.
65
+ cdef long count_triangles_from_dag(vector[int] indptr, vector[int] indices, bint parallelize):
66
+ """Count the number of triangles in a directed acyclic graph.
69
67
 
70
68
  Parameters
71
69
  ----------
72
70
  indptr :
73
- CSR format index pointer array of the normalized adjacency matrix of a DAG.
71
+ CSR format index pointer array of the adjacency matrix of the graph.
74
72
  indices :
75
- CSR format index array of the normalized adjacency matrix of a DAG.
73
+ CSR format index array of the adjacency matrix of the graph.
76
74
  parallelize :
77
75
  If ``True``, use a parallel range to count triangles.
78
76
 
@@ -81,86 +79,73 @@ cdef long fit_core(vector[int] indptr, vector[int] indices, bint parallelize):
81
79
  n_triangles :
82
80
  Number of triangles in the graph
83
81
  """
84
- cdef int n = indptr.size() - 1
85
- cdef int u
82
+ cdef int n_nodes = indptr.size() - 1
83
+ cdef int node
86
84
  cdef long n_triangles = 0
87
85
 
88
86
  if parallelize:
89
- for u in prange(n, nogil=True):
90
- n_triangles += count_local_triangles(u, indptr, indices)
87
+ for node in prange(n_nodes, nogil=True):
88
+ n_triangles += count_local_triangles_from_dag(node, indptr, indices)
91
89
  else:
92
- for u in range(n):
93
- n_triangles += count_local_triangles(u, indptr, indices)
90
+ for node in range(n_nodes):
91
+ n_triangles += count_local_triangles_from_dag(node, indptr, indices)
94
92
 
95
93
  return n_triangles
96
94
 
97
-
98
- class Triangles(Algorithm):
99
- """Count the number of triangles in a graph, and evaluate the clustering coefficient.
100
-
101
- * Graphs
95
+ def count_triangles(adjacency: sparse.csr_matrix, parallelize: bool = False) -> int:
96
+ """Count the number of triangles in a graph. The graph is considered undirected.
102
97
 
103
98
  Parameters
104
99
  ----------
100
+ adjacency :
101
+ Adjacency matrix of the graph.
105
102
  parallelize :
106
103
  If ``True``, use a parallel range while listing the triangles.
107
104
 
108
- Attributes
109
- ----------
110
- n_triangles_ : int
105
+ Returns
106
+ -------
107
+ n_triangles : int
111
108
  Number of triangles.
112
- clustering_coef_ : float
113
- Global clustering coefficient of the graph.
114
109
 
115
110
  Example
116
111
  -------
117
112
  >>> from sknetwork.data import karate_club
118
- >>> triangles = Triangles()
119
113
  >>> adjacency = karate_club()
120
- >>> triangles.fit_transform(adjacency)
114
+ >>> count_triangles(adjacency)
121
115
  45
122
116
  """
123
- def __init__(self, parallelize : bool = False):
124
- super(Triangles, self).__init__()
125
- self.parallelize = parallelize
126
- self.n_triangles_ = None
127
- self.clustering_coef_ = None
128
-
129
- def fit(self, adjacency: sparse.csr_matrix) -> 'Triangles':
130
- """Count triangles.
131
-
132
- Parameters
133
- ----------
134
- adjacency :
135
- Adjacency matrix of the graph.
136
-
137
- Returns
138
- -------
139
- self: :class:`Triangles`
140
- """
141
- degrees = adjacency.indptr[1:] - adjacency.indptr[:-1]
142
- edge_pairs = comb(degrees, 2).sum()
143
-
144
- dag = DAG(ordering='degree')
145
- dag.fit(adjacency)
146
- indptr = dag.indptr_
147
- indices = dag.indices_
148
-
149
- self.n_triangles_ = fit_core(indptr, indices, self.parallelize)
150
- if edge_pairs > 0:
151
- self.clustering_coef_ = 3 * self.n_triangles_ / edge_pairs
152
- else:
153
- self.clustering_coef_ = 0.
154
-
155
- return self
156
-
157
- def fit_transform(self, adjacency: sparse.csr_matrix) -> int:
158
- """ Fit algorithm to the data and return the number of triangles. Same parameters as the ``fit`` method.
159
-
160
- Returns
161
- -------
162
- n_triangles_ : int
163
- Number of triangles.
164
- """
165
- self.fit(adjacency)
166
- return self.n_triangles_
117
+ check_square(adjacency)
118
+ dag = get_dag(directed2undirected(adjacency))
119
+ indptr = dag.indptr
120
+ indices = dag.indices
121
+ n_triangles = count_triangles_from_dag(indptr, indices, parallelize)
122
+ return n_triangles
123
+
124
+ def get_clustering_coefficient(adjacency: sparse.csr_matrix, parallelize: bool = False) -> float:
125
+ """Get the clustering coefficient of a graph.
126
+
127
+ Parameters
128
+ ----------
129
+ adjacency :
130
+ Adjacency matrix of the graph.
131
+ parallelize :
132
+ If ``True``, use a parallel range while listing the triangles.
133
+
134
+ Returns
135
+ -------
136
+ coefficient : float
137
+ Clustering coefficient.
138
+
139
+ Example
140
+ -------
141
+ >>> from sknetwork.data import karate_club
142
+ >>> adjacency = karate_club()
143
+ >>> np.round(get_clustering_coefficient(adjacency), 2)
144
+ 0.26
145
+ """
146
+ n_triangles = count_triangles(adjacency, parallelize)
147
+ degrees = get_degrees(directed2undirected(adjacency))
148
+ degrees = degrees[degrees > 1]
149
+ n_edge_pairs = (degrees * (degrees - 1)).sum() / 2
150
+ coefficient = 3 * n_triangles / n_edge_pairs
151
+ return coefficient
@@ -1,7 +1,7 @@
1
1
  #!/usr/bin/env python3
2
2
  # -*- coding: utf-8 -*-
3
3
  """
4
- Created on July 2, 2020
4
+ Created in July 2020
5
5
  @author: Pierre Pebereau <pierre.pebereau@telecom-paris.fr>
6
6
  @author: Alexis Barreaux <alexis.barreaux@telecom-paris.fr>
7
7
  """
@@ -9,33 +9,33 @@ from typing import Union
9
9
 
10
10
  import numpy as np
11
11
  from scipy import sparse
12
- from sknetwork.topology.weisfeiler_lehman_core import weisfeiler_lehman_coloring
13
12
 
14
- from sknetwork.utils.base import Algorithm
13
+ from sknetwork.topology.weisfeiler_lehman_core import weisfeiler_lehman_coloring
14
+ from sknetwork.utils.check import check_format, check_square
15
15
 
16
16
 
17
- class WeisfeilerLehman(Algorithm):
18
- """Weisfeiler-Lehman algorithm for coloring/labeling graphs in order to check similarity.
17
+ def color_weisfeiler_lehman(adjacency: Union[sparse.csr_matrix, np.ndarray], max_iter: int = -1) -> np.ndarray:
18
+ """Color nodes using Weisfeiler-Lehman algorithm.
19
19
 
20
20
  Parameters
21
21
  ----------
22
+ adjacency : sparse.csr_matrix
23
+ Adjacency matrix of the graph
22
24
  max_iter : int
23
- Maximum number of iterations. Negative value means until convergence.
25
+ Maximum number of iterations. Negative value means no limit (until convergence).
24
26
 
25
- Attributes
26
- ----------
27
- labels_ : np.ndarray
27
+ Returns
28
+ -------
29
+ labels : np.ndarray
28
30
  Label of each node.
29
31
 
30
32
  Example
31
33
  -------
32
- >>> from sknetwork.topology import WeisfeilerLehman
33
34
  >>> from sknetwork.data import house
34
- >>> weisfeiler_lehman = WeisfeilerLehman()
35
35
  >>> adjacency = house()
36
- >>> labels = weisfeiler_lehman.fit_transform(adjacency)
37
- >>> labels
38
- array([0, 2, 1, 1, 2], dtype=int32)
36
+ >>> labels = color_weisfeiler_lehman(adjacency)
37
+ >>> print(labels)
38
+ [0 2 1 1 2]
39
39
 
40
40
  References
41
41
  ----------
@@ -45,57 +45,26 @@ class WeisfeilerLehman(Algorithm):
45
45
 
46
46
  * Shervashidze, N., Schweitzer, P., van Leeuwen, E. J., Melhorn, K., Borgwardt, K. M. (2011)
47
47
  `Weisfeiler-Lehman graph kernels.
48
- <http://www.jmlr.org/papers/volume12/shervashidze11a/shervashidze11a.pdf>`_
48
+ <https://www.jmlr.org/papers/volume12/shervashidze11a/shervashidze11a.pdf>`_
49
49
  Journal of Machine Learning Research 12, 2011.
50
50
  """
51
- def __init__(self, max_iter: int = -1):
52
- super(WeisfeilerLehman, self).__init__()
53
- self.max_iter = max_iter
54
- self.labels_ = None
55
-
56
- def fit(self, adjacency: Union[sparse.csr_matrix, np.ndarray]) -> 'WeisfeilerLehman':
57
- """Fit algorithm to the data.
58
-
59
- Parameters
60
- ----------
61
- adjacency : Union[sparse.csr_matrix, np.ndarray]
62
- Adjacency matrix of the graph.
63
-
64
- Returns
65
- -------
66
- self: :class:`WeisfeilerLehman`
67
- """
68
- n: int = adjacency.shape[0]
69
- if self.max_iter < 0 or self.max_iter > n:
70
- max_iter = np.int32(n)
71
- else:
72
- max_iter = np.int32(self.max_iter)
73
-
74
- labels = np.zeros(n, dtype=np.int32)
75
- powers = (-np.pi / 3.15) ** np.arange(n, dtype=np.double)
76
- indptr = adjacency.indptr.astype(np.int32)
77
- indices = adjacency.indices.astype(np.int32)
78
-
79
- labels, _ = weisfeiler_lehman_coloring(indptr, indices, labels, powers, max_iter)
80
- self.labels_ = np.asarray(labels).astype(np.int32)
81
- return self
82
-
83
- def fit_transform(self, adjacency: Union[sparse.csr_matrix, np.ndarray]) -> np.ndarray:
84
- """Fit algorithm to the data and return the labels. Same parameters as the ``fit`` method.
85
-
86
- Returns
87
- -------
88
- labels : np.ndarray
89
- Labels.
90
- """
91
- self.fit(adjacency)
92
- return self.labels_
93
-
94
-
95
- def are_isomorphic(adjacency1: sparse.csr_matrix,
96
- adjacency2: sparse.csr_matrix, max_iter: int = -1) -> bool:
97
- """Weisfeiler-Lehman isomorphism test. If the test is False, the graphs cannot be isomorphic,
98
- otherwise, they might be.
51
+
52
+ adjacency = check_format(adjacency, allow_empty=True)
53
+ check_square(adjacency)
54
+ n_nodes = adjacency.shape[0]
55
+ if max_iter < 0 or max_iter > n_nodes:
56
+ max_iter = n_nodes
57
+
58
+ labels = np.zeros(n_nodes, dtype=np.int32)
59
+ powers = (-np.pi / 3.15) ** np.arange(n_nodes, dtype=np.double)
60
+ indptr = adjacency.indptr
61
+ indices = adjacency.indices
62
+ labels, _ = weisfeiler_lehman_coloring(indptr, indices, labels, powers, max_iter)
63
+ return np.array(labels)
64
+
65
+
66
+ def are_isomorphic(adjacency1: sparse.csr_matrix, adjacency2: sparse.csr_matrix, max_iter: int = -1) -> bool:
67
+ """Weisfeiler-Lehman isomorphism test. If the test is False, the graphs cannot be isomorphic.
99
68
 
100
69
  Parameters
101
70
  -----------
@@ -104,7 +73,7 @@ def are_isomorphic(adjacency1: sparse.csr_matrix,
104
73
  adjacency2 :
105
74
  Second adjacency matrix.
106
75
  max_iter : int
107
- Maximum number of coloring iterations. Negative value means until convergence.
76
+ Maximum number of iterations. Negative value means no limit (until convergence).
108
77
 
109
78
  Returns
110
79
  -------
@@ -112,7 +81,6 @@ def are_isomorphic(adjacency1: sparse.csr_matrix,
112
81
 
113
82
  Example
114
83
  -------
115
- >>> from sknetwork.topology import are_isomorphic
116
84
  >>> from sknetwork.data import house, bow_tie
117
85
  >>> are_isomorphic(house(), bow_tie())
118
86
  False
@@ -125,39 +93,41 @@ def are_isomorphic(adjacency1: sparse.csr_matrix,
125
93
 
126
94
  * Shervashidze, N., Schweitzer, P., van Leeuwen, E. J., Melhorn, K., Borgwardt, K. M. (2011)
127
95
  `Weisfeiler-Lehman graph kernels.
128
- <http://www.jmlr.org/papers/volume12/shervashidze11a/shervashidze11a.pdf>`_
96
+ <https://www.jmlr.org/papers/volume12/shervashidze11a/shervashidze11a.pdf>`_
129
97
  Journal of Machine Learning Research 12, 2011.
130
98
  """
99
+ adjacency1 = check_format(adjacency1)
100
+ check_square(adjacency1)
101
+ adjacency2 = check_format(adjacency2)
102
+ check_square(adjacency2)
103
+
131
104
  if (adjacency1.shape != adjacency2.shape) or (adjacency1.nnz != adjacency2.nnz):
132
105
  return False
133
106
 
134
- n = adjacency1.shape[0]
107
+ n_nodes = adjacency1.shape[0]
135
108
 
136
- if max_iter < 0 or max_iter > n:
137
- max_iter = n
109
+ if max_iter < 0 or max_iter > n_nodes:
110
+ max_iter = n_nodes
138
111
 
139
- indptr1 = adjacency1.indptr.astype(np.int32)
140
- indptr2 = adjacency2.indptr.astype(np.int32)
141
- indices1 = adjacency1.indices.astype(np.int32)
142
- indices2 = adjacency2.indices.astype(np.int32)
112
+ indptr1 = adjacency1.indptr
113
+ indptr2 = adjacency2.indptr
114
+ indices1 = adjacency1.indices
115
+ indices2 = adjacency2.indices
143
116
 
144
- labels_1 = np.zeros(n, dtype=np.int32)
145
- labels_2 = np.zeros(n, dtype=np.int32)
117
+ labels1 = np.zeros(n_nodes, dtype=np.int32)
118
+ labels2 = np.zeros(n_nodes, dtype=np.int32)
146
119
 
147
- powers = (- np.pi / 3.15) ** np.arange(n, dtype=np.double)
120
+ powers = (-np.pi / 3.15) ** np.arange(n_nodes, dtype=np.double)
148
121
 
149
122
  iteration = 0
150
- has_changed_1, has_changed_2 = True, True
151
- while iteration < max_iter and (has_changed_1 or has_changed_2):
152
- labels_1, has_changed_1 = weisfeiler_lehman_coloring(indptr1, indices1, labels_1, powers, max_iter=1)
153
- labels_2, has_changed_2 = weisfeiler_lehman_coloring(indptr2, indices2, labels_2, powers, max_iter=1)
154
-
155
- colors_1, counts_1 = np.unique(np.asarray(labels_1), return_counts=True)
156
- colors_2, counts_2 = np.unique(np.asarray(labels_2), return_counts=True)
157
-
158
- if (colors_1.shape != colors_2.shape) or (counts_1 != counts_2).any():
123
+ has_changed1, has_changed2 = True, True
124
+ while iteration < max_iter and (has_changed1 or has_changed2):
125
+ labels1, has_changed1 = weisfeiler_lehman_coloring(indptr1, indices1, labels1, powers, max_iter=1)
126
+ labels2, has_changed2 = weisfeiler_lehman_coloring(indptr2, indices2, labels2, powers, max_iter=1)
127
+ _, counts1 = np.unique(np.array(labels1), return_counts=True)
128
+ _, counts2 = np.unique(np.array(labels2), return_counts=True)
129
+ if (counts1 != counts2).any():
159
130
  return False
160
-
161
131
  iteration += 1
162
132
 
163
133
  return True