sclab 0.1.7__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. sclab/__init__.py +3 -1
  2. sclab/_io.py +83 -12
  3. sclab/_methods_registry.py +65 -0
  4. sclab/_sclab.py +241 -21
  5. sclab/dataset/_dataset.py +4 -6
  6. sclab/dataset/processor/_processor.py +41 -19
  7. sclab/dataset/processor/_results_panel.py +94 -0
  8. sclab/dataset/processor/step/_processor_step_base.py +12 -6
  9. sclab/examples/processor_steps/__init__.py +8 -0
  10. sclab/examples/processor_steps/_cluster.py +2 -2
  11. sclab/examples/processor_steps/_differential_expression.py +329 -0
  12. sclab/examples/processor_steps/_doublet_detection.py +68 -0
  13. sclab/examples/processor_steps/_gene_expression.py +125 -0
  14. sclab/examples/processor_steps/_integration.py +116 -0
  15. sclab/examples/processor_steps/_neighbors.py +26 -6
  16. sclab/examples/processor_steps/_pca.py +13 -8
  17. sclab/examples/processor_steps/_preprocess.py +52 -25
  18. sclab/examples/processor_steps/_qc.py +24 -8
  19. sclab/examples/processor_steps/_umap.py +2 -2
  20. sclab/gui/__init__.py +0 -0
  21. sclab/gui/components/__init__.py +7 -0
  22. sclab/gui/components/_guided_pseudotime.py +482 -0
  23. sclab/gui/components/_transfer_metadata.py +186 -0
  24. sclab/methods/__init__.py +50 -0
  25. sclab/preprocess/__init__.py +26 -0
  26. sclab/preprocess/_cca.py +176 -0
  27. sclab/preprocess/_cca_integrate.py +109 -0
  28. sclab/preprocess/_filter_obs.py +42 -0
  29. sclab/preprocess/_harmony.py +421 -0
  30. sclab/preprocess/_harmony_integrate.py +53 -0
  31. sclab/preprocess/_normalize_weighted.py +65 -0
  32. sclab/preprocess/_pca.py +51 -0
  33. sclab/preprocess/_preprocess.py +155 -0
  34. sclab/preprocess/_qc.py +38 -0
  35. sclab/preprocess/_rpca.py +116 -0
  36. sclab/preprocess/_subset.py +208 -0
  37. sclab/preprocess/_transfer_metadata.py +196 -0
  38. sclab/preprocess/_transform.py +82 -0
  39. sclab/preprocess/_utils.py +96 -0
  40. sclab/scanpy/__init__.py +0 -0
  41. sclab/scanpy/_compat.py +92 -0
  42. sclab/scanpy/_settings.py +526 -0
  43. sclab/scanpy/logging.py +290 -0
  44. sclab/scanpy/plotting/__init__.py +0 -0
  45. sclab/scanpy/plotting/_rcmod.py +73 -0
  46. sclab/scanpy/plotting/palettes.py +221 -0
  47. sclab/scanpy/readwrite.py +1108 -0
  48. sclab/tools/__init__.py +0 -0
  49. sclab/tools/cellflow/__init__.py +0 -0
  50. sclab/tools/cellflow/density_dynamics/__init__.py +0 -0
  51. sclab/tools/cellflow/density_dynamics/_density_dynamics.py +349 -0
  52. sclab/tools/cellflow/pseudotime/__init__.py +0 -0
  53. sclab/tools/cellflow/pseudotime/_pseudotime.py +336 -0
  54. sclab/tools/cellflow/pseudotime/timeseries.py +226 -0
  55. sclab/tools/cellflow/utils/__init__.py +0 -0
  56. sclab/tools/cellflow/utils/density_nd.py +215 -0
  57. sclab/tools/cellflow/utils/interpolate.py +334 -0
  58. sclab/tools/cellflow/utils/periodic_genes.py +106 -0
  59. sclab/tools/cellflow/utils/smoothen.py +124 -0
  60. sclab/tools/cellflow/utils/times.py +55 -0
  61. sclab/tools/differential_expression/__init__.py +7 -0
  62. sclab/tools/differential_expression/_pseudobulk_edger.py +309 -0
  63. sclab/tools/differential_expression/_pseudobulk_helpers.py +290 -0
  64. sclab/tools/differential_expression/_pseudobulk_limma.py +257 -0
  65. sclab/tools/doublet_detection/__init__.py +5 -0
  66. sclab/tools/doublet_detection/_scrublet.py +64 -0
  67. sclab/tools/embedding/__init__.py +0 -0
  68. sclab/tools/imputation/__init__.py +0 -0
  69. sclab/tools/imputation/_alra.py +135 -0
  70. sclab/tools/labeling/__init__.py +6 -0
  71. sclab/tools/labeling/sctype.py +233 -0
  72. sclab/tools/utils/__init__.py +5 -0
  73. sclab/tools/utils/_aggregate_and_filter.py +290 -0
  74. sclab/utils/__init__.py +5 -0
  75. sclab/utils/_write_excel.py +510 -0
  76. {sclab-0.1.7.dist-info → sclab-0.3.4.dist-info}/METADATA +29 -12
  77. sclab-0.3.4.dist-info/RECORD +93 -0
  78. {sclab-0.1.7.dist-info → sclab-0.3.4.dist-info}/WHEEL +1 -1
  79. sclab-0.3.4.dist-info/licenses/LICENSE +29 -0
  80. sclab-0.1.7.dist-info/RECORD +0 -30
sclab/preprocess/_harmony.py
@@ -0,0 +1,421 @@
+ # harmonypy - A data alignment algorithm.
+ # Copyright (C) 2018 Ilya Korsunsky
+ #               2019 Kamil Slowikowski <kslowikowski@gmail.com>
+ #
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU General Public License as published by
+ # the Free Software Foundation, either version 3 of the License, or
+ # (at your option) any later version.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU General Public License for more details.
+ #
+ # You should have received a copy of the GNU General Public License
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+ from functools import partial
+ import pandas as pd
+ import numpy as np
+ from sklearn.cluster import KMeans
+ import logging
+
+ # create logger
+ logger = logging.getLogger("harmonypy")
+ logger.setLevel(logging.DEBUG)
+ ch = logging.StreamHandler()
+ ch.setLevel(logging.DEBUG)
+ formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+ ch.setFormatter(formatter)
+ logger.addHandler(ch)
+
+ # from IPython.core.debugger import set_trace
+
+
+ def run_harmony(
+     data_mat: np.ndarray,
+     meta_data: pd.DataFrame,
+     vars_use,
+     theta=None,
+     lamb=None,
+     sigma=0.1,
+     nclust=None,
+     tau=0,
+     block_size=0.05,
+     max_iter_harmony=10,
+     max_iter_kmeans=20,
+     epsilon_cluster=1e-5,
+     epsilon_harmony=1e-4,
+     plot_convergence=False,
+     verbose=True,
+     reference_values=None,
+     cluster_prior=None,
+     random_state=0,
+     cluster_fn="kmeans",
+ ):
+     """Run Harmony."""
+
+     # theta = None
+     # lamb = None
+     # sigma = 0.1
+     # nclust = None
+     # tau = 0
+     # block_size = 0.05
+     # epsilon_cluster = 1e-5
+     # epsilon_harmony = 1e-4
+     # plot_convergence = False
+     # verbose = True
+     # reference_values = None
+     # cluster_prior = None
+     # random_state = 0
+     # cluster_fn = 'kmeans'. Also accepts a callable object with data, num_clusters parameters
+
+     N = meta_data.shape[0]
+     if data_mat.shape[1] != N:
+         data_mat = data_mat.T
+
+     assert data_mat.shape[1] == N, (
+         "data_mat and meta_data do not have the same number of cells"
+     )
+
+     if nclust is None:
+         nclust = np.min([np.round(N / 30.0), 100]).astype(int)
+
+     if type(sigma) is float and nclust > 1:
+         sigma = np.repeat(sigma, nclust)
+
+     if isinstance(vars_use, str):
+         vars_use = [vars_use]
+
+     phi = pd.get_dummies(meta_data[vars_use]).to_numpy().T
+     phi_n = meta_data[vars_use].describe().loc["unique"].to_numpy().astype(int)
+
+     if theta is None:
+         theta = np.repeat([1] * len(phi_n), phi_n)
+     elif isinstance(theta, float) or isinstance(theta, int):
+         theta = np.repeat([theta] * len(phi_n), phi_n)
+     elif len(theta) == len(phi_n):
+         theta = np.repeat([theta], phi_n)
+
+     assert len(theta) == np.sum(phi_n), "each batch variable must have a theta"
+
+     if lamb is None:
+         lamb = np.repeat([1] * len(phi_n), phi_n)
+     elif isinstance(lamb, float) or isinstance(lamb, int):
+         lamb = np.repeat([lamb] * len(phi_n), phi_n)
+     elif len(lamb) == len(phi_n):
+         lamb = np.repeat([lamb], phi_n)
+
+     assert len(lamb) == np.sum(phi_n), "each batch variable must have a lambda"
+
+     # Number of items in each category.
+     N_b = phi.sum(axis=1)
+     # Proportion of items in each category.
+     Pr_b = N_b / N
+
+     if tau > 0:
+         theta = theta * (1 - np.exp(-((N_b / (nclust * tau)) ** 2)))
+
+     lamb_mat = np.diag(np.insert(lamb, 0, 0))
+
+     phi_moe = np.vstack((np.repeat(1, N), phi))
+
+     np.random.seed(random_state)
+
+     ho = Harmony(
+         data_mat,
+         phi,
+         phi_moe,
+         Pr_b,
+         sigma,
+         theta,
+         max_iter_harmony,
+         max_iter_kmeans,
+         epsilon_cluster,
+         epsilon_harmony,
+         nclust,
+         block_size,
+         lamb_mat,
+         verbose,
+         random_state,
+         cluster_fn,
+         reference_values,
+     )
+
+     return ho
+
+
+ class Harmony(object):
+     def __init__(
+         self,
+         Z,
+         Phi,
+         Phi_moe,
+         Pr_b,
+         sigma,
+         theta,
+         max_iter_harmony,
+         max_iter_kmeans,
+         epsilon_kmeans,
+         epsilon_harmony,
+         K,
+         block_size,
+         lamb,
+         verbose,
+         random_state=None,
+         cluster_fn="kmeans",
+         frozen_values=None,
+     ):
+         self.Z_corr = np.array(Z)
+         self.Z_orig = np.array(Z)
+
+         self.Z_cos = self.Z_orig / self.Z_orig.max(axis=0)
+         self.Z_cos = self.Z_cos / np.linalg.norm(self.Z_cos, ord=2, axis=0)
+
+         self.Phi = Phi
+         self.Phi_moe = Phi_moe
+         self.N = self.Z_corr.shape[1]
+         self.Pr_b = Pr_b
+         self.B = self.Phi.shape[0]  # number of batch variables
+         self.d = self.Z_corr.shape[0]
+         self.window_size = 3
+         self.epsilon_kmeans = epsilon_kmeans
+         self.epsilon_harmony = epsilon_harmony
+         self.reference_values = frozen_values
+
+         self.lamb = lamb
+         self.sigma = sigma
+         self.sigma_prior = sigma
+         self.block_size = block_size
+         self.K = K  # number of clusters
+         self.max_iter_harmony = max_iter_harmony
+         self.max_iter_kmeans = max_iter_kmeans
+         self.verbose = verbose
+         self.theta = theta
+
+         self.objective_harmony = []
+         self.objective_kmeans = []
+         self.objective_kmeans_dist = []
+         self.objective_kmeans_entropy = []
+         self.objective_kmeans_cross = []
+         self.kmeans_rounds = []
+
+         self.allocate_buffers()
+         if cluster_fn == "kmeans":
+             cluster_fn = partial(Harmony._cluster_kmeans, random_state=random_state)
+         self.init_cluster(cluster_fn)
+         self.harmonize(self.max_iter_harmony, self.verbose)
+
+     def result(self):
+         return self.Z_corr
+
+     def allocate_buffers(self):
+         self._scale_dist = np.zeros((self.K, self.N))
+         self.dist_mat = np.zeros((self.K, self.N))
+         self.O = np.zeros((self.K, self.B))
+         self.E = np.zeros((self.K, self.B))
+         self.W = np.zeros((self.B + 1, self.d))
+         self.Phi_Rk = np.zeros((self.B + 1, self.N))
+
+     @staticmethod
+     def _cluster_kmeans(data, K, random_state):
+         # Start with cluster centroids
+         logger.info("Computing initial centroids with sklearn.KMeans...")
+         model = KMeans(
+             n_clusters=K,
+             init="k-means++",
+             n_init=10,
+             max_iter=25,
+             random_state=random_state,
+         )
+         model.fit(data)
+         km_centroids, km_labels = model.cluster_centers_, model.labels_
+         logger.info("sklearn.KMeans initialization complete.")
+         return km_centroids
+
+     def init_cluster(self, cluster_fn):
+         self.Y = cluster_fn(self.Z_cos.T, self.K).T
+         # (1) Normalize
+         self.Y = self.Y / np.linalg.norm(self.Y, ord=2, axis=0)
+         # (2) Assign cluster probabilities
+         self.dist_mat = 2 * (1 - np.dot(self.Y.T, self.Z_cos))
+         self.R = -self.dist_mat
+         self.R = self.R / self.sigma[:, None]
+         self.R -= np.max(self.R, axis=0)
+         self.R = np.exp(self.R)
+         self.R = self.R / np.sum(self.R, axis=0)
+         # (3) Batch diversity statistics
+         self.E = np.outer(np.sum(self.R, axis=1), self.Pr_b)
+         self.O = np.inner(self.R, self.Phi)
+         self.compute_objective()
+         # Save results
+         self.objective_harmony.append(self.objective_kmeans[-1])
+
+     def compute_objective(self):
+         kmeans_error = np.sum(self.R * self.dist_mat)
+         # Entropy
+         _entropy = np.sum(safe_entropy(self.R) * self.sigma[:, np.newaxis])
+         # Cross Entropy
+         x = self.R * self.sigma[:, np.newaxis]
+         y = np.tile(self.theta[:, np.newaxis], self.K).T
+         z = np.log((self.O + 1) / (self.E + 1))
+         w = np.dot(y * z, self.Phi)
+         _cross_entropy = np.sum(x * w)
+         # Save results
+         # print(f"{kmeans_error=}, {_entropy=}, {_cross_entropy=}")
+         self.objective_kmeans.append(kmeans_error + _entropy + _cross_entropy)
+         self.objective_kmeans_dist.append(kmeans_error)
+         self.objective_kmeans_entropy.append(_entropy)
+         self.objective_kmeans_cross.append(_cross_entropy)
+
+     def harmonize(self, iter_harmony=10, verbose=True):
+         converged = False
+         for i in range(1, iter_harmony + 1):
+             if verbose:
+                 # logger.info("Iteration {} of {}".format(i, iter_harmony))
+                 pass
+             # STEP 1: Clustering
+             self.cluster()
+             # STEP 2: Regress out covariates
+             # self.moe_correct_ridge()
+             self.Z_cos, self.Z_corr, self.W, self.Phi_Rk = moe_correct_ridge(
+                 self.Z_orig,
+                 self.Z_cos,
+                 self.Z_corr,
+                 self.R,
+                 self.W,
+                 self.K,
+                 self.Phi_Rk,
+                 self.Phi_moe,
+                 self.lamb,
+                 self.reference_values,
+             )
+             # STEP 3: Check for convergence
+             converged = self.check_convergence(1)
+             if converged:
+                 if verbose:
+                     logger.info(
+                         "Converged after {} iteration{}".format(i, "s" if i > 1 else "")
+                     )
+                 break
+         if verbose and not converged:
+             logger.info("Stopped before convergence")
+         return 0
+
+     def cluster(self):
+         # Z_cos has changed
+         # R is assumed to not have changed
+         # Update Y to match new integrated data
+         self.dist_mat = 2 * (1 - np.dot(self.Y.T, self.Z_cos))
+         for i in range(self.max_iter_kmeans):
+             # print("kmeans {}".format(i))
+             # STEP 1: Update Y
+             self.Y = np.dot(self.Z_cos, self.R.T)
+             self.Y = self.Y / np.linalg.norm(self.Y, ord=2, axis=0)
+             # STEP 2: Update dist_mat
+             self.dist_mat = 2 * (1 - np.dot(self.Y.T, self.Z_cos))
+             # STEP 3: Update R
+             self.update_R()
+             # STEP 4: Check for convergence
+             self.compute_objective()
+             if i > self.window_size:
+                 converged = self.check_convergence(0)
+                 if converged:
+                     break
+         self.kmeans_rounds.append(i)
+         self.objective_harmony.append(self.objective_kmeans[-1])
+         return 0
+
+     def update_R(self):
+         self._scale_dist = -self.dist_mat
+         self._scale_dist = self._scale_dist / self.sigma[:, None]
+         self._scale_dist -= np.max(self._scale_dist, axis=0)
+         self._scale_dist = np.exp(self._scale_dist)
+         # Update cells in blocks
+         update_order = np.arange(self.N)
+         np.random.shuffle(update_order)
+         n_blocks = np.ceil(1 / self.block_size).astype(int)
+         blocks = np.array_split(update_order, n_blocks)
+         for b in blocks:
+             # STEP 1: Remove cells
+             self.E -= np.outer(np.sum(self.R[:, b], axis=1), self.Pr_b)
+             self.O -= np.dot(self.R[:, b], self.Phi[:, b].T)
+             # STEP 2: Recompute R for removed cells
+             self.R[:, b] = self._scale_dist[:, b]
+             self.R[:, b] = np.multiply(
+                 self.R[:, b],
+                 np.dot(
+                     np.power((self.E + 1) / (self.O + 1), self.theta), self.Phi[:, b]
+                 ),
+             )
+             self.R[:, b] = self.R[:, b] / np.linalg.norm(self.R[:, b], ord=1, axis=0)
+             # STEP 3: Put cells back
+             self.E += np.outer(np.sum(self.R[:, b], axis=1), self.Pr_b)
+             self.O += np.dot(self.R[:, b], self.Phi[:, b].T)
+         return 0
+
+     def check_convergence(self, i_type):
+         obj_old = 0.0
+         obj_new = 0.0
+         # Clustering, compute new window mean
+         if i_type == 0:
+             okl = len(self.objective_kmeans)
+             for i in range(self.window_size):
+                 obj_old += self.objective_kmeans[okl - 2 - i]
+                 obj_new += self.objective_kmeans[okl - 1 - i]
+             if (score := (abs(obj_old - obj_new) / abs(obj_old))) < self.epsilon_kmeans:
+                 return True
+             # logger.info("Score: {} >= {}".format(score, self.epsilon_kmeans))
+             return False
+         # Harmony
+         if i_type == 1:
+             obj_old = self.objective_harmony[-2]
+             obj_new = self.objective_harmony[-1]
+             if (
+                 score := (abs(obj_old - obj_new) / abs(obj_old))
+             ) < self.epsilon_harmony:
+                 # logger.info("Score: {} >= {}".format(score, self.epsilon_harmony))
+                 return True
+             # logger.info("Score: {} >= {}".format(score, self.epsilon_harmony))
+             return False
+         return True
+
+
+ def safe_entropy(x: np.array):
+     y = np.multiply(x, np.log(x))
+     y[~np.isfinite(y)] = 0.0
+     return y
+
+
+ def moe_correct_ridge(
+     Z_orig, Z_cos, Z_corr, R, W, K, Phi_Rk, Phi_moe, lamb, frozen_values=None
+ ):
+     """
+     Z_orig, Z_cos, Z_corr: DxN
+     R: KxN
+     W: (B+1)xD
+     Phi_moe: (B+1)xN
+     lamb: (B+1)x(B+1) diag matrix
+     """
+     Z_corr = Z_orig.copy()
+
+     if frozen_values is not None:
+         update_mask = ~frozen_values
+     else:
+         update_mask = np.ones(Z_orig.shape[1], dtype=bool)
+
+     for i in range(K):
+         # standard design
+         Phi_Rk = Phi_moe * R[i, :]
+
+         # ridge regression to get W
+         x = Phi_Rk @ Phi_moe.T + lamb
+         W = np.linalg.inv(x) @ Phi_Rk @ Z_orig.T
+         W[0, :] = 0  # don't remove intercept
+
+         # apply correction
+         Z_corr[:, update_mask] -= (W.T @ Phi_Rk)[:, update_mask]
+
+     Z_cos = Z_corr / np.linalg.norm(Z_corr, ord=2, axis=0)
+     return Z_cos, Z_corr, W, Phi_Rk
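
Usage note: `run_harmony` accepts the embedding in either orientation (it transposes to features x cells when the column count does not match `meta_data`) and returns the fitted `Harmony` object rather than a corrected matrix, so the result is read from `Z_corr`. A minimal sketch, with a synthetic embedding and batch table invented for illustration:

    import numpy as np
    import pandas as pd
    from sclab.preprocess._harmony import run_harmony

    rng = np.random.default_rng(0)
    embedding = rng.normal(size=(500, 20))                     # cells x PCs
    meta = pd.DataFrame({"batch": ["a"] * 250 + ["b"] * 250})  # one batch column

    ho = run_harmony(embedding, meta, "batch")  # vars_use may be a str or a list
    corrected = ho.Z_corr.T                     # back to cells x PCs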
sclab/preprocess/_harmony_integrate.py
@@ -0,0 +1,53 @@
+ """Use harmony to integrate cells from different experiments.
+
+ Note: code adapted from scanpy to use a custom version of harmonypy
+
+ Harmony:
+ Korsunsky, I., Millard, N., Fan, J. et al. Fast, sensitive and accurate integration of single-cell data with Harmony.
+ Nat Methods 16, 1289-1296 (2019). https://doi.org/10.1038/s41592-019-0619-0
+
+ Scanpy:
+ Wolf, F., Angerer, P. & Theis, F. SCANPY: large-scale single-cell gene expression data analysis.
+ Genome Biol 19, 15 (2018). https://doi.org/10.1186/s13059-017-1382-0
+
+ Scverse:
+ Virshup, I., Bredikhin, D., Heumos, L. et al. The scverse project provides a computational ecosystem for single-cell omics data analysis.
+ Nat Biotechnol 41, 604-606 (2023). https://doi.org/10.1038/s41587-023-01733-8
+ """
+
+ from collections.abc import Sequence
+
+ import numpy as np
+ from anndata import AnnData
+
+ from ._harmony import run_harmony
+
+
+ def harmony_integrate(
+     adata: AnnData,
+     key: str | Sequence[str],
+     *,
+     basis: str = "X_pca",
+     adjusted_basis: str | None = None,
+     reference_batch: str | list[str] | None = None,
+     **kwargs,
+ ):
+     """Use harmonypy :cite:p:`Korsunsky2019` to integrate different experiments."""
+
+     if adjusted_basis is None:
+         adjusted_basis = f"{basis}_harmony"
+
+     if isinstance(reference_batch, str):
+         reference_batch = [reference_batch]
+
+     if reference_batch is not None:
+         reference_values = np.zeros(adata.n_obs, dtype=bool)
+         for batch in reference_batch:
+             reference_values |= adata.obs[key].values == batch
+         kwargs["reference_values"] = reference_values
+
+     X = adata.obsm[basis].astype(np.float64)
+
+     harmony_out = run_harmony(X, adata.obs, key, **kwargs)
+
+     adata.obsm[adjusted_basis] = harmony_out.Z_corr.T
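
Usage note: `harmony_integrate` is the AnnData-facing wrapper; its `reference_batch` option builds the boolean `reference_values` mask that the modified `moe_correct_ridge` uses to freeze reference cells while the remaining batches are corrected toward them. A minimal sketch, assuming an `adata` that already carries `obsm["X_pca"]` and an obs column `"batch"` with a `"control"` level (illustrative names):

    from sclab.preprocess._harmony_integrate import harmony_integrate

    # Correct all batches toward the frozen "control" cells.
    harmony_integrate(adata, "batch", reference_batch="control")
    corrected = adata.obsm["X_pca_harmony"]  # default adjusted_basis is f"{basis}_harmony"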
sclab/preprocess/_normalize_weighted.py
@@ -0,0 +1,65 @@
+ import warnings
+
+ import numpy as np
+ from anndata import AnnData, ImplicitModificationWarning
+ from scipy.sparse import csr_matrix, issparse
+
+
+ def normalize_weighted(
+     adata: AnnData,
+     target_scale: float | None = None,
+     batch_key: str | None = None,
+     q: float = 0.99,
+ ) -> None:
+     if batch_key is not None:
+         for _, idx in adata.obs.groupby(batch_key, observed=True).groups.items():
+             with warnings.catch_warnings():
+                 warnings.filterwarnings(
+                     "ignore",
+                     category=ImplicitModificationWarning,
+                     message="Modifying `X` on a view results in data being overridden",
+                 )
+                 normalize_weighted(adata[idx], target_scale, None)
+
+         return
+
+     target_scale = None
+
+     X: csr_matrix
+     Y: csr_matrix
+     Z: csr_matrix
+
+     X = adata.X
+     assert issparse(X)
+
+     with warnings.catch_warnings():
+         warnings.filterwarnings(
+             "ignore", category=RuntimeWarning, message="divide by zero"
+         )
+         Y = X.multiply(1 / X.sum(axis=0))
+         Y = Y.tocsr()
+         Y.eliminate_zeros()
+         Y.data = -Y.data * np.log(Y.data)
+         entropy = Y.sum(axis=0)
+         entropy[:, entropy.A1 < np.quantile(entropy.A1, q)] *= 0.0
+
+     Z = X.multiply(entropy)
+     Z = Z.tocsr()
+     Z.eliminate_zeros()
+
+     with warnings.catch_warnings():
+         warnings.filterwarnings(
+             "ignore", category=RuntimeWarning, message="divide by zero"
+         )
+         scale = Z.sum(axis=1)
+         Z = X.multiply(1 / scale)
+         Z = Z.tocsr()
+
+     if target_scale is None:
+         target_scale = np.median(scale.A1[scale.A1 > 0])
+
+     Z = Z * target_scale
+
+     adata.X = Z
+
+     return
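
Usage note: `normalize_weighted` derives per-cell size factors from entropy-weighted counts (genes below the `q` entropy quantile get zero weight), rescales toward the median per-cell scale, and writes the result back to `adata.X`; with `batch_key` it recurses over each batch separately. A minimal sketch, assuming sparse CSR counts in `adata.X` and an illustrative `"sample"` obs column:

    from sclab.preprocess._normalize_weighted import normalize_weighted

    normalize_weighted(adata, batch_key="sample", q=0.99)  # modifies adata.X in place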
sclab/preprocess/_pca.py
@@ -0,0 +1,51 @@
+ from anndata import AnnData
+
+
+ def pca(
+     adata: AnnData,
+     layer: str | None = None,
+     n_comps: int = 30,
+     mask_var: str | None = None,
+     batch_key: str | None = None,
+     reference_batch: str | None = None,
+     zero_center: bool = False,
+ ):
+     import scanpy as sc
+
+     pca_kwargs = dict(
+         n_comps=n_comps,
+         layer=layer,
+         mask_var=mask_var,
+         svd_solver="arpack",
+     )
+
+     if reference_batch:
+         obs_mask = adata.obs[batch_key] == reference_batch
+         adata_ref = adata[obs_mask].copy()
+         if mask_var == "highly_variable":
+             sc.pp.highly_variable_genes(
+                 adata_ref, layer=f"{layer if layer else 'X'}_log1p", flavor="seurat"
+             )
+             hvg_seurat = adata_ref.var["highly_variable"]
+             sc.pp.highly_variable_genes(
+                 adata_ref,
+                 layer=layer,
+                 flavor="seurat_v3_paper",
+                 n_top_genes=hvg_seurat.sum(),
+             )
+             hvg_seurat_v3 = adata_ref.var["highly_variable"]
+             adata_ref.var["highly_variable"] = hvg_seurat | hvg_seurat_v3
+
+         sc.pp.pca(adata_ref, **pca_kwargs)
+         uns_pca = adata_ref.uns["pca"]
+         uns_pca["reference_batch"] = reference_batch
+         PCs = adata_ref.varm["PCs"]
+         adata.obsm["X_pca"] = adata.X.dot(PCs)
+         adata.uns["pca"] = uns_pca
+         adata.varm["PCs"] = PCs
+     else:
+         sc.pp.pca(adata, **pca_kwargs)
+         adata.obsm["X_pca"] = adata.X.dot(adata.varm["PCs"])
+
+     if zero_center:
+         adata.obsm["X_pca"] -= adata.obsm["X_pca"].mean(axis=0, keepdims=True)