scikit-network 0.28.3__cp39-cp39-macosx_12_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of scikit-network might be problematic. Click here for more details.
- scikit_network-0.28.3.dist-info/AUTHORS.rst +41 -0
- scikit_network-0.28.3.dist-info/LICENSE +34 -0
- scikit_network-0.28.3.dist-info/METADATA +457 -0
- scikit_network-0.28.3.dist-info/RECORD +240 -0
- scikit_network-0.28.3.dist-info/WHEEL +5 -0
- scikit_network-0.28.3.dist-info/top_level.txt +1 -0
- sknetwork/__init__.py +21 -0
- sknetwork/classification/__init__.py +8 -0
- sknetwork/classification/base.py +84 -0
- sknetwork/classification/base_rank.py +143 -0
- sknetwork/classification/diffusion.py +134 -0
- sknetwork/classification/knn.py +162 -0
- sknetwork/classification/metrics.py +205 -0
- sknetwork/classification/pagerank.py +66 -0
- sknetwork/classification/propagation.py +152 -0
- sknetwork/classification/tests/__init__.py +1 -0
- sknetwork/classification/tests/test_API.py +35 -0
- sknetwork/classification/tests/test_diffusion.py +37 -0
- sknetwork/classification/tests/test_knn.py +24 -0
- sknetwork/classification/tests/test_metrics.py +53 -0
- sknetwork/classification/tests/test_pagerank.py +20 -0
- sknetwork/classification/tests/test_propagation.py +24 -0
- sknetwork/classification/vote.cpython-39-darwin.so +0 -0
- sknetwork/classification/vote.pyx +58 -0
- sknetwork/clustering/__init__.py +7 -0
- sknetwork/clustering/base.py +102 -0
- sknetwork/clustering/kmeans.py +142 -0
- sknetwork/clustering/louvain.py +255 -0
- sknetwork/clustering/louvain_core.cpython-39-darwin.so +0 -0
- sknetwork/clustering/louvain_core.pyx +134 -0
- sknetwork/clustering/metrics.py +91 -0
- sknetwork/clustering/postprocess.py +66 -0
- sknetwork/clustering/propagation_clustering.py +108 -0
- sknetwork/clustering/tests/__init__.py +1 -0
- sknetwork/clustering/tests/test_API.py +37 -0
- sknetwork/clustering/tests/test_kmeans.py +47 -0
- sknetwork/clustering/tests/test_louvain.py +104 -0
- sknetwork/clustering/tests/test_metrics.py +50 -0
- sknetwork/clustering/tests/test_post_processing.py +23 -0
- sknetwork/clustering/tests/test_postprocess.py +39 -0
- sknetwork/data/__init__.py +5 -0
- sknetwork/data/load.py +408 -0
- sknetwork/data/models.py +459 -0
- sknetwork/data/parse.py +621 -0
- sknetwork/data/test_graphs.py +84 -0
- sknetwork/data/tests/__init__.py +1 -0
- sknetwork/data/tests/test_API.py +30 -0
- sknetwork/data/tests/test_load.py +95 -0
- sknetwork/data/tests/test_models.py +52 -0
- sknetwork/data/tests/test_parse.py +253 -0
- sknetwork/data/tests/test_test_graphs.py +30 -0
- sknetwork/data/tests/test_toy_graphs.py +68 -0
- sknetwork/data/toy_graphs.py +619 -0
- sknetwork/embedding/__init__.py +10 -0
- sknetwork/embedding/base.py +90 -0
- sknetwork/embedding/force_atlas.py +197 -0
- sknetwork/embedding/louvain_embedding.py +174 -0
- sknetwork/embedding/louvain_hierarchy.py +142 -0
- sknetwork/embedding/metrics.py +66 -0
- sknetwork/embedding/random_projection.py +133 -0
- sknetwork/embedding/spectral.py +214 -0
- sknetwork/embedding/spring.py +198 -0
- sknetwork/embedding/svd.py +363 -0
- sknetwork/embedding/tests/__init__.py +1 -0
- sknetwork/embedding/tests/test_API.py +73 -0
- sknetwork/embedding/tests/test_force_atlas.py +35 -0
- sknetwork/embedding/tests/test_louvain_embedding.py +33 -0
- sknetwork/embedding/tests/test_louvain_hierarchy.py +19 -0
- sknetwork/embedding/tests/test_metrics.py +29 -0
- sknetwork/embedding/tests/test_random_projection.py +28 -0
- sknetwork/embedding/tests/test_spectral.py +84 -0
- sknetwork/embedding/tests/test_spring.py +50 -0
- sknetwork/embedding/tests/test_svd.py +37 -0
- sknetwork/flow/__init__.py +3 -0
- sknetwork/flow/flow.py +73 -0
- sknetwork/flow/tests/__init__.py +1 -0
- sknetwork/flow/tests/test_flow.py +17 -0
- sknetwork/flow/tests/test_utils.py +69 -0
- sknetwork/flow/utils.py +91 -0
- sknetwork/gnn/__init__.py +10 -0
- sknetwork/gnn/activation.py +117 -0
- sknetwork/gnn/base.py +155 -0
- sknetwork/gnn/base_activation.py +89 -0
- sknetwork/gnn/base_layer.py +109 -0
- sknetwork/gnn/gnn_classifier.py +381 -0
- sknetwork/gnn/layer.py +153 -0
- sknetwork/gnn/layers.py +127 -0
- sknetwork/gnn/loss.py +180 -0
- sknetwork/gnn/neighbor_sampler.py +65 -0
- sknetwork/gnn/optimizer.py +163 -0
- sknetwork/gnn/tests/__init__.py +1 -0
- sknetwork/gnn/tests/test_activation.py +56 -0
- sknetwork/gnn/tests/test_base.py +79 -0
- sknetwork/gnn/tests/test_base_layer.py +37 -0
- sknetwork/gnn/tests/test_gnn_classifier.py +192 -0
- sknetwork/gnn/tests/test_layers.py +80 -0
- sknetwork/gnn/tests/test_loss.py +33 -0
- sknetwork/gnn/tests/test_neigh_sampler.py +23 -0
- sknetwork/gnn/tests/test_optimizer.py +43 -0
- sknetwork/gnn/tests/test_utils.py +93 -0
- sknetwork/gnn/utils.py +219 -0
- sknetwork/hierarchy/__init__.py +7 -0
- sknetwork/hierarchy/base.py +69 -0
- sknetwork/hierarchy/louvain_hierarchy.py +264 -0
- sknetwork/hierarchy/metrics.py +234 -0
- sknetwork/hierarchy/paris.cpython-39-darwin.so +0 -0
- sknetwork/hierarchy/paris.pyx +317 -0
- sknetwork/hierarchy/postprocess.py +350 -0
- sknetwork/hierarchy/tests/__init__.py +1 -0
- sknetwork/hierarchy/tests/test_API.py +25 -0
- sknetwork/hierarchy/tests/test_algos.py +29 -0
- sknetwork/hierarchy/tests/test_metrics.py +62 -0
- sknetwork/hierarchy/tests/test_postprocess.py +57 -0
- sknetwork/hierarchy/tests/test_ward.py +25 -0
- sknetwork/hierarchy/ward.py +94 -0
- sknetwork/linalg/__init__.py +9 -0
- sknetwork/linalg/basics.py +37 -0
- sknetwork/linalg/diteration.cpython-39-darwin.so +0 -0
- sknetwork/linalg/diteration.pyx +49 -0
- sknetwork/linalg/eig_solver.py +93 -0
- sknetwork/linalg/laplacian.py +15 -0
- sknetwork/linalg/normalization.py +66 -0
- sknetwork/linalg/operators.py +225 -0
- sknetwork/linalg/polynome.py +76 -0
- sknetwork/linalg/ppr_solver.py +170 -0
- sknetwork/linalg/push.cpython-39-darwin.so +0 -0
- sknetwork/linalg/push.pyx +73 -0
- sknetwork/linalg/sparse_lowrank.py +142 -0
- sknetwork/linalg/svd_solver.py +91 -0
- sknetwork/linalg/tests/__init__.py +1 -0
- sknetwork/linalg/tests/test_eig.py +44 -0
- sknetwork/linalg/tests/test_laplacian.py +18 -0
- sknetwork/linalg/tests/test_normalization.py +38 -0
- sknetwork/linalg/tests/test_operators.py +70 -0
- sknetwork/linalg/tests/test_polynome.py +38 -0
- sknetwork/linalg/tests/test_ppr.py +50 -0
- sknetwork/linalg/tests/test_sparse_lowrank.py +61 -0
- sknetwork/linalg/tests/test_svd.py +38 -0
- sknetwork/linkpred/__init__.py +4 -0
- sknetwork/linkpred/base.py +80 -0
- sknetwork/linkpred/first_order.py +508 -0
- sknetwork/linkpred/first_order_core.cpython-39-darwin.so +0 -0
- sknetwork/linkpred/first_order_core.pyx +315 -0
- sknetwork/linkpred/postprocessing.py +98 -0
- sknetwork/linkpred/tests/__init__.py +1 -0
- sknetwork/linkpred/tests/test_API.py +49 -0
- sknetwork/linkpred/tests/test_postprocessing.py +21 -0
- sknetwork/path/__init__.py +4 -0
- sknetwork/path/metrics.py +148 -0
- sknetwork/path/search.py +65 -0
- sknetwork/path/shortest_path.py +186 -0
- sknetwork/path/tests/__init__.py +1 -0
- sknetwork/path/tests/test_metrics.py +29 -0
- sknetwork/path/tests/test_search.py +25 -0
- sknetwork/path/tests/test_shortest_path.py +45 -0
- sknetwork/ranking/__init__.py +9 -0
- sknetwork/ranking/base.py +56 -0
- sknetwork/ranking/betweenness.cpython-39-darwin.so +0 -0
- sknetwork/ranking/betweenness.pyx +99 -0
- sknetwork/ranking/closeness.py +95 -0
- sknetwork/ranking/harmonic.py +82 -0
- sknetwork/ranking/hits.py +94 -0
- sknetwork/ranking/katz.py +81 -0
- sknetwork/ranking/pagerank.py +107 -0
- sknetwork/ranking/postprocess.py +25 -0
- sknetwork/ranking/tests/__init__.py +1 -0
- sknetwork/ranking/tests/test_API.py +34 -0
- sknetwork/ranking/tests/test_betweenness.py +38 -0
- sknetwork/ranking/tests/test_closeness.py +34 -0
- sknetwork/ranking/tests/test_hits.py +20 -0
- sknetwork/ranking/tests/test_pagerank.py +69 -0
- sknetwork/regression/__init__.py +4 -0
- sknetwork/regression/base.py +56 -0
- sknetwork/regression/diffusion.py +190 -0
- sknetwork/regression/tests/__init__.py +1 -0
- sknetwork/regression/tests/test_API.py +34 -0
- sknetwork/regression/tests/test_diffusion.py +48 -0
- sknetwork/sknetwork.py +3 -0
- sknetwork/topology/__init__.py +9 -0
- sknetwork/topology/dag.py +74 -0
- sknetwork/topology/dag_core.cpython-39-darwin.so +0 -0
- sknetwork/topology/dag_core.pyx +38 -0
- sknetwork/topology/kcliques.cpython-39-darwin.so +0 -0
- sknetwork/topology/kcliques.pyx +193 -0
- sknetwork/topology/kcore.cpython-39-darwin.so +0 -0
- sknetwork/topology/kcore.pyx +120 -0
- sknetwork/topology/structure.py +234 -0
- sknetwork/topology/tests/__init__.py +1 -0
- sknetwork/topology/tests/test_cliques.py +28 -0
- sknetwork/topology/tests/test_cores.py +21 -0
- sknetwork/topology/tests/test_dag.py +26 -0
- sknetwork/topology/tests/test_structure.py +99 -0
- sknetwork/topology/tests/test_triangles.py +42 -0
- sknetwork/topology/tests/test_wl_coloring.py +49 -0
- sknetwork/topology/tests/test_wl_kernel.py +31 -0
- sknetwork/topology/triangles.cpython-39-darwin.so +0 -0
- sknetwork/topology/triangles.pyx +166 -0
- sknetwork/topology/weisfeiler_lehman.py +163 -0
- sknetwork/topology/weisfeiler_lehman_core.cpython-39-darwin.so +0 -0
- sknetwork/topology/weisfeiler_lehman_core.pyx +116 -0
- sknetwork/utils/__init__.py +40 -0
- sknetwork/utils/base.py +35 -0
- sknetwork/utils/check.py +354 -0
- sknetwork/utils/co_neighbor.py +71 -0
- sknetwork/utils/format.py +219 -0
- sknetwork/utils/kmeans.py +89 -0
- sknetwork/utils/knn.py +166 -0
- sknetwork/utils/knn1d.cpython-39-darwin.so +0 -0
- sknetwork/utils/knn1d.pyx +80 -0
- sknetwork/utils/membership.py +82 -0
- sknetwork/utils/minheap.cpython-39-darwin.so +0 -0
- sknetwork/utils/minheap.pxd +22 -0
- sknetwork/utils/minheap.pyx +111 -0
- sknetwork/utils/neighbors.py +115 -0
- sknetwork/utils/seeds.py +75 -0
- sknetwork/utils/simplex.py +140 -0
- sknetwork/utils/tests/__init__.py +1 -0
- sknetwork/utils/tests/test_base.py +28 -0
- sknetwork/utils/tests/test_bunch.py +16 -0
- sknetwork/utils/tests/test_check.py +190 -0
- sknetwork/utils/tests/test_co_neighbor.py +43 -0
- sknetwork/utils/tests/test_format.py +61 -0
- sknetwork/utils/tests/test_kmeans.py +21 -0
- sknetwork/utils/tests/test_knn.py +32 -0
- sknetwork/utils/tests/test_membership.py +24 -0
- sknetwork/utils/tests/test_neighbors.py +41 -0
- sknetwork/utils/tests/test_projection_simplex.py +33 -0
- sknetwork/utils/tests/test_seeds.py +67 -0
- sknetwork/utils/tests/test_verbose.py +15 -0
- sknetwork/utils/tests/test_ward.py +20 -0
- sknetwork/utils/timeout.py +38 -0
- sknetwork/utils/verbose.py +37 -0
- sknetwork/utils/ward.py +60 -0
- sknetwork/visualization/__init__.py +4 -0
- sknetwork/visualization/colors.py +34 -0
- sknetwork/visualization/dendrograms.py +229 -0
- sknetwork/visualization/graphs.py +819 -0
- sknetwork/visualization/tests/__init__.py +1 -0
- sknetwork/visualization/tests/test_dendrograms.py +53 -0
- sknetwork/visualization/tests/test_graphs.py +167 -0
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""
|
|
4
|
+
Created on November 19 2019
|
|
5
|
+
@author: Quentin Lutz <qlutz@enst.fr>
|
|
6
|
+
"""
|
|
7
|
+
from typing import Union, Optional
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
from scipy import sparse
|
|
11
|
+
|
|
12
|
+
from sknetwork.path.shortest_path import get_distances
|
|
13
|
+
from sknetwork.ranking.base import BaseRanking
|
|
14
|
+
from sknetwork.utils.check import check_format, check_square
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Harmonic(BaseRanking):
    """Harmonic centrality of each node, corresponding to the sum of the inverse lengths of
    the shortest paths from that node to all the other ones.

    For a directed graph, the harmonic centrality is computed in terms of outgoing paths.

    Parameters
    ----------
    n_jobs:
        If an integer value is given, denotes the number of workers to use (-1 means the maximum number will be used).
        If ``None``, no parallel computations are made.

    Attributes
    ----------
    scores_ : np.ndarray
        Score of each node (sum over all other nodes of the inverse shortest-path length).

    Example
    -------
    >>> from sknetwork.ranking import Harmonic
    >>> from sknetwork.data import house
    >>> harmonic = Harmonic()
    >>> adjacency = house()
    >>> scores = harmonic.fit_predict(adjacency)
    >>> np.round(scores, 2)
    array([3. , 3.5, 3. , 3. , 3.5])

    References
    ----------
    Marchiori, M., & Latora, V. (2000).
    `Harmony in the small-world.
    <https://arxiv.org/pdf/cond-mat/0008357.pdf>`_
    Physica A: Statistical Mechanics and its Applications, 285(3-4), 539-546.
    """

    def __init__(self, n_jobs: Optional[int] = None):
        super(Harmonic, self).__init__()

        self.n_jobs = n_jobs

    def fit(self, adjacency: Union[sparse.csr_matrix, np.ndarray]) -> 'Harmonic':
        """Compute the harmonic centrality of each node.

        Parameters
        ----------
        adjacency :
            Adjacency matrix of the graph (must be square).

        Returns
        -------
        self: :class:`Harmonic`
        """
        adjacency = check_format(adjacency)
        check_square(adjacency)
        n = adjacency.shape[0]
        indices = np.arange(n)

        # Pairwise shortest-path lengths from every node (row i = distances from node i).
        dists = get_distances(adjacency, n_jobs=self.n_jobs, sources=indices)

        # Self-distances are 0: set them to 1 before inverting to avoid a division by zero,
        # then reset them to 0 so that a node does not contribute to its own score.
        # NOTE(review): if get_distances reports unreachable nodes as np.inf, 1 / inf = 0
        # and they simply contribute nothing — confirm against get_distances.
        np.fill_diagonal(dists, 1)
        inv = 1 / dists
        np.fill_diagonal(inv, 0)

        # Row sums: the score of a node is the sum (not the average) of its inverse distances.
        self.scores_ = inv.dot(np.ones(n))

        return self
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""
|
|
4
|
+
Created on Oct 07 2019
|
|
5
|
+
@author: Nathan de Lara <nathan.delara@polytechnique.org>
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import Union
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
from scipy import sparse
|
|
12
|
+
|
|
13
|
+
from sknetwork.linalg import SVDSolver, LanczosSVD
|
|
14
|
+
from sknetwork.ranking.base import BaseRanking
|
|
15
|
+
from sknetwork.utils.check import check_format
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class HITS(BaseRanking):
    """Hub and authority scores of each node.
    For bipartite graphs, the hub score is computed on rows and the authority score on columns.

    Parameters
    ----------
    solver : ``'lanczos'`` (default, Lanczos algorithm) or :class:`SVDSolver` (custom solver)
        Which solver to use. Any string selects the Lanczos solver.

    Attributes
    ----------
    scores_ : np.ndarray
        Hub score of each node.
    scores_row_ : np.ndarray
        Hub score of each row, for bipartite graphs.
    scores_col_ : np.ndarray
        Authority score of each column, for bipartite graphs.

    Example
    -------
    >>> from sknetwork.ranking import HITS
    >>> from sknetwork.data import star_wars
    >>> hits = HITS()
    >>> biadjacency = star_wars()
    >>> scores = hits.fit_predict(biadjacency)
    >>> np.round(scores, 2)
    array([0.5 , 0.23, 0.69, 0.46])

    References
    ----------
    Kleinberg, J. M. (1999). Authoritative sources in a hyperlinked environment.
    Journal of the ACM, 46(5), 604-632.
    """
    def __init__(self, solver: Union[str, SVDSolver] = 'lanczos'):
        super(HITS, self).__init__()

        # Any string (not only 'lanczos') falls back to the Lanczos solver;
        # otherwise the given object is used as a custom SVD solver.
        if isinstance(solver, str):
            self.solver: SVDSolver = LanczosSVD()
        else:
            self.solver = solver

    @staticmethod
    def _to_nonnegative(vector: np.ndarray) -> np.ndarray:
        """Orient a singular vector towards its majority sign and clip remaining negative entries to 0.

        Singular vectors are only defined up to sign, so the orientation with more positive
        entries is kept (ties flip the sign).
        """
        if (vector > 0).sum() > (vector < 0).sum():
            return np.clip(vector, a_min=0., a_max=None)
        return np.clip(-vector, a_min=0., a_max=None)

    def fit(self, adjacency: Union[sparse.csr_matrix, np.ndarray]) -> 'HITS':
        """Compute HITS algorithm with a spectral method.

        Parameters
        ----------
        adjacency :
            Adjacency or biadjacency matrix of the graph.

        Returns
        -------
        self: :class:`HITS`
        """
        adjacency = check_format(adjacency)

        # Leading singular pair: left vector gives hubs, right vector gives authorities.
        self.solver.fit(adjacency, 1)
        hubs: np.ndarray = self._to_nonnegative(self.solver.singular_vectors_left_.reshape(-1))
        authorities: np.ndarray = self._to_nonnegative(self.solver.singular_vectors_right_.reshape(-1))

        self.scores_row_ = hubs
        self.scores_col_ = authorities
        self.scores_ = hubs

        return self
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""
|
|
4
|
+
Created on May 2020
|
|
5
|
+
@author: Nathan de Lara <nathan.delara@polytechnique.org>
|
|
6
|
+
"""
|
|
7
|
+
from typing import Union
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
from scipy import sparse
|
|
11
|
+
from scipy.sparse.linalg import LinearOperator
|
|
12
|
+
|
|
13
|
+
from sknetwork.linalg.polynome import Polynome
|
|
14
|
+
from sknetwork.ranking.base import BaseRanking
|
|
15
|
+
from sknetwork.utils.check import check_format
|
|
16
|
+
from sknetwork.utils.format import get_adjacency
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class Katz(BaseRanking):
    """Katz centrality, defined by:

    :math:`\\sum_{k=1}^K\\alpha^k(A^k)^T\\mathbf{1}`,

    where :math:`\\alpha` is the damping factor, :math:`K` the path length and :math:`A` the
    binary adjacency matrix (edge weights are ignored).

    Parameters
    ----------
    damping_factor : float
        Decay parameter for path contributions.
    path_length : int
        Maximum length of the paths to take into account.

    Attributes
    ----------
    scores_ : np.ndarray
        Score of each node.
    scores_row_: np.ndarray
        Scores of rows, for bipartite graphs.
    scores_col_: np.ndarray
        Scores of columns, for bipartite graphs.

    Examples
    --------
    >>> from sknetwork.data.toy_graphs import house
    >>> adjacency = house()
    >>> katz = Katz()
    >>> scores = katz.fit_predict(adjacency)
    >>> np.round(scores, 2)
    array([6.5 , 8.25, 5.62, 5.62, 8.25])

    References
    ----------
    Katz, L. (1953). `A new status index derived from sociometric analysis
    <https://link.springer.com/content/pdf/10.1007/BF02289026.pdf>`_. Psychometrika, 18(1), 39-43.
    """
    def __init__(self, damping_factor: float = 0.5, path_length: int = 4):
        super(Katz, self).__init__()
        self.damping_factor = damping_factor
        self.path_length = path_length
        self.bipartite = None

    def fit(self, input_matrix: Union[sparse.csr_matrix, np.ndarray, LinearOperator]) -> 'Katz':
        """Compute the Katz centrality of each node.

        Parameters
        ----------
        input_matrix :
            Adjacency matrix or biadjacency matrix of the graph.

        Returns
        -------
        self: :class:`Katz`
        """
        input_matrix = check_format(input_matrix)
        adjacency, self.bipartite = get_adjacency(input_matrix)
        n_nodes = adjacency.shape[0]

        # Polynomial coefficients alpha^k for k = 0..K; the k = 0 (identity) term is dropped
        # so that the sum starts at paths of length 1.
        exponents = np.arange(self.path_length + 1)
        coefficients = self.damping_factor ** exponents
        coefficients[0] = 0.

        # Only the sparsity pattern of the transposed adjacency is used (weights ignored).
        pattern = adjacency.T.astype(bool).tocsr()
        self.scores_ = Polynome(pattern, coefficients).dot(np.ones(n_nodes))

        if self.bipartite:
            self._split_vars(input_matrix.shape)
        return self
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""
|
|
4
|
+
Created on May 2019
|
|
5
|
+
@author: Nathan de Lara <nathan.delara@polytechnique.org>
|
|
6
|
+
@author: Thomas Bonald <bonald@enst.fr>
|
|
7
|
+
"""
|
|
8
|
+
from typing import Union, Optional
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
from scipy import sparse
|
|
12
|
+
from scipy.sparse.linalg import LinearOperator
|
|
13
|
+
|
|
14
|
+
from sknetwork.linalg.ppr_solver import get_pagerank
|
|
15
|
+
from sknetwork.ranking.base import BaseRanking
|
|
16
|
+
from sknetwork.utils.check import check_damping_factor
|
|
17
|
+
from sknetwork.utils.format import get_adjacency_seeds
|
|
18
|
+
from sknetwork.utils.verbose import VerboseMixin
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class PageRank(BaseRanking, VerboseMixin):
    """PageRank of each node, corresponding to its frequency of visit by a random walk.

    The random walk restarts with some fixed probability. The restart distribution can be personalized by the user.
    This variant is known as Personalized PageRank.

    Parameters
    ----------
    damping_factor : float
        Probability to continue the random walk.
    solver : str
        * ``'piteration'``, use power iteration for a given number of iterations.
        * ``'diteration'``, use asynchronous parallel diffusion for a given number of iterations.
        * ``'lanczos'``, use eigensolver with a given tolerance.
        * ``'bicgstab'``, use Biconjugate Gradient Stabilized method for a given tolerance.
        * ``'RH'``, use a Ruffini-Horner polynomial evaluation.
        * ``'push'``, use push-based algorithm for a given tolerance
    n_iter : int
        Number of iterations for some solvers.
    tol : float
        Tolerance for the convergence of some solvers.

    Attributes
    ----------
    scores_ : np.ndarray
        PageRank score of each node.
    scores_row_: np.ndarray
        Scores of rows, for bipartite graphs.
    scores_col_: np.ndarray
        Scores of columns, for bipartite graphs.

    Example
    -------
    >>> from sknetwork.ranking import PageRank
    >>> from sknetwork.data import house
    >>> pagerank = PageRank()
    >>> adjacency = house()
    >>> seeds = {0: 1}
    >>> scores = pagerank.fit_predict(adjacency, seeds)
    >>> np.round(scores, 2)
    array([0.29, 0.24, 0.12, 0.12, 0.24])

    References
    ----------
    Page, L., Brin, S., Motwani, R., & Winograd, T. (1999). The PageRank citation ranking: Bringing order to the web.
    Stanford InfoLab.
    """
    def __init__(self, damping_factor: float = 0.85, solver: str = 'piteration', n_iter: int = 10, tol: float = 1e-6):
        super(PageRank, self).__init__()
        # Validate before storing anything.
        check_damping_factor(damping_factor)
        self.damping_factor = damping_factor
        self.solver = solver
        self.n_iter = n_iter
        self.tol = tol
        self.bipartite = None

    def fit(self, input_matrix: Union[sparse.csr_matrix, np.ndarray, LinearOperator],
            seeds: Optional[Union[dict, np.ndarray]] = None, seeds_row: Optional[Union[dict, np.ndarray]] = None,
            seeds_col: Optional[Union[dict, np.ndarray]] = None, force_bipartite: bool = False) -> 'PageRank':
        """Fit algorithm to data.

        Parameters
        ----------
        input_matrix :
            Adjacency matrix or biadjacency matrix of the graph.
        seeds :
            Parameter to be used for Personalized PageRank.
            Restart distribution as a vector or a dict (node: weight).
            If ``None``, the uniform distribution is used (no personalization, default).
        seeds_row, seeds_col :
            Parameter to be used for Personalized PageRank on bipartite graphs.
            Restart distribution as vectors or dicts on rows, columns (node: weight).
            If both seeds_row and seeds_col are ``None`` (default), the uniform distribution on rows is used.
        force_bipartite :
            If ``True``, consider the input matrix as the biadjacency matrix of a bipartite graph.

        Returns
        -------
        self: :class:`PageRank`
        """
        # Adjacency of the (possibly bipartite) graph together with the restart distribution.
        adjacency, restart, self.bipartite = get_adjacency_seeds(
            input_matrix,
            force_bipartite=force_bipartite,
            seeds=seeds,
            seeds_row=seeds_row,
            seeds_col=seeds_col,
            default_value=0,
            which='probs',
        )
        self.scores_ = get_pagerank(
            adjacency,
            restart,
            damping_factor=self.damping_factor,
            n_iter=self.n_iter,
            solver=self.solver,
            tol=self.tol,
        )
        if self.bipartite:
            # _split_vars (inherited from BaseRanking) fills scores_row_ / scores_col_.
            self._split_vars(input_matrix.shape)
        return self
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""
|
|
4
|
+
Created on May 2019
|
|
5
|
+
@author: Nathan de Lara <nathan.delara@polytechnique.org>
|
|
6
|
+
"""
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def top_k(scores: np.ndarray, k: int = 1):
    """Return the indices of the k elements of highest values.

    Parameters
    ----------
    scores : np.ndarray
        One-dimensional array (or list) of values.
    k : int
        Number of elements to return. If ``k`` exceeds the number of elements, all indices
        are returned; if ``k`` is not positive, an empty array is returned.

    Returns
    -------
    indices : np.ndarray
        Indices of the (at most) ``k`` highest values, sorted by decreasing value.

    Examples
    --------
    >>> top_k([1, 3, 2], k=2)
    array([1, 2])
    """
    values = np.asarray(scores)
    n = values.shape[0]
    # Clip k: np.argpartition rejects a kth index outside [0, n - 1].
    k = min(k, n)
    if k <= 0:
        return np.zeros(0, dtype=int)
    # Partition in O(n) so that the k largest values occupy the last k positions
    # (no negation, which would wrap around for unsigned dtypes)...
    candidates = np.argpartition(values, n - k)[n - k:]
    # ...then sort only those k candidates by decreasing value.
    order = np.argsort(values[candidates], kind='stable')[::-1]
    return candidates[order]
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""tests for ranking"""
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""tests for ranking API"""
|
|
4
|
+
import unittest
|
|
5
|
+
|
|
6
|
+
from sknetwork.data.test_graphs import test_bigraph, test_graph, test_digraph
|
|
7
|
+
from sknetwork.ranking import *
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class TestPageRank(unittest.TestCase):
    """API-level checks shared by all ranking algorithms."""

    def test_basic(self):
        algorithms = [PageRank(), Closeness(), HITS(), Harmonic(), Katz()]
        for graph in (test_graph(), test_digraph()):
            n_nodes = graph.shape[0]
            for algorithm in algorithms:
                predicted = algorithm.fit_predict(graph)
                self.assertEqual(predicted.shape, (n_nodes,))
                # Ranking scores are expected to be non-negative.
                self.assertTrue(min(predicted) >= 0)
                transformed = algorithm.fit_transform(graph)
                self.assertEqual(transformed.shape, (n_nodes,))

    def test_bipartite(self):
        biadjacency = test_bigraph()
        n_row, n_col = biadjacency.shape

        # Only the algorithms that accept a biadjacency matrix.
        for algorithm in (PageRank(), HITS(), Katz()):
            algorithm.fit(biadjacency)
            self.assertEqual(algorithm.scores_row_.shape, (n_row,))
            self.assertEqual(algorithm.scores_col_.shape, (n_col,))
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""tests for betweenness.py"""
|
|
4
|
+
|
|
5
|
+
import unittest
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from sknetwork.ranking.betweenness import Betweenness
|
|
9
|
+
from sknetwork.data.test_graphs import test_graph, test_graph_disconnect
|
|
10
|
+
from sknetwork.data.toy_graphs import bow_tie, star_wars
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class TestBetweenness(unittest.TestCase):
    """Unit tests for betweenness centrality."""

    def test_basic(self):
        graph = test_graph()
        algorithm = Betweenness()
        result = algorithm.fit_predict(graph)
        self.assertEqual(len(result), graph.shape[0])

    def test_bowtie(self):
        graph = bow_tie()
        algorithm = Betweenness()
        result = algorithm.fit_predict(graph)
        # Exactly one node (presumably the center of the bow tie) has a positive score.
        self.assertEqual(np.sum(result > 0), 1)

    def test_disconnected(self):
        graph = test_graph_disconnect()
        algorithm = Betweenness()
        with self.assertRaises(ValueError):
            algorithm.fit(graph)

    def test_bipartite(self):
        # A non-square (biadjacency) matrix must be rejected.
        graph = star_wars()
        algorithm = Betweenness()

        with self.assertRaises(ValueError):
            algorithm.fit_transform(graph)
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""tests for closeness.py"""
|
|
4
|
+
|
|
5
|
+
import unittest
|
|
6
|
+
|
|
7
|
+
from sknetwork.data.test_graphs import *
|
|
8
|
+
from sknetwork.ranking.closeness import Closeness
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class TestDiffusion(unittest.TestCase):
    """Unit tests for :class:`Closeness` (despite the class name)."""

    def test_params(self):
        adjacency = test_graph()
        # Keep only the call under test inside the context, so that a ValueError raised
        # while building the graph cannot make this test pass by accident.
        with self.assertRaises(ValueError):
            Closeness(method='toto').fit(adjacency)

    def test_parallel(self):
        # Local import: do not rely on `np` leaking through a star import.
        import numpy as np

        adjacency = test_graph()
        n = adjacency.shape[0]

        scores_sequential = Closeness(method='approximate').fit_predict(adjacency)
        scores_parallel = Closeness(method='approximate', n_jobs=-1).fit_predict(adjacency)

        self.assertEqual(scores_sequential.shape, (n,))
        self.assertAlmostEqual(np.linalg.norm(scores_sequential - scores_parallel), 0)

    def test_disconnected(self):
        adjacency = test_graph_disconnect()
        closeness = Closeness()
        with self.assertRaises(ValueError):
            closeness.fit(adjacency)
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""Tests for his.py"""
|
|
4
|
+
|
|
5
|
+
import unittest
|
|
6
|
+
|
|
7
|
+
from sknetwork.data.test_graphs import test_bigraph
|
|
8
|
+
from sknetwork.ranking import HITS
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class TestHITS(unittest.TestCase):
    """Unit tests for HITS on a bipartite graph."""

    def test_keywords(self):
        biadjacency = test_bigraph()
        n_row, n_col = biadjacency.shape

        model = HITS()
        model.fit(biadjacency)
        # Hubs live on rows, authorities on columns.
        self.assertEqual(model.scores_row_.shape, (n_row,))
        self.assertEqual(model.scores_col_.shape, (n_col,))
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""tests for pagerank.py"""
|
|
4
|
+
|
|
5
|
+
import unittest
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from sknetwork.data.models import cyclic_digraph
|
|
10
|
+
from sknetwork.data.test_graphs import test_bigraph
|
|
11
|
+
from sknetwork.ranking.pagerank import PageRank
|
|
12
|
+
from sknetwork.utils import co_neighbor_graph
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class TestPageRank(unittest.TestCase):
    """Unit tests for PageRank, mostly on a directed cycle whose expected scores are uniform."""

    def setUp(self) -> None:
        """Cycle graph for tests."""
        self.n = 5
        self.adjacency = cyclic_digraph(self.n)
        self.truth = np.ones(self.n) / self.n

    def test_params(self):
        with self.assertRaises(ValueError):
            PageRank(damping_factor=1789)

    def test_solvers(self):
        for name in ['piteration', 'lanczos', 'bicgstab', 'RH']:
            result = PageRank(solver=name).fit_predict(self.adjacency)
            self.assertAlmostEqual(0, np.linalg.norm(result - self.truth))
        with self.assertRaises(ValueError):
            PageRank(solver='toto').fit_predict(self.adjacency)

    def test_seeding(self):
        model = PageRank()
        # The same restart distribution, given as an array and as a dict.
        as_array = np.zeros(self.n)
        as_array[0] = 1.
        as_dict = {0: 1}

        from_array = model.fit_transform(self.adjacency, as_array)
        from_dict = model.fit_transform(self.adjacency, as_dict)
        self.assertAlmostEqual(np.linalg.norm(from_array - from_dict), 0.)

    def test_input(self):
        model = PageRank()
        result = model.fit_predict(self.adjacency, force_bipartite=True)
        self.assertEqual(len(result), len(model.scores_col_))

    def test_damping(self):
        # On a cycle, the scores stay uniform whatever the damping factor.
        for damping in (0.99, 0.01):
            result = PageRank(damping_factor=damping).fit_transform(self.adjacency)
            self.assertAlmostEqual(np.linalg.norm(result - self.truth), 0.)

    def test_copagerank(self):
        seeds = {0: 1}
        biadjacency = test_bigraph()

        co_adjacency = co_neighbor_graph(biadjacency, method='exact', normalized=True)
        bipartite_model = PageRank(damping_factor=0.85, solver='lanczos')
        bipartite_model.fit(biadjacency, seeds_row=seeds)
        scores_bipartite = bipartite_model.scores_row_
        scores_bipartite /= scores_bipartite.sum()
        # Two steps of the bipartite walk correspond to one step on the co-neighbor graph.
        scores_co = PageRank(damping_factor=0.85**2, solver='lanczos').fit_transform(co_adjacency, seeds)
        self.assertAlmostEqual(np.linalg.norm(scores_bipartite - scores_co), 0., places=6)
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""
|
|
4
|
+
Created on April 2022
|
|
5
|
+
@author: Thomas Bonald <bonald@enst.fr>
|
|
6
|
+
"""
|
|
7
|
+
from abc import ABC
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
|
|
11
|
+
from sknetwork.utils.base import Algorithm
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class BaseRegressor(Algorithm, ABC):
    """Base class for regression algorithms.

    Attributes
    ----------
    values_ : np.ndarray
        Value of each node.
    values_row_: np.ndarray
        Values of rows, for bipartite graphs.
    values_col_: np.ndarray
        Values of columns, for bipartite graphs.
    """
    def __init__(self):
        self.values_ = None

    def fit_predict(self, *args, **kwargs) -> np.ndarray:
        """Fit algorithm to data and return the values. Same parameters as the ``fit`` method.

        Returns
        -------
        values : np.ndarray
            Values.
        """
        self.fit(*args, **kwargs)
        return self.values_

    def fit_transform(self, *args, **kwargs) -> np.ndarray:
        """Fit algorithm to data and return the values. Alias for ``fit_predict``.
        Same parameters as the ``fit`` method.

        Returns
        -------
        values : np.ndarray
            Values.
        """
        self.fit(*args, **kwargs)
        return self.values_

    def _split_vars(self, shape):
        """Split the values of a bipartite graph into row values and column values.

        ``values_`` is expected to hold the values of rows first, then columns.

        Parameters
        ----------
        shape :
            Shape of the biadjacency matrix; only the number of rows ``shape[0]`` is used.
        """
        n_row = shape[0]
        self.values_row_ = self.values_[:n_row]
        self.values_col_ = self.values_[n_row:]
        # The main output for a bipartite graph is the values of rows.
        self.values_ = self.values_row_
|