pertpy 0.10.0__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. pertpy/__init__.py +5 -1
  2. pertpy/_doc.py +1 -3
  3. pertpy/_types.py +6 -0
  4. pertpy/data/_dataloader.py +68 -24
  5. pertpy/data/_datasets.py +9 -9
  6. pertpy/metadata/__init__.py +2 -1
  7. pertpy/metadata/_cell_line.py +133 -25
  8. pertpy/metadata/_look_up.py +13 -19
  9. pertpy/metadata/_moa.py +1 -1
  10. pertpy/preprocessing/_guide_rna.py +138 -44
  11. pertpy/preprocessing/_guide_rna_mixture.py +17 -19
  12. pertpy/tools/__init__.py +1 -1
  13. pertpy/tools/_augur.py +106 -98
  14. pertpy/tools/_cinemaot.py +74 -114
  15. pertpy/tools/_coda/_base_coda.py +129 -145
  16. pertpy/tools/_coda/_sccoda.py +66 -69
  17. pertpy/tools/_coda/_tasccoda.py +71 -79
  18. pertpy/tools/_dialogue.py +48 -40
  19. pertpy/tools/_differential_gene_expression/_base.py +21 -31
  20. pertpy/tools/_differential_gene_expression/_checks.py +4 -6
  21. pertpy/tools/_differential_gene_expression/_dge_comparison.py +5 -6
  22. pertpy/tools/_differential_gene_expression/_edger.py +6 -10
  23. pertpy/tools/_differential_gene_expression/_pydeseq2.py +1 -1
  24. pertpy/tools/_differential_gene_expression/_simple_tests.py +3 -3
  25. pertpy/tools/_differential_gene_expression/_statsmodels.py +8 -5
  26. pertpy/tools/_distances/_distance_tests.py +1 -2
  27. pertpy/tools/_distances/_distances.py +31 -45
  28. pertpy/tools/_enrichment.py +7 -22
  29. pertpy/tools/_milo.py +19 -15
  30. pertpy/tools/_mixscape.py +73 -75
  31. pertpy/tools/_perturbation_space/_clustering.py +4 -4
  32. pertpy/tools/_perturbation_space/_comparison.py +4 -4
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +83 -32
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +10 -10
  35. pertpy/tools/_perturbation_space/_simple.py +12 -14
  36. pertpy/tools/_scgen/_scgen.py +16 -17
  37. pertpy/tools/_scgen/_scgenvae.py +2 -2
  38. pertpy/tools/_scgen/_utils.py +3 -1
  39. {pertpy-0.10.0.dist-info → pertpy-0.11.0.dist-info}/METADATA +36 -20
  40. pertpy-0.11.0.dist-info/RECORD +58 -0
  41. {pertpy-0.10.0.dist-info → pertpy-0.11.0.dist-info}/licenses/LICENSE +1 -0
  42. pertpy/tools/_kernel_pca.py +0 -50
  43. pertpy-0.10.0.dist-info/RECORD +0 -58
  44. {pertpy-0.10.0.dist-info → pertpy-0.11.0.dist-info}/WHEEL +0 -0
pertpy/tools/_augur.py CHANGED
@@ -11,7 +11,6 @@ import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 import scanpy as sc
-import statsmodels.api as sm
 from anndata import AnnData
 from joblib import Parallel, delayed
 from lamin_utils import logger
@@ -34,6 +33,7 @@ from sklearn.metrics import (
 from sklearn.model_selection import StratifiedKFold, cross_validate
 from sklearn.preprocessing import LabelEncoder
 from skmisc.loess import loess
+from statsmodels.api import OLS
 from statsmodels.stats.multitest import fdrcorrection
 
 from pertpy._doc import _doc_params, doc_common_plot_args
@@ -43,14 +43,24 @@ if TYPE_CHECKING:
     from matplotlib.figure import Figure
 
 
-@dataclass
-class Params:
-    """Type signature for random forest and logistic regression parameters.
+class Augur:
+    """Python implementation of Augur."""
+
+    def __init__(
+        self,
+        estimator: Literal["random_forest_classifier", "random_forest_regressor", "logistic_regression_classifier"],
+        n_estimators: int = 100,
+        max_depth: int | None = None,
+        max_features: Literal["auto", "log2", "sqrt"] | int | float = 2,
+        penalty: Literal["l1", "l2", "elasticnet", "none"] = "l2",
+        random_state: int | None = None,
+    ):
+        """Defines the Augur estimator model and parameters.
 
-    Parameters:
-        n_estimators: defines the number of trees in the forest;
-        max_depth: specifies the maximal depth of each tree;
-        max_features: specifies the maximal number of features considered when looking at best split.
+        estimator: The scikit-learn estimator model that classifies.
+        n_estimators: Number of trees in the forest.
+        max_depth: Maximal depth of each tree.
+        max_features: Maximal number of features considered when looking at best split.
 
             * if int then consider max_features for each split
             * if float consider round(max_features*n_features)
@@ -58,38 +68,26 @@ class Params:
             * if `log2` then max_features=log2(n_features)
             * if `sqrt` then max_featuers=sqrt(n_features)
 
-        penalty: defines the norm of the penalty used in logistic regression
+        penalty: Norm of the penalty used in logistic regression
 
             * if `l1` then L1 penalty is added
             * if `l2` then L2 penalty is added (default)
             * if `elasticnet` both L1 and L2 penalties are added
             * if `none` no penalty is added
-
-        random_state: sets random model seed
-    """
-
-    n_estimators: int = 100
-    max_depth: int | None = None
-    max_features: Literal["auto"] | Literal["log2"] | Literal["sqrt"] | int | float = 2
-    penalty: Literal["l1"] | Literal["l2"] | Literal["elasticnet"] | Literal["none"] = "l2"
-    random_state: int | None = None
-
-
-class Augur:
-    """Python implementation of Augur."""
-
-    def __init__(
-        self,
-        estimator: Literal["random_forest_classifier"]
-        | Literal["random_forest_regressor"]
-        | Literal["logistic_regression_classifier"],
-        params: Params | None = None,
-    ):
-        self.estimator = self.create_estimator(classifier=estimator, params=params)
+        """
+        self.estimator = self.create_estimator(
+            classifier=estimator,
+            n_estimators=n_estimators,
+            max_depth=max_depth,
+            max_features=max_features,
+            penalty=penalty,
+            random_state=random_state,
+        )
 
     def load(
         self,
         input: AnnData | pd.DataFrame,
+        *,
         meta: pd.DataFrame | None = None,
         label_col: str = "label_col",
         cell_type_col: str = "cell_type_col",
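In 0.11.0 the `Params` dataclass is removed and its fields become plain constructor arguments, as the hunks above show. A minimal before/after sketch of the call site (hyperparameter values are arbitrary):

```python
import pertpy as pt

# pertpy 0.10.0 (removed above): hyperparameters were wrapped in a Params dataclass
# augur = pt.tl.Augur("random_forest_classifier", params=Params(n_estimators=200, random_state=0))

# pertpy 0.11.0 (added above): hyperparameters are passed directly
augur = pt.tl.Augur(
    "random_forest_classifier",
    n_estimators=200,  # number of trees in the forest
    max_depth=None,    # grow each tree until its leaves are pure
    random_state=0,    # reproducible subsampling and model fits
)
```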
@@ -99,8 +97,8 @@ class Augur:
         """Loads the input data.
 
         Args:
-            input: Anndata or matrix containing gene expression values (genes in rows, cells in columns) and optionally meta
-                data about each cell.
+            input: Anndata or matrix containing gene expression values (genes in rows, cells in columns)
+                and optionally meta data about each cell.
             meta: Optional Pandas DataFrame containing meta data about each cell.
             label_col: column of the meta DataFrame or the Anndata or matrix containing the condition labels for each cell
                 in the cell-by-gene expression matrix
@@ -110,8 +108,8 @@
             treatment_label: in the case of more than two labels, this label is used in the analysis
 
         Returns:
-            Anndata object containing gene expression values (cells in rows, genes in columns) and cell type, label and y
-                dummy variables as obs
+            Anndata object containing gene expression values (cells in rows, genes in columns)
+                and cell type, label and y dummy variables as obs
 
         Examples:
             >>> import pertpy as pt
@@ -157,12 +155,13 @@
 
     def create_estimator(
         self,
-        classifier: (
-            Literal["random_forest_classifier"]
-            | Literal["random_forest_regressor"]
-            | Literal["logistic_regression_classifier"]
-        ),
-        params: Params | None = None,
+        classifier: (Literal["random_forest_classifier", "random_forest_regressor", "logistic_regression_classifier"]),
+        *,
+        n_estimators: int = 100,
+        max_depth: int | None = None,
+        max_features: Literal["auto", "log2", "sqrt"] | int | float = 2,
+        penalty: Literal["l1", "l2", "elasticnet", "none"] = "l2",
+        random_state: int | None = None,
     ) -> RandomForestClassifier | RandomForestRegressor | LogisticRegression:
         """Creates a model object of the provided type and populates it with desired parameters.
 
@@ -170,35 +169,46 @@
             classifier: classifier to use in calculating the area under the curve.
                 Either random forest classifier or logistic regression for categorical data
                 or random forest regressor for continous data
-            params: parameters used to populate the model object. Default values are `n_estimators` =
-                100, `max_depth` = None, `max_features` = 2, `penalty` = `l2`, `random_state` = None.
+            n_estimators: Number of trees in the forest.
+            max_depth: Maximal depth of each tree.
+            max_features: Maximal number of features considered when looking at best split.
 
-        Returns:
-            Estimator object.
+                * if int then consider max_features for each split
+                * if float consider round(max_features*n_features)
+                * if `auto` then max_features=n_features (default)
+                * if `log2` then max_features=log2(n_features)
+                * if `sqrt` then max_featuers=sqrt(n_features)
+
+            penalty: Norm of the penalty used in logistic regression
+
+                * if `l1` then L1 penalty is added
+                * if `l2` then L2 penalty is added (default)
+                * if `elasticnet` both L1 and L2 penalties are added
+                * if `none` no penalty is added
+
+            random_state: Random model seed.
 
         Examples:
             >>> import pertpy as pt
             >>> augur = pt.tl.Augur("random_forest_classifier")
             >>> estimator = augur.create_estimator("logistic_regression_classifier")
         """
-        if params is None:
-            params = Params()
         if classifier == "random_forest_classifier":
             return RandomForestClassifier(
-                n_estimators=params.n_estimators,
-                max_depth=params.max_depth,
-                max_features=params.max_features,
-                random_state=params.random_state,
+                n_estimators=n_estimators,
+                max_depth=max_depth,
+                max_features=max_features,
+                random_state=random_state,
             )
         elif classifier == "random_forest_regressor":
             return RandomForestRegressor(
-                n_estimators=params.n_estimators,
-                max_depth=params.max_depth,
-                max_features=params.max_features,
-                random_state=params.random_state,
+                n_estimators=n_estimators,
+                max_depth=max_depth,
+                max_features=max_features,
+                random_state=random_state,
             )
         elif classifier == "logistic_regression_classifier":
-            return LogisticRegression(penalty=params.penalty, random_state=params.random_state)
+            return LogisticRegression(penalty=penalty, random_state=random_state)
         else:
             raise ValueError("Invalid classifier")
 
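The bare `*` added to `create_estimator` (and to `load` above) makes everything after the first argument keyword-only. A short usage sketch with arbitrary values:

```python
import pertpy as pt

augur = pt.tl.Augur("random_forest_classifier")

# hyperparameters after `classifier` must now be named
rf = augur.create_estimator(
    "random_forest_classifier",
    n_estimators=50,
    max_features="sqrt",
    random_state=0,
)

# augur.create_estimator("random_forest_classifier", 50)  # TypeError in 0.11.0: positional argument after the bare `*`
```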
@@ -231,15 +241,16 @@
         if categorical:
             label_subsamples = []
             y_encodings = adata.obs["y_"].unique()
-            for code in y_encodings:
-                label_subsamples.append(
-                    sc.pp.subsample(
-                        adata[adata.obs["y_"] == code, features],
-                        n_obs=subsample_size,
-                        copy=True,
-                        random_state=random_state,
-                    )
+            label_subsamples = [
+                sc.pp.subsample(
+                    adata[adata.obs["y_"] == code, features],
+                    n_obs=subsample_size,
+                    copy=True,
+                    random_state=random_state,
                 )
+                for code in y_encodings
+            ]
+
             subsample = ad.concat([*label_subsamples], index_unique=None)
         else:
             subsample = sc.pp.subsample(adata[:, features], n_obs=subsample_size, copy=True, random_state=random_state)
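The refactor above swaps an append loop for a list comprehension without changing behavior. A self-contained toy illustration of the equivalence (plain lists standing in for the AnnData subsamples):

```python
groups = ["a", "b", "a", "b", "a"]
labels = sorted(set(groups))

# old style: build the list with an explicit loop
subsets_loop = []
for label in labels:
    subsets_loop.append([g for g in groups if g == label])

# new style: the same list as a comprehension
subsets_comp = [[g for g in groups if g == label] for label in labels]

assert subsets_loop == subsets_comp
```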
@@ -259,6 +270,7 @@
     def draw_subsample(
         self,
         adata: AnnData,
+        *,
         augur_mode: str,
         subsample_size: int,
         feature_perc: float,
@@ -319,6 +331,7 @@
     def cross_validate_subsample(
         self,
         adata: AnnData,
+        *,
         augur_mode: str,
         subsample_size: int,
         folds: int,
@@ -358,8 +371,8 @@
         """
         subsample = self.draw_subsample(
             adata,
-            augur_mode,
-            subsample_size,
+            augur_mode=augur_mode,
+            subsample_size=subsample_size,
             feature_perc=feature_perc,
             categorical=is_classifier(self.estimator),
             random_state=subsample_idx,
@@ -373,7 +386,7 @@
         )
         return results
 
-    def ccc_score(self, y_true, y_pred) -> float:
+    def ccc_score(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
         """Implementation of Lin's Concordance correlation coefficient, based on https://gitlab.com/-/snippets/1730605.
 
         Args:
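For reference, the metric `ccc_score` implements is Lin's concordance correlation coefficient, CCC = 2·cov(y_true, y_pred) / (var(y_true) + var(y_pred) + (mean(y_true) − mean(y_pred))²). A minimal NumPy sketch of that formula (not pertpy's exact implementation, which follows the snippet linked above):

```python
import numpy as np


def concordance_cc(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Lin's concordance correlation coefficient for 1-D arrays."""
    mu_t, mu_p = y_true.mean(), y_pred.mean()
    var_t, var_p = y_true.var(), y_pred.var()          # population (ddof=0) variances
    cov = ((y_true - mu_t) * (y_pred - mu_p)).mean()   # population covariance
    return 2 * cov / (var_t + var_p + (mu_t - mu_p) ** 2)


rng = np.random.default_rng(0)
y = rng.normal(size=100)
assert np.isclose(concordance_cc(y, y), 1.0)  # perfect agreement scores 1
```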
@@ -405,7 +418,7 @@
         """Set scoring fuctions for cross-validation based on estimator.
 
         Args:
-            multiclass: `True` if there are more than two target classes
+            multiclass: Whether there are more than two target classes
             zero_division: 0 or 1 or `warn`; Sets the value to return when there is a zero division. If
                 set to “warn”, this acts as 0, but warnings are also raised. Precision metric parameter.
 
@@ -435,7 +448,7 @@
                 "f1": make_scorer(f1_score, average="binary"),
                 "recall": make_scorer(recall_score, average="binary"),
             }
-            if isinstance(self.estimator, RandomForestClassifier) or isinstance(self.estimator, LogisticRegression)
+            if isinstance(self.estimator, RandomForestClassifier | LogisticRegression)
             else {
                 "augur_score": make_scorer(self.ccc_score),
                 "r2": make_scorer(r2_score),
@@ -448,6 +461,7 @@
     def run_cross_validation(
         self,
         subsample: AnnData,
+        *,
         subsample_idx: int,
         folds: int,
         random_state: int | None,
@@ -479,7 +493,7 @@
         """
         x = subsample.to_df()
         y = subsample.obs["y_"]
-        scorer = self.set_scorer(multiclass=True if len(y.unique()) > 2 else False, zero_division=zero_division)
+        scorer = self.set_scorer(multiclass=len(y.unique()) > 2, zero_division=zero_division)
         folds = StratifiedKFold(n_splits=folds, random_state=random_state, shuffle=True)
 
         results = cross_validate(
@@ -492,12 +506,12 @@
         )
 
         results["subsample_idx"] = subsample_idx
-        for score in scorer.keys():
+        for score in scorer:
             results[f"mean_{score}"] = results[f"test_{score}"].mean()
 
         # feature importances
         feature_importances = defaultdict(list)
-        if isinstance(self.estimator, RandomForestClassifier) or isinstance(self.estimator, RandomForestRegressor):
+        if isinstance(self.estimator, RandomForestClassifier | RandomForestRegressor):
             for fold, estimator in list(zip(range(len(results["estimator"])), results["estimator"], strict=False)):
                 feature_importances["genes"].extend(x.columns.tolist())
                 feature_importances["feature_importances"].extend(estimator.feature_importances_.tolist())
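This hunk and the `set_scorer` one above collapse `isinstance(x, A) or isinstance(x, B)` into `isinstance(x, A | B)`. Passing a PEP 604 union to `isinstance` requires Python 3.10+, consistent with the `X | Y` annotations already used throughout the file. A standalone illustration:

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression

est = RandomForestClassifier()

# the two checks are equivalent; the union form needs Python >= 3.10
assert isinstance(est, RandomForestClassifier) or isinstance(est, LogisticRegression)
assert isinstance(est, RandomForestClassifier | LogisticRegression)
```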
@@ -589,9 +603,9 @@
         sigma2_x = np.sum(np.power(loess1.outputs.fitted_residuals, 2)) / nobs
         sigma2_z = np.sum(np.power(loess2.outputs.fitted_residuals, 2)) / nobs
         yhat_x = loess1.outputs.fitted_values
-        res_dx = sm.OLS(yhat_x, z).fit()
+        res_dx = OLS(yhat_x, z).fit()
         err_zx = res_dx.resid
-        res_xzx = sm.OLS(err_zx, x).fit()
+        res_xzx = OLS(err_zx, x).fit()
         err_xzx = res_xzx.resid
 
         sigma2_zx = sigma2_x + np.dot(err_zx.T, err_zx) / nobs
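These call sites change only because `OLS` is now imported directly (see the import hunks at the top) rather than through `statsmodels.api as sm`; the fitted results are identical. A toy sketch of the residual-on-residual regression step, with synthetic `x`, `z`, and `yhat_x` standing in for the loess quantities above:

```python
import numpy as np
from statsmodels.api import OLS  # direct import, matching the new code

rng = np.random.default_rng(0)
x = rng.normal(size=(50, 1))
z = rng.normal(size=(50, 1))
yhat_x = 2 * x.ravel() + 0.1 * rng.normal(size=50)

res_dx = OLS(yhat_x, z).fit()  # regress the fitted values on the competing predictor
err_zx = res_dx.resid          # these residuals feed the next OLS fit in cox_compare
res_xzx = OLS(err_zx, x).fit()
```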
@@ -602,7 +616,9 @@
 
         return q, pval
 
-    def select_variance(self, adata: AnnData, var_quantile: float, filter_negative_residuals: bool, span: float = 0.75):
+    def select_variance(
+        self, adata: AnnData, *, var_quantile: float, filter_negative_residuals: bool, span: float = 0.75
+    ):
         """Feature selection based on Augur implementation.
 
         Args:
@@ -656,11 +672,8 @@
         cox1 = self.cox_compare(fit1, fit2)
         cox2 = self.cox_compare(fit2, fit1)
 
-        # compare pvalues
-        if cox1[1] < cox2[1]:
-            model = fit1
-        else:
-            model = fit2
+        # compare p values
+        model = fit1 if cox1[1] < cox2[1] else fit2
 
         residuals = model.outputs.fitted_residuals
 
@@ -676,6 +689,7 @@
     def predict(
         self,
         adata: AnnData,
+        *,
         n_subsamples: int = 50,
         subsample_size: int = 20,
         folds: int = 3,
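With this change, `predict` (like `load`, `draw_subsample`, `cross_validate_subsample`, `run_cross_validation`, and `select_variance` above) accepts everything after `adata` by keyword only. A call sketch; the dataset loader is the one used in pertpy's own docstring examples:

```python
import pertpy as pt

augur = pt.tl.Augur("random_forest_classifier")
adata = pt.dt.sc_sim_augur()  # example dataset from pertpy's docstrings
loaded = augur.load(adata)

# options after `adata` must now be named
adata_out, results = augur.predict(loaded, n_subsamples=20, folds=3, random_state=42)
```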
@@ -717,7 +731,7 @@
                 set to “warn”, this acts as 0, but warnings are also raised. Precision metric parameter.
 
         Returns:
-            A tuple with a dictionary containing the following keys with an updated AnnData object with mean_augur_score metrics in obs:
+            A tuple with a dictionary containing the following keys with an updated AnnData object with mean_augur_score metrics in obs.
 
             * summary_metrics: Pandas Dataframe containing mean metrics for each cell type
             * feature_importances: Pandas Dataframe containing feature importances of genes across all cross validation runs
@@ -756,7 +770,7 @@
         adata.obs["augur_score"] = nan
         for cell_type in track(adata.obs["cell_type"].unique(), description="Processing data..."):
             cell_type_subsample = adata[adata.obs["cell_type"] == cell_type].copy()
-            if augur_mode == "default" or augur_mode == "permute":
+            if augur_mode in ("default", "permute"):
                 cell_type_subsample = (
                     self.select_highly_variable(cell_type_subsample)
                     if not select_variance_features
@@ -846,10 +860,10 @@
         between two conditions respectively compared to the control.
 
         Args:
-            augur1: Augurpy results from condition 1, obtained from `predict()[1]`
-            augur2: Augurpy results from condition 2, obtained from `predict()[1]`
-            permuted1: permuted Augurpy results from condition 1, obtained from `predict()` with argument `augur_mode=permute`
-            permuted2: permuted Augurpy results from condition 2, obtained from `predict()` with argument `augur_mode=permute`
+            augur_results1: Augurpy results from condition 1, obtained from `predict()[1]`
+            augur_results2: Augurpy results from condition 2, obtained from `predict()[1]`
+            permuted_results1: permuted Augurpy results from condition 1, obtained from `predict()` with argument `augur_mode=permute`
+            permuted_results2: permuted Augurpy results from condition 2, obtained from `predict()` with argument `augur_mode=permute`
             n_subsamples: number of subsamples to pool when calculating the mean augur score for each permutation.
             n_permutations: the total number of mean augur scores to calculate from a background distribution
 
@@ -972,7 +986,7 @@
         return delta
 
     @_doc_params(common_plot_args=doc_common_plot_args)
-    def plot_dp_scatter(
+    def plot_dp_scatter( # pragma: no cover # noqa: D417
         self,
         results: pd.DataFrame,
         *,
@@ -1042,7 +1056,7 @@
         return None
 
     @_doc_params(common_plot_args=doc_common_plot_args)
-    def plot_important_features(
+    def plot_important_features( # pragma: no cover # noqa: D417
         self,
         data: dict[str, Any],
         *,
@@ -1076,10 +1090,7 @@
         Preview:
             .. image:: /_static/docstring_previews/augur_important_features.png
         """
-        if isinstance(data, AnnData):
-            results = data.uns[key]
-        else:
-            results = data
+        results = data.uns[key] if isinstance(data, AnnData) else data
         n_features = (
             results["feature_importances"]
             .groupby("genes", as_index=False)
@@ -1108,7 +1119,7 @@
         return None
 
     @_doc_params(common_plot_args=doc_common_plot_args)
-    def plot_lollipop(
+    def plot_lollipop( # pragma: no cover # noqa: D417
         self,
         data: dict[str, Any] | AnnData,
         *,
@@ -1140,10 +1151,7 @@
         Preview:
             .. image:: /_static/docstring_previews/augur_lollipop.png
         """
-        if isinstance(data, AnnData):
-            results = data.uns[key]
-        else:
-            results = data
+        results = data.uns[key] if isinstance(data, AnnData) else data
         if ax is None:
             fig, ax = plt.subplots()
         y_axes_range = range(1, len(results["summary_metrics"].columns) + 1)
@@ -1169,7 +1177,7 @@
         return None
 
     @_doc_params(common_plot_args=doc_common_plot_args)
-    def plot_scatterplot(
+    def plot_scatterplot( # pragma: no cover # noqa: D417
         self,
         results1: dict[str, Any],
         results2: dict[str, Any],