pertpy 0.9.5__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +5 -1
- pertpy/_doc.py +2 -5
- pertpy/_types.py +6 -0
- pertpy/data/_dataloader.py +68 -24
- pertpy/data/_datasets.py +9 -9
- pertpy/metadata/__init__.py +2 -1
- pertpy/metadata/_cell_line.py +136 -30
- pertpy/metadata/_look_up.py +13 -19
- pertpy/metadata/_moa.py +1 -1
- pertpy/preprocessing/_guide_rna.py +221 -39
- pertpy/preprocessing/_guide_rna_mixture.py +177 -0
- pertpy/tools/__init__.py +1 -1
- pertpy/tools/_augur.py +138 -142
- pertpy/tools/_cinemaot.py +75 -117
- pertpy/tools/_coda/_base_coda.py +150 -174
- pertpy/tools/_coda/_sccoda.py +66 -69
- pertpy/tools/_coda/_tasccoda.py +71 -79
- pertpy/tools/_dialogue.py +60 -56
- pertpy/tools/_differential_gene_expression/_base.py +25 -43
- pertpy/tools/_differential_gene_expression/_checks.py +4 -6
- pertpy/tools/_differential_gene_expression/_dge_comparison.py +5 -6
- pertpy/tools/_differential_gene_expression/_edger.py +6 -10
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +1 -1
- pertpy/tools/_differential_gene_expression/_simple_tests.py +3 -3
- pertpy/tools/_differential_gene_expression/_statsmodels.py +8 -5
- pertpy/tools/_distances/_distance_tests.py +1 -2
- pertpy/tools/_distances/_distances.py +86 -92
- pertpy/tools/_enrichment.py +8 -25
- pertpy/tools/_milo.py +23 -27
- pertpy/tools/_mixscape.py +261 -175
- pertpy/tools/_perturbation_space/_clustering.py +4 -4
- pertpy/tools/_perturbation_space/_comparison.py +4 -4
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +83 -32
- pertpy/tools/_perturbation_space/_perturbation_space.py +10 -10
- pertpy/tools/_perturbation_space/_simple.py +13 -17
- pertpy/tools/_scgen/_scgen.py +17 -20
- pertpy/tools/_scgen/_scgenvae.py +2 -2
- pertpy/tools/_scgen/_utils.py +3 -1
- {pertpy-0.9.5.dist-info → pertpy-0.11.0.dist-info}/METADATA +37 -21
- pertpy-0.11.0.dist-info/RECORD +58 -0
- {pertpy-0.9.5.dist-info → pertpy-0.11.0.dist-info}/licenses/LICENSE +1 -0
- pertpy/tools/_kernel_pca.py +0 -50
- pertpy-0.9.5.dist-info/RECORD +0 -57
- {pertpy-0.9.5.dist-info → pertpy-0.11.0.dist-info}/WHEEL +0 -0
pertpy/tools/_mixscape.py
CHANGED
@@ -9,15 +9,14 @@ import numpy as np
|
|
9
9
|
import pandas as pd
|
10
10
|
import scanpy as sc
|
11
11
|
import seaborn as sns
|
12
|
+
from fast_array_utils.stats import mean, mean_var
|
12
13
|
from scanpy import get
|
13
|
-
from scanpy._settings import settings
|
14
14
|
from scanpy._utils import _check_use_raw, sanitize_anndata
|
15
15
|
from scanpy.plotting import _utils
|
16
16
|
from scanpy.tools._utils import _choose_representation
|
17
|
-
from scipy.sparse import csr_matrix,
|
17
|
+
from scipy.sparse import csr_matrix, spmatrix
|
18
18
|
from sklearn.mixture import GaussianMixture
|
19
19
|
|
20
|
-
import pertpy as pt
|
21
20
|
from pertpy._doc import _doc_params, doc_common_plot_args
|
22
21
|
|
23
22
|
if TYPE_CHECKING:
|
@@ -41,9 +40,12 @@ class Mixscape:
|
|
41
40
|
adata: AnnData,
|
42
41
|
pert_key: str,
|
43
42
|
control: str,
|
43
|
+
*,
|
44
|
+
ref_selection_mode: Literal["nn", "split_by"] = "nn",
|
44
45
|
split_by: str | None = None,
|
45
46
|
n_neighbors: int = 20,
|
46
47
|
use_rep: str | None = None,
|
48
|
+
n_dims: int | None = 15,
|
47
49
|
n_pcs: int | None = None,
|
48
50
|
batch_size: int | None = None,
|
49
51
|
copy: bool = False,
|
@@ -51,14 +53,18 @@ class Mixscape:
|
|
51
53
|
):
|
52
54
|
"""Calculate perturbation signature.
|
53
55
|
|
54
|
-
|
55
|
-
|
56
|
-
|
56
|
+
The perturbation signature is calculated by subtracting the mRNA expression profile of each cell from the averaged
|
57
|
+
mRNA expression profile of the control cells (selected according to `ref_selection_mode`).
|
58
|
+
The implementation resembles https://satijalab.org/seurat/reference/runmixscape. Note that in the original implementation, the
|
59
|
+
perturbation signature is calculated on unscaled data by default, and we therefore recommend to do the same.
|
57
60
|
|
58
61
|
Args:
|
59
62
|
adata: The annotated data object.
|
60
63
|
pert_key: The column of `.obs` with perturbation categories, should also contain `control`.
|
61
|
-
control:
|
64
|
+
control: Name of the control category from the `pert_key` column.
|
65
|
+
ref_selection_mode: Method to select reference cells for the perturbation signature calculation. If `nn`,
|
66
|
+
the `n_neighbors` cells from the control pool with the most similar mRNA expression profiles are selected. If `split_by`,
|
67
|
+
the control cells from the same split in `split_by` (e.g. indicating biological replicates) are used to calculate the perturbation signature.
|
62
68
|
split_by: Provide the column `.obs` if multiple biological replicates exist to calculate
|
63
69
|
the perturbation signature for every replicate separately.
|
64
70
|
n_neighbors: Number of neighbors from the control to use for the perturbation signature.
|
@@ -66,7 +72,10 @@ class Mixscape:
|
|
66
72
|
If `None`, the representation is chosen automatically:
|
67
73
|
For `.n_vars` < 50, `.X` is used, otherwise 'X_pca' is used.
|
68
74
|
If 'X_pca' is not present, it’s computed with default parameters.
|
69
|
-
|
75
|
+
n_dims: Number of dimensions to use from the representation to calculate the perturbation signature.
|
76
|
+
If `None`, use all dimensions.
|
77
|
+
n_pcs: If PCA representation is used, the number of principal components to compute.
|
78
|
+
If `n_pcs==0` use `.X` if `use_rep is None`.
|
70
79
|
batch_size: Size of batch to calculate the perturbation signature.
|
71
80
|
If 'None', the perturbation signature is calcuated in the full mode, requiring more memory.
|
72
81
|
The batched mode is very inefficient for sparse data.
|
@@ -83,8 +92,13 @@ class Mixscape:
|
|
83
92
|
>>> import pertpy as pt
|
84
93
|
>>> mdata = pt.dt.papalexi_2021()
|
85
94
|
>>> ms_pt = pt.tl.Mixscape()
|
86
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
95
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
87
96
|
"""
|
97
|
+
if ref_selection_mode not in ["nn", "split_by"]:
|
98
|
+
raise ValueError("ref_selection_mode must be either 'nn' or 'split_by'.")
|
99
|
+
if ref_selection_mode == "split_by" and split_by is None:
|
100
|
+
raise ValueError("split_by must be provided if ref_selection_mode is 'split_by'.")
|
101
|
+
|
88
102
|
if copy:
|
89
103
|
adata = adata.copy()
|
90
104
|
|
@@ -92,59 +106,73 @@ class Mixscape:
|
|
92
106
|
|
93
107
|
control_mask = adata.obs[pert_key] == control
|
94
108
|
|
95
|
-
if
|
96
|
-
|
109
|
+
if ref_selection_mode == "split_by":
|
110
|
+
for split in adata.obs[split_by].unique():
|
111
|
+
split_mask = adata.obs[split_by] == split
|
112
|
+
control_mask_group = control_mask & split_mask
|
113
|
+
control_mean_expr = mean(adata.X[control_mask_group], axis=0)
|
114
|
+
adata.layers["X_pert"][split_mask] = (
|
115
|
+
np.repeat(control_mean_expr.reshape(1, -1), split_mask.sum(), axis=0)
|
116
|
+
- adata.layers["X_pert"][split_mask]
|
117
|
+
)
|
97
118
|
else:
|
98
|
-
|
99
|
-
|
119
|
+
if split_by is None:
|
120
|
+
split_masks = [np.full(adata.n_obs, True, dtype=bool)]
|
121
|
+
else:
|
122
|
+
split_obs = adata.obs[split_by]
|
123
|
+
split_masks = [split_obs == cat for cat in split_obs.unique()]
|
100
124
|
|
101
|
-
|
125
|
+
representation = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs)
|
126
|
+
if n_dims is not None and n_dims < representation.shape[1]:
|
127
|
+
representation = representation[:, :n_dims]
|
102
128
|
|
103
|
-
|
104
|
-
control_mask_split = control_mask & split_mask
|
129
|
+
from pynndescent import NNDescent
|
105
130
|
|
106
|
-
|
107
|
-
|
131
|
+
for split_mask in split_masks:
|
132
|
+
control_mask_split = control_mask & split_mask
|
108
133
|
|
109
|
-
|
134
|
+
R_split = representation[split_mask]
|
135
|
+
R_control = representation[np.asarray(control_mask_split)]
|
110
136
|
|
111
|
-
|
112
|
-
|
113
|
-
|
137
|
+
eps = kwargs.pop("epsilon", 0.1)
|
138
|
+
nn_index = NNDescent(R_control, **kwargs)
|
139
|
+
indices, _ = nn_index.query(R_split, k=n_neighbors, epsilon=eps)
|
114
140
|
|
115
|
-
|
141
|
+
X_control = np.expm1(adata.X[np.asarray(control_mask_split)])
|
116
142
|
|
117
|
-
|
118
|
-
|
143
|
+
n_split = split_mask.sum()
|
144
|
+
n_control = X_control.shape[0]
|
119
145
|
|
120
|
-
|
121
|
-
|
122
|
-
|
146
|
+
if batch_size is None:
|
147
|
+
col_indices = np.ravel(indices)
|
148
|
+
row_indices = np.repeat(np.arange(n_split), n_neighbors)
|
123
149
|
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
150
|
+
neigh_matrix = csr_matrix(
|
151
|
+
(np.ones_like(col_indices, dtype=np.float64), (row_indices, col_indices)),
|
152
|
+
shape=(n_split, n_control),
|
153
|
+
)
|
154
|
+
neigh_matrix /= n_neighbors
|
155
|
+
adata.layers["X_pert"][np.asarray(split_mask)] = (
|
156
|
+
sc.pp.log1p(neigh_matrix @ X_control) - adata.layers["X_pert"][np.asarray(split_mask)]
|
157
|
+
)
|
158
|
+
else:
|
159
|
+
split_indices = np.where(split_mask)[0]
|
160
|
+
for i in range(0, n_split, batch_size):
|
161
|
+
size = min(i + batch_size, n_split)
|
162
|
+
select = slice(i, size)
|
136
163
|
|
137
|
-
|
138
|
-
|
164
|
+
batch = np.ravel(indices[select])
|
165
|
+
split_batch = split_indices[select]
|
139
166
|
|
140
|
-
|
167
|
+
size = size - i
|
141
168
|
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
means_batch = means_batch.reshape(size, n_neighbors, -1).mean(1)
|
169
|
+
means_batch = X_control[batch]
|
170
|
+
batch_reshaped = means_batch.reshape(size, n_neighbors, -1)
|
171
|
+
means_batch, _ = mean_var(batch_reshaped, axis=1)
|
146
172
|
|
147
|
-
|
173
|
+
adata.layers["X_pert"][split_batch] = (
|
174
|
+
np.log1p(means_batch) - adata.layers["X_pert"][split_batch]
|
175
|
+
)
|
148
176
|
|
149
177
|
if copy:
|
150
178
|
return adata
|
@@ -154,34 +182,44 @@ class Mixscape:
|
|
154
182
|
adata: AnnData,
|
155
183
|
labels: str,
|
156
184
|
control: str,
|
185
|
+
*,
|
157
186
|
new_class_name: str | None = "mixscape_class",
|
158
|
-
min_de_genes: int | None = 5,
|
159
187
|
layer: str | None = None,
|
188
|
+
min_de_genes: int | None = 5,
|
160
189
|
logfc_threshold: float | None = 0.25,
|
190
|
+
de_layer: str | None = None,
|
191
|
+
test_method: str | None = "wilcoxon",
|
161
192
|
iter_num: int | None = 10,
|
193
|
+
scale: bool | None = True,
|
162
194
|
split_by: str | None = None,
|
163
195
|
pval_cutoff: float | None = 5e-2,
|
164
196
|
perturbation_type: str | None = "KO",
|
197
|
+
random_state: int | None = 0,
|
165
198
|
copy: bool | None = False,
|
199
|
+
**gmmkwargs,
|
166
200
|
):
|
167
201
|
"""Identify perturbed and non-perturbed gRNA expressing cells that accounts for multiple treatments/conditions/chemical perturbations.
|
168
202
|
|
169
|
-
The implementation resembles https://satijalab.org/seurat/reference/runmixscape
|
203
|
+
The implementation resembles https://satijalab.org/seurat/reference/runmixscape.
|
170
204
|
|
171
205
|
Args:
|
172
206
|
adata: The annotated data object.
|
173
207
|
labels: The column of `.obs` with target gene labels.
|
174
|
-
control: Control category from the `
|
208
|
+
control: Control category from the `labels` column.
|
175
209
|
new_class_name: Name of mixscape classification to be stored in `.obs`.
|
176
|
-
min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
|
177
210
|
layer: Key from adata.layers whose value will be used to perform tests on. Default is using `.layers["X_pert"]`.
|
211
|
+
min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
|
178
212
|
logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells (default: 0.25).
|
213
|
+
de_layer: Layer to use for identifying differentially expressed genes. If `None`, adata.X is used.
|
214
|
+
test_method: Method to use for differential expression testing.
|
179
215
|
iter_num: Number of normalmixEM iterations to run if convergence does not occur.
|
180
|
-
|
181
|
-
|
216
|
+
scale: Scale the data specified in `layer` before running the GaussianMixture model on it.
|
217
|
+
split_by: Provide `.obs` column with experimental condition/cell type annotation, if perturbations are condition/cell type-specific.
|
182
218
|
pval_cutoff: P-value cut-off for selection of significantly DE genes.
|
183
219
|
perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications.
|
220
|
+
random_state: Random seed for the GaussianMixture model.
|
184
221
|
copy: Determines whether a copy of the `adata` is returned.
|
222
|
+
**gmmkwargs: Passed to custom implementation of scikit-learn Gaussian Mixture Model.
|
185
223
|
|
186
224
|
Returns:
|
187
225
|
If `copy=True`, returns the copy of `adata` with the classification result in `.obs`.
|
@@ -203,8 +241,8 @@ class Mixscape:
|
|
203
241
|
>>> import pertpy as pt
|
204
242
|
>>> mdata = pt.dt.papalexi_2021()
|
205
243
|
>>> ms_pt = pt.tl.Mixscape()
|
206
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
207
|
-
>>> ms_pt.mixscape(
|
244
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
245
|
+
>>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
|
208
246
|
"""
|
209
247
|
if copy:
|
210
248
|
adata = adata.copy()
|
@@ -218,7 +256,16 @@ class Mixscape:
|
|
218
256
|
split_masks = [split_obs == category for category in categories]
|
219
257
|
|
220
258
|
perturbation_markers = self._get_perturbation_markers(
|
221
|
-
adata,
|
259
|
+
adata=adata,
|
260
|
+
split_masks=split_masks,
|
261
|
+
categories=categories,
|
262
|
+
labels=labels,
|
263
|
+
control=control,
|
264
|
+
layer=de_layer,
|
265
|
+
pval_cutoff=pval_cutoff,
|
266
|
+
min_de_genes=min_de_genes,
|
267
|
+
logfc_threshold=logfc_threshold,
|
268
|
+
test_method=test_method,
|
222
269
|
)
|
223
270
|
|
224
271
|
adata_comp = adata
|
@@ -231,6 +278,7 @@ class Mixscape:
|
|
231
278
|
raise KeyError(
|
232
279
|
"No 'X_pert' found in .layers! Please run perturbation_signature first to calculate perturbation signature!"
|
233
280
|
) from None
|
281
|
+
|
234
282
|
# initialize return variables
|
235
283
|
adata.obs[f"{new_class_name}_p_{perturbation_type.lower()}"] = 0
|
236
284
|
adata.obs[new_class_name] = adata.obs[labels].astype(str)
|
@@ -241,10 +289,12 @@ class Mixscape:
|
|
241
289
|
dtype=np.object_,
|
242
290
|
)
|
243
291
|
gv_list: dict[str, dict] = {}
|
292
|
+
|
293
|
+
adata.obs[f"{new_class_name}_p_{perturbation_type.lower()}"] = 0.0
|
244
294
|
for split, split_mask in enumerate(split_masks):
|
245
295
|
category = categories[split]
|
246
|
-
|
247
|
-
for gene in
|
296
|
+
gene_targets = list(set(adata[split_mask].obs[labels]).difference([control]))
|
297
|
+
for gene in gene_targets:
|
248
298
|
post_prob = 0
|
249
299
|
orig_guide_cells = (adata.obs[labels] == gene) & split_mask
|
250
300
|
orig_guide_cells_index = list(orig_guide_cells.index[orig_guide_cells])
|
@@ -253,63 +303,79 @@ class Mixscape:
|
|
253
303
|
|
254
304
|
if len(perturbation_markers[(category, gene)]) == 0:
|
255
305
|
adata.obs.loc[orig_guide_cells, new_class_name] = f"{gene} NP"
|
306
|
+
|
256
307
|
else:
|
257
308
|
de_genes = perturbation_markers[(category, gene)]
|
258
|
-
de_genes_indices =
|
309
|
+
de_genes_indices = np.where(np.isin(adata.var_names, list(de_genes)))[0]
|
310
|
+
|
259
311
|
dat = X[np.asarray(all_cells)][:, de_genes_indices]
|
312
|
+
if scale:
|
313
|
+
dat = sc.pp.scale(dat)
|
314
|
+
|
260
315
|
converged = False
|
261
316
|
n_iter = 0
|
262
|
-
old_classes = adata.obs[
|
317
|
+
old_classes = adata.obs[new_class_name][all_cells]
|
318
|
+
|
319
|
+
nt_cells_dat_idx = all_cells[all_cells].index.get_indexer(nt_cells[nt_cells].index)
|
320
|
+
nt_cells_mean = np.mean(dat[nt_cells_dat_idx], axis=0)
|
321
|
+
|
263
322
|
while not converged and n_iter < iter_num:
|
264
323
|
# Get all cells in current split&Gene
|
265
|
-
guide_cells = (adata.obs[
|
324
|
+
guide_cells = (adata.obs[new_class_name] == gene) & split_mask
|
325
|
+
|
266
326
|
# get average value for each gene over all selected cells
|
267
327
|
# all cells in current split&Gene minus all NT cells in current split
|
268
328
|
# Each row is for each cell, each column is for each gene, get mean for each column
|
269
|
-
|
270
|
-
|
271
|
-
|
329
|
+
guide_cells_dat_idx = all_cells[all_cells].index.get_indexer(guide_cells[guide_cells].index)
|
330
|
+
guide_cells_mean = np.mean(dat[guide_cells_dat_idx], axis=0)
|
331
|
+
vec = guide_cells_mean - nt_cells_mean
|
332
|
+
|
272
333
|
# project cells onto the perturbation vector
|
273
334
|
if isinstance(dat, spmatrix):
|
274
|
-
pvec =
|
335
|
+
pvec = dat.dot(vec) / np.dot(vec, vec)
|
275
336
|
else:
|
276
|
-
pvec = np.
|
337
|
+
pvec = np.dot(dat, vec) / np.dot(vec, vec)
|
277
338
|
pvec = pd.Series(np.asarray(pvec).flatten(), index=list(all_cells.index[all_cells]))
|
339
|
+
|
278
340
|
if n_iter == 0:
|
279
341
|
gv = pd.DataFrame(columns=["pvec", labels])
|
280
342
|
gv["pvec"] = pvec
|
281
343
|
gv[labels] = control
|
282
344
|
gv.loc[guide_cells, labels] = gene
|
283
|
-
if gene not in gv_list
|
345
|
+
if gene not in gv_list:
|
284
346
|
gv_list[gene] = {}
|
285
347
|
gv_list[gene][category] = gv
|
286
348
|
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
precisions_init = np.array([nt_norm[1], guide_norm[1]])
|
291
|
-
mm = GaussianMixture(
|
349
|
+
means_init = np.array([[pvec[nt_cells].mean()], [pvec[guide_cells].mean()]])
|
350
|
+
std_init = np.array([pvec[nt_cells].std(), pvec[guide_cells].std()])
|
351
|
+
mm = MixscapeGaussianMixture(
|
292
352
|
n_components=2,
|
293
353
|
covariance_type="spherical",
|
294
354
|
means_init=means_init,
|
295
|
-
precisions_init=
|
355
|
+
precisions_init=1 / (std_init**2),
|
356
|
+
random_state=random_state,
|
357
|
+
max_iter=100,
|
358
|
+
fixed_means=[pvec[nt_cells].mean(), None],
|
359
|
+
fixed_covariances=[pvec[nt_cells].std() ** 2, None],
|
360
|
+
**gmmkwargs,
|
296
361
|
).fit(np.asarray(pvec).reshape(-1, 1))
|
297
362
|
probabilities = mm.predict_proba(np.array(pvec[orig_guide_cells_index]).reshape(-1, 1))
|
298
363
|
lik_ratio = probabilities[:, 0] / probabilities[:, 1]
|
299
364
|
post_prob = 1 / (1 + lik_ratio)
|
365
|
+
|
300
366
|
# based on the posterior probability, assign cells to the two classes
|
301
|
-
|
302
|
-
|
303
|
-
] = gene
|
304
|
-
|
305
|
-
[orig_guide_cells_index[cell] for cell in np.where(post_prob <= 0.5)[0]], new_class_name
|
306
|
-
] = f"{gene} NP"
|
367
|
+
ko_mask = post_prob > 0.5
|
368
|
+
adata.obs.loc[np.array(orig_guide_cells_index)[ko_mask], new_class_name] = gene
|
369
|
+
adata.obs.loc[np.array(orig_guide_cells_index)[~ko_mask], new_class_name] = f"{gene} NP"
|
370
|
+
|
307
371
|
if sum(adata.obs[new_class_name][split_mask] == gene) < min_de_genes:
|
308
372
|
adata.obs.loc[guide_cells, new_class_name] = "NP"
|
309
373
|
converged = True
|
310
|
-
|
374
|
+
current_classes = adata.obs[new_class_name][all_cells]
|
375
|
+
if (current_classes == old_classes).all():
|
311
376
|
converged = True
|
312
|
-
old_classes =
|
377
|
+
old_classes = current_classes
|
378
|
+
|
313
379
|
n_iter += 1
|
314
380
|
|
315
381
|
adata.obs.loc[(adata.obs[new_class_name] == gene) & split_mask, new_class_name] = (
|
@@ -317,9 +383,7 @@ class Mixscape:
|
|
317
383
|
)
|
318
384
|
|
319
385
|
adata.obs[f"{new_class_name}_global"] = [a.split(" ")[-1] for a in adata.obs[new_class_name]]
|
320
|
-
adata.obs.loc[orig_guide_cells_index, f"{new_class_name}_p_{perturbation_type.lower()}"] =
|
321
|
-
post_prob
|
322
|
-
).astype("int64")
|
386
|
+
adata.obs.loc[orig_guide_cells_index, f"{new_class_name}_p_{perturbation_type.lower()}"] = post_prob
|
323
387
|
adata.uns["mixscape"] = gv_list
|
324
388
|
|
325
389
|
if copy:
|
@@ -330,11 +394,13 @@ class Mixscape:
|
|
330
394
|
adata: AnnData,
|
331
395
|
labels: str,
|
332
396
|
control: str,
|
397
|
+
*,
|
333
398
|
mixscape_class_global: str | None = "mixscape_class_global",
|
334
399
|
layer: str | None = None,
|
335
400
|
n_comps: int | None = 10,
|
336
401
|
min_de_genes: int | None = 5,
|
337
402
|
logfc_threshold: float | None = 0.25,
|
403
|
+
test_method: str | None = "wilcoxon",
|
338
404
|
split_by: str | None = None,
|
339
405
|
pval_cutoff: float | None = 5e-2,
|
340
406
|
perturbation_type: str | None = "KO",
|
@@ -347,12 +413,12 @@ class Mixscape:
|
|
347
413
|
labels: The column of `.obs` with target gene labels.
|
348
414
|
control: Control category from the `pert_key` column.
|
349
415
|
mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
|
350
|
-
layer:
|
351
|
-
control: Control category from the `pert_key` column.
|
416
|
+
layer: Layer to use for identifying differentially expressed genes. If `None`, adata.X is used.
|
352
417
|
n_comps: Number of principal components to use.
|
353
418
|
min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
|
354
419
|
logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells.
|
355
|
-
|
420
|
+
test_method: Method to use for differential expression testing.
|
421
|
+
split_by: Provide `.obs` column with experimental condition/cell type annotation, if perturbations are condition/cell type-specific.
|
356
422
|
pval_cutoff: P-value cut-off for selection of significantly DE genes.
|
357
423
|
perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
|
358
424
|
copy: Determines whether a copy of the `adata` is returned.
|
@@ -370,9 +436,9 @@ class Mixscape:
|
|
370
436
|
>>> import pertpy as pt
|
371
437
|
>>> mdata = pt.dt.papalexi_2021()
|
372
438
|
>>> ms_pt = pt.tl.Mixscape()
|
373
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
374
|
-
>>> ms_pt.mixscape(
|
375
|
-
>>> ms_pt.lda(
|
439
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
440
|
+
>>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
|
441
|
+
>>> ms_pt.lda(mdata["rna"], "gene_target", "NT")
|
376
442
|
"""
|
377
443
|
if copy:
|
378
444
|
adata = adata.copy()
|
@@ -388,9 +454,8 @@ class Mixscape:
|
|
388
454
|
categories = split_obs.unique()
|
389
455
|
split_masks = [split_obs == category for category in categories]
|
390
456
|
|
391
|
-
mixscape_identifier = pt.tl.Mixscape()
|
392
457
|
# determine gene sets across all splits/groups through differential gene expression
|
393
|
-
perturbation_markers =
|
458
|
+
perturbation_markers = self._get_perturbation_markers(
|
394
459
|
adata=adata,
|
395
460
|
split_masks=split_masks,
|
396
461
|
categories=categories,
|
@@ -400,10 +465,12 @@ class Mixscape:
|
|
400
465
|
pval_cutoff=pval_cutoff,
|
401
466
|
min_de_genes=min_de_genes,
|
402
467
|
logfc_threshold=logfc_threshold,
|
468
|
+
test_method=test_method,
|
403
469
|
)
|
404
470
|
adata_subset = adata[
|
405
471
|
(adata.obs[mixscape_class_global] == perturbation_type) | (adata.obs[mixscape_class_global] == control)
|
406
|
-
]
|
472
|
+
]
|
473
|
+
X = adata_subset.X - adata_subset.X.mean(0)
|
407
474
|
projected_pcs: dict[str, np.ndarray] = {}
|
408
475
|
# performs PCA on each mixscape class separately and projects each subspace onto all cells in the data.
|
409
476
|
for _, (key, value) in enumerate(perturbation_markers.items()):
|
@@ -415,16 +482,10 @@ class Mixscape:
|
|
415
482
|
].copy()
|
416
483
|
sc.pp.scale(gene_subset)
|
417
484
|
sc.tl.pca(gene_subset, n_comps=n_comps)
|
418
|
-
|
419
|
-
|
420
|
-
sc.tl.ingest(adata=adata_subset, adata_ref=gene_subset, embedding_method="pca")
|
421
|
-
projected_pcs[key[1]] = adata_subset.obsm["X_pca"]
|
485
|
+
# project cells into PCA space of gene_subset
|
486
|
+
projected_pcs[key[1]] = np.asarray(np.dot(X, gene_subset.varm["PCs"]))
|
422
487
|
# concatenate all pcs into a single matrix.
|
423
|
-
|
424
|
-
if index == 0:
|
425
|
-
projected_pcs_array = value
|
426
|
-
else:
|
427
|
-
projected_pcs_array = np.concatenate((projected_pcs_array, value), axis=1)
|
488
|
+
projected_pcs_array = np.concatenate(list(projected_pcs.values()), axis=1)
|
428
489
|
|
429
490
|
clf = LinearDiscriminantAnalysis(n_components=len(np.unique(adata_subset.obs[labels])) - 1)
|
430
491
|
clf.fit(projected_pcs_array, adata_subset.obs[labels])
|
@@ -445,12 +506,21 @@ class Mixscape:
|
|
445
506
|
pval_cutoff: float,
|
446
507
|
min_de_genes: float,
|
447
508
|
logfc_threshold: float,
|
509
|
+
test_method: str,
|
448
510
|
) -> dict[tuple, np.ndarray]:
|
449
|
-
"""Determine gene sets across all splits/groups through differential gene expression
|
511
|
+
"""Determine gene sets across all splits/groups through differential gene expression.
|
450
512
|
|
451
513
|
Args:
|
452
514
|
adata: :class:`~anndata.AnnData` object
|
453
|
-
|
515
|
+
split_masks: List of boolean masks for each split/group.
|
516
|
+
categories: List of split/group names.
|
517
|
+
labels: The column of `.obs` with target gene labels.
|
518
|
+
control: Control category from the `labels` column.
|
519
|
+
layer: Key from adata.layers whose value will be used to compare gene expression.
|
520
|
+
pval_cutoff: P-value cut-off for selection of significantly DE genes.
|
521
|
+
min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
|
522
|
+
logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells.
|
523
|
+
test_method: Method to use for differential expression testing.
|
454
524
|
|
455
525
|
Returns:
|
456
526
|
Set of column indices.
|
@@ -459,21 +529,23 @@ class Mixscape:
|
|
459
529
|
for split, split_mask in enumerate(split_masks):
|
460
530
|
category = categories[split]
|
461
531
|
# get gene sets for each split
|
462
|
-
|
532
|
+
gene_targets = list(set(adata[split_mask].obs[labels]).difference([control]))
|
463
533
|
adata_split = adata[split_mask].copy()
|
464
534
|
# find top DE genes between cells with targeting and non-targeting gRNAs
|
465
535
|
sc.tl.rank_genes_groups(
|
466
536
|
adata_split,
|
467
537
|
layer=layer,
|
468
538
|
groupby=labels,
|
469
|
-
groups=
|
539
|
+
groups=gene_targets,
|
470
540
|
reference=control,
|
471
|
-
method=
|
541
|
+
method=test_method,
|
472
542
|
use_raw=False,
|
473
543
|
)
|
474
|
-
# get DE genes for each gene
|
475
|
-
for gene in
|
476
|
-
logfc_threshold_mask =
|
544
|
+
# get DE genes for each target gene
|
545
|
+
for gene in gene_targets:
|
546
|
+
logfc_threshold_mask = (
|
547
|
+
np.abs(adata_split.uns["rank_genes_groups"]["logfoldchanges"][gene]) >= logfc_threshold
|
548
|
+
)
|
477
549
|
de_genes = adata_split.uns["rank_genes_groups"]["names"][gene][logfc_threshold_mask]
|
478
550
|
pvals_adj = adata_split.uns["rank_genes_groups"]["pvals_adj"][gene][logfc_threshold_mask]
|
479
551
|
de_genes = de_genes[pvals_adj < pval_cutoff]
|
@@ -483,33 +555,8 @@ class Mixscape:
|
|
483
555
|
|
484
556
|
return perturbation_markers
|
485
557
|
|
486
|
-
def _get_column_indices(self, adata, col_names):
|
487
|
-
if isinstance(col_names, str): # pragma: no cover
|
488
|
-
col_names = [col_names]
|
489
|
-
|
490
|
-
indices = []
|
491
|
-
for idx, col in enumerate(adata.var_names):
|
492
|
-
if col in col_names:
|
493
|
-
indices.append(idx)
|
494
|
-
|
495
|
-
return indices
|
496
|
-
|
497
|
-
def _define_normal_mixscape(self, X: np.ndarray | sparse.spmatrix | pd.DataFrame | None) -> list[float]:
|
498
|
-
"""Calculates the mean and standard deviation of a matrix.
|
499
|
-
|
500
|
-
Args:
|
501
|
-
X: The matrix to calculate the properties for.
|
502
|
-
|
503
|
-
Returns:
|
504
|
-
Mean and standard deviation of the matrix.
|
505
|
-
"""
|
506
|
-
mu = X.mean()
|
507
|
-
sd = X.std()
|
508
|
-
|
509
|
-
return [mu, sd]
|
510
|
-
|
511
558
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
512
|
-
def plot_barplot( # pragma: no cover
|
559
|
+
def plot_barplot( # pragma: no cover # noqa: D417
|
513
560
|
self,
|
514
561
|
adata: AnnData,
|
515
562
|
guide_rna_column: str,
|
@@ -522,7 +569,6 @@ class Mixscape:
|
|
522
569
|
legend_text_size: int = 8,
|
523
570
|
legend_bbox_to_anchor: tuple[float, float] = None,
|
524
571
|
figsize: tuple[float, float] = (25, 25),
|
525
|
-
show: bool = True,
|
526
572
|
return_fig: bool = False,
|
527
573
|
) -> Figure | None:
|
528
574
|
"""Barplot to visualize perturbation scores calculated by the `mixscape` function.
|
@@ -548,8 +594,8 @@ class Mixscape:
|
|
548
594
|
>>> import pertpy as pt
|
549
595
|
>>> mdata = pt.dt.papalexi_2021()
|
550
596
|
>>> ms_pt = pt.tl.Mixscape()
|
551
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
552
|
-
>>> ms_pt.mixscape(
|
597
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
598
|
+
>>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
|
553
599
|
>>> ms_pt.plot_barplot(mdata["rna"], guide_rna_column="NT")
|
554
600
|
|
555
601
|
Preview:
|
@@ -611,14 +657,13 @@ class Mixscape:
|
|
611
657
|
fig.subplots_adjust(hspace=0.5, wspace=0.5)
|
612
658
|
plt.tight_layout()
|
613
659
|
|
614
|
-
if show:
|
615
|
-
plt.show()
|
616
660
|
if return_fig:
|
617
661
|
return fig
|
662
|
+
plt.show()
|
618
663
|
return None
|
619
664
|
|
620
665
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
621
|
-
def plot_heatmap( # pragma: no cover
|
666
|
+
def plot_heatmap( # pragma: no cover # noqa: D417
|
622
667
|
self,
|
623
668
|
adata: AnnData,
|
624
669
|
labels: str,
|
@@ -630,7 +675,6 @@ class Mixscape:
|
|
630
675
|
subsample_number: int | None = 900,
|
631
676
|
vmin: float | None = -2,
|
632
677
|
vmax: float | None = 2,
|
633
|
-
show: bool = True,
|
634
678
|
return_fig: bool = False,
|
635
679
|
**kwds,
|
636
680
|
) -> Figure | None:
|
@@ -656,8 +700,8 @@ class Mixscape:
|
|
656
700
|
>>> import pertpy as pt
|
657
701
|
>>> mdata = pt.dt.papalexi_2021()
|
658
702
|
>>> ms_pt = pt.tl.Mixscape()
|
659
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
660
|
-
>>> ms_pt.mixscape(
|
703
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
704
|
+
>>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
|
661
705
|
>>> ms_pt.plot_heatmap(
|
662
706
|
... adata=mdata["rna"], labels="gene_target", target_gene="IFNGR2", layer="X_pert", control="NT"
|
663
707
|
... )
|
@@ -683,14 +727,13 @@ class Mixscape:
|
|
683
727
|
**kwds,
|
684
728
|
)
|
685
729
|
|
686
|
-
if show:
|
687
|
-
plt.show()
|
688
730
|
if return_fig:
|
689
731
|
return fig
|
732
|
+
plt.show()
|
690
733
|
return None
|
691
734
|
|
692
735
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
693
|
-
def plot_perturbscore( # pragma: no cover
|
736
|
+
def plot_perturbscore( # pragma: no cover # noqa: D417
|
694
737
|
self,
|
695
738
|
adata: AnnData,
|
696
739
|
labels: str,
|
@@ -702,7 +745,6 @@ class Mixscape:
|
|
702
745
|
split_by: str = None,
|
703
746
|
before_mixscape: bool = False,
|
704
747
|
perturbation_type: str = "KO",
|
705
|
-
show: bool = True,
|
706
748
|
return_fig: bool = False,
|
707
749
|
) -> Figure | None:
|
708
750
|
"""Density plots to visualize perturbation scores calculated by the `pt.tl.mixscape` function.
|
@@ -734,8 +776,8 @@ class Mixscape:
|
|
734
776
|
>>> import pertpy as pt
|
735
777
|
>>> mdata = pt.dt.papalexi_2021()
|
736
778
|
>>> ms_pt = pt.tl.Mixscape()
|
737
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
738
|
-
>>> ms_pt.mixscape(
|
779
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
780
|
+
>>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
|
739
781
|
>>> ms_pt.plot_perturbscore(adata=mdata["rna"], labels="gene_target", target_gene="IFNGR2", color="orange")
|
740
782
|
|
741
783
|
Preview:
|
@@ -744,7 +786,7 @@ class Mixscape:
|
|
744
786
|
if "mixscape" not in adata.uns:
|
745
787
|
raise ValueError("Please run the `mixscape` function first.")
|
746
788
|
perturbation_score = None
|
747
|
-
for key in adata.uns["mixscape"][target_gene]
|
789
|
+
for key in adata.uns["mixscape"][target_gene]:
|
748
790
|
perturbation_score_temp = adata.uns["mixscape"][target_gene][key]
|
749
791
|
perturbation_score_temp["name"] = key
|
750
792
|
if perturbation_score is None:
|
@@ -851,14 +893,13 @@ class Mixscape:
|
|
851
893
|
plt.legend(title="mixscape class", title_fontsize=14, fontsize=12)
|
852
894
|
sns.despine()
|
853
895
|
|
854
|
-
if show:
|
855
|
-
plt.show()
|
856
896
|
if return_fig:
|
857
897
|
return plt.gcf()
|
898
|
+
plt.show()
|
858
899
|
return None
|
859
900
|
|
860
901
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
861
|
-
def plot_violin( # pragma: no cover
|
902
|
+
def plot_violin( # pragma: no cover # noqa: D417
|
862
903
|
self,
|
863
904
|
adata: AnnData,
|
864
905
|
target_gene_idents: str | list[str],
|
@@ -879,7 +920,6 @@ class Mixscape:
|
|
879
920
|
ylabel: str | Sequence[str] | None = None,
|
880
921
|
rotation: float | None = None,
|
881
922
|
ax: Axes | None = None,
|
882
|
-
show: bool = True,
|
883
923
|
return_fig: bool = False,
|
884
924
|
**kwargs,
|
885
925
|
) -> Axes | Figure | None:
|
@@ -910,8 +950,8 @@ class Mixscape:
|
|
910
950
|
>>> import pertpy as pt
|
911
951
|
>>> mdata = pt.dt.papalexi_2021()
|
912
952
|
>>> ms_pt = pt.tl.Mixscape()
|
913
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
914
|
-
>>> ms_pt.mixscape(
|
953
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
954
|
+
>>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
|
915
955
|
>>> ms_pt.plot_violin(
|
916
956
|
... adata=mdata["rna"], target_gene_idents=["NT", "IFNGR2 NP", "IFNGR2 KO"], groupby="mixscape_class"
|
917
957
|
... )
|
@@ -939,7 +979,7 @@ class Mixscape:
|
|
939
979
|
if len(ylabel) != 1:
|
940
980
|
raise ValueError(f"Expected number of y-labels to be `1`, found `{len(ylabel)}`.")
|
941
981
|
elif len(ylabel) != len(keys):
|
942
|
-
raise ValueError(f"Expected number of y-labels to be `{len(keys)}`,
|
982
|
+
raise ValueError(f"Expected number of y-labels to be `{len(keys)}`, found `{len(ylabel)}`.")
|
943
983
|
|
944
984
|
if groupby is not None:
|
945
985
|
if hue is not None:
|
@@ -992,7 +1032,7 @@ class Mixscape:
|
|
992
1032
|
g.set(yscale="log")
|
993
1033
|
g.set_titles(col_template="{col_name}").set_xlabels("")
|
994
1034
|
if rotation is not None:
|
995
|
-
for ax in g.axes[0]:
|
1035
|
+
for ax in g.axes[0]: # noqa: PLR1704
|
996
1036
|
ax.tick_params(axis="x", labelrotation=rotation)
|
997
1037
|
else:
|
998
1038
|
# set by default the violin plot cut=0 to limit the extend
|
@@ -1010,7 +1050,7 @@ class Mixscape:
|
|
1010
1050
|
else:
|
1011
1051
|
axs = [ax]
|
1012
1052
|
for ax, y, ylab in zip(axs, ys, ylabel, strict=False):
|
1013
|
-
ax = sns.violinplot(
|
1053
|
+
ax = sns.violinplot( # noqa: PLW2901
|
1014
1054
|
x=x,
|
1015
1055
|
y=y,
|
1016
1056
|
data=obs_tidy,
|
@@ -1024,7 +1064,7 @@ class Mixscape:
|
|
1024
1064
|
# Get the handles and labels.
|
1025
1065
|
handles, labels = ax.get_legend_handles_labels()
|
1026
1066
|
if stripplot:
|
1027
|
-
ax = sns.stripplot(
|
1067
|
+
ax = sns.stripplot( # noqa: PLW2901
|
1028
1068
|
x=x,
|
1029
1069
|
y=y,
|
1030
1070
|
data=obs_tidy,
|
@@ -1047,12 +1087,9 @@ class Mixscape:
|
|
1047
1087
|
if rotation is not None:
|
1048
1088
|
ax.tick_params(axis="x", labelrotation=rotation)
|
1049
1089
|
|
1050
|
-
show = settings.autoshow if show is None else show
|
1051
1090
|
if hue is not None and stripplot is True:
|
1052
1091
|
plt.legend(handles, labels)
|
1053
1092
|
|
1054
|
-
if show:
|
1055
|
-
plt.show()
|
1056
1093
|
if return_fig:
|
1057
1094
|
if multi_panel and groupby is None and len(ys) == 1:
|
1058
1095
|
return g
|
@@ -1060,10 +1097,11 @@ class Mixscape:
|
|
1060
1097
|
return axs[0]
|
1061
1098
|
else:
|
1062
1099
|
return axs
|
1100
|
+
plt.show()
|
1063
1101
|
return None
|
1064
1102
|
|
1065
1103
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
1066
|
-
def plot_lda( # pragma: no cover
|
1104
|
+
def plot_lda( # pragma: no cover # noqa: D417
|
1067
1105
|
self,
|
1068
1106
|
adata: AnnData,
|
1069
1107
|
control: str,
|
@@ -1076,20 +1114,22 @@ class Mixscape:
|
|
1076
1114
|
color_map: Colormap | str | None = None,
|
1077
1115
|
palette: str | Sequence[str] | None = None,
|
1078
1116
|
ax: Axes | None = None,
|
1079
|
-
show: bool = True,
|
1080
1117
|
return_fig: bool = False,
|
1081
1118
|
**kwds,
|
1082
1119
|
) -> Figure | None:
|
1083
1120
|
"""Visualizing perturbation responses with Linear Discriminant Analysis. Requires `pt.tl.mixscape()` to be run first.
|
1084
1121
|
|
1085
1122
|
Args:
|
1086
|
-
adata: The annotated data
|
1123
|
+
adata: The annotated data objectplot_heatmap.
|
1087
1124
|
control: Control category from the `pert_key` column.
|
1088
1125
|
mixscape_class: The column of `.obs` with the mixscape classification result.
|
1089
1126
|
mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
|
1090
1127
|
perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
|
1091
|
-
lda_key: If not specified, lda looks .uns["mixscape_lda"] for the LDA results.
|
1092
1128
|
n_components: The number of dimensions of the embedding.
|
1129
|
+
lda_key: If not specified, lda looks .uns["mixscape_lda"] for the LDA results.
|
1130
|
+
color_map: Matplotlib color map.
|
1131
|
+
palette: Matplotlib palette.
|
1132
|
+
ax: Matplotlib axes.
|
1093
1133
|
{common_plot_args}
|
1094
1134
|
**kwds: Additional arguments to `scanpy.pl.umap`.
|
1095
1135
|
|
@@ -1097,9 +1137,9 @@ class Mixscape:
|
|
1097
1137
|
>>> import pertpy as pt
|
1098
1138
|
>>> mdata = pt.dt.papalexi_2021()
|
1099
1139
|
>>> ms_pt = pt.tl.Mixscape()
|
1100
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
1101
|
-
>>> ms_pt.mixscape(
|
1102
|
-
>>> ms_pt.lda(
|
1140
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
1141
|
+
>>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
|
1142
|
+
>>> ms_pt.lda(mdata["rna"], "gene_target", "NT", split_by="replicate")
|
1103
1143
|
>>> ms_pt.plot_lda(adata=mdata["rna"], control="NT")
|
1104
1144
|
|
1105
1145
|
Preview:
|
@@ -1129,8 +1169,54 @@ class Mixscape:
|
|
1129
1169
|
**kwds,
|
1130
1170
|
)
|
1131
1171
|
|
1132
|
-
if show:
|
1133
|
-
plt.show()
|
1134
1172
|
if return_fig:
|
1135
1173
|
return fig
|
1174
|
+
plt.show()
|
1136
1175
|
return None
|
1176
|
+
|
1177
|
+
|
1178
|
+
class MixscapeGaussianMixture(GaussianMixture):
|
1179
|
+
def __init__(
|
1180
|
+
self,
|
1181
|
+
n_components: int,
|
1182
|
+
fixed_means: Sequence[float] | None = None,
|
1183
|
+
fixed_covariances: Sequence[float] | None = None,
|
1184
|
+
**kwargs,
|
1185
|
+
):
|
1186
|
+
"""Custom Gaussian Mixture Model where means and covariances can be fixed for specific components.
|
1187
|
+
|
1188
|
+
Args:
|
1189
|
+
n_components: Number of Gaussian components
|
1190
|
+
fixed_means: Means to fix (use None for those that should be estimated)
|
1191
|
+
fixed_covariances: Covariances to fix (use None for those that should be estimated)
|
1192
|
+
kwargs: Additional arguments passed to scikit-learn's GaussianMixture
|
1193
|
+
"""
|
1194
|
+
super().__init__(n_components=n_components, **kwargs)
|
1195
|
+
self.fixed_means = fixed_means
|
1196
|
+
self.fixed_covariances = fixed_covariances
|
1197
|
+
|
1198
|
+
self.fixed_mean_indices = []
|
1199
|
+
self.fixed_mean_values = []
|
1200
|
+
if fixed_means is not None:
|
1201
|
+
self.fixed_mean_indices = [i for i, m in enumerate(fixed_means) if m is not None]
|
1202
|
+
if self.fixed_mean_indices:
|
1203
|
+
self.fixed_mean_values = np.array([fixed_means[i] for i in self.fixed_mean_indices])
|
1204
|
+
|
1205
|
+
self.fixed_cov_indices = []
|
1206
|
+
self.fixed_cov_values = []
|
1207
|
+
if fixed_covariances is not None:
|
1208
|
+
self.fixed_cov_indices = [i for i, c in enumerate(fixed_covariances) if c is not None]
|
1209
|
+
if self.fixed_cov_indices:
|
1210
|
+
self.fixed_cov_values = np.array([fixed_covariances[i] for i in self.fixed_cov_indices])
|
1211
|
+
|
1212
|
+
def _m_step(self, X: np.ndarray, log_resp: np.ndarray):
|
1213
|
+
"""Modified M-step to respect fixed means and covariances."""
|
1214
|
+
super()._m_step(X, log_resp)
|
1215
|
+
|
1216
|
+
if self.fixed_mean_indices:
|
1217
|
+
self.means_[self.fixed_mean_indices] = self.fixed_mean_values
|
1218
|
+
|
1219
|
+
if self.fixed_cov_indices:
|
1220
|
+
self.covariances_[self.fixed_cov_indices] = self.fixed_cov_values
|
1221
|
+
|
1222
|
+
return self
|