pertpy-0.9.5-py3-none-any.whl → pertpy-0.11.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. pertpy/__init__.py +5 -1
  2. pertpy/_doc.py +2 -5
  3. pertpy/_types.py +6 -0
  4. pertpy/data/_dataloader.py +68 -24
  5. pertpy/data/_datasets.py +9 -9
  6. pertpy/metadata/__init__.py +2 -1
  7. pertpy/metadata/_cell_line.py +136 -30
  8. pertpy/metadata/_look_up.py +13 -19
  9. pertpy/metadata/_moa.py +1 -1
  10. pertpy/preprocessing/_guide_rna.py +221 -39
  11. pertpy/preprocessing/_guide_rna_mixture.py +177 -0
  12. pertpy/tools/__init__.py +1 -1
  13. pertpy/tools/_augur.py +138 -142
  14. pertpy/tools/_cinemaot.py +75 -117
  15. pertpy/tools/_coda/_base_coda.py +150 -174
  16. pertpy/tools/_coda/_sccoda.py +66 -69
  17. pertpy/tools/_coda/_tasccoda.py +71 -79
  18. pertpy/tools/_dialogue.py +60 -56
  19. pertpy/tools/_differential_gene_expression/_base.py +25 -43
  20. pertpy/tools/_differential_gene_expression/_checks.py +4 -6
  21. pertpy/tools/_differential_gene_expression/_dge_comparison.py +5 -6
  22. pertpy/tools/_differential_gene_expression/_edger.py +6 -10
  23. pertpy/tools/_differential_gene_expression/_pydeseq2.py +1 -1
  24. pertpy/tools/_differential_gene_expression/_simple_tests.py +3 -3
  25. pertpy/tools/_differential_gene_expression/_statsmodels.py +8 -5
  26. pertpy/tools/_distances/_distance_tests.py +1 -2
  27. pertpy/tools/_distances/_distances.py +86 -92
  28. pertpy/tools/_enrichment.py +8 -25
  29. pertpy/tools/_milo.py +23 -27
  30. pertpy/tools/_mixscape.py +261 -175
  31. pertpy/tools/_perturbation_space/_clustering.py +4 -4
  32. pertpy/tools/_perturbation_space/_comparison.py +4 -4
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +83 -32
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +10 -10
  35. pertpy/tools/_perturbation_space/_simple.py +13 -17
  36. pertpy/tools/_scgen/_scgen.py +17 -20
  37. pertpy/tools/_scgen/_scgenvae.py +2 -2
  38. pertpy/tools/_scgen/_utils.py +3 -1
  39. {pertpy-0.9.5.dist-info → pertpy-0.11.0.dist-info}/METADATA +37 -21
  40. pertpy-0.11.0.dist-info/RECORD +58 -0
  41. {pertpy-0.9.5.dist-info → pertpy-0.11.0.dist-info}/licenses/LICENSE +1 -0
  42. pertpy/tools/_kernel_pca.py +0 -50
  43. pertpy-0.9.5.dist-info/RECORD +0 -57
  44. {pertpy-0.9.5.dist-info → pertpy-0.11.0.dist-info}/WHEEL +0 -0
pertpy/tools/_cinemaot.py CHANGED
@@ -2,18 +2,19 @@ from __future__ import annotations
2
2
 
3
3
  from typing import TYPE_CHECKING
4
4
 
5
+ import jax
5
6
  import matplotlib.pyplot as plt
6
7
  import numpy as np
7
8
  import pandas as pd
8
9
  import scanpy as sc
9
10
  import scipy.stats as ss
10
- import seaborn as sns
11
11
  import sklearn.metrics
12
- from ott.geometry import pointcloud
12
+ from ott.geometry.pointcloud import PointCloud
13
13
  from ott.problems.linear import linear_problem
14
14
  from ott.solvers.linear import sinkhorn, sinkhorn_lr
15
15
  from scanpy.plotting import _utils
16
16
  from scipy.sparse import issparse
17
+ from seaborn import heatmap
17
18
  from sklearn.decomposition import FastICA
18
19
  from sklearn.linear_model import LinearRegression
19
20
  from sklearn.neighbors import NearestNeighbors
@@ -49,6 +50,7 @@ class Cinemaot:
49
50
  eps: float = 1e-3,
50
51
  solver: str = "Sinkhorn",
51
52
  preweight_label: str | None = None,
53
+ random_state: int | None = 0,
52
54
  ):
53
55
  """Calculate the confounding variation, optimal transport counterfactual pairs, and single-cell level treatment effects.
54
56
 
@@ -69,6 +71,7 @@ class Cinemaot:
69
71
  solver: Either "Sinkhorn" or "LRSinkhorn". The ott-jax solver used.
70
72
  preweight_label: The annotated label (e.g. cell type) that is used to assign weights for treated
71
73
  and control cells to balance across the label. Helps overcome the differential abundance issue.
74
+ random_state: The random seed for the shuffling.
72
75
 
73
76
  Returns:
74
77
  Returns an AnnData object that contains the single-cell level treatment effect as de.X and the
@@ -85,7 +88,7 @@ class Cinemaot:
85
88
  """
86
89
  available_solvers = ["Sinkhorn", "LRSinkhorn"]
87
90
  if solver not in available_solvers:
88
- raise ValueError(f"solver = {solver} is not one of the supported solvers:" f" {available_solvers}")
91
+ raise ValueError(f"solver = {solver} is not one of the supported solvers: {available_solvers}")
89
92
 
90
93
  if dim is None:
91
94
  dim = self.get_dim(adata, use_rep=use_rep)
@@ -96,7 +99,7 @@ class Cinemaot:
96
99
  xi = np.zeros(dim)
97
100
  j = 0
98
101
  for source_row in X_transformed.T:
99
- xi_obj = Xi(source_row, groupvec * 1)
102
+ xi_obj = Xi(source_row, groupvec * 1, random_state=random_state)
100
103
  xi[j] = xi_obj.correlation
101
104
  j = j + 1
102
105
 
@@ -111,7 +114,7 @@ class Cinemaot:
111
114
  sklearn.metrics.pairwise_distances(cf1, cf2)
112
115
 
113
116
  e = smoothness * sum(xi < thres)
114
- geom = pointcloud.PointCloud(cf1, cf2, epsilon=e, batch_size=batch_size)
117
+ geom = PointCloud(cf1, cf2, epsilon=e, batch_size=batch_size)
115
118
 
116
119
  if preweight_label is None:
117
120
  ot_prob = linear_problem.LinearProblem(geom, a=None, b=None)
@@ -122,17 +125,15 @@ class Cinemaot:
122
125
  a = np.zeros(cf1.shape[0])
123
126
  b = np.zeros(cf2.shape[0])
124
127
 
125
- adata1 = adata[adata.obs[pert_key] == control, :].copy()
126
- adata2 = adata[adata.obs[pert_key] != control, :].copy()
128
+ adata1 = adata[adata.obs[pert_key] == control]
129
+ adata2 = adata[adata.obs[pert_key] != control]
127
130
 
128
131
  for label in adata1.obs[pert_key].unique():
132
+ mask_label = adata1.obs[pert_key] == label
129
133
  for ct in adata1.obs[preweight_label].unique():
130
- a[((adata1.obs[preweight_label] == ct) & (adata1.obs[pert_key] == label)).values] = np.sum(
131
- (adata2.obs[preweight_label] == ct).values
132
- ) / np.sum((adata1.obs[preweight_label] == ct).values)
133
- a[(adata1.obs[pert_key] == label).values] = a[(adata1.obs[pert_key] == label).values] / np.sum(
134
- a[(adata1.obs[pert_key] == label).values]
135
- )
134
+ mask_ct = adata1.obs[preweight_label] == ct
135
+ a[mask_ct & mask_label] = np.sum(adata2.obs[preweight_label] == ct) / np.sum(mask_ct)
136
+ a[mask_label] /= np.sum(a[mask_label])
136
137
 
137
138
  a = a / np.sum(a)
138
139
  b[:] = 1 / cf2.shape[0]
@@ -141,25 +142,22 @@ class Cinemaot:
141
142
  if solver == "LRSinkhorn":
142
143
  if rank is None:
143
144
  rank = int(min(cf1.shape[0], cf2.shape[0]) / 2)
144
- _solver = sinkhorn_lr.LRSinkhorn(rank=rank, threshold=eps)
145
+ _solver = jax.jit(sinkhorn_lr.LRSinkhorn(rank=rank, threshold=eps))
145
146
  ot_sink = _solver(ot_prob)
146
147
  embedding = (
147
148
  X_transformed[adata.obs[pert_key] != control, :]
148
149
  - ot_sink.apply(X_transformed[adata.obs[pert_key] == control, :].T).T
149
150
  / ot_sink.apply(np.ones_like(X_transformed[adata.obs[pert_key] == control, :].T)).T
150
151
  )
151
- if issparse(adata.X):
152
- te2 = (
153
- adata.X.toarray()[adata.obs[pert_key] != control, :]
154
- - ot_sink.apply(adata.X.toarray()[adata.obs[pert_key] == control, :].T).T
155
- / ot_sink.apply(np.ones_like(adata.X.toarray()[adata.obs[pert_key] == control, :].T)).T
156
- )
157
- else:
158
- te2 = (
159
- adata.X[adata.obs[pert_key] != control, :]
160
- - ot_sink.apply(adata.X[adata.obs[pert_key] == control, :].T).T
161
- / ot_sink.apply(np.ones_like(adata.X[adata.obs[pert_key] == control, :].T)).T
162
- )
152
+
153
+ X = adata.X.toarray() if issparse(adata.X) else adata.X
154
+ te2 = (
155
+ X[adata.obs[pert_key] != control, :]
156
+ - ot_sink.apply(X[adata.obs[pert_key] == control, :].T).T
157
+ / ot_sink.apply(np.ones_like(X[adata.obs[pert_key] == control, :].T)).T
158
+ )
159
+ if issparse(X):
160
+ del X
163
161
 
164
162
  adata.obsm[cf_rep] = cf
165
163
  adata.obsm[cf_rep][adata.obs[pert_key] != control, :] = (
@@ -168,21 +166,20 @@ class Cinemaot:
168
166
  )
169
167
 
170
168
  else:
171
- _solver = sinkhorn.Sinkhorn(threshold=eps)
169
+ _solver = jax.jit(sinkhorn.Sinkhorn(threshold=eps))
172
170
  ot_sink = _solver(ot_prob)
173
171
  ot_matrix = np.array(ot_sink.matrix.T, dtype=np.float64)
174
172
  embedding = X_transformed[adata.obs[pert_key] != control, :] - np.matmul(
175
173
  ot_matrix / np.sum(ot_matrix, axis=1)[:, None], X_transformed[adata.obs[pert_key] == control, :]
176
174
  )
177
175
 
178
- if issparse(adata.X):
179
- te2 = adata.X.toarray()[adata.obs[pert_key] != control, :] - np.matmul(
180
- ot_matrix / np.sum(ot_matrix, axis=1)[:, None], adata.X.toarray()[adata.obs[pert_key] == control, :]
181
- )
182
- else:
183
- te2 = adata.X[adata.obs[pert_key] != control, :] - np.matmul(
184
- ot_matrix / np.sum(ot_matrix, axis=1)[:, None], adata.X[adata.obs[pert_key] == control, :]
185
- )
176
+ X = adata.X.toarray() if issparse(adata.X) else adata.X
177
+
178
+ te2 = X[adata.obs[pert_key] != control, :] - np.matmul(
179
+ ot_matrix / np.sum(ot_matrix, axis=1)[:, None], X[adata.obs[pert_key] == control, :]
180
+ )
181
+ if issparse(X):
182
+ del X
186
183
 
187
184
  adata.obsm[cf_rep] = cf
188
185
  adata.obsm[cf_rep][adata.obs[pert_key] != control, :] = np.matmul(
@@ -251,7 +248,7 @@ class Cinemaot:
251
248
  """
252
249
  available_solvers = ["Sinkhorn", "LRSinkhorn"]
253
250
  assert solver in available_solvers, (
254
- f"solver = {solver} is not one of the supported solvers:" f" {available_solvers}"
251
+ f"solver = {solver} is not one of the supported solvers: {available_solvers}"
255
252
  )
256
253
 
257
254
  if dim is None:
@@ -259,7 +256,9 @@ class Cinemaot:
259
256
 
260
257
  adata.obs_names_make_unique()
261
258
 
262
- idx = self.get_weightidx(adata, pert_key=pert_key, control=control, k=k, use_rep=use_rep, resolution=resolution)
259
+ idx = self._get_weightidx(
260
+ adata, pert_key=pert_key, control=control, k=k, use_rep=use_rep, resolution=resolution
261
+ )
263
262
  adata_ = adata[idx].copy()
264
263
  TE = self.causaleffect(
265
264
  adata_,
@@ -329,11 +328,10 @@ class Cinemaot:
329
328
  df = pd.DataFrame(adata.raw.X.toarray(), columns=adata.raw.var_names, index=adata.raw.obs_names)
330
329
  else:
331
330
  df = pd.DataFrame(adata.raw.X, columns=adata.raw.var_names, index=adata.raw.obs_names)
331
+ elif issparse(adata.X):
332
+ df = pd.DataFrame(adata.X.toarray(), columns=adata.var_names, index=adata.obs_names)
332
333
  else:
333
- if issparse(adata.X):
334
- df = pd.DataFrame(adata.X.toarray(), columns=adata.var_names, index=adata.obs_names)
335
- else:
336
- df = pd.DataFrame(adata.X, columns=adata.var_names, index=adata.obs_names)
334
+ df = pd.DataFrame(adata.X, columns=adata.var_names, index=adata.obs_names)
337
335
 
338
336
  if label_list is None:
339
337
  label_list = ["ct"]
@@ -377,10 +375,7 @@ class Cinemaot:
377
375
  >>> dim = model.get_dim(adata)
378
376
  """
379
377
  sk = SinkhornKnopp()
380
- if issparse(adata.raw.X):
381
- data = adata.raw.X.toarray()
382
- else:
383
- data = adata.raw.X
378
+ data = adata.raw.X.toarray() if issparse(adata.raw.X) else adata.raw.X
384
379
  vm = (1e-3 + data + c * data * data) / (1 + c)
385
380
  sk.fit(vm)
386
381
  wm = np.dot(np.dot(np.sqrt(sk._D1), vm), np.sqrt(sk._D2))
@@ -388,9 +383,10 @@ class Cinemaot:
388
383
  dim = min(sum(s > (np.sqrt(data.shape[0]) + np.sqrt(data.shape[1]))), adata.obsm[use_rep].shape[1])
389
384
  return dim
390
385
 
391
- def get_weightidx(
386
+ def _get_weightidx(
392
387
  self,
393
388
  adata: AnnData,
389
+ *,
394
390
  pert_key: str,
395
391
  control: str,
396
392
  use_rep: str = "X_pca",
@@ -401,7 +397,8 @@ class Cinemaot:
401
397
 
402
398
  Args:
403
399
  adata: The annotated data object.
404
- c: the parameter regarding the quadratic variance distribution. c=0 means Poisson count matrices.
400
+ pert_key: Key of the perturbation col.
401
+ control: Key of the control col.
405
402
  use_rep: the embedding used to give a upper bound for the estimated rank.
406
403
  k: the number of neighbors used in the k-NN matching phase.
407
404
  resolution: the clustering resolution used in the sampling phase.
@@ -413,7 +410,7 @@ class Cinemaot:
413
410
  >>> import pertpy as pt
414
411
  >>> adata = pt.dt.cinemaot_example()
415
412
  >>> model = pt.tl.Cinemaot()
416
- >>> idx = model.get_weightidx(adata, pert_key="perturbation", control="No stimulation")
413
+ >>> idx = model._get_weightidx(adata, pert_key="perturbation", control="No stimulation")
417
414
  """
418
415
  adata_ = adata.copy()
419
416
  X_pca1 = adata_.obsm[use_rep][adata_.obs[pert_key] == control, :]
@@ -621,13 +618,12 @@ class Cinemaot:
621
618
  else:
622
619
  Y0 = adata.raw.X[adata.obs[pert_key] == control, :]
623
620
  Y1 = adata.raw.X[adata.obs[pert_key] != control, :]
621
+ elif issparse(adata.X):
622
+ Y0 = adata.X.toarray()[adata.obs[pert_key] == control, :]
623
+ Y1 = adata.X.toarray()[adata.obs[pert_key] != control, :]
624
624
  else:
625
- if issparse(adata.X):
626
- Y0 = adata.X.toarray()[adata.obs[pert_key] == control, :]
627
- Y1 = adata.X.toarray()[adata.obs[pert_key] != control, :]
628
- else:
629
- Y0 = adata.X[adata.obs[pert_key] == control, :]
630
- Y1 = adata.X[adata.obs[pert_key] != control, :]
625
+ Y0 = adata.X[adata.obs[pert_key] == control, :]
626
+ Y1 = adata.X[adata.obs[pert_key] != control, :]
631
627
  X0 = cf[adata.obs[pert_key] == control, :]
632
628
  X1 = cf[adata.obs[pert_key] != control, :]
633
629
  ols0 = LinearRegression()
@@ -643,7 +639,7 @@ class Cinemaot:
643
639
  return c_effect, s_effect
644
640
 
645
641
  @_doc_params(common_plot_args=doc_common_plot_args)
646
- def plot_vis_matching(
642
+ def plot_vis_matching( # pragma: no cover # noqa: D417
647
643
  self,
648
644
  adata: AnnData,
649
645
  de: AnnData,
@@ -658,7 +654,6 @@ class Cinemaot:
658
654
  title: str = "CINEMA-OT matching matrix",
659
655
  min_val: float = 0.01,
660
656
  ax: Axes | None = None,
661
- show: bool = True,
662
657
  return_fig: bool = False,
663
658
  **kwargs,
664
659
  ) -> Figure | None:
@@ -672,9 +667,11 @@ class Cinemaot:
672
667
  de_label: the label for differential response. If none, use leiden cluster labels at resolution 1.0.
673
668
  source_label: the confounder / cell type label.
674
669
  matching_rep: the place that stores the matching matrix. default de.obsm['ot'].
670
+ resolution: Leiden resolution.
675
671
  normalize: normalize the coarse-grained matching matrix by row / column.
676
672
  title: the title for the figure.
677
673
  min_val: The min value to truncate the matching matrix.
674
+ ax: Matplotlib axes object.
678
675
  {common_plot_args}
679
676
  **kwargs: Other parameters to input for seaborn.heatmap.
680
677
 
@@ -704,35 +701,26 @@ class Cinemaot:
704
701
  df["source_label"] = adata_.obs[source_label].astype(str).values
705
702
  df = df.groupby("source_label").sum()
706
703
 
707
- if normalize == "col":
708
- df = df / df.sum(axis=0)
709
- else:
710
- df = (df.T / df.sum(axis=1)).T
704
+ df = df / df.sum(axis=0) if normalize == "col" else (df.T / df.sum(axis=1)).T
711
705
  df = df.clip(lower=min_val) - min_val
712
- if normalize == "col":
713
- df = df / df.sum(axis=0)
714
- else:
715
- df = (df.T / df.sum(axis=1)).T
706
+ df = df / df.sum(axis=0) if normalize == "col" else (df.T / df.sum(axis=1)).T
716
707
 
717
- g = sns.heatmap(df, annot=True, ax=ax, **kwargs)
708
+ g = heatmap(df, annot=True, ax=ax, **kwargs)
718
709
  plt.title(title)
719
710
 
720
- if show:
721
- plt.show()
722
711
  if return_fig:
723
712
  return g
713
+ plt.show()
724
714
  return None
725
715
 
726
716
 
727
717
  class Xi:
728
- """
729
- A fast implementation of cross-rank dependence metric used in CINEMA-OT.
730
-
731
- """
718
+ """A fast implementation of cross-rank dependence metric used in CINEMA-OT."""
732
719
 
733
- def __init__(self, x, y):
720
+ def __init__(self, x, y, random_state: int | None = 0):
734
721
  self.x = x
735
722
  self.y = y
723
+ self.random_state = random_state
736
724
 
737
725
  @property
738
726
  def sample_size(self):
@@ -741,18 +729,15 @@ class Xi:
741
729
  @property
742
730
  def x_ordered_rank(self):
743
731
  # PI is the rank vector for x, with ties broken at random
744
- # Not mine: source (https://stackoverflow.com/a/47430384/1628971)
745
- # random shuffling of the data - reason to use random.choice is that
746
- # pd.sample(frac=1) uses the same randomizing algorithm
747
732
  len_x = len(self.x)
748
- rng = np.random.default_rng()
749
- randomized_indices = rng.choice(np.arange(len_x), len_x, replace=False)
750
- randomized = [self.x[idx] for idx in randomized_indices]
751
- # same as pandas rank method 'first'
752
- rankdata = ss.rankdata(randomized, method="ordinal")
753
- # Reindexing based on pairs of indices before and after
754
- unrandomized = [rankdata[j] for i, j in sorted(zip(randomized_indices, range(len_x), strict=False))]
755
- return unrandomized
733
+ rng = np.random.default_rng(self.random_state)
734
+ perm = rng.permutation(len_x)
735
+ x_shuffled = self.x[perm]
736
+
737
+ ranks = np.empty(len_x, dtype=int)
738
+ ranks[perm[np.argsort(x_shuffled, stable=True)]] = np.arange(1, len_x + 1)
739
+
740
+ return ranks
756
741
 
757
742
  @property
758
743
  def y_rank_max(self):
@@ -762,7 +747,7 @@ class Xi:
762
747
  @property
763
748
  def g(self):
764
749
  # g[i] is number of j s.t. y[j] >= y[i], divided by n.
765
- return ss.rankdata([-i for i in self.y], method="max") / self.sample_size
750
+ return ss.rankdata(-self.y, method="max") / self.sample_size
766
751
 
767
752
  @property
768
753
  def x_ordered(self):
@@ -771,32 +756,14 @@ class Xi:
771
756
 
772
757
  @property
773
758
  def x_rank_max_ordered(self):
774
- x_ordered_result = self.x_ordered
775
- y_rank_max_result = self.y_rank_max
776
- # Rearrange f according to ord.
777
- return [y_rank_max_result[i] for i in x_ordered_result]
759
+ return self.y_rank_max[self.x_ordered]
778
760
 
779
761
  @property
780
762
  def mean_absolute(self):
781
763
  x1 = self.x_rank_max_ordered[0 : (self.sample_size - 1)]
782
764
  x2 = self.x_rank_max_ordered[1 : self.sample_size]
783
765
 
784
- return (
785
- np.mean(
786
- np.abs(
787
- [
788
- x - y
789
- for x, y in zip(
790
- x1,
791
- x2,
792
- strict=False,
793
- )
794
- ]
795
- )
796
- )
797
- * (self.sample_size - 1)
798
- / (2 * self.sample_size)
799
- )
766
+ return np.mean(np.abs(x1 - x2)) * (self.sample_size - 1) / (2 * self.sample_size)
800
767
 
801
768
  @property
802
769
  def inverse_g_mean(self):
@@ -805,7 +772,7 @@ class Xi:
805
772
 
806
773
  @property
807
774
  def correlation(self):
808
- """xi correlation"""
775
+ """Xi correlation."""
809
776
  return 1 - self.mean_absolute / self.inverse_g_mean
810
777
 
811
778
  @classmethod
@@ -854,10 +821,7 @@ class Xi:
854
821
 
855
822
 
856
823
  class SinkhornKnopp:
857
- """
858
- An simple implementation of Sinkhorn iteration used in the biwhitening approach.
859
-
860
- """
824
+ """An simple implementation of Sinkhorn iteration used in the biwhitening approach."""
861
825
 
862
826
  def __init__(self, max_iter: float = 1000, setr: int = 0, setc: float = 0, epsilon: float = 1e-3):
863
827
  if max_iter < 0:
@@ -891,14 +855,8 @@ class SinkhornKnopp:
891
855
  assert P.ndim == 2
892
856
 
893
857
  N = P.shape[0]
894
- if np.sum(abs(self._setr)) == 0:
895
- rsum = P.shape[1]
896
- else:
897
- rsum = self._setr
898
- if np.sum(abs(self._setc)) == 0:
899
- csum = P.shape[0]
900
- else:
901
- csum = self._setc
858
+ rsum = P.shape[1] if np.sum(abs(self._setr)) == 0 else self._setr
859
+ csum = P.shape[0] if np.sum(abs(self._setc)) == 0 else self._setc
902
860
  max_threshr = rsum + self._epsilon
903
861
  min_threshr = rsum - self._epsilon
904
862
  max_threshc = csum + self._epsilon