scikit-network 0.33.3__cp312-cp312-macosx_10_13_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-network might be problematic. Click here for more details.

Files changed (228)
  1. scikit_network-0.33.3.dist-info/METADATA +122 -0
  2. scikit_network-0.33.3.dist-info/RECORD +228 -0
  3. scikit_network-0.33.3.dist-info/WHEEL +6 -0
  4. scikit_network-0.33.3.dist-info/licenses/AUTHORS.rst +43 -0
  5. scikit_network-0.33.3.dist-info/licenses/LICENSE +34 -0
  6. scikit_network-0.33.3.dist-info/top_level.txt +1 -0
  7. sknetwork/__init__.py +21 -0
  8. sknetwork/base.py +67 -0
  9. sknetwork/classification/__init__.py +8 -0
  10. sknetwork/classification/base.py +142 -0
  11. sknetwork/classification/base_rank.py +133 -0
  12. sknetwork/classification/diffusion.py +134 -0
  13. sknetwork/classification/knn.py +139 -0
  14. sknetwork/classification/metrics.py +205 -0
  15. sknetwork/classification/pagerank.py +66 -0
  16. sknetwork/classification/propagation.py +152 -0
  17. sknetwork/classification/tests/__init__.py +1 -0
  18. sknetwork/classification/tests/test_API.py +30 -0
  19. sknetwork/classification/tests/test_diffusion.py +77 -0
  20. sknetwork/classification/tests/test_knn.py +23 -0
  21. sknetwork/classification/tests/test_metrics.py +53 -0
  22. sknetwork/classification/tests/test_pagerank.py +20 -0
  23. sknetwork/classification/tests/test_propagation.py +24 -0
  24. sknetwork/classification/vote.cpp +27581 -0
  25. sknetwork/classification/vote.cpython-312-darwin.so +0 -0
  26. sknetwork/classification/vote.pyx +56 -0
  27. sknetwork/clustering/__init__.py +8 -0
  28. sknetwork/clustering/base.py +172 -0
  29. sknetwork/clustering/kcenters.py +253 -0
  30. sknetwork/clustering/leiden.py +242 -0
  31. sknetwork/clustering/leiden_core.cpp +31572 -0
  32. sknetwork/clustering/leiden_core.cpython-312-darwin.so +0 -0
  33. sknetwork/clustering/leiden_core.pyx +124 -0
  34. sknetwork/clustering/louvain.py +286 -0
  35. sknetwork/clustering/louvain_core.cpp +31217 -0
  36. sknetwork/clustering/louvain_core.cpython-312-darwin.so +0 -0
  37. sknetwork/clustering/louvain_core.pyx +124 -0
  38. sknetwork/clustering/metrics.py +91 -0
  39. sknetwork/clustering/postprocess.py +66 -0
  40. sknetwork/clustering/propagation_clustering.py +104 -0
  41. sknetwork/clustering/tests/__init__.py +1 -0
  42. sknetwork/clustering/tests/test_API.py +38 -0
  43. sknetwork/clustering/tests/test_kcenters.py +60 -0
  44. sknetwork/clustering/tests/test_leiden.py +34 -0
  45. sknetwork/clustering/tests/test_louvain.py +135 -0
  46. sknetwork/clustering/tests/test_metrics.py +50 -0
  47. sknetwork/clustering/tests/test_postprocess.py +39 -0
  48. sknetwork/data/__init__.py +6 -0
  49. sknetwork/data/base.py +33 -0
  50. sknetwork/data/load.py +406 -0
  51. sknetwork/data/models.py +459 -0
  52. sknetwork/data/parse.py +644 -0
  53. sknetwork/data/test_graphs.py +84 -0
  54. sknetwork/data/tests/__init__.py +1 -0
  55. sknetwork/data/tests/test_API.py +30 -0
  56. sknetwork/data/tests/test_base.py +14 -0
  57. sknetwork/data/tests/test_load.py +95 -0
  58. sknetwork/data/tests/test_models.py +52 -0
  59. sknetwork/data/tests/test_parse.py +250 -0
  60. sknetwork/data/tests/test_test_graphs.py +29 -0
  61. sknetwork/data/tests/test_toy_graphs.py +68 -0
  62. sknetwork/data/timeout.py +38 -0
  63. sknetwork/data/toy_graphs.py +611 -0
  64. sknetwork/embedding/__init__.py +8 -0
  65. sknetwork/embedding/base.py +94 -0
  66. sknetwork/embedding/force_atlas.py +198 -0
  67. sknetwork/embedding/louvain_embedding.py +148 -0
  68. sknetwork/embedding/random_projection.py +135 -0
  69. sknetwork/embedding/spectral.py +141 -0
  70. sknetwork/embedding/spring.py +198 -0
  71. sknetwork/embedding/svd.py +359 -0
  72. sknetwork/embedding/tests/__init__.py +1 -0
  73. sknetwork/embedding/tests/test_API.py +49 -0
  74. sknetwork/embedding/tests/test_force_atlas.py +35 -0
  75. sknetwork/embedding/tests/test_louvain_embedding.py +33 -0
  76. sknetwork/embedding/tests/test_random_projection.py +28 -0
  77. sknetwork/embedding/tests/test_spectral.py +81 -0
  78. sknetwork/embedding/tests/test_spring.py +50 -0
  79. sknetwork/embedding/tests/test_svd.py +43 -0
  80. sknetwork/gnn/__init__.py +10 -0
  81. sknetwork/gnn/activation.py +117 -0
  82. sknetwork/gnn/base.py +181 -0
  83. sknetwork/gnn/base_activation.py +90 -0
  84. sknetwork/gnn/base_layer.py +109 -0
  85. sknetwork/gnn/gnn_classifier.py +305 -0
  86. sknetwork/gnn/layer.py +153 -0
  87. sknetwork/gnn/loss.py +180 -0
  88. sknetwork/gnn/neighbor_sampler.py +65 -0
  89. sknetwork/gnn/optimizer.py +164 -0
  90. sknetwork/gnn/tests/__init__.py +1 -0
  91. sknetwork/gnn/tests/test_activation.py +56 -0
  92. sknetwork/gnn/tests/test_base.py +75 -0
  93. sknetwork/gnn/tests/test_base_layer.py +37 -0
  94. sknetwork/gnn/tests/test_gnn_classifier.py +130 -0
  95. sknetwork/gnn/tests/test_layers.py +80 -0
  96. sknetwork/gnn/tests/test_loss.py +33 -0
  97. sknetwork/gnn/tests/test_neigh_sampler.py +23 -0
  98. sknetwork/gnn/tests/test_optimizer.py +43 -0
  99. sknetwork/gnn/tests/test_utils.py +41 -0
  100. sknetwork/gnn/utils.py +127 -0
  101. sknetwork/hierarchy/__init__.py +6 -0
  102. sknetwork/hierarchy/base.py +96 -0
  103. sknetwork/hierarchy/louvain_hierarchy.py +272 -0
  104. sknetwork/hierarchy/metrics.py +234 -0
  105. sknetwork/hierarchy/paris.cpp +37865 -0
  106. sknetwork/hierarchy/paris.cpython-312-darwin.so +0 -0
  107. sknetwork/hierarchy/paris.pyx +316 -0
  108. sknetwork/hierarchy/postprocess.py +350 -0
  109. sknetwork/hierarchy/tests/__init__.py +1 -0
  110. sknetwork/hierarchy/tests/test_API.py +24 -0
  111. sknetwork/hierarchy/tests/test_algos.py +34 -0
  112. sknetwork/hierarchy/tests/test_metrics.py +62 -0
  113. sknetwork/hierarchy/tests/test_postprocess.py +57 -0
  114. sknetwork/linalg/__init__.py +9 -0
  115. sknetwork/linalg/basics.py +37 -0
  116. sknetwork/linalg/diteration.cpp +27397 -0
  117. sknetwork/linalg/diteration.cpython-312-darwin.so +0 -0
  118. sknetwork/linalg/diteration.pyx +47 -0
  119. sknetwork/linalg/eig_solver.py +93 -0
  120. sknetwork/linalg/laplacian.py +15 -0
  121. sknetwork/linalg/normalizer.py +86 -0
  122. sknetwork/linalg/operators.py +225 -0
  123. sknetwork/linalg/polynome.py +76 -0
  124. sknetwork/linalg/ppr_solver.py +170 -0
  125. sknetwork/linalg/push.cpp +31069 -0
  126. sknetwork/linalg/push.cpython-312-darwin.so +0 -0
  127. sknetwork/linalg/push.pyx +71 -0
  128. sknetwork/linalg/sparse_lowrank.py +142 -0
  129. sknetwork/linalg/svd_solver.py +91 -0
  130. sknetwork/linalg/tests/__init__.py +1 -0
  131. sknetwork/linalg/tests/test_eig.py +44 -0
  132. sknetwork/linalg/tests/test_laplacian.py +18 -0
  133. sknetwork/linalg/tests/test_normalization.py +34 -0
  134. sknetwork/linalg/tests/test_operators.py +66 -0
  135. sknetwork/linalg/tests/test_polynome.py +38 -0
  136. sknetwork/linalg/tests/test_ppr.py +50 -0
  137. sknetwork/linalg/tests/test_sparse_lowrank.py +61 -0
  138. sknetwork/linalg/tests/test_svd.py +38 -0
  139. sknetwork/linkpred/__init__.py +2 -0
  140. sknetwork/linkpred/base.py +46 -0
  141. sknetwork/linkpred/nn.py +126 -0
  142. sknetwork/linkpred/tests/__init__.py +1 -0
  143. sknetwork/linkpred/tests/test_nn.py +27 -0
  144. sknetwork/log.py +19 -0
  145. sknetwork/path/__init__.py +5 -0
  146. sknetwork/path/dag.py +54 -0
  147. sknetwork/path/distances.py +98 -0
  148. sknetwork/path/search.py +31 -0
  149. sknetwork/path/shortest_path.py +61 -0
  150. sknetwork/path/tests/__init__.py +1 -0
  151. sknetwork/path/tests/test_dag.py +37 -0
  152. sknetwork/path/tests/test_distances.py +62 -0
  153. sknetwork/path/tests/test_search.py +40 -0
  154. sknetwork/path/tests/test_shortest_path.py +40 -0
  155. sknetwork/ranking/__init__.py +8 -0
  156. sknetwork/ranking/base.py +61 -0
  157. sknetwork/ranking/betweenness.cpp +9704 -0
  158. sknetwork/ranking/betweenness.cpython-312-darwin.so +0 -0
  159. sknetwork/ranking/betweenness.pyx +97 -0
  160. sknetwork/ranking/closeness.py +92 -0
  161. sknetwork/ranking/hits.py +94 -0
  162. sknetwork/ranking/katz.py +83 -0
  163. sknetwork/ranking/pagerank.py +110 -0
  164. sknetwork/ranking/postprocess.py +37 -0
  165. sknetwork/ranking/tests/__init__.py +1 -0
  166. sknetwork/ranking/tests/test_API.py +32 -0
  167. sknetwork/ranking/tests/test_betweenness.py +38 -0
  168. sknetwork/ranking/tests/test_closeness.py +30 -0
  169. sknetwork/ranking/tests/test_hits.py +20 -0
  170. sknetwork/ranking/tests/test_pagerank.py +62 -0
  171. sknetwork/ranking/tests/test_postprocess.py +26 -0
  172. sknetwork/regression/__init__.py +4 -0
  173. sknetwork/regression/base.py +61 -0
  174. sknetwork/regression/diffusion.py +210 -0
  175. sknetwork/regression/tests/__init__.py +1 -0
  176. sknetwork/regression/tests/test_API.py +32 -0
  177. sknetwork/regression/tests/test_diffusion.py +56 -0
  178. sknetwork/sknetwork.py +3 -0
  179. sknetwork/test_base.py +35 -0
  180. sknetwork/test_log.py +15 -0
  181. sknetwork/topology/__init__.py +8 -0
  182. sknetwork/topology/cliques.cpp +32562 -0
  183. sknetwork/topology/cliques.cpython-312-darwin.so +0 -0
  184. sknetwork/topology/cliques.pyx +149 -0
  185. sknetwork/topology/core.cpp +30648 -0
  186. sknetwork/topology/core.cpython-312-darwin.so +0 -0
  187. sknetwork/topology/core.pyx +90 -0
  188. sknetwork/topology/cycles.py +243 -0
  189. sknetwork/topology/minheap.cpp +27329 -0
  190. sknetwork/topology/minheap.cpython-312-darwin.so +0 -0
  191. sknetwork/topology/minheap.pxd +20 -0
  192. sknetwork/topology/minheap.pyx +109 -0
  193. sknetwork/topology/structure.py +194 -0
  194. sknetwork/topology/tests/__init__.py +1 -0
  195. sknetwork/topology/tests/test_cliques.py +28 -0
  196. sknetwork/topology/tests/test_core.py +19 -0
  197. sknetwork/topology/tests/test_cycles.py +65 -0
  198. sknetwork/topology/tests/test_structure.py +85 -0
  199. sknetwork/topology/tests/test_triangles.py +38 -0
  200. sknetwork/topology/tests/test_wl.py +72 -0
  201. sknetwork/topology/triangles.cpp +8891 -0
  202. sknetwork/topology/triangles.cpython-312-darwin.so +0 -0
  203. sknetwork/topology/triangles.pyx +151 -0
  204. sknetwork/topology/weisfeiler_lehman.py +133 -0
  205. sknetwork/topology/weisfeiler_lehman_core.cpp +27632 -0
  206. sknetwork/topology/weisfeiler_lehman_core.cpython-312-darwin.so +0 -0
  207. sknetwork/topology/weisfeiler_lehman_core.pyx +114 -0
  208. sknetwork/utils/__init__.py +7 -0
  209. sknetwork/utils/check.py +355 -0
  210. sknetwork/utils/format.py +221 -0
  211. sknetwork/utils/membership.py +82 -0
  212. sknetwork/utils/neighbors.py +115 -0
  213. sknetwork/utils/tests/__init__.py +1 -0
  214. sknetwork/utils/tests/test_check.py +190 -0
  215. sknetwork/utils/tests/test_format.py +63 -0
  216. sknetwork/utils/tests/test_membership.py +24 -0
  217. sknetwork/utils/tests/test_neighbors.py +41 -0
  218. sknetwork/utils/tests/test_tfidf.py +18 -0
  219. sknetwork/utils/tests/test_values.py +66 -0
  220. sknetwork/utils/tfidf.py +37 -0
  221. sknetwork/utils/values.py +76 -0
  222. sknetwork/visualization/__init__.py +4 -0
  223. sknetwork/visualization/colors.py +34 -0
  224. sknetwork/visualization/dendrograms.py +277 -0
  225. sknetwork/visualization/graphs.py +1039 -0
  226. sknetwork/visualization/tests/__init__.py +1 -0
  227. sknetwork/visualization/tests/test_dendrograms.py +53 -0
  228. sknetwork/visualization/tests/test_graphs.py +176 -0
@@ -0,0 +1,50 @@
1
+ # -*- coding: utf-8 -*-
2
+ # tests for metrics.py
3
+ """"tests for clustering metrics"""
4
+ import unittest
5
+
6
+ import numpy as np
7
+
8
+ from sknetwork.clustering import get_modularity, Louvain
9
+ from sknetwork.data import star_wars, karate_club
10
+ from sknetwork.data.test_graphs import test_graph
11
+
12
+
13
class TestClusteringMetrics(unittest.TestCase):
    """Tests for the clustering quality metrics."""

    def setUp(self):
        """Prepare a small test graph with node 0 in its own cluster."""
        self.adjacency = test_graph()
        n_nodes = self.adjacency.shape[0]
        partition = np.zeros(n_nodes, dtype=int)
        partition[0] = 1
        self.labels = partition
        self.unique_cluster = np.zeros(n_nodes, dtype=int)

    def test_api(self):
        """Common API: return_all decomposition, trivial clustering, input checks."""
        for metric in [get_modularity]:
            total = metric(self.adjacency, self.labels, return_all=False)
            _, fit, diversity = metric(self.adjacency, self.labels, return_all=True)
            self.assertAlmostEqual(fit - diversity, total)
            # a single cluster has zero modularity
            self.assertAlmostEqual(metric(self.adjacency, self.unique_cluster), 0.)

            # labels shorter than the number of nodes must be rejected
            with self.assertRaises(ValueError):
                metric(self.adjacency, self.labels[:3])

    def test_modularity(self):
        """Modularity of the Louvain clustering of the karate club graph."""
        adjacency = karate_club()
        clusters = Louvain().fit_predict(adjacency)
        self.assertAlmostEqual(get_modularity(adjacency, clusters), 0.42, 2)

    def test_bimodularity(self):
        """Bimodularity on a bipartite graph, plus input validation."""
        biadjacency = star_wars()
        labels_row = np.array([0, 0, 1, 1])
        labels_col = np.array([0, 1, 0])
        self.assertAlmostEqual(get_modularity(biadjacency, labels_row, labels_col), 0.12, 2)

        # bipartite input requires both row and column labels, with matching sizes
        with self.assertRaises(ValueError):
            get_modularity(biadjacency, labels_row)
        with self.assertRaises(ValueError):
            get_modularity(biadjacency, labels_row[:2], labels_col)
        with self.assertRaises(ValueError):
            get_modularity(biadjacency, labels_row, labels_col[:2])
@@ -0,0 +1,39 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """Tests for clustering post-processing"""
4
+ import unittest
5
+
6
+ import numpy as np
7
+
8
+ from sknetwork.data import house, star_wars
9
+ from sknetwork.clustering.postprocess import reindex_labels, aggregate_graph
10
+
11
+
12
class TestClusteringPostProcessing(unittest.TestCase):
    """Tests for clustering post-processing utilities."""

    def test_reindex_clusters(self):
        """Labels are reindexed densely from 0, by decreasing cluster size."""
        truth = np.array([1, 1, 2, 0, 0, 0])

        labels = np.array([0, 0, 1, 2, 2, 2])
        output = reindex_labels(labels)
        self.assertTrue(np.array_equal(truth, output))

        # gaps in the input label values must not change the result
        labels = np.array([0, 0, 5, 2, 2, 2])
        output = reindex_labels(labels)
        self.assertTrue(np.array_equal(truth, output))

    def test_aggregate_graph(self):
        """Shape of the aggregate graph for simple and bipartite graphs."""
        adjacency = house()
        labels = np.array([0, 0, 1, 1, 2])
        aggregate = aggregate_graph(adjacency, labels)
        self.assertEqual(aggregate.shape, (3, 3))

        biadjacency = star_wars()
        labels = np.array([0, 0, 1, 2])
        labels_row = np.array([0, 1, 3, -1])
        labels_col = np.array([0, 0, 1])
        aggregate = aggregate_graph(biadjacency, labels=labels, labels_col=labels_col)
        # fixed: the original asserted this shape twice (copy-paste duplication)
        self.assertEqual(aggregate.shape, (3, 2))
        # NOTE(review): labels_row contains -1 — presumably an excluded node; verify against aggregate_graph
        aggregate = aggregate_graph(biadjacency, labels_row=labels_row, labels_col=labels_col)
        self.assertEqual(aggregate.shape, (4, 2))
@@ -0,0 +1,6 @@
1
+ """data module"""
2
+ from sknetwork.data.base import *
3
+ from sknetwork.data.load import *
4
+ from sknetwork.data.models import *
5
+ from sknetwork.data.parse import from_edge_list, from_adjacency_list, from_csv, from_graphml
6
+ from sknetwork.data.toy_graphs import *
sknetwork/data/base.py ADDED
@@ -0,0 +1,33 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created in May 2023
5
+ @author: Thomas Bonald <bonald@enst.fr>
6
+ """
7
+
8
+
9
class Dataset(dict):
    """Container object for datasets.
    Dictionary-like object that exposes its keys as attributes.
    >>> dataset = Dataset(name='dataset')
    >>> dataset['name']
    'dataset'
    >>> dataset.name
    'dataset'
    """

    def __init__(self, **kwargs):
        super().__init__(kwargs)

    def __setattr__(self, key, value):
        # attribute assignment writes straight into the underlying dict
        self[key] = value

    def __getattr__(self, key):
        # only called for names not found as real attributes
        if key in self:
            return self[key]
        raise AttributeError(key)


# alias for Dataset
Bunch = Dataset
33
+
sknetwork/data/load.py ADDED
@@ -0,0 +1,406 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created in November 2019
5
+ @author: Quentin Lutz <qlutz@enst.fr>
6
+ """
7
+
8
import pickle
import shutil
import tarfile
from os import environ, listdir, makedirs, remove
from os.path import abspath, commonpath, commonprefix, exists, expanduser, isfile, join
from pathlib import Path
from typing import Optional, Union
from urllib.error import HTTPError, URLError
from urllib.request import urlretrieve

import numpy as np
from scipy import sparse

from sknetwork.data.parse import from_csv, load_labels, load_header, load_metadata
from sknetwork.data.base import Dataset
from sknetwork.utils.check import is_square
from sknetwork.log import Log
25
+
26
+ NETSET_URL = 'https://netset.telecom-paris.fr'
27
+
28
+
29
def is_within_directory(directory, target):
    """Return True if *target* lies inside *directory* (or is the directory itself).

    Compares whole path components via ``commonpath``: the original
    ``commonprefix`` test is character-wise, so '/data/foo-extra' was wrongly
    accepted as being within '/data/foo', weakening the tar path-traversal check.
    """
    abs_directory = abspath(directory)
    abs_target = abspath(target)
    try:
        return commonpath([abs_directory, abs_target]) == abs_directory
    except ValueError:
        # paths with no common root (e.g. different drives on Windows)
        return False
35
+
36
+
37
def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
    """Extract a tar archive after vetting every member against path traversal."""
    offenders = [entry.name for entry in tar.getmembers()
                 if not is_within_directory(path, join(path, entry.name))]
    if offenders:
        raise Exception("Attempted path traversal in tar file.")
    tar.extractall(path, members, numeric_owner=numeric_owner)
44
+
45
+
46
def get_data_home(data_home: Optional[Union[str, Path]] = None) -> Path:
    """Return a path to a storage folder depending on the dedicated environment variable and user input.

    Parameters
    ----------
    data_home: str
        The folder to be used for dataset storage
    """
    folder = data_home
    if folder is None:
        # fall back to the env var, then to ~/scikit_network_data
        folder = environ.get('SCIKIT_NETWORK_DATA', join('~', 'scikit_network_data'))
    folder = expanduser(folder)
    if not exists(folder):
        makedirs(folder)
    return Path(folder)
60
+
61
+
62
def clear_data_home(data_home: Optional[Union[str, Path]] = None):
    """Clear storage folder.

    Parameters
    ----------
    data_home: str or :class:`pathlib.Path`
        The folder to be used for dataset storage.
    """
    folder = get_data_home(data_home)
    # remove the folder and everything below it
    shutil.rmtree(folder)
72
+
73
+
74
def clean_data_home(data_home: Optional[Union[str, Path]] = None):
    """Clean storage folder so that it contains folders only.

    Parameters
    ----------
    data_home: str or :class:`pathlib.Path`
        The folder to be used for dataset storage
    """
    folder = get_data_home(data_home)
    for entry in listdir(folder):
        full_path = join(folder, entry)
        # delete stray files; sub-folders (datasets) are kept
        if isfile(full_path):
            remove(full_path)
86
+
87
+
88
def load_netset(name: Optional[str] = None, data_home: Optional[Union[str, Path]] = None,
                verbose: bool = True) -> Optional[Dataset]:
    """Load a dataset from the `NetSet collection
    <https://netset.telecom-paris.fr/>`_.

    Parameters
    ----------
    name : str
        Name of the dataset (all low-case). Examples include 'openflights', 'cinema' and 'wikivitals'.
    data_home : str or :class:`pathlib.Path`
        Folder to be used for dataset storage.
        This folder must be empty or contain other folders (datasets); files will be removed.
    verbose : bool
        Enable verbosity.

    Returns
    -------
    dataset : :class:`Dataset`
        Returned dataset.
    """
    dataset = Dataset()
    dataset_folder = NETSET_URL + '/datasets/'
    folder_npz = NETSET_URL + '/datasets_npz/'

    logger = Log(verbose)

    # without a name, print usage instructions instead of raising
    if name is None:
        print("Please specify the dataset (e.g., 'wikivitals').\n" +
              f"Complete list available here: <{dataset_folder}>.")
        return None
    else:
        name = name.lower()
    data_home = get_data_home(data_home)
    data_netset = data_home / 'netset'
    if not data_netset.exists():
        # first use: drop stray files then create the netset sub-folder
        clean_data_home(data_home)
        makedirs(data_netset)

    # remove previous dataset if not in the netset folder
    direct_path = data_home / name
    if direct_path.exists():
        shutil.rmtree(direct_path)

    data_path = data_netset / name
    if not data_path.exists():
        # archives are published as <name>_npz.tar.gz on the npz mirror
        name_npz = name + '_npz.tar.gz'
        try:
            logger.print_log('Downloading', name, 'from NetSet...')
            urlretrieve(folder_npz + name_npz, data_netset / name_npz)
        except HTTPError:
            # 404 and similar: treat as an unknown dataset name
            raise ValueError('Invalid dataset: ' + name + '.'
                             + "\nAvailable datasets include 'openflights' and 'wikivitals'."
                             + f"\nSee <{NETSET_URL}>")
        except ConnectionResetError:  # pragma: no cover
            raise RuntimeError("Could not reach Netset.")
        with tarfile.open(data_netset / name_npz, 'r:gz') as tar_ref:
            logger.print_log('Unpacking archive...')
            # guarded extraction (refuses path-traversal members)
            safe_extract(tar_ref, data_path)

    files = [file for file in listdir(data_path)]
    logger.print_log('Parsing files...')
    for file in files:
        file_components = file.split('.')
        # only plain '<attribute>.<ext>' files are loaded
        if len(file_components) == 2:
            file_name, file_extension = tuple(file_components)
            if file_extension == 'npz':
                # sparse matrices (adjacency / biadjacency)
                dataset[file_name] = sparse.load_npz(data_path / file)
            elif file_extension == 'npy':
                # dense arrays (names, labels, positions, ...)
                dataset[file_name] = np.load(data_path / file, allow_pickle=True)
            elif file_extension == 'p':
                # pickled objects (e.g. metadata)
                with open(data_path / file, 'rb') as f:
                    dataset[file_name] = pickle.load(f)

    # drop the downloaded archive, keeping only extracted dataset folders
    clean_data_home(data_netset)
    logger.print_log('Done.')
    return dataset
164
+
165
+
166
def load_konect(name: str, data_home: Optional[Union[str, Path]] = None, auto_numpy_bundle: bool = True,
                verbose: bool = True) -> Dataset:
    """Load a dataset from the `Konect database
    <http://konect.cc/networks/>`_.

    Parameters
    ----------
    name : str
        Name of the dataset as specified on the Konect website (e.g. for the Zachary Karate club dataset,
        the corresponding name is ``'ucidata-zachary'``).
    data_home : str or :class:`pathlib.Path`
        Folder to be used for dataset storage.
    auto_numpy_bundle : bool
        Whether the dataset should be stored in its default format (False) or using Numpy files for faster
        subsequent access to the dataset (True).
    verbose : bool
        Enable verbosity.

    Returns
    -------
    dataset : :class:`Dataset`
        Object with the following attributes:

        * `adjacency` or `biadjacency`: the adjacency/biadjacency matrix for the dataset
        * `meta`: a dictionary containing the metadata as specified by Konect
        * each attribute specified by Konect (ent.* file)

    Notes
    -----
    An attribute `meta` of the `Dataset` class is used to store information about the dataset if present. In any case,
    `meta` has the attribute `name` which, if not given, is equal to the name of the dataset as passed to this function.

    References
    ----------
    Kunegis, J. (2013, May).
    `Konect: the Koblenz network collection.
    <https://dl.acm.org/doi/abs/10.1145/2487788.2488173>`_
    In Proceedings of the 22nd International Conference on World Wide Web (pp. 1343-1350).
    """
    logger = Log(verbose)
    if name == '':
        raise ValueError("Please specify the dataset. "
                         + "\nExamples include 'actor-movie' and 'ego-facebook'."
                         + "\n See 'http://konect.cc/networks/' for the full list.")
    data_home = get_data_home(data_home)
    data_konect = data_home / 'konect'
    if not data_konect.exists():
        # first use: drop stray files then create the konect sub-folder
        clean_data_home(data_home)
        makedirs(data_konect)

    # remove previous dataset if not in the konect folder
    direct_path = data_home / name
    if direct_path.exists():
        shutil.rmtree(direct_path)

    data_path = data_konect / name
    name_tar = name + '.tar.bz2'
    if not data_path.exists():
        logger.print_log('Downloading', name, 'from Konect...')
        try:
            urlretrieve('http://konect.cc/files/download.tsv.' + name_tar, data_konect / name_tar)
            with tarfile.open(data_konect / name_tar, 'r:bz2') as tar_ref:
                logger.print_log('Unpacking archive...')
                # guarded extraction (refuses path-traversal members)
                safe_extract(tar_ref, data_path)
        except (HTTPError, tarfile.ReadError):
            # missing archive or corrupt download: unknown dataset name
            raise ValueError('Invalid dataset ' + name + '.'
                             + "\nExamples include 'actor-movie' and 'ego-facebook'."
                             + "\n See 'http://konect.cc/networks/' for the full list.")
        except (URLError, ConnectionResetError):  # pragma: no cover
            raise RuntimeError("Could not reach Konect.")
    elif exists(data_path / (name + '_bundle')):
        # a previous call already converted this dataset to a numpy bundle
        logger.print_log('Loading from local bundle...')
        return load_from_numpy_bundle(name + '_bundle', data_path)

    dataset = Dataset()
    # konect archives unpack to <name>/<name>/ with files named after the dataset
    path = data_konect / name / name
    if not path.exists() or len(listdir(path)) == 0:
        raise Exception("No data downloaded.")
    files = [file for file in listdir(path) if name in file]
    logger.print_log('Parsing files...')
    # 'out.<name>' holds the edge list (adjacency or biadjacency matrix)
    matrix = [file for file in files if 'out.' in file]
    if matrix:
        file = matrix[0]
        directed, bipartite, weighted = load_header(path / file)
        dataset = from_csv(path / file, directed=directed, bipartite=bipartite, weighted=weighted, reindex=True)

    # 'meta.<name>' holds the dataset metadata
    metadata = [file for file in files if 'meta.' in file]
    if metadata:
        file = metadata[0]
        dataset.meta = load_metadata(path / file)

    # 'ent.<name>.<attribute>' files hold per-node attributes
    attributes = [file for file in files if 'ent.' + name in file]
    if attributes:
        for file in attributes:
            attribute_name = file.split('.')[-1]
            dataset[attribute_name] = load_labels(path / file)

    # guarantee dataset.meta.name exists, defaulting to the requested name
    if hasattr(dataset, 'meta'):
        if hasattr(dataset.meta, 'name'):
            pass
        else:
            dataset.meta.name = name
    else:
        dataset.meta = Dataset()
        dataset.meta.name = name

    if auto_numpy_bundle:
        # cache as numpy bundle for faster subsequent loads
        save_to_numpy_bundle(dataset, name + '_bundle', data_path)

    # drop the downloaded archive, keeping only extracted dataset folders
    clean_data_home(data_konect)

    return dataset
278
+
279
+
280
def save_to_numpy_bundle(data: Dataset, bundle_name: str, data_home: Optional[Union[str, Path]] = None):
    """Save a dataset in the specified data home to a collection of Numpy and Pickle files for faster subsequent loads.

    Sparse matrices go to ``.npz``, NumPy arrays to ``.npy``, anything else is
    pickled to ``.p``. Uses ``isinstance`` rather than the original exact-type
    comparison (``type(x) == T``), so subclasses are dispatched correctly.

    Parameters
    ----------
    data: Dataset
        Data to save.
    bundle_name: str
        Name to be used for the bundle folder.
    data_home: str or :class:`pathlib.Path`
        Folder to be used for dataset storage.
    """
    data_home = get_data_home(data_home)
    data_path = data_home / bundle_name
    makedirs(data_path, exist_ok=True)
    for attribute in data:
        value = data[attribute]
        if isinstance(value, sparse.csr_matrix):
            sparse.save_npz(data_path / attribute, value)
        elif isinstance(value, np.ndarray):
            np.save(data_path / attribute, value)
        else:
            # fall back to pickle for strings, dicts and nested Dataset objects
            with open(data_path / (attribute + '.p'), 'wb') as file:
                pickle.dump(value, file)
303
+
304
+
305
def load_from_numpy_bundle(bundle_name: str, data_home: Optional[Union[str, Path]] = None):
    """Load a dataset from a collection of Numpy and Pickle files (inverse function of ``save_to_numpy_bundle``).

    Parameters
    ----------
    bundle_name: str
        Name of the bundle folder.
    data_home: str or :class:`pathlib.Path`
        Folder used for dataset storage.

    Returns
    -------
    data: Dataset
        Data.
    """
    data_path = get_data_home(data_home) / bundle_name
    if not data_path.exists():
        raise FileNotFoundError('No bundle at ' + str(data_path))
    data = Dataset()
    for file in listdir(data_path):
        parts = file.split('.')
        # only plain '<attribute>.<ext>' files belong to the bundle
        if len(parts) != 2:
            continue
        stem, extension = parts
        if extension == 'npz':
            data[stem] = sparse.load_npz(data_path / file)
        elif extension == 'npy':
            data[stem] = np.load(data_path / file, allow_pickle=True)
        elif extension == 'p':
            with open(data_path / file, 'rb') as handle:
                data[stem] = pickle.load(handle)
    return data
338
+
339
+
340
def save(folder: Union[str, Path], data: Union[sparse.csr_matrix, Dataset]):
    """Save a dataset or a CSR matrix in the current directory to a collection of Numpy and Pickle files for faster
    subsequent loads. Supported attribute types include sparse matrices, NumPy arrays, strings and objects Dataset.

    Parameters
    ----------
    folder : str or :class:`pathlib.Path`
        Name of the bundle folder.
    data : Union[sparse.csr_matrix, Dataset]
        Data to save.

    Example
    -------
    >>> from sknetwork.data import save
    >>> dataset = Dataset()
    >>> dataset.adjacency = sparse.csr_matrix(np.random.random((3, 3)) < 0.5)
    >>> dataset.names = np.array(['a', 'b', 'c'])
    >>> save('dataset', dataset)
    >>> 'dataset' in listdir('.')
    True
    """
    folder = Path(folder).expanduser()
    # start from a clean slate: any previous bundle is discarded
    if folder.exists():
        shutil.rmtree(folder)
    if isinstance(data, sparse.csr_matrix):
        # wrap a bare matrix in a Dataset, as adjacency (square) or biadjacency
        wrapped = Dataset()
        if is_square(data):
            wrapped.adjacency = data
        else:
            wrapped.biadjacency = data
        data = wrapped
    root = '/' if folder.is_absolute() else '.'
    save_to_numpy_bundle(data, folder, root)
376
+
377
+
378
def load(folder: Union[str, Path]):
    """Load a dataset from a previously created bundle from the current directory (inverse function of ``save``).

    Parameters
    ----------
    folder: str
        Name of the bundle folder.

    Returns
    -------
    data: Dataset
        Data.

    Example
    -------
    >>> from sknetwork.data import save
    >>> dataset = Dataset()
    >>> dataset.adjacency = sparse.csr_matrix(np.random.random((3, 3)) < 0.5)
    >>> dataset.names = np.array(['a', 'b', 'c'])
    >>> save('dataset', dataset)
    >>> dataset = load('dataset')
    >>> print(dataset.names)
    ['a' 'b' 'c']
    """
    folder = Path(folder)
    # absolute paths are resolved from '/', relative ones from the current directory
    location = '/' if folder.is_absolute() else '.'
    return load_from_numpy_bundle(folder, location)