pertpy 0.9.5__py3-none-any.whl → 0.10.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
pertpy/tools/_mixscape.py CHANGED
@@ -41,9 +41,12 @@ class Mixscape:
41
41
  adata: AnnData,
42
42
  pert_key: str,
43
43
  control: str,
44
+ *,
45
+ ref_selection_mode: Literal["nn", "split_by"] = "nn",
44
46
  split_by: str | None = None,
45
47
  n_neighbors: int = 20,
46
48
  use_rep: str | None = None,
49
+ n_dims: int | None = 15,
47
50
  n_pcs: int | None = None,
48
51
  batch_size: int | None = None,
49
52
  copy: bool = False,
@@ -51,14 +54,18 @@ class Mixscape:
51
54
  ):
52
55
  """Calculate perturbation signature.
53
56
 
54
- For each cell, we identify `n_neighbors` cells from the control pool with the most similar mRNA expression profiles.
55
- The perturbation signature is calculated by subtracting the averaged mRNA expression profile of the control
56
- neighbors from the mRNA expression profile of each cell.
57
+ The perturbation signature is calculated by subtracting the mRNA expression profile of each cell from the averaged
58
+ mRNA expression profile of the control cells (selected according to `ref_selection_mode`).
59
+ The implementation resembles https://satijalab.org/seurat/reference/runmixscape. Note that in the original implementation, the
60
+ perturbation signature is calculated on unscaled data by default, and we therefore recommend to do the same.
57
61
 
58
62
  Args:
59
63
  adata: The annotated data object.
60
64
  pert_key: The column of `.obs` with perturbation categories, should also contain `control`.
61
- control: Control category from the `pert_key` column.
65
+ control: Name of the control category from the `pert_key` column.
66
+ ref_selection_mode: Method to select reference cells for the perturbation signature calculation. If `nn`,
67
+ the `n_neighbors` cells from the control pool with the most similar mRNA expression profiles are selected. If `split_by`,
68
+ the control cells from the same split in `split_by` (e.g. indicating biological replicates) are used to calculate the perturbation signature.
62
69
  split_by: Provide the column `.obs` if multiple biological replicates exist to calculate
63
70
  the perturbation signature for every replicate separately.
64
71
  n_neighbors: Number of neighbors from the control to use for the perturbation signature.
@@ -66,7 +73,10 @@ class Mixscape:
66
73
  If `None`, the representation is chosen automatically:
67
74
  For `.n_vars` < 50, `.X` is used, otherwise 'X_pca' is used.
68
75
  If 'X_pca' is not present, it’s computed with default parameters.
69
- n_pcs: Use this many PCs. If `n_pcs==0` use `.X` if `use_rep is None`.
76
+ n_dims: Number of dimensions to use from the representation to calculate the perturbation signature.
77
+ If `None`, use all dimensions.
78
+ n_pcs: If PCA representation is used, the number of principal components to compute.
79
+ If `n_pcs==0` use `.X` if `use_rep is None`.
70
80
  batch_size: Size of batch to calculate the perturbation signature.
71
81
  If 'None', the perturbation signature is calcuated in the full mode, requiring more memory.
72
82
  The batched mode is very inefficient for sparse data.
@@ -83,8 +93,13 @@ class Mixscape:
83
93
  >>> import pertpy as pt
84
94
  >>> mdata = pt.dt.papalexi_2021()
85
95
  >>> ms_pt = pt.tl.Mixscape()
86
- >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
96
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
87
97
  """
98
+ if ref_selection_mode not in ["nn", "split_by"]:
99
+ raise ValueError("ref_selection_mode must be either 'nn' or 'split_by'.")
100
+ if ref_selection_mode == "split_by" and split_by is None:
101
+ raise ValueError("split_by must be provided if ref_selection_mode is 'split_by'.")
102
+
88
103
  if copy:
89
104
  adata = adata.copy()
90
105
 
@@ -92,59 +107,75 @@ class Mixscape:
92
107
 
93
108
  control_mask = adata.obs[pert_key] == control
94
109
 
95
- if split_by is None:
96
- split_masks = [np.full(adata.n_obs, True, dtype=bool)]
110
+ if ref_selection_mode == "split_by":
111
+ for split in adata.obs[split_by].unique():
112
+ split_mask = adata.obs[split_by] == split
113
+ control_mask_group = control_mask & split_mask
114
+ control_mean_expr = adata.X[control_mask_group].mean(0)
115
+ adata.layers["X_pert"][split_mask] = (
116
+ np.repeat(control_mean_expr.reshape(1, -1), split_mask.sum(), axis=0)
117
+ - adata.layers["X_pert"][split_mask]
118
+ )
97
119
  else:
98
- split_obs = adata.obs[split_by]
99
- split_masks = [split_obs == cat for cat in split_obs.unique()]
120
+ if split_by is None:
121
+ split_masks = [np.full(adata.n_obs, True, dtype=bool)]
122
+ else:
123
+ split_obs = adata.obs[split_by]
124
+ split_masks = [split_obs == cat for cat in split_obs.unique()]
100
125
 
101
- representation = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs)
126
+ representation = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs)
127
+ if n_dims is not None and n_dims < representation.shape[1]:
128
+ representation = representation[:, :n_dims]
102
129
 
103
- for split_mask in split_masks:
104
- control_mask_split = control_mask & split_mask
130
+ for split_mask in split_masks:
131
+ control_mask_split = control_mask & split_mask
105
132
 
106
- R_split = representation[split_mask]
107
- R_control = representation[np.asarray(control_mask_split)]
133
+ R_split = representation[split_mask]
134
+ R_control = representation[np.asarray(control_mask_split)]
108
135
 
109
- from pynndescent import NNDescent
136
+ from pynndescent import NNDescent
110
137
 
111
- eps = kwargs.pop("epsilon", 0.1)
112
- nn_index = NNDescent(R_control, **kwargs)
113
- indices, _ = nn_index.query(R_split, k=n_neighbors, epsilon=eps)
138
+ eps = kwargs.pop("epsilon", 0.1)
139
+ nn_index = NNDescent(R_control, **kwargs)
140
+ indices, _ = nn_index.query(R_split, k=n_neighbors, epsilon=eps)
114
141
 
115
- X_control = np.expm1(adata.X[np.asarray(control_mask_split)])
142
+ X_control = np.expm1(adata.X[np.asarray(control_mask_split)])
116
143
 
117
- n_split = split_mask.sum()
118
- n_control = X_control.shape[0]
144
+ n_split = split_mask.sum()
145
+ n_control = X_control.shape[0]
119
146
 
120
- if batch_size is None:
121
- col_indices = np.ravel(indices)
122
- row_indices = np.repeat(np.arange(n_split), n_neighbors)
147
+ if batch_size is None:
148
+ col_indices = np.ravel(indices)
149
+ row_indices = np.repeat(np.arange(n_split), n_neighbors)
123
150
 
124
- neigh_matrix = csr_matrix(
125
- (np.ones_like(col_indices, dtype=np.float64), (row_indices, col_indices)),
126
- shape=(n_split, n_control),
127
- )
128
- neigh_matrix /= n_neighbors
129
- adata.layers["X_pert"][split_mask] -= np.log1p(neigh_matrix @ X_control)
130
- else:
131
- is_sparse = issparse(X_control)
132
- split_indices = np.where(split_mask)[0]
133
- for i in range(0, n_split, batch_size):
134
- size = min(i + batch_size, n_split)
135
- select = slice(i, size)
151
+ neigh_matrix = csr_matrix(
152
+ (np.ones_like(col_indices, dtype=np.float64), (row_indices, col_indices)),
153
+ shape=(n_split, n_control),
154
+ )
155
+ neigh_matrix /= n_neighbors
156
+ adata.layers["X_pert"][split_mask] = (
157
+ np.log1p(neigh_matrix @ X_control) - adata.layers["X_pert"][split_mask]
158
+ )
159
+ else:
160
+ is_sparse = issparse(X_control)
161
+ split_indices = np.where(split_mask)[0]
162
+ for i in range(0, n_split, batch_size):
163
+ size = min(i + batch_size, n_split)
164
+ select = slice(i, size)
136
165
 
137
- batch = np.ravel(indices[select])
138
- split_batch = split_indices[select]
166
+ batch = np.ravel(indices[select])
167
+ split_batch = split_indices[select]
139
168
 
140
- size = size - i
169
+ size = size - i
141
170
 
142
- # sparse is very slow
143
- means_batch = X_control[batch]
144
- means_batch = means_batch.toarray() if is_sparse else means_batch
145
- means_batch = means_batch.reshape(size, n_neighbors, -1).mean(1)
171
+ # sparse is very slow
172
+ means_batch = X_control[batch]
173
+ means_batch = means_batch.toarray() if is_sparse else means_batch
174
+ means_batch = means_batch.reshape(size, n_neighbors, -1).mean(1)
146
175
 
147
- adata.layers["X_pert"][split_batch] -= np.log1p(means_batch)
176
+ adata.layers["X_pert"][split_batch] = (
177
+ np.log1p(means_batch) - adata.layers["X_pert"][split_batch]
178
+ )
148
179
 
149
180
  if copy:
150
181
  return adata
@@ -154,33 +185,41 @@ class Mixscape:
154
185
  adata: AnnData,
155
186
  labels: str,
156
187
  control: str,
188
+ *,
157
189
  new_class_name: str | None = "mixscape_class",
158
- min_de_genes: int | None = 5,
159
190
  layer: str | None = None,
191
+ min_de_genes: int | None = 5,
160
192
  logfc_threshold: float | None = 0.25,
193
+ de_layer: str | None = None,
194
+ test_method: str | None = "wilcoxon",
161
195
  iter_num: int | None = 10,
196
+ scale: bool | None = True,
162
197
  split_by: str | None = None,
163
198
  pval_cutoff: float | None = 5e-2,
164
199
  perturbation_type: str | None = "KO",
200
+ random_state: int | None = 0,
165
201
  copy: bool | None = False,
166
202
  ):
167
203
  """Identify perturbed and non-perturbed gRNA expressing cells that accounts for multiple treatments/conditions/chemical perturbations.
168
204
 
169
- The implementation resembles https://satijalab.org/seurat/reference/runmixscape
205
+ The implementation resembles https://satijalab.org/seurat/reference/runmixscape.
170
206
 
171
207
  Args:
172
208
  adata: The annotated data object.
173
209
  labels: The column of `.obs` with target gene labels.
174
- control: Control category from the `pert_key` column.
210
+ control: Control category from the `labels` column.
175
211
  new_class_name: Name of mixscape classification to be stored in `.obs`.
176
- min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
177
212
  layer: Key from adata.layers whose value will be used to perform tests on. Default is using `.layers["X_pert"]`.
213
+ min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
178
214
  logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells (default: 0.25).
215
+ de_layer: Layer to use for identifying differentially expressed genes. If `None`, adata.X is used.
216
+ test_method: Method to use for differential expression testing.
179
217
  iter_num: Number of normalmixEM iterations to run if convergence does not occur.
180
- split_by: Provide the column `.obs` if multiple biological replicates exist to calculate
181
- the perturbation signature for every replicate separately.
218
+ scale: Scale the data specified in `layer` before running the GaussianMixture model on it.
219
+ split_by: Provide `.obs` column with experimental condition/cell type annotation, if perturbations are condition/cell type-specific.
182
220
  pval_cutoff: P-value cut-off for selection of significantly DE genes.
183
221
  perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications.
222
+ random_state: Random seed for the GaussianMixture model.
184
223
  copy: Determines whether a copy of the `adata` is returned.
185
224
 
186
225
  Returns:
@@ -203,8 +242,8 @@ class Mixscape:
203
242
  >>> import pertpy as pt
204
243
  >>> mdata = pt.dt.papalexi_2021()
205
244
  >>> ms_pt = pt.tl.Mixscape()
206
- >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
207
- >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
245
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
246
+ >>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
208
247
  """
209
248
  if copy:
210
249
  adata = adata.copy()
@@ -218,7 +257,16 @@ class Mixscape:
218
257
  split_masks = [split_obs == category for category in categories]
219
258
 
220
259
  perturbation_markers = self._get_perturbation_markers(
221
- adata, split_masks, categories, labels, control, layer, pval_cutoff, min_de_genes, logfc_threshold
260
+ adata=adata,
261
+ split_masks=split_masks,
262
+ categories=categories,
263
+ labels=labels,
264
+ control=control,
265
+ layer=de_layer,
266
+ pval_cutoff=pval_cutoff,
267
+ min_de_genes=min_de_genes,
268
+ logfc_threshold=logfc_threshold,
269
+ test_method=test_method,
222
270
  )
223
271
 
224
272
  adata_comp = adata
@@ -231,6 +279,7 @@ class Mixscape:
231
279
  raise KeyError(
232
280
  "No 'X_pert' found in .layers! Please run perturbation_signature first to calculate perturbation signature!"
233
281
  ) from None
282
+
234
283
  # initialize return variables
235
284
  adata.obs[f"{new_class_name}_p_{perturbation_type.lower()}"] = 0
236
285
  adata.obs[new_class_name] = adata.obs[labels].astype(str)
@@ -241,10 +290,12 @@ class Mixscape:
241
290
  dtype=np.object_,
242
291
  )
243
292
  gv_list: dict[str, dict] = {}
293
+
294
+ adata.obs[f"{new_class_name}_p_{perturbation_type.lower()}"] = 0.0
244
295
  for split, split_mask in enumerate(split_masks):
245
296
  category = categories[split]
246
- genes = list(set(adata[split_mask].obs[labels]).difference([control]))
247
- for gene in genes:
297
+ gene_targets = list(set(adata[split_mask].obs[labels]).difference([control]))
298
+ for gene in gene_targets:
248
299
  post_prob = 0
249
300
  orig_guide_cells = (adata.obs[labels] == gene) & split_mask
250
301
  orig_guide_cells_index = list(orig_guide_cells.index[orig_guide_cells])
@@ -253,28 +304,38 @@ class Mixscape:
253
304
 
254
305
  if len(perturbation_markers[(category, gene)]) == 0:
255
306
  adata.obs.loc[orig_guide_cells, new_class_name] = f"{gene} NP"
307
+
256
308
  else:
257
309
  de_genes = perturbation_markers[(category, gene)]
258
310
  de_genes_indices = self._get_column_indices(adata, list(de_genes))
311
+
259
312
  dat = X[np.asarray(all_cells)][:, de_genes_indices]
313
+ dat_cells = all_cells[all_cells].index
314
+ if scale:
315
+ dat = sc.pp.scale(dat)
316
+
260
317
  converged = False
261
318
  n_iter = 0
262
- old_classes = adata.obs[labels][all_cells]
319
+ old_classes = adata.obs[new_class_name][all_cells]
320
+
263
321
  while not converged and n_iter < iter_num:
264
322
  # Get all cells in current split&Gene
265
- guide_cells = (adata.obs[labels] == gene) & split_mask
323
+ guide_cells = (adata.obs[new_class_name] == gene) & split_mask
324
+
266
325
  # get average value for each gene over all selected cells
267
326
  # all cells in current split&Gene minus all NT cells in current split
268
327
  # Each row is for each cell, each column is for each gene, get mean for each column
269
- vec = np.mean(X[np.asarray(guide_cells)][:, de_genes_indices], axis=0) - np.mean(
270
- X[np.asarray(nt_cells)][:, de_genes_indices], axis=0
271
- )
328
+ guide_cells_dat_idx = all_cells[all_cells].index.get_indexer(guide_cells[guide_cells].index)
329
+ nt_cells_dat_idx = all_cells[all_cells].index.get_indexer(nt_cells[nt_cells].index)
330
+ vec = np.mean(dat[guide_cells_dat_idx], axis=0) - np.mean(dat[nt_cells_dat_idx], axis=0)
331
+
272
332
  # project cells onto the perturbation vector
273
333
  if isinstance(dat, spmatrix):
274
- pvec = np.sum(np.multiply(dat.toarray(), vec), axis=1) / np.sum(np.multiply(vec, vec))
334
+ pvec = np.dot(dat.toarray(), vec) / np.dot(vec, vec)
275
335
  else:
276
- pvec = np.sum(np.multiply(dat, vec), axis=1) / np.sum(np.multiply(vec, vec))
336
+ pvec = np.dot(dat, vec) / np.dot(vec, vec)
277
337
  pvec = pd.Series(np.asarray(pvec).flatten(), index=list(all_cells.index[all_cells]))
338
+
278
339
  if n_iter == 0:
279
340
  gv = pd.DataFrame(columns=["pvec", labels])
280
341
  gv["pvec"] = pvec
@@ -284,19 +345,22 @@ class Mixscape:
284
345
  gv_list[gene] = {}
285
346
  gv_list[gene][category] = gv
286
347
 
287
- guide_norm = self._define_normal_mixscape(pvec[guide_cells])
288
- nt_norm = self._define_normal_mixscape(pvec[nt_cells])
289
- means_init = np.array([[nt_norm[0]], [guide_norm[0]]])
290
- precisions_init = np.array([nt_norm[1], guide_norm[1]])
291
- mm = GaussianMixture(
348
+ means_init = np.array([[pvec[nt_cells].mean()], [pvec[guide_cells].mean()]])
349
+ std_init = np.array([pvec[nt_cells].std(), pvec[guide_cells].std()])
350
+ mm = MixscapeGaussianMixture(
292
351
  n_components=2,
293
352
  covariance_type="spherical",
294
353
  means_init=means_init,
295
- precisions_init=precisions_init,
354
+ precisions_init=1 / (std_init ** 2),
355
+ random_state=random_state,
356
+ max_iter=5000,
357
+ fixed_means=[pvec[nt_cells].mean(), None],
358
+ fixed_covariances=[pvec[nt_cells].std() ** 2, None],
296
359
  ).fit(np.asarray(pvec).reshape(-1, 1))
297
360
  probabilities = mm.predict_proba(np.array(pvec[orig_guide_cells_index]).reshape(-1, 1))
298
361
  lik_ratio = probabilities[:, 0] / probabilities[:, 1]
299
362
  post_prob = 1 / (1 + lik_ratio)
363
+
300
364
  # based on the posterior probability, assign cells to the two classes
301
365
  adata.obs.loc[
302
366
  [orig_guide_cells_index[cell] for cell in np.where(post_prob > 0.5)[0]], new_class_name
@@ -304,11 +368,13 @@ class Mixscape:
304
368
  adata.obs.loc[
305
369
  [orig_guide_cells_index[cell] for cell in np.where(post_prob <= 0.5)[0]], new_class_name
306
370
  ] = f"{gene} NP"
371
+
307
372
  if sum(adata.obs[new_class_name][split_mask] == gene) < min_de_genes:
308
373
  adata.obs.loc[guide_cells, new_class_name] = "NP"
309
374
  converged = True
310
375
  if adata.obs[new_class_name][all_cells].equals(old_classes):
311
376
  converged = True
377
+
312
378
  old_classes = adata.obs[new_class_name][all_cells]
313
379
  n_iter += 1
314
380
 
@@ -317,9 +383,7 @@ class Mixscape:
317
383
  )
318
384
 
319
385
  adata.obs[f"{new_class_name}_global"] = [a.split(" ")[-1] for a in adata.obs[new_class_name]]
320
- adata.obs.loc[orig_guide_cells_index, f"{new_class_name}_p_{perturbation_type.lower()}"] = np.round(
321
- post_prob
322
- ).astype("int64")
386
+ adata.obs.loc[orig_guide_cells_index, f"{new_class_name}_p_{perturbation_type.lower()}"] = post_prob
323
387
  adata.uns["mixscape"] = gv_list
324
388
 
325
389
  if copy:
@@ -330,11 +394,13 @@ class Mixscape:
330
394
  adata: AnnData,
331
395
  labels: str,
332
396
  control: str,
397
+ *,
333
398
  mixscape_class_global: str | None = "mixscape_class_global",
334
399
  layer: str | None = None,
335
400
  n_comps: int | None = 10,
336
401
  min_de_genes: int | None = 5,
337
402
  logfc_threshold: float | None = 0.25,
403
+ test_method: str | None = "wilcoxon",
338
404
  split_by: str | None = None,
339
405
  pval_cutoff: float | None = 5e-2,
340
406
  perturbation_type: str | None = "KO",
@@ -347,12 +413,13 @@ class Mixscape:
347
413
  labels: The column of `.obs` with target gene labels.
348
414
  control: Control category from the `pert_key` column.
349
415
  mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
350
- layer: Key from `adata.layers` whose value will be used to perform tests on.
416
+ layer: Layer to use for identifying differentially expressed genes. If `None`, adata.X is used.
351
417
  control: Control category from the `pert_key` column.
352
418
  n_comps: Number of principal components to use.
353
419
  min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
354
420
  logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells.
355
- split_by: Provide the column `.obs` if multiple biological replicates exist to calculate
421
+ test_method: Method to use for differential expression testing.
422
+ split_by: Provide `.obs` column with experimental condition/cell type annotation, if perturbations are condition/cell type-specific.
356
423
  pval_cutoff: P-value cut-off for selection of significantly DE genes.
357
424
  perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
358
425
  copy: Determines whether a copy of the `adata` is returned.
@@ -370,9 +437,9 @@ class Mixscape:
370
437
  >>> import pertpy as pt
371
438
  >>> mdata = pt.dt.papalexi_2021()
372
439
  >>> ms_pt = pt.tl.Mixscape()
373
- >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
374
- >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
375
- >>> ms_pt.lda(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
440
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
441
+ >>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
442
+ >>> ms_pt.lda(mdata["rna"], "gene_target", "NT")
376
443
  """
377
444
  if copy:
378
445
  adata = adata.copy()
@@ -388,9 +455,8 @@ class Mixscape:
388
455
  categories = split_obs.unique()
389
456
  split_masks = [split_obs == category for category in categories]
390
457
 
391
- mixscape_identifier = pt.tl.Mixscape()
392
458
  # determine gene sets across all splits/groups through differential gene expression
393
- perturbation_markers = mixscape_identifier._get_perturbation_markers(
459
+ perturbation_markers = self._get_perturbation_markers(
394
460
  adata=adata,
395
461
  split_masks=split_masks,
396
462
  categories=categories,
@@ -400,6 +466,7 @@ class Mixscape:
400
466
  pval_cutoff=pval_cutoff,
401
467
  min_de_genes=min_de_genes,
402
468
  logfc_threshold=logfc_threshold,
469
+ test_method=test_method,
403
470
  )
404
471
  adata_subset = adata[
405
472
  (adata.obs[mixscape_class_global] == perturbation_type) | (adata.obs[mixscape_class_global] == control)
@@ -445,12 +512,21 @@ class Mixscape:
445
512
  pval_cutoff: float,
446
513
  min_de_genes: float,
447
514
  logfc_threshold: float,
515
+ test_method: str,
448
516
  ) -> dict[tuple, np.ndarray]:
449
517
  """Determine gene sets across all splits/groups through differential gene expression
450
518
 
451
519
  Args:
452
520
  adata: :class:`~anndata.AnnData` object
453
- col_names: Column names to extract the indices for
521
+ split_masks: List of boolean masks for each split/group.
522
+ categories: List of split/group names.
523
+ labels: The column of `.obs` with target gene labels.
524
+ control: Control category from the `labels` column.
525
+ layer: Key from adata.layers whose value will be used to compare gene expression.
526
+ pval_cutoff: P-value cut-off for selection of significantly DE genes.
527
+ min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
528
+ logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells.
529
+ test_method: Method to use for differential expression testing.
454
530
 
455
531
  Returns:
456
532
  Set of column indices.
@@ -459,21 +535,21 @@ class Mixscape:
459
535
  for split, split_mask in enumerate(split_masks):
460
536
  category = categories[split]
461
537
  # get gene sets for each split
462
- genes = list(set(adata[split_mask].obs[labels]).difference([control]))
538
+ gene_targets = list(set(adata[split_mask].obs[labels]).difference([control]))
463
539
  adata_split = adata[split_mask].copy()
464
540
  # find top DE genes between cells with targeting and non-targeting gRNAs
465
541
  sc.tl.rank_genes_groups(
466
542
  adata_split,
467
543
  layer=layer,
468
544
  groupby=labels,
469
- groups=genes,
545
+ groups=gene_targets,
470
546
  reference=control,
471
- method="t-test",
547
+ method=test_method,
472
548
  use_raw=False,
473
549
  )
474
- # get DE genes for each gene
475
- for gene in genes:
476
- logfc_threshold_mask = adata_split.uns["rank_genes_groups"]["logfoldchanges"][gene] >= logfc_threshold
550
+ # get DE genes for each target gene
551
+ for gene in gene_targets:
552
+ logfc_threshold_mask = np.abs(adata_split.uns["rank_genes_groups"]["logfoldchanges"][gene]) >= logfc_threshold
477
553
  de_genes = adata_split.uns["rank_genes_groups"]["names"][gene][logfc_threshold_mask]
478
554
  pvals_adj = adata_split.uns["rank_genes_groups"]["pvals_adj"][gene][logfc_threshold_mask]
479
555
  de_genes = de_genes[pvals_adj < pval_cutoff]
@@ -494,20 +570,6 @@ class Mixscape:
494
570
 
495
571
  return indices
496
572
 
497
- def _define_normal_mixscape(self, X: np.ndarray | sparse.spmatrix | pd.DataFrame | None) -> list[float]:
498
- """Calculates the mean and standard deviation of a matrix.
499
-
500
- Args:
501
- X: The matrix to calculate the properties for.
502
-
503
- Returns:
504
- Mean and standard deviation of the matrix.
505
- """
506
- mu = X.mean()
507
- sd = X.std()
508
-
509
- return [mu, sd]
510
-
511
573
  @_doc_params(common_plot_args=doc_common_plot_args)
512
574
  def plot_barplot( # pragma: no cover
513
575
  self,
@@ -522,7 +584,6 @@ class Mixscape:
522
584
  legend_text_size: int = 8,
523
585
  legend_bbox_to_anchor: tuple[float, float] = None,
524
586
  figsize: tuple[float, float] = (25, 25),
525
- show: bool = True,
526
587
  return_fig: bool = False,
527
588
  ) -> Figure | None:
528
589
  """Barplot to visualize perturbation scores calculated by the `mixscape` function.
@@ -548,8 +609,8 @@ class Mixscape:
548
609
  >>> import pertpy as pt
549
610
  >>> mdata = pt.dt.papalexi_2021()
550
611
  >>> ms_pt = pt.tl.Mixscape()
551
- >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
552
- >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
612
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
613
+ >>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
553
614
  >>> ms_pt.plot_barplot(mdata["rna"], guide_rna_column="NT")
554
615
 
555
616
  Preview:
@@ -611,10 +672,9 @@ class Mixscape:
611
672
  fig.subplots_adjust(hspace=0.5, wspace=0.5)
612
673
  plt.tight_layout()
613
674
 
614
- if show:
615
- plt.show()
616
675
  if return_fig:
617
676
  return fig
677
+ plt.show()
618
678
  return None
619
679
 
620
680
  @_doc_params(common_plot_args=doc_common_plot_args)
@@ -630,7 +690,6 @@ class Mixscape:
630
690
  subsample_number: int | None = 900,
631
691
  vmin: float | None = -2,
632
692
  vmax: float | None = 2,
633
- show: bool = True,
634
693
  return_fig: bool = False,
635
694
  **kwds,
636
695
  ) -> Figure | None:
@@ -656,8 +715,8 @@ class Mixscape:
656
715
  >>> import pertpy as pt
657
716
  >>> mdata = pt.dt.papalexi_2021()
658
717
  >>> ms_pt = pt.tl.Mixscape()
659
- >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
660
- >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
718
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
719
+ >>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
661
720
  >>> ms_pt.plot_heatmap(
662
721
  ... adata=mdata["rna"], labels="gene_target", target_gene="IFNGR2", layer="X_pert", control="NT"
663
722
  ... )
@@ -683,10 +742,9 @@ class Mixscape:
683
742
  **kwds,
684
743
  )
685
744
 
686
- if show:
687
- plt.show()
688
745
  if return_fig:
689
746
  return fig
747
+ plt.show()
690
748
  return None
691
749
 
692
750
  @_doc_params(common_plot_args=doc_common_plot_args)
@@ -702,7 +760,6 @@ class Mixscape:
702
760
  split_by: str = None,
703
761
  before_mixscape: bool = False,
704
762
  perturbation_type: str = "KO",
705
- show: bool = True,
706
763
  return_fig: bool = False,
707
764
  ) -> Figure | None:
708
765
  """Density plots to visualize perturbation scores calculated by the `pt.tl.mixscape` function.
@@ -734,8 +791,8 @@ class Mixscape:
734
791
  >>> import pertpy as pt
735
792
  >>> mdata = pt.dt.papalexi_2021()
736
793
  >>> ms_pt = pt.tl.Mixscape()
737
- >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
738
- >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
794
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
795
+ >>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
739
796
  >>> ms_pt.plot_perturbscore(adata=mdata["rna"], labels="gene_target", target_gene="IFNGR2", color="orange")
740
797
 
741
798
  Preview:
@@ -851,10 +908,9 @@ class Mixscape:
851
908
  plt.legend(title="mixscape class", title_fontsize=14, fontsize=12)
852
909
  sns.despine()
853
910
 
854
- if show:
855
- plt.show()
856
911
  if return_fig:
857
912
  return plt.gcf()
913
+ plt.show()
858
914
  return None
859
915
 
860
916
  @_doc_params(common_plot_args=doc_common_plot_args)
@@ -879,7 +935,6 @@ class Mixscape:
879
935
  ylabel: str | Sequence[str] | None = None,
880
936
  rotation: float | None = None,
881
937
  ax: Axes | None = None,
882
- show: bool = True,
883
938
  return_fig: bool = False,
884
939
  **kwargs,
885
940
  ) -> Axes | Figure | None:
@@ -910,8 +965,8 @@ class Mixscape:
910
965
  >>> import pertpy as pt
911
966
  >>> mdata = pt.dt.papalexi_2021()
912
967
  >>> ms_pt = pt.tl.Mixscape()
913
- >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
914
- >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
968
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
969
+ >>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
915
970
  >>> ms_pt.plot_violin(
916
971
  ... adata=mdata["rna"], target_gene_idents=["NT", "IFNGR2 NP", "IFNGR2 KO"], groupby="mixscape_class"
917
972
  ... )
@@ -1047,12 +1102,9 @@ class Mixscape:
1047
1102
  if rotation is not None:
1048
1103
  ax.tick_params(axis="x", labelrotation=rotation)
1049
1104
 
1050
- show = settings.autoshow if show is None else show
1051
1105
  if hue is not None and stripplot is True:
1052
1106
  plt.legend(handles, labels)
1053
1107
 
1054
- if show:
1055
- plt.show()
1056
1108
  if return_fig:
1057
1109
  if multi_panel and groupby is None and len(ys) == 1:
1058
1110
  return g
@@ -1060,6 +1112,7 @@ class Mixscape:
1060
1112
  return axs[0]
1061
1113
  else:
1062
1114
  return axs
1115
+ plt.show()
1063
1116
  return None
1064
1117
 
1065
1118
  @_doc_params(common_plot_args=doc_common_plot_args)
@@ -1076,7 +1129,6 @@ class Mixscape:
1076
1129
  color_map: Colormap | str | None = None,
1077
1130
  palette: str | Sequence[str] | None = None,
1078
1131
  ax: Axes | None = None,
1079
- show: bool = True,
1080
1132
  return_fig: bool = False,
1081
1133
  **kwds,
1082
1134
  ) -> Figure | None:
@@ -1097,9 +1149,9 @@ class Mixscape:
1097
1149
  >>> import pertpy as pt
1098
1150
  >>> mdata = pt.dt.papalexi_2021()
1099
1151
  >>> ms_pt = pt.tl.Mixscape()
1100
- >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
1101
- >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
1102
- >>> ms_pt.lda(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
1152
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
1153
+ >>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
1154
+ >>> ms_pt.lda(mdata["rna"], "gene_target", "NT", split_by="replicate")
1103
1155
  >>> ms_pt.plot_lda(adata=mdata["rna"], control="NT")
1104
1156
 
1105
1157
  Preview:
@@ -1129,8 +1181,44 @@ class Mixscape:
1129
1181
  **kwds,
1130
1182
  )
1131
1183
 
1132
- if show:
1133
- plt.show()
1134
1184
  if return_fig:
1135
1185
  return fig
1186
+ plt.show()
1136
1187
  return None
1188
+
1189
+ class MixscapeGaussianMixture(GaussianMixture):
1190
+ def __init__(
1191
+ self,
1192
+ n_components: int,
1193
+ fixed_means: Sequence[float] | None = None,
1194
+ fixed_covariances: Sequence[float] | None = None,
1195
+ **kwargs
1196
+ ):
1197
+ """Custom Gaussian Mixture Model where means and covariances can be fixed for specific components.
1198
+
1199
+ Args:
1200
+ n_components: Number of Gaussian components
1201
+ fixed_means: Means to fix (use None for those that should be estimated)
1202
+ fixed_covariances: Covariances to fix (use None for those that should be estimated)
1203
+ kwargs: Additional arguments passed to scikit-learn's GaussianMixture
1204
+ """
1205
+ super().__init__(n_components=n_components, **kwargs)
1206
+ self.fixed_means = fixed_means
1207
+ self.fixed_covariances = fixed_covariances
1208
+
1209
+ def _m_step(self, X: np.ndarray, log_resp: np.ndarray):
1210
+ """Modified M-step to respect fixed means and covariances."""
1211
+ super()._m_step(X, log_resp)
1212
+
1213
+ if self.fixed_means is not None:
1214
+ for i in range(self.n_components):
1215
+ if self.fixed_means[i] is not None:
1216
+ self.means_[i] = self.fixed_means[i]
1217
+
1218
+ if self.fixed_covariances is not None:
1219
+ for i in range(self.n_components):
1220
+ if self.fixed_covariances[i] is not None:
1221
+ self.covariances_[i] = self.fixed_covariances[i]
1222
+
1223
+ return self
1224
+