pertpy 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pertpy/tools/_coda/_tasccoda.py CHANGED
@@ -33,24 +33,6 @@ config.update("jax_enable_x64", True)
  class Tasccoda(CompositionalModel2):
  r"""Statistical model for tree-aggregated differential composition analysis (tascCODA, Ostner et al., 2021).

- The hierarchical formulation of the model for one sample is:
-
- .. math::
- \\begin{align*}
- Y_i &\\sim \\textrm{DirMult}(\\bar{Y}_i, \\textbf{a}(\\textbf{x})_i)\\\\
- \\log(\\textbf{a}(X))_i &= \\alpha + X_{i, \\cdot} \\beta\\\\
- \\alpha_j &\\sim \\mathcal{N}(0, 10) & \\forall j\\in[p]\\\\
- \\beta &= \\hat{\\beta} A^T \\\\
- \\hat{\\beta}_{l, k} &= 0 & \\forall k \\in \\hat{v}, l \\in [d]\\\\
- \\hat{\\beta}_{l, k} &= \\theta \\tilde{\\beta}_{1, l, k} + (1- \\theta) \\tilde{\\beta}_{0, l, k} \\quad & \\forall k\\in\\{[v] \\smallsetminus \\hat{v}\\}, l \\in [d]\\\\
- \\tilde{\\beta}_{m, l, k} &= \\sigma_{m, l, k} * b_{m, l, k} \\quad & \\forall k\\in\\{[v] \\smallsetminus \\hat{v}\\}, m \\in \\{0, 1\\}, l \\in [d]\\\\
- \\sigma_{m, l, k} &\\sim \\textrm{Exp}(\\lambda_{m, l, k}^2/2) \\quad & \\forall k\\in\\{[v] \\smallsetminus \\hat{v}\\}, l \\in \\{0, 1\\}, l \\in [d]\\\\
- b_{m, l, k} &\\sim N(0,1) \\quad & \\forall k\\in\\{[v] \\smallsetminus \\hat{v}\\}, l \\in \\{0, 1\\}, l \\in [d]\\\\
- \\theta &\\sim \\textrm{Beta}(1, \\frac{1}{|\\{[v] \\smallsetminus \\hat{v}\\}|})
- \\end{align*}
-
- with Y being the cell counts, X the covariates, and v the set of nodes of the underlying tree structure.
-
  For further information, see `tascCODA: Bayesian Tree-Aggregated Analysis of Compositional Amplicon and Single-Cell Data`
  (Ostner et al., 2021)
  """
@@ -75,11 +57,14 @@ class Tasccoda(CompositionalModel2):
  modality_key_1: str = "rna",
  modality_key_2: str = "coda",
  ) -> MuData:
- """Prepare a MuData object for subsequent processing. If type is "cell_level", then create a compositional analysis dataset from the input adata. If type is "sample_level", generate ete tree for tascCODA models from dendrogram information or cell-level observations.
+ """Prepare a MuData object for subsequent processing.
+
+ If type is "cell_level", then create a compositional analysis dataset from the input adata.
+ If type is "sample_level", generate ete tree for tascCODA models from dendrogram information or cell-level observations.

- When using ``type="cell_level"``, ``adata`` needs to have a column in ``adata.obs`` that contains the cell type assignment.
+ When using `type="cell_level"`, `adata` needs to have a column in `adata.obs` that contains the cell type assignment.
  Further, it must contain one column or a set of columns (e.g. subject id, treatment, disease status) that uniquely identify each (statistical) sample.
- Further covariates (e.g. subject age) can either be specified via addidional column names in ``adata.obs``, a key in ``adata.uns``, or as a separate DataFrame.
+ Further covariates (e.g. subject age) can either be specified via addidional column names in `adata.obs`, a key in `adata.uns`, or as a separate DataFrame.

  Args:
  adata: AnnData object.
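A hedged sketch of the cell-level path this docstring describes; the obs column names are invented, and the `cell_type_identifier`/`sample_identifier` parameter names are assumptions not shown in this diff (only `covariate_obs`/`covariate_df` appear in the Args section):

    import pertpy as pt

    tasccoda = pt.tl.Tasccoda()
    # Aggregate cell-level observations into a sample-by-cell-type count table.
    mdata = tasccoda.load(
        adata,
        type="cell_level",
        cell_type_identifier="cell_type",  # assumed parameter name, hypothetical obs column
        sample_identifier="sample_id",     # assumed parameter name, hypothetical obs column
        covariate_obs=["age"],             # covariate_obs is documented in the Args section
    )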
@@ -90,10 +75,13 @@ class Tasccoda(CompositionalModel2):
  covariate_obs: If type is "cell_level", specify list of keys for adata.obs, where covariate values are stored.
  covariate_df: If type is "cell_level", specify dataFrame with covariates.
  dendrogram_key: Key to the scanpy.tl.dendrogram result in `.uns` of original cell level anndata object.
- levels_orig: List that indicates which columns in `.obs` of the original data correspond to tree levels. The list must begin with the root level, and end with the leaf level.
- levels_agg: List that indicates which columns in `.var` of the aggregated data correspond to tree levels. The list must begin with the root level, and end with the leaf level.
+ levels_orig: List that indicates which columns in `.obs` of the original data correspond to tree levels.
+ The list must begin with the root level, and end with the leaf level.
+ levels_agg: List that indicates which columns in `.var` of the aggregated data correspond to tree levels.
+ The list must begin with the root level, and end with the leaf level.
  add_level_name: If True, internal nodes in the tree will be named as "{level_name}_{node_name}" instead of just {level_name}.
- key_added: If not specified, the tree is stored in .uns[tree]. If `data` is AnnData, save tree in `data`. If `data` is MuData, save tree in data[modality_2].
+ key_added: If not specified, the tree is stored in `.uns['tree']`.
+ If `data` is AnnData, save tree in `data`. If `data` is MuData, save tree in data[modality_2].
  modality_key_1: Key to the cell-level AnnData in the MuData object.
  modality_key_2: Key to the aggregated sample-level AnnData object in the MuData object.

@@ -120,8 +108,10 @@ class Tasccoda(CompositionalModel2):
  covariate_df=covariate_df,
  )
  mdata = MuData({modality_key_1: adata, modality_key_2: adata_coda})
- else:
+ elif type == "sample_level":
  mdata = MuData({modality_key_1: AnnData(), modality_key_2: adata})
+ else:
+ raise ValueError(f'{type} is not a supported type, expected "cell_level" or "sample_level".')
  import_tree(
  data=mdata,
  modality_1=modality_key_1,
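With this change an unsupported `type` raises a `ValueError` instead of silently falling into the sample-level branch. A short sketch of the two accepted values, reusing the call from the Examples later in this file:

    import pertpy as pt

    adata = pt.dt.tasccoda_example()
    tasccoda = pt.tl.Tasccoda()

    # Accepted values are "cell_level" and "sample_level".
    mdata = tasccoda.load(
        adata, type="sample_level",
        levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
        key_added="lineage", add_level_name=True,
    )

    # Anything else now fails loudly in 1.0.3:
    # tasccoda.load(adata, type="bulk_level")  # ValueError (1.0.1 treated this as sample-level)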
@@ -464,7 +454,7 @@ class Tasccoda(CompositionalModel2):
  self,
  data: AnnData | MuData,
  modality_key: str = "coda",
- rng_key=None,
+ rng_key: int | None = None,
  num_prior_samples: int = 500,
  use_posterior_predictive: bool = True,
  ) -> az.InferenceData:
@@ -547,6 +537,8 @@ class Tasccoda(CompositionalModel2):
  if rng_key is None:
  rng = np.random.default_rng()
  rng_key = random.key(rng.integers(0, 10000))
+ else:
+ rng_key = random.key(rng_key)

  if use_posterior_predictive:
  posterior_predictive = Predictive(self.model, self.mcmc.get_samples())(
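The new `else` branch wraps an integer seed into a typed JAX PRNG key instead of passing the bare int on to numpyro's `Predictive`. A minimal sketch of that conversion (assuming `from jax import random`, which the existing `random.key` call implies):

    from jax import random

    seed = 42               # what a caller passes as rng_key
    key = random.key(seed)  # typed PRNG key expected by JAX/numpyro samplers
    key_pp, key_prior = random.split(key)  # keys can then be split per consumer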
@@ -557,6 +549,15 @@ class Tasccoda(CompositionalModel2):
  ref_index=ref_index,
  sample_adata=sample_adata,
  )
+ # Remove problematic posterior predictive arrays with wrong dimensions
+ if posterior_predictive and "counts" in posterior_predictive:
+ counts_shape = posterior_predictive["counts"].shape
+ expected_dims = 2 # ['sample', 'cell_type']
+ if len(counts_shape) != expected_dims:
+ posterior_predictive = {k: v for k, v in posterior_predictive.items() if k != "counts"}
+ logger.warning(
+ f"Removed 'counts' from posterior_predictive due to dimension mismatch: got {len(counts_shape)}D, expected {expected_dims}D"
+ )
  else:
  posterior_predictive = None

@@ -569,6 +570,15 @@ class Tasccoda(CompositionalModel2):
  ref_index=ref_index,
  sample_adata=sample_adata,
  )
+ # Remove problematic prior arrays with wrong dimensions
+ if prior and "counts" in prior:
+ counts_shape = prior["counts"].shape
+ expected_dims = 2 # ['sample', 'cell_type']
+ if len(counts_shape) != expected_dims:
+ prior = {k: v for k, v in prior.items() if k != "counts"}
+ logger.warning(
+ f"Removed 'counts' from prior due to dimension mismatch: got {len(counts_shape)}D, expected {expected_dims}D"
+ )
  else:
  prior = None

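Both guards apply the same pattern: before the draws are handed to arviz, a "counts" array whose dimensionality does not match the declared ('sample', 'cell_type') coordinates is dropped so that the `az.InferenceData` construction does not fail. A standalone sketch of the idea with made-up shapes:

    import numpy as np

    # Hypothetical dictionary of predictive draws
    draws = {"counts": np.zeros((500, 10, 8)), "alpha": np.zeros((500, 8))}
    expected_dims = 2  # ['sample', 'cell_type']

    if "counts" in draws and draws["counts"].ndim != expected_dims:
        # Drop only the offending array and keep every other site.
        draws = {k: v for k, v in draws.items() if k != "counts"}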
@@ -592,80 +602,88 @@ class Tasccoda(CompositionalModel2):
  *args,
  **kwargs,
  ):
- """Examples:
- >>> import pertpy as pt
- >>> adata = pt.dt.tasccoda_example()
- >>> tasccoda = pt.tl.Tasccoda()
- >>> mdata = tasccoda.load(
- >>> adata, type="sample_level",
- >>> levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
- >>> key_added="lineage", add_level_name=True
- >>> )
- >>> mdata = tasccoda.prepare(
- >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
- >>> )
- >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42).
- """ # noqa: D205
+ """
+
+ Examples:
+ >>> import pertpy as pt
+ >>> adata = pt.dt.tasccoda_example()
+ >>> tasccoda = pt.tl.Tasccoda()
+ >>> mdata = tasccoda.load(
+ >>> adata, type="sample_level",
+ >>> levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
+ >>> key_added="lineage", add_level_name=True
+ >>> )
+ >>> mdata = tasccoda.prepare(
+ >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
+ >>> )
+ >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42).
+ """ # noqa: D205, D212
  return super().run_nuts(data, modality_key, num_samples, num_warmup, rng_key, copy, *args, **kwargs)

  run_nuts.__doc__ = CompositionalModel2.run_nuts.__doc__ + run_nuts.__doc__

  def summary(self, data: AnnData | MuData, extended: bool = False, modality_key: str = "coda", *args, **kwargs):
- """Examples:
- >>> import pertpy as pt
- >>> adata = pt.dt.tasccoda_example()
- >>> tasccoda = pt.tl.Tasccoda()
- >>> mdata = tasccoda.load(
- >>> adata, type="sample_level",
- >>> levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
- >>> key_added="lineage", add_level_name=True
- >>> )
- >>> mdata = tasccoda.prepare(
- >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
- >>> )
- >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
- >>> tasccoda.summary(mdata).
- """ # noqa: D205
+ """
+
+ Examples:
+ >>> import pertpy as pt
+ >>> adata = pt.dt.tasccoda_example()
+ >>> tasccoda = pt.tl.Tasccoda()
+ >>> mdata = tasccoda.load(
+ >>> adata, type="sample_level",
+ >>> levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
+ >>> key_added="lineage", add_level_name=True
+ >>> )
+ >>> mdata = tasccoda.prepare(
+ >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
+ >>> )
+ >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
+ >>> tasccoda.summary(mdata).
+ """ # noqa: D205, D212
  return super().summary(data, extended, modality_key, *args, **kwargs)

  summary.__doc__ = CompositionalModel2.summary.__doc__ + summary.__doc__

  def credible_effects(self, data: AnnData | MuData, modality_key: str = "coda", est_fdr: float = None) -> pd.Series:
- """Examples:
- >>> import pertpy as pt
- >>> adata = pt.dt.tasccoda_example()
- >>> tasccoda = pt.tl.Tasccoda()
- >>> mdata = tasccoda.load(
- >>> adata, type="sample_level",
- >>> levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
- >>> key_added="lineage", add_level_name=True
- >>> )
- >>> mdata = tasccoda.prepare(
- >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
- >>> )
- >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
- >>> tasccoda.credible_effects(mdata).
- """ # noqa: D205
+ """
+
+ Examples:
+ >>> import pertpy as pt
+ >>> adata = pt.dt.tasccoda_example()
+ >>> tasccoda = pt.tl.Tasccoda()
+ >>> mdata = tasccoda.load(
+ >>> adata, type="sample_level",
+ >>> levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
+ >>> key_added="lineage", add_level_name=True
+ >>> )
+ >>> mdata = tasccoda.prepare(
+ >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
+ >>> )
+ >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
+ >>> tasccoda.credible_effects(mdata).
+ """ # noqa: D205, D212
  return super().credible_effects(data, modality_key, est_fdr)

  credible_effects.__doc__ = CompositionalModel2.credible_effects.__doc__ + credible_effects.__doc__

  def set_fdr(self, data: AnnData | MuData, est_fdr: float, modality_key: str = "coda", *args, **kwargs):
- """Examples:
- >>> import pertpy as pt
- >>> adata = pt.dt.tasccoda_example()
- >>> tasccoda = pt.tl.Tasccoda()
- >>> mdata = tasccoda.load(
- >>> adata, type="sample_level",
- >>> levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
- >>> key_added="lineage", add_level_name=True
- >>> )
- >>> mdata = tasccoda.prepare(
- >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
- >>> )
- >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
- >>> tasccoda.set_fdr(mdata, est_fdr=0.4).
- """ # noqa: D205
+ """
+
+ Examples:
+ >>> import pertpy as pt
+ >>> adata = pt.dt.tasccoda_example()
+ >>> tasccoda = pt.tl.Tasccoda()
+ >>> mdata = tasccoda.load(
+ >>> adata, type="sample_level",
+ >>> levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
+ >>> key_added="lineage", add_level_name=True
+ >>> )
+ >>> mdata = tasccoda.prepare(
+ >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
+ >>> )
+ >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
+ >>> tasccoda.set_fdr(mdata, est_fdr=0.4).
+ """ # noqa: D205, D212
  return super().set_fdr(data, est_fdr, modality_key, *args, **kwargs)

  set_fdr.__doc__ = CompositionalModel2.set_fdr.__doc__ + set_fdr.__doc__
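All four wrappers keep only an `Examples:` block in their own docstring and splice it onto the parent class documentation at class-creation time, which is why the reformatted docstrings above start with a bare `"""` followed by a blank line. A minimal sketch of the pattern with placeholder classes:

    class Base:
        def run(self):
            """Run the model.

            Args:
                Full parameter documentation lives on the parent method.
            """

    class Child(Base):
        def run(self):
            """

            Examples:
                >>> Child().run()
            """  # noqa: D205, D212
            return super().run()

        # Concatenate: parent Args section + child Examples section.
        run.__doc__ = Base.run.__doc__ + run.__doc__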
pertpy/tools/_mixscape.py CHANGED
@@ -177,7 +177,7 @@ class Mixscape:
  def mixscape(
  self,
  adata: AnnData,
- labels: str,
+ pert_key: str,
  control: str,
  *,
  new_class_name: str | None = "mixscape_class",
@@ -201,12 +201,12 @@ class Mixscape:

  Args:
  adata: The annotated data object.
- labels: The column of `.obs` with target gene labels.
+ pert_key: The column of `.obs` with target gene labels.
  control: Control category from the `labels` column.
  new_class_name: Name of mixscape classification to be stored in `.obs`.
  layer: Key from adata.layers whose value will be used to perform tests on. Default is using `.layers["X_pert"]`.
  min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
- logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells (default: 0.25).
+ logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells.
  de_layer: Layer to use for identifying differentially expressed genes. If `None`, adata.X is used.
  test_method: Method to use for differential expression testing.
  iter_num: Number of normalmixEM iterations to run if convergence does not occur.
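The `labels` argument is renamed to `pert_key` across the Mixscape API in 1.0.3, as the remaining hunks in this file show. A hedged before/after sketch (the obs column name and control label are invented, and `adata` is assumed to be already prepared, e.g. with a perturbation signature layer):

    import pertpy as pt

    ms = pt.tl.Mixscape()

    # pertpy 1.0.1
    # ms.mixscape(adata, labels="gene_target", control="NT")

    # pertpy 1.0.3
    ms.mixscape(adata, pert_key="gene_target", control="NT")
    ms.lda(adata, pert_key="gene_target", control="NT")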
@@ -256,7 +256,7 @@ class Mixscape:
  adata=adata,
  split_masks=split_masks,
  categories=categories,
- labels=labels,
+ pert_key=pert_key,
  control=control,
  layer=de_layer,
  pval_cutoff=pval_cutoff,
@@ -278,7 +278,7 @@ class Mixscape:

  # initialize return variables
  adata.obs[f"{new_class_name}_p_{perturbation_type.lower()}"] = 0
- adata.obs[new_class_name] = adata.obs[labels].astype(str)
+ adata.obs[new_class_name] = adata.obs[pert_key].astype(str)
  adata.obs[f"{new_class_name}_global"] = np.empty(
  [
  adata.n_obs,
@@ -290,12 +290,12 @@ class Mixscape:
  adata.obs[f"{new_class_name}_p_{perturbation_type.lower()}"] = 0.0
  for split, split_mask in enumerate(split_masks):
  category = categories[split]
- gene_targets = list(set(adata[split_mask].obs[labels]).difference([control]))
+ gene_targets = list(set(adata[split_mask].obs[pert_key]).difference([control]))
  for gene in gene_targets:
  post_prob = 0
- orig_guide_cells = (adata.obs[labels] == gene) & split_mask
+ orig_guide_cells = (adata.obs[pert_key] == gene) & split_mask
  orig_guide_cells_index = list(orig_guide_cells.index[orig_guide_cells])
- nt_cells = (adata.obs[labels] == control) & split_mask
+ nt_cells = (adata.obs[pert_key] == control) & split_mask
  all_cells = orig_guide_cells | nt_cells

  if len(perturbation_markers[(category, gene)]) == 0:
@@ -307,7 +307,11 @@ class Mixscape:

  dat = X[np.asarray(all_cells)][:, de_genes_indices]
  if scale:
- dat = sc.pp.scale(dat)
+ with warnings.catch_warnings():
+ warnings.filterwarnings(
+ "ignore", message="zero-centering a sparse array/matrix densifies it."
+ )
+ dat = sc.pp.scale(dat)
  converged = False
  n_iter = 0

@@ -335,10 +339,10 @@ class Mixscape:
  pvec = pd.Series(np.asarray(pvec).flatten(), index=list(all_cells.index[all_cells]))

  if n_iter == 0:
- gv = pd.DataFrame(columns=["pvec", labels])
+ gv = pd.DataFrame(columns=["pvec", pert_key])
  gv["pvec"] = pvec
- gv[labels] = control
- gv.loc[guide_cells, labels] = gene
+ gv[pert_key] = control
+ gv.loc[guide_cells, pert_key] = gene
  if gene not in gv_list:
  gv_list[gene] = {}
  gv_list[gene][category] = gv
@@ -389,7 +393,7 @@ class Mixscape:
  def lda(
  self,
  adata: AnnData,
- labels: str,
+ pert_key: str,
  control: str,
  *,
  mixscape_class_global: str | None = "mixscape_class_global",
@@ -407,7 +411,7 @@ class Mixscape:

  Args:
  adata: The annotated data object.
- labels: The column of `.obs` with target gene labels.
+ pert_key: The column of `.obs` with target gene labels.
  control: Control category from the `pert_key` column.
  mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
  layer: Layer to use for identifying differentially expressed genes. If `None`, adata.X is used.
@@ -456,7 +460,7 @@ class Mixscape:
  adata=adata,
  split_masks=split_masks,
  categories=categories,
- labels=labels,
+ pert_key=pert_key,
  control=control,
  layer=layer,
  pval_cutoff=pval_cutoff,
@@ -475,17 +479,19 @@ class Mixscape:
  continue
  else:
  gene_subset = adata_subset[
- (adata_subset.obs[labels] == key[1]) | (adata_subset.obs[labels] == control)
+ (adata_subset.obs[pert_key] == key[1]) | (adata_subset.obs[pert_key] == control)
  ].copy()
- sc.pp.scale(gene_subset)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", UserWarning)
+ sc.pp.scale(gene_subset)
  sc.tl.pca(gene_subset, n_comps=n_comps)
  # project cells into PCA space of gene_subset
  projected_pcs[key[1]] = np.asarray(np.dot(X, gene_subset.varm["PCs"]))
  # concatenate all pcs into a single matrix.
  projected_pcs_array = np.concatenate(list(projected_pcs.values()), axis=1)

- clf = LinearDiscriminantAnalysis(n_components=len(np.unique(adata_subset.obs[labels])) - 1)
- clf.fit(projected_pcs_array, adata_subset.obs[labels])
+ clf = LinearDiscriminantAnalysis(n_components=len(np.unique(adata_subset.obs[pert_key])) - 1)
+ clf.fit(projected_pcs_array, adata_subset.obs[pert_key])
  cell_embeddings = clf.transform(projected_pcs_array)
  adata.uns["mixscape_lda"] = cell_embeddings

@@ -495,9 +501,10 @@ class Mixscape:
  def _get_perturbation_markers(
  self,
  adata: AnnData,
+ *,
  split_masks: list[np.ndarray],
  categories: list[str],
- labels: str,
+ pert_key: str,
  control: str,
  layer: str,
  pval_cutoff: float,
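The bare `*` added to `_get_perturbation_markers` makes every argument after `adata` keyword-only, so the internal call sites shown above have to spell each argument out by name. A generic sketch of the effect (the function and values here are hypothetical):

    def get_markers(adata, *, split_masks, categories, pert_key, control):
        """Stand-in with the same argument layout as _get_perturbation_markers."""
        return pert_key, control

    masks, cats = [], []
    # get_markers(None, masks, cats, "gene_target", "NT")  # TypeError: takes 1 positional argument
    get_markers(None, split_masks=masks, categories=cats, pert_key="gene_target", control="NT")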
@@ -511,7 +518,7 @@ class Mixscape:
  adata: :class:`~anndata.AnnData` object
  split_masks: List of boolean masks for each split/group.
  categories: List of split/group names.
- labels: The column of `.obs` with target gene labels.
+ pert_key: The column of `.obs` with target gene labels.
  control: Control category from the `labels` column.
  layer: Key from adata.layers whose value will be used to compare gene expression.
  pval_cutoff: P-value cut-off for selection of significantly DE genes.
@@ -526,7 +533,7 @@ class Mixscape:
  for split, split_mask in enumerate(split_masks):
  category = categories[split]
  # get gene sets for each split
- gene_targets = list(set(adata[split_mask].obs[labels]).difference([control]))
+ gene_targets = list(set(adata[split_mask].obs[pert_key]).difference([control]))
  adata_split = adata[split_mask].copy()
  # find top DE genes between cells with targeting and non-targeting gRNAs
  with warnings.catch_warnings():
@@ -535,7 +542,7 @@ class Mixscape:
  sc.tl.rank_genes_groups(
  adata_split,
  layer=layer,
- groupby=labels,
+ groupby=pert_key,
  groups=gene_targets,
  reference=control,
  method=test_method,
@@ -666,7 +673,7 @@ class Mixscape:
  def plot_heatmap( # pragma: no cover # noqa: D417
  self,
  adata: AnnData,
- labels: str,
+ pert_key: str,
  target_gene: str,
  control: str,
  *,
@@ -682,7 +689,7 @@ class Mixscape:

  Args:
  adata: The annotated data object.
- labels: The column of `.obs` with target gene labels.
+ pert_key: The column of `.obs` with target gene labels.
  target_gene: Target gene name to visualize heatmap for.
  control: Control category from the `pert_key` column.
  layer: Key from `adata.layers` whose value will be used to perform tests on.
@@ -711,12 +718,13 @@ class Mixscape:
  """
  if "mixscape_class" not in adata.obs:
  raise ValueError("Please run `pt.tl.mixscape` first.")
- adata_subset = adata[(adata.obs[labels] == target_gene) | (adata.obs[labels] == control)].copy()
+ adata_subset = adata[(adata.obs[pert_key] == target_gene) | (adata.obs[pert_key] == control)].copy()
  with warnings.catch_warnings():
  warnings.simplefilter("ignore", RuntimeWarning)
  warnings.simplefilter("ignore", PerformanceWarning)
- sc.tl.rank_genes_groups(adata_subset, layer=layer, groupby=labels, method=method)
- sc.pp.scale(adata_subset, max_value=vmax)
+ warnings.simplefilter("ignore", UserWarning)
+ sc.tl.rank_genes_groups(adata_subset, layer=layer, groupby=pert_key, method=method)
+ sc.pp.scale(adata_subset, max_value=vmax)
  sc.pp.subsample(adata_subset, n_obs=subsample_number)

  fig = sc.pl.rank_genes_groups_heatmap(
@@ -739,7 +747,7 @@ class Mixscape:
  def plot_perturbscore( # pragma: no cover # noqa: D417
  self,
  adata: AnnData,
- labels: str,
+ pert_key: str,
  target_gene: str,
  *,
  mixscape_class: str = "mixscape_class",
@@ -758,7 +766,7 @@ class Mixscape:

  Args:
  adata: The annotated data object.
- labels: The column of `.obs` with target gene labels.
+ pert_key: The column of `.obs` with target gene labels.
  target_gene: Target gene name to visualize perturbation scores for.
  mixscape_class: The column of `.obs` with mixscape classifications.
  color: Specify color of target gene class or knockout cell class. For control non-targeting and non-perturbed cells, colors are set to different shades of grey.
@@ -797,21 +805,21 @@ class Mixscape:
  else:
  perturbation_score = pd.concat([perturbation_score, perturbation_score_temp])
  perturbation_score["mix"] = adata.obs[mixscape_class][perturbation_score.index]
- gd = list(set(perturbation_score[labels]).difference({target_gene}))[0]
+ gd = list(set(perturbation_score[pert_key]).difference({target_gene}))[0]

  # If before_mixscape is True, split densities based on original target gene classification
  if before_mixscape is True:
  palette = {gd: "#7d7d7d", target_gene: color}
- plot_dens = sns.kdeplot(data=perturbation_score, x="pvec", hue=labels, fill=False, common_norm=False)
+ plot_dens = sns.kdeplot(data=perturbation_score, x="pvec", hue=pert_key, fill=False, common_norm=False)
  top_r = max(plot_dens.get_lines()[cond].get_data()[1].max() for cond in range(len(plot_dens.get_lines())))
  plt.close()
  perturbation_score["y_jitter"] = perturbation_score["pvec"]
  rng = np.random.default_rng()
- perturbation_score.loc[perturbation_score[labels] == gd, "y_jitter"] = rng.uniform(
- low=0.001, high=top_r / 10, size=sum(perturbation_score[labels] == gd)
+ perturbation_score.loc[perturbation_score[pert_key] == gd, "y_jitter"] = rng.uniform(
+ low=0.001, high=top_r / 10, size=sum(perturbation_score[pert_key] == gd)
  )
- perturbation_score.loc[perturbation_score[labels] == target_gene, "y_jitter"] = rng.uniform(
- low=-top_r / 10, high=0, size=sum(perturbation_score[labels] == target_gene)
+ perturbation_score.loc[perturbation_score[pert_key] == target_gene, "y_jitter"] = rng.uniform(
+ low=-top_r / 10, high=0, size=sum(perturbation_score[pert_key] == target_gene)
  )
  # If split_by is provided, split densities based on the split_by
  if split_by is not None:
@@ -844,7 +852,7 @@ class Mixscape:
  else:
  if palette is None:
  palette = {gd: "#7d7d7d", f"{target_gene} NP": "#c9c9c9", f"{target_gene} {perturbation_type}": color}
- plot_dens = sns.kdeplot(data=perturbation_score, x="pvec", hue=labels, fill=False, common_norm=False)
+ plot_dens = sns.kdeplot(data=perturbation_score, x="pvec", hue=pert_key, fill=False, common_norm=False)
  top_r = max(plot_dens.get_lines()[i].get_data()[1].max() for i in range(len(plot_dens.get_lines())))
  plt.close()
  perturbation_score["y_jitter"] = perturbation_score["pvec"]
@@ -899,6 +907,7 @@ class Mixscape:
  if return_fig:
  return plt.gcf()
  plt.show()
+
  return None

  @_doc_params(common_plot_args=doc_common_plot_args)
@@ -1058,7 +1067,7 @@ class Mixscape:
  data=obs_tidy,
  order=order,
  orient="vertical",
- scale=scale,
+ density_norm=scale,
  ax=ax,
  hue=hue,
  **kwargs,
@@ -1072,7 +1081,7 @@ class Mixscape:
  data=obs_tidy,
  order=order,
  jitter=jitter,
- color="black",
+ palette="dark:black",
  size=size,
  ax=ax,
  hue=hue,
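Both replacements track seaborn's newer API: the violinplot keyword `scale=` was renamed to `density_norm=`, and passing a fixed `color=` together with `hue=` is deprecated in favour of a palette, here the single-colour "dark:black" specification. A hedged sketch outside pertpy (the DataFrame and column names are invented):

    import pandas as pd
    import seaborn as sns

    df = pd.DataFrame({"group": ["a", "a", "b", "b"], "value": [0.1, 0.4, 0.3, 0.9]})

    ax = sns.violinplot(data=df, x="group", y="value", hue="group", density_norm="width")
    sns.stripplot(data=df, x="group", y="value", hue="group", palette="dark:black", size=2, ax=ax)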
@@ -22,7 +22,7 @@ class PerturbationComparison:
  ) -> float:
  """Compare classification accuracy between real and simulated perturbations.

- Trains a classifier on the real perturbation data + the control data and reports a normalized
+ Trains a classifier on the real perturbation data & the control data and reports a normalized
  classification accuracy on the simulated perturbation.

  Args:
@@ -64,8 +64,8 @@
  real: Real perturbed data.
  simulated: Simulated perturbed data.
  control: Control data
- use_simulated_for_knn: Include simulted perturbed data (`simulated`) into the knn graph. Only valid when
- control (`control`) is provided.
+ use_simulated_for_knn: Include simulted perturbed data (`simulated`) into the knn graph.
+ Only valid when control (`control`) is provided.
  n_neighbors: Number of neighbors to use in k-neighbor graph.
  random_state: Random state used for k-neighbor graph construction.
  n_jobs: Number of cores to use. Defaults to -1 (all).