sclab 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sclab might be problematic.

Files changed (51)
  1. sclab/__init__.py +1 -1
  2. sclab/_sclab.py +10 -3
  3. sclab/dataset/_dataset.py +1 -1
  4. sclab/examples/processor_steps/__init__.py +2 -0
  5. sclab/examples/processor_steps/_doublet_detection.py +68 -0
  6. sclab/examples/processor_steps/_integration.py +37 -4
  7. sclab/examples/processor_steps/_neighbors.py +24 -4
  8. sclab/examples/processor_steps/_pca.py +5 -5
  9. sclab/examples/processor_steps/_preprocess.py +14 -1
  10. sclab/examples/processor_steps/_qc.py +22 -6
  11. sclab/gui/__init__.py +0 -0
  12. sclab/gui/components/__init__.py +5 -0
  13. sclab/gui/components/_guided_pseudotime.py +482 -0
  14. sclab/methods/__init__.py +25 -1
  15. sclab/preprocess/__init__.py +18 -0
  16. sclab/preprocess/_cca.py +154 -0
  17. sclab/preprocess/_cca_integrate.py +77 -0
  18. sclab/preprocess/_filter_obs.py +42 -0
  19. sclab/preprocess/_harmony.py +421 -0
  20. sclab/preprocess/_harmony_integrate.py +50 -0
  21. sclab/preprocess/_normalize_weighted.py +61 -0
  22. sclab/preprocess/_subset.py +208 -0
  23. sclab/preprocess/_transfer_metadata.py +137 -0
  24. sclab/preprocess/_transform.py +82 -0
  25. sclab/preprocess/_utils.py +96 -0
  26. sclab/tools/__init__.py +0 -0
  27. sclab/tools/cellflow/__init__.py +0 -0
  28. sclab/tools/cellflow/density_dynamics/__init__.py +0 -0
  29. sclab/tools/cellflow/density_dynamics/_density_dynamics.py +349 -0
  30. sclab/tools/cellflow/pseudotime/__init__.py +0 -0
  31. sclab/tools/cellflow/pseudotime/_pseudotime.py +332 -0
  32. sclab/tools/cellflow/pseudotime/timeseries.py +226 -0
  33. sclab/tools/cellflow/utils/__init__.py +0 -0
  34. sclab/tools/cellflow/utils/density_nd.py +136 -0
  35. sclab/tools/cellflow/utils/interpolate.py +334 -0
  36. sclab/tools/cellflow/utils/smoothen.py +124 -0
  37. sclab/tools/cellflow/utils/times.py +55 -0
  38. sclab/tools/differential_expression/__init__.py +5 -0
  39. sclab/tools/differential_expression/_pseudobulk_edger.py +304 -0
  40. sclab/tools/differential_expression/_pseudobulk_helpers.py +277 -0
  41. sclab/tools/doublet_detection/__init__.py +5 -0
  42. sclab/tools/doublet_detection/_scrublet.py +64 -0
  43. sclab/tools/labeling/__init__.py +6 -0
  44. sclab/tools/labeling/sctype.py +233 -0
  45. sclab/utils/__init__.py +5 -0
  46. sclab/utils/_write_excel.py +510 -0
  47. {sclab-0.2.4.dist-info → sclab-0.3.0.dist-info}/METADATA +7 -2
  48. sclab-0.3.0.dist-info/RECORD +81 -0
  49. sclab-0.2.4.dist-info/RECORD +0 -45
  50. {sclab-0.2.4.dist-info → sclab-0.3.0.dist-info}/WHEEL +0 -0
  51. {sclab-0.2.4.dist-info → sclab-0.3.0.dist-info}/licenses/LICENSE +0 -0
sclab/preprocess/_cca.py
@@ -0,0 +1,154 @@
+ import logging
+ from typing import Literal
+
+ import numpy as np
+ from numpy import matrix
+ from numpy.typing import NDArray
+ from scipy.linalg import svd
+ from scipy.sparse import csc_matrix, csr_matrix, issparse
+ from scipy.sparse.linalg import svds
+ from sklearn.utils.extmath import randomized_svd
+
+ logger = logging.getLogger(__name__)
+
+
+ def cca(
+     X: NDArray | csr_matrix | csc_matrix,
+     Y: NDArray | csr_matrix | csc_matrix,
+     n_components=None,
+     svd_solver: Literal["full", "partial", "randomized"] = "partial",
+     normalize: bool = False,
+     random_state=42,
+ ):
+     """
+     CCA-style integration for two single-cell matrices with unequal numbers of cells.
+
+     Parameters
+     ----------
+     X, Y : array-like, shape (n_cells, n_features)
+         Cell-by-feature matrices with the same feature space (variable genes/PCs) in the same order.
+     n_components : int or None
+         Dimensionality of the canonical space (default = all that the smaller
+         dataset allows).
+     svd_solver : {'full', 'partial', 'randomized'}
+         'randomized' uses the Halko et al. algorithm (`sklearn.utils.extmath.randomized_svd`)
+         and is strongly recommended when only the leading few components are needed.
+     random_state : int or None
+         Passed through to the randomized SVD for reproducibility.
+
+     Returns
+     -------
+     U : (n_cells(X), k) ndarray
+     V : (n_cells(Y), k) ndarray
+         Cell-level canonical variables.
+     """
+     n1, p1 = X.shape
+     n2, p2 = Y.shape
+     if p1 != p2:
+         raise ValueError("The two matrices must have the same number of features.")
+
+     k = n_components or min(n1, n2)
+
+     if issparse(X):
+         C = _cross_covariance_sparse(X, Y)
+     else:
+         C = _cross_covariance_dense(X, Y)
+
+     logger.info(f"Cross-covariance computed. Shape: {C.shape}")
+
+     Uc, s, Vct = _svd_decomposition(C, k, svd_solver, random_state)
+
+     # canonical variables
+     # Left and right singular vectors are cell embeddings
+     U = Uc  # (n1 x k)
+     V = Vct.T  # (n2 x k)
+
+     if normalize:
+         logger.info("Normalizing canonical variables...")
+         U = U / np.linalg.norm(U, axis=1, keepdims=True)
+         V = V / np.linalg.norm(V, axis=1, keepdims=True)
+
+     logger.info("Done.")
+
+     return U, s, V
+
+
+ def _svd_decomposition(
+     C: NDArray,
+     k: int,
+     svd_solver: Literal["full", "partial", "randomized"],
+     random_state: int | None,
+ ) -> tuple[NDArray, NDArray, NDArray]:
+     if svd_solver == "full":
+         logger.info("SVD decomposition with full SVD...")
+         Uc, s, Vct = svd(C, full_matrices=False)
+         Uc, s, Vct = Uc[:, :k], s[:k], Vct[:k, :]
+
+     elif svd_solver == "partial":
+         logger.info("SVD decomposition with partial SVD...")
+         Uc, s, Vct = svds(C, k=k)
+
+     elif svd_solver == "randomized":
+         logger.info("SVD decomposition with randomized SVD...")
+         Uc, s, Vct = randomized_svd(C, n_components=k, random_state=random_state)
+
+     else:
+         raise ValueError("svd_solver must be 'full', 'partial', or 'randomized'.")
+
+     order = np.argsort(-s)
+     s = s[order]
+     Uc = Uc[:, order]
+     Vct = Vct[order, :]
+
+     return Uc, s, Vct
+
+
+ def _cross_covariance_sparse(X: csr_matrix, Y: csr_matrix) -> NDArray:
+     _, p1 = X.shape
+     _, p2 = Y.shape
+     if p1 != p2:
+         raise ValueError("The two matrices must have the same number of features.")
+
+     p = p1
+
+     # TODO: incorporate sparse scaling
+
+     logger.info("Computing cross-covariance on sparse matrices...")
+
+     mux: matrix = X.mean(axis=0)
+     muy: matrix = Y.mean(axis=0)
+
+     XYt: csr_matrix = X.dot(Y.T)
+     Xmuyt: matrix = X.dot(muy.T)
+     muxYt: matrix = Y.dot(mux.T).T
+     muxmuyt: float = (mux @ muy.T)[0, 0]
+
+     C = (XYt - Xmuyt - muxYt + muxmuyt) / (p - 1)
+
+     return np.asarray(C)
+
+
+ def _cross_covariance_dense(X: NDArray, Y: NDArray) -> NDArray:
+     _, p1 = X.shape
+     _, p2 = Y.shape
+     if p1 != p2:
+         raise ValueError("The two matrices must have the same number of features.")
+
+     p = p1
+
+     logger.info("Computing cross-covariance on dense matrices...")
+     X = _dense_scale(X)
+     Y = _dense_scale(Y)
+
+     X = X - X.mean(axis=0, keepdims=True)
+     Y = Y - Y.mean(axis=0, keepdims=True)
+
+     C: NDArray = (X @ Y.T) / (p - 1)
+
+     return C
+
+
+ def _dense_scale(A: NDArray) -> NDArray:
+     A = np.asarray(A)
+     eps = np.finfo(A.dtype).eps
+     return A / (A.std(axis=0, ddof=1, keepdims=True) + eps)
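
A minimal usage sketch for the new cca helper. The import uses the private module path shown in the file list above; the matrix sizes, component count, and random data are made up for illustration, so treat this as a sketch rather than documented API usage.

import numpy as np
from sclab.preprocess._cca import cca  # module path taken from the file list above

# Made-up example: two batches with different cell counts but the same
# 2,000 features (e.g. highly variable genes), as cell-by-feature arrays.
rng = np.random.default_rng(0)
X = rng.poisson(1.0, size=(500, 2000)).astype(float)  # batch 1: 500 cells
Y = rng.poisson(1.0, size=(800, 2000)).astype(float)  # batch 2: 800 cells

# U and V are cell-level canonical variables for X and Y; s holds the
# singular values of the cross-covariance between the two batches.
U, s, V = cca(X, Y, n_components=20, svd_solver="partial", normalize=True)
print(U.shape, s.shape, V.shape)  # (500, 20) (20,) (800, 20)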
sclab/preprocess/_cca_integrate.py
@@ -0,0 +1,77 @@
+ import numpy as np
+ from anndata import AnnData
+
+ from ._cca import cca
+
+
+ def cca_integrate_pair(
+     adata: AnnData,
+     key: str,
+     group1: str,
+     group2: str,
+     *,
+     basis: str | None = None,
+     adjusted_basis: str | None = None,
+     mask_var: str | None = None,
+     n_components: int = 50,
+     svd_solver: str = "partial",
+     normalize: bool = False,
+     random_state: int | None = None,
+ ):
+     if basis is None:
+         basis = "X"
+
+     if adjusted_basis is None:
+         adjusted_basis = basis + "_cca"
+
+     if mask_var is not None:
+         mask = adata.var[mask_var].values
+     else:
+         mask = np.ones(adata.n_vars, dtype=bool)
+
+     Xs = {}
+     groups = adata.obs.groupby(key, observed=True).groups
+     for gr, idx in groups.items():
+         Xs[gr] = _get_basis(adata[idx, mask], basis)
+
+     Ys = {}
+     Ys[group1], sigma, Ys[group2] = cca(
+         Xs[group1],
+         Xs[group2],
+         n_components=n_components,
+         svd_solver=svd_solver,
+         normalize=normalize,
+         random_state=random_state,
+     )
+
+     if (
+         adjusted_basis not in adata.obsm
+         or adata.obsm[adjusted_basis].shape[1] != n_components
+     ):
+         adata.obsm[adjusted_basis] = np.full((adata.n_obs, n_components), np.nan)
+
+     if adjusted_basis not in adata.uns:
+         adata.uns[adjusted_basis] = {}
+
+     uns = adata.uns[adjusted_basis]
+     uns[f"{group1}-{group2}"] = {"sigma": sigma}
+     for gr, obs_names in groups.items():
+         idx = adata.obs_names.get_indexer(obs_names)
+         adata.obsm[adjusted_basis][idx] = Ys[gr]
+         uns[gr] = Ys[gr]
+
+
+ def _get_basis(adata: AnnData, basis: str):
+     if basis == "X":
+         X = adata.X
+
+     elif basis in adata.layers:
+         X = adata.layers[basis]
+
+     elif basis in adata.obsm:
+         X = adata.obsm[basis]
+
+     else:
+         raise ValueError(f"Unknown basis {basis}")
+
+     return X
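
A rough sketch of how cca_integrate_pair might be called. The AnnData object, the "batch" column with exactly two groups, the "X_pca" embedding, and the import path are all assumptions made for this example, not values taken from the package.

import numpy as np
import pandas as pd
from anndata import AnnData
from sclab.preprocess._cca_integrate import cca_integrate_pair  # assumed module path

# Toy AnnData: 300 cells in two batches, plus a precomputed PCA-like embedding.
rng = np.random.default_rng(0)
adata = AnnData(
    X=rng.poisson(1.0, size=(300, 100)).astype(float),
    obs=pd.DataFrame({"batch": ["A"] * 150 + ["B"] * 150}),
)
adata.obsm["X_pca"] = rng.normal(size=(300, 30))

# Writes paired canonical variables into adata.obsm["X_pca_cca"] and the
# singular values into adata.uns["X_pca_cca"]["A-B"].
cca_integrate_pair(adata, "batch", "A", "B", basis="X_pca", n_components=20)
print(adata.obsm["X_pca_cca"].shape)  # (300, 20)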
sclab/preprocess/_filter_obs.py
@@ -0,0 +1,42 @@
+ import numpy as np
+ from anndata import AnnData
+ from scipy.stats import rankdata
+
+
+ def filter_obs(
+     adata: AnnData,
+     *,
+     layer: str | None = None,
+     min_counts: int | None = None,
+     min_genes: int | None = None,
+     max_counts: int | None = None,
+     max_cells: int | None = None,
+ ) -> None:
+     if layer is not None:
+         X = adata.layers[layer]
+     else:
+         X = adata.X
+
+     remove_mask = np.zeros(X.shape[0], dtype=bool)
+
+     if min_genes is not None:
+         M = X > 0
+         rowsums = np.asarray(M.sum(axis=1)).squeeze()
+         remove_mask[rowsums < min_genes] = True
+
+     if min_counts is not None or max_counts is not None or max_cells is not None:
+         rowsums = np.asarray(X.sum(axis=1)).squeeze()
+
+         if min_counts is not None:
+             remove_mask[rowsums < min_counts] = True
+
+         if max_counts is not None:
+             remove_mask[rowsums > max_counts] = True
+
+         if max_cells is not None:
+             ranks = rankdata(-rowsums, method="min")
+             remove_mask[ranks > max_cells] = True
+
+     if remove_mask.any():
+         obs_idx = adata.obs_names[~remove_mask]
+         adata._inplace_subset_obs(obs_idx)
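
A small illustrative call to filter_obs; the import path, toy matrix, and thresholds below are assumed for the example. The function filters cells in place via AnnData._inplace_subset_obs.

import numpy as np
from anndata import AnnData
from sclab.preprocess._filter_obs import filter_obs  # assumed module path

# Toy counts matrix: 100 cells x 50 genes.
rng = np.random.default_rng(0)
adata = AnnData(X=rng.poisson(0.5, size=(100, 50)).astype(float))

# Drop cells with fewer than 10 total counts or fewer than 5 detected genes,
# then keep only cells ranked in the top 80 by total counts.
filter_obs(adata, min_counts=10, min_genes=5, max_cells=80)
print(adata.n_obs)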
sclab/preprocess/_harmony.py
@@ -0,0 +1,421 @@
+ # harmonypy - A data alignment algorithm.
+ # Copyright (C) 2018 Ilya Korsunsky
+ # 2019 Kamil Slowikowski <kslowikowski@gmail.com>
+ #
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU General Public License as published by
+ # the Free Software Foundation, either version 3 of the License, or
+ # (at your option) any later version.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU General Public License for more details.
+ #
+ # You should have received a copy of the GNU General Public License
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+ from functools import partial
+ import pandas as pd
+ import numpy as np
+ from sklearn.cluster import KMeans
+ import logging
+
+ # create logger
+ logger = logging.getLogger("harmonypy")
+ logger.setLevel(logging.DEBUG)
+ ch = logging.StreamHandler()
+ ch.setLevel(logging.DEBUG)
+ formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+ ch.setFormatter(formatter)
+ logger.addHandler(ch)
+
+ # from IPython.core.debugger import set_trace
+
+
+ def run_harmony(
+     data_mat: np.ndarray,
+     meta_data: pd.DataFrame,
+     vars_use,
+     theta=None,
+     lamb=None,
+     sigma=0.1,
+     nclust=None,
+     tau=0,
+     block_size=0.05,
+     max_iter_harmony=10,
+     max_iter_kmeans=20,
+     epsilon_cluster=1e-5,
+     epsilon_harmony=1e-4,
+     plot_convergence=False,
+     verbose=True,
+     reference_values=None,
+     cluster_prior=None,
+     random_state=0,
+     cluster_fn="kmeans",
+ ):
+     """Run Harmony."""
+
+     # theta = None
+     # lamb = None
+     # sigma = 0.1
+     # nclust = None
+     # tau = 0
+     # block_size = 0.05
+     # epsilon_cluster = 1e-5
+     # epsilon_harmony = 1e-4
+     # plot_convergence = False
+     # verbose = True
+     # reference_values = None
+     # cluster_prior = None
+     # random_state = 0
+     # cluster_fn = 'kmeans'. Also accepts a callable object with data, num_clusters parameters
+
+     N = meta_data.shape[0]
+     if data_mat.shape[1] != N:
+         data_mat = data_mat.T
+
+     assert data_mat.shape[1] == N, (
+         "data_mat and meta_data do not have the same number of cells"
+     )
+
+     if nclust is None:
+         nclust = np.min([np.round(N / 30.0), 100]).astype(int)
+
+     if type(sigma) is float and nclust > 1:
+         sigma = np.repeat(sigma, nclust)
+
+     if isinstance(vars_use, str):
+         vars_use = [vars_use]
+
+     phi = pd.get_dummies(meta_data[vars_use]).to_numpy().T
+     phi_n = meta_data[vars_use].describe().loc["unique"].to_numpy().astype(int)
+
+     if theta is None:
+         theta = np.repeat([1] * len(phi_n), phi_n)
+     elif isinstance(theta, float) or isinstance(theta, int):
+         theta = np.repeat([theta] * len(phi_n), phi_n)
+     elif len(theta) == len(phi_n):
+         theta = np.repeat([theta], phi_n)
+
+     assert len(theta) == np.sum(phi_n), "each batch variable must have a theta"
+
+     if lamb is None:
+         lamb = np.repeat([1] * len(phi_n), phi_n)
+     elif isinstance(lamb, float) or isinstance(lamb, int):
+         lamb = np.repeat([lamb] * len(phi_n), phi_n)
+     elif len(lamb) == len(phi_n):
+         lamb = np.repeat([lamb], phi_n)
+
+     assert len(lamb) == np.sum(phi_n), "each batch variable must have a lambda"
+
+     # Number of items in each category.
+     N_b = phi.sum(axis=1)
+     # Proportion of items in each category.
+     Pr_b = N_b / N
+
+     if tau > 0:
+         theta = theta * (1 - np.exp(-((N_b / (nclust * tau)) ** 2)))
+
+     lamb_mat = np.diag(np.insert(lamb, 0, 0))
+
+     phi_moe = np.vstack((np.repeat(1, N), phi))
+
+     np.random.seed(random_state)
+
+     ho = Harmony(
+         data_mat,
+         phi,
+         phi_moe,
+         Pr_b,
+         sigma,
+         theta,
+         max_iter_harmony,
+         max_iter_kmeans,
+         epsilon_cluster,
+         epsilon_harmony,
+         nclust,
+         block_size,
+         lamb_mat,
+         verbose,
+         random_state,
+         cluster_fn,
+         reference_values,
+     )
+
+     return ho
+
+
+ class Harmony(object):
+     def __init__(
+         self,
+         Z,
+         Phi,
+         Phi_moe,
+         Pr_b,
+         sigma,
+         theta,
+         max_iter_harmony,
+         max_iter_kmeans,
+         epsilon_kmeans,
+         epsilon_harmony,
+         K,
+         block_size,
+         lamb,
+         verbose,
+         random_state=None,
+         cluster_fn="kmeans",
+         frozen_values=None,
+     ):
+         self.Z_corr = np.array(Z)
+         self.Z_orig = np.array(Z)
+
+         self.Z_cos = self.Z_orig / self.Z_orig.max(axis=0)
+         self.Z_cos = self.Z_cos / np.linalg.norm(self.Z_cos, ord=2, axis=0)
+
+         self.Phi = Phi
+         self.Phi_moe = Phi_moe
+         self.N = self.Z_corr.shape[1]
+         self.Pr_b = Pr_b
+         self.B = self.Phi.shape[0]  # number of batch variables
+         self.d = self.Z_corr.shape[0]
+         self.window_size = 3
+         self.epsilon_kmeans = epsilon_kmeans
+         self.epsilon_harmony = epsilon_harmony
+         self.reference_values = frozen_values
+
+         self.lamb = lamb
+         self.sigma = sigma
+         self.sigma_prior = sigma
+         self.block_size = block_size
+         self.K = K  # number of clusters
+         self.max_iter_harmony = max_iter_harmony
+         self.max_iter_kmeans = max_iter_kmeans
+         self.verbose = verbose
+         self.theta = theta
+
+         self.objective_harmony = []
+         self.objective_kmeans = []
+         self.objective_kmeans_dist = []
+         self.objective_kmeans_entropy = []
+         self.objective_kmeans_cross = []
+         self.kmeans_rounds = []
+
+         self.allocate_buffers()
+         if cluster_fn == "kmeans":
+             cluster_fn = partial(Harmony._cluster_kmeans, random_state=random_state)
+         self.init_cluster(cluster_fn)
+         self.harmonize(self.max_iter_harmony, self.verbose)
+
+     def result(self):
+         return self.Z_corr
+
+     def allocate_buffers(self):
+         self._scale_dist = np.zeros((self.K, self.N))
+         self.dist_mat = np.zeros((self.K, self.N))
+         self.O = np.zeros((self.K, self.B))
+         self.E = np.zeros((self.K, self.B))
+         self.W = np.zeros((self.B + 1, self.d))
+         self.Phi_Rk = np.zeros((self.B + 1, self.N))
+
+     @staticmethod
+     def _cluster_kmeans(data, K, random_state):
+         # Start with cluster centroids
+         logger.info("Computing initial centroids with sklearn.KMeans...")
+         model = KMeans(
+             n_clusters=K,
+             init="k-means++",
+             n_init=10,
+             max_iter=25,
+             random_state=random_state,
+         )
+         model.fit(data)
+         km_centroids, km_labels = model.cluster_centers_, model.labels_
+         logger.info("sklearn.KMeans initialization complete.")
+         return km_centroids
+
+     def init_cluster(self, cluster_fn):
+         self.Y = cluster_fn(self.Z_cos.T, self.K).T
+         # (1) Normalize
+         self.Y = self.Y / np.linalg.norm(self.Y, ord=2, axis=0)
+         # (2) Assign cluster probabilities
+         self.dist_mat = 2 * (1 - np.dot(self.Y.T, self.Z_cos))
+         self.R = -self.dist_mat
+         self.R = self.R / self.sigma[:, None]
+         self.R -= np.max(self.R, axis=0)
+         self.R = np.exp(self.R)
+         self.R = self.R / np.sum(self.R, axis=0)
+         # (3) Batch diversity statistics
+         self.E = np.outer(np.sum(self.R, axis=1), self.Pr_b)
+         self.O = np.inner(self.R, self.Phi)
+         self.compute_objective()
+         # Save results
+         self.objective_harmony.append(self.objective_kmeans[-1])
+
+     def compute_objective(self):
+         kmeans_error = np.sum(self.R * self.dist_mat)
+         # Entropy
+         _entropy = np.sum(safe_entropy(self.R) * self.sigma[:, np.newaxis])
+         # Cross Entropy
+         x = self.R * self.sigma[:, np.newaxis]
+         y = np.tile(self.theta[:, np.newaxis], self.K).T
+         z = np.log((self.O + 1) / (self.E + 1))
+         w = np.dot(y * z, self.Phi)
+         _cross_entropy = np.sum(x * w)
+         # Save results
+         # print(f"{kmeans_error=}, {_entropy=}, {_cross_entropy=}")
+         self.objective_kmeans.append(kmeans_error + _entropy + _cross_entropy)
+         self.objective_kmeans_dist.append(kmeans_error)
+         self.objective_kmeans_entropy.append(_entropy)
+         self.objective_kmeans_cross.append(_cross_entropy)
+
+     def harmonize(self, iter_harmony=10, verbose=True):
+         converged = False
+         for i in range(1, iter_harmony + 1):
+             if verbose:
+                 # logger.info("Iteration {} of {}".format(i, iter_harmony))
+                 pass
+             # STEP 1: Clustering
+             self.cluster()
+             # STEP 2: Regress out covariates
+             # self.moe_correct_ridge()
+             self.Z_cos, self.Z_corr, self.W, self.Phi_Rk = moe_correct_ridge(
+                 self.Z_orig,
+                 self.Z_cos,
+                 self.Z_corr,
+                 self.R,
+                 self.W,
+                 self.K,
+                 self.Phi_Rk,
+                 self.Phi_moe,
+                 self.lamb,
+                 self.reference_values,
+             )
+             # STEP 3: Check for convergence
+             converged = self.check_convergence(1)
+             if converged:
+                 if verbose:
+                     logger.info(
+                         "Converged after {} iteration{}".format(i, "s" if i > 1 else "")
+                     )
+                 break
+         if verbose and not converged:
+             logger.info("Stopped before convergence")
+         return 0
+
+     def cluster(self):
+         # Z_cos has changed
+         # R is assumed to not have changed
+         # Update Y to match new integrated data
+         self.dist_mat = 2 * (1 - np.dot(self.Y.T, self.Z_cos))
+         for i in range(self.max_iter_kmeans):
+             # print("kmeans {}".format(i))
+             # STEP 1: Update Y
+             self.Y = np.dot(self.Z_cos, self.R.T)
+             self.Y = self.Y / np.linalg.norm(self.Y, ord=2, axis=0)
+             # STEP 2: Update dist_mat
+             self.dist_mat = 2 * (1 - np.dot(self.Y.T, self.Z_cos))
+             # STEP 3: Update R
+             self.update_R()
+             # STEP 4: Check for convergence
+             self.compute_objective()
+             if i > self.window_size:
+                 converged = self.check_convergence(0)
+                 if converged:
+                     break
+         self.kmeans_rounds.append(i)
+         self.objective_harmony.append(self.objective_kmeans[-1])
+         return 0
+
+     def update_R(self):
+         self._scale_dist = -self.dist_mat
+         self._scale_dist = self._scale_dist / self.sigma[:, None]
+         self._scale_dist -= np.max(self._scale_dist, axis=0)
+         self._scale_dist = np.exp(self._scale_dist)
+         # Update cells in blocks
+         update_order = np.arange(self.N)
+         np.random.shuffle(update_order)
+         n_blocks = np.ceil(1 / self.block_size).astype(int)
+         blocks = np.array_split(update_order, n_blocks)
+         for b in blocks:
+             # STEP 1: Remove cells
+             self.E -= np.outer(np.sum(self.R[:, b], axis=1), self.Pr_b)
+             self.O -= np.dot(self.R[:, b], self.Phi[:, b].T)
+             # STEP 2: Recompute R for removed cells
+             self.R[:, b] = self._scale_dist[:, b]
+             self.R[:, b] = np.multiply(
+                 self.R[:, b],
+                 np.dot(
+                     np.power((self.E + 1) / (self.O + 1), self.theta), self.Phi[:, b]
+                 ),
+             )
+             self.R[:, b] = self.R[:, b] / np.linalg.norm(self.R[:, b], ord=1, axis=0)
+             # STEP 3: Put cells back
+             self.E += np.outer(np.sum(self.R[:, b], axis=1), self.Pr_b)
+             self.O += np.dot(self.R[:, b], self.Phi[:, b].T)
+         return 0
+
+     def check_convergence(self, i_type):
+         obj_old = 0.0
+         obj_new = 0.0
+         # Clustering, compute new window mean
+         if i_type == 0:
+             okl = len(self.objective_kmeans)
+             for i in range(self.window_size):
+                 obj_old += self.objective_kmeans[okl - 2 - i]
+                 obj_new += self.objective_kmeans[okl - 1 - i]
+             if (score := (abs(obj_old - obj_new) / abs(obj_old))) < self.epsilon_kmeans:
+                 return True
+             # logger.info("Score: {} >= {}".format(score, self.epsilon_kmeans))
+             return False
+         # Harmony
+         if i_type == 1:
+             obj_old = self.objective_harmony[-2]
+             obj_new = self.objective_harmony[-1]
+             if (
+                 score := (abs(obj_old - obj_new) / abs(obj_old))
+             ) < self.epsilon_harmony:
+                 # logger.info("Score: {} >= {}".format(score, self.epsilon_harmony))
+                 return True
+             # logger.info("Score: {} >= {}".format(score, self.epsilon_harmony))
+             return False
+         return True
+
+
+ def safe_entropy(x: np.array):
+     y = np.multiply(x, np.log(x))
+     y[~np.isfinite(y)] = 0.0
+     return y
+
+
+ def moe_correct_ridge(
+     Z_orig, Z_cos, Z_corr, R, W, K, Phi_Rk, Phi_moe, lamb, frozen_values=None
+ ):
+     """
+     Z_orig, Z_cos, Z_corr: DxN
+     R: KxN
+     W: (B+1)xD
+     Phi_moe: (B+1)xN
+     lamb: (B+1)x(B+1) diag matrix
+     """
+     Z_corr = Z_orig.copy()
+
+     if frozen_values is not None:
+         update_mask = ~frozen_values
+     else:
+         update_mask = np.ones(Z_orig.shape[1], dtype=bool)
+
+     for i in range(K):
+         # standard design
+         Phi_Rk = Phi_moe * R[i, :]
+
+         # ridge regression to get W
+         x = Phi_Rk @ Phi_moe.T + lamb
+         W = np.linalg.inv(x) @ Phi_Rk @ Z_orig.T
+         W[0, :] = 0  # don't remove intercept
+
+         # apply correction
+         Z_corr[:, update_mask] -= (W.T @ Phi_Rk)[:, update_mask]
+
+     Z_cos = Z_corr / np.linalg.norm(Z_corr, ord=2, axis=0)
+     return Z_cos, Z_corr, W, Phi_Rk
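
This vendored copy of harmonypy differs from upstream mainly in threading reference_values/frozen_values through to moe_correct_ridge, so cells flagged as reference are left uncorrected. A hedged usage sketch follows; the embedding, batch column, and import path are assumptions for illustration, and reference_values is left at its default of None.

import numpy as np
import pandas as pd
from sclab.preprocess._harmony import run_harmony  # assumed module path

# Made-up PCA embedding (cells x PCs) and batch labels for two batches.
rng = np.random.default_rng(0)
Z = rng.normal(size=(400, 20))
meta = pd.DataFrame({"batch": ["A"] * 200 + ["B"] * 200})

# run_harmony transposes cells-by-PCs input to PCs-by-cells internally;
# the corrected embedding comes back in .Z_corr (PCs x cells).
ho = run_harmony(Z, meta, ["batch"])
Z_corrected = ho.Z_corr.T  # back to cells x PCs
print(Z_corrected.shape)  # (400, 20)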