sclab-0.2.5-py3-none-any.whl → sclab-0.3.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sclab/__init__.py +1 -1
- sclab/dataset/_dataset.py +1 -1
- sclab/examples/processor_steps/__init__.py +2 -0
- sclab/examples/processor_steps/_doublet_detection.py +68 -0
- sclab/examples/processor_steps/_integration.py +37 -4
- sclab/examples/processor_steps/_neighbors.py +24 -4
- sclab/examples/processor_steps/_pca.py +5 -5
- sclab/examples/processor_steps/_preprocess.py +14 -1
- sclab/examples/processor_steps/_qc.py +22 -6
- sclab/gui/__init__.py +0 -0
- sclab/gui/components/__init__.py +5 -0
- sclab/gui/components/_guided_pseudotime.py +482 -0
- sclab/methods/__init__.py +25 -1
- sclab/preprocess/__init__.py +18 -0
- sclab/preprocess/_cca.py +154 -0
- sclab/preprocess/_cca_integrate.py +77 -0
- sclab/preprocess/_filter_obs.py +42 -0
- sclab/preprocess/_harmony.py +421 -0
- sclab/preprocess/_harmony_integrate.py +50 -0
- sclab/preprocess/_normalize_weighted.py +61 -0
- sclab/preprocess/_subset.py +208 -0
- sclab/preprocess/_transfer_metadata.py +137 -0
- sclab/preprocess/_transform.py +82 -0
- sclab/preprocess/_utils.py +96 -0
- sclab/tools/__init__.py +0 -0
- sclab/tools/cellflow/__init__.py +0 -0
- sclab/tools/cellflow/density_dynamics/__init__.py +0 -0
- sclab/tools/cellflow/density_dynamics/_density_dynamics.py +349 -0
- sclab/tools/cellflow/pseudotime/__init__.py +0 -0
- sclab/tools/cellflow/pseudotime/_pseudotime.py +332 -0
- sclab/tools/cellflow/pseudotime/timeseries.py +226 -0
- sclab/tools/cellflow/utils/__init__.py +0 -0
- sclab/tools/cellflow/utils/density_nd.py +136 -0
- sclab/tools/cellflow/utils/interpolate.py +334 -0
- sclab/tools/cellflow/utils/smoothen.py +124 -0
- sclab/tools/cellflow/utils/times.py +55 -0
- sclab/tools/differential_expression/__init__.py +5 -0
- sclab/tools/differential_expression/_pseudobulk_edger.py +304 -0
- sclab/tools/differential_expression/_pseudobulk_helpers.py +277 -0
- sclab/tools/doublet_detection/__init__.py +5 -0
- sclab/tools/doublet_detection/_scrublet.py +64 -0
- sclab/tools/labeling/__init__.py +6 -0
- sclab/tools/labeling/sctype.py +233 -0
- sclab/utils/__init__.py +5 -0
- sclab/utils/_write_excel.py +510 -0
- {sclab-0.2.5.dist-info → sclab-0.3.0.dist-info}/METADATA +6 -2
- sclab-0.3.0.dist-info/RECORD +81 -0
- sclab-0.2.5.dist-info/RECORD +0 -45
- {sclab-0.2.5.dist-info → sclab-0.3.0.dist-info}/WHEEL +0 -0
- {sclab-0.2.5.dist-info → sclab-0.3.0.dist-info}/licenses/LICENSE +0 -0
sclab/tools/differential_expression/_pseudobulk_helpers.py

@@ -0,0 +1,277 @@
import random

import numpy as np
import pandas as pd
from anndata import AnnData
from numpy import ndarray
from scipy.sparse import csr_matrix, issparse


# code inspired by
# https://www.sc-best-practices.org/conditions/differential_gene_expression.html
def aggregate_and_filter(
    adata: AnnData,
    group_key: str = "batch",
    cell_identity_key: str | None = None,
    layer: str | None = None,
    replicas_per_group: int = 3,
    min_cells_per_group: int = 30,
    bootstrap_sampling: bool = False,
    use_cells: dict[str, list[str]] | None = None,
) -> AnnData:
    """
    Aggregate and filter cells in an AnnData object into pseudobulk samples.

    Parameters
    ----------
    adata : AnnData
        AnnData object to aggregate and filter.
    group_key : str, optional
        Key to group cells by. Defaults to 'batch'.
    cell_identity_key : str, optional
        Key used to identify cell identities. Defaults to None.
    layer : str, optional
        Layer in the AnnData object to use for aggregation. Defaults to None.
    replicas_per_group : int, optional
        Number of replicas to create for each group. Defaults to 3.
    min_cells_per_group : int, optional
        Minimum number of cells required for a group to be included. Defaults to 30.
    bootstrap_sampling : bool, optional
        Whether to use bootstrap sampling to create replicas. Defaults to False.
    use_cells : dict[str, list[str]], optional
        If not None, only use the specified cells. Defaults to None.

    Returns
    -------
    AnnData
        AnnData object with aggregated and filtered cells.
    """
    adata = _prepare_dataset(adata, use_cells)

    grouping_keys = [group_key]
    if cell_identity_key is not None:
        grouping_keys.append(cell_identity_key)

    groups_to_drop = _get_groups_to_drop(adata, grouping_keys, min_cells_per_group)

    _prepare_categorical_column(adata, group_key)
    group_dtype = adata.obs[group_key].dtype

    if cell_identity_key is not None:
        _prepare_categorical_column(adata, cell_identity_key)
        cell_identity_dtype = adata.obs[cell_identity_key].dtype

    var_dataframe = _create_var_dataframe(adata, layer, grouping_keys, groups_to_drop)

    data = {}
    meta = {}
    groups = adata.obs.groupby(grouping_keys, observed=True).groups
    for group, group_idxs in groups.items():
        if not isinstance(group, tuple):
            group = (group,)

        if not _including(group, groups_to_drop):
            continue

        sample_id = "_".join(group)
        match group:
            case (gid, cid):
                group_metadata = {group_key: gid, cell_identity_key: cid}
            case (gid,):
                group_metadata = {group_key: gid}

        adata_group = adata[group_idxs]
        indices = _get_replica_idxs(adata_group, replicas_per_group, bootstrap_sampling)
        for i, rep_idx in enumerate(indices):
            replica_number = i + 1
            replica_size = len(rep_idx)
            replica_sample_id = f"{sample_id}_rep{replica_number}"

            adata_group_replica = adata_group[rep_idx]
            X = _get_layer(adata_group_replica, layer)

            # sum counts across the cells of this replica -> one pseudobulk row
            data[replica_sample_id] = np.array(X.sum(axis=0)).flatten()
            meta[replica_sample_id] = {
                **group_metadata,
                "replica": str(replica_number),
                "replica_size": replica_size,
            }

    data = pd.DataFrame(data).T
    meta = pd.DataFrame(meta).T
    meta["replica"] = meta["replica"].astype("category")
    meta[group_key] = meta[group_key].astype(group_dtype)
    if cell_identity_key is not None:
        meta[cell_identity_key] = meta[cell_identity_key].astype(cell_identity_dtype)

    aggr_adata = AnnData(
        data.values,
        obs=meta,
        var=var_dataframe,
    )

    _join_dummies(aggr_adata, group_key)

    return aggr_adata


def _prepare_dataset(
    adata: AnnData,
    use_cells: dict[str, list[str]] | None,
) -> AnnData:
    if use_cells is not None:
        for key, value in use_cells.items():
            adata = adata[adata.obs[key].isin(value)]

    return adata.copy()


def _get_groups_to_drop(
    adata: AnnData,
    grouping_keys: str | list[str],
    min_cells_per_group: int,
):
    group_sizes = adata.obs.groupby(grouping_keys, observed=True).size()
    groups_to_drop = group_sizes[group_sizes < min_cells_per_group].index.to_list()

    if len(groups_to_drop) > 0:
        print("Dropping the following samples:")
        print(groups_to_drop)

    # also keep tuple-wrapped copies so lookups work for single-key groups
    groups_to_drop = groups_to_drop + [
        (g,) for g in groups_to_drop if not isinstance(g, tuple)
    ]

    return groups_to_drop


def _prepare_categorical_column(adata: AnnData, column: str) -> None:
    if not isinstance(adata.obs[column].dtype, pd.CategoricalDtype):
        adata.obs[column] = adata.obs[column].astype("category")


def _create_var_dataframe(
    adata: AnnData,
    layer: str | None,
    grouping_keys: list[str],
    groups_to_drop: list[str],
):
    columns = _get_var_dataframe_columns(adata, grouping_keys, groups_to_drop)
    var_dataframe = pd.DataFrame(index=adata.var_names, columns=columns, dtype=float)

    groups = adata.obs.groupby(grouping_keys, observed=True).groups
    for group, idx in groups.items():
        if not isinstance(group, tuple):
            group = (group,)

        if not _including(group, groups_to_drop):
            continue

        sample_id = "_".join(group)
        rest_id = f"not{sample_id}"

        adata_subset = adata[idx]
        rest_subset = adata[~adata.obs_names.isin(idx)]

        X = _get_layer(adata_subset, layer, dense=True)
        Y = _get_layer(rest_subset, layer, dense=True)

        # per-gene expression summaries for the group vs. the remaining cells
        var_dataframe[f"pct_expr_{sample_id}"] = (X > 0).mean(axis=0)
        var_dataframe[f"pct_expr_{rest_id}"] = (Y > 0).mean(axis=0)
        var_dataframe[f"num_expr_{sample_id}"] = (X > 0).sum(axis=0)
        var_dataframe[f"num_expr_{rest_id}"] = (Y > 0).sum(axis=0)
        var_dataframe[f"tot_expr_{sample_id}"] = X.sum(axis=0)
        var_dataframe[f"tot_expr_{rest_id}"] = Y.sum(axis=0)

    return var_dataframe


def _get_var_dataframe_columns(
    adata: AnnData, grouping_keys: list[str], groups_to_drop: list[str]
) -> list[str]:
    columns = []

    groups = adata.obs.groupby(grouping_keys, observed=True).groups
    for group, _ in groups.items():
        if not isinstance(group, tuple):
            group = (group,)

        if not _including(group, groups_to_drop):
            continue

        sample_id = "_".join(group)
        rest_id = f"not{sample_id}"

        columns.extend(
            [
                f"pct_expr_{sample_id}",
                f"pct_expr_{rest_id}",
                f"num_expr_{sample_id}",
                f"num_expr_{rest_id}",
                f"tot_expr_{sample_id}",
                f"tot_expr_{rest_id}",
            ]
        )

    return columns


def _including(group: tuple | str, groups_to_drop: list[str]) -> bool:
    match group:
        case (gid, cid):
            # skip groups whose cell identity is missing
            if isinstance(cid, float) and np.isnan(cid):
                return False

        case (gid,) | gid:
            ...

    if gid in groups_to_drop:
        return False

    return True


def _get_replica_idxs(
    adata_group: AnnData,
    replicas_per_group: int,
    bootstrap_sampling: bool,
):
    group_size = adata_group.n_obs
    indices = list(adata_group.obs_names)
    if bootstrap_sampling:
        # sample each replica with replacement from the full group
        indices = np.array(
            [
                np.random.choice(indices, size=group_size, replace=True)
                for _ in range(replicas_per_group)
            ]
        )

    else:
        # partition the shuffled cells into disjoint replicas
        random.shuffle(indices)
        indices = np.array_split(np.array(indices), replicas_per_group)

    return indices


def _get_layer(adata: AnnData, layer: str | None, dense: bool = False):
    X: ndarray | csr_matrix

    if layer is None or layer == "X":
        X = adata.X
    else:
        X = adata.layers[layer]

    if dense:
        if issparse(X):
            X = np.asarray(X.todense())
        else:
            X = np.asarray(X)

    return X


def _join_dummies(aggr_adata: AnnData, group_key: str) -> None:
    # build indicator columns such as "batch_s1" with values "s1" / "nots1",
    # convenient for one-vs-rest contrasts in downstream models
    dummies = pd.get_dummies(aggr_adata.obs[group_key], prefix=group_key).astype(str)
    dummies = dummies.apply(lambda s: s.map({"True": "", "False": "not"}))
    dummies = dummies + aggr_adata.obs[group_key].cat.categories

    aggr_adata.obs = aggr_adata.obs.join(dummies)
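
For orientation, a minimal usage sketch (not part of the diff) of the new pseudobulk helper on synthetic data; the "batch"/"cell_type" column names and the counts are invented for illustration, and the import uses the private module path from the file list above, since this diff does not show what the package-level __init__ re-exports:

import numpy as np
import pandas as pd
from anndata import AnnData

from sclab.tools.differential_expression._pseudobulk_helpers import aggregate_and_filter

rng = np.random.default_rng(0)
adata = AnnData(
    rng.poisson(1.0, size=(300, 50)).astype(float),
    obs=pd.DataFrame(
        {
            "batch": pd.Categorical(rng.choice(["s1", "s2"], size=300)),
            "cell_type": pd.Categorical(rng.choice(["T", "B"], size=300)),
        }
    ),
)

# three pseudo-replicates per (batch, cell_type) group; groups with fewer
# than 30 cells would be dropped before aggregation
aggr = aggregate_and_filter(
    adata,
    group_key="batch",
    cell_identity_key="cell_type",
    replicas_per_group=3,
    min_cells_per_group=30,
)
print(aggr.obs[["batch", "cell_type", "replica", "replica_size"]])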
sclab/tools/doublet_detection/_scrublet.py

@@ -0,0 +1,64 @@
from importlib.util import find_spec
from typing import Any

import pandas as pd
from anndata import AnnData
from numpy import ndarray


def scrublet(
    adata: AnnData,
    layer: str = "X",
    key_added: str = "scrublet",
    total_counts: ndarray | None = None,
    sim_doublet_ratio: float = 2.0,
    n_neighbors: int | None = None,
    expected_doublet_rate: float = 0.1,
    stdev_doublet_rate: float = 0.02,
    random_state: int = 0,
    scrub_doublets_kwargs: dict[str, Any] = dict(
        synthetic_doublet_umi_subsampling=1.0,
        use_approx_neighbors=True,
        distance_metric="euclidean",
        get_doublet_neighbor_parents=False,
        min_counts=3,
        min_cells=3,
        min_gene_variability_pctl=85,
        log_transform=False,
        mean_center=True,
        normalize_variance=True,
        n_prin_comps=30,
        svd_solver="arpack",
        verbose=True,
    ),
):
    """Detect doublets with Scrublet, adding `{key_added}_score` and
    `{key_added}_label` columns to `adata.obs`."""
    # scrublet is an optional dependency; fail early with a clear message
    if find_spec("scrublet") is None:
        raise ImportError(
            "scrublet is not installed. Install with:\npip install scrublet"
        )
    from scrublet import Scrublet  # noqa: E402

    if layer == "X":
        X = adata.X
    else:
        X = adata.layers[layer]

    scrub = Scrublet(
        counts_matrix=X,
        total_counts=total_counts,
        sim_doublet_ratio=sim_doublet_ratio,
        n_neighbors=n_neighbors,
        expected_doublet_rate=expected_doublet_rate,
        stdev_doublet_rate=stdev_doublet_rate,
        random_state=random_state,
    )

    _scores, labels = scrub.scrub_doublets(**scrub_doublets_kwargs)
    if labels is not None:
        _labels = list(map(lambda v: "doublet" if v else "singlet", labels))
        _labels = pd.Categorical(_labels, ["singlet", "doublet"])
        adata.obs[f"{key_added}_label"] = _labels
    else:
        # scrublet could not call doublets; default everything to singlet
        adata.obs[f"{key_added}_label"] = "singlet"

    adata.obs[f"{key_added}_score"] = _scores
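
A usage sketch (not part of the diff) for the wrapper above, assuming the optional scrublet dependency is installed; the counts matrix is synthetic and the import uses the module path from the file list (the package __init__ may also re-export the function):

import numpy as np
from anndata import AnnData

from sclab.tools.doublet_detection._scrublet import scrublet

rng = np.random.default_rng(0)
adata = AnnData(rng.poisson(1.0, size=(500, 1000)).astype(np.float32))

# annotates adata.obs with "scrublet_score" and "scrublet_label"
# ("singlet"/"doublet"); raises ImportError if scrublet is missing
scrublet(adata, layer="X", expected_doublet_rate=0.08)
print(adata.obs["scrublet_label"].value_counts())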
sclab/tools/labeling/sctype.py

@@ -0,0 +1,233 @@
import logging
from typing import Optional

import numpy as np
import pandas as pd
from anndata import AnnData
from numpy.typing import NDArray
from scipy import stats
from scipy.sparse import csc_matrix, csr_matrix, issparse

from ...preprocess import pool_neighbors

logger = logging.getLogger(__name__)


def _get_classification_scores_matrix(
    adata: AnnData,
    markers: pd.DataFrame,
    marker_class_key: str,
    neighbors_key: Optional[str] = None,
    weighted_pooling: bool = False,
    directed_pooling: bool = True,
    layer: Optional[str] = None,
    penalize_non_specific: bool = True,
):
    # Ianevski, A., Giri, A.K. & Aittokallio, T.
    # Fully-automated and ultra-fast cell-type identification using specific
    # marker combinations from single-cell transcriptomic data.
    # Nat Commun 13, 1246 (2022).
    # https://doi.org/10.1038/s41467-022-28803-w

    if layer is not None:
        X = adata.layers[layer]

    else:
        X = adata.X

    # keep genes detected in more than 5 cells
    min_val: np.number = X.min()
    M = X > min_val
    n_cells = np.asarray(M.sum(axis=0)).squeeze()
    mask = n_cells > 5
    print(f"using {mask.sum()} genes")

    markers = markers.loc[markers["names"].isin(adata.var_names[mask])].copy()
    classes = markers[marker_class_key].cat.categories

    # marker specificity weights: genes shared by many classes get lower weight
    x = markers[[marker_class_key, "names"]].groupby("names").count()[marker_class_key]
    if penalize_non_specific:
        S = 1.0 - (x - x.min()) / (x.max() - x.min())
        S = S[S > 0]
    else:
        S = x * 0.0 + 1.0

    X: NDArray | csr_matrix | csc_matrix
    if neighbors_key is not None:
        X = pool_neighbors(
            adata[:, S.index],
            layer=layer,
            neighbors_key=neighbors_key,
            weighted=weighted_pooling,
            directed=directed_pooling,
            copy=True,
        )

    elif layer is not None:
        X = adata[:, S.index].layers[layer].copy()

    else:
        X = adata[:, S.index].X.copy()

    if issparse(X):
        X = np.asarray(X.todense("C"))

    # z-score each gene across cells, then apply the specificity weights
    Z: NDArray
    Z = stats.zscore(X, axis=0)
    Xp = Z * S.values

    Xc = np.zeros((adata.shape[0], len(classes)))
    for c, cell_class in enumerate(classes):
        if cell_class == "Unknown":
            continue
        up_genes = markers.loc[
            (markers[marker_class_key] == cell_class) & (markers["logfoldchanges"] > 0),
            "names",
        ]
        dw_genes = markers.loc[
            (markers[marker_class_key] == cell_class) & (markers["logfoldchanges"] < 0),
            "names",
        ]
        x_up = Xp[:, S.index.isin(up_genes)]
        x_dw = Xp[:, S.index.isin(dw_genes)]
        if len(up_genes) > 0:
            Xc[:, c] += x_up.sum(axis=1) / np.sqrt(len(up_genes))
        if len(dw_genes) > 0:
            Xc[:, c] -= x_dw.sum(axis=1) / np.sqrt(len(dw_genes))

    return Xc


def classify_cells(
    adata: AnnData,
    markers: pd.DataFrame,
    marker_class_key: Optional[str] = None,
    cluster_key: Optional[str] = None,
    layer: Optional[str] = None,
    key_added: Optional[str] = None,
    threshold: float = 0.25,
    penalize_non_specific: bool = True,
    neighbors_key: Optional[str] = None,
    save_scores: bool = False,
):
    """
    Classify cells based on a set of marker genes.

    Ianevski, A., Giri, A.K. & Aittokallio, T.
    Fully-automated and ultra-fast cell-type identification using specific
    marker combinations from single-cell transcriptomic data.
    Nat Commun 13, 1246 (2022).
    https://doi.org/10.1038/s41467-022-28803-w

    Parameters
    ----------
    adata
        AnnData object.
    markers
        Marker genes.
    marker_class_key
        Column in `markers` that contains the cell type information.
    cluster_key
        Column in `adata.obs` that contains the cluster information. If
        not provided, the classification is performed on a cell-by-cell
        basis; counts can be pooled across neighboring cells by providing
        `neighbors_key`.
    layer
        Layer to use for classification. Defaults to `X`.
    key_added
        Key under which to add the classification information.
    threshold
        Confidence threshold for classification. Defaults to `0.25`.
    penalize_non_specific
        Whether to penalize non-specific markers. Defaults to `True`.
    neighbors_key
        If provided, counts will be pooled across neighbor cells using the
        distances in `adata.uns[neighbors_key]["distances"]`. Defaults to `None`.
    save_scores
        Whether to save the classification scores. Defaults to `False`.
    """
    # cite("10.1038/s41467-022-28803-w", __package__)

    if marker_class_key is not None:
        marker_class = markers[marker_class_key]
        if not marker_class.dtype.name.startswith("category"):
            markers[marker_class_key] = marker_class.astype("category")
    else:
        # infer the class column: it must be the single categorical column
        col_mask = markers.dtypes == "category"
        assert col_mask.sum() == 1, (
            "markers_df must have exactly one column of type 'category'"
        )
        marker_class_key = markers.loc[:, col_mask].squeeze().name

    classes = markers[marker_class_key].cat.categories
    dtype = markers[marker_class_key].dtype

    # if doing cell by cell classification, we should pool counts to use cell
    # neighborhood information. This allows us to estimate the confidence of
    # the classification. We enable pooling by providing a neighbors_key.
    posXc = _get_classification_scores_matrix(
        adata,
        markers.query("logfoldchanges > 0"),
        marker_class_key,
        neighbors_key,
        weighted_pooling=True,
        directed_pooling=True,
        layer=layer,
        penalize_non_specific=penalize_non_specific,
    )
    negXc = _get_classification_scores_matrix(
        adata,
        markers.query("logfoldchanges < 0"),
        marker_class_key,
        neighbors_key,
        weighted_pooling=True,
        directed_pooling=True,
        layer=layer,
        penalize_non_specific=penalize_non_specific,
    )
    Xc = posXc + negXc

    if cluster_key is not None:
        # aggregate per-cell scores within each cluster and assign the class
        # with the highest mean score, if it clears the confidence threshold
        mappings = {}
        mappings_nona = {}
        for c in adata.obs[cluster_key].cat.categories:
            cluster_scores_matrix = Xc[adata.obs[cluster_key] == c]
            n_cells_in_cluster = cluster_scores_matrix.shape[0]

            scores = cluster_scores_matrix.sum(axis=0)
            confidence = scores.max() / n_cells_in_cluster
            if confidence >= threshold:
                mappings[c] = classes[np.argmax(scores)]
            else:
                mappings[c] = pd.NA
                logger.warning(
                    f"Cluster {str(c):>5} classified as Unknown with confidence score {confidence: 8.2f}"
                )
            mappings_nona[c] = classes[np.argmax(scores)]
        classifications = adata.obs[cluster_key].map(mappings).astype(dtype)
        classifications_nona = adata.obs[cluster_key].map(mappings_nona).astype(dtype)
    else:
        if neighbors_key is not None:
            n_neigs = adata.uns[neighbors_key]["params"]["n_neighbors"]
        else:
            n_neigs = 1
        index = adata.obs_names
        classifications = classes.values[Xc.argmax(axis=1)]
        classifications = pd.Series(classifications, index=index).astype(dtype)
        classifications_nona = classifications.copy()
        # scale the threshold by the neighborhood size used for pooling
        classifications.loc[Xc.max(axis=1) < threshold * n_neigs] = pd.NA

    N = len(classifications)
    n_unknowns = pd.isna(classifications).sum()
    n_estimated = N - n_unknowns

    logger.info(f"Estimated types for {n_estimated} cells ({n_estimated / N:.2%})")

    if key_added is None:
        key_added = marker_class_key

    adata.obs[key_added] = classifications
    adata.obs[key_added + "_noNA"] = classifications_nona

    if save_scores:
        adata.obs[key_added + "_score"] = Xc.max(axis=1)
        adata.obsm[key_added + "_scores"] = Xc
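
And a usage sketch (not part of the diff) for classify_cells on a toy dataset; the gene names, cluster labels, and marker table are invented for illustration, and penalize_non_specific is disabled because every marker in this toy table is unique to one class, which would otherwise zero out all specificity weights:

import numpy as np
import pandas as pd
from anndata import AnnData

from sclab.tools.labeling.sctype import classify_cells

rng = np.random.default_rng(0)
background = rng.poisson(0.3, size=(200, 4)).astype(float)
signal = np.zeros_like(background)
signal[:100, :2] = rng.poisson(3.0, size=(100, 2))  # cluster c0 expresses the T markers
signal[100:, 2:] = rng.poisson(3.0, size=(100, 2))  # cluster c1 expresses the B markers
adata = AnnData(
    background + signal,
    var=pd.DataFrame(index=["CD3D", "CD3E", "MS4A1", "CD79A"]),
)
adata.obs["cluster"] = pd.Categorical(np.repeat(["c0", "c1"], 100))

markers = pd.DataFrame(
    {
        "names": ["CD3D", "CD3E", "MS4A1", "CD79A"],
        "logfoldchanges": [2.0, 1.5, 2.2, 1.8],
        "cell_type": pd.Categorical(["T", "T", "B", "B"]),
    }
)

# writes per-cluster labels to adata.obs["cell_type"] (key_added defaults
# to marker_class_key), with a "_noNA" companion column
classify_cells(
    adata,
    markers,
    marker_class_key="cell_type",
    cluster_key="cluster",
    penalize_non_specific=False,
)
print(adata.obs["cell_type"].value_counts())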