pertpy 0.9.4__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,7 @@
1
- import os
1
+ import math
2
2
  from abc import ABC, abstractmethod
3
- from dataclasses import dataclass
4
- from itertools import chain
3
+ from collections.abc import Iterable, Mapping, Sequence
4
+ from itertools import zip_longest
5
5
  from types import MappingProxyType
6
6
 
7
7
  import adjustText
@@ -11,27 +11,14 @@ import matplotlib.pyplot as plt
11
11
  import numpy as np
12
12
  import pandas as pd
13
13
  import seaborn as sns
14
+ from formulaic_contrasts import FormulaicContrasts
15
+ from lamin_utils import logger
16
+ from matplotlib.pyplot import Figure
14
17
  from matplotlib.ticker import MaxNLocator
15
18
 
19
+ from pertpy._doc import _doc_params, doc_common_plot_args
20
+ from pertpy.tools import PseudobulkSpace
16
21
  from pertpy.tools._differential_gene_expression._checks import check_is_numeric_matrix
17
- from pertpy.tools._differential_gene_expression._formulaic import (
18
- AmbiguousAttributeError,
19
- Factor,
20
- get_factor_storage_and_materializer,
21
- resolve_ambiguous,
22
- )
23
-
24
-
25
- @dataclass
26
- class Contrast:
27
- """Simple contrast for comparison between groups"""
28
-
29
- column: str
30
- baseline: str
31
- group_to_compare: str
32
-
33
-
34
- ContrastType = Contrast | tuple[str, str, str]
35
22
 
36
23
 
37
24
  class MethodBase(ABC):
@@ -58,7 +45,7 @@ class MethodBase(ABC):
58
45
  if self.layer is None:
59
46
  return self.adata.X
60
47
  else:
61
- return self.adata.layer[self.layer]
48
+ return self.adata.layers[self.layer]
62
49
 
63
50
  @classmethod
64
51
  @abstractmethod
@@ -91,9 +78,28 @@ class MethodBase(ABC):
91
78
 
92
79
  Returns:
93
80
  Pandas dataframe with results ordered by significance. If multiple comparisons were performed this is indicated in an additional column.
81
+
82
+ Examples:
83
+ >>> # Example with EdgeR
84
+ >>> import pertpy as pt
85
+ >>> adata = pt.dt.zhang_2021()
86
+ >>> adata.layers["counts"] = adata.X.copy()
87
+ >>> ps = pt.tl.PseudobulkSpace()
88
+ >>> pdata = ps.compute(
89
+ ... adata,
90
+ ... target_col="Patient",
91
+ ... groups_col="Cluster",
92
+ ... layer_key="counts",
93
+ ... mode="sum",
94
+ ... min_cells=10,
95
+ ... min_counts=1000,
96
+ ... )
97
+ >>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment")
98
+ >>> res_df = edgr.compare_groups(pdata, column="Efficacy", baseline="SD", groups_to_compare=["PR", "PD"])
94
99
  """
95
100
  ...
96
101
 
102
+ @_doc_params(common_plot_args=doc_common_plot_args)
97
103
  def plot_volcano(
98
104
  self,
99
105
  data: pd.DataFrame | ad.AnnData,
@@ -115,13 +121,13 @@ class MethodBase(ABC):
115
121
  figsize: tuple[int, int] = (5, 5),
116
122
  legend_pos: tuple[float, float] = (1.6, 1),
117
123
  point_sizes: tuple[int, int] = (15, 150),
118
- save: bool | str | None = None,
119
124
  shapes: list[str] | None = None,
120
125
  shape_order: list[str] | None = None,
121
126
  x_label: str | None = None,
122
127
  y_label: str | None = None,
128
+ return_fig: bool = False,
123
129
  **kwargs: int,
124
- ) -> None:
130
+ ) -> Figure | None:
125
131
  """Creates a volcano plot from a pandas DataFrame or Anndata.
126
132
 
127
133
  Args:
@@ -143,12 +149,40 @@ class MethodBase(ABC):
143
149
  top_right_frame: Whether to show the top and right frame of the plot.
144
150
  figsize: Size of the figure.
145
151
  legend_pos: Position of the legend as determined by matplotlib.
146
- save: Saves the plot if True or to the path provided.
147
152
  shapes: List of matplotlib marker ids.
148
153
  shape_order: Order of categories for shapes.
149
154
  x_label: Label for the x-axis.
150
155
  y_label: Label for the y-axis.
156
+ {common_plot_args}
151
157
  **kwargs: Additional arguments for seaborn.scatterplot.
158
+
159
+ Returns:
160
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
161
+
162
+ Examples:
163
+ >>> # Example with EdgeR
164
+ >>> import pertpy as pt
165
+ >>> adata = pt.dt.zhang_2021()
166
+ >>> adata.layers["counts"] = adata.X.copy()
167
+ >>> ps = pt.tl.PseudobulkSpace()
168
+ >>> pdata = ps.compute(
169
+ ... adata,
170
+ ... target_col="Patient",
171
+ ... groups_col="Cluster",
172
+ ... layer_key="counts",
173
+ ... mode="sum",
174
+ ... min_cells=10,
175
+ ... min_counts=1000,
176
+ ... )
177
+ >>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment")
178
+ >>> edgr.fit()
179
+ >>> res_df = edgr.test_contrasts(
180
+ ... edgr.contrast(column="Treatment", baseline="Chemo", group_to_compare="Anti-PD-L1+Chemo")
181
+ ... )
182
+ >>> edgr.plot_volcano(res_df, log2fc_thresh=0)
183
+
184
+ Preview:
185
+ .. image:: /_static/docstring_previews/de_volcano.png
152
186
  """
153
187
  if colors is None:
154
188
  colors = ["gray", "#D62728", "#1F77B4"]
@@ -243,7 +277,7 @@ class MethodBase(ABC):
243
277
  if varm_key is None:
244
278
  raise ValueError("Please pass a .varm key to use for plotting")
245
279
 
246
- raise NotImplementedError("Anndata not implemented yet")
280
+ raise NotImplementedError("Anndata not implemented yet") # TODO: Implement this
247
281
  df = data.varm[varm_key].copy()
248
282
 
249
283
  df = data.copy(deep=True)
@@ -449,26 +483,405 @@ class MethodBase(ABC):
449
483
 
450
484
  plt.legend(loc=1, bbox_to_anchor=legend_pos, frameon=False)
451
485
 
452
- # TODO replace with scanpy save style
453
- if save:
454
- files = os.listdir()
455
- for x in range(100):
456
- file_pref = "volcano_" + "%02d" % (x,)
457
- if len([x for x in files if x.startswith(file_pref)]) == 0:
458
- plt.savefig(file_pref + ".png", dpi=300, bbox_inches="tight")
459
- plt.savefig(file_pref + ".svg", bbox_inches="tight")
460
- break
461
- elif isinstance(save, str):
462
- plt.savefig(save + ".png", dpi=300, bbox_inches="tight")
463
- plt.savefig(save + ".svg", bbox_inches="tight")
486
+ if return_fig:
487
+ return plt.gcf()
488
+ plt.show()
489
+ return None
490
+
491
@_doc_params(common_plot_args=doc_common_plot_args)
def plot_paired(
    self,
    adata: ad.AnnData,
    results_df: pd.DataFrame,
    groupby: str,
    pairedby: str | None,
    *,
    var_names: Sequence[str] | None = None,
    n_top_vars: int = 15,
    layer: str | None = None,
    pvalue_col: str = "adj_p_value",
    symbol_col: str = "variable",
    n_cols: int = 4,
    panel_size: tuple[int, int] = (5, 5),
    show_legend: bool = True,
    size: int = 10,
    y_label: str = "expression",
    pvalue_template=lambda x: f"p={x:.2e}",
    boxplot_properties=None,
    palette=None,
    return_fig: bool = False,
) -> Figure | None:
    """Creates a paired expression plot from an AnnData object and differential expression results.

    Visualizes one panel per variable: a boxplot per group of `groupby`, the individual samples as points
    and, if `pairedby` is given, a line connecting the two samples of each pair.

    Args:
        adata: AnnData object, can be pseudobulked. If several observations share the same
            `groupby`/`pairedby` combination, they are summed into pseudobulk samples first.
        results_df: DataFrame with results from a differential expression test. It must contain a single
            row per plotted variable (i.e. one comparison), because the panel titles show its p-values.
        groupby: .obs column containing the grouping. Must contain exactly two different values.
        pairedby: .obs column containing the pairing (e.g. "patient_id"). If None, the samples are drawn
            as independent, jittered points without pairing lines. No statistical test is run here;
            the displayed p-values always come from `results_df`.
        var_names: Variables to plot. Defaults to the first `n_top_vars` variables of `results_df`.
        n_top_vars: Number of top variables to plot if `var_names` is None.
        layer: Layer to use for plotting. Defaults to `.X`.
        pvalue_col: Column name of the p values.
        symbol_col: Column name of gene IDs.
        n_cols: Maximum number of panels per row.
        panel_size: Size of each panel.
        show_legend: Whether to show the legend (only drawn when `pairedby` is given).
        size: Size of the points.
        y_label: Label for the y-axis.
        pvalue_template: Callable formatting the p-value displayed in the title of each panel.
        boxplot_properties: Additional properties for the boxplot, passed to seaborn.boxplot.
        palette: Color palette for the line- and stripplot.
        {common_plot_args}

    Returns:
        If `return_fig` is `True`, returns the figure, otherwise `None`.

    Raises:
        ValueError: If `groupby` does not have exactly two values, if there are no variables to plot,
            or if `results_df` contains several rows for one of the plotted variables.

    Examples:
        >>> # Example with EdgeR
        >>> import pertpy as pt
        >>> adata = pt.dt.zhang_2021()
        >>> adata.layers["counts"] = adata.X.copy()
        >>> ps = pt.tl.PseudobulkSpace()
        >>> pdata = ps.compute(
        ...     adata,
        ...     target_col="Patient",
        ...     groups_col="Cluster",
        ...     layer_key="counts",
        ...     mode="sum",
        ...     min_cells=10,
        ...     min_counts=1000,
        ... )
        >>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment")
        >>> edgr.fit()
        >>> res_df = edgr.test_contrasts(
        ...     edgr.contrast(column="Treatment", baseline="Chemo", group_to_compare="Anti-PD-L1+Chemo")
        ... )
        >>> edgr.plot_paired(pdata, results_df=res_df, n_top_vars=8, groupby="Treatment", pairedby="Efficacy")

    Preview:
        .. image:: /_static/docstring_previews/de_paired_expression.png
    """
    if boxplot_properties is None:
        boxplot_properties = {}
    groups = adata.obs[groupby].unique()
    if len(groups) != 2:
        raise ValueError("The number of groups in the group_by column must be exactly 2 to enable paired testing")

    if var_names is None:
        var_names = results_df.head(n_top_vars)[symbol_col].tolist()
    var_names = list(var_names)
    if not var_names:
        # plt.subplots(0, 0) would otherwise fail with an obscure error
        raise ValueError("No variables to plot: `var_names` is empty.")

    adata = adata[:, var_names]

    # Several observations per (group, pair) combination -> aggregate them into one sample each.
    if pairedby is not None and (adata.obs[[groupby, pairedby]].value_counts() > 1).any():
        logger.info("Performing pseudobulk for paired samples")
        ps = PseudobulkSpace()
        adata = ps.compute(
            adata, target_col=groupby, groups_col=pairedby, layer_key=layer, mode="sum", min_cells=1, min_counts=1
        )

    X = adata.layers[layer] if layer is not None else adata.X
    if hasattr(X, "toarray"):  # densify sparse matrices
        X = X.toarray()

    groupby_cols = [groupby] if pairedby is None else [pairedby, groupby]
    df = adata.obs.loc[:, groupby_cols].join(pd.DataFrame(X, index=adata.obs_names, columns=var_names))

    if pairedby is not None:
        # remove samples that were not observed in both groups
        paired_samples = set(df.loc[df[groupby] == groups[0], pairedby]) & set(
            df.loc[df[groupby] == groups[1], pairedby]
        )
        df = df[df[pairedby].isin(paired_samples)]
        removed_samples = adata.obs[pairedby].nunique() - df[pairedby].nunique()
        if removed_samples > 0:
            logger.warning(f"{removed_samples} unpaired samples removed")

    pvalues = results_df.set_index(symbol_col).loc[var_names, pvalue_col].to_numpy()
    if len(pvalues) != len(var_names):
        # duplicated symbols (e.g. several contrasts) would silently shift the p-values between panels
        raise ValueError(
            f"`results_df` contains several rows per variable in '{symbol_col}' (e.g. multiple contrasts); "
            "filter it to a single comparison before plotting."
        )

    # Long format for seaborn; value_vars restricts melting to the expression columns only.
    df_melt = df.melt(id_vars=groupby_cols, value_vars=var_names, var_name="var", value_name="val")

    n_panels = len(var_names)
    nrows = math.ceil(n_panels / n_cols)
    ncols = min(n_cols, n_panels)

    fig, axes = plt.subplots(
        nrows,
        ncols,
        figsize=(ncols * panel_size[0], nrows * panel_size[1]),
        tight_layout=True,
        squeeze=False,
    )
    axes = axes.flatten()
    jitter = 0 if pairedby is not None else True  # paired points must line up with the pairing lines
    for i, (var, ax) in enumerate(zip_longest(var_names, axes)):
        if var is None:
            # more grid cells than variables: hide the surplus axes
            ax.set_visible(False)
            continue
        var_df = df_melt.loc[df_melt["var"] == var]
        sns.boxplot(
            x=groupby,
            data=var_df,
            y="val",
            ax=ax,
            color="white",
            fliersize=0,
            **boxplot_properties,
        )
        if pairedby is not None:
            sns.lineplot(
                x=groupby,
                data=var_df,
                y="val",
                ax=ax,
                hue=pairedby,
                legend=False,
                errorbar=None,
                palette=palette,
            )
        sns.stripplot(
            x=groupby,
            data=var_df,
            y="val",
            ax=ax,
            hue=pairedby,
            jitter=jitter,
            size=size,
            linewidth=1,
            palette=palette,
        )

        ax.set_xlabel("")
        ax.tick_params(axis="x", labelsize=15)
        per_panel_legend = ax.get_legend()
        if per_panel_legend is not None:
            per_panel_legend.set_visible(False)
        ax.set_ylabel(y_label)
        ax.set_title(f"{var}\n{pvalue_template(pvalues[i])}")

    if show_legend and pairedby is not None:
        # a single shared legend below the last panel
        axes[n_panels - 1].legend(bbox_to_anchor=(0.5, -0.1), loc="upper center", ncol=adata.obs[pairedby].nunique())

    fig.tight_layout()
    if return_fig:
        return fig
    plt.show()
    return None
683
+
684
@_doc_params(common_plot_args=doc_common_plot_args)
def plot_fold_change(
    self,
    results_df: pd.DataFrame,
    *,
    var_names: Sequence[str] | None = None,
    n_top_vars: int = 15,
    log2fc_col: str = "log_fc",
    symbol_col: str = "variable",
    y_label: str = "Log2 fold change",
    figsize: tuple[int, int] = (10, 5),
    return_fig: bool = False,
    **barplot_kwargs,
) -> Figure | None:
    """Plot the log2 fold changes of selected variables from the results as a bar chart.

    Bars are sorted from the highest to the lowest log2 fold change.

    Args:
        results_df: DataFrame with results from DE analysis.
        var_names: Variables to plot. If None, the `n_top_vars` variables with the highest and the
            `n_top_vars` variables with the lowest log2 fold change are plotted.
        n_top_vars: Number of top variables to plot. The top and bottom n_top_vars variables are plotted, respectively.
            If `results_df` has fewer rows than that, all available variables are plotted.
        log2fc_col: Column name of log2 Fold-Change values.
        symbol_col: Column name of gene IDs.
        y_label: Label for the y-axis.
        figsize: Size of the figure.
        {common_plot_args}
        **barplot_kwargs: Additional arguments for seaborn.barplot. They take precedence over the defaults
            used here (`hue` set to `symbol_col`, `palette="RdBu"`, `legend=False`).

    Returns:
        If `return_fig` is `True`, returns the figure, otherwise `None`.

    Examples:
        >>> # Example with EdgeR
        >>> import pertpy as pt
        >>> adata = pt.dt.zhang_2021()
        >>> adata.layers["counts"] = adata.X.copy()
        >>> ps = pt.tl.PseudobulkSpace()
        >>> pdata = ps.compute(
        ...     adata,
        ...     target_col="Patient",
        ...     groups_col="Cluster",
        ...     layer_key="counts",
        ...     mode="sum",
        ...     min_cells=10,
        ...     min_counts=1000,
        ... )
        >>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment")
        >>> edgr.fit()
        >>> res_df = edgr.test_contrasts(
        ...     edgr.contrast(column="Treatment", baseline="Chemo", group_to_compare="Anti-PD-L1+Chemo")
        ... )
        >>> edgr.plot_fold_change(res_df)

    Preview:
        .. image:: /_static/docstring_previews/de_fold_change.png
    """
    if var_names is None:
        top = results_df.sort_values(log2fc_col, ascending=False).head(n_top_vars)[symbol_col].tolist()
        bottom = results_df.sort_values(log2fc_col, ascending=True).head(n_top_vars)[symbol_col].tolist()
        # Top and bottom may overlap for small result tables; the `isin` filter below deduplicates.
        var_names = top + bottom

    # Build a new, sorted frame instead of sorting a filtered slice in place (SettingWithCopyWarning).
    df = results_df[results_df[symbol_col].isin(var_names)].sort_values(log2fc_col, ascending=False)

    # Colouring bars via `hue` (instead of `palette` alone) avoids the seaborn>=0.13 deprecation;
    # user-supplied kwargs override these defaults instead of raising a duplicate-keyword error.
    plot_kwargs = {"hue": symbol_col, "palette": "RdBu", "legend": False} | barplot_kwargs

    plt.figure(figsize=figsize)
    ax = sns.barplot(
        x=symbol_col,
        y=log2fc_col,
        data=df,
        **plot_kwargs,
    )
    ax.tick_params(axis="x", labelrotation=90)
    ax.set_xlabel("")
    ax.set_ylabel(y_label)

    if return_fig:
        return ax.figure
    plt.show()
    return None
764
+
765
@_doc_params(common_plot_args=doc_common_plot_args)
def plot_multicomparison_fc(
    self,
    results_df: pd.DataFrame,
    *,
    n_top_vars: int = 15,
    contrast_col: str = "contrast",
    log2fc_col: str = "log_fc",
    pvalue_col: str = "adj_p_value",
    symbol_col: str = "variable",
    marker_size: int = 100,
    figsize: tuple[int, int] = (10, 2),
    x_label: str = "Contrast",
    y_label: str = "Gene",
    return_fig: bool = False,
    **heatmap_kwargs,
) -> Figure | None:
    """Plot a matrix of log2 fold changes from the results.

    Rows of the heatmap are contrasts and columns are variables. Cells with a p-value below 0.1, 0.01
    or 0.001 are marked with a white star of increasing size. `results_df` itself is not modified.

    Args:
        results_df: DataFrame with results from DE analysis.
        n_top_vars: Number of top variables (by absolute log2 fold change) to plot per group.
            Variables that are top hits in several contrasts are shown once.
        contrast_col: Column in results_df containing information about the contrast.
        log2fc_col: Column in results_df containing the log2 fold change.
        pvalue_col: Column in results_df containing the p-value. Can be used to switch between adjusted and unadjusted p-values.
        symbol_col: Column in results_df containing the gene symbol.
        marker_size: Size of the biggest marker for significant variables.
        figsize: Size of the figure.
        x_label: Label for the x-axis.
        y_label: Label for the y-axis.
        {common_plot_args}
        **heatmap_kwargs: Additional arguments for seaborn.heatmap. They take precedence over the defaults
            used here (`cmap="coolwarm"`, `center=0` and a labelled colorbar).

    Returns:
        If `return_fig` is `True`, returns the figure, otherwise `None`.

    Examples:
        >>> # Example with EdgeR
        >>> import pertpy as pt
        >>> adata = pt.dt.zhang_2021()
        >>> adata.layers["counts"] = adata.X.copy()
        >>> ps = pt.tl.PseudobulkSpace()
        >>> pdata = ps.compute(
        ...     adata,
        ...     target_col="Patient",
        ...     groups_col="Cluster",
        ...     layer_key="counts",
        ...     mode="sum",
        ...     min_cells=10,
        ...     min_counts=1000,
        ... )
        >>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment")
        >>> res_df = edgr.compare_groups(pdata, column="Efficacy", baseline="SD", groups_to_compare=["PR", "PD"])
        >>> edgr.plot_multicomparison_fc(res_df)

    Preview:
        .. image:: /_static/docstring_previews/de_multicomparison_fc.png
    """

    def _get_significance(p_val):
        # Bin a p-value into the star categories; NaN compares False everywhere and ends up "n.s.".
        if p_val < 0.001:
            return "< 0.001"
        elif p_val < 0.01:
            return "< 0.01"
        elif p_val < 0.1:
            return "< 0.1"
        else:
            return "n.s."

    # `assign` returns a new frame, so the helper columns do not leak into the caller's DataFrame.
    results_df = results_df.assign(
        abs_log_fc=results_df[log2fc_col].abs(),
        significance=results_df[pvalue_col].apply(_get_significance),
    )

    var_names = []
    for group in results_df[contrast_col].unique():
        var_names.extend(
            results_df[results_df[contrast_col] == group]
            .sort_values("abs_log_fc", ascending=False)
            .head(n_top_vars)[symbol_col]
            .tolist()
        )
    # A variable may be a top hit in several contrasts; keep its first occurrence only.
    var_names = list(dict.fromkeys(var_names))

    results_df = results_df[results_df[symbol_col].isin(var_names)]
    df = results_df.pivot(index=contrast_col, columns=symbol_col, values=log2fc_col)[var_names]

    plot_kwargs = {"cmap": "coolwarm", "center": 0, "cbar_kws": {"label": "Log2 fold change"}} | heatmap_kwargs
    plt.figure(figsize=figsize)
    ax = sns.heatmap(df, **plot_kwargs)

    _size = {"< 0.001": marker_size, "< 0.01": math.floor(marker_size / 2), "< 0.1": math.floor(marker_size / 4)}

    significant = results_df[results_df["significance"] != "n.s."]
    if not significant.empty:
        # seaborn draws cell (row i, column j) centred at (j + 0.5, i + 0.5). Positions are taken from the
        # pivot table rather than from tick-label text, which heatmap may thin out or render differently.
        ax.scatter(
            x=[df.columns.get_loc(symbol) + 0.5 for symbol in significant[symbol_col]],
            y=[df.index.get_loc(contrast) + 0.5 for contrast in significant[contrast_col]],
            s=significant["significance"].map(_size).to_numpy(),
            marker="*",
            c="white",
        )

    # Invisible proxy markers that only populate the significance legend.
    for label, legend_size in _size.items():
        ax.scatter([], [], s=legend_size, marker="*", c="black", label=label)
    ax.legend(title="Significance", bbox_to_anchor=(1.2, -0.05))

    # NOTE(review): variables are on the x-axis and contrasts on the y-axis, so the defaults
    # x_label="Contrast" / y_label="Gene" look swapped; they are kept unchanged for compatibility.
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)

    if return_fig:
        return ax.figure
    plt.show()
    return None
466
880
 
467
881
 
468
882
  class LinearModelBase(MethodBase):
469
883
  def __init__(self, adata, design, *, mask=None, layer=None, **kwargs):
470
- """
471
- Initialize the method.
884
+ """Initialize the method.
472
885
 
473
886
  Args:
474
887
  adata: AnnData object, usually pseudobulked.
@@ -480,26 +893,24 @@ class LinearModelBase(MethodBase):
480
893
  super().__init__(adata, mask=mask, layer=layer)
481
894
  self._check_counts()
482
895
 
483
- self.factor_storage = None
484
- self.variable_to_factors = None
485
-
896
+ self.formulaic_contrasts = None
486
897
  if isinstance(design, str):
487
- self.factor_storage, self.variable_to_factors, materializer_class = get_factor_storage_and_materializer()
488
- self.design = materializer_class(adata.obs, record_factor_metadata=True).get_model_matrix(design)
898
+ self.formulaic_contrasts = FormulaicContrasts(adata.obs, design)
899
+ self.design = self.formulaic_contrasts.design_matrix
489
900
  else:
490
901
  self.design = design
491
902
 
492
903
  @classmethod
493
904
  def compare_groups(
494
905
  cls,
495
- adata,
496
- column,
497
- baseline,
498
- groups_to_compare,
906
+ adata: ad.AnnData,
907
+ column: str,
908
+ baseline: str,
909
+ groups_to_compare: str | Iterable[str],
499
910
  *,
500
- paired_by=None,
501
- mask=None,
502
- layer=None,
911
+ paired_by: str | None = None,
912
+ mask: pd.Series | None = None,
913
+ layer: str | None = None,
503
914
  fit_kwargs=MappingProxyType({}),
504
915
  test_kwargs=MappingProxyType({}),
505
916
  ):
@@ -525,17 +936,16 @@ class LinearModelBase(MethodBase):
525
936
  @property
526
937
  def variables(self):
527
938
  """Get the names of the variables used in the model definition."""
528
- try:
529
- return self.design.model_spec.variables_by_source["data"]
530
- except AttributeError:
939
+ if self.formulaic_contrasts is None:
531
940
  raise ValueError(
532
941
  "Retrieving variables is only possible if the model was initialized using a formula."
533
942
  ) from None
943
+ else:
944
+ return self.formulaic_contrasts.variables
534
945
 
535
946
  @abstractmethod
536
947
  def _check_counts(self):
537
- """
538
- Check that counts are valid for the specific method.
948
+ """Check that counts are valid for the specific method.
539
949
 
540
950
  Raises:
541
951
  ValueError: if the data matrix does not comply with the expectations.
@@ -544,8 +954,7 @@ class LinearModelBase(MethodBase):
544
954
 
545
955
  @abstractmethod
546
956
  def fit(self, **kwargs):
547
- """
548
- Fit the model.
957
+ """Fit the model.
549
958
 
550
959
  Args:
551
960
  **kwargs: Additional arguments for fitting the specific method.
@@ -555,9 +964,8 @@ class LinearModelBase(MethodBase):
555
964
  @abstractmethod
556
965
  def _test_single_contrast(self, contrast, **kwargs): ...
557
966
 
558
- def test_contrasts(self, contrasts, **kwargs):
559
- """
560
- Perform a comparison as specified in a contrast vector.
967
+ def test_contrasts(self, contrasts: np.ndarray | Mapping[str | None, np.ndarray], **kwargs):
968
+ """Perform a comparison as specified in a contrast vector.
561
969
 
562
970
  Args:
563
971
  contrasts: Either a numeric contrast vector, or a dictionary of numeric contrast vectors.
@@ -573,25 +981,25 @@ class LinearModelBase(MethodBase):
573
981
  results.append(self._test_single_contrast(contrast, **kwargs).assign(contrast=name))
574
982
 
575
983
  results_df = pd.concat(results)
984
+
576
985
  return results_df
577
986
 
578
987
  def test_reduced(self, modelB):
579
- """
580
- Test against a reduced model.
988
+ """Test against a reduced model.
581
989
 
582
990
  Args:
583
991
  modelB: the reduced model against which to test.
584
992
 
585
993
  Example:
586
- modelA = Model().fit()
587
- modelB = Model().fit()
588
- modelA.test_reduced(modelB)
994
+ >>> import pertpy as pt
995
+ >>> modelA = Model().fit()
996
+ >>> modelB = Model().fit()
997
+ >>> modelA.test_reduced(modelB)
589
998
  """
590
999
  raise NotImplementedError
591
1000
 
592
1001
  def cond(self, **kwargs):
593
- """
594
- Get a contrast vector representing a specific condition.
1002
+ """Get a contrast vector representing a specific condition.
595
1003
 
596
1004
  Args:
597
1005
  **kwargs: column/value pairs.
@@ -599,52 +1007,14 @@ class LinearModelBase(MethodBase):
599
1007
  Returns:
600
1008
  A contrast vector that aligns to the columns of the design matrix.
601
1009
  """
602
- if self.factor_storage is None:
1010
+ if self.formulaic_contrasts is None:
603
1011
  raise RuntimeError(
604
1012
  "Building contrasts with `cond` only works if you specified the model using a formulaic formula. Please manually provide a contrast vector."
605
1013
  )
606
- cond_dict = kwargs
607
- if not set(cond_dict.keys()).issubset(self.variables):
608
- raise ValueError(
609
- "You specified a variable that is not part of the model. Available variables: "
610
- + ",".join(self.variables)
611
- )
612
- for var in self.variables:
613
- if var in cond_dict:
614
- self._check_category(var, cond_dict[var])
615
- else:
616
- cond_dict[var] = self._get_default_value(var)
617
- df = pd.DataFrame([kwargs])
618
- return self.design.model_spec.get_model_matrix(df).iloc[0]
619
-
620
- def _get_factor_metadata_for_variable(self, var):
621
- factors = self.variable_to_factors[var]
622
- return list(chain.from_iterable(self.factor_storage[f] for f in factors))
623
-
624
- def _get_default_value(self, var):
625
- factor_metadata = self._get_factor_metadata_for_variable(var)
626
- if resolve_ambiguous(factor_metadata, "kind") == Factor.Kind.CATEGORICAL:
627
- try:
628
- tmp_base = resolve_ambiguous(factor_metadata, "base")
629
- except AmbiguousAttributeError as e:
630
- raise ValueError(
631
- f"Could not automatically resolve base category for variable {var}. Please specify it explicity in `model.cond`."
632
- ) from e
633
- return tmp_base if tmp_base is not None else "\0"
634
- else:
635
- return 0
636
-
637
- def _check_category(self, var, value):
638
- factor_metadata = self._get_factor_metadata_for_variable(var)
639
- tmp_categories = resolve_ambiguous(factor_metadata, "categories")
640
- if resolve_ambiguous(factor_metadata, "kind") == Factor.Kind.CATEGORICAL and value not in tmp_categories:
641
- raise ValueError(
642
- f"You specified a non-existant category for {var}. Possible categories: {', '.join(tmp_categories)}"
643
- )
1014
+ return self.formulaic_contrasts.cond(**kwargs)
644
1015
 
645
- def contrast(self, column, baseline, group_to_compare):
646
- """
647
- Build a simple contrast for pairwise comparisons.
1016
+ def contrast(self, *args, **kwargs):
1017
+ """Build a simple contrast for pairwise comparisons.
648
1018
 
649
1019
  Args:
650
1020
  column: column in adata.obs to test on.
@@ -654,4 +1024,8 @@ class LinearModelBase(MethodBase):
654
1024
  Returns:
655
1025
  Numeric contrast vector.
656
1026
  """
657
- return self.cond(**{column: group_to_compare}) - self.cond(**{column: baseline})
1027
+ if self.formulaic_contrasts is None:
1028
+ raise RuntimeError(
1029
+ "Building contrasts with `cond` only works if you specified the model using a formulaic formula. Please manually provide a contrast vector."
1030
+ )
1031
+ return self.formulaic_contrasts.contrast(*args, **kwargs)