pertpy 0.9.4__py3-none-any.whl → 0.10.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,7 +1,7 @@
1
- import os
1
+ import math
2
2
  from abc import ABC, abstractmethod
3
- from dataclasses import dataclass
4
- from itertools import chain
3
+ from collections.abc import Iterable, Mapping, Sequence
4
+ from itertools import zip_longest
5
5
  from types import MappingProxyType
6
6
 
7
7
  import adjustText
@@ -11,27 +11,14 @@ import matplotlib.pyplot as plt
11
11
  import numpy as np
12
12
  import pandas as pd
13
13
  import seaborn as sns
14
+ from formulaic_contrasts import FormulaicContrasts
15
+ from lamin_utils import logger
16
+ from matplotlib.pyplot import Figure
14
17
  from matplotlib.ticker import MaxNLocator
15
18
 
19
+ from pertpy._doc import _doc_params, doc_common_plot_args
20
+ from pertpy.tools import PseudobulkSpace
16
21
  from pertpy.tools._differential_gene_expression._checks import check_is_numeric_matrix
17
- from pertpy.tools._differential_gene_expression._formulaic import (
18
- AmbiguousAttributeError,
19
- Factor,
20
- get_factor_storage_and_materializer,
21
- resolve_ambiguous,
22
- )
23
-
24
-
25
- @dataclass
26
- class Contrast:
27
- """Simple contrast for comparison between groups"""
28
-
29
- column: str
30
- baseline: str
31
- group_to_compare: str
32
-
33
-
34
- ContrastType = Contrast | tuple[str, str, str]
35
22
 
36
23
 
37
24
  class MethodBase(ABC):
@@ -58,7 +45,7 @@ class MethodBase(ABC):
58
45
  if self.layer is None:
59
46
  return self.adata.X
60
47
  else:
61
- return self.adata.layer[self.layer]
48
+ return self.adata.layers[self.layer]
62
49
 
63
50
  @classmethod
64
51
  @abstractmethod
@@ -91,9 +78,28 @@ class MethodBase(ABC):
91
78
 
92
79
  Returns:
93
80
  Pandas dataframe with results ordered by significance. If multiple comparisons were performed this is indicated in an additional column.
81
+
82
+ Examples:
83
+ >>> # Example with EdgeR
84
+ >>> import pertpy as pt
85
+ >>> adata = pt.dt.zhang_2021()
86
+ >>> adata.layers["counts"] = adata.X.copy()
87
+ >>> ps = pt.tl.PseudobulkSpace()
88
+ >>> pdata = ps.compute(
89
+ ... adata,
90
+ ... target_col="Patient",
91
+ ... groups_col="Cluster",
92
+ ... layer_key="counts",
93
+ ... mode="sum",
94
+ ... min_cells=10,
95
+ ... min_counts=1000,
96
+ ... )
97
+ >>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment")
98
+ >>> res_df = edgr.compare_groups(pdata, column="Efficacy", baseline="SD", groups_to_compare=["PR", "PD"])
94
99
  """
95
100
  ...
96
101
 
102
+ @_doc_params(common_plot_args=doc_common_plot_args)
97
103
  def plot_volcano(
98
104
  self,
99
105
  data: pd.DataFrame | ad.AnnData,
@@ -115,13 +121,13 @@ class MethodBase(ABC):
115
121
  figsize: tuple[int, int] = (5, 5),
116
122
  legend_pos: tuple[float, float] = (1.6, 1),
117
123
  point_sizes: tuple[int, int] = (15, 150),
118
- save: bool | str | None = None,
119
124
  shapes: list[str] | None = None,
120
125
  shape_order: list[str] | None = None,
121
126
  x_label: str | None = None,
122
127
  y_label: str | None = None,
128
+ return_fig: bool = False,
123
129
  **kwargs: int,
124
- ) -> None:
130
+ ) -> Figure | None:
125
131
  """Creates a volcano plot from a pandas DataFrame or Anndata.
126
132
 
127
133
  Args:
@@ -143,12 +149,40 @@ class MethodBase(ABC):
143
149
  top_right_frame: Whether to show the top and right frame of the plot.
144
150
  figsize: Size of the figure.
145
151
  legend_pos: Position of the legend as determined by matplotlib.
146
- save: Saves the plot if True or to the path provided.
147
152
  shapes: List of matplotlib marker ids.
148
153
  shape_order: Order of categories for shapes.
149
154
  x_label: Label for the x-axis.
150
155
  y_label: Label for the y-axis.
156
+ {common_plot_args}
151
157
  **kwargs: Additional arguments for seaborn.scatterplot.
158
+
159
+ Returns:
160
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
161
+
162
+ Examples:
163
+ >>> # Example with EdgeR
164
+ >>> import pertpy as pt
165
+ >>> adata = pt.dt.zhang_2021()
166
+ >>> adata.layers["counts"] = adata.X.copy()
167
+ >>> ps = pt.tl.PseudobulkSpace()
168
+ >>> pdata = ps.compute(
169
+ ... adata,
170
+ ... target_col="Patient",
171
+ ... groups_col="Cluster",
172
+ ... layer_key="counts",
173
+ ... mode="sum",
174
+ ... min_cells=10,
175
+ ... min_counts=1000,
176
+ ... )
177
+ >>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment")
178
+ >>> edgr.fit()
179
+ >>> res_df = edgr.test_contrasts(
180
+ ... edgr.contrast(column="Treatment", baseline="Chemo", group_to_compare="Anti-PD-L1+Chemo")
181
+ ... )
182
+ >>> edgr.plot_volcano(res_df, log2fc_thresh=0)
183
+
184
+ Preview:
185
+ .. image:: /_static/docstring_previews/de_volcano.png
152
186
  """
153
187
  if colors is None:
154
188
  colors = ["gray", "#D62728", "#1F77B4"]
@@ -243,7 +277,7 @@ class MethodBase(ABC):
243
277
  if varm_key is None:
244
278
  raise ValueError("Please pass a .varm key to use for plotting")
245
279
 
246
- raise NotImplementedError("Anndata not implemented yet")
280
+ raise NotImplementedError("Anndata not implemented yet") # TODO: Implement this
247
281
  df = data.varm[varm_key].copy()
248
282
 
249
283
  df = data.copy(deep=True)
@@ -449,26 +483,405 @@ class MethodBase(ABC):
449
483
 
450
484
  plt.legend(loc=1, bbox_to_anchor=legend_pos, frameon=False)
451
485
 
452
- # TODO replace with scanpy save style
453
- if save:
454
- files = os.listdir()
455
- for x in range(100):
456
- file_pref = "volcano_" + "%02d" % (x,)
457
- if len([x for x in files if x.startswith(file_pref)]) == 0:
458
- plt.savefig(file_pref + ".png", dpi=300, bbox_inches="tight")
459
- plt.savefig(file_pref + ".svg", bbox_inches="tight")
460
- break
461
- elif isinstance(save, str):
462
- plt.savefig(save + ".png", dpi=300, bbox_inches="tight")
463
- plt.savefig(save + ".svg", bbox_inches="tight")
486
+ if return_fig:
487
+ return plt.gcf()
488
+ plt.show()
489
+ return None
490
+
491
@_doc_params(common_plot_args=doc_common_plot_args)
def plot_paired(
    self,
    adata: ad.AnnData,
    results_df: pd.DataFrame,
    groupby: str,
    pairedby: str,
    *,
    var_names: Sequence[str] | None = None,
    n_top_vars: int = 15,
    layer: str | None = None,
    pvalue_col: str = "adj_p_value",
    symbol_col: str = "variable",
    n_cols: int = 4,
    panel_size: tuple[int, int] = (5, 5),
    show_legend: bool = True,
    size: int = 10,
    y_label: str = "expression",
    pvalue_template=lambda x: f"p={x:.2e}",
    boxplot_properties=None,
    palette=None,
    return_fig: bool = False,
) -> Figure | None:
    """Creates a pairwise expression plot from an AnnData object and a results table.

    Draws one panel per variable. Each panel shows a boxplot of the two groups,
    overlaid with one point per sample and a line connecting the paired samples.
    The p-value shown in each panel title is read from `results_df`; no test is
    computed by this function.

    Args:
        adata: AnnData object, can be pseudobulked.
        results_df: DataFrame with results from a differential expression test.
        groupby: .obs column containing the grouping. Must contain exactly two different values.
        pairedby: .obs column containing the pairing (e.g. "patient_id").
        var_names: Variables to plot. If None, the first `n_top_vars` entries of `results_df` are used.
        n_top_vars: Number of top variables to plot.
        layer: Layer to use for plotting.
        pvalue_col: Column name of the p values.
        symbol_col: Column name of gene IDs.
        n_cols: Number of columns in the plot.
        panel_size: Size of each panel.
        show_legend: Whether to show the legend.
        size: Size of the points.
        y_label: Label for the y-axis.
        pvalue_template: Template for the p-value string displayed in the title of each panel.
        boxplot_properties: Additional properties for the boxplot, passed to seaborn.boxplot.
        palette: Color palette for the line- and stripplot.
        {common_plot_args}

    Returns:
        If `return_fig` is `True`, returns the figure, otherwise `None`.

    Raises:
        ValueError: If `groupby` does not contain exactly two groups or if there are no variables to plot.

    Examples:
        >>> # Example with EdgeR
        >>> import pertpy as pt
        >>> adata = pt.dt.zhang_2021()
        >>> adata.layers["counts"] = adata.X.copy()
        >>> ps = pt.tl.PseudobulkSpace()
        >>> pdata = ps.compute(
        ...     adata,
        ...     target_col="Patient",
        ...     groups_col="Cluster",
        ...     layer_key="counts",
        ...     mode="sum",
        ...     min_cells=10,
        ...     min_counts=1000,
        ... )
        >>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment")
        >>> edgr.fit()
        >>> res_df = edgr.test_contrasts(
        ...     edgr.contrast(column="Treatment", baseline="Chemo", group_to_compare="Anti-PD-L1+Chemo")
        ... )
        >>> edgr.plot_paired(pdata, results_df=res_df, n_top_vars=8, groupby="Treatment", pairedby="Efficacy")

    Preview:
        .. image:: /_static/docstring_previews/de_paired_expression.png
    """
    if boxplot_properties is None:
        boxplot_properties = {}

    groups = adata.obs[groupby].unique()
    if len(groups) != 2:
        raise ValueError("The number of groups in the group_by column must be exactly 2 to enable paired testing")

    if var_names is None:
        var_names = results_df.head(n_top_vars)[symbol_col].tolist()
    else:
        # A tuple would be interpreted as a MultiIndex key by `.loc` below.
        var_names = list(var_names)
    if not var_names:
        raise ValueError("No variables to plot: `var_names` is empty.")

    adata = adata[:, var_names]

    # More than one observation per (group, pair) combination: aggregate first.
    if any(adata.obs[[groupby, pairedby]].value_counts() > 1):
        logger.info("Performing pseudobulk for paired samples")
        adata = PseudobulkSpace().compute(
            adata, target_col=groupby, groups_col=pairedby, layer_key=layer, mode="sum", min_cells=1, min_counts=1
        )

    # NOTE(review): after pseudobulking, confirm that the returned object still
    # exposes `layer` in `.layers`; this lookup is unchanged from the original.
    X = adata.X if layer is None else adata.layers[layer]
    if hasattr(X, "toarray"):
        # Sparse matrices must be densified before building the DataFrame.
        X = X.toarray()

    groupby_cols = [pairedby, groupby]
    df = adata.obs.loc[:, groupby_cols].join(pd.DataFrame(X, index=adata.obs_names, columns=var_names))

    # Keep only samples that are present in both groups.
    in_first_group = set(df.loc[df[groupby] == groups[0], pairedby])
    in_second_group = set(df.loc[df[groupby] == groups[1], pairedby])
    df = df[df[pairedby].isin(in_first_group & in_second_group)]
    removed_samples = adata.obs[pairedby].nunique() - len(df[pairedby].unique())
    if removed_samples > 0:
        logger.warning(f"{removed_samples} unpaired samples removed")

    pvalues = results_df.set_index(symbol_col).loc[var_names, pvalue_col].values

    # Long format for seaborn. `value_vars` is explicit so that only expression
    # columns end up in "val" (no index/obs-name strings mixed into the values).
    df_melt = df.melt(
        id_vars=groupby_cols,
        value_vars=var_names,
        var_name="var",
        value_name="val",
    )

    n_panels = len(var_names)
    nrows = math.ceil(n_panels / n_cols)
    ncols = min(n_cols, n_panels)

    fig, axes = plt.subplots(
        nrows,
        ncols,
        figsize=(ncols * panel_size[0], nrows * panel_size[1]),
        tight_layout=True,
        squeeze=False,
    )
    axes = axes.flatten()
    for i, (var, ax) in enumerate(zip_longest(var_names, axes)):
        if var is None:
            # Surplus axes in the last row of the grid stay empty.
            ax.set_visible(False)
            continue

        panel_data = df_melt.loc[df_melt["var"] == var]
        sns.boxplot(
            x=groupby,
            data=panel_data,
            y="val",
            ax=ax,
            color="white",
            fliersize=0,
            **boxplot_properties,
        )
        sns.lineplot(
            x=groupby,
            data=panel_data,
            y="val",
            ax=ax,
            hue=pairedby,
            legend=False,
            errorbar=None,
            palette=palette,
        )
        # No jitter, so that the points sit exactly on the line ends.
        sns.stripplot(
            x=groupby,
            data=panel_data,
            y="val",
            ax=ax,
            hue=pairedby,
            jitter=0,
            size=size,
            linewidth=1,
            palette=palette,
        )

        ax.set_xlabel("")
        ax.tick_params(
            axis="x",
            labelsize=15,
        )
        ax.legend().set_visible(False)
        ax.set_ylabel(y_label)
        ax.set_title(f"{var}\n{pvalue_template(pvalues[i])}")

    if show_legend:
        # One shared legend, attached below the last populated panel.
        axes[n_panels - 1].legend(
            bbox_to_anchor=(0.5, -0.1), loc="upper center", ncol=adata.obs[pairedby].nunique()
        )

    fig.tight_layout()
    if return_fig:
        return fig
    plt.show()
    return None
683
+
684
@_doc_params(common_plot_args=doc_common_plot_args)
def plot_fold_change(
    self,
    results_df: pd.DataFrame,
    *,
    var_names: Sequence[str] | None = None,
    n_top_vars: int = 15,
    log2fc_col: str = "log_fc",
    symbol_col: str = "variable",
    y_label: str = "Log2 fold change",
    figsize: tuple[int, int] = (10, 5),
    return_fig: bool = False,
    **barplot_kwargs,
) -> Figure | None:
    """Plot the log2 fold change of selected variables as a bar chart.

    Bars are ordered from the highest to the lowest fold change.

    Args:
        results_df: DataFrame with results from DE analysis.
        var_names: Variables to plot. If None, the top n_top_vars variables based on the log2 fold change are plotted.
        n_top_vars: Number of top variables to plot. The top and bottom n_top_vars variables are plotted, respectively.
            If `results_df` holds fewer variables than requested, all of them are plotted.
        log2fc_col: Column name of log2 Fold-Change values.
        symbol_col: Column name of gene IDs.
        y_label: Label for the y-axis.
        figsize: Size of the figure.
        {common_plot_args}
        **barplot_kwargs: Additional arguments for seaborn.barplot. These take precedence over the defaults
            (`hue`, `palette`, `legend`) set by this function.

    Returns:
        If `return_fig` is `True`, returns the figure, otherwise `None`.

    Examples:
        >>> # Example with EdgeR
        >>> import pertpy as pt
        >>> adata = pt.dt.zhang_2021()
        >>> adata.layers["counts"] = adata.X.copy()
        >>> ps = pt.tl.PseudobulkSpace()
        >>> pdata = ps.compute(
        ...     adata,
        ...     target_col="Patient",
        ...     groups_col="Cluster",
        ...     layer_key="counts",
        ...     mode="sum",
        ...     min_cells=10,
        ...     min_counts=1000,
        ... )
        >>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment")
        >>> edgr.fit()
        >>> res_df = edgr.test_contrasts(
        ...     edgr.contrast(column="Treatment", baseline="Chemo", group_to_compare="Anti-PD-L1+Chemo")
        ... )
        >>> edgr.plot_fold_change(res_df)

    Preview:
        .. image:: /_static/docstring_previews/de_fold_change.png
    """
    if var_names is None:
        # Strongest up- and down-regulated variables. The two selections may
        # overlap for small result tables; `isin` below tolerates that, so no
        # length check is needed.
        strongest_up = results_df.sort_values(log2fc_col, ascending=False).head(n_top_vars)[symbol_col].tolist()
        strongest_down = results_df.sort_values(log2fc_col, ascending=True).head(n_top_vars)[symbol_col].tolist()
        var_names = strongest_up + strongest_down

    # Sort into a new frame instead of sorting a filtered slice in place.
    df = results_df[results_df[symbol_col].isin(var_names)].sort_values(log2fc_col, ascending=False)

    plt.figure(figsize=figsize)
    # seaborn deprecated `palette` without `hue`; mapping hue onto the x variable
    # with the legend switched off gives one colour per bar. User kwargs win.
    plot_kwargs = {"hue": symbol_col, "palette": "RdBu", "legend": False, **barplot_kwargs}
    sns.barplot(
        x=symbol_col,
        y=log2fc_col,
        data=df,
        **plot_kwargs,
    )
    plt.xticks(rotation=90)
    plt.xlabel("")
    plt.ylabel(y_label)

    if return_fig:
        return plt.gcf()
    plt.show()
    return None
764
+
765
@_doc_params(common_plot_args=doc_common_plot_args)
def plot_multicomparison_fc(
    self,
    results_df: pd.DataFrame,
    *,
    n_top_vars: int = 15,
    contrast_col: str = "contrast",
    log2fc_col: str = "log_fc",
    pvalue_col: str = "adj_p_value",
    symbol_col: str = "variable",
    marker_size: int = 100,
    figsize: tuple[int, int] = (10, 2),
    x_label: str = "Contrast",
    y_label: str = "Gene",
    return_fig: bool = False,
    **heatmap_kwargs,
) -> Figure | None:
    """Plot a matrix of log2 fold changes from the results.

    Rows are contrasts, columns are variables. Cells whose p-value is below 0.1
    are marked with a star whose size reflects the significance level.
    `results_df` is not modified.

    Args:
        results_df: DataFrame with results from DE analysis.
        n_top_vars: Number of top variables to plot per group.
        contrast_col: Column in results_df containing information about the contrast.
        log2fc_col: Column in results_df containing the log2 fold change.
        pvalue_col: Column in results_df containing the p-value. Can be used to switch between adjusted and unadjusted p-values.
        symbol_col: Column in results_df containing the gene symbol.
        marker_size: Size of the biggest marker for significant variables.
        figsize: Size of the figure.
        x_label: Label for the x-axis.
        y_label: Label for the y-axis.
        {common_plot_args}
        **heatmap_kwargs: Additional arguments for seaborn.heatmap. These take precedence over the defaults
            (`cmap`, `center`, `cbar_kws`) set by this function.

    Returns:
        If `return_fig` is `True`, returns the figure, otherwise `None`.

    Examples:
        >>> # Example with EdgeR
        >>> import pertpy as pt
        >>> adata = pt.dt.zhang_2021()
        >>> adata.layers["counts"] = adata.X.copy()
        >>> ps = pt.tl.PseudobulkSpace()
        >>> pdata = ps.compute(
        ...     adata,
        ...     target_col="Patient",
        ...     groups_col="Cluster",
        ...     layer_key="counts",
        ...     mode="sum",
        ...     min_cells=10,
        ...     min_counts=1000,
        ... )
        >>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment")
        >>> res_df = edgr.compare_groups(pdata, column="Efficacy", baseline="SD", groups_to_compare=["PR", "PD"])
        >>> edgr.plot_multicomparison_fc(res_df)

    Preview:
        .. image:: /_static/docstring_previews/de_multicomparison_fc.png
    """

    def _get_significance(p_val) -> str:
        # NaN compares False everywhere and therefore ends up as "n.s.".
        if p_val < 0.001:
            return "< 0.001"
        elif p_val < 0.01:
            return "< 0.01"
        elif p_val < 0.1:
            return "< 0.1"
        else:
            return "n.s."

    # Work on a copy so the helper columns never leak into the caller's frame.
    results_df = results_df.copy()
    results_df["abs_log_fc"] = results_df[log2fc_col].abs()
    results_df["significance"] = results_df[pvalue_col].apply(_get_significance)

    # Strongest variables of each contrast, in order of first appearance.
    var_names: list = []
    for group in results_df[contrast_col].unique().tolist():
        var_names += (
            results_df[results_df[contrast_col] == group]
            .sort_values("abs_log_fc", ascending=False)
            .head(n_top_vars)[symbol_col]
            .tolist()
        )
    # A variable ranking high in several contrasts must only get one column.
    var_names = list(dict.fromkeys(var_names))

    results_df = results_df[results_df[symbol_col].isin(var_names)]
    df = results_df.pivot(index=contrast_col, columns=symbol_col, values=log2fc_col)[var_names]

    plt.figure(figsize=figsize)
    heatmap_defaults = {"cmap": "coolwarm", "center": 0, "cbar_kws": {"label": "Log2 fold change"}}
    sns.heatmap(df, **{**heatmap_defaults, **heatmap_kwargs})

    star_sizes = {
        "< 0.001": marker_size,
        "< 0.01": math.floor(marker_size / 2),
        "< 0.1": math.floor(marker_size / 4),
    }
    # Heatmap cell (row r, column c) is centred at (c + 0.5, r + 0.5). Deriving
    # the positions from the pivot table keeps them valid even when seaborn
    # thins out the tick labels.
    x_center = {name: pos + 0.5 for pos, name in enumerate(df.columns)}
    y_center = {name: pos + 0.5 for pos, name in enumerate(df.index)}

    for level, star_size in star_sizes.items():
        hits = results_df[results_df["significance"] == level]
        if hits.empty:
            continue
        plt.scatter(
            x=[x_center[name] for name in hits[symbol_col]],
            y=[y_center[name] for name in hits[contrast_col]],
            s=star_size,
            marker="*",
            c="white",
        )

    # Empty scatters only provide the legend handles.
    for level, star_size in star_sizes.items():
        plt.scatter([], [], s=star_size, marker="*", c="black", label=level)
    plt.legend(title="Significance", bbox_to_anchor=(1.2, -0.05))

    plt.xlabel(x_label)
    plt.ylabel(y_label)

    if return_fig:
        return plt.gcf()
    plt.show()
    return None
466
880
 
467
881
 
468
882
  class LinearModelBase(MethodBase):
469
883
  def __init__(self, adata, design, *, mask=None, layer=None, **kwargs):
470
- """
471
- Initialize the method.
884
+ """Initialize the method.
472
885
 
473
886
  Args:
474
887
  adata: AnnData object, usually pseudobulked.
@@ -480,26 +893,24 @@ class LinearModelBase(MethodBase):
480
893
  super().__init__(adata, mask=mask, layer=layer)
481
894
  self._check_counts()
482
895
 
483
- self.factor_storage = None
484
- self.variable_to_factors = None
485
-
896
+ self.formulaic_contrasts = None
486
897
  if isinstance(design, str):
487
- self.factor_storage, self.variable_to_factors, materializer_class = get_factor_storage_and_materializer()
488
- self.design = materializer_class(adata.obs, record_factor_metadata=True).get_model_matrix(design)
898
+ self.formulaic_contrasts = FormulaicContrasts(adata.obs, design)
899
+ self.design = self.formulaic_contrasts.design_matrix
489
900
  else:
490
901
  self.design = design
491
902
 
492
903
  @classmethod
493
904
  def compare_groups(
494
905
  cls,
495
- adata,
496
- column,
497
- baseline,
498
- groups_to_compare,
906
+ adata: ad.AnnData,
907
+ column: str,
908
+ baseline: str,
909
+ groups_to_compare: str | Iterable[str],
499
910
  *,
500
- paired_by=None,
501
- mask=None,
502
- layer=None,
911
+ paired_by: str | None = None,
912
+ mask: pd.Series | None = None,
913
+ layer: str | None = None,
503
914
  fit_kwargs=MappingProxyType({}),
504
915
  test_kwargs=MappingProxyType({}),
505
916
  ):
@@ -525,17 +936,16 @@ class LinearModelBase(MethodBase):
525
936
  @property
526
937
  def variables(self):
527
938
  """Get the names of the variables used in the model definition."""
528
- try:
529
- return self.design.model_spec.variables_by_source["data"]
530
- except AttributeError:
939
+ if self.formulaic_contrasts is None:
531
940
  raise ValueError(
532
941
  "Retrieving variables is only possible if the model was initialized using a formula."
533
942
  ) from None
943
+ else:
944
+ return self.formulaic_contrasts.variables
534
945
 
535
946
  @abstractmethod
536
947
  def _check_counts(self):
537
- """
538
- Check that counts are valid for the specific method.
948
+ """Check that counts are valid for the specific method.
539
949
 
540
950
  Raises:
541
951
  ValueError: if the data matrix does not comply with the expectations.
@@ -544,8 +954,7 @@ class LinearModelBase(MethodBase):
544
954
 
545
955
  @abstractmethod
546
956
  def fit(self, **kwargs):
547
- """
548
- Fit the model.
957
+ """Fit the model.
549
958
 
550
959
  Args:
551
960
  **kwargs: Additional arguments for fitting the specific method.
@@ -555,9 +964,8 @@ class LinearModelBase(MethodBase):
555
964
  @abstractmethod
556
965
  def _test_single_contrast(self, contrast, **kwargs): ...
557
966
 
558
- def test_contrasts(self, contrasts, **kwargs):
559
- """
560
- Perform a comparison as specified in a contrast vector.
967
+ def test_contrasts(self, contrasts: np.ndarray | Mapping[str | None, np.ndarray], **kwargs):
968
+ """Perform a comparison as specified in a contrast vector.
561
969
 
562
970
  Args:
563
971
  contrasts: Either a numeric contrast vector, or a dictionary of numeric contrast vectors.
@@ -573,25 +981,25 @@ class LinearModelBase(MethodBase):
573
981
  results.append(self._test_single_contrast(contrast, **kwargs).assign(contrast=name))
574
982
 
575
983
  results_df = pd.concat(results)
984
+
576
985
  return results_df
577
986
 
578
987
  def test_reduced(self, modelB):
579
- """
580
- Test against a reduced model.
988
+ """Test against a reduced model.
581
989
 
582
990
  Args:
583
991
  modelB: the reduced model against which to test.
584
992
 
585
993
  Example:
586
- modelA = Model().fit()
587
- modelB = Model().fit()
588
- modelA.test_reduced(modelB)
994
+ >>> import pertpy as pt
995
+ >>> modelA = Model().fit()
996
+ >>> modelB = Model().fit()
997
+ >>> modelA.test_reduced(modelB)
589
998
  """
590
999
  raise NotImplementedError
591
1000
 
592
1001
  def cond(self, **kwargs):
593
- """
594
- Get a contrast vector representing a specific condition.
1002
+ """Get a contrast vector representing a specific condition.
595
1003
 
596
1004
  Args:
597
1005
  **kwargs: column/value pairs.
@@ -599,52 +1007,14 @@ class LinearModelBase(MethodBase):
599
1007
  Returns:
600
1008
  A contrast vector that aligns to the columns of the design matrix.
601
1009
  """
602
- if self.factor_storage is None:
1010
+ if self.formulaic_contrasts is None:
603
1011
  raise RuntimeError(
604
1012
  "Building contrasts with `cond` only works if you specified the model using a formulaic formula. Please manually provide a contrast vector."
605
1013
  )
606
- cond_dict = kwargs
607
- if not set(cond_dict.keys()).issubset(self.variables):
608
- raise ValueError(
609
- "You specified a variable that is not part of the model. Available variables: "
610
- + ",".join(self.variables)
611
- )
612
- for var in self.variables:
613
- if var in cond_dict:
614
- self._check_category(var, cond_dict[var])
615
- else:
616
- cond_dict[var] = self._get_default_value(var)
617
- df = pd.DataFrame([kwargs])
618
- return self.design.model_spec.get_model_matrix(df).iloc[0]
619
-
620
- def _get_factor_metadata_for_variable(self, var):
621
- factors = self.variable_to_factors[var]
622
- return list(chain.from_iterable(self.factor_storage[f] for f in factors))
623
-
624
- def _get_default_value(self, var):
625
- factor_metadata = self._get_factor_metadata_for_variable(var)
626
- if resolve_ambiguous(factor_metadata, "kind") == Factor.Kind.CATEGORICAL:
627
- try:
628
- tmp_base = resolve_ambiguous(factor_metadata, "base")
629
- except AmbiguousAttributeError as e:
630
- raise ValueError(
631
- f"Could not automatically resolve base category for variable {var}. Please specify it explicity in `model.cond`."
632
- ) from e
633
- return tmp_base if tmp_base is not None else "\0"
634
- else:
635
- return 0
636
-
637
- def _check_category(self, var, value):
638
- factor_metadata = self._get_factor_metadata_for_variable(var)
639
- tmp_categories = resolve_ambiguous(factor_metadata, "categories")
640
- if resolve_ambiguous(factor_metadata, "kind") == Factor.Kind.CATEGORICAL and value not in tmp_categories:
641
- raise ValueError(
642
- f"You specified a non-existant category for {var}. Possible categories: {', '.join(tmp_categories)}"
643
- )
1014
+ return self.formulaic_contrasts.cond(**kwargs)
644
1015
 
645
- def contrast(self, column, baseline, group_to_compare):
646
- """
647
- Build a simple contrast for pairwise comparisons.
1016
+ def contrast(self, *args, **kwargs):
1017
+ """Build a simple contrast for pairwise comparisons.
648
1018
 
649
1019
  Args:
650
1020
  column: column in adata.obs to test on.
@@ -654,4 +1024,8 @@ class LinearModelBase(MethodBase):
654
1024
  Returns:
655
1025
  Numeric contrast vector.
656
1026
  """
657
- return self.cond(**{column: group_to_compare}) - self.cond(**{column: baseline})
1027
+ if self.formulaic_contrasts is None:
1028
+ raise RuntimeError(
1029
+ "Building contrasts with `cond` only works if you specified the model using a formulaic formula. Please manually provide a contrast vector."
1030
+ )
1031
+ return self.formulaic_contrasts.contrast(*args, **kwargs)