pertpy 0.9.5__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. pertpy/__init__.py +5 -1
  2. pertpy/_doc.py +2 -5
  3. pertpy/_types.py +6 -0
  4. pertpy/data/_dataloader.py +68 -24
  5. pertpy/data/_datasets.py +9 -9
  6. pertpy/metadata/__init__.py +2 -1
  7. pertpy/metadata/_cell_line.py +136 -30
  8. pertpy/metadata/_look_up.py +13 -19
  9. pertpy/metadata/_moa.py +1 -1
  10. pertpy/preprocessing/_guide_rna.py +221 -39
  11. pertpy/preprocessing/_guide_rna_mixture.py +177 -0
  12. pertpy/tools/__init__.py +1 -1
  13. pertpy/tools/_augur.py +138 -142
  14. pertpy/tools/_cinemaot.py +75 -117
  15. pertpy/tools/_coda/_base_coda.py +150 -174
  16. pertpy/tools/_coda/_sccoda.py +66 -69
  17. pertpy/tools/_coda/_tasccoda.py +71 -79
  18. pertpy/tools/_dialogue.py +60 -56
  19. pertpy/tools/_differential_gene_expression/_base.py +25 -43
  20. pertpy/tools/_differential_gene_expression/_checks.py +4 -6
  21. pertpy/tools/_differential_gene_expression/_dge_comparison.py +5 -6
  22. pertpy/tools/_differential_gene_expression/_edger.py +6 -10
  23. pertpy/tools/_differential_gene_expression/_pydeseq2.py +1 -1
  24. pertpy/tools/_differential_gene_expression/_simple_tests.py +3 -3
  25. pertpy/tools/_differential_gene_expression/_statsmodels.py +8 -5
  26. pertpy/tools/_distances/_distance_tests.py +1 -2
  27. pertpy/tools/_distances/_distances.py +86 -92
  28. pertpy/tools/_enrichment.py +8 -25
  29. pertpy/tools/_milo.py +23 -27
  30. pertpy/tools/_mixscape.py +261 -175
  31. pertpy/tools/_perturbation_space/_clustering.py +4 -4
  32. pertpy/tools/_perturbation_space/_comparison.py +4 -4
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +83 -32
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +10 -10
  35. pertpy/tools/_perturbation_space/_simple.py +13 -17
  36. pertpy/tools/_scgen/_scgen.py +17 -20
  37. pertpy/tools/_scgen/_scgenvae.py +2 -2
  38. pertpy/tools/_scgen/_utils.py +3 -1
  39. {pertpy-0.9.5.dist-info → pertpy-0.11.0.dist-info}/METADATA +37 -21
  40. pertpy-0.11.0.dist-info/RECORD +58 -0
  41. {pertpy-0.9.5.dist-info → pertpy-0.11.0.dist-info}/licenses/LICENSE +1 -0
  42. pertpy/tools/_kernel_pca.py +0 -50
  43. pertpy-0.9.5.dist-info/RECORD +0 -57
  44. {pertpy-0.9.5.dist-info → pertpy-0.11.0.dist-info}/WHEEL +0 -0
@@ -1,44 +1,53 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import uuid
4
- from typing import TYPE_CHECKING
4
+ from functools import singledispatchmethod
5
+ from typing import TYPE_CHECKING, Literal
6
+ from warnings import warn
5
7
 
6
8
  import matplotlib.pyplot as plt
7
9
  import numpy as np
8
10
  import pandas as pd
9
11
  import scanpy as sc
10
- import scipy
12
+ from anndata import AnnData
13
+ from numba import njit, prange
14
+ from rich.progress import track
15
+ from scanpy.get import _get_obs_rep, _set_obs_rep
16
+ from scipy.sparse import csr_matrix, issparse
11
17
 
12
18
  from pertpy._doc import _doc_params, doc_common_plot_args
19
+ from pertpy._types import CSRBase
20
+ from pertpy.preprocessing._guide_rna_mixture import PoissonGaussMixture
13
21
 
14
22
  if TYPE_CHECKING:
15
- from anndata import AnnData
16
23
  from matplotlib.pyplot import Figure
17
24
 
18
25
 
19
26
  class GuideAssignment:
20
- """Offers simple guide assigment based on count thresholds."""
27
+ """Assign cells to guide RNAs."""
21
28
 
29
+ @singledispatchmethod
22
30
  def assign_by_threshold(
23
31
  self,
24
- adata: AnnData,
32
+ data: AnnData | np.ndarray | CSRBase,
33
+ /,
34
+ *,
25
35
  assignment_threshold: float,
26
36
  layer: str | None = None,
27
37
  output_layer: str = "assigned_guides",
28
- only_return_results: bool = False,
29
- ) -> np.ndarray | None:
38
+ ):
30
39
  """Simple threshold based gRNA assignment function.
31
40
 
32
41
  Each cell is assigned to gRNA with at least `assignment_threshold` counts.
33
42
  This function expects unnormalized data as input.
34
43
 
35
44
  Args:
36
- adata: Annotated data matrix containing gRNA values
45
+ data: The (annotated) data matrix of shape `n_obs` × `n_vars`.
46
+ Rows correspond to cells and columns to genes.
37
47
  assignment_threshold: The count threshold that is required for an assignment to be viable.
38
48
  layer: Key to the layer containing raw count values of the gRNAs.
39
49
  adata.X is used if layer is None. Expects count data.
40
50
  output_layer: Assigned guides will be saved on adata.layers[output_layer].
41
- only_return_results: If True, input AnnData is not modified and the result is returned as an np.ndarray.
42
51
 
43
52
  Examples:
44
53
  Each cell is assigned to gRNA that occurs at least 5 times in the respective cell.
@@ -49,26 +58,52 @@ class GuideAssignment:
49
58
  >>> ga = pt.pp.GuideAssignment()
50
59
  >>> ga.assign_by_threshold(gdo, assignment_threshold=5)
51
60
  """
52
- counts = adata.X if layer is None else adata.layers[layer]
53
- if scipy.sparse.issparse(counts):
54
- counts = counts.toarray()
55
-
56
- assigned_grnas = np.where(counts >= assignment_threshold, 1, 0)
57
- assigned_grnas = scipy.sparse.csr_matrix(assigned_grnas)
58
- if only_return_results:
59
- return assigned_grnas
60
- adata.layers[output_layer] = assigned_grnas
61
-
62
- return None
61
+ raise NotImplementedError(
62
+ f"No implementation found for {type(data)}. Must be numpy array, sparse matrix, or AnnData object."
63
+ )
63
64
 
64
- def assign_to_max_guide(
65
@assign_by_threshold.register(AnnData)
def _assign_by_threshold_anndata(
    self,
    adata: AnnData,
    /,
    *,
    assignment_threshold: float,
    layer: str | None = None,
    output_layer: str = "assigned_guides",
) -> None:
    """Threshold the guide counts of an AnnData object and store the result as a layer.

    Args:
        adata: Annotated data matrix holding the gRNA counts.
        assignment_threshold: The count threshold that is required for an assignment to be viable.
        layer: Layer to read the counts from; `adata.X` is used if None.
        output_layer: Layer of `adata` that receives the thresholded matrix.
    """
    counts = _get_obs_rep(adata, layer=layer)
    # Re-enter the dispatcher so the array-type specific implementation is picked.
    thresholded = self.assign_by_threshold(counts, assignment_threshold=assignment_threshold)
    _set_obs_rep(adata, thresholded, layer=output_layer)
78
+
79
@assign_by_threshold.register(np.ndarray)
def _assign_by_threshold_numpy(self, X: np.ndarray, /, *, assignment_threshold: float) -> np.ndarray:
    """Return an array shaped like `X` with 1 where the count reaches the threshold, else 0.

    Args:
        X: Dense count matrix.
        assignment_threshold: The count threshold that is required for an assignment to be viable.
    """
    reaches_threshold = X >= assignment_threshold
    return np.where(reaches_threshold, 1, 0)
82
+
83
@staticmethod
@njit(parallel=True)
def _threshold_sparse_numba(data: np.ndarray, threshold: float) -> np.ndarray:
    """Flag every entry of a 1-D value array that reaches `threshold`.

    Args:
        data: Values to test (called with the `.data` array of a CSR matrix).
        threshold: Minimum value for an entry to be flagged.

    Returns:
        int8 array of the same shape as `data`: 1 where `data >= threshold`, 0 elsewhere.
    """
    flags = np.zeros_like(data, dtype=np.int8)
    n_values = data.shape[0]
    # Each iteration touches only its own slot, so the loop can run in parallel.
    for pos in prange(n_values):
        reached = data[pos] >= threshold
        if reached:
            flags[pos] = 1
    return flags
91
+
92
@assign_by_threshold.register(CSRBase)
def _assign_by_threshold_sparse(self, X: CSRBase, /, *, assignment_threshold: float) -> CSRBase:
    """Threshold a CSR count matrix without densifying it.

    Only the stored values are tested. The result has the same sparsity
    structure as `X`; stored entries below the threshold become explicit zeros.

    Args:
        X: Sparse count matrix in CSR format.
        assignment_threshold: The count threshold that is required for an assignment to be viable.

    Returns:
        CSR matrix of the same shape whose stored values are 0/1 flags.
    """
    flags = self._threshold_sparse_numba(X.data, assignment_threshold)
    csr_parts = (flags, X.indices, X.indptr)
    return csr_matrix(csr_parts, shape=X.shape)
96
+
97
+ @singledispatchmethod
98
+ def assign_to_max_guide(
99
+ self,
100
+ data: AnnData | np.ndarray | CSRBase,
101
+ /,
102
+ *,
103
+ assignment_threshold: float,
104
+ layer: str | None = None,
105
+ obs_key: str = "assigned_guide",
106
+ no_grna_assigned_key: str = "Negative",
72
107
  ) -> np.ndarray | None:
73
108
  """Simple threshold based max gRNA assignment function.
74
109
 
@@ -76,13 +111,13 @@ class GuideAssignment:
76
111
  This function expects unnormalized data as input.
77
112
 
78
113
  Args:
79
- adata: Annotated data matrix containing gRNA values
114
+ data: The (annotated) data matrix of shape `n_obs` × `n_vars`.
115
+ Rows correspond to cells and columns to genes.
80
116
  assignment_threshold: The count threshold that is required for an assignment to be viable.
81
117
  layer: Key to the layer containing raw count values of the gRNAs.
82
118
  adata.X is used if layer is None. Expects count data.
83
- output_key: Assigned guide will be saved on adata.obs[output_key]. default value is `assigned_guide`.
119
+ obs_key: Assigned guide will be saved on adata.obs[obs_key].
84
120
  no_grna_assigned_key: The key to return if no gRNA is expressed enough.
85
- only_return_results: If True, input AnnData is not modified and the result is returned as an np.ndarray.
86
121
 
87
122
  Examples:
88
123
  Each cell is assigned to the most expressed gRNA if it has at least 5 counts.
@@ -93,31 +128,179 @@ class GuideAssignment:
93
128
  >>> ga = pt.pp.GuideAssignment()
94
129
  >>> ga.assign_to_max_guide(gdo, assignment_threshold=5)
95
130
  """
96
- counts = adata.X if layer is None else adata.layers[layer]
97
- if scipy.sparse.issparse(counts):
98
- counts = counts.toarray()
131
+ raise NotImplementedError(
132
+ f"No implementation found for {type(data)}. Must be numpy array, sparse matrix, or AnnData object."
133
+ )
134
+
135
@assign_to_max_guide.register(AnnData)
def assign_to_max_guide_anndata(
    self,
    adata: AnnData,
    /,
    *,
    assignment_threshold: float,
    layer: str | None = None,
    obs_key: str = "assigned_guide",
    no_grna_assigned_key: str = "Negative",
) -> None:
    """Assign each cell of an AnnData object to its most abundant guide.

    Args:
        adata: Annotated data matrix holding the gRNA counts.
        assignment_threshold: The count threshold that is required for an assignment to be viable.
        layer: Layer to read the counts from; `adata.X` is used if None.
        obs_key: Column of `adata.obs` that receives the assigned guide names.
        no_grna_assigned_key: Label used for cells without a viable assignment.
    """
    counts = _get_obs_rep(adata, layer=layer)
    # Re-enter the dispatcher; the array implementations need `var` to map
    # column positions back to guide names.
    adata.obs[obs_key] = self.assign_to_max_guide(
        counts,
        var=adata.var,
        assignment_threshold=assignment_threshold,
        no_grna_assigned_key=no_grna_assigned_key,
    )
99
151
 
152
@assign_to_max_guide.register(np.ndarray)
def assign_to_max_guide_numpy(
    self,
    X: np.ndarray,
    /,
    *,
    var: pd.DataFrame,
    assignment_threshold: float,
    no_grna_assigned_key: str = "Negative",
) -> np.ndarray:
    """Assign each row of a dense count matrix to its highest-count guide.

    Args:
        X: Dense count matrix.
        var: DataFrame whose index supplies the guide name for every column of `X`.
        assignment_threshold: The count threshold that is required for an assignment to be viable.
        no_grna_assigned_key: Label used for rows whose maximum is below the threshold.

    Returns:
        Array with one guide name (or `no_grna_assigned_key`) per row.
    """
    row_max = X.max(axis=1).squeeze()
    top_guide = var.index[X.argmax(axis=1).squeeze()]
    return np.where(row_max >= assignment_threshold, top_guide, no_grna_assigned_key)
169
+
170
@staticmethod
@njit(parallel=True)
def _assign_max_guide_sparse(indptr, data, indices, assignment_threshold, assigned_grna):
    """Find, per CSR row, the column of the largest stored value that reaches the threshold.

    Args:
        indptr: CSR row pointer array.
        data: CSR stored values.
        indices: CSR column indices belonging to `data`.
        assignment_threshold: Minimum value the row maximum must reach.
        assigned_grna: Output array with one slot per row; it is written in place.
            Rows without stored entries, or whose maximum is below the threshold,
            keep the value they had on entry.

    Returns:
        `assigned_grna`, holding the winning column index for every assigned row.
    """
    n_rows = len(indptr) - 1
    # `prange` is required for `parallel=True` to have any effect; with a plain
    # `range` the loop runs serially. Every iteration writes only to its own
    # slot of `assigned_grna`, so the rows are independent.
    for i in prange(n_rows):
        row_start = indptr[i]
        row_end = indptr[i + 1]

        if row_end > row_start:
            data_row = data[row_start:row_end]
            indices_row = indices[row_start:row_end]
            max_pos = np.argmax(data_row)
            if data_row[max_pos] >= assignment_threshold:
                assigned_grna[i] = indices_row[max_pos]
    return assigned_grna
185
+
186
@assign_to_max_guide.register(CSRBase)
def assign_to_max_guide_sparse(
    self, X: CSRBase, /, *, var: pd.DataFrame, assignment_threshold: float, no_grna_assigned_key: str = "Negative"
) -> np.ndarray:
    """Assign each row of a CSR count matrix to its highest-count guide.

    Args:
        X: Sparse count matrix in CSR format.
        var: DataFrame whose index supplies the guide name for every column of `X`.
        assignment_threshold: The count threshold that is required for an assignment to be viable.
        no_grna_assigned_key: Label used for rows without a viable assignment.

    Returns:
        Object array with one guide name (or `no_grna_assigned_key`) per row.
    """
    n_cells = X.shape[0]

    # -1 marks rows for which no guide was assigned.
    best_column = self._assign_max_guide_sparse(
        X.indptr, X.data, X.indices, assignment_threshold, np.full(n_cells, -1, dtype=np.int32)
    )

    labels = np.full(n_cells, no_grna_assigned_key, dtype=object)
    has_guide = best_column >= 0
    guide_names = np.array(var.index)
    if has_guide.any():
        labels[has_guide] = guide_names[best_column[has_guide]]

    return labels
204
+
205
def assign_mixture_model(
    self,
    adata: AnnData,
    model: Literal["poisson_gauss_mixture"] = "poisson_gauss_mixture",
    assigned_guides_key: str = "assigned_guide",
    no_grna_assigned_key: str = "negative",
    max_assignments_per_cell: int = 5,
    multiple_grna_assigned_key: str = "multiple",
    multiple_grna_assignment_string: str = "+",
    only_return_results: bool = False,
    show_progress: bool = False,
    **mixture_model_kwargs,
) -> np.ndarray | None:
    """Assigns gRNAs to cells using a mixture model.

    One model is fitted per guide on the log2 of its non-zero counts. Cells with
    a zero count for a guide are treated as negative for that guide. The fitted
    parameters of every guide are written to `adata.var`; this happens
    regardless of `only_return_results`.

    Args:
        adata: AnnData object containing gRNA values.
        model: The model to use for the mixture model. Currently only `poisson_gauss_mixture` is supported.
        assigned_guides_key: Assigned guide will be saved on adata.obs[assigned_guides_key].
        no_grna_assigned_key: The key to return if a cell is negative for all gRNAs.
        max_assignments_per_cell: The maximum number of gRNAs that can be assigned to a cell.
        multiple_grna_assigned_key: The key to return if more than `max_assignments_per_cell` gRNAs
            are assigned to a cell.
        multiple_grna_assignment_string: The string to use to join multiple gRNAs assigned to a cell.
        only_return_results: Whether to return the assignments as an np.ndarray instead of storing
            them in `adata.obs`.
        show_progress: Whether to show a progress bar.
        mixture_model_kwargs: Are passed to the mixture model.

    Returns:
        The per-cell assignments if `only_return_results` is True, otherwise None.

    Raises:
        ValueError: If `model` is not supported or a guide has negative values.

    Examples:
        >>> import pertpy as pt
        >>> mdata = pt.dt.papalexi_2021()
        >>> gdo = mdata.mod["gdo"]
        >>> ga = pt.pp.GuideAssignment()
        >>> ga.assign_mixture_model(gdo)
    """
    if model == "poisson_gauss_mixture":
        mixture_model = PoissonGaussMixture(**mixture_model_kwargs)
    else:
        raise ValueError("Model not implemented. Please use 'poisson_gauss_mixture'.")

    obs_names = adata.obs_names
    res = pd.DataFrame(0, index=obs_names, columns=adata.var_names)
    genes = track(adata.var_names) if show_progress else adata.var_names
    for gene in genes:
        # Slice the guide's column once and work on a flat dense vector afterwards.
        column = adata[:, gene].X
        counts = np.ravel(column.toarray() if issparse(column) else column)
        is_nonzero = counts != 0
        if np.count_nonzero(is_nonzero) < 2:
            warn(f"Skipping {gene} as there are less than 2 cells expressing the guide at all.", stacklevel=2)
            continue
        # We are only fitting the model to the non-zero values, the rest is
        # automatically assigned to the negative class
        data = counts[is_nonzero]

        if np.any(data < 0):
            raise ValueError(
                "Data contains negative values. Please use non-negative data for guide assignment with the Mixture Model."
            )

        # Log2 transform the data so positive population is approximately normal
        data = np.log2(data)
        assignments = mixture_model.run_model(data)
        res.loc[obs_names[is_nonzero][assignments == "Positive"], gene] = 1

        # Add the parameters to the adata.var DataFrame
        for params_name, param in mixture_model.params.items():
            if param.ndim == 0:
                if params_name not in adata.var.columns:
                    adata.var[params_name] = np.nan
                adata.var.loc[gene, params_name] = param.item()
            else:
                # Vector-valued parameters get one column per component.
                for i, p in enumerate(param):
                    if f"{params_name}_{i}" not in adata.var.columns:
                        adata.var[f"{params_name}_{i}"] = np.nan
                    adata.var.loc[gene, f"{params_name}_{i}"] = p

    # Assign guides to cells
    # Some cells might have multiple guides assigned
    series = pd.Series(no_grna_assigned_key, index=obs_names)
    num_guides_assigned = res.sum(axis=1)
    series.loc[(num_guides_assigned <= max_assignments_per_cell) & (num_guides_assigned != 0)] = res.apply(
        lambda row: row.index[row == 1].tolist(), axis=1
    ).str.join(multiple_grna_assignment_string)
    series.loc[num_guides_assigned > max_assignments_per_cell] = multiple_grna_assigned_key

    if only_return_results:
        return series.values

    adata.obs[assigned_guides_key] = series.values

    return None
111
295
 
112
296
  @_doc_params(common_plot_args=doc_common_plot_args)
113
- def plot_heatmap(
297
+ def plot_heatmap( # pragma: no cover # noqa: D417
114
298
  self,
115
299
  adata: AnnData,
116
300
  *,
117
301
  layer: str | None = None,
118
302
  order_by: np.ndarray | str | None = None,
119
303
  key_to_save_order: str = None,
120
- show: bool = True,
121
304
  return_fig: bool = False,
122
305
  **kwargs,
123
306
  ) -> Figure | None:
@@ -159,7 +342,7 @@ class GuideAssignment:
159
342
  data = adata.X if layer is None else adata.layers[layer]
160
343
 
161
344
  if order_by is None:
162
- if scipy.sparse.issparse(data):
345
+ if issparse(data):
163
346
  max_values = data.max(axis=1).toarray().squeeze()
164
347
  data_argmax = data.argmax(axis=1).A.squeeze()
165
348
  max_guide_index = np.where(max_values != data.min(axis=1).toarray().squeeze(), data_argmax, -1)
@@ -194,8 +377,7 @@ class GuideAssignment:
194
377
  finally:
195
378
  del adata.obs[temp_col_name]
196
379
 
197
- if show:
198
- plt.show()
199
380
  if return_fig:
200
381
  return fig
382
+ plt.show()
201
383
  return None
@@ -0,0 +1,177 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from collections.abc import Mapping
5
+
6
+ import jax.numpy as jnp
7
+ import numpy as np
8
+ from jax.random import PRNGKey
9
+ from jax.scipy.special import logsumexp
10
+ from numpyro import factor, plate, sample
11
+ from numpyro.distributions import Dirichlet, Exponential, HalfNormal, Normal, Poisson
12
+ from numpyro.infer import MCMC, NUTS
13
+
14
+ ParamsDict = Mapping[str, jnp.ndarray]
15
+
16
+
17
class MixtureModel(ABC):
    """Abstract base class for 2-component mixture models.

    Subclasses provide the priors (`initialize_params`) and the per-component
    log likelihoods (`log_likelihood`). This class runs NUTS sampling on the
    resulting model and turns the posterior means into a hard assignment of
    every datapoint to one of the two components.

    Args:
        num_warmup: Number of warmup steps for MCMC sampling.
        num_samples: Number of samples to draw after warmup.
        fraction_positive_expected: Prior belief about fraction of positive components.
        poisson_rate_prior: Rate parameter for exponential prior on Poisson component.
        gaussian_mean_prior: Mean and standard deviation for Gaussian prior on positive component mean.
        gaussian_std_prior: Scale parameter for half-normal prior on positive component std.
    """

    def __init__(
        self,
        num_warmup: int = 50,
        num_samples: int = 100,
        fraction_positive_expected: float = 0.15,
        poisson_rate_prior: float = 0.2,
        gaussian_mean_prior: tuple[float, float] = (3, 2),
        gaussian_std_prior: float = 1,
    ) -> None:
        self.num_warmup = num_warmup
        self.num_samples = num_samples
        self.fraction_positive_expected = fraction_positive_expected
        self.poisson_rate_prior = poisson_rate_prior
        self.gaussian_mean_prior = gaussian_mean_prior
        self.gaussian_std_prior = gaussian_std_prior

    @abstractmethod
    def initialize_params(self) -> ParamsDict:
        """Initialize model parameters via sampling from priors.

        Returns:
            Dictionary of sampled parameter values.
        """

    @abstractmethod
    def log_likelihood(self, data: jnp.ndarray, params: ParamsDict) -> jnp.ndarray:
        """Calculate log likelihood of data under current parameters.

        Args:
            data: Input data array.
            params: Current parameter values.

        Returns:
            Log likelihood values for each datapoint, with the components on the last axis.
        """

    def fit_model(self, data: jnp.ndarray, seed: int = 0) -> MCMC:
        """Fit the mixture model using MCMC.

        Args:
            data: Input data to fit.
            seed: Random seed for reproducibility.

        Returns:
            Fitted MCMC object containing samples.
        """
        nuts_kernel = NUTS(self.mixture_model)
        mcmc = MCMC(nuts_kernel, num_warmup=self.num_warmup, num_samples=self.num_samples, progress_bar=False)
        mcmc.run(PRNGKey(seed), data=data)
        return mcmc

    def run_model(self, data: jnp.ndarray, seed: int = 0) -> np.ndarray:
        """Run model fitting and assign components.

        Stores the MCMC object, its samples and the assignments on the instance
        as `mcmc`, `samples` and `assignments`.

        Args:
            data: Input data array.
            seed: Random seed.

        Returns:
            Array of "Positive"/"Negative" assignments for each datapoint.
        """
        self.mcmc = self.fit_model(data, seed)
        self.samples = self.mcmc.get_samples()
        self.assignments = self.assignment(self.samples, data)
        return self.assignments

    def mixture_model(self, data: jnp.ndarray) -> None:
        """Define mixture model structure for NumPyro.

        Args:
            data: Input data array.
        """
        params = self.initialize_params()

        with plate("data", data.shape[0]):
            log_likelihoods = self.log_likelihood(data, params)
            log_mixture_likelihood = logsumexp(log_likelihoods, axis=-1)
            # NOTE(review): the observation site centres a unit-variance Normal on
            # the log mixture likelihood instead of adding that log likelihood to
            # the joint density directly - confirm this is intended.
            sample("obs", Normal(log_mixture_likelihood, 1.0), obs=data)

    def assignment(self, samples: ParamsDict, data: jnp.ndarray) -> np.ndarray:
        """Assign data points to mixture components.

        The posterior mean of every parameter is stored on the instance as
        `params` and used to evaluate the component log likelihoods; each
        datapoint goes to the component with the larger value.

        Args:
            samples: MCMC samples of parameters.
            data: Input data array.

        Returns:
            Array of component assignments ("Negative" for component 0, "Positive" otherwise).
        """
        params = {key: draws.mean(axis=0) for key, draws in samples.items()}
        self.params = params

        log_likelihoods = self.log_likelihood(data, params)
        # Convert once to NumPy and label in a single vectorised step rather than
        # looping over the elements of a JAX array in Python.
        guide_assignments = np.asarray(jnp.argmax(log_likelihoods, axis=-1))
        return np.where(guide_assignments == 0, "Negative", "Positive")
126
+
127
+
128
class PoissonGaussMixture(MixtureModel):
    """Two-component mixture: a Poisson component (index 0) and a Gaussian component (index 1)."""

    def log_likelihood(self, data: np.ndarray, params: ParamsDict) -> jnp.ndarray:
        """Calculate component-wise log likelihoods.

        Args:
            data: Input data array.
            params: Current parameter values, keyed by `poisson_rate`,
                `gaussian_mean`, `gaussian_std` and `mix_probs`.

        Returns:
            Weighted log likelihood of every datapoint under each component,
            with the components stacked on the last axis.
        """
        rate = params["poisson_rate"]
        mean = params["gaussian_mean"]
        std = params["gaussian_std"]
        weights = params["mix_probs"]

        # Heuristic regularization to keep the components from flipping: the
        # factor adds 10 to the log density whenever the Gaussian mean lies to
        # the right of the Poisson rate, which amounts to a soft penalty on the
        # reversed ordering.
        factor("separation_penalty", 10 * jnp.heaviside(mean - rate, 0))

        poisson_component = jnp.log(weights[0]) + Poisson(rate).log_prob(data)
        gaussian_component = jnp.log(weights[1]) + Normal(mean, std).log_prob(data)
        return jnp.stack([poisson_component, gaussian_component], axis=-1)

    def initialize_params(self) -> ParamsDict:
        """Initialize model parameters via prior sampling.

        Returns:
            Dictionary of sampled parameter values.
        """
        expected_positive = self.fraction_positive_expected
        mixing_prior = Dirichlet(jnp.array([1 - expected_positive, expected_positive]))
        # Sites are sampled in the order they appear in this literal.
        return {
            "poisson_rate": sample("poisson_rate", Exponential(self.poisson_rate_prior)),
            "gaussian_mean": sample("gaussian_mean", Normal(*self.gaussian_mean_prior)),
            "gaussian_std": sample("gaussian_std", HalfNormal(self.gaussian_std_prior)),
            "mix_probs": sample("mix_probs", mixing_prior),
        }
pertpy/tools/__init__.py CHANGED
@@ -41,7 +41,7 @@ from pertpy.tools._perturbation_space._simple import (
41
41
  )
42
42
  from pertpy.tools._scgen import Scgen
43
43
 
44
- CODA_EXTRAS = ["toytree", "arviz", "ete3"] # also pyqt5 technically
44
+ CODA_EXTRAS = ["toytree", "arviz", "ete4"] # also pyqt6 technically
45
45
  Sccoda = lazy_import("pertpy.tools._coda._sccoda", "Sccoda", CODA_EXTRAS)
46
46
  Tasccoda = lazy_import("pertpy.tools._coda._tasccoda", "Tasccoda", CODA_EXTRAS)
47
47