pertpy 0.9.5__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. pertpy/__init__.py +5 -1
  2. pertpy/_doc.py +2 -5
  3. pertpy/_types.py +6 -0
  4. pertpy/data/_dataloader.py +68 -24
  5. pertpy/data/_datasets.py +9 -9
  6. pertpy/metadata/__init__.py +2 -1
  7. pertpy/metadata/_cell_line.py +136 -30
  8. pertpy/metadata/_look_up.py +13 -19
  9. pertpy/metadata/_moa.py +1 -1
  10. pertpy/preprocessing/_guide_rna.py +221 -39
  11. pertpy/preprocessing/_guide_rna_mixture.py +177 -0
  12. pertpy/tools/__init__.py +1 -1
  13. pertpy/tools/_augur.py +138 -142
  14. pertpy/tools/_cinemaot.py +75 -117
  15. pertpy/tools/_coda/_base_coda.py +150 -174
  16. pertpy/tools/_coda/_sccoda.py +66 -69
  17. pertpy/tools/_coda/_tasccoda.py +71 -79
  18. pertpy/tools/_dialogue.py +60 -56
  19. pertpy/tools/_differential_gene_expression/_base.py +25 -43
  20. pertpy/tools/_differential_gene_expression/_checks.py +4 -6
  21. pertpy/tools/_differential_gene_expression/_dge_comparison.py +5 -6
  22. pertpy/tools/_differential_gene_expression/_edger.py +6 -10
  23. pertpy/tools/_differential_gene_expression/_pydeseq2.py +1 -1
  24. pertpy/tools/_differential_gene_expression/_simple_tests.py +3 -3
  25. pertpy/tools/_differential_gene_expression/_statsmodels.py +8 -5
  26. pertpy/tools/_distances/_distance_tests.py +1 -2
  27. pertpy/tools/_distances/_distances.py +86 -92
  28. pertpy/tools/_enrichment.py +8 -25
  29. pertpy/tools/_milo.py +23 -27
  30. pertpy/tools/_mixscape.py +261 -175
  31. pertpy/tools/_perturbation_space/_clustering.py +4 -4
  32. pertpy/tools/_perturbation_space/_comparison.py +4 -4
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +83 -32
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +10 -10
  35. pertpy/tools/_perturbation_space/_simple.py +13 -17
  36. pertpy/tools/_scgen/_scgen.py +17 -20
  37. pertpy/tools/_scgen/_scgenvae.py +2 -2
  38. pertpy/tools/_scgen/_utils.py +3 -1
  39. {pertpy-0.9.5.dist-info → pertpy-0.11.0.dist-info}/METADATA +37 -21
  40. pertpy-0.11.0.dist-info/RECORD +58 -0
  41. {pertpy-0.9.5.dist-info → pertpy-0.11.0.dist-info}/licenses/LICENSE +1 -0
  42. pertpy/tools/_kernel_pca.py +0 -50
  43. pertpy-0.9.5.dist-info/RECORD +0 -57
  44. {pertpy-0.9.5.dist-info → pertpy-0.11.0.dist-info}/WHEEL +0 -0
pertpy/tools/_dialogue.py CHANGED
@@ -33,9 +33,17 @@ if TYPE_CHECKING:
 
 
  class Dialogue:
- """Python implementation of DIALOGUE"""
+ """Python implementation of DIALOGUE."""
 
- def __init__(self, sample_id: str, celltype_key: str, n_counts_key: str, n_mpcs: int):
+ def __init__(
+ self,
+ sample_id: str,
+ celltype_key: str,
+ n_counts_key: str,
+ n_mpcs: int,
+ feature_space_key: str = "X_pca",
+ n_components: int = 50,
+ ):
  """Constructor for Dialogue.
 
  Args:
@@ -43,6 +51,8 @@ class Dialogue:
  celltype_key: The key in AnnData.obs which contains the cell type column.
  n_counts_key: The key of the number of counts in Anndata.obs . Also commonly the size factor.
  n_mpcs: Number of PMD components which corresponds to the number of determined MCPs.
+ feature_space_key: The key in adata.obsm for the feature space (e.g., "X_pca", "X_umap").
+ n_components: The number of components of the feature space to use, e.g. PCA components.
  """
  self.sample_id = sample_id
  self.celltype_key = celltype_key
@@ -53,6 +63,8 @@ class Dialogue:
  )
  self.n_counts_key = n_counts_key
  self.n_mcps = n_mpcs
+ self.feature_space_key = feature_space_key
+ self.n_components = n_components
 
  def _get_pseudobulks(
  self, adata: AnnData, groupby: str, strategy: Literal["median", "mean"] = "median"
@@ -62,6 +74,7 @@ class Dialogue:
  Copied from `https://github.com/schillerlab/sc-toolbox/blob/397e80dc5e8fb8017b75f6c3fa634a1e1213d484/sc_toolbox/tools/__init__.py#L458`
 
  Args:
+ adata: Annotated data matrix.
  groupby: The key to groupby for pseudobulks
  strategy: The pseudobulking strategy. One of "median" or "mean"
 
@@ -82,27 +95,28 @@ class Dialogue:
 
  return pseudobulk
 
- def _pseudobulk_pca(self, adata: AnnData, groupby: str, n_components: int = 50) -> pd.DataFrame:
- """Return cell-averaged PCA components.
+ def _pseudobulk_feature_space(
+ self,
+ adata: AnnData,
+ groupby: str,
+ ) -> pd.DataFrame:
+ """Return Cell-averaged components from a passed feature space.
 
  TODO: consider merging with `get_pseudobulks`
  TODO: DIALOGUE recommends running PCA on each cell type separately before running PMD - this should be implemented as an option here.
 
  Args:
- groupby: The key to groupby for pseudobulks
- n_components: The number of PCA components
+ adata: Annotated data matrix.
+ groupby: The key to groupby for pseudobulks.
 
  Returns:
- A pseudobulk of PCA components.
+ A pseudobulk DataFrame of the averaged components.
  """
  aggr = {}
-
  for category in adata.obs.loc[:, groupby].cat.categories:
  temp = adata.obs.loc[:, groupby] == category
- aggr[category] = adata[temp].obsm["X_pca"][:, :n_components].mean(axis=0)
-
+ aggr[category] = adata[temp].obsm[self.feature_space_key][:, : self.n_components].mean(axis=0)
  aggr = pd.DataFrame(aggr)
-
  return aggr
 
  def _scale_data(self, pseudobulks: pd.DataFrame, normalize: bool = True) -> np.ndarray:
@@ -130,6 +144,7 @@ class Dialogue:
 
  Args:
  adata: The AnnData object to append mcp scores to.
+ ct_subs: cell type objects.
  mcp_scores: The MCP scores dictionary.
  celltype_key: Key of the cell type column in obs.
 
@@ -213,7 +228,7 @@ class Dialogue:
  sample_obs: str,
  return_all: bool = False,
  ):
- """Applies a mixed linear model using the specified formula (MCP scores used for the dependent var) and returns the coefficient and p-value
+ """Applies a mixed linear model using the specified formula (MCP scores used for the dependent var) and returns the coefficient and p-value.
 
  TODO: reduce runtime? Maybe we can use an approximation or something that isn't statsmodels.
 
@@ -332,7 +347,7 @@ class Dialogue:
 
  Args:
  mcp_name: The name of the MCP to model.
- scores: The MCP scores for a cell type. Number of MCPs x number of features.
+ scores_df: The MCP scores for a cell type. Number of MCPs x number of features.
  ct_data: The AnnData object containing the metadata and labels in obs.
  tme: Transcript mean expression in `x`.
  sig: DataFrame containing a series of up and downregulated MCPs.
@@ -418,11 +433,10 @@ class Dialogue:
  # Finally get corr coeff
  return np.dot(A_mA, B_mB.T) / np.sqrt(np.dot(ssA[:, None], ssB[None]))
 
+ # TODO: needs check for correctness and variable renaming
+ # TODO: Confirm that this doesn't return duplicate gene names.
  def _get_top_elements(self, m: pd.DataFrame, max_length: int, min_threshold: float):
- """
-
- TODO: needs check for correctness and variable renaming
- TODO: Confirm that this doesn't return duplicate gene names
+ """Get top elements.
 
  Args:
  m: Any DataFrame of Gene name as index with variable columns.
@@ -457,12 +471,11 @@ class Dialogue:
  # TODO this whole function should be standalone
  # It will contain the calculation of up/down + calculation (new final mcp scores)
  # Ensure that it'll still fit/work with the hierarchical multilevel_modeling
-
  """Determine the up and down genes per MCP."""
  # TODO: something is slightly slow here
  cca_sig_results: dict[Any, dict[str, Any]] = {}
  new_mcp_scores: dict[Any, list[Any]] = {}
- for ct in ct_subs.keys():
+ for ct in ct_subs:
  ct_adata = ct_subs[ct]
  conf_m = ct_adata.obs[n_counts_key].values
 
@@ -483,9 +496,7 @@ class Dialogue:
  from scipy.stats import spearmanr
 
  def _pcor_mat(v1, v2, v3, method="spearman"):
- """
- MAJOR TODO: I've only used normal correlation instead of partial correlation as we wait on the implementation
- """
+ """MAJOR TODO: I've only used normal correlation instead of partial correlation as we wait on the implementation."""
  correlations = [] # R
  pvals = [] # P
  for x2 in v2:
@@ -506,7 +517,7 @@ class Dialogue:
  return np.array(correlations), np.array(pvals) # pvals_adjusted
 
  C1, P1 = _pcor_mat(ct_adata[:, top_cor_genes_flattened].X.toarray().T, mcp_scores[ct].T, conf_m)
- C1[P1 > (0.05 / ct_adata.shape[1])] = 0 # why?
+ C1[(0.05 / ct_adata.shape[1]) < P1] = 0 # why?
 
  cca_sig_unformatted = self._get_top_elements( # 3 up, 3 dn, for each mcp
  pd.DataFrame(C1.T, index=top_cor_genes_flattened), max_length=max_genes, min_threshold=0.05
@@ -514,7 +525,7 @@ class Dialogue:
 
  # TODO: probably format the up and down within get_top_elements
  cca_sig: dict[str, Any] = defaultdict(dict)
- for i in range(0, int(len(cca_sig_unformatted) / 2)):
+ for i in range(int(len(cca_sig_unformatted) / 2)):
  cca_sig[f"MCP{i}"]["up"] = cca_sig_unformatted[i * 2]
  cca_sig[f"MCP{i}"]["down"] = cca_sig_unformatted[i * 2 + 1]
 
@@ -523,7 +534,7 @@ class Dialogue:
 
  # This is basically DIALOGUE 3 now
  pre_r_scores = {
- ct: ct_subs[ct].obsm["X_pca"][:, :50] @ ws_dict[ct]
+ ct: ct_subs[ct].obsm[self.feature_space_key][:, : self.n_components] @ ws_dict[ct]
  for i, ct in enumerate(ct_subs.keys())
  # TODO This is a recalculation and not a new calculation
  }
@@ -558,7 +569,7 @@ class Dialogue:
  self,
  adata: AnnData,
  ct_order: list[str],
- agg_pca: bool = True,
+ agg_feature: bool = True,
  normalize: bool = True,
  ) -> tuple[list, dict]:
  """Separates cell into AnnDatas by celltype_key and creates the multifactor PMD input.
@@ -568,14 +579,14 @@ class Dialogue:
  Args:
  adata: AnnData object generate celltype objects for
  ct_order: The order of cell types
- agg_pca: Whether to aggregate pseudobulks with PCA or not.
+ agg_feature: Whether to aggregate pseudobulks with some embeddings or not.
  normalize: Whether to mimic DIALOGUE behavior or not.
 
  Returns:
  A celltype_label:array dictionary.
  """
  ct_subs = {ct: adata[adata.obs[self.celltype_key] == ct].copy() for ct in ct_order}
- fn = self._pseudobulk_pca if agg_pca else self._get_pseudobulks
+ fn = self._pseudobulk_feature_space if agg_feature else self._get_pseudobulks
  ct_aggr = {ct: fn(ad, self.sample_id) for ct, ad in ct_subs.items()} # type: ignore
 
  # TODO: implement check (as in https://github.com/livnatje/DIALOGUE/blob/55da9be0a9bf2fcd360d9e11f63e30d041ec4318/R/DIALOGUE.main.R#L114-L119)
@@ -591,9 +602,9 @@ class Dialogue:
  def calculate_multifactor_PMD(
  self,
  adata: AnnData,
- penalties: list[int] = None,
- ct_order: list[str] = None,
- agg_pca: bool = True,
+ penalties: list[int] | None = None,
+ ct_order: list[str] | None = None,
+ agg_feature: bool = True,
  solver: Literal["lp", "bs"] = "bs",
  normalize: bool = True,
  ) -> tuple[AnnData, dict[str, np.ndarray], dict[Any, Any], dict[Any, Any]]:
@@ -603,10 +614,9 @@ class Dialogue:
 
  Args:
  adata: AnnData object to calculate PMD for.
- sample_id: Key to use for pseudobulk determination.
  penalties: PMD penalties.
  ct_order: The order of cell types.
- agg_pca: Whether to calculate cell-averaged PCA components.
+ agg_feature: Whether to calculate cell-averaged principal components.
  solver: Which solver to use for PMD. Must be one of "lp" (linear programming) or "bs" (binary search).
  For differences between these to please refer to https://github.com/theislab/sparsecca/blob/main/examples/linear_programming_multicca.ipynb
  normalize: Whether to mimic DIALOGUE as close as possible
@@ -631,7 +641,7 @@ class Dialogue:
  else:
  ct_order = cell_types = adata.obs[self.celltype_key].astype("category").cat.categories
 
- mcca_in, ct_subs = self._load(adata, ct_order=cell_types, agg_pca=agg_pca, normalize=normalize)
+ mcca_in, ct_subs = self._load(adata, ct_order=cell_types, agg_feature=agg_feature, normalize=normalize)
 
  n_samples = mcca_in[0].shape[1]
  if penalties is None:
@@ -644,8 +654,6 @@ class Dialogue:
  raise ValueError("Please ensure that every cell type is represented in every sample.") from e
  else:
  raise
- else:
- penalties = penalties
 
  if solver == "bs":
  ws, _ = multicca_pmd(mcca_in, penalties, K=self.n_mcps, standardize=True, niter=100, mimic_R=normalize)
@@ -656,8 +664,8 @@ class Dialogue:
  ws_dict = {ct: ws[i] for i, ct in enumerate(ct_order)}
 
  pre_r_scores = {
- ct: ct_subs[ct].obsm["X_pca"][:, :50] @ ws[i]
- for i, ct in enumerate(cell_types) # TODO change from 50
+ ct: ct_subs[ct].obsm[self.feature_space_key][:, : self.n_components] @ ws[i]
+ for i, ct in enumerate(cell_types)
  }
 
  # TODO: output format needs some cleanup, even though each MCP score is matched to one cell, it's not at all
@@ -681,17 +689,17 @@ class Dialogue:
  ws_dict: dict,
  confounder: str | None,
  formula: str = None,
- ):
+ ) -> pd.DataFrame:
  """Runs the multilevel modeling step to match genes to MCPs and generate p-values for MCPs.
 
  Args:
  ct_subs: The DIALOGUE cell type objects.
  mcp_scores: The determined MCP scores from the PMD step.
+ ws_dict: WS dictionary.
  confounder: Any modeling confounders.
  formula: The hierarchical modeling formula. Defaults to y ~ x + n_counts.
 
  Returns:
- A Pandas DataFrame containing:
  - for each mcp: HLM_result_1, HLM_result_2, sig_genes_1, sig_genes_2
  - merged HLM_result_1, HLM_result_2, sig_genes_1, sig_genes_2 of all mcps
 
@@ -875,15 +883,15 @@ class Dialogue:
  if len(conditions_compare) != 2:
  raise ValueError("Please specify conditions to compare or supply an object with only 2 conditions")
 
- pvals = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(0, n_mcps)])
- tstats = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(0, n_mcps)])
- pvals_adj = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(0, n_mcps)])
+ pvals = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(n_mcps)])
+ tstats = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(n_mcps)])
+ pvals_adj = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(n_mcps)])
 
  response = adata.obs.groupby(sample_label)[condition_label].agg(pd.Series.mode)
  for celltype in adata.obs[celltype_label].unique():
  df = adata.obs[adata.obs[celltype_label] == celltype]
 
- for mcpnum in ["mcp_" + str(n) for n in range(0, n_mcps)]:
+ for mcpnum in ["mcp_" + str(n) for n in range(n_mcps)]:
  mns = df.groupby(sample_label)[mcpnum].mean()
  mns = pd.concat([mns, response], axis=1)
  res = stats.ttest_ind(
@@ -893,7 +901,7 @@ class Dialogue:
  pvals.loc[celltype, mcpnum] = res[1]
  tstats.loc[celltype, mcpnum] = res[0]
 
- for mcpnum in ["mcp_" + str(n) for n in range(0, n_mcps)]:
+ for mcpnum in ["mcp_" + str(n) for n in range(n_mcps)]:
  pvals_adj[mcpnum] = multipletests(pvals[mcpnum], method="fdr_bh")[1]
 
  return {"pvals": pvals, "tstats": tstats, "pvals_adj": pvals_adj}
@@ -956,7 +964,7 @@ class Dialogue:
 
  genes_dict_up = {} # type: ignore
  genes_dict_down = {} # type: ignore
- for celltype2 in mcp_dict.keys():
+ for celltype2 in mcp_dict:
  for gene in mcp_dict[celltype2][MCP + ".up"]:
  if gene in genes_dict_up:
  genes_dict_up[gene] += 1
@@ -1008,7 +1016,7 @@ class Dialogue:
  >>> genes_results = _get_extrema_MCP_genes_single(ct_subs, mcp="mcp_4", fraction=0.2)
  """
  genes = {}
- for ct in ct_subs.keys():
+ for ct in ct_subs:
  mini = ct_subs[ct]
  mini.obs["extrema"] = pd.qcut(
  mini.obs[mcp],
@@ -1056,13 +1064,13 @@ class Dialogue:
  for mcp in mcps:
  rank_dfs[mcp] = {}
  ct_ranked = self._get_extrema_MCP_genes_single(ct_subs, mcp=mcp, fraction=fraction)
- for celltype in ct_ranked.keys():
+ for celltype in ct_ranked:
  rank_dfs[mcp][celltype] = sc.get.rank_genes_groups_df(ct_ranked[celltype], group=None)
 
  return rank_dfs
 
  @_doc_params(common_plot_args=doc_common_plot_args)
- def plot_split_violins(
+ def plot_split_violins( # pragma: no cover # noqa: D417
  self,
  adata: AnnData,
  split_key: str,
@@ -1070,7 +1078,6 @@ class Dialogue:
  *,
  split_which: tuple[str, str] = None,
  mcp: str = "mcp_0",
- show: bool = True,
  return_fig: bool = False,
  ) -> Figure | None:
  """Plots split violin plots for a given MCP and split variable.
@@ -1110,14 +1117,13 @@ class Dialogue:
  ax = sns.violinplot(data=df, x=celltype_key, y=mcp, hue=split_key, split=True)
  ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
 
- if show:
- plt.show()
  if return_fig:
  return plt.gcf()
+ plt.show()
  return None
 
  @_doc_params(common_plot_args=doc_common_plot_args)
- def plot_pairplot(
+ def plot_pairplot( # pragma: no cover # noqa: D417
  self,
  adata: AnnData,
  celltype_key: str,
@@ -1125,7 +1131,6 @@ class Dialogue:
  sample_id: str,
  *,
  mcp: str = "mcp_0",
- show: bool = True,
  return_fig: bool = False,
  ) -> Figure | None:
  """Generate a pairplot visualization for multi-cell perturbation (MCP) data.
@@ -1167,8 +1172,7 @@ class Dialogue:
  mcp_pivot = pd.concat([mcp_pivot, aggstats[color]], axis=1)
  sns.pairplot(mcp_pivot, hue=color, corner=True)
 
- if show:
- plt.show()
  if return_fig:
  return plt.gcf()
+ plt.show()
  return None
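
The change set above reworks how DIALOGUE consumes embeddings: the constructor gains `feature_space_key` and `n_components`, and `calculate_multifactor_PMD` takes `agg_feature` in place of `agg_pca`. A minimal usage sketch of the updated API, assuming an AnnData object with a precomputed PCA embedding; the obs column names and the number of MCPs below are illustrative placeholders, not taken from this diff:

    import pertpy as pt
    import scanpy as sc

    sc.pp.pca(adata)  # make sure the embedding referenced by feature_space_key exists in adata.obsm
    dl = pt.tl.Dialogue(
        sample_id="sample",         # hypothetical obs column identifying samples
        celltype_key="cell_type",   # hypothetical obs column holding cell types
        n_counts_key="n_counts",    # hypothetical obs column with counts / size factors
        n_mpcs=3,
        feature_space_key="X_pca",  # new: any adata.obsm embedding can be used
        n_components=50,            # new: number of embedding components to aggregate
    )
    adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, agg_feature=True, normalize=True)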
pertpy/tools/_differential_gene_expression/_base.py CHANGED
@@ -1,3 +1,4 @@
+ import contextlib
  import math
  from abc import ABC, abstractmethod
  from collections.abc import Iterable, Mapping, Sequence
@@ -23,8 +24,7 @@ from pertpy.tools._differential_gene_expression._checks import check_is_numeric_
 
  class MethodBase(ABC):
  def __init__(self, adata, *, mask=None, layer=None, **kwargs):
- """
- Initialize the method.
+ """Initialize the method.
 
  Args:
  adata: AnnData object, usually pseudobulked.
@@ -62,8 +62,7 @@ class MethodBase(ABC):
  fit_kwargs=MappingProxyType({}),
  test_kwargs=MappingProxyType({}),
  ):
- """
- Compare between groups in a specified column.
+ """Compare between groups in a specified column.
 
  Args:
  adata: AnnData object.
@@ -100,7 +99,7 @@ class MethodBase(ABC):
  ...
 
  @_doc_params(common_plot_args=doc_common_plot_args)
- def plot_volcano(
+ def plot_volcano( # pragma: no cover # noqa: D417
  self,
  data: pd.DataFrame | ad.AnnData,
  *,
@@ -125,7 +124,6 @@ class MethodBase(ABC):
  shape_order: list[str] | None = None,
  x_label: str | None = None,
  y_label: str | None = None,
- show: bool = True,
  return_fig: bool = False,
  **kwargs: int,
  ) -> Figure | None:
@@ -189,8 +187,7 @@ class MethodBase(ABC):
  colors = ["gray", "#D62728", "#1F77B4"]
 
  def _pval_reciprocal(lfc: float) -> float:
- """
- Function for relating -log10(pvalue) and logfoldchange in a reciprocal.
+ """Function for relating -log10(pvalue) and logfoldchange in a reciprocal.
 
  Used for plotting the S-curve
  """
@@ -198,7 +195,7 @@ class MethodBase(ABC):
 
  def _map_shape(symbol: str) -> str:
  if shape_dict is not None:
- for k in shape_dict.keys():
+ for k in shape_dict:
  if shape_dict[k] is not None and symbol in shape_dict[k]:
  return k
  return "other"
@@ -212,8 +209,7 @@ class MethodBase(ABC):
  pval_thresh: float = None,
  s_curve: bool = False,
  ) -> str:
- """
- Map genes to categorize based on log2fc and pvalue.
+ """Map genes to categorize based on log2fc and pvalue.
 
  These categories are used for coloring the dots.
  Used when no color_dict is passed, sets up/down/nonsignificant.
@@ -230,14 +226,13 @@ class MethodBase(ABC):
  return "Down"
  else:
  return "not DE"
+ # Standard condition for Up or Down categorization
+ elif log2fc > log2fc_thresh and nlog10 > pval_thresh:
+ return "Up"
+ elif log2fc < -log2fc_thresh and nlog10 > pval_thresh:
+ return "Down"
  else:
- # Standard condition for Up or Down categorization
- if log2fc > log2fc_thresh and nlog10 > pval_thresh:
- return "Up"
- elif log2fc < -log2fc_thresh and nlog10 > pval_thresh:
- return "Down"
- else:
- return "not DE"
+ return "not DE"
 
  def _map_genes_categories_highlight(
  row: pd.Series,
@@ -248,8 +243,7 @@ class MethodBase(ABC):
  s_curve: bool = False,
  symbol_col: str = None,
  ) -> str:
- """
- Map genes to categorize based on log2fc and pvalue.
+ """Map genes to categorize based on log2fc and pvalue.
 
  These categories are used for coloring the dots.
  Used when color_dict is passed, sets DE / not DE for background and user supplied highlight genes.
@@ -259,7 +253,7 @@ class MethodBase(ABC):
  symbol = row[symbol_col]
 
  if color_dict is not None:
- for k in color_dict.keys():
+ for k in color_dict:
  if symbol in color_dict[k]:
  return k
 
@@ -484,14 +478,13 @@ class MethodBase(ABC):
 
  plt.legend(loc=1, bbox_to_anchor=legend_pos, frameon=False)
 
- if show:
- plt.show()
  if return_fig:
  return plt.gcf()
+ plt.show()
  return None
 
  @_doc_params(common_plot_args=doc_common_plot_args)
- def plot_paired(
+ def plot_paired( # pragma: no cover # noqa: D417
  self,
  adata: ad.AnnData,
  results_df: pd.DataFrame,
@@ -511,7 +504,6 @@ class MethodBase(ABC):
  pvalue_template=lambda x: f"p={x:.2e}",
  boxplot_properties=None,
  palette=None,
- show: bool = True,
  return_fig: bool = False,
  ) -> Figure | None:
  """Creates a pairwise expression plot from a Pandas DataFrame or Anndata.
@@ -584,14 +576,9 @@ class MethodBase(ABC):
  adata, target_col=groupby, groups_col=pairedby, layer_key=layer, mode="sum", min_cells=1, min_counts=1
  )
 
- if layer is not None:
- X = adata.layers[layer]
- else:
- X = adata.X
- try:
+ X = adata.layers[layer] if layer is not None else adata.X
+ with contextlib.suppress(AttributeError):
  X = X.toarray()
- except AttributeError:
- pass
 
  groupby_cols = [pairedby, groupby]
  df = adata.obs.loc[:, groupby_cols].join(pd.DataFrame(X, index=adata.obs_names, columns=var_names))
@@ -679,14 +666,13 @@ class MethodBase(ABC):
  )
 
  plt.tight_layout()
- if show:
- plt.show()
  if return_fig:
  return plt.gcf()
+ plt.show()
  return None
 
  @_doc_params(common_plot_args=doc_common_plot_args)
- def plot_fold_change(
+ def plot_fold_change( # pragma: no cover # noqa: D417
  self,
  results_df: pd.DataFrame,
  *,
@@ -696,7 +682,6 @@ class MethodBase(ABC):
  symbol_col: str = "variable",
  y_label: str = "Log2 fold change",
  figsize: tuple[int, int] = (10, 5),
- show: bool = True,
  return_fig: bool = False,
  **barplot_kwargs,
  ) -> Figure | None:
@@ -762,14 +747,13 @@ class MethodBase(ABC):
  plt.xlabel("")
  plt.ylabel(y_label)
 
- if show:
- plt.show()
  if return_fig:
  return plt.gcf()
+ plt.show()
  return None
 
  @_doc_params(common_plot_args=doc_common_plot_args)
- def plot_multicomparison_fc(
+ def plot_multicomparison_fc( # pragma: no cover # noqa: D417
  self,
  results_df: pd.DataFrame,
  *,
@@ -782,7 +766,6 @@ class MethodBase(ABC):
  figsize: tuple[int, int] = (10, 2),
  x_label: str = "Contrast",
  y_label: str = "Gene",
- show: bool = True,
  return_fig: bool = False,
  **heatmap_kwargs,
  ) -> Figure | None:
@@ -880,10 +863,9 @@ class MethodBase(ABC):
  plt.xlabel(x_label)
  plt.ylabel(y_label)
 
- if show:
- plt.show()
  if return_fig:
  return plt.gcf()
+ plt.show()
  return None
 
 
@@ -1021,7 +1003,7 @@ class LinearModelBase(MethodBase):
  )
  return self.formulaic_contrasts.cond(**kwargs)
 
- def contrast(self, *args, **kwargs):
+ def contrast(self, *args, **kwargs): # noqa: D417
  """Build a simple contrast for pairwise comparisons.
 
  Args:
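
Across these plotting methods the `show` parameter is removed: figures are displayed by default and a `Figure` object is returned only when `return_fig=True`. A sketch of the adjusted call pattern, where `method` stands for any DGE method instance exposing `plot_volcano` and `results` for its results DataFrame (both names are placeholders):

    # 0.9.x: fig = method.plot_volcano(results, show=False, return_fig=True)
    # 0.11.x: request the Figure explicitly if you want to modify or save it
    fig = method.plot_volcano(results, return_fig=True)
    fig.savefig("volcano.png", dpi=300)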
pertpy/tools/_differential_gene_expression/_checks.py CHANGED
@@ -16,9 +16,8 @@ def check_is_numeric_matrix(array: np.ndarray | spmatrix) -> None:
  if issparse(array):
  if np.any(~np.isfinite(array.data)):
  raise ValueError("Counts cannot contain negative, NaN or Inf values.")
- else:
- if np.any(~np.isfinite(array)):
- raise ValueError("Counts cannot contain negative, NaN or Inf values.")
+ elif np.any(~np.isfinite(array)):
+ raise ValueError("Counts cannot contain negative, NaN or Inf values.")
 
 
  def check_is_integer_matrix(array: np.ndarray | spmatrix, tolerance: float = 1e-6) -> None:
@@ -34,8 +33,7 @@ def check_is_integer_matrix(array: np.ndarray | spmatrix, tolerance: float = 1e-
  if issparse(array):
  if not array.data.dtype.kind == "i" and not np.all(np.abs(array.data - np.round(array.data)) < tolerance):
  raise ValueError("Non-zero elements of the matrix must be close to integer values.")
- else:
- if not array.dtype.kind == "i" and not np.all(np.abs(array - np.round(array)) < tolerance):
- raise ValueError("Matrix must be a count matrix.")
+ elif array.dtype.kind != "i" and not np.all(np.abs(array - np.round(array)) < tolerance):
+ raise ValueError("Matrix must be a count matrix.")
  if (array < 0).sum() > 0:
  raise ValueError("Non-zero elements of the matrix must be positive.")
pertpy/tools/_differential_gene_expression/_dge_comparison.py CHANGED
@@ -36,16 +36,15 @@ class DGEEVAL:
  if not de_key1 or not de_key2:
  raise ValueError("Both `de_key1` and `de_key2` must be provided together if using `adata`.")
 
- else: # use dfs
- if de_df1 is None or de_df2 is None:
- raise ValueError("Both `de_df1` and `de_df2` must be provided together if using DataFrames.")
+ elif de_df1 is None or de_df2 is None:
+ raise ValueError("Both `de_df1` and `de_df2` must be provided together if using DataFrames.")
 
  if de_key1:
  if not adata:
  raise ValueError("`adata` should be provided with `de_key1` and `de_key2`. ")
- assert all(
- k in adata.uns for k in [de_key1, de_key2]
- ), "Provided `de_key1` and `de_key2` must exist in `adata.uns`."
+ assert all(k in adata.uns for k in [de_key1, de_key2]), (
+ "Provided `de_key1` and `de_key2` must exist in `adata.uns`."
+ )
  vars = adata.var_names
 
  if de_df1 is not None:
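
The validation shown above accepts either an AnnData whose `.uns` holds two DE result keys, or two DE result DataFrames; the refactor only flattens the branching. A sketch of the two call patterns, assuming the comparison entry point is `DGEEVAL.compare` and that the class is reachable as `pt.tl.DGEEVAL` (result keys and DataFrame names are illustrative):

    import pertpy as pt

    evaluator = pt.tl.DGEEVAL()
    # Option 1: DE results stored in adata.uns under two keys
    metrics = evaluator.compare(adata=adata, de_key1="de_result_1", de_key2="de_result_2")
    # Option 2: two DE result DataFrames passed directly
    metrics = evaluator.compare(de_df1=df_method1, de_df2=df_method2)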