pertpy 0.9.5__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. pertpy/__init__.py +5 -1
  2. pertpy/_doc.py +2 -5
  3. pertpy/_types.py +6 -0
  4. pertpy/data/_dataloader.py +68 -24
  5. pertpy/data/_datasets.py +9 -9
  6. pertpy/metadata/__init__.py +2 -1
  7. pertpy/metadata/_cell_line.py +136 -30
  8. pertpy/metadata/_look_up.py +13 -19
  9. pertpy/metadata/_moa.py +1 -1
  10. pertpy/preprocessing/_guide_rna.py +221 -39
  11. pertpy/preprocessing/_guide_rna_mixture.py +177 -0
  12. pertpy/tools/__init__.py +1 -1
  13. pertpy/tools/_augur.py +138 -142
  14. pertpy/tools/_cinemaot.py +75 -117
  15. pertpy/tools/_coda/_base_coda.py +150 -174
  16. pertpy/tools/_coda/_sccoda.py +66 -69
  17. pertpy/tools/_coda/_tasccoda.py +71 -79
  18. pertpy/tools/_dialogue.py +60 -56
  19. pertpy/tools/_differential_gene_expression/_base.py +25 -43
  20. pertpy/tools/_differential_gene_expression/_checks.py +4 -6
  21. pertpy/tools/_differential_gene_expression/_dge_comparison.py +5 -6
  22. pertpy/tools/_differential_gene_expression/_edger.py +6 -10
  23. pertpy/tools/_differential_gene_expression/_pydeseq2.py +1 -1
  24. pertpy/tools/_differential_gene_expression/_simple_tests.py +3 -3
  25. pertpy/tools/_differential_gene_expression/_statsmodels.py +8 -5
  26. pertpy/tools/_distances/_distance_tests.py +1 -2
  27. pertpy/tools/_distances/_distances.py +86 -92
  28. pertpy/tools/_enrichment.py +8 -25
  29. pertpy/tools/_milo.py +23 -27
  30. pertpy/tools/_mixscape.py +261 -175
  31. pertpy/tools/_perturbation_space/_clustering.py +4 -4
  32. pertpy/tools/_perturbation_space/_comparison.py +4 -4
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +83 -32
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +10 -10
  35. pertpy/tools/_perturbation_space/_simple.py +13 -17
  36. pertpy/tools/_scgen/_scgen.py +17 -20
  37. pertpy/tools/_scgen/_scgenvae.py +2 -2
  38. pertpy/tools/_scgen/_utils.py +3 -1
  39. {pertpy-0.9.5.dist-info → pertpy-0.11.0.dist-info}/METADATA +37 -21
  40. pertpy-0.11.0.dist-info/RECORD +58 -0
  41. {pertpy-0.9.5.dist-info → pertpy-0.11.0.dist-info}/licenses/LICENSE +1 -0
  42. pertpy/tools/_kernel_pca.py +0 -50
  43. pertpy-0.9.5.dist-info/RECORD +0 -57
  44. {pertpy-0.9.5.dist-info → pertpy-0.11.0.dist-info}/WHEEL +0 -0
pertpy/tools/_augur.py CHANGED
@@ -11,7 +11,6 @@ import matplotlib.pyplot as plt
11
11
  import numpy as np
12
12
  import pandas as pd
13
13
  import scanpy as sc
14
- import statsmodels.api as sm
15
14
  from anndata import AnnData
16
15
  from joblib import Parallel, delayed
17
16
  from lamin_utils import logger
@@ -34,6 +33,7 @@ from sklearn.metrics import (
34
33
  from sklearn.model_selection import StratifiedKFold, cross_validate
35
34
  from sklearn.preprocessing import LabelEncoder
36
35
  from skmisc.loess import loess
36
+ from statsmodels.api import OLS
37
37
  from statsmodels.stats.multitest import fdrcorrection
38
38
 
39
39
  from pertpy._doc import _doc_params, doc_common_plot_args
@@ -43,14 +43,24 @@ if TYPE_CHECKING:
43
43
  from matplotlib.figure import Figure
44
44
 
45
45
 
46
- @dataclass
47
- class Params:
48
- """Type signature for random forest and logistic regression parameters.
46
+ class Augur:
47
+ """Python implementation of Augur."""
49
48
 
50
- Parameters:
51
- n_estimators: defines the number of trees in the forest;
52
- max_depth: specifies the maximal depth of each tree;
53
- max_features: specifies the maximal number of features considered when looking at best split.
49
+ def __init__(
50
+ self,
51
+ estimator: Literal["random_forest_classifier", "random_forest_regressor", "logistic_regression_classifier"],
52
+ n_estimators: int = 100,
53
+ max_depth: int | None = None,
54
+ max_features: Literal["auto", "log2", "sqrt"] | int | float = 2,
55
+ penalty: Literal["l1", "l2", "elasticnet", "none"] = "l2",
56
+ random_state: int | None = None,
57
+ ):
58
+ """Defines the Augur estimator model and parameters.
59
+
60
+ estimator: The scikit-learn estimator model that classifies.
61
+ n_estimators: Number of trees in the forest.
62
+ max_depth: Maximal depth of each tree.
63
+ max_features: Maximal number of features considered when looking at best split.
54
64
 
55
65
  * if int then consider max_features for each split
56
66
  * if float consider round(max_features*n_features)
@@ -58,38 +68,26 @@ class Params:
58
68
  * if `log2` then max_features=log2(n_features)
59
69
  * if `sqrt` then max_features=sqrt(n_features)
60
70
 
61
- penalty: defines the norm of the penalty used in logistic regression
71
+ penalty: Norm of the penalty used in logistic regression
62
72
 
63
73
  * if `l1` then L1 penalty is added
64
74
  * if `l2` then L2 penalty is added (default)
65
75
  * if `elasticnet` both L1 and L2 penalties are added
66
76
  * if `none` no penalty is added
67
-
68
- random_state: sets random model seed
69
- """
70
-
71
- n_estimators: int = 100
72
- max_depth: int | None = None
73
- max_features: Literal["auto"] | Literal["log2"] | Literal["sqrt"] | int | float = 2
74
- penalty: Literal["l1"] | Literal["l2"] | Literal["elasticnet"] | Literal["none"] = "l2"
75
- random_state: int | None = None
76
-
77
-
78
- class Augur:
79
- """Python implementation of Augur."""
80
-
81
- def __init__(
82
- self,
83
- estimator: Literal["random_forest_classifier"]
84
- | Literal["random_forest_regressor"]
85
- | Literal["logistic_regression_classifier"],
86
- params: Params | None = None,
87
- ):
88
- self.estimator = self.create_estimator(classifier=estimator, params=params)
77
+ """
78
+ self.estimator = self.create_estimator(
79
+ classifier=estimator,
80
+ n_estimators=n_estimators,
81
+ max_depth=max_depth,
82
+ max_features=max_features,
83
+ penalty=penalty,
84
+ random_state=random_state,
85
+ )
89
86
 
90
87
  def load(
91
88
  self,
92
89
  input: AnnData | pd.DataFrame,
90
+ *,
93
91
  meta: pd.DataFrame | None = None,
94
92
  label_col: str = "label_col",
95
93
  cell_type_col: str = "cell_type_col",
@@ -99,8 +97,8 @@ class Augur:
99
97
  """Loads the input data.
100
98
 
101
99
  Args:
102
- input: Anndata or matrix containing gene expression values (genes in rows, cells in columns) and optionally meta
103
- data about each cell.
100
+ input: Anndata or matrix containing gene expression values (genes in rows, cells in columns)
101
+ and optionally meta data about each cell.
104
102
  meta: Optional Pandas DataFrame containing meta data about each cell.
105
103
  label_col: column of the meta DataFrame or the Anndata or matrix containing the condition labels for each cell
106
104
  in the cell-by-gene expression matrix
@@ -110,8 +108,8 @@ class Augur:
110
108
  treatment_label: in the case of more than two labels, this label is used in the analysis
111
109
 
112
110
  Returns:
113
- Anndata object containing gene expression values (cells in rows, genes in columns) and cell type, label and y
114
- dummy variables as obs
111
+ Anndata object containing gene expression values (cells in rows, genes in columns)
112
+ and cell type, label and y dummy variables as obs
115
113
 
116
114
  Examples:
117
115
  >>> import pertpy as pt
@@ -157,12 +155,13 @@ class Augur:
157
155
 
158
156
  def create_estimator(
159
157
  self,
160
- classifier: (
161
- Literal["random_forest_classifier"]
162
- | Literal["random_forest_regressor"]
163
- | Literal["logistic_regression_classifier"]
164
- ),
165
- params: Params | None = None,
158
+ classifier: (Literal["random_forest_classifier", "random_forest_regressor", "logistic_regression_classifier"]),
159
+ *,
160
+ n_estimators: int = 100,
161
+ max_depth: int | None = None,
162
+ max_features: Literal["auto", "log2", "sqrt"] | int | float = 2,
163
+ penalty: Literal["l1", "l2", "elasticnet", "none"] = "l2",
164
+ random_state: int | None = None,
166
165
  ) -> RandomForestClassifier | RandomForestRegressor | LogisticRegression:
167
166
  """Creates a model object of the provided type and populates it with desired parameters.
168
167
 
@@ -170,35 +169,46 @@ class Augur:
170
169
  classifier: classifier to use in calculating the area under the curve.
171
170
  Either random forest classifier or logistic regression for categorical data
172
171
  or random forest regressor for continuous data
173
- params: parameters used to populate the model object. Default values are `n_estimators` =
174
- 100, `max_depth` = None, `max_features` = 2, `penalty` = `l2`, `random_state` = None.
172
+ n_estimators: Number of trees in the forest.
173
+ max_depth: Maximal depth of each tree.
174
+ max_features: Maximal number of features considered when looking at best split.
175
175
 
176
- Returns:
177
- Estimator object.
176
+ * if int then consider max_features for each split
177
+ * if float consider round(max_features*n_features)
178
+ * if `auto` then max_features=n_features (default)
179
+ * if `log2` then max_features=log2(n_features)
180
+ * if `sqrt` then max_features=sqrt(n_features)
181
+
182
+ penalty: Norm of the penalty used in logistic regression
183
+
184
+ * if `l1` then L1 penalty is added
185
+ * if `l2` then L2 penalty is added (default)
186
+ * if `elasticnet` both L1 and L2 penalties are added
187
+ * if `none` no penalty is added
188
+
189
+ random_state: Random model seed.
178
190
 
179
191
  Examples:
180
192
  >>> import pertpy as pt
181
193
  >>> augur = pt.tl.Augur("random_forest_classifier")
182
194
  >>> estimator = augur.create_estimator("logistic_regression_classifier")
183
195
  """
184
- if params is None:
185
- params = Params()
186
196
  if classifier == "random_forest_classifier":
187
197
  return RandomForestClassifier(
188
- n_estimators=params.n_estimators,
189
- max_depth=params.max_depth,
190
- max_features=params.max_features,
191
- random_state=params.random_state,
198
+ n_estimators=n_estimators,
199
+ max_depth=max_depth,
200
+ max_features=max_features,
201
+ random_state=random_state,
192
202
  )
193
203
  elif classifier == "random_forest_regressor":
194
204
  return RandomForestRegressor(
195
- n_estimators=params.n_estimators,
196
- max_depth=params.max_depth,
197
- max_features=params.max_features,
198
- random_state=params.random_state,
205
+ n_estimators=n_estimators,
206
+ max_depth=max_depth,
207
+ max_features=max_features,
208
+ random_state=random_state,
199
209
  )
200
210
  elif classifier == "logistic_regression_classifier":
201
- return LogisticRegression(penalty=params.penalty, random_state=params.random_state)
211
+ return LogisticRegression(penalty=penalty, random_state=random_state)
202
212
  else:
203
213
  raise ValueError("Invalid classifier")
204
214
 
@@ -231,15 +241,16 @@ class Augur:
231
241
  if categorical:
232
242
  label_subsamples = []
233
243
  y_encodings = adata.obs["y_"].unique()
234
- for code in y_encodings:
235
- label_subsamples.append(
236
- sc.pp.subsample(
237
- adata[adata.obs["y_"] == code, features],
238
- n_obs=subsample_size,
239
- copy=True,
240
- random_state=random_state,
241
- )
244
+ label_subsamples = [
245
+ sc.pp.subsample(
246
+ adata[adata.obs["y_"] == code, features],
247
+ n_obs=subsample_size,
248
+ copy=True,
249
+ random_state=random_state,
242
250
  )
251
+ for code in y_encodings
252
+ ]
253
+
243
254
  subsample = ad.concat([*label_subsamples], index_unique=None)
244
255
  else:
245
256
  subsample = sc.pp.subsample(adata[:, features], n_obs=subsample_size, copy=True, random_state=random_state)
@@ -259,6 +270,7 @@ class Augur:
259
270
  def draw_subsample(
260
271
  self,
261
272
  adata: AnnData,
273
+ *,
262
274
  augur_mode: str,
263
275
  subsample_size: int,
264
276
  feature_perc: float,
@@ -319,6 +331,7 @@ class Augur:
319
331
  def cross_validate_subsample(
320
332
  self,
321
333
  adata: AnnData,
334
+ *,
322
335
  augur_mode: str,
323
336
  subsample_size: int,
324
337
  folds: int,
@@ -358,8 +371,8 @@ class Augur:
358
371
  """
359
372
  subsample = self.draw_subsample(
360
373
  adata,
361
- augur_mode,
362
- subsample_size,
374
+ augur_mode=augur_mode,
375
+ subsample_size=subsample_size,
363
376
  feature_perc=feature_perc,
364
377
  categorical=is_classifier(self.estimator),
365
378
  random_state=subsample_idx,
@@ -373,7 +386,7 @@ class Augur:
373
386
  )
374
387
  return results
375
388
 
376
- def ccc_score(self, y_true, y_pred) -> float:
389
+ def ccc_score(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
377
390
  """Implementation of Lin's Concordance correlation coefficient, based on https://gitlab.com/-/snippets/1730605.
378
391
 
379
392
  Args:
@@ -405,7 +418,7 @@ class Augur:
405
418
  """Set scoring fuctions for cross-validation based on estimator.
406
419
 
407
420
  Args:
408
- multiclass: `True` if there are more than two target classes
421
+ multiclass: Whether there are more than two target classes
409
422
  zero_division: 0 or 1 or `warn`; Sets the value to return when there is a zero division. If
410
423
  set to “warn”, this acts as 0, but warnings are also raised. Precision metric parameter.
411
424
 
@@ -435,7 +448,7 @@ class Augur:
435
448
  "f1": make_scorer(f1_score, average="binary"),
436
449
  "recall": make_scorer(recall_score, average="binary"),
437
450
  }
438
- if isinstance(self.estimator, RandomForestClassifier) or isinstance(self.estimator, LogisticRegression)
451
+ if isinstance(self.estimator, RandomForestClassifier | LogisticRegression)
439
452
  else {
440
453
  "augur_score": make_scorer(self.ccc_score),
441
454
  "r2": make_scorer(r2_score),
@@ -448,6 +461,7 @@ class Augur:
448
461
  def run_cross_validation(
449
462
  self,
450
463
  subsample: AnnData,
464
+ *,
451
465
  subsample_idx: int,
452
466
  folds: int,
453
467
  random_state: int | None,
@@ -479,7 +493,7 @@ class Augur:
479
493
  """
480
494
  x = subsample.to_df()
481
495
  y = subsample.obs["y_"]
482
- scorer = self.set_scorer(multiclass=True if len(y.unique()) > 2 else False, zero_division=zero_division)
496
+ scorer = self.set_scorer(multiclass=len(y.unique()) > 2, zero_division=zero_division)
483
497
  folds = StratifiedKFold(n_splits=folds, random_state=random_state, shuffle=True)
484
498
 
485
499
  results = cross_validate(
@@ -492,12 +506,12 @@ class Augur:
492
506
  )
493
507
 
494
508
  results["subsample_idx"] = subsample_idx
495
- for score in scorer.keys():
509
+ for score in scorer:
496
510
  results[f"mean_{score}"] = results[f"test_{score}"].mean()
497
511
 
498
512
  # feature importances
499
513
  feature_importances = defaultdict(list)
500
- if isinstance(self.estimator, RandomForestClassifier) or isinstance(self.estimator, RandomForestRegressor):
514
+ if isinstance(self.estimator, RandomForestClassifier | RandomForestRegressor):
501
515
  for fold, estimator in list(zip(range(len(results["estimator"])), results["estimator"], strict=False)):
502
516
  feature_importances["genes"].extend(x.columns.tolist())
503
517
  feature_importances["feature_importances"].extend(estimator.feature_importances_.tolist())
@@ -589,9 +603,9 @@ class Augur:
589
603
  sigma2_x = np.sum(np.power(loess1.outputs.fitted_residuals, 2)) / nobs
590
604
  sigma2_z = np.sum(np.power(loess2.outputs.fitted_residuals, 2)) / nobs
591
605
  yhat_x = loess1.outputs.fitted_values
592
- res_dx = sm.OLS(yhat_x, z).fit()
606
+ res_dx = OLS(yhat_x, z).fit()
593
607
  err_zx = res_dx.resid
594
- res_xzx = sm.OLS(err_zx, x).fit()
608
+ res_xzx = OLS(err_zx, x).fit()
595
609
  err_xzx = res_xzx.resid
596
610
 
597
611
  sigma2_zx = sigma2_x + np.dot(err_zx.T, err_zx) / nobs
@@ -602,7 +616,9 @@ class Augur:
602
616
 
603
617
  return q, pval
604
618
 
605
- def select_variance(self, adata: AnnData, var_quantile: float, filter_negative_residuals: bool, span: float = 0.75):
619
+ def select_variance(
620
+ self, adata: AnnData, *, var_quantile: float, filter_negative_residuals: bool, span: float = 0.75
621
+ ):
606
622
  """Feature selection based on Augur implementation.
607
623
 
608
624
  Args:
@@ -656,11 +672,8 @@ class Augur:
656
672
  cox1 = self.cox_compare(fit1, fit2)
657
673
  cox2 = self.cox_compare(fit2, fit1)
658
674
 
659
- # compare pvalues
660
- if cox1[1] < cox2[1]:
661
- model = fit1
662
- else:
663
- model = fit2
675
+ # compare p values
676
+ model = fit1 if cox1[1] < cox2[1] else fit2
664
677
 
665
678
  residuals = model.outputs.fitted_residuals
666
679
 
@@ -676,6 +689,7 @@ class Augur:
676
689
  def predict(
677
690
  self,
678
691
  adata: AnnData,
692
+ *,
679
693
  n_subsamples: int = 50,
680
694
  subsample_size: int = 20,
681
695
  folds: int = 3,
@@ -685,7 +699,7 @@ class Augur:
685
699
  span: float = 0.75,
686
700
  filter_negative_residuals: bool = False,
687
701
  n_threads: int = 4,
688
- augur_mode: Literal["permute"] | Literal["default"] | Literal["velocity"] = "default",
702
+ augur_mode: Literal["default", "permute", "velocity"] = "default",
689
703
  select_variance_features: bool = True,
690
704
  key_added: str = "augurpy_results",
691
705
  random_state: int | None = None,
@@ -717,7 +731,7 @@ class Augur:
717
731
  set to “warn”, this acts as 0, but warnings are also raised. Precision metric parameter.
718
732
 
719
733
  Returns:
720
- A tuple with a dictionary containing the following keys with an updated AnnData object with mean_augur_score metrics in obs:
734
+ A tuple with a dictionary containing the following keys with an updated AnnData object with mean_augur_score metrics in obs.
721
735
 
722
736
  * summary_metrics: Pandas Dataframe containing mean metrics for each cell type
723
737
  * feature_importances: Pandas Dataframe containing feature importances of genes across all cross validation runs
@@ -756,7 +770,7 @@ class Augur:
756
770
  adata.obs["augur_score"] = nan
757
771
  for cell_type in track(adata.obs["cell_type"].unique(), description="Processing data..."):
758
772
  cell_type_subsample = adata[adata.obs["cell_type"] == cell_type].copy()
759
- if augur_mode == "default" or augur_mode == "permute":
773
+ if augur_mode in ("default", "permute"):
760
774
  cell_type_subsample = (
761
775
  self.select_highly_variable(cell_type_subsample)
762
776
  if not select_variance_features
@@ -846,10 +860,10 @@ class Augur:
846
860
  between two conditions respectively compared to the control.
847
861
 
848
862
  Args:
849
- augur1: Augurpy results from condition 1, obtained from `predict()[1]`
850
- augur2: Augurpy results from condition 2, obtained from `predict()[1]`
851
- permuted1: permuted Augurpy results from condition 1, obtained from `predict()` with argument `augur_mode=permute`
852
- permuted2: permuted Augurpy results from condition 2, obtained from `predict()` with argument `augur_mode=permute`
863
+ augur_results1: Augurpy results from condition 1, obtained from `predict()[1]`
864
+ augur_results2: Augurpy results from condition 2, obtained from `predict()[1]`
865
+ permuted_results1: permuted Augurpy results from condition 1, obtained from `predict()` with argument `augur_mode=permute`
866
+ permuted_results2: permuted Augurpy results from condition 2, obtained from `predict()` with argument `augur_mode=permute`
853
867
  n_subsamples: number of subsamples to pool when calculating the mean augur score for each permutation.
854
868
  n_permutations: the total number of mean augur scores to calculate from a background distribution
855
869
 
@@ -908,41 +922,39 @@ class Augur:
908
922
  .mean()
909
923
  )
910
924
 
911
- sampled_permuted_cv_augur1 = []
912
- sampled_permuted_cv_augur2 = []
925
+ rng = np.random.default_rng()
926
+ sampled_data = []
913
927
 
914
928
  # draw mean aucs for permute1 and permute2
915
929
  for celltype in permuted_cv_augur1["cell_type"].unique():
916
930
  df1 = permuted_cv_augur1[permuted_cv_augur1["cell_type"] == celltype]
917
931
  df2 = permuted_cv_augur2[permuted_cv_augur2["cell_type"] == celltype]
918
- for permutation_idx in range(n_permutations):
919
- # subsample
920
- sample1 = df1.sample(n=n_subsamples, random_state=permutation_idx, axis="index")
921
- sampled_permuted_cv_augur1.append(
922
- pd.DataFrame(
923
- {
924
- "cell_type": [celltype],
925
- "permutation_idx": [permutation_idx],
926
- "mean": [sample1["augur_score"].mean(axis=0)],
927
- "std": [sample1["augur_score"].std(axis=0)],
928
- }
929
- )
930
- )
931
932
 
932
- sample2 = df2.sample(n=n_subsamples, random_state=permutation_idx, axis="index")
933
- sampled_permuted_cv_augur2.append(
934
- pd.DataFrame(
935
- {
936
- "cell_type": [celltype],
937
- "permutation_idx": [permutation_idx],
938
- "mean": [sample2["augur_score"].mean(axis=0)],
939
- "std": [sample2["augur_score"].std(axis=0)],
940
- }
941
- )
933
+ indices1 = rng.choice(len(df1), size=(n_permutations, n_subsamples), replace=True)
934
+ indices2 = rng.choice(len(df2), size=(n_permutations, n_subsamples), replace=True)
935
+
936
+ scores1 = df1["augur_score"].values[indices1]
937
+ scores2 = df2["augur_score"].values[indices2]
938
+
939
+ means1 = scores1.mean(axis=1)
940
+ means2 = scores2.mean(axis=1)
941
+ stds1 = scores1.std(axis=1)
942
+ stds2 = scores2.std(axis=1)
943
+
944
+ sampled_data.append(
945
+ pd.DataFrame(
946
+ {
947
+ "cell_type": np.repeat(celltype, n_permutations),
948
+ "permutation_idx": np.arange(n_permutations),
949
+ "mean1": means1,
950
+ "mean2": means2,
951
+ "std1": stds1,
952
+ "std2": stds2,
953
+ }
942
954
  )
955
+ )
943
956
 
944
- permuted_samples1 = pd.concat(sampled_permuted_cv_augur1)
945
- permuted_samples2 = pd.concat(sampled_permuted_cv_augur2)
957
+ sampled_df = pd.concat(sampled_data)
946
958
 
947
959
  # delta between augur scores
948
960
  delta = augur_score1.merge(augur_score2, on=["cell_type"], suffixes=("1", "2")).assign(
@@ -950,9 +962,7 @@ class Augur:
950
962
  )
951
963
 
952
964
  # delta between permutation scores
953
- delta_rnd = permuted_samples1.merge(
954
- permuted_samples2, on=["cell_type", "permutation_idx"], suffixes=("1", "2")
955
- ).assign(delta_rnd=lambda x: x.mean2 - x.mean1)
965
+ delta_rnd = sampled_df.assign(delta_rnd=lambda x: x.mean2 - x.mean1)
956
966
 
957
967
  # number of values where permutations are larger than test statistic
958
968
  delta["b"] = (
@@ -967,7 +977,7 @@ class Augur:
967
977
  delta["z"] = (
968
978
  delta["delta_augur"] - delta_rnd.groupby("cell_type", as_index=False).mean()["delta_rnd"]
969
979
  ) / delta_rnd.groupby("cell_type", as_index=False).std()["delta_rnd"]
970
- # calculate pvalues
980
+
971
981
  delta["pval"] = np.minimum(
972
982
  2 * (delta["b"] + 1) / (delta["m"] + 1), 2 * (delta["m"] - delta["b"] + 1) / (delta["m"] + 1)
973
983
  )
@@ -976,13 +986,12 @@ class Augur:
976
986
  return delta
977
987
 
978
988
  @_doc_params(common_plot_args=doc_common_plot_args)
979
- def plot_dp_scatter(
989
+ def plot_dp_scatter( # pragma: no cover # noqa: D417
980
990
  self,
981
991
  results: pd.DataFrame,
982
992
  *,
983
993
  top_n: int = None,
984
994
  ax: Axes = None,
985
- show: bool = True,
986
995
  return_fig: bool = False,
987
996
  ) -> Figure | None:
988
997
  """Plot scatterplot of differential prioritization.
@@ -1041,21 +1050,19 @@ class Augur:
1041
1050
  legend1 = ax.legend(*scatter.legend_elements(), loc="center left", title="z-scores", bbox_to_anchor=(1, 0.5))
1042
1051
  ax.add_artist(legend1)
1043
1052
 
1044
- if show:
1045
- plt.show()
1046
1053
  if return_fig:
1047
1054
  return plt.gcf()
1055
+ plt.show()
1048
1056
  return None
1049
1057
 
1050
1058
  @_doc_params(common_plot_args=doc_common_plot_args)
1051
- def plot_important_features(
1059
+ def plot_important_features( # pragma: no cover # noqa: D417
1052
1060
  self,
1053
1061
  data: dict[str, Any],
1054
1062
  *,
1055
1063
  key: str = "augurpy_results",
1056
1064
  top_n: int = 10,
1057
1065
  ax: Axes = None,
1058
- show: bool = True,
1059
1066
  return_fig: bool = False,
1060
1067
  ) -> Figure | None:
1061
1068
  """Plot a lollipop plot of the n features with largest feature importances.
@@ -1083,10 +1090,7 @@ class Augur:
1083
1090
  Preview:
1084
1091
  .. image:: /_static/docstring_previews/augur_important_features.png
1085
1092
  """
1086
- if isinstance(data, AnnData):
1087
- results = data.uns[key]
1088
- else:
1089
- results = data
1093
+ results = data.uns[key] if isinstance(data, AnnData) else data
1090
1094
  n_features = (
1091
1095
  results["feature_importances"]
1092
1096
  .groupby("genes", as_index=False)
@@ -1109,20 +1113,18 @@ class Augur:
1109
1113
  plt.ylabel("Gene")
1110
1114
  plt.yticks(y_axes_range, n_features["genes"])
1111
1115
 
1112
- if show:
1113
- plt.show()
1114
1116
  if return_fig:
1115
1117
  return plt.gcf()
1118
+ plt.show()
1116
1119
  return None
1117
1120
 
1118
1121
  @_doc_params(common_plot_args=doc_common_plot_args)
1119
- def plot_lollipop(
1122
+ def plot_lollipop( # pragma: no cover # noqa: D417
1120
1123
  self,
1121
1124
  data: dict[str, Any] | AnnData,
1122
1125
  *,
1123
1126
  key: str = "augurpy_results",
1124
1127
  ax: Axes = None,
1125
- show: bool = True,
1126
1128
  return_fig: bool = False,
1127
1129
  ) -> Figure | None:
1128
1130
  """Plot a lollipop plot of the mean augur values.
@@ -1149,10 +1151,7 @@ class Augur:
1149
1151
  Preview:
1150
1152
  .. image:: /_static/docstring_previews/augur_lollipop.png
1151
1153
  """
1152
- if isinstance(data, AnnData):
1153
- results = data.uns[key]
1154
- else:
1155
- results = data
1154
+ results = data.uns[key] if isinstance(data, AnnData) else data
1156
1155
  if ax is None:
1157
1156
  fig, ax = plt.subplots()
1158
1157
  y_axes_range = range(1, len(results["summary_metrics"].columns) + 1)
@@ -1172,20 +1171,18 @@ class Augur:
1172
1171
  plt.ylabel("Cell Type")
1173
1172
  plt.yticks(y_axes_range, results["summary_metrics"].sort_values("mean_augur_score", axis=1).columns)
1174
1173
 
1175
- if show:
1176
- plt.show()
1177
1174
  if return_fig:
1178
1175
  return plt.gcf()
1176
+ plt.show()
1179
1177
  return None
1180
1178
 
1181
1179
  @_doc_params(common_plot_args=doc_common_plot_args)
1182
- def plot_scatterplot(
1180
+ def plot_scatterplot( # pragma: no cover # noqa: D417
1183
1181
  self,
1184
1182
  results1: dict[str, Any],
1185
1183
  results2: dict[str, Any],
1186
1184
  *,
1187
1185
  top_n: int = None,
1188
- show: bool = True,
1189
1186
  return_fig: bool = False,
1190
1187
  ) -> Figure | None:
1191
1188
  """Create scatterplot with two augur results.
@@ -1243,8 +1240,7 @@ class Augur:
1243
1240
  plt.xlabel("Augur scores 1")
1244
1241
  plt.ylabel("Augur scores 2")
1245
1242
 
1246
- if show:
1247
- plt.show()
1248
1243
  if return_fig:
1249
1244
  return plt.gcf()
1245
+ plt.show()
1250
1246
  return None