pertpy 0.9.4__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +1 -1
- pertpy/_doc.py +19 -0
- pertpy/data/_datasets.py +1 -1
- pertpy/metadata/_cell_line.py +18 -8
- pertpy/metadata/_compound.py +3 -4
- pertpy/metadata/_metadata.py +1 -1
- pertpy/preprocessing/_guide_rna.py +114 -13
- pertpy/preprocessing/_guide_rna_mixture.py +179 -0
- pertpy/tools/__init__.py +1 -1
- pertpy/tools/_augur.py +64 -86
- pertpy/tools/_cinemaot.py +21 -17
- pertpy/tools/_coda/_base_coda.py +90 -117
- pertpy/tools/_dialogue.py +32 -40
- pertpy/tools/_differential_gene_expression/__init__.py +1 -2
- pertpy/tools/_differential_gene_expression/_base.py +486 -112
- pertpy/tools/_differential_gene_expression/_edger.py +30 -21
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +15 -29
- pertpy/tools/_differential_gene_expression/_statsmodels.py +0 -11
- pertpy/tools/_distances/_distances.py +71 -56
- pertpy/tools/_enrichment.py +16 -8
- pertpy/tools/_milo.py +54 -50
- pertpy/tools/_mixscape.py +307 -208
- pertpy/tools/_perturbation_space/_perturbation_space.py +40 -31
- pertpy/tools/_perturbation_space/_simple.py +48 -0
- pertpy/tools/_scgen/_scgen.py +35 -27
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/METADATA +6 -6
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/RECORD +29 -28
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/WHEEL +1 -1
- pertpy/tools/_differential_gene_expression/_formulaic.py +0 -189
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/licenses/LICENSE +0 -0
pertpy/__init__.py
CHANGED
pertpy/_doc.py
ADDED
@@ -0,0 +1,19 @@
|
|
1
|
+
from textwrap import dedent
|
2
|
+
|
3
|
+
|
4
|
+
def _doc_params(**kwds):  # pragma: no cover
    r"""Decorator factory that fills ``{placeholder}`` fields of a docstring.

    The decorated object's docstring is formatted with ``kwds`` via :meth:`str.format_map`
    and then passed through :func:`textwrap.dedent`. The unformatted docstring is kept on
    ``obj.__orig_doc__``.

    Docstrings of decorated objects should start with "\" right after the opening quotes,
    so that every line (including the first) shares the same indentation and ``dedent``
    can strip it.

    Args:
        kwds: Replacement values for the placeholders in the decorated docstring.

    Returns:
        A decorator that returns the same object with its ``__doc__`` rewritten.
    """

    def dec(obj):
        obj.__orig_doc__ = obj.__doc__
        # Docstrings are stripped under ``python -OO`` (``__doc__`` is None); leave the
        # object untouched instead of failing at import time on ``None.format_map``.
        if obj.__doc__ is not None:
            obj.__doc__ = dedent(obj.__doc__.format_map(kwds))
        return obj

    return dec


# Shared ``Args`` entry for plotting functions that accept ``return_fig``; substituted
# into docstrings via ``@_doc_params(common_plot_args=doc_common_plot_args)``.
doc_common_plot_args = """\
return_fig: if `True`, returns figure of the plot, that can be used for saving.\
"""
|
pertpy/data/_datasets.py
CHANGED
@@ -66,7 +66,7 @@ def sc_sim_augur() -> AnnData: # pragma: no cover
|
|
66
66
|
output_file_path = settings.datasetdir / output_file_name
|
67
67
|
if not Path(output_file_path).exists():
|
68
68
|
_download(
|
69
|
-
url="https://figshare.com/ndownloader/files/
|
69
|
+
url="https://figshare.com/ndownloader/files/49828902",
|
70
70
|
output_file_name=output_file_name,
|
71
71
|
output_path=settings.datasetdir,
|
72
72
|
is_zip=False,
|
pertpy/metadata/_cell_line.py
CHANGED
@@ -8,12 +8,15 @@ from lamin_utils import logger
|
|
8
8
|
if TYPE_CHECKING:
|
9
9
|
from collections.abc import Iterable
|
10
10
|
|
11
|
+
from matplotlib.pyplot import Figure
|
12
|
+
|
11
13
|
import matplotlib.pyplot as plt
|
12
14
|
import numpy as np
|
13
15
|
import pandas as pd
|
14
16
|
from scanpy import settings
|
15
17
|
from scipy import stats
|
16
18
|
|
19
|
+
from pertpy._doc import _doc_params, doc_common_plot_args
|
17
20
|
from pertpy.data._dataloader import _download
|
18
21
|
|
19
22
|
from ._look_up import LookUp
|
@@ -338,8 +341,8 @@ class CellLine(MetaData):
|
|
338
341
|
# then we can compare these keys and fetch the corresponding metadata.
|
339
342
|
if query_id not in adata.obs.columns:
|
340
343
|
raise ValueError(
|
341
|
-
f"The specified `query_id` {query_id} can't be found in the `adata.obs
|
342
|
-
"Ensure that you are using one of the available query IDs present in the adata.obs for the annotation
|
344
|
+
f"The specified `query_id` {query_id} can't be found in the `adata.obs`. \n"
|
345
|
+
"Ensure that you are using one of the available query IDs present in the adata.obs for the annotation."
|
343
346
|
"If the desired query ID is not available, you can fetch the cell line metadata "
|
344
347
|
"using the `annotate()` function before calling 'annotate_bulk_rna()'. "
|
345
348
|
"This ensures that the required query ID is included in your data, e.g. stripped_cell_line_name, DepMap ID."
|
@@ -356,9 +359,8 @@ class CellLine(MetaData):
|
|
356
359
|
else:
|
357
360
|
reference_id = "DepMap_ID"
|
358
361
|
logger.warning(
|
359
|
-
"To annotate bulk RNA data from Broad Institue, `DepMap_ID` is used as default reference and query identifier if no `reference_id` is given
|
360
|
-
"
|
361
|
-
"Alternatively, use `annotate()` to annotate the cell line first "
|
362
|
+
"To annotate bulk RNA data from Broad Institue, `DepMap_ID` is used as default reference and query identifier if no `reference_id` is given."
|
363
|
+
"If `DepMap_ID` isn't available in 'adata.obs', use `annotate()` to annotate the cell line first."
|
362
364
|
)
|
363
365
|
if self.bulk_rna_broad is None:
|
364
366
|
self._download_bulk_rna(cell_line_source="broad")
|
@@ -690,6 +692,7 @@ class CellLine(MetaData):
|
|
690
692
|
|
691
693
|
return corr, pvals, new_corr, new_pvals
|
692
694
|
|
695
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
693
696
|
def plot_correlation(
|
694
697
|
self,
|
695
698
|
adata: AnnData,
|
@@ -700,7 +703,8 @@ class CellLine(MetaData):
|
|
700
703
|
metadata_key: str = "bulk_rna_broad",
|
701
704
|
category: str = "cell line",
|
702
705
|
subset_identifier: str | int | Iterable[str] | Iterable[int] | None = None,
|
703
|
-
|
706
|
+
return_fig: bool = False,
|
707
|
+
) -> Figure | None:
|
704
708
|
"""Visualise the correlation of cell lines with annotated metadata.
|
705
709
|
|
706
710
|
Args:
|
@@ -713,6 +717,8 @@ class CellLine(MetaData):
|
|
713
717
|
subset_identifier: Selected identifiers for scatter plot visualization between the X matrix and `metadata_key`.
|
714
718
|
If not None, only the chosen cell line will be plotted, either specified as a value in `identifier` (string) or as an index number.
|
715
719
|
If None, all cell lines will be plotted.
|
720
|
+
{common_plot_args}
|
721
|
+
|
716
722
|
Returns:
|
717
723
|
Pearson correlation coefficients and their corresponding p-values for matched and unmatched cell lines separately.
|
718
724
|
"""
|
@@ -740,7 +746,7 @@ class CellLine(MetaData):
|
|
740
746
|
if all(isinstance(id, str) for id in subset_identifier_list):
|
741
747
|
if set(subset_identifier_list).issubset(adata.obs[identifier].unique()):
|
742
748
|
subset_identifier_list = np.where(
|
743
|
-
np.
|
749
|
+
np.isin(adata.obs[identifier].values, subset_identifier_list)
|
744
750
|
)[0]
|
745
751
|
else:
|
746
752
|
raise ValueError("`Subset_identifier` must be found in adata.obs.`identifier`.")
|
@@ -790,6 +796,10 @@ class CellLine(MetaData):
|
|
790
796
|
"edgecolor": "black",
|
791
797
|
},
|
792
798
|
)
|
799
|
+
|
800
|
+
if return_fig:
|
801
|
+
return plt.gcf()
|
793
802
|
plt.show()
|
803
|
+
return None
|
794
804
|
else:
|
795
|
-
raise NotImplementedError
|
805
|
+
raise NotImplementedError("Only 'cell line' category is supported for correlation comparison.")
|
pertpy/metadata/_compound.py
CHANGED
@@ -42,7 +42,7 @@ class Compound(MetaData):
|
|
42
42
|
adata = adata.copy()
|
43
43
|
|
44
44
|
if query_id not in adata.obs.columns:
|
45
|
-
raise ValueError(f"The requested query_id {query_id} is not in `adata.obs`.\n
|
45
|
+
raise ValueError(f"The requested query_id {query_id} is not in `adata.obs`.\n Please check again.")
|
46
46
|
|
47
47
|
query_dict = {}
|
48
48
|
not_matched_identifiers = []
|
@@ -84,7 +84,7 @@ class Compound(MetaData):
|
|
84
84
|
query_df = pd.DataFrame.from_dict(query_dict, orient="index", columns=["pubchem_name", "pubchem_ID", "smiles"])
|
85
85
|
# Merge and remove duplicate columns
|
86
86
|
# Column is converted to float after merging due to unmatches
|
87
|
-
# Convert back to integers
|
87
|
+
# Convert back to integers afterwards
|
88
88
|
if query_id_type == "cid":
|
89
89
|
query_df.pubchem_ID = query_df.pubchem_ID.astype("Int64")
|
90
90
|
adata.obs = (
|
@@ -119,8 +119,7 @@ class Compound(MetaData):
|
|
119
119
|
|
120
120
|
The LookUp object provides an overview of the metadata to annotate.
|
121
121
|
Each annotate_{metadata} function has a corresponding lookup function in the LookUp object,
|
122
|
-
where users can search the reference_id in the metadata and
|
123
|
-
compare with the query_id in their own data.
|
122
|
+
where users can search the reference_id in the metadata and compare with the query_id in their own data.
|
124
123
|
|
125
124
|
Returns:
|
126
125
|
Returns a LookUp object specific for compound annotation.
|
pertpy/metadata/_metadata.py
CHANGED
@@ -62,7 +62,7 @@ class MetaData:
|
|
62
62
|
if verbosity > 0:
|
63
63
|
logger.info(
|
64
64
|
f"There are {total_identifiers} identifiers in `adata.obs`."
|
65
|
-
f"However, {len(unmatched_identifiers)} identifiers can't be found in the {metadata_type} annotation,"
|
65
|
+
f"However, {len(unmatched_identifiers)} identifiers can't be found in the {metadata_type} annotation, "
|
66
66
|
"leading to the presence of NA values for their respective metadata.\n"
|
67
67
|
f"Please check again: *unmatched_identifiers[:verbosity]..."
|
68
68
|
)
|
@@ -1,20 +1,27 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import uuid
|
4
|
-
from typing import TYPE_CHECKING
|
4
|
+
from typing import TYPE_CHECKING, Literal
|
5
|
+
from warnings import warn
|
5
6
|
|
7
|
+
import matplotlib.pyplot as plt
|
6
8
|
import numpy as np
|
7
9
|
import pandas as pd
|
8
10
|
import scanpy as sc
|
9
11
|
import scipy
|
12
|
+
from rich.progress import track
|
13
|
+
from scipy.sparse import issparse
|
14
|
+
|
15
|
+
from pertpy._doc import _doc_params, doc_common_plot_args
|
16
|
+
from pertpy.preprocessing._guide_rna_mixture import PoissonGaussMixture
|
10
17
|
|
11
18
|
if TYPE_CHECKING:
|
12
19
|
from anndata import AnnData
|
13
|
-
from matplotlib.
|
20
|
+
from matplotlib.pyplot import Figure
|
14
21
|
|
15
22
|
|
16
23
|
class GuideAssignment:
|
17
|
-
"""
|
24
|
+
"""Assign cells to guide RNAs."""
|
18
25
|
|
19
26
|
def assign_by_threshold(
|
20
27
|
self,
|
@@ -30,12 +37,12 @@ class GuideAssignment:
|
|
30
37
|
This function expects unnormalized data as input.
|
31
38
|
|
32
39
|
Args:
|
33
|
-
adata:
|
40
|
+
adata: AnnData object containing gRNA values.
|
34
41
|
assignment_threshold: The count threshold that is required for an assignment to be viable.
|
35
42
|
layer: Key to the layer containing raw count values of the gRNAs.
|
36
43
|
adata.X is used if layer is None. Expects count data.
|
37
44
|
output_layer: Assigned guide will be saved on adata.layers[output_key].
|
38
|
-
only_return_results:
|
45
|
+
only_return_results: Whether to input AnnData is not modified and the result is returned as an :class:`np.ndarray`.
|
39
46
|
|
40
47
|
Examples:
|
41
48
|
Each cell is assigned to gRNA that occurs at least 5 times in the respective cell.
|
@@ -64,7 +71,7 @@ class GuideAssignment:
|
|
64
71
|
assignment_threshold: float,
|
65
72
|
layer: str | None = None,
|
66
73
|
output_key: str = "assigned_guide",
|
67
|
-
no_grna_assigned_key: str = "
|
74
|
+
no_grna_assigned_key: str = "Negative",
|
68
75
|
only_return_results: bool = False,
|
69
76
|
) -> np.ndarray | None:
|
70
77
|
"""Simple threshold based max gRNA assignment function.
|
@@ -73,13 +80,13 @@ class GuideAssignment:
|
|
73
80
|
This function expects unnormalized data as input.
|
74
81
|
|
75
82
|
Args:
|
76
|
-
adata:
|
83
|
+
adata: AnnData object containing gRNA values.
|
77
84
|
assignment_threshold: The count threshold that is required for an assignment to be viable.
|
78
85
|
layer: Key to the layer containing raw count values of the gRNAs.
|
79
86
|
adata.X is used if layer is None. Expects count data.
|
80
87
|
output_key: Assigned guide will be saved on adata.obs[output_key]. default value is `assigned_guide`.
|
81
88
|
no_grna_assigned_key: The key to return if no gRNA is expressed enough.
|
82
|
-
only_return_results:
|
89
|
+
only_return_results: Whether to input AnnData is not modified and the result is returned as an np.ndarray.
|
83
90
|
|
84
91
|
Examples:
|
85
92
|
Each cell is assigned to the most expressed gRNA if it has at least 5 counts.
|
@@ -106,14 +113,103 @@ class GuideAssignment:
|
|
106
113
|
|
107
114
|
return None
|
108
115
|
|
116
|
+
def assign_mixture_model(
    self,
    adata: AnnData,
    model: Literal["poisson_gauss_mixture"] = "poisson_gauss_mixture",
    assigned_guides_key: str = "assigned_guide",
    no_grna_assigned_key: str = "negative",
    max_assignments_per_cell: int = 5,
    multiple_grna_assigned_key: str = "multiple",
    multiple_grna_assignment_string: str = "+",
    only_return_results: bool = False,
    uns_key: str = "guide_assignment_params",
    show_progress: bool = False,
    **mixture_model_kwargs,
) -> np.ndarray | None:
    """Assigns gRNAs to cells using a mixture model.

    For every gRNA (column of ``adata``), a two-component mixture model is fit to the
    log2-transformed non-zero values. Cells with a zero value are always negative for
    that gRNA. Guides with fewer than 2 non-zero cells are skipped with a warning.

    Args:
        adata: AnnData object containing gRNA values in ``adata.X``. Values must be non-negative.
        model: The model to use for the mixture model. Currently only ``"poisson_gauss_mixture"`` is supported.
        assigned_guides_key: Assigned guide will be saved on ``adata.obs[assigned_guides_key]``.
        no_grna_assigned_key: The key to return if a cell is negative for all gRNAs.
        max_assignments_per_cell: The maximum number of gRNAs that can be assigned to a cell.
            Cells with more positive gRNAs are labelled with ``multiple_grna_assigned_key``.
        multiple_grna_assigned_key: The key to return if more than ``max_assignments_per_cell``
            gRNAs are assigned to a cell.
        multiple_grna_assignment_string: The string to use to join multiple gRNAs assigned to a cell.
        only_return_results: Whether ``adata.obs`` is left unmodified and the result is returned as an np.ndarray.
            The fitted per-guide parameters are still written to ``adata.uns[uns_key]``.
        uns_key: Key in ``adata.uns`` (a dict) under which the fitted mixture parameters
            of each guide are stored, keyed by guide name.
        show_progress: Whether to show a progress bar.
        mixture_model_kwargs: Are passed to the mixture model.

    Returns:
        If ``only_return_results`` is ``True``, an array with one assignment per cell;
        otherwise ``None`` and the assignments are stored in ``adata.obs[assigned_guides_key]``.

    Raises:
        ValueError: If ``model`` is unknown, ``adata.uns[uns_key]`` exists but is not a dict,
            or the data contain negative values.

    Examples:
        >>> import pertpy as pt
        >>> mdata = pt.dt.papalexi_2021()
        >>> gdo = mdata.mod["gdo"]
        >>> ga = pt.pp.GuideAssignment()
        >>> ga.assign_mixture_model(gdo)
    """
    if model == "poisson_gauss_mixture":
        mixture_model = PoissonGaussMixture(**mixture_model_kwargs)
    else:
        raise ValueError("Model not implemented. Please use 'poisson_gauss_mixture'.")

    if uns_key not in adata.uns:
        adata.uns[uns_key] = {}
    elif not isinstance(adata.uns[uns_key], dict):
        raise ValueError(f"adata.uns['{uns_key}'] should be a dictionary. Please remove it or change the key.")

    # 1 where the model assigns a cell (row) positive for a guide (column), 0 otherwise.
    res = pd.DataFrame(0, index=adata.obs_names, columns=adata.var_names)
    fct = track if show_progress else lambda iterable: iterable
    for gene in fct(adata.var_names):
        is_nonzero = (
            np.ravel((adata[:, gene].X != 0).todense()) if issparse(adata.X) else np.ravel(adata[:, gene].X != 0)
        )
        if is_nonzero.sum() < 2:
            warn(f"Skipping {gene} as there are less than 2 cells expressing the guide at all.", stacklevel=2)
            continue
        # We are only fitting the model to the non-zero values, the rest is
        # automatically assigned to the negative class
        data = adata[is_nonzero, gene].X.todense().A1 if issparse(adata.X) else adata[is_nonzero, gene].X
        data = np.ravel(data)

        if np.any(data < 0):
            raise ValueError(
                "Data contains negative values. Please use non-negative data for guide assignment with the Mixture Model."
            )

        # Log2 transform the data so positive population is approximately normal
        data = np.log2(data)
        assignments = mixture_model.run_model(data)
        res.loc[adata.obs_names[is_nonzero][assignments == "Positive"], gene] = 1
        # The fitted parameters are JAX arrays; store NumPy copies so adata.uns stays
        # writable by anndata's serializers.
        adata.uns[uns_key][gene] = {key: np.asarray(value) for key, value in mixture_model.params.items()}

    # Assign guides to cells
    # Some cells might have multiple guides assigned
    series = pd.Series(no_grna_assigned_key, index=adata.obs_names)
    num_guides_assigned = res.sum(axis=1)
    within_limit = (num_guides_assigned <= max_assignments_per_cell) & (num_guides_assigned != 0)
    # Only build joined guide names for the cells that receive them; the guard avoids
    # DataFrame.apply on an empty frame, which does not return a Series.
    if within_limit.any():
        series.loc[within_limit] = (
            res.loc[within_limit]
            .apply(lambda row: row.index[row == 1].tolist(), axis=1)
            .str.join(multiple_grna_assignment_string)
        )
    series.loc[num_guides_assigned > max_assignments_per_cell] = multiple_grna_assigned_key

    if only_return_results:
        return series.values

    adata.obs[assigned_guides_key] = series.values

    return None
|
201
|
+
|
202
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
109
203
|
def plot_heatmap(
|
110
204
|
self,
|
111
205
|
adata: AnnData,
|
206
|
+
*,
|
112
207
|
layer: str | None = None,
|
113
208
|
order_by: np.ndarray | str | None = None,
|
114
209
|
key_to_save_order: str = None,
|
210
|
+
return_fig: bool = False,
|
115
211
|
**kwargs,
|
116
|
-
) ->
|
212
|
+
) -> Figure | None:
|
117
213
|
"""Heatmap plotting of guide RNA expression matrix.
|
118
214
|
|
119
215
|
Assuming guides have sparse expression, this function reorders cells
|
@@ -131,11 +227,12 @@ class GuideAssignment:
|
|
131
227
|
If a string is provided, adata.obs[order_by] will be used as the order.
|
132
228
|
If a numpy array is provided, the array will be used for ordering.
|
133
229
|
key_to_save_order: The obs key to save cell orders in the current plot. Only saves if not None.
|
230
|
+
{common_plot_args}
|
134
231
|
kwargs: Are passed to sc.pl.heatmap.
|
135
232
|
|
136
233
|
Returns:
|
137
|
-
|
138
|
-
Order of cells in the y-axis will be saved on adata.obs[key_to_save_order] if provided.
|
234
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
235
|
+
Order of cells in the y-axis will be saved on `adata.obs[key_to_save_order]` if provided.
|
139
236
|
|
140
237
|
Examples:
|
141
238
|
Each cell is assigned to gRNA that occurs at least 5 times in the respective cell, which is then
|
@@ -172,7 +269,7 @@ class GuideAssignment:
|
|
172
269
|
adata.obs[key_to_save_order] = pd.Categorical(order)
|
173
270
|
|
174
271
|
try:
|
175
|
-
|
272
|
+
fig = sc.pl.heatmap(
|
176
273
|
adata[order, :],
|
177
274
|
var_names=adata.var.index.tolist(),
|
178
275
|
groupby=temp_col_name,
|
@@ -180,9 +277,13 @@ class GuideAssignment:
|
|
180
277
|
use_raw=False,
|
181
278
|
dendrogram=False,
|
182
279
|
layer=layer,
|
280
|
+
show=False,
|
183
281
|
**kwargs,
|
184
282
|
)
|
185
283
|
finally:
|
186
284
|
del adata.obs[temp_col_name]
|
187
285
|
|
188
|
-
|
286
|
+
if return_fig:
|
287
|
+
return fig
|
288
|
+
plt.show()
|
289
|
+
return None
|
@@ -0,0 +1,179 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from abc import ABC, abstractmethod
|
4
|
+
from collections.abc import Mapping
|
5
|
+
|
6
|
+
import jax
|
7
|
+
import jax.numpy as jnp
|
8
|
+
import numpy as np
|
9
|
+
import numpyro
|
10
|
+
import numpyro.distributions as dist
|
11
|
+
from jax import random
|
12
|
+
from numpyro.infer import MCMC, NUTS
|
13
|
+
|
14
|
+
ParamsDict = Mapping[str, jnp.ndarray]
|
15
|
+
|
16
|
+
|
17
|
+
class MixtureModel(ABC):
    """Abstract base class for 2-component mixture models.

    Subclasses define the priors (:meth:`initialize_params`) and the per-component log
    likelihoods (:meth:`log_likelihood`). The model is fit with NUTS, and each data point
    is assigned to the component with the highest log likelihood under the posterior-mean
    parameters. Component 0 is reported as "Negative" and component 1 as "Positive".

    Args:
        num_warmup: Number of warmup steps for MCMC sampling.
        num_samples: Number of samples to draw after warmup.
        fraction_positive_expected: Prior belief about fraction of positive components.
        poisson_rate_prior: Rate parameter for exponential prior on Poisson component.
        gaussian_mean_prior: Mean and standard deviation for Gaussian prior on positive component mean.
        gaussian_std_prior: Scale parameter for half-normal prior on positive component std.
    """

    def __init__(
        self,
        num_warmup: int = 50,
        num_samples: int = 100,
        fraction_positive_expected: float = 0.15,
        poisson_rate_prior: float = 0.2,
        gaussian_mean_prior: tuple[float, float] = (3, 2),
        gaussian_std_prior: float = 1,
    ) -> None:
        self.num_warmup = num_warmup
        self.num_samples = num_samples
        self.fraction_positive_expected = fraction_positive_expected
        self.poisson_rate_prior = poisson_rate_prior
        self.gaussian_mean_prior = gaussian_mean_prior
        self.gaussian_std_prior = gaussian_std_prior

    @abstractmethod
    def initialize_params(self) -> ParamsDict:
        """Initialize model parameters via sampling from priors.

        Called from :meth:`mixture_model`, i.e. while the NumPyro model is being traced.

        Returns:
            Dictionary of sampled parameter values.
        """

    @abstractmethod
    def log_likelihood(self, data: jnp.ndarray, params: ParamsDict) -> jnp.ndarray:
        """Calculate log likelihood of data under current parameters.

        Args:
            data: Input data array.
            params: Current parameter values.

        Returns:
            Log likelihood of each datapoint under each component, with the component index
            on the last axis (index 0 = negative, index 1 = positive).
        """

    def fit_model(self, data: jnp.ndarray, seed: int = 0) -> MCMC:
        """Fit the mixture model using MCMC (NUTS).

        Args:
            data: Input data to fit.
            seed: Random seed for reproducibility.

        Returns:
            Fitted MCMC object containing samples.
        """
        nuts_kernel = NUTS(self.mixture_model)
        mcmc = MCMC(nuts_kernel, num_warmup=self.num_warmup, num_samples=self.num_samples, progress_bar=False)
        mcmc.run(random.PRNGKey(seed), data=data)
        return mcmc

    def run_model(self, data: jnp.ndarray, seed: int = 0) -> np.ndarray:
        """Run model fitting and assign components.

        Stores the fitted MCMC object, its samples and the assignments on ``self``.

        Args:
            data: Input data array.
            seed: Random seed.

        Returns:
            Array of "Positive"/"Negative" assignments for each datapoint.
        """
        self.mcmc = self.fit_model(data, seed)
        self.samples = self.mcmc.get_samples()
        self.assignments = self.assignment(self.samples, data)
        return self.assignments

    def mixture_model(self, data: jnp.ndarray) -> None:
        """Define mixture model structure for NumPyro.

        Args:
            data: Input data array.
        """
        params = self.initialize_params()

        with numpyro.plate("data", data.shape[0]):
            log_likelihoods = self.log_likelihood(data, params)
            # Marginalize over the discrete component membership of each data point.
            log_mixture_likelihood = jax.scipy.special.logsumexp(log_likelihoods, axis=-1)
            # The marginal mixture log likelihood already *is* the observation term, so it is
            # added to the joint log density directly. Modelling `data` as Normal around its
            # own log likelihood would not correspond to the mixture model.
            numpyro.factor("obs", log_mixture_likelihood)

    def assignment(self, samples: ParamsDict, data: jnp.ndarray) -> np.ndarray:
        """Assign data points to mixture components.

        Uses the posterior mean of every parameter, which is also stored on ``self.params``.

        Args:
            samples: MCMC samples of parameters.
            data: Input data array.

        Returns:
            Array with "Negative" (component 0) or "Positive" (component 1) per datapoint.
        """
        params = {key: value.mean(axis=0) for key, value in samples.items()}
        self.params = params

        log_likelihoods = self.log_likelihood(data, params)
        # Move the argmax to NumPy once instead of comparing JAX scalars one by one.
        guide_assignments = np.asarray(jnp.argmax(log_likelihoods, axis=-1))

        return np.where(guide_assignments == 0, "Negative", "Positive")
|
128
|
+
|
129
|
+
|
130
|
+
class PoissonGaussMixture(MixtureModel):
    """Mixture model combining Poisson and Gaussian distributions.

    Component 0 (negative/background) is Poisson-distributed and component 1 (positive)
    is Gaussian-distributed.
    """

    def log_likelihood(self, data: jnp.ndarray, params: ParamsDict) -> jnp.ndarray:
        """Calculate component-wise log likelihoods.

        Args:
            data: Input data array.
            params: Current parameter values (``poisson_rate``, ``gaussian_mean``,
                ``gaussian_std`` and the two-element ``mix_probs``).

        Returns:
            Log likelihood values for each component, stacked on the last axis
            (index 0 = Poisson, index 1 = Gaussian).
        """
        poisson_rate = params["poisson_rate"]
        gaussian_mean = params["gaussian_mean"]
        gaussian_std = params["gaussian_std"]
        mix_probs = params["mix_probs"]

        # Heuristic regularization term to prevent flipping of the components: add +10 to the
        # log density when the Gaussian mean lies to the right of the Poisson rate, which
        # (relative to that) penalizes configurations with the Poisson component on the right.
        # NOTE(review): when this runs inside the `data` plate of `mixture_model`, the scalar
        # factor is presumably broadcast over the plate, so its effective weight grows with
        # the number of data points — confirm this is intended.
        numpyro.factor("separation_penalty", +10 * jnp.heaviside(-poisson_rate + gaussian_mean, 0))

        log_likelihoods = jnp.stack(
            [
                # Poisson component
                jnp.log(mix_probs[0]) + dist.Poisson(poisson_rate).log_prob(data),
                # Gaussian component
                jnp.log(mix_probs[1]) + dist.Normal(gaussian_mean, gaussian_std).log_prob(data),
            ],
            axis=-1,
        )

        return log_likelihoods

    def initialize_params(self) -> ParamsDict:
        """Initialize model parameters via prior sampling.

        Priors: Exponential(``poisson_rate_prior``) for the Poisson rate,
        Normal(*``gaussian_mean_prior``) for the Gaussian mean, HalfNormal(``gaussian_std_prior``)
        for the Gaussian std, and a Dirichlet with concentration
        ``[1 - fraction_positive_expected, fraction_positive_expected]`` for the mixture weights.

        Returns:
            Dictionary of sampled parameter values.
        """
        params = {}
        params["poisson_rate"] = numpyro.sample("poisson_rate", dist.Exponential(self.poisson_rate_prior))
        params["gaussian_mean"] = numpyro.sample("gaussian_mean", dist.Normal(*self.gaussian_mean_prior))
        params["gaussian_std"] = numpyro.sample("gaussian_std", dist.HalfNormal(self.gaussian_std_prior))
        params["mix_probs"] = numpyro.sample(
            "mix_probs",
            dist.Dirichlet(jnp.array([1 - self.fraction_positive_expected, self.fraction_positive_expected])),
        )
        return params
|
pertpy/tools/__init__.py
CHANGED
@@ -46,7 +46,7 @@ Sccoda = lazy_import("pertpy.tools._coda._sccoda", "Sccoda", CODA_EXTRAS)
|
|
46
46
|
Tasccoda = lazy_import("pertpy.tools._coda._tasccoda", "Tasccoda", CODA_EXTRAS)
|
47
47
|
|
48
48
|
DE_EXTRAS = ["formulaic", "pydeseq2"]
|
49
|
-
EdgeR = lazy_import("pertpy.tools._differential_gene_expression", "EdgeR", DE_EXTRAS)
|
49
|
+
EdgeR = lazy_import("pertpy.tools._differential_gene_expression", "EdgeR", DE_EXTRAS) # edgeR will be imported via rpy2
|
50
50
|
PyDESeq2 = lazy_import("pertpy.tools._differential_gene_expression", "PyDESeq2", DE_EXTRAS)
|
51
51
|
Statsmodels = lazy_import("pertpy.tools._differential_gene_expression", "Statsmodels", DE_EXTRAS + ["statsmodels"])
|
52
52
|
TTest = lazy_import("pertpy.tools._differential_gene_expression", "TTest", DE_EXTRAS)
|