scikit-network 0.28.3 (cp39-cp39-macosx_12_0_arm64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of scikit-network might be problematic.

Files changed (240)
  1. scikit_network-0.28.3.dist-info/AUTHORS.rst +41 -0
  2. scikit_network-0.28.3.dist-info/LICENSE +34 -0
  3. scikit_network-0.28.3.dist-info/METADATA +457 -0
  4. scikit_network-0.28.3.dist-info/RECORD +240 -0
  5. scikit_network-0.28.3.dist-info/WHEEL +5 -0
  6. scikit_network-0.28.3.dist-info/top_level.txt +1 -0
  7. sknetwork/__init__.py +21 -0
  8. sknetwork/classification/__init__.py +8 -0
  9. sknetwork/classification/base.py +84 -0
  10. sknetwork/classification/base_rank.py +143 -0
  11. sknetwork/classification/diffusion.py +134 -0
  12. sknetwork/classification/knn.py +162 -0
  13. sknetwork/classification/metrics.py +205 -0
  14. sknetwork/classification/pagerank.py +66 -0
  15. sknetwork/classification/propagation.py +152 -0
  16. sknetwork/classification/tests/__init__.py +1 -0
  17. sknetwork/classification/tests/test_API.py +35 -0
  18. sknetwork/classification/tests/test_diffusion.py +37 -0
  19. sknetwork/classification/tests/test_knn.py +24 -0
  20. sknetwork/classification/tests/test_metrics.py +53 -0
  21. sknetwork/classification/tests/test_pagerank.py +20 -0
  22. sknetwork/classification/tests/test_propagation.py +24 -0
  23. sknetwork/classification/vote.cpython-39-darwin.so +0 -0
  24. sknetwork/classification/vote.pyx +58 -0
  25. sknetwork/clustering/__init__.py +7 -0
  26. sknetwork/clustering/base.py +102 -0
  27. sknetwork/clustering/kmeans.py +142 -0
  28. sknetwork/clustering/louvain.py +255 -0
  29. sknetwork/clustering/louvain_core.cpython-39-darwin.so +0 -0
  30. sknetwork/clustering/louvain_core.pyx +134 -0
  31. sknetwork/clustering/metrics.py +91 -0
  32. sknetwork/clustering/postprocess.py +66 -0
  33. sknetwork/clustering/propagation_clustering.py +108 -0
  34. sknetwork/clustering/tests/__init__.py +1 -0
  35. sknetwork/clustering/tests/test_API.py +37 -0
  36. sknetwork/clustering/tests/test_kmeans.py +47 -0
  37. sknetwork/clustering/tests/test_louvain.py +104 -0
  38. sknetwork/clustering/tests/test_metrics.py +50 -0
  39. sknetwork/clustering/tests/test_post_processing.py +23 -0
  40. sknetwork/clustering/tests/test_postprocess.py +39 -0
  41. sknetwork/data/__init__.py +5 -0
  42. sknetwork/data/load.py +408 -0
  43. sknetwork/data/models.py +459 -0
  44. sknetwork/data/parse.py +621 -0
  45. sknetwork/data/test_graphs.py +84 -0
  46. sknetwork/data/tests/__init__.py +1 -0
  47. sknetwork/data/tests/test_API.py +30 -0
  48. sknetwork/data/tests/test_load.py +95 -0
  49. sknetwork/data/tests/test_models.py +52 -0
  50. sknetwork/data/tests/test_parse.py +253 -0
  51. sknetwork/data/tests/test_test_graphs.py +30 -0
  52. sknetwork/data/tests/test_toy_graphs.py +68 -0
  53. sknetwork/data/toy_graphs.py +619 -0
  54. sknetwork/embedding/__init__.py +10 -0
  55. sknetwork/embedding/base.py +90 -0
  56. sknetwork/embedding/force_atlas.py +197 -0
  57. sknetwork/embedding/louvain_embedding.py +174 -0
  58. sknetwork/embedding/louvain_hierarchy.py +142 -0
  59. sknetwork/embedding/metrics.py +66 -0
  60. sknetwork/embedding/random_projection.py +133 -0
  61. sknetwork/embedding/spectral.py +214 -0
  62. sknetwork/embedding/spring.py +198 -0
  63. sknetwork/embedding/svd.py +363 -0
  64. sknetwork/embedding/tests/__init__.py +1 -0
  65. sknetwork/embedding/tests/test_API.py +73 -0
  66. sknetwork/embedding/tests/test_force_atlas.py +35 -0
  67. sknetwork/embedding/tests/test_louvain_embedding.py +33 -0
  68. sknetwork/embedding/tests/test_louvain_hierarchy.py +19 -0
  69. sknetwork/embedding/tests/test_metrics.py +29 -0
  70. sknetwork/embedding/tests/test_random_projection.py +28 -0
  71. sknetwork/embedding/tests/test_spectral.py +84 -0
  72. sknetwork/embedding/tests/test_spring.py +50 -0
  73. sknetwork/embedding/tests/test_svd.py +37 -0
  74. sknetwork/flow/__init__.py +3 -0
  75. sknetwork/flow/flow.py +73 -0
  76. sknetwork/flow/tests/__init__.py +1 -0
  77. sknetwork/flow/tests/test_flow.py +17 -0
  78. sknetwork/flow/tests/test_utils.py +69 -0
  79. sknetwork/flow/utils.py +91 -0
  80. sknetwork/gnn/__init__.py +10 -0
  81. sknetwork/gnn/activation.py +117 -0
  82. sknetwork/gnn/base.py +155 -0
  83. sknetwork/gnn/base_activation.py +89 -0
  84. sknetwork/gnn/base_layer.py +109 -0
  85. sknetwork/gnn/gnn_classifier.py +381 -0
  86. sknetwork/gnn/layer.py +153 -0
  87. sknetwork/gnn/layers.py +127 -0
  88. sknetwork/gnn/loss.py +180 -0
  89. sknetwork/gnn/neighbor_sampler.py +65 -0
  90. sknetwork/gnn/optimizer.py +163 -0
  91. sknetwork/gnn/tests/__init__.py +1 -0
  92. sknetwork/gnn/tests/test_activation.py +56 -0
  93. sknetwork/gnn/tests/test_base.py +79 -0
  94. sknetwork/gnn/tests/test_base_layer.py +37 -0
  95. sknetwork/gnn/tests/test_gnn_classifier.py +192 -0
  96. sknetwork/gnn/tests/test_layers.py +80 -0
  97. sknetwork/gnn/tests/test_loss.py +33 -0
  98. sknetwork/gnn/tests/test_neigh_sampler.py +23 -0
  99. sknetwork/gnn/tests/test_optimizer.py +43 -0
  100. sknetwork/gnn/tests/test_utils.py +93 -0
  101. sknetwork/gnn/utils.py +219 -0
  102. sknetwork/hierarchy/__init__.py +7 -0
  103. sknetwork/hierarchy/base.py +69 -0
  104. sknetwork/hierarchy/louvain_hierarchy.py +264 -0
  105. sknetwork/hierarchy/metrics.py +234 -0
  106. sknetwork/hierarchy/paris.cpython-39-darwin.so +0 -0
  107. sknetwork/hierarchy/paris.pyx +317 -0
  108. sknetwork/hierarchy/postprocess.py +350 -0
  109. sknetwork/hierarchy/tests/__init__.py +1 -0
  110. sknetwork/hierarchy/tests/test_API.py +25 -0
  111. sknetwork/hierarchy/tests/test_algos.py +29 -0
  112. sknetwork/hierarchy/tests/test_metrics.py +62 -0
  113. sknetwork/hierarchy/tests/test_postprocess.py +57 -0
  114. sknetwork/hierarchy/tests/test_ward.py +25 -0
  115. sknetwork/hierarchy/ward.py +94 -0
  116. sknetwork/linalg/__init__.py +9 -0
  117. sknetwork/linalg/basics.py +37 -0
  118. sknetwork/linalg/diteration.cpython-39-darwin.so +0 -0
  119. sknetwork/linalg/diteration.pyx +49 -0
  120. sknetwork/linalg/eig_solver.py +93 -0
  121. sknetwork/linalg/laplacian.py +15 -0
  122. sknetwork/linalg/normalization.py +66 -0
  123. sknetwork/linalg/operators.py +225 -0
  124. sknetwork/linalg/polynome.py +76 -0
  125. sknetwork/linalg/ppr_solver.py +170 -0
  126. sknetwork/linalg/push.cpython-39-darwin.so +0 -0
  127. sknetwork/linalg/push.pyx +73 -0
  128. sknetwork/linalg/sparse_lowrank.py +142 -0
  129. sknetwork/linalg/svd_solver.py +91 -0
  130. sknetwork/linalg/tests/__init__.py +1 -0
  131. sknetwork/linalg/tests/test_eig.py +44 -0
  132. sknetwork/linalg/tests/test_laplacian.py +18 -0
  133. sknetwork/linalg/tests/test_normalization.py +38 -0
  134. sknetwork/linalg/tests/test_operators.py +70 -0
  135. sknetwork/linalg/tests/test_polynome.py +38 -0
  136. sknetwork/linalg/tests/test_ppr.py +50 -0
  137. sknetwork/linalg/tests/test_sparse_lowrank.py +61 -0
  138. sknetwork/linalg/tests/test_svd.py +38 -0
  139. sknetwork/linkpred/__init__.py +4 -0
  140. sknetwork/linkpred/base.py +80 -0
  141. sknetwork/linkpred/first_order.py +508 -0
  142. sknetwork/linkpred/first_order_core.cpython-39-darwin.so +0 -0
  143. sknetwork/linkpred/first_order_core.pyx +315 -0
  144. sknetwork/linkpred/postprocessing.py +98 -0
  145. sknetwork/linkpred/tests/__init__.py +1 -0
  146. sknetwork/linkpred/tests/test_API.py +49 -0
  147. sknetwork/linkpred/tests/test_postprocessing.py +21 -0
  148. sknetwork/path/__init__.py +4 -0
  149. sknetwork/path/metrics.py +148 -0
  150. sknetwork/path/search.py +65 -0
  151. sknetwork/path/shortest_path.py +186 -0
  152. sknetwork/path/tests/__init__.py +1 -0
  153. sknetwork/path/tests/test_metrics.py +29 -0
  154. sknetwork/path/tests/test_search.py +25 -0
  155. sknetwork/path/tests/test_shortest_path.py +45 -0
  156. sknetwork/ranking/__init__.py +9 -0
  157. sknetwork/ranking/base.py +56 -0
  158. sknetwork/ranking/betweenness.cpython-39-darwin.so +0 -0
  159. sknetwork/ranking/betweenness.pyx +99 -0
  160. sknetwork/ranking/closeness.py +95 -0
  161. sknetwork/ranking/harmonic.py +82 -0
  162. sknetwork/ranking/hits.py +94 -0
  163. sknetwork/ranking/katz.py +81 -0
  164. sknetwork/ranking/pagerank.py +107 -0
  165. sknetwork/ranking/postprocess.py +25 -0
  166. sknetwork/ranking/tests/__init__.py +1 -0
  167. sknetwork/ranking/tests/test_API.py +34 -0
  168. sknetwork/ranking/tests/test_betweenness.py +38 -0
  169. sknetwork/ranking/tests/test_closeness.py +34 -0
  170. sknetwork/ranking/tests/test_hits.py +20 -0
  171. sknetwork/ranking/tests/test_pagerank.py +69 -0
  172. sknetwork/regression/__init__.py +4 -0
  173. sknetwork/regression/base.py +56 -0
  174. sknetwork/regression/diffusion.py +190 -0
  175. sknetwork/regression/tests/__init__.py +1 -0
  176. sknetwork/regression/tests/test_API.py +34 -0
  177. sknetwork/regression/tests/test_diffusion.py +48 -0
  178. sknetwork/sknetwork.py +3 -0
  179. sknetwork/topology/__init__.py +9 -0
  180. sknetwork/topology/dag.py +74 -0
  181. sknetwork/topology/dag_core.cpython-39-darwin.so +0 -0
  182. sknetwork/topology/dag_core.pyx +38 -0
  183. sknetwork/topology/kcliques.cpython-39-darwin.so +0 -0
  184. sknetwork/topology/kcliques.pyx +193 -0
  185. sknetwork/topology/kcore.cpython-39-darwin.so +0 -0
  186. sknetwork/topology/kcore.pyx +120 -0
  187. sknetwork/topology/structure.py +234 -0
  188. sknetwork/topology/tests/__init__.py +1 -0
  189. sknetwork/topology/tests/test_cliques.py +28 -0
  190. sknetwork/topology/tests/test_cores.py +21 -0
  191. sknetwork/topology/tests/test_dag.py +26 -0
  192. sknetwork/topology/tests/test_structure.py +99 -0
  193. sknetwork/topology/tests/test_triangles.py +42 -0
  194. sknetwork/topology/tests/test_wl_coloring.py +49 -0
  195. sknetwork/topology/tests/test_wl_kernel.py +31 -0
  196. sknetwork/topology/triangles.cpython-39-darwin.so +0 -0
  197. sknetwork/topology/triangles.pyx +166 -0
  198. sknetwork/topology/weisfeiler_lehman.py +163 -0
  199. sknetwork/topology/weisfeiler_lehman_core.cpython-39-darwin.so +0 -0
  200. sknetwork/topology/weisfeiler_lehman_core.pyx +116 -0
  201. sknetwork/utils/__init__.py +40 -0
  202. sknetwork/utils/base.py +35 -0
  203. sknetwork/utils/check.py +354 -0
  204. sknetwork/utils/co_neighbor.py +71 -0
  205. sknetwork/utils/format.py +219 -0
  206. sknetwork/utils/kmeans.py +89 -0
  207. sknetwork/utils/knn.py +166 -0
  208. sknetwork/utils/knn1d.cpython-39-darwin.so +0 -0
  209. sknetwork/utils/knn1d.pyx +80 -0
  210. sknetwork/utils/membership.py +82 -0
  211. sknetwork/utils/minheap.cpython-39-darwin.so +0 -0
  212. sknetwork/utils/minheap.pxd +22 -0
  213. sknetwork/utils/minheap.pyx +111 -0
  214. sknetwork/utils/neighbors.py +115 -0
  215. sknetwork/utils/seeds.py +75 -0
  216. sknetwork/utils/simplex.py +140 -0
  217. sknetwork/utils/tests/__init__.py +1 -0
  218. sknetwork/utils/tests/test_base.py +28 -0
  219. sknetwork/utils/tests/test_bunch.py +16 -0
  220. sknetwork/utils/tests/test_check.py +190 -0
  221. sknetwork/utils/tests/test_co_neighbor.py +43 -0
  222. sknetwork/utils/tests/test_format.py +61 -0
  223. sknetwork/utils/tests/test_kmeans.py +21 -0
  224. sknetwork/utils/tests/test_knn.py +32 -0
  225. sknetwork/utils/tests/test_membership.py +24 -0
  226. sknetwork/utils/tests/test_neighbors.py +41 -0
  227. sknetwork/utils/tests/test_projection_simplex.py +33 -0
  228. sknetwork/utils/tests/test_seeds.py +67 -0
  229. sknetwork/utils/tests/test_verbose.py +15 -0
  230. sknetwork/utils/tests/test_ward.py +20 -0
  231. sknetwork/utils/timeout.py +38 -0
  232. sknetwork/utils/verbose.py +37 -0
  233. sknetwork/utils/ward.py +60 -0
  234. sknetwork/visualization/__init__.py +4 -0
  235. sknetwork/visualization/colors.py +34 -0
  236. sknetwork/visualization/dendrograms.py +229 -0
  237. sknetwork/visualization/graphs.py +819 -0
  238. sknetwork/visualization/tests/__init__.py +1 -0
  239. sknetwork/visualization/tests/test_dendrograms.py +53 -0
  240. sknetwork/visualization/tests/test_graphs.py +167 -0
sknetwork/gnn/base_activation.py
@@ -0,0 +1,89 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ Created on April 2022
+ @author: Simon Delarue <sdelarue@enst.fr>
+ """
+ import numpy as np
+
+
+ class BaseActivation:
+     """Base class for activation functions.
+     Parameters
+     ----------
+     name : str
+         Name of the activation function.
+     """
+     def __init__(self, name: str = 'custom'):
+         self.name = name
+
+     @staticmethod
+     def output(signal: np.ndarray) -> np.ndarray:
+         """Output of the activation function.
+
+         Parameters
+         ----------
+         signal : np.ndarray, shape (n_samples, n_channels)
+             Input signal.
+
+         Returns
+         -------
+         output : np.ndarray, shape (n_samples, n_channels)
+             Output signal.
+         """
+         output = signal
+         return output
+
+     @staticmethod
+     def gradient(signal: np.ndarray, direction: np.ndarray) -> np.ndarray:
+         """Gradient of the activation function.
+
+         Parameters
+         ----------
+         signal : np.ndarray, shape (n_samples, n_channels)
+             Input signal.
+         direction : np.ndarray, shape (n_samples, n_channels)
+             Direction where the gradient is taken.
+
+         Returns
+         -------
+         gradient : np.ndarray, shape (n_samples, n_channels)
+             Gradient.
+         """
+         gradient = direction
+         return gradient
+
+
+ class BaseLoss(BaseActivation):
+     """Base class for loss functions."""
+     @staticmethod
+     def loss(signal: np.ndarray, labels: np.ndarray) -> float:
+         """Get the loss value.
+
+         Parameters
+         ----------
+         signal : np.ndarray, shape (n_samples, n_channels)
+             Input signal (before activation).
+         labels : np.ndarray, shape (n_samples,)
+             True labels.
+         """
+         return 0
+
+     @staticmethod
+     def loss_gradient(signal: np.ndarray, labels: np.ndarray) -> np.ndarray:
+         """Gradient of the loss function.
+
+         Parameters
+         ----------
+         signal : np.ndarray, shape (n_samples, n_channels)
+             Input signal.
+         labels : np.ndarray, shape (n_samples,)
+             True labels.
+
+         Returns
+         -------
+         gradient : np.ndarray, shape (n_samples, n_channels)
+             Gradient.
+         """
+         gradient = np.ones_like(signal)
+         return gradient
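Note: the base class above defines the contract shared by activations and losses (identity output, pass-through gradient). As a minimal sketch of how that interface is meant to be overridden, the snippet below implements an illustrative leaky-ReLU-style activation; the class name and the 0.01 slope are illustrative choices, not part of the package, and whether such an instance can be passed directly to a layer depends on get_activation (defined in activation.py, not shown in this diff):

    import numpy as np

    from sknetwork.gnn.base_activation import BaseActivation


    class LeakyReLU(BaseActivation):
        """Illustrative activation: leaky ReLU with slope 0.01 on the negative side."""
        def __init__(self):
            super(LeakyReLU, self).__init__(name='leaky relu')

        @staticmethod
        def output(signal: np.ndarray) -> np.ndarray:
            # element-wise: x if x > 0, else 0.01 * x
            return np.where(signal > 0, signal, 0.01 * signal)

        @staticmethod
        def gradient(signal: np.ndarray, direction: np.ndarray) -> np.ndarray:
            # chain rule: local derivative (1 or 0.01) times the incoming direction
            return np.where(signal > 0, 1.0, 0.01) * direction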
sknetwork/gnn/base_layer.py
@@ -0,0 +1,109 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ Created on July 2022
+ @author: Simon Delarue <sdelarue@enst.fr>
+ """
+ from typing import Optional, Union
+
+ import numpy as np
+
+ from sknetwork.gnn.activation import BaseActivation, get_activation
+ from sknetwork.gnn.loss import BaseLoss, get_loss
+
+
+ class BaseLayer:
+     """Base class for GNN layers.
+
+     Parameters
+     ----------
+     layer_type : str
+         Layer type. Can be either ``'Conv'`` (Convolution) or ``'Sage'`` (GraphSAGE).
+     out_channels : int
+         Dimension of the output.
+     activation : str (default = ``'Relu'``) or custom activation
+         Activation function.
+         If a string, can be either ``'Identity'``, ``'Relu'``, ``'Sigmoid'`` or ``'Softmax'``.
+     use_bias : bool (default = ``True``)
+         If ``True``, add a bias vector.
+     normalization : str (default = ``'both'``)
+         Normalization of the adjacency matrix for message passing.
+         Can be either ``'left'`` (left normalization by the degrees), ``'right'`` (right normalization by the degrees),
+         ``'both'`` (symmetric normalization by the square root of degrees, default) or ``None`` (no normalization).
+     self_embeddings : bool (default = ``True``)
+         If ``True``, consider the self-embedding of each node in addition to its neighbors' embeddings.
+     sample_size : int (default = 25)
+         Size of the neighborhood sampled for each node. Used only for ``'Sage'`` layers.
+
+     Attributes
+     ----------
+     weight : np.ndarray
+         Trainable weight matrix.
+     bias : np.ndarray
+         Bias vector.
+     embedding : np.ndarray
+         Embedding of the nodes (before activation).
+     output : np.ndarray
+         Output of the layer (after activation).
+     """
+     def __init__(self, layer_type: str, out_channels: int, activation: Optional[Union[BaseActivation, str]] = 'Relu',
+                  use_bias: bool = True, normalization: str = 'both', self_embeddings: bool = True,
+                  sample_size: int = 25, loss: Optional[Union[BaseLoss, str]] = None):
+         self.layer_type = layer_type
+         self.out_channels = out_channels
+         if loss is None:
+             self.activation = get_activation(activation)
+         else:
+             self.activation = get_loss(loss)
+         self.use_bias = use_bias
+         self.normalization = normalization.lower()
+         self.self_embeddings = self_embeddings
+         self.sample_size = sample_size
+         self.weight = None
+         self.bias = None
+         self.embedding = None
+         self.output = None
+         self.weights_initialized = False
+
+     def _initialize_weights(self, in_channels: int):
+         """Initialize weights and bias.
+
+         Parameters
+         ----------
+         in_channels : int
+             Number of input channels.
+         """
+         # Trainable parameters with He initialization
+         self.weight = np.random.randn(in_channels, self.out_channels) * np.sqrt(2 / self.out_channels)
+         if self.use_bias:
+             self.bias = np.zeros((self.out_channels, 1)).T
+         self.weights_initialized = True
+
+     def forward(self, *args, **kwargs):
+         """Compute forward pass."""
+         raise NotImplementedError
+
+     def __call__(self, *args, **kwargs):
+         return self.forward(*args, **kwargs)
+
+     def __repr__(self) -> str:
+         """String representation of the object.
+
+         Returns
+         -------
+         str
+             String representation of the object.
+         """
+         print_attr = ['out_channels', 'layer_type', 'activation', 'use_bias', 'normalization', 'self_embeddings']
+         if 'sage' in self.layer_type:
+             print_attr.append('sample_size')
+         attributes_dict = {k: v for k, v in self.__dict__.items() if k in print_attr}
+         string = ''
+
+         for k, v in attributes_dict.items():
+             if k == 'activation':
+                 string += f'{k}: {v.name}, '
+             else:
+                 string += f'{k}: {v}, '
+
+         return f' {self.__class__.__name__}({string[:-2]})'
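Note: BaseLayer only holds the shared state (weight, bias, embedding, output) and leaves forward to subclasses; the package's actual layers live in layer.py, which is not reproduced in this diff. The sketch below is a toy, dense-feature subclass showing how the contract is meant to be used (symmetric 'both' normalization, linear transform, bias, activation); the class name ToyConvolution and all numerical details are illustrative assumptions, not the library's Convolution layer, and it only covers the forward pass (no backward/loss handling):

    import numpy as np
    from scipy import sparse

    from sknetwork.gnn.base_layer import BaseLayer


    class ToyConvolution(BaseLayer):
        """Illustrative layer: symmetric normalization, then linear transform and activation."""
        def __init__(self, out_channels: int, activation: str = 'Relu'):
            super(ToyConvolution, self).__init__('conv', out_channels, activation=activation)

        def forward(self, adjacency: sparse.csr_matrix, features: np.ndarray) -> np.ndarray:
            # lazily initialize weight (in_channels, out_channels) and bias (1, out_channels)
            if not self.weights_initialized:
                self._initialize_weights(features.shape[1])
            # optionally add self-loops so each node keeps its own embedding
            if self.self_embeddings:
                adjacency = adjacency + sparse.identity(adjacency.shape[0], format='csr')
            # symmetric normalization by the square roots of the degrees ('both')
            degrees = np.asarray(adjacency.sum(axis=1)).ravel()
            inv_sqrt = sparse.diags(1 / np.sqrt(np.maximum(degrees, 1)))
            message = inv_sqrt @ adjacency @ inv_sqrt @ features
            # linear transform, bias, activation (features assumed dense here)
            self.embedding = message @ self.weight
            if self.use_bias:
                self.embedding = self.embedding + self.bias
            self.output = self.activation.output(self.embedding)
            return self.output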
sknetwork/gnn/gnn_classifier.py
@@ -0,0 +1,381 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ Created in April 2022
+ @author: Simon Delarue <sdelarue@enst.fr>
+ """
+ from typing import Optional, Union
+ from collections import defaultdict
+
+ import numpy as np
+ from scipy import sparse
+
+ from sknetwork.classification.metrics import get_accuracy_score
+ from sknetwork.gnn.base import BaseGNN
+ from sknetwork.gnn.loss import BaseLoss
+ from sknetwork.gnn.layer import get_layer
+ from sknetwork.gnn.neighbor_sampler import UniformNeighborSampler
+ from sknetwork.gnn.optimizer import BaseOptimizer
+ from sknetwork.gnn.utils import filter_mask, check_existing_masks, check_output, check_early_stopping, check_loss, \
+     get_layers
+ from sknetwork.utils.check import check_format, check_nonnegative, check_square
+
+
+ class GNNClassifier(BaseGNN):
+     """Graph Neural Network for node classification.
+
+     Parameters
+     ----------
+     dims : list or int
+         Dimensions of the output of each layer (in forward direction).
+         If an integer, dimension of the output layer (no hidden layer).
+         Optional if ``layers`` is specified.
+     layer_types : list or str
+         Layer types (in forward direction).
+         If a string, use the same type of layer for all layers.
+         Can be ``'Conv'``, graph convolutional layer (default) or ``'Sage'`` (GraphSAGE).
+     activations : list or str
+         Activation functions (in forward direction).
+         If a string, use the same activation function for all layers.
+         Can be either ``'Identity'``, ``'Relu'``, ``'Sigmoid'`` or ``'Softmax'`` (default = ``'Relu'``).
+     use_bias : list or bool
+         Whether to use a bias term at each layer.
+         If ``True``, use a bias term at all layers.
+     normalizations : list or str
+         Normalization of the adjacency matrix for message passing.
+         If a string, use the same normalization for all layers.
+         Can be either ``'left'`` (left normalization by the degrees), ``'right'`` (right normalization by the degrees),
+         ``'both'`` (symmetric normalization by the square root of degrees, default) or ``None`` (no normalization).
+     self_embeddings : list or bool
+         Whether to add a self-embedding to each node of the graph for message passing.
+         If ``True``, add self-embeddings at all layers.
+     sample_sizes : list or int
+         Size of neighborhood sampled for each node. Used only for ``'Sage'`` layer type.
+     loss : str (default = ``'CrossEntropy'``) or BaseLoss
+         Loss function name or custom loss.
+     layers : list or None
+         Custom layers. If used, previous parameters are ignored.
+     optimizer : str or optimizer
+         * ``'Adam'``, stochastic gradient-based optimizer (default).
+         * ``'GD'``, gradient descent.
+     learning_rate : float
+         Learning rate.
+     early_stopping : bool (default = ``True``)
+         Whether to use early stopping to end training.
+         If ``True``, training terminates when the validation score has not improved for ``patience`` epochs.
+     patience : int (default = 10)
+         Number of iterations with no improvement to wait before stopping fitting.
+     verbose : bool
+         Verbose mode.
+
+     Attributes
+     ----------
+     conv2, ..., conv1 : :class:`GCNConv`
+         Graph convolutional layers.
+     output_ : array
+         Output of the GNN.
+     labels_ : np.ndarray
+         Predicted node labels.
+     history_ : dict
+         Training history per epoch: {``'embedding'``, ``'loss'``, ``'train_accuracy'``, ``'val_accuracy'``}.
+
+     Example
+     -------
+     >>> from sknetwork.gnn.gnn_classifier import GNNClassifier
+     >>> from sknetwork.data import karate_club
+     >>> from numpy.random import randint
+     >>> graph = karate_club(metadata=True)
+     >>> adjacency = graph.adjacency
+     >>> labels = graph.labels
+     >>> features = adjacency.copy()
+     >>> gnn = GNNClassifier(dims=1, early_stopping=False)
+     >>> labels_pred = gnn.fit_predict(adjacency, features, labels, random_state=42)
+     >>> np.round(np.mean(labels_pred == labels), 2)
+     0.91
+     """
+
+     def __init__(self, dims: Optional[Union[int, list]] = None, layer_types: Union[str, list] = 'Conv',
+                  activations: Union[str, list] = 'ReLu', use_bias: Union[bool, list] = True,
+                  normalizations: Union[str, list] = 'both', self_embeddings: Union[bool, list] = True,
+                  sample_sizes: Union[int, list] = 25, loss: Union[BaseLoss, str] = 'CrossEntropy',
+                  layers: Optional[list] = None, optimizer: Union[BaseOptimizer, str] = 'Adam',
+                  learning_rate: float = 0.01, early_stopping: bool = True, patience: int = 10, verbose: bool = False):
+         super(GNNClassifier, self).__init__(loss, optimizer, learning_rate, verbose)
+         if layers is not None:
+             layers = [get_layer(layer) for layer in layers]
+         else:
+             layers = get_layers(dims, layer_types, activations, use_bias, normalizations, self_embeddings, sample_sizes,
+                                 loss)
+         self.loss = check_loss(layers[-1])
+         self.layers = layers
+         self.early_stopping = early_stopping
+         self.patience = patience
+         self.history_ = defaultdict(list)
+
+     def forward(self, adjacency: Union[list, sparse.csr_matrix], features: Union[sparse.csr_matrix, np.ndarray]) \
+             -> np.ndarray:
+         """Perform a forward pass on the graph and return the output.
+
+         Parameters
+         ----------
+         adjacency : Union[list, sparse.csr_matrix]
+             Adjacency matrix or list of sampled adjacency matrices.
+         features : sparse.csr_matrix, np.ndarray
+             Features, array of shape (n_nodes, n_features).
+
+         Returns
+         -------
+         output : np.ndarray
+             Output of the GNN.
+         """
+         h = features.copy()
+         for i, layer in enumerate(self.layers):
+             if isinstance(adjacency, list):
+                 h = layer(adjacency[i], h)
+             else:
+                 h = layer(adjacency, h)
+         return h
+
+     @staticmethod
+     def _compute_predictions(output: np.ndarray) -> np.ndarray:
+         """Compute predictions from the output of the GNN.
+
+         Parameters
+         ----------
+         output : np.ndarray
+             Output of the GNN.
+
+         Returns
+         -------
+         labels : np.ndarray
+             Predicted labels.
+         """
+         if output.shape[1] == 1:
+             labels = (output.ravel() > 0.5).astype(int)
+         else:
+             labels = output.argmax(axis=1)
+         return labels
+
+     def fit(self, adjacency: Union[sparse.csr_matrix, np.ndarray], features: Union[sparse.csr_matrix, np.ndarray],
+             labels: np.ndarray, n_epochs: int = 100, train_mask: Optional[np.ndarray] = None,
+             val_mask: Optional[np.ndarray] = None,
+             test_mask: Optional[np.ndarray] = None, train_size: Optional[float] = 0.8,
+             val_size: Optional[float] = 0.1, test_size: Optional[float] = 0.1, resample: bool = False,
+             reinit: bool = False, random_state: Optional[int] = None, history: bool = False) -> 'GNNClassifier':
+         """Fit model to data and store trained parameters.
+
+         Parameters
+         ----------
+         adjacency : sparse.csr_matrix
+             Adjacency matrix of the graph.
+         features : sparse.csr_matrix, np.ndarray
+             Input features of shape :math:`(n, d)` with :math:`n` the number of nodes in the graph and :math:`d`
+             the size of the feature space.
+         labels : np.ndarray
+             Label vector of length :math:`n`, with :math:`n` the number of nodes in `adjacency`. A value of `labels`
+             equal to `-1` means no label. The associated nodes are not considered in training steps.
+         n_epochs : int (default = 100)
+             Number of epochs (iterations over the whole graph).
+         train_mask, val_mask, test_mask : np.ndarray
+             Boolean arrays indicating whether nodes are in the training/validation/test set.
+         train_size, test_size : float
+             Proportion of the nodes in the training/test set (between 0 and 1).
+             Only used if the corresponding masks are ``None``.
+         val_size : float
+             Proportion of the training set used for validation (between 0 and 1).
+             Only used if the corresponding mask is ``None``.
+         resample : bool (default = ``False``)
+             If ``True``, resample the train/test/validation sets before fitting.
+             Otherwise, the train/test/validation sets remain the same after the first fit.
+         reinit : bool (default = ``False``)
+             If ``True``, reinitialize the trainable parameters of the GNN (weights and biases).
+         random_state : int
+             Pass an int for reproducible results across multiple runs.
+         history : bool (default = ``False``)
+             If ``True``, save training history.
+         """
+         if reinit:
+             for layer in self.layers:
+                 layer.weights_initialized = False
+
+         if resample or self.output_ is None:
+             exists_mask, self.train_mask, self.val_mask, self.test_mask = \
+                 check_existing_masks(labels, train_mask, val_mask, test_mask, train_size, val_size, test_size)
+             if not exists_mask:
+                 self._generate_masks(train_size, val_size, test_size, random_state)
+
+         check_format(adjacency)
+         check_format(features)
+
+         check_output(self.layers[-1].out_channels, labels)
+
+         early_stopping = check_early_stopping(self.early_stopping, self.val_mask, self.patience)
+
+         # List of sampled adjacencies (one per layer)
+         adjacencies = self._sample_nodes(adjacency)
+
+         best_val_acc = 0
+         trigger_times = 0
+
+         for epoch in range(n_epochs):
+
+             # Forward
+             output = self.forward(adjacencies, features)
+
+             # Compute predictions
+             labels_pred = self._compute_predictions(output)
+
+             # Loss
+             loss_value = self.loss.loss(output[self.train_mask], labels[self.train_mask])
+
+             # Accuracy
+             train_acc = get_accuracy_score(labels[self.train_mask], labels_pred[self.train_mask])
+             if self.val_mask is not None and any(self.val_mask):
+                 val_acc = get_accuracy_score(labels[self.val_mask], labels_pred[self.val_mask])
+             else:
+                 val_acc = None
+
+             # Backpropagation
+             self.backward(features, labels, self.train_mask)
+
+             # Update weights using optimizer
+             self.optimizer.step(self)
+
+             # Save results
+             if history:
+                 self.history_['embedding'].append(self.layers[-1].embedding)
+                 self.history_['loss'].append(loss_value)
+                 self.history_['train_accuracy'].append(train_acc)
+                 if val_acc is not None:
+                     self.history_['val_accuracy'].append(val_acc)
+
+             if n_epochs > 10 and epoch % int(n_epochs / 10) == 0:
+                 if val_acc is not None:
+                     self.log.print(
+                         f'In epoch {epoch:>3}, loss: {loss_value:.3f}, train accuracy: {train_acc:.3f}, '
+                         f'val accuracy: {val_acc:.3f}')
+                 else:
+                     self.log.print(
+                         f'In epoch {epoch:>3}, loss: {loss_value:.3f}, train accuracy: {train_acc:.3f}')
+             elif n_epochs <= 10:
+                 if val_acc is not None:
+                     self.log.print(
+                         f'In epoch {epoch:>3}, loss: {loss_value:.3f}, train accuracy: {train_acc:.3f}, '
+                         f'val accuracy: {val_acc:.3f}')
+                 else:
+                     self.log.print(
+                         f'In epoch {epoch:>3}, loss: {loss_value:.3f}, train accuracy: {train_acc:.3f}')
+
+             # Early stopping
+             if early_stopping:
+                 if val_acc > best_val_acc:
+                     trigger_times = 0
+                     best_val_acc = val_acc
+                 else:
+                     trigger_times += 1
+                     if trigger_times >= self.patience:
+                         self.log.print('Early stopping.')
+                         break
+
+         output = self.forward(adjacencies, features)
+         labels_pred = self._compute_predictions(output)
+
+         self.embedding_ = self.layers[-1].embedding
+         self.output_ = self.layers[-1].output
+         self.labels_ = labels_pred
+
+         return self
+
+     def _generate_masks(self, train_size: Optional[float] = None, val_size: Optional[float] = None,
+                         test_size: Optional[float] = None, random_state: int = None):
+         """Create training, validation and test masks.
+
+         Parameters
+         ----------
+         train_size : float
+             Proportion of nodes in the training set (between 0 and 1).
+         val_size : float
+             Proportion of nodes in the validation set (between 0 and 1).
+         test_size : float
+             Proportion of nodes in the test set (between 0 and 1).
+         random_state : int
+             Pass an int for reproducible results across multiple runs.
+         """
+         is_negative_labels = self.test_mask
+
+         if random_state is not None:
+             np.random.seed(random_state)
+
+         if train_size is None:
+             train_size = 1 - test_size
+
+         self.train_mask = filter_mask(~is_negative_labels, train_size)
+         self.val_mask = filter_mask(np.logical_and(~self.train_mask, ~is_negative_labels), val_size)
+         self.test_mask = np.logical_and(~self.train_mask, ~self.val_mask)
+
+     def _sample_nodes(self, adjacency: Union[sparse.csr_matrix, np.ndarray]) -> list:
+         """Perform node sampling on the adjacency matrix for GraphSAGE layers. For other layers, the
+         adjacency matrix remains unchanged.
+
+         Parameters
+         ----------
+         adjacency : sparse.csr_matrix
+             Adjacency matrix of the graph.
+
+         Returns
+         -------
+         List of (sampled) adjacency matrices.
+         """
+         adjacencies = []
+
+         for layer in self.layers:
+             if layer.layer_type == 'sage':
+                 sampler = UniformNeighborSampler(sample_size=layer.sample_size)
+                 adjacencies.append(sampler(adjacency))
+             else:
+                 adjacencies.append(adjacency)
+
+         return adjacencies
+
+     def predict(self, adjacency_vectors: Union[sparse.csr_matrix, np.ndarray] = None,
+                 feature_vectors: Union[sparse.csr_matrix, np.ndarray] = None) -> np.ndarray:
+         """Predict labels for new nodes. If called without parameters, labels are returned for all nodes.
+
+         Parameters
+         ----------
+         adjacency_vectors : np.ndarray
+             Square adjacency matrix. Array of shape (n, n).
+         feature_vectors : np.ndarray
+             Feature row vectors. Array of shape (n, n_feat). The number of features n_feat must match the one
+             used during training.
+
+         Returns
+         -------
+         labels : np.ndarray
+             Label of each node of the graph.
+         """
+         self._check_fitted()
+
+         if adjacency_vectors is None and feature_vectors is None:
+             return self.labels_
+         elif adjacency_vectors is None:
+             adjacency_vectors = sparse.identity(feature_vectors.shape[0], format='csr')
+
+         check_square(adjacency_vectors)
+         check_nonnegative(adjacency_vectors)
+         feature_vectors = check_format(feature_vectors)
+
+         n_row, n_col = adjacency_vectors.shape
+         feat_row, feat_col = feature_vectors.shape
+
+         if n_col != feat_row:
+             raise ValueError(f'Dimension mismatch: dim0={n_col} != dim1={feat_row}.')
+         elif feat_col != self.layers[0].weight.shape[0]:
+             raise ValueError(f'Dimension mismatch: current number of features is {feat_col} whereas GNN has been '
+                              f'trained with '
+                              f'{self.layers[0].weight.shape[0]} features.')
+
+         h = self.forward(adjacency_vectors, feature_vectors)
+         labels = self._compute_predictions(h)
+
+         return labels
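Note: beyond the doctest in the class docstring, the sketch below shows a fuller use of the fit/predict signatures above (partial labels marked with -1, training history, reproducible random state). It is only an illustration based on this diff; the fraction of hidden labels, the number of epochs and the other parameter values are arbitrary choices, not recommendations from the package:

    import numpy as np

    from sknetwork.data import karate_club
    from sknetwork.gnn.gnn_classifier import GNNClassifier

    graph = karate_club(metadata=True)
    adjacency = graph.adjacency
    features = adjacency.copy()
    labels = np.array(graph.labels)

    # Hide the labels of some nodes: a value of -1 means "no label",
    # so these nodes are not considered during the training steps.
    labels_partial = labels.copy()
    labels_partial[np.arange(0, len(labels), 3)] = -1

    gnn = GNNClassifier(dims=1, learning_rate=0.01, early_stopping=False, verbose=True)
    gnn.fit(adjacency, features, labels_partial, n_epochs=50, random_state=42, history=True)

    labels_pred = gnn.predict()        # labels for all nodes of the training graph
    accuracy = np.mean(labels_pred == labels)
    print(accuracy, len(gnn.history_['loss']))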