pertpy 0.9.5__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44):
  1. pertpy/__init__.py +5 -1
  2. pertpy/_doc.py +2 -5
  3. pertpy/_types.py +6 -0
  4. pertpy/data/_dataloader.py +68 -24
  5. pertpy/data/_datasets.py +9 -9
  6. pertpy/metadata/__init__.py +2 -1
  7. pertpy/metadata/_cell_line.py +136 -30
  8. pertpy/metadata/_look_up.py +13 -19
  9. pertpy/metadata/_moa.py +1 -1
  10. pertpy/preprocessing/_guide_rna.py +221 -39
  11. pertpy/preprocessing/_guide_rna_mixture.py +177 -0
  12. pertpy/tools/__init__.py +1 -1
  13. pertpy/tools/_augur.py +138 -142
  14. pertpy/tools/_cinemaot.py +75 -117
  15. pertpy/tools/_coda/_base_coda.py +150 -174
  16. pertpy/tools/_coda/_sccoda.py +66 -69
  17. pertpy/tools/_coda/_tasccoda.py +71 -79
  18. pertpy/tools/_dialogue.py +60 -56
  19. pertpy/tools/_differential_gene_expression/_base.py +25 -43
  20. pertpy/tools/_differential_gene_expression/_checks.py +4 -6
  21. pertpy/tools/_differential_gene_expression/_dge_comparison.py +5 -6
  22. pertpy/tools/_differential_gene_expression/_edger.py +6 -10
  23. pertpy/tools/_differential_gene_expression/_pydeseq2.py +1 -1
  24. pertpy/tools/_differential_gene_expression/_simple_tests.py +3 -3
  25. pertpy/tools/_differential_gene_expression/_statsmodels.py +8 -5
  26. pertpy/tools/_distances/_distance_tests.py +1 -2
  27. pertpy/tools/_distances/_distances.py +86 -92
  28. pertpy/tools/_enrichment.py +8 -25
  29. pertpy/tools/_milo.py +23 -27
  30. pertpy/tools/_mixscape.py +261 -175
  31. pertpy/tools/_perturbation_space/_clustering.py +4 -4
  32. pertpy/tools/_perturbation_space/_comparison.py +4 -4
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +83 -32
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +10 -10
  35. pertpy/tools/_perturbation_space/_simple.py +13 -17
  36. pertpy/tools/_scgen/_scgen.py +17 -20
  37. pertpy/tools/_scgen/_scgenvae.py +2 -2
  38. pertpy/tools/_scgen/_utils.py +3 -1
  39. {pertpy-0.9.5.dist-info → pertpy-0.11.0.dist-info}/METADATA +37 -21
  40. pertpy-0.11.0.dist-info/RECORD +58 -0
  41. {pertpy-0.9.5.dist-info → pertpy-0.11.0.dist-info}/licenses/LICENSE +1 -0
  42. pertpy/tools/_kernel_pca.py +0 -50
  43. pertpy-0.9.5.dist-info/RECORD +0 -57
  44. {pertpy-0.9.5.dist-info → pertpy-0.11.0.dist-info}/WHEEL +0 -0
@@ -10,7 +10,7 @@ from ._checks import check_is_integer_matrix
10
10
 
11
11
 
12
12
  class EdgeR(LinearModelBase):
13
- """Differential expression test using EdgeR"""
13
+ """Differential expression test using EdgeR."""
14
14
 
15
15
  def _check_counts(self):
16
16
  check_is_integer_matrix(self.data)
@@ -39,17 +39,13 @@ class EdgeR(LinearModelBase):
39
39
  edger = importr("edgeR")
40
40
  except ImportError as e:
41
41
  raise ImportError(
42
- "edgeR requires a valid R installation with the following packages:\n"
43
- "edgeR, BiocParallel, RhpcBLASctl"
42
+ "edgeR requires a valid R installation with the following packages:\nedgeR, BiocParallel, RhpcBLASctl"
44
43
  ) from e
45
44
 
46
45
  # Convert dataframe
47
46
  with localconverter(get_conversion() + numpy2ri.converter):
48
47
  expr = self.adata.X if self.layer is None else self.adata.layers[self.layer]
49
- if issparse(expr):
50
- expr = expr.T.toarray()
51
- else:
52
- expr = expr.T
48
+ expr = expr.T.toarray() if issparse(expr) else expr.T
53
49
 
54
50
  with localconverter(get_conversion() + pandas2ri.converter):
55
51
  expr_r = ro.conversion.py2rpy(pd.DataFrame(expr, index=self.adata.var_names, columns=self.adata.obs_names))
@@ -72,8 +68,8 @@ class EdgeR(LinearModelBase):
72
68
  ro.globalenv["fit"] = fit
73
69
  self.fit = fit
74
70
 
75
- def _test_single_contrast(self, contrast: Sequence[float], **kwargs) -> pd.DataFrame:
76
- """Conduct test for each contrast and return a data frame
71
+ def _test_single_contrast(self, contrast: Sequence[float], **kwargs) -> pd.DataFrame: # noqa: D417
72
+ """Conduct test for each contrast and return a data frame.
77
73
 
78
74
  Args:
79
75
  contrast: numpy array of integars indicating contrast i.e. [-1, 0, 1, 0, 0]
@@ -100,7 +96,7 @@ class EdgeR(LinearModelBase):
100
96
  importr("edgeR")
101
97
  except ImportError:
102
98
  raise ImportError(
103
- "edgeR requires a valid R installation with the following packages: " "edgeR, BiocParallel, RhpcBLASctl"
99
+ "edgeR requires a valid R installation with the following packages: edgeR, BiocParallel, RhpcBLASctl"
104
100
  ) from None
105
101
 
106
102
  # Convert vector to R, which drops a category like `self.design_matrix` to use the intercept for the left out.
@@ -16,7 +16,7 @@ from ._checks import check_is_integer_matrix
16
16
 
17
17
 
18
18
  class PyDESeq2(LinearModelBase):
19
- """Differential expression test using a PyDESeq2"""
19
+ """Differential expression test using a PyDESeq2."""
20
20
 
21
21
  def __init__(
22
22
  self, adata: AnnData, design: str | ndarray, *, mask: str | None = None, layer: str | None = None, **kwargs
@@ -1,4 +1,4 @@
1
- """Simple tests such as t-test, wilcoxon"""
1
+ """Simple tests such as t-test, wilcoxon."""
2
2
 
3
3
  import warnings
4
4
  from abc import abstractmethod
@@ -10,7 +10,7 @@ import pandas as pd
10
10
  import scipy.stats
11
11
  import statsmodels
12
12
  from anndata import AnnData
13
- from pandas.core.api import DataFrame as DataFrame
13
+ from pandas.core.api import DataFrame
14
14
  from scipy.sparse import diags, issparse
15
15
  from tqdm.auto import tqdm
16
16
 
@@ -152,7 +152,7 @@ class WilcoxonTest(SimpleComparisonBase):
152
152
 
153
153
 
154
154
  class TTest(SimpleComparisonBase):
155
- """Perform a unpaired or paired T-test"""
155
+ """Perform a unpaired or paired T-test."""
156
156
 
157
157
  @staticmethod
158
158
  def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> float:
@@ -6,14 +6,14 @@ import statsmodels.api as sm
6
6
  from tqdm.auto import tqdm
7
7
 
8
8
  from ._base import LinearModelBase
9
- from ._checks import check_is_integer_matrix
9
+ from ._checks import check_is_numeric_matrix
10
10
 
11
11
 
12
12
  class Statsmodels(LinearModelBase):
13
- """Differential expression test using a statsmodels linear regression"""
13
+ """Differential expression test using a statsmodels linear regression."""
14
14
 
15
15
  def _check_counts(self):
16
- check_is_integer_matrix(self.data)
16
+ check_is_numeric_matrix(self.data)
17
17
 
18
18
  def fit(
19
19
  self,
@@ -55,7 +55,10 @@ class Statsmodels(LinearModelBase):
55
55
  "t_value": t_test.tvalue.item(),
56
56
  "sd": t_test.sd.item(),
57
57
  "log_fc": t_test.effect.item(),
58
- "adj_p_value": statsmodels.stats.multitest.fdrcorrection(np.array([t_test.pvalue]))[1].item(),
59
58
  }
60
59
  )
61
- return pd.DataFrame(res).sort_values("p_value")
60
+ return (
61
+ pd.DataFrame(res)
62
+ .sort_values("p_value")
63
+ .assign(adj_p_value=lambda x: statsmodels.stats.multitest.fdrcorrection(x["p_value"])[1])
64
+ )
@@ -83,8 +83,7 @@ class DistanceTest:
83
83
  contrast: str,
84
84
  show_progressbar: bool = True,
85
85
  ) -> pd.DataFrame:
86
- """Run a permutation test using the specified distance metric, testing
87
- all groups of cells against a specified contrast group ("control").
86
+ """Run a permutation test using the specified distance metric, testing all groups of cells against a specified contrast group ("control").
88
87
 
89
88
  Args:
90
89
  adata: Annotated data matrix.
@@ -4,9 +4,9 @@ import multiprocessing
4
4
  from abc import ABC, abstractmethod
5
5
  from typing import TYPE_CHECKING, Literal, NamedTuple
6
6
 
7
- import numba
8
7
  import numpy as np
9
8
  import pandas as pd
9
+ from numba import jit
10
10
  from ott.geometry.geometry import Geometry
11
11
  from ott.geometry.pointcloud import PointCloud
12
12
  from ott.problems.linear.linear_problem import LinearProblem
@@ -135,9 +135,7 @@ class Distance:
135
135
  self.aggregation_func = agg_fct
136
136
  if metric == "edistance":
137
137
  metric_fct = Edistance()
138
- elif metric == "euclidean":
139
- metric_fct = EuclideanDistance(self.aggregation_func)
140
- elif metric == "root_mean_squared_error":
138
+ elif metric in ("euclidean", "root_mean_squared_error"):
141
139
  metric_fct = EuclideanDistance(self.aggregation_func)
142
140
  elif metric == "mse":
143
141
  metric_fct = MeanSquaredDistance(self.aggregation_func)
@@ -181,7 +179,7 @@ class Distance:
181
179
 
182
180
  if layer_key and obsm_key:
183
181
  raise ValueError(
184
- "Cannot use 'layer_key' and 'obsm_key' at the same time.\n" "Please provide only one of the two keys."
182
+ "Cannot use 'layer_key' and 'obsm_key' at the same time.\nPlease provide only one of the two keys."
185
183
  )
186
184
  if not layer_key and not obsm_key:
187
185
  obsm_key = "X_pca"
@@ -201,6 +199,7 @@ class Distance:
201
199
  Args:
202
200
  X: First vector of shape (n_samples, n_features).
203
201
  Y: Second vector of shape (n_samples, n_features).
202
+ kwargs: Passed to the metric function.
204
203
 
205
204
  Returns:
206
205
  float: Distance between X and Y.
@@ -239,9 +238,10 @@ class Distance:
239
238
  Y: Second vector of shape (n_samples, n_features).
240
239
  n_bootstrap: Number of bootstrap samples.
241
240
  random_state: Random state for bootstrapping.
241
+ **kwargs: Passed to the metric function.
242
242
 
243
243
  Returns:
244
- MeanVar: Mean and variance of distance between X and Y.
244
+ Mean and variance of distance between X and Y.
245
245
 
246
246
  Examples:
247
247
  >>> import pertpy as pt
@@ -286,8 +286,8 @@ class Distance:
286
286
  kwargs: Additional keyword arguments passed to the metric function.
287
287
 
288
288
  Returns:
289
- pd.DataFrame: Dataframe with pairwise distances.
290
- tuple[pd.DataFrame, pd.DataFrame]: Two Dataframes, one for the mean and one for the variance of pairwise distances.
289
+ :class:`pandas.DataFrame`: Dataframe with pairwise distances.
290
+ tuple[:class:`pandas.DataFrame`, :class:`pandas.DataFrame`]: Two Dataframes, one for the mean and one for the variance of pairwise distances.
291
291
 
292
292
  Examples:
293
293
  >>> import pertpy as pt
@@ -309,7 +309,7 @@ class Distance:
309
309
  # able to handle precomputed distances such as the PseudobulkDistance.
310
310
  if self.metric_fct.accepts_precomputed:
311
311
  # Precompute the pairwise distances if needed
312
- if f"{self.obsm_key}_{self.cell_wise_metric}_predistances" not in adata.obsp.keys():
312
+ if f"{self.obsm_key}_{self.cell_wise_metric}_predistances" not in adata.obsp:
313
313
  self.precompute_distances(adata, n_jobs=n_jobs, **kwargs)
314
314
  pwd = adata.obsp[f"{self.obsm_key}_{self.cell_wise_metric}_predistances"]
315
315
  for index_x, group_x in enumerate(fct(groups)):
@@ -339,10 +339,7 @@ class Distance:
339
339
  df.loc[group_x, group_y] = df.loc[group_y, group_x] = bootstrap_output.mean
340
340
  df_var.loc[group_x, group_y] = df_var.loc[group_y, group_x] = bootstrap_output.variance
341
341
  else:
342
- if self.layer_key:
343
- embedding = adata.layers[self.layer_key]
344
- else:
345
- embedding = adata.obsm[self.obsm_key].copy()
342
+ embedding = adata.layers[self.layer_key] if self.layer_key else adata.obsm[self.obsm_key].copy()
346
343
  for index_x, group_x in enumerate(fct(groups)):
347
344
  cells_x = embedding[np.asarray(grouping == group_x)].copy()
348
345
  for group_y in groups[index_x:]: # type: ignore
@@ -409,8 +406,8 @@ class Distance:
409
406
  kwargs: Additional keyword arguments passed to the metric function.
410
407
 
411
408
  Returns:
412
- pd.DataFrame: Dataframe with distances of groups to selected_group.
413
- tuple[pd.DataFrame, pd.DataFrame]: Two Dataframes, one for the mean and one for the variance of distances of groups to selected_group.
409
+ :class:`pandas.DataFrame`: Dataframe with distances of groups to selected_group.
410
+ tuple[:class:`pandas.DataFrame`, :class:`pandas.DataFrame`]: Two Dataframes, one for the mean and one for the variance of distances of groups to selected_group.
414
411
 
415
412
 
416
413
  Examples:
@@ -446,7 +443,7 @@ class Distance:
446
443
  # able to handle precomputed distances such as the PseudobulkDistance.
447
444
  if self.metric_fct.accepts_precomputed:
448
445
  # Precompute the pairwise distances if needed
449
- if f"{self.obsm_key}_{self.cell_wise_metric}_predistances" not in adata.obsp.keys():
446
+ if f"{self.obsm_key}_{self.cell_wise_metric}_predistances" not in adata.obsp:
450
447
  self.precompute_distances(adata, n_jobs=n_jobs, **kwargs)
451
448
  pwd = adata.obsp[f"{self.obsm_key}_{self.cell_wise_metric}_predistances"]
452
449
  for group_x in fct(groups):
@@ -473,10 +470,7 @@ class Distance:
473
470
  df.loc[group_x] = bootstrap_output.mean
474
471
  df_var.loc[group_x] = bootstrap_output.variance
475
472
  else:
476
- if self.layer_key:
477
- embedding = adata.layers[self.layer_key]
478
- else:
479
- embedding = adata.obsm[self.obsm_key].copy()
473
+ embedding = adata.layers[self.layer_key] if self.layer_key else adata.obsm[self.obsm_key].copy()
480
474
  for group_x in fct(groups):
481
475
  cells_x = embedding[np.asarray(grouping == group_x)].copy()
482
476
  group_y = selected_group
@@ -524,10 +518,7 @@ class Distance:
524
518
  >>> distance = pt.tools.Distance(metric="edistance")
525
519
  >>> distance.precompute_distances(adata)
526
520
  """
527
- if self.layer_key:
528
- cells = adata.layers[self.layer_key]
529
- else:
530
- cells = adata.obsm[self.obsm_key].copy()
521
+ cells = adata.layers[self.layer_key] if self.layer_key else adata.obsm[self.obsm_key].copy()
531
522
  pwd = pairwise_distances(cells, cells, metric=self.cell_wise_metric, n_jobs=n_jobs)
532
523
  adata.obsp[f"{self.obsm_key}_{self.cell_wise_metric}_predistances"] = pwd
533
524
 
@@ -618,6 +609,7 @@ class AbstractDistance(ABC):
618
609
  Args:
619
610
  X: First vector of shape (n_samples, n_features).
620
611
  Y: Second vector of shape (n_samples, n_features).
612
+ kwargs: Passed to the metrics function.
621
613
 
622
614
  Returns:
623
615
  float: Distance between X and Y.
@@ -630,8 +622,8 @@ class AbstractDistance(ABC):
630
622
 
631
623
  Args:
632
624
  P: Pairwise distance matrix of shape (n_samples, n_samples).
633
- idx: Boolean array of shape (n_samples,) indicating which
634
- samples belong to X (or Y, since each metric is symmetric).
625
+ idx: Boolean array of shape (n_samples,) indicating which samples belong to X (or Y, since each metric is symmetric).
626
+ kwargs: Passed to the metrics function.
635
627
 
636
628
  Returns:
637
629
  float: Distance between X and Y.
@@ -645,12 +637,12 @@ class Edistance(AbstractDistance):
645
637
  def __init__(self) -> None:
646
638
  super().__init__()
647
639
  self.accepts_precomputed = True
648
- self.cell_wise_metric = "sqeuclidean"
640
+ self.cell_wise_metric = "euclidean"
649
641
 
650
642
  def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
651
- sigma_X = pairwise_distances(X, X, metric="sqeuclidean").mean()
652
- sigma_Y = pairwise_distances(Y, Y, metric="sqeuclidean").mean()
653
- delta = pairwise_distances(X, Y, metric="sqeuclidean").mean()
643
+ sigma_X = pairwise_distances(X, X, metric=self.cell_wise_metric, **kwargs).mean()
644
+ sigma_Y = pairwise_distances(Y, Y, metric=self.cell_wise_metric, **kwargs).mean()
645
+ delta = pairwise_distances(X, Y, metric=self.cell_wise_metric, **kwargs).mean()
654
646
  return 2 * delta - sigma_X - sigma_Y
655
647
 
656
648
  def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
@@ -881,7 +873,7 @@ class R2ScoreDistance(AbstractDistance):
881
873
 
882
874
 
883
875
  class SymmetricKLDivergence(AbstractDistance):
884
- """Average of symmetric KL divergence between gene distributions of two groups
876
+ """Average of symmetric KL divergence between gene distributions of two groups.
885
877
 
886
878
  Assuming a Gaussian distribution for each gene in each group, calculates
887
879
  the KL divergence between them and averages over all genes. Repeats this ABBA to get a symmetrized distance.
@@ -908,7 +900,7 @@ class SymmetricKLDivergence(AbstractDistance):
908
900
 
909
901
 
910
902
  class TTestDistance(AbstractDistance):
911
- """Average of T test statistic between two groups assuming unequal variances"""
903
+ """Average of T test statistic between two groups assuming unequal variances."""
912
904
 
913
905
  def __init__(self) -> None:
914
906
  super().__init__()
@@ -932,16 +924,14 @@ class TTestDistance(AbstractDistance):
932
924
 
933
925
 
934
926
  class KSTestDistance(AbstractDistance):
935
- """Average of two-sided KS test statistic between two groups"""
927
+ """Average of two-sided KS test statistic between two groups."""
936
928
 
937
929
  def __init__(self) -> None:
938
930
  super().__init__()
939
931
  self.accepts_precomputed = False
940
932
 
941
933
  def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
942
- stats = []
943
- for i in range(X.shape[1]):
944
- stats.append(abs(kstest(X[:, i], Y[:, i])[0]))
934
+ stats = [abs(kstest(X[:, i], Y[:, i])[0]) for i in range(X.shape[1])]
945
935
  return sum(stats) / len(stats)
946
936
 
947
937
  def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
@@ -949,10 +939,7 @@ class KSTestDistance(AbstractDistance):
949
939
 
950
940
 
951
941
  class NBLL(AbstractDistance):
952
- """
953
- Average of Log likelihood (scalar) of group B cells
954
- according to a NB distribution fitted over group A
955
- """
942
+ """Average of Log likelihood (scalar) of group B cells according to a NB distribution fitted over group A."""
956
943
 
957
944
  def __init__(self) -> None:
958
945
  super().__init__()
@@ -960,15 +947,12 @@ class NBLL(AbstractDistance):
960
947
 
961
948
  def __call__(self, X: np.ndarray, Y: np.ndarray, epsilon=1e-8, **kwargs) -> float:
962
949
  def _is_count_matrix(matrix, tolerance=1e-6):
963
- if matrix.dtype.kind == "i" or np.all(np.abs(matrix - np.round(matrix)) < tolerance):
964
- return True
965
- else:
966
- return False
950
+ return bool(matrix.dtype.kind == "i" or np.all(np.abs(matrix - np.round(matrix)) < tolerance))
967
951
 
968
952
  if not _is_count_matrix(matrix=X) or not _is_count_matrix(matrix=Y):
969
953
  raise ValueError("NBLL distance only works for raw counts.")
970
954
 
971
- @numba.jit(forceobj=True)
955
+ @jit(forceobj=True)
972
956
  def _compute_nll(y: np.ndarray, nb_params: tuple[float, float], epsilon: float) -> float:
973
957
  mu = np.exp(nb_params[0])
974
958
  theta = 1 / nb_params[1]
@@ -1117,67 +1101,77 @@ class MeanVarDistributionDistance(AbstractDistance):
1117
1101
  super().__init__()
1118
1102
  self.accepts_precomputed = False
1119
1103
 
1104
+ @staticmethod
1105
+ def _mean_var(x, log: bool = False):
1106
+ mean = np.mean(x, axis=0)
1107
+ var = np.var(x, axis=0)
1108
+ positive = mean > 0
1109
+ mean = mean[positive]
1110
+ var = var[positive]
1111
+ if log:
1112
+ mean = np.log(mean)
1113
+ var = np.log(var)
1114
+ return mean, var
1115
+
1116
+ @staticmethod
1117
+ def _prep_kde_data(x, y):
1118
+ return np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1)], axis=1)
1119
+
1120
+ @staticmethod
1121
+ def _grid_points(d, n_points=100):
1122
+ # Make grid, add 1 bin on lower/upper end to get final n_points
1123
+ d_min = d.min()
1124
+ d_max = d.max()
1125
+ # Compute bin size
1126
+ d_bin = (d_max - d_min) / (n_points - 2)
1127
+ d_min = d_min - d_bin
1128
+ d_max = d_max + d_bin
1129
+ return np.arange(start=d_min + 0.5 * d_bin, stop=d_max, step=d_bin)
1130
+
1131
+ @staticmethod
1132
+ def _kde_eval_both(x_kde, y_kde, grid):
1133
+ n_points = len(grid)
1134
+ chunk_size = 10000
1135
+
1136
+ result_x = np.zeros(n_points)
1137
+ result_y = np.zeros(n_points)
1138
+
1139
+ # Process same chunks for both KDEs
1140
+ for start in range(0, n_points, chunk_size):
1141
+ end = min(start + chunk_size, n_points)
1142
+ chunk = grid[start:end]
1143
+ result_x[start:end] = x_kde.score_samples(chunk)
1144
+ result_y[start:end] = y_kde.score_samples(chunk)
1145
+
1146
+ return result_x, result_y
1147
+
1120
1148
  def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
1121
1149
  """Difference of mean-var distributions in 2 matrices.
1122
1150
 
1123
1151
  Args:
1124
1152
  X: Normalized and log transformed cells x genes count matrix.
1125
1153
  Y: Normalized and log transformed cells x genes count matrix.
1154
+ kwargs: Passed to the metrics function.
1126
1155
  """
1156
+ mean_x, var_x = self._mean_var(X, log=True)
1157
+ mean_y, var_y = self._mean_var(Y, log=True)
1127
1158
 
1128
- def _mean_var(x, log: bool = False):
1129
- mean = np.mean(x, axis=0)
1130
- var = np.var(x, axis=0)
1131
- positive = mean > 0
1132
- mean = mean[positive]
1133
- var = var[positive]
1134
- if log:
1135
- mean = np.log(mean)
1136
- var = np.log(var)
1137
- return mean, var
1138
-
1139
- def _prep_kde_data(x, y):
1140
- return np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1)], axis=1)
1141
-
1142
- def _grid_points(d, n_points=100):
1143
- # Make grid, add 1 bin on lower/upper end to get final n_points
1144
- d_min = d.min()
1145
- d_max = d.max()
1146
- # Compute bin size
1147
- d_bin = (d_max - d_min) / (n_points - 2)
1148
- d_min = d_min - d_bin
1149
- d_max = d_max + d_bin
1150
- return np.arange(start=d_min + 0.5 * d_bin, stop=d_max, step=d_bin)
1151
-
1152
- def _parallel_score_samples(kde, samples, thread_count=int(0.875 * multiprocessing.cpu_count())):
1153
- # the thread_count is determined using the factor 0.875 as recommended here:
1154
- # https://stackoverflow.com/questions/32625094/scipy-parallel-computing-in-ipython-notebook
1155
- with multiprocessing.Pool(thread_count) as p:
1156
- return np.concatenate(p.map(kde.score_samples, np.array_split(samples, thread_count)))
1157
-
1158
- def _kde_eval(d, grid):
1159
- # Kernel choice: Gaussian is too smoothing and cosine or other kernels that do not stretch out
1160
- # can not be compared well on regions further away from the data as they are -inf
1161
- kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(d)
1162
- return _parallel_score_samples(kde, grid)
1163
-
1164
- mean_x, var_x = _mean_var(X, log=True)
1165
- mean_y, var_y = _mean_var(Y, log=True)
1166
-
1167
- x = _prep_kde_data(mean_x, var_x)
1168
- y = _prep_kde_data(mean_y, var_y)
1159
+ x = self._prep_kde_data(mean_x, var_x)
1160
+ y = self._prep_kde_data(mean_y, var_y)
1169
1161
 
1170
1162
  # Gridpoints to eval KDE on
1171
- mean_grid = _grid_points(np.concatenate([mean_x, mean_y]))
1172
- var_grid = _grid_points(np.concatenate([var_x, var_y]))
1163
+ mean_grid = self._grid_points(np.concatenate([mean_x, mean_y]))
1164
+ var_grid = self._grid_points(np.concatenate([var_x, var_y]))
1173
1165
  grid = np.array(np.meshgrid(mean_grid, var_grid)).T.reshape(-1, 2)
1174
1166
 
1175
- kde_x = _kde_eval(x, grid)
1176
- kde_y = _kde_eval(y, grid)
1167
+ # Fit both KDEs first
1168
+ x_kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(x)
1169
+ y_kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(y)
1177
1170
 
1178
- kde_diff = ((kde_x - kde_y) ** 2).mean()
1171
+ # Evaluate both KDEs on same grid chunks
1172
+ kde_x, kde_y = self._kde_eval_both(x_kde, y_kde, grid)
1179
1173
 
1180
- return kde_diff
1174
+ return ((np.exp(kde_x) - np.exp(kde_y)) ** 2).mean()
1181
1175
 
1182
1176
  def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
1183
1177
  raise NotImplementedError("MeanVarDistributionDistance cannot be called on a pairwise distance matrix.")
@@ -25,10 +25,7 @@ def _prepare_targets(
25
25
  categories: str | Sequence[str] = None,
26
26
  ) -> ChainMap | dict:
27
27
  if categories is not None:
28
- if isinstance(categories, str):
29
- categories = [categories]
30
- else:
31
- categories = list(categories)
28
+ categories = [categories] if isinstance(categories, str) else list(categories)
32
29
 
33
30
  if targets is None:
34
31
  pt_drug = Drug()
@@ -97,10 +94,7 @@ class Enrichment:
97
94
  Returns:
98
95
  An AnnData object with scores.
99
96
  """
100
- if layer is not None:
101
- mtx = adata.layers[layer]
102
- else:
103
- mtx = adata.X
97
+ mtx = adata.layers[layer] if layer is not None else adata.X
104
98
 
105
99
  targets = _prepare_targets(targets=targets, nested=nested, categories=categories) # type: ignore
106
100
  full_targets = targets.copy()
@@ -114,10 +108,7 @@ class Enrichment:
114
108
  weights = pd.DataFrame(targets, index=adata.var_names)
115
109
  weights = weights.loc[:, weights.sum() > 0]
116
110
  weights = weights / weights.sum()
117
- if issparse(mtx):
118
- scores = mtx.dot(weights)
119
- else:
120
- scores = np.dot(mtx, weights)
111
+ scores = mtx.dot(weights) if issparse(mtx) else np.dot(mtx, weights)
121
112
 
122
113
  if method == "seurat":
123
114
  obs_avg = _mean(mtx, names=adata.var_names, axis=0)
@@ -136,10 +127,7 @@ class Enrichment:
136
127
  control_gene_weights = pd.DataFrame(control_groups, index=adata.var_names)
137
128
  control_gene_weights = control_gene_weights / control_gene_weights.sum()
138
129
 
139
- if issparse(mtx):
140
- control_profiles = mtx.dot(control_gene_weights)
141
- else:
142
- control_profiles = np.dot(mtx, control_gene_weights)
130
+ control_profiles = mtx.dot(control_gene_weights) if issparse(mtx) else np.dot(mtx, control_gene_weights)
143
131
  drug_bins = {}
144
132
  for drug in weights.columns:
145
133
  bins = np.unique(obs_cut[targets[drug]])
@@ -178,7 +166,7 @@ class Enrichment:
178
166
  Accepts two forms:
179
167
  - A dictionary with the names of the groups as keys, and the entries being the corresponding gene lists.
180
168
  - A dictionary of dictionaries defined like above, with names of gene group categories as keys.
181
- If passing one of those, specify `nested=True`.
169
+ If passing one of those, specify `nested=True`.
182
170
  nested: Whether `targets` is a dictionary of dictionaries with group categories as keys.
183
171
  categories: If `targets=None` or `nested=True`, this argument can be used to subset the gene groups to one or more categories (keys of the original dictionary).
184
172
  In case of the ChEMBL drug targets, these are ATC level 1/level 2 category codes.
@@ -293,7 +281,7 @@ class Enrichment:
293
281
  return enrichment
294
282
 
295
283
  @_doc_params(common_plot_args=doc_common_plot_args)
296
- def plot_dotplot(
284
+ def plot_dotplot( # pragma: no cover # noqa: D417
297
285
  self,
298
286
  adata: AnnData,
299
287
  *,
@@ -304,7 +292,6 @@ class Enrichment:
304
292
  groupby: str = None,
305
293
  key: str = "pertpy_enrichment",
306
294
  ax: Axes | None = None,
307
- show: bool = True,
308
295
  return_fig: bool = False,
309
296
  **kwargs,
310
297
  ) -> DotPlot | None:
@@ -342,10 +329,7 @@ class Enrichment:
342
329
  .. image:: /_static/docstring_previews/enrichment_dotplot.png
343
330
  """
344
331
  if categories is not None:
345
- if isinstance(categories, str):
346
- categories = [categories]
347
- else:
348
- categories = list(categories)
332
+ categories = [categories] if isinstance(categories, str) else list(categories)
349
333
 
350
334
  if targets is None:
351
335
  pt_drug = Drug()
@@ -417,10 +401,9 @@ class Enrichment:
417
401
  **kwargs,
418
402
  )
419
403
 
420
- if show:
421
- plt.show()
422
404
  if return_fig:
423
405
  return fig
406
+ plt.show()
424
407
  return None
425
408
 
426
409
  def plot_gsea(