sclab 0.2.5__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sclab might be problematic.

Files changed (50)
  1. sclab/__init__.py +1 -1
  2. sclab/dataset/_dataset.py +1 -1
  3. sclab/examples/processor_steps/__init__.py +2 -0
  4. sclab/examples/processor_steps/_doublet_detection.py +68 -0
  5. sclab/examples/processor_steps/_integration.py +37 -4
  6. sclab/examples/processor_steps/_neighbors.py +24 -4
  7. sclab/examples/processor_steps/_pca.py +5 -5
  8. sclab/examples/processor_steps/_preprocess.py +14 -1
  9. sclab/examples/processor_steps/_qc.py +22 -6
  10. sclab/gui/__init__.py +0 -0
  11. sclab/gui/components/__init__.py +5 -0
  12. sclab/gui/components/_guided_pseudotime.py +482 -0
  13. sclab/methods/__init__.py +25 -1
  14. sclab/preprocess/__init__.py +18 -0
  15. sclab/preprocess/_cca.py +154 -0
  16. sclab/preprocess/_cca_integrate.py +77 -0
  17. sclab/preprocess/_filter_obs.py +42 -0
  18. sclab/preprocess/_harmony.py +421 -0
  19. sclab/preprocess/_harmony_integrate.py +50 -0
  20. sclab/preprocess/_normalize_weighted.py +61 -0
  21. sclab/preprocess/_subset.py +208 -0
  22. sclab/preprocess/_transfer_metadata.py +137 -0
  23. sclab/preprocess/_transform.py +82 -0
  24. sclab/preprocess/_utils.py +96 -0
  25. sclab/tools/__init__.py +0 -0
  26. sclab/tools/cellflow/__init__.py +0 -0
  27. sclab/tools/cellflow/density_dynamics/__init__.py +0 -0
  28. sclab/tools/cellflow/density_dynamics/_density_dynamics.py +349 -0
  29. sclab/tools/cellflow/pseudotime/__init__.py +0 -0
  30. sclab/tools/cellflow/pseudotime/_pseudotime.py +332 -0
  31. sclab/tools/cellflow/pseudotime/timeseries.py +226 -0
  32. sclab/tools/cellflow/utils/__init__.py +0 -0
  33. sclab/tools/cellflow/utils/density_nd.py +136 -0
  34. sclab/tools/cellflow/utils/interpolate.py +334 -0
  35. sclab/tools/cellflow/utils/smoothen.py +124 -0
  36. sclab/tools/cellflow/utils/times.py +55 -0
  37. sclab/tools/differential_expression/__init__.py +5 -0
  38. sclab/tools/differential_expression/_pseudobulk_edger.py +304 -0
  39. sclab/tools/differential_expression/_pseudobulk_helpers.py +277 -0
  40. sclab/tools/doublet_detection/__init__.py +5 -0
  41. sclab/tools/doublet_detection/_scrublet.py +64 -0
  42. sclab/tools/labeling/__init__.py +6 -0
  43. sclab/tools/labeling/sctype.py +233 -0
  44. sclab/utils/__init__.py +5 -0
  45. sclab/utils/_write_excel.py +510 -0
  46. {sclab-0.2.5.dist-info → sclab-0.3.0.dist-info}/METADATA +6 -2
  47. sclab-0.3.0.dist-info/RECORD +81 -0
  48. sclab-0.2.5.dist-info/RECORD +0 -45
  49. {sclab-0.2.5.dist-info → sclab-0.3.0.dist-info}/WHEEL +0 -0
  50. {sclab-0.2.5.dist-info → sclab-0.3.0.dist-info}/licenses/LICENSE +0 -0
sclab/tools/differential_expression/_pseudobulk_helpers.py
@@ -0,0 +1,277 @@
+ import random
+
+ import numpy as np
+ import pandas as pd
+ from anndata import AnnData
+ from numpy import ndarray
+ from scipy.sparse import csr_matrix, issparse
+
+
+ # code inspired from
+ # https://www.sc-best-practices.org/conditions/differential_gene_expression.html
+ def aggregate_and_filter(
+     adata: AnnData,
+     group_key: str = "batch",
+     cell_identity_key: str | None = None,
+     layer: str | None = None,
+     replicas_per_group: int = 3,
+     min_cells_per_group: int = 30,
+     bootstrap_sampling: bool = False,
+     use_cells: dict[str, list[str]] | None = None,
+ ) -> AnnData:
+     """
+     Aggregate and filter cells in an AnnData object into cell populations.
+
+     Parameters
+     ----------
+     adata : AnnData
+         AnnData object to aggregate and filter.
+     group_key : str, optional
+         Key to group cells by. Defaults to 'batch'.
+     cell_identity_key : str, optional
+         Key to use to identify cell identities. Defaults to None.
+     layer : str, optional
+         Layer in AnnData object to use for aggregation. Defaults to None.
+     replicas_per_group : int, optional
+         Number of replicas to create for each group. Defaults to 3.
+     min_cells_per_group : int, optional
+         Minimum number of cells required for a group to be included. Defaults to 30.
+     bootstrap_sampling : bool, optional
+         Whether to use bootstrap sampling to create replicas. Defaults to False.
+     use_cells : dict[str, list[str]], optional
+         If not None, only use the specified cells. Defaults to None.
+
+     Returns
+     -------
+     AnnData
+         AnnData object with aggregated and filtered cells.
+     """
+     adata = _prepare_dataset(adata, use_cells)
+
+     grouping_keys = [group_key]
+     if cell_identity_key is not None:
+         grouping_keys.append(cell_identity_key)
+
+     groups_to_drop = _get_groups_to_drop(adata, grouping_keys, min_cells_per_group)
+
+     _prepare_categorical_column(adata, group_key)
+     group_dtype = adata.obs[group_key].dtype
+
+     if cell_identity_key is not None:
+         _prepare_categorical_column(adata, cell_identity_key)
+         cell_identity_dtype = adata.obs[cell_identity_key].dtype
+
+     var_dataframe = _create_var_dataframe(adata, layer, grouping_keys, groups_to_drop)
+
+     data = {}
+     meta = {}
+     groups = adata.obs.groupby(grouping_keys, observed=True).groups
+     for group, group_idxs in groups.items():
+         if not isinstance(group, tuple):
+             group = (group,)
+
+         if not _including(group, groups_to_drop):
+             continue
+
+         sample_id = "_".join(group)
+         match group:
+             case (gid, cid):
+                 group_metadata = {group_key: gid, cell_identity_key: cid}
+             case (gid,):
+                 group_metadata = {group_key: gid}
+
+         adata_group = adata[group_idxs]
+         indices = _get_replica_idxs(adata_group, replicas_per_group, bootstrap_sampling)
+         for i, rep_idx in enumerate(indices):
+             replica_number = i + 1
+             replica_size = len(rep_idx)
+             replica_sample_id = f"{sample_id}_rep{replica_number}"
+
+             adata_group_replica = adata_group[rep_idx]
+             X = _get_layer(adata_group_replica, layer)
+
+             data[replica_sample_id] = np.array(X.sum(axis=0)).flatten()
+             meta[replica_sample_id] = {
+                 **group_metadata,
+                 "replica": str(replica_number),
+                 "replica_size": replica_size,
+             }
+
+     data = pd.DataFrame(data).T
+     meta = pd.DataFrame(meta).T
+     meta["replica"] = meta["replica"].astype("category")
+     meta[group_key] = meta[group_key].astype(group_dtype)
+     if cell_identity_key is not None:
+         meta[cell_identity_key] = meta[cell_identity_key].astype(cell_identity_dtype)
+
+     aggr_adata = AnnData(
+         data.values,
+         obs=meta,
+         var=var_dataframe,
+     )
+
+     _join_dummies(aggr_adata, group_key)
+
+     return aggr_adata
+
+
+ def _prepare_dataset(
+     adata: AnnData,
+     use_cells: dict[str, list[str]] | None,
+ ) -> AnnData:
+     if use_cells is not None:
+         for key, value in use_cells.items():
+             adata = adata[adata.obs[key].isin(value)]
+
+     return adata.copy()
+
+
+ def _get_groups_to_drop(
+     adata: AnnData,
+     grouping_keys: str | list[str],
+     min_cells_per_group: int,
+ ):
+     group_sizes = adata.obs.groupby(grouping_keys, observed=True).size()
+     groups_to_drop = group_sizes[group_sizes < min_cells_per_group].index.to_list()
+
+     if len(groups_to_drop) > 0:
+         print("Dropping the following samples:")
+
+     groups_to_drop = groups_to_drop + [
+         (g,) for g in groups_to_drop if not isinstance(g, tuple)
+     ]
+
+     return groups_to_drop
+
+
+ def _prepare_categorical_column(adata: AnnData, column: str) -> None:
+     if not isinstance(adata.obs[column].dtype, pd.CategoricalDtype):
+         adata.obs[column] = adata.obs[column].astype("category")
+
+
+ def _create_var_dataframe(
+     adata: AnnData,
+     layer: str,
+     grouping_keys: list[str],
+     groups_to_drop: list[str],
+ ):
+     columns = _get_var_dataframe_columns(adata, grouping_keys, groups_to_drop)
+     var_dataframe = pd.DataFrame(index=adata.var_names, columns=columns, dtype=float)
+
+     groups = adata.obs.groupby(grouping_keys, observed=True).groups
+     for group, idx in groups.items():
+         if not isinstance(group, tuple):
+             group = (group,)
+
+         if not _including(group, groups_to_drop):
+             continue
+
+         sample_id = "_".join(group)
+         rest_id = f"not{sample_id}"
+
+         adata_subset = adata[idx]
+         rest_subset = adata[~adata.obs_names.isin(idx)]
+
+         X = _get_layer(adata_subset, layer, dense=True)
+         Y = _get_layer(rest_subset, layer, dense=True)
+
+         var_dataframe[f"pct_expr_{sample_id}"] = (X > 0).mean(axis=0)
+         var_dataframe[f"pct_expr_{rest_id}"] = (Y > 0).mean(axis=0)
+         var_dataframe[f"num_expr_{sample_id}"] = (X > 0).sum(axis=0)
+         var_dataframe[f"num_expr_{rest_id}"] = (Y > 0).sum(axis=0)
+         var_dataframe[f"tot_expr_{sample_id}"] = X.sum(axis=0)
+         var_dataframe[f"tot_expr_{rest_id}"] = Y.sum(axis=0)
+
+     return var_dataframe
+
+
+ def _get_var_dataframe_columns(
+     adata: AnnData, grouping_keys: list[str], groups_to_drop: list[str]
+ ) -> list[str]:
+     columns = []
+
+     groups = adata.obs.groupby(grouping_keys, observed=True).groups
+     for group, _ in groups.items():
+         if not isinstance(group, tuple):
+             group = (group,)
+
+         if not _including(group, groups_to_drop):
+             continue
+
+         sample_id = "_".join(group)
+         rest_id = f"not{sample_id}"
+
+         columns.extend(
+             [
+                 f"pct_expr_{sample_id}",
+                 f"pct_expr_{rest_id}",
+                 f"num_expr_{sample_id}",
+                 f"num_expr_{rest_id}",
+                 f"tot_expr_{sample_id}",
+                 f"tot_expr_{rest_id}",
+             ]
+         )
+
+     return columns
+
+
+ def _including(group: tuple | str, groups_to_drop: list[str]) -> bool:
+     match group:
+         case (gid, cid):
+             if isinstance(cid, float) and np.isnan(cid):
+                 return False
+
+         case (gid,) | gid:
+             ...
+
+     if gid in groups_to_drop:
+         return False
+
+     return True
+
+
+ def _get_replica_idxs(
+     adata_group: AnnData,
+     replicas_per_group: int,
+     bootstrap_sampling: bool,
+ ):
+     group_size = adata_group.n_obs
+     indices = list(adata_group.obs_names)
+     if bootstrap_sampling:
+         indices = np.array(
+             [
+                 np.random.choice(indices, size=group_size, replace=True)
+                 for _ in range(replicas_per_group)
+             ]
+         )
+
+     else:
+         random.shuffle(indices)
+         indices = np.array_split(np.array(indices), replicas_per_group)
+
+     return indices
+
+
+ def _get_layer(adata: AnnData, layer: str | None, dense: bool = False):
+     X: ndarray | csr_matrix
+
+     if layer is None or layer == "X":
+         X = adata.X
+     else:
+         X = adata.layers[layer]
+
+     if dense:
+         if issparse(X):
+             X = np.asarray(X.todense())
+         else:
+             X = np.asarray(X)
+
+     return X
+
+
+ def _join_dummies(aggr_adata: AnnData, group_key: str) -> None:
+     dummies = pd.get_dummies(aggr_adata.obs[group_key], prefix=group_key).astype(str)
+     dummies = dummies.astype(str).apply(lambda s: s.map({"True": "", "False": "not"}))
+     dummies = dummies + aggr_adata.obs[group_key].cat.categories
+
+     aggr_adata.obs = aggr_adata.obs.join(dummies)
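For orientation, here is a minimal usage sketch of the new pseudobulk helper above (not part of the released code; the input file name, the "batch"/"cell_type" column names, and the "counts" layer are assumptions, and the import path is inferred from the file list):

import scanpy as sc

from sclab.tools.differential_expression._pseudobulk_helpers import aggregate_and_filter

# Hypothetical input: an AnnData with raw counts in layer "counts" and
# obs columns "batch" and "cell_type".
adata = sc.read_h5ad("example.h5ad")

# Sum counts into 3 pseudobulk replicas per (batch, cell_type) combination,
# dropping combinations with fewer than 30 cells.
pb = aggregate_and_filter(
    adata,
    group_key="batch",
    cell_identity_key="cell_type",
    layer="counts",
    replicas_per_group=3,
    min_cells_per_group=30,
)

# One observation per pseudobulk replica; var carries pct/num/tot expression
# summaries per group, ready for edgeR-style differential expression.
print(pb)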
sclab/tools/doublet_detection/__init__.py
@@ -0,0 +1,5 @@
+ from ._scrublet import scrublet
+
+ __all__ = [
+     "scrublet",
+ ]
sclab/tools/doublet_detection/_scrublet.py
@@ -0,0 +1,64 @@
+ from importlib.util import find_spec
+ from typing import Any
+
+ import pandas as pd
+ from anndata import AnnData
+ from numpy import ndarray
+
+
+ def scrublet(
+     adata: AnnData,
+     layer: str = "X",
+     key_added: str = "scrublet",
+     total_counts: ndarray | None = None,
+     sim_doublet_ratio: float = 2.0,
+     n_neighbors: int = None,
+     expected_doublet_rate: float = 0.1,
+     stdev_doublet_rate: float = 0.02,
+     random_state: int = 0,
+     scrub_doublets_kwargs: dict[str, Any] = dict(
+         synthetic_doublet_umi_subsampling=1.0,
+         use_approx_neighbors=True,
+         distance_metric="euclidean",
+         get_doublet_neighbor_parents=False,
+         min_counts=3,
+         min_cells=3,
+         min_gene_variability_pctl=85,
+         log_transform=False,
+         mean_center=True,
+         normalize_variance=True,
+         n_prin_comps=30,
+         svd_solver="arpack",
+         verbose=True,
+     ),
+ ):
+     if find_spec("scrublet") is None:
+         raise ImportError(
+             "scrublet is not installed. Install with:\npip install scrublet"
+         )
+     from scrublet import Scrublet  # noqa: E402
+
+     if layer == "X":
+         X = adata.X
+     else:
+         X = adata.layers[layer]
+
+     scrub = Scrublet(
+         counts_matrix=X,
+         total_counts=total_counts,
+         sim_doublet_ratio=sim_doublet_ratio,
+         n_neighbors=n_neighbors,
+         expected_doublet_rate=expected_doublet_rate,
+         stdev_doublet_rate=stdev_doublet_rate,
+         random_state=random_state,
+     )
+
+     _scores, labels = scrub.scrub_doublets(**scrub_doublets_kwargs)
+     if labels is not None:
+         _labels = list(map(lambda v: "doublet" if v else "singlet", labels))
+         _labels = pd.Categorical(_labels, ["singlet", "doublet"])
+         adata.obs[f"{key_added}_label"] = _labels
+     else:
+         adata.obs[f"{key_added}_label"] = "singlet"
+
+     adata.obs[f"{key_added}_score"] = _scores
sclab/tools/labeling/__init__.py
@@ -0,0 +1,6 @@
+ from . import sctype
+
+
+ __all__ = [
+     "sctype",
+ ]
sclab/tools/labeling/sctype.py
@@ -0,0 +1,233 @@
+ import logging
+ from typing import Optional
+
+ import numpy as np
+ import pandas as pd
+ from anndata import AnnData
+ from numpy.typing import NDArray
+ from scipy import stats
+ from scipy.sparse import csc_matrix, csr_matrix, issparse
+
+ from ...preprocess import pool_neighbors
+
+ logger = logging.getLogger(__name__)
+
+
+ def _get_classification_scores_matrix(
+     adata: AnnData,
+     markers: pd.DataFrame,
+     marker_class_key: str,
+     neighbors_key: Optional[str] = None,
+     weighted_pooling: bool = False,
+     directed_pooling: bool = True,
+     layer: Optional[str] = None,
+     penalize_non_specific: bool = True,
+ ):
+     # Ianevski, A., Giri, A.K. & Aittokallio, T.
+     # Fully-automated and ultra-fast cell-type identification using specific
+     # marker combinations from single-cell transcriptomic data.
+     # Nat Commun 13, 1246 (2022).
+     # https://doi.org/10.1038/s41467-022-28803-w
+
+     if layer is not None:
+         X = adata.layers[layer]
+
+     else:
+         X = adata.X
+
+     min_val: np.number = X.min()
+     M = X > min_val
+     n_cells = np.asarray(M.sum(axis=0)).squeeze()
+     mask = n_cells > 5
+     print(f"using {mask.sum()} genes")
+
+     markers = markers.loc[markers["names"].isin(adata.var_names[mask])].copy()
+     classes = markers[marker_class_key].cat.categories
+
+     x = markers[[marker_class_key, "names"]].groupby("names").count()[marker_class_key]
+     if penalize_non_specific:
+         S = 1.0 - (x - x.min()) / (x.max() - x.min())
+         S = S[S > 0]
+     else:
+         S = x * 0.0 + 1.0
+
+     X: NDArray | csr_matrix | csc_matrix
+     if neighbors_key is not None:
+         X = pool_neighbors(
+             adata[:, S.index],
+             layer=layer,
+             neighbors_key=neighbors_key,
+             weighted=weighted_pooling,
+             directed=directed_pooling,
+             copy=True,
+         )
+
+     elif layer is not None:
+         X = adata[:, S.index].layers[layer].copy()
+
+     else:
+         X = adata[:, S.index].X.copy()
+
+     if issparse(X):
+         X = np.asarray(X.todense("C"))
+
+     Z: NDArray
+     Z = stats.zscore(X, axis=0)
+     Xp = Z * S.values
+
+     Xc = np.zeros((adata.shape[0], len(classes)))
+     for c, cell_class in enumerate(classes):
+         if cell_class == "Unknown":
+             continue
+         up_genes = markers.loc[
+             (markers[marker_class_key] == cell_class) & (markers["logfoldchanges"] > 0),
+             "names",
+         ]
+         dw_genes = markers.loc[
+             (markers[marker_class_key] == cell_class) & (markers["logfoldchanges"] < 0),
+             "names",
+         ]
+         x_up = Xp[:, S.index.isin(up_genes)]
+         x_dw = Xp[:, S.index.isin(dw_genes)]
+         if len(up_genes) > 0:
+             Xc[:, c] += x_up.sum(axis=1) / np.sqrt(len(up_genes))
+         if len(dw_genes) > 0:
+             Xc[:, c] -= x_dw.sum(axis=1) / np.sqrt(len(dw_genes))
+
+     return Xc
+
+
+ def classify_cells(
+     adata: AnnData,
+     markers: pd.DataFrame,
+     marker_class_key: Optional[str] = None,
+     cluster_key: Optional[str] = None,
+     layer: Optional[str] = None,
+     key_added: Optional[str] = None,
+     threshold: float = 0.25,
+     penalize_non_specific: bool = True,
+     neighbors_key: Optional[str] = None,
+     save_scores: bool = False,
+ ):
+     """
+     Classify cells based on a set of marker genes.
+
+     Ianevski, A., Giri, A.K. & Aittokallio, T.
+     Fully-automated and ultra-fast cell-type identification using specific
+     marker combinations from single-cell transcriptomic data.
+     Nat Commun 13, 1246 (2022).
+     https://doi.org/10.1038/s41467-022-28803-w
+
+     Parameters
+     ----------
+     adata
+         AnnData object.
+     markers
+         Marker genes.
+     marker_class_key
+         Column in `markers` that contains the cell type information.
+     cluster_key
+         Column in `adata.obs` that contains the cluster information. If
+         not provided, the classification will be performed on a cell by cell
+         basis, pooling across neighbor cells. This pooling can be avoided by
+         setting `force_pooling` to `False`.
+     layer
+         Layer to use for classification. Defaults to `X`.
+     key_added
+         Key under which to add the classification information.
+     threshold
+         Confidence threshold for classification. Defaults to `0.25`.
+     penalize_non_specific
+         Whether to penalize non-specific markers. Defaults to `True`.
+     neighbors_key
+         If provided, counts will be pooled across neighbor cells using the
+         distances in `adata.uns[neighbors_key]["distances"]`. Defaults to `None`.
+     save_scores
+         Whether to save the classification scores. Defaults to `False`
+     """
+     # cite("10.1038/s41467-022-28803-w", __package__)
+
+     if marker_class_key is not None:
+         marker_class = markers[marker_class_key]
+         if not marker_class.dtype.name.startswith("category"):
+             markers[marker_class_key] = marker_class.astype("category")
+     else:
+         col_mask = markers.dtypes == "category"
+         assert col_mask.sum() == 1, (
+             "markers_df must have exactly one column of type 'category'"
+         )
+         marker_class_key = markers.loc[:, col_mask].squeeze().name
+
+     classes = markers[marker_class_key].cat.categories
+     dtype = markers[marker_class_key].dtype
+
+     # if doing cell by cell classification, we should pool counts to use cell
+     # neighborhood information. This allows to estimate the confidence of the
+     # classification. We specify pooling by providing a neighbors_key.
+     posXc = _get_classification_scores_matrix(
+         adata,
+         markers.query("logfoldchanges > 0"),
+         marker_class_key,
+         neighbors_key,
+         weighted_pooling=True,
+         directed_pooling=True,
+         layer=layer,
+         penalize_non_specific=penalize_non_specific,
+     )
+     negXc = _get_classification_scores_matrix(
+         adata,
+         markers.query("logfoldchanges < 0"),
+         marker_class_key,
+         neighbors_key,
+         weighted_pooling=True,
+         directed_pooling=True,
+         layer=layer,
+         penalize_non_specific=penalize_non_specific,
+     )
+     Xc = posXc + negXc
+
+     if cluster_key is not None:
+         mappings = {}
+         mappings_nona = {}
+         for c in adata.obs[cluster_key].cat.categories:
+             cluster_scores_matrix = Xc[adata.obs[cluster_key] == c]
+             n_cells_in_cluster = cluster_scores_matrix.shape[0]
+
+             scores = cluster_scores_matrix.sum(axis=0)
+             confidence = scores.max() / n_cells_in_cluster
+             if confidence >= threshold:
+                 mappings[c] = classes[np.argmax(scores)]
+             else:
+                 mappings[c] = pd.NA
+                 logger.warning(
+                     f"Cluster {str(c):>5} classified as Unknown with confidence score {confidence: 8.2f}"
+                 )
+             mappings_nona[c] = classes[np.argmax(scores)]
+         classifications = adata.obs[cluster_key].map(mappings).astype(dtype)
+         classifications_nona = adata.obs[cluster_key].map(mappings_nona).astype(dtype)
+     else:
+         if neighbors_key is not None:
+             n_neigs = adata.uns[neighbors_key]["params"]["n_neighbors"]
+         else:
+             n_neigs = 1
+         index = adata.obs_names
+         classifications = classes.values[Xc.argmax(axis=1)]
+         classifications = pd.Series(classifications, index=index).astype(dtype)
+         classifications_nona = classifications.copy()
+         classifications.loc[Xc.max(axis=1) < threshold * n_neigs] = pd.NA
+
+     N = len(classifications)
+     n_unknowns = pd.isna(classifications).sum()
+     n_estimated = N - n_unknowns
+
+     logger.info(f"Estimated types for {n_estimated} cells ({n_estimated / N:.2%})")
+
+     if key_added is None:
+         key_added = marker_class_key
+
+     adata.obs[key_added] = classifications
+     adata.obs[key_added + "_noNA"] = classifications_nona
+
+     if save_scores:
+         adata.obs[key_added + "_score"] = Xc.max(axis=1)
+         adata.obsm[key_added + "_scores"] = Xc
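A minimal sketch of cluster-level annotation with the new sctype module (not part of the release; the marker table contents, the "leiden" cluster key, and the input file are assumptions; the markers DataFrame needs "names", "logfoldchanges", and a categorical cell-type column, as the code above expects):

import pandas as pd
import scanpy as sc

from sclab.tools.labeling import sctype

adata = sc.read_h5ad("clustered.h5ad")  # hypothetical input, already normalized and clustered

# Hypothetical marker table: one row per (cell type, gene) pair.
markers = pd.DataFrame(
    {
        "cell_type": pd.Categorical(["T cell", "T cell", "B cell", "B cell"]),
        "names": ["CD3D", "CD3E", "CD79A", "MS4A1"],
        "logfoldchanges": [2.0, 1.5, 2.5, 2.0],
    }
)

sctype.classify_cells(
    adata,
    markers,
    marker_class_key="cell_type",
    cluster_key="leiden",  # per-cluster labels; omit for per-cell classification
    threshold=0.25,
    save_scores=True,
)

# adata.obs["cell_type"] holds the assigned labels (NA below the confidence threshold);
# adata.obs["cell_type_noNA"] keeps the best guess for every cluster.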
sclab/utils/__init__.py
@@ -0,0 +1,5 @@
+ from ._write_excel import write_excel
+
+ __all__ = [
+     "write_excel",
+ ]