pertpy 0.9.3__py3-none-any.whl → 0.9.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pertpy/tools/_mixscape.py CHANGED
@@ -18,6 +18,7 @@ from scipy.sparse import csr_matrix, issparse, spmatrix
18
18
  from sklearn.mixture import GaussianMixture
19
19
 
20
20
  import pertpy as pt
21
+ from pertpy._doc import _doc_params, doc_common_plot_args
21
22
 
22
23
  if TYPE_CHECKING:
23
24
  from collections.abc import Sequence
@@ -25,6 +26,7 @@ if TYPE_CHECKING:
25
26
  from anndata import AnnData
26
27
  from matplotlib.axes import Axes
27
28
  from matplotlib.colors import Colormap
29
+ from matplotlib.pyplot import Figure
28
30
  from scipy import sparse
29
31
 
30
32
 
@@ -102,7 +104,7 @@ class Mixscape:
102
104
  control_mask_split = control_mask & split_mask
103
105
 
104
106
  R_split = representation[split_mask]
105
- R_control = representation[control_mask_split]
107
+ R_control = representation[np.asarray(control_mask_split)]
106
108
 
107
109
  from pynndescent import NNDescent
108
110
 
@@ -110,7 +112,7 @@ class Mixscape:
110
112
  nn_index = NNDescent(R_control, **kwargs)
111
113
  indices, _ = nn_index.query(R_split, k=n_neighbors, epsilon=eps)
112
114
 
113
- X_control = np.expm1(adata.X[control_mask_split])
115
+ X_control = np.expm1(adata.X[np.asarray(control_mask_split)])
114
116
 
115
117
  n_split = split_mask.sum()
116
118
  n_control = X_control.shape[0]
@@ -254,7 +256,7 @@ class Mixscape:
254
256
  else:
255
257
  de_genes = perturbation_markers[(category, gene)]
256
258
  de_genes_indices = self._get_column_indices(adata, list(de_genes))
257
- dat = X[all_cells][:, de_genes_indices]
259
+ dat = X[np.asarray(all_cells)][:, de_genes_indices]
258
260
  converged = False
259
261
  n_iter = 0
260
262
  old_classes = adata.obs[labels][all_cells]
@@ -264,8 +266,8 @@ class Mixscape:
264
266
  # get average value for each gene over all selected cells
265
267
  # all cells in current split&Gene minus all NT cells in current split
266
268
  # Each row is for each cell, each column is for each gene, get mean for each column
267
- vec = np.mean(X[guide_cells][:, de_genes_indices], axis=0) - np.mean(
268
- X[nt_cells][:, de_genes_indices], axis=0
269
+ vec = np.mean(X[np.asarray(guide_cells)][:, de_genes_indices], axis=0) - np.mean(
270
+ X[np.asarray(nt_cells)][:, de_genes_indices], axis=0
269
271
  )
270
272
  # project cells onto the perturbation vector
271
273
  if isinstance(dat, spmatrix):
@@ -506,21 +508,23 @@ class Mixscape:
506
508
 
507
509
  return [mu, sd]
508
510
 
511
+ @_doc_params(common_plot_args=doc_common_plot_args)
509
512
  def plot_barplot( # pragma: no cover
510
513
  self,
511
514
  adata: AnnData,
512
515
  guide_rna_column: str,
516
+ *,
513
517
  mixscape_class_global: str = "mixscape_class_global",
514
518
  axis_text_x_size: int = 8,
515
519
  axis_text_y_size: int = 6,
516
520
  axis_title_size: int = 8,
517
521
  legend_title_size: int = 8,
518
522
  legend_text_size: int = 8,
519
- return_fig: bool | None = None,
520
- ax: Axes | None = None,
521
- show: bool | None = None,
522
- save: bool | str | None = None,
523
- ):
523
+ legend_bbox_to_anchor: tuple[float, float] = None,
524
+ figsize: tuple[float, float] = (25, 25),
525
+ show: bool = True,
526
+ return_fig: bool = False,
527
+ ) -> Figure | None:
524
528
  """Barplot to visualize perturbation scores calculated by the `mixscape` function.
525
529
 
526
530
  Args:
@@ -528,12 +532,17 @@ class Mixscape:
528
532
  guide_rna_column: The column of `.obs` with guide RNA labels. The target gene labels.
529
533
  The format must be <gene_target>g<#>. Examples are 'STAT2g1' and 'ATF2g1'.
530
534
  mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
531
- show: Show the plot, do not return axis.
532
- save: If True or a str, save the figure. A string is appended to the default filename.
533
- Infer the filetype if ending on {'.pdf', '.png', '.svg'}.
535
+ axis_text_x_size: Size of the x-axis text.
536
+ axis_text_y_size: Size of the y-axis text.
537
+ axis_title_size: Size of the axis title.
538
+ legend_title_size: Size of the legend title.
539
+ legend_text_size: Size of the legend text.
540
+ legend_bbox_to_anchor: The bbox to which the legend is anchored.
541
+ figsize: The size of the figure.
542
+ {common_plot_args}
534
543
 
535
544
  Returns:
536
- If `show==False`, return a :class:`~matplotlib.axes.Axes.
545
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
537
546
 
538
547
  Examples:
539
548
  >>> import pertpy as pt
@@ -565,63 +574,66 @@ class Mixscape:
565
574
  all_cells_percentage["guide_number"] = "g" + all_cells_percentage["guide_number"]
566
575
  NP_KO_cells = all_cells_percentage[all_cells_percentage["gene"] != "NT"]
567
576
 
568
- if show:
569
- color_mapping = {"KO": "salmon", "NP": "lightgray", "NT": "grey"}
570
- unique_genes = NP_KO_cells["gene"].unique()
571
- fig, axs = plt.subplots(int(len(unique_genes) / 5), 5, figsize=(25, 25), sharey=True)
572
- for i, gene in enumerate(unique_genes):
573
- ax = axs[int(i / 5), i % 5]
574
- grouped_df = (
575
- NP_KO_cells[NP_KO_cells["gene"] == gene]
576
- .groupby(["guide_number", "mixscape_class_global"], observed=False)["value"]
577
- .sum()
578
- .unstack()
579
- )
580
- grouped_df.plot(
581
- kind="bar",
582
- stacked=True,
583
- color=[color_mapping[col] for col in grouped_df.columns],
584
- ax=ax,
585
- width=0.8,
586
- legend=False,
587
- )
588
- ax.set_title(
589
- gene, bbox={"facecolor": "white", "edgecolor": "black", "pad": 1}, fontsize=axis_title_size
590
- )
591
- ax.set(xlabel="sgRNA", ylabel="% of cells")
592
- sns.despine(ax=ax, top=True, right=True, left=False, bottom=False)
593
- ax.set_xticklabels(ax.get_xticklabels(), rotation=0, ha="right", fontsize=axis_text_x_size)
594
- ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=axis_text_y_size)
595
- fig.subplots_adjust(right=0.8)
596
- fig.subplots_adjust(hspace=0.5, wspace=0.5)
577
+ color_mapping = {"KO": "salmon", "NP": "lightgray", "NT": "grey"}
578
+ unique_genes = NP_KO_cells["gene"].unique()
579
+ fig, axs = plt.subplots(int(len(unique_genes) / 5), 5, figsize=figsize, sharey=True)
580
+ for i, gene in enumerate(unique_genes):
581
+ ax = axs[int(i / 5), i % 5]
582
+ grouped_df = (
583
+ NP_KO_cells[NP_KO_cells["gene"] == gene]
584
+ .groupby(["guide_number", "mixscape_class_global"], observed=False)["value"]
585
+ .sum()
586
+ .unstack()
587
+ )
588
+ grouped_df.plot(
589
+ kind="bar",
590
+ stacked=True,
591
+ color=[color_mapping[col] for col in grouped_df.columns],
592
+ ax=ax,
593
+ width=0.8,
594
+ legend=False,
595
+ )
596
+ ax.set_title(gene, bbox={"facecolor": "white", "edgecolor": "black", "pad": 1}, fontsize=axis_title_size)
597
+ ax.set(xlabel="sgRNA", ylabel="% of cells")
598
+ sns.despine(ax=ax, top=True, right=True, left=False, bottom=False)
599
+ ax.set_xticks(ax.get_xticks(), ax.get_xticklabels(), rotation=0, ha="right", fontsize=axis_text_x_size)
600
+ ax.set_yticks(ax.get_yticks(), ax.get_yticklabels(), rotation=0, fontsize=axis_text_y_size)
597
601
  ax.legend(
598
- title="mixscape_class_global",
602
+ title="Mixscape Class",
599
603
  loc="center right",
600
- bbox_to_anchor=(2.2, 3.5),
604
+ bbox_to_anchor=legend_bbox_to_anchor,
601
605
  frameon=True,
602
606
  fontsize=legend_text_size,
603
607
  title_fontsize=legend_title_size,
604
608
  )
605
609
 
610
+ fig.subplots_adjust(right=0.8)
611
+ fig.subplots_adjust(hspace=0.5, wspace=0.5)
606
612
  plt.tight_layout()
607
- _utils.savefig_or_show("mixscape_barplot", show=show, save=save)
608
613
 
614
+ if show:
615
+ plt.show()
616
+ if return_fig:
617
+ return fig
618
+ return None
619
+
620
+ @_doc_params(common_plot_args=doc_common_plot_args)
609
621
  def plot_heatmap( # pragma: no cover
610
622
  self,
611
623
  adata: AnnData,
612
624
  labels: str,
613
625
  target_gene: str,
614
626
  control: str,
627
+ *,
615
628
  layer: str | None = None,
616
629
  method: str | None = "wilcoxon",
617
630
  subsample_number: int | None = 900,
618
631
  vmin: float | None = -2,
619
632
  vmax: float | None = 2,
620
- return_fig: bool | None = None,
621
- show: bool | None = None,
622
- save: bool | str | None = None,
633
+ show: bool = True,
634
+ return_fig: bool = False,
623
635
  **kwds,
624
- ) -> Axes | None:
636
+ ) -> Figure | None:
625
637
  """Heatmap plot using mixscape results. Requires `pt.tl.mixscape()` to be run first.
626
638
 
627
639
  Args:
@@ -634,14 +646,11 @@ class Mixscape:
634
646
  subsample_number: Subsample to this number of observations.
635
647
  vmin: The value representing the lower limit of the color scale. Values smaller than vmin are plotted with the same color as vmin.
636
648
  vmax: The value representing the upper limit of the color scale. Values larger than vmax are plotted with the same color as vmax.
637
- show: Show the plot, do not return axis.
638
- save: If `True` or a `str`, save the figure. A string is appended to the default filename.
639
- Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
640
- ax: A matplotlib axes object. Only works if plotting a single component.
649
+ {common_plot_args}
641
650
  **kwds: Additional arguments to `scanpy.pl.rank_genes_groups_heatmap`.
642
651
 
643
652
  Returns:
644
- If `show==False`, return a :class:`~matplotlib.axes.Axes`.
653
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
645
654
 
646
655
  Examples:
647
656
  >>> import pertpy as pt
@@ -663,35 +672,39 @@ class Mixscape:
663
672
  sc.pp.scale(adata_subset, max_value=vmax)
664
673
  sc.pp.subsample(adata_subset, n_obs=subsample_number)
665
674
 
666
- return sc.pl.rank_genes_groups_heatmap(
675
+ fig = sc.pl.rank_genes_groups_heatmap(
667
676
  adata_subset,
668
677
  groupby="mixscape_class",
669
678
  vmin=vmin,
670
679
  vmax=vmax,
671
680
  n_genes=20,
672
681
  groups=["NT"],
673
- return_fig=return_fig,
674
- show=show,
675
- save=save,
682
+ show=False,
676
683
  **kwds,
677
684
  )
678
685
 
686
+ if show:
687
+ plt.show()
688
+ if return_fig:
689
+ return fig
690
+ return None
691
+
692
+ @_doc_params(common_plot_args=doc_common_plot_args)
679
693
  def plot_perturbscore( # pragma: no cover
680
694
  self,
681
695
  adata: AnnData,
682
696
  labels: str,
683
697
  target_gene: str,
698
+ *,
684
699
  mixscape_class: str = "mixscape_class",
685
700
  color: str = "orange",
686
701
  palette: dict[str, str] = None,
687
702
  split_by: str = None,
688
703
  before_mixscape: bool = False,
689
704
  perturbation_type: str = "KO",
690
- return_fig: bool | None = None,
691
- ax: Axes | None = None,
692
- show: bool | None = None,
693
- save: bool | str | None = None,
694
- ) -> None:
705
+ show: bool = True,
706
+ return_fig: bool = False,
707
+ ) -> Figure | None:
695
708
  """Density plots to visualize perturbation scores calculated by the `pt.tl.mixscape` function.
696
709
 
697
710
  Requires `pt.tl.mixscape` to be run first.
@@ -710,6 +723,10 @@ class Mixscape:
710
723
  before_mixscape: Option to split densities based on mixscape classification (default) or original target gene classification.
711
724
  Default is `False`, which plots cells by their mixscape classification.
712
725
  perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
726
+ {common_plot_args}
727
+
728
+ Returns:
729
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
713
730
 
714
731
  Examples:
715
732
  Visualizing the perturbation scores for the cells in a dataset:
@@ -778,15 +795,6 @@ class Mixscape:
778
795
  plt.legend(title="gene_target", title_fontsize=14, fontsize=12)
779
796
  sns.despine()
780
797
 
781
- if save:
782
- plt.savefig(save, bbox_inches="tight")
783
- if show:
784
- plt.show()
785
- if return_fig:
786
- return plt.gcf()
787
- if not (show or save):
788
- return plt.gca()
789
-
790
798
  # If before_mixscape is False, split densities based on mixscape classifications
791
799
  else:
792
800
  if palette is None:
@@ -843,19 +851,18 @@ class Mixscape:
843
851
  plt.legend(title="mixscape class", title_fontsize=14, fontsize=12)
844
852
  sns.despine()
845
853
 
846
- if save:
847
- plt.savefig(save, bbox_inches="tight")
848
- if show:
849
- plt.show()
850
- if return_fig:
851
- return plt.gcf()
852
- if not (show or save):
853
- return plt.gca()
854
+ if show:
855
+ plt.show()
856
+ if return_fig:
857
+ return plt.gcf()
858
+ return None
854
859
 
860
+ @_doc_params(common_plot_args=doc_common_plot_args)
855
861
  def plot_violin( # pragma: no cover
856
862
  self,
857
863
  adata: AnnData,
858
864
  target_gene_idents: str | list[str],
865
+ *,
859
866
  keys: str | Sequence[str] = "mixscape_class_p_ko",
860
867
  groupby: str | None = "mixscape_class",
861
868
  log: bool = False,
@@ -872,10 +879,10 @@ class Mixscape:
872
879
  ylabel: str | Sequence[str] | None = None,
873
880
  rotation: float | None = None,
874
881
  ax: Axes | None = None,
875
- show: bool | None = None,
876
- save: bool | str | None = None,
882
+ show: bool = True,
883
+ return_fig: bool = False,
877
884
  **kwargs,
878
- ):
885
+ ) -> Axes | Figure | None:
879
886
  """Violin plot using mixscape results.
880
887
 
881
888
  Requires `pt.tl.mixscape` to be run first.
@@ -892,14 +899,12 @@ class Mixscape:
892
899
  xlabel: Label of the x-axis. Defaults to `groupby` if `rotation` is `None`, otherwise, no label is shown.
893
900
  ylabel: Label of the y-axis. If `None` and `groupby` is `None`, defaults to `'value'`.
894
901
  If `None` and `groupby` is not `None`, defaults to `keys`.
895
- show: Show the plot, do not return axis.
896
- save: If `True` or a `str`, save the figure. A string is appended to the default filename.
897
- Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
898
902
  ax: A matplotlib axes object. Only works if plotting a single component.
903
+ {common_plot_args}
899
904
  **kwargs: Additional arguments to `seaborn.violinplot`.
900
905
 
901
906
  Returns:
902
- A :class:`~matplotlib.axes.Axes` object if `ax` is `None` else `None`.
907
+ If `return_fig` is `True`, returns the figure (as Axes list if it's a multi-panel plot), otherwise `None`.
903
908
 
904
909
  Examples:
905
910
  >>> import pertpy as pt
@@ -1045,20 +1050,24 @@ class Mixscape:
1045
1050
  show = settings.autoshow if show is None else show
1046
1051
  if hue is not None and stripplot is True:
1047
1052
  plt.legend(handles, labels)
1048
- _utils.savefig_or_show("mixscape_violin", show=show, save=save)
1049
1053
 
1050
- if not show:
1054
+ if show:
1055
+ plt.show()
1056
+ if return_fig:
1051
1057
  if multi_panel and groupby is None and len(ys) == 1:
1052
1058
  return g
1053
1059
  elif len(axs) == 1:
1054
1060
  return axs[0]
1055
1061
  else:
1056
1062
  return axs
1063
+ return None
1057
1064
 
1065
+ @_doc_params(common_plot_args=doc_common_plot_args)
1058
1066
  def plot_lda( # pragma: no cover
1059
1067
  self,
1060
1068
  adata: AnnData,
1061
1069
  control: str,
1070
+ *,
1062
1071
  mixscape_class: str = "mixscape_class",
1063
1072
  mixscape_class_global: str = "mixscape_class_global",
1064
1073
  perturbation_type: str | None = "KO",
@@ -1066,12 +1075,11 @@ class Mixscape:
1066
1075
  n_components: int | None = None,
1067
1076
  color_map: Colormap | str | None = None,
1068
1077
  palette: str | Sequence[str] | None = None,
1069
- return_fig: bool | None = None,
1070
1078
  ax: Axes | None = None,
1071
- show: bool | None = None,
1072
- save: bool | str | None = None,
1079
+ show: bool = True,
1080
+ return_fig: bool = False,
1073
1081
  **kwds,
1074
- ) -> None:
1082
+ ) -> Figure | None:
1075
1083
  """Visualizing perturbation responses with Linear Discriminant Analysis. Requires `pt.tl.mixscape()` to be run first.
1076
1084
 
1077
1085
  Args:
@@ -1082,9 +1090,7 @@ class Mixscape:
1082
1090
  perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
1083
1091
  lda_key: If not specified, lda looks in .uns["mixscape_lda"] for the LDA results.
1084
1092
  n_components: The number of dimensions of the embedding.
1085
- show: Show the plot, do not return axis.
1086
- save: If `True` or a `str`, save the figure. A string is appended to the default filename.
1087
- Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
1093
+ {common_plot_args}
1088
1094
  **kwds: Additional arguments to `scanpy.pl.umap`.
1089
1095
 
1090
1096
  Examples:
@@ -1112,14 +1118,19 @@ class Mixscape:
1112
1118
  n_components = adata_subset.uns[lda_key].shape[1]
1113
1119
  sc.pp.neighbors(adata_subset, use_rep=lda_key)
1114
1120
  sc.tl.umap(adata_subset, n_components=n_components)
1115
- sc.pl.umap(
1121
+ fig = sc.pl.umap(
1116
1122
  adata_subset,
1117
1123
  color=mixscape_class,
1118
1124
  palette=palette,
1119
1125
  color_map=color_map,
1120
1126
  return_fig=return_fig,
1121
- show=show,
1122
- save=save,
1127
+ show=False,
1123
1128
  ax=ax,
1124
1129
  **kwds,
1125
1130
  )
1131
+
1132
+ if show:
1133
+ plt.show()
1134
+ if return_fig:
1135
+ return fig
1136
+ return None
@@ -7,6 +7,7 @@ import pandas as pd
7
7
  from anndata import AnnData
8
8
  from lamin_utils import logger
9
9
  from rich import print
10
+ from scipy.stats import entropy
10
11
 
11
12
  if TYPE_CHECKING:
12
13
  from collections.abc import Iterable
@@ -41,7 +42,7 @@ class PerturbationSpace:
41
42
  Args:
42
43
  adata: Anndata object of size cells x genes.
43
44
  target_col: .obs column name that stores the label of the perturbation applied to each cell.
44
- group_col: .obs column name that stores the label of the group of eah cell. If None, ignore groups.
45
+ group_col: .obs column name that stores the label of the group of each cell. If None, ignore groups.
45
46
  reference_key: The key of the control values.
46
47
  layer_key: Key of the AnnData layer to use for computation.
47
48
  new_layer_key: the results are stored in the given layer.
@@ -364,50 +365,58 @@ class PerturbationSpace:
364
365
  self,
365
366
  adata: AnnData,
366
367
  column: str = "perturbation",
368
+ column_uncertainty_score_key: str = "perturbation_transfer_uncertainty",
367
369
  target_val: str = "unknown",
368
- n_neighbors: int = 5,
369
- use_rep: str = "X_umap",
370
+ neighbors_key: str = "neighbors",
370
371
  ) -> None:
371
372
  """Impute missing values in the specified column by transferring labels from neighboring cells in the precomputed neighbor graph referenced by `neighbors_key`.
372
373
 
374
+ Uncertainty is calculated as the entropy of the label distribution in the neighborhood of the target cell.
375
+ In other words, a cell whose neighbors all share the same label will have an uncertainty of 0, whereas a cell
376
+ whose neighbors carry many different labels will have high uncertainty.
377
+
373
378
  Args:
374
379
  adata: The AnnData object containing single-cell data.
375
- column: The column name in AnnData object to perform imputation on.
380
+ column: The column name in adata.obs to perform imputation on.
381
+ column_uncertainty_score_key: The column name in adata.obs to store the uncertainty score of the label transfer.
376
382
  target_val: The target value to impute.
377
- n_neighbors: Number of neighbors to use for imputation.
378
- use_rep: The key in `adata.obsm` where the embedding (UMAP, PCA, etc.) is stored.
383
+ neighbors_key: The key in adata.uns where the neighbors are stored.
379
384
 
380
385
  Examples:
381
386
  >>> import pertpy as pt
382
387
  >>> import scanpy as sc
383
388
  >>> import numpy as np
384
389
  >>> adata = sc.datasets.pbmc68k_reduced()
385
- >>> rng = np.random.default_rng()
386
- >>> adata.obs["perturbation"] = rng.choice(
387
- ... ["A", "B", "C", "unknown"], size=adata.n_obs, p=[0.33, 0.33, 0.33, 0.01]
388
- ... )
390
+ >>> # randomly dropout 10% of the data annotations
391
+ >>> adata.obs["perturbation"] = adata.obs["louvain"].astype(str).copy()
392
+ >>> random_cells = np.random.choice(adata.obs.index, int(adata.obs.shape[0] * 0.1), replace=False)
393
+ >>> adata.obs.loc[random_cells, "perturbation"] = "unknown"
389
394
  >>> sc.pp.neighbors(adata)
390
395
  >>> sc.tl.umap(adata)
391
396
  >>> ps = pt.tl.PseudobulkSpace()
392
- >>> ps.label_transfer(adata, n_neighbors=5, use_rep="X_umap")
397
+ >>> ps.label_transfer(adata)
393
398
  """
394
- if use_rep not in adata.obsm:
395
- raise ValueError(f"Representation {use_rep} not found in the AnnData object.")
396
-
397
- embedding = adata.obsm[use_rep]
398
-
399
- from pynndescent import NNDescent
400
-
401
- nnd = NNDescent(embedding, n_neighbors=n_neighbors)
402
- indices, _ = nnd.query(embedding, k=n_neighbors)
403
-
404
- perturbations = np.array(adata.obs[column])
405
- missing_mask = perturbations == target_val
406
-
407
- for idx in np.where(missing_mask)[0]:
408
- neighbor_indices = indices[idx]
409
- neighbor_categories = perturbations[neighbor_indices]
410
- most_common = pd.Series(neighbor_categories).mode()[0]
411
- perturbations[idx] = most_common
412
-
413
- adata.obs[column] = perturbations
399
+ if neighbors_key not in adata.uns:
400
+ raise ValueError(f"Key {neighbors_key} not found in adata.uns. Please run `sc.pp.neighbors` first.")
401
+
402
+ labels = adata.obs[column].astype(str)
403
+ target_cells = labels == target_val
404
+
405
+ connectivities = adata.obsp[adata.uns[neighbors_key]["connectivities_key"]]
406
+ # convert labels to an incidence matrix
407
+ one_hot_encoded_labels = adata.obs[column].astype(str).str.get_dummies()
408
+ # convert to connectivity-weighted neighborhood incidence matrix
409
+ weighted_label_occurence = pd.DataFrame(
410
+ (one_hot_encoded_labels.values.T * connectivities).T,
411
+ index=adata.obs_names,
412
+ columns=one_hot_encoded_labels.columns,
413
+ )
414
+ # choose best label for each target cell
415
+ best_labels = weighted_label_occurence.drop(target_val, axis=1)[target_cells].idxmax(axis=1)
416
+ adata.obs[column] = labels
417
+ adata.obs.loc[target_cells, column] = best_labels
418
+
419
+ # calculate uncertainty
420
+ uncertainty = np.zeros(adata.n_obs)
421
+ uncertainty[target_cells] = entropy(weighted_label_occurence.drop(target_val, axis=1)[target_cells], axis=1)
422
+ adata.obs[column_uncertainty_score_key] = uncertainty
@@ -1,13 +1,20 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from typing import TYPE_CHECKING
4
+
3
5
  import decoupler as dc
6
+ import matplotlib.pyplot as plt
4
7
  import numpy as np
5
8
  from anndata import AnnData
6
9
  from sklearn.cluster import DBSCAN, KMeans
7
10
 
11
+ from pertpy._doc import _doc_params, doc_common_plot_args
8
12
  from pertpy.tools._perturbation_space._clustering import ClusteringSpace
9
13
  from pertpy.tools._perturbation_space._perturbation_space import PerturbationSpace
10
14
 
15
+ if TYPE_CHECKING:
16
+ from matplotlib.pyplot import Figure
17
+
11
18
 
12
19
  class CentroidSpace(PerturbationSpace):
13
20
  """Computes the centroids per perturbation of a pre-computed embedding."""
@@ -168,6 +175,49 @@ class PseudobulkSpace(PerturbationSpace):
168
175
 
169
176
  return ps_adata
170
177
 
178
+ @_doc_params(common_plot_args=doc_common_plot_args)
179
+ def plot_psbulk_samples(
180
+ self,
181
+ adata: AnnData,
182
+ groupby: str,
183
+ *,
184
+ show: bool = True,
185
+ return_fig: bool = False,
186
+ **kwargs,
187
+ ) -> Figure | None:
188
+ """Plot the pseudobulk samples of an AnnData object.
189
+
190
+ Plot the count number vs. the number of cells per pseudobulk sample.
191
+
192
+ Args:
193
+ adata: Anndata containing pseudobulk samples.
194
+ groupby: `.obs` column to color the samples by.
195
+ {common_plot_args}
196
+ **kwargs: Are passed to decoupler's plot_psbulk_samples.
197
+
198
+ Returns:
199
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
200
+
201
+ Examples:
202
+ >>> import pertpy as pt
203
+ >>> adata = pt.dt.zhang_2021()
204
+ >>> ps = pt.tl.PseudobulkSpace()
205
+ >>> pdata = ps.compute(
206
+ ... adata, target_col="Patient", groups_col="Cluster", mode="sum", min_cells=10, min_counts=1000
207
+ ... )
208
+ >>> ps.plot_psbulk_samples(pdata, groupby=["Patient", "Major celltype"], figsize=(12, 4))
209
+
210
+ Preview:
211
+ .. image:: /_static/docstring_previews/pseudobulk_samples.png
212
+ """
213
+ fig = dc.plot_psbulk_samples(adata, groupby, return_fig=True, **kwargs)
214
+
215
+ if show:
216
+ plt.show()
217
+ if return_fig:
218
+ return fig
219
+ return None
220
+
171
221
 
172
222
  class KMeansSpace(ClusteringSpace):
173
223
  """Computes K-Means clustering of the expression values."""