scikit-network 0.28.3__cp39-cp39-macosx_12_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-network might be problematic. Click here for more details.

Files changed (240) hide show
  1. scikit_network-0.28.3.dist-info/AUTHORS.rst +41 -0
  2. scikit_network-0.28.3.dist-info/LICENSE +34 -0
  3. scikit_network-0.28.3.dist-info/METADATA +457 -0
  4. scikit_network-0.28.3.dist-info/RECORD +240 -0
  5. scikit_network-0.28.3.dist-info/WHEEL +5 -0
  6. scikit_network-0.28.3.dist-info/top_level.txt +1 -0
  7. sknetwork/__init__.py +21 -0
  8. sknetwork/classification/__init__.py +8 -0
  9. sknetwork/classification/base.py +84 -0
  10. sknetwork/classification/base_rank.py +143 -0
  11. sknetwork/classification/diffusion.py +134 -0
  12. sknetwork/classification/knn.py +162 -0
  13. sknetwork/classification/metrics.py +205 -0
  14. sknetwork/classification/pagerank.py +66 -0
  15. sknetwork/classification/propagation.py +152 -0
  16. sknetwork/classification/tests/__init__.py +1 -0
  17. sknetwork/classification/tests/test_API.py +35 -0
  18. sknetwork/classification/tests/test_diffusion.py +37 -0
  19. sknetwork/classification/tests/test_knn.py +24 -0
  20. sknetwork/classification/tests/test_metrics.py +53 -0
  21. sknetwork/classification/tests/test_pagerank.py +20 -0
  22. sknetwork/classification/tests/test_propagation.py +24 -0
  23. sknetwork/classification/vote.cpython-39-darwin.so +0 -0
  24. sknetwork/classification/vote.pyx +58 -0
  25. sknetwork/clustering/__init__.py +7 -0
  26. sknetwork/clustering/base.py +102 -0
  27. sknetwork/clustering/kmeans.py +142 -0
  28. sknetwork/clustering/louvain.py +255 -0
  29. sknetwork/clustering/louvain_core.cpython-39-darwin.so +0 -0
  30. sknetwork/clustering/louvain_core.pyx +134 -0
  31. sknetwork/clustering/metrics.py +91 -0
  32. sknetwork/clustering/postprocess.py +66 -0
  33. sknetwork/clustering/propagation_clustering.py +108 -0
  34. sknetwork/clustering/tests/__init__.py +1 -0
  35. sknetwork/clustering/tests/test_API.py +37 -0
  36. sknetwork/clustering/tests/test_kmeans.py +47 -0
  37. sknetwork/clustering/tests/test_louvain.py +104 -0
  38. sknetwork/clustering/tests/test_metrics.py +50 -0
  39. sknetwork/clustering/tests/test_post_processing.py +23 -0
  40. sknetwork/clustering/tests/test_postprocess.py +39 -0
  41. sknetwork/data/__init__.py +5 -0
  42. sknetwork/data/load.py +408 -0
  43. sknetwork/data/models.py +459 -0
  44. sknetwork/data/parse.py +621 -0
  45. sknetwork/data/test_graphs.py +84 -0
  46. sknetwork/data/tests/__init__.py +1 -0
  47. sknetwork/data/tests/test_API.py +30 -0
  48. sknetwork/data/tests/test_load.py +95 -0
  49. sknetwork/data/tests/test_models.py +52 -0
  50. sknetwork/data/tests/test_parse.py +253 -0
  51. sknetwork/data/tests/test_test_graphs.py +30 -0
  52. sknetwork/data/tests/test_toy_graphs.py +68 -0
  53. sknetwork/data/toy_graphs.py +619 -0
  54. sknetwork/embedding/__init__.py +10 -0
  55. sknetwork/embedding/base.py +90 -0
  56. sknetwork/embedding/force_atlas.py +197 -0
  57. sknetwork/embedding/louvain_embedding.py +174 -0
  58. sknetwork/embedding/louvain_hierarchy.py +142 -0
  59. sknetwork/embedding/metrics.py +66 -0
  60. sknetwork/embedding/random_projection.py +133 -0
  61. sknetwork/embedding/spectral.py +214 -0
  62. sknetwork/embedding/spring.py +198 -0
  63. sknetwork/embedding/svd.py +363 -0
  64. sknetwork/embedding/tests/__init__.py +1 -0
  65. sknetwork/embedding/tests/test_API.py +73 -0
  66. sknetwork/embedding/tests/test_force_atlas.py +35 -0
  67. sknetwork/embedding/tests/test_louvain_embedding.py +33 -0
  68. sknetwork/embedding/tests/test_louvain_hierarchy.py +19 -0
  69. sknetwork/embedding/tests/test_metrics.py +29 -0
  70. sknetwork/embedding/tests/test_random_projection.py +28 -0
  71. sknetwork/embedding/tests/test_spectral.py +84 -0
  72. sknetwork/embedding/tests/test_spring.py +50 -0
  73. sknetwork/embedding/tests/test_svd.py +37 -0
  74. sknetwork/flow/__init__.py +3 -0
  75. sknetwork/flow/flow.py +73 -0
  76. sknetwork/flow/tests/__init__.py +1 -0
  77. sknetwork/flow/tests/test_flow.py +17 -0
  78. sknetwork/flow/tests/test_utils.py +69 -0
  79. sknetwork/flow/utils.py +91 -0
  80. sknetwork/gnn/__init__.py +10 -0
  81. sknetwork/gnn/activation.py +117 -0
  82. sknetwork/gnn/base.py +155 -0
  83. sknetwork/gnn/base_activation.py +89 -0
  84. sknetwork/gnn/base_layer.py +109 -0
  85. sknetwork/gnn/gnn_classifier.py +381 -0
  86. sknetwork/gnn/layer.py +153 -0
  87. sknetwork/gnn/layers.py +127 -0
  88. sknetwork/gnn/loss.py +180 -0
  89. sknetwork/gnn/neighbor_sampler.py +65 -0
  90. sknetwork/gnn/optimizer.py +163 -0
  91. sknetwork/gnn/tests/__init__.py +1 -0
  92. sknetwork/gnn/tests/test_activation.py +56 -0
  93. sknetwork/gnn/tests/test_base.py +79 -0
  94. sknetwork/gnn/tests/test_base_layer.py +37 -0
  95. sknetwork/gnn/tests/test_gnn_classifier.py +192 -0
  96. sknetwork/gnn/tests/test_layers.py +80 -0
  97. sknetwork/gnn/tests/test_loss.py +33 -0
  98. sknetwork/gnn/tests/test_neigh_sampler.py +23 -0
  99. sknetwork/gnn/tests/test_optimizer.py +43 -0
  100. sknetwork/gnn/tests/test_utils.py +93 -0
  101. sknetwork/gnn/utils.py +219 -0
  102. sknetwork/hierarchy/__init__.py +7 -0
  103. sknetwork/hierarchy/base.py +69 -0
  104. sknetwork/hierarchy/louvain_hierarchy.py +264 -0
  105. sknetwork/hierarchy/metrics.py +234 -0
  106. sknetwork/hierarchy/paris.cpython-39-darwin.so +0 -0
  107. sknetwork/hierarchy/paris.pyx +317 -0
  108. sknetwork/hierarchy/postprocess.py +350 -0
  109. sknetwork/hierarchy/tests/__init__.py +1 -0
  110. sknetwork/hierarchy/tests/test_API.py +25 -0
  111. sknetwork/hierarchy/tests/test_algos.py +29 -0
  112. sknetwork/hierarchy/tests/test_metrics.py +62 -0
  113. sknetwork/hierarchy/tests/test_postprocess.py +57 -0
  114. sknetwork/hierarchy/tests/test_ward.py +25 -0
  115. sknetwork/hierarchy/ward.py +94 -0
  116. sknetwork/linalg/__init__.py +9 -0
  117. sknetwork/linalg/basics.py +37 -0
  118. sknetwork/linalg/diteration.cpython-39-darwin.so +0 -0
  119. sknetwork/linalg/diteration.pyx +49 -0
  120. sknetwork/linalg/eig_solver.py +93 -0
  121. sknetwork/linalg/laplacian.py +15 -0
  122. sknetwork/linalg/normalization.py +66 -0
  123. sknetwork/linalg/operators.py +225 -0
  124. sknetwork/linalg/polynome.py +76 -0
  125. sknetwork/linalg/ppr_solver.py +170 -0
  126. sknetwork/linalg/push.cpython-39-darwin.so +0 -0
  127. sknetwork/linalg/push.pyx +73 -0
  128. sknetwork/linalg/sparse_lowrank.py +142 -0
  129. sknetwork/linalg/svd_solver.py +91 -0
  130. sknetwork/linalg/tests/__init__.py +1 -0
  131. sknetwork/linalg/tests/test_eig.py +44 -0
  132. sknetwork/linalg/tests/test_laplacian.py +18 -0
  133. sknetwork/linalg/tests/test_normalization.py +38 -0
  134. sknetwork/linalg/tests/test_operators.py +70 -0
  135. sknetwork/linalg/tests/test_polynome.py +38 -0
  136. sknetwork/linalg/tests/test_ppr.py +50 -0
  137. sknetwork/linalg/tests/test_sparse_lowrank.py +61 -0
  138. sknetwork/linalg/tests/test_svd.py +38 -0
  139. sknetwork/linkpred/__init__.py +4 -0
  140. sknetwork/linkpred/base.py +80 -0
  141. sknetwork/linkpred/first_order.py +508 -0
  142. sknetwork/linkpred/first_order_core.cpython-39-darwin.so +0 -0
  143. sknetwork/linkpred/first_order_core.pyx +315 -0
  144. sknetwork/linkpred/postprocessing.py +98 -0
  145. sknetwork/linkpred/tests/__init__.py +1 -0
  146. sknetwork/linkpred/tests/test_API.py +49 -0
  147. sknetwork/linkpred/tests/test_postprocessing.py +21 -0
  148. sknetwork/path/__init__.py +4 -0
  149. sknetwork/path/metrics.py +148 -0
  150. sknetwork/path/search.py +65 -0
  151. sknetwork/path/shortest_path.py +186 -0
  152. sknetwork/path/tests/__init__.py +1 -0
  153. sknetwork/path/tests/test_metrics.py +29 -0
  154. sknetwork/path/tests/test_search.py +25 -0
  155. sknetwork/path/tests/test_shortest_path.py +45 -0
  156. sknetwork/ranking/__init__.py +9 -0
  157. sknetwork/ranking/base.py +56 -0
  158. sknetwork/ranking/betweenness.cpython-39-darwin.so +0 -0
  159. sknetwork/ranking/betweenness.pyx +99 -0
  160. sknetwork/ranking/closeness.py +95 -0
  161. sknetwork/ranking/harmonic.py +82 -0
  162. sknetwork/ranking/hits.py +94 -0
  163. sknetwork/ranking/katz.py +81 -0
  164. sknetwork/ranking/pagerank.py +107 -0
  165. sknetwork/ranking/postprocess.py +25 -0
  166. sknetwork/ranking/tests/__init__.py +1 -0
  167. sknetwork/ranking/tests/test_API.py +34 -0
  168. sknetwork/ranking/tests/test_betweenness.py +38 -0
  169. sknetwork/ranking/tests/test_closeness.py +34 -0
  170. sknetwork/ranking/tests/test_hits.py +20 -0
  171. sknetwork/ranking/tests/test_pagerank.py +69 -0
  172. sknetwork/regression/__init__.py +4 -0
  173. sknetwork/regression/base.py +56 -0
  174. sknetwork/regression/diffusion.py +190 -0
  175. sknetwork/regression/tests/__init__.py +1 -0
  176. sknetwork/regression/tests/test_API.py +34 -0
  177. sknetwork/regression/tests/test_diffusion.py +48 -0
  178. sknetwork/sknetwork.py +3 -0
  179. sknetwork/topology/__init__.py +9 -0
  180. sknetwork/topology/dag.py +74 -0
  181. sknetwork/topology/dag_core.cpython-39-darwin.so +0 -0
  182. sknetwork/topology/dag_core.pyx +38 -0
  183. sknetwork/topology/kcliques.cpython-39-darwin.so +0 -0
  184. sknetwork/topology/kcliques.pyx +193 -0
  185. sknetwork/topology/kcore.cpython-39-darwin.so +0 -0
  186. sknetwork/topology/kcore.pyx +120 -0
  187. sknetwork/topology/structure.py +234 -0
  188. sknetwork/topology/tests/__init__.py +1 -0
  189. sknetwork/topology/tests/test_cliques.py +28 -0
  190. sknetwork/topology/tests/test_cores.py +21 -0
  191. sknetwork/topology/tests/test_dag.py +26 -0
  192. sknetwork/topology/tests/test_structure.py +99 -0
  193. sknetwork/topology/tests/test_triangles.py +42 -0
  194. sknetwork/topology/tests/test_wl_coloring.py +49 -0
  195. sknetwork/topology/tests/test_wl_kernel.py +31 -0
  196. sknetwork/topology/triangles.cpython-39-darwin.so +0 -0
  197. sknetwork/topology/triangles.pyx +166 -0
  198. sknetwork/topology/weisfeiler_lehman.py +163 -0
  199. sknetwork/topology/weisfeiler_lehman_core.cpython-39-darwin.so +0 -0
  200. sknetwork/topology/weisfeiler_lehman_core.pyx +116 -0
  201. sknetwork/utils/__init__.py +40 -0
  202. sknetwork/utils/base.py +35 -0
  203. sknetwork/utils/check.py +354 -0
  204. sknetwork/utils/co_neighbor.py +71 -0
  205. sknetwork/utils/format.py +219 -0
  206. sknetwork/utils/kmeans.py +89 -0
  207. sknetwork/utils/knn.py +166 -0
  208. sknetwork/utils/knn1d.cpython-39-darwin.so +0 -0
  209. sknetwork/utils/knn1d.pyx +80 -0
  210. sknetwork/utils/membership.py +82 -0
  211. sknetwork/utils/minheap.cpython-39-darwin.so +0 -0
  212. sknetwork/utils/minheap.pxd +22 -0
  213. sknetwork/utils/minheap.pyx +111 -0
  214. sknetwork/utils/neighbors.py +115 -0
  215. sknetwork/utils/seeds.py +75 -0
  216. sknetwork/utils/simplex.py +140 -0
  217. sknetwork/utils/tests/__init__.py +1 -0
  218. sknetwork/utils/tests/test_base.py +28 -0
  219. sknetwork/utils/tests/test_bunch.py +16 -0
  220. sknetwork/utils/tests/test_check.py +190 -0
  221. sknetwork/utils/tests/test_co_neighbor.py +43 -0
  222. sknetwork/utils/tests/test_format.py +61 -0
  223. sknetwork/utils/tests/test_kmeans.py +21 -0
  224. sknetwork/utils/tests/test_knn.py +32 -0
  225. sknetwork/utils/tests/test_membership.py +24 -0
  226. sknetwork/utils/tests/test_neighbors.py +41 -0
  227. sknetwork/utils/tests/test_projection_simplex.py +33 -0
  228. sknetwork/utils/tests/test_seeds.py +67 -0
  229. sknetwork/utils/tests/test_verbose.py +15 -0
  230. sknetwork/utils/tests/test_ward.py +20 -0
  231. sknetwork/utils/timeout.py +38 -0
  232. sknetwork/utils/verbose.py +37 -0
  233. sknetwork/utils/ward.py +60 -0
  234. sknetwork/visualization/__init__.py +4 -0
  235. sknetwork/visualization/colors.py +34 -0
  236. sknetwork/visualization/dendrograms.py +229 -0
  237. sknetwork/visualization/graphs.py +819 -0
  238. sknetwork/visualization/tests/__init__.py +1 -0
  239. sknetwork/visualization/tests/test_dendrograms.py +53 -0
  240. sknetwork/visualization/tests/test_graphs.py +167 -0
@@ -0,0 +1,33 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """tests for simplex.py"""
4
+
5
+ import unittest
6
+
7
+ import numpy as np
8
+ from scipy import sparse
9
+
10
+ from sknetwork.utils.check import is_proba_array
11
+ from sknetwork.utils.simplex import projection_simplex
12
+
13
+
14
class TestProjSimplex(unittest.TestCase):
    """Checks of ``projection_simplex`` on dense, sparse and invalid inputs."""

    def test_array(self):
        """Dense 1-D and 2-D random inputs must project to arrays accepted by ``is_proba_array``."""
        for shape in [(5,), (4, 3)]:
            x = np.random.rand(*shape)
            self.assertTrue(is_proba_array(projection_simplex(x)))

    def test_csr(self):
        """Float and boolean versions of the same all-ones CSR matrix must give the same projection."""
        ones = sparse.csr_matrix(np.ones((3, 3)))
        from_float = projection_simplex(ones)
        from_bool = projection_simplex(ones.astype(bool))
        difference = from_float - from_bool
        self.assertEqual(0, difference.nnz)

    def test_other(self):
        """A string input must be rejected with ``TypeError``."""
        self.assertRaises(TypeError, projection_simplex, 'toto')
@@ -0,0 +1,67 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """tests for seeds.py"""
4
+
5
+ import unittest
6
+
7
+ import numpy as np
8
+
9
+ from sknetwork.utils.seeds import get_seeds, stack_seeds, seeds2probs
10
+
11
+
12
class TestSeeds(unittest.TestCase):
    """Checks that array and dict seed descriptions are handled consistently."""

    def test_get_seeds(self):
        """``get_seeds`` must agree on array/dict inputs and reject or warn on bad ones."""
        n = 10
        as_dict = {0: 0, 1: 1}
        # Same seeds as the dict: nodes 0 and 1 labelled 0 and 1, every other entry set to -1.
        as_array = -np.ones(n)
        as_array[:2] = np.arange(2)

        labels_array = get_seeds((n,), as_array)
        labels_dict = get_seeds((n,), as_dict)
        self.assertTrue(np.allclose(labels_array, labels_dict))

        # Labels of length 10 do not fit a shape of 5.
        self.assertRaises(ValueError, get_seeds, (5,), labels_array)
        with self.assertRaises(TypeError):
            get_seeds('toto', 3)

        # A negative label given through the dict is expected to trigger a warning.
        as_dict[0] = -1
        with self.assertWarns(Warning):
            get_seeds((n,), as_dict)

    def test_seeds2probs(self):
        """``seeds2probs`` must agree on array/dict inputs and raise ``ValueError`` on ``bad_input``."""
        n = 4
        probs_from_array = seeds2probs(n, np.array([0, 1, -1, 0]))
        probs_from_dict = seeds2probs(n, {0: 0, 1: 1, 3: 0})
        self.assertTrue(np.allclose(probs_from_array, probs_from_dict))

        bad_input = np.array([0, 0, -1, 0])
        self.assertRaises(ValueError, seeds2probs, n, bad_input)

    def test_stack_seeds(self):
        """``stack_seeds`` must give the same result for every array/dict combination of row and column seeds."""
        shape = 4, 3
        row_array = np.array([0, 1, -1, 0])
        row_dict = {0: 0, 1: 1, 3: 0}
        col_array = np.array([0, 1, -1])
        col_dict = {0: 0, 1: 1}

        # Both rows and columns given.
        both_arrays = stack_seeds(shape, row_array, col_array)
        both_dicts = stack_seeds(shape, row_dict, col_dict)
        array_then_dict = stack_seeds(shape, row_array, col_dict)
        dict_then_array = stack_seeds(shape, row_dict, col_array)
        self.assertTrue(np.allclose(both_arrays, both_dicts))
        self.assertTrue(np.allclose(both_dicts, array_then_dict))
        self.assertTrue(np.allclose(array_then_dict, dict_then_array))

        # Rows only.
        rows_from_array = stack_seeds(shape, row_array, None)
        rows_from_dict = stack_seeds(shape, row_dict, None)
        self.assertTrue(np.allclose(rows_from_array, rows_from_dict))

        # Columns only.
        cols_from_array = stack_seeds(shape, None, col_array)
        cols_from_dict = stack_seeds(shape, None, col_dict)
        self.assertTrue(np.allclose(cols_from_array, cols_from_dict))
@@ -0,0 +1,15 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """tests for verbose.py"""
4
+
5
+ import unittest
6
+
7
+ from sknetwork.utils.verbose import VerboseMixin
8
+
9
+
10
class TestVerbose(unittest.TestCase):
    """Checks the log exposed by ``VerboseMixin``."""

    def test_prints(self):
        """Printed arguments must be stored space-separated, followed by a newline."""
        mixin = VerboseMixin(verbose=True)
        mixin.log.print('There are', 4, 'seasons in a year')
        expected = 'There are 4 seasons in a year\n'
        self.assertEqual(str(mixin.log), expected)
@@ -0,0 +1,20 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on October 2019
5
+ @author: Nathan de Lara <nathan.delara@polytechnique.org>
6
+ """
7
+ import unittest
8
+
9
+ import numpy as np
10
+
11
+ from sknetwork.utils import WardDense
12
+
13
+
14
class TestKMeans(unittest.TestCase):
    """Shape check for the dendrogram produced by ``WardDense``."""
    # NOTE(review): the class and method names mention k-means but the code only exercises WardDense;
    # names are kept as they are so that existing test selections keep working.

    def test_kmeans(self):
        """The dendrogram of ``n`` random samples must have shape ``(n - 1, 4)``."""
        n_samples = 10
        data = np.random.randn(n_samples, 3)
        dendrogram = WardDense().fit_transform(data)
        self.assertEqual(dendrogram.shape, (n_samples - 1, 4))
@@ -0,0 +1,38 @@
1
+ #!/usr/bin/env python3
2
+ import contextlib
3
+ import signal
4
+ import warnings
5
+
6
+
7
class TimeOut(contextlib.ContextDecorator):
    """
    Timeout context manager/decorator.

    Adapted from https://gist.github.com/TySkby/143190ad1b88c6115597c45f996b030c on 12/10/2020.

    Parameters
    ----------
    seconds :
        Delay after which ``TimeoutError`` is raised inside the managed code.
        Fractional values are supported; ``0`` disarms the timer.

    Notes
    -----
    The timer relies on ``SIGALRM``. Where this signal is not available, a warning is issued
    and the managed code runs without any time limit.

    Examples
    --------
    >>> from time import sleep
    >>> try:
    ...     with TimeOut(1):
    ...         sleep(10)
    ... except TimeoutError:
    ...     print("Function timed out")
    Function timed out
    """
    def __init__(self, seconds: float):
        self.seconds = seconds
        # Handler that was installed for SIGALRM before entering the context (restored on exit).
        self._previous_handler = None

    @staticmethod
    def _is_supported() -> bool:
        """Tell whether the signal-based timer can be used on this platform."""
        return hasattr(signal, "SIGALRM") and hasattr(signal, "setitimer")

    def _timeout_handler(self, signum, frame):
        raise TimeoutError("Code timed out.")

    def __enter__(self):
        if self._is_supported():
            self._previous_handler = signal.signal(signal.SIGALRM, self._timeout_handler)
            # setitimer accepts floats, whereas signal.alarm rejects non-integer delays.
            signal.setitimer(signal.ITIMER_REAL, self.seconds)
        else:
            warnings.warn("SIGALRM is unavailable on Windows. Timeouts are not functional.")

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._is_supported():
            # Disarm the timer first so that it cannot fire while the handler is being swapped back.
            signal.setitimer(signal.ITIMER_REAL, 0)
            # signal.signal returns None when the previous handler was not installed from Python;
            # such a handler cannot be re-installed, so it is left untouched.
            if self._previous_handler is not None:
                signal.signal(signal.SIGALRM, self._previous_handler)
                self._previous_handler = None
@@ -0,0 +1,37 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created in December 2019
5
+ @author: Quentin Lutz <qlutz@enst.fr>
6
+ """
7
+
8
+
9
class Log:
    """Text accumulator that can also echo its messages to standard output.

    Parameters
    ----------
    verbose :
        If ``True``, every message is also printed when it is recorded.

    Attributes
    ----------
    verbose : bool
        Whether messages are echoed.
    log : str
        All recorded messages, one per line.
    """
    def __init__(self, verbose: bool = False):
        self.log = ''
        self.verbose = verbose

    def print(self, *args):
        """Record the arguments as one line of the log (and print them if verbose)."""
        if self.verbose:
            print(*args)
        line = ' '.join(str(arg) for arg in args)
        self.log = self.log + line + '\n'

    def __repr__(self):
        # The representation of the object is the recorded text itself.
        return self.log
23
+
24
+
25
class VerboseMixin:
    """Mixin giving a class a ``log`` attribute for verbosity.

    Parameters
    ----------
    verbose :
        Passed to :class:`Log`; if ``True``, logged messages are also printed.
    """
    def __init__(self, verbose: bool = False):
        self.log = Log(verbose)

    def _scipy_solver_info(self, info: int):
        """Log a message describing the integer status ``info`` of a solver.

        Zero is logged as success, a positive value as lack of convergence
        and any other value as illegal input or breakdown.
        """
        if info > 0:
            message = 'Convergence to tolerance not achieved.'
        elif info == 0:
            message = 'Successful exit.'
        else:
            message = 'Illegal input or breakdown.'
        self.log.print(message)
@@ -0,0 +1,60 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on October 2019
5
+ @author: Nathan de Lara <nathan.delara@polytechnique.org>
6
+ """
7
+ import numpy as np
8
+ from scipy.cluster.hierarchy import ward
9
+
10
+ from sknetwork.utils.base import Algorithm
11
+
12
+
13
class WardDense(Algorithm):
    """Ward hierarchical clustering of dense data, delegated to SciPy.

    Attributes
    ----------
    dendrogram_ : np.ndarray (n - 1, 4)
        Dendrogram (``None`` until the algorithm has been fitted).

    References
    ----------
    * Ward, J. H., Jr. (1963). Hierarchical grouping to optimize an objective function.
      Journal of the American Statistical Association, 58, 236–244.

    * Murtagh, F., & Contreras, P. (2012). Algorithms for hierarchical clustering: an overview.
      Wiley Interdisciplinary Reviews: Data Mining and Knowledge Discovery, 2(1), 86-97.
    """
    def __init__(self):
        self.dendrogram_ = None

    def fit(self, x: np.ndarray) -> 'WardDense':
        """Compute the dendrogram of a dense matrix and store it in ``dendrogram_``.

        Parameters
        ----------
        x:
            Data to cluster.

        Returns
        -------
        self: :class:`WardDense`
        """
        dendrogram = ward(x)
        self.dendrogram_ = dendrogram
        return self

    def fit_transform(self, x: np.ndarray) -> np.ndarray:
        """Fit to a dense matrix and return the resulting dendrogram.

        Parameters
        ----------
        x:
            Data to cluster.

        Returns
        -------
        dendrogram: np.ndarray
        """
        return self.fit(x).dendrogram_
@@ -0,0 +1,4 @@
1
+ """Visualization module."""
2
+
3
+ from sknetwork.visualization.dendrograms import svg_dendrogram
4
+ from sknetwork.visualization.graphs import svg_graph, svg_bigraph
@@ -0,0 +1,34 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on April 2020
5
+ @authors:
6
+ Thomas Bonald <bonald@enst.fr>
7
+ """
8
+
9
+ import numpy as np
10
+
11
+ # standard SVG colors
12
STANDARD_COLORS = np.array([
    'blue', 'red', 'green', 'orange', 'purple',
    'yellow', 'fuchsia', 'olive', 'aqua', 'brown',
])

# 100 RGB colors of coolwarm color map (one [red, green, blue] triple per entry, five entries per line).
COOLWARM_RGB = np.array([
    [58, 76, 192], [60, 79, 195], [64, 84, 199], [66, 88, 202], [70, 93, 207],
    [72, 96, 209], [76, 102, 214], [80, 107, 218], [82, 110, 220], [86, 115, 224],
    [88, 118, 226], [92, 123, 229], [96, 128, 232], [99, 131, 234], [103, 136, 237],
    [105, 139, 239], [109, 144, 241], [112, 147, 243], [116, 151, 245], [120, 155, 247],
    [123, 158, 248], [127, 162, 250], [130, 165, 251], [134, 169, 252], [138, 173, 253],
    [141, 175, 253], [145, 179, 254], [148, 181, 254], [152, 185, 254], [155, 187, 254],
    [159, 190, 254], [163, 193, 254], [166, 195, 253], [170, 198, 253], [172, 200, 252],
    [176, 203, 251], [180, 205, 250], [183, 207, 249], [187, 209, 247], [189, 210, 246],
    [193, 212, 244], [197, 213, 242], [199, 214, 240], [202, 216, 238], [205, 217, 236],
    [208, 218, 233], [210, 218, 231], [214, 219, 228], [217, 220, 224], [219, 220, 222],
    [222, 219, 218], [224, 218, 215], [227, 217, 211], [230, 215, 207], [231, 214, 204],
    [234, 211, 199], [236, 210, 196], [237, 207, 192], [239, 206, 188], [241, 203, 184],
    [242, 200, 179], [243, 198, 176], [244, 195, 171], [245, 193, 168], [246, 189, 164],
    [246, 186, 159], [246, 183, 156], [247, 179, 151], [247, 177, 148], [247, 173, 143],
    [246, 169, 138], [246, 166, 135], [245, 161, 130], [245, 158, 127], [244, 154, 123],
    [243, 150, 120], [242, 145, 115], [240, 141, 111], [239, 137, 108], [237, 132, 103],
    [236, 128, 100], [234, 123, 96], [231, 117, 92], [230, 114, 89], [227, 108, 84],
    [225, 104, 82], [222, 98, 78], [220, 94, 75], [217, 88, 71], [214, 82, 67],
    [211, 77, 64], [207, 70, 61], [205, 66, 58], [201, 59, 55], [197, 50, 51],
    [194, 45, 49], [190, 35, 45], [187, 26, 43], [182, 13, 40], [179, 3, 38],
])
@@ -0,0 +1,229 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on April 2020
5
+ @author: Thomas Bonald <bonald@enst.fr>
6
+ """
7
+ from typing import Iterable, Optional
8
+
9
+ import numpy as np
10
+
11
+ from sknetwork.hierarchy.postprocess import cut_straight
12
+ from sknetwork.visualization.colors import STANDARD_COLORS
13
+
14
+
15
def get_index(dendrogram, reorder=True):
    """Return the leaves in their display order, for a pretty dendrogram.

    Parameters
    ----------
    dendrogram :
        Array with one row per merge; columns 0 and 1 hold the indices of the two merged nodes.
        The node created by merge ``t`` gets index ``n + t``, where ``n`` is the number of leaves.
    reorder :
        If ``True``, the child with more leaves is placed first at each merge.

    Returns
    -------
    index : list
        Leaves listed in the order of the last remaining node.
    """
    n_leaves = dendrogram.shape[0] + 1
    # Maps each node still to be merged to the ordered list of leaves below it.
    clusters = {leaf: [leaf] for leaf in range(n_leaves)}
    for step in range(n_leaves - 1):
        first = int(dendrogram[step, 0])
        second = int(dendrogram[step, 1])
        left: list = clusters.pop(first)
        right: list = clusters.pop(second)
        swap = reorder and len(left) < len(right)
        clusters[n_leaves + step] = right + left if swap else left + right
    return list(clusters.values())[0]
29
+
30
+
31
def svg_dendrogram_top(dendrogram, names, width, height, margin, margin_text, scale, line_width, n_clusters,
                       color, colors, font_size, reorder, rotate_names):
    """Dendrogram as SVG image with root on top.

    Parameters
    ----------
    dendrogram :
        Dendrogram to display. Row ``t`` holds the two merged nodes (columns 0 and 1)
        and the height of the merge (column 2).
    names :
        Names of leaves, indexed by leaf (``None`` for no text).
    width, height :
        Size of the tree area, before scaling and excluding margins.
    margin :
        Margin added on each side of the tree area.
    margin_text :
        Offset between leaves and their names.
    scale :
        Multiplicative factor applied to ``width`` and ``height``.
    line_width :
        Stroke width of the branches.
    n_clusters :
        Number of clusters requested from ``cut_straight`` to colour the branches.
    color :
        Stroke colour of branches joining nodes with different cluster labels.
    colors :
        Sized, indexable collection of stroke colours; cluster ``l`` uses ``colors[l % len(colors)]``.
    font_size :
        Font size of the names.
    reorder :
        Passed to ``get_index`` to order the leaves.
    rotate_names :
        If ``True``, names are rotated by 60 degrees around their anchor point.

    Returns
    -------
    svg : str
        SVG image.
    """

    # scaling
    height *= scale
    width *= scale

    # positioning
    labels = cut_straight(dendrogram, n_clusters, return_dendrogram=False)
    index = get_index(dendrogram, reorder)
    n = len(index)
    unit_height = height / dendrogram[-1, 2]
    unit_width = width / n
    height_basis = margin + height
    # Leaves are spread horizontally in display order, all at the bottom of the tree area.
    position = {index[i]: (margin + i * unit_width, height_basis) for i in range(n)}
    label = {i: l for i, l in enumerate(labels)}
    width += 2 * margin
    height += 2 * margin
    if names is not None:
        # Reserve room below the tree for the longest name.
        text_length = np.max(np.array([len(str(name)) for name in names]))
        height += text_length * font_size * .5 + margin_text

    # Fragments are collected in a list and joined once at the end.
    parts = ["""<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">""".format(width, height)]

    # text
    if names is not None:
        for i in range(n):
            x, y = position[i]
            x -= margin_text
            y += margin_text
            # '&' is replaced by a space (historical behaviour); '<' and '>' are escaped so that
            # a name can never be parsed as markup and break the SVG document.
            text = str(names[i]).replace('&', ' ').replace('<', '&lt;').replace('>', '&gt;')
            if rotate_names:
                parts.append("""<text x="{}" y="{}" transform="rotate(60, {}, {})" font-size="{}">{}</text>"""
                             .format(x, y, x, y, font_size, text))
            else:
                y += margin_text
                parts.append("""<text x="{}" y="{}" font-size="{}">{}</text>"""
                             .format(x, y, font_size, text))

    # tree
    for t in range(n - 1):
        i = int(dendrogram[t, 0])
        j = int(dendrogram[t, 1])
        x1, y1 = position.pop(i)
        x2, y2 = position.pop(j)
        l1 = label.pop(i)
        l2 = label.pop(j)
        if l1 == l2:
            line_color = colors[l1 % len(colors)]
        else:
            line_color = color
        x = .5 * (x1 + x2)
        y = height_basis - dendrogram[t, 2] * unit_height
        # Two vertical segments from the children up to the merge height, then the horizontal link.
        parts.append("""<path stroke-width="{}" stroke="{}" d="M {} {} {} {}" />"""
                     .format(line_width, line_color, x1, y1, x1, y))
        parts.append("""<path stroke-width="{}" stroke="{}" d="M {} {} {} {}" />"""
                     .format(line_width, line_color, x2, y2, x2, y))
        parts.append("""<path stroke-width="{}" stroke="{}" d="M {} {} {} {}" />"""
                     .format(line_width, line_color, x1, y, x2, y))
        position[n + t] = (x, y)
        # The merged node inherits the label of its first child.
        label[n + t] = l1

    parts.append('</svg>')
    return ''.join(parts)
96
+
97
+
98
def svg_dendrogram_left(dendrogram, names, width, height, margin, margin_text, scale, line_width, n_clusters,
                        color, colors, font_size, reorder):
    """Dendrogram as SVG image with root on left side.

    Parameters
    ----------
    dendrogram :
        Dendrogram to display. Row ``t`` holds the two merged nodes (columns 0 and 1)
        and the height of the merge (column 2).
    names :
        Names of leaves, indexed by leaf (``None`` for no text).
    width, height :
        Size of the tree area, before scaling and excluding margins.
    margin :
        Margin added on each side of the tree area.
    margin_text :
        Offset between leaves and their names.
    scale :
        Multiplicative factor applied to ``width`` and ``height``.
    line_width :
        Stroke width of the branches.
    n_clusters :
        Number of clusters requested from ``cut_straight`` to colour the branches.
    color :
        Stroke colour of branches joining nodes with different cluster labels.
    colors :
        Sized, indexable collection of stroke colours; cluster ``l`` uses ``colors[l % len(colors)]``.
    font_size :
        Font size of the names.
    reorder :
        Passed to ``get_index`` to order the leaves.

    Returns
    -------
    svg : str
        SVG image.
    """

    # scaling
    height *= scale
    width *= scale

    # positioning
    labels = cut_straight(dendrogram, n_clusters, return_dendrogram=False)
    index = get_index(dendrogram, reorder)
    n = len(index)
    unit_height = height / n
    unit_width = width / dendrogram[-1, 2]
    width_basis = width + margin
    # Leaves are spread vertically in display order, all on the right edge of the tree area.
    position = {index[i]: (width_basis, margin + i * unit_height) for i in range(n)}
    label = {i: l for i, l in enumerate(labels)}
    width += 2 * margin
    height += 2 * margin
    if names is not None:
        # Reserve room on the right of the tree for the longest name.
        text_length = np.max(np.array([len(str(name)) for name in names]))
        width += text_length * font_size * .5 + margin_text

    # Fragments are collected in a list and joined once at the end.
    parts = ["""<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">""".format(width, height)]

    # text
    if names is not None:
        for i in range(n):
            x, y = position[i]
            x += margin_text
            y += unit_height / 3
            # '&' is replaced by a space (historical behaviour); '<' and '>' are escaped so that
            # a name can never be parsed as markup and break the SVG document.
            text = str(names[i]).replace('&', ' ').replace('<', '&lt;').replace('>', '&gt;')
            parts.append("""<text x="{}" y="{}" font-size="{}">{}</text>"""
                         .format(x, y, font_size, text))

    # tree
    for t in range(n - 1):
        i = int(dendrogram[t, 0])
        j = int(dendrogram[t, 1])
        x1, y1 = position.pop(i)
        x2, y2 = position.pop(j)
        l1 = label.pop(i)
        l2 = label.pop(j)
        if l1 == l2:
            line_color = colors[l1 % len(colors)]
        else:
            line_color = color
        y = .5 * (y1 + y2)
        x = width_basis - dendrogram[t, 2] * unit_width
        # Two horizontal segments from the children to the merge abscissa, then the vertical link.
        parts.append("""<path stroke-width="{}" stroke="{}" d="M {} {} {} {}" />"""
                     .format(line_width, line_color, x1, y1, x, y1))
        parts.append("""<path stroke-width="{}" stroke="{}" d="M {} {} {} {}" />"""
                     .format(line_width, line_color, x2, y2, x, y2))
        parts.append("""<path stroke-width="{}" stroke="{}" d="M {} {} {} {}" />"""
                     .format(line_width, line_color, x, y1, x, y2))
        position[n + t] = (x, y)
        # The merged node inherits the label of its first child.
        label[n + t] = l1

    parts.append('</svg>')

    return ''.join(parts)
159
+
160
+
161
def svg_dendrogram(dendrogram: np.ndarray, names: Optional[np.ndarray] = None, rotate: bool = False, width: float = 400,
                   height: float = 300, margin: float = 10, margin_text: float = 5, scale: float = 1,
                   line_width: float = 2, n_clusters: int = 2, color: str = 'black', colors: Optional[Iterable] = None,
                   font_size: int = 12, reorder: bool = False, rotate_names: bool = True,
                   filename: Optional[str] = None):
    """Return SVG image of a dendrogram.

    Parameters
    ----------
    dendrogram :
        Dendrogram to display.
    names :
        Names of leaves.
    rotate :
        If ``True``, rotate the tree so that the root is on the left.
    width :
        Width of the image (margins excluded).
    height :
        Height of the image (margins excluded).
    margin :
        Margin.
    margin_text :
        Margin between leaves and their names, if any.
    scale :
        Scaling factor.
    line_width :
        Line width.
    n_clusters :
        Number of coloured clusters to display.
    color :
        Default SVG color for the dendrogram.
    colors :
        SVG colors of the clusters of the dendrogram (optional).
        Any iterable is accepted; for a dictionary, the values are used.
    font_size :
        Font size.
    reorder :
        If ``True``, reorder leaves so that left subtree has more leaves than right subtree.
    rotate_names :
        If ``True``, rotate names of leaves (only valid if **rotate** is ``False``).
    filename :
        Filename for saving image (optional). The extension ``.svg`` is appended to it
        and the file is written in UTF-8.

    Returns
    -------
    svg : str
        SVG image.

    Example
    -------
    >>> dendrogram = np.array([[0, 1, 1, 2], [2, 3, 2, 3]])
    >>> from sknetwork.visualization import svg_dendrogram
    >>> image = svg_dendrogram(dendrogram)
    >>> image[1:4]
    'svg'
    """
    if colors is None:
        colors = STANDARD_COLORS
    elif isinstance(colors, dict):
        colors = np.array(list(colors.values()))
    elif not isinstance(colors, np.ndarray):
        # The drawing code needs len() and indexing, which a generic iterable (e.g. a generator) lacks.
        colors = np.array(list(colors))

    if rotate:
        svg = svg_dendrogram_left(dendrogram, names, width, height, margin, margin_text, scale, line_width, n_clusters,
                                  color, colors, font_size, reorder)
    else:
        svg = svg_dendrogram_top(dendrogram, names, width, height, margin, margin_text, scale, line_width, n_clusters,
                                 color, colors, font_size, reorder, rotate_names)

    if filename is not None:
        # An SVG file without an XML declaration is read as UTF-8, so the encoding must not
        # depend on the platform default.
        with open(filename + '.svg', 'w', encoding='utf-8') as f:
            f.write(svg)

    return svg