scikit-network 0.33.4__cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scikit_network-0.33.4.dist-info/METADATA +122 -0
- scikit_network-0.33.4.dist-info/RECORD +229 -0
- scikit_network-0.33.4.dist-info/WHEEL +6 -0
- scikit_network-0.33.4.dist-info/licenses/AUTHORS.rst +43 -0
- scikit_network-0.33.4.dist-info/licenses/LICENSE +34 -0
- scikit_network-0.33.4.dist-info/top_level.txt +1 -0
- scikit_network.libs/libgomp-a34b3233.so.1.0.0 +0 -0
- sknetwork/__init__.py +21 -0
- sknetwork/base.py +67 -0
- sknetwork/classification/__init__.py +8 -0
- sknetwork/classification/base.py +138 -0
- sknetwork/classification/base_rank.py +129 -0
- sknetwork/classification/diffusion.py +127 -0
- sknetwork/classification/knn.py +131 -0
- sknetwork/classification/metrics.py +205 -0
- sknetwork/classification/pagerank.py +58 -0
- sknetwork/classification/propagation.py +144 -0
- sknetwork/classification/tests/__init__.py +1 -0
- sknetwork/classification/tests/test_API.py +30 -0
- sknetwork/classification/tests/test_diffusion.py +77 -0
- sknetwork/classification/tests/test_knn.py +23 -0
- sknetwork/classification/tests/test_metrics.py +53 -0
- sknetwork/classification/tests/test_pagerank.py +20 -0
- sknetwork/classification/tests/test_propagation.py +24 -0
- sknetwork/classification/vote.cpp +27593 -0
- sknetwork/classification/vote.cpython-312-x86_64-linux-gnu.so +0 -0
- sknetwork/classification/vote.pyx +56 -0
- sknetwork/clustering/__init__.py +8 -0
- sknetwork/clustering/base.py +168 -0
- sknetwork/clustering/kcenters.py +251 -0
- sknetwork/clustering/leiden.py +238 -0
- sknetwork/clustering/leiden_core.cpp +31928 -0
- sknetwork/clustering/leiden_core.cpython-312-x86_64-linux-gnu.so +0 -0
- sknetwork/clustering/leiden_core.pyx +124 -0
- sknetwork/clustering/louvain.py +282 -0
- sknetwork/clustering/louvain_core.cpp +31573 -0
- sknetwork/clustering/louvain_core.cpython-312-x86_64-linux-gnu.so +0 -0
- sknetwork/clustering/louvain_core.pyx +124 -0
- sknetwork/clustering/metrics.py +91 -0
- sknetwork/clustering/postprocess.py +66 -0
- sknetwork/clustering/propagation_clustering.py +100 -0
- sknetwork/clustering/tests/__init__.py +1 -0
- sknetwork/clustering/tests/test_API.py +38 -0
- sknetwork/clustering/tests/test_kcenters.py +60 -0
- sknetwork/clustering/tests/test_leiden.py +34 -0
- sknetwork/clustering/tests/test_louvain.py +135 -0
- sknetwork/clustering/tests/test_metrics.py +50 -0
- sknetwork/clustering/tests/test_postprocess.py +39 -0
- sknetwork/data/__init__.py +6 -0
- sknetwork/data/base.py +33 -0
- sknetwork/data/load.py +292 -0
- sknetwork/data/models.py +459 -0
- sknetwork/data/parse.py +644 -0
- sknetwork/data/test_graphs.py +93 -0
- sknetwork/data/tests/__init__.py +1 -0
- sknetwork/data/tests/test_API.py +30 -0
- sknetwork/data/tests/test_base.py +14 -0
- sknetwork/data/tests/test_load.py +61 -0
- sknetwork/data/tests/test_models.py +52 -0
- sknetwork/data/tests/test_parse.py +250 -0
- sknetwork/data/tests/test_test_graphs.py +29 -0
- sknetwork/data/tests/test_toy_graphs.py +68 -0
- sknetwork/data/timeout.py +38 -0
- sknetwork/data/toy_graphs.py +611 -0
- sknetwork/embedding/__init__.py +8 -0
- sknetwork/embedding/base.py +90 -0
- sknetwork/embedding/force_atlas.py +198 -0
- sknetwork/embedding/louvain_embedding.py +142 -0
- sknetwork/embedding/random_projection.py +131 -0
- sknetwork/embedding/spectral.py +137 -0
- sknetwork/embedding/spring.py +198 -0
- sknetwork/embedding/svd.py +351 -0
- sknetwork/embedding/tests/__init__.py +1 -0
- sknetwork/embedding/tests/test_API.py +49 -0
- sknetwork/embedding/tests/test_force_atlas.py +35 -0
- sknetwork/embedding/tests/test_louvain_embedding.py +33 -0
- sknetwork/embedding/tests/test_random_projection.py +28 -0
- sknetwork/embedding/tests/test_spectral.py +81 -0
- sknetwork/embedding/tests/test_spring.py +50 -0
- sknetwork/embedding/tests/test_svd.py +43 -0
- sknetwork/gnn/__init__.py +10 -0
- sknetwork/gnn/activation.py +117 -0
- sknetwork/gnn/base.py +181 -0
- sknetwork/gnn/base_activation.py +90 -0
- sknetwork/gnn/base_layer.py +109 -0
- sknetwork/gnn/gnn_classifier.py +305 -0
- sknetwork/gnn/layer.py +153 -0
- sknetwork/gnn/loss.py +180 -0
- sknetwork/gnn/neighbor_sampler.py +65 -0
- sknetwork/gnn/optimizer.py +164 -0
- sknetwork/gnn/tests/__init__.py +1 -0
- sknetwork/gnn/tests/test_activation.py +56 -0
- sknetwork/gnn/tests/test_base.py +75 -0
- sknetwork/gnn/tests/test_base_layer.py +37 -0
- sknetwork/gnn/tests/test_gnn_classifier.py +130 -0
- sknetwork/gnn/tests/test_layers.py +80 -0
- sknetwork/gnn/tests/test_loss.py +33 -0
- sknetwork/gnn/tests/test_neigh_sampler.py +23 -0
- sknetwork/gnn/tests/test_optimizer.py +43 -0
- sknetwork/gnn/tests/test_utils.py +41 -0
- sknetwork/gnn/utils.py +127 -0
- sknetwork/hierarchy/__init__.py +6 -0
- sknetwork/hierarchy/base.py +90 -0
- sknetwork/hierarchy/louvain_hierarchy.py +260 -0
- sknetwork/hierarchy/metrics.py +234 -0
- sknetwork/hierarchy/paris.cpp +37877 -0
- sknetwork/hierarchy/paris.cpython-312-x86_64-linux-gnu.so +0 -0
- sknetwork/hierarchy/paris.pyx +310 -0
- sknetwork/hierarchy/postprocess.py +350 -0
- sknetwork/hierarchy/tests/__init__.py +1 -0
- sknetwork/hierarchy/tests/test_API.py +24 -0
- sknetwork/hierarchy/tests/test_algos.py +34 -0
- sknetwork/hierarchy/tests/test_metrics.py +62 -0
- sknetwork/hierarchy/tests/test_postprocess.py +57 -0
- sknetwork/linalg/__init__.py +9 -0
- sknetwork/linalg/basics.py +37 -0
- sknetwork/linalg/diteration.cpp +27409 -0
- sknetwork/linalg/diteration.cpython-312-x86_64-linux-gnu.so +0 -0
- sknetwork/linalg/diteration.pyx +47 -0
- sknetwork/linalg/eig_solver.py +93 -0
- sknetwork/linalg/laplacian.py +15 -0
- sknetwork/linalg/normalizer.py +86 -0
- sknetwork/linalg/operators.py +225 -0
- sknetwork/linalg/polynome.py +76 -0
- sknetwork/linalg/ppr_solver.py +170 -0
- sknetwork/linalg/push.cpp +31081 -0
- sknetwork/linalg/push.cpython-312-x86_64-linux-gnu.so +0 -0
- sknetwork/linalg/push.pyx +71 -0
- sknetwork/linalg/sparse_lowrank.py +142 -0
- sknetwork/linalg/svd_solver.py +91 -0
- sknetwork/linalg/tests/__init__.py +1 -0
- sknetwork/linalg/tests/test_eig.py +44 -0
- sknetwork/linalg/tests/test_laplacian.py +18 -0
- sknetwork/linalg/tests/test_normalization.py +34 -0
- sknetwork/linalg/tests/test_operators.py +66 -0
- sknetwork/linalg/tests/test_polynome.py +38 -0
- sknetwork/linalg/tests/test_ppr.py +50 -0
- sknetwork/linalg/tests/test_sparse_lowrank.py +61 -0
- sknetwork/linalg/tests/test_svd.py +38 -0
- sknetwork/linkpred/__init__.py +2 -0
- sknetwork/linkpred/base.py +46 -0
- sknetwork/linkpred/nn.py +126 -0
- sknetwork/linkpred/tests/__init__.py +1 -0
- sknetwork/linkpred/tests/test_nn.py +26 -0
- sknetwork/log.py +19 -0
- sknetwork/path/__init__.py +5 -0
- sknetwork/path/dag.py +54 -0
- sknetwork/path/distances.py +98 -0
- sknetwork/path/search.py +31 -0
- sknetwork/path/shortest_path.py +61 -0
- sknetwork/path/tests/__init__.py +1 -0
- sknetwork/path/tests/test_dag.py +37 -0
- sknetwork/path/tests/test_distances.py +62 -0
- sknetwork/path/tests/test_search.py +40 -0
- sknetwork/path/tests/test_shortest_path.py +40 -0
- sknetwork/ranking/__init__.py +8 -0
- sknetwork/ranking/base.py +57 -0
- sknetwork/ranking/betweenness.cpp +9716 -0
- sknetwork/ranking/betweenness.cpython-312-x86_64-linux-gnu.so +0 -0
- sknetwork/ranking/betweenness.pyx +97 -0
- sknetwork/ranking/closeness.py +92 -0
- sknetwork/ranking/hits.py +90 -0
- sknetwork/ranking/katz.py +79 -0
- sknetwork/ranking/pagerank.py +106 -0
- sknetwork/ranking/postprocess.py +37 -0
- sknetwork/ranking/tests/__init__.py +1 -0
- sknetwork/ranking/tests/test_API.py +32 -0
- sknetwork/ranking/tests/test_betweenness.py +38 -0
- sknetwork/ranking/tests/test_closeness.py +30 -0
- sknetwork/ranking/tests/test_hits.py +20 -0
- sknetwork/ranking/tests/test_pagerank.py +62 -0
- sknetwork/ranking/tests/test_postprocess.py +26 -0
- sknetwork/regression/__init__.py +4 -0
- sknetwork/regression/base.py +57 -0
- sknetwork/regression/diffusion.py +204 -0
- sknetwork/regression/tests/__init__.py +1 -0
- sknetwork/regression/tests/test_API.py +32 -0
- sknetwork/regression/tests/test_diffusion.py +56 -0
- sknetwork/sknetwork.py +3 -0
- sknetwork/test_base.py +35 -0
- sknetwork/test_log.py +15 -0
- sknetwork/topology/__init__.py +8 -0
- sknetwork/topology/cliques.cpp +32574 -0
- sknetwork/topology/cliques.cpython-312-x86_64-linux-gnu.so +0 -0
- sknetwork/topology/cliques.pyx +149 -0
- sknetwork/topology/core.cpp +30660 -0
- sknetwork/topology/core.cpython-312-x86_64-linux-gnu.so +0 -0
- sknetwork/topology/core.pyx +90 -0
- sknetwork/topology/cycles.py +243 -0
- sknetwork/topology/minheap.cpp +27341 -0
- sknetwork/topology/minheap.cpython-312-x86_64-linux-gnu.so +0 -0
- sknetwork/topology/minheap.pxd +20 -0
- sknetwork/topology/minheap.pyx +109 -0
- sknetwork/topology/structure.py +194 -0
- sknetwork/topology/tests/__init__.py +1 -0
- sknetwork/topology/tests/test_cliques.py +28 -0
- sknetwork/topology/tests/test_core.py +19 -0
- sknetwork/topology/tests/test_cycles.py +65 -0
- sknetwork/topology/tests/test_structure.py +85 -0
- sknetwork/topology/tests/test_triangles.py +38 -0
- sknetwork/topology/tests/test_wl.py +72 -0
- sknetwork/topology/triangles.cpp +8903 -0
- sknetwork/topology/triangles.cpython-312-x86_64-linux-gnu.so +0 -0
- sknetwork/topology/triangles.pyx +151 -0
- sknetwork/topology/weisfeiler_lehman.py +133 -0
- sknetwork/topology/weisfeiler_lehman_core.cpp +27644 -0
- sknetwork/topology/weisfeiler_lehman_core.cpython-312-x86_64-linux-gnu.so +0 -0
- sknetwork/topology/weisfeiler_lehman_core.pyx +114 -0
- sknetwork/utils/__init__.py +7 -0
- sknetwork/utils/check.py +355 -0
- sknetwork/utils/format.py +221 -0
- sknetwork/utils/membership.py +82 -0
- sknetwork/utils/neighbors.py +115 -0
- sknetwork/utils/tests/__init__.py +1 -0
- sknetwork/utils/tests/test_check.py +190 -0
- sknetwork/utils/tests/test_format.py +63 -0
- sknetwork/utils/tests/test_membership.py +24 -0
- sknetwork/utils/tests/test_neighbors.py +41 -0
- sknetwork/utils/tests/test_tfidf.py +18 -0
- sknetwork/utils/tests/test_values.py +66 -0
- sknetwork/utils/tfidf.py +37 -0
- sknetwork/utils/values.py +76 -0
- sknetwork/visualization/__init__.py +4 -0
- sknetwork/visualization/colors.py +34 -0
- sknetwork/visualization/dendrograms.py +277 -0
- sknetwork/visualization/graphs.py +1039 -0
- sknetwork/visualization/tests/__init__.py +1 -0
- sknetwork/visualization/tests/test_dendrograms.py +53 -0
- sknetwork/visualization/tests/test_graphs.py +176 -0
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# tests for metrics.py
|
|
3
|
+
"""Tests for clustering metrics."""
|
|
4
|
+
import unittest
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from sknetwork.clustering import get_modularity, Louvain
|
|
9
|
+
from sknetwork.data import star_wars, karate_club
|
|
10
|
+
from sknetwork.data.test_graphs import test_graph
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class TestClusteringMetrics(unittest.TestCase):
    """Unit tests for ``get_modularity`` on graphs and bipartite graphs."""

    def setUp(self):
        """Build the shared fixtures: a test graph and two labelings of its nodes."""
        self.adjacency = test_graph()
        n_nodes = self.adjacency.shape[0]
        # Node 0 alone in cluster 1, every other node in cluster 0.
        two_clusters = np.zeros(n_nodes, dtype=int)
        two_clusters[0] = 1
        self.labels = two_clusters
        # Trivial clustering: all nodes share cluster 0.
        self.unique_cluster = np.zeros(n_nodes, dtype=int)

    def test_api(self):
        """The detailed and scalar outputs must agree; bad label lengths must be rejected."""
        for metric in (get_modularity,):
            _total, fit, div = metric(self.adjacency, self.labels, return_all=True)
            scalar = metric(self.adjacency, self.labels, return_all=False)
            self.assertAlmostEqual(fit - div, scalar)

            single_cluster_score = metric(self.adjacency, self.unique_cluster)
            self.assertAlmostEqual(single_cluster_score, 0.)

            # Fewer labels than nodes is an error.
            truncated = self.labels[:3]
            with self.assertRaises(ValueError):
                metric(self.adjacency, truncated)

    def test_modularity(self):
        """Modularity of the Louvain clustering of the karate club graph."""
        graph = karate_club()
        clusters = Louvain().fit_predict(graph)
        score = get_modularity(graph, clusters)
        self.assertAlmostEqual(score, 0.42, places=2)

    def test_bimodularity(self):
        """Modularity of a bipartite graph, with separate row and column labels."""
        biadjacency = star_wars()
        labels_row = np.array([0, 0, 1, 1])
        labels_col = np.array([0, 1, 0])
        score = get_modularity(biadjacency, labels_row, labels_col)
        self.assertAlmostEqual(score, 0.12, places=2)

        # Each of these calls is expected to be rejected with ValueError:
        # missing column labels, too few row labels, too few column labels.
        with self.assertRaises(ValueError):
            get_modularity(biadjacency, labels_row)
        with self.assertRaises(ValueError):
            get_modularity(biadjacency, labels_row[:2], labels_col)
        with self.assertRaises(ValueError):
            get_modularity(biadjacency, labels_row, labels_col[:2])
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""Tests for clustering post-processing"""
|
|
4
|
+
import unittest
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from sknetwork.data import house, star_wars
|
|
9
|
+
from sknetwork.clustering.postprocess import reindex_labels, aggregate_graph
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class TestClusteringPostProcessing(unittest.TestCase):
    """Unit tests for ``reindex_labels`` and ``aggregate_graph``."""

    def test_reindex_clusters(self):
        """Both label arrays below must be reindexed to the same expected output."""
        truth = np.array([1, 1, 2, 0, 0, 0])

        # Labels already contiguous (0, 1, 2).
        labels = np.array([0, 0, 1, 2, 2, 2])
        output = reindex_labels(labels)
        self.assertTrue(np.array_equal(truth, output))

        # Same partition, but with a gap in the label values (0, 2, 5).
        labels = np.array([0, 0, 5, 2, 2, 2])
        output = reindex_labels(labels)
        self.assertTrue(np.array_equal(truth, output))

    def test_aggregate_graph(self):
        """Check the shape of the aggregate graph for a graph and a bipartite graph."""
        adjacency = house()
        labels = np.array([0, 0, 1, 1, 2])
        aggregate = aggregate_graph(adjacency, labels)
        self.assertEqual(aggregate.shape, (3, 3))

        biadjacency = star_wars()
        labels = np.array([0, 0, 1, 2])
        labels_row = np.array([0, 1, 3, -1])
        labels_col = np.array([0, 0, 1])

        # Row labels given through ``labels``: 3 row clusters, 2 column clusters.
        aggregate = aggregate_graph(biadjacency, labels=labels, labels_col=labels_col)
        self.assertEqual(aggregate.shape, (3, 2))

        # Row labels given through ``labels_row``, with a gap and a negative value.
        # NOTE(review): the expected 4 rows presumably come from max label + 1 with
        # -1 ignored -- confirm against aggregate_graph.
        aggregate = aggregate_graph(biadjacency, labels_row=labels_row, labels_col=labels_col)
        self.assertEqual(aggregate.shape, (4, 2))
|
sknetwork/data/base.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""
|
|
4
|
+
Created in May 2023
|
|
5
|
+
@author: Thomas Bonald <bonald@enst.fr>
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Dataset(dict):
    """Container object for datasets.

    Dictionary-like object that exposes its keys as attributes.
    Attributes can be read, assigned and deleted; each operation acts on the
    underlying dictionary entry of the same name.

    >>> dataset = Dataset(name='dataset')
    >>> dataset['name']
    'dataset'
    >>> dataset.name
    'dataset'
    >>> del dataset.name
    >>> 'name' in dataset
    False
    """
    def __init__(self, **kwargs):
        super().__init__(kwargs)

    def __setattr__(self, key, value):
        # Attribute assignment is stored as a dictionary entry.
        self[key] = value

    def __getattr__(self, key):
        # Only called when normal attribute lookup fails, so dict methods still work.
        try:
            return self[key]
        except KeyError:
            # Hide the KeyError: callers asked for an attribute, not a key.
            raise AttributeError(key) from None

    def __delattr__(self, key):
        # Mirror __setattr__ / __getattr__ so that ``del dataset.key`` removes the entry.
        try:
            del self[key]
        except KeyError:
            raise AttributeError(key) from None


# alias for Dataset
Bunch = Dataset
|
|
33
|
+
|
sknetwork/data/load.py
ADDED
|
@@ -0,0 +1,292 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""
|
|
4
|
+
Created in November 2019
|
|
5
|
+
@author: Quentin Lutz <qlutz@enst.fr>
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import pickle
|
|
9
|
+
import shutil
|
|
10
|
+
import tarfile
|
|
11
|
+
from os import environ, makedirs, remove, listdir
|
|
12
|
+
from os.path import abspath, commonprefix, exists, expanduser, isfile, join
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
from typing import Optional, Union
|
|
15
|
+
from urllib.error import HTTPError, URLError
|
|
16
|
+
from urllib.request import urlretrieve
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
19
|
+
from scipy import sparse
|
|
20
|
+
|
|
21
|
+
from sknetwork.data.parse import from_csv, load_labels, load_header, load_metadata
|
|
22
|
+
from sknetwork.data.base import Dataset
|
|
23
|
+
from sknetwork.utils.check import is_square
|
|
24
|
+
from sknetwork.log import Log
|
|
25
|
+
|
|
26
|
+
NETSET_URL = 'https://netset.telecom-paris.fr'
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def is_within_directory(directory, target):
    """Tell whether ``target`` is ``directory`` itself or located inside it.

    Both paths are made absolute and normalized first, so ``..`` components in
    ``target`` are resolved before the comparison.

    Parameters
    ----------
    directory :
        Path of the reference directory.
    target :
        Path to test.

    Returns
    -------
    bool
        ``True`` if ``target`` is within ``directory``, ``False`` otherwise.
    """
    # Local import: ``commonpath`` is not among the file-level imports.
    from os.path import commonpath

    abs_directory = abspath(directory)
    abs_target = abspath(target)
    try:
        # Compare whole path components. A character-wise prefix test would
        # wrongly accept a sibling such as '/data/foobar' for '/data/foo'.
        return commonpath([abs_directory, abs_target]) == abs_directory
    except ValueError:
        # Raised e.g. for paths on different drives: not within the directory.
        return False
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
    """Extract a tar archive after checking that no member escapes ``path``.

    Parameters
    ----------
    tar :
        Open archive; must provide ``getmembers`` and ``extractall``.
    path :
        Destination folder.
    members :
        Passed through to ``extractall``.
    numeric_owner :
        Passed through to ``extractall``.

    Raises
    ------
    Exception
        If the name of any member of the archive resolves outside ``path``.
    """
    # All members of the archive are checked, whatever the value of ``members``.
    has_escaping_member = any(
        not is_within_directory(path, join(path, entry.name))
        for entry in tar.getmembers()
    )
    if has_escaping_member:
        raise Exception("Attempted path traversal in tar file.")

    tar.extractall(path, members, numeric_owner=numeric_owner)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def get_data_home(data_home: Optional[Union[str, Path]] = None) -> Path:
    """Return the path of the storage folder, creating the folder if it does not exist.

    Parameters
    ----------
    data_home: str or :class:`pathlib.Path`
        The folder to be used for dataset storage. If ``None``, the value of the
        environment variable ``SCIKIT_NETWORK_DATA`` is used, falling back on
        ``~/scikit_network_data``.

    Returns
    -------
    path : :class:`pathlib.Path`
        Storage folder, with a leading ``~`` expanded.
    """
    location = data_home
    if location is None:
        fallback = join('~', 'scikit_network_data')
        location = environ.get('SCIKIT_NETWORK_DATA', fallback)

    location = expanduser(location)
    if exists(location):
        return Path(location)

    makedirs(location)
    return Path(location)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def clear_data_home(data_home: Optional[Union[str, Path]] = None):
    """Delete the storage folder and everything it contains.

    Parameters
    ----------
    data_home: str or :class:`pathlib.Path`
        The folder used for dataset storage (resolved by ``get_data_home``).
    """
    storage_folder = get_data_home(data_home)
    shutil.rmtree(storage_folder)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def clean_data_home(data_home: Optional[Union[str, Path]] = None):
    """Remove the files located directly in the storage folder, keeping its sub-folders.

    Parameters
    ----------
    data_home: str or :class:`pathlib.Path`
        The folder used for dataset storage (resolved by ``get_data_home``).
    """
    storage_folder = get_data_home(data_home)
    for entry in listdir(storage_folder):
        entry_path = join(storage_folder, entry)
        if not isfile(entry_path):
            # Folders (datasets) are left untouched.
            continue
        remove(entry_path)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def load_netset(name: Optional[str] = None, data_home: Optional[Union[str, Path]] = None,
                verbose: bool = True) -> Optional[Dataset]:
    """Load a dataset from the `NetSet collection
    <https://netset.telecom-paris.fr/>`_.

    The dataset is downloaded and unpacked in the sub-folder ``netset`` of the storage folder
    on first use, then loaded from this local copy.

    Parameters
    ----------
    name : str
        Name of the dataset (converted to lowercase). Examples include 'openflights', 'cinema' and 'wikivitals'.
    data_home : str or :class:`pathlib.Path`
        Folder to be used for dataset storage.
        This folder must be empty or contain other folders (datasets); files will be removed.
    verbose : bool
        Enable verbosity.

    Returns
    -------
    dataset : :class:`Dataset`
        Returned dataset (``None`` if no name is given; a hint is printed instead).

    Raises
    ------
    ValueError
        If the name is not a plain dataset name or if the download fails with an HTTP error.
    RuntimeError
        If the connection is reset during the download.
    """
    dataset = Dataset()
    dataset_folder = NETSET_URL + '/datasets/'
    folder_npz = NETSET_URL + '/datasets_npz/'

    logger = Log(verbose)

    if name is None:
        print("Please specify the dataset (e.g., 'wikivitals').\n" +
              f"Complete list available here: <{dataset_folder}>.")
        return None
    else:
        name = name.lower()

    invalid_message = ('Invalid dataset: ' + name + '.'
                       + "\nAvailable datasets include 'openflights' and 'wikivitals'."
                       + f"\nSee <{NETSET_URL}>")

    # The name is used to build paths that are deleted below: refuse anything that is
    # not a single path component (empty string, '.', '..', names with separators),
    # otherwise the storage folder or one of its parents could be removed.
    if not name or name in ('.', '..') or Path(name).name != name:
        raise ValueError(invalid_message)

    data_home = get_data_home(data_home)
    data_netset = data_home / 'netset'
    if not data_netset.exists():
        clean_data_home(data_home)
        makedirs(data_netset)

    # remove previous dataset if not in the netset folder
    # (never the netset folder itself, which holds all cached datasets)
    direct_path = data_home / name
    if direct_path != data_netset and direct_path.exists():
        shutil.rmtree(direct_path)

    data_path = data_netset / name
    if not data_path.exists():
        name_npz = name + '_npz.tar.gz'
        try:
            logger.print_log('Downloading', name, 'from NetSet...')
            urlretrieve(folder_npz + name_npz, data_netset / name_npz)
        except HTTPError as error:
            raise ValueError(invalid_message) from error
        except ConnectionResetError as error:  # pragma: no cover
            raise RuntimeError("Could not reach Netset.") from error
        try:
            with tarfile.open(data_netset / name_npz, 'r:gz') as tar_ref:
                logger.print_log('Unpacking archive...')
                safe_extract(tar_ref, data_path)
        except BaseException:
            # Do not leave a partially extracted folder behind: the next call would
            # consider the dataset as available and load incomplete data.
            shutil.rmtree(data_path, ignore_errors=True)
            raise

    logger.print_log('Parsing files...')
    for file in listdir(data_path):
        file_components = file.split('.')
        # Only files of the form <name>.<extension> are loaded.
        if len(file_components) != 2:
            continue
        file_name, file_extension = file_components
        if file_extension == 'npz':
            dataset[file_name] = sparse.load_npz(data_path / file)
        elif file_extension == 'npy':
            # NOTE(security): allow_pickle=True can execute code from the downloaded file;
            # the content of the archive is trusted here.
            dataset[file_name] = np.load(data_path / file, allow_pickle=True)
        elif file_extension == 'p':
            with open(data_path / file, 'rb') as f:
                # NOTE(security): unpickling can execute code from the downloaded file;
                # the content of the archive is trusted here.
                dataset[file_name] = pickle.load(f)

    # Remove the downloaded archive(s), keep the extracted folders.
    clean_data_home(data_netset)
    logger.print_log('Done.')
    return dataset
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def save_to_numpy_bundle(data: Dataset, bundle_name: str, data_home: Optional[Union[str, Path]] = None):
    """Save a dataset as a folder of Numpy and Pickle files, one file per attribute.

    CSR matrices are saved as ``.npz`` files, NumPy arrays as ``.npy`` files and
    all other attributes as Pickle files with extension ``.p``.

    Parameters
    ----------
    data: Dataset
        Data to save.
    bundle_name: str
        Name to be used for the bundle folder.
    data_home: str or :class:`pathlib.Path`
        Folder to be used for dataset storage.
    """
    bundle_path = get_data_home(data_home) / bundle_name
    makedirs(bundle_path, exist_ok=True)
    for attribute in data:
        value = data[attribute]
        # Dispatch on the exact type: subclasses go through pickle.
        value_type = type(value)
        if value_type is sparse.csr_matrix:
            sparse.save_npz(bundle_path / attribute, value)
        elif value_type is np.ndarray:
            np.save(bundle_path / attribute, value)
        else:
            pickle_path = bundle_path / (attribute + '.p')
            with open(pickle_path, 'wb') as file:
                pickle.dump(value, file)
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def load_from_numpy_bundle(bundle_name: str, data_home: Optional[Union[str, Path]] = None):
    """Load a dataset from a folder of Numpy and Pickle files (inverse function of ``save_to_numpy_bundle``).

    Only files named ``<attribute>.<extension>`` with extension ``npz``, ``npy`` or ``p``
    are loaded; other files are ignored.

    Parameters
    ----------
    bundle_name: str
        Name of the bundle folder.
    data_home: str or :class:`pathlib.Path`
        Folder used for dataset storage.

    Returns
    -------
    data: Dataset
        Data.

    Raises
    ------
    FileNotFoundError
        If the bundle folder does not exist.
    """
    bundle_path = get_data_home(data_home) / bundle_name
    if not bundle_path.exists():
        raise FileNotFoundError('No bundle at ' + str(bundle_path))

    entries = listdir(bundle_path)
    data = Dataset()
    for entry in entries:
        parts = entry.split('.')
        if len(parts) != 2:
            continue
        attribute, extension = parts
        entry_path = bundle_path / entry
        if extension == 'npz':
            data[attribute] = sparse.load_npz(entry_path)
        elif extension == 'npy':
            # NOTE: allow_pickle=True -- only load bundles from trusted sources.
            data[attribute] = np.load(entry_path, allow_pickle=True)
        elif extension == 'p':
            # NOTE: pickle -- only load bundles from trusted sources.
            with open(entry_path, 'rb') as stream:
                data[attribute] = pickle.load(stream)
    return data
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def save(folder: Union[str, Path], data: Union[sparse.csr_matrix, Dataset]):
    """Save a dataset or a CSR matrix to a bundle folder (collection of Numpy and Pickle files) for faster
    subsequent loads. Supported attribute types include sparse matrices, NumPy arrays, strings and objects Dataset.

    A relative folder is created in the current directory. An existing folder with the same
    path is deleted first. A CSR matrix is stored as attribute ``adjacency`` if it is square,
    ``biadjacency`` otherwise.

    Parameters
    ----------
    folder : str or :class:`pathlib.Path`
        Name of the bundle folder.
    data : Union[sparse.csr_matrix, Dataset]
        Data to save.

    Example
    -------
    >>> from sknetwork.data import save
    >>> dataset = Dataset()
    >>> dataset.adjacency = sparse.csr_matrix(np.random.random((3, 3)) < 0.5)
    >>> dataset.names = np.array(['a', 'b', 'c'])
    >>> save('dataset', dataset)
    >>> 'dataset' in listdir('.')
    True
    """
    target = Path(folder).expanduser()
    if target.exists():
        shutil.rmtree(target)

    if isinstance(data, sparse.csr_matrix):
        # Wrap the bare matrix in a dataset under the attribute matching its shape.
        attribute = 'adjacency' if is_square(data) else 'biadjacency'
        wrapped = Dataset()
        setattr(wrapped, attribute, data)
        data = wrapped

    # An absolute target is joined to the root, a relative one to the current directory.
    root = '/' if target.is_absolute() else '.'
    save_to_numpy_bundle(data, target, root)
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def load(folder: Union[str, Path]):
    """Load a dataset from a previously created bundle (inverse function of ``save``).

    A relative folder is looked up in the current directory. A leading ``~`` is expanded,
    as in ``save``, so that a bundle saved to ``'~/bundle'`` can be loaded with the same path.

    Parameters
    ----------
    folder: str or :class:`pathlib.Path`
        Name of the bundle folder.

    Returns
    -------
    data: Dataset
        Data.

    Example
    -------
    >>> from sknetwork.data import save
    >>> dataset = Dataset()
    >>> dataset.adjacency = sparse.csr_matrix(np.random.random((3, 3)) < 0.5)
    >>> dataset.names = np.array(['a', 'b', 'c'])
    >>> save('dataset', dataset)
    >>> dataset = load('dataset')
    >>> print(dataset.names)
    ['a' 'b' 'c']
    """
    # os.path.expanduser leaves the path unchanged when the home directory is unknown.
    folder = Path(expanduser(folder))
    # An absolute folder is joined to the root, a relative one to the current directory.
    root = '/' if folder.is_absolute() else '.'
    return load_from_numpy_bundle(folder, root)
|