scikit-network 0.28.3 (cp39-cp39-macosx_12_0_arm64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of scikit-network might be problematic.
- scikit_network-0.28.3.dist-info/AUTHORS.rst +41 -0
- scikit_network-0.28.3.dist-info/LICENSE +34 -0
- scikit_network-0.28.3.dist-info/METADATA +457 -0
- scikit_network-0.28.3.dist-info/RECORD +240 -0
- scikit_network-0.28.3.dist-info/WHEEL +5 -0
- scikit_network-0.28.3.dist-info/top_level.txt +1 -0
- sknetwork/__init__.py +21 -0
- sknetwork/classification/__init__.py +8 -0
- sknetwork/classification/base.py +84 -0
- sknetwork/classification/base_rank.py +143 -0
- sknetwork/classification/diffusion.py +134 -0
- sknetwork/classification/knn.py +162 -0
- sknetwork/classification/metrics.py +205 -0
- sknetwork/classification/pagerank.py +66 -0
- sknetwork/classification/propagation.py +152 -0
- sknetwork/classification/tests/__init__.py +1 -0
- sknetwork/classification/tests/test_API.py +35 -0
- sknetwork/classification/tests/test_diffusion.py +37 -0
- sknetwork/classification/tests/test_knn.py +24 -0
- sknetwork/classification/tests/test_metrics.py +53 -0
- sknetwork/classification/tests/test_pagerank.py +20 -0
- sknetwork/classification/tests/test_propagation.py +24 -0
- sknetwork/classification/vote.cpython-39-darwin.so +0 -0
- sknetwork/classification/vote.pyx +58 -0
- sknetwork/clustering/__init__.py +7 -0
- sknetwork/clustering/base.py +102 -0
- sknetwork/clustering/kmeans.py +142 -0
- sknetwork/clustering/louvain.py +255 -0
- sknetwork/clustering/louvain_core.cpython-39-darwin.so +0 -0
- sknetwork/clustering/louvain_core.pyx +134 -0
- sknetwork/clustering/metrics.py +91 -0
- sknetwork/clustering/postprocess.py +66 -0
- sknetwork/clustering/propagation_clustering.py +108 -0
- sknetwork/clustering/tests/__init__.py +1 -0
- sknetwork/clustering/tests/test_API.py +37 -0
- sknetwork/clustering/tests/test_kmeans.py +47 -0
- sknetwork/clustering/tests/test_louvain.py +104 -0
- sknetwork/clustering/tests/test_metrics.py +50 -0
- sknetwork/clustering/tests/test_post_processing.py +23 -0
- sknetwork/clustering/tests/test_postprocess.py +39 -0
- sknetwork/data/__init__.py +5 -0
- sknetwork/data/load.py +408 -0
- sknetwork/data/models.py +459 -0
- sknetwork/data/parse.py +621 -0
- sknetwork/data/test_graphs.py +84 -0
- sknetwork/data/tests/__init__.py +1 -0
- sknetwork/data/tests/test_API.py +30 -0
- sknetwork/data/tests/test_load.py +95 -0
- sknetwork/data/tests/test_models.py +52 -0
- sknetwork/data/tests/test_parse.py +253 -0
- sknetwork/data/tests/test_test_graphs.py +30 -0
- sknetwork/data/tests/test_toy_graphs.py +68 -0
- sknetwork/data/toy_graphs.py +619 -0
- sknetwork/embedding/__init__.py +10 -0
- sknetwork/embedding/base.py +90 -0
- sknetwork/embedding/force_atlas.py +197 -0
- sknetwork/embedding/louvain_embedding.py +174 -0
- sknetwork/embedding/louvain_hierarchy.py +142 -0
- sknetwork/embedding/metrics.py +66 -0
- sknetwork/embedding/random_projection.py +133 -0
- sknetwork/embedding/spectral.py +214 -0
- sknetwork/embedding/spring.py +198 -0
- sknetwork/embedding/svd.py +363 -0
- sknetwork/embedding/tests/__init__.py +1 -0
- sknetwork/embedding/tests/test_API.py +73 -0
- sknetwork/embedding/tests/test_force_atlas.py +35 -0
- sknetwork/embedding/tests/test_louvain_embedding.py +33 -0
- sknetwork/embedding/tests/test_louvain_hierarchy.py +19 -0
- sknetwork/embedding/tests/test_metrics.py +29 -0
- sknetwork/embedding/tests/test_random_projection.py +28 -0
- sknetwork/embedding/tests/test_spectral.py +84 -0
- sknetwork/embedding/tests/test_spring.py +50 -0
- sknetwork/embedding/tests/test_svd.py +37 -0
- sknetwork/flow/__init__.py +3 -0
- sknetwork/flow/flow.py +73 -0
- sknetwork/flow/tests/__init__.py +1 -0
- sknetwork/flow/tests/test_flow.py +17 -0
- sknetwork/flow/tests/test_utils.py +69 -0
- sknetwork/flow/utils.py +91 -0
- sknetwork/gnn/__init__.py +10 -0
- sknetwork/gnn/activation.py +117 -0
- sknetwork/gnn/base.py +155 -0
- sknetwork/gnn/base_activation.py +89 -0
- sknetwork/gnn/base_layer.py +109 -0
- sknetwork/gnn/gnn_classifier.py +381 -0
- sknetwork/gnn/layer.py +153 -0
- sknetwork/gnn/layers.py +127 -0
- sknetwork/gnn/loss.py +180 -0
- sknetwork/gnn/neighbor_sampler.py +65 -0
- sknetwork/gnn/optimizer.py +163 -0
- sknetwork/gnn/tests/__init__.py +1 -0
- sknetwork/gnn/tests/test_activation.py +56 -0
- sknetwork/gnn/tests/test_base.py +79 -0
- sknetwork/gnn/tests/test_base_layer.py +37 -0
- sknetwork/gnn/tests/test_gnn_classifier.py +192 -0
- sknetwork/gnn/tests/test_layers.py +80 -0
- sknetwork/gnn/tests/test_loss.py +33 -0
- sknetwork/gnn/tests/test_neigh_sampler.py +23 -0
- sknetwork/gnn/tests/test_optimizer.py +43 -0
- sknetwork/gnn/tests/test_utils.py +93 -0
- sknetwork/gnn/utils.py +219 -0
- sknetwork/hierarchy/__init__.py +7 -0
- sknetwork/hierarchy/base.py +69 -0
- sknetwork/hierarchy/louvain_hierarchy.py +264 -0
- sknetwork/hierarchy/metrics.py +234 -0
- sknetwork/hierarchy/paris.cpython-39-darwin.so +0 -0
- sknetwork/hierarchy/paris.pyx +317 -0
- sknetwork/hierarchy/postprocess.py +350 -0
- sknetwork/hierarchy/tests/__init__.py +1 -0
- sknetwork/hierarchy/tests/test_API.py +25 -0
- sknetwork/hierarchy/tests/test_algos.py +29 -0
- sknetwork/hierarchy/tests/test_metrics.py +62 -0
- sknetwork/hierarchy/tests/test_postprocess.py +57 -0
- sknetwork/hierarchy/tests/test_ward.py +25 -0
- sknetwork/hierarchy/ward.py +94 -0
- sknetwork/linalg/__init__.py +9 -0
- sknetwork/linalg/basics.py +37 -0
- sknetwork/linalg/diteration.cpython-39-darwin.so +0 -0
- sknetwork/linalg/diteration.pyx +49 -0
- sknetwork/linalg/eig_solver.py +93 -0
- sknetwork/linalg/laplacian.py +15 -0
- sknetwork/linalg/normalization.py +66 -0
- sknetwork/linalg/operators.py +225 -0
- sknetwork/linalg/polynome.py +76 -0
- sknetwork/linalg/ppr_solver.py +170 -0
- sknetwork/linalg/push.cpython-39-darwin.so +0 -0
- sknetwork/linalg/push.pyx +73 -0
- sknetwork/linalg/sparse_lowrank.py +142 -0
- sknetwork/linalg/svd_solver.py +91 -0
- sknetwork/linalg/tests/__init__.py +1 -0
- sknetwork/linalg/tests/test_eig.py +44 -0
- sknetwork/linalg/tests/test_laplacian.py +18 -0
- sknetwork/linalg/tests/test_normalization.py +38 -0
- sknetwork/linalg/tests/test_operators.py +70 -0
- sknetwork/linalg/tests/test_polynome.py +38 -0
- sknetwork/linalg/tests/test_ppr.py +50 -0
- sknetwork/linalg/tests/test_sparse_lowrank.py +61 -0
- sknetwork/linalg/tests/test_svd.py +38 -0
- sknetwork/linkpred/__init__.py +4 -0
- sknetwork/linkpred/base.py +80 -0
- sknetwork/linkpred/first_order.py +508 -0
- sknetwork/linkpred/first_order_core.cpython-39-darwin.so +0 -0
- sknetwork/linkpred/first_order_core.pyx +315 -0
- sknetwork/linkpred/postprocessing.py +98 -0
- sknetwork/linkpred/tests/__init__.py +1 -0
- sknetwork/linkpred/tests/test_API.py +49 -0
- sknetwork/linkpred/tests/test_postprocessing.py +21 -0
- sknetwork/path/__init__.py +4 -0
- sknetwork/path/metrics.py +148 -0
- sknetwork/path/search.py +65 -0
- sknetwork/path/shortest_path.py +186 -0
- sknetwork/path/tests/__init__.py +1 -0
- sknetwork/path/tests/test_metrics.py +29 -0
- sknetwork/path/tests/test_search.py +25 -0
- sknetwork/path/tests/test_shortest_path.py +45 -0
- sknetwork/ranking/__init__.py +9 -0
- sknetwork/ranking/base.py +56 -0
- sknetwork/ranking/betweenness.cpython-39-darwin.so +0 -0
- sknetwork/ranking/betweenness.pyx +99 -0
- sknetwork/ranking/closeness.py +95 -0
- sknetwork/ranking/harmonic.py +82 -0
- sknetwork/ranking/hits.py +94 -0
- sknetwork/ranking/katz.py +81 -0
- sknetwork/ranking/pagerank.py +107 -0
- sknetwork/ranking/postprocess.py +25 -0
- sknetwork/ranking/tests/__init__.py +1 -0
- sknetwork/ranking/tests/test_API.py +34 -0
- sknetwork/ranking/tests/test_betweenness.py +38 -0
- sknetwork/ranking/tests/test_closeness.py +34 -0
- sknetwork/ranking/tests/test_hits.py +20 -0
- sknetwork/ranking/tests/test_pagerank.py +69 -0
- sknetwork/regression/__init__.py +4 -0
- sknetwork/regression/base.py +56 -0
- sknetwork/regression/diffusion.py +190 -0
- sknetwork/regression/tests/__init__.py +1 -0
- sknetwork/regression/tests/test_API.py +34 -0
- sknetwork/regression/tests/test_diffusion.py +48 -0
- sknetwork/sknetwork.py +3 -0
- sknetwork/topology/__init__.py +9 -0
- sknetwork/topology/dag.py +74 -0
- sknetwork/topology/dag_core.cpython-39-darwin.so +0 -0
- sknetwork/topology/dag_core.pyx +38 -0
- sknetwork/topology/kcliques.cpython-39-darwin.so +0 -0
- sknetwork/topology/kcliques.pyx +193 -0
- sknetwork/topology/kcore.cpython-39-darwin.so +0 -0
- sknetwork/topology/kcore.pyx +120 -0
- sknetwork/topology/structure.py +234 -0
- sknetwork/topology/tests/__init__.py +1 -0
- sknetwork/topology/tests/test_cliques.py +28 -0
- sknetwork/topology/tests/test_cores.py +21 -0
- sknetwork/topology/tests/test_dag.py +26 -0
- sknetwork/topology/tests/test_structure.py +99 -0
- sknetwork/topology/tests/test_triangles.py +42 -0
- sknetwork/topology/tests/test_wl_coloring.py +49 -0
- sknetwork/topology/tests/test_wl_kernel.py +31 -0
- sknetwork/topology/triangles.cpython-39-darwin.so +0 -0
- sknetwork/topology/triangles.pyx +166 -0
- sknetwork/topology/weisfeiler_lehman.py +163 -0
- sknetwork/topology/weisfeiler_lehman_core.cpython-39-darwin.so +0 -0
- sknetwork/topology/weisfeiler_lehman_core.pyx +116 -0
- sknetwork/utils/__init__.py +40 -0
- sknetwork/utils/base.py +35 -0
- sknetwork/utils/check.py +354 -0
- sknetwork/utils/co_neighbor.py +71 -0
- sknetwork/utils/format.py +219 -0
- sknetwork/utils/kmeans.py +89 -0
- sknetwork/utils/knn.py +166 -0
- sknetwork/utils/knn1d.cpython-39-darwin.so +0 -0
- sknetwork/utils/knn1d.pyx +80 -0
- sknetwork/utils/membership.py +82 -0
- sknetwork/utils/minheap.cpython-39-darwin.so +0 -0
- sknetwork/utils/minheap.pxd +22 -0
- sknetwork/utils/minheap.pyx +111 -0
- sknetwork/utils/neighbors.py +115 -0
- sknetwork/utils/seeds.py +75 -0
- sknetwork/utils/simplex.py +140 -0
- sknetwork/utils/tests/__init__.py +1 -0
- sknetwork/utils/tests/test_base.py +28 -0
- sknetwork/utils/tests/test_bunch.py +16 -0
- sknetwork/utils/tests/test_check.py +190 -0
- sknetwork/utils/tests/test_co_neighbor.py +43 -0
- sknetwork/utils/tests/test_format.py +61 -0
- sknetwork/utils/tests/test_kmeans.py +21 -0
- sknetwork/utils/tests/test_knn.py +32 -0
- sknetwork/utils/tests/test_membership.py +24 -0
- sknetwork/utils/tests/test_neighbors.py +41 -0
- sknetwork/utils/tests/test_projection_simplex.py +33 -0
- sknetwork/utils/tests/test_seeds.py +67 -0
- sknetwork/utils/tests/test_verbose.py +15 -0
- sknetwork/utils/tests/test_ward.py +20 -0
- sknetwork/utils/timeout.py +38 -0
- sknetwork/utils/verbose.py +37 -0
- sknetwork/utils/ward.py +60 -0
- sknetwork/visualization/__init__.py +4 -0
- sknetwork/visualization/colors.py +34 -0
- sknetwork/visualization/dendrograms.py +229 -0
- sknetwork/visualization/graphs.py +819 -0
- sknetwork/visualization/tests/__init__.py +1 -0
- sknetwork/visualization/tests/test_dendrograms.py +53 -0
- sknetwork/visualization/tests/test_graphs.py +167 -0
sknetwork/gnn/layer.py
ADDED
@@ -0,0 +1,153 @@
+#!/usr/bin/env python3
+# coding: utf-8
+"""
+Created on Thu Apr 21 2022
+@author: Simon Delarue <sdelarue@enst.fr>
+"""
+from typing import Optional, Union
+
+import numpy as np
+from scipy import sparse
+
+from sknetwork.gnn.activation import BaseActivation
+from sknetwork.gnn.loss import BaseLoss
+from sknetwork.gnn.base_layer import BaseLayer
+from sknetwork.utils.check import add_self_loops
+from sknetwork.linalg import diag_pinv
+
+
+class Convolution(BaseLayer):
+    """Graph convolutional layer.
+
+    Apply the following function to the embedding :math:`X`:
+
+    :math:`\\sigma(\\bar AXW + b)`,
+
+    where :math:`\\bar A` is the normalized adjacency matrix (possibly with inserted self-embeddings),
+    :math:`W`, :math:`b` are trainable parameters and :math:`\\sigma` is the activation function.
+
+    Parameters
+    ----------
+    layer_type : str
+        Layer type. Can be either ``'Conv'``, the convolutional operator as in [1], or ``'Sage'``, as in [2].
+    out_channels : int
+        Dimension of the output.
+    activation : str (default = ``'Relu'``) or custom activation
+        Activation function.
+        If a string, can be either ``'Identity'``, ``'Relu'``, ``'Sigmoid'`` or ``'Softmax'``.
+    use_bias : bool (default = ``True``)
+        If ``True``, add a bias vector.
+    normalization : str (default = ``'both'``)
+        Normalization of the adjacency matrix for message passing.
+        Can be either ``'left'`` (left normalization by the degrees), ``'right'`` (right normalization by the
+        degrees), ``'both'`` (symmetric normalization by the square root of degrees, default) or ``None``
+        (no normalization).
+    self_embeddings : bool (default = ``True``)
+        If ``True``, consider the self-embedding in addition to the neighbor embeddings for each node of the graph.
+    sample_size : int (default = 25)
+        Size of the neighborhood sampled for each node. Used only for ``'Sage'`` layers.
+
+    Attributes
+    ----------
+    weight : np.ndarray
+        Trainable weight matrix.
+    bias : np.ndarray
+        Bias vector.
+    embedding : np.ndarray
+        Embedding of the nodes (before activation).
+    output : np.ndarray
+        Output of the layer (after activation).
+
+    References
+    ----------
+    [1] Kipf, T., & Welling, M. (2017).
+    `Semi-supervised Classification with Graph Convolutional Networks.
+    <https://arxiv.org/pdf/1609.02907.pdf>`_
+    5th International Conference on Learning Representations.
+
+    [2] Hamilton, W., Ying, R., & Leskovec, J. (2017).
+    `Inductive Representation Learning on Large Graphs.
+    <https://arxiv.org/pdf/1706.02216.pdf>`_
+    NIPS.
+    """
+    def __init__(self, layer_type: str, out_channels: int, activation: Optional[Union[BaseActivation, str]] = 'Relu',
+                 use_bias: bool = True, normalization: str = 'both', self_embeddings: bool = True,
+                 sample_size: Optional[int] = None, loss: Optional[Union[BaseLoss, str]] = None):
+        super(Convolution, self).__init__(layer_type, out_channels, activation, use_bias, normalization,
+                                          self_embeddings, sample_size, loss)
+
+    def forward(self, adjacency: Union[sparse.csr_matrix, np.ndarray],
+                features: Union[sparse.csr_matrix, np.ndarray]) -> np.ndarray:
+        """Compute the graph convolution.
+
+        Parameters
+        ----------
+        adjacency
+            Adjacency matrix of the graph.
+        features : sparse.csr_matrix, np.ndarray
+            Input features of shape :math:`(n, d)` with :math:`n` the number of nodes in the graph
+            and :math:`d` the dimension of the feature space.
+
+        Returns
+        -------
+        output : np.ndarray
+            Output of the layer.
+        """
+        if not self.weights_initialized:
+            self._initialize_weights(features.shape[1])
+
+        n_row, n_col = adjacency.shape
+
+        weights = adjacency.dot(np.ones(n_col))
+        if self.normalization == 'left':
+            d_inv = diag_pinv(weights)
+            adjacency = d_inv.dot(adjacency)
+        elif self.normalization == 'right':
+            d_inv = diag_pinv(weights)
+            adjacency = adjacency.dot(d_inv)
+        elif self.normalization == 'both':
+            d_inv = diag_pinv(np.sqrt(weights))
+            adjacency = d_inv.dot(adjacency).dot(d_inv)
+
+        if self.self_embeddings:
+            adjacency = add_self_loops(adjacency)
+
+        message = adjacency.dot(features)
+        embedding = message.dot(self.weight)
+
+        if self.use_bias:
+            embedding += self.bias
+
+        output = self.activation.output(embedding)
+
+        self.embedding = embedding
+        self.output = output
+
+        return output
+
+
+def get_layer(layer: Union[BaseLayer, str] = 'conv', **kwargs) -> BaseLayer:
+    """Get a layer.
+
+    Parameters
+    ----------
+    layer : str or custom layer
+        If a string, must be either ``'Conv'`` (Convolution) or ``'Sage'`` (GraphSAGE).
+
+    Returns
+    -------
+    Layer object.
+    """
+    if issubclass(type(layer), BaseLayer):
+        return layer
+    elif type(layer) == str:
+        layer = layer.lower()
+        if layer in ['conv', 'gcnconv', 'graphconv']:
+            return Convolution('conv', **kwargs)
+        elif layer in ['sage', 'sageconv', 'graphsage']:
+            kwargs['normalization'] = 'left'
+            kwargs['self_embeddings'] = True
+            return Convolution('sage', **kwargs)
+        else:
+            raise ValueError("Layer name must be \"Conv\" or \"SAGEConv\".")
+    else:
+        raise TypeError("Layer must be a string or a \"BaseLayer\" object.")
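A minimal sketch of how the factory above might be used, assuming the karate_club loader from sknetwork.data and that the BaseLayer plumbing initializes weights from the feature dimension on the first forward pass, as the code suggests. Feeding the adjacency matrix itself as node features is a purely illustrative choice.

from sknetwork.data import karate_club
from sknetwork.gnn.layer import get_layer

adjacency = karate_club()                     # CSR adjacency of a 34-node graph
layer = get_layer('conv', out_channels=4)     # returns Convolution('conv', out_channels=4)
output = layer.forward(adjacency, adjacency)  # adjacency doubles as (34, 34) node features
print(output.shape)                           # (34, 4)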
sknetwork/gnn/layers.py
ADDED
@@ -0,0 +1,127 @@
+#!/usr/bin/env python3
+# coding: utf-8
+"""
+Created on Thu Apr 21 2022
+@author: Simon Delarue <sdelarue@enst.fr>
+"""
+from typing import Union
+
+import numpy as np
+from scipy import sparse
+
+from sknetwork.gnn.activation import get_activation_function
+from sknetwork.gnn.base_layer import BaseLayer
+from sknetwork.gnn.utils import has_self_loops, add_self_loops
+from sknetwork.linalg import diag_pinv
+
+
+class GCNConv(BaseLayer):
+    """Graph convolutional operator.
+
+    :math:`H^{\\prime}=\\sigma(D^{-1/2}\\hat{A}D^{-1/2}HW + b)`,
+
+    where :math:`\\hat{A} = A + I` denotes the adjacency matrix with inserted self-loops and
+    :math:`D` its diagonal degree matrix. :math:`W` and :math:`b` are trainable parameters and
+    :math:`\\sigma` is the activation function.
+
+    Parameters
+    ----------
+    out_channels : int
+        Size of each output sample.
+    activation : str (default = ``'Relu'``)
+        Activation function.
+        Can be either:
+
+        * ``'relu'``, the rectified linear unit function, returns f(x) = max(0, x)
+        * ``'sigmoid'``, the logistic sigmoid function, returns f(x) = 1 / (1 + exp(-x))
+        * ``'softmax'``, the softmax function, returns f(x) = exp(x) / sum(exp(x))
+    use_bias : bool (default = ``True``)
+        If ``True``, add a bias vector.
+    normalization : str (default = ``'both'``)
+        Normalization of the adjacency matrix for message passing.
+        Can be either:
+
+        * ``'left'``, left normalization by the vector of degrees
+        * ``'right'``, right normalization by the vector of degrees
+        * ``'both'``, symmetric normalization by the square root of degrees
+    self_loops : bool (default = ``True``)
+        If ``True``, add self-loops to each node in the graph.
+
+    Attributes
+    ----------
+    weight : np.ndarray
+        Trainable weight matrix.
+    bias : np.ndarray
+        Bias vector.
+    update : np.ndarray
+        Embedding of the nodes before the activation function.
+    embedding : np.ndarray
+        Embedding of the nodes after the convolution layer.
+
+    References
+    ----------
+    Kipf, T., & Welling, M. (2017).
+    `Semi-supervised Classification with Graph Convolutional Networks.
+    <https://arxiv.org/pdf/1609.02907.pdf>`_
+    5th International Conference on Learning Representations.
+
+    He, K., Zhang, X., Ren, S., & Sun, J. (2015).
+    `Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification.
+    <https://arxiv.org/pdf/1502.01852.pdf>`_
+    Proceedings of the IEEE International Conference on Computer Vision (ICCV).
+    """
+    def __init__(self, out_channels: int, activation: str = 'Relu', use_bias: bool = True,
+                 normalization: str = 'both', self_loops: bool = True):
+        super(GCNConv, self).__init__(out_channels, activation, use_bias, normalization, self_loops)
+
+    def forward(self, adjacency: Union[sparse.csr_matrix, np.ndarray],
+                features: Union[sparse.csr_matrix, np.ndarray]) -> np.ndarray:
+        """Compute graph convolution.
+
+        Parameters
+        ----------
+        adjacency
+            Adjacency matrix of the graph.
+        features : sparse.csr_matrix, np.ndarray
+            Input features of shape :math:`(n, d)` with :math:`n` the number of nodes in the graph
+            and :math:`d` the dimension of the feature space.
+
+        Returns
+        -------
+        embedding : np.ndarray
+            Node embedding.
+        """
+        if not self.weights_initialized:
+            self._initialize_weights(features.shape[1])
+
+        n_row, n_col = adjacency.shape
+
+        if self.self_loops:
+            if not has_self_loops(adjacency):
+                adjacency = add_self_loops(adjacency)
+
+        weights = adjacency.dot(np.ones(n_col))
+        if self.normalization == 'left':
+            d_inv = diag_pinv(weights)
+            adjacency = d_inv.dot(adjacency)
+        elif self.normalization == 'right':
+            d_inv = diag_pinv(weights)
+            adjacency = adjacency.dot(d_inv)
+        elif self.normalization == 'both':
+            d_inv = diag_pinv(np.sqrt(weights))
+            adjacency = d_inv.dot(adjacency).dot(d_inv)
+
+        msg = adjacency.dot(features)
+        update = msg.dot(self.weight)
+
+        if self.use_bias:
+            update += self.bias
+
+        activation_function = get_activation_function(self.activation)
+        embedding = activation_function(update)
+
+        # Keep track of results for backprop
+        self.embedding = embedding
+        self.update = update
+
+        return embedding
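The ``'both'`` branch above implements the symmetric normalization D^{-1/2} A D^{-1/2} of Kipf and Welling. A standalone sketch of that step, using the same diag_pinv helper the layer imports and a hypothetical 3-node path graph:

import numpy as np
from scipy import sparse
from sknetwork.linalg import diag_pinv

adjacency = sparse.csr_matrix(np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]]))
degrees = adjacency.dot(np.ones(adjacency.shape[1]))    # node degrees: [1, 2, 1]
d_inv_sqrt = diag_pinv(np.sqrt(degrees))                # sparse diagonal D^{-1/2}
normalized = d_inv_sqrt.dot(adjacency).dot(d_inv_sqrt)  # D^{-1/2} A D^{-1/2}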
sknetwork/gnn/loss.py
ADDED
@@ -0,0 +1,180 @@
+#!/usr/bin/env python3
+# coding: utf-8
+"""
+Created in April 2022
+@author: Simon Delarue <sdelarue@enst.fr>
+@author: Thomas Bonald <bonald@enst.fr>
+"""
+
+from typing import Union
+
+import numpy as np
+
+from sknetwork.gnn.base_activation import BaseLoss
+from sknetwork.gnn.activation import Sigmoid, Softmax
+
+
+class CrossEntropy(BaseLoss, Softmax):
+    """Cross entropy loss with softmax activation.
+
+    For a single sample with value :math:`x` and true label :math:`y`, the cross-entropy loss is:
+
+    :math:`-\\sum_i 1_{\\{y=i\\}} \\log (p_i)`
+
+    with
+
+    :math:`p_i = e^{x_i} / \\sum_j e^{x_j}`.
+
+    For :math:`n` samples, return the average loss.
+    """
+    def __init__(self):
+        super(CrossEntropy, self).__init__()
+        self.name = 'Cross entropy'
+
+    @staticmethod
+    def loss(signal: np.ndarray, labels: np.ndarray) -> float:
+        """Get the loss value.
+
+        Parameters
+        ----------
+        signal : np.ndarray, shape (n_samples, n_channels)
+            Input signal (before activation).
+            The number of channels must be at least 2.
+        labels : np.ndarray, shape (n_samples)
+            True labels.
+
+        Returns
+        -------
+        value : float
+            Loss value.
+        """
+        n = len(labels)
+        probs = Softmax.output(signal)
+
+        # for numerical stability
+        eps = 1e-15
+        probs = np.clip(probs, eps, 1 - eps)
+
+        value = -np.log(probs[np.arange(n), labels]).sum()
+
+        return value / n
+
+    @staticmethod
+    def loss_gradient(signal: np.ndarray, labels: np.ndarray) -> np.ndarray:
+        """Get the gradient of the loss function (including activation).
+
+        Parameters
+        ----------
+        signal : np.ndarray, shape (n_samples, n_channels)
+            Input signal (before activation).
+        labels : np.ndarray, shape (n_samples)
+            True labels.
+
+        Returns
+        -------
+        gradient : np.ndarray
+            Gradient of the loss function.
+        """
+        probs = Softmax.output(signal)
+        one_hot_encoding = np.zeros_like(probs)
+        one_hot_encoding[np.arange(len(labels)), labels] = 1
+        gradient = probs - one_hot_encoding
+
+        return gradient
+
+
+class BinaryCrossEntropy(BaseLoss, Sigmoid):
+    """Binary cross entropy loss with sigmoid activation.
+
+    For a single sample with true label :math:`y` and predicted probability :math:`p`, the binary
+    cross-entropy loss is:
+
+    :math:`-y \\log (p) - (1-y) \\log (1 - p).`
+
+    For :math:`n` samples, return the average loss.
+    """
+    def __init__(self):
+        super(BinaryCrossEntropy, self).__init__()
+        self.name = 'Binary cross entropy'
+
+    @staticmethod
+    def loss(signal: np.ndarray, labels: np.ndarray) -> float:
+        """Get the loss value.
+
+        Parameters
+        ----------
+        signal : np.ndarray, shape (n_samples, n_channels)
+            Input signal (before activation).
+            With a single channel, labels are treated as binary.
+        labels : np.ndarray, shape (n_samples)
+            True labels.
+
+        Returns
+        -------
+        value : float
+            Loss value.
+        """
+        probs = Sigmoid.output(signal)
+        n = len(labels)
+
+        # for numerical stability
+        eps = 1e-15
+        probs = np.clip(probs, eps, 1 - eps)
+
+        if probs.shape[1] == 1:
+            # binary labels
+            value = -np.log(probs[labels > 0]).sum()
+            value -= np.log((1 - probs)[labels == 0]).sum()
+        else:
+            # general case
+            value = -np.log(1 - probs)
+            value[np.arange(n), labels] = -np.log(probs[np.arange(n), labels])
+            value = value.sum()
+
+        return value / n
+
+    @staticmethod
+    def loss_gradient(signal: np.ndarray, labels: np.ndarray) -> np.ndarray:
+        """Get the gradient of the loss function (including activation).
+
+        Parameters
+        ----------
+        signal : np.ndarray, shape (n_samples, n_channels)
+            Input signal (before activation).
+        labels : np.ndarray, shape (n_samples)
+            True labels.
+
+        Returns
+        -------
+        gradient : np.ndarray
+            Gradient of the loss function.
+        """
+        probs = Sigmoid.output(signal)
+        gradient = (probs.T - labels).T
+
+        return gradient
+
+
+def get_loss(loss: Union[BaseLoss, str] = 'CrossEntropy') -> BaseLoss:
+    """Instantiate a loss function according to parameters.
+
+    Parameters
+    ----------
+    loss : str or loss function
+        Which loss function to use. Can be ``'CrossEntropy'``, ``'BinaryCrossEntropy'`` or a custom loss.
+
+    Returns
+    -------
+    Loss function object.
+    """
+    if issubclass(type(loss), BaseLoss):
+        return loss
+    elif type(loss) == str:
+        loss = loss.lower().replace(' ', '')
+        if loss in ['crossentropy', 'ce']:
+            return CrossEntropy()
+        elif loss in ['binarycrossentropy', 'bce']:
+            return BinaryCrossEntropy()
+        else:
+            raise ValueError("Loss must be either \"CrossEntropy\" or \"BinaryCrossEntropy\".")
+    else:
+        raise TypeError("Loss must be either a \"BaseLoss\" object or a string.")
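A small sketch of the loss factory in action, with hypothetical raw scores. Note that loss_gradient returns softmax(signal) minus the one-hot encoded labels, the standard combined softmax/cross-entropy gradient:

import numpy as np
from sknetwork.gnn.loss import get_loss

cross_entropy = get_loss('CrossEntropy')
signal = np.array([[2., 0.], [0., 1.]])                 # raw scores, one row per sample
labels = np.array([0, 1])
value = cross_entropy.loss(signal, labels)              # average loss over the two samples
gradient = cross_entropy.loss_gradient(signal, labels)  # softmax(signal) - one_hot(labels)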
sknetwork/gnn/neighbor_sampler.py
ADDED
@@ -0,0 +1,65 @@
+#!/usr/bin/env python3
+# coding: utf-8
+"""
+@author: Simon Delarue <sdelarue@enst.fr>
+"""
+from typing import Union
+
+import numpy as np
+from scipy import sparse
+
+from sknetwork.utils import get_degrees
+
+
+class UniformNeighborSampler:
+    """Neighbor node sampler.
+
+    Uniformly sample nodes over the neighborhood.
+
+    Parameters
+    ----------
+    sample_size : int
+        Size of the neighborhood sampled for each node.
+    """
+    def __init__(self, sample_size: int):
+        self.sample_size = sample_size
+
+    def _sample_indexes(self, size: int) -> np.ndarray:
+        """Randomly choose indexes without replacement.
+
+        Parameters
+        ----------
+        size : int
+            Number of available indexes. The number of sampled indexes is capped at ``sample_size``.
+
+        Returns
+        -------
+        Array of sampled indexes.
+        """
+        return np.random.choice(size, size=min(size, self.sample_size), replace=False)
+
+    def __call__(self, adjacency: Union[sparse.csr_matrix, np.ndarray]) -> sparse.csr_matrix:
+        """Apply node sampling to each node and return the filtered adjacency matrix.
+
+        Parameters
+        ----------
+        adjacency
+            Adjacency matrix of the graph.
+
+        Returns
+        -------
+        Filtered adjacency matrix using node sampling.
+        """
+        n_row, _ = adjacency.shape
+        sampled_adjacency = adjacency.copy()
+
+        degrees = get_degrees(adjacency)
+        neighbor_samples = list(map(self._sample_indexes, degrees))
+
+        for i, neighbors in enumerate(neighbor_samples):
+            sampled_adjacency.data[sampled_adjacency.indptr[i]:sampled_adjacency.indptr[i + 1]] = np.zeros(degrees[i])
+            sampled_adjacency.data[sampled_adjacency.indptr[i]:sampled_adjacency.indptr[i + 1]][neighbors] = 1
+
+        sampled_adjacency.eliminate_zeros()
+
+        return sampled_adjacency
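A usage sketch on a hypothetical 4-node graph, assuming a CSR adjacency matrix (the sampler relies on the indptr layout). Note that the kept edges have their weights reset to 1 and the rest are dropped:

import numpy as np
from scipy import sparse
from sknetwork.gnn.neighbor_sampler import UniformNeighborSampler

adjacency = sparse.csr_matrix(np.array([[0, 1, 1, 1],
                                        [1, 0, 0, 0],
                                        [1, 0, 0, 1],
                                        [0, 0, 0, 0]]))
sampler = UniformNeighborSampler(sample_size=2)
sampled = sampler(adjacency)  # each row keeps at most 2 uniformly chosen neighbors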
sknetwork/gnn/optimizer.py
ADDED
@@ -0,0 +1,163 @@
+"""
+Created on Thu Apr 21 2022
+@author: Simon Delarue <sdelarue@enst.fr>
+"""
+from __future__ import annotations
+
+from typing import Union, TYPE_CHECKING
+
+import numpy as np
+
+if TYPE_CHECKING:
+    from sknetwork.gnn.base import BaseGNN
+
+
+class BaseOptimizer:
+    """Base class for optimizers.
+
+    Parameters
+    ----------
+    learning_rate : float (default = 0.01)
+        Learning rate for updating weights.
+    """
+
+    def __init__(self, learning_rate):
+        self.learning_rate = learning_rate
+
+    def step(self, gnn: BaseGNN):
+        """Update model parameters according to gradient values.
+
+        Parameters
+        ----------
+        gnn : BaseGNN
+            Model containing the parameters to update.
+        """
+
+
+class GD(BaseOptimizer):
+    """Gradient Descent optimizer.
+
+    Parameters
+    ----------
+    learning_rate : float (default = 0.01)
+        Learning rate for updating weights.
+    """
+
+    def __init__(self, learning_rate: float = 0.01):
+        super(GD, self).__init__(learning_rate)
+
+    def step(self, gnn: BaseGNN):
+        """Update model parameters according to gradient values.
+
+        Parameters
+        ----------
+        gnn : BaseGNN
+            Model containing the parameters to update.
+        """
+        for idx, layer in enumerate(gnn.layers):
+            layer.weight = layer.weight - self.learning_rate * gnn.derivative_weight[idx]
+            layer.bias = layer.bias - self.learning_rate * gnn.derivative_bias[idx]
+
+
+class ADAM(BaseOptimizer):
+    """Adam optimizer.
+
+    Parameters
+    ----------
+    learning_rate : float (default = 0.01)
+        Learning rate for updating weights.
+    beta1, beta2 : float
+        Coefficients used for computing running averages of gradients.
+    eps : float (default = 1e-8)
+        Term added to the denominator to improve numerical stability.
+
+    References
+    ----------
+    Kingma, D. P., & Ba, J. (2014).
+    `Adam: A method for stochastic optimization.
+    <https://arxiv.org/pdf/1412.6980.pdf>`_
+    3rd International Conference on Learning Representations.
+    """
+
+    def __init__(self, learning_rate: float = 0.01, beta1: float = 0.9, beta2: float = 0.999,
+                 eps: float = 1e-8):
+        super(ADAM, self).__init__(learning_rate)
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.eps = eps
+        self.m_derivative_weight, self.v_derivative_weight = [], []
+        self.m_derivative_bias, self.v_derivative_bias = [], []
+        self.t = 0
+
+    def step(self, gnn: BaseGNN):
+        """Update model parameters according to gradient values and parameters.
+
+        Parameters
+        ----------
+        gnn : BaseGNN
+            Model containing the parameters to update.
+        """
+        if self.t == 0:
+            self.m_derivative_weight, self.v_derivative_weight = \
+                [np.zeros(x.shape) for x in gnn.derivative_weight], [np.zeros(x.shape) for x in gnn.derivative_weight]
+            self.m_derivative_bias, self.v_derivative_bias = \
+                [np.zeros(x.shape) for x in gnn.derivative_bias], [np.zeros(x.shape) for x in gnn.derivative_bias]
+
+        for idx, layer in enumerate(gnn.layers):
+            self.t += 1
+
+            # Moving averages
+            self.m_derivative_weight[idx] = \
+                self.beta1 * self.m_derivative_weight[idx] + (1 - self.beta1) * gnn.derivative_weight[idx]
+            self.m_derivative_bias[idx] = \
+                self.beta1 * self.m_derivative_bias[idx] + (1 - self.beta1) * gnn.derivative_bias[idx]
+
+            self.v_derivative_weight[idx] = \
+                self.beta2 * self.v_derivative_weight[idx] + (1 - self.beta2) * (gnn.derivative_weight[idx] ** 2)
+            self.v_derivative_bias[idx] = \
+                self.beta2 * self.v_derivative_bias[idx] + (1 - self.beta2) * (gnn.derivative_bias[idx] ** 2)
+
+            # Correcting moving averages
+            denom_1 = (1 - self.beta1 ** self.t)
+            denom_2 = (1 - self.beta2 ** self.t)
+
+            m_derivative_weight_corr = self.m_derivative_weight[idx] / denom_1
+            m_derivative_bias_corr = self.m_derivative_bias[idx] / denom_1
+            v_derivative_weight_corr = self.v_derivative_weight[idx] / denom_2
+            v_derivative_bias_corr = self.v_derivative_bias[idx] / denom_2
+
+            # Parameters update
+            layer.weight = \
+                layer.weight - (self.learning_rate * m_derivative_weight_corr) / (np.sqrt(v_derivative_weight_corr)
+                                                                                  + self.eps)
+            layer.bias = \
+                layer.bias - (self.learning_rate * m_derivative_bias_corr) / (np.sqrt(v_derivative_bias_corr)
+                                                                              + self.eps)
+
+
+def get_optimizer(optimizer: Union[BaseOptimizer, str] = 'Adam', learning_rate: float = 0.01) -> BaseOptimizer:
+    """Instantiate an optimizer according to parameters.
+
+    Parameters
+    ----------
+    optimizer : str or optimizer
+        Which optimizer to use. Can be ``'Adam'``, ``'GD'`` or a custom optimizer.
+    learning_rate : float
+        Learning rate.
+
+    Returns
+    -------
+    Optimizer object.
+    """
+    if issubclass(type(optimizer), BaseOptimizer):
+        return optimizer
+    elif type(optimizer) == str:
+        optimizer = optimizer.lower()
+        if optimizer == 'adam':
+            return ADAM(learning_rate=learning_rate)
+        elif optimizer in ['gd', 'gradient']:
+            return GD(learning_rate=learning_rate)
+        else:
+            raise ValueError("Optimizer must be either \"Adam\" or \"GD\" (Gradient Descent).")
+    else:
+        raise TypeError("Optimizer must be either a \"BaseOptimizer\" object or a string.")
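The optimizers only read layers, derivative_weight and derivative_bias from the model, so a stand-in object is enough to sketch one update step; the dummy shapes below are hypothetical:

import numpy as np
from types import SimpleNamespace
from sknetwork.gnn.optimizer import get_optimizer

# stand-in single-layer "model" exposing the attributes the optimizers read
layer = SimpleNamespace(weight=np.ones((3, 2)), bias=np.zeros((1, 2)))
gnn = SimpleNamespace(layers=[layer],
                      derivative_weight=[0.1 * np.ones((3, 2))],
                      derivative_bias=[0.1 * np.ones((1, 2))])

optimizer = get_optimizer('Adam', learning_rate=0.01)
optimizer.step(gnn)  # one bias-corrected Adam update, applied in place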
sknetwork/gnn/tests/__init__.py
ADDED
@@ -0,0 +1 @@
+"""tests for gnn"""