sclab 0.3.2__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,257 @@
+ import pandas as pd
+ from anndata import AnnData
+
+ from ._pseudobulk_helpers import aggregate_and_filter
+
+
+ def pseudobulk_limma(
+     adata_: AnnData,
+     group_key: str,
+     condition_group: str | list[str] | None = None,
+     reference_group: str | None = None,
+     cell_identity_key: str | None = None,
+     batch_key: str | None = None,
+     layer: str | None = None,
+     replicas_per_group: int = 5,
+     min_cells_per_group: int = 30,
+     bootstrap_sampling: bool = False,
+     use_cells: dict[str, list[str]] | None = None,
+     aggregate: bool = True,
+     verbosity: int = 0,
+ ) -> dict[str, pd.DataFrame]:
+     _try_imports()
+     import anndata2ri  # noqa: F401
+     import rpy2.robjects as robjects
+     from rpy2.rinterface_lib.embedded import RRuntimeError  # noqa: F401
+     from rpy2.robjects import pandas2ri  # noqa: F401
+     from rpy2.robjects.conversion import localconverter  # noqa: F401
+
+     R = robjects.r
+
+     if aggregate:
+         aggr_adata = aggregate_and_filter(
+             adata_,
+             group_key,
+             cell_identity_key,
+             layer,
+             replicas_per_group,
+             min_cells_per_group,
+             bootstrap_sampling,
+             use_cells,
+         )
+     else:
+         aggr_adata = adata_.copy()
+
+     with localconverter(anndata2ri.converter):
+         R.assign("aggr_adata", aggr_adata)
+
+     # defines the R function for fitting the model with limma
+     R(_fit_model_r_script)
+
+     if condition_group is None:
+         condition_group_list = aggr_adata.obs[group_key].unique()
+     elif isinstance(condition_group, str):
+         condition_group_list = [condition_group]
+     else:
+         condition_group_list = condition_group
+
+     if cell_identity_key is not None:
+         cids = aggr_adata.obs[cell_identity_key].unique()
+     else:
+         cids = [""]
+
+     tt_dict = {}
+     for condition_group in condition_group_list:
+         if reference_group is not None and condition_group == reference_group:
+             continue
+
+         if verbosity > 0:
+             print(f"Fitting model for {condition_group}...")
+
+         if reference_group is not None:
+             gk = group_key
+         else:
+             gk = f"{group_key}_{condition_group}"
+
77
+ R(f"""
78
+ outs <- fit_limma_model(aggr_adata, "{gk}", "{cell_identity_key}", verbosity = {verbosity})
79
+ fit <- outs$fit
80
+ v <- outs$v
81
+ """)
82
+
83
+ except RRuntimeError as e:
84
+ print("Error fitting model for", condition_group)
85
+ print("Error:", e)
86
+ print("Skipping...", flush=True)
87
+ continue
88
+
89
+ if reference_group is None:
90
+ new_contrasts_tuples = [
91
+ (
92
+ condition_group, # common prefix
93
+ "", # condition group
94
+ "not", # reference group
95
+ cid, # cell identity
96
+ )
97
+ for cid in cids
98
+ ]
99
+
100
+ else:
101
+ new_contrasts_tuples = [
102
+ (
103
+ "", # common prefix
104
+ condition_group, # condition group
105
+ reference_group, # reference group
106
+ cid, # cell identity
107
+ )
108
+ for cid in cids
109
+ ]
110
+
111
+ new_contrasts = [
112
+ f"group{cnd}{prefix}_{cid}".strip("_")
113
+ + "-"
114
+ + f"group{ref}{prefix}_{cid}".strip("_")
115
+ for prefix, cnd, ref, cid in new_contrasts_tuples
116
+ ]
117
+
118
+ for contrast, contrast_tuple in zip(new_contrasts, new_contrasts_tuples):
119
+ prefix, cnd, ref, cid = contrast_tuple
120
+
121
+ if ref == "not":
122
+ cnd, ref = "", "rest"
123
+
124
+ contrast_key = f"{prefix}{cnd}_vs_{ref}"
125
+ if cid:
126
+ contrast_key = f"{cell_identity_key}:{cid}|{contrast_key}"
127
+
128
+ if verbosity > 0:
129
+ print(f"Computing contrast: {contrast_key}... ({contrast})")
130
+
131
+ R(f"myContrast <- makeContrasts('{contrast}', levels = v$design)")
132
+ R("fit2 <- contrasts.fit(fit, myContrast)")
133
+ R("fit2 <- eBayes(fit2)")
134
+ R("tt <- topTable(fit2, n = Inf)")
135
+ tt: pd.DataFrame = pandas2ri.rpy2py(R("tt"))
136
+ tt.index.name = "gene_ids"
137
+
138
+ genes = tt.index
139
+ cnd, ref = [c[5:] for c in contrast.split("-")]
140
+ tt["pct_expr_cnd"] = aggr_adata.var[f"pct_expr_{cnd}"].loc[genes]
141
+ tt["pct_expr_ref"] = aggr_adata.var[f"pct_expr_{ref}"].loc[genes]
142
+ tt["num_expr_cnd"] = aggr_adata.var[f"num_expr_{cnd}"].loc[genes]
143
+ tt["num_expr_ref"] = aggr_adata.var[f"num_expr_{ref}"].loc[genes]
144
+ tt["tot_expr_cnd"] = aggr_adata.var[f"tot_expr_{cnd}"].loc[genes]
145
+ tt["tot_expr_ref"] = aggr_adata.var[f"tot_expr_{ref}"].loc[genes]
146
+ tt["mean_cnd"] = tt["tot_expr_cnd"] / tt["num_expr_cnd"]
147
+ tt["mean_ref"] = tt["tot_expr_ref"] / tt["num_expr_ref"]
148
+ tt_dict[contrast_key] = tt
149
+
150
+ return tt_dict
151
+
152
+
+ _fit_model_r_script = """
+ suppressPackageStartupMessages({
+     library(edgeR)
+     library(limma)
+     library(MAST)
+ })
+
+ fit_limma_model <- function(adata_, group_key, cell_identity_key = "None", batch_key = "None", verbosity = 0){
+
+     if (verbosity > 0){
+         cat("Group key:", group_key, "\n")
+         cat("Cell identity key:", cell_identity_key, "\n")
+     }
+
+     # create a vector that is the concatenation of condition and cell type,
+     # which we will later use with contrasts
+     if (cell_identity_key == "None"){
+         group <- colData(adata_)[[group_key]]
+     } else {
+         group <- paste0(colData(adata_)[[group_key]], "_", colData(adata_)[[cell_identity_key]])
+     }
+
+     if (verbosity > 1){
+         cat("Group(s):", group, "\n")
+     }
+
+     group <- factor(group)
+     replica <- factor(colData(adata_)$replica)
+
+     # create a design matrix
+     if (batch_key == "None"){
+         design <- model.matrix(~ 0 + group + replica)
+     } else {
+         batch <- factor(colData(adata_)[[batch_key]])
+         design <- model.matrix(~ 0 + group + replica + batch)
+     }
+     colnames(design) <- make.names(colnames(design))
+
+     # create an edgeR object with counts and grouping factor
+     y <- DGEList(assay(adata_, "X"), group = group)
+
+     # filter out genes with low counts
+     if (verbosity > 1){
+         cat("Dimensions before subsetting:", dim(y), "\n")
+     }
+
+     keep <- filterByExpr(y, design = design)
+     y <- y[keep, , keep.lib.sizes = FALSE]
+     if (verbosity > 1){
+         cat("Dimensions after subsetting:", dim(y), "\n")
+     }
+
+     # normalize
+     y <- calcNormFactors(y)
+
+     # apply voom transformation to prepare for linear modeling
+     v <- voom(y, design = design)
+
+     # fit the linear model
+     fit <- lmFit(v, design)
+     ne <- limma::nonEstimable(design)
+     if (!is.null(ne) && verbosity > 0) cat("Non-estimable:", ne, "\n")
+     fit <- eBayes(fit)
+
+     return(list("fit" = fit, "design" = design, "v" = v))
+ }
+ """
+
+
+ def _try_imports():
+     try:
+         import rpy2.robjects as robjects
+         from rpy2.robjects.packages import PackageNotInstalledError, importr
+
+         robjects.r("options(warn=-1)")
+         import anndata2ri  # noqa: F401
+         from rpy2.rinterface_lib.embedded import RRuntimeError  # noqa: F401
+         from rpy2.robjects import numpy2ri, pandas2ri  # noqa: F401
+         from rpy2.robjects.conversion import localconverter  # noqa: F401
+
+         importr("edgeR")
+         importr("limma")
+         importr("MAST")
+         importr("SingleCellExperiment")
+
+     except ModuleNotFoundError:
+         message = (
+             "pseudobulk_limma requires rpy2 and anndata2ri to be installed.\n"
+             "Please install with one of the following:\n"
+             "$ pip install rpy2 anndata2ri\n"
+             "or\n"
+             "$ conda install -c conda-forge rpy2 anndata2ri\n"
+         )
+         print(message)
+         raise ModuleNotFoundError(message)
+
+     except PackageNotInstalledError:
+         message = (
+             "pseudobulk_limma requires the following R packages to be installed: "
+             "limma, edgeR, MAST, and SingleCellExperiment.\n"
+             "> \n"
+             "> if (!require('BiocManager', quietly = TRUE)) install.packages('BiocManager');\n"
+             "> BiocManager::install(c('limma', 'edgeR', 'MAST', 'SingleCellExperiment'));\n"
+             "> \n"
+         )
+         print(message)
+         raise ImportError(message)
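
For orientation, a minimal usage sketch of the new entry point. The import path, input file, and column names here are hypothetical (the diff does not show where the function is exported), and rpy2, anndata2ri, and the R packages above must be installed:

    import anndata as ad

    from sclab import pseudobulk_limma  # hypothetical import path

    # assumed: raw counts in .layers["counts"], obs columns "condition" and "cell_type"
    adata = ad.read_h5ad("example.h5ad")

    tt_dict = pseudobulk_limma(
        adata,
        group_key="condition",
        condition_group="treated",
        reference_group="control",
        cell_identity_key="cell_type",
        layer="counts",
        verbosity=1,
    )

    # one limma topTable per contrast, keyed like "cell_type:B|treated_vs_control"
    for key, tt in tt_dict.items():
        print(key, tt[["logFC", "adj.P.Val"]].head())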
@@ -0,0 +1,135 @@
+ import gc
+
+ import numpy as np
+ from numpy import float32
+ from numpy.typing import NDArray
+ from scipy.sparse import csc_matrix, csr_matrix, issparse
+
+
+ def _alra_on_ndarray(
+     data: NDArray | csr_matrix,
+ ) -> tuple[NDArray[float32], NDArray[float32], NDArray[float32]]:
+     """
+     Run ALRA on the given data.
+
+     Parameters
+     ----------
+     data : NDArray | csr_matrix
+         Input data to impute.
+
+     Returns
+     -------
+     data_aprx : NDArray
+         Rank-k approximation of the data.
+     data_thrs : NDArray
+         Approximated data after thresholding.
+     data_alra : NDArray
+         Imputed (thresholded and rescaled) data.
+     """
+     import rpy2.robjects as robjects
+     import rpy2.robjects.numpy2ri
+     from rpy2.robjects.packages import importr
+
+     rpy2.robjects.numpy2ri.activate()
+     R = robjects.r
+     alra = importr("ALRA")
+
+     if issparse(data):
+         data = np.ascontiguousarray(data.todense("C"), dtype=np.float32)
+
+     # convert to R object
+     r_X = R.matrix(data, nrow=data.shape[0], ncol=data.shape[1])
+     # run ALRA
+     r_res = alra.alra(r_X, 0, 10, 0.001)
+     # retrieve imputed data
+     r_K = r_res[0]  # rank k
+     r_T = r_res[1]  # rank k thresholded
+     r_S = r_res[2]  # rank k thresholded scaled
+     # convert back to numpy arrays
+     data_aprx = np.array(r_K, dtype=float32)
+     data_thrs = np.array(r_T, dtype=float32)
+     data_alra = np.array(r_S, dtype=float32)
+
+     # clean up
+     del r_X, r_res, r_K, r_T, r_S
+     R("gc()")
+     gc.collect()
+
+     return data_aprx, data_thrs, data_alra
+
+
+ def _fix_alra_scale(
+     input_data: NDArray | csr_matrix | csc_matrix,
+     thrs_data: NDArray,
+     target_data: NDArray,
+ ) -> NDArray:
+     # convert sparse -> dense
+     if issparse(input_data):
+         input_data = input_data.toarray("C")
+     input_data = np.ascontiguousarray(input_data, dtype=np.float32)
+
+     n_cells, n_genes = input_data.shape
+
+     # per-gene nonzero means/sds (match R: sample sd, ddof=1)
+     input_means = np.full(n_genes, fill_value=np.nan)
+     input_stds = np.full(n_genes, fill_value=np.nan)
+     thrs_means = np.full(n_genes, fill_value=np.nan)
+     thrs_stds = np.full(n_genes, fill_value=np.nan)
+     v: NDArray
+
+     for i, e in enumerate(input_data.T):
+         v = e[e > 0]
+
+         if v.size == 0:
+             continue
+         input_means[i] = v.mean()
+
+         if v.size == 1:
+             continue
+         input_stds[i] = v.std(ddof=1)
+
+     for i, e in enumerate(thrs_data.T):
+         v = e[e > 0]
+
+         if v.size == 0:
+             continue
+         thrs_means[i] = v.mean()
+
+         if v.size == 1:
+             continue
+         thrs_stds[i] = v.std(ddof=1)
+
+     # columns to scale (mirror R's toscale)
+     toscale = (
+         ~np.isnan(thrs_stds)
+         & ~np.isnan(input_stds)
+         & ~((thrs_stds == 0) & (input_stds == 0))
+         & ~(thrs_stds == 0)
+     )
+
+     # affine params
+     a = np.full(n_genes, fill_value=1.0)
+     b = np.full(n_genes, fill_value=0.0)
+     a[toscale] = input_stds[toscale] / thrs_stds[toscale]
+     b[toscale] = input_means[toscale] - a[toscale] * thrs_means[toscale]
+
+     # apply to target matrix (only columns in toscale)
+     out = target_data.copy()
+     out[:, toscale] = out[:, toscale] * a[toscale] + b[toscale]
+
+     # keep zeros as zeros
+     out[thrs_data == 0] = 0
+
+     # clip negatives to zero
+     out[out < 0] = 0
+
+     # restore originally observed positives that became zero
+     mask = (input_data > 0) & (out == 0)
+     out[mask] = input_data[mask]
+
+     return out
+
+
+ __all__ = [
+     "_alra_on_ndarray",
+     "_fix_alra_scale",
+ ]
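
A sketch of how these two private helpers compose, on a hypothetical matrix. ALRA is run on a normalized cells-by-genes matrix via rpy2 (the ALRA R package must be installed), and `_fix_alra_scale` re-applies the per-gene affine rescaling in NumPy, here using the thresholded output as the matrix to rescale:

    import numpy as np

    rng = np.random.default_rng(0)
    # hypothetical normalized expression matrix (cells x genes, non-negative)
    X = rng.gamma(shape=0.3, scale=1.0, size=(500, 200)).astype(np.float32)
    X[rng.random(X.shape) < 0.8] = 0.0  # simulate dropout sparsity

    data_aprx, data_thrs, data_alra = _alra_on_ndarray(X)

    # reproduce the rescaling step outside R: match each gene's nonzero
    # mean/sd in the thresholded matrix to the observed data
    rescaled = _fix_alra_scale(X, data_thrs, data_thrs)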
@@ -0,0 +1,5 @@
+ from ._aggregate_and_filter import aggregate_and_filter
+
+ __all__ = [
+     "aggregate_and_filter",
+ ]
@@ -0,0 +1,290 @@
+ import random
+
+ import numpy as np
+ import pandas as pd
+ from anndata import AnnData
+ from numpy import ndarray
+ from scipy.sparse import csr_matrix, issparse
+
+
+ # code inspired by
+ # https://www.sc-best-practices.org/conditions/differential_gene_expression.html
+ def aggregate_and_filter(
+     adata: AnnData,
+     group_key: str = "batch",
+     cell_identity_key: str | None = None,
+     layer: str | None = None,
+     replicas_per_group: int = 3,
+     min_cells_per_group: int = 30,
+     bootstrap_sampling: bool = False,
+     use_cells: dict[str, list[str]] | None = None,
+     make_stats: bool = True,
+     make_dummies: bool = True,
+ ) -> AnnData:
+     """
+     Aggregate and filter cells in an AnnData object into pseudobulk samples.
+
+     Parameters
+     ----------
+     adata : AnnData
+         AnnData object to aggregate and filter.
+     group_key : str, optional
+         Key to group cells by. Defaults to 'batch'.
+     cell_identity_key : str, optional
+         Key identifying cell identities. Defaults to None.
+     layer : str, optional
+         Layer in the AnnData object to use for aggregation. Defaults to None.
+     replicas_per_group : int, optional
+         Number of replicas to create for each group. Defaults to 3.
+     min_cells_per_group : int, optional
+         Minimum number of cells required for a group to be included. Defaults to 30.
+     bootstrap_sampling : bool, optional
+         Whether to use bootstrap sampling to create replicas. Defaults to False.
+     use_cells : dict[str, list[str]], optional
+         If not None, only use the specified cells. Defaults to None.
+     make_stats : bool, optional
+         Whether to create expression statistics for each group. Defaults to True.
+     make_dummies : bool, optional
+         Whether to encode categorical columns as dummies. Defaults to True.
+
+     Returns
+     -------
+     AnnData
+         AnnData object with aggregated and filtered cells.
+     """
+     adata = _prepare_dataset(adata, use_cells)
+
+     grouping_keys = [group_key]
+     if cell_identity_key is not None:
+         grouping_keys.append(cell_identity_key)
+
+     groups_to_drop = _get_groups_to_drop(adata, grouping_keys, min_cells_per_group)
+
+     _prepare_categorical_column(adata, group_key)
+     group_dtype = adata.obs[group_key].dtype
+
+     if cell_identity_key is not None:
+         _prepare_categorical_column(adata, cell_identity_key)
+         cell_identity_dtype = adata.obs[cell_identity_key].dtype
+
+     if make_stats:
+         var_dataframe = _create_var_dataframe(
+             adata, layer, grouping_keys, groups_to_drop
+         )
+     else:
+         var_dataframe = pd.DataFrame(index=adata.var_names)
+
+     data = {}
+     meta = {}
+     groups = adata.obs.groupby(grouping_keys, observed=True).groups
+     for group, group_idxs in groups.items():
+         if not isinstance(group, tuple):
+             group = (group,)
+
+         if not _including(group, groups_to_drop):
+             continue
+
+         sample_id = "_".join(group)
+         match group:
+             case (gid, cid):
+                 group_metadata = {group_key: gid, cell_identity_key: cid}
+             case (gid,):
+                 group_metadata = {group_key: gid}
+
+         adata_group = adata[group_idxs]
+         indices = _get_replica_idxs(adata_group, replicas_per_group, bootstrap_sampling)
+         for i, rep_idx in enumerate(indices):
+             replica_number = i + 1
+             replica_size = len(rep_idx)
+             replica_sample_id = f"{sample_id}_rep{replica_number}"
+
+             adata_group_replica = adata_group[rep_idx]
+             X = _get_layer(adata_group_replica, layer)
+
+             data[replica_sample_id] = np.array(X.sum(axis=0)).flatten()
+             meta[replica_sample_id] = {
+                 **group_metadata,
+                 "replica": str(replica_number),
+                 "replica_size": replica_size,
+             }
+
+     data = pd.DataFrame(data).T
+     meta = pd.DataFrame(meta).T
+     meta["replica"] = meta["replica"].astype("category")
+     meta["replica_size"] = meta["replica_size"].astype(int)
+     meta[group_key] = meta[group_key].astype(group_dtype)
+     if cell_identity_key is not None:
+         meta[cell_identity_key] = meta[cell_identity_key].astype(cell_identity_dtype)
+
+     aggr_adata = AnnData(
+         data.values,
+         obs=meta,
+         var=var_dataframe,
+     )
+
+     if make_dummies:
+         _join_dummies(aggr_adata, group_key)
+
+     return aggr_adata
+
+
+ def _prepare_dataset(
+     adata: AnnData,
+     use_cells: dict[str, list[str]] | None,
+ ) -> AnnData:
+     if use_cells is not None:
+         for key, value in use_cells.items():
+             adata = adata[adata.obs[key].isin(value)]
+
+     return adata.copy()
+
+
+ def _get_groups_to_drop(
+     adata: AnnData,
+     grouping_keys: str | list[str],
+     min_cells_per_group: int,
+ ):
+     group_sizes = adata.obs.groupby(grouping_keys, observed=True).size()
+     groups_to_drop = group_sizes[group_sizes < min_cells_per_group].index.to_list()
+
+     if len(groups_to_drop) > 0:
+         print("Dropping the following samples:", groups_to_drop)
+
+     # also keep 1-tuple versions so lookups work for single-key groupings
+     groups_to_drop = groups_to_drop + [
+         (g,) for g in groups_to_drop if not isinstance(g, tuple)
+     ]
+
+     return groups_to_drop
+
+
+ def _prepare_categorical_column(adata: AnnData, column: str) -> None:
+     if not isinstance(adata.obs[column].dtype, pd.CategoricalDtype):
+         adata.obs[column] = adata.obs[column].astype("category")
+
+
+ def _create_var_dataframe(
+     adata: AnnData,
+     layer: str | None,
+     grouping_keys: list[str],
+     groups_to_drop: list,
+ ):
+     columns = _get_var_dataframe_columns(adata, grouping_keys, groups_to_drop)
+     var_dataframe = pd.DataFrame(index=adata.var_names, columns=columns, dtype=float)
+
+     groups = adata.obs.groupby(grouping_keys, observed=True).groups
+     for group, idx in groups.items():
+         if not isinstance(group, tuple):
+             group = (group,)
+
+         if not _including(group, groups_to_drop):
+             continue
+
+         sample_id = "_".join(group)
+         rest_id = f"not{sample_id}"
+
+         adata_subset = adata[idx]
+         rest_subset = adata[~adata.obs_names.isin(idx)]
+
+         X = _get_layer(adata_subset, layer, dense=True)
+         Y = _get_layer(rest_subset, layer, dense=True)
+
+         var_dataframe[f"pct_expr_{sample_id}"] = (X > 0).mean(axis=0)
+         var_dataframe[f"pct_expr_{rest_id}"] = (Y > 0).mean(axis=0)
+         var_dataframe[f"num_expr_{sample_id}"] = (X > 0).sum(axis=0)
+         var_dataframe[f"num_expr_{rest_id}"] = (Y > 0).sum(axis=0)
+         var_dataframe[f"tot_expr_{sample_id}"] = X.sum(axis=0)
+         var_dataframe[f"tot_expr_{rest_id}"] = Y.sum(axis=0)
+
+     return var_dataframe
+
+
+ def _get_var_dataframe_columns(
+     adata: AnnData, grouping_keys: list[str], groups_to_drop: list
+ ) -> list[str]:
+     columns = []
+
+     groups = adata.obs.groupby(grouping_keys, observed=True).groups
+     for group, _ in groups.items():
+         if not isinstance(group, tuple):
+             group = (group,)
+
+         if not _including(group, groups_to_drop):
+             continue
+
+         sample_id = "_".join(group)
+         rest_id = f"not{sample_id}"
+
+         columns.extend(
+             [
+                 f"pct_expr_{sample_id}",
+                 f"pct_expr_{rest_id}",
+                 f"num_expr_{sample_id}",
+                 f"num_expr_{rest_id}",
+                 f"tot_expr_{sample_id}",
+                 f"tot_expr_{rest_id}",
+             ]
+         )
+
+     return columns
+
+
+ def _including(group: tuple | str, groups_to_drop: list) -> bool:
+     match group:
+         case (gid, cid):
+             if isinstance(cid, float) and np.isnan(cid):
+                 return False
+             # for two-key groupings the dropped entries are (group, identity) tuples
+             if group in groups_to_drop:
+                 return False
+
+         case (gid,) | gid:
+             ...
+
+     if gid in groups_to_drop:
+         return False
+
+     return True
+
+
+ def _get_replica_idxs(
+     adata_group: AnnData,
+     replicas_per_group: int,
+     bootstrap_sampling: bool,
+ ):
+     group_size = adata_group.n_obs
+     indices = list(adata_group.obs_names)
+     if bootstrap_sampling:
+         indices = np.array(
+             [
+                 np.random.choice(indices, size=group_size, replace=True)
+                 for _ in range(replicas_per_group)
+             ]
+         )
+
+     else:
+         random.shuffle(indices)
+         indices = np.array_split(np.array(indices), replicas_per_group)
+
+     return indices
+
+
+ def _get_layer(adata: AnnData, layer: str | None, dense: bool = False):
+     X: ndarray | csr_matrix
+
+     if layer is None or layer == "X":
+         X = adata.X
+     else:
+         X = adata.layers[layer]
+
+     if dense:
+         if issparse(X):
+             X = np.asarray(X.todense())
+         else:
+             X = np.asarray(X)
+
+     return X
+
+
+ def _join_dummies(aggr_adata: AnnData, group_key: str) -> None:
+     # one indicator column per category, encoded as "<category>" / "not<category>"
+     dummies = pd.get_dummies(aggr_adata.obs[group_key], prefix=group_key).astype(str)
+     dummies = dummies.apply(lambda s: s.map({"True": "", "False": "not"}))
+     dummies = dummies + aggr_adata.obs[group_key].cat.categories
+
+     aggr_adata.obs = aggr_adata.obs.join(dummies)
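
To illustrate the output shape (object and column names hypothetical): each group, and optionally each group/cell-identity pair, is split into pseudoreplicas whose counts are summed into one observation, and `_join_dummies` adds one obs column per category encoding membership as "<category>" versus "not<category>", which pseudobulk_limma uses for its one-vs-rest contrasts:

    import anndata as ad

    adata = ad.read_h5ad("example.h5ad")  # assumed obs column: "condition"

    aggr = aggregate_and_filter(
        adata,
        group_key="condition",
        replicas_per_group=3,
        min_cells_per_group=30,
    )

    print(aggr.obs.columns.tolist())
    # e.g. ["condition", "replica", "replica_size",
    #       "condition_control", "condition_treated"]
    print(aggr.var.filter(like="pct_expr_").head())  # per-group expression stats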