pertpy 0.9.4__py3-none-any.whl → 0.10.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- pertpy/__init__.py +1 -1
- pertpy/_doc.py +19 -0
- pertpy/data/_datasets.py +1 -1
- pertpy/metadata/_cell_line.py +18 -8
- pertpy/metadata/_compound.py +3 -4
- pertpy/metadata/_metadata.py +1 -1
- pertpy/preprocessing/_guide_rna.py +114 -13
- pertpy/preprocessing/_guide_rna_mixture.py +179 -0
- pertpy/tools/__init__.py +1 -1
- pertpy/tools/_augur.py +64 -86
- pertpy/tools/_cinemaot.py +21 -17
- pertpy/tools/_coda/_base_coda.py +90 -117
- pertpy/tools/_dialogue.py +32 -40
- pertpy/tools/_differential_gene_expression/__init__.py +1 -2
- pertpy/tools/_differential_gene_expression/_base.py +486 -112
- pertpy/tools/_differential_gene_expression/_edger.py +30 -21
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +15 -29
- pertpy/tools/_differential_gene_expression/_statsmodels.py +0 -11
- pertpy/tools/_distances/_distances.py +71 -56
- pertpy/tools/_enrichment.py +16 -8
- pertpy/tools/_milo.py +54 -50
- pertpy/tools/_mixscape.py +307 -208
- pertpy/tools/_perturbation_space/_perturbation_space.py +40 -31
- pertpy/tools/_perturbation_space/_simple.py +48 -0
- pertpy/tools/_scgen/_scgen.py +35 -27
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/METADATA +6 -6
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/RECORD +29 -28
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/WHEEL +1 -1
- pertpy/tools/_differential_gene_expression/_formulaic.py +0 -189
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_mixscape.py
CHANGED
@@ -18,6 +18,7 @@ from scipy.sparse import csr_matrix, issparse, spmatrix
|
|
18
18
|
from sklearn.mixture import GaussianMixture
|
19
19
|
|
20
20
|
import pertpy as pt
|
21
|
+
from pertpy._doc import _doc_params, doc_common_plot_args
|
21
22
|
|
22
23
|
if TYPE_CHECKING:
|
23
24
|
from collections.abc import Sequence
|
@@ -25,6 +26,7 @@ if TYPE_CHECKING:
|
|
25
26
|
from anndata import AnnData
|
26
27
|
from matplotlib.axes import Axes
|
27
28
|
from matplotlib.colors import Colormap
|
29
|
+
from matplotlib.pyplot import Figure
|
28
30
|
from scipy import sparse
|
29
31
|
|
30
32
|
|
@@ -39,9 +41,12 @@ class Mixscape:
|
|
39
41
|
adata: AnnData,
|
40
42
|
pert_key: str,
|
41
43
|
control: str,
|
44
|
+
*,
|
45
|
+
ref_selection_mode: Literal["nn", "split_by"] = "nn",
|
42
46
|
split_by: str | None = None,
|
43
47
|
n_neighbors: int = 20,
|
44
48
|
use_rep: str | None = None,
|
49
|
+
n_dims: int | None = 15,
|
45
50
|
n_pcs: int | None = None,
|
46
51
|
batch_size: int | None = None,
|
47
52
|
copy: bool = False,
|
@@ -49,14 +54,18 @@ class Mixscape:
|
|
49
54
|
):
|
50
55
|
"""Calculate perturbation signature.
|
51
56
|
|
52
|
-
|
53
|
-
|
54
|
-
|
57
|
+
The perturbation signature is calculated by subtracting the mRNA expression profile of each cell from the averaged
|
58
|
+
mRNA expression profile of the control cells (selected according to `ref_selection_mode`).
|
59
|
+
The implementation resembles https://satijalab.org/seurat/reference/runmixscape. Note that in the original implementation, the
|
60
|
+
perturbation signature is calculated on unscaled data by default, and we therefore recommend doing the same.
|
55
61
|
|
56
62
|
Args:
|
57
63
|
adata: The annotated data object.
|
58
64
|
pert_key: The column of `.obs` with perturbation categories, should also contain `control`.
|
59
|
-
control:
|
65
|
+
control: Name of the control category from the `pert_key` column.
|
66
|
+
ref_selection_mode: Method to select reference cells for the perturbation signature calculation. If `nn`,
|
67
|
+
the `n_neighbors` cells from the control pool with the most similar mRNA expression profiles are selected. If `split_by`,
|
68
|
+
the control cells from the same split in `split_by` (e.g. indicating biological replicates) are used to calculate the perturbation signature.
|
60
69
|
split_by: Provide the `.obs` column if multiple biological replicates exist to calculate
|
61
70
|
the perturbation signature for every replicate separately.
|
62
71
|
n_neighbors: Number of neighbors from the control to use for the perturbation signature.
|
@@ -64,7 +73,10 @@ class Mixscape:
|
|
64
73
|
If `None`, the representation is chosen automatically:
|
65
74
|
For `.n_vars` < 50, `.X` is used, otherwise 'X_pca' is used.
|
66
75
|
If 'X_pca' is not present, it’s computed with default parameters.
|
67
|
-
|
76
|
+
n_dims: Number of dimensions to use from the representation to calculate the perturbation signature.
|
77
|
+
If `None`, use all dimensions.
|
78
|
+
n_pcs: If PCA representation is used, the number of principal components to compute.
|
79
|
+
If `n_pcs==0` use `.X` if `use_rep is None`.
|
68
80
|
batch_size: Size of batch to calculate the perturbation signature.
|
69
81
|
If 'None', the perturbation signature is calculated in the full mode, requiring more memory.
|
70
82
|
The batched mode is very inefficient for sparse data.
|
@@ -81,8 +93,13 @@ class Mixscape:
|
|
81
93
|
>>> import pertpy as pt
|
82
94
|
>>> mdata = pt.dt.papalexi_2021()
|
83
95
|
>>> ms_pt = pt.tl.Mixscape()
|
84
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
96
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
85
97
|
"""
|
98
|
+
if ref_selection_mode not in ["nn", "split_by"]:
|
99
|
+
raise ValueError("ref_selection_mode must be either 'nn' or 'split_by'.")
|
100
|
+
if ref_selection_mode == "split_by" and split_by is None:
|
101
|
+
raise ValueError("split_by must be provided if ref_selection_mode is 'split_by'.")
|
102
|
+
|
86
103
|
if copy:
|
87
104
|
adata = adata.copy()
|
88
105
|
|
@@ -90,59 +107,75 @@ class Mixscape:
|
|
90
107
|
|
91
108
|
control_mask = adata.obs[pert_key] == control
|
92
109
|
|
93
|
-
if
|
94
|
-
|
110
|
+
if ref_selection_mode == "split_by":
|
111
|
+
for split in adata.obs[split_by].unique():
|
112
|
+
split_mask = adata.obs[split_by] == split
|
113
|
+
control_mask_group = control_mask & split_mask
|
114
|
+
control_mean_expr = adata.X[control_mask_group].mean(0)
|
115
|
+
adata.layers["X_pert"][split_mask] = (
|
116
|
+
np.repeat(control_mean_expr.reshape(1, -1), split_mask.sum(), axis=0)
|
117
|
+
- adata.layers["X_pert"][split_mask]
|
118
|
+
)
|
95
119
|
else:
|
96
|
-
|
97
|
-
|
120
|
+
if split_by is None:
|
121
|
+
split_masks = [np.full(adata.n_obs, True, dtype=bool)]
|
122
|
+
else:
|
123
|
+
split_obs = adata.obs[split_by]
|
124
|
+
split_masks = [split_obs == cat for cat in split_obs.unique()]
|
98
125
|
|
99
|
-
|
126
|
+
representation = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs)
|
127
|
+
if n_dims is not None and n_dims < representation.shape[1]:
|
128
|
+
representation = representation[:, :n_dims]
|
100
129
|
|
101
|
-
|
102
|
-
|
130
|
+
for split_mask in split_masks:
|
131
|
+
control_mask_split = control_mask & split_mask
|
103
132
|
|
104
|
-
|
105
|
-
|
133
|
+
R_split = representation[split_mask]
|
134
|
+
R_control = representation[np.asarray(control_mask_split)]
|
106
135
|
|
107
|
-
|
136
|
+
from pynndescent import NNDescent
|
108
137
|
|
109
|
-
|
110
|
-
|
111
|
-
|
138
|
+
eps = kwargs.pop("epsilon", 0.1)
|
139
|
+
nn_index = NNDescent(R_control, **kwargs)
|
140
|
+
indices, _ = nn_index.query(R_split, k=n_neighbors, epsilon=eps)
|
112
141
|
|
113
|
-
|
142
|
+
X_control = np.expm1(adata.X[np.asarray(control_mask_split)])
|
114
143
|
|
115
|
-
|
116
|
-
|
144
|
+
n_split = split_mask.sum()
|
145
|
+
n_control = X_control.shape[0]
|
117
146
|
|
118
|
-
|
119
|
-
|
120
|
-
|
147
|
+
if batch_size is None:
|
148
|
+
col_indices = np.ravel(indices)
|
149
|
+
row_indices = np.repeat(np.arange(n_split), n_neighbors)
|
121
150
|
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
151
|
+
neigh_matrix = csr_matrix(
|
152
|
+
(np.ones_like(col_indices, dtype=np.float64), (row_indices, col_indices)),
|
153
|
+
shape=(n_split, n_control),
|
154
|
+
)
|
155
|
+
neigh_matrix /= n_neighbors
|
156
|
+
adata.layers["X_pert"][split_mask] = (
|
157
|
+
np.log1p(neigh_matrix @ X_control) - adata.layers["X_pert"][split_mask]
|
158
|
+
)
|
159
|
+
else:
|
160
|
+
is_sparse = issparse(X_control)
|
161
|
+
split_indices = np.where(split_mask)[0]
|
162
|
+
for i in range(0, n_split, batch_size):
|
163
|
+
size = min(i + batch_size, n_split)
|
164
|
+
select = slice(i, size)
|
134
165
|
|
135
|
-
|
136
|
-
|
166
|
+
batch = np.ravel(indices[select])
|
167
|
+
split_batch = split_indices[select]
|
137
168
|
|
138
|
-
|
169
|
+
size = size - i
|
139
170
|
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
171
|
+
# sparse is very slow
|
172
|
+
means_batch = X_control[batch]
|
173
|
+
means_batch = means_batch.toarray() if is_sparse else means_batch
|
174
|
+
means_batch = means_batch.reshape(size, n_neighbors, -1).mean(1)
|
144
175
|
|
145
|
-
|
176
|
+
adata.layers["X_pert"][split_batch] = (
|
177
|
+
np.log1p(means_batch) - adata.layers["X_pert"][split_batch]
|
178
|
+
)
|
146
179
|
|
147
180
|
if copy:
|
148
181
|
return adata
|
@@ -152,33 +185,41 @@ class Mixscape:
|
|
152
185
|
adata: AnnData,
|
153
186
|
labels: str,
|
154
187
|
control: str,
|
188
|
+
*,
|
155
189
|
new_class_name: str | None = "mixscape_class",
|
156
|
-
min_de_genes: int | None = 5,
|
157
190
|
layer: str | None = None,
|
191
|
+
min_de_genes: int | None = 5,
|
158
192
|
logfc_threshold: float | None = 0.25,
|
193
|
+
de_layer: str | None = None,
|
194
|
+
test_method: str | None = "wilcoxon",
|
159
195
|
iter_num: int | None = 10,
|
196
|
+
scale: bool | None = True,
|
160
197
|
split_by: str | None = None,
|
161
198
|
pval_cutoff: float | None = 5e-2,
|
162
199
|
perturbation_type: str | None = "KO",
|
200
|
+
random_state: int | None = 0,
|
163
201
|
copy: bool | None = False,
|
164
202
|
):
|
165
203
|
"""Identify perturbed and non-perturbed gRNA-expressing cells, accounting for multiple treatments/conditions/chemical perturbations.
|
166
204
|
|
167
|
-
The implementation resembles https://satijalab.org/seurat/reference/runmixscape
|
205
|
+
The implementation resembles https://satijalab.org/seurat/reference/runmixscape.
|
168
206
|
|
169
207
|
Args:
|
170
208
|
adata: The annotated data object.
|
171
209
|
labels: The column of `.obs` with target gene labels.
|
172
|
-
control: Control category from the `
|
210
|
+
control: Control category from the `labels` column.
|
173
211
|
new_class_name: Name of mixscape classification to be stored in `.obs`.
|
174
|
-
min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
|
175
212
|
layer: Key from adata.layers whose value will be used to perform tests on. Default is using `.layers["X_pert"]`.
|
213
|
+
min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
|
176
214
|
logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells (default: 0.25).
|
215
|
+
de_layer: Layer to use for identifying differentially expressed genes. If `None`, adata.X is used.
|
216
|
+
test_method: Method to use for differential expression testing.
|
177
217
|
iter_num: Number of normalmixEM iterations to run if convergence does not occur.
|
178
|
-
|
179
|
-
|
218
|
+
scale: Scale the data specified in `layer` before running the GaussianMixture model on it.
|
219
|
+
split_by: Provide `.obs` column with experimental condition/cell type annotation, if perturbations are condition/cell type-specific.
|
180
220
|
pval_cutoff: P-value cut-off for selection of significantly DE genes.
|
181
221
|
perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications.
|
222
|
+
random_state: Random seed for the GaussianMixture model.
|
182
223
|
copy: Determines whether a copy of the `adata` is returned.
|
183
224
|
|
184
225
|
Returns:
|
@@ -201,8 +242,8 @@ class Mixscape:
|
|
201
242
|
>>> import pertpy as pt
|
202
243
|
>>> mdata = pt.dt.papalexi_2021()
|
203
244
|
>>> ms_pt = pt.tl.Mixscape()
|
204
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
205
|
-
>>> ms_pt.mixscape(
|
245
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
246
|
+
>>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
|
206
247
|
"""
|
207
248
|
if copy:
|
208
249
|
adata = adata.copy()
|
@@ -216,7 +257,16 @@ class Mixscape:
|
|
216
257
|
split_masks = [split_obs == category for category in categories]
|
217
258
|
|
218
259
|
perturbation_markers = self._get_perturbation_markers(
|
219
|
-
adata,
|
260
|
+
adata=adata,
|
261
|
+
split_masks=split_masks,
|
262
|
+
categories=categories,
|
263
|
+
labels=labels,
|
264
|
+
control=control,
|
265
|
+
layer=de_layer,
|
266
|
+
pval_cutoff=pval_cutoff,
|
267
|
+
min_de_genes=min_de_genes,
|
268
|
+
logfc_threshold=logfc_threshold,
|
269
|
+
test_method=test_method,
|
220
270
|
)
|
221
271
|
|
222
272
|
adata_comp = adata
|
@@ -229,6 +279,7 @@ class Mixscape:
|
|
229
279
|
raise KeyError(
|
230
280
|
"No 'X_pert' found in .layers! Please run perturbation_signature first to calculate perturbation signature!"
|
231
281
|
) from None
|
282
|
+
|
232
283
|
# initialize return variables
|
233
284
|
adata.obs[f"{new_class_name}_p_{perturbation_type.lower()}"] = 0
|
234
285
|
adata.obs[new_class_name] = adata.obs[labels].astype(str)
|
@@ -239,10 +290,12 @@ class Mixscape:
|
|
239
290
|
dtype=np.object_,
|
240
291
|
)
|
241
292
|
gv_list: dict[str, dict] = {}
|
293
|
+
|
294
|
+
adata.obs[f"{new_class_name}_p_{perturbation_type.lower()}"] = 0.0
|
242
295
|
for split, split_mask in enumerate(split_masks):
|
243
296
|
category = categories[split]
|
244
|
-
|
245
|
-
for gene in
|
297
|
+
gene_targets = list(set(adata[split_mask].obs[labels]).difference([control]))
|
298
|
+
for gene in gene_targets:
|
246
299
|
post_prob = 0
|
247
300
|
orig_guide_cells = (adata.obs[labels] == gene) & split_mask
|
248
301
|
orig_guide_cells_index = list(orig_guide_cells.index[orig_guide_cells])
|
@@ -251,28 +304,38 @@ class Mixscape:
|
|
251
304
|
|
252
305
|
if len(perturbation_markers[(category, gene)]) == 0:
|
253
306
|
adata.obs.loc[orig_guide_cells, new_class_name] = f"{gene} NP"
|
307
|
+
|
254
308
|
else:
|
255
309
|
de_genes = perturbation_markers[(category, gene)]
|
256
310
|
de_genes_indices = self._get_column_indices(adata, list(de_genes))
|
257
|
-
|
311
|
+
|
312
|
+
dat = X[np.asarray(all_cells)][:, de_genes_indices]
|
313
|
+
dat_cells = all_cells[all_cells].index
|
314
|
+
if scale:
|
315
|
+
dat = sc.pp.scale(dat)
|
316
|
+
|
258
317
|
converged = False
|
259
318
|
n_iter = 0
|
260
|
-
old_classes = adata.obs[
|
319
|
+
old_classes = adata.obs[new_class_name][all_cells]
|
320
|
+
|
261
321
|
while not converged and n_iter < iter_num:
|
262
322
|
# Get all cells in current split&Gene
|
263
|
-
guide_cells = (adata.obs[
|
323
|
+
guide_cells = (adata.obs[new_class_name] == gene) & split_mask
|
324
|
+
|
264
325
|
# get average value for each gene over all selected cells
|
265
326
|
# all cells in current split&Gene minus all NT cells in current split
|
266
327
|
# Each row is for each cell, each column is for each gene, get mean for each column
|
267
|
-
|
268
|
-
|
269
|
-
)
|
328
|
+
guide_cells_dat_idx = all_cells[all_cells].index.get_indexer(guide_cells[guide_cells].index)
|
329
|
+
nt_cells_dat_idx = all_cells[all_cells].index.get_indexer(nt_cells[nt_cells].index)
|
330
|
+
vec = np.mean(dat[guide_cells_dat_idx], axis=0) - np.mean(dat[nt_cells_dat_idx], axis=0)
|
331
|
+
|
270
332
|
# project cells onto the perturbation vector
|
271
333
|
if isinstance(dat, spmatrix):
|
272
|
-
pvec = np.
|
334
|
+
pvec = np.dot(dat.toarray(), vec) / np.dot(vec, vec)
|
273
335
|
else:
|
274
|
-
pvec = np.
|
336
|
+
pvec = np.dot(dat, vec) / np.dot(vec, vec)
|
275
337
|
pvec = pd.Series(np.asarray(pvec).flatten(), index=list(all_cells.index[all_cells]))
|
338
|
+
|
276
339
|
if n_iter == 0:
|
277
340
|
gv = pd.DataFrame(columns=["pvec", labels])
|
278
341
|
gv["pvec"] = pvec
|
@@ -282,19 +345,22 @@ class Mixscape:
|
|
282
345
|
gv_list[gene] = {}
|
283
346
|
gv_list[gene][category] = gv
|
284
347
|
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
precisions_init = np.array([nt_norm[1], guide_norm[1]])
|
289
|
-
mm = GaussianMixture(
|
348
|
+
means_init = np.array([[pvec[nt_cells].mean()], [pvec[guide_cells].mean()]])
|
349
|
+
std_init = np.array([pvec[nt_cells].std(), pvec[guide_cells].std()])
|
350
|
+
mm = MixscapeGaussianMixture(
|
290
351
|
n_components=2,
|
291
352
|
covariance_type="spherical",
|
292
353
|
means_init=means_init,
|
293
|
-
precisions_init=
|
354
|
+
precisions_init=1 / (std_init ** 2),
|
355
|
+
random_state=random_state,
|
356
|
+
max_iter=5000,
|
357
|
+
fixed_means=[pvec[nt_cells].mean(), None],
|
358
|
+
fixed_covariances=[pvec[nt_cells].std() ** 2, None],
|
294
359
|
).fit(np.asarray(pvec).reshape(-1, 1))
|
295
360
|
probabilities = mm.predict_proba(np.array(pvec[orig_guide_cells_index]).reshape(-1, 1))
|
296
361
|
lik_ratio = probabilities[:, 0] / probabilities[:, 1]
|
297
362
|
post_prob = 1 / (1 + lik_ratio)
|
363
|
+
|
298
364
|
# based on the posterior probability, assign cells to the two classes
|
299
365
|
adata.obs.loc[
|
300
366
|
[orig_guide_cells_index[cell] for cell in np.where(post_prob > 0.5)[0]], new_class_name
|
@@ -302,11 +368,13 @@ class Mixscape:
|
|
302
368
|
adata.obs.loc[
|
303
369
|
[orig_guide_cells_index[cell] for cell in np.where(post_prob <= 0.5)[0]], new_class_name
|
304
370
|
] = f"{gene} NP"
|
371
|
+
|
305
372
|
if sum(adata.obs[new_class_name][split_mask] == gene) < min_de_genes:
|
306
373
|
adata.obs.loc[guide_cells, new_class_name] = "NP"
|
307
374
|
converged = True
|
308
375
|
if adata.obs[new_class_name][all_cells].equals(old_classes):
|
309
376
|
converged = True
|
377
|
+
|
310
378
|
old_classes = adata.obs[new_class_name][all_cells]
|
311
379
|
n_iter += 1
|
312
380
|
|
@@ -315,9 +383,7 @@ class Mixscape:
|
|
315
383
|
)
|
316
384
|
|
317
385
|
adata.obs[f"{new_class_name}_global"] = [a.split(" ")[-1] for a in adata.obs[new_class_name]]
|
318
|
-
adata.obs.loc[orig_guide_cells_index, f"{new_class_name}_p_{perturbation_type.lower()}"] =
|
319
|
-
post_prob
|
320
|
-
).astype("int64")
|
386
|
+
adata.obs.loc[orig_guide_cells_index, f"{new_class_name}_p_{perturbation_type.lower()}"] = post_prob
|
321
387
|
adata.uns["mixscape"] = gv_list
|
322
388
|
|
323
389
|
if copy:
|
@@ -328,11 +394,13 @@ class Mixscape:
|
|
328
394
|
adata: AnnData,
|
329
395
|
labels: str,
|
330
396
|
control: str,
|
397
|
+
*,
|
331
398
|
mixscape_class_global: str | None = "mixscape_class_global",
|
332
399
|
layer: str | None = None,
|
333
400
|
n_comps: int | None = 10,
|
334
401
|
min_de_genes: int | None = 5,
|
335
402
|
logfc_threshold: float | None = 0.25,
|
403
|
+
test_method: str | None = "wilcoxon",
|
336
404
|
split_by: str | None = None,
|
337
405
|
pval_cutoff: float | None = 5e-2,
|
338
406
|
perturbation_type: str | None = "KO",
|
@@ -345,12 +413,13 @@ class Mixscape:
|
|
345
413
|
labels: The column of `.obs` with target gene labels.
|
346
414
|
control: Control category from the `labels` column.
|
347
415
|
mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
|
348
|
-
layer:
|
416
|
+
layer: Layer to use for identifying differentially expressed genes. If `None`, adata.X is used.
|
349
417
|
control: Control category from the `labels` column.
|
350
418
|
n_comps: Number of principal components to use.
|
351
419
|
min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
|
352
420
|
logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells.
|
353
|
-
|
421
|
+
test_method: Method to use for differential expression testing.
|
422
|
+
split_by: Provide `.obs` column with experimental condition/cell type annotation, if perturbations are condition/cell type-specific.
|
354
423
|
pval_cutoff: P-value cut-off for selection of significantly DE genes.
|
355
424
|
perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
|
356
425
|
copy: Determines whether a copy of the `adata` is returned.
|
@@ -368,9 +437,9 @@ class Mixscape:
|
|
368
437
|
>>> import pertpy as pt
|
369
438
|
>>> mdata = pt.dt.papalexi_2021()
|
370
439
|
>>> ms_pt = pt.tl.Mixscape()
|
371
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
372
|
-
>>> ms_pt.mixscape(
|
373
|
-
>>> ms_pt.lda(
|
440
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
441
|
+
>>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
|
442
|
+
>>> ms_pt.lda(mdata["rna"], "gene_target", "NT")
|
374
443
|
"""
|
375
444
|
if copy:
|
376
445
|
adata = adata.copy()
|
@@ -386,9 +455,8 @@ class Mixscape:
|
|
386
455
|
categories = split_obs.unique()
|
387
456
|
split_masks = [split_obs == category for category in categories]
|
388
457
|
|
389
|
-
mixscape_identifier = pt.tl.Mixscape()
|
390
458
|
# determine gene sets across all splits/groups through differential gene expression
|
391
|
-
perturbation_markers =
|
459
|
+
perturbation_markers = self._get_perturbation_markers(
|
392
460
|
adata=adata,
|
393
461
|
split_masks=split_masks,
|
394
462
|
categories=categories,
|
@@ -398,6 +466,7 @@ class Mixscape:
|
|
398
466
|
pval_cutoff=pval_cutoff,
|
399
467
|
min_de_genes=min_de_genes,
|
400
468
|
logfc_threshold=logfc_threshold,
|
469
|
+
test_method=test_method,
|
401
470
|
)
|
402
471
|
adata_subset = adata[
|
403
472
|
(adata.obs[mixscape_class_global] == perturbation_type) | (adata.obs[mixscape_class_global] == control)
|
@@ -443,12 +512,21 @@ class Mixscape:
|
|
443
512
|
pval_cutoff: float,
|
444
513
|
min_de_genes: float,
|
445
514
|
logfc_threshold: float,
|
515
|
+
test_method: str,
|
446
516
|
) -> dict[tuple, np.ndarray]:
|
447
517
|
"""Determine gene sets across all splits/groups through differential gene expression
|
448
518
|
|
449
519
|
Args:
|
450
520
|
adata: :class:`~anndata.AnnData` object
|
451
|
-
|
521
|
+
split_masks: List of boolean masks for each split/group.
|
522
|
+
categories: List of split/group names.
|
523
|
+
labels: The column of `.obs` with target gene labels.
|
524
|
+
control: Control category from the `labels` column.
|
525
|
+
layer: Key from adata.layers whose value will be used to compare gene expression.
|
526
|
+
pval_cutoff: P-value cut-off for selection of significantly DE genes.
|
527
|
+
min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
|
528
|
+
logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells.
|
529
|
+
test_method: Method to use for differential expression testing.
|
452
530
|
|
453
531
|
Returns:
|
454
532
|
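Dictionary mapping each `(category, gene)` tuple to an array of the differentially expressed gene names found for that target gene in that split.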
|
@@ -457,21 +535,21 @@ class Mixscape:
|
|
457
535
|
for split, split_mask in enumerate(split_masks):
|
458
536
|
category = categories[split]
|
459
537
|
# get gene sets for each split
|
460
|
-
|
538
|
+
gene_targets = list(set(adata[split_mask].obs[labels]).difference([control]))
|
461
539
|
adata_split = adata[split_mask].copy()
|
462
540
|
# find top DE genes between cells with targeting and non-targeting gRNAs
|
463
541
|
sc.tl.rank_genes_groups(
|
464
542
|
adata_split,
|
465
543
|
layer=layer,
|
466
544
|
groupby=labels,
|
467
|
-
groups=
|
545
|
+
groups=gene_targets,
|
468
546
|
reference=control,
|
469
|
-
method=
|
547
|
+
method=test_method,
|
470
548
|
use_raw=False,
|
471
549
|
)
|
472
|
-
# get DE genes for each gene
|
473
|
-
for gene in
|
474
|
-
logfc_threshold_mask = adata_split.uns["rank_genes_groups"]["logfoldchanges"][gene] >= logfc_threshold
|
550
|
+
# get DE genes for each target gene
|
551
|
+
for gene in gene_targets:
|
552
|
+
logfc_threshold_mask = np.abs(adata_split.uns["rank_genes_groups"]["logfoldchanges"][gene]) >= logfc_threshold
|
475
553
|
de_genes = adata_split.uns["rank_genes_groups"]["names"][gene][logfc_threshold_mask]
|
476
554
|
pvals_adj = adata_split.uns["rank_genes_groups"]["pvals_adj"][gene][logfc_threshold_mask]
|
477
555
|
de_genes = de_genes[pvals_adj < pval_cutoff]
|
@@ -492,35 +570,22 @@ class Mixscape:
|
|
492
570
|
|
493
571
|
return indices
|
494
572
|
|
495
|
-
|
496
|
-
"""Calculates the mean and standard deviation of a matrix.
|
497
|
-
|
498
|
-
Args:
|
499
|
-
X: The matrix to calculate the properties for.
|
500
|
-
|
501
|
-
Returns:
|
502
|
-
Mean and standard deviation of the matrix.
|
503
|
-
"""
|
504
|
-
mu = X.mean()
|
505
|
-
sd = X.std()
|
506
|
-
|
507
|
-
return [mu, sd]
|
508
|
-
|
573
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
509
574
|
def plot_barplot( # pragma: no cover
|
510
575
|
self,
|
511
576
|
adata: AnnData,
|
512
577
|
guide_rna_column: str,
|
578
|
+
*,
|
513
579
|
mixscape_class_global: str = "mixscape_class_global",
|
514
580
|
axis_text_x_size: int = 8,
|
515
581
|
axis_text_y_size: int = 6,
|
516
582
|
axis_title_size: int = 8,
|
517
583
|
legend_title_size: int = 8,
|
518
584
|
legend_text_size: int = 8,
|
519
|
-
|
520
|
-
|
521
|
-
|
522
|
-
|
523
|
-
):
|
585
|
+
legend_bbox_to_anchor: tuple[float, float] = None,
|
586
|
+
figsize: tuple[float, float] = (25, 25),
|
587
|
+
return_fig: bool = False,
|
588
|
+
) -> Figure | None:
|
524
589
|
"""Barplot to visualize perturbation scores calculated by the `mixscape` function.
|
525
590
|
|
526
591
|
Args:
|
@@ -528,19 +593,24 @@ class Mixscape:
|
|
528
593
|
guide_rna_column: The column of `.obs` with guide RNA labels. The target gene labels.
|
529
594
|
The format must be <gene_target>g<#>. Examples are 'STAT2g1' and 'ATF2g1'.
|
530
595
|
mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
|
531
|
-
|
532
|
-
|
533
|
-
|
596
|
+
axis_text_x_size: Size of the x-axis text.
|
597
|
+
axis_text_y_size: Size of the y-axis text.
|
598
|
+
axis_title_size: Size of the axis title.
|
599
|
+
legend_title_size: Size of the legend title.
|
600
|
+
legend_text_size: Size of the legend text.
|
601
|
+
legend_bbox_to_anchor: The bbox to which the legend will be anchored.
|
602
|
+
figsize: The size of the figure.
|
603
|
+
{common_plot_args}
|
534
604
|
|
535
605
|
Returns:
|
536
|
-
If `
|
606
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
537
607
|
|
538
608
|
Examples:
|
539
609
|
>>> import pertpy as pt
|
540
610
|
>>> mdata = pt.dt.papalexi_2021()
|
541
611
|
>>> ms_pt = pt.tl.Mixscape()
|
542
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
543
|
-
>>> ms_pt.mixscape(
|
612
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
613
|
+
>>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
|
544
614
|
>>> ms_pt.plot_barplot(mdata["rna"], guide_rna_column="NT")
|
545
615
|
|
546
616
|
Preview:
|
@@ -565,63 +635,64 @@ class Mixscape:
|
|
565
635
|
all_cells_percentage["guide_number"] = "g" + all_cells_percentage["guide_number"]
|
566
636
|
NP_KO_cells = all_cells_percentage[all_cells_percentage["gene"] != "NT"]
|
567
637
|
|
568
|
-
|
569
|
-
|
570
|
-
|
571
|
-
|
572
|
-
|
573
|
-
|
574
|
-
|
575
|
-
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
|
580
|
-
|
581
|
-
|
582
|
-
|
583
|
-
|
584
|
-
|
585
|
-
|
586
|
-
|
587
|
-
|
588
|
-
|
589
|
-
|
590
|
-
|
591
|
-
|
592
|
-
sns.despine(ax=ax, top=True, right=True, left=False, bottom=False)
|
593
|
-
ax.set_xticklabels(ax.get_xticklabels(), rotation=0, ha="right", fontsize=axis_text_x_size)
|
594
|
-
ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=axis_text_y_size)
|
595
|
-
fig.subplots_adjust(right=0.8)
|
596
|
-
fig.subplots_adjust(hspace=0.5, wspace=0.5)
|
638
|
+
color_mapping = {"KO": "salmon", "NP": "lightgray", "NT": "grey"}
|
639
|
+
unique_genes = NP_KO_cells["gene"].unique()
|
640
|
+
fig, axs = plt.subplots(int(len(unique_genes) / 5), 5, figsize=figsize, sharey=True)
|
641
|
+
for i, gene in enumerate(unique_genes):
|
642
|
+
ax = axs[int(i / 5), i % 5]
|
643
|
+
grouped_df = (
|
644
|
+
NP_KO_cells[NP_KO_cells["gene"] == gene]
|
645
|
+
.groupby(["guide_number", "mixscape_class_global"], observed=False)["value"]
|
646
|
+
.sum()
|
647
|
+
.unstack()
|
648
|
+
)
|
649
|
+
grouped_df.plot(
|
650
|
+
kind="bar",
|
651
|
+
stacked=True,
|
652
|
+
color=[color_mapping[col] for col in grouped_df.columns],
|
653
|
+
ax=ax,
|
654
|
+
width=0.8,
|
655
|
+
legend=False,
|
656
|
+
)
|
657
|
+
ax.set_title(gene, bbox={"facecolor": "white", "edgecolor": "black", "pad": 1}, fontsize=axis_title_size)
|
658
|
+
ax.set(xlabel="sgRNA", ylabel="% of cells")
|
659
|
+
sns.despine(ax=ax, top=True, right=True, left=False, bottom=False)
|
660
|
+
ax.set_xticks(ax.get_xticks(), ax.get_xticklabels(), rotation=0, ha="right", fontsize=axis_text_x_size)
|
661
|
+
ax.set_yticks(ax.get_yticks(), ax.get_yticklabels(), rotation=0, fontsize=axis_text_y_size)
|
597
662
|
ax.legend(
|
598
|
-
title="
|
663
|
+
title="Mixscape Class",
|
599
664
|
loc="center right",
|
600
|
-
bbox_to_anchor=
|
665
|
+
bbox_to_anchor=legend_bbox_to_anchor,
|
601
666
|
frameon=True,
|
602
667
|
fontsize=legend_text_size,
|
603
668
|
title_fontsize=legend_title_size,
|
604
669
|
)
|
605
670
|
|
671
|
+
fig.subplots_adjust(right=0.8)
|
672
|
+
fig.subplots_adjust(hspace=0.5, wspace=0.5)
|
606
673
|
plt.tight_layout()
|
607
|
-
_utils.savefig_or_show("mixscape_barplot", show=show, save=save)
|
608
674
|
|
675
|
+
if return_fig:
|
676
|
+
return fig
|
677
|
+
plt.show()
|
678
|
+
return None
|
679
|
+
|
680
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
609
681
|
def plot_heatmap( # pragma: no cover
|
610
682
|
self,
|
611
683
|
adata: AnnData,
|
612
684
|
labels: str,
|
613
685
|
target_gene: str,
|
614
686
|
control: str,
|
687
|
+
*,
|
615
688
|
layer: str | None = None,
|
616
689
|
method: str | None = "wilcoxon",
|
617
690
|
subsample_number: int | None = 900,
|
618
691
|
vmin: float | None = -2,
|
619
692
|
vmax: float | None = 2,
|
620
|
-
return_fig: bool
|
621
|
-
show: bool | None = None,
|
622
|
-
save: bool | str | None = None,
|
693
|
+
return_fig: bool = False,
|
623
694
|
**kwds,
|
624
|
-
) ->
|
695
|
+
) -> Figure | None:
|
625
696
|
"""Heatmap plot using mixscape results. Requires `pt.tl.mixscape()` to be run first.
|
626
697
|
|
627
698
|
Args:
|
@@ -634,21 +705,18 @@ class Mixscape:
|
|
634
705
|
subsample_number: Subsample to this number of observations.
|
635
706
|
vmin: The value representing the lower limit of the color scale. Values smaller than vmin are plotted with the same color as vmin.
|
636
707
|
vmax: The value representing the upper limit of the color scale. Values larger than vmax are plotted with the same color as vmax.
|
637
|
-
|
638
|
-
save: If `True` or a `str`, save the figure. A string is appended to the default filename.
|
639
|
-
Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
|
640
|
-
ax: A matplotlib axes object. Only works if plotting a single component.
|
708
|
+
{common_plot_args}
|
641
709
|
**kwds: Additional arguments to `scanpy.pl.rank_genes_groups_heatmap`.
|
642
710
|
|
643
711
|
Returns:
|
644
|
-
If `
|
712
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
645
713
|
|
646
714
|
Examples:
|
647
715
|
>>> import pertpy as pt
|
648
716
|
>>> mdata = pt.dt.papalexi_2021()
|
649
717
|
>>> ms_pt = pt.tl.Mixscape()
|
650
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
651
|
-
>>> ms_pt.mixscape(
|
718
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
719
|
+
>>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
|
652
720
|
>>> ms_pt.plot_heatmap(
|
653
721
|
... adata=mdata["rna"], labels="gene_target", target_gene="IFNGR2", layer="X_pert", control="NT"
|
654
722
|
... )
|
@@ -663,35 +731,37 @@ class Mixscape:
|
|
663
731
|
sc.pp.scale(adata_subset, max_value=vmax)
|
664
732
|
sc.pp.subsample(adata_subset, n_obs=subsample_number)
|
665
733
|
|
666
|
-
|
734
|
+
fig = sc.pl.rank_genes_groups_heatmap(
|
667
735
|
adata_subset,
|
668
736
|
groupby="mixscape_class",
|
669
737
|
vmin=vmin,
|
670
738
|
vmax=vmax,
|
671
739
|
n_genes=20,
|
672
740
|
groups=["NT"],
|
673
|
-
|
674
|
-
show=show,
|
675
|
-
save=save,
|
741
|
+
show=False,
|
676
742
|
**kwds,
|
677
743
|
)
|
678
744
|
|
745
|
+
if return_fig:
|
746
|
+
return fig
|
747
|
+
plt.show()
|
748
|
+
return None
|
749
|
+
|
750
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
679
751
|
def plot_perturbscore( # pragma: no cover
|
680
752
|
self,
|
681
753
|
adata: AnnData,
|
682
754
|
labels: str,
|
683
755
|
target_gene: str,
|
756
|
+
*,
|
684
757
|
mixscape_class: str = "mixscape_class",
|
685
758
|
color: str = "orange",
|
686
759
|
palette: dict[str, str] = None,
|
687
760
|
split_by: str = None,
|
688
761
|
before_mixscape: bool = False,
|
689
762
|
perturbation_type: str = "KO",
|
690
|
-
return_fig: bool
|
691
|
-
|
692
|
-
show: bool | None = None,
|
693
|
-
save: bool | str | None = None,
|
694
|
-
) -> None:
|
763
|
+
return_fig: bool = False,
|
764
|
+
) -> Figure | None:
|
695
765
|
"""Density plots to visualize perturbation scores calculated by the `pt.tl.mixscape` function.
|
696
766
|
|
697
767
|
Requires `pt.tl.mixscape` to be run first.
|
@@ -710,6 +780,10 @@ class Mixscape:
|
|
710
780
|
before_mixscape: Option to split densities based on mixscape classification (default) or original target gene classification.
|
711
781
|
Defaults to `False`, which splits densities by mixscape classification; set to `True` to plot cells by their original class ID.
|
712
782
|
perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
|
783
|
+
{common_plot_args}
|
784
|
+
|
785
|
+
Returns:
|
786
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
713
787
|
|
714
788
|
Examples:
|
715
789
|
Visualizing the perturbation scores for the cells in a dataset:
|
@@ -717,8 +791,8 @@ class Mixscape:
|
|
717
791
|
>>> import pertpy as pt
|
718
792
|
>>> mdata = pt.dt.papalexi_2021()
|
719
793
|
>>> ms_pt = pt.tl.Mixscape()
|
720
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
721
|
-
>>> ms_pt.mixscape(
|
794
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
795
|
+
>>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
|
722
796
|
>>> ms_pt.plot_perturbscore(adata=mdata["rna"], labels="gene_target", target_gene="IFNGR2", color="orange")
|
723
797
|
|
724
798
|
Preview:
|
@@ -778,15 +852,6 @@ class Mixscape:
|
|
778
852
|
plt.legend(title="gene_target", title_fontsize=14, fontsize=12)
|
779
853
|
sns.despine()
|
780
854
|
|
781
|
-
if save:
|
782
|
-
plt.savefig(save, bbox_inches="tight")
|
783
|
-
if show:
|
784
|
-
plt.show()
|
785
|
-
if return_fig:
|
786
|
-
return plt.gcf()
|
787
|
-
if not (show or save):
|
788
|
-
return plt.gca()
|
789
|
-
|
790
855
|
# If before_mixscape is False, split densities based on mixscape classifications
|
791
856
|
else:
|
792
857
|
if palette is None:
|
@@ -843,19 +908,17 @@ class Mixscape:
|
|
843
908
|
plt.legend(title="mixscape class", title_fontsize=14, fontsize=12)
|
844
909
|
sns.despine()
|
845
910
|
|
846
|
-
|
847
|
-
|
848
|
-
|
849
|
-
|
850
|
-
if return_fig:
|
851
|
-
return plt.gcf()
|
852
|
-
if not (show or save):
|
853
|
-
return plt.gca()
|
911
|
+
if return_fig:
|
912
|
+
return plt.gcf()
|
913
|
+
plt.show()
|
914
|
+
return None
|
854
915
|
|
916
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
855
917
|
def plot_violin( # pragma: no cover
|
856
918
|
self,
|
857
919
|
adata: AnnData,
|
858
920
|
target_gene_idents: str | list[str],
|
921
|
+
*,
|
859
922
|
keys: str | Sequence[str] = "mixscape_class_p_ko",
|
860
923
|
groupby: str | None = "mixscape_class",
|
861
924
|
log: bool = False,
|
@@ -872,10 +935,9 @@ class Mixscape:
|
|
872
935
|
ylabel: str | Sequence[str] | None = None,
|
873
936
|
rotation: float | None = None,
|
874
937
|
ax: Axes | None = None,
|
875
|
-
|
876
|
-
save: bool | str | None = None,
|
938
|
+
return_fig: bool = False,
|
877
939
|
**kwargs,
|
878
|
-
):
|
940
|
+
) -> Axes | Figure | None:
|
879
941
|
"""Violin plot using mixscape results.
|
880
942
|
|
881
943
|
Requires `pt.tl.mixscape` to be run first.
|
@@ -892,21 +954,19 @@ class Mixscape:
|
|
892
954
|
xlabel: Label of the x-axis. Defaults to `groupby` if `rotation` is `None`, otherwise, no label is shown.
|
893
955
|
ylabel: Label of the y-axis. If `None` and `groupby` is `None`, defaults to `'value'`.
|
894
956
|
If `None` and `groupby` is not `None`, defaults to `keys`.
|
895
|
-
show: Show the plot, do not return axis.
|
896
|
-
save: If `True` or a `str`, save the figure. A string is appended to the default filename.
|
897
|
-
Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
|
898
957
|
ax: A matplotlib axes object. Only works if plotting a single component.
|
958
|
+
{common_plot_args}
|
899
959
|
**kwargs: Additional arguments to `seaborn.violinplot`.
|
900
960
|
|
901
961
|
Returns:
|
902
|
-
|
962
|
+
If `return_fig` is `True`, returns the figure (as Axes list if it's a multi-panel plot), otherwise `None`.
|
903
963
|
|
904
964
|
Examples:
|
905
965
|
>>> import pertpy as pt
|
906
966
|
>>> mdata = pt.dt.papalexi_2021()
|
907
967
|
>>> ms_pt = pt.tl.Mixscape()
|
908
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
909
|
-
>>> ms_pt.mixscape(
|
968
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
969
|
+
>>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
|
910
970
|
>>> ms_pt.plot_violin(
|
911
971
|
... adata=mdata["rna"], target_gene_idents=["NT", "IFNGR2 NP", "IFNGR2 KO"], groupby="mixscape_class"
|
912
972
|
... )
|
@@ -1042,23 +1102,25 @@ class Mixscape:
|
|
1042
1102
|
if rotation is not None:
|
1043
1103
|
ax.tick_params(axis="x", labelrotation=rotation)
|
1044
1104
|
|
1045
|
-
show = settings.autoshow if show is None else show
|
1046
1105
|
if hue is not None and stripplot is True:
|
1047
1106
|
plt.legend(handles, labels)
|
1048
|
-
_utils.savefig_or_show("mixscape_violin", show=show, save=save)
|
1049
1107
|
|
1050
|
-
if
|
1108
|
+
if return_fig:
|
1051
1109
|
if multi_panel and groupby is None and len(ys) == 1:
|
1052
1110
|
return g
|
1053
1111
|
elif len(axs) == 1:
|
1054
1112
|
return axs[0]
|
1055
1113
|
else:
|
1056
1114
|
return axs
|
1115
|
+
plt.show()
|
1116
|
+
return None
|
1057
1117
|
|
1118
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
1058
1119
|
def plot_lda( # pragma: no cover
|
1059
1120
|
self,
|
1060
1121
|
adata: AnnData,
|
1061
1122
|
control: str,
|
1123
|
+
*,
|
1062
1124
|
mixscape_class: str = "mixscape_class",
|
1063
1125
|
mixscape_class_global: str = "mixscape_class_global",
|
1064
1126
|
perturbation_type: str | None = "KO",
|
@@ -1066,12 +1128,10 @@ class Mixscape:
|
|
1066
1128
|
n_components: int | None = None,
|
1067
1129
|
color_map: Colormap | str | None = None,
|
1068
1130
|
palette: str | Sequence[str] | None = None,
|
1069
|
-
return_fig: bool | None = None,
|
1070
1131
|
ax: Axes | None = None,
|
1071
|
-
|
1072
|
-
save: bool | str | None = None,
|
1132
|
+
return_fig: bool = False,
|
1073
1133
|
**kwds,
|
1074
|
-
) -> None:
|
1134
|
+
) -> Figure | None:
|
1075
1135
|
"""Visualizing perturbation responses with Linear Discriminant Analysis. Requires `pt.tl.mixscape()` to be run first.
|
1076
1136
|
|
1077
1137
|
Args:
|
@@ -1082,18 +1142,16 @@ class Mixscape:
|
|
1082
1142
|
perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
|
1083
1143
|
lda_key: If not specified, looks in `.uns["mixscape_lda"]` for the LDA results.
|
1084
1144
|
n_components: The number of dimensions of the embedding.
|
1085
|
-
|
1086
|
-
save: If `True` or a `str`, save the figure. A string is appended to the default filename.
|
1087
|
-
Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
|
1145
|
+
{common_plot_args}
|
1088
1146
|
**kwds: Additional arguments to `scanpy.pl.umap`.
|
1089
1147
|
|
1090
1148
|
Examples:
|
1091
1149
|
>>> import pertpy as pt
|
1092
1150
|
>>> mdata = pt.dt.papalexi_2021()
|
1093
1151
|
>>> ms_pt = pt.tl.Mixscape()
|
1094
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
1095
|
-
>>> ms_pt.mixscape(
|
1096
|
-
>>> ms_pt.lda(
|
1152
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
1153
|
+
>>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
|
1154
|
+
>>> ms_pt.lda(mdata["rna"], "gene_target", "NT", split_by="replicate")
|
1097
1155
|
>>> ms_pt.plot_lda(adata=mdata["rna"], control="NT")
|
1098
1156
|
|
1099
1157
|
Preview:
|
@@ -1112,14 +1170,55 @@ class Mixscape:
|
|
1112
1170
|
n_components = adata_subset.uns[lda_key].shape[1]
|
1113
1171
|
sc.pp.neighbors(adata_subset, use_rep=lda_key)
|
1114
1172
|
sc.tl.umap(adata_subset, n_components=n_components)
|
1115
|
-
sc.pl.umap(
|
1173
|
+
fig = sc.pl.umap(
|
1116
1174
|
adata_subset,
|
1117
1175
|
color=mixscape_class,
|
1118
1176
|
palette=palette,
|
1119
1177
|
color_map=color_map,
|
1120
1178
|
return_fig=return_fig,
|
1121
|
-
show=
|
1122
|
-
save=save,
|
1179
|
+
show=False,
|
1123
1180
|
ax=ax,
|
1124
1181
|
**kwds,
|
1125
1182
|
)
|
1183
|
+
|
1184
|
+
if return_fig:
|
1185
|
+
return fig
|
1186
|
+
plt.show()
|
1187
|
+
return None
|
1188
|
+
|
1189
|
+
class MixscapeGaussianMixture(GaussianMixture):
|
1190
|
+
def __init__(
|
1191
|
+
self,
|
1192
|
+
n_components: int,
|
1193
|
+
fixed_means: Sequence[float] | None = None,
|
1194
|
+
fixed_covariances: Sequence[float] | None = None,
|
1195
|
+
**kwargs
|
1196
|
+
):
|
1197
|
+
"""Custom Gaussian Mixture Model where means and covariances can be fixed for specific components.
|
1198
|
+
|
1199
|
+
Args:
|
1200
|
+
n_components: Number of Gaussian components
|
1201
|
+
fixed_means: Means to fix (use None for those that should be estimated)
|
1202
|
+
fixed_covariances: Covariances to fix (use None for those that should be estimated)
|
1203
|
+
kwargs: Additional arguments passed to scikit-learn's GaussianMixture
|
1204
|
+
"""
|
1205
|
+
super().__init__(n_components=n_components, **kwargs)
|
1206
|
+
self.fixed_means = fixed_means
|
1207
|
+
self.fixed_covariances = fixed_covariances
|
1208
|
+
|
1209
|
+
def _m_step(self, X: np.ndarray, log_resp: np.ndarray):
|
1210
|
+
"""Modified M-step to respect fixed means and covariances."""
|
1211
|
+
super()._m_step(X, log_resp)
|
1212
|
+
|
1213
|
+
if self.fixed_means is not None:
|
1214
|
+
for i in range(self.n_components):
|
1215
|
+
if self.fixed_means[i] is not None:
|
1216
|
+
self.means_[i] = self.fixed_means[i]
|
1217
|
+
|
1218
|
+
if self.fixed_covariances is not None:
|
1219
|
+
for i in range(self.n_components):
|
1220
|
+
if self.fixed_covariances[i] is not None:
|
1221
|
+
self.covariances_[i] = self.fixed_covariances[i]
|
1222
|
+
|
1223
|
+
return self
|
1224
|
+
|