sclab 0.2.5__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sclab/__init__.py +1 -1
- sclab/dataset/_dataset.py +1 -1
- sclab/examples/processor_steps/__init__.py +2 -0
- sclab/examples/processor_steps/_doublet_detection.py +68 -0
- sclab/examples/processor_steps/_integration.py +37 -4
- sclab/examples/processor_steps/_neighbors.py +24 -4
- sclab/examples/processor_steps/_pca.py +5 -5
- sclab/examples/processor_steps/_preprocess.py +14 -1
- sclab/examples/processor_steps/_qc.py +22 -6
- sclab/gui/__init__.py +0 -0
- sclab/gui/components/__init__.py +5 -0
- sclab/gui/components/_guided_pseudotime.py +482 -0
- sclab/methods/__init__.py +25 -1
- sclab/preprocess/__init__.py +18 -0
- sclab/preprocess/_cca.py +154 -0
- sclab/preprocess/_cca_integrate.py +77 -0
- sclab/preprocess/_filter_obs.py +42 -0
- sclab/preprocess/_harmony.py +421 -0
- sclab/preprocess/_harmony_integrate.py +50 -0
- sclab/preprocess/_normalize_weighted.py +61 -0
- sclab/preprocess/_subset.py +208 -0
- sclab/preprocess/_transfer_metadata.py +137 -0
- sclab/preprocess/_transform.py +82 -0
- sclab/preprocess/_utils.py +96 -0
- sclab/tools/__init__.py +0 -0
- sclab/tools/cellflow/__init__.py +0 -0
- sclab/tools/cellflow/density_dynamics/__init__.py +0 -0
- sclab/tools/cellflow/density_dynamics/_density_dynamics.py +349 -0
- sclab/tools/cellflow/pseudotime/__init__.py +0 -0
- sclab/tools/cellflow/pseudotime/_pseudotime.py +332 -0
- sclab/tools/cellflow/pseudotime/timeseries.py +226 -0
- sclab/tools/cellflow/utils/__init__.py +0 -0
- sclab/tools/cellflow/utils/density_nd.py +136 -0
- sclab/tools/cellflow/utils/interpolate.py +334 -0
- sclab/tools/cellflow/utils/smoothen.py +124 -0
- sclab/tools/cellflow/utils/times.py +55 -0
- sclab/tools/differential_expression/__init__.py +5 -0
- sclab/tools/differential_expression/_pseudobulk_edger.py +304 -0
- sclab/tools/differential_expression/_pseudobulk_helpers.py +277 -0
- sclab/tools/doublet_detection/__init__.py +5 -0
- sclab/tools/doublet_detection/_scrublet.py +64 -0
- sclab/tools/labeling/__init__.py +6 -0
- sclab/tools/labeling/sctype.py +233 -0
- sclab/utils/__init__.py +5 -0
- sclab/utils/_write_excel.py +510 -0
- {sclab-0.2.5.dist-info → sclab-0.3.0.dist-info}/METADATA +6 -2
- sclab-0.3.0.dist-info/RECORD +81 -0
- sclab-0.2.5.dist-info/RECORD +0 -45
- {sclab-0.2.5.dist-info → sclab-0.3.0.dist-info}/WHEEL +0 -0
- {sclab-0.2.5.dist-info → sclab-0.3.0.dist-info}/licenses/LICENSE +0 -0
sclab/preprocess/_cca.py
ADDED
@@ -0,0 +1,154 @@
+import logging
+from typing import Literal
+
+import numpy as np
+from numpy import matrix
+from numpy.typing import NDArray
+from scipy.linalg import svd
+from scipy.sparse import csc_matrix, csr_matrix, issparse
+from scipy.sparse.linalg import svds
+from sklearn.utils.extmath import randomized_svd
+
+logger = logging.getLogger(__name__)
+
+
+def cca(
+    X: NDArray | csr_matrix | csc_matrix,
+    Y: NDArray | csr_matrix | csc_matrix,
+    n_components=None,
+    svd_solver: Literal["full", "partial", "randomized"] = "partial",
+    normalize: bool = False,
+    random_state=42,
+):
+    """
+    CCA-style integration for two single-cell matrices with unequal numbers of cells.
+
+    Parameters
+    ----------
+    X, Y : array-like, shape (n_cells, n_features)
+        Cell-by-feature matrices with the same features (variable genes/PCs)
+        in the same order.
+    n_components : int or None
+        Dimensionality of the canonical space (default = all that the smaller
+        dataset allows).
+    svd_solver : {'full', 'partial', 'randomized'}
+        'randomized' uses the Halko et al. algorithm
+        (`sklearn.utils.extmath.randomized_svd`) and is strongly recommended
+        when only the leading few components are needed.
+    normalize : bool
+        If True, L2-normalize each cell's canonical variables.
+    random_state : int or None
+        Passed through to the randomized SVD for reproducibility.
+
+    Returns
+    -------
+    U : (n_cells(X), k) ndarray
+        Cell-level canonical variables for X.
+    s : (k,) ndarray
+        Singular values of the cross-covariance matrix.
+    V : (n_cells(Y), k) ndarray
+        Cell-level canonical variables for Y.
+    """
+    n1, p1 = X.shape
+    n2, p2 = Y.shape
+    if p1 != p2:
+        raise ValueError("The two matrices must have the same number of features.")
+
+    k = n_components or min(n1, n2)
+
+    if issparse(X):
+        C = _cross_covariance_sparse(X, Y)
+    else:
+        C = _cross_covariance_dense(X, Y)
+
+    logger.info(f"Cross-covariance computed. Shape: {C.shape}")
+
+    Uc, s, Vct = _svd_decomposition(C, k, svd_solver, random_state)
+
+    # canonical variables
+    # Left and right singular vectors are cell embeddings
+    U = Uc  # (n1 x k)
+    V = Vct.T  # (n2 x k)
+
+    if normalize:
+        logger.info("Normalizing canonical variables...")
+        U = U / np.linalg.norm(U, axis=1, keepdims=True)
+        V = V / np.linalg.norm(V, axis=1, keepdims=True)
+
+    logger.info("Done.")
+
+    return U, s, V
+
+
+def _svd_decomposition(
+    C: NDArray,
+    k: int,
+    svd_solver: Literal["full", "partial", "randomized"],
+    random_state: int | None,
+) -> tuple[NDArray, NDArray, NDArray]:
+    if svd_solver == "full":
+        logger.info("SVD decomposition with full SVD...")
+        Uc, s, Vct = svd(C, full_matrices=False)
+        Uc, s, Vct = Uc[:, :k], s[:k], Vct[:k, :]
+
+    elif svd_solver == "partial":
+        logger.info("SVD decomposition with partial SVD...")
+        Uc, s, Vct = svds(C, k=k)
+
+    elif svd_solver == "randomized":
+        logger.info("SVD decomposition with randomized SVD...")
+        Uc, s, Vct = randomized_svd(C, n_components=k, random_state=random_state)
+
+    else:
+        raise ValueError("svd_solver must be 'full', 'partial' or 'randomized'.")
+
+    # svds returns components in ascending order; sort every solver's output
+    # by decreasing singular value
+    order = np.argsort(-s)
+    s = s[order]
+    Uc = Uc[:, order]
+    Vct = Vct[order, :]
+
+    return Uc, s, Vct
+
+
+def _cross_covariance_sparse(X: csr_matrix, Y: csr_matrix) -> NDArray:
+    _, p1 = X.shape
+    _, p2 = Y.shape
+    if p1 != p2:
+        raise ValueError("The two matrices must have the same number of features.")
+
+    p = p1
+
+    # TODO: incorporate sparse scaling
+
+    logger.info("Computing cross-covariance on sparse matrices...")
+
+    # expand (X - mux)(Y - muy)^T term by term to avoid densifying X or Y
+    mux: matrix = X.mean(axis=0)
+    muy: matrix = Y.mean(axis=0)
+
+    XYt: csr_matrix = X.dot(Y.T)
+    Xmuyt: matrix = X.dot(muy.T)
+    muxYt: matrix = Y.dot(mux.T).T
+    muxmuyt: float = (mux @ muy.T)[0, 0]
+
+    C = (XYt - Xmuyt - muxYt + muxmuyt) / (p - 1)
+
+    return np.asarray(C)
+
+
+def _cross_covariance_dense(X: NDArray, Y: NDArray) -> NDArray:
+    _, p1 = X.shape
+    _, p2 = Y.shape
+    if p1 != p2:
+        raise ValueError("The two matrices must have the same number of features.")
+
+    p = p1
+
+    logger.info("Computing cross-covariance on dense matrices...")
+    X = _dense_scale(X)
+    Y = _dense_scale(Y)
+
+    X = X - X.mean(axis=0, keepdims=True)
+    Y = Y - Y.mean(axis=0, keepdims=True)
+
+    C: NDArray = (X @ Y.T) / (p - 1)
+
+    return C
+
+
+def _dense_scale(A: NDArray) -> NDArray:
+    A = np.asarray(A)
+    eps = np.finfo(A.dtype).eps
+    return A / (A.std(axis=0, ddof=1, keepdims=True) + eps)
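A minimal usage sketch for `cca` as added above; the toy matrices and shapes are illustrative, not from the package:

    import numpy as np

    from sclab.preprocess._cca import cca

    rng = np.random.default_rng(0)
    X = rng.normal(size=(300, 2000))  # 300 cells x 2000 shared variable genes
    Y = rng.normal(size=(500, 2000))  # 500 cells, same genes in the same order

    # U: (300, 20), s: (20,), V: (500, 20), sorted by decreasing singular value
    U, s, V = cca(X, Y, n_components=20, svd_solver="randomized", random_state=0)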
sclab/preprocess/_cca_integrate.py
ADDED
@@ -0,0 +1,77 @@
+import numpy as np
+from anndata import AnnData
+
+from ._cca import cca
+
+
+def cca_integrate_pair(
+    adata: AnnData,
+    key: str,
+    group1: str,
+    group2: str,
+    *,
+    basis: str | None = None,
+    adjusted_basis: str | None = None,
+    mask_var: str | None = None,
+    n_components: int = 50,
+    svd_solver: str = "partial",
+    normalize: bool = False,
+    random_state: int | None = None,
+):
+    if basis is None:
+        basis = "X"
+
+    if adjusted_basis is None:
+        adjusted_basis = basis + "_cca"
+
+    if mask_var is not None:
+        mask = adata.var[mask_var].values
+    else:
+        mask = np.ones(adata.n_vars, dtype=bool)
+
+    Xs = {}
+    groups = adata.obs.groupby(key, observed=True).groups
+    for gr, idx in groups.items():
+        Xs[gr] = _get_basis(adata[idx, mask], basis)
+
+    Ys = {}
+    Ys[group1], sigma, Ys[group2] = cca(
+        Xs[group1],
+        Xs[group2],
+        n_components=n_components,
+        svd_solver=svd_solver,
+        normalize=normalize,
+        random_state=random_state,
+    )
+
+    if (
+        adjusted_basis not in adata.obsm
+        or adata.obsm[adjusted_basis].shape[1] != n_components
+    ):
+        adata.obsm[adjusted_basis] = np.full((adata.n_obs, n_components), np.nan)
+
+    if adjusted_basis not in adata.uns:
+        adata.uns[adjusted_basis] = {}
+
+    uns = adata.uns[adjusted_basis]
+    uns[f"{group1}-{group2}"] = {"sigma": sigma}
+    for gr, obs_names in groups.items():
+        idx = adata.obs_names.get_indexer(obs_names)
+        adata.obsm[adjusted_basis][idx] = Ys[gr]
+        uns[gr] = Ys[gr]
+
+
+def _get_basis(adata: AnnData, basis: str):
+    if basis == "X":
+        X = adata.X
+
+    elif basis in adata.layers:
+        X = adata.layers[basis]
+
+    elif basis in adata.obsm:
+        X = adata.obsm[basis]
+
+    else:
+        raise ValueError(f"Unknown basis {basis}")
+
+    return X
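A hedged sketch of calling `cca_integrate_pair` on a toy AnnData. Note the write-back loop indexes `Ys[gr]` for every group under `key`, so this assumes `key` has exactly the two groups being integrated:

    import numpy as np
    import pandas as pd
    from anndata import AnnData

    from sclab.preprocess._cca_integrate import cca_integrate_pair

    rng = np.random.default_rng(0)
    adata = AnnData(
        X=rng.poisson(1.0, size=(200, 100)).astype(float),
        obs=pd.DataFrame({"batch": ["a"] * 120 + ["b"] * 80}),
    )

    # writes the joint embedding to adata.obsm["X_cca"] (basis defaults to "X")
    cca_integrate_pair(adata, "batch", "a", "b", n_components=10)
    print(adata.obsm["X_cca"].shape)  # (200, 10)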
sclab/preprocess/_filter_obs.py
ADDED
@@ -0,0 +1,42 @@
+import numpy as np
+from anndata import AnnData
+from scipy.stats import rankdata
+
+
+def filter_obs(
+    adata: AnnData,
+    *,
+    layer: str | None = None,
+    min_counts: int | None = None,
+    min_genes: int | None = None,
+    max_counts: int | None = None,
+    max_cells: int | None = None,
+) -> None:
+    if layer is not None:
+        X = adata.layers[layer]
+    else:
+        X = adata.X
+
+    remove_mask = np.zeros(X.shape[0], dtype=bool)
+
+    if min_genes is not None:
+        M = X > 0
+        rowsums = np.asarray(M.sum(axis=1)).squeeze()
+        remove_mask[rowsums < min_genes] = True
+
+    if min_counts is not None or max_counts is not None or max_cells is not None:
+        rowsums = np.asarray(X.sum(axis=1)).squeeze()
+
+        if min_counts is not None:
+            remove_mask[rowsums < min_counts] = True
+
+        if max_counts is not None:
+            remove_mask[rowsums > max_counts] = True
+
+        if max_cells is not None:
+            ranks = rankdata(-rowsums, method="min")
+            remove_mask[ranks > max_cells] = True
+
+    if remove_mask.any():
+        obs_idx = adata.obs_names[~remove_mask]
+        adata._inplace_subset_obs(obs_idx)
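For illustration, a small sketch of `filter_obs` on a toy AnnData; it subsets cells in place via `adata._inplace_subset_obs`, and `max_cells` keeps at most that many cells ranked by total counts:

    import numpy as np
    from anndata import AnnData

    from sclab.preprocess._filter_obs import filter_obs

    adata = AnnData(X=np.random.default_rng(0).poisson(1.0, size=(100, 50)).astype(float))

    # drop cells with < 10 detected genes or < 30 total counts,
    # then keep at most the 80 highest-count cells
    filter_obs(adata, min_genes=10, min_counts=30, max_cells=80)
    print(adata.n_obs)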
sclab/preprocess/_harmony.py
ADDED
@@ -0,0 +1,421 @@
+# harmonypy - A data alignment algorithm.
+# Copyright (C) 2018  Ilya Korsunsky
+#               2019  Kamil Slowikowski <kslowikowski@gmail.com>
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from functools import partial
+import pandas as pd
+import numpy as np
+from sklearn.cluster import KMeans
+import logging
+
+# create logger
+logger = logging.getLogger("harmonypy")
+logger.setLevel(logging.DEBUG)
+ch = logging.StreamHandler()
+ch.setLevel(logging.DEBUG)
+formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+ch.setFormatter(formatter)
+logger.addHandler(ch)
+
+# from IPython.core.debugger import set_trace
+
+
+def run_harmony(
+    data_mat: np.ndarray,
+    meta_data: pd.DataFrame,
+    vars_use,
+    theta=None,
+    lamb=None,
+    sigma=0.1,
+    nclust=None,
+    tau=0,
+    block_size=0.05,
+    max_iter_harmony=10,
+    max_iter_kmeans=20,
+    epsilon_cluster=1e-5,
+    epsilon_harmony=1e-4,
+    plot_convergence=False,
+    verbose=True,
+    reference_values=None,
+    cluster_prior=None,
+    random_state=0,
+    cluster_fn="kmeans",
+):
+    """Run Harmony."""
+
+    # theta = None
+    # lamb = None
+    # sigma = 0.1
+    # nclust = None
+    # tau = 0
+    # block_size = 0.05
+    # epsilon_cluster = 1e-5
+    # epsilon_harmony = 1e-4
+    # plot_convergence = False
+    # verbose = True
+    # reference_values = None
+    # cluster_prior = None
+    # random_state = 0
+    # cluster_fn = 'kmeans'. Also accepts a callable object with data, num_clusters parameters
+
+    N = meta_data.shape[0]
+    if data_mat.shape[1] != N:
+        data_mat = data_mat.T
+
+    assert data_mat.shape[1] == N, (
+        "data_mat and meta_data do not have the same number of cells"
+    )
+
+    if nclust is None:
+        nclust = np.min([np.round(N / 30.0), 100]).astype(int)
+
+    if type(sigma) is float and nclust > 1:
+        sigma = np.repeat(sigma, nclust)
+
+    if isinstance(vars_use, str):
+        vars_use = [vars_use]
+
+    phi = pd.get_dummies(meta_data[vars_use]).to_numpy().T
+    phi_n = meta_data[vars_use].describe().loc["unique"].to_numpy().astype(int)
+
+    if theta is None:
+        theta = np.repeat([1] * len(phi_n), phi_n)
+    elif isinstance(theta, float) or isinstance(theta, int):
+        theta = np.repeat([theta] * len(phi_n), phi_n)
+    elif len(theta) == len(phi_n):
+        theta = np.repeat([theta], phi_n)
+
+    assert len(theta) == np.sum(phi_n), "each batch variable must have a theta"
+
+    if lamb is None:
+        lamb = np.repeat([1] * len(phi_n), phi_n)
+    elif isinstance(lamb, float) or isinstance(lamb, int):
+        lamb = np.repeat([lamb] * len(phi_n), phi_n)
+    elif len(lamb) == len(phi_n):
+        lamb = np.repeat([lamb], phi_n)
+
+    assert len(lamb) == np.sum(phi_n), "each batch variable must have a lambda"
+
+    # Number of items in each category.
+    N_b = phi.sum(axis=1)
+    # Proportion of items in each category.
+    Pr_b = N_b / N
+
+    if tau > 0:
+        theta = theta * (1 - np.exp(-((N_b / (nclust * tau)) ** 2)))
+
+    lamb_mat = np.diag(np.insert(lamb, 0, 0))
+
+    phi_moe = np.vstack((np.repeat(1, N), phi))
+
+    np.random.seed(random_state)
+
+    ho = Harmony(
+        data_mat,
+        phi,
+        phi_moe,
+        Pr_b,
+        sigma,
+        theta,
+        max_iter_harmony,
+        max_iter_kmeans,
+        epsilon_cluster,
+        epsilon_harmony,
+        nclust,
+        block_size,
+        lamb_mat,
+        verbose,
+        random_state,
+        cluster_fn,
+        reference_values,
+    )
+
+    return ho
+
+
+class Harmony(object):
+    def __init__(
+        self,
+        Z,
+        Phi,
+        Phi_moe,
+        Pr_b,
+        sigma,
+        theta,
+        max_iter_harmony,
+        max_iter_kmeans,
+        epsilon_kmeans,
+        epsilon_harmony,
+        K,
+        block_size,
+        lamb,
+        verbose,
+        random_state=None,
+        cluster_fn="kmeans",
+        frozen_values=None,
+    ):
+        self.Z_corr = np.array(Z)
+        self.Z_orig = np.array(Z)
+
+        self.Z_cos = self.Z_orig / self.Z_orig.max(axis=0)
+        self.Z_cos = self.Z_cos / np.linalg.norm(self.Z_cos, ord=2, axis=0)
+
+        self.Phi = Phi
+        self.Phi_moe = Phi_moe
+        self.N = self.Z_corr.shape[1]
+        self.Pr_b = Pr_b
+        self.B = self.Phi.shape[0]  # number of batch variables
+        self.d = self.Z_corr.shape[0]
+        self.window_size = 3
+        self.epsilon_kmeans = epsilon_kmeans
+        self.epsilon_harmony = epsilon_harmony
+        self.reference_values = frozen_values
+
+        self.lamb = lamb
+        self.sigma = sigma
+        self.sigma_prior = sigma
+        self.block_size = block_size
+        self.K = K  # number of clusters
+        self.max_iter_harmony = max_iter_harmony
+        self.max_iter_kmeans = max_iter_kmeans
+        self.verbose = verbose
+        self.theta = theta
+
+        self.objective_harmony = []
+        self.objective_kmeans = []
+        self.objective_kmeans_dist = []
+        self.objective_kmeans_entropy = []
+        self.objective_kmeans_cross = []
+        self.kmeans_rounds = []
+
+        self.allocate_buffers()
+        if cluster_fn == "kmeans":
+            cluster_fn = partial(Harmony._cluster_kmeans, random_state=random_state)
+        self.init_cluster(cluster_fn)
+        self.harmonize(self.max_iter_harmony, self.verbose)
+
+    def result(self):
+        return self.Z_corr
+
+    def allocate_buffers(self):
+        self._scale_dist = np.zeros((self.K, self.N))
+        self.dist_mat = np.zeros((self.K, self.N))
+        self.O = np.zeros((self.K, self.B))
+        self.E = np.zeros((self.K, self.B))
+        self.W = np.zeros((self.B + 1, self.d))
+        self.Phi_Rk = np.zeros((self.B + 1, self.N))
+
+    @staticmethod
+    def _cluster_kmeans(data, K, random_state):
+        # Start with cluster centroids
+        logger.info("Computing initial centroids with sklearn.KMeans...")
+        model = KMeans(
+            n_clusters=K,
+            init="k-means++",
+            n_init=10,
+            max_iter=25,
+            random_state=random_state,
+        )
+        model.fit(data)
+        km_centroids, km_labels = model.cluster_centers_, model.labels_
+        logger.info("sklearn.KMeans initialization complete.")
+        return km_centroids
+
+    def init_cluster(self, cluster_fn):
+        self.Y = cluster_fn(self.Z_cos.T, self.K).T
+        # (1) Normalize
+        self.Y = self.Y / np.linalg.norm(self.Y, ord=2, axis=0)
+        # (2) Assign cluster probabilities
+        self.dist_mat = 2 * (1 - np.dot(self.Y.T, self.Z_cos))
+        self.R = -self.dist_mat
+        self.R = self.R / self.sigma[:, None]
+        self.R -= np.max(self.R, axis=0)
+        self.R = np.exp(self.R)
+        self.R = self.R / np.sum(self.R, axis=0)
+        # (3) Batch diversity statistics
+        self.E = np.outer(np.sum(self.R, axis=1), self.Pr_b)
+        self.O = np.inner(self.R, self.Phi)
+        self.compute_objective()
+        # Save results
+        self.objective_harmony.append(self.objective_kmeans[-1])
+
+    def compute_objective(self):
+        kmeans_error = np.sum(self.R * self.dist_mat)
+        # Entropy
+        _entropy = np.sum(safe_entropy(self.R) * self.sigma[:, np.newaxis])
+        # Cross Entropy
+        x = self.R * self.sigma[:, np.newaxis]
+        y = np.tile(self.theta[:, np.newaxis], self.K).T
+        z = np.log((self.O + 1) / (self.E + 1))
+        w = np.dot(y * z, self.Phi)
+        _cross_entropy = np.sum(x * w)
+        # Save results
+        # print(f"{kmeans_error=}, {_entropy=}, {_cross_entropy=}")
+        self.objective_kmeans.append(kmeans_error + _entropy + _cross_entropy)
+        self.objective_kmeans_dist.append(kmeans_error)
+        self.objective_kmeans_entropy.append(_entropy)
+        self.objective_kmeans_cross.append(_cross_entropy)
+
+    def harmonize(self, iter_harmony=10, verbose=True):
+        converged = False
+        for i in range(1, iter_harmony + 1):
+            if verbose:
+                # logger.info("Iteration {} of {}".format(i, iter_harmony))
+                pass
+            # STEP 1: Clustering
+            self.cluster()
+            # STEP 2: Regress out covariates
+            # self.moe_correct_ridge()
+            self.Z_cos, self.Z_corr, self.W, self.Phi_Rk = moe_correct_ridge(
+                self.Z_orig,
+                self.Z_cos,
+                self.Z_corr,
+                self.R,
+                self.W,
+                self.K,
+                self.Phi_Rk,
+                self.Phi_moe,
+                self.lamb,
+                self.reference_values,
+            )
+            # STEP 3: Check for convergence
+            converged = self.check_convergence(1)
+            if converged:
+                if verbose:
+                    logger.info(
+                        "Converged after {} iteration{}".format(i, "s" if i > 1 else "")
+                    )
+                break
+        if verbose and not converged:
+            logger.info("Stopped before convergence")
+        return 0
+
+    def cluster(self):
+        # Z_cos has changed
+        # R is assumed to not have changed
+        # Update Y to match new integrated data
+        self.dist_mat = 2 * (1 - np.dot(self.Y.T, self.Z_cos))
+        for i in range(self.max_iter_kmeans):
+            # print("kmeans {}".format(i))
+            # STEP 1: Update Y
+            self.Y = np.dot(self.Z_cos, self.R.T)
+            self.Y = self.Y / np.linalg.norm(self.Y, ord=2, axis=0)
+            # STEP 2: Update dist_mat
+            self.dist_mat = 2 * (1 - np.dot(self.Y.T, self.Z_cos))
+            # STEP 3: Update R
+            self.update_R()
+            # STEP 4: Check for convergence
+            self.compute_objective()
+            if i > self.window_size:
+                converged = self.check_convergence(0)
+                if converged:
+                    break
+        self.kmeans_rounds.append(i)
+        self.objective_harmony.append(self.objective_kmeans[-1])
+        return 0
+
+    def update_R(self):
+        self._scale_dist = -self.dist_mat
+        self._scale_dist = self._scale_dist / self.sigma[:, None]
+        self._scale_dist -= np.max(self._scale_dist, axis=0)
+        self._scale_dist = np.exp(self._scale_dist)
+        # Update cells in blocks
+        update_order = np.arange(self.N)
+        np.random.shuffle(update_order)
+        n_blocks = np.ceil(1 / self.block_size).astype(int)
+        blocks = np.array_split(update_order, n_blocks)
+        for b in blocks:
+            # STEP 1: Remove cells
+            self.E -= np.outer(np.sum(self.R[:, b], axis=1), self.Pr_b)
+            self.O -= np.dot(self.R[:, b], self.Phi[:, b].T)
+            # STEP 2: Recompute R for removed cells
+            self.R[:, b] = self._scale_dist[:, b]
+            self.R[:, b] = np.multiply(
+                self.R[:, b],
+                np.dot(
+                    np.power((self.E + 1) / (self.O + 1), self.theta), self.Phi[:, b]
+                ),
+            )
+            self.R[:, b] = self.R[:, b] / np.linalg.norm(self.R[:, b], ord=1, axis=0)
+            # STEP 3: Put cells back
+            self.E += np.outer(np.sum(self.R[:, b], axis=1), self.Pr_b)
+            self.O += np.dot(self.R[:, b], self.Phi[:, b].T)
+        return 0
+
+    def check_convergence(self, i_type):
+        obj_old = 0.0
+        obj_new = 0.0
+        # Clustering, compute new window mean
+        if i_type == 0:
+            okl = len(self.objective_kmeans)
+            for i in range(self.window_size):
+                obj_old += self.objective_kmeans[okl - 2 - i]
+                obj_new += self.objective_kmeans[okl - 1 - i]
+            if (score := (abs(obj_old - obj_new) / abs(obj_old))) < self.epsilon_kmeans:
+                return True
+            # logger.info("Score: {} >= {}".format(score, self.epsilon_kmeans))
+            return False
+        # Harmony
+        if i_type == 1:
+            obj_old = self.objective_harmony[-2]
+            obj_new = self.objective_harmony[-1]
+            if (
+                score := (abs(obj_old - obj_new) / abs(obj_old))
+            ) < self.epsilon_harmony:
+                # logger.info("Score: {} >= {}".format(score, self.epsilon_harmony))
+                return True
+            # logger.info("Score: {} >= {}".format(score, self.epsilon_harmony))
+            return False
+        return True
+
+
+def safe_entropy(x: np.array):
+    y = np.multiply(x, np.log(x))
+    y[~np.isfinite(y)] = 0.0
+    return y
+
+
+def moe_correct_ridge(
+    Z_orig, Z_cos, Z_corr, R, W, K, Phi_Rk, Phi_moe, lamb, frozen_values=None
+):
+    """
+    Z_orig, Z_cos, Z_corr: DxN
+    R: KxN
+    W: (B+1)xD
+    Phi_moe: (B+1)xN
+    lamb: (B+1)x(B+1) diag matrix
+    """
+    Z_corr = Z_orig.copy()
+
+    if frozen_values is not None:
+        update_mask = ~frozen_values
+    else:
+        update_mask = np.ones(Z_orig.shape[1], dtype=bool)
+
+    for i in range(K):
+        # standard design
+        Phi_Rk = Phi_moe * R[i, :]
+
+        # ridge regression to get W
+        x = Phi_Rk @ Phi_moe.T + lamb
+        W = np.linalg.inv(x) @ Phi_Rk @ Z_orig.T
+        W[0, :] = 0  # don't remove intercept
+
+        # apply correction
+        Z_corr[:, update_mask] -= (W.T @ Phi_Rk)[:, update_mask]
+
+    Z_cos = Z_corr / np.linalg.norm(Z_corr, ord=2, axis=0)
+    return Z_cos, Z_corr, W, Phi_Rk
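A minimal sketch of driving this vendored harmonypy implementation directly from a PCA embedding; the toy data is illustrative, and in practice the embedding would come from e.g. adata.obsm["X_pca"]:

    import numpy as np
    import pandas as pd

    from sclab.preprocess._harmony import run_harmony

    rng = np.random.default_rng(0)
    Z = rng.normal(size=(300, 20))  # cells x PCs
    meta = pd.DataFrame({"batch": rng.choice(["a", "b"], size=300)})

    ho = run_harmony(Z, meta, vars_use="batch")  # transposed internally to PCs x cells
    Z_corr = ho.result().T                       # batch-corrected embedding, cells x PCs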