pertpy 0.10.0__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. pertpy/__init__.py +5 -1
  2. pertpy/_doc.py +1 -3
  3. pertpy/_types.py +6 -0
  4. pertpy/data/_dataloader.py +68 -24
  5. pertpy/data/_datasets.py +9 -9
  6. pertpy/metadata/__init__.py +2 -1
  7. pertpy/metadata/_cell_line.py +133 -25
  8. pertpy/metadata/_look_up.py +13 -19
  9. pertpy/metadata/_moa.py +1 -1
  10. pertpy/preprocessing/_guide_rna.py +138 -44
  11. pertpy/preprocessing/_guide_rna_mixture.py +17 -19
  12. pertpy/tools/__init__.py +1 -1
  13. pertpy/tools/_augur.py +106 -98
  14. pertpy/tools/_cinemaot.py +74 -114
  15. pertpy/tools/_coda/_base_coda.py +129 -145
  16. pertpy/tools/_coda/_sccoda.py +66 -69
  17. pertpy/tools/_coda/_tasccoda.py +71 -79
  18. pertpy/tools/_dialogue.py +48 -40
  19. pertpy/tools/_differential_gene_expression/_base.py +21 -31
  20. pertpy/tools/_differential_gene_expression/_checks.py +4 -6
  21. pertpy/tools/_differential_gene_expression/_dge_comparison.py +5 -6
  22. pertpy/tools/_differential_gene_expression/_edger.py +6 -10
  23. pertpy/tools/_differential_gene_expression/_pydeseq2.py +1 -1
  24. pertpy/tools/_differential_gene_expression/_simple_tests.py +3 -3
  25. pertpy/tools/_differential_gene_expression/_statsmodels.py +8 -5
  26. pertpy/tools/_distances/_distance_tests.py +1 -2
  27. pertpy/tools/_distances/_distances.py +31 -45
  28. pertpy/tools/_enrichment.py +7 -22
  29. pertpy/tools/_milo.py +19 -15
  30. pertpy/tools/_mixscape.py +73 -75
  31. pertpy/tools/_perturbation_space/_clustering.py +4 -4
  32. pertpy/tools/_perturbation_space/_comparison.py +4 -4
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +83 -32
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +10 -10
  35. pertpy/tools/_perturbation_space/_simple.py +12 -14
  36. pertpy/tools/_scgen/_scgen.py +16 -17
  37. pertpy/tools/_scgen/_scgenvae.py +2 -2
  38. pertpy/tools/_scgen/_utils.py +3 -1
  39. {pertpy-0.10.0.dist-info → pertpy-0.11.0.dist-info}/METADATA +36 -20
  40. pertpy-0.11.0.dist-info/RECORD +58 -0
  41. {pertpy-0.10.0.dist-info → pertpy-0.11.0.dist-info}/licenses/LICENSE +1 -0
  42. pertpy/tools/_kernel_pca.py +0 -50
  43. pertpy-0.10.0.dist-info/RECORD +0 -58
  44. {pertpy-0.10.0.dist-info → pertpy-0.11.0.dist-info}/WHEEL +0 -0
pertpy/preprocessing/_guide_rna.py CHANGED
@@ -1,6 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import uuid
4
+ from functools import singledispatchmethod
4
5
  from typing import TYPE_CHECKING, Literal
5
6
  from warnings import warn
6
7
 
@@ -8,41 +9,45 @@ import matplotlib.pyplot as plt
8
9
  import numpy as np
9
10
  import pandas as pd
10
11
  import scanpy as sc
11
- import scipy
12
+ from anndata import AnnData
13
+ from numba import njit, prange
12
14
  from rich.progress import track
13
- from scipy.sparse import issparse
15
+ from scanpy.get import _get_obs_rep, _set_obs_rep
16
+ from scipy.sparse import csr_matrix, issparse
14
17
 
15
18
  from pertpy._doc import _doc_params, doc_common_plot_args
19
+ from pertpy._types import CSRBase
16
20
  from pertpy.preprocessing._guide_rna_mixture import PoissonGaussMixture
17
21
 
18
22
  if TYPE_CHECKING:
19
- from anndata import AnnData
20
23
  from matplotlib.pyplot import Figure
21
24
 
22
25
 
23
26
  class GuideAssignment:
24
27
  """Assign cells to guide RNAs."""
25
28
 
29
+ @singledispatchmethod
26
30
  def assign_by_threshold(
27
31
  self,
28
- adata: AnnData,
32
+ data: AnnData | np.ndarray | CSRBase,
33
+ /,
34
+ *,
29
35
  assignment_threshold: float,
30
36
  layer: str | None = None,
31
37
  output_layer: str = "assigned_guides",
32
- only_return_results: bool = False,
33
- ) -> np.ndarray | None:
38
+ ):
34
39
  """Simple threshold based gRNA assignment function.
35
40
 
36
41
  Each cell is assigned to gRNA with at least `assignment_threshold` counts.
37
42
  This function expects unnormalized data as input.
38
43
 
39
44
  Args:
40
- adata: AnnData object containing gRNA values.
45
+ data: The (annotated) data matrix of shape `n_obs` × `n_vars`.
46
+ Rows correspond to cells and columns to genes.
41
47
  assignment_threshold: The count threshold that is required for an assignment to be viable.
42
48
  layer: Key to the layer containing raw count values of the gRNAs.
43
49
  adata.X is used if layer is None. Expects count data.
44
50
  output_layer: Assigned guide will be saved on adata.layers[output_key].
45
- only_return_results: Whether to input AnnData is not modified and the result is returned as an :class:`np.ndarray`.
46
51
 
47
52
  Examples:
48
53
  Each cell is assigned to gRNA that occurs at least 5 times in the respective cell.
@@ -53,26 +58,52 @@ class GuideAssignment:
53
58
  >>> ga = pt.pp.GuideAssignment()
54
59
  >>> ga.assign_by_threshold(gdo, assignment_threshold=5)
55
60
  """
56
- counts = adata.X if layer is None else adata.layers[layer]
57
- if scipy.sparse.issparse(counts):
58
- counts = counts.toarray()
59
-
60
- assigned_grnas = np.where(counts >= assignment_threshold, 1, 0)
61
- assigned_grnas = scipy.sparse.csr_matrix(assigned_grnas)
62
- if only_return_results:
63
- return assigned_grnas
64
- adata.layers[output_layer] = assigned_grnas
65
-
66
- return None
61
+ raise NotImplementedError(
62
+ f"No implementation found for {type(data)}. Must be numpy array, sparse matrix, or AnnData object."
63
+ )
67
64
 
68
- def assign_to_max_guide(
65
+ @assign_by_threshold.register(AnnData)
66
+ def _assign_by_threshold_anndata(
69
67
  self,
70
68
  adata: AnnData,
69
+ /,
70
+ *,
71
+ assignment_threshold: float,
72
+ layer: str | None = None,
73
+ output_layer: str = "assigned_guides",
74
+ ) -> None:
75
+ X = _get_obs_rep(adata, layer=layer)
76
+ guide_assignments = self.assign_by_threshold(X, assignment_threshold=assignment_threshold)
77
+ _set_obs_rep(adata, guide_assignments, layer=output_layer)
78
+
79
+ @assign_by_threshold.register(np.ndarray)
80
+ def _assign_by_threshold_numpy(self, X: np.ndarray, /, *, assignment_threshold: float) -> np.ndarray:
81
+ return np.where(assignment_threshold <= X, 1, 0)
82
+
83
+ @staticmethod
84
+ @njit(parallel=True)
85
+ def _threshold_sparse_numba(data: np.ndarray, threshold: float) -> np.ndarray:
86
+ out = np.zeros_like(data, dtype=np.int8)
87
+ for i in prange(data.shape[0]):
88
+ if data[i] >= threshold:
89
+ out[i] = 1
90
+ return out
91
+
92
+ @assign_by_threshold.register(CSRBase)
93
+ def _assign_by_threshold_sparse(self, X: CSRBase, /, *, assignment_threshold: float) -> CSRBase:
94
+ new_data = self._threshold_sparse_numba(X.data, assignment_threshold)
95
+ return csr_matrix((new_data, X.indices, X.indptr), shape=X.shape)
96
+
97
+ @singledispatchmethod
98
+ def assign_to_max_guide(
99
+ self,
100
+ data: AnnData | np.ndarray | CSRBase,
101
+ /,
102
+ *,
71
103
  assignment_threshold: float,
72
104
  layer: str | None = None,
73
- output_key: str = "assigned_guide",
105
+ obs_key: str = "assigned_guide",
74
106
  no_grna_assigned_key: str = "Negative",
75
- only_return_results: bool = False,
76
107
  ) -> np.ndarray | None:
77
108
  """Simple threshold based max gRNA assignment function.
78
109
 
@@ -80,13 +111,13 @@ class GuideAssignment:
80
111
  This function expects unnormalized data as input.
81
112
 
82
113
  Args:
83
- adata: AnnData object containing gRNA values.
114
+ data: The (annotated) data matrix of shape `n_obs` × `n_vars`.
115
+ Rows correspond to cells and columns to genes.
84
116
  assignment_threshold: The count threshold that is required for an assignment to be viable.
85
117
  layer: Key to the layer containing raw count values of the gRNAs.
86
118
  adata.X is used if layer is None. Expects count data.
87
- output_key: Assigned guide will be saved on adata.obs[output_key]. default value is `assigned_guide`.
119
+ obs_key: Assigned guide will be saved on adata.obs[output_key].
88
120
  no_grna_assigned_key: The key to return if no gRNA is expressed enough.
89
- only_return_results: Whether to input AnnData is not modified and the result is returned as an np.ndarray.
90
121
 
91
122
  Examples:
92
123
  Each cell is assigned to the most expressed gRNA if it has at least 5 counts.
@@ -97,21 +128,79 @@ class GuideAssignment:
97
128
  >>> ga = pt.pp.GuideAssignment()
98
129
  >>> ga.assign_to_max_guide(gdo, assignment_threshold=5)
99
130
  """
100
- counts = adata.X if layer is None else adata.layers[layer]
101
- if scipy.sparse.issparse(counts):
102
- counts = counts.toarray()
131
+ raise NotImplementedError(
132
+ f"No implementation found for {type(data)}. Must be numpy array, sparse matrix, or AnnData object."
133
+ )
134
+
135
+ @assign_to_max_guide.register(AnnData)
136
+ def assign_to_max_guide_anndata(
137
+ self,
138
+ adata: AnnData,
139
+ /,
140
+ *,
141
+ assignment_threshold: float,
142
+ layer: str | None = None,
143
+ obs_key: str = "assigned_guide",
144
+ no_grna_assigned_key: str = "Negative",
145
+ ) -> None:
146
+ X = _get_obs_rep(adata, layer=layer)
147
+ guide_assignments = self.assign_to_max_guide(
148
+ X, var=adata.var, assignment_threshold=assignment_threshold, no_grna_assigned_key=no_grna_assigned_key
149
+ )
150
+ adata.obs[obs_key] = guide_assignments
103
151
 
152
+ @assign_to_max_guide.register(np.ndarray)
153
+ def assign_to_max_guide_numpy(
154
+ self,
155
+ X: np.ndarray,
156
+ /,
157
+ *,
158
+ var: pd.DataFrame,
159
+ assignment_threshold: float,
160
+ no_grna_assigned_key: str = "Negative",
161
+ ) -> np.ndarray:
104
162
  assigned_grna = np.where(
105
- counts.max(axis=1).squeeze() >= assignment_threshold,
106
- adata.var.index[counts.argmax(axis=1).squeeze()],
163
+ X.max(axis=1).squeeze() >= assignment_threshold,
164
+ var.index[X.argmax(axis=1).squeeze()],
107
165
  no_grna_assigned_key,
108
166
  )
109
167
 
110
- if only_return_results:
111
- return assigned_grna
112
- adata.obs[output_key] = assigned_grna
168
+ return assigned_grna
169
+
170
+ @staticmethod
171
+ @njit(parallel=True)
172
+ def _assign_max_guide_sparse(indptr, data, indices, assignment_threshold, assigned_grna):
173
+ n_rows = len(indptr) - 1
174
+ for i in range(n_rows):
175
+ row_start = indptr[i]
176
+ row_end = indptr[i + 1]
177
+
178
+ if row_end > row_start:
179
+ data_row = data[row_start:row_end]
180
+ indices_row = indices[row_start:row_end]
181
+ max_pos = np.argmax(data_row)
182
+ if data_row[max_pos] >= assignment_threshold:
183
+ assigned_grna[i] = indices_row[max_pos]
184
+ return assigned_grna
185
+
186
+ @assign_to_max_guide.register(CSRBase)
187
+ def assign_to_max_guide_sparse(
188
+ self, X: CSRBase, /, *, var: pd.DataFrame, assignment_threshold: float, no_grna_assigned_key: str = "Negative"
189
+ ) -> np.ndarray:
190
+ n_rows = X.shape[0]
191
+
192
+ assigned_positions = np.zeros(n_rows, dtype=np.int32) - 1 # -1 means not assigned
193
+ assigned_positions = self._assign_max_guide_sparse(
194
+ X.indptr, X.data, X.indices, assignment_threshold, assigned_positions
195
+ )
113
196
 
114
- return None
197
+ assigned_grna = np.full(n_rows, no_grna_assigned_key, dtype=object)
198
+ mask = assigned_positions >= 0
199
+ var_index_array = np.array(var.index)
200
+ if np.any(mask):
201
+ assigned_grna[mask] = var_index_array[assigned_positions[mask]]
202
+
203
+ return assigned_grna
115
204
 
116
205
  def assign_mixture_model(
117
206
  self,
@@ -123,7 +212,6 @@ class GuideAssignment:
123
212
  multiple_grna_assigned_key: str = "multiple",
124
213
  multiple_grna_assignment_string: str = "+",
125
214
  only_return_results: bool = False,
126
- uns_key: str = "guide_assignment_params",
127
215
  show_progress: bool = False,
128
216
  **mixture_model_kwargs,
129
217
  ) -> np.ndarray | None:
@@ -132,7 +220,7 @@ class GuideAssignment:
132
220
  Args:
133
221
  adata: AnnData object containing gRNA values.
134
222
  model: The model to use for the mixture model. Currently only `Poisson_Gauss_Mixture` is supported.
135
- output_key: Assigned guide will be saved on adata.obs[output_key].
223
+ assigned_guides_key: Assigned guide will be saved on adata.obs[output_key].
136
224
  no_grna_assigned_key: The key to return if a cell is negative for all gRNAs.
137
225
  max_assignments_per_cell: The maximum number of gRNAs that can be assigned to a cell.
138
226
  multiple_grna_assigned_key: The key to return if multiple gRNAs are assigned to a cell.
@@ -153,11 +241,6 @@ class GuideAssignment:
153
241
  else:
154
242
  raise ValueError("Model not implemented. Please use 'poisson_gauss_mixture'.")
155
243
 
156
- if uns_key not in adata.uns:
157
- adata.uns[uns_key] = {}
158
- elif type(adata.uns[uns_key]) is not dict:
159
- raise ValueError(f"adata.uns['{uns_key}'] should be a dictionary. Please remove it or change the key.")
160
-
161
244
  res = pd.DataFrame(0, index=adata.obs_names, columns=adata.var_names)
162
245
  fct = track if show_progress else lambda iterable: iterable
163
246
  for gene in fct(adata.var_names):
@@ -181,7 +264,18 @@ class GuideAssignment:
181
264
  data = np.log2(data)
182
265
  assignments = mixture_model.run_model(data)
183
266
  res.loc[adata.obs_names[is_nonzero][assignments == "Positive"], gene] = 1
184
- adata.uns[uns_key][gene] = mixture_model.params
267
+
268
+ # Add the parameters to the adata.var DataFrame
269
+ for params_name, param in mixture_model.params.items():
270
+ if param.ndim == 0:
271
+ if params_name not in adata.var.columns:
272
+ adata.var[params_name] = np.nan
273
+ adata.var.loc[gene, params_name] = param.item()
274
+ else:
275
+ for i, p in enumerate(param):
276
+ if f"{params_name}_{i}" not in adata.var.columns:
277
+ adata.var[f"{params_name}_{i}"] = np.nan
278
+ adata.var.loc[gene, f"{params_name}_{i}"] = p
185
279
 
186
280
  # Assign guides to cells
187
281
  # Some cells might have multiple guides assigned
@@ -200,7 +294,7 @@ class GuideAssignment:
200
294
  return None
201
295
 
202
296
  @_doc_params(common_plot_args=doc_common_plot_args)
203
- def plot_heatmap(
297
+ def plot_heatmap( # pragma: no cover # noqa: D417
204
298
  self,
205
299
  adata: AnnData,
206
300
  *,
@@ -248,7 +342,7 @@ class GuideAssignment:
248
342
  data = adata.X if layer is None else adata.layers[layer]
249
343
 
250
344
  if order_by is None:
251
- if scipy.sparse.issparse(data):
345
+ if issparse(data):
252
346
  max_values = data.max(axis=1).toarray().squeeze()
253
347
  data_argmax = data.argmax(axis=1).A.squeeze()
254
348
  max_guide_index = np.where(max_values != data.min(axis=1).toarray().squeeze(), data_argmax, -1)
pertpy/preprocessing/_guide_rna_mixture.py CHANGED
@@ -3,12 +3,12 @@ from __future__ import annotations
3
3
  from abc import ABC, abstractmethod
4
4
  from collections.abc import Mapping
5
5
 
6
- import jax
7
6
  import jax.numpy as jnp
8
7
  import numpy as np
9
- import numpyro
10
- import numpyro.distributions as dist
11
- from jax import random
8
+ from jax.random import PRNGKey
9
+ from jax.scipy.special import logsumexp
10
+ from numpyro import factor, plate, sample
11
+ from numpyro.distributions import Dirichlet, Exponential, HalfNormal, Normal, Poisson
12
12
  from numpyro.infer import MCMC, NUTS
13
13
 
14
14
  ParamsDict = Mapping[str, jnp.ndarray]
@@ -49,7 +49,6 @@ class MixtureModel(ABC):
49
49
  Returns:
50
50
  Dictionary of sampled parameter values.
51
51
  """
52
- pass
53
52
 
54
53
  @abstractmethod
55
54
  def log_likelihood(self, data: jnp.ndarray, params: ParamsDict) -> jnp.ndarray:
@@ -62,7 +61,6 @@ class MixtureModel(ABC):
62
61
  Returns:
63
62
  Log likelihood values for each datapoint.
64
63
  """
65
- pass
66
64
 
67
65
  def fit_model(self, data: jnp.ndarray, seed: int = 0) -> MCMC:
68
66
  """Fit the mixture model using MCMC.
@@ -76,7 +74,7 @@ class MixtureModel(ABC):
76
74
  """
77
75
  nuts_kernel = NUTS(self.mixture_model)
78
76
  mcmc = MCMC(nuts_kernel, num_warmup=self.num_warmup, num_samples=self.num_samples, progress_bar=False)
79
- mcmc.run(random.PRNGKey(seed), data=data)
77
+ mcmc.run(PRNGKey(seed), data=data)
80
78
  return mcmc
81
79
 
82
80
  def run_model(self, data: jnp.ndarray, seed: int = 0) -> np.ndarray:
@@ -102,10 +100,10 @@ class MixtureModel(ABC):
102
100
  """
103
101
  params = self.initialize_params()
104
102
 
105
- with numpyro.plate("data", data.shape[0]):
103
+ with plate("data", data.shape[0]):
106
104
  log_likelihoods = self.log_likelihood(data, params)
107
- log_mixture_likelihood = jax.scipy.special.logsumexp(log_likelihoods, axis=-1)
108
- numpyro.sample("obs", dist.Normal(log_mixture_likelihood, 1.0), obs=data)
105
+ log_mixture_likelihood = logsumexp(log_likelihoods, axis=-1)
106
+ sample("obs", Normal(log_mixture_likelihood, 1.0), obs=data)
109
107
 
110
108
  def assignment(self, samples: ParamsDict, data: jnp.ndarray) -> np.ndarray:
111
109
  """Assign data points to mixture components.
@@ -117,7 +115,7 @@ class MixtureModel(ABC):
117
115
  Returns:
118
116
  Array of component assignments.
119
117
  """
120
- params = {key: samples[key].mean(axis=0) for key in samples.keys()}
118
+ params = {key: samples[key].mean(axis=0) for key in samples}
121
119
  self.params = params
122
120
 
123
121
  log_likelihoods = self.log_likelihood(data, params)
@@ -148,14 +146,14 @@ class PoissonGaussMixture(MixtureModel):
148
146
  # We penalize the model for positioning the Poisson component to the right of the Gaussian component
149
147
  # by imposing a soft constraint to penalize the Poisson rate being larger than the Gaussian mean
150
148
  # Heuristic regularization term to prevent flipping of the components
151
- numpyro.factor("separation_penalty", +10 * jnp.heaviside(-poisson_rate + gaussian_mean, 0))
149
+ factor("separation_penalty", +10 * jnp.heaviside(-poisson_rate + gaussian_mean, 0))
152
150
 
153
151
  log_likelihoods = jnp.stack(
154
152
  [
155
153
  # Poisson component
156
- jnp.log(mix_probs[0]) + dist.Poisson(poisson_rate).log_prob(data),
154
+ jnp.log(mix_probs[0]) + Poisson(poisson_rate).log_prob(data),
157
155
  # Gaussian component
158
- jnp.log(mix_probs[1]) + dist.Normal(gaussian_mean, gaussian_std).log_prob(data),
156
+ jnp.log(mix_probs[1]) + Normal(gaussian_mean, gaussian_std).log_prob(data),
159
157
  ],
160
158
  axis=-1,
161
159
  )
@@ -169,11 +167,11 @@ class PoissonGaussMixture(MixtureModel):
169
167
  Dictionary of sampled parameter values.
170
168
  """
171
169
  params = {}
172
- params["poisson_rate"] = numpyro.sample("poisson_rate", dist.Exponential(self.poisson_rate_prior))
173
- params["gaussian_mean"] = numpyro.sample("gaussian_mean", dist.Normal(*self.gaussian_mean_prior))
174
- params["gaussian_std"] = numpyro.sample("gaussian_std", dist.HalfNormal(self.gaussian_std_prior))
175
- params["mix_probs"] = numpyro.sample(
170
+ params["poisson_rate"] = sample("poisson_rate", Exponential(self.poisson_rate_prior))
171
+ params["gaussian_mean"] = sample("gaussian_mean", Normal(*self.gaussian_mean_prior))
172
+ params["gaussian_std"] = sample("gaussian_std", HalfNormal(self.gaussian_std_prior))
173
+ params["mix_probs"] = sample(
176
174
  "mix_probs",
177
- dist.Dirichlet(jnp.array([1 - self.fraction_positive_expected, self.fraction_positive_expected])),
175
+ Dirichlet(jnp.array([1 - self.fraction_positive_expected, self.fraction_positive_expected])),
178
176
  )
179
177
  return params
pertpy/tools/__init__.py CHANGED
@@ -41,7 +41,7 @@ from pertpy.tools._perturbation_space._simple import (
41
41
  )
42
42
  from pertpy.tools._scgen import Scgen
43
43
 
44
- CODA_EXTRAS = ["toytree", "arviz", "ete3"] # also pyqt5 technically
44
+ CODA_EXTRAS = ["toytree", "arviz", "ete4"] # also pyqt6 technically
45
45
  Sccoda = lazy_import("pertpy.tools._coda._sccoda", "Sccoda", CODA_EXTRAS)
46
46
  Tasccoda = lazy_import("pertpy.tools._coda._tasccoda", "Tasccoda", CODA_EXTRAS)
47
47