pertpy 0.9.3__py3-none-any.whl → 0.9.5__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,7 +1,7 @@
1
- import os
1
+ import math
2
2
  from abc import ABC, abstractmethod
3
- from dataclasses import dataclass
4
- from itertools import chain
3
+ from collections.abc import Iterable, Mapping, Sequence
4
+ from itertools import zip_longest
5
5
  from types import MappingProxyType
6
6
 
7
7
  import adjustText
@@ -11,27 +11,14 @@ import matplotlib.pyplot as plt
11
11
  import numpy as np
12
12
  import pandas as pd
13
13
  import seaborn as sns
14
+ from formulaic_contrasts import FormulaicContrasts
15
+ from lamin_utils import logger
16
+ from matplotlib.pyplot import Figure
14
17
  from matplotlib.ticker import MaxNLocator
15
18
 
19
+ from pertpy._doc import _doc_params, doc_common_plot_args
20
+ from pertpy.tools import PseudobulkSpace
16
21
  from pertpy.tools._differential_gene_expression._checks import check_is_numeric_matrix
17
- from pertpy.tools._differential_gene_expression._formulaic import (
18
- AmbiguousAttributeError,
19
- Factor,
20
- get_factor_storage_and_materializer,
21
- resolve_ambiguous,
22
- )
23
-
24
-
25
- @dataclass
26
- class Contrast:
27
- """Simple contrast for comparison between groups"""
28
-
29
- column: str
30
- baseline: str
31
- group_to_compare: str
32
-
33
-
34
- ContrastType = Contrast | tuple[str, str, str]
35
22
 
36
23
 
37
24
  class MethodBase(ABC):
@@ -58,7 +45,7 @@ class MethodBase(ABC):
58
45
  if self.layer is None:
59
46
  return self.adata.X
60
47
  else:
61
- return self.adata.layer[self.layer]
48
+ return self.adata.layers[self.layer]
62
49
 
63
50
  @classmethod
64
51
  @abstractmethod
@@ -91,9 +78,28 @@ class MethodBase(ABC):
91
78
 
92
79
  Returns:
93
80
  Pandas dataframe with results ordered by significance. If multiple comparisons were performed this is indicated in an additional column.
81
+
82
+ Examples:
83
+ >>> # Example with EdgeR
84
+ >>> import pertpy as pt
85
+ >>> adata = pt.dt.zhang_2021()
86
+ >>> adata.layers["counts"] = adata.X.copy()
87
+ >>> ps = pt.tl.PseudobulkSpace()
88
+ >>> pdata = ps.compute(
89
+ ... adata,
90
+ ... target_col="Patient",
91
+ ... groups_col="Cluster",
92
+ ... layer_key="counts",
93
+ ... mode="sum",
94
+ ... min_cells=10,
95
+ ... min_counts=1000,
96
+ ... )
97
+ >>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment")
98
+ >>> res_df = edgr.compare_groups(pdata, column="Efficacy", baseline="SD", groups_to_compare=["PR", "PD"])
94
99
  """
95
100
  ...
96
101
 
102
+ @_doc_params(common_plot_args=doc_common_plot_args)
97
103
  def plot_volcano(
98
104
  self,
99
105
  data: pd.DataFrame | ad.AnnData,
@@ -115,13 +121,14 @@ class MethodBase(ABC):
115
121
  figsize: tuple[int, int] = (5, 5),
116
122
  legend_pos: tuple[float, float] = (1.6, 1),
117
123
  point_sizes: tuple[int, int] = (15, 150),
118
- save: bool | str | None = None,
119
124
  shapes: list[str] | None = None,
120
125
  shape_order: list[str] | None = None,
121
126
  x_label: str | None = None,
122
127
  y_label: str | None = None,
128
+ show: bool = True,
129
+ return_fig: bool = False,
123
130
  **kwargs: int,
124
- ) -> None:
131
+ ) -> Figure | None:
125
132
  """Creates a volcano plot from a pandas DataFrame or Anndata.
126
133
 
127
134
  Args:
@@ -143,12 +150,40 @@ class MethodBase(ABC):
143
150
  top_right_frame: Whether to show the top and right frame of the plot.
144
151
  figsize: Size of the figure.
145
152
  legend_pos: Position of the legend as determined by matplotlib.
146
- save: Saves the plot if True or to the path provided.
147
153
  shapes: List of matplotlib marker ids.
148
154
  shape_order: Order of categories for shapes.
149
155
  x_label: Label for the x-axis.
150
156
  y_label: Label for the y-axis.
157
+ {common_plot_args}
151
158
  **kwargs: Additional arguments for seaborn.scatterplot.
159
+
160
+ Returns:
161
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
162
+
163
+ Examples:
164
+ >>> # Example with EdgeR
165
+ >>> import pertpy as pt
166
+ >>> adata = pt.dt.zhang_2021()
167
+ >>> adata.layers["counts"] = adata.X.copy()
168
+ >>> ps = pt.tl.PseudobulkSpace()
169
+ >>> pdata = ps.compute(
170
+ ... adata,
171
+ ... target_col="Patient",
172
+ ... groups_col="Cluster",
173
+ ... layer_key="counts",
174
+ ... mode="sum",
175
+ ... min_cells=10,
176
+ ... min_counts=1000,
177
+ ... )
178
+ >>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment")
179
+ >>> edgr.fit()
180
+ >>> res_df = edgr.test_contrasts(
181
+ ... edgr.contrast(column="Treatment", baseline="Chemo", group_to_compare="Anti-PD-L1+Chemo")
182
+ ... )
183
+ >>> edgr.plot_volcano(res_df, log2fc_thresh=0)
184
+
185
+ Preview:
186
+ .. image:: /_static/docstring_previews/de_volcano.png
152
187
  """
153
188
  if colors is None:
154
189
  colors = ["gray", "#D62728", "#1F77B4"]
@@ -243,7 +278,7 @@ class MethodBase(ABC):
243
278
  if varm_key is None:
244
279
  raise ValueError("Please pass a .varm key to use for plotting")
245
280
 
246
- raise NotImplementedError("Anndata not implemented yet")
281
+ raise NotImplementedError("Anndata not implemented yet") # TODO: Implement this
247
282
  df = data.varm[varm_key].copy()
248
283
 
249
284
  df = data.copy(deep=True)
@@ -449,26 +484,412 @@ class MethodBase(ABC):
449
484
 
450
485
  plt.legend(loc=1, bbox_to_anchor=legend_pos, frameon=False)
451
486
 
452
- # TODO replace with scanpy save style
453
- if save:
454
- files = os.listdir()
455
- for x in range(100):
456
- file_pref = "volcano_" + "%02d" % (x,)
457
- if len([x for x in files if x.startswith(file_pref)]) == 0:
458
- plt.savefig(file_pref + ".png", dpi=300, bbox_inches="tight")
459
- plt.savefig(file_pref + ".svg", bbox_inches="tight")
460
- break
461
- elif isinstance(save, str):
462
- plt.savefig(save + ".png", dpi=300, bbox_inches="tight")
463
- plt.savefig(save + ".svg", bbox_inches="tight")
487
+ if show:
488
+ plt.show()
489
+ if return_fig:
490
+ return plt.gcf()
491
+ return None
492
+
493
@_doc_params(common_plot_args=doc_common_plot_args)
def plot_paired(
    self,
    adata: ad.AnnData,
    results_df: pd.DataFrame,
    groupby: str,
    pairedby: str,
    *,
    var_names: Sequence[str] | None = None,
    n_top_vars: int = 15,
    layer: str | None = None,
    pvalue_col: str = "adj_p_value",
    symbol_col: str = "variable",
    n_cols: int = 4,
    panel_size: tuple[int, int] = (5, 5),
    show_legend: bool = True,
    size: int = 10,
    y_label: str = "expression",
    pvalue_template=lambda x: f"p={x:.2e}",
    boxplot_properties=None,
    palette=None,
    show: bool = True,
    return_fig: bool = False,
) -> Figure | None:
    """Creates a panel of paired expression plots, one panel per variable.

    Each panel shows a boxplot of the two groups, overlaid with one point per
    sample and a line connecting the two samples of every pair.

    Args:
        adata: AnnData object, can be pseudobulked.
        results_df: DataFrame with results from a differential expression test.
        groupby: .obs column containing the grouping. Must contain exactly two different values.
        pairedby: .obs column containing the pairing (e.g. "patient_id"). Pairs that are not
            present in both groups are dropped from the plot.
        var_names: Variables to plot. If None, the first `n_top_vars` rows of `results_df` are used.
        n_top_vars: Number of top variables to plot.
        layer: Layer to use for plotting.
        pvalue_col: Column name of the p values.
        symbol_col: Column name of gene IDs.
        n_cols: Number of columns in the plot.
        panel_size: Size of each panel.
        show_legend: Whether to show the legend.
        size: Size of the points.
        y_label: Label for the y-axis.
        pvalue_template: Callable turning a p-value into the string displayed in the title of each panel.
        boxplot_properties: Additional properties for the boxplot, passed to seaborn.boxplot.
        palette: Color palette for the line- and stripplot.
        {common_plot_args}

    Returns:
        If `return_fig` is `True`, returns the figure, otherwise `None`.

    Raises:
        ValueError: If `groupby` does not contain exactly two groups or no variables are selected.

    Examples:
        >>> # Example with EdgeR
        >>> import pertpy as pt
        >>> adata = pt.dt.zhang_2021()
        >>> adata.layers["counts"] = adata.X.copy()
        >>> ps = pt.tl.PseudobulkSpace()
        >>> pdata = ps.compute(
        ...     adata,
        ...     target_col="Patient",
        ...     groups_col="Cluster",
        ...     layer_key="counts",
        ...     mode="sum",
        ...     min_cells=10,
        ...     min_counts=1000,
        ... )
        >>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment")
        >>> edgr.fit()
        >>> res_df = edgr.test_contrasts(
        ...     edgr.contrast(column="Treatment", baseline="Chemo", group_to_compare="Anti-PD-L1+Chemo")
        ... )
        >>> edgr.plot_paired(pdata, results_df=res_df, n_top_vars=8, groupby="Treatment", pairedby="Efficacy")

    Preview:
        .. image:: /_static/docstring_previews/de_paired_expression.png
    """
    if boxplot_properties is None:
        boxplot_properties = {}

    groups = adata.obs[groupby].unique()
    if len(groups) != 2:
        raise ValueError("The number of groups in the group_by column must be exactly 2 to enable paired testing")

    if var_names is None:
        var_names = results_df.head(n_top_vars)[symbol_col].tolist()
    var_names = list(var_names)
    if not var_names:
        # Fail early with a clear message instead of a matplotlib error about a 0x0 grid.
        raise ValueError("No variables to plot: `var_names` is empty")

    adata = adata[:, var_names]

    # More than one observation per (group, pair) combination: aggregate to one sample each.
    if (adata.obs[[groupby, pairedby]].value_counts() > 1).any():
        logger.info("Performing pseudobulk for paired samples")
        ps = PseudobulkSpace()
        adata = ps.compute(
            adata, target_col=groupby, groups_col=pairedby, layer_key=layer, mode="sum", min_cells=1, min_counts=1
        )

    X = adata.layers[layer] if layer is not None else adata.X
    try:
        X = X.toarray()  # densify sparse matrices
    except AttributeError:
        pass  # already a dense array

    groupby_cols = [pairedby, groupby]
    df = adata.obs.loc[:, groupby_cols].join(pd.DataFrame(X, index=adata.obs_names, columns=var_names))

    # remove unpaired samples: keep only pairs observed in both groups
    in_first = set(df.loc[df[groupby] == groups[0], pairedby])
    in_second = set(df.loc[df[groupby] == groups[1], pairedby])
    df = df[df[pairedby].isin(in_first & in_second)]
    n_pairs_total = adata.obs[pairedby].nunique()
    removed_samples = n_pairs_total - df[pairedby].nunique()
    if removed_samples > 0:
        logger.warning(f"{removed_samples} unpaired samples removed")

    # NOTE(review): assumes `symbol_col` is unique in `results_df` (a single contrast);
    # otherwise the p-values no longer align one-to-one with `var_names` - confirm with callers.
    pvalues = results_df.set_index(symbol_col).loc[var_names, pvalue_col].values

    # transform data for seaborn; `value_vars` keeps the long table restricted to the
    # expression columns so that "val" stays numeric
    df_melt = df.melt(
        id_vars=groupby_cols,
        value_vars=var_names,
        var_name="var",
        value_name="val",
    )

    n_panels = len(var_names)
    nrows = math.ceil(n_panels / n_cols)
    ncols = min(n_cols, n_panels)

    fig, axes = plt.subplots(
        nrows,
        ncols,
        figsize=(ncols * panel_size[0], nrows * panel_size[1]),
        tight_layout=True,
        squeeze=False,
    )
    axes = axes.flatten()

    for i, ax in enumerate(axes):
        if i >= n_panels:
            # surplus cell of the grid
            ax.set_visible(False)
            continue

        var = var_names[i]
        panel_data = df_melt.loc[df_melt["var"] == var]

        sns.boxplot(
            x=groupby,
            data=panel_data,
            y="val",
            ax=ax,
            color="white",
            fliersize=0,
            **boxplot_properties,
        )
        # one line per pair, connecting its two samples
        sns.lineplot(
            x=groupby,
            data=panel_data,
            y="val",
            ax=ax,
            hue=pairedby,
            legend=False,
            errorbar=None,
            palette=palette,
        )
        # no jitter, so that the points sit exactly on the line ends
        sns.stripplot(
            x=groupby,
            data=panel_data,
            y="val",
            ax=ax,
            hue=pairedby,
            jitter=0,
            size=size,
            linewidth=1,
            palette=palette,
        )

        ax.set_xlabel("")
        ax.tick_params(
            axis="x",
            labelsize=15,
        )
        ax.legend().set_visible(False)
        ax.set_ylabel(y_label)
        ax.set_title(f"{var}\n{pvalue_template(pvalues[i])}")

    if show_legend:
        # a single shared legend, attached below the last populated panel
        axes[n_panels - 1].legend(bbox_to_anchor=(0.5, -0.1), loc="upper center", ncol=n_pairs_total)

    fig.tight_layout()

    if show:
        plt.show()
    if return_fig:
        # Return the handle we created: after `plt.show()` the current figure may
        # already be closed, in which case `plt.gcf()` would create an empty one.
        return fig
    return None
687
+
688
@_doc_params(common_plot_args=doc_common_plot_args)
def plot_fold_change(
    self,
    results_df: pd.DataFrame,
    *,
    var_names: Sequence[str] | None = None,
    n_top_vars: int = 15,
    log2fc_col: str = "log_fc",
    symbol_col: str = "variable",
    y_label: str = "Log2 fold change",
    figsize: tuple[int, int] = (10, 5),
    show: bool = True,
    return_fig: bool = False,
    **barplot_kwargs,
) -> Figure | None:
    """Plot the log2 fold changes of selected variables as a bar chart.

    Bars are ordered by decreasing log2 fold change.

    Args:
        results_df: DataFrame with results from DE analysis.
        var_names: Variables to plot. If None, the top n_top_vars variables based on the log2 fold change are plotted.
        n_top_vars: Number of top variables to plot. The top and bottom n_top_vars variables are plotted, respectively.
            If `results_df` has fewer rows than that, all available variables are plotted.
        log2fc_col: Column name of log2 Fold-Change values.
        symbol_col: Column name of gene IDs.
        y_label: Label for the y-axis.
        figsize: Size of the figure.
        {common_plot_args}
        **barplot_kwargs: Additional arguments for seaborn.barplot. They take precedence over the defaults set here.

    Returns:
        If `return_fig` is `True`, returns the figure, otherwise `None`.

    Examples:
        >>> # Example with EdgeR
        >>> import pertpy as pt
        >>> adata = pt.dt.zhang_2021()
        >>> adata.layers["counts"] = adata.X.copy()
        >>> ps = pt.tl.PseudobulkSpace()
        >>> pdata = ps.compute(
        ...     adata,
        ...     target_col="Patient",
        ...     groups_col="Cluster",
        ...     layer_key="counts",
        ...     mode="sum",
        ...     min_cells=10,
        ...     min_counts=1000,
        ... )
        >>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment")
        >>> edgr.fit()
        >>> res_df = edgr.test_contrasts(
        ...     edgr.contrast(column="Treatment", baseline="Chemo", group_to_compare="Anti-PD-L1+Chemo")
        ... )
        >>> edgr.plot_fold_change(res_df)

    Preview:
        .. image:: /_static/docstring_previews/de_fold_change.png
    """
    if var_names is None:
        # Strongest up- and down-regulated variables. With fewer than 2 * n_top_vars rows the
        # two selections overlap, which is harmless because they are only used as a filter.
        top_up = results_df.sort_values(log2fc_col, ascending=False).head(n_top_vars)[symbol_col].tolist()
        top_down = results_df.sort_values(log2fc_col, ascending=True).head(n_top_vars)[symbol_col].tolist()
        var_names = top_up + top_down

    # Sort into a new frame; an in-place sort would operate on a slice of the caller's data.
    df = results_df[results_df[symbol_col].isin(var_names)].sort_values(log2fc_col, ascending=False)

    # Colouring each bar via `hue` is the supported way to apply a palette along x;
    # passing `palette` without `hue` is deprecated in seaborn.
    plot_kwargs = {"hue": symbol_col, "palette": "RdBu", "legend": False, **barplot_kwargs}

    fig = plt.figure(figsize=figsize)
    sns.barplot(
        x=symbol_col,
        y=log2fc_col,
        data=df,
        **plot_kwargs,
    )
    plt.xticks(rotation=90)
    plt.xlabel("")
    plt.ylabel(y_label)

    if show:
        plt.show()
    if return_fig:
        # Return the handle we created: after `plt.show()` the current figure may
        # already be closed, in which case `plt.gcf()` would create an empty one.
        return fig
    return None
770
+
771
@_doc_params(common_plot_args=doc_common_plot_args)
def plot_multicomparison_fc(
    self,
    results_df: pd.DataFrame,
    *,
    n_top_vars: int = 15,
    contrast_col: str = "contrast",
    log2fc_col: str = "log_fc",
    pvalue_col: str = "adj_p_value",
    symbol_col: str = "variable",
    marker_size: int = 100,
    figsize: tuple[int, int] = (10, 2),
    x_label: str = "Contrast",
    y_label: str = "Gene",
    show: bool = True,
    return_fig: bool = False,
    **heatmap_kwargs,
) -> Figure | None:
    """Plot a matrix of log2 fold changes from the results.

    Rows are contrasts, columns are variables. Cells whose p-value is below 0.1 are
    marked with a star whose size reflects the significance level.

    Args:
        results_df: DataFrame with results from DE analysis. It is not modified.
        n_top_vars: Number of top variables to plot per group.
        contrast_col: Column in results_df containing information about the contrast.
        log2fc_col: Column in results_df containing the log2 fold change.
        pvalue_col: Column in results_df containing the p-value. Can be used to switch between adjusted and unadjusted p-values.
        symbol_col: Column in results_df containing the gene symbol.
        marker_size: Size of the biggest marker for significant variables.
        figsize: Size of the figure.
        x_label: Label for the x-axis.
        y_label: Label for the y-axis.
        {common_plot_args}
        **heatmap_kwargs: Additional arguments for seaborn.heatmap. They take precedence over the defaults set here.

    Returns:
        If `return_fig` is `True`, returns the figure, otherwise `None`.

    Examples:
        >>> # Example with EdgeR
        >>> import pertpy as pt
        >>> adata = pt.dt.zhang_2021()
        >>> adata.layers["counts"] = adata.X.copy()
        >>> ps = pt.tl.PseudobulkSpace()
        >>> pdata = ps.compute(
        ...     adata,
        ...     target_col="Patient",
        ...     groups_col="Cluster",
        ...     layer_key="counts",
        ...     mode="sum",
        ...     min_cells=10,
        ...     min_counts=1000,
        ... )
        >>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment")
        >>> res_df = edgr.compare_groups(pdata, column="Efficacy", baseline="SD", groups_to_compare=["PR", "PD"])
        >>> edgr.plot_multicomparison_fc(res_df)

    Preview:
        .. image:: /_static/docstring_previews/de_multicomparison_fc.png
    """

    def _get_significance(p_val):
        # Map a p-value to its significance bin; NaN fails every comparison and ends up as "n.s.".
        if p_val < 0.001:
            return "< 0.001"
        elif p_val < 0.01:
            return "< 0.01"
        elif p_val < 0.1:
            return "< 0.1"
        else:
            return "n.s."

    # Work on a copy so that the helper columns do not leak into the caller's DataFrame.
    results = results_df.copy()
    results["abs_log_fc"] = results[log2fc_col].abs()
    results["significance"] = results[pvalue_col].apply(_get_significance)

    # Top variables per contrast, ranked by absolute fold change.
    selected = []
    for group in results[contrast_col].unique().tolist():
        selected += (
            results[results[contrast_col] == group]
            .sort_values("abs_log_fc", ascending=False)
            .head(n_top_vars)[symbol_col]
            .tolist()
        )
    # A variable can rank top in several contrasts; keep its first occurrence only,
    # otherwise it would show up as duplicated heatmap columns.
    var_names = list(dict.fromkeys(selected))

    results = results[results[symbol_col].isin(var_names)]
    df = results.pivot(index=contrast_col, columns=symbol_col, values=log2fc_col)[var_names]

    fig = plt.figure(figsize=figsize)
    heatmap_args = {"cmap": "coolwarm", "center": 0, "cbar_kws": {"label": "Log2 fold change"}, **heatmap_kwargs}
    sns.heatmap(df, **heatmap_args)

    # Heatmap cell (row i, column j) is centred at (j + 0.5, i + 0.5). Deriving the positions
    # from the pivot table rather than from tick labels also works when tick labels are
    # thinned out, disabled, or not strings.
    x_pos = {name: idx + 0.5 for idx, name in enumerate(df.columns)}
    y_pos = {name: idx + 0.5 for idx, name in enumerate(df.index)}

    _size = {"< 0.001": marker_size, "< 0.01": math.floor(marker_size / 2), "< 0.1": math.floor(marker_size / 4)}

    # One scatter call per significance level instead of one per row.
    for level, level_size in _size.items():
        hits = results[results["significance"] == level]
        if hits.empty:
            continue
        plt.scatter(
            x=hits[symbol_col].map(x_pos).to_numpy(),
            y=hits[contrast_col].map(y_pos).to_numpy(),
            s=level_size,
            marker="*",
            c="white",
        )

    # Empty scatters act as legend proxies (black, so they stay visible on the white legend).
    for level, level_size in _size.items():
        plt.scatter([], [], s=level_size, marker="*", c="black", label=level)
    plt.legend(title="Significance", bbox_to_anchor=(1.2, -0.05))

    plt.xlabel(x_label)
    plt.ylabel(y_label)

    if show:
        plt.show()
    if return_fig:
        # Return the handle we created: after `plt.show()` the current figure may
        # already be closed, in which case `plt.gcf()` would create an empty one.
        return fig
    return None
466
888
 
467
889
 
468
890
  class LinearModelBase(MethodBase):
469
891
  def __init__(self, adata, design, *, mask=None, layer=None, **kwargs):
470
- """
471
- Initialize the method.
892
+ """Initialize the method.
472
893
 
473
894
  Args:
474
895
  adata: AnnData object, usually pseudobulked.
@@ -480,26 +901,24 @@ class LinearModelBase(MethodBase):
480
901
  super().__init__(adata, mask=mask, layer=layer)
481
902
  self._check_counts()
482
903
 
483
- self.factor_storage = None
484
- self.variable_to_factors = None
485
-
904
+ self.formulaic_contrasts = None
486
905
  if isinstance(design, str):
487
- self.factor_storage, self.variable_to_factors, materializer_class = get_factor_storage_and_materializer()
488
- self.design = materializer_class(adata.obs, record_factor_metadata=True).get_model_matrix(design)
906
+ self.formulaic_contrasts = FormulaicContrasts(adata.obs, design)
907
+ self.design = self.formulaic_contrasts.design_matrix
489
908
  else:
490
909
  self.design = design
491
910
 
492
911
  @classmethod
493
912
  def compare_groups(
494
913
  cls,
495
- adata,
496
- column,
497
- baseline,
498
- groups_to_compare,
914
+ adata: ad.AnnData,
915
+ column: str,
916
+ baseline: str,
917
+ groups_to_compare: str | Iterable[str],
499
918
  *,
500
- paired_by=None,
501
- mask=None,
502
- layer=None,
919
+ paired_by: str | None = None,
920
+ mask: pd.Series | None = None,
921
+ layer: str | None = None,
503
922
  fit_kwargs=MappingProxyType({}),
504
923
  test_kwargs=MappingProxyType({}),
505
924
  ):
@@ -525,17 +944,16 @@ class LinearModelBase(MethodBase):
525
944
  @property
526
945
  def variables(self):
527
946
  """Get the names of the variables used in the model definition."""
528
- try:
529
- return self.design.model_spec.variables_by_source["data"]
530
- except AttributeError:
947
+ if self.formulaic_contrasts is None:
531
948
  raise ValueError(
532
949
  "Retrieving variables is only possible if the model was initialized using a formula."
533
950
  ) from None
951
+ else:
952
+ return self.formulaic_contrasts.variables
534
953
 
535
954
  @abstractmethod
536
955
  def _check_counts(self):
537
- """
538
- Check that counts are valid for the specific method.
956
+ """Check that counts are valid for the specific method.
539
957
 
540
958
  Raises:
541
959
  ValueError: if the data matrix does not comply with the expectations.
@@ -544,8 +962,7 @@ class LinearModelBase(MethodBase):
544
962
 
545
963
  @abstractmethod
546
964
  def fit(self, **kwargs):
547
- """
548
- Fit the model.
965
+ """Fit the model.
549
966
 
550
967
  Args:
551
968
  **kwargs: Additional arguments for fitting the specific method.
@@ -555,9 +972,8 @@ class LinearModelBase(MethodBase):
555
972
  @abstractmethod
556
973
  def _test_single_contrast(self, contrast, **kwargs): ...
557
974
 
558
- def test_contrasts(self, contrasts, **kwargs):
559
- """
560
- Perform a comparison as specified in a contrast vector.
975
+ def test_contrasts(self, contrasts: np.ndarray | Mapping[str | None, np.ndarray], **kwargs):
976
+ """Perform a comparison as specified in a contrast vector.
561
977
 
562
978
  Args:
563
979
  contrasts: Either a numeric contrast vector, or a dictionary of numeric contrast vectors.
@@ -573,25 +989,25 @@ class LinearModelBase(MethodBase):
573
989
  results.append(self._test_single_contrast(contrast, **kwargs).assign(contrast=name))
574
990
 
575
991
  results_df = pd.concat(results)
992
+
576
993
  return results_df
577
994
 
578
995
  def test_reduced(self, modelB):
579
- """
580
- Test against a reduced model.
996
+ """Test against a reduced model.
581
997
 
582
998
  Args:
583
999
  modelB: the reduced model against which to test.
584
1000
 
585
1001
  Example:
586
- modelA = Model().fit()
587
- modelB = Model().fit()
588
- modelA.test_reduced(modelB)
1002
+ >>> import pertpy as pt
1003
+ >>> modelA = Model().fit()
1004
+ >>> modelB = Model().fit()
1005
+ >>> modelA.test_reduced(modelB)
589
1006
  """
590
1007
  raise NotImplementedError
591
1008
 
592
1009
  def cond(self, **kwargs):
593
- """
594
- Get a contrast vector representing a specific condition.
1010
+ """Get a contrast vector representing a specific condition.
595
1011
 
596
1012
  Args:
597
1013
  **kwargs: column/value pairs.
@@ -599,52 +1015,14 @@ class LinearModelBase(MethodBase):
599
1015
  Returns:
600
1016
  A contrast vector that aligns to the columns of the design matrix.
601
1017
  """
602
- if self.factor_storage is None:
1018
+ if self.formulaic_contrasts is None:
603
1019
  raise RuntimeError(
604
1020
  "Building contrasts with `cond` only works if you specified the model using a formulaic formula. Please manually provide a contrast vector."
605
1021
  )
606
- cond_dict = kwargs
607
- if not set(cond_dict.keys()).issubset(self.variables):
608
- raise ValueError(
609
- "You specified a variable that is not part of the model. Available variables: "
610
- + ",".join(self.variables)
611
- )
612
- for var in self.variables:
613
- if var in cond_dict:
614
- self._check_category(var, cond_dict[var])
615
- else:
616
- cond_dict[var] = self._get_default_value(var)
617
- df = pd.DataFrame([kwargs])
618
- return self.design.model_spec.get_model_matrix(df).iloc[0]
619
-
620
- def _get_factor_metadata_for_variable(self, var):
621
- factors = self.variable_to_factors[var]
622
- return list(chain.from_iterable(self.factor_storage[f] for f in factors))
623
-
624
- def _get_default_value(self, var):
625
- factor_metadata = self._get_factor_metadata_for_variable(var)
626
- if resolve_ambiguous(factor_metadata, "kind") == Factor.Kind.CATEGORICAL:
627
- try:
628
- tmp_base = resolve_ambiguous(factor_metadata, "base")
629
- except AmbiguousAttributeError as e:
630
- raise ValueError(
631
- f"Could not automatically resolve base category for variable {var}. Please specify it explicity in `model.cond`."
632
- ) from e
633
- return tmp_base if tmp_base is not None else "\0"
634
- else:
635
- return 0
636
-
637
- def _check_category(self, var, value):
638
- factor_metadata = self._get_factor_metadata_for_variable(var)
639
- tmp_categories = resolve_ambiguous(factor_metadata, "categories")
640
- if resolve_ambiguous(factor_metadata, "kind") == Factor.Kind.CATEGORICAL and value not in tmp_categories:
641
- raise ValueError(
642
- f"You specified a non-existant category for {var}. Possible categories: {', '.join(tmp_categories)}"
643
- )
1022
+ return self.formulaic_contrasts.cond(**kwargs)
644
1023
 
645
- def contrast(self, column, baseline, group_to_compare):
646
- """
647
- Build a simple contrast for pairwise comparisons.
1024
+ def contrast(self, *args, **kwargs):
1025
+ """Build a simple contrast for pairwise comparisons.
648
1026
 
649
1027
  Args:
650
1028
  column: column in adata.obs to test on.
@@ -654,4 +1032,8 @@ class LinearModelBase(MethodBase):
654
1032
  Returns:
655
1033
  Numeric contrast vector.
656
1034
  """
657
- return self.cond(**{column: group_to_compare}) - self.cond(**{column: baseline})
1035
+ if self.formulaic_contrasts is None:
1036
+ raise RuntimeError(
1037
+ "Building contrasts with `cond` only works if you specified the model using a formulaic formula. Please manually provide a contrast vector."
1038
+ )
1039
+ return self.formulaic_contrasts.contrast(*args, **kwargs)