pertpy 0.10.0__py3-none-any.whl → 0.11.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. pertpy/__init__.py +5 -1
  2. pertpy/_doc.py +1 -3
  3. pertpy/_types.py +6 -0
  4. pertpy/data/_dataloader.py +68 -24
  5. pertpy/data/_datasets.py +9 -9
  6. pertpy/metadata/__init__.py +2 -1
  7. pertpy/metadata/_cell_line.py +133 -25
  8. pertpy/metadata/_look_up.py +13 -19
  9. pertpy/metadata/_moa.py +1 -1
  10. pertpy/preprocessing/_guide_rna.py +138 -44
  11. pertpy/preprocessing/_guide_rna_mixture.py +17 -19
  12. pertpy/tools/__init__.py +4 -3
  13. pertpy/tools/_augur.py +106 -98
  14. pertpy/tools/_cinemaot.py +74 -114
  15. pertpy/tools/_coda/_base_coda.py +134 -148
  16. pertpy/tools/_coda/_sccoda.py +69 -70
  17. pertpy/tools/_coda/_tasccoda.py +74 -80
  18. pertpy/tools/_dialogue.py +48 -41
  19. pertpy/tools/_differential_gene_expression/_base.py +21 -31
  20. pertpy/tools/_differential_gene_expression/_checks.py +4 -6
  21. pertpy/tools/_differential_gene_expression/_dge_comparison.py +5 -6
  22. pertpy/tools/_differential_gene_expression/_edger.py +6 -10
  23. pertpy/tools/_differential_gene_expression/_pydeseq2.py +1 -1
  24. pertpy/tools/_differential_gene_expression/_simple_tests.py +3 -3
  25. pertpy/tools/_differential_gene_expression/_statsmodels.py +8 -5
  26. pertpy/tools/_distances/_distance_tests.py +1 -2
  27. pertpy/tools/_distances/_distances.py +31 -46
  28. pertpy/tools/_enrichment.py +7 -22
  29. pertpy/tools/_milo.py +19 -15
  30. pertpy/tools/_mixscape.py +73 -75
  31. pertpy/tools/_perturbation_space/_clustering.py +4 -4
  32. pertpy/tools/_perturbation_space/_comparison.py +4 -4
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +83 -32
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +10 -10
  35. pertpy/tools/_perturbation_space/_simple.py +12 -14
  36. pertpy/tools/_scgen/_scgen.py +16 -17
  37. pertpy/tools/_scgen/_scgenvae.py +2 -2
  38. pertpy/tools/_scgen/_utils.py +3 -1
  39. {pertpy-0.10.0.dist-info → pertpy-0.11.1.dist-info}/METADATA +42 -24
  40. pertpy-0.11.1.dist-info/RECORD +58 -0
  41. {pertpy-0.10.0.dist-info → pertpy-0.11.1.dist-info}/licenses/LICENSE +1 -0
  42. pertpy/tools/_kernel_pca.py +0 -50
  43. pertpy-0.10.0.dist-info/RECORD +0 -58
  44. {pertpy-0.10.0.dist-info → pertpy-0.11.1.dist-info}/WHEEL +0 -0
pertpy/tools/_dialogue.py CHANGED
@@ -20,7 +20,6 @@ from rich.live import Live
20
20
  from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn
21
21
  from scipy import stats
22
22
  from scipy.optimize import nnls
23
- from seaborn import PairGrid
24
23
  from sklearn.linear_model import LinearRegression
25
24
  from sparsecca import lp_pmd, multicca_permute, multicca_pmd
26
25
  from statsmodels.sandbox.stats.multicomp import multipletests
@@ -33,9 +32,17 @@ if TYPE_CHECKING:
33
32
 
34
33
 
35
34
  class Dialogue:
36
- """Python implementation of DIALOGUE"""
35
+ """Python implementation of DIALOGUE."""
37
36
 
38
- def __init__(self, sample_id: str, celltype_key: str, n_counts_key: str, n_mpcs: int):
37
+ def __init__(
38
+ self,
39
+ sample_id: str,
40
+ celltype_key: str,
41
+ n_counts_key: str,
42
+ n_mpcs: int,
43
+ feature_space_key: str = "X_pca",
44
+ n_components: int = 50,
45
+ ):
39
46
  """Constructor for Dialogue.
40
47
 
41
48
  Args:
@@ -43,6 +50,8 @@ class Dialogue:
43
50
  celltype_key: The key in AnnData.obs which contains the cell type column.
44
51
  n_counts_key: The key of the number of counts in Anndata.obs . Also commonly the size factor.
45
52
  n_mpcs: Number of PMD components which corresponds to the number of determined MCPs.
53
+ feature_space_key: The key in adata.obsm for the feature space (e.g., "X_pca", "X_umap").
54
+ n_components: The number of components of the feature space to use, e.g. PCA components.
46
55
  """
47
56
  self.sample_id = sample_id
48
57
  self.celltype_key = celltype_key
@@ -53,6 +62,8 @@ class Dialogue:
53
62
  )
54
63
  self.n_counts_key = n_counts_key
55
64
  self.n_mcps = n_mpcs
65
+ self.feature_space_key = feature_space_key
66
+ self.n_components = n_components
56
67
 
57
68
  def _get_pseudobulks(
58
69
  self, adata: AnnData, groupby: str, strategy: Literal["median", "mean"] = "median"
@@ -62,6 +73,7 @@ class Dialogue:
62
73
  Copied from `https://github.com/schillerlab/sc-toolbox/blob/397e80dc5e8fb8017b75f6c3fa634a1e1213d484/sc_toolbox/tools/__init__.py#L458`
63
74
 
64
75
  Args:
76
+ adata: Annotated data matrix.
65
77
  groupby: The key to groupby for pseudobulks
66
78
  strategy: The pseudobulking strategy. One of "median" or "mean"
67
79
 
@@ -83,7 +95,9 @@ class Dialogue:
83
95
  return pseudobulk
84
96
 
85
97
  def _pseudobulk_feature_space(
86
- self, adata: AnnData, groupby: str, n_components: int = 50, feature_space_key: str = "X_pca"
98
+ self,
99
+ adata: AnnData,
100
+ groupby: str,
87
101
  ) -> pd.DataFrame:
88
102
  """Return Cell-averaged components from a passed feature space.
89
103
 
@@ -91,9 +105,8 @@ class Dialogue:
91
105
  TODO: DIALOGUE recommends running PCA on each cell type separately before running PMD - this should be implemented as an option here.
92
106
 
93
107
  Args:
108
+ adata: Annotated data matrix.
94
109
  groupby: The key to groupby for pseudobulks.
95
- n_components: The number of components to use.
96
- feature_key: The key in adata.obsm for the feature space (e.g., "X_pca", "X_umap").
97
110
 
98
111
  Returns:
99
112
  A pseudobulk DataFrame of the averaged components.
@@ -101,7 +114,7 @@ class Dialogue:
101
114
  aggr = {}
102
115
  for category in adata.obs.loc[:, groupby].cat.categories:
103
116
  temp = adata.obs.loc[:, groupby] == category
104
- aggr[category] = adata[temp].obsm[feature_space_key][:, :n_components].mean(axis=0)
117
+ aggr[category] = adata[temp].obsm[self.feature_space_key][:, : self.n_components].mean(axis=0)
105
118
  aggr = pd.DataFrame(aggr)
106
119
  return aggr
107
120
 
@@ -130,6 +143,7 @@ class Dialogue:
130
143
 
131
144
  Args:
132
145
  adata: The AnnData object to append mcp scores to.
146
+ ct_subs: cell type objects.
133
147
  mcp_scores: The MCP scores dictionary.
134
148
  celltype_key: Key of the cell type column in obs.
135
149
 
@@ -213,7 +227,7 @@ class Dialogue:
213
227
  sample_obs: str,
214
228
  return_all: bool = False,
215
229
  ):
216
- """Applies a mixed linear model using the specified formula (MCP scores used for the dependent var) and returns the coefficient and p-value
230
+ """Applies a mixed linear model using the specified formula (MCP scores used for the dependent var) and returns the coefficient and p-value.
217
231
 
218
232
  TODO: reduce runtime? Maybe we can use an approximation or something that isn't statsmodels.
219
233
 
@@ -332,7 +346,7 @@ class Dialogue:
332
346
 
333
347
  Args:
334
348
  mcp_name: The name of the MCP to model.
335
- scores: The MCP scores for a cell type. Number of MCPs x number of features.
349
+ scores_df: The MCP scores for a cell type. Number of MCPs x number of features.
336
350
  ct_data: The AnnData object containing the metadata and labels in obs.
337
351
  tme: Transcript mean expression in `x`.
338
352
  sig: DataFrame containing a series of up and downregulated MCPs.
@@ -418,11 +432,10 @@ class Dialogue:
418
432
  # Finally get corr coeff
419
433
  return np.dot(A_mA, B_mB.T) / np.sqrt(np.dot(ssA[:, None], ssB[None]))
420
434
 
435
+ # TODO: needs check for correctness and variable renaming
436
+ # TODO: Confirm that this doesn't return duplicate gene names.
421
437
  def _get_top_elements(self, m: pd.DataFrame, max_length: int, min_threshold: float):
422
- """
423
-
424
- TODO: needs check for correctness and variable renaming
425
- TODO: Confirm that this doesn't return duplicate gene names
438
+ """Get top elements.
426
439
 
427
440
  Args:
428
441
  m: Any DataFrame of Gene name as index with variable columns.
@@ -457,12 +470,11 @@ class Dialogue:
457
470
  # TODO this whole function should be standalone
458
471
  # It will contain the calculation of up/down + calculation (new final mcp scores)
459
472
  # Ensure that it'll still fit/work with the hierarchical multilevel_modeling
460
-
461
473
  """Determine the up and down genes per MCP."""
462
474
  # TODO: something is slightly slow here
463
475
  cca_sig_results: dict[Any, dict[str, Any]] = {}
464
476
  new_mcp_scores: dict[Any, list[Any]] = {}
465
- for ct in ct_subs.keys():
477
+ for ct in ct_subs:
466
478
  ct_adata = ct_subs[ct]
467
479
  conf_m = ct_adata.obs[n_counts_key].values
468
480
 
@@ -483,9 +495,7 @@ class Dialogue:
483
495
  from scipy.stats import spearmanr
484
496
 
485
497
  def _pcor_mat(v1, v2, v3, method="spearman"):
486
- """
487
- MAJOR TODO: I've only used normal correlation instead of partial correlation as we wait on the implementation
488
- """
498
+ """MAJOR TODO: I've only used normal correlation instead of partial correlation as we wait on the implementation."""
489
499
  correlations = [] # R
490
500
  pvals = [] # P
491
501
  for x2 in v2:
@@ -506,7 +516,7 @@ class Dialogue:
506
516
  return np.array(correlations), np.array(pvals) # pvals_adjusted
507
517
 
508
518
  C1, P1 = _pcor_mat(ct_adata[:, top_cor_genes_flattened].X.toarray().T, mcp_scores[ct].T, conf_m)
509
- C1[P1 > (0.05 / ct_adata.shape[1])] = 0 # why?
519
+ C1[(0.05 / ct_adata.shape[1]) < P1] = 0 # why?
510
520
 
511
521
  cca_sig_unformatted = self._get_top_elements( # 3 up, 3 dn, for each mcp
512
522
  pd.DataFrame(C1.T, index=top_cor_genes_flattened), max_length=max_genes, min_threshold=0.05
@@ -514,7 +524,7 @@ class Dialogue:
514
524
 
515
525
  # TODO: probably format the up and down within get_top_elements
516
526
  cca_sig: dict[str, Any] = defaultdict(dict)
517
- for i in range(0, int(len(cca_sig_unformatted) / 2)):
527
+ for i in range(int(len(cca_sig_unformatted) / 2)):
518
528
  cca_sig[f"MCP{i}"]["up"] = cca_sig_unformatted[i * 2]
519
529
  cca_sig[f"MCP{i}"]["down"] = cca_sig_unformatted[i * 2 + 1]
520
530
 
@@ -523,7 +533,7 @@ class Dialogue:
523
533
 
524
534
  # This is basically DIALOGUE 3 now
525
535
  pre_r_scores = {
526
- ct: ct_subs[ct].obsm["X_pca"][:, :50] @ ws_dict[ct]
536
+ ct: ct_subs[ct].obsm[self.feature_space_key][:, : self.n_components] @ ws_dict[ct]
527
537
  for i, ct in enumerate(ct_subs.keys())
528
538
  # TODO This is a recalculation and not a new calculation
529
539
  }
@@ -591,8 +601,8 @@ class Dialogue:
591
601
  def calculate_multifactor_PMD(
592
602
  self,
593
603
  adata: AnnData,
594
- penalties: list[int] = None,
595
- ct_order: list[str] = None,
604
+ penalties: list[int] | None = None,
605
+ ct_order: list[str] | None = None,
596
606
  agg_feature: bool = True,
597
607
  solver: Literal["lp", "bs"] = "bs",
598
608
  normalize: bool = True,
@@ -603,10 +613,9 @@ class Dialogue:
603
613
 
604
614
  Args:
605
615
  adata: AnnData object to calculate PMD for.
606
- sample_id: Key to use for pseudobulk determination.
607
616
  penalties: PMD penalties.
608
617
  ct_order: The order of cell types.
609
- agg_features: Whether to calculate cell-averaged principal components.
618
+ agg_feature: Whether to calculate cell-averaged principal components.
610
619
  solver: Which solver to use for PMD. Must be one of "lp" (linear programming) or "bs" (binary search).
611
620
  For differences between these to please refer to https://github.com/theislab/sparsecca/blob/main/examples/linear_programming_multicca.ipynb
612
621
  normalize: Whether to mimic DIALOGUE as close as possible
@@ -644,8 +653,6 @@ class Dialogue:
644
653
  raise ValueError("Please ensure that every cell type is represented in every sample.") from e
645
654
  else:
646
655
  raise
647
- else:
648
- penalties = penalties
649
656
 
650
657
  if solver == "bs":
651
658
  ws, _ = multicca_pmd(mcca_in, penalties, K=self.n_mcps, standardize=True, niter=100, mimic_R=normalize)
@@ -656,8 +663,8 @@ class Dialogue:
656
663
  ws_dict = {ct: ws[i] for i, ct in enumerate(ct_order)}
657
664
 
658
665
  pre_r_scores = {
659
- ct: ct_subs[ct].obsm["X_pca"][:, :50] @ ws[i]
660
- for i, ct in enumerate(cell_types) # TODO change from 50
666
+ ct: ct_subs[ct].obsm[self.feature_space_key][:, : self.n_components] @ ws[i]
667
+ for i, ct in enumerate(cell_types)
661
668
  }
662
669
 
663
670
  # TODO: output format needs some cleanup, even though each MCP score is matched to one cell, it's not at all
@@ -681,17 +688,17 @@ class Dialogue:
681
688
  ws_dict: dict,
682
689
  confounder: str | None,
683
690
  formula: str = None,
684
- ):
691
+ ) -> pd.DataFrame:
685
692
  """Runs the multilevel modeling step to match genes to MCPs and generate p-values for MCPs.
686
693
 
687
694
  Args:
688
695
  ct_subs: The DIALOGUE cell type objects.
689
696
  mcp_scores: The determined MCP scores from the PMD step.
697
+ ws_dict: WS dictionary.
690
698
  confounder: Any modeling confounders.
691
699
  formula: The hierarchical modeling formula. Defaults to y ~ x + n_counts.
692
700
 
693
701
  Returns:
694
- A Pandas DataFrame containing:
695
702
  - for each mcp: HLM_result_1, HLM_result_2, sig_genes_1, sig_genes_2
696
703
  - merged HLM_result_1, HLM_result_2, sig_genes_1, sig_genes_2 of all mcps
697
704
 
@@ -875,15 +882,15 @@ class Dialogue:
875
882
  if len(conditions_compare) != 2:
876
883
  raise ValueError("Please specify conditions to compare or supply an object with only 2 conditions")
877
884
 
878
- pvals = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(0, n_mcps)])
879
- tstats = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(0, n_mcps)])
880
- pvals_adj = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(0, n_mcps)])
885
+ pvals = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(n_mcps)])
886
+ tstats = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(n_mcps)])
887
+ pvals_adj = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(n_mcps)])
881
888
 
882
889
  response = adata.obs.groupby(sample_label)[condition_label].agg(pd.Series.mode)
883
890
  for celltype in adata.obs[celltype_label].unique():
884
891
  df = adata.obs[adata.obs[celltype_label] == celltype]
885
892
 
886
- for mcpnum in ["mcp_" + str(n) for n in range(0, n_mcps)]:
893
+ for mcpnum in ["mcp_" + str(n) for n in range(n_mcps)]:
887
894
  mns = df.groupby(sample_label)[mcpnum].mean()
888
895
  mns = pd.concat([mns, response], axis=1)
889
896
  res = stats.ttest_ind(
@@ -893,7 +900,7 @@ class Dialogue:
893
900
  pvals.loc[celltype, mcpnum] = res[1]
894
901
  tstats.loc[celltype, mcpnum] = res[0]
895
902
 
896
- for mcpnum in ["mcp_" + str(n) for n in range(0, n_mcps)]:
903
+ for mcpnum in ["mcp_" + str(n) for n in range(n_mcps)]:
897
904
  pvals_adj[mcpnum] = multipletests(pvals[mcpnum], method="fdr_bh")[1]
898
905
 
899
906
  return {"pvals": pvals, "tstats": tstats, "pvals_adj": pvals_adj}
@@ -956,7 +963,7 @@ class Dialogue:
956
963
 
957
964
  genes_dict_up = {} # type: ignore
958
965
  genes_dict_down = {} # type: ignore
959
- for celltype2 in mcp_dict.keys():
966
+ for celltype2 in mcp_dict:
960
967
  for gene in mcp_dict[celltype2][MCP + ".up"]:
961
968
  if gene in genes_dict_up:
962
969
  genes_dict_up[gene] += 1
@@ -1008,7 +1015,7 @@ class Dialogue:
1008
1015
  >>> genes_results = _get_extrema_MCP_genes_single(ct_subs, mcp="mcp_4", fraction=0.2)
1009
1016
  """
1010
1017
  genes = {}
1011
- for ct in ct_subs.keys():
1018
+ for ct in ct_subs:
1012
1019
  mini = ct_subs[ct]
1013
1020
  mini.obs["extrema"] = pd.qcut(
1014
1021
  mini.obs[mcp],
@@ -1056,13 +1063,13 @@ class Dialogue:
1056
1063
  for mcp in mcps:
1057
1064
  rank_dfs[mcp] = {}
1058
1065
  ct_ranked = self._get_extrema_MCP_genes_single(ct_subs, mcp=mcp, fraction=fraction)
1059
- for celltype in ct_ranked.keys():
1066
+ for celltype in ct_ranked:
1060
1067
  rank_dfs[mcp][celltype] = sc.get.rank_genes_groups_df(ct_ranked[celltype], group=None)
1061
1068
 
1062
1069
  return rank_dfs
1063
1070
 
1064
1071
  @_doc_params(common_plot_args=doc_common_plot_args)
1065
- def plot_split_violins(
1072
+ def plot_split_violins( # pragma: no cover # noqa: D417
1066
1073
  self,
1067
1074
  adata: AnnData,
1068
1075
  split_key: str,
@@ -1115,7 +1122,7 @@ class Dialogue:
1115
1122
  return None
1116
1123
 
1117
1124
  @_doc_params(common_plot_args=doc_common_plot_args)
1118
- def plot_pairplot(
1125
+ def plot_pairplot( # pragma: no cover # noqa: D417
1119
1126
  self,
1120
1127
  adata: AnnData,
1121
1128
  celltype_key: str,
@@ -1,3 +1,4 @@
1
+ import contextlib
1
2
  import math
2
3
  from abc import ABC, abstractmethod
3
4
  from collections.abc import Iterable, Mapping, Sequence
@@ -23,8 +24,7 @@ from pertpy.tools._differential_gene_expression._checks import check_is_numeric_
23
24
 
24
25
  class MethodBase(ABC):
25
26
  def __init__(self, adata, *, mask=None, layer=None, **kwargs):
26
- """
27
- Initialize the method.
27
+ """Initialize the method.
28
28
 
29
29
  Args:
30
30
  adata: AnnData object, usually pseudobulked.
@@ -62,8 +62,7 @@ class MethodBase(ABC):
62
62
  fit_kwargs=MappingProxyType({}),
63
63
  test_kwargs=MappingProxyType({}),
64
64
  ):
65
- """
66
- Compare between groups in a specified column.
65
+ """Compare between groups in a specified column.
67
66
 
68
67
  Args:
69
68
  adata: AnnData object.
@@ -100,7 +99,7 @@ class MethodBase(ABC):
100
99
  ...
101
100
 
102
101
  @_doc_params(common_plot_args=doc_common_plot_args)
103
- def plot_volcano(
102
+ def plot_volcano( # pragma: no cover # noqa: D417
104
103
  self,
105
104
  data: pd.DataFrame | ad.AnnData,
106
105
  *,
@@ -188,8 +187,7 @@ class MethodBase(ABC):
188
187
  colors = ["gray", "#D62728", "#1F77B4"]
189
188
 
190
189
  def _pval_reciprocal(lfc: float) -> float:
191
- """
192
- Function for relating -log10(pvalue) and logfoldchange in a reciprocal.
190
+ """Function for relating -log10(pvalue) and logfoldchange in a reciprocal.
193
191
 
194
192
  Used for plotting the S-curve
195
193
  """
@@ -197,7 +195,7 @@ class MethodBase(ABC):
197
195
 
198
196
  def _map_shape(symbol: str) -> str:
199
197
  if shape_dict is not None:
200
- for k in shape_dict.keys():
198
+ for k in shape_dict:
201
199
  if shape_dict[k] is not None and symbol in shape_dict[k]:
202
200
  return k
203
201
  return "other"
@@ -211,8 +209,7 @@ class MethodBase(ABC):
211
209
  pval_thresh: float = None,
212
210
  s_curve: bool = False,
213
211
  ) -> str:
214
- """
215
- Map genes to categorize based on log2fc and pvalue.
212
+ """Map genes to categorize based on log2fc and pvalue.
216
213
 
217
214
  These categories are used for coloring the dots.
218
215
  Used when no color_dict is passed, sets up/down/nonsignificant.
@@ -229,14 +226,13 @@ class MethodBase(ABC):
229
226
  return "Down"
230
227
  else:
231
228
  return "not DE"
229
+ # Standard condition for Up or Down categorization
230
+ elif log2fc > log2fc_thresh and nlog10 > pval_thresh:
231
+ return "Up"
232
+ elif log2fc < -log2fc_thresh and nlog10 > pval_thresh:
233
+ return "Down"
232
234
  else:
233
- # Standard condition for Up or Down categorization
234
- if log2fc > log2fc_thresh and nlog10 > pval_thresh:
235
- return "Up"
236
- elif log2fc < -log2fc_thresh and nlog10 > pval_thresh:
237
- return "Down"
238
- else:
239
- return "not DE"
235
+ return "not DE"
240
236
 
241
237
  def _map_genes_categories_highlight(
242
238
  row: pd.Series,
@@ -247,8 +243,7 @@ class MethodBase(ABC):
247
243
  s_curve: bool = False,
248
244
  symbol_col: str = None,
249
245
  ) -> str:
250
- """
251
- Map genes to categorize based on log2fc and pvalue.
246
+ """Map genes to categorize based on log2fc and pvalue.
252
247
 
253
248
  These categories are used for coloring the dots.
254
249
  Used when color_dict is passed, sets DE / not DE for background and user supplied highlight genes.
@@ -258,7 +253,7 @@ class MethodBase(ABC):
258
253
  symbol = row[symbol_col]
259
254
 
260
255
  if color_dict is not None:
261
- for k in color_dict.keys():
256
+ for k in color_dict:
262
257
  if symbol in color_dict[k]:
263
258
  return k
264
259
 
@@ -489,7 +484,7 @@ class MethodBase(ABC):
489
484
  return None
490
485
 
491
486
  @_doc_params(common_plot_args=doc_common_plot_args)
492
- def plot_paired(
487
+ def plot_paired( # pragma: no cover # noqa: D417
493
488
  self,
494
489
  adata: ad.AnnData,
495
490
  results_df: pd.DataFrame,
@@ -581,14 +576,9 @@ class MethodBase(ABC):
581
576
  adata, target_col=groupby, groups_col=pairedby, layer_key=layer, mode="sum", min_cells=1, min_counts=1
582
577
  )
583
578
 
584
- if layer is not None:
585
- X = adata.layers[layer]
586
- else:
587
- X = adata.X
588
- try:
579
+ X = adata.layers[layer] if layer is not None else adata.X
580
+ with contextlib.suppress(AttributeError):
589
581
  X = X.toarray()
590
- except AttributeError:
591
- pass
592
582
 
593
583
  groupby_cols = [pairedby, groupby]
594
584
  df = adata.obs.loc[:, groupby_cols].join(pd.DataFrame(X, index=adata.obs_names, columns=var_names))
@@ -682,7 +672,7 @@ class MethodBase(ABC):
682
672
  return None
683
673
 
684
674
  @_doc_params(common_plot_args=doc_common_plot_args)
685
- def plot_fold_change(
675
+ def plot_fold_change( # pragma: no cover # noqa: D417
686
676
  self,
687
677
  results_df: pd.DataFrame,
688
678
  *,
@@ -763,7 +753,7 @@ class MethodBase(ABC):
763
753
  return None
764
754
 
765
755
  @_doc_params(common_plot_args=doc_common_plot_args)
766
- def plot_multicomparison_fc(
756
+ def plot_multicomparison_fc( # pragma: no cover # noqa: D417
767
757
  self,
768
758
  results_df: pd.DataFrame,
769
759
  *,
@@ -1013,7 +1003,7 @@ class LinearModelBase(MethodBase):
1013
1003
  )
1014
1004
  return self.formulaic_contrasts.cond(**kwargs)
1015
1005
 
1016
- def contrast(self, *args, **kwargs):
1006
+ def contrast(self, *args, **kwargs): # noqa: D417
1017
1007
  """Build a simple contrast for pairwise comparisons.
1018
1008
 
1019
1009
  Args:
@@ -16,9 +16,8 @@ def check_is_numeric_matrix(array: np.ndarray | spmatrix) -> None:
16
16
  if issparse(array):
17
17
  if np.any(~np.isfinite(array.data)):
18
18
  raise ValueError("Counts cannot contain negative, NaN or Inf values.")
19
- else:
20
- if np.any(~np.isfinite(array)):
21
- raise ValueError("Counts cannot contain negative, NaN or Inf values.")
19
+ elif np.any(~np.isfinite(array)):
20
+ raise ValueError("Counts cannot contain negative, NaN or Inf values.")
22
21
 
23
22
 
24
23
  def check_is_integer_matrix(array: np.ndarray | spmatrix, tolerance: float = 1e-6) -> None:
@@ -34,8 +33,7 @@ def check_is_integer_matrix(array: np.ndarray | spmatrix, tolerance: float = 1e-
34
33
  if issparse(array):
35
34
  if not array.data.dtype.kind == "i" and not np.all(np.abs(array.data - np.round(array.data)) < tolerance):
36
35
  raise ValueError("Non-zero elements of the matrix must be close to integer values.")
37
- else:
38
- if not array.dtype.kind == "i" and not np.all(np.abs(array - np.round(array)) < tolerance):
39
- raise ValueError("Matrix must be a count matrix.")
36
+ elif array.dtype.kind != "i" and not np.all(np.abs(array - np.round(array)) < tolerance):
37
+ raise ValueError("Matrix must be a count matrix.")
40
38
  if (array < 0).sum() > 0:
41
39
  raise ValueError("Non-zero elements of the matrix must be positive.")
@@ -36,16 +36,15 @@ class DGEEVAL:
36
36
  if not de_key1 or not de_key2:
37
37
  raise ValueError("Both `de_key1` and `de_key2` must be provided together if using `adata`.")
38
38
 
39
- else: # use dfs
40
- if de_df1 is None or de_df2 is None:
41
- raise ValueError("Both `de_df1` and `de_df2` must be provided together if using DataFrames.")
39
+ elif de_df1 is None or de_df2 is None:
40
+ raise ValueError("Both `de_df1` and `de_df2` must be provided together if using DataFrames.")
42
41
 
43
42
  if de_key1:
44
43
  if not adata:
45
44
  raise ValueError("`adata` should be provided with `de_key1` and `de_key2`. ")
46
- assert all(
47
- k in adata.uns for k in [de_key1, de_key2]
48
- ), "Provided `de_key1` and `de_key2` must exist in `adata.uns`."
45
+ assert all(k in adata.uns for k in [de_key1, de_key2]), (
46
+ "Provided `de_key1` and `de_key2` must exist in `adata.uns`."
47
+ )
49
48
  vars = adata.var_names
50
49
 
51
50
  if de_df1 is not None:
@@ -10,7 +10,7 @@ from ._checks import check_is_integer_matrix
10
10
 
11
11
 
12
12
  class EdgeR(LinearModelBase):
13
- """Differential expression test using EdgeR"""
13
+ """Differential expression test using EdgeR."""
14
14
 
15
15
  def _check_counts(self):
16
16
  check_is_integer_matrix(self.data)
@@ -39,17 +39,13 @@ class EdgeR(LinearModelBase):
39
39
  edger = importr("edgeR")
40
40
  except ImportError as e:
41
41
  raise ImportError(
42
- "edgeR requires a valid R installation with the following packages:\n"
43
- "edgeR, BiocParallel, RhpcBLASctl"
42
+ "edgeR requires a valid R installation with the following packages:\nedgeR, BiocParallel, RhpcBLASctl"
44
43
  ) from e
45
44
 
46
45
  # Convert dataframe
47
46
  with localconverter(get_conversion() + numpy2ri.converter):
48
47
  expr = self.adata.X if self.layer is None else self.adata.layers[self.layer]
49
- if issparse(expr):
50
- expr = expr.T.toarray()
51
- else:
52
- expr = expr.T
48
+ expr = expr.T.toarray() if issparse(expr) else expr.T
53
49
 
54
50
  with localconverter(get_conversion() + pandas2ri.converter):
55
51
  expr_r = ro.conversion.py2rpy(pd.DataFrame(expr, index=self.adata.var_names, columns=self.adata.obs_names))
@@ -72,8 +68,8 @@ class EdgeR(LinearModelBase):
72
68
  ro.globalenv["fit"] = fit
73
69
  self.fit = fit
74
70
 
75
- def _test_single_contrast(self, contrast: Sequence[float], **kwargs) -> pd.DataFrame:
76
- """Conduct test for each contrast and return a data frame
71
+ def _test_single_contrast(self, contrast: Sequence[float], **kwargs) -> pd.DataFrame: # noqa: D417
72
+ """Conduct test for each contrast and return a data frame.
77
73
 
78
74
  Args:
79
75
  contrast: numpy array of integars indicating contrast i.e. [-1, 0, 1, 0, 0]
@@ -100,7 +96,7 @@ class EdgeR(LinearModelBase):
100
96
  importr("edgeR")
101
97
  except ImportError:
102
98
  raise ImportError(
103
- "edgeR requires a valid R installation with the following packages: " "edgeR, BiocParallel, RhpcBLASctl"
99
+ "edgeR requires a valid R installation with the following packages: edgeR, BiocParallel, RhpcBLASctl"
104
100
  ) from None
105
101
 
106
102
  # Convert vector to R, which drops a category like `self.design_matrix` to use the intercept for the left out.
@@ -16,7 +16,7 @@ from ._checks import check_is_integer_matrix
16
16
 
17
17
 
18
18
  class PyDESeq2(LinearModelBase):
19
- """Differential expression test using a PyDESeq2"""
19
+ """Differential expression test using a PyDESeq2."""
20
20
 
21
21
  def __init__(
22
22
  self, adata: AnnData, design: str | ndarray, *, mask: str | None = None, layer: str | None = None, **kwargs
@@ -1,4 +1,4 @@
1
- """Simple tests such as t-test, wilcoxon"""
1
+ """Simple tests such as t-test, wilcoxon."""
2
2
 
3
3
  import warnings
4
4
  from abc import abstractmethod
@@ -10,7 +10,7 @@ import pandas as pd
10
10
  import scipy.stats
11
11
  import statsmodels
12
12
  from anndata import AnnData
13
- from pandas.core.api import DataFrame as DataFrame
13
+ from pandas.core.api import DataFrame
14
14
  from scipy.sparse import diags, issparse
15
15
  from tqdm.auto import tqdm
16
16
 
@@ -152,7 +152,7 @@ class WilcoxonTest(SimpleComparisonBase):
152
152
 
153
153
 
154
154
  class TTest(SimpleComparisonBase):
155
- """Perform a unpaired or paired T-test"""
155
+ """Perform a unpaired or paired T-test."""
156
156
 
157
157
  @staticmethod
158
158
  def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> float:
@@ -6,14 +6,14 @@ import statsmodels.api as sm
6
6
  from tqdm.auto import tqdm
7
7
 
8
8
  from ._base import LinearModelBase
9
- from ._checks import check_is_integer_matrix
9
+ from ._checks import check_is_numeric_matrix
10
10
 
11
11
 
12
12
  class Statsmodels(LinearModelBase):
13
- """Differential expression test using a statsmodels linear regression"""
13
+ """Differential expression test using a statsmodels linear regression."""
14
14
 
15
15
  def _check_counts(self):
16
- check_is_integer_matrix(self.data)
16
+ check_is_numeric_matrix(self.data)
17
17
 
18
18
  def fit(
19
19
  self,
@@ -55,7 +55,10 @@ class Statsmodels(LinearModelBase):
55
55
  "t_value": t_test.tvalue.item(),
56
56
  "sd": t_test.sd.item(),
57
57
  "log_fc": t_test.effect.item(),
58
- "adj_p_value": statsmodels.stats.multitest.fdrcorrection(np.array([t_test.pvalue]))[1].item(),
59
58
  }
60
59
  )
61
- return pd.DataFrame(res).sort_values("p_value")
60
+ return (
61
+ pd.DataFrame(res)
62
+ .sort_values("p_value")
63
+ .assign(adj_p_value=lambda x: statsmodels.stats.multitest.fdrcorrection(x["p_value"])[1])
64
+ )
@@ -83,8 +83,7 @@ class DistanceTest:
83
83
  contrast: str,
84
84
  show_progressbar: bool = True,
85
85
  ) -> pd.DataFrame:
86
- """Run a permutation test using the specified distance metric, testing
87
- all groups of cells against a specified contrast group ("control").
86
+ """Run a permutation test using the specified distance metric, testing all groups of cells against a specified contrast group ("control").
88
87
 
89
88
  Args:
90
89
  adata: Annotated data matrix.