pertpy 0.10.0__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. pertpy/__init__.py +5 -1
  2. pertpy/_doc.py +1 -3
  3. pertpy/_types.py +6 -0
  4. pertpy/data/_dataloader.py +68 -24
  5. pertpy/data/_datasets.py +9 -9
  6. pertpy/metadata/__init__.py +2 -1
  7. pertpy/metadata/_cell_line.py +133 -25
  8. pertpy/metadata/_look_up.py +13 -19
  9. pertpy/metadata/_moa.py +1 -1
  10. pertpy/preprocessing/_guide_rna.py +138 -44
  11. pertpy/preprocessing/_guide_rna_mixture.py +17 -19
  12. pertpy/tools/__init__.py +1 -1
  13. pertpy/tools/_augur.py +106 -98
  14. pertpy/tools/_cinemaot.py +74 -114
  15. pertpy/tools/_coda/_base_coda.py +129 -145
  16. pertpy/tools/_coda/_sccoda.py +66 -69
  17. pertpy/tools/_coda/_tasccoda.py +71 -79
  18. pertpy/tools/_dialogue.py +48 -40
  19. pertpy/tools/_differential_gene_expression/_base.py +21 -31
  20. pertpy/tools/_differential_gene_expression/_checks.py +4 -6
  21. pertpy/tools/_differential_gene_expression/_dge_comparison.py +5 -6
  22. pertpy/tools/_differential_gene_expression/_edger.py +6 -10
  23. pertpy/tools/_differential_gene_expression/_pydeseq2.py +1 -1
  24. pertpy/tools/_differential_gene_expression/_simple_tests.py +3 -3
  25. pertpy/tools/_differential_gene_expression/_statsmodels.py +8 -5
  26. pertpy/tools/_distances/_distance_tests.py +1 -2
  27. pertpy/tools/_distances/_distances.py +31 -45
  28. pertpy/tools/_enrichment.py +7 -22
  29. pertpy/tools/_milo.py +19 -15
  30. pertpy/tools/_mixscape.py +73 -75
  31. pertpy/tools/_perturbation_space/_clustering.py +4 -4
  32. pertpy/tools/_perturbation_space/_comparison.py +4 -4
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +83 -32
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +10 -10
  35. pertpy/tools/_perturbation_space/_simple.py +12 -14
  36. pertpy/tools/_scgen/_scgen.py +16 -17
  37. pertpy/tools/_scgen/_scgenvae.py +2 -2
  38. pertpy/tools/_scgen/_utils.py +3 -1
  39. {pertpy-0.10.0.dist-info → pertpy-0.11.0.dist-info}/METADATA +36 -20
  40. pertpy-0.11.0.dist-info/RECORD +58 -0
  41. {pertpy-0.10.0.dist-info → pertpy-0.11.0.dist-info}/licenses/LICENSE +1 -0
  42. pertpy/tools/_kernel_pca.py +0 -50
  43. pertpy-0.10.0.dist-info/RECORD +0 -58
  44. {pertpy-0.10.0.dist-info → pertpy-0.11.0.dist-info}/WHEEL +0 -0
pertpy/tools/_cinemaot.py CHANGED
@@ -2,18 +2,19 @@ from __future__ import annotations
2
2
 
3
3
  from typing import TYPE_CHECKING
4
4
 
5
+ import jax
5
6
  import matplotlib.pyplot as plt
6
7
  import numpy as np
7
8
  import pandas as pd
8
9
  import scanpy as sc
9
10
  import scipy.stats as ss
10
- import seaborn as sns
11
11
  import sklearn.metrics
12
- from ott.geometry import pointcloud
12
+ from ott.geometry.pointcloud import PointCloud
13
13
  from ott.problems.linear import linear_problem
14
14
  from ott.solvers.linear import sinkhorn, sinkhorn_lr
15
15
  from scanpy.plotting import _utils
16
16
  from scipy.sparse import issparse
17
+ from seaborn import heatmap
17
18
  from sklearn.decomposition import FastICA
18
19
  from sklearn.linear_model import LinearRegression
19
20
  from sklearn.neighbors import NearestNeighbors
@@ -49,6 +50,7 @@ class Cinemaot:
49
50
  eps: float = 1e-3,
50
51
  solver: str = "Sinkhorn",
51
52
  preweight_label: str | None = None,
53
+ random_state: int | None = 0,
52
54
  ):
53
55
  """Calculate the confounding variation, optimal transport counterfactual pairs, and single-cell level treatment effects.
54
56
 
@@ -69,6 +71,7 @@ class Cinemaot:
69
71
  solver: Either "Sinkhorn" or "LRSinkhorn". The ott-jax solver used.
70
72
  preweight_label: The annotated label (e.g. cell type) that is used to assign weights for treated
71
73
  and control cells to balance across the label. Helps overcome the differential abundance issue.
74
+ random_state: The random seed for the shuffling that breaks rank ties at random in the Xi correlation.
72
75
 
73
76
  Returns:
74
77
  Returns an AnnData object that contains the single-cell level treatment effect as de.X and the
@@ -85,7 +88,7 @@ class Cinemaot:
85
88
  """
86
89
  available_solvers = ["Sinkhorn", "LRSinkhorn"]
87
90
  if solver not in available_solvers:
88
- raise ValueError(f"solver = {solver} is not one of the supported solvers:" f" {available_solvers}")
91
+ raise ValueError(f"solver = {solver} is not one of the supported solvers: {available_solvers}")
89
92
 
90
93
  if dim is None:
91
94
  dim = self.get_dim(adata, use_rep=use_rep)
@@ -96,7 +99,7 @@ class Cinemaot:
96
99
  xi = np.zeros(dim)
97
100
  j = 0
98
101
  for source_row in X_transformed.T:
99
- xi_obj = Xi(source_row, groupvec * 1)
102
+ xi_obj = Xi(source_row, groupvec * 1, random_state=random_state)
100
103
  xi[j] = xi_obj.correlation
101
104
  j = j + 1
102
105
 
@@ -111,7 +114,7 @@ class Cinemaot:
111
114
  sklearn.metrics.pairwise_distances(cf1, cf2)
112
115
 
113
116
  e = smoothness * sum(xi < thres)
114
- geom = pointcloud.PointCloud(cf1, cf2, epsilon=e, batch_size=batch_size)
117
+ geom = PointCloud(cf1, cf2, epsilon=e, batch_size=batch_size)
115
118
 
116
119
  if preweight_label is None:
117
120
  ot_prob = linear_problem.LinearProblem(geom, a=None, b=None)
@@ -122,17 +125,15 @@ class Cinemaot:
122
125
  a = np.zeros(cf1.shape[0])
123
126
  b = np.zeros(cf2.shape[0])
124
127
 
125
- adata1 = adata[adata.obs[pert_key] == control, :].copy()
126
- adata2 = adata[adata.obs[pert_key] != control, :].copy()
128
+ adata1 = adata[adata.obs[pert_key] == control]
129
+ adata2 = adata[adata.obs[pert_key] != control]
127
130
 
128
131
  for label in adata1.obs[pert_key].unique():
132
+ mask_label = adata1.obs[pert_key] == label
129
133
  for ct in adata1.obs[preweight_label].unique():
130
- a[((adata1.obs[preweight_label] == ct) & (adata1.obs[pert_key] == label)).values] = np.sum(
131
- (adata2.obs[preweight_label] == ct).values
132
- ) / np.sum((adata1.obs[preweight_label] == ct).values)
133
- a[(adata1.obs[pert_key] == label).values] = a[(adata1.obs[pert_key] == label).values] / np.sum(
134
- a[(adata1.obs[pert_key] == label).values]
135
- )
134
+ mask_ct = adata1.obs[preweight_label] == ct
135
+ a[mask_ct & mask_label] = np.sum(adata2.obs[preweight_label] == ct) / np.sum(mask_ct)
136
+ a[mask_label] /= np.sum(a[mask_label])
136
137
 
137
138
  a = a / np.sum(a)
138
139
  b[:] = 1 / cf2.shape[0]
@@ -141,25 +142,22 @@ class Cinemaot:
141
142
  if solver == "LRSinkhorn":
142
143
  if rank is None:
143
144
  rank = int(min(cf1.shape[0], cf2.shape[0]) / 2)
144
- _solver = sinkhorn_lr.LRSinkhorn(rank=rank, threshold=eps)
145
+ _solver = jax.jit(sinkhorn_lr.LRSinkhorn(rank=rank, threshold=eps))
145
146
  ot_sink = _solver(ot_prob)
146
147
  embedding = (
147
148
  X_transformed[adata.obs[pert_key] != control, :]
148
149
  - ot_sink.apply(X_transformed[adata.obs[pert_key] == control, :].T).T
149
150
  / ot_sink.apply(np.ones_like(X_transformed[adata.obs[pert_key] == control, :].T)).T
150
151
  )
151
- if issparse(adata.X):
152
- te2 = (
153
- adata.X.toarray()[adata.obs[pert_key] != control, :]
154
- - ot_sink.apply(adata.X.toarray()[adata.obs[pert_key] == control, :].T).T
155
- / ot_sink.apply(np.ones_like(adata.X.toarray()[adata.obs[pert_key] == control, :].T)).T
156
- )
157
- else:
158
- te2 = (
159
- adata.X[adata.obs[pert_key] != control, :]
160
- - ot_sink.apply(adata.X[adata.obs[pert_key] == control, :].T).T
161
- / ot_sink.apply(np.ones_like(adata.X[adata.obs[pert_key] == control, :].T)).T
162
- )
152
+
153
+ X = adata.X.toarray() if issparse(adata.X) else adata.X
154
+ te2 = (
155
+ X[adata.obs[pert_key] != control, :]
156
+ - ot_sink.apply(X[adata.obs[pert_key] == control, :].T).T
157
+ / ot_sink.apply(np.ones_like(X[adata.obs[pert_key] == control, :].T)).T
158
+ )
159
+ if issparse(X):
160
+ del X
163
161
 
164
162
  adata.obsm[cf_rep] = cf
165
163
  adata.obsm[cf_rep][adata.obs[pert_key] != control, :] = (
@@ -168,21 +166,20 @@ class Cinemaot:
168
166
  )
169
167
 
170
168
  else:
171
- _solver = sinkhorn.Sinkhorn(threshold=eps)
169
+ _solver = jax.jit(sinkhorn.Sinkhorn(threshold=eps))
172
170
  ot_sink = _solver(ot_prob)
173
171
  ot_matrix = np.array(ot_sink.matrix.T, dtype=np.float64)
174
172
  embedding = X_transformed[adata.obs[pert_key] != control, :] - np.matmul(
175
173
  ot_matrix / np.sum(ot_matrix, axis=1)[:, None], X_transformed[adata.obs[pert_key] == control, :]
176
174
  )
177
175
 
178
- if issparse(adata.X):
179
- te2 = adata.X.toarray()[adata.obs[pert_key] != control, :] - np.matmul(
180
- ot_matrix / np.sum(ot_matrix, axis=1)[:, None], adata.X.toarray()[adata.obs[pert_key] == control, :]
181
- )
182
- else:
183
- te2 = adata.X[adata.obs[pert_key] != control, :] - np.matmul(
184
- ot_matrix / np.sum(ot_matrix, axis=1)[:, None], adata.X[adata.obs[pert_key] == control, :]
185
- )
176
+ X = adata.X.toarray() if issparse(adata.X) else adata.X
177
+
178
+ te2 = X[adata.obs[pert_key] != control, :] - np.matmul(
179
+ ot_matrix / np.sum(ot_matrix, axis=1)[:, None], X[adata.obs[pert_key] == control, :]
180
+ )
181
+ if issparse(X):
182
+ del X
186
183
 
187
184
  adata.obsm[cf_rep] = cf
188
185
  adata.obsm[cf_rep][adata.obs[pert_key] != control, :] = np.matmul(
@@ -251,7 +248,7 @@ class Cinemaot:
251
248
  """
252
249
  available_solvers = ["Sinkhorn", "LRSinkhorn"]
253
250
  assert solver in available_solvers, (
254
- f"solver = {solver} is not one of the supported solvers:" f" {available_solvers}"
251
+ f"solver = {solver} is not one of the supported solvers: {available_solvers}"
255
252
  )
256
253
 
257
254
  if dim is None:
@@ -259,7 +256,9 @@ class Cinemaot:
259
256
 
260
257
  adata.obs_names_make_unique()
261
258
 
262
- idx = self.get_weightidx(adata, pert_key=pert_key, control=control, k=k, use_rep=use_rep, resolution=resolution)
259
+ idx = self._get_weightidx(
260
+ adata, pert_key=pert_key, control=control, k=k, use_rep=use_rep, resolution=resolution
261
+ )
263
262
  adata_ = adata[idx].copy()
264
263
  TE = self.causaleffect(
265
264
  adata_,
@@ -329,11 +328,10 @@ class Cinemaot:
329
328
  df = pd.DataFrame(adata.raw.X.toarray(), columns=adata.raw.var_names, index=adata.raw.obs_names)
330
329
  else:
331
330
  df = pd.DataFrame(adata.raw.X, columns=adata.raw.var_names, index=adata.raw.obs_names)
331
+ elif issparse(adata.X):
332
+ df = pd.DataFrame(adata.X.toarray(), columns=adata.var_names, index=adata.obs_names)
332
333
  else:
333
- if issparse(adata.X):
334
- df = pd.DataFrame(adata.X.toarray(), columns=adata.var_names, index=adata.obs_names)
335
- else:
336
- df = pd.DataFrame(adata.X, columns=adata.var_names, index=adata.obs_names)
334
+ df = pd.DataFrame(adata.X, columns=adata.var_names, index=adata.obs_names)
337
335
 
338
336
  if label_list is None:
339
337
  label_list = ["ct"]
@@ -377,10 +375,7 @@ class Cinemaot:
377
375
  >>> dim = model.get_dim(adata)
378
376
  """
379
377
  sk = SinkhornKnopp()
380
- if issparse(adata.raw.X):
381
- data = adata.raw.X.toarray()
382
- else:
383
- data = adata.raw.X
378
+ data = adata.raw.X.toarray() if issparse(adata.raw.X) else adata.raw.X
384
379
  vm = (1e-3 + data + c * data * data) / (1 + c)
385
380
  sk.fit(vm)
386
381
  wm = np.dot(np.dot(np.sqrt(sk._D1), vm), np.sqrt(sk._D2))
@@ -388,9 +383,10 @@ class Cinemaot:
388
383
  dim = min(sum(s > (np.sqrt(data.shape[0]) + np.sqrt(data.shape[1]))), adata.obsm[use_rep].shape[1])
389
384
  return dim
390
385
 
391
- def get_weightidx(
386
+ def _get_weightidx(
392
387
  self,
393
388
  adata: AnnData,
389
+ *,
394
390
  pert_key: str,
395
391
  control: str,
396
392
  use_rep: str = "X_pca",
@@ -401,7 +397,8 @@ class Cinemaot:
401
397
 
402
398
  Args:
403
399
  adata: The annotated data object.
404
- c: the parameter regarding the quadratic variance distribution. c=0 means Poisson count matrices.
400
+ pert_key: Key of the perturbation column in adata.obs.
401
+ control: Label of the control cells in the pert_key column.
405
402
  use_rep: the embedding used to give an upper bound for the estimated rank.
406
403
  k: the number of neighbors used in the k-NN matching phase.
407
404
  resolution: the clustering resolution used in the sampling phase.
@@ -413,7 +410,7 @@ class Cinemaot:
413
410
  >>> import pertpy as pt
414
411
  >>> adata = pt.dt.cinemaot_example()
415
412
  >>> model = pt.tl.Cinemaot()
416
- >>> idx = model.get_weightidx(adata, pert_key="perturbation", control="No stimulation")
413
+ >>> idx = model._get_weightidx(adata, pert_key="perturbation", control="No stimulation")
417
414
  """
418
415
  adata_ = adata.copy()
419
416
  X_pca1 = adata_.obsm[use_rep][adata_.obs[pert_key] == control, :]
@@ -621,13 +618,12 @@ class Cinemaot:
621
618
  else:
622
619
  Y0 = adata.raw.X[adata.obs[pert_key] == control, :]
623
620
  Y1 = adata.raw.X[adata.obs[pert_key] != control, :]
621
+ elif issparse(adata.X):
622
+ Y0 = adata.X.toarray()[adata.obs[pert_key] == control, :]
623
+ Y1 = adata.X.toarray()[adata.obs[pert_key] != control, :]
624
624
  else:
625
- if issparse(adata.X):
626
- Y0 = adata.X.toarray()[adata.obs[pert_key] == control, :]
627
- Y1 = adata.X.toarray()[adata.obs[pert_key] != control, :]
628
- else:
629
- Y0 = adata.X[adata.obs[pert_key] == control, :]
630
- Y1 = adata.X[adata.obs[pert_key] != control, :]
625
+ Y0 = adata.X[adata.obs[pert_key] == control, :]
626
+ Y1 = adata.X[adata.obs[pert_key] != control, :]
631
627
  X0 = cf[adata.obs[pert_key] == control, :]
632
628
  X1 = cf[adata.obs[pert_key] != control, :]
633
629
  ols0 = LinearRegression()
@@ -643,7 +639,7 @@ class Cinemaot:
643
639
  return c_effect, s_effect
644
640
 
645
641
  @_doc_params(common_plot_args=doc_common_plot_args)
646
- def plot_vis_matching(
642
+ def plot_vis_matching( # pragma: no cover # noqa: D417
647
643
  self,
648
644
  adata: AnnData,
649
645
  de: AnnData,
@@ -671,9 +667,11 @@ class Cinemaot:
671
667
  de_label: the label for differential response. If none, use leiden cluster labels at resolution 1.0.
672
668
  source_label: the confounder / cell type label.
673
669
  matching_rep: the place that stores the matching matrix. default de.obsm['ot'].
670
+ resolution: Leiden resolution.
674
671
  normalize: normalize the coarse-grained matching matrix by row / column.
675
672
  title: the title for the figure.
676
673
  min_val: The min value to truncate the matching matrix.
674
+ ax: Matplotlib axes object.
677
675
  {common_plot_args}
678
676
  **kwargs: Other parameters to input for seaborn.heatmap.
679
677
 
@@ -703,17 +701,11 @@ class Cinemaot:
703
701
  df["source_label"] = adata_.obs[source_label].astype(str).values
704
702
  df = df.groupby("source_label").sum()
705
703
 
706
- if normalize == "col":
707
- df = df / df.sum(axis=0)
708
- else:
709
- df = (df.T / df.sum(axis=1)).T
704
+ df = df / df.sum(axis=0) if normalize == "col" else (df.T / df.sum(axis=1)).T
710
705
  df = df.clip(lower=min_val) - min_val
711
- if normalize == "col":
712
- df = df / df.sum(axis=0)
713
- else:
714
- df = (df.T / df.sum(axis=1)).T
706
+ df = df / df.sum(axis=0) if normalize == "col" else (df.T / df.sum(axis=1)).T
715
707
 
716
- g = sns.heatmap(df, annot=True, ax=ax, **kwargs)
708
+ g = heatmap(df, annot=True, ax=ax, **kwargs)
717
709
  plt.title(title)
718
710
 
719
711
  if return_fig:
@@ -723,14 +715,12 @@ class Cinemaot:
723
715
 
724
716
 
725
717
  class Xi:
726
- """
727
- A fast implementation of cross-rank dependence metric used in CINEMA-OT.
728
-
729
- """
718
+ """A fast implementation of cross-rank dependence metric used in CINEMA-OT."""
730
719
 
731
- def __init__(self, x, y):
720
+ def __init__(self, x, y, random_state: int | None = 0):
732
721
  self.x = x
733
722
  self.y = y
723
+ self.random_state = random_state
734
724
 
735
725
  @property
736
726
  def sample_size(self):
@@ -739,18 +729,15 @@ class Xi:
739
729
  @property
740
730
  def x_ordered_rank(self):
741
731
  # PI is the rank vector for x, with ties broken at random
742
- # Not mine: source (https://stackoverflow.com/a/47430384/1628971)
743
- # random shuffling of the data - reason to use random.choice is that
744
- # pd.sample(frac=1) uses the same randomizing algorithm
745
732
  len_x = len(self.x)
746
- rng = np.random.default_rng()
747
- randomized_indices = rng.choice(np.arange(len_x), len_x, replace=False)
748
- randomized = [self.x[idx] for idx in randomized_indices]
749
- # same as pandas rank method 'first'
750
- rankdata = ss.rankdata(randomized, method="ordinal")
751
- # Reindexing based on pairs of indices before and after
752
- unrandomized = [rankdata[j] for i, j in sorted(zip(randomized_indices, range(len_x), strict=False))]
753
- return unrandomized
733
+ rng = np.random.default_rng(self.random_state)
734
+ perm = rng.permutation(len_x)
735
+ x_shuffled = self.x[perm]
736
+
737
+ ranks = np.empty(len_x, dtype=int)
738
+ ranks[perm[np.argsort(x_shuffled, stable=True)]] = np.arange(1, len_x + 1)
739
+
740
+ return ranks
754
741
 
755
742
  @property
756
743
  def y_rank_max(self):
@@ -760,7 +747,7 @@ class Xi:
760
747
  @property
761
748
  def g(self):
762
749
  # g[i] is number of j s.t. y[j] >= y[i], divided by n.
763
- return ss.rankdata([-i for i in self.y], method="max") / self.sample_size
750
+ return ss.rankdata(-self.y, method="max") / self.sample_size
764
751
 
765
752
  @property
766
753
  def x_ordered(self):
@@ -769,32 +756,14 @@ class Xi:
769
756
 
770
757
  @property
771
758
  def x_rank_max_ordered(self):
772
- x_ordered_result = self.x_ordered
773
- y_rank_max_result = self.y_rank_max
774
- # Rearrange f according to ord.
775
- return [y_rank_max_result[i] for i in x_ordered_result]
759
+ return self.y_rank_max[self.x_ordered]
776
760
 
777
761
  @property
778
762
  def mean_absolute(self):
779
763
  x1 = self.x_rank_max_ordered[0 : (self.sample_size - 1)]
780
764
  x2 = self.x_rank_max_ordered[1 : self.sample_size]
781
765
 
782
- return (
783
- np.mean(
784
- np.abs(
785
- [
786
- x - y
787
- for x, y in zip(
788
- x1,
789
- x2,
790
- strict=False,
791
- )
792
- ]
793
- )
794
- )
795
- * (self.sample_size - 1)
796
- / (2 * self.sample_size)
797
- )
766
+ return np.mean(np.abs(x1 - x2)) * (self.sample_size - 1) / (2 * self.sample_size)
798
767
 
799
768
  @property
800
769
  def inverse_g_mean(self):
@@ -803,7 +772,7 @@ class Xi:
803
772
 
804
773
  @property
805
774
  def correlation(self):
806
- """xi correlation"""
775
+ """Xi correlation."""
807
776
  return 1 - self.mean_absolute / self.inverse_g_mean
808
777
 
809
778
  @classmethod
@@ -852,10 +821,7 @@ class Xi:
852
821
 
853
822
 
854
823
  class SinkhornKnopp:
855
- """
856
- An simple implementation of Sinkhorn iteration used in the biwhitening approach.
857
-
858
- """
824
+ """A simple implementation of Sinkhorn iteration used in the biwhitening approach."""
859
825
 
860
826
  def __init__(self, max_iter: float = 1000, setr: int = 0, setc: float = 0, epsilon: float = 1e-3):
861
827
  if max_iter < 0:
@@ -889,14 +855,8 @@ class SinkhornKnopp:
889
855
  assert P.ndim == 2
890
856
 
891
857
  N = P.shape[0]
892
- if np.sum(abs(self._setr)) == 0:
893
- rsum = P.shape[1]
894
- else:
895
- rsum = self._setr
896
- if np.sum(abs(self._setc)) == 0:
897
- csum = P.shape[0]
898
- else:
899
- csum = self._setc
858
+ rsum = P.shape[1] if np.sum(abs(self._setr)) == 0 else self._setr
859
+ csum = P.shape[0] if np.sum(abs(self._setc)) == 0 else self._setc
900
860
  max_threshr = rsum + self._epsilon
901
861
  min_threshr = rsum - self._epsilon
902
862
  max_threshc = csum + self._epsilon