pertpy 0.9.4__py3-none-any.whl → 0.10.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- pertpy/__init__.py +1 -1
- pertpy/_doc.py +19 -0
- pertpy/data/_datasets.py +1 -1
- pertpy/metadata/_cell_line.py +18 -8
- pertpy/metadata/_compound.py +3 -4
- pertpy/metadata/_metadata.py +1 -1
- pertpy/preprocessing/_guide_rna.py +114 -13
- pertpy/preprocessing/_guide_rna_mixture.py +179 -0
- pertpy/tools/__init__.py +1 -1
- pertpy/tools/_augur.py +64 -86
- pertpy/tools/_cinemaot.py +21 -17
- pertpy/tools/_coda/_base_coda.py +90 -117
- pertpy/tools/_dialogue.py +32 -40
- pertpy/tools/_differential_gene_expression/__init__.py +1 -2
- pertpy/tools/_differential_gene_expression/_base.py +486 -112
- pertpy/tools/_differential_gene_expression/_edger.py +30 -21
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +15 -29
- pertpy/tools/_differential_gene_expression/_statsmodels.py +0 -11
- pertpy/tools/_distances/_distances.py +71 -56
- pertpy/tools/_enrichment.py +16 -8
- pertpy/tools/_milo.py +54 -50
- pertpy/tools/_mixscape.py +307 -208
- pertpy/tools/_perturbation_space/_perturbation_space.py +40 -31
- pertpy/tools/_perturbation_space/_simple.py +48 -0
- pertpy/tools/_scgen/_scgen.py +35 -27
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/METADATA +6 -6
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/RECORD +29 -28
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/WHEEL +1 -1
- pertpy/tools/_differential_gene_expression/_formulaic.py +0 -189
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/licenses/LICENSE +0 -0
pertpy/__init__.py
CHANGED
pertpy/_doc.py
ADDED
@@ -0,0 +1,19 @@
|
|
1
|
+
from textwrap import dedent
|
2
|
+
|
3
|
+
|
4
|
+
def _doc_params(**kwds): # pragma: no cover
|
5
|
+
"""\
|
6
|
+
Docstrings should start with "\" in the first line for proper formatting.
|
7
|
+
"""
|
8
|
+
|
9
|
+
def dec(obj):
|
10
|
+
obj.__orig_doc__ = obj.__doc__
|
11
|
+
obj.__doc__ = dedent(obj.__doc__.format_map(kwds))
|
12
|
+
return obj
|
13
|
+
|
14
|
+
return dec
|
15
|
+
|
16
|
+
|
17
|
+
doc_common_plot_args = """\
|
18
|
+
return_fig: if `True`, returns figure of the plot, that can be used for saving.\
|
19
|
+
"""
|
pertpy/data/_datasets.py
CHANGED
@@ -66,7 +66,7 @@ def sc_sim_augur() -> AnnData: # pragma: no cover
|
|
66
66
|
output_file_path = settings.datasetdir / output_file_name
|
67
67
|
if not Path(output_file_path).exists():
|
68
68
|
_download(
|
69
|
-
url="https://figshare.com/ndownloader/files/
|
69
|
+
url="https://figshare.com/ndownloader/files/49828902",
|
70
70
|
output_file_name=output_file_name,
|
71
71
|
output_path=settings.datasetdir,
|
72
72
|
is_zip=False,
|
pertpy/metadata/_cell_line.py
CHANGED
@@ -8,12 +8,15 @@ from lamin_utils import logger
|
|
8
8
|
if TYPE_CHECKING:
|
9
9
|
from collections.abc import Iterable
|
10
10
|
|
11
|
+
from matplotlib.pyplot import Figure
|
12
|
+
|
11
13
|
import matplotlib.pyplot as plt
|
12
14
|
import numpy as np
|
13
15
|
import pandas as pd
|
14
16
|
from scanpy import settings
|
15
17
|
from scipy import stats
|
16
18
|
|
19
|
+
from pertpy._doc import _doc_params, doc_common_plot_args
|
17
20
|
from pertpy.data._dataloader import _download
|
18
21
|
|
19
22
|
from ._look_up import LookUp
|
@@ -338,8 +341,8 @@ class CellLine(MetaData):
|
|
338
341
|
# then we can compare these keys and fetch the corresponding metadata.
|
339
342
|
if query_id not in adata.obs.columns:
|
340
343
|
raise ValueError(
|
341
|
-
f"The specified `query_id` {query_id} can't be found in the `adata.obs
|
342
|
-
"Ensure that you are using one of the available query IDs present in the adata.obs for the annotation
|
344
|
+
f"The specified `query_id` {query_id} can't be found in the `adata.obs`. \n"
|
345
|
+
"Ensure that you are using one of the available query IDs present in the adata.obs for the annotation."
|
343
346
|
"If the desired query ID is not available, you can fetch the cell line metadata "
|
344
347
|
"using the `annotate()` function before calling 'annotate_bulk_rna()'. "
|
345
348
|
"This ensures that the required query ID is included in your data, e.g. stripped_cell_line_name, DepMap ID."
|
@@ -356,9 +359,8 @@ class CellLine(MetaData):
|
|
356
359
|
else:
|
357
360
|
reference_id = "DepMap_ID"
|
358
361
|
logger.warning(
|
359
|
-
"To annotate bulk RNA data from Broad Institue, `DepMap_ID` is used as default reference and query identifier if no `reference_id` is given
|
360
|
-
"
|
361
|
-
"Alternatively, use `annotate()` to annotate the cell line first "
|
362
|
+
"To annotate bulk RNA data from Broad Institue, `DepMap_ID` is used as default reference and query identifier if no `reference_id` is given."
|
363
|
+
"If `DepMap_ID` isn't available in 'adata.obs', use `annotate()` to annotate the cell line first."
|
362
364
|
)
|
363
365
|
if self.bulk_rna_broad is None:
|
364
366
|
self._download_bulk_rna(cell_line_source="broad")
|
@@ -690,6 +692,7 @@ class CellLine(MetaData):
|
|
690
692
|
|
691
693
|
return corr, pvals, new_corr, new_pvals
|
692
694
|
|
695
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
693
696
|
def plot_correlation(
|
694
697
|
self,
|
695
698
|
adata: AnnData,
|
@@ -700,7 +703,8 @@ class CellLine(MetaData):
|
|
700
703
|
metadata_key: str = "bulk_rna_broad",
|
701
704
|
category: str = "cell line",
|
702
705
|
subset_identifier: str | int | Iterable[str] | Iterable[int] | None = None,
|
703
|
-
|
706
|
+
return_fig: bool = False,
|
707
|
+
) -> Figure | None:
|
704
708
|
"""Visualise the correlation of cell lines with annotated metadata.
|
705
709
|
|
706
710
|
Args:
|
@@ -713,6 +717,8 @@ class CellLine(MetaData):
|
|
713
717
|
subset_identifier: Selected identifiers for scatter plot visualization between the X matrix and `metadata_key`.
|
714
718
|
If not None, only the chosen cell line will be plotted, either specified as a value in `identifier` (string) or as an index number.
|
715
719
|
If None, all cell lines will be plotted.
|
720
|
+
{common_plot_args}
|
721
|
+
|
716
722
|
Returns:
|
717
723
|
Pearson correlation coefficients and their corresponding p-values for matched and unmatched cell lines separately.
|
718
724
|
"""
|
@@ -740,7 +746,7 @@ class CellLine(MetaData):
|
|
740
746
|
if all(isinstance(id, str) for id in subset_identifier_list):
|
741
747
|
if set(subset_identifier_list).issubset(adata.obs[identifier].unique()):
|
742
748
|
subset_identifier_list = np.where(
|
743
|
-
np.
|
749
|
+
np.isin(adata.obs[identifier].values, subset_identifier_list)
|
744
750
|
)[0]
|
745
751
|
else:
|
746
752
|
raise ValueError("`Subset_identifier` must be found in adata.obs.`identifier`.")
|
@@ -790,6 +796,10 @@ class CellLine(MetaData):
|
|
790
796
|
"edgecolor": "black",
|
791
797
|
},
|
792
798
|
)
|
799
|
+
|
800
|
+
if return_fig:
|
801
|
+
return plt.gcf()
|
793
802
|
plt.show()
|
803
|
+
return None
|
794
804
|
else:
|
795
|
-
raise NotImplementedError
|
805
|
+
raise NotImplementedError("Only 'cell line' category is supported for correlation comparison.")
|
pertpy/metadata/_compound.py
CHANGED
@@ -42,7 +42,7 @@ class Compound(MetaData):
|
|
42
42
|
adata = adata.copy()
|
43
43
|
|
44
44
|
if query_id not in adata.obs.columns:
|
45
|
-
raise ValueError(f"The requested query_id {query_id} is not in `adata.obs`.\n
|
45
|
+
raise ValueError(f"The requested query_id {query_id} is not in `adata.obs`.\n Please check again.")
|
46
46
|
|
47
47
|
query_dict = {}
|
48
48
|
not_matched_identifiers = []
|
@@ -84,7 +84,7 @@ class Compound(MetaData):
|
|
84
84
|
query_df = pd.DataFrame.from_dict(query_dict, orient="index", columns=["pubchem_name", "pubchem_ID", "smiles"])
|
85
85
|
# Merge and remove duplicate columns
|
86
86
|
# Column is converted to float after merging due to unmatches
|
87
|
-
# Convert back to integers
|
87
|
+
# Convert back to integers afterwards
|
88
88
|
if query_id_type == "cid":
|
89
89
|
query_df.pubchem_ID = query_df.pubchem_ID.astype("Int64")
|
90
90
|
adata.obs = (
|
@@ -119,8 +119,7 @@ class Compound(MetaData):
|
|
119
119
|
|
120
120
|
The LookUp object provides an overview of the metadata to annotate.
|
121
121
|
Each annotate_{metadata} function has a corresponding lookup function in the LookUp object,
|
122
|
-
where users can search the reference_id in the metadata and
|
123
|
-
compare with the query_id in their own data.
|
122
|
+
where users can search the reference_id in the metadata and compare with the query_id in their own data.
|
124
123
|
|
125
124
|
Returns:
|
126
125
|
Returns a LookUp object specific for compound annotation.
|
pertpy/metadata/_metadata.py
CHANGED
@@ -62,7 +62,7 @@ class MetaData:
|
|
62
62
|
if verbosity > 0:
|
63
63
|
logger.info(
|
64
64
|
f"There are {total_identifiers} identifiers in `adata.obs`."
|
65
|
-
f"However, {len(unmatched_identifiers)} identifiers can't be found in the {metadata_type} annotation,"
|
65
|
+
f"However, {len(unmatched_identifiers)} identifiers can't be found in the {metadata_type} annotation, "
|
66
66
|
"leading to the presence of NA values for their respective metadata.\n"
|
67
67
|
f"Please check again: *unmatched_identifiers[:verbosity]..."
|
68
68
|
)
|
@@ -1,20 +1,27 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import uuid
|
4
|
-
from typing import TYPE_CHECKING
|
4
|
+
from typing import TYPE_CHECKING, Literal
|
5
|
+
from warnings import warn
|
5
6
|
|
7
|
+
import matplotlib.pyplot as plt
|
6
8
|
import numpy as np
|
7
9
|
import pandas as pd
|
8
10
|
import scanpy as sc
|
9
11
|
import scipy
|
12
|
+
from rich.progress import track
|
13
|
+
from scipy.sparse import issparse
|
14
|
+
|
15
|
+
from pertpy._doc import _doc_params, doc_common_plot_args
|
16
|
+
from pertpy.preprocessing._guide_rna_mixture import PoissonGaussMixture
|
10
17
|
|
11
18
|
if TYPE_CHECKING:
|
12
19
|
from anndata import AnnData
|
13
|
-
from matplotlib.
|
20
|
+
from matplotlib.pyplot import Figure
|
14
21
|
|
15
22
|
|
16
23
|
class GuideAssignment:
|
17
|
-
"""
|
24
|
+
"""Assign cells to guide RNAs."""
|
18
25
|
|
19
26
|
def assign_by_threshold(
|
20
27
|
self,
|
@@ -30,12 +37,12 @@ class GuideAssignment:
|
|
30
37
|
This function expects unnormalized data as input.
|
31
38
|
|
32
39
|
Args:
|
33
|
-
adata:
|
40
|
+
adata: AnnData object containing gRNA values.
|
34
41
|
assignment_threshold: The count threshold that is required for an assignment to be viable.
|
35
42
|
layer: Key to the layer containing raw count values of the gRNAs.
|
36
43
|
adata.X is used if layer is None. Expects count data.
|
37
44
|
output_layer: Assigned guide will be saved on adata.layers[output_key].
|
38
|
-
only_return_results:
|
45
|
+
only_return_results: Whether to input AnnData is not modified and the result is returned as an :class:`np.ndarray`.
|
39
46
|
|
40
47
|
Examples:
|
41
48
|
Each cell is assigned to gRNA that occurs at least 5 times in the respective cell.
|
@@ -64,7 +71,7 @@ class GuideAssignment:
|
|
64
71
|
assignment_threshold: float,
|
65
72
|
layer: str | None = None,
|
66
73
|
output_key: str = "assigned_guide",
|
67
|
-
no_grna_assigned_key: str = "
|
74
|
+
no_grna_assigned_key: str = "Negative",
|
68
75
|
only_return_results: bool = False,
|
69
76
|
) -> np.ndarray | None:
|
70
77
|
"""Simple threshold based max gRNA assignment function.
|
@@ -73,13 +80,13 @@ class GuideAssignment:
|
|
73
80
|
This function expects unnormalized data as input.
|
74
81
|
|
75
82
|
Args:
|
76
|
-
adata:
|
83
|
+
adata: AnnData object containing gRNA values.
|
77
84
|
assignment_threshold: The count threshold that is required for an assignment to be viable.
|
78
85
|
layer: Key to the layer containing raw count values of the gRNAs.
|
79
86
|
adata.X is used if layer is None. Expects count data.
|
80
87
|
output_key: Assigned guide will be saved on adata.obs[output_key]. default value is `assigned_guide`.
|
81
88
|
no_grna_assigned_key: The key to return if no gRNA is expressed enough.
|
82
|
-
only_return_results:
|
89
|
+
only_return_results: Whether to input AnnData is not modified and the result is returned as an np.ndarray.
|
83
90
|
|
84
91
|
Examples:
|
85
92
|
Each cell is assigned to the most expressed gRNA if it has at least 5 counts.
|
@@ -106,14 +113,103 @@ class GuideAssignment:
|
|
106
113
|
|
107
114
|
return None
|
108
115
|
|
116
|
+
def assign_mixture_model(
    self,
    adata: AnnData,
    model: Literal["poisson_gauss_mixture"] = "poisson_gauss_mixture",
    assigned_guides_key: str = "assigned_guide",
    no_grna_assigned_key: str = "negative",
    max_assignments_per_cell: int = 5,
    multiple_grna_assigned_key: str = "multiple",
    multiple_grna_assignment_string: str = "+",
    only_return_results: bool = False,
    uns_key: str = "guide_assignment_params",
    show_progress: bool = False,
    **mixture_model_kwargs,
) -> np.ndarray | None:
    """Assigns gRNAs to cells using a mixture model.

    For every guide, a mixture model is fitted to the log2-transformed non-zero values of that
    guide in `adata.X`. Cells assigned to the "Positive" component are marked as carrying the
    guide. Cells with a zero value are always treated as negative for that guide. Guides that are
    non-zero in fewer than two cells are skipped with a warning.

    Args:
        adata: AnnData object containing gRNA values in `adata.X`. Values must be non-negative.
        model: The model to use for the mixture model. Currently only `poisson_gauss_mixture` is supported.
        assigned_guides_key: Assigned guide will be saved on adata.obs[assigned_guides_key].
        no_grna_assigned_key: The key to return if a cell is negative for all gRNAs.
        max_assignments_per_cell: The maximum number of gRNAs that can be assigned to a cell.
            Cells with more positive gRNAs receive `multiple_grna_assigned_key`.
        multiple_grna_assigned_key: The key to return if more than `max_assignments_per_cell` gRNAs
            are assigned to a cell.
        multiple_grna_assignment_string: The string to use to join multiple gRNAs assigned to a cell.
        only_return_results: Whether the result is returned as an np.ndarray instead of being written
            to `adata.obs`. Note that the fitted parameters are written to `adata.uns[uns_key]` in either case.
        uns_key: Key in `adata.uns` under which the fitted mixture model parameters of each guide are stored.
        show_progress: Whether to show a progress bar.
        mixture_model_kwargs: Are passed to the mixture model.

    Returns:
        If `only_return_results` is `True`, an array with one assignment label per cell.
        Otherwise `None`, and the labels are stored in `adata.obs[assigned_guides_key]`.

    Raises:
        ValueError: If `model` is not supported, if `adata.uns[uns_key]` exists but is not a
            dictionary, or if the data contains negative values.

    Examples:
        >>> import pertpy as pt
        >>> mdata = pt.dt.papalexi_2021()
        >>> gdo = mdata.mod["gdo"]
        >>> ga = pt.pp.GuideAssignment()
        >>> ga.assign_mixture_model(gdo)
    """
    if model == "poisson_gauss_mixture":
        mixture_model = PoissonGaussMixture(**mixture_model_kwargs)
    else:
        raise ValueError("Model not implemented. Please use 'poisson_gauss_mixture'.")

    if uns_key not in adata.uns:
        adata.uns[uns_key] = {}
    elif not isinstance(adata.uns[uns_key], dict):
        raise ValueError(f"adata.uns['{uns_key}'] should be a dictionary. Please remove it or change the key.")

    # Binary cell x guide matrix: 1 where the model assigns the cell to the positive component.
    res = pd.DataFrame(0, index=adata.obs_names, columns=adata.var_names)
    fct = track if show_progress else lambda iterable: iterable
    for gene in fct(adata.var_names):
        # Densify the single guide column once and derive both the mask and the data from it.
        column = adata[:, gene].X
        values = np.ravel(column.toarray() if issparse(column) else np.asarray(column))
        is_nonzero = values != 0
        if np.count_nonzero(is_nonzero) < 2:
            warn(f"Skipping {gene} as there are less than 2 cells expressing the guide at all.", stacklevel=2)
            continue
        # We are only fitting the model to the non-zero values, the rest is
        # automatically assigned to the negative class
        data = values[is_nonzero]

        if np.any(data < 0):
            raise ValueError(
                "Data contains negative values. Please use non-negative data for guide assignment with the Mixture Model."
            )

        # Log2 transform the data so positive population is approximately normal
        data = np.log2(data)
        assignments = mixture_model.run_model(data)
        res.loc[adata.obs_names[is_nonzero][assignments == "Positive"], gene] = 1
        adata.uns[uns_key][gene] = mixture_model.params

    # Assign guides to cells; some cells might have multiple guides assigned.
    # Work positionally on numpy arrays so duplicate obs names cannot misalign labels.
    positive = res.to_numpy() == 1
    num_guides_assigned = positive.sum(axis=1)
    guide_names = res.columns.to_numpy()
    within_limit = (num_guides_assigned != 0) & (num_guides_assigned <= max_assignments_per_cell)
    too_many = num_guides_assigned > max_assignments_per_cell

    series = pd.Series(no_grna_assigned_key, index=adata.obs_names)
    if within_limit.any():
        series.loc[within_limit] = [
            multiple_grna_assignment_string.join(guide_names[row]) for row in positive[within_limit]
        ]
    series.loc[too_many] = multiple_grna_assigned_key

    if only_return_results:
        return series.values

    adata.obs[assigned_guides_key] = series.values

    return None
|
201
|
+
|
202
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
109
203
|
def plot_heatmap(
|
110
204
|
self,
|
111
205
|
adata: AnnData,
|
206
|
+
*,
|
112
207
|
layer: str | None = None,
|
113
208
|
order_by: np.ndarray | str | None = None,
|
114
209
|
key_to_save_order: str = None,
|
210
|
+
return_fig: bool = False,
|
115
211
|
**kwargs,
|
116
|
-
) ->
|
212
|
+
) -> Figure | None:
|
117
213
|
"""Heatmap plotting of guide RNA expression matrix.
|
118
214
|
|
119
215
|
Assuming guides have sparse expression, this function reorders cells
|
@@ -131,11 +227,12 @@ class GuideAssignment:
|
|
131
227
|
If a string is provided, adata.obs[order_by] will be used as the order.
|
132
228
|
If a numpy array is provided, the array will be used for ordering.
|
133
229
|
key_to_save_order: The obs key to save cell orders in the current plot. Only saves if not None.
|
230
|
+
{common_plot_args}
|
134
231
|
kwargs: Are passed to sc.pl.heatmap.
|
135
232
|
|
136
233
|
Returns:
|
137
|
-
|
138
|
-
Order of cells in the y-axis will be saved on adata.obs[key_to_save_order] if provided.
|
234
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
235
|
+
Order of cells in the y-axis will be saved on `adata.obs[key_to_save_order]` if provided.
|
139
236
|
|
140
237
|
Examples:
|
141
238
|
Each cell is assigned to gRNA that occurs at least 5 times in the respective cell, which is then
|
@@ -172,7 +269,7 @@ class GuideAssignment:
|
|
172
269
|
adata.obs[key_to_save_order] = pd.Categorical(order)
|
173
270
|
|
174
271
|
try:
|
175
|
-
|
272
|
+
fig = sc.pl.heatmap(
|
176
273
|
adata[order, :],
|
177
274
|
var_names=adata.var.index.tolist(),
|
178
275
|
groupby=temp_col_name,
|
@@ -180,9 +277,13 @@ class GuideAssignment:
|
|
180
277
|
use_raw=False,
|
181
278
|
dendrogram=False,
|
182
279
|
layer=layer,
|
280
|
+
show=False,
|
183
281
|
**kwargs,
|
184
282
|
)
|
185
283
|
finally:
|
186
284
|
del adata.obs[temp_col_name]
|
187
285
|
|
188
|
-
|
286
|
+
if return_fig:
|
287
|
+
return fig
|
288
|
+
plt.show()
|
289
|
+
return None
|
@@ -0,0 +1,179 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from abc import ABC, abstractmethod
|
4
|
+
from collections.abc import Mapping
|
5
|
+
|
6
|
+
import jax
|
7
|
+
import jax.numpy as jnp
|
8
|
+
import numpy as np
|
9
|
+
import numpyro
|
10
|
+
import numpyro.distributions as dist
|
11
|
+
from jax import random
|
12
|
+
from numpyro.infer import MCMC, NUTS
|
13
|
+
|
14
|
+
# Mapping from a parameter / sample-site name (e.g. "poisson_rate") to its value(s),
# either a single draw or a stack of MCMC samples.
ParamsDict = Mapping[str, jnp.ndarray]
|
15
|
+
|
16
|
+
|
17
|
+
class MixtureModel(ABC):
    """Abstract base class for 2-component mixture models.

    Subclasses define the priors (`initialize_params`) and the per-component log likelihoods
    (`log_likelihood`, last axis = component). Component 0 is reported as "Negative" and
    component 1 as "Positive" by `assignment`.

    Args:
        num_warmup: Number of warmup steps for MCMC sampling.
        num_samples: Number of samples to draw after warmup.
        fraction_positive_expected: Prior belief about fraction of positive components.
        poisson_rate_prior: Rate parameter for exponential prior on Poisson component.
        gaussian_mean_prior: Mean and standard deviation for Gaussian prior on positive component mean.
        gaussian_std_prior: Scale parameter for half-normal prior on positive component std.
    """

    def __init__(
        self,
        num_warmup: int = 50,
        num_samples: int = 100,
        fraction_positive_expected: float = 0.15,
        poisson_rate_prior: float = 0.2,
        gaussian_mean_prior: tuple[float, float] = (3, 2),
        gaussian_std_prior: float = 1,
    ) -> None:
        self.num_warmup = num_warmup
        self.num_samples = num_samples
        self.fraction_positive_expected = fraction_positive_expected
        self.poisson_rate_prior = poisson_rate_prior
        self.gaussian_mean_prior = gaussian_mean_prior
        self.gaussian_std_prior = gaussian_std_prior

    @abstractmethod
    def initialize_params(self) -> ParamsDict:
        """Initialize model parameters via sampling from priors.

        Returns:
            Dictionary of sampled parameter values.
        """

    @abstractmethod
    def log_likelihood(self, data: jnp.ndarray, params: ParamsDict) -> jnp.ndarray:
        """Calculate log likelihood of data under current parameters.

        Args:
            data: Input data array.
            params: Current parameter values.

        Returns:
            Log likelihood values for each datapoint, with one entry per component along the last axis.
        """

    def fit_model(self, data: jnp.ndarray, seed: int = 0) -> MCMC:
        """Fit the mixture model using MCMC (NUTS).

        Args:
            data: Input data to fit.
            seed: Random seed for reproducibility.

        Returns:
            Fitted MCMC object containing samples.
        """
        nuts_kernel = NUTS(self.mixture_model)
        mcmc = MCMC(nuts_kernel, num_warmup=self.num_warmup, num_samples=self.num_samples, progress_bar=False)
        mcmc.run(random.PRNGKey(seed), data=data)
        return mcmc

    def run_model(self, data: jnp.ndarray, seed: int = 0) -> np.ndarray:
        """Run model fitting and assign components.

        Stores the fitted MCMC object, its samples and the assignments on `self.mcmc`,
        `self.samples` and `self.assignments`.

        Args:
            data: Input data array.
            seed: Random seed.

        Returns:
            Array of "Positive"/"Negative" assignments for each datapoint.
        """
        self.mcmc = self.fit_model(data, seed)
        self.samples = self.mcmc.get_samples()
        self.assignments = self.assignment(self.samples, data)
        return self.assignments

    def mixture_model(self, data: jnp.ndarray) -> None:
        """Define mixture model structure for NumPyro.

        Args:
            data: Input data array.
        """
        params = self.initialize_params()

        with numpyro.plate("data", data.shape[0]):
            log_likelihoods = self.log_likelihood(data, params)
            # log p(x) = logsumexp_k [log pi_k + log p_k(x)] is the mixture log density of each point.
            # It is not a named distribution, so it is added to the joint log probability directly
            # (observing `data` under Normal(log p(x), 1) would not be the mixture likelihood).
            log_mixture_likelihood = jax.scipy.special.logsumexp(log_likelihoods, axis=-1)
            numpyro.factor("obs", log_mixture_likelihood)

    def assignment(self, samples: ParamsDict, data: jnp.ndarray) -> np.ndarray:
        """Assign data points to mixture components.

        Uses the posterior mean of every parameter (stored on `self.params`) and assigns each
        point to the component with the highest weighted log likelihood.

        Args:
            samples: MCMC samples of parameters (leading axis = sample).
            data: Input data array.

        Returns:
            Array of component assignments: "Negative" for component 0, "Positive" otherwise.
        """
        params = {key: value.mean(axis=0) for key, value in samples.items()}
        self.params = params

        log_likelihoods = self.log_likelihood(data, params)
        components = np.asarray(jnp.argmax(log_likelihoods, axis=-1))

        # Vectorized labelling instead of iterating over the jax array element by element.
        return np.where(components == 0, "Negative", "Positive")
|
128
|
+
|
129
|
+
|
130
|
+
class PoissonGaussMixture(MixtureModel):
    """Mixture model combining Poisson and Gaussian distributions.

    Component 0 is a Poisson distribution and component 1 a Normal distribution, weighted by
    `mix_probs[0]` and `mix_probs[1]` respectively.
    """

    def log_likelihood(self, data: jnp.ndarray, params: ParamsDict) -> jnp.ndarray:
        """Calculate component-wise log likelihoods.

        Also registers a "separation_penalty" factor (see comment below) when called inside a
        NumPyro model.

        Args:
            data: Input data array.
            params: Current parameter values with keys "poisson_rate", "gaussian_mean",
                "gaussian_std" and "mix_probs".

        Returns:
            Log likelihood values for each component, stacked along the last axis
            (index 0 = Poisson, index 1 = Gaussian), each including its log mixture weight.
        """
        poisson_rate = params["poisson_rate"]
        gaussian_mean = params["gaussian_mean"]
        gaussian_std = params["gaussian_std"]
        mix_probs = params["mix_probs"]

        # Heuristic regularization against flipping of the components: the factor adds +10 to the
        # log density when the Gaussian mean lies to the right of the Poisson rate
        # (heaviside(x, 0) is 1 for x > 0 and 0 otherwise). Configurations with the Poisson
        # component on the right are thereby penalized relative to the intended ordering.
        numpyro.factor("separation_penalty", +10 * jnp.heaviside(-poisson_rate + gaussian_mean, 0))

        log_likelihoods = jnp.stack(
            [
                # Poisson component
                jnp.log(mix_probs[0]) + dist.Poisson(poisson_rate).log_prob(data),
                # Gaussian component
                jnp.log(mix_probs[1]) + dist.Normal(gaussian_mean, gaussian_std).log_prob(data),
            ],
            axis=-1,
        )

        return log_likelihoods

    def initialize_params(self) -> ParamsDict:
        """Initialize model parameters via prior sampling.

        Priors: Exponential(rate=`poisson_rate_prior`) for the Poisson rate,
        Normal(*`gaussian_mean_prior`) for the Gaussian mean, HalfNormal(`gaussian_std_prior`)
        for the Gaussian std, and a Dirichlet with concentration
        [1 - `fraction_positive_expected`, `fraction_positive_expected`] for the mixture weights.
        The order of the sample statements is kept fixed, since it may affect seeded RNG behavior.

        Returns:
            Dictionary of sampled parameter values.
        """
        params = {}
        params["poisson_rate"] = numpyro.sample("poisson_rate", dist.Exponential(self.poisson_rate_prior))
        params["gaussian_mean"] = numpyro.sample("gaussian_mean", dist.Normal(*self.gaussian_mean_prior))
        params["gaussian_std"] = numpyro.sample("gaussian_std", dist.HalfNormal(self.gaussian_std_prior))
        params["mix_probs"] = numpyro.sample(
            "mix_probs",
            dist.Dirichlet(jnp.array([1 - self.fraction_positive_expected, self.fraction_positive_expected])),
        )
        return params
|
pertpy/tools/__init__.py
CHANGED
@@ -46,7 +46,7 @@ Sccoda = lazy_import("pertpy.tools._coda._sccoda", "Sccoda", CODA_EXTRAS)
|
|
46
46
|
Tasccoda = lazy_import("pertpy.tools._coda._tasccoda", "Tasccoda", CODA_EXTRAS)
|
47
47
|
|
48
48
|
DE_EXTRAS = ["formulaic", "pydeseq2"]
|
49
|
-
EdgeR = lazy_import("pertpy.tools._differential_gene_expression", "EdgeR", DE_EXTRAS)
|
49
|
+
EdgeR = lazy_import("pertpy.tools._differential_gene_expression", "EdgeR", DE_EXTRAS) # edgeR will be imported via rpy2
|
50
50
|
PyDESeq2 = lazy_import("pertpy.tools._differential_gene_expression", "PyDESeq2", DE_EXTRAS)
|
51
51
|
Statsmodels = lazy_import("pertpy.tools._differential_gene_expression", "Statsmodels", DE_EXTRAS + ["statsmodels"])
|
52
52
|
TTest = lazy_import("pertpy.tools._differential_gene_expression", "TTest", DE_EXTRAS)
|