scikit_network-0.33.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (229)
  1. scikit_network-0.33.4.dist-info/METADATA +122 -0
  2. scikit_network-0.33.4.dist-info/RECORD +229 -0
  3. scikit_network-0.33.4.dist-info/WHEEL +6 -0
  4. scikit_network-0.33.4.dist-info/licenses/AUTHORS.rst +43 -0
  5. scikit_network-0.33.4.dist-info/licenses/LICENSE +34 -0
  6. scikit_network-0.33.4.dist-info/top_level.txt +1 -0
  7. scikit_network.libs/libgomp-a34b3233.so.1.0.0 +0 -0
  8. sknetwork/__init__.py +21 -0
  9. sknetwork/base.py +67 -0
  10. sknetwork/classification/__init__.py +8 -0
  11. sknetwork/classification/base.py +138 -0
  12. sknetwork/classification/base_rank.py +129 -0
  13. sknetwork/classification/diffusion.py +127 -0
  14. sknetwork/classification/knn.py +131 -0
  15. sknetwork/classification/metrics.py +205 -0
  16. sknetwork/classification/pagerank.py +58 -0
  17. sknetwork/classification/propagation.py +144 -0
  18. sknetwork/classification/tests/__init__.py +1 -0
  19. sknetwork/classification/tests/test_API.py +30 -0
  20. sknetwork/classification/tests/test_diffusion.py +77 -0
  21. sknetwork/classification/tests/test_knn.py +23 -0
  22. sknetwork/classification/tests/test_metrics.py +53 -0
  23. sknetwork/classification/tests/test_pagerank.py +20 -0
  24. sknetwork/classification/tests/test_propagation.py +24 -0
  25. sknetwork/classification/vote.cpp +27593 -0
  26. sknetwork/classification/vote.cpython-312-x86_64-linux-gnu.so +0 -0
  27. sknetwork/classification/vote.pyx +56 -0
  28. sknetwork/clustering/__init__.py +8 -0
  29. sknetwork/clustering/base.py +168 -0
  30. sknetwork/clustering/kcenters.py +251 -0
  31. sknetwork/clustering/leiden.py +238 -0
  32. sknetwork/clustering/leiden_core.cpp +31928 -0
  33. sknetwork/clustering/leiden_core.cpython-312-x86_64-linux-gnu.so +0 -0
  34. sknetwork/clustering/leiden_core.pyx +124 -0
  35. sknetwork/clustering/louvain.py +282 -0
  36. sknetwork/clustering/louvain_core.cpp +31573 -0
  37. sknetwork/clustering/louvain_core.cpython-312-x86_64-linux-gnu.so +0 -0
  38. sknetwork/clustering/louvain_core.pyx +124 -0
  39. sknetwork/clustering/metrics.py +91 -0
  40. sknetwork/clustering/postprocess.py +66 -0
  41. sknetwork/clustering/propagation_clustering.py +100 -0
  42. sknetwork/clustering/tests/__init__.py +1 -0
  43. sknetwork/clustering/tests/test_API.py +38 -0
  44. sknetwork/clustering/tests/test_kcenters.py +60 -0
  45. sknetwork/clustering/tests/test_leiden.py +34 -0
  46. sknetwork/clustering/tests/test_louvain.py +135 -0
  47. sknetwork/clustering/tests/test_metrics.py +50 -0
  48. sknetwork/clustering/tests/test_postprocess.py +39 -0
  49. sknetwork/data/__init__.py +6 -0
  50. sknetwork/data/base.py +33 -0
  51. sknetwork/data/load.py +292 -0
  52. sknetwork/data/models.py +459 -0
  53. sknetwork/data/parse.py +644 -0
  54. sknetwork/data/test_graphs.py +93 -0
  55. sknetwork/data/tests/__init__.py +1 -0
  56. sknetwork/data/tests/test_API.py +30 -0
  57. sknetwork/data/tests/test_base.py +14 -0
  58. sknetwork/data/tests/test_load.py +61 -0
  59. sknetwork/data/tests/test_models.py +52 -0
  60. sknetwork/data/tests/test_parse.py +250 -0
  61. sknetwork/data/tests/test_test_graphs.py +29 -0
  62. sknetwork/data/tests/test_toy_graphs.py +68 -0
  63. sknetwork/data/timeout.py +38 -0
  64. sknetwork/data/toy_graphs.py +611 -0
  65. sknetwork/embedding/__init__.py +8 -0
  66. sknetwork/embedding/base.py +90 -0
  67. sknetwork/embedding/force_atlas.py +198 -0
  68. sknetwork/embedding/louvain_embedding.py +142 -0
  69. sknetwork/embedding/random_projection.py +131 -0
  70. sknetwork/embedding/spectral.py +137 -0
  71. sknetwork/embedding/spring.py +198 -0
  72. sknetwork/embedding/svd.py +351 -0
  73. sknetwork/embedding/tests/__init__.py +1 -0
  74. sknetwork/embedding/tests/test_API.py +49 -0
  75. sknetwork/embedding/tests/test_force_atlas.py +35 -0
  76. sknetwork/embedding/tests/test_louvain_embedding.py +33 -0
  77. sknetwork/embedding/tests/test_random_projection.py +28 -0
  78. sknetwork/embedding/tests/test_spectral.py +81 -0
  79. sknetwork/embedding/tests/test_spring.py +50 -0
  80. sknetwork/embedding/tests/test_svd.py +43 -0
  81. sknetwork/gnn/__init__.py +10 -0
  82. sknetwork/gnn/activation.py +117 -0
  83. sknetwork/gnn/base.py +181 -0
  84. sknetwork/gnn/base_activation.py +90 -0
  85. sknetwork/gnn/base_layer.py +109 -0
  86. sknetwork/gnn/gnn_classifier.py +305 -0
  87. sknetwork/gnn/layer.py +153 -0
  88. sknetwork/gnn/loss.py +180 -0
  89. sknetwork/gnn/neighbor_sampler.py +65 -0
  90. sknetwork/gnn/optimizer.py +164 -0
  91. sknetwork/gnn/tests/__init__.py +1 -0
  92. sknetwork/gnn/tests/test_activation.py +56 -0
  93. sknetwork/gnn/tests/test_base.py +75 -0
  94. sknetwork/gnn/tests/test_base_layer.py +37 -0
  95. sknetwork/gnn/tests/test_gnn_classifier.py +130 -0
  96. sknetwork/gnn/tests/test_layers.py +80 -0
  97. sknetwork/gnn/tests/test_loss.py +33 -0
  98. sknetwork/gnn/tests/test_neigh_sampler.py +23 -0
  99. sknetwork/gnn/tests/test_optimizer.py +43 -0
  100. sknetwork/gnn/tests/test_utils.py +41 -0
  101. sknetwork/gnn/utils.py +127 -0
  102. sknetwork/hierarchy/__init__.py +6 -0
  103. sknetwork/hierarchy/base.py +90 -0
  104. sknetwork/hierarchy/louvain_hierarchy.py +260 -0
  105. sknetwork/hierarchy/metrics.py +234 -0
  106. sknetwork/hierarchy/paris.cpp +37877 -0
  107. sknetwork/hierarchy/paris.cpython-312-x86_64-linux-gnu.so +0 -0
  108. sknetwork/hierarchy/paris.pyx +310 -0
  109. sknetwork/hierarchy/postprocess.py +350 -0
  110. sknetwork/hierarchy/tests/__init__.py +1 -0
  111. sknetwork/hierarchy/tests/test_API.py +24 -0
  112. sknetwork/hierarchy/tests/test_algos.py +34 -0
  113. sknetwork/hierarchy/tests/test_metrics.py +62 -0
  114. sknetwork/hierarchy/tests/test_postprocess.py +57 -0
  115. sknetwork/linalg/__init__.py +9 -0
  116. sknetwork/linalg/basics.py +37 -0
  117. sknetwork/linalg/diteration.cpp +27409 -0
  118. sknetwork/linalg/diteration.cpython-312-x86_64-linux-gnu.so +0 -0
  119. sknetwork/linalg/diteration.pyx +47 -0
  120. sknetwork/linalg/eig_solver.py +93 -0
  121. sknetwork/linalg/laplacian.py +15 -0
  122. sknetwork/linalg/normalizer.py +86 -0
  123. sknetwork/linalg/operators.py +225 -0
  124. sknetwork/linalg/polynome.py +76 -0
  125. sknetwork/linalg/ppr_solver.py +170 -0
  126. sknetwork/linalg/push.cpp +31081 -0
  127. sknetwork/linalg/push.cpython-312-x86_64-linux-gnu.so +0 -0
  128. sknetwork/linalg/push.pyx +71 -0
  129. sknetwork/linalg/sparse_lowrank.py +142 -0
  130. sknetwork/linalg/svd_solver.py +91 -0
  131. sknetwork/linalg/tests/__init__.py +1 -0
  132. sknetwork/linalg/tests/test_eig.py +44 -0
  133. sknetwork/linalg/tests/test_laplacian.py +18 -0
  134. sknetwork/linalg/tests/test_normalization.py +34 -0
  135. sknetwork/linalg/tests/test_operators.py +66 -0
  136. sknetwork/linalg/tests/test_polynome.py +38 -0
  137. sknetwork/linalg/tests/test_ppr.py +50 -0
  138. sknetwork/linalg/tests/test_sparse_lowrank.py +61 -0
  139. sknetwork/linalg/tests/test_svd.py +38 -0
  140. sknetwork/linkpred/__init__.py +2 -0
  141. sknetwork/linkpred/base.py +46 -0
  142. sknetwork/linkpred/nn.py +126 -0
  143. sknetwork/linkpred/tests/__init__.py +1 -0
  144. sknetwork/linkpred/tests/test_nn.py +26 -0
  145. sknetwork/log.py +19 -0
  146. sknetwork/path/__init__.py +5 -0
  147. sknetwork/path/dag.py +54 -0
  148. sknetwork/path/distances.py +98 -0
  149. sknetwork/path/search.py +31 -0
  150. sknetwork/path/shortest_path.py +61 -0
  151. sknetwork/path/tests/__init__.py +1 -0
  152. sknetwork/path/tests/test_dag.py +37 -0
  153. sknetwork/path/tests/test_distances.py +62 -0
  154. sknetwork/path/tests/test_search.py +40 -0
  155. sknetwork/path/tests/test_shortest_path.py +40 -0
  156. sknetwork/ranking/__init__.py +8 -0
  157. sknetwork/ranking/base.py +57 -0
  158. sknetwork/ranking/betweenness.cpp +9716 -0
  159. sknetwork/ranking/betweenness.cpython-312-x86_64-linux-gnu.so +0 -0
  160. sknetwork/ranking/betweenness.pyx +97 -0
  161. sknetwork/ranking/closeness.py +92 -0
  162. sknetwork/ranking/hits.py +90 -0
  163. sknetwork/ranking/katz.py +79 -0
  164. sknetwork/ranking/pagerank.py +106 -0
  165. sknetwork/ranking/postprocess.py +37 -0
  166. sknetwork/ranking/tests/__init__.py +1 -0
  167. sknetwork/ranking/tests/test_API.py +32 -0
  168. sknetwork/ranking/tests/test_betweenness.py +38 -0
  169. sknetwork/ranking/tests/test_closeness.py +30 -0
  170. sknetwork/ranking/tests/test_hits.py +20 -0
  171. sknetwork/ranking/tests/test_pagerank.py +62 -0
  172. sknetwork/ranking/tests/test_postprocess.py +26 -0
  173. sknetwork/regression/__init__.py +4 -0
  174. sknetwork/regression/base.py +57 -0
  175. sknetwork/regression/diffusion.py +204 -0
  176. sknetwork/regression/tests/__init__.py +1 -0
  177. sknetwork/regression/tests/test_API.py +32 -0
  178. sknetwork/regression/tests/test_diffusion.py +56 -0
  179. sknetwork/sknetwork.py +3 -0
  180. sknetwork/test_base.py +35 -0
  181. sknetwork/test_log.py +15 -0
  182. sknetwork/topology/__init__.py +8 -0
  183. sknetwork/topology/cliques.cpp +32574 -0
  184. sknetwork/topology/cliques.cpython-312-x86_64-linux-gnu.so +0 -0
  185. sknetwork/topology/cliques.pyx +149 -0
  186. sknetwork/topology/core.cpp +30660 -0
  187. sknetwork/topology/core.cpython-312-x86_64-linux-gnu.so +0 -0
  188. sknetwork/topology/core.pyx +90 -0
  189. sknetwork/topology/cycles.py +243 -0
  190. sknetwork/topology/minheap.cpp +27341 -0
  191. sknetwork/topology/minheap.cpython-312-x86_64-linux-gnu.so +0 -0
  192. sknetwork/topology/minheap.pxd +20 -0
  193. sknetwork/topology/minheap.pyx +109 -0
  194. sknetwork/topology/structure.py +194 -0
  195. sknetwork/topology/tests/__init__.py +1 -0
  196. sknetwork/topology/tests/test_cliques.py +28 -0
  197. sknetwork/topology/tests/test_core.py +19 -0
  198. sknetwork/topology/tests/test_cycles.py +65 -0
  199. sknetwork/topology/tests/test_structure.py +85 -0
  200. sknetwork/topology/tests/test_triangles.py +38 -0
  201. sknetwork/topology/tests/test_wl.py +72 -0
  202. sknetwork/topology/triangles.cpp +8903 -0
  203. sknetwork/topology/triangles.cpython-312-x86_64-linux-gnu.so +0 -0
  204. sknetwork/topology/triangles.pyx +151 -0
  205. sknetwork/topology/weisfeiler_lehman.py +133 -0
  206. sknetwork/topology/weisfeiler_lehman_core.cpp +27644 -0
  207. sknetwork/topology/weisfeiler_lehman_core.cpython-312-x86_64-linux-gnu.so +0 -0
  208. sknetwork/topology/weisfeiler_lehman_core.pyx +114 -0
  209. sknetwork/utils/__init__.py +7 -0
  210. sknetwork/utils/check.py +355 -0
  211. sknetwork/utils/format.py +221 -0
  212. sknetwork/utils/membership.py +82 -0
  213. sknetwork/utils/neighbors.py +115 -0
  214. sknetwork/utils/tests/__init__.py +1 -0
  215. sknetwork/utils/tests/test_check.py +190 -0
  216. sknetwork/utils/tests/test_format.py +63 -0
  217. sknetwork/utils/tests/test_membership.py +24 -0
  218. sknetwork/utils/tests/test_neighbors.py +41 -0
  219. sknetwork/utils/tests/test_tfidf.py +18 -0
  220. sknetwork/utils/tests/test_values.py +66 -0
  221. sknetwork/utils/tfidf.py +37 -0
  222. sknetwork/utils/values.py +76 -0
  223. sknetwork/visualization/__init__.py +4 -0
  224. sknetwork/visualization/colors.py +34 -0
  225. sknetwork/visualization/dendrograms.py +277 -0
  226. sknetwork/visualization/graphs.py +1039 -0
  227. sknetwork/visualization/tests/__init__.py +1 -0
  228. sknetwork/visualization/tests/test_dendrograms.py +53 -0
  229. sknetwork/visualization/tests/test_graphs.py +176 -0
@@ -0,0 +1,34 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created in March 2020
5
+ @author: Quentin Lutz <qlutz@enst.fr>
6
+ @author: Thomas Bonald <tbonald@enst.fr>
7
+ """
8
+
9
+ import unittest
10
+
11
+ from sknetwork.data.test_graphs import *
12
+ from sknetwork.hierarchy import LouvainIteration, LouvainHierarchy, Paris
13
+
14
+
15
class TestLouvainHierarchy(unittest.TestCase):
    """Shape checks for the hierarchical clustering algorithms."""

    def test(self):
        # Default and customized instances of each hierarchical algorithm.
        algorithms = (
            LouvainIteration(),
            LouvainIteration(resolution=2, depth=1),
            LouvainHierarchy(),
            LouvainHierarchy(tol_aggregation=0.1),
            Paris(),
            Paris(weights='uniform', reorder=False),
        )
        for algorithm in algorithms:
            for matrix in (test_graph(), test_digraph(), test_bigraph()):
                dendrogram = algorithm.fit_predict(matrix)
                # Expected shape: (number of rows - 1, 4).
                n_rows = matrix.shape[0]
                self.assertEqual(dendrogram.shape, (n_rows - 1, 4))
                if not algorithm.bipartite:
                    continue
                # For bipartite inputs, the full dendrogram spans rows and columns together.
                full_shape = (sum(matrix.shape) - 1, 4)
                self.assertEqual(algorithm.dendrogram_full_.shape, full_shape)
        # Calling predict() after fit_predict() must return the same dendrogram.
        adjacency = test_graph()
        paris = Paris()
        fitted = paris.fit_predict(adjacency)
        predicted = paris.predict()
        self.assertAlmostEqual(np.linalg.norm(fitted - predicted), 0)
@@ -0,0 +1,62 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on March 2019
5
+ @author: Thomas Bonald <bonald@enst.fr>
6
+ """
7
+
8
+ import unittest
9
+
10
+ from sknetwork.data.test_graphs import *
11
+ from sknetwork.data import cyclic_graph
12
+ from sknetwork.hierarchy import Paris, LouvainIteration, dasgupta_cost, dasgupta_score, tree_sampling_divergence
13
+
14
+
15
+ # noinspection PyMissingOrEmptyDocstring
16
class TestMetrics(unittest.TestCase):
    """Regression values for the dendrogram quality metrics."""

    def setUp(self):
        self.paris = Paris()
        self.louvain_iteration = LouvainIteration()

    def _assert_scores(self, adjacency, dendrogram, expectations):
        """Check metric values on (adjacency, dendrogram), in order.

        Each expectation is a tuple ``(metric, kwargs, value, places)``: the
        metric is called as ``metric(adjacency, dendrogram, **kwargs)`` and
        compared to ``value`` with ``assertAlmostEqual`` at ``places`` decimals.
        """
        for metric, options, value, places in expectations:
            self.assertAlmostEqual(metric(adjacency, dendrogram, **options), value, places)

    def test_undirected(self):
        triangle = cyclic_graph(3)
        tree = self.paris.fit_predict(triangle)
        self._assert_scores(triangle, tree, [
            (dasgupta_cost, {}, 2.666, 2),
            (dasgupta_score, {}, 0.111, 2),
            (tree_sampling_divergence, {}, 0.0632, 3),
            (tree_sampling_divergence, {'normalized': False}, 0.0256, 3),
        ])
        graph = test_graph()
        tree = self.paris.fit_transform(graph)
        self._assert_scores(graph, tree, [
            (dasgupta_cost, {}, 4.26, 2),
            (dasgupta_score, {}, 0.573, 2),
            (tree_sampling_divergence, {}, 0.304, 2),
        ])
        tree = self.louvain_iteration.fit_transform(graph)
        self._assert_scores(graph, tree, [
            (dasgupta_cost, {}, 4.43, 2),
            (dasgupta_score, {}, 0.555, 2),
            (tree_sampling_divergence, {}, 0.286, 2),
        ])

    def test_directed(self):
        digraph = test_digraph()
        tree = self.paris.fit_transform(digraph)
        self._assert_scores(digraph, tree, [
            (dasgupta_score, {}, 0.566, 2),
            (tree_sampling_divergence, {}, 0.318, 2),
        ])
        tree = self.louvain_iteration.fit_transform(digraph)
        self._assert_scores(digraph, tree, [
            (dasgupta_score, {}, 0.55, 2),
            (tree_sampling_divergence, {}, 0.313, 2),
        ])

    def test_disconnected(self):
        graph = test_disconnected_graph()
        tree = self.paris.fit_transform(graph)
        self._assert_scores(graph, tree, [
            (dasgupta_score, {}, 0.682, 2),
            (tree_sampling_divergence, {}, 0.464, 2),
        ])
        tree = self.louvain_iteration.fit_transform(graph)
        self._assert_scores(graph, tree, [
            (dasgupta_score, {}, 0.670, 2),
            (tree_sampling_divergence, {}, 0.594, 2),
        ])

    def test_options(self):
        graph = test_graph()
        tree = self.paris.fit_transform(graph)
        self._assert_scores(graph, tree, [
            (dasgupta_score, {'weights': 'degree'}, 0.573, 2),
            (tree_sampling_divergence, {'weights': 'uniform'}, 0.271, 2),
            (tree_sampling_divergence, {'normalized': False}, 0.367, 2),
        ])
@@ -0,0 +1,57 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on March 2019
5
+ @author: Thomas Bonald <bonald@enst.fr>
6
+ @author: Quentin Lutz <qlutz@enst.fr>
7
+ """
8
+
9
+ import unittest
10
+
11
+ from sknetwork.data import karate_club
12
+ from sknetwork.hierarchy import Paris, cut_straight, cut_balanced, aggregate_dendrogram
13
+
14
+
15
+ # noinspection PyMissingOrEmptyDocstring
16
class TestCuts(unittest.TestCase):
    """Tests for cutting and aggregating a Paris dendrogram of the karate club graph."""

    def setUp(self):
        model = Paris()
        self.adjacency = karate_club()
        self.dendrogram = model.fit_transform(self.adjacency)

    @staticmethod
    def _count_labels(labels):
        """Return the number of distinct labels."""
        return len(set(labels))

    def test_cuts(self):
        count = self._count_labels
        self.assertEqual(count(cut_straight(self.dendrogram)), 2)
        self.assertEqual(count(cut_straight(self.dendrogram, n_clusters=5)), 5)
        self.assertEqual(count(cut_balanced(self.dendrogram, 2)), 21)
        labels, coarse = cut_balanced(self.dendrogram, max_cluster_size=4, return_dendrogram=True)
        self.assertEqual(count(labels), 12)
        self.assertTupleEqual(coarse.shape, (11, 4))
        # Without reordering, a balanced cut at size 4 is also expected to give 12 clusters.
        unordered = Paris(reorder=False).fit_predict(self.adjacency)
        self.assertEqual(count(cut_balanced(unordered, 4)), 12)

    def test_options(self):
        count = self._count_labels
        self.assertEqual(count(cut_straight(self.dendrogram, threshold=0.5)), 7)
        self.assertEqual(count(cut_straight(self.dendrogram, n_clusters=3, threshold=0.5)), 3)
        self.assertEqual(count(cut_straight(self.dendrogram, sort_clusters=False)), 2)
        self.assertEqual(count(cut_balanced(self.dendrogram, max_cluster_size=2, sort_clusters=False)), 21)
        self.assertEqual(count(cut_balanced(self.dendrogram, max_cluster_size=10)), 5)

    def test_aggregation(self):
        coarse = aggregate_dendrogram(self.dendrogram, n_clusters=3)
        self.assertEqual(len(coarse), 2)

        coarse, sizes = aggregate_dendrogram(self.dendrogram, n_clusters=3, return_counts=True)
        self.assertEqual(len(coarse), 2)
        self.assertEqual(len(sizes), 3)
57
+
@@ -0,0 +1,9 @@
1
+ """Module of linear algebra."""
2
+ from sknetwork.linalg.basics import safe_sparse_dot
3
+ from sknetwork.linalg.eig_solver import EigSolver, LanczosEig
4
+ from sknetwork.linalg.laplacian import get_laplacian
5
+ from sknetwork.linalg.normalizer import diagonal_pseudo_inverse, get_norms, normalize
6
+ from sknetwork.linalg.operators import Regularizer, Laplacian, Normalizer, CoNeighbor
7
+ from sknetwork.linalg.polynome import Polynome
8
+ from sknetwork.linalg.sparse_lowrank import SparseLR
9
+ from sknetwork.linalg.svd_solver import SVDSolver, LanczosSVD
@@ -0,0 +1,37 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on Apr 2020
5
+ @author: Nathan de Lara <nathan.delara@polytechnique.org>
6
+ """
7
+
8
+ import numpy as np
9
+ from scipy import sparse
10
+ from scipy.sparse.linalg import LinearOperator
11
+
12
+
13
def safe_sparse_dot(a, b):
    """Dot product ``a . b`` dispatched to the most appropriate implementation.

    * If ``a`` is exactly a dense ``np.ndarray``, the product is computed as
      ``(b^T . a^T)^T`` so that the ``dot`` method of ``b`` (e.g. a sparse
      matrix or a custom operator) performs the multiplication.
    * If ``b`` is a CSR matrix and ``a`` exposes a callable
      ``right_sparse_dot``, that method is used.
    * If ``a`` is a CSR matrix and ``b`` exposes a callable
      ``left_sparse_dot``, that method is used.
    * Otherwise ``a.dot(b)`` is returned. This includes the case where the
      hook attributes exist but are not callable.

    Parameters
    ----------
    a : array, sparse matrix or LinearOperator
    b : array, sparse matrix or LinearOperator

    Returns
    -------
    dot_product : array or sparse matrix
        sparse if ``a`` or ``b`` is sparse.

    Raises
    ------
    NotImplementedError
        If both ``a`` and ``b`` are LinearOperators.
    """
    # Exact type check: ndarray subclasses (e.g. np.matrix) take the generic path.
    if type(a) == np.ndarray:
        # Let b drive the product so that its sparse/custom dot method is used.
        return b.T.dot(a.T).T
    if isinstance(a, LinearOperator) and isinstance(b, LinearOperator):
        raise NotImplementedError
    if type(b) == sparse.csr_matrix:
        right_sparse_dot = getattr(a, 'right_sparse_dot', None)
        if callable(right_sparse_dot):
            return right_sparse_dot(b)
    if type(a) == sparse.csr_matrix:
        left_sparse_dot = getattr(b, 'left_sparse_dot', None)
        if callable(left_sparse_dot):
            return left_sparse_dot(a)
    # Generic fallback. Previously a non-callable `left_sparse_dot` made the
    # function return None instead of reaching this line.
    return a.dot(b)