scikit-network 0.33.4__cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (229) hide show
  1. scikit_network-0.33.4.dist-info/METADATA +122 -0
  2. scikit_network-0.33.4.dist-info/RECORD +229 -0
  3. scikit_network-0.33.4.dist-info/WHEEL +6 -0
  4. scikit_network-0.33.4.dist-info/licenses/AUTHORS.rst +43 -0
  5. scikit_network-0.33.4.dist-info/licenses/LICENSE +34 -0
  6. scikit_network-0.33.4.dist-info/top_level.txt +1 -0
  7. scikit_network.libs/libgomp-a34b3233.so.1.0.0 +0 -0
  8. sknetwork/__init__.py +21 -0
  9. sknetwork/base.py +67 -0
  10. sknetwork/classification/__init__.py +8 -0
  11. sknetwork/classification/base.py +138 -0
  12. sknetwork/classification/base_rank.py +129 -0
  13. sknetwork/classification/diffusion.py +127 -0
  14. sknetwork/classification/knn.py +131 -0
  15. sknetwork/classification/metrics.py +205 -0
  16. sknetwork/classification/pagerank.py +58 -0
  17. sknetwork/classification/propagation.py +144 -0
  18. sknetwork/classification/tests/__init__.py +1 -0
  19. sknetwork/classification/tests/test_API.py +30 -0
  20. sknetwork/classification/tests/test_diffusion.py +77 -0
  21. sknetwork/classification/tests/test_knn.py +23 -0
  22. sknetwork/classification/tests/test_metrics.py +53 -0
  23. sknetwork/classification/tests/test_pagerank.py +20 -0
  24. sknetwork/classification/tests/test_propagation.py +24 -0
  25. sknetwork/classification/vote.cpp +27593 -0
  26. sknetwork/classification/vote.cpython-312-x86_64-linux-gnu.so +0 -0
  27. sknetwork/classification/vote.pyx +56 -0
  28. sknetwork/clustering/__init__.py +8 -0
  29. sknetwork/clustering/base.py +168 -0
  30. sknetwork/clustering/kcenters.py +251 -0
  31. sknetwork/clustering/leiden.py +238 -0
  32. sknetwork/clustering/leiden_core.cpp +31928 -0
  33. sknetwork/clustering/leiden_core.cpython-312-x86_64-linux-gnu.so +0 -0
  34. sknetwork/clustering/leiden_core.pyx +124 -0
  35. sknetwork/clustering/louvain.py +282 -0
  36. sknetwork/clustering/louvain_core.cpp +31573 -0
  37. sknetwork/clustering/louvain_core.cpython-312-x86_64-linux-gnu.so +0 -0
  38. sknetwork/clustering/louvain_core.pyx +124 -0
  39. sknetwork/clustering/metrics.py +91 -0
  40. sknetwork/clustering/postprocess.py +66 -0
  41. sknetwork/clustering/propagation_clustering.py +100 -0
  42. sknetwork/clustering/tests/__init__.py +1 -0
  43. sknetwork/clustering/tests/test_API.py +38 -0
  44. sknetwork/clustering/tests/test_kcenters.py +60 -0
  45. sknetwork/clustering/tests/test_leiden.py +34 -0
  46. sknetwork/clustering/tests/test_louvain.py +135 -0
  47. sknetwork/clustering/tests/test_metrics.py +50 -0
  48. sknetwork/clustering/tests/test_postprocess.py +39 -0
  49. sknetwork/data/__init__.py +6 -0
  50. sknetwork/data/base.py +33 -0
  51. sknetwork/data/load.py +292 -0
  52. sknetwork/data/models.py +459 -0
  53. sknetwork/data/parse.py +644 -0
  54. sknetwork/data/test_graphs.py +93 -0
  55. sknetwork/data/tests/__init__.py +1 -0
  56. sknetwork/data/tests/test_API.py +30 -0
  57. sknetwork/data/tests/test_base.py +14 -0
  58. sknetwork/data/tests/test_load.py +61 -0
  59. sknetwork/data/tests/test_models.py +52 -0
  60. sknetwork/data/tests/test_parse.py +250 -0
  61. sknetwork/data/tests/test_test_graphs.py +29 -0
  62. sknetwork/data/tests/test_toy_graphs.py +68 -0
  63. sknetwork/data/timeout.py +38 -0
  64. sknetwork/data/toy_graphs.py +611 -0
  65. sknetwork/embedding/__init__.py +8 -0
  66. sknetwork/embedding/base.py +90 -0
  67. sknetwork/embedding/force_atlas.py +198 -0
  68. sknetwork/embedding/louvain_embedding.py +142 -0
  69. sknetwork/embedding/random_projection.py +131 -0
  70. sknetwork/embedding/spectral.py +137 -0
  71. sknetwork/embedding/spring.py +198 -0
  72. sknetwork/embedding/svd.py +351 -0
  73. sknetwork/embedding/tests/__init__.py +1 -0
  74. sknetwork/embedding/tests/test_API.py +49 -0
  75. sknetwork/embedding/tests/test_force_atlas.py +35 -0
  76. sknetwork/embedding/tests/test_louvain_embedding.py +33 -0
  77. sknetwork/embedding/tests/test_random_projection.py +28 -0
  78. sknetwork/embedding/tests/test_spectral.py +81 -0
  79. sknetwork/embedding/tests/test_spring.py +50 -0
  80. sknetwork/embedding/tests/test_svd.py +43 -0
  81. sknetwork/gnn/__init__.py +10 -0
  82. sknetwork/gnn/activation.py +117 -0
  83. sknetwork/gnn/base.py +181 -0
  84. sknetwork/gnn/base_activation.py +90 -0
  85. sknetwork/gnn/base_layer.py +109 -0
  86. sknetwork/gnn/gnn_classifier.py +305 -0
  87. sknetwork/gnn/layer.py +153 -0
  88. sknetwork/gnn/loss.py +180 -0
  89. sknetwork/gnn/neighbor_sampler.py +65 -0
  90. sknetwork/gnn/optimizer.py +164 -0
  91. sknetwork/gnn/tests/__init__.py +1 -0
  92. sknetwork/gnn/tests/test_activation.py +56 -0
  93. sknetwork/gnn/tests/test_base.py +75 -0
  94. sknetwork/gnn/tests/test_base_layer.py +37 -0
  95. sknetwork/gnn/tests/test_gnn_classifier.py +130 -0
  96. sknetwork/gnn/tests/test_layers.py +80 -0
  97. sknetwork/gnn/tests/test_loss.py +33 -0
  98. sknetwork/gnn/tests/test_neigh_sampler.py +23 -0
  99. sknetwork/gnn/tests/test_optimizer.py +43 -0
  100. sknetwork/gnn/tests/test_utils.py +41 -0
  101. sknetwork/gnn/utils.py +127 -0
  102. sknetwork/hierarchy/__init__.py +6 -0
  103. sknetwork/hierarchy/base.py +90 -0
  104. sknetwork/hierarchy/louvain_hierarchy.py +260 -0
  105. sknetwork/hierarchy/metrics.py +234 -0
  106. sknetwork/hierarchy/paris.cpp +37877 -0
  107. sknetwork/hierarchy/paris.cpython-312-x86_64-linux-gnu.so +0 -0
  108. sknetwork/hierarchy/paris.pyx +310 -0
  109. sknetwork/hierarchy/postprocess.py +350 -0
  110. sknetwork/hierarchy/tests/__init__.py +1 -0
  111. sknetwork/hierarchy/tests/test_API.py +24 -0
  112. sknetwork/hierarchy/tests/test_algos.py +34 -0
  113. sknetwork/hierarchy/tests/test_metrics.py +62 -0
  114. sknetwork/hierarchy/tests/test_postprocess.py +57 -0
  115. sknetwork/linalg/__init__.py +9 -0
  116. sknetwork/linalg/basics.py +37 -0
  117. sknetwork/linalg/diteration.cpp +27409 -0
  118. sknetwork/linalg/diteration.cpython-312-x86_64-linux-gnu.so +0 -0
  119. sknetwork/linalg/diteration.pyx +47 -0
  120. sknetwork/linalg/eig_solver.py +93 -0
  121. sknetwork/linalg/laplacian.py +15 -0
  122. sknetwork/linalg/normalizer.py +86 -0
  123. sknetwork/linalg/operators.py +225 -0
  124. sknetwork/linalg/polynome.py +76 -0
  125. sknetwork/linalg/ppr_solver.py +170 -0
  126. sknetwork/linalg/push.cpp +31081 -0
  127. sknetwork/linalg/push.cpython-312-x86_64-linux-gnu.so +0 -0
  128. sknetwork/linalg/push.pyx +71 -0
  129. sknetwork/linalg/sparse_lowrank.py +142 -0
  130. sknetwork/linalg/svd_solver.py +91 -0
  131. sknetwork/linalg/tests/__init__.py +1 -0
  132. sknetwork/linalg/tests/test_eig.py +44 -0
  133. sknetwork/linalg/tests/test_laplacian.py +18 -0
  134. sknetwork/linalg/tests/test_normalization.py +34 -0
  135. sknetwork/linalg/tests/test_operators.py +66 -0
  136. sknetwork/linalg/tests/test_polynome.py +38 -0
  137. sknetwork/linalg/tests/test_ppr.py +50 -0
  138. sknetwork/linalg/tests/test_sparse_lowrank.py +61 -0
  139. sknetwork/linalg/tests/test_svd.py +38 -0
  140. sknetwork/linkpred/__init__.py +2 -0
  141. sknetwork/linkpred/base.py +46 -0
  142. sknetwork/linkpred/nn.py +126 -0
  143. sknetwork/linkpred/tests/__init__.py +1 -0
  144. sknetwork/linkpred/tests/test_nn.py +26 -0
  145. sknetwork/log.py +19 -0
  146. sknetwork/path/__init__.py +5 -0
  147. sknetwork/path/dag.py +54 -0
  148. sknetwork/path/distances.py +98 -0
  149. sknetwork/path/search.py +31 -0
  150. sknetwork/path/shortest_path.py +61 -0
  151. sknetwork/path/tests/__init__.py +1 -0
  152. sknetwork/path/tests/test_dag.py +37 -0
  153. sknetwork/path/tests/test_distances.py +62 -0
  154. sknetwork/path/tests/test_search.py +40 -0
  155. sknetwork/path/tests/test_shortest_path.py +40 -0
  156. sknetwork/ranking/__init__.py +8 -0
  157. sknetwork/ranking/base.py +57 -0
  158. sknetwork/ranking/betweenness.cpp +9716 -0
  159. sknetwork/ranking/betweenness.cpython-312-x86_64-linux-gnu.so +0 -0
  160. sknetwork/ranking/betweenness.pyx +97 -0
  161. sknetwork/ranking/closeness.py +92 -0
  162. sknetwork/ranking/hits.py +90 -0
  163. sknetwork/ranking/katz.py +79 -0
  164. sknetwork/ranking/pagerank.py +106 -0
  165. sknetwork/ranking/postprocess.py +37 -0
  166. sknetwork/ranking/tests/__init__.py +1 -0
  167. sknetwork/ranking/tests/test_API.py +32 -0
  168. sknetwork/ranking/tests/test_betweenness.py +38 -0
  169. sknetwork/ranking/tests/test_closeness.py +30 -0
  170. sknetwork/ranking/tests/test_hits.py +20 -0
  171. sknetwork/ranking/tests/test_pagerank.py +62 -0
  172. sknetwork/ranking/tests/test_postprocess.py +26 -0
  173. sknetwork/regression/__init__.py +4 -0
  174. sknetwork/regression/base.py +57 -0
  175. sknetwork/regression/diffusion.py +204 -0
  176. sknetwork/regression/tests/__init__.py +1 -0
  177. sknetwork/regression/tests/test_API.py +32 -0
  178. sknetwork/regression/tests/test_diffusion.py +56 -0
  179. sknetwork/sknetwork.py +3 -0
  180. sknetwork/test_base.py +35 -0
  181. sknetwork/test_log.py +15 -0
  182. sknetwork/topology/__init__.py +8 -0
  183. sknetwork/topology/cliques.cpp +32574 -0
  184. sknetwork/topology/cliques.cpython-312-x86_64-linux-gnu.so +0 -0
  185. sknetwork/topology/cliques.pyx +149 -0
  186. sknetwork/topology/core.cpp +30660 -0
  187. sknetwork/topology/core.cpython-312-x86_64-linux-gnu.so +0 -0
  188. sknetwork/topology/core.pyx +90 -0
  189. sknetwork/topology/cycles.py +243 -0
  190. sknetwork/topology/minheap.cpp +27341 -0
  191. sknetwork/topology/minheap.cpython-312-x86_64-linux-gnu.so +0 -0
  192. sknetwork/topology/minheap.pxd +20 -0
  193. sknetwork/topology/minheap.pyx +109 -0
  194. sknetwork/topology/structure.py +194 -0
  195. sknetwork/topology/tests/__init__.py +1 -0
  196. sknetwork/topology/tests/test_cliques.py +28 -0
  197. sknetwork/topology/tests/test_core.py +19 -0
  198. sknetwork/topology/tests/test_cycles.py +65 -0
  199. sknetwork/topology/tests/test_structure.py +85 -0
  200. sknetwork/topology/tests/test_triangles.py +38 -0
  201. sknetwork/topology/tests/test_wl.py +72 -0
  202. sknetwork/topology/triangles.cpp +8903 -0
  203. sknetwork/topology/triangles.cpython-312-x86_64-linux-gnu.so +0 -0
  204. sknetwork/topology/triangles.pyx +151 -0
  205. sknetwork/topology/weisfeiler_lehman.py +133 -0
  206. sknetwork/topology/weisfeiler_lehman_core.cpp +27644 -0
  207. sknetwork/topology/weisfeiler_lehman_core.cpython-312-x86_64-linux-gnu.so +0 -0
  208. sknetwork/topology/weisfeiler_lehman_core.pyx +114 -0
  209. sknetwork/utils/__init__.py +7 -0
  210. sknetwork/utils/check.py +355 -0
  211. sknetwork/utils/format.py +221 -0
  212. sknetwork/utils/membership.py +82 -0
  213. sknetwork/utils/neighbors.py +115 -0
  214. sknetwork/utils/tests/__init__.py +1 -0
  215. sknetwork/utils/tests/test_check.py +190 -0
  216. sknetwork/utils/tests/test_format.py +63 -0
  217. sknetwork/utils/tests/test_membership.py +24 -0
  218. sknetwork/utils/tests/test_neighbors.py +41 -0
  219. sknetwork/utils/tests/test_tfidf.py +18 -0
  220. sknetwork/utils/tests/test_values.py +66 -0
  221. sknetwork/utils/tfidf.py +37 -0
  222. sknetwork/utils/values.py +76 -0
  223. sknetwork/visualization/__init__.py +4 -0
  224. sknetwork/visualization/colors.py +34 -0
  225. sknetwork/visualization/dendrograms.py +277 -0
  226. sknetwork/visualization/graphs.py +1039 -0
  227. sknetwork/visualization/tests/__init__.py +1 -0
  228. sknetwork/visualization/tests/test_dendrograms.py +53 -0
  229. sknetwork/visualization/tests/test_graphs.py +176 -0
@@ -0,0 +1,50 @@
1
+ # -*- coding: utf-8 -*-
2
+ # tests for metrics.py
3
+ """Tests for clustering metrics"""
4
+ import unittest
5
+
6
+ import numpy as np
7
+
8
+ from sknetwork.clustering import get_modularity, Louvain
9
+ from sknetwork.data import star_wars, karate_club
10
+ from sknetwork.data.test_graphs import test_graph
11
+
12
+
13
class TestClusteringMetrics(unittest.TestCase):
    """Tests for the clustering quality metric ``get_modularity``."""

    def setUp(self):
        """Prepare a test graph together with two labelings of its nodes."""
        self.adjacency = test_graph()
        n_nodes = self.adjacency.shape[0]
        # Two clusters: node 0 gets label 1, every other node gets label 0.
        two_clusters = np.zeros(n_nodes)
        two_clusters[0] = 1
        self.labels = two_clusters.astype(int)
        # Trivial labeling: all nodes share label 0.
        self.unique_cluster = np.zeros(n_nodes, dtype=int)

    def test_api(self):
        """Check the ``return_all`` contract, the single-cluster value and label-size validation."""
        for metric in (get_modularity,):
            _, fit, div = metric(self.adjacency, self.labels, return_all=True)
            score = metric(self.adjacency, self.labels, return_all=False)
            # The scalar score must equal the difference of the two components.
            self.assertAlmostEqual(fit - div, score)
            self.assertAlmostEqual(metric(self.adjacency, self.unique_cluster), 0.)

            # A label vector shorter than the number of nodes must be rejected.
            with self.assertRaises(ValueError):
                metric(self.adjacency, self.labels[:3])

    def test_modularity(self):
        """Modularity of the Louvain clustering of the karate club graph is about 0.42."""
        graph = karate_club()
        predicted = Louvain().fit_predict(graph)
        self.assertAlmostEqual(get_modularity(graph, predicted), 0.42, places=2)

    def test_bimodularity(self):
        """Modularity of a bipartite graph needs row and column labels of matching sizes."""
        biadjacency = star_wars()
        labels_row = np.array([0, 0, 1, 1])
        labels_col = np.array([0, 1, 0])
        score = get_modularity(biadjacency, labels_row, labels_col)
        self.assertAlmostEqual(score, 0.12, places=2)

        # Missing column labels, then truncated row labels, then truncated column labels.
        with self.assertRaises(ValueError):
            get_modularity(biadjacency, labels_row)
        with self.assertRaises(ValueError):
            get_modularity(biadjacency, labels_row[:2], labels_col)
        with self.assertRaises(ValueError):
            get_modularity(biadjacency, labels_row, labels_col[:2])
@@ -0,0 +1,39 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """Tests for clustering post-processing"""
4
+ import unittest
5
+
6
+ import numpy as np
7
+
8
+ from sknetwork.data import house, star_wars
9
+ from sknetwork.clustering.postprocess import reindex_labels, aggregate_graph
10
+
11
+
12
class TestClusteringPostProcessing(unittest.TestCase):
    """Tests for the clustering post-processing helpers ``reindex_labels`` and ``aggregate_graph``."""

    def test_reindex_clusters(self):
        """Both label vectors must be re-indexed to the same expected array."""
        truth = np.array([1, 1, 2, 0, 0, 0])

        candidates = (
            np.array([0, 0, 1, 2, 2, 2]),
            # Same partition, but with a gap in the label values.
            np.array([0, 0, 5, 2, 2, 2]),
        )
        for labels in candidates:
            with self.subTest(labels=labels.tolist()):
                output = reindex_labels(labels)
                self.assertTrue(np.array_equal(truth, output))

    def test_aggregate_graph(self):
        """The aggregate graph has the shapes asserted below for each labeling."""
        adjacency = house()
        labels = np.array([0, 0, 1, 1, 2])
        aggregate = aggregate_graph(adjacency, labels)
        self.assertEqual(aggregate.shape, (3, 3))

        biadjacency = star_wars()
        labels = np.array([0, 0, 1, 2])
        labels_row = np.array([0, 1, 3, -1])
        labels_col = np.array([0, 0, 1])
        aggregate = aggregate_graph(biadjacency, labels=labels, labels_col=labels_col)
        self.assertEqual(aggregate.shape, (3, 2))
        aggregate = aggregate_graph(biadjacency, labels_row=labels_row, labels_col=labels_col)
        self.assertEqual(aggregate.shape, (4, 2))
@@ -0,0 +1,6 @@
1
+ """data module"""
2
+ from sknetwork.data.base import *
3
+ from sknetwork.data.load import *
4
+ from sknetwork.data.models import *
5
+ from sknetwork.data.parse import from_edge_list, from_adjacency_list, from_csv, from_graphml
6
+ from sknetwork.data.toy_graphs import *
sknetwork/data/base.py ADDED
@@ -0,0 +1,33 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created in May 2023
5
+ @author: Thomas Bonald <bonald@enst.fr>
6
+ """
7
+
8
+
9
class Dataset(dict):
    """Container object for datasets.

    Dictionary-like object that exposes its keys as attributes: reading,
    assigning and deleting an attribute act on the dictionary entry of the
    same name.

    >>> dataset = Dataset(name='dataset')
    >>> dataset['name']
    'dataset'
    >>> dataset.name
    'dataset'
    >>> del dataset.name
    >>> 'name' in dataset
    False
    """
    def __init__(self, **kwargs):
        super().__init__(kwargs)

    def __setattr__(self, key, value):
        # Attributes are stored as dictionary entries, never in the instance __dict__.
        self[key] = value

    def __getattr__(self, key):
        # Only called when regular attribute lookup fails.
        try:
            return self[key]
        except KeyError:
            # Hide the internal KeyError: callers asked for an attribute, not a key.
            raise AttributeError(key) from None

    def __delattr__(self, key):
        # Mirror of __setattr__: attributes live in the dictionary, so delete the entry.
        try:
            del self[key]
        except KeyError:
            raise AttributeError(key) from None


# alias for Dataset
Bunch = Dataset
33
+
sknetwork/data/load.py ADDED
@@ -0,0 +1,292 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created in November 2019
5
+ @author: Quentin Lutz <qlutz@enst.fr>
6
+ """
7
+
8
+ import pickle
9
+ import shutil
10
+ import tarfile
11
+ from os import environ, makedirs, remove, listdir
12
+ from os.path import abspath, commonprefix, exists, expanduser, isfile, join
13
+ from pathlib import Path
14
+ from typing import Optional, Union
15
+ from urllib.error import HTTPError, URLError
16
+ from urllib.request import urlretrieve
17
+
18
+ import numpy as np
19
+ from scipy import sparse
20
+
21
+ from sknetwork.data.parse import from_csv, load_labels, load_header, load_metadata
22
+ from sknetwork.data.base import Dataset
23
+ from sknetwork.utils.check import is_square
24
+ from sknetwork.log import Log
25
+
26
+ NETSET_URL = 'https://netset.telecom-paris.fr'
27
+
28
+
29
def is_within_directory(directory, target):
    """Check whether ``target`` is ``directory`` itself or is located inside it.

    Both paths are made absolute and normalised (``..`` segments are collapsed)
    before the comparison. Symbolic links are not resolved.

    Parameters
    ----------
    directory : str or path-like
        Reference directory.
    target : str or path-like
        Path to test.

    Returns
    -------
    bool
        ``True`` if ``target`` is ``directory`` or one of its descendants.
    """
    # Local import: ``commonpath`` is not among the module-level imports.
    from os.path import commonpath

    abs_directory = abspath(directory)
    abs_target = abspath(target)
    try:
        # Compare whole path components. A character-wise prefix would wrongly
        # accept siblings such as '/data_evil' for the directory '/data'.
        return commonpath([abs_directory, abs_target]) == abs_directory
    except ValueError:
        # No common ancestor at all (e.g. different drives on Windows).
        return False
35
+
36
+
37
def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
    """Extract a tar archive after checking that no member escapes ``path``.

    Parameters
    ----------
    tar : :class:`tarfile.TarFile`
        Opened archive.
    path : str or path-like
        Destination folder.
    members : iterable, optional
        Subset of members to extract, passed on to ``extractall`` (all members
        if ``None``). The path check always covers every member of the archive.
    numeric_owner : bool
        Passed on to ``extractall``.

    Raises
    ------
    Exception
        If the name of a member points outside ``path``.
    """
    for member in tar.getmembers():
        member_path = join(path, member.name)
        if not is_within_directory(path, member_path):
            raise Exception("Attempted path traversal in tar file.")

    extract_options = {'numeric_owner': numeric_owner}
    # The name check above does not cover link targets or special files. Where the
    # interpreter provides extraction filters, let the 'data' filter reject those.
    # This also avoids the DeprecationWarning for a missing filter on Python 3.12+.
    if hasattr(tarfile, 'data_filter'):
        extract_options['filter'] = 'data'
    tar.extractall(path, members, **extract_options)
44
+
45
+
46
def get_data_home(data_home: Optional[Union[str, Path]] = None) -> Path:
    """Resolve the dataset storage folder and create it if it does not exist.

    Parameters
    ----------
    data_home : str or :class:`pathlib.Path`
        The folder to be used for dataset storage. If ``None``, the value of the
        environment variable ``SCIKIT_NETWORK_DATA`` is used, falling back to
        ``~/scikit_network_data``.

    Returns
    -------
    :class:`pathlib.Path`
        Path to the storage folder, with a leading ``~`` expanded.
    """
    if data_home is None:
        fallback = join('~', 'scikit_network_data')
        data_home = environ.get('SCIKIT_NETWORK_DATA', fallback)
    resolved = expanduser(data_home)
    if not exists(resolved):
        makedirs(resolved)
    return Path(resolved)
60
+
61
+
62
def clear_data_home(data_home: Optional[Union[str, Path]] = None):
    """Delete the storage folder together with all of its content.

    Parameters
    ----------
    data_home : str or :class:`pathlib.Path`
        The folder used for dataset storage; resolved by ``get_data_home``.
    """
    shutil.rmtree(get_data_home(data_home))
72
+
73
+
74
def clean_data_home(data_home: Optional[Union[str, Path]] = None):
    """Remove the regular files located directly in the storage folder.

    Sub-folders and their content are left untouched.

    Parameters
    ----------
    data_home : str or :class:`pathlib.Path`
        The folder used for dataset storage; resolved by ``get_data_home``.
    """
    root = get_data_home(data_home)
    for entry in listdir(root):
        entry_path = join(root, entry)
        if isfile(entry_path):
            remove(entry_path)
86
+
87
+
88
def load_netset(name: Optional[str] = None, data_home: Optional[Union[str, Path]] = None,
                verbose: bool = True) -> Optional[Dataset]:
    """Load a dataset from the `NetSet collection
    <https://netset.telecom-paris.fr/>`_.

    The archive is downloaded and unpacked into ``<data_home>/netset/<name>``
    unless that folder already exists; the files of that folder are then parsed.

    Parameters
    ----------
    name : str
        Name of the dataset (converted to lower case). Examples include 'openflights', 'cinema'
        and 'wikivitals'.
    data_home : str or :class:`pathlib.Path`
        Folder to be used for dataset storage.
        This folder must be empty or contain other folders (datasets); files will be removed.
    verbose : bool
        Enable verbosity.

    Returns
    -------
    dataset : :class:`Dataset`
        Returned dataset, or ``None`` (after printing a hint) if ``name`` is ``None``.

    Raises
    ------
    ValueError
        If the download fails with an HTTP error.
    RuntimeError
        If the connection is reset during the download.
    """
    dataset = Dataset()
    dataset_folder = NETSET_URL + '/datasets/'
    folder_npz = NETSET_URL + '/datasets_npz/'

    logger = Log(verbose)

    if name is None:
        print("Please specify the dataset (e.g., 'wikivitals').\n" +
              f"Complete list available here: <{dataset_folder}>.")
        return None

    name = name.lower()
    data_home = get_data_home(data_home)
    data_netset = data_home / 'netset'
    if not data_netset.exists():
        # First use of the 'netset' sub-folder: drop loose files from the storage folder.
        clean_data_home(data_home)
        makedirs(data_netset)

    # remove previous dataset if not in the netset folder
    direct_path = data_home / name
    if direct_path.exists():
        shutil.rmtree(direct_path)

    data_path = data_netset / name
    if not data_path.exists():
        archive_name = name + '_npz.tar.gz'
        archive_path = data_netset / archive_name
        try:
            logger.print_log('Downloading', name, 'from NetSet...')
            urlretrieve(folder_npz + archive_name, archive_path)
        except HTTPError:
            message = ('Invalid dataset: ' + name + '.'
                       + "\nAvailable datasets include 'openflights' and 'wikivitals'."
                       + f"\nSee <{NETSET_URL}>")
            raise ValueError(message)
        except ConnectionResetError:  # pragma: no cover
            raise RuntimeError("Could not reach Netset.")
        with tarfile.open(archive_path, 'r:gz') as tar_ref:
            logger.print_log('Unpacking archive...')
            safe_extract(tar_ref, data_path)

    files = listdir(data_path)
    logger.print_log('Parsing files...')
    for file in files:
        parts = file.split('.')
        # Only names of the form '<attribute>.<extension>' are loaded.
        if len(parts) != 2:
            continue
        attribute, extension = parts
        # NOTE(review): 'npy' (allow_pickle=True) and 'p' files are unpickled, which can
        # execute arbitrary code; this relies on the downloaded archive being trusted.
        if extension == 'npz':
            dataset[attribute] = sparse.load_npz(data_path / file)
        elif extension == 'npy':
            dataset[attribute] = np.load(data_path / file, allow_pickle=True)
        elif extension == 'p':
            with open(data_path / file, 'rb') as f:
                dataset[attribute] = pickle.load(f)

    # Remove loose files (e.g. the downloaded archive) from the netset folder.
    clean_data_home(data_netset)
    logger.print_log('Done.')
    return dataset
164
+
165
+
166
def save_to_numpy_bundle(data: Dataset, bundle_name: str, data_home: Optional[Union[str, Path]] = None):
    """Write each attribute of a dataset to its own file in ``<data_home>/<bundle_name>``.

    CSR matrices are stored with ``scipy.sparse.save_npz``, NumPy arrays with ``numpy.save``
    and any other object with ``pickle`` (file ``<attribute>.p``).

    Parameters
    ----------
    data : Dataset
        Data to save.
    bundle_name : str
        Name to be used for the bundle folder.
    data_home : str or :class:`pathlib.Path`
        Folder to be used for dataset storage.
    """
    bundle_path = get_data_home(data_home) / bundle_name
    makedirs(bundle_path, exist_ok=True)
    for attribute in data:
        value = data[attribute]
        kind = type(value)
        # Exact type match on purpose: instances of subclasses go through pickle.
        if kind == sparse.csr_matrix:
            sparse.save_npz(bundle_path / attribute, value)
        elif kind == np.ndarray:
            np.save(bundle_path / attribute, value)
        else:
            with open(bundle_path / (attribute + '.p'), 'wb') as file:
                pickle.dump(value, file)
189
+
190
+
191
def load_from_numpy_bundle(bundle_name: str, data_home: Optional[Union[str, Path]] = None):
    """Load a dataset from a collection of Numpy and Pickle files (inverse function of ``save_to_numpy_bundle``).

    Only files named ``<attribute>.<extension>`` with extension ``npz``, ``npy`` or ``p``
    are read; any other file of the folder is ignored.

    Parameters
    ----------
    bundle_name : str
        Name of the bundle folder.
    data_home : str or :class:`pathlib.Path`
        Folder used for dataset storage.

    Returns
    -------
    data : Dataset
        Data.

    Raises
    ------
    FileNotFoundError
        If the bundle folder does not exist.
    """
    bundle_path = get_data_home(data_home) / bundle_name
    if not bundle_path.exists():
        raise FileNotFoundError('No bundle at ' + str(bundle_path))

    entries = listdir(bundle_path)
    data = Dataset()
    for entry in entries:
        parts = entry.split('.')
        if len(parts) != 2:
            continue
        attribute, extension = parts
        if extension == 'npz':
            data[attribute] = sparse.load_npz(bundle_path / entry)
        elif extension == 'npy':
            data[attribute] = np.load(bundle_path / entry, allow_pickle=True)
        elif extension == 'p':
            with open(bundle_path / entry, 'rb') as f:
                data[attribute] = pickle.load(f)
    return data
224
+
225
+
226
def save(folder: Union[str, Path], data: Union[sparse.csr_matrix, Dataset]):
    """Save a dataset or a CSR matrix to a collection of Numpy and Pickle files for faster subsequent loads.

    A relative ``folder`` is taken relative to the current directory and a leading ``~`` is
    expanded. An existing folder of the same name is removed first. A CSR matrix is wrapped
    in a dataset, as attribute ``adjacency`` if it is square and ``biadjacency`` otherwise.
    Supported attribute types include sparse matrices, NumPy arrays, strings and objects Dataset.

    Parameters
    ----------
    folder : str or :class:`pathlib.Path`
        Name of the bundle folder.
    data : Union[sparse.csr_matrix, Dataset]
        Data to save.

    Example
    -------
    >>> from sknetwork.data import save
    >>> dataset = Dataset()
    >>> dataset.adjacency = sparse.csr_matrix(np.random.random((3, 3)) < 0.5)
    >>> dataset.names = np.array(['a', 'b', 'c'])
    >>> save('dataset', dataset)
    >>> 'dataset' in listdir('.')
    True
    """
    target = Path(folder).expanduser()
    if target.exists():
        shutil.rmtree(target)

    if isinstance(data, sparse.csr_matrix):
        wrapped = Dataset()
        if is_square(data):
            wrapped.adjacency = data
        else:
            wrapped.biadjacency = data
        data = wrapped

    # Joining an absolute folder onto the root yields the folder itself.
    root = '/' if target.is_absolute() else '.'
    save_to_numpy_bundle(data, target, root)
262
+
263
+
264
def load(folder: Union[str, Path]):
    """Load a dataset from a previously created bundle (inverse function of ``save``).

    A relative ``folder`` is taken relative to the current directory. A leading ``~`` is
    expanded, as in ``save``, so that the same argument designates the same folder in
    both functions.

    Parameters
    ----------
    folder : str or :class:`pathlib.Path`
        Name of the bundle folder.

    Returns
    -------
    data : Dataset
        Data.

    Example
    -------
    >>> from sknetwork.data import save
    >>> dataset = Dataset()
    >>> dataset.adjacency = sparse.csr_matrix(np.random.random((3, 3)) < 0.5)
    >>> dataset.names = np.array(['a', 'b', 'c'])
    >>> save('dataset', dataset)
    >>> dataset = load('dataset')
    >>> print(dataset.names)
    ['a' 'b' 'c']
    """
    # Expand '~' exactly like ``save`` does; otherwise load('~/x') would look for a
    # literal './~/x' folder and miss the bundle written by save('~/x', ...).
    folder = Path(folder).expanduser()
    # Joining an absolute folder onto the root yields the folder itself.
    root = '/' if folder.is_absolute() else '.'
    return load_from_numpy_bundle(folder, root)