sclab 0.1.7__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. sclab/__init__.py +3 -1
  2. sclab/_io.py +83 -12
  3. sclab/_methods_registry.py +65 -0
  4. sclab/_sclab.py +241 -21
  5. sclab/dataset/_dataset.py +4 -6
  6. sclab/dataset/processor/_processor.py +41 -19
  7. sclab/dataset/processor/_results_panel.py +94 -0
  8. sclab/dataset/processor/step/_processor_step_base.py +12 -6
  9. sclab/examples/processor_steps/__init__.py +8 -0
  10. sclab/examples/processor_steps/_cluster.py +2 -2
  11. sclab/examples/processor_steps/_differential_expression.py +329 -0
  12. sclab/examples/processor_steps/_doublet_detection.py +68 -0
  13. sclab/examples/processor_steps/_gene_expression.py +125 -0
  14. sclab/examples/processor_steps/_integration.py +116 -0
  15. sclab/examples/processor_steps/_neighbors.py +26 -6
  16. sclab/examples/processor_steps/_pca.py +13 -8
  17. sclab/examples/processor_steps/_preprocess.py +52 -25
  18. sclab/examples/processor_steps/_qc.py +24 -8
  19. sclab/examples/processor_steps/_umap.py +2 -2
  20. sclab/gui/__init__.py +0 -0
  21. sclab/gui/components/__init__.py +7 -0
  22. sclab/gui/components/_guided_pseudotime.py +482 -0
  23. sclab/gui/components/_transfer_metadata.py +186 -0
  24. sclab/methods/__init__.py +50 -0
  25. sclab/preprocess/__init__.py +26 -0
  26. sclab/preprocess/_cca.py +176 -0
  27. sclab/preprocess/_cca_integrate.py +109 -0
  28. sclab/preprocess/_filter_obs.py +42 -0
  29. sclab/preprocess/_harmony.py +421 -0
  30. sclab/preprocess/_harmony_integrate.py +53 -0
  31. sclab/preprocess/_normalize_weighted.py +65 -0
  32. sclab/preprocess/_pca.py +51 -0
  33. sclab/preprocess/_preprocess.py +155 -0
  34. sclab/preprocess/_qc.py +38 -0
  35. sclab/preprocess/_rpca.py +116 -0
  36. sclab/preprocess/_subset.py +208 -0
  37. sclab/preprocess/_transfer_metadata.py +196 -0
  38. sclab/preprocess/_transform.py +82 -0
  39. sclab/preprocess/_utils.py +96 -0
  40. sclab/scanpy/__init__.py +0 -0
  41. sclab/scanpy/_compat.py +92 -0
  42. sclab/scanpy/_settings.py +526 -0
  43. sclab/scanpy/logging.py +290 -0
  44. sclab/scanpy/plotting/__init__.py +0 -0
  45. sclab/scanpy/plotting/_rcmod.py +73 -0
  46. sclab/scanpy/plotting/palettes.py +221 -0
  47. sclab/scanpy/readwrite.py +1108 -0
  48. sclab/tools/__init__.py +0 -0
  49. sclab/tools/cellflow/__init__.py +0 -0
  50. sclab/tools/cellflow/density_dynamics/__init__.py +0 -0
  51. sclab/tools/cellflow/density_dynamics/_density_dynamics.py +349 -0
  52. sclab/tools/cellflow/pseudotime/__init__.py +0 -0
  53. sclab/tools/cellflow/pseudotime/_pseudotime.py +336 -0
  54. sclab/tools/cellflow/pseudotime/timeseries.py +226 -0
  55. sclab/tools/cellflow/utils/__init__.py +0 -0
  56. sclab/tools/cellflow/utils/density_nd.py +215 -0
  57. sclab/tools/cellflow/utils/interpolate.py +334 -0
  58. sclab/tools/cellflow/utils/periodic_genes.py +106 -0
  59. sclab/tools/cellflow/utils/smoothen.py +124 -0
  60. sclab/tools/cellflow/utils/times.py +55 -0
  61. sclab/tools/differential_expression/__init__.py +7 -0
  62. sclab/tools/differential_expression/_pseudobulk_edger.py +309 -0
  63. sclab/tools/differential_expression/_pseudobulk_helpers.py +290 -0
  64. sclab/tools/differential_expression/_pseudobulk_limma.py +257 -0
  65. sclab/tools/doublet_detection/__init__.py +5 -0
  66. sclab/tools/doublet_detection/_scrublet.py +64 -0
  67. sclab/tools/embedding/__init__.py +0 -0
  68. sclab/tools/imputation/__init__.py +0 -0
  69. sclab/tools/imputation/_alra.py +135 -0
  70. sclab/tools/labeling/__init__.py +6 -0
  71. sclab/tools/labeling/sctype.py +233 -0
  72. sclab/tools/utils/__init__.py +5 -0
  73. sclab/tools/utils/_aggregate_and_filter.py +290 -0
  74. sclab/utils/__init__.py +5 -0
  75. sclab/utils/_write_excel.py +510 -0
  76. {sclab-0.1.7.dist-info → sclab-0.3.4.dist-info}/METADATA +29 -12
  77. sclab-0.3.4.dist-info/RECORD +93 -0
  78. {sclab-0.1.7.dist-info → sclab-0.3.4.dist-info}/WHEEL +1 -1
  79. sclab-0.3.4.dist-info/licenses/LICENSE +29 -0
  80. sclab-0.1.7.dist-info/RECORD +0 -30
sclab/tools/differential_expression/_pseudobulk_limma.py
@@ -0,0 +1,257 @@
+ import pandas as pd
+ from anndata import AnnData
+
+ from ._pseudobulk_helpers import aggregate_and_filter
+
+
+ def pseudobulk_limma(
+     adata_: AnnData,
+     group_key: str,
+     condition_group: str | list[str] | None = None,
+     reference_group: str | None = None,
+     cell_identity_key: str | None = None,
+     batch_key: str | None = None,
+     layer: str | None = None,
+     replicas_per_group: int = 5,
+     min_cells_per_group: int = 30,
+     bootstrap_sampling: bool = False,
+     use_cells: dict[str, list[str]] | None = None,
+     aggregate: bool = True,
+     verbosity: int = 0,
+ ) -> dict[str, pd.DataFrame]:
+     _try_imports()
+     import anndata2ri
+     import rpy2.robjects as robjects
+     from rpy2.rinterface_lib.embedded import RRuntimeError
+     from rpy2.robjects import pandas2ri
+     from rpy2.robjects.conversion import localconverter
+
+     R = robjects.r
+
+     if aggregate:
+         aggr_adata = aggregate_and_filter(
+             adata_,
+             group_key,
+             cell_identity_key,
+             layer,
+             replicas_per_group,
+             min_cells_per_group,
+             bootstrap_sampling,
+             use_cells,
+         )
+     else:
+         aggr_adata = adata_.copy()
+
+     with localconverter(anndata2ri.converter):
+         R.assign("aggr_adata", aggr_adata)
+
+     # defines the R function for fitting the model with limma
+     R(_fit_model_r_script)
+
+     if condition_group is None:
+         condition_group_list = aggr_adata.obs[group_key].unique()
+     elif isinstance(condition_group, str):
+         condition_group_list = [condition_group]
+     else:
+         condition_group_list = condition_group
+
+     if cell_identity_key is not None:
+         cids = aggr_adata.obs[cell_identity_key].unique()
+     else:
+         cids = [""]
+
+     tt_dict = {}
+     for condition_group in condition_group_list:
+         if reference_group is not None and condition_group == reference_group:
+             continue
+
+         if verbosity > 0:
+             print(f"Fitting model for {condition_group}...")
+
+         if reference_group is not None:
+             gk = group_key
+         else:
+             gk = f"{group_key}_{condition_group}"
+
+         try:
+             # forward batch_key so the R-side design can include the batch term
+             R(f"""
+             outs <- fit_limma_model(aggr_adata, "{gk}", "{cell_identity_key}", "{batch_key}", verbosity = {verbosity})
+             fit <- outs$fit
+             v <- outs$v
+             """)
+
+         except RRuntimeError as e:
+             print("Error fitting model for", condition_group)
+             print("Error:", e)
+             print("Skipping...", flush=True)
+             continue
+
+         if reference_group is None:
+             new_contrasts_tuples = [
+                 (
+                     condition_group,  # common prefix
+                     "",  # condition group
+                     "not",  # reference group
+                     cid,  # cell identity
+                 )
+                 for cid in cids
+             ]
+
+         else:
+             new_contrasts_tuples = [
+                 (
+                     "",  # common prefix
+                     condition_group,  # condition group
+                     reference_group,  # reference group
+                     cid,  # cell identity
+                 )
+                 for cid in cids
+             ]
+
+         new_contrasts = [
+             f"group{cnd}{prefix}_{cid}".strip("_")
+             + "-"
+             + f"group{ref}{prefix}_{cid}".strip("_")
+             for prefix, cnd, ref, cid in new_contrasts_tuples
+         ]
+
+         for contrast, contrast_tuple in zip(new_contrasts, new_contrasts_tuples):
+             prefix, cnd, ref, cid = contrast_tuple
+
+             if ref == "not":
+                 cnd, ref = "", "rest"
+
+             contrast_key = f"{prefix}{cnd}_vs_{ref}"
+             if cid:
+                 contrast_key = f"{cell_identity_key}:{cid}|{contrast_key}"
+
+             if verbosity > 0:
+                 print(f"Computing contrast: {contrast_key}... ({contrast})")
+
+             R(f"myContrast <- makeContrasts('{contrast}', levels = v$design)")
+             R("fit2 <- contrasts.fit(fit, myContrast)")
+             R("fit2 <- eBayes(fit2)")
+             R("tt <- topTable(fit2, n = Inf)")
+             tt: pd.DataFrame = pandas2ri.rpy2py(R("tt"))
+             tt.index.name = "gene_ids"
+
+             genes = tt.index
+             cnd, ref = [c[5:] for c in contrast.split("-")]
+             tt["pct_expr_cnd"] = aggr_adata.var[f"pct_expr_{cnd}"].loc[genes]
+             tt["pct_expr_ref"] = aggr_adata.var[f"pct_expr_{ref}"].loc[genes]
+             tt["num_expr_cnd"] = aggr_adata.var[f"num_expr_{cnd}"].loc[genes]
+             tt["num_expr_ref"] = aggr_adata.var[f"num_expr_{ref}"].loc[genes]
+             tt["tot_expr_cnd"] = aggr_adata.var[f"tot_expr_{cnd}"].loc[genes]
+             tt["tot_expr_ref"] = aggr_adata.var[f"tot_expr_{ref}"].loc[genes]
+             tt["mean_cnd"] = tt["tot_expr_cnd"] / tt["num_expr_cnd"]
+             tt["mean_ref"] = tt["tot_expr_ref"] / tt["num_expr_ref"]
+             tt_dict[contrast_key] = tt
+
+     return tt_dict
+
+
+ _fit_model_r_script = """
+ suppressPackageStartupMessages({
+     library(edgeR)
+     library(limma)
+     library(MAST)
+ })
+
+ fit_limma_model <- function(adata_, group_key, cell_identity_key = "None", batch_key = "None", verbosity = 0){
+
+     if (verbosity > 0){
+         cat("Group key:", group_key, "\n")
+         cat("Cell identity key:", cell_identity_key, "\n")
+     }
+
+     # create a vector that is a concatenation of condition and cell type that we will later use with contrasts
+     if (cell_identity_key == "None"){
+         group <- colData(adata_)[[group_key]]
+     } else {
+         group <- paste0(colData(adata_)[[group_key]], "_", colData(adata_)[[cell_identity_key]])
+     }
+
+     if (verbosity > 1){
+         cat("Group(s):", group, "\n")
+     }
+
+     group <- factor(group)
+     replica <- factor(colData(adata_)$replica)
+
+     # create a design matrix
+     if (batch_key == "None"){
+         design <- model.matrix(~ 0 + group + replica)
+     } else {
+         batch <- factor(colData(adata_)[[batch_key]])
+         design <- model.matrix(~ 0 + group + replica + batch)
+     }
+     colnames(design) <- make.names(colnames(design))
+
+     # create an edgeR object with counts and grouping factor
+     y <- DGEList(assay(adata_, "X"), group = group)
+
+     # filter out genes with low counts
+     if (verbosity > 1){
+         cat("Dimensions before subsetting:", dim(y), "\n")
+     }
+
+     keep <- filterByExpr(y, design = design)
+     y <- y[keep, , keep.lib.sizes=FALSE]
+     if (verbosity > 1){
+         cat("Dimensions after subsetting:", dim(y), "\n")
+     }
+
+     # normalize
+     y <- calcNormFactors(y)
+
+     # apply the voom transformation to prepare for linear modeling
+     v <- voom(y, design = design)
+
+     # fit the linear model
+     fit <- lmFit(v, design)
+     ne <- limma::nonEstimable(design)
+     if (!is.null(ne) && verbosity > 0) cat("Non-estimable:", ne, "\n")
+     fit <- eBayes(fit)
+
+     return(list("fit"=fit, "design"=design, "v"=v))
+ }
+ """
+
+
+ def _try_imports():
+     try:
+         import rpy2.robjects as robjects
+         from rpy2.robjects.packages import PackageNotInstalledError, importr
+
+         robjects.r("options(warn=-1)")
+         import anndata2ri  # noqa: F401
+         from rpy2.rinterface_lib.embedded import RRuntimeError  # noqa: F401
+         from rpy2.robjects import numpy2ri, pandas2ri  # noqa: F401
+         from rpy2.robjects.conversion import localconverter  # noqa: F401
+
+         importr("edgeR")
+         importr("limma")
+         importr("MAST")
+         importr("SingleCellExperiment")
+
+     except ModuleNotFoundError:
+         message = (
+             "pseudobulk_limma requires rpy2 and anndata2ri to be installed.\n"
+             "Please install with one of the following:\n"
+             "$ pip install rpy2 anndata2ri\n"
+             "or\n"
+             "$ conda install -c conda-forge rpy2 anndata2ri\n"
+         )
+         print(message)
+         raise ModuleNotFoundError(message)
+
+     except PackageNotInstalledError:
+         message = (
+             "pseudobulk_limma requires the following R packages to be installed: "
+             "limma, edgeR, MAST, and SingleCellExperiment.\n"
+             "> \n"
+             "> if (!require('BiocManager', quietly = TRUE)) install.packages('BiocManager');\n"
+             "> BiocManager::install(c('limma', 'edgeR', 'MAST', 'SingleCellExperiment'));\n"
+             "> \n"
+         )
+         print(message)
+         raise ImportError(message)
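For orientation, a minimal usage sketch of the new pseudobulk_limma entry point. The AnnData object `adata`, the obs column "treatment", the group names "drug"/"control", and the "counts" layer are hypothetical; rpy2, anndata2ri, and the R-side limma/edgeR stack must be installed:

    from sclab.tools.differential_expression._pseudobulk_limma import pseudobulk_limma

    # returns one limma topTable DataFrame per contrast, keyed like "drug_vs_control"
    tt_dict = pseudobulk_limma(
        adata,                      # AnnData with raw counts in layers["counts"]
        group_key="treatment",      # obs column defining the groups
        condition_group="drug",     # group to test
        reference_group="control",  # reference group for the contrast
        layer="counts",
        verbosity=1,
    )
    tt = tt_dict["drug_vs_control"]  # columns include logFC, P.Value, adj.P.Val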
sclab/tools/doublet_detection/__init__.py
@@ -0,0 +1,5 @@
+ from ._scrublet import scrublet
+
+ __all__ = [
+     "scrublet",
+ ]
sclab/tools/doublet_detection/_scrublet.py
@@ -0,0 +1,64 @@
+ from importlib.util import find_spec
+ from typing import Any
+
+ import pandas as pd
+ from anndata import AnnData
+ from numpy import ndarray
+
+
+ def scrublet(
+     adata: AnnData,
+     layer: str = "X",
+     key_added: str = "scrublet",
+     total_counts: ndarray | None = None,
+     sim_doublet_ratio: float = 2.0,
+     n_neighbors: int | None = None,
+     expected_doublet_rate: float = 0.1,
+     stdev_doublet_rate: float = 0.02,
+     random_state: int = 0,
+     scrub_doublets_kwargs: dict[str, Any] = dict(
+         synthetic_doublet_umi_subsampling=1.0,
+         use_approx_neighbors=True,
+         distance_metric="euclidean",
+         get_doublet_neighbor_parents=False,
+         min_counts=3,
+         min_cells=3,
+         min_gene_variability_pctl=85,
+         log_transform=False,
+         mean_center=True,
+         normalize_variance=True,
+         n_prin_comps=30,
+         svd_solver="arpack",
+         verbose=True,
+     ),
+ ):
+     if find_spec("scrublet") is None:
+         raise ImportError(
+             "scrublet is not installed. Install with:\npip install scrublet"
+         )
+     from scrublet import Scrublet  # noqa: E402
+
+     if layer == "X":
+         X = adata.X
+     else:
+         X = adata.layers[layer]
+
+     scrub = Scrublet(
+         counts_matrix=X,
+         total_counts=total_counts,
+         sim_doublet_ratio=sim_doublet_ratio,
+         n_neighbors=n_neighbors,
+         expected_doublet_rate=expected_doublet_rate,
+         stdev_doublet_rate=stdev_doublet_rate,
+         random_state=random_state,
+     )
+
+     _scores, labels = scrub.scrub_doublets(**scrub_doublets_kwargs)
+     if labels is not None:
+         _labels = list(map(lambda v: "doublet" if v else "singlet", labels))
+         _labels = pd.Categorical(_labels, ["singlet", "doublet"])
+         adata.obs[f"{key_added}_label"] = _labels
+     else:
+         adata.obs[f"{key_added}_label"] = "singlet"
+
+     adata.obs[f"{key_added}_score"] = _scores
sclab/tools/embedding/__init__.py: File without changes
sclab/tools/imputation/__init__.py: File without changes
sclab/tools/imputation/_alra.py
@@ -0,0 +1,135 @@
+ import gc
+
+ import numpy as np
+ from numpy import float32
+ from numpy.typing import NDArray
+ from scipy.sparse import csc_matrix, csr_matrix, issparse
+
+
+ def _alra_on_ndarray(
+     data: NDArray | csr_matrix,
+ ) -> tuple[NDArray[float32], NDArray[float32], NDArray[float32]]:
+     """
+     Run ALRA on the given data.
+
+     Parameters
+     ----------
+     data : NDArray | csr_matrix
+         Input data to impute.
+
+     Returns
+     -------
+     data_aprx : NDArray
+         Rank-k approximated data.
+     data_thrs : NDArray
+         Approximated data after thresholding.
+     data_alra : NDArray
+         Imputed (thresholded and rescaled) data.
+     """
+     import rpy2.robjects as robjects
+     import rpy2.robjects.numpy2ri
+     from rpy2.robjects.packages import importr
+
+     rpy2.robjects.numpy2ri.activate()
+     R = robjects.r
+     alra = importr("ALRA")
+
+     if issparse(data):
+         data = np.ascontiguousarray(data.todense("C"), dtype=np.float32)
+
+     # convert to R object
+     r_X = R.matrix(data, nrow=data.shape[0], ncol=data.shape[1])
+     # run ALRA
+     r_res = alra.alra(r_X, 0, 10, 0.001)
+     # retrieve imputed data
+     r_K = r_res[0]  # rank k
+     r_T = r_res[1]  # rank k thresholded
+     r_S = r_res[2]  # rank k thresholded scaled
+     # convert back to numpy arrays
+     data_aprx = np.array(r_K, dtype=float32)
+     data_thrs = np.array(r_T, dtype=float32)
+     data_alra = np.array(r_S, dtype=float32)
+
+     # clean up
+     del r_X, r_res, r_K, r_T, r_S
+     R("gc()")
+     gc.collect()
+
+     return data_aprx, data_thrs, data_alra
+
+
+ def _fix_alra_scale(
+     input_data: NDArray | csr_matrix | csc_matrix,
+     thrs_data: NDArray,
+     target_data: NDArray,
+ ) -> NDArray:
+     # convert sparse -> dense
+     if issparse(input_data):
+         input_data = input_data.toarray("C")
+     input_data = input_data.astype(np.float32)
+     input_data = np.ascontiguousarray(input_data, dtype=np.float32)
+
+     n_cells, n_genes = input_data.shape
+
+     # per-gene nonzero means/sds (match R: sample sd, ddof=1)
+     input_means = np.full(n_genes, fill_value=np.nan)
+     input_stds = np.full(n_genes, fill_value=np.nan)
+     thrs_means = np.full(n_genes, fill_value=np.nan)
+     thrs_stds = np.full(n_genes, fill_value=np.nan)
+     v: NDArray
+
+     for i, e in enumerate(input_data.T):
+         v = e[e > 0]
+
+         if v.size == 0:
+             continue
+         input_means[i] = v.mean()
+
+         if v.size == 1:
+             continue
+         input_stds[i] = v.std(ddof=1)
+
+     for i, e in enumerate(thrs_data.T):
+         v = e[e > 0]
+
+         if v.size == 0:
+             continue
+         thrs_means[i] = v.mean()
+
+         if v.size == 1:
+             continue
+         thrs_stds[i] = v.std(ddof=1)
+
+     # columns to scale (mirror R's toscale)
+     toscale = (
+         ~np.isnan(thrs_stds)
+         & ~np.isnan(input_stds)
+         & ~((thrs_stds == 0) & (input_stds == 0))
+         & ~(thrs_stds == 0)
+     )
+
+     # affine params
+     a = np.full(n_genes, fill_value=1.0)
+     b = np.full(n_genes, fill_value=0.0)
+     a[toscale] = input_stds[toscale] / thrs_stds[toscale]
+     b[toscale] = input_means[toscale] - a[toscale] * thrs_means[toscale]
+
+     # apply to target matrix (only columns in toscale)
+     out = target_data.copy()
+     out[:, toscale] = out[:, toscale] * a[toscale] + b[toscale]
+
+     # keep zeros as zeros
+     out[thrs_data == 0] = 0
+
+     # clip negatives to zero
+     out[out < 0] = 0
+
+     # restore originally observed positives that became zero
+     mask = (input_data > 0) & (out == 0)
+     out[mask] = input_data[mask]
+
+     return out
+
+
+ __all__ = [
+     "_alra_on_ndarray",
+     "_fix_alra_scale",
+ ]
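For orientation, a sketch of how the two private ALRA helpers might compose; the AnnData object, the "lognorm" layer, and the wiring of `_fix_alra_scale` are assumptions (the public wrapper is not shown in this diff), and the R package ALRA must be installed:

    from sclab.tools.imputation._alra import _alra_on_ndarray, _fix_alra_scale

    X = adata.layers["lognorm"]             # cells x genes, log-normalized (hypothetical)
    aprx, thrs, alra = _alra_on_ndarray(X)  # rank-k, thresholded, rescaled
    # re-derive the per-gene affine rescaling of the thresholded matrix in Python
    adata.layers["alra"] = _fix_alra_scale(X, thrs, thrs)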
sclab/tools/labeling/__init__.py
@@ -0,0 +1,6 @@
+ from . import sctype
+
+
+ __all__ = [
+     "sctype",
+ ]
sclab/tools/labeling/sctype.py
@@ -0,0 +1,233 @@
+ import logging
+ from typing import Optional
+
+ import numpy as np
+ import pandas as pd
+ from anndata import AnnData
+ from numpy.typing import NDArray
+ from scipy import stats
+ from scipy.sparse import csc_matrix, csr_matrix, issparse
+
+ from ...preprocess import pool_neighbors
+
+ logger = logging.getLogger(__name__)
+
+
+ def _get_classification_scores_matrix(
+     adata: AnnData,
+     markers: pd.DataFrame,
+     marker_class_key: str,
+     neighbors_key: Optional[str] = None,
+     weighted_pooling: bool = False,
+     directed_pooling: bool = True,
+     layer: Optional[str] = None,
+     penalize_non_specific: bool = True,
+ ):
+     # Ianevski, A., Giri, A.K. & Aittokallio, T.
+     # Fully-automated and ultra-fast cell-type identification using specific
+     # marker combinations from single-cell transcriptomic data.
+     # Nat Commun 13, 1246 (2022).
+     # https://doi.org/10.1038/s41467-022-28803-w
+
+     if layer is not None:
+         X = adata.layers[layer]
+     else:
+         X = adata.X
+
+     min_val: np.number = X.min()
+     M = X > min_val
+     n_cells = np.asarray(M.sum(axis=0)).squeeze()
+     mask = n_cells > 5
+     print(f"using {mask.sum()} genes")
+
+     markers = markers.loc[markers["names"].isin(adata.var_names[mask])].copy()
+     classes = markers[marker_class_key].cat.categories
+
+     x = markers[[marker_class_key, "names"]].groupby("names").count()[marker_class_key]
+     if penalize_non_specific:
+         S = 1.0 - (x - x.min()) / (x.max() - x.min())
+         S = S[S > 0]
+     else:
+         S = x * 0.0 + 1.0
+
+     X: NDArray | csr_matrix | csc_matrix
+     if neighbors_key is not None:
+         X = pool_neighbors(
+             adata[:, S.index],
+             layer=layer,
+             neighbors_key=neighbors_key,
+             weighted=weighted_pooling,
+             directed=directed_pooling,
+             copy=True,
+         )
+     elif layer is not None:
+         X = adata[:, S.index].layers[layer].copy()
+     else:
+         X = adata[:, S.index].X.copy()
+
+     if issparse(X):
+         X = np.asarray(X.todense("C"))
+
+     Z: NDArray
+     Z = stats.zscore(X, axis=0)
+     Xp = Z * S.values
+
+     Xc = np.zeros((adata.shape[0], len(classes)))
+     for c, cell_class in enumerate(classes):
+         if cell_class == "Unknown":
+             continue
+         up_genes = markers.loc[
+             (markers[marker_class_key] == cell_class) & (markers["logfoldchanges"] > 0),
+             "names",
+         ]
+         dw_genes = markers.loc[
+             (markers[marker_class_key] == cell_class) & (markers["logfoldchanges"] < 0),
+             "names",
+         ]
+         x_up = Xp[:, S.index.isin(up_genes)]
+         x_dw = Xp[:, S.index.isin(dw_genes)]
+         if len(up_genes) > 0:
+             Xc[:, c] += x_up.sum(axis=1) / np.sqrt(len(up_genes))
+         if len(dw_genes) > 0:
+             Xc[:, c] -= x_dw.sum(axis=1) / np.sqrt(len(dw_genes))
+
+     return Xc
+
+
+ def classify_cells(
+     adata: AnnData,
+     markers: pd.DataFrame,
+     marker_class_key: Optional[str] = None,
+     cluster_key: Optional[str] = None,
+     layer: Optional[str] = None,
+     key_added: Optional[str] = None,
+     threshold: float = 0.25,
+     penalize_non_specific: bool = True,
+     neighbors_key: Optional[str] = None,
+     save_scores: bool = False,
+ ):
+     """
+     Classify cells based on a set of marker genes.
+
+     Ianevski, A., Giri, A.K. & Aittokallio, T.
+     Fully-automated and ultra-fast cell-type identification using specific
+     marker combinations from single-cell transcriptomic data.
+     Nat Commun 13, 1246 (2022).
+     https://doi.org/10.1038/s41467-022-28803-w
+
+     Parameters
+     ----------
+     adata
+         AnnData object.
+     markers
+         Marker genes.
+     marker_class_key
+         Column in `markers` that contains the cell type information.
+     cluster_key
+         Column in `adata.obs` that contains the cluster information. If
+         not provided, the classification is performed on a cell-by-cell
+         basis, pooling counts across neighbor cells when `neighbors_key`
+         is provided.
+     layer
+         Layer to use for classification. Defaults to `X`.
+     key_added
+         Key under which to add the classification information.
+     threshold
+         Confidence threshold for classification. Defaults to `0.25`.
+     penalize_non_specific
+         Whether to penalize non-specific markers. Defaults to `True`.
+     neighbors_key
+         If provided, counts will be pooled across neighbor cells using the
+         distances in `adata.uns[neighbors_key]["distances"]`. Defaults to `None`.
+     save_scores
+         Whether to save the classification scores. Defaults to `False`.
+     """
+     # cite("10.1038/s41467-022-28803-w", __package__)
+
+     if marker_class_key is not None:
+         marker_class = markers[marker_class_key]
+         if not marker_class.dtype.name.startswith("category"):
+             markers[marker_class_key] = marker_class.astype("category")
+     else:
+         col_mask = markers.dtypes == "category"
+         assert col_mask.sum() == 1, (
+             "markers_df must have exactly one column of type 'category'"
+         )
+         marker_class_key = markers.loc[:, col_mask].squeeze().name
+
+     classes = markers[marker_class_key].cat.categories
+     dtype = markers[marker_class_key].dtype
+
+     # if doing cell-by-cell classification, we should pool counts to use cell
+     # neighborhood information. This allows us to estimate the confidence of
+     # the classification. Pooling is enabled by providing a neighbors_key.
+     posXc = _get_classification_scores_matrix(
+         adata,
+         markers.query("logfoldchanges > 0"),
+         marker_class_key,
+         neighbors_key,
+         weighted_pooling=True,
+         directed_pooling=True,
+         layer=layer,
+         penalize_non_specific=penalize_non_specific,
+     )
+     negXc = _get_classification_scores_matrix(
+         adata,
+         markers.query("logfoldchanges < 0"),
+         marker_class_key,
+         neighbors_key,
+         weighted_pooling=True,
+         directed_pooling=True,
+         layer=layer,
+         penalize_non_specific=penalize_non_specific,
+     )
+     Xc = posXc + negXc
+
+     if cluster_key is not None:
+         mappings = {}
+         mappings_nona = {}
+         for c in adata.obs[cluster_key].cat.categories:
+             cluster_scores_matrix = Xc[adata.obs[cluster_key] == c]
+             n_cells_in_cluster = cluster_scores_matrix.shape[0]
+
+             scores = cluster_scores_matrix.sum(axis=0)
+             confidence = scores.max() / n_cells_in_cluster
+             if confidence >= threshold:
+                 mappings[c] = classes[np.argmax(scores)]
+             else:
+                 mappings[c] = pd.NA
+                 logger.warning(
+                     f"Cluster {str(c):>5} classified as Unknown with confidence score {confidence: 8.2f}"
+                 )
+             mappings_nona[c] = classes[np.argmax(scores)]
+         classifications = adata.obs[cluster_key].map(mappings).astype(dtype)
+         classifications_nona = adata.obs[cluster_key].map(mappings_nona).astype(dtype)
+     else:
+         if neighbors_key is not None:
+             n_neigs = adata.uns[neighbors_key]["params"]["n_neighbors"]
+         else:
+             n_neigs = 1
+         index = adata.obs_names
+         classifications = classes.values[Xc.argmax(axis=1)]
+         classifications = pd.Series(classifications, index=index).astype(dtype)
+         classifications_nona = classifications.copy()
+         classifications.loc[Xc.max(axis=1) < threshold * n_neigs] = pd.NA
+
+     N = len(classifications)
+     n_unknowns = pd.isna(classifications).sum()
+     n_estimated = N - n_unknowns
+
+     logger.info(f"Estimated types for {n_estimated} cells ({n_estimated / N:.2%})")
+
+     if key_added is None:
+         key_added = marker_class_key
+
+     adata.obs[key_added] = classifications
+     adata.obs[key_added + "_noNA"] = classifications_nona
+
+     if save_scores:
+         adata.obs[key_added + "_score"] = Xc.max(axis=1)
+         adata.obsm[key_added + "_scores"] = Xc
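For orientation, a minimal usage sketch of classify_cells. The AnnData object, the marker genes, and the "leiden" cluster key are hypothetical; the marker table layout follows what the code reads ("names", "logfoldchanges", and one categorical cell-type column):

    import pandas as pd
    from sclab.tools.labeling.sctype import classify_cells

    markers = pd.DataFrame({
        "names": ["CD3D", "CD3E", "MS4A1", "CD79A"],
        "logfoldchanges": [2.1, 1.8, 2.5, 2.2],
        "cell_type": pd.Categorical(["T cell"] * 2 + ["B cell"] * 2),
    })
    classify_cells(adata, markers, marker_class_key="cell_type",
                   cluster_key="leiden", threshold=0.25, save_scores=True)
    adata.obs["cell_type"]       # per-cluster labels, NA where below threshold
    adata.obs["cell_type_noNA"]  # best guess ignoring the threshold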