scikit_network-0.33.3-cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of scikit-network has been flagged as potentially problematic.
- scikit_network-0.33.3.dist-info/METADATA +122 -0
- scikit_network-0.33.3.dist-info/RECORD +228 -0
- scikit_network-0.33.3.dist-info/WHEEL +5 -0
- scikit_network-0.33.3.dist-info/licenses/AUTHORS.rst +43 -0
- scikit_network-0.33.3.dist-info/licenses/LICENSE +34 -0
- scikit_network-0.33.3.dist-info/top_level.txt +1 -0
- sknetwork/__init__.py +21 -0
- sknetwork/base.py +67 -0
- sknetwork/classification/__init__.py +8 -0
- sknetwork/classification/base.py +142 -0
- sknetwork/classification/base_rank.py +133 -0
- sknetwork/classification/diffusion.py +134 -0
- sknetwork/classification/knn.py +139 -0
- sknetwork/classification/metrics.py +205 -0
- sknetwork/classification/pagerank.py +66 -0
- sknetwork/classification/propagation.py +152 -0
- sknetwork/classification/tests/__init__.py +1 -0
- sknetwork/classification/tests/test_API.py +30 -0
- sknetwork/classification/tests/test_diffusion.py +77 -0
- sknetwork/classification/tests/test_knn.py +23 -0
- sknetwork/classification/tests/test_metrics.py +53 -0
- sknetwork/classification/tests/test_pagerank.py +20 -0
- sknetwork/classification/tests/test_propagation.py +24 -0
- sknetwork/classification/vote.cp313-win_amd64.pyd +0 -0
- sknetwork/classification/vote.cpp +27584 -0
- sknetwork/classification/vote.pyx +56 -0
- sknetwork/clustering/__init__.py +8 -0
- sknetwork/clustering/base.py +172 -0
- sknetwork/clustering/kcenters.py +253 -0
- sknetwork/clustering/leiden.py +242 -0
- sknetwork/clustering/leiden_core.cp313-win_amd64.pyd +0 -0
- sknetwork/clustering/leiden_core.cpp +31575 -0
- sknetwork/clustering/leiden_core.pyx +124 -0
- sknetwork/clustering/louvain.py +286 -0
- sknetwork/clustering/louvain_core.cp313-win_amd64.pyd +0 -0
- sknetwork/clustering/louvain_core.cpp +31220 -0
- sknetwork/clustering/louvain_core.pyx +124 -0
- sknetwork/clustering/metrics.py +91 -0
- sknetwork/clustering/postprocess.py +66 -0
- sknetwork/clustering/propagation_clustering.py +104 -0
- sknetwork/clustering/tests/__init__.py +1 -0
- sknetwork/clustering/tests/test_API.py +38 -0
- sknetwork/clustering/tests/test_kcenters.py +60 -0
- sknetwork/clustering/tests/test_leiden.py +34 -0
- sknetwork/clustering/tests/test_louvain.py +135 -0
- sknetwork/clustering/tests/test_metrics.py +50 -0
- sknetwork/clustering/tests/test_postprocess.py +39 -0
- sknetwork/data/__init__.py +6 -0
- sknetwork/data/base.py +33 -0
- sknetwork/data/load.py +406 -0
- sknetwork/data/models.py +459 -0
- sknetwork/data/parse.py +644 -0
- sknetwork/data/test_graphs.py +84 -0
- sknetwork/data/tests/__init__.py +1 -0
- sknetwork/data/tests/test_API.py +30 -0
- sknetwork/data/tests/test_base.py +14 -0
- sknetwork/data/tests/test_load.py +95 -0
- sknetwork/data/tests/test_models.py +52 -0
- sknetwork/data/tests/test_parse.py +250 -0
- sknetwork/data/tests/test_test_graphs.py +29 -0
- sknetwork/data/tests/test_toy_graphs.py +68 -0
- sknetwork/data/timeout.py +38 -0
- sknetwork/data/toy_graphs.py +611 -0
- sknetwork/embedding/__init__.py +8 -0
- sknetwork/embedding/base.py +94 -0
- sknetwork/embedding/force_atlas.py +198 -0
- sknetwork/embedding/louvain_embedding.py +148 -0
- sknetwork/embedding/random_projection.py +135 -0
- sknetwork/embedding/spectral.py +141 -0
- sknetwork/embedding/spring.py +198 -0
- sknetwork/embedding/svd.py +359 -0
- sknetwork/embedding/tests/__init__.py +1 -0
- sknetwork/embedding/tests/test_API.py +49 -0
- sknetwork/embedding/tests/test_force_atlas.py +35 -0
- sknetwork/embedding/tests/test_louvain_embedding.py +33 -0
- sknetwork/embedding/tests/test_random_projection.py +28 -0
- sknetwork/embedding/tests/test_spectral.py +81 -0
- sknetwork/embedding/tests/test_spring.py +50 -0
- sknetwork/embedding/tests/test_svd.py +43 -0
- sknetwork/gnn/__init__.py +10 -0
- sknetwork/gnn/activation.py +117 -0
- sknetwork/gnn/base.py +181 -0
- sknetwork/gnn/base_activation.py +90 -0
- sknetwork/gnn/base_layer.py +109 -0
- sknetwork/gnn/gnn_classifier.py +305 -0
- sknetwork/gnn/layer.py +153 -0
- sknetwork/gnn/loss.py +180 -0
- sknetwork/gnn/neighbor_sampler.py +65 -0
- sknetwork/gnn/optimizer.py +164 -0
- sknetwork/gnn/tests/__init__.py +1 -0
- sknetwork/gnn/tests/test_activation.py +56 -0
- sknetwork/gnn/tests/test_base.py +75 -0
- sknetwork/gnn/tests/test_base_layer.py +37 -0
- sknetwork/gnn/tests/test_gnn_classifier.py +130 -0
- sknetwork/gnn/tests/test_layers.py +80 -0
- sknetwork/gnn/tests/test_loss.py +33 -0
- sknetwork/gnn/tests/test_neigh_sampler.py +23 -0
- sknetwork/gnn/tests/test_optimizer.py +43 -0
- sknetwork/gnn/tests/test_utils.py +41 -0
- sknetwork/gnn/utils.py +127 -0
- sknetwork/hierarchy/__init__.py +6 -0
- sknetwork/hierarchy/base.py +96 -0
- sknetwork/hierarchy/louvain_hierarchy.py +272 -0
- sknetwork/hierarchy/metrics.py +234 -0
- sknetwork/hierarchy/paris.cp313-win_amd64.pyd +0 -0
- sknetwork/hierarchy/paris.cpp +37868 -0
- sknetwork/hierarchy/paris.pyx +316 -0
- sknetwork/hierarchy/postprocess.py +350 -0
- sknetwork/hierarchy/tests/__init__.py +1 -0
- sknetwork/hierarchy/tests/test_API.py +24 -0
- sknetwork/hierarchy/tests/test_algos.py +34 -0
- sknetwork/hierarchy/tests/test_metrics.py +62 -0
- sknetwork/hierarchy/tests/test_postprocess.py +57 -0
- sknetwork/linalg/__init__.py +9 -0
- sknetwork/linalg/basics.py +37 -0
- sknetwork/linalg/diteration.cp313-win_amd64.pyd +0 -0
- sknetwork/linalg/diteration.cpp +27400 -0
- sknetwork/linalg/diteration.pyx +47 -0
- sknetwork/linalg/eig_solver.py +93 -0
- sknetwork/linalg/laplacian.py +15 -0
- sknetwork/linalg/normalizer.py +86 -0
- sknetwork/linalg/operators.py +225 -0
- sknetwork/linalg/polynome.py +76 -0
- sknetwork/linalg/ppr_solver.py +170 -0
- sknetwork/linalg/push.cp313-win_amd64.pyd +0 -0
- sknetwork/linalg/push.cpp +31072 -0
- sknetwork/linalg/push.pyx +71 -0
- sknetwork/linalg/sparse_lowrank.py +142 -0
- sknetwork/linalg/svd_solver.py +91 -0
- sknetwork/linalg/tests/__init__.py +1 -0
- sknetwork/linalg/tests/test_eig.py +44 -0
- sknetwork/linalg/tests/test_laplacian.py +18 -0
- sknetwork/linalg/tests/test_normalization.py +34 -0
- sknetwork/linalg/tests/test_operators.py +66 -0
- sknetwork/linalg/tests/test_polynome.py +38 -0
- sknetwork/linalg/tests/test_ppr.py +50 -0
- sknetwork/linalg/tests/test_sparse_lowrank.py +61 -0
- sknetwork/linalg/tests/test_svd.py +38 -0
- sknetwork/linkpred/__init__.py +2 -0
- sknetwork/linkpred/base.py +46 -0
- sknetwork/linkpred/nn.py +126 -0
- sknetwork/linkpred/tests/__init__.py +1 -0
- sknetwork/linkpred/tests/test_nn.py +27 -0
- sknetwork/log.py +19 -0
- sknetwork/path/__init__.py +5 -0
- sknetwork/path/dag.py +54 -0
- sknetwork/path/distances.py +98 -0
- sknetwork/path/search.py +31 -0
- sknetwork/path/shortest_path.py +61 -0
- sknetwork/path/tests/__init__.py +1 -0
- sknetwork/path/tests/test_dag.py +37 -0
- sknetwork/path/tests/test_distances.py +62 -0
- sknetwork/path/tests/test_search.py +40 -0
- sknetwork/path/tests/test_shortest_path.py +40 -0
- sknetwork/ranking/__init__.py +8 -0
- sknetwork/ranking/base.py +61 -0
- sknetwork/ranking/betweenness.cp313-win_amd64.pyd +0 -0
- sknetwork/ranking/betweenness.cpp +9707 -0
- sknetwork/ranking/betweenness.pyx +97 -0
- sknetwork/ranking/closeness.py +92 -0
- sknetwork/ranking/hits.py +94 -0
- sknetwork/ranking/katz.py +83 -0
- sknetwork/ranking/pagerank.py +110 -0
- sknetwork/ranking/postprocess.py +37 -0
- sknetwork/ranking/tests/__init__.py +1 -0
- sknetwork/ranking/tests/test_API.py +32 -0
- sknetwork/ranking/tests/test_betweenness.py +38 -0
- sknetwork/ranking/tests/test_closeness.py +30 -0
- sknetwork/ranking/tests/test_hits.py +20 -0
- sknetwork/ranking/tests/test_pagerank.py +62 -0
- sknetwork/ranking/tests/test_postprocess.py +26 -0
- sknetwork/regression/__init__.py +4 -0
- sknetwork/regression/base.py +61 -0
- sknetwork/regression/diffusion.py +210 -0
- sknetwork/regression/tests/__init__.py +1 -0
- sknetwork/regression/tests/test_API.py +32 -0
- sknetwork/regression/tests/test_diffusion.py +56 -0
- sknetwork/sknetwork.py +3 -0
- sknetwork/test_base.py +35 -0
- sknetwork/test_log.py +15 -0
- sknetwork/topology/__init__.py +8 -0
- sknetwork/topology/cliques.cp313-win_amd64.pyd +0 -0
- sknetwork/topology/cliques.cpp +32565 -0
- sknetwork/topology/cliques.pyx +149 -0
- sknetwork/topology/core.cp313-win_amd64.pyd +0 -0
- sknetwork/topology/core.cpp +30651 -0
- sknetwork/topology/core.pyx +90 -0
- sknetwork/topology/cycles.py +243 -0
- sknetwork/topology/minheap.cp313-win_amd64.pyd +0 -0
- sknetwork/topology/minheap.cpp +27332 -0
- sknetwork/topology/minheap.pxd +20 -0
- sknetwork/topology/minheap.pyx +109 -0
- sknetwork/topology/structure.py +194 -0
- sknetwork/topology/tests/__init__.py +1 -0
- sknetwork/topology/tests/test_cliques.py +28 -0
- sknetwork/topology/tests/test_core.py +19 -0
- sknetwork/topology/tests/test_cycles.py +65 -0
- sknetwork/topology/tests/test_structure.py +85 -0
- sknetwork/topology/tests/test_triangles.py +38 -0
- sknetwork/topology/tests/test_wl.py +72 -0
- sknetwork/topology/triangles.cp313-win_amd64.pyd +0 -0
- sknetwork/topology/triangles.cpp +8894 -0
- sknetwork/topology/triangles.pyx +151 -0
- sknetwork/topology/weisfeiler_lehman.py +133 -0
- sknetwork/topology/weisfeiler_lehman_core.cp313-win_amd64.pyd +0 -0
- sknetwork/topology/weisfeiler_lehman_core.cpp +27635 -0
- sknetwork/topology/weisfeiler_lehman_core.pyx +114 -0
- sknetwork/utils/__init__.py +7 -0
- sknetwork/utils/check.py +355 -0
- sknetwork/utils/format.py +221 -0
- sknetwork/utils/membership.py +82 -0
- sknetwork/utils/neighbors.py +115 -0
- sknetwork/utils/tests/__init__.py +1 -0
- sknetwork/utils/tests/test_check.py +190 -0
- sknetwork/utils/tests/test_format.py +63 -0
- sknetwork/utils/tests/test_membership.py +24 -0
- sknetwork/utils/tests/test_neighbors.py +41 -0
- sknetwork/utils/tests/test_tfidf.py +18 -0
- sknetwork/utils/tests/test_values.py +66 -0
- sknetwork/utils/tfidf.py +37 -0
- sknetwork/utils/values.py +76 -0
- sknetwork/visualization/__init__.py +4 -0
- sknetwork/visualization/colors.py +34 -0
- sknetwork/visualization/dendrograms.py +277 -0
- sknetwork/visualization/graphs.py +1039 -0
- sknetwork/visualization/tests/__init__.py +1 -0
- sknetwork/visualization/tests/test_dendrograms.py +53 -0
- sknetwork/visualization/tests/test_graphs.py +176 -0
sknetwork/gnn/base_activation.py
ADDED
@@ -0,0 +1,90 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created on April 2022
+@author: Simon Delarue <sdelarue@enst.fr>
+"""
+import numpy as np
+
+
+class BaseActivation:
+    """Base class for activation functions.
+
+    Parameters
+    ----------
+    name : str
+        Name of the activation function.
+    """
+    def __init__(self, name: str = 'custom'):
+        self.name = name
+
+    @staticmethod
+    def output(signal: np.ndarray) -> np.ndarray:
+        """Output of the activation function.
+
+        Parameters
+        ----------
+        signal : np.ndarray, shape (n_samples, n_channels)
+            Input signal.
+
+        Returns
+        -------
+        output : np.ndarray, shape (n_samples, n_channels)
+            Output signal.
+        """
+        output = signal
+        return output
+
+    @staticmethod
+    def gradient(signal: np.ndarray, direction: np.ndarray) -> np.ndarray:
+        """Gradient of the activation function.
+
+        Parameters
+        ----------
+        signal : np.ndarray, shape (n_samples, n_channels)
+            Input signal.
+        direction : np.ndarray, shape (n_samples, n_channels)
+            Direction where the gradient is taken.
+
+        Returns
+        -------
+        gradient : np.ndarray, shape (n_samples, n_channels)
+            Gradient.
+        """
+        gradient = direction
+        return gradient
+
+
+class BaseLoss(BaseActivation):
+    """Base class for loss functions."""
+    @staticmethod
+    def loss(signal: np.ndarray, labels: np.ndarray) -> float:
+        """Get the loss value.
+
+        Parameters
+        ----------
+        signal : np.ndarray, shape (n_samples, n_channels)
+            Input signal (before activation).
+        labels : np.ndarray, shape (n_samples,)
+            True labels.
+        """
+        return 0
+
+    @staticmethod
+    def loss_gradient(signal: np.ndarray, labels: np.ndarray) -> np.ndarray:
+        """Gradient of the loss function.
+
+        Parameters
+        ----------
+        signal : np.ndarray, shape (n_samples, n_channels)
+            Input signal.
+        labels : np.ndarray, shape (n_samples,)
+            True labels.
+
+        Returns
+        -------
+        gradient : np.ndarray, shape (n_samples, n_channels)
+            Gradient.
+        """
+        gradient = np.ones_like(signal)
+        return gradient
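The identity defaults above make BaseActivation (and BaseLoss) easy to extend: a subclass only overrides output and gradient (or loss and loss_gradient). Here is a minimal sketch of a custom ReLU following that interface; the class name CustomReLu is illustrative only, since the package ships its own activations in sknetwork/gnn/activation.py.

import numpy as np

from sknetwork.gnn.base_activation import BaseActivation


class CustomReLu(BaseActivation):
    # Illustrative subclass, not part of the package.
    def __init__(self):
        super(CustomReLu, self).__init__('custom relu')

    @staticmethod
    def output(signal: np.ndarray) -> np.ndarray:
        # Clip negative entries of the signal to zero.
        return np.maximum(signal, 0)

    @staticmethod
    def gradient(signal: np.ndarray, direction: np.ndarray) -> np.ndarray:
        # Pass the incoming gradient through only where the input was positive.
        return direction * (signal > 0)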
sknetwork/gnn/base_layer.py
ADDED
@@ -0,0 +1,109 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created in July 2022
+@author: Simon Delarue <sdelarue@enst.fr>
+"""
+from typing import Optional, Union
+
+import numpy as np
+
+from sknetwork.gnn.activation import BaseActivation, get_activation
+from sknetwork.gnn.loss import BaseLoss, get_loss
+
+
+class BaseLayer:
+    """Base class for GNN layers.
+
+    Parameters
+    ----------
+    layer_type : str
+        Layer type. Can be either ``'Conv'`` (Convolution) or ``'Sage'`` (GraphSAGE).
+    out_channels : int
+        Dimension of the output.
+    activation : str (default = ``'Relu'``) or custom activation
+        Activation function.
+        If a string, can be either ``'Identity'``, ``'Relu'``, ``'Sigmoid'`` or ``'Softmax'``.
+    use_bias : bool (default = ``True``)
+        If ``True``, add a bias vector.
+    normalization : str (default = ``'both'``)
+        Normalization of the adjacency matrix for message passing.
+        Can be either ``'left'`` (left normalization by the degrees), ``'right'`` (right normalization by the degrees),
+        ``'both'`` (symmetric normalization by the square root of degrees, default) or ``None`` (no normalization).
+    self_embeddings : bool (default = ``True``)
+        If ``True``, consider the self-embedding in addition to neighbor embeddings for each node of the graph.
+    sample_size : int (default = 25)
+        Size of the neighborhood sampled for each node. Used only for ``'Sage'`` layers.
+
+    Attributes
+    ----------
+    weight : np.ndarray
+        Trainable weight matrix.
+    bias : np.ndarray
+        Bias vector.
+    embedding : np.ndarray
+        Embedding of the nodes (before activation).
+    output : np.ndarray
+        Output of the layer (after activation).
+    """
+    def __init__(self, layer_type: str, out_channels: int, activation: Optional[Union[BaseActivation, str]] = 'Relu',
+                 use_bias: bool = True, normalization: str = 'both', self_embeddings: bool = True,
+                 sample_size: int = 25, loss: Optional[Union[BaseLoss, str]] = None):
+        self.layer_type = layer_type
+        self.out_channels = out_channels
+        if loss is None:
+            self.activation = get_activation(activation)
+        else:
+            self.activation = get_loss(loss)
+        self.use_bias = use_bias
+        self.normalization = normalization.lower()
+        self.self_embeddings = self_embeddings
+        self.sample_size = sample_size
+        self.weight = None
+        self.bias = None
+        self.embedding = None
+        self.output = None
+        self.weights_initialized = False
+
+    def _initialize_weights(self, in_channels: int):
+        """Initialize weights and bias.
+
+        Parameters
+        ----------
+        in_channels : int
+            Number of input channels.
+        """
+        # He initialization
+        self.weight = np.random.randn(in_channels, self.out_channels) * np.sqrt(2 / self.out_channels)
+        if self.use_bias:
+            self.bias = np.zeros((1, self.out_channels))
+        self.weights_initialized = True
+
+    def forward(self, *args, **kwargs):
+        """Compute forward pass."""
+        raise NotImplementedError
+
+    def __call__(self, *args, **kwargs):
+        return self.forward(*args, **kwargs)
+
+    def __repr__(self) -> str:
+        """String representation of the object.
+
+        Returns
+        -------
+        str
+            String representation of the object.
+        """
+        print_attr = ['out_channels', 'layer_type', 'activation', 'use_bias', 'normalization', 'self_embeddings']
+        if 'sage' in self.layer_type:
+            print_attr.append('sample_size')
+        attributes_dict = {k: v for k, v in self.__dict__.items() if k in print_attr}
+        string = ''
+
+        for k, v in attributes_dict.items():
+            if k == 'activation':
+                string += f'{k}: {v.name}, '
+            else:
+                string += f'{k}: {v}, '
+
+        return f'  {self.__class__.__name__}({string[:-2]})'
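Concrete layers are expected to subclass BaseLayer, call _initialize_weights lazily on the first forward pass, and fill in embedding and output. A minimal sketch under those assumptions follows; MeanLayer is hypothetical and not part of the package (the shipped layers are in sknetwork/gnn/layer.py).

import numpy as np
from scipy import sparse

from sknetwork.gnn.base_layer import BaseLayer


class MeanLayer(BaseLayer):
    # Hypothetical layer: average neighbor features, then apply weight, bias and activation.
    def __init__(self, out_channels: int):
        super(MeanLayer, self).__init__('custom', out_channels, activation='Relu')

    def forward(self, adjacency: sparse.csr_matrix, features: np.ndarray) -> np.ndarray:
        # Lazy initialization, as in the base class contract above.
        if not self.weights_initialized:
            self._initialize_weights(features.shape[1])
        # Mean aggregation of neighbor features (isolated nodes keep a zero message).
        degrees = adjacency.dot(np.ones(adjacency.shape[1]))
        message = adjacency.dot(features) / np.maximum(degrees, 1)[:, np.newaxis]
        self.embedding = message.dot(self.weight)
        if self.use_bias:
            self.embedding = self.embedding + self.bias
        self.output = self.activation.output(self.embedding)
        return self.output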
sknetwork/gnn/gnn_classifier.py
ADDED
@@ -0,0 +1,305 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created in April 2022
+@author: Simon Delarue <sdelarue@enst.fr>
+"""
+from typing import Iterable, Optional, Union
+from collections import defaultdict
+
+import numpy as np
+from scipy import sparse
+
+from sknetwork.classification.metrics import get_accuracy_score
+from sknetwork.gnn.base import BaseGNN
+from sknetwork.gnn.loss import BaseLoss
+from sknetwork.gnn.layer import get_layer
+from sknetwork.gnn.neighbor_sampler import UniformNeighborSampler
+from sknetwork.gnn.optimizer import BaseOptimizer
+from sknetwork.gnn.utils import check_output, check_early_stopping, check_loss, get_layers
+from sknetwork.utils.check import check_format, check_nonnegative, check_square
+from sknetwork.utils.values import get_values
+
+
+class GNNClassifier(BaseGNN):
+    """Graph Neural Network for node classification.
+
+    Parameters
+    ----------
+    dims : iterable or int
+        Dimension of the output of each layer (in forward direction).
+        If an integer, dimension of the output layer (no hidden layer).
+        Optional if ``layers`` is specified.
+    layer_types : iterable or str
+        Layer types (in forward direction).
+        If a string, the same type is used at each layer.
+        Can be ``'Conv'``, graph convolutional layer (default) or ``'Sage'`` (GraphSage).
+    activations : iterable or str
+        Activation functions (in forward direction).
+        If a string, the same activation function is used at each layer.
+        Can be either ``'Identity'``, ``'Relu'``, ``'Sigmoid'`` or ``'Softmax'`` (default = ``'Relu'``).
+    use_bias : iterable or bool
+        Whether to add a bias term at each layer (in forward direction).
+        If ``True``, use a bias term at each layer.
+    normalizations : iterable or str
+        Normalizations of the adjacency matrix for message passing (in forward direction).
+        If a string, the same type of normalization is used at each layer.
+        Can be either ``'left'`` (left normalization by the degrees), ``'right'`` (right normalization by the degrees),
+        ``'both'`` (symmetric normalization by the square root of degrees, default) or ``None`` (no normalization).
+    self_embeddings : iterable or bool
+        Whether to add the self-embedding of each node for message passing (in forward direction).
+        If ``True``, add a self-embedding at each layer.
+    sample_sizes : iterable or int
+        Sizes of the neighborhood sampled for each node (in forward direction).
+        If an integer, the same sampling size is used at each layer.
+        Used only for the ``'Sage'`` layer type.
+    loss : str (default = ``'CrossEntropy'``) or BaseLoss
+        Name of loss function or custom loss function.
+    layers : iterable or None
+        Custom layers (in forward direction). If used, previous parameters are ignored.
+    optimizer : str or optimizer
+        * ``'Adam'``, stochastic gradient-based optimizer (default).
+        * ``'GD'``, gradient descent.
+    learning_rate : float
+        Learning rate.
+    early_stopping : bool (default = ``True``)
+        Whether to use early stopping to end training.
+        If ``True``, training terminates when the validation score has not improved for ``patience`` epochs.
+    patience : int (default = 10)
+        Number of iterations with no improvement to wait before stopping fitting.
+    verbose : bool
+        Verbose mode.
+
+    Attributes
+    ----------
+    conv2, ..., conv1 : :class:`GCNConv`
+        Graph convolutional layers.
+    output_ : np.ndarray
+        Output of the GNN.
+    labels_ : np.ndarray
+        Predicted node labels.
+    history_ : dict
+        Training history per epoch: {``'embedding'``, ``'loss'``, ``'train_accuracy'``, ``'val_accuracy'``}.
+
+    Example
+    -------
+    >>> from sknetwork.gnn.gnn_classifier import GNNClassifier
+    >>> from sknetwork.data import karate_club
+    >>> import numpy as np
+    >>> graph = karate_club(metadata=True)
+    >>> adjacency = graph.adjacency
+    >>> labels_true = graph.labels
+    >>> labels = {i: labels_true[i] for i in [0, 1, 33]}
+    >>> features = adjacency.copy()
+    >>> gnn = GNNClassifier(dims=1, early_stopping=False)
+    >>> labels_pred = gnn.fit_predict(adjacency, features, labels, random_state=42)
+    >>> float(round(np.mean(labels_pred == labels_true), 2))
+    0.88
+    """
+
+    def __init__(self, dims: Optional[Union[int, Iterable]] = None, layer_types: Union[str, Iterable] = 'Conv',
+                 activations: Union[str, Iterable] = 'ReLu', use_bias: Union[bool, list] = True,
+                 normalizations: Union[str, Iterable] = 'both', self_embeddings: Union[bool, Iterable] = True,
+                 sample_sizes: Union[int, list] = 25, loss: Union[BaseLoss, str] = 'CrossEntropy',
+                 layers: Optional[Iterable] = None, optimizer: Union[BaseOptimizer, str] = 'Adam',
+                 learning_rate: float = 0.01, early_stopping: bool = True, patience: int = 10, verbose: bool = False):
+        super(GNNClassifier, self).__init__(loss, optimizer, learning_rate, verbose)
+        if layers is not None:
+            layers = [get_layer(layer) for layer in layers]
+        else:
+            layers = get_layers(dims, layer_types, activations, use_bias, normalizations, self_embeddings,
+                                sample_sizes, loss)
+        self.loss = check_loss(layers[-1])
+        self.layers = layers
+        self.early_stopping = early_stopping
+        self.patience = patience
+        self.history_ = defaultdict(list)
+
+    def forward(self, adjacency: Union[list, sparse.csr_matrix],
+                features: Union[sparse.csr_matrix, np.ndarray]) -> np.ndarray:
+        """Perform a forward pass on the graph and return the output.
+
+        Parameters
+        ----------
+        adjacency : list or sparse.csr_matrix
+            Adjacency matrix or list of sampled adjacency matrices.
+        features : sparse.csr_matrix, np.ndarray
+            Features, array of shape (n_nodes, n_features).
+
+        Returns
+        -------
+        output : np.ndarray
+            Output of the GNN.
+        """
+        h = features.copy()
+        for i, layer in enumerate(self.layers):
+            if isinstance(adjacency, list):
+                h = layer(adjacency[i], h)
+            else:
+                h = layer(adjacency, h)
+        return h
+
+    @staticmethod
+    def _compute_predictions(output: np.ndarray) -> np.ndarray:
+        """Compute predictions from the output of the GNN.
+
+        Parameters
+        ----------
+        output : np.ndarray
+            Output of the GNN.
+
+        Returns
+        -------
+        labels : np.ndarray
+            Predicted labels.
+        """
+        if output.shape[1] == 1:
+            labels = (output.ravel() > 0.5).astype(int)
+        else:
+            labels = output.argmax(axis=1)
+        return labels
+
+    def fit(self, adjacency: Union[sparse.csr_matrix, np.ndarray], features: Union[sparse.csr_matrix, np.ndarray],
+            labels: np.ndarray, n_epochs: int = 100, validation: float = 0, reinit: bool = False,
+            random_state: Optional[int] = None) -> 'GNNClassifier':
+        """Fit model to data and store trained parameters.
+
+        Parameters
+        ----------
+        adjacency : sparse.csr_matrix
+            Adjacency matrix of the graph.
+        features : sparse.csr_matrix, np.ndarray
+            Input features of shape :math:`(n, d)` with :math:`n` the number of nodes in the graph and :math:`d`
+            the size of the feature space.
+        labels : dict, np.ndarray
+            Known labels. Negative values are ignored.
+        n_epochs : int (default = 100)
+            Number of epochs (iterations over the whole graph).
+        validation : float
+            Proportion of the training set used for validation (between 0 and 1).
+        reinit : bool (default = ``False``)
+            If ``True``, reinitialize the trainable parameters of the GNN (weights and biases).
+        random_state : int
+            Random seed, used for reproducible results across multiple runs.
+        """
+        if reinit:
+            for layer in self.layers:
+                layer.weights_initialized = False
+            self.history_ = defaultdict(list)
+
+        if random_state is not None:
+            np.random.seed(random_state)
+
+        check_format(adjacency, allow_empty=True)
+        check_format(features, allow_empty=True)
+
+        labels = get_values(adjacency.shape, labels)
+        labels = labels.astype(int)
+        if (labels < 0).all():
+            raise ValueError('At least one node must have a non-negative label.')
+        check_output(self.layers[-1].out_channels, labels)
+
+        self.train_mask = labels >= 0
+        if self.val_mask is None and 0 < validation < 1:
+            mask = np.random.random(size=len(labels)) < validation
+            self.val_mask = self.train_mask & mask
+            self.train_mask &= ~mask
+
+        early_stopping = check_early_stopping(self.early_stopping, self.val_mask, self.patience)
+
+        # List of sampled adjacencies (one per layer)
+        adjacencies = self._sample_nodes(adjacency)
+
+        best_val_accuracy = 0
+        count = 0
+
+        for epoch in range(n_epochs):
+
+            # Forward
+            output = self.forward(adjacencies, features)
+
+            # Compute predictions
+            labels_pred = self._compute_predictions(output)
+
+            # Loss
+            loss_value = self.loss.loss(output[self.train_mask], labels[self.train_mask])
+
+            # Accuracy
+            train_accuracy = get_accuracy_score(labels[self.train_mask], labels_pred[self.train_mask])
+            if self.val_mask is not None and any(self.val_mask):
+                val_accuracy = get_accuracy_score(labels[self.val_mask], labels_pred[self.val_mask])
+            else:
+                val_accuracy = None
+
+            # Backpropagation
+            self.backward(features, labels, self.train_mask)
+
+            # Update weights using optimizer
+            self.optimizer.step(self)
+
+            # Save results
+            self.history_['loss'].append(loss_value)
+            self.history_['train_accuracy'].append(train_accuracy)
+            if val_accuracy is not None:
+                self.history_['val_accuracy'].append(val_accuracy)
+
+            if n_epochs > 10 and epoch % int(n_epochs / 10) == 0:
+                if val_accuracy is not None:
+                    self.print_log(
+                        f'In epoch {epoch:>3}, loss: {loss_value:.3f}, train accuracy: {train_accuracy:.3f}, '
+                        f'val accuracy: {val_accuracy:.3f}')
+                else:
+                    self.print_log(
+                        f'In epoch {epoch:>3}, loss: {loss_value:.3f}, train accuracy: {train_accuracy:.3f}')
+            elif n_epochs <= 10:
+                if val_accuracy is not None:
+                    self.print_log(
+                        f'In epoch {epoch:>3}, loss: {loss_value:.3f}, train accuracy: {train_accuracy:.3f}, '
+                        f'val accuracy: {val_accuracy:.3f}')
+                else:
+                    self.print_log(
+                        f'In epoch {epoch:>3}, loss: {loss_value:.3f}, train accuracy: {train_accuracy:.3f}')
+
+            # Early stopping
+            if early_stopping:
+                if val_accuracy > best_val_accuracy:
+                    count = 0
+                    best_val_accuracy = val_accuracy
+                else:
+                    count += 1
+                    if count >= self.patience:
+                        self.print_log('Early stopping.')
+                        break
+
+        output = self.forward(adjacencies, features)
+        labels_pred = self._compute_predictions(output)
+
+        self.embedding_ = self.layers[-1].embedding
+        self.output_ = self.layers[-1].output
+        self.labels_ = labels_pred
+
+        return self
+
+    def _sample_nodes(self, adjacency: Union[sparse.csr_matrix, np.ndarray]) -> list:
+        """Perform node sampling on the adjacency matrix for GraphSAGE layers. For other layers, the
+        adjacency matrix remains unchanged.
+
+        Parameters
+        ----------
+        adjacency : sparse.csr_matrix
+            Adjacency matrix of the graph.
+
+        Returns
+        -------
+        List of (sampled) adjacency matrices.
+        """
+        adjacencies = []
+
+        for layer in self.layers:
+            if layer.layer_type == 'sage':
+                sampler = UniformNeighborSampler(sample_size=layer.sample_size)
+                adjacencies.append(sampler(adjacency))
+            else:
+                adjacencies.append(adjacency)
+
+        return adjacencies
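A usage sketch expanding on the docstring example above: the hidden dimension 8, the validation split and the reuse of the adjacency matrix as features are illustrative choices, not package defaults; the calls themselves (fit, labels_, history_) follow the signatures shown in this file.

from sknetwork.data import karate_club
from sknetwork.gnn.gnn_classifier import GNNClassifier

graph = karate_club(metadata=True)
adjacency = graph.adjacency
labels_true = graph.labels
labels = {i: labels_true[i] for i in [0, 1, 33]}   # three labeled seed nodes

gnn = GNNClassifier(dims=[8, 2], layer_types='Conv', early_stopping=False, verbose=True)
gnn.fit(adjacency, adjacency.copy(), labels, n_epochs=100, validation=0.2, random_state=42)

print(gnn.labels_[:10])           # predicted labels for the first ten nodes
print(gnn.history_['loss'][-1])   # training loss at the last epoch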
sknetwork/gnn/layer.py
ADDED
@@ -0,0 +1,153 @@
+#!/usr/bin/env python3
+# coding: utf-8
+"""
+Created in April 2022
+@author: Simon Delarue <sdelarue@enst.fr>
+"""
+from typing import Optional, Union
+
+import numpy as np
+from scipy import sparse
+
+from sknetwork.gnn.activation import BaseActivation
+from sknetwork.gnn.loss import BaseLoss
+from sknetwork.gnn.base_layer import BaseLayer
+from sknetwork.utils.check import add_self_loops
+from sknetwork.linalg import diagonal_pseudo_inverse
+
+
+class Convolution(BaseLayer):
+    """Graph convolutional layer.
+
+    Apply the following function to the embedding :math:`X`:
+
+    :math:`\\sigma(\\bar AXW + b)`,
+
+    where :math:`\\bar A` is the normalized adjacency matrix (possibly with inserted self-embeddings),
+    :math:`W`, :math:`b` are trainable parameters and :math:`\\sigma` is the activation function.
+
+    Parameters
+    ----------
+    layer_type : str
+        Layer type. Can be either ``'Conv'``, the convolutional operator as in [1], or ``'Sage'``, as in [2].
+    out_channels : int
+        Dimension of the output.
+    activation : str (default = ``'Relu'``) or custom activation
+        Activation function.
+        If a string, can be either ``'Identity'``, ``'Relu'``, ``'Sigmoid'`` or ``'Softmax'``.
+    use_bias : bool (default = ``True``)
+        If ``True``, add a bias vector.
+    normalization : str (default = ``'both'``)
+        Normalization of the adjacency matrix for message passing.
+        Can be either ``'left'`` (left normalization by the degrees), ``'right'`` (right normalization by the degrees),
+        ``'both'`` (symmetric normalization by the square root of degrees, default) or ``None`` (no normalization).
+    self_embeddings : bool (default = ``True``)
+        If ``True``, consider the self-embedding in addition to neighbor embeddings for each node of the graph.
+    sample_size : int (default = 25)
+        Size of the neighborhood sampled for each node. Used only for the ``'Sage'`` layer.
+
+    Attributes
+    ----------
+    weight : np.ndarray
+        Trainable weight matrix.
+    bias : np.ndarray
+        Bias vector.
+    embedding : np.ndarray
+        Embedding of the nodes (before activation).
+    output : np.ndarray
+        Output of the layer (after activation).
+
+    References
+    ----------
+    [1] Kipf, T., & Welling, M. (2017).
+    `Semi-supervised Classification with Graph Convolutional Networks.
+    <https://arxiv.org/pdf/1609.02907.pdf>`_
+    5th International Conference on Learning Representations.
+
+    [2] Hamilton, W., Ying, R., & Leskovec, J. (2017).
+    `Inductive Representation Learning on Large Graphs.
+    <https://arxiv.org/pdf/1706.02216.pdf>`_
+    NIPS.
+    """
+    def __init__(self, layer_type: str, out_channels: int, activation: Optional[Union[BaseActivation, str]] = 'Relu',
+                 use_bias: bool = True, normalization: str = 'both', self_embeddings: bool = True,
+                 sample_size: int = None, loss: Optional[Union[BaseLoss, str]] = None):
+        super(Convolution, self).__init__(layer_type, out_channels, activation, use_bias, normalization,
+                                          self_embeddings, sample_size, loss)
+
+    def forward(self, adjacency: Union[sparse.csr_matrix, np.ndarray],
+                features: Union[sparse.csr_matrix, np.ndarray]) -> np.ndarray:
+        """Compute graph convolution.
+
+        Parameters
+        ----------
+        adjacency
+            Adjacency matrix of the graph.
+        features : sparse.csr_matrix, np.ndarray
+            Input features of shape :math:`(n, d)` with :math:`n` the number of nodes in the graph and :math:`d`
+            the size of the feature space.
+
+        Returns
+        -------
+        output : np.ndarray
+            Output of the layer.
+        """
+        if not self.weights_initialized:
+            self._initialize_weights(features.shape[1])
+
+        n_row, n_col = adjacency.shape
+
+        weights = adjacency.dot(np.ones(n_col))
+        if self.normalization == 'left':
+            d_inv = diagonal_pseudo_inverse(weights)
+            adjacency = d_inv.dot(adjacency)
+        elif self.normalization == 'right':
+            d_inv = diagonal_pseudo_inverse(weights)
+            adjacency = adjacency.dot(d_inv)
+        elif self.normalization == 'both':
+            d_inv = diagonal_pseudo_inverse(np.sqrt(weights))
+            adjacency = d_inv.dot(adjacency).dot(d_inv)
+
+        if self.self_embeddings:
+            adjacency = add_self_loops(adjacency)
+
+        message = adjacency.dot(features)
+        embedding = message.dot(self.weight)
+
+        if self.use_bias:
+            embedding += self.bias
+
+        output = self.activation.output(embedding)
+
+        self.embedding = embedding
+        self.output = output
+
+        return output
+
+
+def get_layer(layer: Union[BaseLayer, str] = 'conv', **kwargs) -> BaseLayer:
+    """Get layer.
+
+    Parameters
+    ----------
+    layer : str or custom layer
+        If a string, must be either ``'Conv'`` (Convolution) or ``'Sage'`` (GraphSAGE).
+
+    Returns
+    -------
+    Layer object.
+    """
+    if issubclass(type(layer), BaseLayer):
+        return layer
+    elif type(layer) == str:
+        layer = layer.lower()
+        if 'sage' in layer:
+            kwargs['normalization'] = 'left'
+            kwargs['self_embeddings'] = True
+            return Convolution('sage', **kwargs)
+        elif 'conv' in layer:
+            return Convolution('conv', **kwargs)
+        else:
+            raise ValueError("Layer name must be \"Conv\" or \"Sage\".")
+    else:
+        raise TypeError("Layer must be a string or a \"BaseLayer\" object.")
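To make the message-passing step concrete, here is a small sketch of the ``'both'`` normalization computed in Convolution.forward, using plain numpy/scipy in place of diagonal_pseudo_inverse (zero degrees map to zero, matching pseudo-inverse semantics), followed by a call to get_layer as defined above.

import numpy as np
from scipy import sparse

from sknetwork.gnn.layer import get_layer

adjacency = sparse.csr_matrix(np.array([[0, 1, 1],
                                        [1, 0, 0],
                                        [1, 0, 0]]))
# Node degrees, then inverse square roots (guarded against zero degrees).
degrees = adjacency.dot(np.ones(adjacency.shape[1]))
inv_sqrt = np.where(degrees > 0, 1 / np.sqrt(np.maximum(degrees, 1e-12)), 0)
d_inv = sparse.diags(inv_sqrt)
normalized = d_inv.dot(adjacency).dot(d_inv)   # symmetric normalization, as in 'both'
print(normalized.toarray())

sage = get_layer('Sage', out_channels=4)   # forces left normalization and self-embeddings
print(sage)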