scikit-network 0.28.3__cp39-cp39-macosx_12_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of scikit-network might be problematic. Click here for more details.
- scikit_network-0.28.3.dist-info/AUTHORS.rst +41 -0
- scikit_network-0.28.3.dist-info/LICENSE +34 -0
- scikit_network-0.28.3.dist-info/METADATA +457 -0
- scikit_network-0.28.3.dist-info/RECORD +240 -0
- scikit_network-0.28.3.dist-info/WHEEL +5 -0
- scikit_network-0.28.3.dist-info/top_level.txt +1 -0
- sknetwork/__init__.py +21 -0
- sknetwork/classification/__init__.py +8 -0
- sknetwork/classification/base.py +84 -0
- sknetwork/classification/base_rank.py +143 -0
- sknetwork/classification/diffusion.py +134 -0
- sknetwork/classification/knn.py +162 -0
- sknetwork/classification/metrics.py +205 -0
- sknetwork/classification/pagerank.py +66 -0
- sknetwork/classification/propagation.py +152 -0
- sknetwork/classification/tests/__init__.py +1 -0
- sknetwork/classification/tests/test_API.py +35 -0
- sknetwork/classification/tests/test_diffusion.py +37 -0
- sknetwork/classification/tests/test_knn.py +24 -0
- sknetwork/classification/tests/test_metrics.py +53 -0
- sknetwork/classification/tests/test_pagerank.py +20 -0
- sknetwork/classification/tests/test_propagation.py +24 -0
- sknetwork/classification/vote.cpython-39-darwin.so +0 -0
- sknetwork/classification/vote.pyx +58 -0
- sknetwork/clustering/__init__.py +7 -0
- sknetwork/clustering/base.py +102 -0
- sknetwork/clustering/kmeans.py +142 -0
- sknetwork/clustering/louvain.py +255 -0
- sknetwork/clustering/louvain_core.cpython-39-darwin.so +0 -0
- sknetwork/clustering/louvain_core.pyx +134 -0
- sknetwork/clustering/metrics.py +91 -0
- sknetwork/clustering/postprocess.py +66 -0
- sknetwork/clustering/propagation_clustering.py +108 -0
- sknetwork/clustering/tests/__init__.py +1 -0
- sknetwork/clustering/tests/test_API.py +37 -0
- sknetwork/clustering/tests/test_kmeans.py +47 -0
- sknetwork/clustering/tests/test_louvain.py +104 -0
- sknetwork/clustering/tests/test_metrics.py +50 -0
- sknetwork/clustering/tests/test_post_processing.py +23 -0
- sknetwork/clustering/tests/test_postprocess.py +39 -0
- sknetwork/data/__init__.py +5 -0
- sknetwork/data/load.py +408 -0
- sknetwork/data/models.py +459 -0
- sknetwork/data/parse.py +621 -0
- sknetwork/data/test_graphs.py +84 -0
- sknetwork/data/tests/__init__.py +1 -0
- sknetwork/data/tests/test_API.py +30 -0
- sknetwork/data/tests/test_load.py +95 -0
- sknetwork/data/tests/test_models.py +52 -0
- sknetwork/data/tests/test_parse.py +253 -0
- sknetwork/data/tests/test_test_graphs.py +30 -0
- sknetwork/data/tests/test_toy_graphs.py +68 -0
- sknetwork/data/toy_graphs.py +619 -0
- sknetwork/embedding/__init__.py +10 -0
- sknetwork/embedding/base.py +90 -0
- sknetwork/embedding/force_atlas.py +197 -0
- sknetwork/embedding/louvain_embedding.py +174 -0
- sknetwork/embedding/louvain_hierarchy.py +142 -0
- sknetwork/embedding/metrics.py +66 -0
- sknetwork/embedding/random_projection.py +133 -0
- sknetwork/embedding/spectral.py +214 -0
- sknetwork/embedding/spring.py +198 -0
- sknetwork/embedding/svd.py +363 -0
- sknetwork/embedding/tests/__init__.py +1 -0
- sknetwork/embedding/tests/test_API.py +73 -0
- sknetwork/embedding/tests/test_force_atlas.py +35 -0
- sknetwork/embedding/tests/test_louvain_embedding.py +33 -0
- sknetwork/embedding/tests/test_louvain_hierarchy.py +19 -0
- sknetwork/embedding/tests/test_metrics.py +29 -0
- sknetwork/embedding/tests/test_random_projection.py +28 -0
- sknetwork/embedding/tests/test_spectral.py +84 -0
- sknetwork/embedding/tests/test_spring.py +50 -0
- sknetwork/embedding/tests/test_svd.py +37 -0
- sknetwork/flow/__init__.py +3 -0
- sknetwork/flow/flow.py +73 -0
- sknetwork/flow/tests/__init__.py +1 -0
- sknetwork/flow/tests/test_flow.py +17 -0
- sknetwork/flow/tests/test_utils.py +69 -0
- sknetwork/flow/utils.py +91 -0
- sknetwork/gnn/__init__.py +10 -0
- sknetwork/gnn/activation.py +117 -0
- sknetwork/gnn/base.py +155 -0
- sknetwork/gnn/base_activation.py +89 -0
- sknetwork/gnn/base_layer.py +109 -0
- sknetwork/gnn/gnn_classifier.py +381 -0
- sknetwork/gnn/layer.py +153 -0
- sknetwork/gnn/layers.py +127 -0
- sknetwork/gnn/loss.py +180 -0
- sknetwork/gnn/neighbor_sampler.py +65 -0
- sknetwork/gnn/optimizer.py +163 -0
- sknetwork/gnn/tests/__init__.py +1 -0
- sknetwork/gnn/tests/test_activation.py +56 -0
- sknetwork/gnn/tests/test_base.py +79 -0
- sknetwork/gnn/tests/test_base_layer.py +37 -0
- sknetwork/gnn/tests/test_gnn_classifier.py +192 -0
- sknetwork/gnn/tests/test_layers.py +80 -0
- sknetwork/gnn/tests/test_loss.py +33 -0
- sknetwork/gnn/tests/test_neigh_sampler.py +23 -0
- sknetwork/gnn/tests/test_optimizer.py +43 -0
- sknetwork/gnn/tests/test_utils.py +93 -0
- sknetwork/gnn/utils.py +219 -0
- sknetwork/hierarchy/__init__.py +7 -0
- sknetwork/hierarchy/base.py +69 -0
- sknetwork/hierarchy/louvain_hierarchy.py +264 -0
- sknetwork/hierarchy/metrics.py +234 -0
- sknetwork/hierarchy/paris.cpython-39-darwin.so +0 -0
- sknetwork/hierarchy/paris.pyx +317 -0
- sknetwork/hierarchy/postprocess.py +350 -0
- sknetwork/hierarchy/tests/__init__.py +1 -0
- sknetwork/hierarchy/tests/test_API.py +25 -0
- sknetwork/hierarchy/tests/test_algos.py +29 -0
- sknetwork/hierarchy/tests/test_metrics.py +62 -0
- sknetwork/hierarchy/tests/test_postprocess.py +57 -0
- sknetwork/hierarchy/tests/test_ward.py +25 -0
- sknetwork/hierarchy/ward.py +94 -0
- sknetwork/linalg/__init__.py +9 -0
- sknetwork/linalg/basics.py +37 -0
- sknetwork/linalg/diteration.cpython-39-darwin.so +0 -0
- sknetwork/linalg/diteration.pyx +49 -0
- sknetwork/linalg/eig_solver.py +93 -0
- sknetwork/linalg/laplacian.py +15 -0
- sknetwork/linalg/normalization.py +66 -0
- sknetwork/linalg/operators.py +225 -0
- sknetwork/linalg/polynome.py +76 -0
- sknetwork/linalg/ppr_solver.py +170 -0
- sknetwork/linalg/push.cpython-39-darwin.so +0 -0
- sknetwork/linalg/push.pyx +73 -0
- sknetwork/linalg/sparse_lowrank.py +142 -0
- sknetwork/linalg/svd_solver.py +91 -0
- sknetwork/linalg/tests/__init__.py +1 -0
- sknetwork/linalg/tests/test_eig.py +44 -0
- sknetwork/linalg/tests/test_laplacian.py +18 -0
- sknetwork/linalg/tests/test_normalization.py +38 -0
- sknetwork/linalg/tests/test_operators.py +70 -0
- sknetwork/linalg/tests/test_polynome.py +38 -0
- sknetwork/linalg/tests/test_ppr.py +50 -0
- sknetwork/linalg/tests/test_sparse_lowrank.py +61 -0
- sknetwork/linalg/tests/test_svd.py +38 -0
- sknetwork/linkpred/__init__.py +4 -0
- sknetwork/linkpred/base.py +80 -0
- sknetwork/linkpred/first_order.py +508 -0
- sknetwork/linkpred/first_order_core.cpython-39-darwin.so +0 -0
- sknetwork/linkpred/first_order_core.pyx +315 -0
- sknetwork/linkpred/postprocessing.py +98 -0
- sknetwork/linkpred/tests/__init__.py +1 -0
- sknetwork/linkpred/tests/test_API.py +49 -0
- sknetwork/linkpred/tests/test_postprocessing.py +21 -0
- sknetwork/path/__init__.py +4 -0
- sknetwork/path/metrics.py +148 -0
- sknetwork/path/search.py +65 -0
- sknetwork/path/shortest_path.py +186 -0
- sknetwork/path/tests/__init__.py +1 -0
- sknetwork/path/tests/test_metrics.py +29 -0
- sknetwork/path/tests/test_search.py +25 -0
- sknetwork/path/tests/test_shortest_path.py +45 -0
- sknetwork/ranking/__init__.py +9 -0
- sknetwork/ranking/base.py +56 -0
- sknetwork/ranking/betweenness.cpython-39-darwin.so +0 -0
- sknetwork/ranking/betweenness.pyx +99 -0
- sknetwork/ranking/closeness.py +95 -0
- sknetwork/ranking/harmonic.py +82 -0
- sknetwork/ranking/hits.py +94 -0
- sknetwork/ranking/katz.py +81 -0
- sknetwork/ranking/pagerank.py +107 -0
- sknetwork/ranking/postprocess.py +25 -0
- sknetwork/ranking/tests/__init__.py +1 -0
- sknetwork/ranking/tests/test_API.py +34 -0
- sknetwork/ranking/tests/test_betweenness.py +38 -0
- sknetwork/ranking/tests/test_closeness.py +34 -0
- sknetwork/ranking/tests/test_hits.py +20 -0
- sknetwork/ranking/tests/test_pagerank.py +69 -0
- sknetwork/regression/__init__.py +4 -0
- sknetwork/regression/base.py +56 -0
- sknetwork/regression/diffusion.py +190 -0
- sknetwork/regression/tests/__init__.py +1 -0
- sknetwork/regression/tests/test_API.py +34 -0
- sknetwork/regression/tests/test_diffusion.py +48 -0
- sknetwork/sknetwork.py +3 -0
- sknetwork/topology/__init__.py +9 -0
- sknetwork/topology/dag.py +74 -0
- sknetwork/topology/dag_core.cpython-39-darwin.so +0 -0
- sknetwork/topology/dag_core.pyx +38 -0
- sknetwork/topology/kcliques.cpython-39-darwin.so +0 -0
- sknetwork/topology/kcliques.pyx +193 -0
- sknetwork/topology/kcore.cpython-39-darwin.so +0 -0
- sknetwork/topology/kcore.pyx +120 -0
- sknetwork/topology/structure.py +234 -0
- sknetwork/topology/tests/__init__.py +1 -0
- sknetwork/topology/tests/test_cliques.py +28 -0
- sknetwork/topology/tests/test_cores.py +21 -0
- sknetwork/topology/tests/test_dag.py +26 -0
- sknetwork/topology/tests/test_structure.py +99 -0
- sknetwork/topology/tests/test_triangles.py +42 -0
- sknetwork/topology/tests/test_wl_coloring.py +49 -0
- sknetwork/topology/tests/test_wl_kernel.py +31 -0
- sknetwork/topology/triangles.cpython-39-darwin.so +0 -0
- sknetwork/topology/triangles.pyx +166 -0
- sknetwork/topology/weisfeiler_lehman.py +163 -0
- sknetwork/topology/weisfeiler_lehman_core.cpython-39-darwin.so +0 -0
- sknetwork/topology/weisfeiler_lehman_core.pyx +116 -0
- sknetwork/utils/__init__.py +40 -0
- sknetwork/utils/base.py +35 -0
- sknetwork/utils/check.py +354 -0
- sknetwork/utils/co_neighbor.py +71 -0
- sknetwork/utils/format.py +219 -0
- sknetwork/utils/kmeans.py +89 -0
- sknetwork/utils/knn.py +166 -0
- sknetwork/utils/knn1d.cpython-39-darwin.so +0 -0
- sknetwork/utils/knn1d.pyx +80 -0
- sknetwork/utils/membership.py +82 -0
- sknetwork/utils/minheap.cpython-39-darwin.so +0 -0
- sknetwork/utils/minheap.pxd +22 -0
- sknetwork/utils/minheap.pyx +111 -0
- sknetwork/utils/neighbors.py +115 -0
- sknetwork/utils/seeds.py +75 -0
- sknetwork/utils/simplex.py +140 -0
- sknetwork/utils/tests/__init__.py +1 -0
- sknetwork/utils/tests/test_base.py +28 -0
- sknetwork/utils/tests/test_bunch.py +16 -0
- sknetwork/utils/tests/test_check.py +190 -0
- sknetwork/utils/tests/test_co_neighbor.py +43 -0
- sknetwork/utils/tests/test_format.py +61 -0
- sknetwork/utils/tests/test_kmeans.py +21 -0
- sknetwork/utils/tests/test_knn.py +32 -0
- sknetwork/utils/tests/test_membership.py +24 -0
- sknetwork/utils/tests/test_neighbors.py +41 -0
- sknetwork/utils/tests/test_projection_simplex.py +33 -0
- sknetwork/utils/tests/test_seeds.py +67 -0
- sknetwork/utils/tests/test_verbose.py +15 -0
- sknetwork/utils/tests/test_ward.py +20 -0
- sknetwork/utils/timeout.py +38 -0
- sknetwork/utils/verbose.py +37 -0
- sknetwork/utils/ward.py +60 -0
- sknetwork/visualization/__init__.py +4 -0
- sknetwork/visualization/colors.py +34 -0
- sknetwork/visualization/dendrograms.py +229 -0
- sknetwork/visualization/graphs.py +819 -0
- sknetwork/visualization/tests/__init__.py +1 -0
- sknetwork/visualization/tests/test_dendrograms.py +53 -0
- sknetwork/visualization/tests/test_graphs.py +167 -0
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""tests for simplex.py"""
|
|
4
|
+
|
|
5
|
+
import unittest
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
from scipy import sparse
|
|
9
|
+
|
|
10
|
+
from sknetwork.utils.check import is_proba_array
|
|
11
|
+
from sknetwork.utils.simplex import projection_simplex
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class TestProjSimplex(unittest.TestCase):
    """Unit tests for ``projection_simplex``."""

    def test_array(self):
        """Dense vectors and matrices project to arrays accepted by ``is_proba_array``."""
        # Same two draws as before: a vector of length 5, then a 4 x 3 matrix.
        for shape in ((5,), (4, 3)):
            sample = np.random.rand(*shape)
            projected = projection_simplex(sample)
            self.assertTrue(is_proba_array(projected))

    def test_csr(self):
        """Float and boolean CSR inputs with the same pattern give the same projection."""
        ones = sparse.csr_matrix(np.ones((3, 3)))
        from_float = projection_simplex(ones)
        from_bool = projection_simplex(ones.astype(bool))
        difference = from_float - from_bool
        self.assertEqual(0, difference.nnz)

    def test_other(self):
        """Unsupported input types are rejected with ``TypeError``."""
        self.assertRaises(TypeError, projection_simplex, 'toto')
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""tests for seeds.py"""
|
|
4
|
+
|
|
5
|
+
import unittest
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from sknetwork.utils.seeds import get_seeds, stack_seeds, seeds2probs
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class TestSeeds(unittest.TestCase):
    """Unit tests for the seed helpers ``get_seeds``, ``seeds2probs`` and ``stack_seeds``."""

    def test_get_seeds(self):
        """Array and dict seeds give the same labels; invalid inputs are reported."""
        n = 10
        as_dict = {0: 0, 1: 1}
        as_array = -np.ones(n)
        as_array[:2] = np.arange(2)

        from_array = get_seeds((n,), as_array)
        from_dict = get_seeds((n,), as_dict)
        self.assertTrue(np.allclose(from_array, from_dict))

        # Labels of length 10 do not fit a shape of 5 nodes.
        self.assertRaises(ValueError, get_seeds, (5,), from_array)
        with self.assertRaises(TypeError):
            get_seeds('toto', 3)

        # A negative label in a dict of seeds is expected to trigger a warning.
        with self.assertWarns(Warning):
            as_dict[0] = -1
            get_seeds((n,), as_dict)

    def test_seeds2probs(self):
        """Array and dict seeds give the same probabilities."""
        n = 4
        as_array = np.array([0, 1, -1, 0])
        as_dict = {0: 0, 1: 1, 3: 0}

        from_array = seeds2probs(n, as_array)
        from_dict = seeds2probs(n, as_dict)
        self.assertTrue(np.allclose(from_array, from_dict))

        # This input must be rejected with ValueError.
        rejected = np.array([0, 0, -1, 0])
        self.assertRaises(ValueError, seeds2probs, n, rejected)

    def test_stack_seeds(self):
        """All array/dict combinations of row and column seeds stack to the same result."""
        shape = (4, 3)
        row_array = np.array([0, 1, -1, 0])
        row_dict = {0: 0, 1: 1, 3: 0}
        col_array = np.array([0, 1, -1])
        col_dict = {0: 0, 1: 1}

        array_array = stack_seeds(shape, row_array, col_array)
        dict_dict = stack_seeds(shape, row_dict, col_dict)
        array_dict = stack_seeds(shape, row_array, col_dict)
        dict_array = stack_seeds(shape, row_dict, col_array)

        self.assertTrue(np.allclose(array_array, dict_dict))
        self.assertTrue(np.allclose(dict_dict, array_dict))
        self.assertTrue(np.allclose(array_dict, dict_array))

        # Row seeds only.
        rows_from_array = stack_seeds(shape, row_array, None)
        rows_from_dict = stack_seeds(shape, row_dict, None)
        self.assertTrue(np.allclose(rows_from_array, rows_from_dict))

        # Column seeds only.
        cols_from_array = stack_seeds(shape, None, col_array)
        cols_from_dict = stack_seeds(shape, None, col_dict)
        self.assertTrue(np.allclose(cols_from_array, cols_from_dict))
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""tests for verbose.py"""
|
|
4
|
+
|
|
5
|
+
import unittest
|
|
6
|
+
|
|
7
|
+
from sknetwork.utils.verbose import VerboseMixin
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class TestVerbose(unittest.TestCase):
    """Unit tests for ``VerboseMixin``."""

    def test_prints(self):
        """Printed arguments are joined with spaces and stored with a trailing newline."""
        mixin = VerboseMixin(verbose=True)
        mixin.log.print('There are', 4, 'seasons in a year')
        expected = 'There are 4 seasons in a year\n'
        self.assertEqual(str(mixin.log), expected)
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""
|
|
4
|
+
Created on October 2019
|
|
5
|
+
@author: Nathan de Lara <nathan.delara@polytechnique.org>
|
|
6
|
+
"""
|
|
7
|
+
import unittest
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
|
|
11
|
+
from sknetwork.utils import WardDense
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class TestKMeans(unittest.TestCase):
    """Unit test for ``WardDense``.

    NOTE(review): the class and method names mention k-means although the body
    exercises Ward clustering; the names are kept so test identifiers stay stable.
    """

    def test_kmeans(self):
        """The dendrogram of ``n`` samples has ``n - 1`` rows and 4 columns."""
        n_samples, n_features = 10, 3
        x = np.random.randn(n_samples, n_features)
        dendrogram = WardDense().fit_transform(x)
        self.assertEqual(dendrogram.shape, (x.shape[0] - 1, 4))
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
import contextlib
|
|
3
|
+
import signal
|
|
4
|
+
import warnings
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class TimeOut(contextlib.ContextDecorator):
    """
    Timeout context manager/decorator.

    Adapted from https://gist.github.com/TySkby/143190ad1b88c6115597c45f996b030c on 12/10/2020.

    The timeout relies on the ``SIGALRM`` signal. Where this signal is not available,
    a warning is issued on entry and the guarded code runs without any time limit.

    Parameters
    ----------
    seconds :
        Delay after which ``TimeoutError`` is raised in the guarded code.
        Fractional values are supported; ``0`` disables the timer.

    Examples
    --------
    >>> from time import sleep
    >>> try:
    ...     with TimeOut(1):
    ...         sleep(10)
    ... except TimeoutError:
    ...     print("Function timed out")
    Function timed out
    """
    def __init__(self, seconds: float):
        self.seconds = seconds
        # Handler that was installed for SIGALRM before entering, restored on exit.
        self._previous_handler = None

    def _timeout_handler(self, signum, frame):
        """Signal handler: abort the guarded code."""
        raise TimeoutError("Code timed out.")

    def __enter__(self):
        """Install the handler and arm the timer."""
        if hasattr(signal, "SIGALRM"):
            self._previous_handler = signal.signal(signal.SIGALRM, self._timeout_handler)
            # setitimer accepts floats, whereas signal.alarm only accepts integers
            # and raises TypeError on a fractional delay.
            signal.setitimer(signal.ITIMER_REAL, self.seconds)
        else:
            warnings.warn("SIGALRM is unavailable on Windows. Timeouts are not functional.")

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Disarm the timer and restore the previous handler; exceptions propagate."""
        if hasattr(signal, "SIGALRM"):
            signal.setitimer(signal.ITIMER_REAL, 0)
            previous, self._previous_handler = self._previous_handler, None
            # signal.signal returns None when the former handler was not installed
            # from Python; such a handler cannot be re-installed.
            if previous is not None:
                signal.signal(signal.SIGALRM, previous)
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""
|
|
4
|
+
Created in December 2019
|
|
5
|
+
@author: Quentin Lutz <qlutz@enst.fr>
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Log:
    """Text buffer collecting messages, optionally echoed to standard output.

    Parameters
    ----------
    verbose :
        If ``True``, every message is also printed when it is recorded.
    """
    def __init__(self, verbose: bool = False):
        self.log = ''
        self.verbose = verbose

    def print(self, *args):
        """Record the arguments as one line: space-separated, newline-terminated."""
        if self.verbose:
            print(*args)
        line = ' '.join(str(arg) for arg in args)
        self.log = self.log + line + '\n'

    def __repr__(self):
        # The representation is the accumulated text itself.
        return self.log
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class VerboseMixin:
    """Mixin giving a class a ``log`` attribute for verbosity.

    Parameters
    ----------
    verbose :
        Passed to the underlying :class:`Log`.
    """
    def __init__(self, verbose: bool = False):
        self.log = Log(verbose)

    def _scipy_solver_info(self, info: int):
        """Record a message describing the ``info`` code returned by a SciPy solver.

        Zero means success, a positive value means the tolerance was not reached,
        and anything else is reported as illegal input or breakdown.
        """
        if info == 0:
            message = 'Successful exit.'
        elif info > 0:
            message = 'Convergence to tolerance not achieved.'
        else:
            message = 'Illegal input or breakdown.'
        self.log.print(message)
|
sknetwork/utils/ward.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""
|
|
4
|
+
Created on October 2019
|
|
5
|
+
@author: Nathan de Lara <nathan.delara@polytechnique.org>
|
|
6
|
+
"""
|
|
7
|
+
import numpy as np
|
|
8
|
+
from scipy.cluster.hierarchy import ward
|
|
9
|
+
|
|
10
|
+
from sknetwork.utils.base import Algorithm
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class WardDense(Algorithm):
    """Ward hierarchical clustering of dense data, delegating to SciPy.

    Attributes
    ----------
    dendrogram_ : np.ndarray (n - 1, 4)
        Dendrogram (``None`` until the algorithm is fitted).

    References
    ----------
    * Ward, J. H., Jr. (1963). Hierarchical grouping to optimize an objective function.
      Journal of the American Statistical Association, 58, 236–244.

    * Murtagh, F., & Contreras, P. (2012). Algorithms for hierarchical clustering: an overview.
      Wiley Interdisciplinary Reviews: Data Mining and Knowledge Discovery, 2(1), 86-97.
    """
    def __init__(self):
        self.dendrogram_ = None

    def fit(self, x: np.ndarray) -> 'WardDense':
        """Compute the dendrogram of a dense matrix and store it.

        Parameters
        ----------
        x:
            Data to cluster.

        Returns
        -------
        self: :class:`WardDense`
        """
        dendrogram = ward(x)
        self.dendrogram_ = dendrogram
        return self

    def fit_transform(self, x: np.ndarray) -> np.ndarray:
        """Fit on a dense matrix, then return the resulting dendrogram.

        Parameters
        ----------
        x:
            Data to cluster.

        Returns
        -------
        dendrogram: np.ndarray
        """
        self.fit(x)
        dendrogram = self.dendrogram_
        return dendrogram
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""
|
|
4
|
+
Created on April 2020
|
|
5
|
+
@authors:
|
|
6
|
+
Thomas Bonald <bonald@enst.fr>
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
|
|
11
|
+
# Standard SVG color names; 10 entries.
STANDARD_COLORS = np.array(['blue', 'red', 'green', 'orange', 'purple', 'yellow', 'fuchsia', 'olive', 'aqua', 'brown'])

# 100 RGB colors of coolwarm color map.
# One [red, green, blue] row per color, running from blue (first row) to red (last row).
COOLWARM_RGB = np.array([[58, 76, 192], [60, 79, 195], [64, 84, 199], [66, 88, 202], [70, 93, 207], [72, 96, 209],
                         [76, 102, 214], [80, 107, 218], [82, 110, 220], [86, 115, 224], [88, 118, 226], [92, 123, 229],
                         [96, 128, 232], [99, 131, 234], [103, 136, 237], [105, 139, 239], [109, 144, 241],
                         [112, 147, 243], [116, 151, 245], [120, 155, 247], [123, 158, 248], [127, 162, 250],
                         [130, 165, 251], [134, 169, 252], [138, 173, 253], [141, 175, 253], [145, 179, 254],
                         [148, 181, 254], [152, 185, 254], [155, 187, 254], [159, 190, 254], [163, 193, 254],
                         [166, 195, 253], [170, 198, 253], [172, 200, 252], [176, 203, 251], [180, 205, 250],
                         [183, 207, 249], [187, 209, 247], [189, 210, 246], [193, 212, 244], [197, 213, 242],
                         [199, 214, 240], [202, 216, 238], [205, 217, 236], [208, 218, 233], [210, 218, 231],
                         [214, 219, 228], [217, 220, 224], [219, 220, 222], [222, 219, 218], [224, 218, 215],
                         [227, 217, 211], [230, 215, 207], [231, 214, 204], [234, 211, 199], [236, 210, 196],
                         [237, 207, 192], [239, 206, 188], [241, 203, 184], [242, 200, 179], [243, 198, 176],
                         [244, 195, 171], [245, 193, 168], [246, 189, 164], [246, 186, 159], [246, 183, 156],
                         [247, 179, 151], [247, 177, 148], [247, 173, 143], [246, 169, 138], [246, 166, 135],
                         [245, 161, 130], [245, 158, 127], [244, 154, 123], [243, 150, 120], [242, 145, 115],
                         [240, 141, 111], [239, 137, 108], [237, 132, 103], [236, 128, 100], [234, 123, 96],
                         [231, 117, 92], [230, 114, 89], [227, 108, 84], [225, 104, 82], [222, 98, 78],
                         [220, 94, 75], [217, 88, 71], [214, 82, 67], [211, 77, 64], [207, 70, 61],
                         [205, 66, 58], [201, 59, 55], [197, 50, 51], [194, 45, 49], [190, 35, 45],
                         [187, 26, 43], [182, 13, 40], [179, 3, 38]])
|
|
@@ -0,0 +1,229 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""
|
|
4
|
+
Created on April 2020
|
|
5
|
+
@author: Thomas Bonald <bonald@enst.fr>
|
|
6
|
+
"""
|
|
7
|
+
from typing import Iterable, Optional
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
|
|
11
|
+
from sknetwork.hierarchy.postprocess import cut_straight
|
|
12
|
+
from sknetwork.visualization.colors import STANDARD_COLORS
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def get_index(dendrogram, reorder=True):
    """Return the left-to-right order of the leaves for drawing the dendrogram.

    Row ``t`` of the dendrogram merges the clusters given in its first two columns
    into a new cluster numbered ``n + t``, where ``n`` is the number of leaves.
    With ``reorder``, the larger of the two merged clusters is placed first.
    """
    n_leaves = dendrogram.shape[0] + 1
    clusters = {leaf: [leaf] for leaf in range(n_leaves)}
    for t in range(n_leaves - 1):
        first_id = int(dendrogram[t, 0])
        second_id = int(dendrogram[t, 1])
        first = clusters.pop(first_id)
        second = clusters.pop(second_id)
        swap = reorder and len(first) < len(second)
        clusters[n_leaves + t] = second + first if swap else first + second
    # A single cluster remains: the root, holding all leaves in display order.
    return next(iter(clusters.values()))
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def svg_dendrogram_top(dendrogram, names, width, height, margin, margin_text, scale, line_width, n_clusters,
                       color, colors, font_size, reorder, rotate_names):
    """Build the SVG markup of a dendrogram drawn with its root at the top.

    Leaves are spread evenly along the horizontal axis; the vertical position of a
    merge is proportional to the third column of its dendrogram row. A merge of two
    clusters carrying the same label is drawn in that label's entry of ``colors``,
    any other merge in ``color``. Names, if given, are written below the leaves
    (``&`` replaced by a space), rotated by 60 degrees when ``rotate_names`` is set.
    """
    text_rotated = """<text x="{}" y="{}" transform="rotate(60, {}, {})" font-size="{}">{}</text>"""
    text_plain = """<text x="{}" y="{}" font-size="{}">{}</text>"""
    path = """<path stroke-width="{}" stroke="{}" d="M {} {} {} {}" />"""

    # scaling
    plot_height = height * scale
    plot_width = width * scale

    # positioning
    labels = cut_straight(dendrogram, n_clusters, return_dendrogram=False)
    index = get_index(dendrogram, reorder)
    n = len(index)
    unit_height = plot_height / dendrogram[-1, 2]
    unit_width = plot_width / n
    height_basis = margin + plot_height
    position = {leaf: (margin + rank * unit_width, height_basis) for rank, leaf in enumerate(index)}
    label = dict(enumerate(labels))

    # overall size of the image: margins, plus room for the names if any
    total_width = plot_width + 2 * margin
    total_height = plot_height + 2 * margin
    if names is not None:
        text_length = np.max(np.array([len(str(name)) for name in names]))
        total_height = total_height + (text_length * font_size * .5 + margin_text)

    parts = ["""<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">""".format(total_width, total_height)]

    # text
    if names is not None:
        for leaf in range(n):
            x, y = position[leaf]
            x = x - margin_text
            y = y + margin_text
            text = str(names[leaf]).replace('&', ' ')
            if rotate_names:
                parts.append(text_rotated.format(x, y, x, y, font_size, text))
            else:
                # horizontal names sit one more text margin below the leaves
                y = y + margin_text
                parts.append(text_plain.format(x, y, font_size, text))

    # tree
    for t in range(n - 1):
        i = int(dendrogram[t, 0])
        j = int(dendrogram[t, 1])
        x1, y1 = position.pop(i)
        x2, y2 = position.pop(j)
        l1 = label.pop(i)
        l2 = label.pop(j)
        line_color = colors[l1 % len(colors)] if l1 == l2 else color
        x = .5 * (x1 + x2)
        y = height_basis - dendrogram[t, 2] * unit_height
        # two vertical segments up to the merge height, then the horizontal link
        parts.append(path.format(line_width, line_color, x1, y1, x1, y))
        parts.append(path.format(line_width, line_color, x2, y2, x2, y))
        parts.append(path.format(line_width, line_color, x1, y, x2, y))
        # the merged cluster becomes node n + t and inherits the first label
        position[n + t] = (x, y)
        label[n + t] = l1

    parts.append('</svg>')
    return ''.join(parts)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def svg_dendrogram_left(dendrogram, names, width, height, margin, margin_text, scale, line_width, n_clusters,
                        color, colors, font_size, reorder):
    """Build the SVG markup of a dendrogram drawn with its root on the left side.

    Leaves are spread evenly along the vertical axis; the horizontal position of a
    merge is proportional to the third column of its dendrogram row. A merge of two
    clusters carrying the same label is drawn in that label's entry of ``colors``,
    any other merge in ``color``. Names, if given, are written to the right of the
    leaves (``&`` replaced by a space).
    """
    text_plain = """<text x="{}" y="{}" font-size="{}">{}</text>"""
    path = """<path stroke-width="{}" stroke="{}" d="M {} {} {} {}" />"""

    # scaling
    plot_height = height * scale
    plot_width = width * scale

    # positioning
    labels = cut_straight(dendrogram, n_clusters, return_dendrogram=False)
    index = get_index(dendrogram, reorder)
    n = len(index)
    unit_height = plot_height / n
    unit_width = plot_width / dendrogram[-1, 2]
    width_basis = plot_width + margin
    position = {leaf: (width_basis, margin + rank * unit_height) for rank, leaf in enumerate(index)}
    label = dict(enumerate(labels))

    # overall size of the image: margins, plus room for the names if any
    total_width = plot_width + 2 * margin
    total_height = plot_height + 2 * margin
    if names is not None:
        text_length = np.max(np.array([len(str(name)) for name in names]))
        total_width = total_width + (text_length * font_size * .5 + margin_text)

    parts = ["""<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">""".format(total_width, total_height)]

    # text
    if names is not None:
        for leaf in range(n):
            x, y = position[leaf]
            x = x + margin_text
            y = y + unit_height / 3
            text = str(names[leaf]).replace('&', ' ')
            parts.append(text_plain.format(x, y, font_size, text))

    # tree
    for t in range(n - 1):
        i = int(dendrogram[t, 0])
        j = int(dendrogram[t, 1])
        x1, y1 = position.pop(i)
        x2, y2 = position.pop(j)
        l1 = label.pop(i)
        l2 = label.pop(j)
        line_color = colors[l1 % len(colors)] if l1 == l2 else color
        y = .5 * (y1 + y2)
        x = width_basis - dendrogram[t, 2] * unit_width
        # two horizontal segments to the merge abscissa, then the vertical link
        parts.append(path.format(line_width, line_color, x1, y1, x, y1))
        parts.append(path.format(line_width, line_color, x2, y2, x, y2))
        parts.append(path.format(line_width, line_color, x, y1, x, y2))
        # the merged cluster becomes node n + t and inherits the first label
        position[n + t] = (x, y)
        label[n + t] = l1

    parts.append('</svg>')

    return ''.join(parts)
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def svg_dendrogram(dendrogram: np.ndarray, names: Optional[np.ndarray] = None, rotate: bool = False, width: float = 400,
                   height: float = 300, margin: float = 10, margin_text: float = 5, scale: float = 1,
                   line_width: float = 2, n_clusters: int = 2, color: str = 'black', colors: Optional[Iterable] = None,
                   font_size: int = 12, reorder: bool = False, rotate_names: bool = True,
                   filename: Optional[str] = None):
    """Return SVG image of a dendrogram.

    Parameters
    ----------
    dendrogram :
        Dendrogram to display.
    names :
        Names of leaves.
    rotate :
        If ``True``, rotate the tree so that the root is on the left.
    width :
        Width of the image (margins excluded).
    height :
        Height of the image (margins excluded).
    margin :
        Margin.
    margin_text :
        Margin between leaves and their names, if any.
    scale :
        Scaling factor.
    line_width :
        Line width.
    n_clusters :
        Number of coloured clusters to display.
    color :
        Default SVG color for the dendrogram.
    colors :
        SVG colors of the clusters of the dendrogram (optional).
        A dictionary is reduced to its values; the standard colors are used by default.
    font_size :
        Font size.
    reorder :
        If ``True``, reorder leaves so that left subtree has more leaves than right subtree.
    rotate_names :
        If ``True``, rotate names of leaves (only valid if **rotate** is ``False``).
    filename :
        Filename for saving image (optional). The extension ``.svg`` is appended.

    Returns
    -------
    svg : str
        SVG image.

    Example
    -------
    >>> dendrogram = np.array([[0, 1, 1, 2], [2, 3, 2, 3]])
    >>> from sknetwork.visualization import svg_dendrogram
    >>> image = svg_dendrogram(dendrogram)
    >>> image[1:4]
    'svg'
    """
    if colors is None:
        colors = STANDARD_COLORS
    elif isinstance(colors, dict):
        colors = np.array(list(colors.values()))
    elif isinstance(colors, list):
        colors = np.array(colors)

    if rotate:
        svg = svg_dendrogram_left(dendrogram, names, width, height, margin, margin_text, scale, line_width, n_clusters,
                                  color, colors, font_size, reorder)
    else:
        svg = svg_dendrogram_top(dendrogram, names, width, height, margin, margin_text, scale, line_width, n_clusters,
                                 color, colors, font_size, reorder, rotate_names)

    if filename is not None:
        # The SVG has no XML declaration, so readers assume UTF-8. Writing with the
        # locale encoding could fail or corrupt names containing non-ASCII characters.
        with open(filename + '.svg', 'w', encoding='utf-8') as f:
            f.write(svg)

    return svg
|