pertpy 0.9.4__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +1 -1
- pertpy/_doc.py +19 -0
- pertpy/data/_datasets.py +1 -1
- pertpy/metadata/_cell_line.py +18 -8
- pertpy/metadata/_compound.py +3 -4
- pertpy/metadata/_metadata.py +1 -1
- pertpy/preprocessing/_guide_rna.py +114 -13
- pertpy/preprocessing/_guide_rna_mixture.py +179 -0
- pertpy/tools/__init__.py +1 -1
- pertpy/tools/_augur.py +64 -86
- pertpy/tools/_cinemaot.py +21 -17
- pertpy/tools/_coda/_base_coda.py +90 -117
- pertpy/tools/_dialogue.py +32 -40
- pertpy/tools/_differential_gene_expression/__init__.py +1 -2
- pertpy/tools/_differential_gene_expression/_base.py +486 -112
- pertpy/tools/_differential_gene_expression/_edger.py +30 -21
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +15 -29
- pertpy/tools/_differential_gene_expression/_statsmodels.py +0 -11
- pertpy/tools/_distances/_distances.py +71 -56
- pertpy/tools/_enrichment.py +16 -8
- pertpy/tools/_milo.py +54 -50
- pertpy/tools/_mixscape.py +307 -208
- pertpy/tools/_perturbation_space/_perturbation_space.py +40 -31
- pertpy/tools/_perturbation_space/_simple.py +48 -0
- pertpy/tools/_scgen/_scgen.py +35 -27
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/METADATA +6 -6
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/RECORD +29 -28
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/WHEEL +1 -1
- pertpy/tools/_differential_gene_expression/_formulaic.py +0 -189
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_mixscape.py
CHANGED
@@ -18,6 +18,7 @@ from scipy.sparse import csr_matrix, issparse, spmatrix
|
|
18
18
|
from sklearn.mixture import GaussianMixture
|
19
19
|
|
20
20
|
import pertpy as pt
|
21
|
+
from pertpy._doc import _doc_params, doc_common_plot_args
|
21
22
|
|
22
23
|
if TYPE_CHECKING:
|
23
24
|
from collections.abc import Sequence
|
@@ -25,6 +26,7 @@ if TYPE_CHECKING:
|
|
25
26
|
from anndata import AnnData
|
26
27
|
from matplotlib.axes import Axes
|
27
28
|
from matplotlib.colors import Colormap
|
29
|
+
from matplotlib.pyplot import Figure
|
28
30
|
from scipy import sparse
|
29
31
|
|
30
32
|
|
@@ -39,9 +41,12 @@ class Mixscape:
|
|
39
41
|
adata: AnnData,
|
40
42
|
pert_key: str,
|
41
43
|
control: str,
|
44
|
+
*,
|
45
|
+
ref_selection_mode: Literal["nn", "split_by"] = "nn",
|
42
46
|
split_by: str | None = None,
|
43
47
|
n_neighbors: int = 20,
|
44
48
|
use_rep: str | None = None,
|
49
|
+
n_dims: int | None = 15,
|
45
50
|
n_pcs: int | None = None,
|
46
51
|
batch_size: int | None = None,
|
47
52
|
copy: bool = False,
|
@@ -49,14 +54,18 @@ class Mixscape:
|
|
49
54
|
):
|
50
55
|
"""Calculate perturbation signature.
|
51
56
|
|
52
|
-
|
53
|
-
|
54
|
-
|
57
|
+
The perturbation signature is calculated by subtracting the mRNA expression profile of each cell from the averaged
|
58
|
+
mRNA expression profile of the control cells (selected according to `ref_selection_mode`).
|
59
|
+
The implementation resembles https://satijalab.org/seurat/reference/runmixscape. Note that in the original implementation, the
|
60
|
+
perturbation signature is calculated on unscaled data by default, and we therefore recommend doing the same.
|
55
61
|
|
56
62
|
Args:
|
57
63
|
adata: The annotated data object.
|
58
64
|
pert_key: The column of `.obs` with perturbation categories, should also contain `control`.
|
59
|
-
control:
|
65
|
+
control: Name of the control category from the `pert_key` column.
|
66
|
+
ref_selection_mode: Method to select reference cells for the perturbation signature calculation. If `nn`,
|
67
|
+
the `n_neighbors` cells from the control pool with the most similar mRNA expression profiles are selected. If `split_by`,
|
68
|
+
the control cells from the same split in `split_by` (e.g. indicating biological replicates) are used to calculate the perturbation signature.
|
60
69
|
split_by: Provide the column `.obs` if multiple biological replicates exist to calculate
|
61
70
|
the perturbation signature for every replicate separately.
|
62
71
|
n_neighbors: Number of neighbors from the control to use for the perturbation signature.
|
@@ -64,7 +73,10 @@ class Mixscape:
|
|
64
73
|
If `None`, the representation is chosen automatically:
|
65
74
|
For `.n_vars` < 50, `.X` is used, otherwise 'X_pca' is used.
|
66
75
|
If 'X_pca' is not present, it’s computed with default parameters.
|
67
|
-
|
76
|
+
n_dims: Number of dimensions to use from the representation to calculate the perturbation signature.
|
77
|
+
If `None`, use all dimensions.
|
78
|
+
n_pcs: If PCA representation is used, the number of principal components to compute.
|
79
|
+
If `n_pcs==0` use `.X` if `use_rep is None`.
|
68
80
|
batch_size: Size of batch to calculate the perturbation signature.
|
69
81
|
If 'None', the perturbation signature is calculated in the full mode, requiring more memory.
|
70
82
|
The batched mode is very inefficient for sparse data.
|
@@ -81,8 +93,13 @@ class Mixscape:
|
|
81
93
|
>>> import pertpy as pt
|
82
94
|
>>> mdata = pt.dt.papalexi_2021()
|
83
95
|
>>> ms_pt = pt.tl.Mixscape()
|
84
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
96
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
85
97
|
"""
|
98
|
+
if ref_selection_mode not in ["nn", "split_by"]:
|
99
|
+
raise ValueError("ref_selection_mode must be either 'nn' or 'split_by'.")
|
100
|
+
if ref_selection_mode == "split_by" and split_by is None:
|
101
|
+
raise ValueError("split_by must be provided if ref_selection_mode is 'split_by'.")
|
102
|
+
|
86
103
|
if copy:
|
87
104
|
adata = adata.copy()
|
88
105
|
|
@@ -90,59 +107,75 @@ class Mixscape:
|
|
90
107
|
|
91
108
|
control_mask = adata.obs[pert_key] == control
|
92
109
|
|
93
|
-
if
|
94
|
-
|
110
|
+
if ref_selection_mode == "split_by":
|
111
|
+
for split in adata.obs[split_by].unique():
|
112
|
+
split_mask = adata.obs[split_by] == split
|
113
|
+
control_mask_group = control_mask & split_mask
|
114
|
+
control_mean_expr = adata.X[control_mask_group].mean(0)
|
115
|
+
adata.layers["X_pert"][split_mask] = (
|
116
|
+
np.repeat(control_mean_expr.reshape(1, -1), split_mask.sum(), axis=0)
|
117
|
+
- adata.layers["X_pert"][split_mask]
|
118
|
+
)
|
95
119
|
else:
|
96
|
-
|
97
|
-
|
120
|
+
if split_by is None:
|
121
|
+
split_masks = [np.full(adata.n_obs, True, dtype=bool)]
|
122
|
+
else:
|
123
|
+
split_obs = adata.obs[split_by]
|
124
|
+
split_masks = [split_obs == cat for cat in split_obs.unique()]
|
98
125
|
|
99
|
-
|
126
|
+
representation = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs)
|
127
|
+
if n_dims is not None and n_dims < representation.shape[1]:
|
128
|
+
representation = representation[:, :n_dims]
|
100
129
|
|
101
|
-
|
102
|
-
|
130
|
+
for split_mask in split_masks:
|
131
|
+
control_mask_split = control_mask & split_mask
|
103
132
|
|
104
|
-
|
105
|
-
|
133
|
+
R_split = representation[split_mask]
|
134
|
+
R_control = representation[np.asarray(control_mask_split)]
|
106
135
|
|
107
|
-
|
136
|
+
from pynndescent import NNDescent
|
108
137
|
|
109
|
-
|
110
|
-
|
111
|
-
|
138
|
+
eps = kwargs.pop("epsilon", 0.1)
|
139
|
+
nn_index = NNDescent(R_control, **kwargs)
|
140
|
+
indices, _ = nn_index.query(R_split, k=n_neighbors, epsilon=eps)
|
112
141
|
|
113
|
-
|
142
|
+
X_control = np.expm1(adata.X[np.asarray(control_mask_split)])
|
114
143
|
|
115
|
-
|
116
|
-
|
144
|
+
n_split = split_mask.sum()
|
145
|
+
n_control = X_control.shape[0]
|
117
146
|
|
118
|
-
|
119
|
-
|
120
|
-
|
147
|
+
if batch_size is None:
|
148
|
+
col_indices = np.ravel(indices)
|
149
|
+
row_indices = np.repeat(np.arange(n_split), n_neighbors)
|
121
150
|
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
151
|
+
neigh_matrix = csr_matrix(
|
152
|
+
(np.ones_like(col_indices, dtype=np.float64), (row_indices, col_indices)),
|
153
|
+
shape=(n_split, n_control),
|
154
|
+
)
|
155
|
+
neigh_matrix /= n_neighbors
|
156
|
+
adata.layers["X_pert"][split_mask] = (
|
157
|
+
np.log1p(neigh_matrix @ X_control) - adata.layers["X_pert"][split_mask]
|
158
|
+
)
|
159
|
+
else:
|
160
|
+
is_sparse = issparse(X_control)
|
161
|
+
split_indices = np.where(split_mask)[0]
|
162
|
+
for i in range(0, n_split, batch_size):
|
163
|
+
size = min(i + batch_size, n_split)
|
164
|
+
select = slice(i, size)
|
134
165
|
|
135
|
-
|
136
|
-
|
166
|
+
batch = np.ravel(indices[select])
|
167
|
+
split_batch = split_indices[select]
|
137
168
|
|
138
|
-
|
169
|
+
size = size - i
|
139
170
|
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
171
|
+
# sparse is very slow
|
172
|
+
means_batch = X_control[batch]
|
173
|
+
means_batch = means_batch.toarray() if is_sparse else means_batch
|
174
|
+
means_batch = means_batch.reshape(size, n_neighbors, -1).mean(1)
|
144
175
|
|
145
|
-
|
176
|
+
adata.layers["X_pert"][split_batch] = (
|
177
|
+
np.log1p(means_batch) - adata.layers["X_pert"][split_batch]
|
178
|
+
)
|
146
179
|
|
147
180
|
if copy:
|
148
181
|
return adata
|
@@ -152,33 +185,41 @@ class Mixscape:
|
|
152
185
|
adata: AnnData,
|
153
186
|
labels: str,
|
154
187
|
control: str,
|
188
|
+
*,
|
155
189
|
new_class_name: str | None = "mixscape_class",
|
156
|
-
min_de_genes: int | None = 5,
|
157
190
|
layer: str | None = None,
|
191
|
+
min_de_genes: int | None = 5,
|
158
192
|
logfc_threshold: float | None = 0.25,
|
193
|
+
de_layer: str | None = None,
|
194
|
+
test_method: str | None = "wilcoxon",
|
159
195
|
iter_num: int | None = 10,
|
196
|
+
scale: bool | None = True,
|
160
197
|
split_by: str | None = None,
|
161
198
|
pval_cutoff: float | None = 5e-2,
|
162
199
|
perturbation_type: str | None = "KO",
|
200
|
+
random_state: int | None = 0,
|
163
201
|
copy: bool | None = False,
|
164
202
|
):
|
165
203
|
"""Identify perturbed and non-perturbed gRNA expressing cells that accounts for multiple treatments/conditions/chemical perturbations.
|
166
204
|
|
167
|
-
The implementation resembles https://satijalab.org/seurat/reference/runmixscape
|
205
|
+
The implementation resembles https://satijalab.org/seurat/reference/runmixscape.
|
168
206
|
|
169
207
|
Args:
|
170
208
|
adata: The annotated data object.
|
171
209
|
labels: The column of `.obs` with target gene labels.
|
172
|
-
control: Control category from the `
|
210
|
+
control: Control category from the `labels` column.
|
173
211
|
new_class_name: Name of mixscape classification to be stored in `.obs`.
|
174
|
-
min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
|
175
212
|
layer: Key from adata.layers whose value will be used to perform tests on. Default is using `.layers["X_pert"]`.
|
213
|
+
min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
|
176
214
|
logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells (default: 0.25).
|
215
|
+
de_layer: Layer to use for identifying differentially expressed genes. If `None`, adata.X is used.
|
216
|
+
test_method: Method to use for differential expression testing.
|
177
217
|
iter_num: Number of normalmixEM iterations to run if convergence does not occur.
|
178
|
-
|
179
|
-
|
218
|
+
scale: Scale the data specified in `layer` before running the GaussianMixture model on it.
|
219
|
+
split_by: Provide `.obs` column with experimental condition/cell type annotation, if perturbations are condition/cell type-specific.
|
180
220
|
pval_cutoff: P-value cut-off for selection of significantly DE genes.
|
181
221
|
perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications.
|
222
|
+
random_state: Random seed for the GaussianMixture model.
|
182
223
|
copy: Determines whether a copy of the `adata` is returned.
|
183
224
|
|
184
225
|
Returns:
|
@@ -201,8 +242,8 @@ class Mixscape:
|
|
201
242
|
>>> import pertpy as pt
|
202
243
|
>>> mdata = pt.dt.papalexi_2021()
|
203
244
|
>>> ms_pt = pt.tl.Mixscape()
|
204
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
205
|
-
>>> ms_pt.mixscape(
|
245
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
246
|
+
>>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
|
206
247
|
"""
|
207
248
|
if copy:
|
208
249
|
adata = adata.copy()
|
@@ -216,7 +257,16 @@ class Mixscape:
|
|
216
257
|
split_masks = [split_obs == category for category in categories]
|
217
258
|
|
218
259
|
perturbation_markers = self._get_perturbation_markers(
|
219
|
-
adata,
|
260
|
+
adata=adata,
|
261
|
+
split_masks=split_masks,
|
262
|
+
categories=categories,
|
263
|
+
labels=labels,
|
264
|
+
control=control,
|
265
|
+
layer=de_layer,
|
266
|
+
pval_cutoff=pval_cutoff,
|
267
|
+
min_de_genes=min_de_genes,
|
268
|
+
logfc_threshold=logfc_threshold,
|
269
|
+
test_method=test_method,
|
220
270
|
)
|
221
271
|
|
222
272
|
adata_comp = adata
|
@@ -229,6 +279,7 @@ class Mixscape:
|
|
229
279
|
raise KeyError(
|
230
280
|
"No 'X_pert' found in .layers! Please run perturbation_signature first to calculate perturbation signature!"
|
231
281
|
) from None
|
282
|
+
|
232
283
|
# initialize return variables
|
233
284
|
adata.obs[f"{new_class_name}_p_{perturbation_type.lower()}"] = 0
|
234
285
|
adata.obs[new_class_name] = adata.obs[labels].astype(str)
|
@@ -239,10 +290,12 @@ class Mixscape:
|
|
239
290
|
dtype=np.object_,
|
240
291
|
)
|
241
292
|
gv_list: dict[str, dict] = {}
|
293
|
+
|
294
|
+
adata.obs[f"{new_class_name}_p_{perturbation_type.lower()}"] = 0.0
|
242
295
|
for split, split_mask in enumerate(split_masks):
|
243
296
|
category = categories[split]
|
244
|
-
|
245
|
-
for gene in
|
297
|
+
gene_targets = list(set(adata[split_mask].obs[labels]).difference([control]))
|
298
|
+
for gene in gene_targets:
|
246
299
|
post_prob = 0
|
247
300
|
orig_guide_cells = (adata.obs[labels] == gene) & split_mask
|
248
301
|
orig_guide_cells_index = list(orig_guide_cells.index[orig_guide_cells])
|
@@ -251,28 +304,38 @@ class Mixscape:
|
|
251
304
|
|
252
305
|
if len(perturbation_markers[(category, gene)]) == 0:
|
253
306
|
adata.obs.loc[orig_guide_cells, new_class_name] = f"{gene} NP"
|
307
|
+
|
254
308
|
else:
|
255
309
|
de_genes = perturbation_markers[(category, gene)]
|
256
310
|
de_genes_indices = self._get_column_indices(adata, list(de_genes))
|
257
|
-
|
311
|
+
|
312
|
+
dat = X[np.asarray(all_cells)][:, de_genes_indices]
|
313
|
+
dat_cells = all_cells[all_cells].index
|
314
|
+
if scale:
|
315
|
+
dat = sc.pp.scale(dat)
|
316
|
+
|
258
317
|
converged = False
|
259
318
|
n_iter = 0
|
260
|
-
old_classes = adata.obs[
|
319
|
+
old_classes = adata.obs[new_class_name][all_cells]
|
320
|
+
|
261
321
|
while not converged and n_iter < iter_num:
|
262
322
|
# Get all cells in current split&Gene
|
263
|
-
guide_cells = (adata.obs[
|
323
|
+
guide_cells = (adata.obs[new_class_name] == gene) & split_mask
|
324
|
+
|
264
325
|
# get average value for each gene over all selected cells
|
265
326
|
# all cells in current split&Gene minus all NT cells in current split
|
266
327
|
# Each row is for each cell, each column is for each gene, get mean for each column
|
267
|
-
|
268
|
-
|
269
|
-
)
|
328
|
+
guide_cells_dat_idx = all_cells[all_cells].index.get_indexer(guide_cells[guide_cells].index)
|
329
|
+
nt_cells_dat_idx = all_cells[all_cells].index.get_indexer(nt_cells[nt_cells].index)
|
330
|
+
vec = np.mean(dat[guide_cells_dat_idx], axis=0) - np.mean(dat[nt_cells_dat_idx], axis=0)
|
331
|
+
|
270
332
|
# project cells onto the perturbation vector
|
271
333
|
if isinstance(dat, spmatrix):
|
272
|
-
pvec = np.
|
334
|
+
pvec = np.dot(dat.toarray(), vec) / np.dot(vec, vec)
|
273
335
|
else:
|
274
|
-
pvec = np.
|
336
|
+
pvec = np.dot(dat, vec) / np.dot(vec, vec)
|
275
337
|
pvec = pd.Series(np.asarray(pvec).flatten(), index=list(all_cells.index[all_cells]))
|
338
|
+
|
276
339
|
if n_iter == 0:
|
277
340
|
gv = pd.DataFrame(columns=["pvec", labels])
|
278
341
|
gv["pvec"] = pvec
|
@@ -282,19 +345,22 @@ class Mixscape:
|
|
282
345
|
gv_list[gene] = {}
|
283
346
|
gv_list[gene][category] = gv
|
284
347
|
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
precisions_init = np.array([nt_norm[1], guide_norm[1]])
|
289
|
-
mm = GaussianMixture(
|
348
|
+
means_init = np.array([[pvec[nt_cells].mean()], [pvec[guide_cells].mean()]])
|
349
|
+
std_init = np.array([pvec[nt_cells].std(), pvec[guide_cells].std()])
|
350
|
+
mm = MixscapeGaussianMixture(
|
290
351
|
n_components=2,
|
291
352
|
covariance_type="spherical",
|
292
353
|
means_init=means_init,
|
293
|
-
precisions_init=
|
354
|
+
precisions_init=1 / (std_init ** 2),
|
355
|
+
random_state=random_state,
|
356
|
+
max_iter=5000,
|
357
|
+
fixed_means=[pvec[nt_cells].mean(), None],
|
358
|
+
fixed_covariances=[pvec[nt_cells].std() ** 2, None],
|
294
359
|
).fit(np.asarray(pvec).reshape(-1, 1))
|
295
360
|
probabilities = mm.predict_proba(np.array(pvec[orig_guide_cells_index]).reshape(-1, 1))
|
296
361
|
lik_ratio = probabilities[:, 0] / probabilities[:, 1]
|
297
362
|
post_prob = 1 / (1 + lik_ratio)
|
363
|
+
|
298
364
|
# based on the posterior probability, assign cells to the two classes
|
299
365
|
adata.obs.loc[
|
300
366
|
[orig_guide_cells_index[cell] for cell in np.where(post_prob > 0.5)[0]], new_class_name
|
@@ -302,11 +368,13 @@ class Mixscape:
|
|
302
368
|
adata.obs.loc[
|
303
369
|
[orig_guide_cells_index[cell] for cell in np.where(post_prob <= 0.5)[0]], new_class_name
|
304
370
|
] = f"{gene} NP"
|
371
|
+
|
305
372
|
if sum(adata.obs[new_class_name][split_mask] == gene) < min_de_genes:
|
306
373
|
adata.obs.loc[guide_cells, new_class_name] = "NP"
|
307
374
|
converged = True
|
308
375
|
if adata.obs[new_class_name][all_cells].equals(old_classes):
|
309
376
|
converged = True
|
377
|
+
|
310
378
|
old_classes = adata.obs[new_class_name][all_cells]
|
311
379
|
n_iter += 1
|
312
380
|
|
@@ -315,9 +383,7 @@ class Mixscape:
|
|
315
383
|
)
|
316
384
|
|
317
385
|
adata.obs[f"{new_class_name}_global"] = [a.split(" ")[-1] for a in adata.obs[new_class_name]]
|
318
|
-
adata.obs.loc[orig_guide_cells_index, f"{new_class_name}_p_{perturbation_type.lower()}"] =
|
319
|
-
post_prob
|
320
|
-
).astype("int64")
|
386
|
+
adata.obs.loc[orig_guide_cells_index, f"{new_class_name}_p_{perturbation_type.lower()}"] = post_prob
|
321
387
|
adata.uns["mixscape"] = gv_list
|
322
388
|
|
323
389
|
if copy:
|
@@ -328,11 +394,13 @@ class Mixscape:
|
|
328
394
|
adata: AnnData,
|
329
395
|
labels: str,
|
330
396
|
control: str,
|
397
|
+
*,
|
331
398
|
mixscape_class_global: str | None = "mixscape_class_global",
|
332
399
|
layer: str | None = None,
|
333
400
|
n_comps: int | None = 10,
|
334
401
|
min_de_genes: int | None = 5,
|
335
402
|
logfc_threshold: float | None = 0.25,
|
403
|
+
test_method: str | None = "wilcoxon",
|
336
404
|
split_by: str | None = None,
|
337
405
|
pval_cutoff: float | None = 5e-2,
|
338
406
|
perturbation_type: str | None = "KO",
|
@@ -345,12 +413,13 @@ class Mixscape:
|
|
345
413
|
labels: The column of `.obs` with target gene labels.
|
346
414
|
control: Control category from the `labels` column.
|
347
415
|
mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
|
348
|
-
layer:
|
416
|
+
layer: Layer to use for identifying differentially expressed genes. If `None`, adata.X is used.
|
349
417
|
control: Control category from the `pert_key` column.
|
350
418
|
n_comps: Number of principal components to use.
|
351
419
|
min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
|
352
420
|
logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells.
|
353
|
-
|
421
|
+
test_method: Method to use for differential expression testing.
|
422
|
+
split_by: Provide `.obs` column with experimental condition/cell type annotation, if perturbations are condition/cell type-specific.
|
354
423
|
pval_cutoff: P-value cut-off for selection of significantly DE genes.
|
355
424
|
perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
|
356
425
|
copy: Determines whether a copy of the `adata` is returned.
|
@@ -368,9 +437,9 @@ class Mixscape:
|
|
368
437
|
>>> import pertpy as pt
|
369
438
|
>>> mdata = pt.dt.papalexi_2021()
|
370
439
|
>>> ms_pt = pt.tl.Mixscape()
|
371
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
372
|
-
>>> ms_pt.mixscape(
|
373
|
-
>>> ms_pt.lda(
|
440
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
441
|
+
>>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
|
442
|
+
>>> ms_pt.lda(mdata["rna"], "gene_target", "NT")
|
374
443
|
"""
|
375
444
|
if copy:
|
376
445
|
adata = adata.copy()
|
@@ -386,9 +455,8 @@ class Mixscape:
|
|
386
455
|
categories = split_obs.unique()
|
387
456
|
split_masks = [split_obs == category for category in categories]
|
388
457
|
|
389
|
-
mixscape_identifier = pt.tl.Mixscape()
|
390
458
|
# determine gene sets across all splits/groups through differential gene expression
|
391
|
-
perturbation_markers =
|
459
|
+
perturbation_markers = self._get_perturbation_markers(
|
392
460
|
adata=adata,
|
393
461
|
split_masks=split_masks,
|
394
462
|
categories=categories,
|
@@ -398,6 +466,7 @@ class Mixscape:
|
|
398
466
|
pval_cutoff=pval_cutoff,
|
399
467
|
min_de_genes=min_de_genes,
|
400
468
|
logfc_threshold=logfc_threshold,
|
469
|
+
test_method=test_method,
|
401
470
|
)
|
402
471
|
adata_subset = adata[
|
403
472
|
(adata.obs[mixscape_class_global] == perturbation_type) | (adata.obs[mixscape_class_global] == control)
|
@@ -443,12 +512,21 @@ class Mixscape:
|
|
443
512
|
pval_cutoff: float,
|
444
513
|
min_de_genes: float,
|
445
514
|
logfc_threshold: float,
|
515
|
+
test_method: str,
|
446
516
|
) -> dict[tuple, np.ndarray]:
|
447
517
|
"""Determine gene sets across all splits/groups through differential gene expression
|
448
518
|
|
449
519
|
Args:
|
450
520
|
adata: :class:`~anndata.AnnData` object
|
451
|
-
|
521
|
+
split_masks: List of boolean masks for each split/group.
|
522
|
+
categories: List of split/group names.
|
523
|
+
labels: The column of `.obs` with target gene labels.
|
524
|
+
control: Control category from the `labels` column.
|
525
|
+
layer: Key from adata.layers whose value will be used to compare gene expression.
|
526
|
+
pval_cutoff: P-value cut-off for selection of significantly DE genes.
|
527
|
+
min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
|
528
|
+
logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells.
|
529
|
+
test_method: Method to use for differential expression testing.
|
452
530
|
|
453
531
|
Returns:
|
454
532
|
Set of column indices.
|
@@ -457,21 +535,21 @@ class Mixscape:
|
|
457
535
|
for split, split_mask in enumerate(split_masks):
|
458
536
|
category = categories[split]
|
459
537
|
# get gene sets for each split
|
460
|
-
|
538
|
+
gene_targets = list(set(adata[split_mask].obs[labels]).difference([control]))
|
461
539
|
adata_split = adata[split_mask].copy()
|
462
540
|
# find top DE genes between cells with targeting and non-targeting gRNAs
|
463
541
|
sc.tl.rank_genes_groups(
|
464
542
|
adata_split,
|
465
543
|
layer=layer,
|
466
544
|
groupby=labels,
|
467
|
-
groups=
|
545
|
+
groups=gene_targets,
|
468
546
|
reference=control,
|
469
|
-
method=
|
547
|
+
method=test_method,
|
470
548
|
use_raw=False,
|
471
549
|
)
|
472
|
-
# get DE genes for each gene
|
473
|
-
for gene in
|
474
|
-
logfc_threshold_mask = adata_split.uns["rank_genes_groups"]["logfoldchanges"][gene] >= logfc_threshold
|
550
|
+
# get DE genes for each target gene
|
551
|
+
for gene in gene_targets:
|
552
|
+
logfc_threshold_mask = np.abs(adata_split.uns["rank_genes_groups"]["logfoldchanges"][gene]) >= logfc_threshold
|
475
553
|
de_genes = adata_split.uns["rank_genes_groups"]["names"][gene][logfc_threshold_mask]
|
476
554
|
pvals_adj = adata_split.uns["rank_genes_groups"]["pvals_adj"][gene][logfc_threshold_mask]
|
477
555
|
de_genes = de_genes[pvals_adj < pval_cutoff]
|
@@ -492,35 +570,22 @@ class Mixscape:
|
|
492
570
|
|
493
571
|
return indices
|
494
572
|
|
495
|
-
|
496
|
-
"""Calculates the mean and standard deviation of a matrix.
|
497
|
-
|
498
|
-
Args:
|
499
|
-
X: The matrix to calculate the properties for.
|
500
|
-
|
501
|
-
Returns:
|
502
|
-
Mean and standard deviation of the matrix.
|
503
|
-
"""
|
504
|
-
mu = X.mean()
|
505
|
-
sd = X.std()
|
506
|
-
|
507
|
-
return [mu, sd]
|
508
|
-
|
573
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
509
574
|
def plot_barplot( # pragma: no cover
|
510
575
|
self,
|
511
576
|
adata: AnnData,
|
512
577
|
guide_rna_column: str,
|
578
|
+
*,
|
513
579
|
mixscape_class_global: str = "mixscape_class_global",
|
514
580
|
axis_text_x_size: int = 8,
|
515
581
|
axis_text_y_size: int = 6,
|
516
582
|
axis_title_size: int = 8,
|
517
583
|
legend_title_size: int = 8,
|
518
584
|
legend_text_size: int = 8,
|
519
|
-
|
520
|
-
|
521
|
-
|
522
|
-
|
523
|
-
):
|
585
|
+
legend_bbox_to_anchor: tuple[float, float] = None,
|
586
|
+
figsize: tuple[float, float] = (25, 25),
|
587
|
+
return_fig: bool = False,
|
588
|
+
) -> Figure | None:
|
524
589
|
"""Barplot to visualize perturbation scores calculated by the `mixscape` function.
|
525
590
|
|
526
591
|
Args:
|
@@ -528,19 +593,24 @@ class Mixscape:
|
|
528
593
|
guide_rna_column: The column of `.obs` with guide RNA labels. The target gene labels.
|
529
594
|
The format must be <gene_target>g<#>. Examples are 'STAT2g1' and 'ATF2g1'.
|
530
595
|
mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
|
531
|
-
|
532
|
-
|
533
|
-
|
596
|
+
axis_text_x_size: Size of the x-axis text.
|
597
|
+
axis_text_y_size: Size of the y-axis text.
|
598
|
+
axis_title_size: Size of the axis title.
|
599
|
+
legend_title_size: Size of the legend title.
|
600
|
+
legend_text_size: Size of the legend text.
|
601
|
+
legend_bbox_to_anchor: The bounding box to which the legend is anchored.
|
602
|
+
figsize: The size of the figure.
|
603
|
+
{common_plot_args}
|
534
604
|
|
535
605
|
Returns:
|
536
|
-
If `
|
606
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
537
607
|
|
538
608
|
Examples:
|
539
609
|
>>> import pertpy as pt
|
540
610
|
>>> mdata = pt.dt.papalexi_2021()
|
541
611
|
>>> ms_pt = pt.tl.Mixscape()
|
542
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
543
|
-
>>> ms_pt.mixscape(
|
612
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
613
|
+
>>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
|
544
614
|
>>> ms_pt.plot_barplot(mdata["rna"], guide_rna_column="NT")
|
545
615
|
|
546
616
|
Preview:
|
@@ -565,63 +635,64 @@ class Mixscape:
|
|
565
635
|
all_cells_percentage["guide_number"] = "g" + all_cells_percentage["guide_number"]
|
566
636
|
NP_KO_cells = all_cells_percentage[all_cells_percentage["gene"] != "NT"]
|
567
637
|
|
568
|
-
|
569
|
-
|
570
|
-
|
571
|
-
|
572
|
-
|
573
|
-
|
574
|
-
|
575
|
-
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
|
580
|
-
|
581
|
-
|
582
|
-
|
583
|
-
|
584
|
-
|
585
|
-
|
586
|
-
|
587
|
-
|
588
|
-
|
589
|
-
|
590
|
-
|
591
|
-
|
592
|
-
sns.despine(ax=ax, top=True, right=True, left=False, bottom=False)
|
593
|
-
ax.set_xticklabels(ax.get_xticklabels(), rotation=0, ha="right", fontsize=axis_text_x_size)
|
594
|
-
ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=axis_text_y_size)
|
595
|
-
fig.subplots_adjust(right=0.8)
|
596
|
-
fig.subplots_adjust(hspace=0.5, wspace=0.5)
|
638
|
+
color_mapping = {"KO": "salmon", "NP": "lightgray", "NT": "grey"}
|
639
|
+
unique_genes = NP_KO_cells["gene"].unique()
|
640
|
+
fig, axs = plt.subplots(int(len(unique_genes) / 5), 5, figsize=figsize, sharey=True)
|
641
|
+
for i, gene in enumerate(unique_genes):
|
642
|
+
ax = axs[int(i / 5), i % 5]
|
643
|
+
grouped_df = (
|
644
|
+
NP_KO_cells[NP_KO_cells["gene"] == gene]
|
645
|
+
.groupby(["guide_number", "mixscape_class_global"], observed=False)["value"]
|
646
|
+
.sum()
|
647
|
+
.unstack()
|
648
|
+
)
|
649
|
+
grouped_df.plot(
|
650
|
+
kind="bar",
|
651
|
+
stacked=True,
|
652
|
+
color=[color_mapping[col] for col in grouped_df.columns],
|
653
|
+
ax=ax,
|
654
|
+
width=0.8,
|
655
|
+
legend=False,
|
656
|
+
)
|
657
|
+
ax.set_title(gene, bbox={"facecolor": "white", "edgecolor": "black", "pad": 1}, fontsize=axis_title_size)
|
658
|
+
ax.set(xlabel="sgRNA", ylabel="% of cells")
|
659
|
+
sns.despine(ax=ax, top=True, right=True, left=False, bottom=False)
|
660
|
+
ax.set_xticks(ax.get_xticks(), ax.get_xticklabels(), rotation=0, ha="right", fontsize=axis_text_x_size)
|
661
|
+
ax.set_yticks(ax.get_yticks(), ax.get_yticklabels(), rotation=0, fontsize=axis_text_y_size)
|
597
662
|
ax.legend(
|
598
|
-
title="
|
663
|
+
title="Mixscape Class",
|
599
664
|
loc="center right",
|
600
|
-
bbox_to_anchor=
|
665
|
+
bbox_to_anchor=legend_bbox_to_anchor,
|
601
666
|
frameon=True,
|
602
667
|
fontsize=legend_text_size,
|
603
668
|
title_fontsize=legend_title_size,
|
604
669
|
)
|
605
670
|
|
671
|
+
fig.subplots_adjust(right=0.8)
|
672
|
+
fig.subplots_adjust(hspace=0.5, wspace=0.5)
|
606
673
|
plt.tight_layout()
|
607
|
-
_utils.savefig_or_show("mixscape_barplot", show=show, save=save)
|
608
674
|
|
675
|
+
if return_fig:
|
676
|
+
return fig
|
677
|
+
plt.show()
|
678
|
+
return None
|
679
|
+
|
680
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
609
681
|
def plot_heatmap( # pragma: no cover
|
610
682
|
self,
|
611
683
|
adata: AnnData,
|
612
684
|
labels: str,
|
613
685
|
target_gene: str,
|
614
686
|
control: str,
|
687
|
+
*,
|
615
688
|
layer: str | None = None,
|
616
689
|
method: str | None = "wilcoxon",
|
617
690
|
subsample_number: int | None = 900,
|
618
691
|
vmin: float | None = -2,
|
619
692
|
vmax: float | None = 2,
|
620
|
-
return_fig: bool
|
621
|
-
show: bool | None = None,
|
622
|
-
save: bool | str | None = None,
|
693
|
+
return_fig: bool = False,
|
623
694
|
**kwds,
|
624
|
-
) ->
|
695
|
+
) -> Figure | None:
|
625
696
|
"""Heatmap plot using mixscape results. Requires `pt.tl.mixscape()` to be run first.
|
626
697
|
|
627
698
|
Args:
|
@@ -634,21 +705,18 @@ class Mixscape:
|
|
634
705
|
subsample_number: Subsample to this number of observations.
|
635
706
|
vmin: The value representing the lower limit of the color scale. Values smaller than vmin are plotted with the same color as vmin.
|
636
707
|
vmax: The value representing the upper limit of the color scale. Values larger than vmax are plotted with the same color as vmax.
|
637
|
-
|
638
|
-
save: If `True` or a `str`, save the figure. A string is appended to the default filename.
|
639
|
-
Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
|
640
|
-
ax: A matplotlib axes object. Only works if plotting a single component.
|
708
|
+
{common_plot_args}
|
641
709
|
**kwds: Additional arguments to `scanpy.pl.rank_genes_groups_heatmap`.
|
642
710
|
|
643
711
|
Returns:
|
644
|
-
If `
|
712
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
645
713
|
|
646
714
|
Examples:
|
647
715
|
>>> import pertpy as pt
|
648
716
|
>>> mdata = pt.dt.papalexi_2021()
|
649
717
|
>>> ms_pt = pt.tl.Mixscape()
|
650
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
651
|
-
>>> ms_pt.mixscape(
|
718
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
719
|
+
>>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
|
652
720
|
>>> ms_pt.plot_heatmap(
|
653
721
|
... adata=mdata["rna"], labels="gene_target", target_gene="IFNGR2", layer="X_pert", control="NT"
|
654
722
|
... )
|
@@ -663,35 +731,37 @@ class Mixscape:
|
|
663
731
|
sc.pp.scale(adata_subset, max_value=vmax)
|
664
732
|
sc.pp.subsample(adata_subset, n_obs=subsample_number)
|
665
733
|
|
666
|
-
|
734
|
+
fig = sc.pl.rank_genes_groups_heatmap(
|
667
735
|
adata_subset,
|
668
736
|
groupby="mixscape_class",
|
669
737
|
vmin=vmin,
|
670
738
|
vmax=vmax,
|
671
739
|
n_genes=20,
|
672
740
|
groups=["NT"],
|
673
|
-
|
674
|
-
show=show,
|
675
|
-
save=save,
|
741
|
+
show=False,
|
676
742
|
**kwds,
|
677
743
|
)
|
678
744
|
|
745
|
+
if return_fig:
|
746
|
+
return fig
|
747
|
+
plt.show()
|
748
|
+
return None
|
749
|
+
|
750
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
679
751
|
def plot_perturbscore( # pragma: no cover
|
680
752
|
self,
|
681
753
|
adata: AnnData,
|
682
754
|
labels: str,
|
683
755
|
target_gene: str,
|
756
|
+
*,
|
684
757
|
mixscape_class: str = "mixscape_class",
|
685
758
|
color: str = "orange",
|
686
759
|
palette: dict[str, str] = None,
|
687
760
|
split_by: str = None,
|
688
761
|
before_mixscape: bool = False,
|
689
762
|
perturbation_type: str = "KO",
|
690
|
-
return_fig: bool
|
691
|
-
|
692
|
-
show: bool | None = None,
|
693
|
-
save: bool | str | None = None,
|
694
|
-
) -> None:
|
763
|
+
return_fig: bool = False,
|
764
|
+
) -> Figure | None:
|
695
765
|
"""Density plots to visualize perturbation scores calculated by the `pt.tl.mixscape` function.
|
696
766
|
|
697
767
|
Requires `pt.tl.mixscape` to be run first.
|
@@ -710,6 +780,10 @@ class Mixscape:
|
|
710
780
|
before_mixscape: Option to split densities based on mixscape classification (default) or original target gene classification.
|
711
781
|
Default is `False`, which splits densities by mixscape classification; set to `True` to plot cells by their original target gene class.
|
712
782
|
perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
|
783
|
+
{common_plot_args}
|
784
|
+
|
785
|
+
Returns:
|
786
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
713
787
|
|
714
788
|
Examples:
|
715
789
|
Visualizing the perturbation scores for the cells in a dataset:
|
@@ -717,8 +791,8 @@ class Mixscape:
|
|
717
791
|
>>> import pertpy as pt
|
718
792
|
>>> mdata = pt.dt.papalexi_2021()
|
719
793
|
>>> ms_pt = pt.tl.Mixscape()
|
720
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
721
|
-
>>> ms_pt.mixscape(
|
794
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
795
|
+
>>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
|
722
796
|
>>> ms_pt.plot_perturbscore(adata=mdata["rna"], labels="gene_target", target_gene="IFNGR2", color="orange")
|
723
797
|
|
724
798
|
Preview:
|
@@ -778,15 +852,6 @@ class Mixscape:
|
|
778
852
|
plt.legend(title="gene_target", title_fontsize=14, fontsize=12)
|
779
853
|
sns.despine()
|
780
854
|
|
781
|
-
if save:
|
782
|
-
plt.savefig(save, bbox_inches="tight")
|
783
|
-
if show:
|
784
|
-
plt.show()
|
785
|
-
if return_fig:
|
786
|
-
return plt.gcf()
|
787
|
-
if not (show or save):
|
788
|
-
return plt.gca()
|
789
|
-
|
790
855
|
# If before_mixscape is False, split densities based on mixscape classifications
|
791
856
|
else:
|
792
857
|
if palette is None:
|
@@ -843,19 +908,17 @@ class Mixscape:
|
|
843
908
|
plt.legend(title="mixscape class", title_fontsize=14, fontsize=12)
|
844
909
|
sns.despine()
|
845
910
|
|
846
|
-
|
847
|
-
|
848
|
-
|
849
|
-
|
850
|
-
if return_fig:
|
851
|
-
return plt.gcf()
|
852
|
-
if not (show or save):
|
853
|
-
return plt.gca()
|
911
|
+
if return_fig:
|
912
|
+
return plt.gcf()
|
913
|
+
plt.show()
|
914
|
+
return None
|
854
915
|
|
916
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
855
917
|
def plot_violin( # pragma: no cover
|
856
918
|
self,
|
857
919
|
adata: AnnData,
|
858
920
|
target_gene_idents: str | list[str],
|
921
|
+
*,
|
859
922
|
keys: str | Sequence[str] = "mixscape_class_p_ko",
|
860
923
|
groupby: str | None = "mixscape_class",
|
861
924
|
log: bool = False,
|
@@ -872,10 +935,9 @@ class Mixscape:
|
|
872
935
|
ylabel: str | Sequence[str] | None = None,
|
873
936
|
rotation: float | None = None,
|
874
937
|
ax: Axes | None = None,
|
875
|
-
|
876
|
-
save: bool | str | None = None,
|
938
|
+
return_fig: bool = False,
|
877
939
|
**kwargs,
|
878
|
-
):
|
940
|
+
) -> Axes | Figure | None:
|
879
941
|
"""Violin plot using mixscape results.
|
880
942
|
|
881
943
|
Requires `pt.tl.mixscape` to be run first.
|
@@ -892,21 +954,19 @@ class Mixscape:
|
|
892
954
|
xlabel: Label of the x-axis. Defaults to `groupby` if `rotation` is `None`, otherwise, no label is shown.
|
893
955
|
ylabel: Label of the y-axis. If `None` and `groupby` is `None`, defaults to `'value'`.
|
894
956
|
If `None` and `groupby` is not `None`, defaults to `keys`.
|
895
|
-
show: Show the plot, do not return axis.
|
896
|
-
save: If `True` or a `str`, save the figure. A string is appended to the default filename.
|
897
|
-
Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
|
898
957
|
ax: A matplotlib axes object. Only works if plotting a single component.
|
958
|
+
{common_plot_args}
|
899
959
|
**kwargs: Additional arguments to `seaborn.violinplot`.
|
900
960
|
|
901
961
|
Returns:
|
902
|
-
|
962
|
+
If `return_fig` is `True`, returns the figure (as Axes list if it's a multi-panel plot), otherwise `None`.
|
903
963
|
|
904
964
|
Examples:
|
905
965
|
>>> import pertpy as pt
|
906
966
|
>>> mdata = pt.dt.papalexi_2021()
|
907
967
|
>>> ms_pt = pt.tl.Mixscape()
|
908
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
909
|
-
>>> ms_pt.mixscape(
|
968
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
969
|
+
>>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
|
910
970
|
>>> ms_pt.plot_violin(
|
911
971
|
... adata=mdata["rna"], target_gene_idents=["NT", "IFNGR2 NP", "IFNGR2 KO"], groupby="mixscape_class"
|
912
972
|
... )
|
@@ -1042,23 +1102,25 @@ class Mixscape:
|
|
1042
1102
|
if rotation is not None:
|
1043
1103
|
ax.tick_params(axis="x", labelrotation=rotation)
|
1044
1104
|
|
1045
|
-
show = settings.autoshow if show is None else show
|
1046
1105
|
if hue is not None and stripplot is True:
|
1047
1106
|
plt.legend(handles, labels)
|
1048
|
-
_utils.savefig_or_show("mixscape_violin", show=show, save=save)
|
1049
1107
|
|
1050
|
-
if
|
1108
|
+
if return_fig:
|
1051
1109
|
if multi_panel and groupby is None and len(ys) == 1:
|
1052
1110
|
return g
|
1053
1111
|
elif len(axs) == 1:
|
1054
1112
|
return axs[0]
|
1055
1113
|
else:
|
1056
1114
|
return axs
|
1115
|
+
plt.show()
|
1116
|
+
return None
|
1057
1117
|
|
1118
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
1058
1119
|
def plot_lda( # pragma: no cover
|
1059
1120
|
self,
|
1060
1121
|
adata: AnnData,
|
1061
1122
|
control: str,
|
1123
|
+
*,
|
1062
1124
|
mixscape_class: str = "mixscape_class",
|
1063
1125
|
mixscape_class_global: str = "mixscape_class_global",
|
1064
1126
|
perturbation_type: str | None = "KO",
|
@@ -1066,12 +1128,10 @@ class Mixscape:
|
|
1066
1128
|
n_components: int | None = None,
|
1067
1129
|
color_map: Colormap | str | None = None,
|
1068
1130
|
palette: str | Sequence[str] | None = None,
|
1069
|
-
return_fig: bool | None = None,
|
1070
1131
|
ax: Axes | None = None,
|
1071
|
-
|
1072
|
-
save: bool | str | None = None,
|
1132
|
+
return_fig: bool = False,
|
1073
1133
|
**kwds,
|
1074
|
-
) -> None:
|
1134
|
+
) -> Figure | None:
|
1075
1135
|
"""Visualizing perturbation responses with Linear Discriminant Analysis. Requires `pt.tl.mixscape()` to be run first.
|
1076
1136
|
|
1077
1137
|
Args:
|
@@ -1082,18 +1142,16 @@ class Mixscape:
|
|
1082
1142
|
perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
|
1083
1143
|
lda_key: If not specified, lda looks in .uns["mixscape_lda"] for the LDA results.
|
1084
1144
|
n_components: The number of dimensions of the embedding.
|
1085
|
-
|
1086
|
-
save: If `True` or a `str`, save the figure. A string is appended to the default filename.
|
1087
|
-
Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
|
1145
|
+
{common_plot_args}
|
1088
1146
|
**kwds: Additional arguments to `scanpy.pl.umap`.
|
1089
1147
|
|
1090
1148
|
Examples:
|
1091
1149
|
>>> import pertpy as pt
|
1092
1150
|
>>> mdata = pt.dt.papalexi_2021()
|
1093
1151
|
>>> ms_pt = pt.tl.Mixscape()
|
1094
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
1095
|
-
>>> ms_pt.mixscape(
|
1096
|
-
>>> ms_pt.lda(
|
1152
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
1153
|
+
>>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
|
1154
|
+
>>> ms_pt.lda(mdata["rna"], "gene_target", "NT", split_by="replicate")
|
1097
1155
|
>>> ms_pt.plot_lda(adata=mdata["rna"], control="NT")
|
1098
1156
|
|
1099
1157
|
Preview:
|
@@ -1112,14 +1170,55 @@ class Mixscape:
|
|
1112
1170
|
n_components = adata_subset.uns[lda_key].shape[1]
|
1113
1171
|
sc.pp.neighbors(adata_subset, use_rep=lda_key)
|
1114
1172
|
sc.tl.umap(adata_subset, n_components=n_components)
|
1115
|
-
sc.pl.umap(
|
1173
|
+
fig = sc.pl.umap(
|
1116
1174
|
adata_subset,
|
1117
1175
|
color=mixscape_class,
|
1118
1176
|
palette=palette,
|
1119
1177
|
color_map=color_map,
|
1120
1178
|
return_fig=return_fig,
|
1121
|
-
show=
|
1122
|
-
save=save,
|
1179
|
+
show=False,
|
1123
1180
|
ax=ax,
|
1124
1181
|
**kwds,
|
1125
1182
|
)
|
1183
|
+
|
1184
|
+
if return_fig:
|
1185
|
+
return fig
|
1186
|
+
plt.show()
|
1187
|
+
return None
|
1188
|
+
|
1189
|
+
class MixscapeGaussianMixture(GaussianMixture):
|
1190
|
+
def __init__(
|
1191
|
+
self,
|
1192
|
+
n_components: int,
|
1193
|
+
fixed_means: Sequence[float] | None = None,
|
1194
|
+
fixed_covariances: Sequence[float] | None = None,
|
1195
|
+
**kwargs
|
1196
|
+
):
|
1197
|
+
"""Custom Gaussian Mixture Model where means and covariances can be fixed for specific components.
|
1198
|
+
|
1199
|
+
Args:
|
1200
|
+
n_components: Number of Gaussian components
|
1201
|
+
fixed_means: Means to fix (use None for those that should be estimated)
|
1202
|
+
fixed_covariances: Covariances to fix (use None for those that should be estimated)
|
1203
|
+
kwargs: Additional arguments passed to scikit-learn's GaussianMixture
|
1204
|
+
"""
|
1205
|
+
super().__init__(n_components=n_components, **kwargs)
|
1206
|
+
self.fixed_means = fixed_means
|
1207
|
+
self.fixed_covariances = fixed_covariances
|
1208
|
+
|
1209
|
+
def _m_step(self, X: np.ndarray, log_resp: np.ndarray):
|
1210
|
+
"""Modified M-step to respect fixed means and covariances."""
|
1211
|
+
super()._m_step(X, log_resp)
|
1212
|
+
|
1213
|
+
if self.fixed_means is not None:
|
1214
|
+
for i in range(self.n_components):
|
1215
|
+
if self.fixed_means[i] is not None:
|
1216
|
+
self.means_[i] = self.fixed_means[i]
|
1217
|
+
|
1218
|
+
if self.fixed_covariances is not None:
|
1219
|
+
for i in range(self.n_components):
|
1220
|
+
if self.fixed_covariances[i] is not None:
|
1221
|
+
self.covariances_[i] = self.fixed_covariances[i]
|
1222
|
+
|
1223
|
+
return self
|
1224
|
+
|