pertpy 0.9.4__py3-none-any.whl → 0.10.0__py3-none-any.whl

pertpy/__init__.py CHANGED
@@ -2,7 +2,7 @@
 
 __author__ = "Lukas Heumos"
 __email__ = "lukas.heumos@posteo.net"
-__version__ = "0.9.4"
+__version__ = "0.10.0"
 
 import warnings
 
pertpy/_doc.py ADDED
@@ -0,0 +1,19 @@
+from textwrap import dedent
+
+
+def _doc_params(**kwds):  # pragma: no cover
+    """\
+    Docstrings should start with "\" in the first line for proper formatting.
+    """
+
+    def dec(obj):
+        obj.__orig_doc__ = obj.__doc__
+        obj.__doc__ = dedent(obj.__doc__.format_map(kwds))
+        return obj
+
+    return dec
+
+
+doc_common_plot_args = """\
+return_fig: if `True`, returns figure of the plot, that can be used for saving.\
+"""
pertpy/data/_datasets.py CHANGED
@@ -66,7 +66,7 @@ def sc_sim_augur() -> AnnData:  # pragma: no cover
     output_file_path = settings.datasetdir / output_file_name
     if not Path(output_file_path).exists():
         _download(
-            url="https://figshare.com/ndownloader/files/31645886",
+            url="https://figshare.com/ndownloader/files/49828902",
             output_file_name=output_file_name,
             output_path=settings.datasetdir,
             is_zip=False,
pertpy/metadata/_cell_line.py CHANGED
@@ -8,12 +8,15 @@ from lamin_utils import logger
 if TYPE_CHECKING:
     from collections.abc import Iterable
 
+    from matplotlib.pyplot import Figure
+
 import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 from scanpy import settings
 from scipy import stats
 
+from pertpy._doc import _doc_params, doc_common_plot_args
 from pertpy.data._dataloader import _download
 
 from ._look_up import LookUp
@@ -338,8 +341,8 @@ class CellLine(MetaData):
         # then we can compare these keys and fetch the corresponding metadata.
         if query_id not in adata.obs.columns:
             raise ValueError(
-                f"The specified `query_id` {query_id} can't be found in the `adata.obs`.\n"
-                "Ensure that you are using one of the available query IDs present in the adata.obs for the annotation.\n"
+                f"The specified `query_id` {query_id} can't be found in the `adata.obs`. \n"
+                "Ensure that you are using one of the available query IDs present in the adata.obs for the annotation."
                 "If the desired query ID is not available, you can fetch the cell line metadata "
                 "using the `annotate()` function before calling 'annotate_bulk_rna()'. "
                 "This ensures that the required query ID is included in your data, e.g. stripped_cell_line_name, DepMap ID."
@@ -356,9 +359,8 @@ class CellLine(MetaData):
         else:
             reference_id = "DepMap_ID"
             logger.warning(
-                "To annotate bulk RNA data from Broad Institue, `DepMap_ID` is used as default reference and query identifier if no `reference_id` is given.\n"
-                "Ensure that `DepMap_ID` is available in 'adata.obs'.\n"
-                "Alternatively, use `annotate()` to annotate the cell line first "
+                "To annotate bulk RNA data from the Broad Institute, `DepMap_ID` is used as the default reference and query identifier if no `reference_id` is given. "
+                "If `DepMap_ID` isn't available in 'adata.obs', use `annotate()` to annotate the cell line first."
             )
         if self.bulk_rna_broad is None:
             self._download_bulk_rna(cell_line_source="broad")
@@ -690,6 +692,7 @@ class CellLine(MetaData):
 
         return corr, pvals, new_corr, new_pvals
 
+    @_doc_params(common_plot_args=doc_common_plot_args)
     def plot_correlation(
         self,
         adata: AnnData,
@@ -700,7 +703,8 @@ class CellLine(MetaData):
         metadata_key: str = "bulk_rna_broad",
         category: str = "cell line",
         subset_identifier: str | int | Iterable[str] | Iterable[int] | None = None,
-    ) -> None:
+        return_fig: bool = False,
+    ) -> Figure | None:
         """Visualise the correlation of cell lines with annotated metadata.
 
         Args:
@@ -713,6 +717,8 @@ class CellLine(MetaData):
             subset_identifier: Selected identifiers for scatter plot visualization between the X matrix and `metadata_key`.
                 If not None, only the chosen cell line will be plotted, either specified as a value in `identifier` (string) or as an index number.
                 If None, all cell lines will be plotted.
+            {common_plot_args}
+
         Returns:
             Pearson correlation coefficients and their corresponding p-values for matched and unmatched cell lines separately.
         """
@@ -740,7 +746,7 @@ class CellLine(MetaData):
            if all(isinstance(id, str) for id in subset_identifier_list):
                if set(subset_identifier_list).issubset(adata.obs[identifier].unique()):
                    subset_identifier_list = np.where(
-                        np.in1d(adata.obs[identifier].values, subset_identifier_list)
+                        np.isin(adata.obs[identifier].values, subset_identifier_list)
                    )[0]
                else:
                    raise ValueError("`Subset_identifier` must be found in adata.obs.`identifier`.")
@@ -790,6 +796,10 @@ class CellLine(MetaData):
                    "edgecolor": "black",
                },
            )
+
+            if return_fig:
+                return plt.gcf()
            plt.show()
+            return None
        else:
-            raise NotImplementedError
+            raise NotImplementedError("Only 'cell line' category is supported for correlation comparison.")
pertpy/metadata/_compound.py CHANGED
@@ -42,7 +42,7 @@ class Compound(MetaData):
         adata = adata.copy()
 
         if query_id not in adata.obs.columns:
-            raise ValueError(f"The requested query_id {query_id} is not in `adata.obs`.\n" f"Please check again. ")
+            raise ValueError(f"The requested query_id {query_id} is not in `adata.obs`.\n Please check again.")
 
         query_dict = {}
         not_matched_identifiers = []
@@ -84,7 +84,7 @@ class Compound(MetaData):
         query_df = pd.DataFrame.from_dict(query_dict, orient="index", columns=["pubchem_name", "pubchem_ID", "smiles"])
         # Merge and remove duplicate columns
         # Column is converted to float after merging due to unmatches
-        # Convert back to integers
+        # Convert back to integers afterwards
         if query_id_type == "cid":
             query_df.pubchem_ID = query_df.pubchem_ID.astype("Int64")
         adata.obs = (
@@ -119,8 +119,7 @@ class Compound(MetaData):
 
         The LookUp object provides an overview of the metadata to annotate.
         Each annotate_{metadata} function has a corresponding lookup function in the LookUp object,
-        where users can search the reference_id in the metadata and
-        compare with the query_id in their own data.
+        where users can search the reference_id in the metadata and compare with the query_id in their own data.
 
         Returns:
             Returns a LookUp object specific for compound annotation.
pertpy/metadata/_metadata.py CHANGED
@@ -62,7 +62,7 @@ class MetaData:
        if verbosity > 0:
            logger.info(
                f"There are {total_identifiers} identifiers in `adata.obs`."
-                f"However, {len(unmatched_identifiers)} identifiers can't be found in the {metadata_type} annotation,"
+                f"However, {len(unmatched_identifiers)} identifiers can't be found in the {metadata_type} annotation, "
                "leading to the presence of NA values for their respective metadata.\n"
                f"Please check again: *unmatched_identifiers[:verbosity]..."
            )
pertpy/preprocessing/_guide_rna.py CHANGED
@@ -1,20 +1,27 @@
 from __future__ import annotations
 
 import uuid
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Literal
+from warnings import warn
 
+import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 import scanpy as sc
 import scipy
+from rich.progress import track
+from scipy.sparse import issparse
+
+from pertpy._doc import _doc_params, doc_common_plot_args
+from pertpy.preprocessing._guide_rna_mixture import PoissonGaussMixture
 
 if TYPE_CHECKING:
     from anndata import AnnData
-    from matplotlib.axes import Axes
+    from matplotlib.pyplot import Figure
 
 
 class GuideAssignment:
-    """Offers simple guide assigment based on count thresholds."""
+    """Assign cells to guide RNAs."""
 
     def assign_by_threshold(
         self,
@@ -30,12 +37,12 @@ class GuideAssignment:
         This function expects unnormalized data as input.
 
         Args:
-            adata: Annotated data matrix containing gRNA values
+            adata: AnnData object containing gRNA values.
             assignment_threshold: The count threshold that is required for an assignment to be viable.
             layer: Key to the layer containing raw count values of the gRNAs.
                 adata.X is used if layer is None. Expects count data.
             output_layer: Assigned guide will be saved on adata.layers[output_key].
-            only_return_results: If True, input AnnData is not modified and the result is returned as an np.ndarray.
+            only_return_results: Whether to return the result as an :class:`np.ndarray` instead of modifying the input AnnData.
 
         Examples:
             Each cell is assigned to gRNA that occurs at least 5 times in the respective cell.
@@ -64,7 +71,7 @@ class GuideAssignment:
         assignment_threshold: float,
         layer: str | None = None,
         output_key: str = "assigned_guide",
-        no_grna_assigned_key: str = "NT",
+        no_grna_assigned_key: str = "Negative",
         only_return_results: bool = False,
     ) -> np.ndarray | None:
         """Simple threshold based max gRNA assignment function.
@@ -73,13 +80,13 @@ class GuideAssignment:
         This function expects unnormalized data as input.
 
         Args:
-            adata: Annotated data matrix containing gRNA values
+            adata: AnnData object containing gRNA values.
             assignment_threshold: The count threshold that is required for an assignment to be viable.
             layer: Key to the layer containing raw count values of the gRNAs.
                 adata.X is used if layer is None. Expects count data.
             output_key: Assigned guide will be saved on adata.obs[output_key]. default value is `assigned_guide`.
             no_grna_assigned_key: The key to return if no gRNA is expressed enough.
-            only_return_results: If True, input AnnData is not modified and the result is returned as an np.ndarray.
+            only_return_results: Whether to return the result as an np.ndarray instead of modifying the input AnnData.
 
         Examples:
             Each cell is assigned to the most expressed gRNA if it has at least 5 counts.
@@ -106,14 +113,103 @@ class GuideAssignment:
 
         return None
 
+    def assign_mixture_model(
+        self,
+        adata: AnnData,
+        model: Literal["poisson_gauss_mixture"] = "poisson_gauss_mixture",
+        assigned_guides_key: str = "assigned_guide",
+        no_grna_assigned_key: str = "negative",
+        max_assignments_per_cell: int = 5,
+        multiple_grna_assigned_key: str = "multiple",
+        multiple_grna_assignment_string: str = "+",
+        only_return_results: bool = False,
+        uns_key: str = "guide_assignment_params",
+        show_progress: bool = False,
+        **mixture_model_kwargs,
+    ) -> np.ndarray | None:
+        """Assigns gRNAs to cells using a mixture model.
+
+        Args:
+            adata: AnnData object containing gRNA values.
+            model: The model to use for the mixture model. Currently only `poisson_gauss_mixture` is supported.
+            assigned_guides_key: Assigned guide will be saved on adata.obs[assigned_guides_key].
+            no_grna_assigned_key: The key to return if a cell is negative for all gRNAs.
+            max_assignments_per_cell: The maximum number of gRNAs that can be assigned to a cell.
+            multiple_grna_assigned_key: The key to return if multiple gRNAs are assigned to a cell.
+            multiple_grna_assignment_string: The string to use to join multiple gRNAs assigned to a cell.
+            only_return_results: Whether to return the result as an np.ndarray instead of modifying the input AnnData.
+            show_progress: Whether to show a progress bar.
+            mixture_model_kwargs: Are passed to the mixture model.
+
+        Examples:
+            >>> import pertpy as pt
+            >>> mdata = pt.dt.papalexi_2021()
+            >>> gdo = mdata.mod["gdo"]
+            >>> ga = pt.pp.GuideAssignment()
+            >>> ga.assign_mixture_model(gdo)
+        """
+        if model == "poisson_gauss_mixture":
+            mixture_model = PoissonGaussMixture(**mixture_model_kwargs)
+        else:
+            raise ValueError("Model not implemented. Please use 'poisson_gauss_mixture'.")
+
+        if uns_key not in adata.uns:
+            adata.uns[uns_key] = {}
+        elif type(adata.uns[uns_key]) is not dict:
+            raise ValueError(f"adata.uns['{uns_key}'] should be a dictionary. Please remove it or change the key.")
+
+        res = pd.DataFrame(0, index=adata.obs_names, columns=adata.var_names)
+        fct = track if show_progress else lambda iterable: iterable
+        for gene in fct(adata.var_names):
+            is_nonzero = (
+                np.ravel((adata[:, gene].X != 0).todense()) if issparse(adata.X) else np.ravel(adata[:, gene].X != 0)
+            )
+            if sum(is_nonzero) < 2:
+                warn(f"Skipping {gene} as there are less than 2 cells expressing the guide at all.", stacklevel=2)
+                continue
+            # We are only fitting the model to the non-zero values, the rest is
+            # automatically assigned to the negative class
+            data = adata[is_nonzero, gene].X.todense().A1 if issparse(adata.X) else adata[is_nonzero, gene].X
+            data = np.ravel(data)
+
+            if np.any(data < 0):
+                raise ValueError(
+                    "Data contains negative values. Please use non-negative data for guide assignment with the Mixture Model."
+                )
+
+            # Log2 transform the data so positive population is approximately normal
+            data = np.log2(data)
+            assignments = mixture_model.run_model(data)
+            res.loc[adata.obs_names[is_nonzero][assignments == "Positive"], gene] = 1
+            adata.uns[uns_key][gene] = mixture_model.params
+
+        # Assign guides to cells
+        # Some cells might have multiple guides assigned
+        series = pd.Series(no_grna_assigned_key, index=adata.obs_names)
+        num_guides_assigned = res.sum(1)
+        series.loc[(num_guides_assigned <= max_assignments_per_cell) & (num_guides_assigned != 0)] = res.apply(
+            lambda row: row.index[row == 1].tolist(), axis=1
+        ).str.join(multiple_grna_assignment_string)
+        series.loc[num_guides_assigned > max_assignments_per_cell] = multiple_grna_assigned_key
+
+        if only_return_results:
+            return series.values
+
+        adata.obs[assigned_guides_key] = series.values
+
+        return None
+
+    @_doc_params(common_plot_args=doc_common_plot_args)
     def plot_heatmap(
         self,
         adata: AnnData,
+        *,
         layer: str | None = None,
         order_by: np.ndarray | str | None = None,
         key_to_save_order: str = None,
+        return_fig: bool = False,
         **kwargs,
-    ) -> list[Axes]:
+    ) -> Figure | None:
         """Heatmap plotting of guide RNA expression matrix.
 
         Assuming guides have sparse expression, this function reorders cells
@@ -131,11 +227,12 @@ class GuideAssignment:
                 If a string is provided, adata.obs[order_by] will be used as the order.
                 If a numpy array is provided, the array will be used for ordering.
             key_to_save_order: The obs key to save cell orders in the current plot. Only saves if not None.
+            {common_plot_args}
             kwargs: Are passed to sc.pl.heatmap.
 
         Returns:
-            List of Axes. Alternatively you can pass save or show parameters as they will be passed to sc.pl.heatmap.
-            Order of cells in the y-axis will be saved on adata.obs[key_to_save_order] if provided.
+            If `return_fig` is `True`, returns the figure, otherwise `None`.
+            Order of cells in the y-axis will be saved on `adata.obs[key_to_save_order]` if provided.
 
         Examples:
             Each cell is assigned to gRNA that occurs at least 5 times in the respective cell, which is then
@@ -172,7 +269,7 @@ class GuideAssignment:
            adata.obs[key_to_save_order] = pd.Categorical(order)
 
        try:
-            axis_group = sc.pl.heatmap(
+            fig = sc.pl.heatmap(
                adata[order, :],
                var_names=adata.var.index.tolist(),
                groupby=temp_col_name,
@@ -180,9 +277,13 @@ class GuideAssignment:
                use_raw=False,
                dendrogram=False,
                layer=layer,
+                show=False,
                **kwargs,
            )
        finally:
            del adata.obs[temp_col_name]
 
-        return axis_group
+        if return_fig:
+            return fig
+        plt.show()
+        return None
pertpy/preprocessing/_guide_rna_mixture.py ADDED
@@ -0,0 +1,179 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from collections.abc import Mapping
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import numpyro
+import numpyro.distributions as dist
+from jax import random
+from numpyro.infer import MCMC, NUTS
+
+ParamsDict = Mapping[str, jnp.ndarray]
+
+
+class MixtureModel(ABC):
+    """Abstract base class for 2-component mixture models.
+
+    Args:
+        num_warmup: Number of warmup steps for MCMC sampling.
+        num_samples: Number of samples to draw after warmup.
+        fraction_positive_expected: Prior belief about fraction of positive components.
+        poisson_rate_prior: Rate parameter for exponential prior on Poisson component.
+        gaussian_mean_prior: Mean and standard deviation for Gaussian prior on positive component mean.
+        gaussian_std_prior: Scale parameter for half-normal prior on positive component std.
+    """
+
+    def __init__(
+        self,
+        num_warmup: int = 50,
+        num_samples: int = 100,
+        fraction_positive_expected: float = 0.15,
+        poisson_rate_prior: float = 0.2,
+        gaussian_mean_prior: tuple[float, float] = (3, 2),
+        gaussian_std_prior: float = 1,
+    ) -> None:
+        self.num_warmup = num_warmup
+        self.num_samples = num_samples
+        self.fraction_positive_expected = fraction_positive_expected
+        self.poisson_rate_prior = poisson_rate_prior
+        self.gaussian_mean_prior = gaussian_mean_prior
+        self.gaussian_std_prior = gaussian_std_prior
+
+    @abstractmethod
+    def initialize_params(self) -> ParamsDict:
+        """Initialize model parameters via sampling from priors.
+
+        Returns:
+            Dictionary of sampled parameter values.
+        """
+        pass
+
+    @abstractmethod
+    def log_likelihood(self, data: jnp.ndarray, params: ParamsDict) -> jnp.ndarray:
+        """Calculate log likelihood of data under current parameters.
+
+        Args:
+            data: Input data array.
+            params: Current parameter values.
+
+        Returns:
+            Log likelihood values for each datapoint.
+        """
+        pass
+
+    def fit_model(self, data: jnp.ndarray, seed: int = 0) -> MCMC:
+        """Fit the mixture model using MCMC.
+
+        Args:
+            data: Input data to fit.
+            seed: Random seed for reproducibility.
+
+        Returns:
+            Fitted MCMC object containing samples.
+        """
+        nuts_kernel = NUTS(self.mixture_model)
+        mcmc = MCMC(nuts_kernel, num_warmup=self.num_warmup, num_samples=self.num_samples, progress_bar=False)
+        mcmc.run(random.PRNGKey(seed), data=data)
+        return mcmc
+
+    def run_model(self, data: jnp.ndarray, seed: int = 0) -> np.ndarray:
+        """Run model fitting and assign components.
+
+        Args:
+            data: Input data array.
+            seed: Random seed.
+
+        Returns:
+            Array of "Positive"/"Negative" assignments for each datapoint.
+        """
+        self.mcmc = self.fit_model(data, seed)
+        self.samples = self.mcmc.get_samples()
+        self.assignments = self.assignment(self.samples, data)
+        return self.assignments
+
+    def mixture_model(self, data: jnp.ndarray) -> None:
+        """Define mixture model structure for NumPyro.
+
+        Args:
+            data: Input data array.
+        """
+        params = self.initialize_params()
+
+        with numpyro.plate("data", data.shape[0]):
+            log_likelihoods = self.log_likelihood(data, params)
+            log_mixture_likelihood = jax.scipy.special.logsumexp(log_likelihoods, axis=-1)
+            numpyro.sample("obs", dist.Normal(log_mixture_likelihood, 1.0), obs=data)
+
+    def assignment(self, samples: ParamsDict, data: jnp.ndarray) -> np.ndarray:
+        """Assign data points to mixture components.
+
+        Args:
+            samples: MCMC samples of parameters.
+            data: Input data array.
+
+        Returns:
+            Array of component assignments.
+        """
+        params = {key: samples[key].mean(axis=0) for key in samples.keys()}
+        self.params = params
+
+        log_likelihoods = self.log_likelihood(data, params)
+        guide_assignments = jnp.argmax(log_likelihoods, axis=-1)
+
+        assignments = ["Negative" if assign == 0 else "Positive" for assign in guide_assignments]
+        return np.array(assignments)
+
+
+class PoissonGaussMixture(MixtureModel):
+    """Mixture model combining Poisson and Gaussian distributions."""
+
+    def log_likelihood(self, data: np.ndarray, params: ParamsDict) -> jnp.ndarray:
+        """Calculate component-wise log likelihoods.
+
+        Args:
+            data: Input data array.
+            params: Current parameter values.
+
+        Returns:
+            Log likelihood values for each component.
+        """
+        poisson_rate = params["poisson_rate"]
+        gaussian_mean = params["gaussian_mean"]
+        gaussian_std = params["gaussian_std"]
+        mix_probs = params["mix_probs"]
+
+        # We penalize the model for positioning the Poisson component to the right of the Gaussian component
+        # by imposing a soft constraint to penalize the Poisson rate being larger than the Gaussian mean
+        # Heuristic regularization term to prevent flipping of the components
+        numpyro.factor("separation_penalty", +10 * jnp.heaviside(-poisson_rate + gaussian_mean, 0))
+
+        log_likelihoods = jnp.stack(
+            [
+                # Poisson component
+                jnp.log(mix_probs[0]) + dist.Poisson(poisson_rate).log_prob(data),
+                # Gaussian component
+                jnp.log(mix_probs[1]) + dist.Normal(gaussian_mean, gaussian_std).log_prob(data),
+            ],
+            axis=-1,
+        )
+
+        return log_likelihoods
+
+    def initialize_params(self) -> ParamsDict:
+        """Initialize model parameters via prior sampling.
+
+        Returns:
+            Dictionary of sampled parameter values.
+        """
+        params = {}
+        params["poisson_rate"] = numpyro.sample("poisson_rate", dist.Exponential(self.poisson_rate_prior))
+        params["gaussian_mean"] = numpyro.sample("gaussian_mean", dist.Normal(*self.gaussian_mean_prior))
+        params["gaussian_std"] = numpyro.sample("gaussian_std", dist.HalfNormal(self.gaussian_std_prior))
+        params["mix_probs"] = numpyro.sample(
+            "mix_probs",
+            dist.Dirichlet(jnp.array([1 - self.fraction_positive_expected, self.fraction_positive_expected])),
+        )
+        return params
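For orientation, a minimal sketch of how the new mixture model is driven, assuming pertpy 0.10.0 with jax/numpyro installed; the synthetic data and parameter values below are illustrative assumptions, only `PoissonGaussMixture` and `run_model` come from the diff:

import numpy as np
from pertpy.preprocessing._guide_rna_mixture import PoissonGaussMixture

rng = np.random.default_rng(0)
# Synthetic log2-transformed guide counts: a low background population plus a high positive population
background = np.log2(rng.poisson(1.5, size=200) + 1)
positive = rng.normal(6.0, 1.0, size=50)
data = np.concatenate([background, positive])

model = PoissonGaussMixture(num_warmup=50, num_samples=100)
assignments = model.run_model(data)  # array of "Positive"/"Negative" labels, one per value
print(dict(zip(*np.unique(assignments, return_counts=True))))

This mirrors what `assign_mixture_model` does per guide: it log2-transforms the non-zero counts, fits the two-component model, and labels the Gaussian (high-count) component "Positive".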
pertpy/tools/__init__.py CHANGED
@@ -46,7 +46,7 @@ Sccoda = lazy_import("pertpy.tools._coda._sccoda", "Sccoda", CODA_EXTRAS)
 Tasccoda = lazy_import("pertpy.tools._coda._tasccoda", "Tasccoda", CODA_EXTRAS)
 
 DE_EXTRAS = ["formulaic", "pydeseq2"]
-EdgeR = lazy_import("pertpy.tools._differential_gene_expression", "EdgeR", DE_EXTRAS) # edgeR will be imported via rpy2
+EdgeR = lazy_import("pertpy.tools._differential_gene_expression", "EdgeR", DE_EXTRAS)  # edgeR will be imported via rpy2
 PyDESeq2 = lazy_import("pertpy.tools._differential_gene_expression", "PyDESeq2", DE_EXTRAS)
 Statsmodels = lazy_import("pertpy.tools._differential_gene_expression", "Statsmodels", DE_EXTRAS + ["statsmodels"])
 TTest = lazy_import("pertpy.tools._differential_gene_expression", "TTest", DE_EXTRAS)