scikit-network 0.30.0__cp38-cp38-win_amd64.whl → 0.32.1__cp38-cp38-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of scikit-network might be problematic. Click here for more details.
- {scikit_network-0.30.0.dist-info → scikit_network-0.32.1.dist-info}/AUTHORS.rst +3 -0
- {scikit_network-0.30.0.dist-info → scikit_network-0.32.1.dist-info}/METADATA +31 -3
- scikit_network-0.32.1.dist-info/RECORD +228 -0
- {scikit_network-0.30.0.dist-info → scikit_network-0.32.1.dist-info}/WHEEL +1 -1
- sknetwork/__init__.py +1 -1
- sknetwork/base.py +67 -0
- sknetwork/classification/base.py +24 -24
- sknetwork/classification/base_rank.py +17 -25
- sknetwork/classification/diffusion.py +35 -35
- sknetwork/classification/knn.py +24 -21
- sknetwork/classification/metrics.py +1 -1
- sknetwork/classification/pagerank.py +10 -10
- sknetwork/classification/propagation.py +23 -20
- sknetwork/classification/tests/test_diffusion.py +13 -3
- sknetwork/classification/vote.cp38-win_amd64.pyd +0 -0
- sknetwork/classification/vote.cpp +14482 -10351
- sknetwork/classification/vote.pyx +1 -3
- sknetwork/clustering/__init__.py +3 -1
- sknetwork/clustering/base.py +36 -40
- sknetwork/clustering/kcenters.py +253 -0
- sknetwork/clustering/leiden.py +241 -0
- sknetwork/clustering/leiden_core.cp38-win_amd64.pyd +0 -0
- sknetwork/clustering/leiden_core.cpp +31564 -0
- sknetwork/clustering/leiden_core.pyx +124 -0
- sknetwork/clustering/louvain.py +133 -102
- sknetwork/clustering/louvain_core.cp38-win_amd64.pyd +0 -0
- sknetwork/clustering/louvain_core.cpp +22457 -18792
- sknetwork/clustering/louvain_core.pyx +86 -96
- sknetwork/clustering/postprocess.py +2 -2
- sknetwork/clustering/propagation_clustering.py +15 -19
- sknetwork/clustering/tests/test_API.py +8 -4
- sknetwork/clustering/tests/test_kcenters.py +92 -0
- sknetwork/clustering/tests/test_leiden.py +34 -0
- sknetwork/clustering/tests/test_louvain.py +3 -4
- sknetwork/data/__init__.py +2 -1
- sknetwork/data/base.py +28 -0
- sknetwork/data/load.py +38 -37
- sknetwork/data/models.py +18 -18
- sknetwork/data/parse.py +54 -33
- sknetwork/data/test_graphs.py +2 -2
- sknetwork/data/tests/test_API.py +1 -1
- sknetwork/data/tests/test_base.py +14 -0
- sknetwork/data/tests/test_load.py +1 -1
- sknetwork/data/tests/test_parse.py +9 -12
- sknetwork/data/tests/test_test_graphs.py +1 -2
- sknetwork/data/toy_graphs.py +18 -18
- sknetwork/embedding/__init__.py +0 -1
- sknetwork/embedding/base.py +21 -20
- sknetwork/embedding/force_atlas.py +3 -2
- sknetwork/embedding/louvain_embedding.py +2 -2
- sknetwork/embedding/random_projection.py +5 -3
- sknetwork/embedding/spectral.py +0 -73
- sknetwork/embedding/tests/test_API.py +4 -28
- sknetwork/embedding/tests/test_louvain_embedding.py +4 -9
- sknetwork/embedding/tests/test_random_projection.py +2 -2
- sknetwork/embedding/tests/test_spectral.py +5 -8
- sknetwork/embedding/tests/test_svd.py +1 -1
- sknetwork/gnn/base.py +4 -4
- sknetwork/gnn/base_layer.py +3 -3
- sknetwork/gnn/gnn_classifier.py +45 -89
- sknetwork/gnn/layer.py +1 -1
- sknetwork/gnn/loss.py +1 -1
- sknetwork/gnn/optimizer.py +4 -3
- sknetwork/gnn/tests/test_base_layer.py +4 -4
- sknetwork/gnn/tests/test_gnn_classifier.py +12 -35
- sknetwork/gnn/utils.py +8 -8
- sknetwork/hierarchy/base.py +29 -2
- sknetwork/hierarchy/louvain_hierarchy.py +45 -41
- sknetwork/hierarchy/paris.cp38-win_amd64.pyd +0 -0
- sknetwork/hierarchy/paris.cpp +27371 -22844
- sknetwork/hierarchy/paris.pyx +7 -9
- sknetwork/hierarchy/postprocess.py +16 -16
- sknetwork/hierarchy/tests/test_API.py +1 -1
- sknetwork/hierarchy/tests/test_algos.py +5 -0
- sknetwork/hierarchy/tests/test_metrics.py +1 -1
- sknetwork/linalg/__init__.py +1 -1
- sknetwork/linalg/diteration.cp38-win_amd64.pyd +0 -0
- sknetwork/linalg/diteration.cpp +13474 -9454
- sknetwork/linalg/diteration.pyx +0 -2
- sknetwork/linalg/eig_solver.py +1 -1
- sknetwork/linalg/{normalization.py → normalizer.py} +18 -15
- sknetwork/linalg/operators.py +1 -1
- sknetwork/linalg/ppr_solver.py +1 -1
- sknetwork/linalg/push.cp38-win_amd64.pyd +0 -0
- sknetwork/linalg/push.cpp +23003 -18807
- sknetwork/linalg/push.pyx +0 -2
- sknetwork/linalg/svd_solver.py +1 -1
- sknetwork/linalg/tests/test_normalization.py +3 -7
- sknetwork/linalg/tests/test_operators.py +4 -8
- sknetwork/linalg/tests/test_ppr.py +1 -1
- sknetwork/linkpred/base.py +13 -2
- sknetwork/linkpred/nn.py +6 -6
- sknetwork/log.py +19 -0
- sknetwork/path/__init__.py +4 -3
- sknetwork/path/dag.py +54 -0
- sknetwork/path/distances.py +98 -0
- sknetwork/path/search.py +13 -47
- sknetwork/path/shortest_path.py +37 -162
- sknetwork/path/tests/test_dag.py +37 -0
- sknetwork/path/tests/test_distances.py +62 -0
- sknetwork/path/tests/test_search.py +26 -11
- sknetwork/path/tests/test_shortest_path.py +31 -36
- sknetwork/ranking/__init__.py +0 -1
- sknetwork/ranking/base.py +13 -8
- sknetwork/ranking/betweenness.cp38-win_amd64.pyd +0 -0
- sknetwork/ranking/betweenness.cpp +5709 -3017
- sknetwork/ranking/betweenness.pyx +0 -2
- sknetwork/ranking/closeness.py +7 -10
- sknetwork/ranking/pagerank.py +14 -14
- sknetwork/ranking/postprocess.py +12 -3
- sknetwork/ranking/tests/test_API.py +2 -4
- sknetwork/ranking/tests/test_betweenness.py +3 -3
- sknetwork/ranking/tests/test_closeness.py +3 -7
- sknetwork/ranking/tests/test_pagerank.py +11 -5
- sknetwork/ranking/tests/test_postprocess.py +5 -0
- sknetwork/regression/base.py +19 -2
- sknetwork/regression/diffusion.py +24 -10
- sknetwork/regression/tests/test_diffusion.py +8 -0
- sknetwork/test_base.py +35 -0
- sknetwork/test_log.py +15 -0
- sknetwork/topology/__init__.py +7 -8
- sknetwork/topology/cliques.cp38-win_amd64.pyd +0 -0
- sknetwork/topology/{kcliques.cpp → cliques.cpp} +23423 -20277
- sknetwork/topology/cliques.pyx +149 -0
- sknetwork/topology/core.cp38-win_amd64.pyd +0 -0
- sknetwork/topology/{kcore.cpp → core.cpp} +21637 -18762
- sknetwork/topology/core.pyx +90 -0
- sknetwork/topology/cycles.py +243 -0
- sknetwork/topology/minheap.cp38-win_amd64.pyd +0 -0
- sknetwork/{utils → topology}/minheap.cpp +19452 -15368
- sknetwork/{utils → topology}/minheap.pxd +1 -3
- sknetwork/{utils → topology}/minheap.pyx +1 -3
- sknetwork/topology/structure.py +3 -43
- sknetwork/topology/tests/test_cliques.py +11 -11
- sknetwork/topology/tests/test_core.py +19 -0
- sknetwork/topology/tests/test_cycles.py +65 -0
- sknetwork/topology/tests/test_structure.py +2 -16
- sknetwork/topology/tests/test_triangles.py +11 -15
- sknetwork/topology/tests/test_wl.py +72 -0
- sknetwork/topology/triangles.cp38-win_amd64.pyd +0 -0
- sknetwork/topology/triangles.cpp +5056 -2696
- sknetwork/topology/triangles.pyx +74 -89
- sknetwork/topology/weisfeiler_lehman.py +56 -86
- sknetwork/topology/weisfeiler_lehman_core.cp38-win_amd64.pyd +0 -0
- sknetwork/topology/weisfeiler_lehman_core.cpp +14727 -10622
- sknetwork/topology/weisfeiler_lehman_core.pyx +0 -2
- sknetwork/utils/__init__.py +1 -31
- sknetwork/utils/check.py +2 -2
- sknetwork/utils/format.py +5 -3
- sknetwork/utils/membership.py +2 -2
- sknetwork/utils/tests/test_check.py +3 -3
- sknetwork/utils/tests/test_format.py +3 -1
- sknetwork/utils/values.py +1 -1
- sknetwork/visualization/__init__.py +2 -2
- sknetwork/visualization/dendrograms.py +55 -7
- sknetwork/visualization/graphs.py +292 -72
- sknetwork/visualization/tests/test_dendrograms.py +9 -9
- sknetwork/visualization/tests/test_graphs.py +71 -62
- scikit_network-0.30.0.dist-info/RECORD +0 -227
- sknetwork/embedding/louvain_hierarchy.py +0 -142
- sknetwork/embedding/tests/test_louvain_hierarchy.py +0 -19
- sknetwork/path/metrics.py +0 -148
- sknetwork/path/tests/test_metrics.py +0 -29
- sknetwork/ranking/harmonic.py +0 -82
- sknetwork/topology/dag.py +0 -74
- sknetwork/topology/dag_core.cp38-win_amd64.pyd +0 -0
- sknetwork/topology/dag_core.cpp +0 -23350
- sknetwork/topology/dag_core.pyx +0 -38
- sknetwork/topology/kcliques.cp38-win_amd64.pyd +0 -0
- sknetwork/topology/kcliques.pyx +0 -193
- sknetwork/topology/kcore.cp38-win_amd64.pyd +0 -0
- sknetwork/topology/kcore.pyx +0 -120
- sknetwork/topology/tests/test_cores.py +0 -21
- sknetwork/topology/tests/test_dag.py +0 -26
- sknetwork/topology/tests/test_wl_coloring.py +0 -49
- sknetwork/topology/tests/test_wl_kernel.py +0 -31
- sknetwork/utils/base.py +0 -35
- sknetwork/utils/minheap.cp38-win_amd64.pyd +0 -0
- sknetwork/utils/simplex.py +0 -140
- sknetwork/utils/tests/test_base.py +0 -28
- sknetwork/utils/tests/test_bunch.py +0 -16
- sknetwork/utils/tests/test_projection_simplex.py +0 -33
- sknetwork/utils/tests/test_verbose.py +0 -15
- sknetwork/utils/verbose.py +0 -37
- {scikit_network-0.30.0.dist-info → scikit_network-0.32.1.dist-info}/LICENSE +0 -0
- {scikit_network-0.30.0.dist-info → scikit_network-0.32.1.dist-info}/top_level.txt +0 -0
- /sknetwork/{utils → data}/timeout.py +0 -0
|
@@ -20,6 +20,10 @@ class TestParser(unittest.TestCase):
|
|
|
20
20
|
self.assertTrue((adjacency.indices == [2, 3, 0, 1, 5, 4]).all())
|
|
21
21
|
self.assertTrue((adjacency.indptr == [0, 1, 2, 3, 4, 5, 6]).all())
|
|
22
22
|
self.assertTrue((adjacency.data == [1, 1, 1, 1, 1, 1]).all())
|
|
23
|
+
adjacency = parse.from_csv(self.stub_data_1, shape=(7, 7))
|
|
24
|
+
self.assertTrue((adjacency.shape == (7, 7)))
|
|
25
|
+
biadjacency = parse.from_csv(self.stub_data_1, bipartite=True, shape=(7, 9))
|
|
26
|
+
self.assertTrue((biadjacency.shape == (7, 9)))
|
|
23
27
|
remove(self.stub_data_1)
|
|
24
28
|
|
|
25
29
|
def test_labeled_weighted(self):
|
|
@@ -33,13 +37,14 @@ class TestParser(unittest.TestCase):
|
|
|
33
37
|
self.assertTrue((adjacency.indptr == [0, 1, 2, 3, 4, 5, 6]).all())
|
|
34
38
|
self.assertTrue((adjacency.data == [1, 6, 5, 6, 1, 5]).all())
|
|
35
39
|
self.assertTrue((names == [' b', ' d', ' e', 'a', 'c', 'f']).all())
|
|
40
|
+
|
|
36
41
|
remove(self.stub_data_2)
|
|
37
42
|
|
|
38
43
|
def test_auto_reindex(self):
|
|
39
44
|
self.stub_data_4 = 'stub_4.txt'
|
|
40
45
|
with open(self.stub_data_4, "w") as text_file:
|
|
41
46
|
text_file.write('%stub\n14 31\n42 50\n0 12')
|
|
42
|
-
graph = parse.from_csv(self.stub_data_4)
|
|
47
|
+
graph = parse.from_csv(self.stub_data_4, reindex=True)
|
|
43
48
|
adjacency = graph.adjacency
|
|
44
49
|
names = graph.names
|
|
45
50
|
self.assertTrue((adjacency.data == [1, 1, 1, 1, 1, 1]).all())
|
|
@@ -164,23 +169,15 @@ class TestParser(unittest.TestCase):
|
|
|
164
169
|
self.stub_data_9 = 'stub_9.txt'
|
|
165
170
|
with open(self.stub_data_9, "w") as text_file:
|
|
166
171
|
text_file.write('#stub\n1 3\n4 5\n0 3')
|
|
167
|
-
graph = parse.from_csv(self.stub_data_9, bipartite=True)
|
|
172
|
+
graph = parse.from_csv(self.stub_data_9, bipartite=True, reindex=True)
|
|
168
173
|
biadjacency = graph.biadjacency
|
|
169
174
|
self.assertTrue((biadjacency.indices == [0, 0, 1]).all())
|
|
170
175
|
self.assertTrue((biadjacency.indptr == [0, 1, 2, 3]).all())
|
|
171
176
|
self.assertTrue((biadjacency.data == [1, 1, 1]).all())
|
|
177
|
+
biadjacency = parse.from_csv(self.stub_data_9, bipartite=True)
|
|
178
|
+
self.assertTrue(biadjacency.shape == (5, 6))
|
|
172
179
|
remove(self.stub_data_9)
|
|
173
180
|
|
|
174
|
-
def test_csv_adjacency_bipartite(self):
|
|
175
|
-
self.stub_data_10 = 'stub_10.txt'
|
|
176
|
-
with open(self.stub_data_10, "w") as text_file:
|
|
177
|
-
text_file.write('%stub\n3\n3\n0')
|
|
178
|
-
graph = parse.from_csv(self.stub_data_10, bipartite=True)
|
|
179
|
-
biadjacency = graph.biadjacency
|
|
180
|
-
self.assertTupleEqual(biadjacency.shape, (3, 2))
|
|
181
|
-
self.assertTrue((biadjacency.data == [1, 1, 1]).all())
|
|
182
|
-
remove(self.stub_data_10)
|
|
183
|
-
|
|
184
181
|
def test_edge_list(self):
|
|
185
182
|
edge_list_1 = [('Alice', 'Bob'), ('Carol', 'Alice')]
|
|
186
183
|
graph = parse.from_edge_list(edge_list_1)
|
|
@@ -5,7 +5,6 @@
|
|
|
5
5
|
@author: Nathan de Lara <nathan.delara@polytechnique.org>
|
|
6
6
|
@author: Thomas Bonald <tbonald@enst.fr>
|
|
7
7
|
"""
|
|
8
|
-
|
|
9
8
|
import unittest
|
|
10
9
|
|
|
11
10
|
from sknetwork.data.test_graphs import *
|
|
@@ -16,7 +15,7 @@ class TestTestGraphs(unittest.TestCase):
|
|
|
16
15
|
def test_undirected(self):
|
|
17
16
|
adjacency = test_graph()
|
|
18
17
|
self.assertEqual(adjacency.shape, (10, 10))
|
|
19
|
-
adjacency =
|
|
18
|
+
adjacency = test_disconnected_graph()
|
|
20
19
|
self.assertEqual(adjacency.shape, (10, 10))
|
|
21
20
|
|
|
22
21
|
def test_directed(self):
|
sknetwork/data/toy_graphs.py
CHANGED
|
@@ -11,7 +11,7 @@ from typing import Union
|
|
|
11
11
|
import numpy as np
|
|
12
12
|
from scipy import sparse
|
|
13
13
|
|
|
14
|
-
from sknetwork.
|
|
14
|
+
from sknetwork.data.base import Bunch
|
|
15
15
|
|
|
16
16
|
|
|
17
17
|
def house(metadata: bool = False) -> Union[sparse.csr_matrix, Bunch]:
|
|
@@ -23,11 +23,11 @@ def house(metadata: bool = False) -> Union[sparse.csr_matrix, Bunch]:
|
|
|
23
23
|
Parameters
|
|
24
24
|
----------
|
|
25
25
|
metadata :
|
|
26
|
-
If ``True``, return a `
|
|
26
|
+
If ``True``, return a `Dataset` object with metadata.
|
|
27
27
|
|
|
28
28
|
Returns
|
|
29
29
|
-------
|
|
30
|
-
adjacency or graph : Union[sparse.csr_matrix,
|
|
30
|
+
adjacency or graph : Union[sparse.csr_matrix, Dataset]
|
|
31
31
|
Adjacency matrix or graph with metadata (positions).
|
|
32
32
|
|
|
33
33
|
Example
|
|
@@ -64,11 +64,11 @@ def bow_tie(metadata: bool = False) -> Union[sparse.csr_matrix, Bunch]:
|
|
|
64
64
|
Parameters
|
|
65
65
|
----------
|
|
66
66
|
metadata :
|
|
67
|
-
If ``True``, return a `
|
|
67
|
+
If ``True``, return a `Dataset` object with metadata.
|
|
68
68
|
|
|
69
69
|
Returns
|
|
70
70
|
-------
|
|
71
|
-
adjacency or graph : Union[sparse.csr_matrix,
|
|
71
|
+
adjacency or graph : Union[sparse.csr_matrix, Dataset]
|
|
72
72
|
Adjacency matrix or graph with metadata (positions).
|
|
73
73
|
|
|
74
74
|
Example
|
|
@@ -105,11 +105,11 @@ def karate_club(metadata: bool = False) -> Union[sparse.csr_matrix, Bunch]:
|
|
|
105
105
|
Parameters
|
|
106
106
|
----------
|
|
107
107
|
metadata :
|
|
108
|
-
If ``True``, return a `
|
|
108
|
+
If ``True``, return a `Dataset` object with metadata.
|
|
109
109
|
|
|
110
110
|
Returns
|
|
111
111
|
-------
|
|
112
|
-
adjacency or graph : Union[sparse.csr_matrix,
|
|
112
|
+
adjacency or graph : Union[sparse.csr_matrix, Dataset]
|
|
113
113
|
Adjacency matrix or graph with metadata (labels, positions).
|
|
114
114
|
|
|
115
115
|
Example
|
|
@@ -170,11 +170,11 @@ def miserables(metadata: bool = False) -> Union[sparse.csr_matrix, Bunch]:
|
|
|
170
170
|
Parameters
|
|
171
171
|
----------
|
|
172
172
|
metadata :
|
|
173
|
-
If ``True``, return a `
|
|
173
|
+
If ``True``, return a `Dataset` object with metadata.
|
|
174
174
|
|
|
175
175
|
Returns
|
|
176
176
|
-------
|
|
177
|
-
adjacency or graph : Union[sparse.csr_matrix,
|
|
177
|
+
adjacency or graph : Union[sparse.csr_matrix, Dataset]
|
|
178
178
|
Adjacency matrix or graph with metadata (names, positions).
|
|
179
179
|
|
|
180
180
|
Example
|
|
@@ -277,11 +277,11 @@ def painters(metadata: bool = False) -> Union[sparse.csr_matrix, Bunch]:
|
|
|
277
277
|
Parameters
|
|
278
278
|
----------
|
|
279
279
|
metadata :
|
|
280
|
-
If ``True``, return a `
|
|
280
|
+
If ``True``, return a `Dataset` object with metadata.
|
|
281
281
|
|
|
282
282
|
Returns
|
|
283
283
|
-------
|
|
284
|
-
adjacency or graph : Union[sparse.csr_matrix,
|
|
284
|
+
adjacency or graph : Union[sparse.csr_matrix, Dataset]
|
|
285
285
|
Adjacency matrix or graph with metadata (names, positions).
|
|
286
286
|
|
|
287
287
|
Example
|
|
@@ -330,7 +330,7 @@ def hourglass(metadata: bool = False) -> Union[sparse.csr_matrix, Bunch]:
|
|
|
330
330
|
|
|
331
331
|
Returns
|
|
332
332
|
-------
|
|
333
|
-
biadjacency or graph : Union[sparse.csr_matrix,
|
|
333
|
+
biadjacency or graph : Union[sparse.csr_matrix, Dataset]
|
|
334
334
|
Biadjacency matrix or graph.
|
|
335
335
|
|
|
336
336
|
Example
|
|
@@ -359,11 +359,11 @@ def star_wars(metadata: bool = False) -> Union[sparse.csr_matrix, Bunch]:
|
|
|
359
359
|
Parameters
|
|
360
360
|
----------
|
|
361
361
|
metadata :
|
|
362
|
-
If ``True``, return a `
|
|
362
|
+
If ``True``, return a `Dataset` object with metadata.
|
|
363
363
|
|
|
364
364
|
Returns
|
|
365
365
|
-------
|
|
366
|
-
biadjacency or graph : Union[sparse.csr_matrix,
|
|
366
|
+
biadjacency or graph : Union[sparse.csr_matrix, Dataset]
|
|
367
367
|
Biadjacency matrix or graph with metadata (names).
|
|
368
368
|
|
|
369
369
|
Example
|
|
@@ -403,11 +403,11 @@ def movie_actor(metadata: bool = False) -> Union[sparse.csr_matrix, Bunch]:
|
|
|
403
403
|
Parameters
|
|
404
404
|
----------
|
|
405
405
|
metadata :
|
|
406
|
-
If ``True``, return a `
|
|
406
|
+
If ``True``, return a `Dataset` object with metadata.
|
|
407
407
|
|
|
408
408
|
Returns
|
|
409
409
|
-------
|
|
410
|
-
biadjacency or graph : Union[sparse.csr_matrix,
|
|
410
|
+
biadjacency or graph : Union[sparse.csr_matrix, Dataset]
|
|
411
411
|
Biadjacency matrix or graph with metadata (names).
|
|
412
412
|
|
|
413
413
|
Example
|
|
@@ -465,11 +465,11 @@ def art_philo_science(metadata: bool = False) -> Union[sparse.csr_matrix, Bunch]
|
|
|
465
465
|
Parameters
|
|
466
466
|
----------
|
|
467
467
|
metadata :
|
|
468
|
-
If ``True``, return a `
|
|
468
|
+
If ``True``, return a `Dataset` object with metadata.
|
|
469
469
|
|
|
470
470
|
Returns
|
|
471
471
|
-------
|
|
472
|
-
adjacency or graph : Union[sparse.csr_matrix,
|
|
472
|
+
adjacency or graph : Union[sparse.csr_matrix, Dataset]
|
|
473
473
|
Adjacency matrix or graph with metadata (names, positions, labels, names_labels,
|
|
474
474
|
biadjacency, names_col).
|
|
475
475
|
|
sknetwork/embedding/__init__.py
CHANGED
|
@@ -2,7 +2,6 @@
|
|
|
2
2
|
from sknetwork.embedding.base import BaseEmbedding
|
|
3
3
|
from sknetwork.embedding.force_atlas import ForceAtlas
|
|
4
4
|
from sknetwork.embedding.louvain_embedding import LouvainEmbedding
|
|
5
|
-
from sknetwork.embedding.louvain_hierarchy import LouvainNE
|
|
6
5
|
from sknetwork.embedding.random_projection import RandomProjection
|
|
7
6
|
from sknetwork.embedding.spectral import Spectral
|
|
8
7
|
from sknetwork.embedding.spring import Spring
|
sknetwork/embedding/base.py
CHANGED
|
@@ -1,17 +1,17 @@
|
|
|
1
1
|
#!/usr/bin/env python3
|
|
2
2
|
# -*- coding: utf-8 -*-
|
|
3
3
|
"""
|
|
4
|
-
Created
|
|
4
|
+
Created in November 2019
|
|
5
5
|
@author: Nathan de Lara <nathan.delara@polytechnique.org>
|
|
6
6
|
"""
|
|
7
7
|
from abc import ABC
|
|
8
|
-
from typing import Union
|
|
8
|
+
from typing import Optional, Union
|
|
9
9
|
|
|
10
10
|
import numpy as np
|
|
11
11
|
from scipy import sparse
|
|
12
12
|
|
|
13
13
|
from sknetwork.topology.structure import is_connected
|
|
14
|
-
from sknetwork.
|
|
14
|
+
from sknetwork.base import Algorithm
|
|
15
15
|
|
|
16
16
|
|
|
17
17
|
class BaseEmbedding(Algorithm, ABC):
|
|
@@ -26,10 +26,19 @@ class BaseEmbedding(Algorithm, ABC):
|
|
|
26
26
|
embedding_col_ : array, shape = (n_col, n_components)
|
|
27
27
|
Embedding of the columns, for bipartite graphs.
|
|
28
28
|
"""
|
|
29
|
-
|
|
30
29
|
def __init__(self):
|
|
31
30
|
self._init_vars()
|
|
32
31
|
|
|
32
|
+
def transform(self) -> np.ndarray:
|
|
33
|
+
"""Return the embedding.
|
|
34
|
+
|
|
35
|
+
Returns
|
|
36
|
+
-------
|
|
37
|
+
embedding : np.ndarray
|
|
38
|
+
Embedding.
|
|
39
|
+
"""
|
|
40
|
+
return self.embedding_
|
|
41
|
+
|
|
33
42
|
def fit_transform(self, *args, **kwargs) -> np.ndarray:
|
|
34
43
|
"""Fit to data and return the embedding. Same parameters as the ``fit`` method.
|
|
35
44
|
|
|
@@ -41,30 +50,22 @@ class BaseEmbedding(Algorithm, ABC):
|
|
|
41
50
|
self.fit(*args, **kwargs)
|
|
42
51
|
return self.embedding_
|
|
43
52
|
|
|
44
|
-
def predict(self,
|
|
45
|
-
"""
|
|
46
|
-
|
|
47
|
-
Each new node is defined by its adjacency row vector.
|
|
53
|
+
def predict(self, columns: bool = False) -> np.ndarray:
|
|
54
|
+
"""Return the embedding of nodes.
|
|
48
55
|
|
|
49
56
|
Parameters
|
|
50
57
|
----------
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
Array of shape (n_col,) (single vector) or (n_vectors, n_col)
|
|
58
|
+
columns : bool
|
|
59
|
+
If ``True``, return the prediction for columns.
|
|
54
60
|
|
|
55
61
|
Returns
|
|
56
62
|
-------
|
|
57
|
-
|
|
63
|
+
embedding_ : np.ndarray
|
|
58
64
|
Embedding of the nodes.
|
|
59
65
|
"""
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
if self.embedding_ is None:
|
|
64
|
-
raise ValueError("This embedding instance is not fitted yet."
|
|
65
|
-
" Call 'fit' with appropriate arguments before using this method.")
|
|
66
|
-
else:
|
|
67
|
-
return self
|
|
66
|
+
if columns:
|
|
67
|
+
return self.embedding_col_
|
|
68
|
+
return self.embedding_
|
|
68
69
|
|
|
69
70
|
@staticmethod
|
|
70
71
|
def _get_regularization(regularization: float, adjacency: sparse.csr_matrix) -> float:
|
|
@@ -77,6 +77,7 @@ class ForceAtlas(BaseEmbedding):
|
|
|
77
77
|
self.tolerance = tolerance
|
|
78
78
|
self.speed = speed
|
|
79
79
|
self.speed_max = speed_max
|
|
80
|
+
self.embedding_ = None
|
|
80
81
|
|
|
81
82
|
def fit(self, adjacency: Union[sparse.csr_matrix, np.ndarray], pos_init: Optional[np.ndarray] = None,
|
|
82
83
|
n_iter: Optional[int] = None) -> 'ForceAtlas':
|
|
@@ -155,7 +156,7 @@ class ForceAtlas(BaseEmbedding):
|
|
|
155
156
|
if tree is None:
|
|
156
157
|
neighbors = np.arange(n)
|
|
157
158
|
else:
|
|
158
|
-
neighbors = tree.query_ball_point(position[i], self.approx_radius)
|
|
159
|
+
neighbors = tree.query_ball_point(position[i], self.approx_radius, p=2)
|
|
159
160
|
|
|
160
161
|
grad: np.ndarray = (position[i] - position[neighbors]) # shape (n_neigh, n_components)
|
|
161
162
|
distance: np.ndarray = np.linalg.norm(grad, axis=1) # shape (n_neigh,)
|
|
@@ -191,7 +192,7 @@ class ForceAtlas(BaseEmbedding):
|
|
|
191
192
|
|
|
192
193
|
position += delta # calculating displacement and final position of points after iteration
|
|
193
194
|
if (swing_vector < 1).all():
|
|
194
|
-
break # if the swing of all nodes is zero, then convergence is reached
|
|
195
|
+
break # if the swing of all nodes is zero, then convergence is reached.
|
|
195
196
|
|
|
196
197
|
self.embedding_ = position
|
|
197
198
|
return self
|
|
@@ -12,7 +12,7 @@ from scipy import sparse
|
|
|
12
12
|
|
|
13
13
|
from sknetwork.clustering.louvain import Louvain
|
|
14
14
|
from sknetwork.embedding.base import BaseEmbedding
|
|
15
|
-
from sknetwork.linalg.
|
|
15
|
+
from sknetwork.linalg.normalizer import normalize
|
|
16
16
|
from sknetwork.utils.check import check_random_state, check_adjacency_vector, check_nonnegative, is_square
|
|
17
17
|
from sknetwork.utils.membership import get_membership
|
|
18
18
|
|
|
@@ -121,7 +121,7 @@ class LouvainEmbedding(BaseEmbedding):
|
|
|
121
121
|
louvain = Louvain(resolution=self.resolution, modularity=self.modularity,
|
|
122
122
|
tol_optimization=self.tol_optimization, tol_aggregation=self.tol_aggregation,
|
|
123
123
|
n_aggregations=self.n_aggregations, shuffle_nodes=self.shuffle_nodes, sort_clusters=False,
|
|
124
|
-
|
|
124
|
+
return_probs=True, return_aggregate=True, random_state=self.random_state)
|
|
125
125
|
louvain.fit(input_matrix, force_bipartite=force_bipartite)
|
|
126
126
|
|
|
127
127
|
# isolated nodes
|
|
@@ -1,9 +1,10 @@
|
|
|
1
1
|
#!/usr/bin/env python3
|
|
2
2
|
# coding: utf-8
|
|
3
3
|
"""
|
|
4
|
-
Created
|
|
4
|
+
Created in January 2021
|
|
5
5
|
@author: Thomas Bonald <bonald@enst.fr>
|
|
6
6
|
"""
|
|
7
|
+
from abc import ABC
|
|
7
8
|
from typing import Union
|
|
8
9
|
|
|
9
10
|
import numpy as np
|
|
@@ -15,7 +16,7 @@ from sknetwork.utils.check import check_format, check_random_state
|
|
|
15
16
|
from sknetwork.utils.format import get_adjacency
|
|
16
17
|
|
|
17
18
|
|
|
18
|
-
class RandomProjection(BaseEmbedding):
|
|
19
|
+
class RandomProjection(BaseEmbedding, ABC):
|
|
19
20
|
"""Embedding of graphs based the random projection of the adjacency matrix:
|
|
20
21
|
|
|
21
22
|
:math:`(I + \\alpha A +... + (\\alpha A)^K)G`
|
|
@@ -71,6 +72,7 @@ class RandomProjection(BaseEmbedding):
|
|
|
71
72
|
regularization: float = -1, normalized: bool = True, random_state: int = None):
|
|
72
73
|
super(RandomProjection, self).__init__()
|
|
73
74
|
|
|
75
|
+
self.embedding_ = None
|
|
74
76
|
self.n_components = n_components
|
|
75
77
|
self.alpha = alpha
|
|
76
78
|
self.n_iter = n_iter
|
|
@@ -87,7 +89,7 @@ class RandomProjection(BaseEmbedding):
|
|
|
87
89
|
|
|
88
90
|
Parameters
|
|
89
91
|
----------
|
|
90
|
-
input_matrix :
|
|
92
|
+
input_matrix : sparse.csr_matrix, np.ndarray
|
|
91
93
|
Adjacency matrix or biadjacency matrix of the graph.
|
|
92
94
|
force_bipartite : bool (default = ``False``)
|
|
93
95
|
If ``True``, force the input matrix to be considered as a biadjacency matrix.
|
sknetwork/embedding/spectral.py
CHANGED
|
@@ -139,76 +139,3 @@ class Spectral(BaseEmbedding):
|
|
|
139
139
|
self._split_vars(input_matrix.shape)
|
|
140
140
|
|
|
141
141
|
return self
|
|
142
|
-
|
|
143
|
-
def predict(self, adjacency_vectors: Union[sparse.csr_matrix, np.ndarray]) -> np.ndarray:
|
|
144
|
-
"""Predict the embedding of new nodes, when possible (otherwise return 0).
|
|
145
|
-
|
|
146
|
-
Each new node is defined by its adjacency row vector.
|
|
147
|
-
|
|
148
|
-
Parameters
|
|
149
|
-
----------
|
|
150
|
-
adjacency_vectors :
|
|
151
|
-
Adjacency vectors of nodes.
|
|
152
|
-
Array of shape (n_col,) (single vector) or (n_vectors, n_col)
|
|
153
|
-
|
|
154
|
-
Returns
|
|
155
|
-
-------
|
|
156
|
-
embedding_vectors : np.ndarray
|
|
157
|
-
Embedding of the nodes.
|
|
158
|
-
|
|
159
|
-
Example
|
|
160
|
-
-------
|
|
161
|
-
>>> from sknetwork.embedding import Spectral
|
|
162
|
-
>>> from sknetwork.data import karate_club
|
|
163
|
-
>>> spectral = Spectral(n_components=3)
|
|
164
|
-
>>> adjacency = karate_club()
|
|
165
|
-
>>> adjacency_vector = np.arange(34) < 5
|
|
166
|
-
>>> _ = spectral.fit(adjacency)
|
|
167
|
-
>>> len(spectral.predict(adjacency_vector))
|
|
168
|
-
3
|
|
169
|
-
"""
|
|
170
|
-
self._check_fitted()
|
|
171
|
-
|
|
172
|
-
# input
|
|
173
|
-
if self.bipartite:
|
|
174
|
-
n = len(self.embedding_col_)
|
|
175
|
-
else:
|
|
176
|
-
n = len(self.embedding_)
|
|
177
|
-
adjacency_vectors = check_adjacency_vector(adjacency_vectors, n)
|
|
178
|
-
check_nonnegative(adjacency_vectors)
|
|
179
|
-
|
|
180
|
-
if self.bipartite:
|
|
181
|
-
shape = (adjacency_vectors.shape[0], self.embedding_row_.shape[0])
|
|
182
|
-
adjacency_vectors = sparse.csr_matrix(adjacency_vectors)
|
|
183
|
-
adjacency_vectors = sparse.hstack([sparse.csr_matrix(shape), adjacency_vectors], format='csr')
|
|
184
|
-
eigenvectors = self.eigenvectors_
|
|
185
|
-
eigenvalues = self.eigenvalues_
|
|
186
|
-
|
|
187
|
-
# regularization
|
|
188
|
-
if self.regularized:
|
|
189
|
-
regularization = np.abs(self.regularization)
|
|
190
|
-
else:
|
|
191
|
-
regularization = 0
|
|
192
|
-
normalizer = Normalizer(adjacency_vectors, regularization)
|
|
193
|
-
|
|
194
|
-
# prediction
|
|
195
|
-
embedding_vectors = normalizer.dot(eigenvectors)
|
|
196
|
-
normalized_laplacian = self.decomposition == 'rw'
|
|
197
|
-
if normalized_laplacian:
|
|
198
|
-
norm_vect = eigenvalues.copy()
|
|
199
|
-
norm_vect[norm_vect == 0] = 1
|
|
200
|
-
embedding_vectors /= norm_vect
|
|
201
|
-
else:
|
|
202
|
-
norm_matrix = sparse.csr_matrix(1 - np.outer(normalizer.norm_diag.data, eigenvalues))
|
|
203
|
-
norm_matrix.data = 1 / norm_matrix.data
|
|
204
|
-
embedding_vectors *= norm_matrix.toarray()
|
|
205
|
-
|
|
206
|
-
# normalization
|
|
207
|
-
if self.normalized:
|
|
208
|
-
embedding_vectors = normalize(embedding_vectors, p=2)
|
|
209
|
-
|
|
210
|
-
# shape
|
|
211
|
-
if len(embedding_vectors) == 1:
|
|
212
|
-
embedding_vectors = embedding_vectors.ravel()
|
|
213
|
-
|
|
214
|
-
return embedding_vectors
|
|
@@ -13,7 +13,6 @@ class TestEmbeddings(unittest.TestCase):
|
|
|
13
13
|
def setUp(self):
|
|
14
14
|
"""Algorithms by input types."""
|
|
15
15
|
self.methods = [Spectral(), GSVD(), SVD()]
|
|
16
|
-
self.bimethods = [GSVD(), SVD()]
|
|
17
16
|
|
|
18
17
|
def test_undirected(self):
|
|
19
18
|
adjacency = test_graph()
|
|
@@ -22,44 +21,21 @@ class TestEmbeddings(unittest.TestCase):
|
|
|
22
21
|
method = Spring()
|
|
23
22
|
embedding = method.fit_transform(adjacency)
|
|
24
23
|
self.assertEqual(embedding.shape, (n, 2))
|
|
25
|
-
pred1 = method.predict(adjacency[0])
|
|
26
|
-
pred2 = method.predict(adjacency[0].toarray())
|
|
27
|
-
self.assertEqual(pred1.shape, (2,))
|
|
28
|
-
self.assertAlmostEqual(np.linalg.norm(pred1 - pred2), 0)
|
|
29
24
|
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
self.assertTupleEqual(pred1.shape, (n, 2))
|
|
33
|
-
self.assertAlmostEqual(np.linalg.norm(pred1 - pred2), 0)
|
|
34
|
-
|
|
35
|
-
def test_bimethods(self):
|
|
25
|
+
embedding = method.transform()
|
|
26
|
+
self.assertEqual(embedding.shape, (n, 2))
|
|
36
27
|
|
|
28
|
+
def test_bipartite(self):
|
|
37
29
|
for adjacency in [test_digraph(), test_bigraph()]:
|
|
38
30
|
n_row, n_col = adjacency.shape
|
|
39
31
|
|
|
40
|
-
for method in self.
|
|
32
|
+
for method in self.methods:
|
|
41
33
|
method.fit(adjacency)
|
|
42
34
|
|
|
43
35
|
self.assertEqual(method.embedding_.shape, (n_row, 2))
|
|
44
36
|
self.assertEqual(method.embedding_row_.shape, (n_row, 2))
|
|
45
37
|
self.assertEqual(method.embedding_col_.shape, (n_col, 2))
|
|
46
38
|
|
|
47
|
-
ref = method.embedding_[0]
|
|
48
|
-
pred1 = method.predict(adjacency[0])
|
|
49
|
-
pred2 = method.predict(adjacency[0].toarray())
|
|
50
|
-
|
|
51
|
-
self.assertEqual(pred1.shape, (2,))
|
|
52
|
-
self.assertAlmostEqual(np.linalg.norm(pred1 - pred2), 0)
|
|
53
|
-
self.assertAlmostEqual(np.linalg.norm(pred1 - ref), 0)
|
|
54
|
-
|
|
55
|
-
ref = method.embedding_
|
|
56
|
-
pred1 = method.predict(adjacency)
|
|
57
|
-
pred2 = method.predict(adjacency.toarray())
|
|
58
|
-
|
|
59
|
-
self.assertTupleEqual(pred1.shape, (n_row, 2))
|
|
60
|
-
self.assertAlmostEqual(np.linalg.norm(pred1 - pred2), 0)
|
|
61
|
-
self.assertAlmostEqual(np.linalg.norm(pred1 - ref), 0)
|
|
62
|
-
|
|
63
39
|
def test_disconnected(self):
|
|
64
40
|
n = 10
|
|
65
41
|
adjacency = np.eye(n)
|
|
@@ -12,22 +12,17 @@ from sknetwork.embedding import LouvainEmbedding
|
|
|
12
12
|
class TestLouvainEmbedding(unittest.TestCase):
|
|
13
13
|
|
|
14
14
|
def test_predict(self):
|
|
15
|
+
adjacency = test_graph()
|
|
15
16
|
louvain = LouvainEmbedding()
|
|
16
17
|
louvain.fit(test_graph())
|
|
17
18
|
self.assertEqual(louvain.embedding_.shape[0], 10)
|
|
18
|
-
louvain.fit(
|
|
19
|
+
louvain.fit(adjacency, force_bipartite=True)
|
|
19
20
|
self.assertEqual(louvain.embedding_.shape[0], 10)
|
|
20
21
|
|
|
21
22
|
for method in ['remove', 'merge', 'keep']:
|
|
22
23
|
louvain = LouvainEmbedding(isolated_nodes=method)
|
|
23
|
-
louvain.
|
|
24
|
-
|
|
25
|
-
self.assertEqual(embedding_vector.shape[0], 1)
|
|
24
|
+
embedding = louvain.fit_transform(adjacency)
|
|
25
|
+
self.assertEqual(embedding.shape[0], adjacency.shape[0])
|
|
26
26
|
|
|
27
|
-
for method in ['remove', 'merge', 'keep']:
|
|
28
|
-
bilouvain = LouvainEmbedding(isolated_nodes=method)
|
|
29
|
-
bilouvain.fit(test_bigraph())
|
|
30
|
-
embedding_vector = bilouvain.predict(np.array([1, 0, 0, 0, 1, 1, 0, 1]))
|
|
31
|
-
self.assertEqual(embedding_vector.shape[0], 1)
|
|
32
27
|
|
|
33
28
|
|
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
"""Tests for random projection"""
|
|
4
4
|
import unittest
|
|
5
5
|
|
|
6
|
-
from sknetwork.data.test_graphs import test_graph, test_bigraph, test_digraph,
|
|
6
|
+
from sknetwork.data.test_graphs import test_graph, test_bigraph, test_digraph, test_disconnected_graph
|
|
7
7
|
from sknetwork.embedding import RandomProjection
|
|
8
8
|
|
|
9
9
|
|
|
@@ -19,7 +19,7 @@ class TestEmbeddings(unittest.TestCase):
|
|
|
19
19
|
adjacency = test_digraph()
|
|
20
20
|
embedding = algo.fit_transform(adjacency)
|
|
21
21
|
self.assertEqual(embedding.shape[1], 2)
|
|
22
|
-
adjacency =
|
|
22
|
+
adjacency = test_disconnected_graph()
|
|
23
23
|
embedding = algo.fit_transform(adjacency)
|
|
24
24
|
self.assertEqual(embedding.shape[1], 2)
|
|
25
25
|
biadjacency = test_bigraph()
|
|
@@ -13,7 +13,7 @@ from sknetwork.utils.format import bipartite2undirected
|
|
|
13
13
|
class TestEmbeddings(unittest.TestCase):
|
|
14
14
|
|
|
15
15
|
def test_undirected(self):
|
|
16
|
-
for adjacency in [test_graph(),
|
|
16
|
+
for adjacency in [test_graph(), test_disconnected_graph()]:
|
|
17
17
|
n = adjacency.shape[0]
|
|
18
18
|
# random walk
|
|
19
19
|
spectral = Spectral(3, normalized=False)
|
|
@@ -22,27 +22,24 @@ class TestEmbeddings(unittest.TestCase):
|
|
|
22
22
|
if not is_weakly_connected(adjacency):
|
|
23
23
|
weights += 1
|
|
24
24
|
self.assertAlmostEqual(np.linalg.norm(embedding.T.dot(weights)), 0)
|
|
25
|
-
self.assertAlmostEqual(np.linalg.norm(embedding[1:4] - spectral.predict(adjacency[1:4])), 0)
|
|
26
25
|
# Laplacian
|
|
27
26
|
spectral = Spectral(3, decomposition='laplacian', normalized=False)
|
|
28
27
|
embedding = spectral.fit_transform(adjacency)
|
|
29
28
|
self.assertAlmostEqual(np.linalg.norm(embedding.sum(axis=0)), 0)
|
|
30
|
-
self.assertAlmostEqual(np.linalg.norm(embedding[1:4] - spectral.predict(adjacency[1:4])), 0)
|
|
31
29
|
|
|
32
30
|
def test_directed(self):
|
|
33
31
|
for adjacency in [test_digraph(), test_digraph().astype(bool)]:
|
|
34
32
|
# random walk
|
|
35
33
|
spectral = Spectral(3, normalized=False)
|
|
36
34
|
embedding = spectral.fit_transform(adjacency)
|
|
37
|
-
self.assertAlmostEqual(
|
|
35
|
+
self.assertAlmostEqual(embedding.shape[0], adjacency.shape[0])
|
|
38
36
|
# Laplacian
|
|
39
37
|
spectral = Spectral(3, decomposition='laplacian', normalized=False)
|
|
40
|
-
|
|
38
|
+
spectral.fit(adjacency)
|
|
41
39
|
self.assertAlmostEqual(np.linalg.norm(spectral.eigenvectors_.sum(axis=0)), 0)
|
|
42
|
-
self.assertAlmostEqual(np.linalg.norm(embedding[6:8] - spectral.predict(adjacency[6:8])), 0)
|
|
43
40
|
|
|
44
41
|
def test_regularization(self):
|
|
45
|
-
for adjacency in [test_graph(),
|
|
42
|
+
for adjacency in [test_graph(), test_disconnected_graph()]:
|
|
46
43
|
n = adjacency.shape[0]
|
|
47
44
|
# random walk
|
|
48
45
|
regularization = 0.1
|
|
@@ -78,7 +75,7 @@ class TestEmbeddings(unittest.TestCase):
|
|
|
78
75
|
self.assertAlmostEqual(np.linalg.norm(embedding_full.sum(axis=0)), 0)
|
|
79
76
|
|
|
80
77
|
def test_normalization(self):
|
|
81
|
-
for adjacency in [test_graph(),
|
|
78
|
+
for adjacency in [test_graph(), test_disconnected_graph()]:
|
|
82
79
|
spectral = Spectral(3)
|
|
83
80
|
embedding = spectral.fit_transform(adjacency)
|
|
84
81
|
self.assertAlmostEqual(np.linalg.norm(np.linalg.norm(embedding, axis=1) - np.ones(adjacency.shape[0])), 0)
|
|
@@ -26,7 +26,7 @@ class TestSVD(unittest.TestCase):
|
|
|
26
26
|
|
|
27
27
|
gsvd = GSVD(n_components=1, regularization=0.1, solver='lanczos')
|
|
28
28
|
gsvd.fit(biadjacency)
|
|
29
|
-
|
|
29
|
+
self.assertEqual(gsvd.embedding_row_.shape, (n_row, 1))
|
|
30
30
|
|
|
31
31
|
pca = PCA(n_components=min_dim, solver='lanczos')
|
|
32
32
|
pca.fit(biadjacency)
|
sknetwork/gnn/base.py
CHANGED
|
@@ -15,11 +15,11 @@ from scipy import sparse
|
|
|
15
15
|
|
|
16
16
|
from sknetwork.gnn.loss import BaseLoss, get_loss
|
|
17
17
|
from sknetwork.gnn.optimizer import BaseOptimizer, get_optimizer
|
|
18
|
-
from sknetwork.
|
|
19
|
-
from sknetwork.
|
|
18
|
+
from sknetwork.base import Algorithm
|
|
19
|
+
from sknetwork.log import Log
|
|
20
20
|
|
|
21
21
|
|
|
22
|
-
class BaseGNN(
|
|
22
|
+
class BaseGNN(ABC, Algorithm, Log):
|
|
23
23
|
"""Base class for GNNs.
|
|
24
24
|
|
|
25
25
|
Parameters
|
|
@@ -47,7 +47,7 @@ class BaseGNN(Algorithm, ABC, VerboseMixin):
|
|
|
47
47
|
"""
|
|
48
48
|
def __init__(self, loss: Union[BaseLoss, str] = 'CrossEntropy', optimizer: Union[BaseOptimizer, str] = 'Adam',
|
|
49
49
|
learning_rate: float = 0.01, verbose: bool = False):
|
|
50
|
-
|
|
50
|
+
Log.__init__(self, verbose)
|
|
51
51
|
self.optimizer = get_optimizer(optimizer, learning_rate)
|
|
52
52
|
self.loss = get_loss(loss)
|
|
53
53
|
self.layers = []
|
sknetwork/gnn/base_layer.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
#!/usr/bin/env python3
|
|
2
2
|
# -*- coding: utf-8 -*-
|
|
3
3
|
"""
|
|
4
|
-
Created
|
|
4
|
+
Created in July 2022
|
|
5
5
|
@author: Simon Delarue <sdelarue@enst.fr>
|
|
6
6
|
"""
|
|
7
7
|
from typing import Optional, Union
|
|
@@ -73,10 +73,10 @@ class BaseLayer:
|
|
|
73
73
|
in_channels: int
|
|
74
74
|
Number of input channels.
|
|
75
75
|
"""
|
|
76
|
-
#
|
|
76
|
+
# He initialization
|
|
77
77
|
self.weight = np.random.randn(in_channels, self.out_channels) * np.sqrt(2 / self.out_channels)
|
|
78
78
|
if self.use_bias:
|
|
79
|
-
self.bias = np.zeros((self.out_channels
|
|
79
|
+
self.bias = np.zeros((1, self.out_channels))
|
|
80
80
|
self.weights_initialized = True
|
|
81
81
|
|
|
82
82
|
def forward(self, *args, **kwargs):
|