pertpy 0.9.5__py3-none-any.whl → 0.10.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- pertpy/__init__.py +1 -1
- pertpy/_doc.py +1 -2
- pertpy/metadata/_cell_line.py +3 -5
- pertpy/preprocessing/_guide_rna.py +98 -10
- pertpy/preprocessing/_guide_rna_mixture.py +179 -0
- pertpy/tools/_augur.py +32 -44
- pertpy/tools/_cinemaot.py +1 -3
- pertpy/tools/_coda/_base_coda.py +21 -29
- pertpy/tools/_dialogue.py +17 -21
- pertpy/tools/_differential_gene_expression/_base.py +4 -12
- pertpy/tools/_distances/_distances.py +56 -48
- pertpy/tools/_enrichment.py +1 -3
- pertpy/tools/_milo.py +4 -12
- pertpy/tools/_mixscape.py +215 -127
- pertpy/tools/_perturbation_space/_simple.py +1 -3
- pertpy/tools/_scgen/_scgen.py +1 -3
- {pertpy-0.9.5.dist-info → pertpy-0.10.0.dist-info}/METADATA +2 -2
- {pertpy-0.9.5.dist-info → pertpy-0.10.0.dist-info}/RECORD +20 -19
- {pertpy-0.9.5.dist-info → pertpy-0.10.0.dist-info}/WHEEL +0 -0
- {pertpy-0.9.5.dist-info → pertpy-0.10.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_mixscape.py
CHANGED
@@ -41,9 +41,12 @@ class Mixscape:
|
|
41
41
|
adata: AnnData,
|
42
42
|
pert_key: str,
|
43
43
|
control: str,
|
44
|
+
*,
|
45
|
+
ref_selection_mode: Literal["nn", "split_by"] = "nn",
|
44
46
|
split_by: str | None = None,
|
45
47
|
n_neighbors: int = 20,
|
46
48
|
use_rep: str | None = None,
|
49
|
+
n_dims: int | None = 15,
|
47
50
|
n_pcs: int | None = None,
|
48
51
|
batch_size: int | None = None,
|
49
52
|
copy: bool = False,
|
@@ -51,14 +54,18 @@ class Mixscape:
|
|
51
54
|
):
|
52
55
|
"""Calculate perturbation signature.
|
53
56
|
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
+
The perturbation signature is calculated by subtracting the mRNA expression profile of each cell from the averaged
|
58
|
+
mRNA expression profile of the control cells (selected according to `ref_selection_mode`).
|
59
|
+
The implementation resembles https://satijalab.org/seurat/reference/runmixscape. Note that in the original implementation, the
|
60
|
+
perturbation signature is calculated on unscaled data by default, and we therefore recommend to do the same.
|
57
61
|
|
58
62
|
Args:
|
59
63
|
adata: The annotated data object.
|
60
64
|
pert_key: The column of `.obs` with perturbation categories, should also contain `control`.
|
61
|
-
control:
|
65
|
+
control: Name of the control category from the `pert_key` column.
|
66
|
+
ref_selection_mode: Method to select reference cells for the perturbation signature calculation. If `nn`,
|
67
|
+
the `n_neighbors` cells from the control pool with the most similar mRNA expression profiles are selected. If `split_by`,
|
68
|
+
the control cells from the same split in `split_by` (e.g. indicating biological replicates) are used to calculate the perturbation signature.
|
62
69
|
split_by: Provide the column `.obs` if multiple biological replicates exist to calculate
|
63
70
|
the perturbation signature for every replicate separately.
|
64
71
|
n_neighbors: Number of neighbors from the control to use for the perturbation signature.
|
@@ -66,7 +73,10 @@ class Mixscape:
|
|
66
73
|
If `None`, the representation is chosen automatically:
|
67
74
|
For `.n_vars` < 50, `.X` is used, otherwise 'X_pca' is used.
|
68
75
|
If 'X_pca' is not present, it’s computed with default parameters.
|
69
|
-
|
76
|
+
n_dims: Number of dimensions to use from the representation to calculate the perturbation signature.
|
77
|
+
If `None`, use all dimensions.
|
78
|
+
n_pcs: If PCA representation is used, the number of principal components to compute.
|
79
|
+
If `n_pcs==0` use `.X` if `use_rep is None`.
|
70
80
|
batch_size: Size of batch to calculate the perturbation signature.
|
71
81
|
If 'None', the perturbation signature is calcuated in the full mode, requiring more memory.
|
72
82
|
The batched mode is very inefficient for sparse data.
|
@@ -83,8 +93,13 @@ class Mixscape:
|
|
83
93
|
>>> import pertpy as pt
|
84
94
|
>>> mdata = pt.dt.papalexi_2021()
|
85
95
|
>>> ms_pt = pt.tl.Mixscape()
|
86
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
96
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
87
97
|
"""
|
98
|
+
if ref_selection_mode not in ["nn", "split_by"]:
|
99
|
+
raise ValueError("ref_selection_mode must be either 'nn' or 'split_by'.")
|
100
|
+
if ref_selection_mode == "split_by" and split_by is None:
|
101
|
+
raise ValueError("split_by must be provided if ref_selection_mode is 'split_by'.")
|
102
|
+
|
88
103
|
if copy:
|
89
104
|
adata = adata.copy()
|
90
105
|
|
@@ -92,59 +107,75 @@ class Mixscape:
|
|
92
107
|
|
93
108
|
control_mask = adata.obs[pert_key] == control
|
94
109
|
|
95
|
-
if
|
96
|
-
|
110
|
+
if ref_selection_mode == "split_by":
|
111
|
+
for split in adata.obs[split_by].unique():
|
112
|
+
split_mask = adata.obs[split_by] == split
|
113
|
+
control_mask_group = control_mask & split_mask
|
114
|
+
control_mean_expr = adata.X[control_mask_group].mean(0)
|
115
|
+
adata.layers["X_pert"][split_mask] = (
|
116
|
+
np.repeat(control_mean_expr.reshape(1, -1), split_mask.sum(), axis=0)
|
117
|
+
- adata.layers["X_pert"][split_mask]
|
118
|
+
)
|
97
119
|
else:
|
98
|
-
|
99
|
-
|
120
|
+
if split_by is None:
|
121
|
+
split_masks = [np.full(adata.n_obs, True, dtype=bool)]
|
122
|
+
else:
|
123
|
+
split_obs = adata.obs[split_by]
|
124
|
+
split_masks = [split_obs == cat for cat in split_obs.unique()]
|
100
125
|
|
101
|
-
|
126
|
+
representation = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs)
|
127
|
+
if n_dims is not None and n_dims < representation.shape[1]:
|
128
|
+
representation = representation[:, :n_dims]
|
102
129
|
|
103
|
-
|
104
|
-
|
130
|
+
for split_mask in split_masks:
|
131
|
+
control_mask_split = control_mask & split_mask
|
105
132
|
|
106
|
-
|
107
|
-
|
133
|
+
R_split = representation[split_mask]
|
134
|
+
R_control = representation[np.asarray(control_mask_split)]
|
108
135
|
|
109
|
-
|
136
|
+
from pynndescent import NNDescent
|
110
137
|
|
111
|
-
|
112
|
-
|
113
|
-
|
138
|
+
eps = kwargs.pop("epsilon", 0.1)
|
139
|
+
nn_index = NNDescent(R_control, **kwargs)
|
140
|
+
indices, _ = nn_index.query(R_split, k=n_neighbors, epsilon=eps)
|
114
141
|
|
115
|
-
|
142
|
+
X_control = np.expm1(adata.X[np.asarray(control_mask_split)])
|
116
143
|
|
117
|
-
|
118
|
-
|
144
|
+
n_split = split_mask.sum()
|
145
|
+
n_control = X_control.shape[0]
|
119
146
|
|
120
|
-
|
121
|
-
|
122
|
-
|
147
|
+
if batch_size is None:
|
148
|
+
col_indices = np.ravel(indices)
|
149
|
+
row_indices = np.repeat(np.arange(n_split), n_neighbors)
|
123
150
|
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
151
|
+
neigh_matrix = csr_matrix(
|
152
|
+
(np.ones_like(col_indices, dtype=np.float64), (row_indices, col_indices)),
|
153
|
+
shape=(n_split, n_control),
|
154
|
+
)
|
155
|
+
neigh_matrix /= n_neighbors
|
156
|
+
adata.layers["X_pert"][split_mask] = (
|
157
|
+
np.log1p(neigh_matrix @ X_control) - adata.layers["X_pert"][split_mask]
|
158
|
+
)
|
159
|
+
else:
|
160
|
+
is_sparse = issparse(X_control)
|
161
|
+
split_indices = np.where(split_mask)[0]
|
162
|
+
for i in range(0, n_split, batch_size):
|
163
|
+
size = min(i + batch_size, n_split)
|
164
|
+
select = slice(i, size)
|
136
165
|
|
137
|
-
|
138
|
-
|
166
|
+
batch = np.ravel(indices[select])
|
167
|
+
split_batch = split_indices[select]
|
139
168
|
|
140
|
-
|
169
|
+
size = size - i
|
141
170
|
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
171
|
+
# sparse is very slow
|
172
|
+
means_batch = X_control[batch]
|
173
|
+
means_batch = means_batch.toarray() if is_sparse else means_batch
|
174
|
+
means_batch = means_batch.reshape(size, n_neighbors, -1).mean(1)
|
146
175
|
|
147
|
-
|
176
|
+
adata.layers["X_pert"][split_batch] = (
|
177
|
+
np.log1p(means_batch) - adata.layers["X_pert"][split_batch]
|
178
|
+
)
|
148
179
|
|
149
180
|
if copy:
|
150
181
|
return adata
|
@@ -154,33 +185,41 @@ class Mixscape:
|
|
154
185
|
adata: AnnData,
|
155
186
|
labels: str,
|
156
187
|
control: str,
|
188
|
+
*,
|
157
189
|
new_class_name: str | None = "mixscape_class",
|
158
|
-
min_de_genes: int | None = 5,
|
159
190
|
layer: str | None = None,
|
191
|
+
min_de_genes: int | None = 5,
|
160
192
|
logfc_threshold: float | None = 0.25,
|
193
|
+
de_layer: str | None = None,
|
194
|
+
test_method: str | None = "wilcoxon",
|
161
195
|
iter_num: int | None = 10,
|
196
|
+
scale: bool | None = True,
|
162
197
|
split_by: str | None = None,
|
163
198
|
pval_cutoff: float | None = 5e-2,
|
164
199
|
perturbation_type: str | None = "KO",
|
200
|
+
random_state: int | None = 0,
|
165
201
|
copy: bool | None = False,
|
166
202
|
):
|
167
203
|
"""Identify perturbed and non-perturbed gRNA expressing cells that accounts for multiple treatments/conditions/chemical perturbations.
|
168
204
|
|
169
|
-
The implementation resembles https://satijalab.org/seurat/reference/runmixscape
|
205
|
+
The implementation resembles https://satijalab.org/seurat/reference/runmixscape.
|
170
206
|
|
171
207
|
Args:
|
172
208
|
adata: The annotated data object.
|
173
209
|
labels: The column of `.obs` with target gene labels.
|
174
|
-
control: Control category from the `
|
210
|
+
control: Control category from the `labels` column.
|
175
211
|
new_class_name: Name of mixscape classification to be stored in `.obs`.
|
176
|
-
min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
|
177
212
|
layer: Key from adata.layers whose value will be used to perform tests on. Default is using `.layers["X_pert"]`.
|
213
|
+
min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
|
178
214
|
logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells (default: 0.25).
|
215
|
+
de_layer: Layer to use for identifying differentially expressed genes. If `None`, adata.X is used.
|
216
|
+
test_method: Method to use for differential expression testing.
|
179
217
|
iter_num: Number of normalmixEM iterations to run if convergence does not occur.
|
180
|
-
|
181
|
-
|
218
|
+
scale: Scale the data specified in `layer` before running the GaussianMixture model on it.
|
219
|
+
split_by: Provide `.obs` column with experimental condition/cell type annotation, if perturbations are condition/cell type-specific.
|
182
220
|
pval_cutoff: P-value cut-off for selection of significantly DE genes.
|
183
221
|
perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications.
|
222
|
+
random_state: Random seed for the GaussianMixture model.
|
184
223
|
copy: Determines whether a copy of the `adata` is returned.
|
185
224
|
|
186
225
|
Returns:
|
@@ -203,8 +242,8 @@ class Mixscape:
|
|
203
242
|
>>> import pertpy as pt
|
204
243
|
>>> mdata = pt.dt.papalexi_2021()
|
205
244
|
>>> ms_pt = pt.tl.Mixscape()
|
206
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
207
|
-
>>> ms_pt.mixscape(
|
245
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
246
|
+
>>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
|
208
247
|
"""
|
209
248
|
if copy:
|
210
249
|
adata = adata.copy()
|
@@ -218,7 +257,16 @@ class Mixscape:
|
|
218
257
|
split_masks = [split_obs == category for category in categories]
|
219
258
|
|
220
259
|
perturbation_markers = self._get_perturbation_markers(
|
221
|
-
adata,
|
260
|
+
adata=adata,
|
261
|
+
split_masks=split_masks,
|
262
|
+
categories=categories,
|
263
|
+
labels=labels,
|
264
|
+
control=control,
|
265
|
+
layer=de_layer,
|
266
|
+
pval_cutoff=pval_cutoff,
|
267
|
+
min_de_genes=min_de_genes,
|
268
|
+
logfc_threshold=logfc_threshold,
|
269
|
+
test_method=test_method,
|
222
270
|
)
|
223
271
|
|
224
272
|
adata_comp = adata
|
@@ -231,6 +279,7 @@ class Mixscape:
|
|
231
279
|
raise KeyError(
|
232
280
|
"No 'X_pert' found in .layers! Please run perturbation_signature first to calculate perturbation signature!"
|
233
281
|
) from None
|
282
|
+
|
234
283
|
# initialize return variables
|
235
284
|
adata.obs[f"{new_class_name}_p_{perturbation_type.lower()}"] = 0
|
236
285
|
adata.obs[new_class_name] = adata.obs[labels].astype(str)
|
@@ -241,10 +290,12 @@ class Mixscape:
|
|
241
290
|
dtype=np.object_,
|
242
291
|
)
|
243
292
|
gv_list: dict[str, dict] = {}
|
293
|
+
|
294
|
+
adata.obs[f"{new_class_name}_p_{perturbation_type.lower()}"] = 0.0
|
244
295
|
for split, split_mask in enumerate(split_masks):
|
245
296
|
category = categories[split]
|
246
|
-
|
247
|
-
for gene in
|
297
|
+
gene_targets = list(set(adata[split_mask].obs[labels]).difference([control]))
|
298
|
+
for gene in gene_targets:
|
248
299
|
post_prob = 0
|
249
300
|
orig_guide_cells = (adata.obs[labels] == gene) & split_mask
|
250
301
|
orig_guide_cells_index = list(orig_guide_cells.index[orig_guide_cells])
|
@@ -253,28 +304,38 @@ class Mixscape:
|
|
253
304
|
|
254
305
|
if len(perturbation_markers[(category, gene)]) == 0:
|
255
306
|
adata.obs.loc[orig_guide_cells, new_class_name] = f"{gene} NP"
|
307
|
+
|
256
308
|
else:
|
257
309
|
de_genes = perturbation_markers[(category, gene)]
|
258
310
|
de_genes_indices = self._get_column_indices(adata, list(de_genes))
|
311
|
+
|
259
312
|
dat = X[np.asarray(all_cells)][:, de_genes_indices]
|
313
|
+
dat_cells = all_cells[all_cells].index
|
314
|
+
if scale:
|
315
|
+
dat = sc.pp.scale(dat)
|
316
|
+
|
260
317
|
converged = False
|
261
318
|
n_iter = 0
|
262
|
-
old_classes = adata.obs[
|
319
|
+
old_classes = adata.obs[new_class_name][all_cells]
|
320
|
+
|
263
321
|
while not converged and n_iter < iter_num:
|
264
322
|
# Get all cells in current split&Gene
|
265
|
-
guide_cells = (adata.obs[
|
323
|
+
guide_cells = (adata.obs[new_class_name] == gene) & split_mask
|
324
|
+
|
266
325
|
# get average value for each gene over all selected cells
|
267
326
|
# all cells in current split&Gene minus all NT cells in current split
|
268
327
|
# Each row is for each cell, each column is for each gene, get mean for each column
|
269
|
-
|
270
|
-
|
271
|
-
)
|
328
|
+
guide_cells_dat_idx = all_cells[all_cells].index.get_indexer(guide_cells[guide_cells].index)
|
329
|
+
nt_cells_dat_idx = all_cells[all_cells].index.get_indexer(nt_cells[nt_cells].index)
|
330
|
+
vec = np.mean(dat[guide_cells_dat_idx], axis=0) - np.mean(dat[nt_cells_dat_idx], axis=0)
|
331
|
+
|
272
332
|
# project cells onto the perturbation vector
|
273
333
|
if isinstance(dat, spmatrix):
|
274
|
-
pvec = np.
|
334
|
+
pvec = np.dot(dat.toarray(), vec) / np.dot(vec, vec)
|
275
335
|
else:
|
276
|
-
pvec = np.
|
336
|
+
pvec = np.dot(dat, vec) / np.dot(vec, vec)
|
277
337
|
pvec = pd.Series(np.asarray(pvec).flatten(), index=list(all_cells.index[all_cells]))
|
338
|
+
|
278
339
|
if n_iter == 0:
|
279
340
|
gv = pd.DataFrame(columns=["pvec", labels])
|
280
341
|
gv["pvec"] = pvec
|
@@ -284,19 +345,22 @@ class Mixscape:
|
|
284
345
|
gv_list[gene] = {}
|
285
346
|
gv_list[gene][category] = gv
|
286
347
|
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
precisions_init = np.array([nt_norm[1], guide_norm[1]])
|
291
|
-
mm = GaussianMixture(
|
348
|
+
means_init = np.array([[pvec[nt_cells].mean()], [pvec[guide_cells].mean()]])
|
349
|
+
std_init = np.array([pvec[nt_cells].std(), pvec[guide_cells].std()])
|
350
|
+
mm = MixscapeGaussianMixture(
|
292
351
|
n_components=2,
|
293
352
|
covariance_type="spherical",
|
294
353
|
means_init=means_init,
|
295
|
-
precisions_init=
|
354
|
+
precisions_init=1 / (std_init ** 2),
|
355
|
+
random_state=random_state,
|
356
|
+
max_iter=5000,
|
357
|
+
fixed_means=[pvec[nt_cells].mean(), None],
|
358
|
+
fixed_covariances=[pvec[nt_cells].std() ** 2, None],
|
296
359
|
).fit(np.asarray(pvec).reshape(-1, 1))
|
297
360
|
probabilities = mm.predict_proba(np.array(pvec[orig_guide_cells_index]).reshape(-1, 1))
|
298
361
|
lik_ratio = probabilities[:, 0] / probabilities[:, 1]
|
299
362
|
post_prob = 1 / (1 + lik_ratio)
|
363
|
+
|
300
364
|
# based on the posterior probability, assign cells to the two classes
|
301
365
|
adata.obs.loc[
|
302
366
|
[orig_guide_cells_index[cell] for cell in np.where(post_prob > 0.5)[0]], new_class_name
|
@@ -304,11 +368,13 @@ class Mixscape:
|
|
304
368
|
adata.obs.loc[
|
305
369
|
[orig_guide_cells_index[cell] for cell in np.where(post_prob <= 0.5)[0]], new_class_name
|
306
370
|
] = f"{gene} NP"
|
371
|
+
|
307
372
|
if sum(adata.obs[new_class_name][split_mask] == gene) < min_de_genes:
|
308
373
|
adata.obs.loc[guide_cells, new_class_name] = "NP"
|
309
374
|
converged = True
|
310
375
|
if adata.obs[new_class_name][all_cells].equals(old_classes):
|
311
376
|
converged = True
|
377
|
+
|
312
378
|
old_classes = adata.obs[new_class_name][all_cells]
|
313
379
|
n_iter += 1
|
314
380
|
|
@@ -317,9 +383,7 @@ class Mixscape:
|
|
317
383
|
)
|
318
384
|
|
319
385
|
adata.obs[f"{new_class_name}_global"] = [a.split(" ")[-1] for a in adata.obs[new_class_name]]
|
320
|
-
adata.obs.loc[orig_guide_cells_index, f"{new_class_name}_p_{perturbation_type.lower()}"] =
|
321
|
-
post_prob
|
322
|
-
).astype("int64")
|
386
|
+
adata.obs.loc[orig_guide_cells_index, f"{new_class_name}_p_{perturbation_type.lower()}"] = post_prob
|
323
387
|
adata.uns["mixscape"] = gv_list
|
324
388
|
|
325
389
|
if copy:
|
@@ -330,11 +394,13 @@ class Mixscape:
|
|
330
394
|
adata: AnnData,
|
331
395
|
labels: str,
|
332
396
|
control: str,
|
397
|
+
*,
|
333
398
|
mixscape_class_global: str | None = "mixscape_class_global",
|
334
399
|
layer: str | None = None,
|
335
400
|
n_comps: int | None = 10,
|
336
401
|
min_de_genes: int | None = 5,
|
337
402
|
logfc_threshold: float | None = 0.25,
|
403
|
+
test_method: str | None = "wilcoxon",
|
338
404
|
split_by: str | None = None,
|
339
405
|
pval_cutoff: float | None = 5e-2,
|
340
406
|
perturbation_type: str | None = "KO",
|
@@ -347,12 +413,13 @@ class Mixscape:
|
|
347
413
|
labels: The column of `.obs` with target gene labels.
|
348
414
|
control: Control category from the `pert_key` column.
|
349
415
|
mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
|
350
|
-
layer:
|
416
|
+
layer: Layer to use for identifying differentially expressed genes. If `None`, adata.X is used.
|
351
417
|
control: Control category from the `pert_key` column.
|
352
418
|
n_comps: Number of principal components to use.
|
353
419
|
min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
|
354
420
|
logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells.
|
355
|
-
|
421
|
+
test_method: Method to use for differential expression testing.
|
422
|
+
split_by: Provide `.obs` column with experimental condition/cell type annotation, if perturbations are condition/cell type-specific.
|
356
423
|
pval_cutoff: P-value cut-off for selection of significantly DE genes.
|
357
424
|
perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
|
358
425
|
copy: Determines whether a copy of the `adata` is returned.
|
@@ -370,9 +437,9 @@ class Mixscape:
|
|
370
437
|
>>> import pertpy as pt
|
371
438
|
>>> mdata = pt.dt.papalexi_2021()
|
372
439
|
>>> ms_pt = pt.tl.Mixscape()
|
373
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
374
|
-
>>> ms_pt.mixscape(
|
375
|
-
>>> ms_pt.lda(
|
440
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
441
|
+
>>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
|
442
|
+
>>> ms_pt.lda(mdata["rna"], "gene_target", "NT")
|
376
443
|
"""
|
377
444
|
if copy:
|
378
445
|
adata = adata.copy()
|
@@ -388,9 +455,8 @@ class Mixscape:
|
|
388
455
|
categories = split_obs.unique()
|
389
456
|
split_masks = [split_obs == category for category in categories]
|
390
457
|
|
391
|
-
mixscape_identifier = pt.tl.Mixscape()
|
392
458
|
# determine gene sets across all splits/groups through differential gene expression
|
393
|
-
perturbation_markers =
|
459
|
+
perturbation_markers = self._get_perturbation_markers(
|
394
460
|
adata=adata,
|
395
461
|
split_masks=split_masks,
|
396
462
|
categories=categories,
|
@@ -400,6 +466,7 @@ class Mixscape:
|
|
400
466
|
pval_cutoff=pval_cutoff,
|
401
467
|
min_de_genes=min_de_genes,
|
402
468
|
logfc_threshold=logfc_threshold,
|
469
|
+
test_method=test_method,
|
403
470
|
)
|
404
471
|
adata_subset = adata[
|
405
472
|
(adata.obs[mixscape_class_global] == perturbation_type) | (adata.obs[mixscape_class_global] == control)
|
@@ -445,12 +512,21 @@ class Mixscape:
|
|
445
512
|
pval_cutoff: float,
|
446
513
|
min_de_genes: float,
|
447
514
|
logfc_threshold: float,
|
515
|
+
test_method: str,
|
448
516
|
) -> dict[tuple, np.ndarray]:
|
449
517
|
"""Determine gene sets across all splits/groups through differential gene expression
|
450
518
|
|
451
519
|
Args:
|
452
520
|
adata: :class:`~anndata.AnnData` object
|
453
|
-
|
521
|
+
split_masks: List of boolean masks for each split/group.
|
522
|
+
categories: List of split/group names.
|
523
|
+
labels: The column of `.obs` with target gene labels.
|
524
|
+
control: Control category from the `labels` column.
|
525
|
+
layer: Key from adata.layers whose value will be used to compare gene expression.
|
526
|
+
pval_cutoff: P-value cut-off for selection of significantly DE genes.
|
527
|
+
min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
|
528
|
+
logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells.
|
529
|
+
test_method: Method to use for differential expression testing.
|
454
530
|
|
455
531
|
Returns:
|
456
532
|
Set of column indices.
|
@@ -459,21 +535,21 @@ class Mixscape:
|
|
459
535
|
for split, split_mask in enumerate(split_masks):
|
460
536
|
category = categories[split]
|
461
537
|
# get gene sets for each split
|
462
|
-
|
538
|
+
gene_targets = list(set(adata[split_mask].obs[labels]).difference([control]))
|
463
539
|
adata_split = adata[split_mask].copy()
|
464
540
|
# find top DE genes between cells with targeting and non-targeting gRNAs
|
465
541
|
sc.tl.rank_genes_groups(
|
466
542
|
adata_split,
|
467
543
|
layer=layer,
|
468
544
|
groupby=labels,
|
469
|
-
groups=
|
545
|
+
groups=gene_targets,
|
470
546
|
reference=control,
|
471
|
-
method=
|
547
|
+
method=test_method,
|
472
548
|
use_raw=False,
|
473
549
|
)
|
474
|
-
# get DE genes for each gene
|
475
|
-
for gene in
|
476
|
-
logfc_threshold_mask = adata_split.uns["rank_genes_groups"]["logfoldchanges"][gene] >= logfc_threshold
|
550
|
+
# get DE genes for each target gene
|
551
|
+
for gene in gene_targets:
|
552
|
+
logfc_threshold_mask = np.abs(adata_split.uns["rank_genes_groups"]["logfoldchanges"][gene]) >= logfc_threshold
|
477
553
|
de_genes = adata_split.uns["rank_genes_groups"]["names"][gene][logfc_threshold_mask]
|
478
554
|
pvals_adj = adata_split.uns["rank_genes_groups"]["pvals_adj"][gene][logfc_threshold_mask]
|
479
555
|
de_genes = de_genes[pvals_adj < pval_cutoff]
|
@@ -494,20 +570,6 @@ class Mixscape:
|
|
494
570
|
|
495
571
|
return indices
|
496
572
|
|
497
|
-
def _define_normal_mixscape(self, X: np.ndarray | sparse.spmatrix | pd.DataFrame | None) -> list[float]:
|
498
|
-
"""Calculates the mean and standard deviation of a matrix.
|
499
|
-
|
500
|
-
Args:
|
501
|
-
X: The matrix to calculate the properties for.
|
502
|
-
|
503
|
-
Returns:
|
504
|
-
Mean and standard deviation of the matrix.
|
505
|
-
"""
|
506
|
-
mu = X.mean()
|
507
|
-
sd = X.std()
|
508
|
-
|
509
|
-
return [mu, sd]
|
510
|
-
|
511
573
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
512
574
|
def plot_barplot( # pragma: no cover
|
513
575
|
self,
|
@@ -522,7 +584,6 @@ class Mixscape:
|
|
522
584
|
legend_text_size: int = 8,
|
523
585
|
legend_bbox_to_anchor: tuple[float, float] = None,
|
524
586
|
figsize: tuple[float, float] = (25, 25),
|
525
|
-
show: bool = True,
|
526
587
|
return_fig: bool = False,
|
527
588
|
) -> Figure | None:
|
528
589
|
"""Barplot to visualize perturbation scores calculated by the `mixscape` function.
|
@@ -548,8 +609,8 @@ class Mixscape:
|
|
548
609
|
>>> import pertpy as pt
|
549
610
|
>>> mdata = pt.dt.papalexi_2021()
|
550
611
|
>>> ms_pt = pt.tl.Mixscape()
|
551
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
552
|
-
>>> ms_pt.mixscape(
|
612
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
613
|
+
>>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
|
553
614
|
>>> ms_pt.plot_barplot(mdata["rna"], guide_rna_column="NT")
|
554
615
|
|
555
616
|
Preview:
|
@@ -611,10 +672,9 @@ class Mixscape:
|
|
611
672
|
fig.subplots_adjust(hspace=0.5, wspace=0.5)
|
612
673
|
plt.tight_layout()
|
613
674
|
|
614
|
-
if show:
|
615
|
-
plt.show()
|
616
675
|
if return_fig:
|
617
676
|
return fig
|
677
|
+
plt.show()
|
618
678
|
return None
|
619
679
|
|
620
680
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
@@ -630,7 +690,6 @@ class Mixscape:
|
|
630
690
|
subsample_number: int | None = 900,
|
631
691
|
vmin: float | None = -2,
|
632
692
|
vmax: float | None = 2,
|
633
|
-
show: bool = True,
|
634
693
|
return_fig: bool = False,
|
635
694
|
**kwds,
|
636
695
|
) -> Figure | None:
|
@@ -656,8 +715,8 @@ class Mixscape:
|
|
656
715
|
>>> import pertpy as pt
|
657
716
|
>>> mdata = pt.dt.papalexi_2021()
|
658
717
|
>>> ms_pt = pt.tl.Mixscape()
|
659
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
660
|
-
>>> ms_pt.mixscape(
|
718
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
719
|
+
>>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
|
661
720
|
>>> ms_pt.plot_heatmap(
|
662
721
|
... adata=mdata["rna"], labels="gene_target", target_gene="IFNGR2", layer="X_pert", control="NT"
|
663
722
|
... )
|
@@ -683,10 +742,9 @@ class Mixscape:
|
|
683
742
|
**kwds,
|
684
743
|
)
|
685
744
|
|
686
|
-
if show:
|
687
|
-
plt.show()
|
688
745
|
if return_fig:
|
689
746
|
return fig
|
747
|
+
plt.show()
|
690
748
|
return None
|
691
749
|
|
692
750
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
@@ -702,7 +760,6 @@ class Mixscape:
|
|
702
760
|
split_by: str = None,
|
703
761
|
before_mixscape: bool = False,
|
704
762
|
perturbation_type: str = "KO",
|
705
|
-
show: bool = True,
|
706
763
|
return_fig: bool = False,
|
707
764
|
) -> Figure | None:
|
708
765
|
"""Density plots to visualize perturbation scores calculated by the `pt.tl.mixscape` function.
|
@@ -734,8 +791,8 @@ class Mixscape:
|
|
734
791
|
>>> import pertpy as pt
|
735
792
|
>>> mdata = pt.dt.papalexi_2021()
|
736
793
|
>>> ms_pt = pt.tl.Mixscape()
|
737
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
738
|
-
>>> ms_pt.mixscape(
|
794
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
795
|
+
>>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
|
739
796
|
>>> ms_pt.plot_perturbscore(adata=mdata["rna"], labels="gene_target", target_gene="IFNGR2", color="orange")
|
740
797
|
|
741
798
|
Preview:
|
@@ -851,10 +908,9 @@ class Mixscape:
|
|
851
908
|
plt.legend(title="mixscape class", title_fontsize=14, fontsize=12)
|
852
909
|
sns.despine()
|
853
910
|
|
854
|
-
if show:
|
855
|
-
plt.show()
|
856
911
|
if return_fig:
|
857
912
|
return plt.gcf()
|
913
|
+
plt.show()
|
858
914
|
return None
|
859
915
|
|
860
916
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
@@ -879,7 +935,6 @@ class Mixscape:
|
|
879
935
|
ylabel: str | Sequence[str] | None = None,
|
880
936
|
rotation: float | None = None,
|
881
937
|
ax: Axes | None = None,
|
882
|
-
show: bool = True,
|
883
938
|
return_fig: bool = False,
|
884
939
|
**kwargs,
|
885
940
|
) -> Axes | Figure | None:
|
@@ -910,8 +965,8 @@ class Mixscape:
|
|
910
965
|
>>> import pertpy as pt
|
911
966
|
>>> mdata = pt.dt.papalexi_2021()
|
912
967
|
>>> ms_pt = pt.tl.Mixscape()
|
913
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
914
|
-
>>> ms_pt.mixscape(
|
968
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
969
|
+
>>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
|
915
970
|
>>> ms_pt.plot_violin(
|
916
971
|
... adata=mdata["rna"], target_gene_idents=["NT", "IFNGR2 NP", "IFNGR2 KO"], groupby="mixscape_class"
|
917
972
|
... )
|
@@ -1047,12 +1102,9 @@ class Mixscape:
|
|
1047
1102
|
if rotation is not None:
|
1048
1103
|
ax.tick_params(axis="x", labelrotation=rotation)
|
1049
1104
|
|
1050
|
-
show = settings.autoshow if show is None else show
|
1051
1105
|
if hue is not None and stripplot is True:
|
1052
1106
|
plt.legend(handles, labels)
|
1053
1107
|
|
1054
|
-
if show:
|
1055
|
-
plt.show()
|
1056
1108
|
if return_fig:
|
1057
1109
|
if multi_panel and groupby is None and len(ys) == 1:
|
1058
1110
|
return g
|
@@ -1060,6 +1112,7 @@ class Mixscape:
|
|
1060
1112
|
return axs[0]
|
1061
1113
|
else:
|
1062
1114
|
return axs
|
1115
|
+
plt.show()
|
1063
1116
|
return None
|
1064
1117
|
|
1065
1118
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
@@ -1076,7 +1129,6 @@ class Mixscape:
|
|
1076
1129
|
color_map: Colormap | str | None = None,
|
1077
1130
|
palette: str | Sequence[str] | None = None,
|
1078
1131
|
ax: Axes | None = None,
|
1079
|
-
show: bool = True,
|
1080
1132
|
return_fig: bool = False,
|
1081
1133
|
**kwds,
|
1082
1134
|
) -> Figure | None:
|
@@ -1097,9 +1149,9 @@ class Mixscape:
|
|
1097
1149
|
>>> import pertpy as pt
|
1098
1150
|
>>> mdata = pt.dt.papalexi_2021()
|
1099
1151
|
>>> ms_pt = pt.tl.Mixscape()
|
1100
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
1101
|
-
>>> ms_pt.mixscape(
|
1102
|
-
>>> ms_pt.lda(
|
1152
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
1153
|
+
>>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
|
1154
|
+
>>> ms_pt.lda(mdata["rna"], "gene_target", "NT", split_by="replicate")
|
1103
1155
|
>>> ms_pt.plot_lda(adata=mdata["rna"], control="NT")
|
1104
1156
|
|
1105
1157
|
Preview:
|
@@ -1129,8 +1181,44 @@ class Mixscape:
|
|
1129
1181
|
**kwds,
|
1130
1182
|
)
|
1131
1183
|
|
1132
|
-
if show:
|
1133
|
-
plt.show()
|
1134
1184
|
if return_fig:
|
1135
1185
|
return fig
|
1186
|
+
plt.show()
|
1136
1187
|
return None
|
1188
|
+
|
1189
|
+
class MixscapeGaussianMixture(GaussianMixture):
|
1190
|
+
def __init__(
|
1191
|
+
self,
|
1192
|
+
n_components: int,
|
1193
|
+
fixed_means: Sequence[float] | None = None,
|
1194
|
+
fixed_covariances: Sequence[float] | None = None,
|
1195
|
+
**kwargs
|
1196
|
+
):
|
1197
|
+
"""Custom Gaussian Mixture Model where means and covariances can be fixed for specific components.
|
1198
|
+
|
1199
|
+
Args:
|
1200
|
+
n_components: Number of Gaussian components
|
1201
|
+
fixed_means: Means to fix (use None for those that should be estimated)
|
1202
|
+
fixed_covariances: Covariances to fix (use None for those that should be estimated)
|
1203
|
+
kwargs: Additional arguments passed to scikit-learn's GaussianMixture
|
1204
|
+
"""
|
1205
|
+
super().__init__(n_components=n_components, **kwargs)
|
1206
|
+
self.fixed_means = fixed_means
|
1207
|
+
self.fixed_covariances = fixed_covariances
|
1208
|
+
|
1209
|
+
def _m_step(self, X: np.ndarray, log_resp: np.ndarray):
|
1210
|
+
"""Modified M-step to respect fixed means and covariances."""
|
1211
|
+
super()._m_step(X, log_resp)
|
1212
|
+
|
1213
|
+
if self.fixed_means is not None:
|
1214
|
+
for i in range(self.n_components):
|
1215
|
+
if self.fixed_means[i] is not None:
|
1216
|
+
self.means_[i] = self.fixed_means[i]
|
1217
|
+
|
1218
|
+
if self.fixed_covariances is not None:
|
1219
|
+
for i in range(self.n_components):
|
1220
|
+
if self.fixed_covariances[i] is not None:
|
1221
|
+
self.covariances_[i] = self.fixed_covariances[i]
|
1222
|
+
|
1223
|
+
return self
|
1224
|
+
|