pertpy 0.9.5__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +1 -1
- pertpy/_doc.py +1 -2
- pertpy/metadata/_cell_line.py +3 -5
- pertpy/preprocessing/_guide_rna.py +98 -10
- pertpy/preprocessing/_guide_rna_mixture.py +179 -0
- pertpy/tools/_augur.py +32 -44
- pertpy/tools/_cinemaot.py +1 -3
- pertpy/tools/_coda/_base_coda.py +21 -29
- pertpy/tools/_dialogue.py +17 -21
- pertpy/tools/_differential_gene_expression/_base.py +4 -12
- pertpy/tools/_distances/_distances.py +56 -48
- pertpy/tools/_enrichment.py +1 -3
- pertpy/tools/_milo.py +4 -12
- pertpy/tools/_mixscape.py +215 -127
- pertpy/tools/_perturbation_space/_simple.py +1 -3
- pertpy/tools/_scgen/_scgen.py +1 -3
- {pertpy-0.9.5.dist-info → pertpy-0.10.0.dist-info}/METADATA +2 -2
- {pertpy-0.9.5.dist-info → pertpy-0.10.0.dist-info}/RECORD +20 -19
- {pertpy-0.9.5.dist-info → pertpy-0.10.0.dist-info}/WHEEL +0 -0
- {pertpy-0.9.5.dist-info → pertpy-0.10.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_mixscape.py
CHANGED
@@ -41,9 +41,12 @@ class Mixscape:
|
|
41
41
|
adata: AnnData,
|
42
42
|
pert_key: str,
|
43
43
|
control: str,
|
44
|
+
*,
|
45
|
+
ref_selection_mode: Literal["nn", "split_by"] = "nn",
|
44
46
|
split_by: str | None = None,
|
45
47
|
n_neighbors: int = 20,
|
46
48
|
use_rep: str | None = None,
|
49
|
+
n_dims: int | None = 15,
|
47
50
|
n_pcs: int | None = None,
|
48
51
|
batch_size: int | None = None,
|
49
52
|
copy: bool = False,
|
@@ -51,14 +54,18 @@ class Mixscape:
|
|
51
54
|
):
|
52
55
|
"""Calculate perturbation signature.
|
53
56
|
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
+
The perturbation signature is calculated by subtracting the mRNA expression profile of each cell from the averaged
|
58
|
+
mRNA expression profile of the control cells (selected according to `ref_selection_mode`).
|
59
|
+
The implementation resembles https://satijalab.org/seurat/reference/runmixscape. Note that in the original implementation, the
|
60
|
+
perturbation signature is calculated on unscaled data by default, and we therefore recommend to do the same.
|
57
61
|
|
58
62
|
Args:
|
59
63
|
adata: The annotated data object.
|
60
64
|
pert_key: The column of `.obs` with perturbation categories, should also contain `control`.
|
61
|
-
control:
|
65
|
+
control: Name of the control category from the `pert_key` column.
|
66
|
+
ref_selection_mode: Method to select reference cells for the perturbation signature calculation. If `nn`,
|
67
|
+
the `n_neighbors` cells from the control pool with the most similar mRNA expression profiles are selected. If `split_by`,
|
68
|
+
the control cells from the same split in `split_by` (e.g. indicating biological replicates) are used to calculate the perturbation signature.
|
62
69
|
split_by: Provide the column `.obs` if multiple biological replicates exist to calculate
|
63
70
|
the perturbation signature for every replicate separately.
|
64
71
|
n_neighbors: Number of neighbors from the control to use for the perturbation signature.
|
@@ -66,7 +73,10 @@ class Mixscape:
|
|
66
73
|
If `None`, the representation is chosen automatically:
|
67
74
|
For `.n_vars` < 50, `.X` is used, otherwise 'X_pca' is used.
|
68
75
|
If 'X_pca' is not present, it’s computed with default parameters.
|
69
|
-
|
76
|
+
n_dims: Number of dimensions to use from the representation to calculate the perturbation signature.
|
77
|
+
If `None`, use all dimensions.
|
78
|
+
n_pcs: If PCA representation is used, the number of principal components to compute.
|
79
|
+
If `n_pcs==0` use `.X` if `use_rep is None`.
|
70
80
|
batch_size: Size of batch to calculate the perturbation signature.
|
71
81
|
If 'None', the perturbation signature is calcuated in the full mode, requiring more memory.
|
72
82
|
The batched mode is very inefficient for sparse data.
|
@@ -83,8 +93,13 @@ class Mixscape:
|
|
83
93
|
>>> import pertpy as pt
|
84
94
|
>>> mdata = pt.dt.papalexi_2021()
|
85
95
|
>>> ms_pt = pt.tl.Mixscape()
|
86
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
96
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
87
97
|
"""
|
98
|
+
if ref_selection_mode not in ["nn", "split_by"]:
|
99
|
+
raise ValueError("ref_selection_mode must be either 'nn' or 'split_by'.")
|
100
|
+
if ref_selection_mode == "split_by" and split_by is None:
|
101
|
+
raise ValueError("split_by must be provided if ref_selection_mode is 'split_by'.")
|
102
|
+
|
88
103
|
if copy:
|
89
104
|
adata = adata.copy()
|
90
105
|
|
@@ -92,59 +107,75 @@ class Mixscape:
|
|
92
107
|
|
93
108
|
control_mask = adata.obs[pert_key] == control
|
94
109
|
|
95
|
-
if
|
96
|
-
|
110
|
+
if ref_selection_mode == "split_by":
|
111
|
+
for split in adata.obs[split_by].unique():
|
112
|
+
split_mask = adata.obs[split_by] == split
|
113
|
+
control_mask_group = control_mask & split_mask
|
114
|
+
control_mean_expr = adata.X[control_mask_group].mean(0)
|
115
|
+
adata.layers["X_pert"][split_mask] = (
|
116
|
+
np.repeat(control_mean_expr.reshape(1, -1), split_mask.sum(), axis=0)
|
117
|
+
- adata.layers["X_pert"][split_mask]
|
118
|
+
)
|
97
119
|
else:
|
98
|
-
|
99
|
-
|
120
|
+
if split_by is None:
|
121
|
+
split_masks = [np.full(adata.n_obs, True, dtype=bool)]
|
122
|
+
else:
|
123
|
+
split_obs = adata.obs[split_by]
|
124
|
+
split_masks = [split_obs == cat for cat in split_obs.unique()]
|
100
125
|
|
101
|
-
|
126
|
+
representation = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs)
|
127
|
+
if n_dims is not None and n_dims < representation.shape[1]:
|
128
|
+
representation = representation[:, :n_dims]
|
102
129
|
|
103
|
-
|
104
|
-
|
130
|
+
for split_mask in split_masks:
|
131
|
+
control_mask_split = control_mask & split_mask
|
105
132
|
|
106
|
-
|
107
|
-
|
133
|
+
R_split = representation[split_mask]
|
134
|
+
R_control = representation[np.asarray(control_mask_split)]
|
108
135
|
|
109
|
-
|
136
|
+
from pynndescent import NNDescent
|
110
137
|
|
111
|
-
|
112
|
-
|
113
|
-
|
138
|
+
eps = kwargs.pop("epsilon", 0.1)
|
139
|
+
nn_index = NNDescent(R_control, **kwargs)
|
140
|
+
indices, _ = nn_index.query(R_split, k=n_neighbors, epsilon=eps)
|
114
141
|
|
115
|
-
|
142
|
+
X_control = np.expm1(adata.X[np.asarray(control_mask_split)])
|
116
143
|
|
117
|
-
|
118
|
-
|
144
|
+
n_split = split_mask.sum()
|
145
|
+
n_control = X_control.shape[0]
|
119
146
|
|
120
|
-
|
121
|
-
|
122
|
-
|
147
|
+
if batch_size is None:
|
148
|
+
col_indices = np.ravel(indices)
|
149
|
+
row_indices = np.repeat(np.arange(n_split), n_neighbors)
|
123
150
|
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
151
|
+
neigh_matrix = csr_matrix(
|
152
|
+
(np.ones_like(col_indices, dtype=np.float64), (row_indices, col_indices)),
|
153
|
+
shape=(n_split, n_control),
|
154
|
+
)
|
155
|
+
neigh_matrix /= n_neighbors
|
156
|
+
adata.layers["X_pert"][split_mask] = (
|
157
|
+
np.log1p(neigh_matrix @ X_control) - adata.layers["X_pert"][split_mask]
|
158
|
+
)
|
159
|
+
else:
|
160
|
+
is_sparse = issparse(X_control)
|
161
|
+
split_indices = np.where(split_mask)[0]
|
162
|
+
for i in range(0, n_split, batch_size):
|
163
|
+
size = min(i + batch_size, n_split)
|
164
|
+
select = slice(i, size)
|
136
165
|
|
137
|
-
|
138
|
-
|
166
|
+
batch = np.ravel(indices[select])
|
167
|
+
split_batch = split_indices[select]
|
139
168
|
|
140
|
-
|
169
|
+
size = size - i
|
141
170
|
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
171
|
+
# sparse is very slow
|
172
|
+
means_batch = X_control[batch]
|
173
|
+
means_batch = means_batch.toarray() if is_sparse else means_batch
|
174
|
+
means_batch = means_batch.reshape(size, n_neighbors, -1).mean(1)
|
146
175
|
|
147
|
-
|
176
|
+
adata.layers["X_pert"][split_batch] = (
|
177
|
+
np.log1p(means_batch) - adata.layers["X_pert"][split_batch]
|
178
|
+
)
|
148
179
|
|
149
180
|
if copy:
|
150
181
|
return adata
|
@@ -154,33 +185,41 @@ class Mixscape:
|
|
154
185
|
adata: AnnData,
|
155
186
|
labels: str,
|
156
187
|
control: str,
|
188
|
+
*,
|
157
189
|
new_class_name: str | None = "mixscape_class",
|
158
|
-
min_de_genes: int | None = 5,
|
159
190
|
layer: str | None = None,
|
191
|
+
min_de_genes: int | None = 5,
|
160
192
|
logfc_threshold: float | None = 0.25,
|
193
|
+
de_layer: str | None = None,
|
194
|
+
test_method: str | None = "wilcoxon",
|
161
195
|
iter_num: int | None = 10,
|
196
|
+
scale: bool | None = True,
|
162
197
|
split_by: str | None = None,
|
163
198
|
pval_cutoff: float | None = 5e-2,
|
164
199
|
perturbation_type: str | None = "KO",
|
200
|
+
random_state: int | None = 0,
|
165
201
|
copy: bool | None = False,
|
166
202
|
):
|
167
203
|
"""Identify perturbed and non-perturbed gRNA expressing cells that accounts for multiple treatments/conditions/chemical perturbations.
|
168
204
|
|
169
|
-
The implementation resembles https://satijalab.org/seurat/reference/runmixscape
|
205
|
+
The implementation resembles https://satijalab.org/seurat/reference/runmixscape.
|
170
206
|
|
171
207
|
Args:
|
172
208
|
adata: The annotated data object.
|
173
209
|
labels: The column of `.obs` with target gene labels.
|
174
|
-
control: Control category from the `
|
210
|
+
control: Control category from the `labels` column.
|
175
211
|
new_class_name: Name of mixscape classification to be stored in `.obs`.
|
176
|
-
min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
|
177
212
|
layer: Key from adata.layers whose value will be used to perform tests on. Default is using `.layers["X_pert"]`.
|
213
|
+
min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
|
178
214
|
logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells (default: 0.25).
|
215
|
+
de_layer: Layer to use for identifying differentially expressed genes. If `None`, adata.X is used.
|
216
|
+
test_method: Method to use for differential expression testing.
|
179
217
|
iter_num: Number of normalmixEM iterations to run if convergence does not occur.
|
180
|
-
|
181
|
-
|
218
|
+
scale: Scale the data specified in `layer` before running the GaussianMixture model on it.
|
219
|
+
split_by: Provide `.obs` column with experimental condition/cell type annotation, if perturbations are condition/cell type-specific.
|
182
220
|
pval_cutoff: P-value cut-off for selection of significantly DE genes.
|
183
221
|
perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications.
|
222
|
+
random_state: Random seed for the GaussianMixture model.
|
184
223
|
copy: Determines whether a copy of the `adata` is returned.
|
185
224
|
|
186
225
|
Returns:
|
@@ -203,8 +242,8 @@ class Mixscape:
|
|
203
242
|
>>> import pertpy as pt
|
204
243
|
>>> mdata = pt.dt.papalexi_2021()
|
205
244
|
>>> ms_pt = pt.tl.Mixscape()
|
206
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
207
|
-
>>> ms_pt.mixscape(
|
245
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
246
|
+
>>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
|
208
247
|
"""
|
209
248
|
if copy:
|
210
249
|
adata = adata.copy()
|
@@ -218,7 +257,16 @@ class Mixscape:
|
|
218
257
|
split_masks = [split_obs == category for category in categories]
|
219
258
|
|
220
259
|
perturbation_markers = self._get_perturbation_markers(
|
221
|
-
adata,
|
260
|
+
adata=adata,
|
261
|
+
split_masks=split_masks,
|
262
|
+
categories=categories,
|
263
|
+
labels=labels,
|
264
|
+
control=control,
|
265
|
+
layer=de_layer,
|
266
|
+
pval_cutoff=pval_cutoff,
|
267
|
+
min_de_genes=min_de_genes,
|
268
|
+
logfc_threshold=logfc_threshold,
|
269
|
+
test_method=test_method,
|
222
270
|
)
|
223
271
|
|
224
272
|
adata_comp = adata
|
@@ -231,6 +279,7 @@ class Mixscape:
|
|
231
279
|
raise KeyError(
|
232
280
|
"No 'X_pert' found in .layers! Please run perturbation_signature first to calculate perturbation signature!"
|
233
281
|
) from None
|
282
|
+
|
234
283
|
# initialize return variables
|
235
284
|
adata.obs[f"{new_class_name}_p_{perturbation_type.lower()}"] = 0
|
236
285
|
adata.obs[new_class_name] = adata.obs[labels].astype(str)
|
@@ -241,10 +290,12 @@ class Mixscape:
|
|
241
290
|
dtype=np.object_,
|
242
291
|
)
|
243
292
|
gv_list: dict[str, dict] = {}
|
293
|
+
|
294
|
+
adata.obs[f"{new_class_name}_p_{perturbation_type.lower()}"] = 0.0
|
244
295
|
for split, split_mask in enumerate(split_masks):
|
245
296
|
category = categories[split]
|
246
|
-
|
247
|
-
for gene in
|
297
|
+
gene_targets = list(set(adata[split_mask].obs[labels]).difference([control]))
|
298
|
+
for gene in gene_targets:
|
248
299
|
post_prob = 0
|
249
300
|
orig_guide_cells = (adata.obs[labels] == gene) & split_mask
|
250
301
|
orig_guide_cells_index = list(orig_guide_cells.index[orig_guide_cells])
|
@@ -253,28 +304,38 @@ class Mixscape:
|
|
253
304
|
|
254
305
|
if len(perturbation_markers[(category, gene)]) == 0:
|
255
306
|
adata.obs.loc[orig_guide_cells, new_class_name] = f"{gene} NP"
|
307
|
+
|
256
308
|
else:
|
257
309
|
de_genes = perturbation_markers[(category, gene)]
|
258
310
|
de_genes_indices = self._get_column_indices(adata, list(de_genes))
|
311
|
+
|
259
312
|
dat = X[np.asarray(all_cells)][:, de_genes_indices]
|
313
|
+
dat_cells = all_cells[all_cells].index
|
314
|
+
if scale:
|
315
|
+
dat = sc.pp.scale(dat)
|
316
|
+
|
260
317
|
converged = False
|
261
318
|
n_iter = 0
|
262
|
-
old_classes = adata.obs[
|
319
|
+
old_classes = adata.obs[new_class_name][all_cells]
|
320
|
+
|
263
321
|
while not converged and n_iter < iter_num:
|
264
322
|
# Get all cells in current split&Gene
|
265
|
-
guide_cells = (adata.obs[
|
323
|
+
guide_cells = (adata.obs[new_class_name] == gene) & split_mask
|
324
|
+
|
266
325
|
# get average value for each gene over all selected cells
|
267
326
|
# all cells in current split&Gene minus all NT cells in current split
|
268
327
|
# Each row is for each cell, each column is for each gene, get mean for each column
|
269
|
-
|
270
|
-
|
271
|
-
)
|
328
|
+
guide_cells_dat_idx = all_cells[all_cells].index.get_indexer(guide_cells[guide_cells].index)
|
329
|
+
nt_cells_dat_idx = all_cells[all_cells].index.get_indexer(nt_cells[nt_cells].index)
|
330
|
+
vec = np.mean(dat[guide_cells_dat_idx], axis=0) - np.mean(dat[nt_cells_dat_idx], axis=0)
|
331
|
+
|
272
332
|
# project cells onto the perturbation vector
|
273
333
|
if isinstance(dat, spmatrix):
|
274
|
-
pvec = np.
|
334
|
+
pvec = np.dot(dat.toarray(), vec) / np.dot(vec, vec)
|
275
335
|
else:
|
276
|
-
pvec = np.
|
336
|
+
pvec = np.dot(dat, vec) / np.dot(vec, vec)
|
277
337
|
pvec = pd.Series(np.asarray(pvec).flatten(), index=list(all_cells.index[all_cells]))
|
338
|
+
|
278
339
|
if n_iter == 0:
|
279
340
|
gv = pd.DataFrame(columns=["pvec", labels])
|
280
341
|
gv["pvec"] = pvec
|
@@ -284,19 +345,22 @@ class Mixscape:
|
|
284
345
|
gv_list[gene] = {}
|
285
346
|
gv_list[gene][category] = gv
|
286
347
|
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
precisions_init = np.array([nt_norm[1], guide_norm[1]])
|
291
|
-
mm = GaussianMixture(
|
348
|
+
means_init = np.array([[pvec[nt_cells].mean()], [pvec[guide_cells].mean()]])
|
349
|
+
std_init = np.array([pvec[nt_cells].std(), pvec[guide_cells].std()])
|
350
|
+
mm = MixscapeGaussianMixture(
|
292
351
|
n_components=2,
|
293
352
|
covariance_type="spherical",
|
294
353
|
means_init=means_init,
|
295
|
-
precisions_init=
|
354
|
+
precisions_init=1 / (std_init ** 2),
|
355
|
+
random_state=random_state,
|
356
|
+
max_iter=5000,
|
357
|
+
fixed_means=[pvec[nt_cells].mean(), None],
|
358
|
+
fixed_covariances=[pvec[nt_cells].std() ** 2, None],
|
296
359
|
).fit(np.asarray(pvec).reshape(-1, 1))
|
297
360
|
probabilities = mm.predict_proba(np.array(pvec[orig_guide_cells_index]).reshape(-1, 1))
|
298
361
|
lik_ratio = probabilities[:, 0] / probabilities[:, 1]
|
299
362
|
post_prob = 1 / (1 + lik_ratio)
|
363
|
+
|
300
364
|
# based on the posterior probability, assign cells to the two classes
|
301
365
|
adata.obs.loc[
|
302
366
|
[orig_guide_cells_index[cell] for cell in np.where(post_prob > 0.5)[0]], new_class_name
|
@@ -304,11 +368,13 @@ class Mixscape:
|
|
304
368
|
adata.obs.loc[
|
305
369
|
[orig_guide_cells_index[cell] for cell in np.where(post_prob <= 0.5)[0]], new_class_name
|
306
370
|
] = f"{gene} NP"
|
371
|
+
|
307
372
|
if sum(adata.obs[new_class_name][split_mask] == gene) < min_de_genes:
|
308
373
|
adata.obs.loc[guide_cells, new_class_name] = "NP"
|
309
374
|
converged = True
|
310
375
|
if adata.obs[new_class_name][all_cells].equals(old_classes):
|
311
376
|
converged = True
|
377
|
+
|
312
378
|
old_classes = adata.obs[new_class_name][all_cells]
|
313
379
|
n_iter += 1
|
314
380
|
|
@@ -317,9 +383,7 @@ class Mixscape:
|
|
317
383
|
)
|
318
384
|
|
319
385
|
adata.obs[f"{new_class_name}_global"] = [a.split(" ")[-1] for a in adata.obs[new_class_name]]
|
320
|
-
adata.obs.loc[orig_guide_cells_index, f"{new_class_name}_p_{perturbation_type.lower()}"] =
|
321
|
-
post_prob
|
322
|
-
).astype("int64")
|
386
|
+
adata.obs.loc[orig_guide_cells_index, f"{new_class_name}_p_{perturbation_type.lower()}"] = post_prob
|
323
387
|
adata.uns["mixscape"] = gv_list
|
324
388
|
|
325
389
|
if copy:
|
@@ -330,11 +394,13 @@ class Mixscape:
|
|
330
394
|
adata: AnnData,
|
331
395
|
labels: str,
|
332
396
|
control: str,
|
397
|
+
*,
|
333
398
|
mixscape_class_global: str | None = "mixscape_class_global",
|
334
399
|
layer: str | None = None,
|
335
400
|
n_comps: int | None = 10,
|
336
401
|
min_de_genes: int | None = 5,
|
337
402
|
logfc_threshold: float | None = 0.25,
|
403
|
+
test_method: str | None = "wilcoxon",
|
338
404
|
split_by: str | None = None,
|
339
405
|
pval_cutoff: float | None = 5e-2,
|
340
406
|
perturbation_type: str | None = "KO",
|
@@ -347,12 +413,13 @@ class Mixscape:
|
|
347
413
|
labels: The column of `.obs` with target gene labels.
|
348
414
|
control: Control category from the `pert_key` column.
|
349
415
|
mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
|
350
|
-
layer:
|
416
|
+
layer: Layer to use for identifying differentially expressed genes. If `None`, adata.X is used.
|
351
417
|
control: Control category from the `pert_key` column.
|
352
418
|
n_comps: Number of principal components to use.
|
353
419
|
min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
|
354
420
|
logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells.
|
355
|
-
|
421
|
+
test_method: Method to use for differential expression testing.
|
422
|
+
split_by: Provide `.obs` column with experimental condition/cell type annotation, if perturbations are condition/cell type-specific.
|
356
423
|
pval_cutoff: P-value cut-off for selection of significantly DE genes.
|
357
424
|
perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
|
358
425
|
copy: Determines whether a copy of the `adata` is returned.
|
@@ -370,9 +437,9 @@ class Mixscape:
|
|
370
437
|
>>> import pertpy as pt
|
371
438
|
>>> mdata = pt.dt.papalexi_2021()
|
372
439
|
>>> ms_pt = pt.tl.Mixscape()
|
373
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
374
|
-
>>> ms_pt.mixscape(
|
375
|
-
>>> ms_pt.lda(
|
440
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
441
|
+
>>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
|
442
|
+
>>> ms_pt.lda(mdata["rna"], "gene_target", "NT")
|
376
443
|
"""
|
377
444
|
if copy:
|
378
445
|
adata = adata.copy()
|
@@ -388,9 +455,8 @@ class Mixscape:
|
|
388
455
|
categories = split_obs.unique()
|
389
456
|
split_masks = [split_obs == category for category in categories]
|
390
457
|
|
391
|
-
mixscape_identifier = pt.tl.Mixscape()
|
392
458
|
# determine gene sets across all splits/groups through differential gene expression
|
393
|
-
perturbation_markers =
|
459
|
+
perturbation_markers = self._get_perturbation_markers(
|
394
460
|
adata=adata,
|
395
461
|
split_masks=split_masks,
|
396
462
|
categories=categories,
|
@@ -400,6 +466,7 @@ class Mixscape:
|
|
400
466
|
pval_cutoff=pval_cutoff,
|
401
467
|
min_de_genes=min_de_genes,
|
402
468
|
logfc_threshold=logfc_threshold,
|
469
|
+
test_method=test_method,
|
403
470
|
)
|
404
471
|
adata_subset = adata[
|
405
472
|
(adata.obs[mixscape_class_global] == perturbation_type) | (adata.obs[mixscape_class_global] == control)
|
@@ -445,12 +512,21 @@ class Mixscape:
|
|
445
512
|
pval_cutoff: float,
|
446
513
|
min_de_genes: float,
|
447
514
|
logfc_threshold: float,
|
515
|
+
test_method: str,
|
448
516
|
) -> dict[tuple, np.ndarray]:
|
449
517
|
"""Determine gene sets across all splits/groups through differential gene expression
|
450
518
|
|
451
519
|
Args:
|
452
520
|
adata: :class:`~anndata.AnnData` object
|
453
|
-
|
521
|
+
split_masks: List of boolean masks for each split/group.
|
522
|
+
categories: List of split/group names.
|
523
|
+
labels: The column of `.obs` with target gene labels.
|
524
|
+
control: Control category from the `labels` column.
|
525
|
+
layer: Key from adata.layers whose value will be used to compare gene expression.
|
526
|
+
pval_cutoff: P-value cut-off for selection of significantly DE genes.
|
527
|
+
min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
|
528
|
+
logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells.
|
529
|
+
test_method: Method to use for differential expression testing.
|
454
530
|
|
455
531
|
Returns:
|
456
532
|
Set of column indices.
|
@@ -459,21 +535,21 @@ class Mixscape:
|
|
459
535
|
for split, split_mask in enumerate(split_masks):
|
460
536
|
category = categories[split]
|
461
537
|
# get gene sets for each split
|
462
|
-
|
538
|
+
gene_targets = list(set(adata[split_mask].obs[labels]).difference([control]))
|
463
539
|
adata_split = adata[split_mask].copy()
|
464
540
|
# find top DE genes between cells with targeting and non-targeting gRNAs
|
465
541
|
sc.tl.rank_genes_groups(
|
466
542
|
adata_split,
|
467
543
|
layer=layer,
|
468
544
|
groupby=labels,
|
469
|
-
groups=
|
545
|
+
groups=gene_targets,
|
470
546
|
reference=control,
|
471
|
-
method=
|
547
|
+
method=test_method,
|
472
548
|
use_raw=False,
|
473
549
|
)
|
474
|
-
# get DE genes for each gene
|
475
|
-
for gene in
|
476
|
-
logfc_threshold_mask = adata_split.uns["rank_genes_groups"]["logfoldchanges"][gene] >= logfc_threshold
|
550
|
+
# get DE genes for each target gene
|
551
|
+
for gene in gene_targets:
|
552
|
+
logfc_threshold_mask = np.abs(adata_split.uns["rank_genes_groups"]["logfoldchanges"][gene]) >= logfc_threshold
|
477
553
|
de_genes = adata_split.uns["rank_genes_groups"]["names"][gene][logfc_threshold_mask]
|
478
554
|
pvals_adj = adata_split.uns["rank_genes_groups"]["pvals_adj"][gene][logfc_threshold_mask]
|
479
555
|
de_genes = de_genes[pvals_adj < pval_cutoff]
|
@@ -494,20 +570,6 @@ class Mixscape:
|
|
494
570
|
|
495
571
|
return indices
|
496
572
|
|
497
|
-
def _define_normal_mixscape(self, X: np.ndarray | sparse.spmatrix | pd.DataFrame | None) -> list[float]:
|
498
|
-
"""Calculates the mean and standard deviation of a matrix.
|
499
|
-
|
500
|
-
Args:
|
501
|
-
X: The matrix to calculate the properties for.
|
502
|
-
|
503
|
-
Returns:
|
504
|
-
Mean and standard deviation of the matrix.
|
505
|
-
"""
|
506
|
-
mu = X.mean()
|
507
|
-
sd = X.std()
|
508
|
-
|
509
|
-
return [mu, sd]
|
510
|
-
|
511
573
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
512
574
|
def plot_barplot( # pragma: no cover
|
513
575
|
self,
|
@@ -522,7 +584,6 @@ class Mixscape:
|
|
522
584
|
legend_text_size: int = 8,
|
523
585
|
legend_bbox_to_anchor: tuple[float, float] = None,
|
524
586
|
figsize: tuple[float, float] = (25, 25),
|
525
|
-
show: bool = True,
|
526
587
|
return_fig: bool = False,
|
527
588
|
) -> Figure | None:
|
528
589
|
"""Barplot to visualize perturbation scores calculated by the `mixscape` function.
|
@@ -548,8 +609,8 @@ class Mixscape:
|
|
548
609
|
>>> import pertpy as pt
|
549
610
|
>>> mdata = pt.dt.papalexi_2021()
|
550
611
|
>>> ms_pt = pt.tl.Mixscape()
|
551
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
552
|
-
>>> ms_pt.mixscape(
|
612
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
613
|
+
>>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
|
553
614
|
>>> ms_pt.plot_barplot(mdata["rna"], guide_rna_column="NT")
|
554
615
|
|
555
616
|
Preview:
|
@@ -611,10 +672,9 @@ class Mixscape:
|
|
611
672
|
fig.subplots_adjust(hspace=0.5, wspace=0.5)
|
612
673
|
plt.tight_layout()
|
613
674
|
|
614
|
-
if show:
|
615
|
-
plt.show()
|
616
675
|
if return_fig:
|
617
676
|
return fig
|
677
|
+
plt.show()
|
618
678
|
return None
|
619
679
|
|
620
680
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
@@ -630,7 +690,6 @@ class Mixscape:
|
|
630
690
|
subsample_number: int | None = 900,
|
631
691
|
vmin: float | None = -2,
|
632
692
|
vmax: float | None = 2,
|
633
|
-
show: bool = True,
|
634
693
|
return_fig: bool = False,
|
635
694
|
**kwds,
|
636
695
|
) -> Figure | None:
|
@@ -656,8 +715,8 @@ class Mixscape:
|
|
656
715
|
>>> import pertpy as pt
|
657
716
|
>>> mdata = pt.dt.papalexi_2021()
|
658
717
|
>>> ms_pt = pt.tl.Mixscape()
|
659
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
660
|
-
>>> ms_pt.mixscape(
|
718
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
719
|
+
>>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
|
661
720
|
>>> ms_pt.plot_heatmap(
|
662
721
|
... adata=mdata["rna"], labels="gene_target", target_gene="IFNGR2", layer="X_pert", control="NT"
|
663
722
|
... )
|
@@ -683,10 +742,9 @@ class Mixscape:
|
|
683
742
|
**kwds,
|
684
743
|
)
|
685
744
|
|
686
|
-
if show:
|
687
|
-
plt.show()
|
688
745
|
if return_fig:
|
689
746
|
return fig
|
747
|
+
plt.show()
|
690
748
|
return None
|
691
749
|
|
692
750
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
@@ -702,7 +760,6 @@ class Mixscape:
|
|
702
760
|
split_by: str = None,
|
703
761
|
before_mixscape: bool = False,
|
704
762
|
perturbation_type: str = "KO",
|
705
|
-
show: bool = True,
|
706
763
|
return_fig: bool = False,
|
707
764
|
) -> Figure | None:
|
708
765
|
"""Density plots to visualize perturbation scores calculated by the `pt.tl.mixscape` function.
|
@@ -734,8 +791,8 @@ class Mixscape:
|
|
734
791
|
>>> import pertpy as pt
|
735
792
|
>>> mdata = pt.dt.papalexi_2021()
|
736
793
|
>>> ms_pt = pt.tl.Mixscape()
|
737
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
738
|
-
>>> ms_pt.mixscape(
|
794
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
795
|
+
>>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
|
739
796
|
>>> ms_pt.plot_perturbscore(adata=mdata["rna"], labels="gene_target", target_gene="IFNGR2", color="orange")
|
740
797
|
|
741
798
|
Preview:
|
@@ -851,10 +908,9 @@ class Mixscape:
|
|
851
908
|
plt.legend(title="mixscape class", title_fontsize=14, fontsize=12)
|
852
909
|
sns.despine()
|
853
910
|
|
854
|
-
if show:
|
855
|
-
plt.show()
|
856
911
|
if return_fig:
|
857
912
|
return plt.gcf()
|
913
|
+
plt.show()
|
858
914
|
return None
|
859
915
|
|
860
916
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
@@ -879,7 +935,6 @@ class Mixscape:
|
|
879
935
|
ylabel: str | Sequence[str] | None = None,
|
880
936
|
rotation: float | None = None,
|
881
937
|
ax: Axes | None = None,
|
882
|
-
show: bool = True,
|
883
938
|
return_fig: bool = False,
|
884
939
|
**kwargs,
|
885
940
|
) -> Axes | Figure | None:
|
@@ -910,8 +965,8 @@ class Mixscape:
|
|
910
965
|
>>> import pertpy as pt
|
911
966
|
>>> mdata = pt.dt.papalexi_2021()
|
912
967
|
>>> ms_pt = pt.tl.Mixscape()
|
913
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
914
|
-
>>> ms_pt.mixscape(
|
968
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
969
|
+
>>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
|
915
970
|
>>> ms_pt.plot_violin(
|
916
971
|
... adata=mdata["rna"], target_gene_idents=["NT", "IFNGR2 NP", "IFNGR2 KO"], groupby="mixscape_class"
|
917
972
|
... )
|
@@ -1047,12 +1102,9 @@ class Mixscape:
|
|
1047
1102
|
if rotation is not None:
|
1048
1103
|
ax.tick_params(axis="x", labelrotation=rotation)
|
1049
1104
|
|
1050
|
-
show = settings.autoshow if show is None else show
|
1051
1105
|
if hue is not None and stripplot is True:
|
1052
1106
|
plt.legend(handles, labels)
|
1053
1107
|
|
1054
|
-
if show:
|
1055
|
-
plt.show()
|
1056
1108
|
if return_fig:
|
1057
1109
|
if multi_panel and groupby is None and len(ys) == 1:
|
1058
1110
|
return g
|
@@ -1060,6 +1112,7 @@ class Mixscape:
|
|
1060
1112
|
return axs[0]
|
1061
1113
|
else:
|
1062
1114
|
return axs
|
1115
|
+
plt.show()
|
1063
1116
|
return None
|
1064
1117
|
|
1065
1118
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
@@ -1076,7 +1129,6 @@ class Mixscape:
|
|
1076
1129
|
color_map: Colormap | str | None = None,
|
1077
1130
|
palette: str | Sequence[str] | None = None,
|
1078
1131
|
ax: Axes | None = None,
|
1079
|
-
show: bool = True,
|
1080
1132
|
return_fig: bool = False,
|
1081
1133
|
**kwds,
|
1082
1134
|
) -> Figure | None:
|
@@ -1097,9 +1149,9 @@ class Mixscape:
|
|
1097
1149
|
>>> import pertpy as pt
|
1098
1150
|
>>> mdata = pt.dt.papalexi_2021()
|
1099
1151
|
>>> ms_pt = pt.tl.Mixscape()
|
1100
|
-
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
|
1101
|
-
>>> ms_pt.mixscape(
|
1102
|
-
>>> ms_pt.lda(
|
1152
|
+
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
|
1153
|
+
>>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
|
1154
|
+
>>> ms_pt.lda(mdata["rna"], "gene_target", "NT", split_by="replicate")
|
1103
1155
|
>>> ms_pt.plot_lda(adata=mdata["rna"], control="NT")
|
1104
1156
|
|
1105
1157
|
Preview:
|
@@ -1129,8 +1181,44 @@ class Mixscape:
|
|
1129
1181
|
**kwds,
|
1130
1182
|
)
|
1131
1183
|
|
1132
|
-
if show:
|
1133
|
-
plt.show()
|
1134
1184
|
if return_fig:
|
1135
1185
|
return fig
|
1186
|
+
plt.show()
|
1136
1187
|
return None
|
1188
|
+
|
1189
|
+
class MixscapeGaussianMixture(GaussianMixture):
    """Gaussian mixture whose per-component means and variances can be pinned.

    After every regular M-step the parameters listed in `fixed_means` /
    `fixed_covariances` are overwritten with their fixed values, so that only
    the remaining (``None``) entries are effectively estimated from the data.
    """

    def __init__(
        self,
        n_components: int,
        fixed_means: Sequence[float] | None = None,
        fixed_covariances: Sequence[float] | None = None,
        **kwargs,
    ):
        """Custom Gaussian Mixture Model where means and covariances can be fixed for specific components.

        Args:
            n_components: Number of Gaussian components
            fixed_means: Means to fix (use None for those that should be estimated)
            fixed_covariances: Covariances to fix (use None for those that should be estimated)
            kwargs: Additional arguments passed to scikit-learn's GaussianMixture
        """
        super().__init__(n_components=n_components, **kwargs)
        self.fixed_means = fixed_means
        self.fixed_covariances = fixed_covariances

    def _m_step(self, X: np.ndarray, log_resp: np.ndarray):
        """Modified M-step to respect fixed means and covariances.

        Runs the parent M-step and then overwrites the parameters of every
        component that has a fixed mean and/or covariance.

        Args:
            X: Data passed through to the parent M-step.
            log_resp: Log responsibilities passed through to the parent M-step.

        Returns:
            The estimator itself.
        """
        super()._m_step(X, log_resp)

        if self.fixed_means is not None:
            for i in range(self.n_components):
                if self.fixed_means[i] is not None:
                    self.means_[i] = self.fixed_means[i]

        if self.fixed_covariances is not None:
            covariance_overridden = False
            for i in range(self.n_components):
                if self.fixed_covariances[i] is not None:
                    self.covariances_[i] = self.fixed_covariances[i]
                    covariance_overridden = True

            # The parent M-step derived `precisions_cholesky_` from the *estimated*
            # covariances, and the E-step / `predict_proba` only read that factor
            # (never `covariances_`). Without refreshing it, the fixed variances
            # would be ignored by the model. For per-component variances
            # ("spherical"/"diag") the factor is simply 1 / sqrt(variance), which is
            # the formula the parent uses. Other covariance types are left untouched
            # because scalar fixed values do not describe a full covariance matrix.
            if covariance_overridden and self.covariance_type in ("spherical", "diag"):
                self.precisions_cholesky_ = 1.0 / np.sqrt(self.covariances_)

        return self
|
1224
|
+
|