pertpy 0.9.4__py3-none-any.whl → 0.9.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,7 @@
1
- import os
1
+ import math
2
2
  from abc import ABC, abstractmethod
3
- from dataclasses import dataclass
4
- from itertools import chain
3
+ from collections.abc import Iterable, Mapping, Sequence
4
+ from itertools import zip_longest
5
5
  from types import MappingProxyType
6
6
 
7
7
  import adjustText
@@ -11,27 +11,14 @@ import matplotlib.pyplot as plt
11
11
  import numpy as np
12
12
  import pandas as pd
13
13
  import seaborn as sns
14
+ from formulaic_contrasts import FormulaicContrasts
15
+ from lamin_utils import logger
16
+ from matplotlib.pyplot import Figure
14
17
  from matplotlib.ticker import MaxNLocator
15
18
 
19
+ from pertpy._doc import _doc_params, doc_common_plot_args
20
+ from pertpy.tools import PseudobulkSpace
16
21
  from pertpy.tools._differential_gene_expression._checks import check_is_numeric_matrix
17
- from pertpy.tools._differential_gene_expression._formulaic import (
18
- AmbiguousAttributeError,
19
- Factor,
20
- get_factor_storage_and_materializer,
21
- resolve_ambiguous,
22
- )
23
-
24
-
25
- @dataclass
26
- class Contrast:
27
- """Simple contrast for comparison between groups"""
28
-
29
- column: str
30
- baseline: str
31
- group_to_compare: str
32
-
33
-
34
- ContrastType = Contrast | tuple[str, str, str]
35
22
 
36
23
 
37
24
  class MethodBase(ABC):
@@ -58,7 +45,7 @@ class MethodBase(ABC):
58
45
  if self.layer is None:
59
46
  return self.adata.X
60
47
  else:
61
- return self.adata.layer[self.layer]
48
+ return self.adata.layers[self.layer]
62
49
 
63
50
  @classmethod
64
51
  @abstractmethod
@@ -91,9 +78,28 @@ class MethodBase(ABC):
91
78
 
92
79
  Returns:
93
80
  Pandas dataframe with results ordered by significance. If multiple comparisons were performed this is indicated in an additional column.
81
+
82
+ Examples:
83
+ >>> # Example with EdgeR
84
+ >>> import pertpy as pt
85
+ >>> adata = pt.dt.zhang_2021()
86
+ >>> adata.layers["counts"] = adata.X.copy()
87
+ >>> ps = pt.tl.PseudobulkSpace()
88
+ >>> pdata = ps.compute(
89
+ ... adata,
90
+ ... target_col="Patient",
91
+ ... groups_col="Cluster",
92
+ ... layer_key="counts",
93
+ ... mode="sum",
94
+ ... min_cells=10,
95
+ ... min_counts=1000,
96
+ ... )
97
+ >>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment")
98
+ >>> res_df = edgr.compare_groups(pdata, column="Efficacy", baseline="SD", groups_to_compare=["PR", "PD"])
94
99
  """
95
100
  ...
96
101
 
102
+ @_doc_params(common_plot_args=doc_common_plot_args)
97
103
  def plot_volcano(
98
104
  self,
99
105
  data: pd.DataFrame | ad.AnnData,
@@ -115,13 +121,14 @@ class MethodBase(ABC):
115
121
  figsize: tuple[int, int] = (5, 5),
116
122
  legend_pos: tuple[float, float] = (1.6, 1),
117
123
  point_sizes: tuple[int, int] = (15, 150),
118
- save: bool | str | None = None,
119
124
  shapes: list[str] | None = None,
120
125
  shape_order: list[str] | None = None,
121
126
  x_label: str | None = None,
122
127
  y_label: str | None = None,
128
+ show: bool = True,
129
+ return_fig: bool = False,
123
130
  **kwargs: int,
124
- ) -> None:
131
+ ) -> Figure | None:
125
132
  """Creates a volcano plot from a pandas DataFrame or Anndata.
126
133
 
127
134
  Args:
@@ -143,12 +150,40 @@ class MethodBase(ABC):
143
150
  top_right_frame: Whether to show the top and right frame of the plot.
144
151
  figsize: Size of the figure.
145
152
  legend_pos: Position of the legend as determined by matplotlib.
146
- save: Saves the plot if True or to the path provided.
147
153
  shapes: List of matplotlib marker ids.
148
154
  shape_order: Order of categories for shapes.
149
155
  x_label: Label for the x-axis.
150
156
  y_label: Label for the y-axis.
157
+ {common_plot_args}
151
158
  **kwargs: Additional arguments for seaborn.scatterplot.
159
+
160
+ Returns:
161
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
162
+
163
+ Examples:
164
+ >>> # Example with EdgeR
165
+ >>> import pertpy as pt
166
+ >>> adata = pt.dt.zhang_2021()
167
+ >>> adata.layers["counts"] = adata.X.copy()
168
+ >>> ps = pt.tl.PseudobulkSpace()
169
+ >>> pdata = ps.compute(
170
+ ... adata,
171
+ ... target_col="Patient",
172
+ ... groups_col="Cluster",
173
+ ... layer_key="counts",
174
+ ... mode="sum",
175
+ ... min_cells=10,
176
+ ... min_counts=1000,
177
+ ... )
178
+ >>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment")
179
+ >>> edgr.fit()
180
+ >>> res_df = edgr.test_contrasts(
181
+ ... edgr.contrast(column="Treatment", baseline="Chemo", group_to_compare="Anti-PD-L1+Chemo")
182
+ ... )
183
+ >>> edgr.plot_volcano(res_df, log2fc_thresh=0)
184
+
185
+ Preview:
186
+ .. image:: /_static/docstring_previews/de_volcano.png
152
187
  """
153
188
  if colors is None:
154
189
  colors = ["gray", "#D62728", "#1F77B4"]
@@ -243,7 +278,7 @@ class MethodBase(ABC):
243
278
  if varm_key is None:
244
279
  raise ValueError("Please pass a .varm key to use for plotting")
245
280
 
246
- raise NotImplementedError("Anndata not implemented yet")
281
+ raise NotImplementedError("Anndata not implemented yet") # TODO: Implement this
247
282
  df = data.varm[varm_key].copy()
248
283
 
249
284
  df = data.copy(deep=True)
@@ -449,26 +484,412 @@ class MethodBase(ABC):
449
484
 
450
485
  plt.legend(loc=1, bbox_to_anchor=legend_pos, frameon=False)
451
486
 
452
- # TODO replace with scanpy save style
453
- if save:
454
- files = os.listdir()
455
- for x in range(100):
456
- file_pref = "volcano_" + "%02d" % (x,)
457
- if len([x for x in files if x.startswith(file_pref)]) == 0:
458
- plt.savefig(file_pref + ".png", dpi=300, bbox_inches="tight")
459
- plt.savefig(file_pref + ".svg", bbox_inches="tight")
460
- break
461
- elif isinstance(save, str):
462
- plt.savefig(save + ".png", dpi=300, bbox_inches="tight")
463
- plt.savefig(save + ".svg", bbox_inches="tight")
487
+ if show:
488
+ plt.show()
489
+ if return_fig:
490
+ return plt.gcf()
491
+ return None
492
+
493
+ @_doc_params(common_plot_args=doc_common_plot_args)
494
+ def plot_paired(
495
+ self,
496
+ adata: ad.AnnData,
497
+ results_df: pd.DataFrame,
498
+ groupby: str,
499
+ pairedby: str,
500
+ *,
501
+ var_names: Sequence[str] = None,
502
+ n_top_vars: int = 15,
503
+ layer: str = None,
504
+ pvalue_col: str = "adj_p_value",
505
+ symbol_col: str = "variable",
506
+ n_cols: int = 4,
507
+ panel_size: tuple[int, int] = (5, 5),
508
+ show_legend: bool = True,
509
+ size: int = 10,
510
+ y_label: str = "expression",
511
+ pvalue_template=lambda x: f"p={x:.2e}",
512
+ boxplot_properties=None,
513
+ palette=None,
514
+ show: bool = True,
515
+ return_fig: bool = False,
516
+ ) -> Figure | None:
517
+ """Creates a pairwise expression plot from a Pandas DataFrame or Anndata.
518
+
519
+ Visualizes a panel of paired scatterplots per variable.
520
+
521
+ Args:
522
+ adata: AnnData object, can be pseudobulked.
523
+ results_df: DataFrame with results from a differential expression test.
524
+ groupby: .obs column containing the grouping. Must contain exactly two different values.
525
+ pairedby: .obs column containing the pairing (e.g. "patient_id"). If None, an independent t-test is performed.
526
+ var_names: Variables to plot.
527
+ n_top_vars: Number of top variables to plot.
528
+ layer: Layer to use for plotting.
529
+ pvalue_col: Column name of the p values.
530
+ symbol_col: Column name of gene IDs.
531
+ n_cols: Number of columns in the plot.
532
+ panel_size: Size of each panel.
533
+ show_legend: Whether to show the legend.
534
+ size: Size of the points.
535
+ y_label: Label for the y-axis.
536
+ pvalue_template: Template for the p-value string displayed in the title of each panel.
537
+ boxplot_properties: Additional properties for the boxplot, passed to seaborn.boxplot.
538
+ palette: Color palette for the line- and stripplot.
539
+ {common_plot_args}
540
+
541
+ Returns:
542
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
543
+
544
+ Examples:
545
+ >>> # Example with EdgeR
546
+ >>> import pertpy as pt
547
+ >>> adata = pt.dt.zhang_2021()
548
+ >>> adata.layers["counts"] = adata.X.copy()
549
+ >>> ps = pt.tl.PseudobulkSpace()
550
+ >>> pdata = ps.compute(
551
+ ... adata,
552
+ ... target_col="Patient",
553
+ ... groups_col="Cluster",
554
+ ... layer_key="counts",
555
+ ... mode="sum",
556
+ ... min_cells=10,
557
+ ... min_counts=1000,
558
+ ... )
559
+ >>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment")
560
+ >>> edgr.fit()
561
+ >>> res_df = edgr.test_contrasts(
562
+ ... edgr.contrast(column="Treatment", baseline="Chemo", group_to_compare="Anti-PD-L1+Chemo")
563
+ ... )
564
+ >>> edgr.plot_paired(pdata, results_df=res_df, n_top_vars=8, groupby="Treatment", pairedby="Efficacy")
565
+
566
+ Preview:
567
+ .. image:: /_static/docstring_previews/de_paired_expression.png
568
+ """
569
+ if boxplot_properties is None:
570
+ boxplot_properties = {}
571
+ groups = adata.obs[groupby].unique()
572
+ if len(groups) != 2:
573
+ raise ValueError("The number of groups in the group_by column must be exactly 2 to enable paired testing")
574
+
575
+ if var_names is None:
576
+ var_names = results_df.head(n_top_vars)[symbol_col].tolist()
577
+
578
+ adata = adata[:, var_names]
579
+
580
+ if any(adata.obs[[groupby, pairedby]].value_counts() > 1):
581
+ logger.info("Performing pseudobulk for paired samples")
582
+ ps = PseudobulkSpace()
583
+ adata = ps.compute(
584
+ adata, target_col=groupby, groups_col=pairedby, layer_key=layer, mode="sum", min_cells=1, min_counts=1
585
+ )
586
+
587
+ if layer is not None:
588
+ X = adata.layers[layer]
589
+ else:
590
+ X = adata.X
591
+ try:
592
+ X = X.toarray()
593
+ except AttributeError:
594
+ pass
595
+
596
+ groupby_cols = [pairedby, groupby]
597
+ df = adata.obs.loc[:, groupby_cols].join(pd.DataFrame(X, index=adata.obs_names, columns=var_names))
598
+
599
+ # remove unpaired samples
600
+ paired_samples = set(df[df[groupby] == groups[0]][pairedby]) & set(df[df[groupby] == groups[1]][pairedby])
601
+ df = df[df[pairedby].isin(paired_samples)]
602
+ removed_samples = adata.obs[pairedby].nunique() - len(df[pairedby].unique())
603
+ if removed_samples > 0:
604
+ logger.warning(f"{removed_samples} unpaired samples removed")
605
+
606
+ pvalues = results_df.set_index(symbol_col).loc[var_names, pvalue_col].values
607
+ df.reset_index(drop=False, inplace=True)
608
+
609
+ # transform data for seaborn
610
+ df_melt = df.melt(
611
+ id_vars=groupby_cols,
612
+ var_name="var",
613
+ value_name="val",
614
+ )
615
+
616
+ n_panels = len(var_names)
617
+ nrows = math.ceil(n_panels / n_cols)
618
+ ncols = min(n_cols, n_panels)
619
+
620
+ fig, axes = plt.subplots(
621
+ nrows,
622
+ ncols,
623
+ figsize=(ncols * panel_size[0], nrows * panel_size[1]),
624
+ tight_layout=True,
625
+ squeeze=False,
626
+ )
627
+ axes = axes.flatten()
628
+ for i, (var, ax) in enumerate(zip_longest(var_names, axes)):
629
+ if var is not None:
630
+ sns.boxplot(
631
+ x=groupby,
632
+ data=df_melt.loc[df_melt["var"] == var],
633
+ y="val",
634
+ ax=ax,
635
+ color="white",
636
+ fliersize=0,
637
+ **boxplot_properties,
638
+ )
639
+ if pairedby is not None:
640
+ sns.lineplot(
641
+ x=groupby,
642
+ data=df_melt.loc[df_melt["var"] == var],
643
+ y="val",
644
+ ax=ax,
645
+ hue=pairedby,
646
+ legend=False,
647
+ errorbar=None,
648
+ palette=palette,
649
+ )
650
+ jitter = 0 if pairedby else True
651
+ sns.stripplot(
652
+ x=groupby,
653
+ data=df_melt.loc[df_melt["var"] == var],
654
+ y="val",
655
+ ax=ax,
656
+ hue=pairedby,
657
+ jitter=jitter,
658
+ size=size,
659
+ linewidth=1,
660
+ palette=palette,
661
+ )
662
+
663
+ ax.set_xlabel("")
664
+ ax.tick_params(
665
+ axis="x",
666
+ labelsize=15,
667
+ )
668
+ ax.legend().set_visible(False)
669
+ ax.set_ylabel(y_label)
670
+ ax.set_title(f"{var}\n{pvalue_template(pvalues[i])}")
671
+ else:
672
+ ax.set_visible(False)
673
+ fig.tight_layout()
674
+
675
+ if show_legend is True:
676
+ axes[n_panels - 1].legend().set_visible(True)
677
+ axes[n_panels - 1].legend(
678
+ bbox_to_anchor=(0.5, -0.1), loc="upper center", ncol=adata.obs[pairedby].nunique()
679
+ )
680
+
681
+ plt.tight_layout()
682
+ if show:
683
+ plt.show()
684
+ if return_fig:
685
+ return plt.gcf()
686
+ return None
687
+
688
+ @_doc_params(common_plot_args=doc_common_plot_args)
689
+ def plot_fold_change(
690
+ self,
691
+ results_df: pd.DataFrame,
692
+ *,
693
+ var_names: Sequence[str] = None,
694
+ n_top_vars: int = 15,
695
+ log2fc_col: str = "log_fc",
696
+ symbol_col: str = "variable",
697
+ y_label: str = "Log2 fold change",
698
+ figsize: tuple[int, int] = (10, 5),
699
+ show: bool = True,
700
+ return_fig: bool = False,
701
+ **barplot_kwargs,
702
+ ) -> Figure | None:
703
+ """Plot a metric from the results as a bar chart, optionally with additional information about paired samples in a scatter plot.
704
+
705
+ Args:
706
+ results_df: DataFrame with results from DE analysis.
707
+ var_names: Variables to plot. If None, the top n_top_vars variables based on the log2 fold change are plotted.
708
+ n_top_vars: Number of top variables to plot. The top and bottom n_top_vars variables are plotted, respectively.
709
+ log2fc_col: Column name of log2 Fold-Change values.
710
+ symbol_col: Column name of gene IDs.
711
+ y_label: Label for the y-axis.
712
+ figsize: Size of the figure.
713
+ {common_plot_args}
714
+ **barplot_kwargs: Additional arguments for seaborn.barplot.
715
+
716
+ Returns:
717
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
718
+
719
+ Examples:
720
+ >>> # Example with EdgeR
721
+ >>> import pertpy as pt
722
+ >>> adata = pt.dt.zhang_2021()
723
+ >>> adata.layers["counts"] = adata.X.copy()
724
+ >>> ps = pt.tl.PseudobulkSpace()
725
+ >>> pdata = ps.compute(
726
+ ... adata,
727
+ ... target_col="Patient",
728
+ ... groups_col="Cluster",
729
+ ... layer_key="counts",
730
+ ... mode="sum",
731
+ ... min_cells=10,
732
+ ... min_counts=1000,
733
+ ... )
734
+ >>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment")
735
+ >>> edgr.fit()
736
+ >>> res_df = edgr.test_contrasts(
737
+ ... edgr.contrast(column="Treatment", baseline="Chemo", group_to_compare="Anti-PD-L1+Chemo")
738
+ ... )
739
+ >>> edgr.plot_fold_change(res_df)
740
+
741
+ Preview:
742
+ .. image:: /_static/docstring_previews/de_fold_change.png
743
+ """
744
+ if var_names is None:
745
+ var_names = results_df.sort_values(log2fc_col, ascending=False).head(n_top_vars)[symbol_col].tolist()
746
+ var_names += results_df.sort_values(log2fc_col, ascending=True).head(n_top_vars)[symbol_col].tolist()
747
+ assert len(var_names) == 2 * n_top_vars
748
+
749
+ df = results_df[results_df[symbol_col].isin(var_names)]
750
+ df.sort_values(log2fc_col, ascending=False, inplace=True)
751
+
752
+ plt.figure(figsize=figsize)
753
+ sns.barplot(
754
+ x=symbol_col,
755
+ y=log2fc_col,
756
+ data=df,
757
+ palette="RdBu",
758
+ legend=False,
759
+ **barplot_kwargs,
760
+ )
761
+ plt.xticks(rotation=90)
762
+ plt.xlabel("")
763
+ plt.ylabel(y_label)
764
+
765
+ if show:
766
+ plt.show()
767
+ if return_fig:
768
+ return plt.gcf()
769
+ return None
770
+
771
+ @_doc_params(common_plot_args=doc_common_plot_args)
772
+ def plot_multicomparison_fc(
773
+ self,
774
+ results_df: pd.DataFrame,
775
+ *,
776
+ n_top_vars=15,
777
+ contrast_col: str = "contrast",
778
+ log2fc_col: str = "log_fc",
779
+ pvalue_col: str = "adj_p_value",
780
+ symbol_col: str = "variable",
781
+ marker_size: int = 100,
782
+ figsize: tuple[int, int] = (10, 2),
783
+ x_label: str = "Contrast",
784
+ y_label: str = "Gene",
785
+ show: bool = True,
786
+ return_fig: bool = False,
787
+ **heatmap_kwargs,
788
+ ) -> Figure | None:
789
+ """Plot a matrix of log2 fold changes from the results.
790
+
791
+ Args:
792
+ results_df: DataFrame with results from DE analysis.
793
+ n_top_vars: Number of top variables to plot per group.
794
+ contrast_col: Column in results_df containing information about the contrast.
795
+ log2fc_col: Column in results_df containing the log2 fold change.
796
+ pvalue_col: Column in results_df containing the p-value. Can be used to switch between adjusted and unadjusted p-values.
797
+ symbol_col: Column in results_df containing the gene symbol.
798
+ marker_size: Size of the biggest marker for significant variables.
799
+ figsize: Size of the figure.
800
+ x_label: Label for the x-axis.
801
+ y_label: Label for the y-axis.
802
+ {common_plot_args}
803
+ **heatmap_kwargs: Additional arguments for seaborn.heatmap.
804
+
805
+ Returns:
806
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
807
+
808
+ Examples:
809
+ >>> # Example with EdgeR
810
+ >>> import pertpy as pt
811
+ >>> adata = pt.dt.zhang_2021()
812
+ >>> adata.layers["counts"] = adata.X.copy()
813
+ >>> ps = pt.tl.PseudobulkSpace()
814
+ >>> pdata = ps.compute(
815
+ ... adata,
816
+ ... target_col="Patient",
817
+ ... groups_col="Cluster",
818
+ ... layer_key="counts",
819
+ ... mode="sum",
820
+ ... min_cells=10,
821
+ ... min_counts=1000,
822
+ ... )
823
+ >>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment")
824
+ >>> res_df = edgr.compare_groups(pdata, column="Efficacy", baseline="SD", groups_to_compare=["PR", "PD"])
825
+ >>> edgr.plot_multicomparison_fc(res_df)
826
+
827
+ Preview:
828
+ .. image:: /_static/docstring_previews/de_multicomparison_fc.png
829
+ """
830
+ groups = results_df[contrast_col].unique().tolist()
831
+
832
+ results_df["abs_log_fc"] = results_df[log2fc_col].abs()
833
+
834
+ def _get_significance(p_val):
835
+ if p_val < 0.001:
836
+ return "< 0.001"
837
+ elif p_val < 0.01:
838
+ return "< 0.01"
839
+ elif p_val < 0.1:
840
+ return "< 0.1"
841
+ else:
842
+ return "n.s."
843
+
844
+ results_df["significance"] = results_df[pvalue_col].apply(_get_significance)
845
+
846
+ var_names = []
847
+ for group in groups:
848
+ var_names += (
849
+ results_df[results_df[contrast_col] == group]
850
+ .sort_values("abs_log_fc", ascending=False)
851
+ .head(n_top_vars)[symbol_col]
852
+ .tolist()
853
+ )
854
+
855
+ results_df = results_df[results_df[symbol_col].isin(var_names)]
856
+ df = results_df.pivot(index=contrast_col, columns=symbol_col, values=log2fc_col)[var_names]
464
857
 
465
- plt.show()
858
+ plt.figure(figsize=figsize)
859
+ sns.heatmap(df, **heatmap_kwargs, cmap="coolwarm", center=0, cbar_kws={"label": "Log2 fold change"})
860
+
861
+ _size = {"< 0.001": marker_size, "< 0.01": math.floor(marker_size / 2), "< 0.1": math.floor(marker_size / 4)}
862
+ x_locs, x_labels = plt.xticks()[0], [label.get_text() for label in plt.xticks()[1]]
863
+ y_locs, y_labels = plt.yticks()[0], [label.get_text() for label in plt.yticks()[1]]
864
+
865
+ for _i, row in results_df.iterrows():
866
+ if row["significance"] != "n.s.":
867
+ plt.scatter(
868
+ x=x_locs[x_labels.index(row[symbol_col])],
869
+ y=y_locs[y_labels.index(row[contrast_col])],
870
+ s=_size[row["significance"]],
871
+ marker="*",
872
+ c="white",
873
+ )
874
+
875
+ plt.scatter([], [], s=marker_size, marker="*", c="black", label="< 0.001")
876
+ plt.scatter([], [], s=math.floor(marker_size / 2), marker="*", c="black", label="< 0.01")
877
+ plt.scatter([], [], s=math.floor(marker_size / 4), marker="*", c="black", label="< 0.1")
878
+ plt.legend(title="Significance", bbox_to_anchor=(1.2, -0.05))
879
+
880
+ plt.xlabel(x_label)
881
+ plt.ylabel(y_label)
882
+
883
+ if show:
884
+ plt.show()
885
+ if return_fig:
886
+ return plt.gcf()
887
+ return None
466
888
 
467
889
 
468
890
  class LinearModelBase(MethodBase):
469
891
  def __init__(self, adata, design, *, mask=None, layer=None, **kwargs):
470
- """
471
- Initialize the method.
892
+ """Initialize the method.
472
893
 
473
894
  Args:
474
895
  adata: AnnData object, usually pseudobulked.
@@ -480,26 +901,24 @@ class LinearModelBase(MethodBase):
480
901
  super().__init__(adata, mask=mask, layer=layer)
481
902
  self._check_counts()
482
903
 
483
- self.factor_storage = None
484
- self.variable_to_factors = None
485
-
904
+ self.formulaic_contrasts = None
486
905
  if isinstance(design, str):
487
- self.factor_storage, self.variable_to_factors, materializer_class = get_factor_storage_and_materializer()
488
- self.design = materializer_class(adata.obs, record_factor_metadata=True).get_model_matrix(design)
906
+ self.formulaic_contrasts = FormulaicContrasts(adata.obs, design)
907
+ self.design = self.formulaic_contrasts.design_matrix
489
908
  else:
490
909
  self.design = design
491
910
 
492
911
  @classmethod
493
912
  def compare_groups(
494
913
  cls,
495
- adata,
496
- column,
497
- baseline,
498
- groups_to_compare,
914
+ adata: ad.AnnData,
915
+ column: str,
916
+ baseline: str,
917
+ groups_to_compare: str | Iterable[str],
499
918
  *,
500
- paired_by=None,
501
- mask=None,
502
- layer=None,
919
+ paired_by: str | None = None,
920
+ mask: pd.Series | None = None,
921
+ layer: str | None = None,
503
922
  fit_kwargs=MappingProxyType({}),
504
923
  test_kwargs=MappingProxyType({}),
505
924
  ):
@@ -525,17 +944,16 @@ class LinearModelBase(MethodBase):
525
944
  @property
526
945
  def variables(self):
527
946
  """Get the names of the variables used in the model definition."""
528
- try:
529
- return self.design.model_spec.variables_by_source["data"]
530
- except AttributeError:
947
+ if self.formulaic_contrasts is None:
531
948
  raise ValueError(
532
949
  "Retrieving variables is only possible if the model was initialized using a formula."
533
950
  ) from None
951
+ else:
952
+ return self.formulaic_contrasts.variables
534
953
 
535
954
  @abstractmethod
536
955
  def _check_counts(self):
537
- """
538
- Check that counts are valid for the specific method.
956
+ """Check that counts are valid for the specific method.
539
957
 
540
958
  Raises:
541
959
  ValueError: if the data matrix does not comply with the expectations.
@@ -544,8 +962,7 @@ class LinearModelBase(MethodBase):
544
962
 
545
963
  @abstractmethod
546
964
  def fit(self, **kwargs):
547
- """
548
- Fit the model.
965
+ """Fit the model.
549
966
 
550
967
  Args:
551
968
  **kwargs: Additional arguments for fitting the specific method.
@@ -555,9 +972,8 @@ class LinearModelBase(MethodBase):
555
972
  @abstractmethod
556
973
  def _test_single_contrast(self, contrast, **kwargs): ...
557
974
 
558
- def test_contrasts(self, contrasts, **kwargs):
559
- """
560
- Perform a comparison as specified in a contrast vector.
975
+ def test_contrasts(self, contrasts: np.ndarray | Mapping[str | None, np.ndarray], **kwargs):
976
+ """Perform a comparison as specified in a contrast vector.
561
977
 
562
978
  Args:
563
979
  contrasts: Either a numeric contrast vector, or a dictionary of numeric contrast vectors.
@@ -573,25 +989,25 @@ class LinearModelBase(MethodBase):
573
989
  results.append(self._test_single_contrast(contrast, **kwargs).assign(contrast=name))
574
990
 
575
991
  results_df = pd.concat(results)
992
+
576
993
  return results_df
577
994
 
578
995
  def test_reduced(self, modelB):
579
- """
580
- Test against a reduced model.
996
+ """Test against a reduced model.
581
997
 
582
998
  Args:
583
999
  modelB: the reduced model against which to test.
584
1000
 
585
1001
  Example:
586
- modelA = Model().fit()
587
- modelB = Model().fit()
588
- modelA.test_reduced(modelB)
1002
+ >>> import pertpy as pt
1003
+ >>> modelA = Model().fit()
1004
+ >>> modelB = Model().fit()
1005
+ >>> modelA.test_reduced(modelB)
589
1006
  """
590
1007
  raise NotImplementedError
591
1008
 
592
1009
  def cond(self, **kwargs):
593
- """
594
- Get a contrast vector representing a specific condition.
1010
+ """Get a contrast vector representing a specific condition.
595
1011
 
596
1012
  Args:
597
1013
  **kwargs: column/value pairs.
@@ -599,52 +1015,14 @@ class LinearModelBase(MethodBase):
599
1015
  Returns:
600
1016
  A contrast vector that aligns to the columns of the design matrix.
601
1017
  """
602
- if self.factor_storage is None:
1018
+ if self.formulaic_contrasts is None:
603
1019
  raise RuntimeError(
604
1020
  "Building contrasts with `cond` only works if you specified the model using a formulaic formula. Please manually provide a contrast vector."
605
1021
  )
606
- cond_dict = kwargs
607
- if not set(cond_dict.keys()).issubset(self.variables):
608
- raise ValueError(
609
- "You specified a variable that is not part of the model. Available variables: "
610
- + ",".join(self.variables)
611
- )
612
- for var in self.variables:
613
- if var in cond_dict:
614
- self._check_category(var, cond_dict[var])
615
- else:
616
- cond_dict[var] = self._get_default_value(var)
617
- df = pd.DataFrame([kwargs])
618
- return self.design.model_spec.get_model_matrix(df).iloc[0]
619
-
620
- def _get_factor_metadata_for_variable(self, var):
621
- factors = self.variable_to_factors[var]
622
- return list(chain.from_iterable(self.factor_storage[f] for f in factors))
623
-
624
- def _get_default_value(self, var):
625
- factor_metadata = self._get_factor_metadata_for_variable(var)
626
- if resolve_ambiguous(factor_metadata, "kind") == Factor.Kind.CATEGORICAL:
627
- try:
628
- tmp_base = resolve_ambiguous(factor_metadata, "base")
629
- except AmbiguousAttributeError as e:
630
- raise ValueError(
631
- f"Could not automatically resolve base category for variable {var}. Please specify it explicity in `model.cond`."
632
- ) from e
633
- return tmp_base if tmp_base is not None else "\0"
634
- else:
635
- return 0
636
-
637
- def _check_category(self, var, value):
638
- factor_metadata = self._get_factor_metadata_for_variable(var)
639
- tmp_categories = resolve_ambiguous(factor_metadata, "categories")
640
- if resolve_ambiguous(factor_metadata, "kind") == Factor.Kind.CATEGORICAL and value not in tmp_categories:
641
- raise ValueError(
642
- f"You specified a non-existant category for {var}. Possible categories: {', '.join(tmp_categories)}"
643
- )
1022
+ return self.formulaic_contrasts.cond(**kwargs)
644
1023
 
645
- def contrast(self, column, baseline, group_to_compare):
646
- """
647
- Build a simple contrast for pairwise comparisons.
1024
+ def contrast(self, *args, **kwargs):
1025
+ """Build a simple contrast for pairwise comparisons.
648
1026
 
649
1027
  Args:
650
1028
  column: column in adata.obs to test on.
@@ -654,4 +1032,8 @@ class LinearModelBase(MethodBase):
654
1032
  Returns:
655
1033
  Numeric contrast vector.
656
1034
  """
657
- return self.cond(**{column: group_to_compare}) - self.cond(**{column: baseline})
1035
+ if self.formulaic_contrasts is None:
1036
+ raise RuntimeError(
1037
+ "Building contrasts with `cond` only works if you specified the model using a formulaic formula. Please manually provide a contrast vector."
1038
+ )
1039
+ return self.formulaic_contrasts.contrast(*args, **kwargs)