pertpy 0.9.5__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +5 -1
- pertpy/_doc.py +2 -5
- pertpy/_types.py +6 -0
- pertpy/data/_dataloader.py +68 -24
- pertpy/data/_datasets.py +9 -9
- pertpy/metadata/__init__.py +2 -1
- pertpy/metadata/_cell_line.py +136 -30
- pertpy/metadata/_look_up.py +13 -19
- pertpy/metadata/_moa.py +1 -1
- pertpy/preprocessing/_guide_rna.py +221 -39
- pertpy/preprocessing/_guide_rna_mixture.py +177 -0
- pertpy/tools/__init__.py +1 -1
- pertpy/tools/_augur.py +138 -142
- pertpy/tools/_cinemaot.py +75 -117
- pertpy/tools/_coda/_base_coda.py +150 -174
- pertpy/tools/_coda/_sccoda.py +66 -69
- pertpy/tools/_coda/_tasccoda.py +71 -79
- pertpy/tools/_dialogue.py +60 -56
- pertpy/tools/_differential_gene_expression/_base.py +25 -43
- pertpy/tools/_differential_gene_expression/_checks.py +4 -6
- pertpy/tools/_differential_gene_expression/_dge_comparison.py +5 -6
- pertpy/tools/_differential_gene_expression/_edger.py +6 -10
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +1 -1
- pertpy/tools/_differential_gene_expression/_simple_tests.py +3 -3
- pertpy/tools/_differential_gene_expression/_statsmodels.py +8 -5
- pertpy/tools/_distances/_distance_tests.py +1 -2
- pertpy/tools/_distances/_distances.py +86 -92
- pertpy/tools/_enrichment.py +8 -25
- pertpy/tools/_milo.py +23 -27
- pertpy/tools/_mixscape.py +261 -175
- pertpy/tools/_perturbation_space/_clustering.py +4 -4
- pertpy/tools/_perturbation_space/_comparison.py +4 -4
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +83 -32
- pertpy/tools/_perturbation_space/_perturbation_space.py +10 -10
- pertpy/tools/_perturbation_space/_simple.py +13 -17
- pertpy/tools/_scgen/_scgen.py +17 -20
- pertpy/tools/_scgen/_scgenvae.py +2 -2
- pertpy/tools/_scgen/_utils.py +3 -1
- {pertpy-0.9.5.dist-info → pertpy-0.11.0.dist-info}/METADATA +37 -21
- pertpy-0.11.0.dist-info/RECORD +58 -0
- {pertpy-0.9.5.dist-info → pertpy-0.11.0.dist-info}/licenses/LICENSE +1 -0
- pertpy/tools/_kernel_pca.py +0 -50
- pertpy-0.9.5.dist-info/RECORD +0 -57
- {pertpy-0.9.5.dist-info → pertpy-0.11.0.dist-info}/WHEEL +0 -0
@@ -1,44 +1,53 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import uuid
|
4
|
-
from
|
4
|
+
from functools import singledispatchmethod
|
5
|
+
from typing import TYPE_CHECKING, Literal
|
6
|
+
from warnings import warn
|
5
7
|
|
6
8
|
import matplotlib.pyplot as plt
|
7
9
|
import numpy as np
|
8
10
|
import pandas as pd
|
9
11
|
import scanpy as sc
|
10
|
-
import
|
12
|
+
from anndata import AnnData
|
13
|
+
from numba import njit, prange
|
14
|
+
from rich.progress import track
|
15
|
+
from scanpy.get import _get_obs_rep, _set_obs_rep
|
16
|
+
from scipy.sparse import csr_matrix, issparse
|
11
17
|
|
12
18
|
from pertpy._doc import _doc_params, doc_common_plot_args
|
19
|
+
from pertpy._types import CSRBase
|
20
|
+
from pertpy.preprocessing._guide_rna_mixture import PoissonGaussMixture
|
13
21
|
|
14
22
|
if TYPE_CHECKING:
|
15
|
-
from anndata import AnnData
|
16
23
|
from matplotlib.pyplot import Figure
|
17
24
|
|
18
25
|
|
19
26
|
class GuideAssignment:
|
20
|
-
"""
|
27
|
+
"""Assign cells to guide RNAs."""
|
21
28
|
|
29
|
+
@singledispatchmethod
|
22
30
|
def assign_by_threshold(
|
23
31
|
self,
|
24
|
-
|
32
|
+
data: AnnData | np.ndarray | CSRBase,
|
33
|
+
/,
|
34
|
+
*,
|
25
35
|
assignment_threshold: float,
|
26
36
|
layer: str | None = None,
|
27
37
|
output_layer: str = "assigned_guides",
|
28
|
-
|
29
|
-
) -> np.ndarray | None:
|
38
|
+
):
|
30
39
|
"""Simple threshold based gRNA assignment function.
|
31
40
|
|
32
41
|
Each cell is assigned to gRNA with at least `assignment_threshold` counts.
|
33
42
|
This function expects unnormalized data as input.
|
34
43
|
|
35
44
|
Args:
|
36
|
-
|
45
|
+
data: The (annotated) data matrix of shape `n_obs` × `n_vars`.
|
46
|
+
Rows correspond to cells and columns to genes.
|
37
47
|
assignment_threshold: The count threshold that is required for an assignment to be viable.
|
38
48
|
layer: Key to the layer containing raw count values of the gRNAs.
|
39
49
|
adata.X is used if layer is None. Expects count data.
|
40
50
|
output_layer: Assigned guide will be saved on adata.layers[output_key].
|
41
|
-
only_return_results: If True, input AnnData is not modified and the result is returned as an np.ndarray.
|
42
51
|
|
43
52
|
Examples:
|
44
53
|
Each cell is assigned to gRNA that occurs at least 5 times in the respective cell.
|
@@ -49,26 +58,52 @@ class GuideAssignment:
|
|
49
58
|
>>> ga = pt.pp.GuideAssignment()
|
50
59
|
>>> ga.assign_by_threshold(gdo, assignment_threshold=5)
|
51
60
|
"""
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
assigned_grnas = np.where(counts >= assignment_threshold, 1, 0)
|
57
|
-
assigned_grnas = scipy.sparse.csr_matrix(assigned_grnas)
|
58
|
-
if only_return_results:
|
59
|
-
return assigned_grnas
|
60
|
-
adata.layers[output_layer] = assigned_grnas
|
61
|
-
|
62
|
-
return None
|
61
|
+
raise NotImplementedError(
|
62
|
+
f"No implementation found for {type(data)}. Must be numpy array, sparse matrix, or AnnData object."
|
63
|
+
)
|
63
64
|
|
64
|
-
|
65
|
+
@assign_by_threshold.register(AnnData)
|
66
|
+
def _assign_by_threshold_anndata(
|
65
67
|
self,
|
66
68
|
adata: AnnData,
|
69
|
+
/,
|
70
|
+
*,
|
67
71
|
assignment_threshold: float,
|
68
72
|
layer: str | None = None,
|
69
|
-
|
70
|
-
|
71
|
-
|
73
|
+
output_layer: str = "assigned_guides",
|
74
|
+
) -> None:
|
75
|
+
X = _get_obs_rep(adata, layer=layer)
|
76
|
+
guide_assignments = self.assign_by_threshold(X, assignment_threshold=assignment_threshold)
|
77
|
+
_set_obs_rep(adata, guide_assignments, layer=output_layer)
|
78
|
+
|
79
|
+
@assign_by_threshold.register(np.ndarray)
|
80
|
+
def _assign_by_threshold_numpy(self, X: np.ndarray, /, *, assignment_threshold: float) -> np.ndarray:
|
81
|
+
return np.where(assignment_threshold <= X, 1, 0)
|
82
|
+
|
83
|
+
@staticmethod
|
84
|
+
@njit(parallel=True)
|
85
|
+
def _threshold_sparse_numba(data: np.ndarray, threshold: float) -> np.ndarray:
|
86
|
+
out = np.zeros_like(data, dtype=np.int8)
|
87
|
+
for i in prange(data.shape[0]):
|
88
|
+
if data[i] >= threshold:
|
89
|
+
out[i] = 1
|
90
|
+
return out
|
91
|
+
|
92
|
+
@assign_by_threshold.register(CSRBase)
|
93
|
+
def _assign_by_threshold_sparse(self, X: CSRBase, /, *, assignment_threshold: float) -> CSRBase:
|
94
|
+
new_data = self._threshold_sparse_numba(X.data, assignment_threshold)
|
95
|
+
return csr_matrix((new_data, X.indices, X.indptr), shape=X.shape)
|
96
|
+
|
97
|
+
@singledispatchmethod
|
98
|
+
def assign_to_max_guide(
|
99
|
+
self,
|
100
|
+
data: AnnData | np.ndarray | CSRBase,
|
101
|
+
/,
|
102
|
+
*,
|
103
|
+
assignment_threshold: float,
|
104
|
+
layer: str | None = None,
|
105
|
+
obs_key: str = "assigned_guide",
|
106
|
+
no_grna_assigned_key: str = "Negative",
|
72
107
|
) -> np.ndarray | None:
|
73
108
|
"""Simple threshold based max gRNA assignment function.
|
74
109
|
|
@@ -76,13 +111,13 @@ class GuideAssignment:
|
|
76
111
|
This function expects unnormalized data as input.
|
77
112
|
|
78
113
|
Args:
|
79
|
-
|
114
|
+
data: The (annotated) data matrix of shape `n_obs` × `n_vars`.
|
115
|
+
Rows correspond to cells and columns to genes.
|
80
116
|
assignment_threshold: The count threshold that is required for an assignment to be viable.
|
81
117
|
layer: Key to the layer containing raw count values of the gRNAs.
|
82
118
|
adata.X is used if layer is None. Expects count data.
|
83
|
-
|
119
|
+
obs_key: Assigned guide will be saved on adata.obs[output_key].
|
84
120
|
no_grna_assigned_key: The key to return if no gRNA is expressed enough.
|
85
|
-
only_return_results: If True, input AnnData is not modified and the result is returned as an np.ndarray.
|
86
121
|
|
87
122
|
Examples:
|
88
123
|
Each cell is assigned to the most expressed gRNA if it has at least 5 counts.
|
@@ -93,31 +128,179 @@ class GuideAssignment:
|
|
93
128
|
>>> ga = pt.pp.GuideAssignment()
|
94
129
|
>>> ga.assign_to_max_guide(gdo, assignment_threshold=5)
|
95
130
|
"""
|
96
|
-
|
97
|
-
|
98
|
-
|
131
|
+
raise NotImplementedError(
|
132
|
+
f"No implementation found for {type(data)}. Must be numpy array, sparse matrix, or AnnData object."
|
133
|
+
)
|
134
|
+
|
135
|
+
@assign_to_max_guide.register(AnnData)
|
136
|
+
def assign_to_max_guide_anndata(
|
137
|
+
self,
|
138
|
+
adata: AnnData,
|
139
|
+
/,
|
140
|
+
*,
|
141
|
+
assignment_threshold: float,
|
142
|
+
layer: str | None = None,
|
143
|
+
obs_key: str = "assigned_guide",
|
144
|
+
no_grna_assigned_key: str = "Negative",
|
145
|
+
) -> None:
|
146
|
+
X = _get_obs_rep(adata, layer=layer)
|
147
|
+
guide_assignments = self.assign_to_max_guide(
|
148
|
+
X, var=adata.var, assignment_threshold=assignment_threshold, no_grna_assigned_key=no_grna_assigned_key
|
149
|
+
)
|
150
|
+
adata.obs[obs_key] = guide_assignments
|
99
151
|
|
152
|
+
@assign_to_max_guide.register(np.ndarray)
|
153
|
+
def assign_to_max_guide_numpy(
|
154
|
+
self,
|
155
|
+
X: np.ndarray,
|
156
|
+
/,
|
157
|
+
*,
|
158
|
+
var: pd.DataFrame,
|
159
|
+
assignment_threshold: float,
|
160
|
+
no_grna_assigned_key: str = "Negative",
|
161
|
+
) -> np.ndarray:
|
100
162
|
assigned_grna = np.where(
|
101
|
-
|
102
|
-
|
163
|
+
X.max(axis=1).squeeze() >= assignment_threshold,
|
164
|
+
var.index[X.argmax(axis=1).squeeze()],
|
103
165
|
no_grna_assigned_key,
|
104
166
|
)
|
105
167
|
|
168
|
+
return assigned_grna
|
169
|
+
|
170
|
+
@staticmethod
|
171
|
+
@njit(parallel=True)
|
172
|
+
def _assign_max_guide_sparse(indptr, data, indices, assignment_threshold, assigned_grna):
|
173
|
+
n_rows = len(indptr) - 1
|
174
|
+
for i in range(n_rows):
|
175
|
+
row_start = indptr[i]
|
176
|
+
row_end = indptr[i + 1]
|
177
|
+
|
178
|
+
if row_end > row_start:
|
179
|
+
data_row = data[row_start:row_end]
|
180
|
+
indices_row = indices[row_start:row_end]
|
181
|
+
max_pos = np.argmax(data_row)
|
182
|
+
if data_row[max_pos] >= assignment_threshold:
|
183
|
+
assigned_grna[i] = indices_row[max_pos]
|
184
|
+
return assigned_grna
|
185
|
+
|
186
|
+
@assign_to_max_guide.register(CSRBase)
|
187
|
+
def assign_to_max_guide_sparse(
|
188
|
+
self, X: CSRBase, /, *, var: pd.DataFrame, assignment_threshold: float, no_grna_assigned_key: str = "Negative"
|
189
|
+
) -> np.ndarray:
|
190
|
+
n_rows = X.shape[0]
|
191
|
+
|
192
|
+
assigned_positions = np.zeros(n_rows, dtype=np.int32) - 1 # -1 means not assigned
|
193
|
+
assigned_positions = self._assign_max_guide_sparse(
|
194
|
+
X.indptr, X.data, X.indices, assignment_threshold, assigned_positions
|
195
|
+
)
|
196
|
+
|
197
|
+
assigned_grna = np.full(n_rows, no_grna_assigned_key, dtype=object)
|
198
|
+
mask = assigned_positions >= 0
|
199
|
+
var_index_array = np.array(var.index)
|
200
|
+
if np.any(mask):
|
201
|
+
assigned_grna[mask] = var_index_array[assigned_positions[mask]]
|
202
|
+
|
203
|
+
return assigned_grna
|
204
|
+
|
205
|
+
def assign_mixture_model(
|
206
|
+
self,
|
207
|
+
adata: AnnData,
|
208
|
+
model: Literal["poisson_gauss_mixture"] = "poisson_gauss_mixture",
|
209
|
+
assigned_guides_key: str = "assigned_guide",
|
210
|
+
no_grna_assigned_key: str = "negative",
|
211
|
+
max_assignments_per_cell: int = 5,
|
212
|
+
multiple_grna_assigned_key: str = "multiple",
|
213
|
+
multiple_grna_assignment_string: str = "+",
|
214
|
+
only_return_results: bool = False,
|
215
|
+
show_progress: bool = False,
|
216
|
+
**mixture_model_kwargs,
|
217
|
+
) -> np.ndarray | None:
|
218
|
+
"""Assigns gRNAs to cells using a mixture model.
|
219
|
+
|
220
|
+
Args:
|
221
|
+
adata: AnnData object containing gRNA values.
|
222
|
+
model: The model to use for the mixture model. Currently only `Poisson_Gauss_Mixture` is supported.
|
223
|
+
assigned_guides_key: Assigned guide will be saved on adata.obs[output_key].
|
224
|
+
no_grna_assigned_key: The key to return if a cell is negative for all gRNAs.
|
225
|
+
max_assignments_per_cell: The maximum number of gRNAs that can be assigned to a cell.
|
226
|
+
multiple_grna_assigned_key: The key to return if multiple gRNAs are assigned to a cell.
|
227
|
+
multiple_grna_assignment_string: The string to use to join multiple gRNAs assigned to a cell.
|
228
|
+
only_return_results: Whether input AnnData is not modified and the result is returned as an np.ndarray.
|
229
|
+
show_progress: Whether to shows progress bar.
|
230
|
+
mixture_model_kwargs: Are passed to the mixture model.
|
231
|
+
|
232
|
+
Examples:
|
233
|
+
>>> import pertpy as pt
|
234
|
+
>>> mdata = pt.dt.papalexi_2021()
|
235
|
+
>>> gdo = mdata.mod["gdo"]
|
236
|
+
>>> ga = pt.pp.GuideAssignment()
|
237
|
+
>>> ga.assign_mixture_model(gdo)
|
238
|
+
"""
|
239
|
+
if model == "poisson_gauss_mixture":
|
240
|
+
mixture_model = PoissonGaussMixture(**mixture_model_kwargs)
|
241
|
+
else:
|
242
|
+
raise ValueError("Model not implemented. Please use 'poisson_gauss_mixture'.")
|
243
|
+
|
244
|
+
res = pd.DataFrame(0, index=adata.obs_names, columns=adata.var_names)
|
245
|
+
fct = track if show_progress else lambda iterable: iterable
|
246
|
+
for gene in fct(adata.var_names):
|
247
|
+
is_nonzero = (
|
248
|
+
np.ravel((adata[:, gene].X != 0).todense()) if issparse(adata.X) else np.ravel(adata[:, gene].X != 0)
|
249
|
+
)
|
250
|
+
if sum(is_nonzero) < 2:
|
251
|
+
warn(f"Skipping {gene} as there are less than 2 cells expressing the guide at all.", stacklevel=2)
|
252
|
+
continue
|
253
|
+
# We are only fitting the model to the non-zero values, the rest is
|
254
|
+
# automatically assigned to the negative class
|
255
|
+
data = adata[is_nonzero, gene].X.todense().A1 if issparse(adata.X) else adata[is_nonzero, gene].X
|
256
|
+
data = np.ravel(data)
|
257
|
+
|
258
|
+
if np.any(data < 0):
|
259
|
+
raise ValueError(
|
260
|
+
"Data contains negative values. Please use non-negative data for guide assignment with the Mixture Model."
|
261
|
+
)
|
262
|
+
|
263
|
+
# Log2 transform the data so positive population is approximately normal
|
264
|
+
data = np.log2(data)
|
265
|
+
assignments = mixture_model.run_model(data)
|
266
|
+
res.loc[adata.obs_names[is_nonzero][assignments == "Positive"], gene] = 1
|
267
|
+
|
268
|
+
# Add the parameters to the adata.var DataFrame
|
269
|
+
for params_name, param in mixture_model.params.items():
|
270
|
+
if param.ndim == 0:
|
271
|
+
if params_name not in adata.var.columns:
|
272
|
+
adata.var[params_name] = np.nan
|
273
|
+
adata.var.loc[gene, params_name] = param.item()
|
274
|
+
else:
|
275
|
+
for i, p in enumerate(param):
|
276
|
+
if f"{params_name}_{i}" not in adata.var.columns:
|
277
|
+
adata.var[f"{params_name}_{i}"] = np.nan
|
278
|
+
adata.var.loc[gene, f"{params_name}_{i}"] = p
|
279
|
+
|
280
|
+
# Assign guides to cells
|
281
|
+
# Some cells might have multiple guides assigned
|
282
|
+
series = pd.Series(no_grna_assigned_key, index=adata.obs_names)
|
283
|
+
num_guides_assigned = res.sum(1)
|
284
|
+
series.loc[(num_guides_assigned <= max_assignments_per_cell) & (num_guides_assigned != 0)] = res.apply(
|
285
|
+
lambda row: row.index[row == 1].tolist(), axis=1
|
286
|
+
).str.join(multiple_grna_assignment_string)
|
287
|
+
series.loc[num_guides_assigned > max_assignments_per_cell] = multiple_grna_assigned_key
|
288
|
+
|
106
289
|
if only_return_results:
|
107
|
-
return
|
108
|
-
|
290
|
+
return series.values
|
291
|
+
|
292
|
+
adata.obs[assigned_guides_key] = series.values
|
109
293
|
|
110
294
|
return None
|
111
295
|
|
112
296
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
113
|
-
def plot_heatmap(
|
297
|
+
def plot_heatmap( # pragma: no cover # noqa: D417
|
114
298
|
self,
|
115
299
|
adata: AnnData,
|
116
300
|
*,
|
117
301
|
layer: str | None = None,
|
118
302
|
order_by: np.ndarray | str | None = None,
|
119
303
|
key_to_save_order: str = None,
|
120
|
-
show: bool = True,
|
121
304
|
return_fig: bool = False,
|
122
305
|
**kwargs,
|
123
306
|
) -> Figure | None:
|
@@ -159,7 +342,7 @@ class GuideAssignment:
|
|
159
342
|
data = adata.X if layer is None else adata.layers[layer]
|
160
343
|
|
161
344
|
if order_by is None:
|
162
|
-
if
|
345
|
+
if issparse(data):
|
163
346
|
max_values = data.max(axis=1).toarray().squeeze()
|
164
347
|
data_argmax = data.argmax(axis=1).A.squeeze()
|
165
348
|
max_guide_index = np.where(max_values != data.min(axis=1).toarray().squeeze(), data_argmax, -1)
|
@@ -194,8 +377,7 @@ class GuideAssignment:
|
|
194
377
|
finally:
|
195
378
|
del adata.obs[temp_col_name]
|
196
379
|
|
197
|
-
if show:
|
198
|
-
plt.show()
|
199
380
|
if return_fig:
|
200
381
|
return fig
|
382
|
+
plt.show()
|
201
383
|
return None
|
@@ -0,0 +1,177 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from abc import ABC, abstractmethod
|
4
|
+
from collections.abc import Mapping
|
5
|
+
|
6
|
+
import jax.numpy as jnp
|
7
|
+
import numpy as np
|
8
|
+
from jax.random import PRNGKey
|
9
|
+
from jax.scipy.special import logsumexp
|
10
|
+
from numpyro import factor, plate, sample
|
11
|
+
from numpyro.distributions import Dirichlet, Exponential, HalfNormal, Normal, Poisson
|
12
|
+
from numpyro.infer import MCMC, NUTS
|
13
|
+
|
14
|
+
ParamsDict = Mapping[str, jnp.ndarray]
|
15
|
+
|
16
|
+
|
17
|
+
class MixtureModel(ABC):
|
18
|
+
"""Abstract base class for 2-component mixture models.
|
19
|
+
|
20
|
+
Args:
|
21
|
+
num_warmup: Number of warmup steps for MCMC sampling.
|
22
|
+
num_samples: Number of samples to draw after warmup.
|
23
|
+
fraction_positive_expected: Prior belief about fraction of positive components.
|
24
|
+
poisson_rate_prior: Rate parameter for exponential prior on Poisson component.
|
25
|
+
gaussian_mean_prior: Mean and standard deviation for Gaussian prior on positive component mean.
|
26
|
+
gaussian_std_prior: Scale parameter for half-normal prior on positive component std.
|
27
|
+
"""
|
28
|
+
|
29
|
+
def __init__(
|
30
|
+
self,
|
31
|
+
num_warmup: int = 50,
|
32
|
+
num_samples: int = 100,
|
33
|
+
fraction_positive_expected: float = 0.15,
|
34
|
+
poisson_rate_prior: float = 0.2,
|
35
|
+
gaussian_mean_prior: tuple[float, float] = (3, 2),
|
36
|
+
gaussian_std_prior: float = 1,
|
37
|
+
) -> None:
|
38
|
+
self.num_warmup = num_warmup
|
39
|
+
self.num_samples = num_samples
|
40
|
+
self.fraction_positive_expected = fraction_positive_expected
|
41
|
+
self.poisson_rate_prior = poisson_rate_prior
|
42
|
+
self.gaussian_mean_prior = gaussian_mean_prior
|
43
|
+
self.gaussian_std_prior = gaussian_std_prior
|
44
|
+
|
45
|
+
@abstractmethod
|
46
|
+
def initialize_params(self) -> ParamsDict:
|
47
|
+
"""Initialize model parameters via sampling from priors.
|
48
|
+
|
49
|
+
Returns:
|
50
|
+
Dictionary of sampled parameter values.
|
51
|
+
"""
|
52
|
+
|
53
|
+
@abstractmethod
|
54
|
+
def log_likelihood(self, data: jnp.ndarray, params: ParamsDict) -> jnp.ndarray:
|
55
|
+
"""Calculate log likelihood of data under current parameters.
|
56
|
+
|
57
|
+
Args:
|
58
|
+
data: Input data array.
|
59
|
+
params: Current parameter values.
|
60
|
+
|
61
|
+
Returns:
|
62
|
+
Log likelihood values for each datapoint.
|
63
|
+
"""
|
64
|
+
|
65
|
+
def fit_model(self, data: jnp.ndarray, seed: int = 0) -> MCMC:
|
66
|
+
"""Fit the mixture model using MCMC.
|
67
|
+
|
68
|
+
Args:
|
69
|
+
data: Input data to fit.
|
70
|
+
seed: Random seed for reproducibility.
|
71
|
+
|
72
|
+
Returns:
|
73
|
+
Fitted MCMC object containing samples.
|
74
|
+
"""
|
75
|
+
nuts_kernel = NUTS(self.mixture_model)
|
76
|
+
mcmc = MCMC(nuts_kernel, num_warmup=self.num_warmup, num_samples=self.num_samples, progress_bar=False)
|
77
|
+
mcmc.run(PRNGKey(seed), data=data)
|
78
|
+
return mcmc
|
79
|
+
|
80
|
+
def run_model(self, data: jnp.ndarray, seed: int = 0) -> np.ndarray:
|
81
|
+
"""Run model fitting and assign components.
|
82
|
+
|
83
|
+
Args:
|
84
|
+
data: Input data array.
|
85
|
+
seed: Random seed.
|
86
|
+
|
87
|
+
Returns:
|
88
|
+
Array of "Positive"/"Negative" assignments for each datapoint.
|
89
|
+
"""
|
90
|
+
self.mcmc = self.fit_model(data, seed)
|
91
|
+
self.samples = self.mcmc.get_samples()
|
92
|
+
self.assignments = self.assignment(self.samples, data)
|
93
|
+
return self.assignments
|
94
|
+
|
95
|
+
def mixture_model(self, data: jnp.ndarray) -> None:
|
96
|
+
"""Define mixture model structure for NumPyro.
|
97
|
+
|
98
|
+
Args:
|
99
|
+
data: Input data array.
|
100
|
+
"""
|
101
|
+
params = self.initialize_params()
|
102
|
+
|
103
|
+
with plate("data", data.shape[0]):
|
104
|
+
log_likelihoods = self.log_likelihood(data, params)
|
105
|
+
log_mixture_likelihood = logsumexp(log_likelihoods, axis=-1)
|
106
|
+
sample("obs", Normal(log_mixture_likelihood, 1.0), obs=data)
|
107
|
+
|
108
|
+
def assignment(self, samples: ParamsDict, data: jnp.ndarray) -> np.ndarray:
|
109
|
+
"""Assign data points to mixture components.
|
110
|
+
|
111
|
+
Args:
|
112
|
+
samples: MCMC samples of parameters.
|
113
|
+
data: Input data array.
|
114
|
+
|
115
|
+
Returns:
|
116
|
+
Array of component assignments.
|
117
|
+
"""
|
118
|
+
params = {key: samples[key].mean(axis=0) for key in samples}
|
119
|
+
self.params = params
|
120
|
+
|
121
|
+
log_likelihoods = self.log_likelihood(data, params)
|
122
|
+
guide_assignments = jnp.argmax(log_likelihoods, axis=-1)
|
123
|
+
|
124
|
+
assignments = ["Negative" if assign == 0 else "Positive" for assign in guide_assignments]
|
125
|
+
return np.array(assignments)
|
126
|
+
|
127
|
+
|
128
|
+
class PoissonGaussMixture(MixtureModel):
|
129
|
+
"""Mixture model combining Poisson and Gaussian distributions."""
|
130
|
+
|
131
|
+
def log_likelihood(self, data: np.ndarray, params: ParamsDict) -> jnp.ndarray:
|
132
|
+
"""Calculate component-wise log likelihoods.
|
133
|
+
|
134
|
+
Args:
|
135
|
+
data: Input data array.
|
136
|
+
params: Current parameter values.
|
137
|
+
|
138
|
+
Returns:
|
139
|
+
Log likelihood values for each component.
|
140
|
+
"""
|
141
|
+
poisson_rate = params["poisson_rate"]
|
142
|
+
gaussian_mean = params["gaussian_mean"]
|
143
|
+
gaussian_std = params["gaussian_std"]
|
144
|
+
mix_probs = params["mix_probs"]
|
145
|
+
|
146
|
+
# We penalize the model for positioning the Poisson component to the right of the Gaussian component
|
147
|
+
# by imposing a soft constraint to penalize the Poisson rate being larger than the Gaussian mean
|
148
|
+
# Heuristic regularization term to prevent flipping of the components
|
149
|
+
factor("separation_penalty", +10 * jnp.heaviside(-poisson_rate + gaussian_mean, 0))
|
150
|
+
|
151
|
+
log_likelihoods = jnp.stack(
|
152
|
+
[
|
153
|
+
# Poisson component
|
154
|
+
jnp.log(mix_probs[0]) + Poisson(poisson_rate).log_prob(data),
|
155
|
+
# Gaussian component
|
156
|
+
jnp.log(mix_probs[1]) + Normal(gaussian_mean, gaussian_std).log_prob(data),
|
157
|
+
],
|
158
|
+
axis=-1,
|
159
|
+
)
|
160
|
+
|
161
|
+
return log_likelihoods
|
162
|
+
|
163
|
+
def initialize_params(self) -> ParamsDict:
|
164
|
+
"""Initialize model parameters via prior sampling.
|
165
|
+
|
166
|
+
Returns:
|
167
|
+
Dictionary of sampled parameter values.
|
168
|
+
"""
|
169
|
+
params = {}
|
170
|
+
params["poisson_rate"] = sample("poisson_rate", Exponential(self.poisson_rate_prior))
|
171
|
+
params["gaussian_mean"] = sample("gaussian_mean", Normal(*self.gaussian_mean_prior))
|
172
|
+
params["gaussian_std"] = sample("gaussian_std", HalfNormal(self.gaussian_std_prior))
|
173
|
+
params["mix_probs"] = sample(
|
174
|
+
"mix_probs",
|
175
|
+
Dirichlet(jnp.array([1 - self.fraction_positive_expected, self.fraction_positive_expected])),
|
176
|
+
)
|
177
|
+
return params
|
pertpy/tools/__init__.py
CHANGED
@@ -41,7 +41,7 @@ from pertpy.tools._perturbation_space._simple import (
|
|
41
41
|
)
|
42
42
|
from pertpy.tools._scgen import Scgen
|
43
43
|
|
44
|
-
CODA_EXTRAS = ["toytree", "arviz", "
|
44
|
+
CODA_EXTRAS = ["toytree", "arviz", "ete4"] # also pyqt6 technically
|
45
45
|
Sccoda = lazy_import("pertpy.tools._coda._sccoda", "Sccoda", CODA_EXTRAS)
|
46
46
|
Tasccoda = lazy_import("pertpy.tools._coda._tasccoda", "Tasccoda", CODA_EXTRAS)
|
47
47
|
|