scikit-network 0.28.3 (scikit_network-0.28.3-cp39-cp39-macosx_12_0_arm64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-network might be problematic. Click here for more details.

Files changed (240)
  1. scikit_network-0.28.3.dist-info/AUTHORS.rst +41 -0
  2. scikit_network-0.28.3.dist-info/LICENSE +34 -0
  3. scikit_network-0.28.3.dist-info/METADATA +457 -0
  4. scikit_network-0.28.3.dist-info/RECORD +240 -0
  5. scikit_network-0.28.3.dist-info/WHEEL +5 -0
  6. scikit_network-0.28.3.dist-info/top_level.txt +1 -0
  7. sknetwork/__init__.py +21 -0
  8. sknetwork/classification/__init__.py +8 -0
  9. sknetwork/classification/base.py +84 -0
  10. sknetwork/classification/base_rank.py +143 -0
  11. sknetwork/classification/diffusion.py +134 -0
  12. sknetwork/classification/knn.py +162 -0
  13. sknetwork/classification/metrics.py +205 -0
  14. sknetwork/classification/pagerank.py +66 -0
  15. sknetwork/classification/propagation.py +152 -0
  16. sknetwork/classification/tests/__init__.py +1 -0
  17. sknetwork/classification/tests/test_API.py +35 -0
  18. sknetwork/classification/tests/test_diffusion.py +37 -0
  19. sknetwork/classification/tests/test_knn.py +24 -0
  20. sknetwork/classification/tests/test_metrics.py +53 -0
  21. sknetwork/classification/tests/test_pagerank.py +20 -0
  22. sknetwork/classification/tests/test_propagation.py +24 -0
  23. sknetwork/classification/vote.cpython-39-darwin.so +0 -0
  24. sknetwork/classification/vote.pyx +58 -0
  25. sknetwork/clustering/__init__.py +7 -0
  26. sknetwork/clustering/base.py +102 -0
  27. sknetwork/clustering/kmeans.py +142 -0
  28. sknetwork/clustering/louvain.py +255 -0
  29. sknetwork/clustering/louvain_core.cpython-39-darwin.so +0 -0
  30. sknetwork/clustering/louvain_core.pyx +134 -0
  31. sknetwork/clustering/metrics.py +91 -0
  32. sknetwork/clustering/postprocess.py +66 -0
  33. sknetwork/clustering/propagation_clustering.py +108 -0
  34. sknetwork/clustering/tests/__init__.py +1 -0
  35. sknetwork/clustering/tests/test_API.py +37 -0
  36. sknetwork/clustering/tests/test_kmeans.py +47 -0
  37. sknetwork/clustering/tests/test_louvain.py +104 -0
  38. sknetwork/clustering/tests/test_metrics.py +50 -0
  39. sknetwork/clustering/tests/test_post_processing.py +23 -0
  40. sknetwork/clustering/tests/test_postprocess.py +39 -0
  41. sknetwork/data/__init__.py +5 -0
  42. sknetwork/data/load.py +408 -0
  43. sknetwork/data/models.py +459 -0
  44. sknetwork/data/parse.py +621 -0
  45. sknetwork/data/test_graphs.py +84 -0
  46. sknetwork/data/tests/__init__.py +1 -0
  47. sknetwork/data/tests/test_API.py +30 -0
  48. sknetwork/data/tests/test_load.py +95 -0
  49. sknetwork/data/tests/test_models.py +52 -0
  50. sknetwork/data/tests/test_parse.py +253 -0
  51. sknetwork/data/tests/test_test_graphs.py +30 -0
  52. sknetwork/data/tests/test_toy_graphs.py +68 -0
  53. sknetwork/data/toy_graphs.py +619 -0
  54. sknetwork/embedding/__init__.py +10 -0
  55. sknetwork/embedding/base.py +90 -0
  56. sknetwork/embedding/force_atlas.py +197 -0
  57. sknetwork/embedding/louvain_embedding.py +174 -0
  58. sknetwork/embedding/louvain_hierarchy.py +142 -0
  59. sknetwork/embedding/metrics.py +66 -0
  60. sknetwork/embedding/random_projection.py +133 -0
  61. sknetwork/embedding/spectral.py +214 -0
  62. sknetwork/embedding/spring.py +198 -0
  63. sknetwork/embedding/svd.py +363 -0
  64. sknetwork/embedding/tests/__init__.py +1 -0
  65. sknetwork/embedding/tests/test_API.py +73 -0
  66. sknetwork/embedding/tests/test_force_atlas.py +35 -0
  67. sknetwork/embedding/tests/test_louvain_embedding.py +33 -0
  68. sknetwork/embedding/tests/test_louvain_hierarchy.py +19 -0
  69. sknetwork/embedding/tests/test_metrics.py +29 -0
  70. sknetwork/embedding/tests/test_random_projection.py +28 -0
  71. sknetwork/embedding/tests/test_spectral.py +84 -0
  72. sknetwork/embedding/tests/test_spring.py +50 -0
  73. sknetwork/embedding/tests/test_svd.py +37 -0
  74. sknetwork/flow/__init__.py +3 -0
  75. sknetwork/flow/flow.py +73 -0
  76. sknetwork/flow/tests/__init__.py +1 -0
  77. sknetwork/flow/tests/test_flow.py +17 -0
  78. sknetwork/flow/tests/test_utils.py +69 -0
  79. sknetwork/flow/utils.py +91 -0
  80. sknetwork/gnn/__init__.py +10 -0
  81. sknetwork/gnn/activation.py +117 -0
  82. sknetwork/gnn/base.py +155 -0
  83. sknetwork/gnn/base_activation.py +89 -0
  84. sknetwork/gnn/base_layer.py +109 -0
  85. sknetwork/gnn/gnn_classifier.py +381 -0
  86. sknetwork/gnn/layer.py +153 -0
  87. sknetwork/gnn/layers.py +127 -0
  88. sknetwork/gnn/loss.py +180 -0
  89. sknetwork/gnn/neighbor_sampler.py +65 -0
  90. sknetwork/gnn/optimizer.py +163 -0
  91. sknetwork/gnn/tests/__init__.py +1 -0
  92. sknetwork/gnn/tests/test_activation.py +56 -0
  93. sknetwork/gnn/tests/test_base.py +79 -0
  94. sknetwork/gnn/tests/test_base_layer.py +37 -0
  95. sknetwork/gnn/tests/test_gnn_classifier.py +192 -0
  96. sknetwork/gnn/tests/test_layers.py +80 -0
  97. sknetwork/gnn/tests/test_loss.py +33 -0
  98. sknetwork/gnn/tests/test_neigh_sampler.py +23 -0
  99. sknetwork/gnn/tests/test_optimizer.py +43 -0
  100. sknetwork/gnn/tests/test_utils.py +93 -0
  101. sknetwork/gnn/utils.py +219 -0
  102. sknetwork/hierarchy/__init__.py +7 -0
  103. sknetwork/hierarchy/base.py +69 -0
  104. sknetwork/hierarchy/louvain_hierarchy.py +264 -0
  105. sknetwork/hierarchy/metrics.py +234 -0
  106. sknetwork/hierarchy/paris.cpython-39-darwin.so +0 -0
  107. sknetwork/hierarchy/paris.pyx +317 -0
  108. sknetwork/hierarchy/postprocess.py +350 -0
  109. sknetwork/hierarchy/tests/__init__.py +1 -0
  110. sknetwork/hierarchy/tests/test_API.py +25 -0
  111. sknetwork/hierarchy/tests/test_algos.py +29 -0
  112. sknetwork/hierarchy/tests/test_metrics.py +62 -0
  113. sknetwork/hierarchy/tests/test_postprocess.py +57 -0
  114. sknetwork/hierarchy/tests/test_ward.py +25 -0
  115. sknetwork/hierarchy/ward.py +94 -0
  116. sknetwork/linalg/__init__.py +9 -0
  117. sknetwork/linalg/basics.py +37 -0
  118. sknetwork/linalg/diteration.cpython-39-darwin.so +0 -0
  119. sknetwork/linalg/diteration.pyx +49 -0
  120. sknetwork/linalg/eig_solver.py +93 -0
  121. sknetwork/linalg/laplacian.py +15 -0
  122. sknetwork/linalg/normalization.py +66 -0
  123. sknetwork/linalg/operators.py +225 -0
  124. sknetwork/linalg/polynome.py +76 -0
  125. sknetwork/linalg/ppr_solver.py +170 -0
  126. sknetwork/linalg/push.cpython-39-darwin.so +0 -0
  127. sknetwork/linalg/push.pyx +73 -0
  128. sknetwork/linalg/sparse_lowrank.py +142 -0
  129. sknetwork/linalg/svd_solver.py +91 -0
  130. sknetwork/linalg/tests/__init__.py +1 -0
  131. sknetwork/linalg/tests/test_eig.py +44 -0
  132. sknetwork/linalg/tests/test_laplacian.py +18 -0
  133. sknetwork/linalg/tests/test_normalization.py +38 -0
  134. sknetwork/linalg/tests/test_operators.py +70 -0
  135. sknetwork/linalg/tests/test_polynome.py +38 -0
  136. sknetwork/linalg/tests/test_ppr.py +50 -0
  137. sknetwork/linalg/tests/test_sparse_lowrank.py +61 -0
  138. sknetwork/linalg/tests/test_svd.py +38 -0
  139. sknetwork/linkpred/__init__.py +4 -0
  140. sknetwork/linkpred/base.py +80 -0
  141. sknetwork/linkpred/first_order.py +508 -0
  142. sknetwork/linkpred/first_order_core.cpython-39-darwin.so +0 -0
  143. sknetwork/linkpred/first_order_core.pyx +315 -0
  144. sknetwork/linkpred/postprocessing.py +98 -0
  145. sknetwork/linkpred/tests/__init__.py +1 -0
  146. sknetwork/linkpred/tests/test_API.py +49 -0
  147. sknetwork/linkpred/tests/test_postprocessing.py +21 -0
  148. sknetwork/path/__init__.py +4 -0
  149. sknetwork/path/metrics.py +148 -0
  150. sknetwork/path/search.py +65 -0
  151. sknetwork/path/shortest_path.py +186 -0
  152. sknetwork/path/tests/__init__.py +1 -0
  153. sknetwork/path/tests/test_metrics.py +29 -0
  154. sknetwork/path/tests/test_search.py +25 -0
  155. sknetwork/path/tests/test_shortest_path.py +45 -0
  156. sknetwork/ranking/__init__.py +9 -0
  157. sknetwork/ranking/base.py +56 -0
  158. sknetwork/ranking/betweenness.cpython-39-darwin.so +0 -0
  159. sknetwork/ranking/betweenness.pyx +99 -0
  160. sknetwork/ranking/closeness.py +95 -0
  161. sknetwork/ranking/harmonic.py +82 -0
  162. sknetwork/ranking/hits.py +94 -0
  163. sknetwork/ranking/katz.py +81 -0
  164. sknetwork/ranking/pagerank.py +107 -0
  165. sknetwork/ranking/postprocess.py +25 -0
  166. sknetwork/ranking/tests/__init__.py +1 -0
  167. sknetwork/ranking/tests/test_API.py +34 -0
  168. sknetwork/ranking/tests/test_betweenness.py +38 -0
  169. sknetwork/ranking/tests/test_closeness.py +34 -0
  170. sknetwork/ranking/tests/test_hits.py +20 -0
  171. sknetwork/ranking/tests/test_pagerank.py +69 -0
  172. sknetwork/regression/__init__.py +4 -0
  173. sknetwork/regression/base.py +56 -0
  174. sknetwork/regression/diffusion.py +190 -0
  175. sknetwork/regression/tests/__init__.py +1 -0
  176. sknetwork/regression/tests/test_API.py +34 -0
  177. sknetwork/regression/tests/test_diffusion.py +48 -0
  178. sknetwork/sknetwork.py +3 -0
  179. sknetwork/topology/__init__.py +9 -0
  180. sknetwork/topology/dag.py +74 -0
  181. sknetwork/topology/dag_core.cpython-39-darwin.so +0 -0
  182. sknetwork/topology/dag_core.pyx +38 -0
  183. sknetwork/topology/kcliques.cpython-39-darwin.so +0 -0
  184. sknetwork/topology/kcliques.pyx +193 -0
  185. sknetwork/topology/kcore.cpython-39-darwin.so +0 -0
  186. sknetwork/topology/kcore.pyx +120 -0
  187. sknetwork/topology/structure.py +234 -0
  188. sknetwork/topology/tests/__init__.py +1 -0
  189. sknetwork/topology/tests/test_cliques.py +28 -0
  190. sknetwork/topology/tests/test_cores.py +21 -0
  191. sknetwork/topology/tests/test_dag.py +26 -0
  192. sknetwork/topology/tests/test_structure.py +99 -0
  193. sknetwork/topology/tests/test_triangles.py +42 -0
  194. sknetwork/topology/tests/test_wl_coloring.py +49 -0
  195. sknetwork/topology/tests/test_wl_kernel.py +31 -0
  196. sknetwork/topology/triangles.cpython-39-darwin.so +0 -0
  197. sknetwork/topology/triangles.pyx +166 -0
  198. sknetwork/topology/weisfeiler_lehman.py +163 -0
  199. sknetwork/topology/weisfeiler_lehman_core.cpython-39-darwin.so +0 -0
  200. sknetwork/topology/weisfeiler_lehman_core.pyx +116 -0
  201. sknetwork/utils/__init__.py +40 -0
  202. sknetwork/utils/base.py +35 -0
  203. sknetwork/utils/check.py +354 -0
  204. sknetwork/utils/co_neighbor.py +71 -0
  205. sknetwork/utils/format.py +219 -0
  206. sknetwork/utils/kmeans.py +89 -0
  207. sknetwork/utils/knn.py +166 -0
  208. sknetwork/utils/knn1d.cpython-39-darwin.so +0 -0
  209. sknetwork/utils/knn1d.pyx +80 -0
  210. sknetwork/utils/membership.py +82 -0
  211. sknetwork/utils/minheap.cpython-39-darwin.so +0 -0
  212. sknetwork/utils/minheap.pxd +22 -0
  213. sknetwork/utils/minheap.pyx +111 -0
  214. sknetwork/utils/neighbors.py +115 -0
  215. sknetwork/utils/seeds.py +75 -0
  216. sknetwork/utils/simplex.py +140 -0
  217. sknetwork/utils/tests/__init__.py +1 -0
  218. sknetwork/utils/tests/test_base.py +28 -0
  219. sknetwork/utils/tests/test_bunch.py +16 -0
  220. sknetwork/utils/tests/test_check.py +190 -0
  221. sknetwork/utils/tests/test_co_neighbor.py +43 -0
  222. sknetwork/utils/tests/test_format.py +61 -0
  223. sknetwork/utils/tests/test_kmeans.py +21 -0
  224. sknetwork/utils/tests/test_knn.py +32 -0
  225. sknetwork/utils/tests/test_membership.py +24 -0
  226. sknetwork/utils/tests/test_neighbors.py +41 -0
  227. sknetwork/utils/tests/test_projection_simplex.py +33 -0
  228. sknetwork/utils/tests/test_seeds.py +67 -0
  229. sknetwork/utils/tests/test_verbose.py +15 -0
  230. sknetwork/utils/tests/test_ward.py +20 -0
  231. sknetwork/utils/timeout.py +38 -0
  232. sknetwork/utils/verbose.py +37 -0
  233. sknetwork/utils/ward.py +60 -0
  234. sknetwork/visualization/__init__.py +4 -0
  235. sknetwork/visualization/colors.py +34 -0
  236. sknetwork/visualization/dendrograms.py +229 -0
  237. sknetwork/visualization/graphs.py +819 -0
  238. sknetwork/visualization/tests/__init__.py +1 -0
  239. sknetwork/visualization/tests/test_dendrograms.py +53 -0
  240. sknetwork/visualization/tests/test_graphs.py +167 -0
@@ -0,0 +1,56 @@
1
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""tests for activation"""

import unittest

import numpy as np

from sknetwork.gnn.activation import BaseActivation, ReLu, Sigmoid, Softmax, get_activation


class TestActivation(unittest.TestCase):
    """Tests for activation lookup and for the output/gradient of each activation."""

    def _assert_close(self, actual, expected):
        """Assert element-wise closeness using the default tolerances of ``np.allclose``."""
        np.testing.assert_allclose(actual, expected, rtol=1e-5, atol=1e-8)

    def test_get_activation(self):
        # Lookup by name returns an instance of the matching class.
        self.assertIsInstance(get_activation('Identity'), BaseActivation)
        self.assertIsInstance(get_activation('Relu'), ReLu)
        self.assertIsInstance(get_activation('Sigmoid'), Sigmoid)
        self.assertIsInstance(get_activation('Softmax'), Softmax)
        with self.assertRaises(ValueError):
            get_activation('foo')

        # Passing an activation instance gives back an equal object; an integer is rejected.
        base_act = BaseActivation()
        self.assertEqual(base_act, get_activation(base_act))
        with self.assertRaises(TypeError):
            get_activation(0)

    def test_activation_identity(self):
        activation = get_activation('Identity')
        signal = np.arange(5)
        np.testing.assert_array_equal(activation.output(signal), signal)
        direction = np.arange(5)
        np.testing.assert_array_equal(activation.gradient(signal, direction), direction)

    def test_activation_relu(self):
        activation = get_activation('ReLu')
        signal = np.linspace(-2, 2, 5)
        np.testing.assert_array_equal(activation.output(signal), [0., 0., 0., 1., 2.])
        direction = np.arange(5)
        # The gradient only lets the direction through where the signal is positive.
        np.testing.assert_array_equal(activation.gradient(signal, direction), direction * (signal > 0))

    def test_activation_sigmoid(self):
        activation = get_activation('Sigmoid')
        signal = np.array([-np.inf, -1.5, 0, 1.5, np.inf])
        self._assert_close(activation.output(signal), np.array([0., 0.18242552, 0.5, 0.81757448, 1.]))
        # Large-magnitude inputs must saturate without producing NaNs.
        signal = np.array([[-1000, 1000, 1000]])
        direction = np.arange(3)
        self._assert_close(activation.output(signal), np.array([[0., 1., 1.]]))
        self._assert_close(activation.gradient(signal, direction), np.array([[0., 0., 0.]]))

    def test_activation_softmax(self):
        activation = get_activation('Softmax')
        signal = np.array([[-1, 0, 3, 5]])
        output = activation.output(signal)
        self._assert_close(output, np.array([[0.0021657, 0.00588697, 0.11824302, 0.87370431]]))
        # Large-magnitude inputs must not overflow.
        signal = np.array([[-1000, 1000, 1000]])
        direction = np.arange(3)
        self._assert_close(activation.output(signal), np.array([[0., 0.5, 0.5]]))
        self._assert_close(activation.gradient(signal, direction), np.array([[0., -0.25, 0.25]]))
@@ -0,0 +1,79 @@
1
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""tests for base gnn"""

import unittest

import numpy as np

# Aliased so that pytest does not collect the imported ``test_graph`` function as a test of this module.
from sknetwork.data.test_graphs import test_graph as make_test_graph
from sknetwork.gnn.base import BaseGNN
from sknetwork.gnn.gnn_classifier import GNNClassifier
from sknetwork.gnn.layer import Convolution
from sknetwork.gnn.optimizer import ADAM


class TestBaseGNN(unittest.TestCase):
    """Tests for BaseGNN and for the behavior GNNClassifier builds on it."""

    def setUp(self) -> None:
        """Test graph for tests."""
        self.adjacency = make_test_graph()
        self.n = self.adjacency.shape[0]
        # The adjacency matrix is reused as the node-feature matrix.
        self.features = self.adjacency
        self.labels = np.array([0] * 5 + [1] * 5)

    def _fit_transform(self, gnn):
        """Fit ``gnn`` for one epoch on the test graph and return the resulting embedding."""
        return gnn.fit_transform(self.adjacency, self.features, labels=self.labels, n_epochs=1, val_size=0.2)

    def _assert_embedding_shape(self, embedding):
        """Check that ``embedding`` has one row per node and 2 columns."""
        self.assertEqual(len(embedding), self.n)
        self.assertEqual(embedding.shape, (self.n, 2))

    def test_base_gnn_fit(self):
        gnn = BaseGNN()
        with self.assertRaises(NotImplementedError):
            gnn.fit(self.adjacency, self.features, self.labels, test_size=0.2)

    def test_gnn_fit_transform(self):
        gnn = GNNClassifier(dims=2, layer_types='Conv', activations='Relu', optimizer='GD', verbose=False)
        self._assert_embedding_shape(self._fit_transform(gnn))

    def test_gnn_custom_optimizer(self):
        gnn = GNNClassifier(dims=2, layer_types='Conv', activations='Relu', optimizer=ADAM(beta1=0.5), verbose=False)
        self._assert_embedding_shape(self._fit_transform(gnn))

    def test_gnn_custom_layers(self):
        gnn = GNNClassifier(layers=[Convolution('Conv', 2, loss='CrossEntropy')], optimizer=ADAM(beta1=0.5),
                            verbose=False)
        self._assert_embedding_shape(self._fit_transform(gnn))

        gnn = GNNClassifier(layers=[Convolution('SAGEConv', 2, sample_size=5, loss='CrossEntropy')],
                            optimizer=ADAM(beta1=0.5), verbose=False)
        self._assert_embedding_shape(self._fit_transform(gnn))

    def test_gnn_custom(self):
        gnn = GNNClassifier(dims=[20, 8, 2], layer_types='conv',
                            activations=['Relu', 'Sigmoid', 'Softmax'], optimizer='Adam', verbose=False)
        self.assertIsInstance(gnn, GNNClassifier)
        self.assertEqual(gnn.layers[-1].activation.name, 'Cross entropy')
        y_pred = gnn.fit_predict(self.adjacency, self.features, labels=self.labels, n_epochs=1, val_size=0.2)
        self.assertEqual(len(y_pred), self.n)

    def test_check_fitted(self):
        # An unfitted model is rejected.
        gnn = BaseGNN()
        with self.assertRaises(ValueError):
            gnn._check_fitted()
        # A fitted model is returned and exposes its embedding.
        gnn = GNNClassifier(dims=2, layer_types='conv', activations='Relu', optimizer='GD', verbose=False)
        self._fit_transform(gnn)
        fit_gnn = gnn._check_fitted()
        self.assertIsInstance(fit_gnn, GNNClassifier)
        self.assertIsNotNone(fit_gnn.embedding_)

    def test_base_gnn_repr(self):
        gnn = GNNClassifier(dims=[8, 2], layer_types='conv', activations=['Relu', 'Softmax'], optimizer='Adam')
        self.assertTrue(repr(gnn).startswith("GNNClassifier"))

    def test_gnn_predict(self):
        gnn = BaseGNN()
        with self.assertRaises(NotImplementedError):
            gnn.predict()
@@ -0,0 +1,37 @@
1
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""tests for base layer gnn"""

import unittest

import numpy as np

# Aliased so that pytest does not collect the imported ``test_graph`` function as a test of this module.
from sknetwork.data.test_graphs import test_graph as make_test_graph
from sknetwork.gnn.base_layer import BaseLayer


class TestBaseLayer(unittest.TestCase):
    """Tests for BaseLayer: abstract forward pass, weight initialization and repr."""

    def setUp(self) -> None:
        """Test graph for tests."""
        self.adjacency = make_test_graph()
        self.n = self.adjacency.shape[0]
        self.features = self.adjacency
        self.labels = np.array([0] * 5 + [1] * 5)
        # out_channels = len(self.labels) = 10, as checked in the repr test below.
        self.base_layer = BaseLayer('Conv', len(self.labels))

    def test_base_layer_init(self):
        with self.assertRaises(NotImplementedError):
            self.base_layer.forward(self.adjacency, self.features)

    def test_base_layer_initialize_weights(self):
        n_channels = len(self.labels)
        self.base_layer._initialize_weights(10)
        self.assertEqual(self.base_layer.weight.shape, (10, n_channels))
        # The first row of the bias must be all zeros after initialization.
        self.assertTrue(all(self.base_layer.bias[0] == np.zeros(n_channels)))
        self.assertTrue(self.base_layer.weights_initialized)

    def test_base_layer_repr(self):
        self.assertTrue(repr(self.base_layer).startswith(" BaseLayer(layer_type: Conv, out_channels: 10"))
        sage_layer = BaseLayer(layer_type='sageconv', out_channels=len(self.labels))
        sage_repr = repr(sage_layer)
        self.assertIn('sample_size', sage_repr)
        self.assertIn('sageconv', sage_repr)
@@ -0,0 +1,192 @@
1
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""tests for gnn classifier"""

import unittest

import numpy as np
from scipy import sparse

# Aliased so that pytest does not collect the imported ``test_graph`` function as a test of this module.
from sknetwork.data.test_graphs import test_graph as make_test_graph
from sknetwork.gnn.gnn_classifier import GNNClassifier


class TestGNNClassifier(unittest.TestCase):
    """Tests for GNNClassifier: fitting, data splits, early stopping, re-initialization and prediction."""

    def setUp(self) -> None:
        """Test graph for tests."""
        self.adjacency = make_test_graph()
        self.n = self.adjacency.shape[0]
        # The adjacency matrix is reused as the node-feature matrix.
        self.features = self.adjacency
        self.labels = np.array([0] * 5 + [1] * 5)

    def _assert_masks_cover_all_nodes(self, gnn):
        """Check that the train, validation and test masks of ``gnn`` select as many entries as there are nodes."""
        n_selected = sum(gnn.train_mask) + sum(gnn.val_mask) + sum(gnn.test_mask)
        self.assertEqual(n_selected, self.adjacency.shape[0])

    def test_gnn_classifier_sparse_feat(self):
        gnn = GNNClassifier([3, 2], 'Conv', 'Softmax')
        self.assertEqual(gnn.layers[0].activation.name, 'Softmax')
        self.assertEqual(gnn.layers[1].activation.name, 'Cross entropy')
        labels_pred = gnn.fit_predict(self.adjacency, self.features, self.labels, val_size=0.2)
        self.assertEqual(len(labels_pred), self.n)
        self.assertEqual(gnn.embedding_.shape, (self.n, 2))

    def test_gnn_classifier_dense_feat(self):
        # features not in a numpy array (``todense`` gives a matrix) are rejected
        features = self.adjacency.todense()
        gnn = GNNClassifier(2)
        with self.assertRaises(TypeError):
            gnn.fit_predict(self.adjacency, features, self.labels, val_size=0.2)

        # features in a numpy array are accepted
        features = np.array(self.adjacency.todense())
        gnn = GNNClassifier(2, 'Conv')
        y_pred = gnn.fit_predict(self.adjacency, features, self.labels, val_size=0.2)
        self.assertEqual(len(y_pred), self.n)
        self.assertEqual(gnn.embedding_.shape, (self.n, 2))

    def test_gnn_classifier_optimizer(self):
        for optimizer in ['GD', 'Adam']:
            with self.subTest(optimizer=optimizer):
                gnn = GNNClassifier(2, 'Conv', optimizer=optimizer)
                y_pred = gnn.fit_predict(self.adjacency, self.features, self.labels, val_size=0.2)
                self.assertEqual(len(y_pred), self.n)
                self.assertEqual(gnn.embedding_.shape, (self.n, 2))

    def test_gnn_classifier_binary(self):
        gnn = GNNClassifier([5, 1], 'Conv', 'Softmax')
        self.assertEqual(gnn.layers[1].activation.name, 'Binary cross entropy')
        labels_pred = gnn.fit_predict(self.adjacency, self.features, self.labels)
        self.assertEqual(len(labels_pred), self.n)

    def test_gnn_classifier_norm(self):
        n_labels = len(set(self.labels))
        gnn = GNNClassifier([5, n_labels], 'Conv', normalizations=['left', 'both'])
        labels_pred = gnn.fit_predict(self.adjacency, self.features, self.labels)
        self.assertEqual(len(labels_pred), self.n)

    def test_gnn_classifier_1label(self):
        gnn = GNNClassifier(1, 'Conv', 'Relu')
        labels_pred = gnn.fit_predict(self.adjacency, self.features, self.labels, val_size=0.2)
        self.assertEqual(len(labels_pred), self.n)

    def test_gnn_classifier_masks(self):
        gnn = GNNClassifier(2, 'Conv', 'Softmax', early_stopping=False)
        train_mask = np.array([True] * 6 + [False] * 4)
        gnn.fit_predict(self.adjacency, self.features, self.labels, train_mask=train_mask, n_epochs=5)
        self._assert_masks_cover_all_nodes(gnn)

        # A fresh mask is built in case the previous call modified the array in place.
        train_mask = np.array([True] * 6 + [False] * 4)
        gnn.fit_predict(self.adjacency, self.features, self.labels, train_mask=train_mask, resample=True,
                        n_epochs=10)
        self._assert_masks_cover_all_nodes(gnn)

        val_mask = np.array([False] * 6 + [True] + [False] * 3)
        gnn.fit_predict(self.adjacency, self.features, self.labels, train_mask=train_mask, val_mask=val_mask)
        self._assert_masks_cover_all_nodes(gnn)

        test_mask = np.array([False] * 7 + [True] * 3)
        gnn.fit_predict(self.adjacency, self.features, self.labels, train_mask=train_mask, val_mask=val_mask,
                        test_mask=test_mask)
        self._assert_masks_cover_all_nodes(gnn)

    def test_gnn_classifier_val_size(self):
        gnn = GNNClassifier(2)
        with self.assertRaises(ValueError):
            gnn.fit_predict(self.adjacency, self.features, self.labels, train_size=None, val_size=None, test_size=None)
        with self.assertRaises(ValueError):
            gnn.fit_predict(self.adjacency, self.features, self.labels, val_size=-1)
        with self.assertRaises(ValueError):
            gnn.fit_predict(self.adjacency, self.features, self.labels, val_size=1.5)

        gnn.fit_predict(self.adjacency, self.features, self.labels, train_size=0.6)
        self._assert_masks_cover_all_nodes(gnn)

        gnn.fit_predict(self.adjacency, self.features, self.labels, train_size=0.7, val_size=0.1, test_size=0.2)
        self._assert_masks_cover_all_nodes(gnn)

        labels = self.labels.copy()
        labels[:2] = -1  # missing labels
        gnn.fit_predict(self.adjacency, self.features, labels, train_size=0.8, val_size=0.1)
        self.assertNotEqual(sum(gnn.test_mask), 0)
        self._assert_masks_cover_all_nodes(gnn)

    def test_gnn_classifier_dim_output(self):
        # More distinct labels than output dimensions is rejected.
        gnn = GNNClassifier(2)
        labels = np.arange(len(self.labels))
        with self.assertRaises(ValueError):
            gnn.fit(self.adjacency, self.features, labels)

    def test_gnn_classifier_random_state(self):
        gnn = GNNClassifier(2)
        labels_pred = gnn.fit_predict(self.adjacency, self.features, self.labels, val_size=0.2, random_state=42)
        self.assertEqual(len(labels_pred), self.adjacency.shape[0])
        self.assertEqual(gnn.embedding_.shape, (self.adjacency.shape[0], 2))

    def test_gnn_classifier_verbose(self):
        gnn = GNNClassifier(2, verbose=True)
        self.assertIsInstance(gnn, GNNClassifier)

    def test_gnn_classifier_early_stopping(self):
        gnn = GNNClassifier(2, patience=2)
        gnn.fit_predict(self.adjacency, self.features, self.labels, n_epochs=100, history=True)
        self.assertLess(len(gnn.history_['val_accuracy']), 100)

        gnn = GNNClassifier(2, early_stopping=False)
        train_mask = np.array([True] * 6 + [False] * 4)
        val_mask = np.array([False] * 6 + [True] * 2 + [False] * 2)
        gnn.fit_predict(self.adjacency, self.features, self.labels, train_mask=train_mask, val_mask=val_mask,
                        n_epochs=100, history=True)
        self.assertEqual(len(gnn.history_['val_accuracy']), 100)

    def test_gnn_classifier_reinit(self):
        gnn = GNNClassifier([4, 2])
        gnn.fit(self.adjacency, self.features, self.labels, reinit=False)
        weights = [layer.weight for layer in gnn.layers]
        biases = [layer.bias for layer in gnn.layers]
        gnn.fit(self.adjacency, self.features, self.labels, n_epochs=1, reinit=True)
        # Re-fitting must keep the architecture but replace every parameter.
        self.assertEqual(len(gnn.layers), len(weights))
        for old_weight, old_bias, layer in zip(weights, biases, gnn.layers):
            self.assertTrue((old_weight != layer.weight).all())
            self.assertTrue((old_bias != layer.bias).all())

    def test_gnn_classifier_sageconv(self):
        gnn = GNNClassifier([4, 2], ['SAGEConv', 'SAGEConv'], sample_sizes=[5, 3])
        gnn.fit_predict(self.adjacency, self.features, self.labels, n_epochs=100)
        self.assertEqual(gnn.layers[0].sample_size, 5)
        self.assertEqual(gnn.layers[0].normalization, 'left')
        self.assertEqual(gnn.layers[1].sample_size, 3)
        self.assertEqual(gnn.layers[1].normalization, 'left')

    def test_gnn_classifier_predict(self):
        # Seeded generator so that the random inputs below, and any failure they trigger, are reproducible.
        rng = np.random.default_rng(42)
        gnn = GNNClassifier([4, 2])
        labels_pred = gnn.fit_predict(self.adjacency, self.features, self.labels, val_size=0.2, random_state=42)
        preds = gnn.predict()
        np.testing.assert_array_equal(labels_pred, gnn.labels_)
        np.testing.assert_array_equal(labels_pred, preds)

        # Predict same nodes
        predictions = gnn.predict(self.adjacency, self.features)
        np.testing.assert_array_equal(predictions, gnn.labels_)

        # Incorrect shapes
        n_feat = self.features.shape[1]
        new_n = sparse.csr_matrix(rng.integers(2, size=n_feat))
        new_feat = sparse.csr_matrix(rng.integers(3, size=n_feat))
        with self.assertRaises(ValueError):
            gnn.predict(new_n, self.features)
        with self.assertRaises(ValueError):
            gnn.predict(self.adjacency, new_feat)

        new_feat = sparse.csr_matrix(rng.random((self.adjacency.shape[0], n_feat - 1)))
        with self.assertRaises(ValueError):
            gnn.predict(self.adjacency, new_feat)

        # Predict new graph
        n = 4
        adjacency = sparse.csr_matrix(rng.integers(2, size=(n, n)))
        features = sparse.csr_matrix(rng.integers(2, size=(n, n_feat)))
        preds = gnn.predict(adjacency, features)
        self.assertEqual(len(preds), n)

        # No adjacency matrix
        preds = gnn.predict(None, features)
        self.assertEqual(len(preds), features.shape[0])
@@ -0,0 +1,80 @@
1
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""tests for layers"""

import unittest

import numpy as np

# Aliased so that pytest does not collect the imported ``test_graph`` function as a test of this module.
from sknetwork.data.test_graphs import test_graph as make_test_graph
from sknetwork.gnn.layer import Convolution, get_layer


class TestLayer(unittest.TestCase):
    """Tests for Convolution layers (shapes and options) and for the get_layer factory."""

    def setUp(self) -> None:
        """Test graph for tests."""
        self.adjacency = make_test_graph()
        self.features = self.adjacency
        self.labels = np.array([0] * 5 + [1] * 5)

    def _assert_two_layer_output(self, conv1, conv2):
        """Apply ``conv1`` then ``conv2`` on the test graph and check the output has one row per node and 2 columns."""
        h = conv1.forward(self.adjacency, self.features)
        emb = conv2.forward(self.adjacency, h)
        self.assertEqual(emb.shape, (self.adjacency.shape[0], 2))

    def test_graph_conv_shapes(self):
        n_features = self.features.shape[1]
        conv1 = Convolution('Conv', 4)
        conv1._initialize_weights(n_features)
        conv2 = Convolution('Conv', 2)
        conv2._initialize_weights(4)

        self.assertEqual(conv1.weight.shape, (n_features, 4))
        self.assertEqual(conv1.bias.shape, (1, 4))
        self.assertTrue(conv1.weights_initialized)
        self.assertEqual(conv2.weight.shape, (4, 2))
        self.assertEqual(conv2.bias.shape, (1, 2))
        self.assertTrue(conv2.weights_initialized)

        h = conv1.forward(self.adjacency, self.features)
        self.assertEqual(h.shape, (self.adjacency.shape[0], 4))
        emb = conv2.forward(self.adjacency, h)
        self.assertEqual(emb.shape, (self.adjacency.shape[0], 2))

    def test_graph_conv_bias_use(self):
        self._assert_two_layer_output(Convolution('Conv', 4, use_bias=False),
                                      Convolution('Conv', 2, use_bias=False))

    def test_graph_conv_self_embeddings(self):
        self._assert_two_layer_output(Convolution('Conv', 4, self_embeddings=False),
                                      Convolution('Conv', 2, self_embeddings=False))

    def test_graph_conv_norm(self):
        self._assert_two_layer_output(Convolution('Conv', 4, normalization='left'),
                                      Convolution('Conv', 2, normalization='right'))

    def test_graph_conv_activation(self):
        for activation in ['Relu', 'Sigmoid']:
            with self.subTest(activation=activation):
                self._assert_two_layer_output(Convolution('Conv', 4, activation=activation),
                                              Convolution('Conv', 2))

    def test_graph_sage(self):
        self._assert_two_layer_output(Convolution('Sageconv', 4, normalization='left', self_embeddings=True),
                                      Convolution('Sageconv', 2, normalization='right', self_embeddings=True))

    def test_get_layer(self):
        # Unknown names and non-instance arguments (here, a class) are rejected.
        with self.assertRaises(ValueError):
            get_layer('toto')
        with self.assertRaises(TypeError):
            get_layer(Convolution)
@@ -0,0 +1,33 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """tests for loss"""
4
+
5
+ import unittest
6
+
7
+ from sknetwork.gnn.loss import *
8
+
9
+
10
class TestLoss(unittest.TestCase):
    """Tests for the loss factory and the cross-entropy losses."""

    def test_get_loss(self):
        """get_loss resolves known names, passes a loss instance through and rejects bad input."""
        self.assertIsInstance(get_loss('CrossEntropy'), CrossEntropy)
        self.assertIsInstance(get_loss('BinaryCrossEntropy'), BinaryCrossEntropy)
        with self.assertRaises(ValueError):
            get_loss('foo')

        # An existing loss object is returned as-is (compared with ==).
        base_loss = BaseLoss()
        self.assertEqual(base_loss, get_loss(base_loss))
        with self.assertRaises(TypeError):
            get_loss(0)

    def test_ce_loss(self):
        """Cross-entropy of a single node scored [0, 5] with label 1 equals log(1 + exp(-5))."""
        cross_entropy = CrossEntropy()
        signal = np.array([[0, 5]])
        labels = np.array([1])
        self.assertAlmostEqual(cross_entropy.loss(signal, labels), 0.00671534848911828)

    def test_bce_loss(self):
        """Binary cross-entropy on the same single-node input matches a reference value."""
        binary_cross_entropy = BinaryCrossEntropy()
        signal = np.array([[0, 5]])
        labels = np.array([1])
        self.assertAlmostEqual(binary_cross_entropy.loss(signal, labels), 0.6998625290490632)
@@ -0,0 +1,23 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """tests for neighbor sampler"""
4
+
5
+ import unittest
6
+
7
+ from sknetwork.data.test_graphs import test_graph
8
+ from sknetwork.gnn.neighbor_sampler import *
9
+ from sknetwork.utils import get_degrees
10
+
11
+
12
class TestNeighSampler(unittest.TestCase):
    """Tests for neighbor samplers."""

    def setUp(self) -> None:
        """Load the test graph shared by the tests."""
        self.adjacency = test_graph()

    def test_uni_node_sampler(self):
        """Uniform sampling keeps the adjacency shape and caps every degree at sample_size."""
        sample_size = 2
        uni_sampler = UniformNeighborSampler(sample_size=sample_size)
        sampled_adj = uni_sampler(self.adjacency)
        self.assertEqual(sampled_adj.shape, self.adjacency.shape)
        degrees = get_degrees(sampled_adj)
        self.assertTrue(all(degrees <= sample_size), f'degrees exceed {sample_size}: {degrees}')
@@ -0,0 +1,43 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """tests for optimizer"""
4
+
5
+ import unittest
6
+
7
+ import numpy as np
8
+
9
+ from sknetwork.data.test_graphs import test_graph
10
+ from sknetwork.gnn.gnn_classifier import GNNClassifier
11
+ from sknetwork.gnn.optimizer import get_optimizer
12
+
13
+
14
class TestOptimizer(unittest.TestCase):
    """Tests for optimizer selection and parameter updates."""

    def setUp(self) -> None:
        """Use the test graph's adjacency as both structure and features, with 5 + 5 labels."""
        self.adjacency = test_graph()
        self.features = self.adjacency
        self.labels = np.array([0] * 5 + [1] * 5)

    def test_get_optimizer(self):
        """get_optimizer raises ValueError for an unknown name and TypeError for a GNNClassifier."""
        with self.assertRaises(ValueError):
            get_optimizer('foo')
        with self.assertRaises(TypeError):
            get_optimizer(GNNClassifier())

    def test_optimizer(self):
        """One extra optimizer step changes both layers' weights and biases but not their shapes."""
        for optimizer in ['Adam', 'GD']:
            # subTest names the failing optimizer and keeps checking the others.
            with self.subTest(optimizer=optimizer):
                gnn = GNNClassifier([4, 2], 'Conv', ['Relu', 'Softmax'], optimizer=optimizer)
                _ = gnn.fit_predict(self.adjacency, self.features, self.labels, n_epochs=1, val_size=0.2)
                layers = (gnn.layers[0], gnn.layers[1])
                # Snapshot parameters before the step so that the update can be detected.
                weights_before = [layer.weight.copy() for layer in layers]
                biases_before = [layer.bias.copy() for layer in layers]
                gnn.optimizer.step(gnn)
                for layer, weight, bias in zip(layers, weights_before, biases_before):
                    # Test weight matrix
                    self.assertEqual(layer.weight.shape, weight.shape)
                    self.assertTrue((layer.weight != weight).any())
                    # Test bias vector
                    self.assertEqual(layer.bias.shape, bias.shape)
                    self.assertTrue((layer.bias != bias).any())
@@ -0,0 +1,93 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """tests for gnn utils"""
4
+
5
+ import unittest
6
+
7
+ from sknetwork.gnn.utils import *
8
+
9
+
10
class TestUtils(unittest.TestCase):
    """Tests for GNN utility checks and layer construction."""

    def _assert_masks(self, result, train, val, test):
        """Assert check_existing_masks reported existing masks equal to the expected train/val/test masks."""
        mask_exists, train_mask, val_mask, test_mask = result
        self.assertTrue(mask_exists)
        # assert_array_equal also checks shapes and prints the differing entries on failure.
        np.testing.assert_array_equal(train_mask, np.array(train))
        np.testing.assert_array_equal(val_mask, np.array(val))
        np.testing.assert_array_equal(test_mask, np.array(test))

    def test_check_norm(self):
        with self.assertRaises(ValueError):
            check_normalizations('foo')
        with self.assertRaises(ValueError):
            check_normalizations(['foo', 'bar'])

    def test_mask_similarity(self):
        m1 = np.array([True, True, False])
        m2 = np.array([True, False, False])
        with self.assertWarns(Warning):
            check_mask_similarity(m1, m2)

    def test_early_stopping(self):
        self.assertTrue(check_early_stopping(True, np.array([True, False]), 2))
        self.assertFalse(check_early_stopping(True, None, 2))
        self.assertFalse(check_early_stopping(True, np.array([True, False, True]), None))
        self.assertFalse(check_early_stopping(True, np.array([False, False, False]), 5))

    def test_check_existing_masks(self):
        # Check invalid entries: calling with labels only raises.
        labels = np.array([1, 1, 0])
        with self.assertRaises(ValueError):
            check_existing_masks(labels)
        # Check valid size parameters: no pre-computed masks are reported.
        self.assertFalse(check_existing_masks(labels, None, None, None, 0.1, 0.2, 0.7)[0])

        # Check valid pre-computed masks
        self._assert_masks(check_existing_masks(labels, np.array([False, True, False])),
                           [False, True, False], [False, False, False], [True, False, True])

        self._assert_masks(check_existing_masks(labels, np.array([True, False, False]),
                                                np.array([False, True, False])),
                           [True, False, False], [False, True, False], [False, False, True])

        self._assert_masks(check_existing_masks(labels, np.array([True, True, False]),
                                                None, np.array([False, False, True])),
                           [True, True, False], [False, False, False], [False, False, True])

        self._assert_masks(check_existing_masks(labels, np.array([True, False, False]),
                                                np.array([False, False, True]),
                                                np.array([False, True, False])),
                           [True, False, False], [False, False, True], [False, True, False])

        # Check negative labels: the node labeled -1 is dropped from train and lands in test.
        labels = np.array([1, -1, 0])
        self._assert_masks(check_existing_masks(labels, np.array([True, True, False])),
                           [True, False, False], [False, False, False], [False, True, True])

    def test_get_layers(self):
        # More activations than layers is rejected.
        with self.assertRaises(ValueError):
            get_layers([4, 2], 'Conv', activations=['Relu', 'Sigmoid', 'Relu'], use_bias=True, normalizations='Both',
                       self_embeddings=True, sample_sizes=5, loss=None)
        # Type compatibility
        layers = get_layers([4], 'Conv', activations=['Relu'], use_bias=[True], normalizations=['Both'],
                            self_embeddings=[True], sample_sizes=[5], loss='Cross entropy')
        self.assertEqual(len(np.ravel(layers)), 1)
        # Broadcasting parameters
        layers = get_layers([4, 2], ['Conv', 'Conv'], activations='Relu', use_bias=True, normalizations='Both',
                            self_embeddings=True, sample_sizes=5, loss='Cross entropy')
        self.assertEqual(len(layers), 2)

    def test_check_loss(self):
        layers = get_layers([4], 'Conv', activations=['Relu'], use_bias=[True], normalizations=['Both'],
                            self_embeddings=[True], sample_sizes=[5], loss=None)
        with self.assertRaises(ValueError):
            check_loss(layers[0])