pertpy 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -266,16 +266,23 @@ class GuideAssignment:
266
266
  res.loc[adata.obs_names[is_nonzero][assignments == "Positive"], gene] = 1
267
267
 
268
268
  # Add the parameters to the adata.var DataFrame
269
- for params_name, param in mixture_model.params.items():
270
- if param.ndim == 0:
271
- if params_name not in adata.var.columns:
272
- adata.var[params_name] = np.nan
273
- adata.var.loc[gene, params_name] = param.item()
274
- else:
275
- for i, p in enumerate(param):
276
- if f"{params_name}_{i}" not in adata.var.columns:
277
- adata.var[f"{params_name}_{i}"] = np.nan
278
- adata.var.loc[gene, f"{params_name}_{i}"] = p
269
+ samples = mixture_model.mcmc.get_samples()
270
+ param_data = {}
271
+
272
+ for param_name in ["gaussian_mean", "gaussian_std", "poisson_rate", "mix_probs"]:
273
+ if param_name in samples:
274
+ param_value = samples[param_name].mean(axis=0)
275
+ if param_value.ndim == 0:
276
+ param_data[param_name] = param_value.item()
277
+ else:
278
+ for i, p in enumerate(param_value):
279
+ param_data[f"{param_name}_{i}"] = p.item()
280
+
281
+ # Add all columns at once
282
+ for col_name, value in param_data.items():
283
+ if col_name not in adata.var.columns:
284
+ adata.var[col_name] = np.nan
285
+ adata.var.loc[gene, col_name] = value
279
286
 
280
287
  # Assign guides to cells
281
288
  # Some cells might have multiple guides assigned
@@ -8,7 +8,7 @@ import numpy as np
8
8
  from jax.random import PRNGKey
9
9
  from jax.scipy.special import logsumexp
10
10
  from numpyro import factor, plate, sample
11
- from numpyro.distributions import Dirichlet, Exponential, HalfNormal, Normal, Poisson
11
+ from numpyro.distributions import Categorical, Dirichlet, Exponential, HalfNormal, Normal, Poisson
12
12
  from numpyro.infer import MCMC, NUTS
13
13
 
14
14
  ParamsDict = Mapping[str, jnp.ndarray]
@@ -102,8 +102,14 @@ class MixtureModel(ABC):
102
102
 
103
103
  with plate("data", data.shape[0]):
104
104
  log_likelihoods = self.log_likelihood(data, params)
105
- log_mixture_likelihood = logsumexp(log_likelihoods, axis=-1)
106
- sample("obs", Normal(log_mixture_likelihood, 1.0), obs=data)
105
+ mixture_probs = jnp.exp(log_likelihoods - logsumexp(log_likelihoods, axis=-1, keepdims=True))
106
+ z = sample("z", Categorical(mixture_probs), infer={"enumerate": "parallel"})
107
+
108
+ # Observe under selected component
109
+ poisson_ll = Poisson(params["poisson_rate"]).log_prob(data)
110
+ gaussian_ll = Normal(params["gaussian_mean"], params["gaussian_std"]).log_prob(data)
111
+ obs_ll = jnp.where(z == 0, poisson_ll, gaussian_ll)
112
+ factor("obs", obs_ll)
107
113
 
108
114
  def assignment(self, samples: ParamsDict, data: jnp.ndarray) -> np.ndarray:
109
115
  """Assign data points to mixture components.
pertpy/tools/__init__.py CHANGED
@@ -21,7 +21,6 @@ from pertpy.tools._perturbation_space._simple import (
21
21
  KMeansSpace,
22
22
  PseudobulkSpace,
23
23
  )
24
- from pertpy.tools._scgen import Scgen
25
24
 
26
25
 
27
26
  def __getattr__(name: str):
@@ -35,14 +34,25 @@ def __getattr__(name: str):
35
34
  raise ImportError(
36
35
  "Extra dependencies required: toytree, ete4. Please install with: pip install toytree ete4"
37
36
  ) from None
38
-
39
37
  elif name in ["EdgeR", "PyDESeq2", "Statsmodels", "TTest", "WilcoxonTest"]:
40
38
  module = import_module("pertpy.tools._differential_gene_expression")
41
39
  return getattr(module, name)
40
+ elif name == "Scgen":
41
+ try:
42
+ module = import_module("pertpy.tools._scgen")
43
+ return module.Scgen
44
+ except ImportError:
45
+ raise ImportError(
46
+ "Scgen requires scvi-tools to be installed. Please install with: pip install scvi-tools"
47
+ ) from None
42
48
 
43
49
  raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
44
50
 
45
51
 
52
+ def __dir__():
53
+ return __all__
54
+
55
+
46
56
  __all__ = [
47
57
  "Augur",
48
58
  "Cinemaot",
pertpy/tools/_augur.py CHANGED
@@ -36,6 +36,7 @@ from statsmodels.api import OLS
36
36
  from statsmodels.stats.multitest import fdrcorrection
37
37
 
38
38
  from pertpy._doc import _doc_params, doc_common_plot_args
39
+ from pertpy.tools.core import _is_raw_counts
39
40
 
40
41
  if TYPE_CHECKING:
41
42
  from matplotlib.axes import Axes
@@ -87,6 +88,7 @@ class Augur:
87
88
  self,
88
89
  input: AnnData | pd.DataFrame,
89
90
  *,
91
+ layer: str | None = None,
90
92
  meta: pd.DataFrame | None = None,
91
93
  label_col: str = "label_col",
92
94
  cell_type_col: str = "cell_type_col",
@@ -98,6 +100,7 @@ class Augur:
98
100
  Args:
99
101
  input: Anndata or matrix containing gene expression values (genes in rows, cells in columns)
100
102
  and optionally meta data about each cell.
103
+ layer: Layer in AnnData to use for expression data. If None, uses .X
101
104
  meta: Optional Pandas DataFrame containing meta data about each cell.
102
105
  label_col: column of the meta DataFrame or the Anndata or matrix containing the condition labels for each cell
103
106
  in the cell-by-gene expression matrix
@@ -114,11 +117,11 @@ class Augur:
114
117
  >>> import pertpy as pt
115
118
  >>> adata = pt.dt.sc_sim_augur()
116
119
  >>> ag_rfc = pt.tl.Augur("random_forest_classifier")
117
- >>> loaded_data = ag_rfc.load(adata)
120
+ >>> augur_adata = ag_rfc.load(adata)
118
121
  """
119
122
  if isinstance(input, AnnData):
120
- input.obs = input.obs.rename(columns={cell_type_col: "cell_type", label_col: "label"})
121
123
  adata = input
124
+ obs_renamed = adata.obs.rename(columns={cell_type_col: "cell_type", label_col: "label"})
122
125
 
123
126
  elif isinstance(input, pd.DataFrame):
124
127
  if meta is None:
@@ -130,27 +133,47 @@ class Augur:
130
133
 
131
134
  label = input[label_col] if meta is None else meta[label_col]
132
135
  cell_type = input[cell_type_col] if meta is None else meta[cell_type_col]
133
- x = input.drop([label_col, cell_type_col], axis=1) if meta is None else input
134
- adata = AnnData(X=x, obs=pd.DataFrame({"cell_type": cell_type, "label": label}))
136
+ X = input.drop([label_col, cell_type_col], axis=1) if meta is None else input
137
+ adata = AnnData(X=X, obs=pd.DataFrame({"cell_type": cell_type, "label": label}))
138
+ obs_renamed = adata.obs
135
139
 
136
- if len(adata.obs["label"].unique()) < 2:
140
+ if len(obs_renamed["label"].unique()) < 2:
137
141
  raise ValueError("Less than two unique labels in dataset. At least two are needed for the analysis.")
142
+
143
+ if isinstance(input, AnnData):
144
+ final_adata = AnnData(X=adata.X, obs=obs_renamed, var=adata.var, layers=adata.layers)
145
+ else:
146
+ final_adata = adata
147
+
138
148
  # dummy variables for categorical data
139
- if adata.obs["label"].dtype.name == "category":
140
- # filter samples according to label
149
+ if final_adata.obs["label"].dtype.name == "category":
150
+ label_encoder = LabelEncoder()
151
+ final_adata.obs["y_"] = label_encoder.fit_transform(final_adata.obs["label"])
152
+
141
153
  if condition_label is not None and treatment_label is not None:
142
154
  logger.info(f"Filtering samples with {condition_label} and {treatment_label} labels.")
143
- adata = ad.concat(
144
- [adata[adata.obs["label"] == condition_label], adata[adata.obs["label"] == treatment_label]]
155
+ final_adata = ad.concat(
156
+ [
157
+ final_adata[final_adata.obs["label"] == condition_label],
158
+ final_adata[final_adata.obs["label"] == treatment_label],
159
+ ]
145
160
  )
146
- label_encoder = LabelEncoder()
147
- adata.obs["y_"] = label_encoder.fit_transform(adata.obs["label"])
148
161
  else:
149
- y = adata.obs["label"].to_frame()
162
+ y = final_adata.obs["label"].to_frame()
150
163
  y = y.rename(columns={"label": "y_"})
151
- adata.obs = pd.concat([adata.obs, y], axis=1)
164
+ final_adata.obs = pd.concat([final_adata.obs, y], axis=1)
152
165
 
153
- return adata
166
+ if layer is not None:
167
+ if layer not in final_adata.layers:
168
+ raise ValueError(f"Layer '{layer}' not found in AnnData object")
169
+ X = final_adata.layers[layer]
170
+ else:
171
+ X = final_adata.X
172
+
173
+ if not _is_raw_counts(X):
174
+ logger.warning("Data does not appear to be raw counts. Augur developers recommend using raw counts.")
175
+
176
+ return final_adata
154
177
 
155
178
  def create_estimator(
156
179
  self,
@@ -11,7 +11,6 @@ from jax import config, random
11
11
  from lamin_utils import logger
12
12
  from mudata import MuData
13
13
  from numpyro.infer import Predictive
14
- from rich import print
15
14
 
16
15
  from pertpy.tools._coda._base_coda import CompositionalModel2, from_scanpy
17
16
 
@@ -25,24 +24,6 @@ config.update("jax_enable_x64", True)
25
24
  class Sccoda(CompositionalModel2):
26
25
  r"""Statistical model for single-cell differential composition analysis with specification of a reference cell type.
27
26
 
28
- This is the standard scCODA model and recommended for all uses.
29
-
30
- The hierarchical formulation of the model for one sample is:
31
-
32
- .. math::
33
- y|x &\\sim DirMult(\\phi, \\bar{y}) \\\\
34
- \\log(\\phi) &= \\alpha + x \\beta \\\\
35
- \\alpha_k &\\sim N(0, 5) \\quad &\\forall k \\in [K] \\\\
36
- \\beta_{m, \\hat{k}} &= 0 &\\forall m \\in [M]\\\\
37
- \\beta_{m, k} &= \\tau_{m, k} \\tilde{\\beta}_{m, k} \\quad &\\forall m \\in [M], k \\in \\{[K] \\smallsetminus \\hat{k}\\} \\\\
38
- \\tau_{m, k} &= \\frac{\\exp(t_{m, k})}{1+ \\exp(t_{m, k})} \\quad &\\forall m \\in [M], k \\in \\{[K] \\smallsetminus \\hat{k}\\} \\\\
39
- \\frac{t_{m, k}}{50} &\\sim N(0, 1) \\quad &\\forall m \\in [M], k \\in \\{[K] \\smallsetminus \\hat{k}\\} \\\\
40
- \\tilde{\\beta}_{m, k} &= \\sigma_m^2 \\cdot \\gamma_{m, k} \\quad &\\forall m \\in [M], k \\in \\{[K] \\smallsetminus \\hat{k}\\} \\\\
41
- \\sigma_m^2 &\\sim HC(0, 1) \\quad &\\forall m \\in [M] \\\\
42
- \\gamma_{m, k} &\\sim N(0,1) \\quad &\\forall m \\in [M], k \\in \\{[K] \\smallsetminus \\hat{k}\\} \\\\
43
-
44
- with y being the cell counts and x the covariates.
45
-
46
27
  For further information, see `scCODA is a Bayesian model for compositional single-cell data analysis`
47
28
  (Büttner, Ostner et al., NatComms, 2021)
48
29
  """
@@ -303,7 +284,7 @@ class Sccoda(CompositionalModel2):
303
284
  self,
304
285
  data: AnnData | MuData,
305
286
  modality_key: str = "coda",
306
- rng_key=None,
287
+ rng_key: int | None = None,
307
288
  num_prior_samples: int = 500,
308
289
  use_posterior_predictive: bool = True,
309
290
  ) -> az.InferenceData:
@@ -381,34 +362,9 @@ class Sccoda(CompositionalModel2):
381
362
  if rng_key is None:
382
363
  rng = np.random.default_rng()
383
364
  rng_key = random.key(rng.integers(0, 10000))
384
-
385
- if use_posterior_predictive:
386
- posterior_predictive = Predictive(self.model, self.mcmc.get_samples())(
387
- rng_key,
388
- counts=None,
389
- covariates=numpyro_covariates,
390
- n_total=numpyro_n_total,
391
- ref_index=ref_index,
392
- sample_adata=sample_adata,
393
- )
394
- else:
395
- posterior_predictive = None
396
-
397
- if num_prior_samples > 0:
398
- prior = Predictive(self.model, num_samples=num_prior_samples)(
399
- rng_key,
400
- counts=None,
401
- covariates=numpyro_covariates,
402
- n_total=numpyro_n_total,
403
- ref_index=ref_index,
404
- sample_adata=sample_adata,
405
- )
406
365
  else:
407
- prior = None
366
+ rng_key = random.key(rng_key)
408
367
 
409
- import arviz as az
410
-
411
- # Create arviz object
412
368
  if use_posterior_predictive:
413
369
  posterior_predictive = Predictive(self.model, self.mcmc.get_samples())(
414
370
  rng_key,
@@ -451,6 +407,9 @@ class Sccoda(CompositionalModel2):
451
407
  else:
452
408
  prior = None
453
409
 
410
+ import arviz as az
411
+
412
+ # Create arviz object
454
413
  arviz_data = az.from_numpyro(
455
414
  self.mcmc, prior=prior, posterior_predictive=posterior_predictive, dims=dims, coords=coords
456
415
  )
@@ -468,76 +427,84 @@ class Sccoda(CompositionalModel2):
468
427
  *args,
469
428
  **kwargs,
470
429
  ):
471
- """Examples:
472
- >>> import pertpy as pt
473
- >>> haber_cells = pt.dt.haber_2017_regions()
474
- >>> sccoda = pt.tl.Sccoda()
475
- >>> mdata = sccoda.load(haber_cells,
476
- >>> type="cell_level",
477
- >>> generate_sample_level=True,
478
- >>> cell_type_identifier="cell_label",
479
- >>> sample_identifier="batch",
480
- >>> covariate_obs=["condition"])
481
- >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
482
- >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42).
483
- """ # noqa: D205
430
+ """
431
+
432
+ Examples:
433
+ >>> import pertpy as pt
434
+ >>> haber_cells = pt.dt.haber_2017_regions()
435
+ >>> sccoda = pt.tl.Sccoda()
436
+ >>> mdata = sccoda.load(haber_cells,
437
+ >>> type="cell_level",
438
+ >>> generate_sample_level=True,
439
+ >>> cell_type_identifier="cell_label",
440
+ >>> sample_identifier="batch",
441
+ >>> covariate_obs=["condition"])
442
+ >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
443
+ >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42).
444
+ """ # noqa: D205, D212
484
445
  return super().run_nuts(data, modality_key, num_samples, num_warmup, rng_key, copy, *args, **kwargs)
485
446
 
486
447
  run_nuts.__doc__ = CompositionalModel2.run_nuts.__doc__ + run_nuts.__doc__
487
448
 
488
449
  def credible_effects(self, data: AnnData | MuData, modality_key: str = "coda", est_fdr: float = None) -> pd.Series:
489
- """Examples:
490
- >>> import pertpy as pt
491
- >>> haber_cells = pt.dt.haber_2017_regions()
492
- >>> sccoda = pt.tl.Sccoda()
493
- >>> mdata = sccoda.load(haber_cells,
494
- >>> type="cell_level",
495
- >>> generate_sample_level=True,
496
- >>> cell_type_identifier="cell_label",
497
- >>> sample_identifier="batch",
498
- >>> covariate_obs=["condition"])
499
- >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
500
- >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
501
- >>> credible_effects = sccoda.credible_effects(mdata).
502
- """ # noqa: D205
450
+ """
451
+
452
+ Examples:
453
+ >>> import pertpy as pt
454
+ >>> haber_cells = pt.dt.haber_2017_regions()
455
+ >>> sccoda = pt.tl.Sccoda()
456
+ >>> mdata = sccoda.load(haber_cells,
457
+ >>> type="cell_level",
458
+ >>> generate_sample_level=True,
459
+ >>> cell_type_identifier="cell_label",
460
+ >>> sample_identifier="batch",
461
+ >>> covariate_obs=["condition"])
462
+ >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
463
+ >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
464
+ >>> credible_effects = sccoda.credible_effects(mdata).
465
+ """ # noqa: D205, D212
503
466
  return super().credible_effects(data, modality_key, est_fdr)
504
467
 
505
468
  credible_effects.__doc__ = CompositionalModel2.credible_effects.__doc__ + credible_effects.__doc__
506
469
 
507
470
  def summary(self, data: AnnData | MuData, extended: bool = False, modality_key: str = "coda", *args, **kwargs):
508
- """Examples:
509
- >>> import pertpy as pt
510
- >>> haber_cells = pt.dt.haber_2017_regions()
511
- >>> sccoda = pt.tl.Sccoda()
512
- >>> mdata = sccoda.load(haber_cells,
513
- >>> type="cell_level",
514
- >>> generate_sample_level=True,
515
- >>> cell_type_identifier="cell_label",
516
- >>> sample_identifier="batch",
517
- >>> covariate_obs=["condition"])
518
- >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
519
- >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
520
- >>> sccoda.summary(mdata).
521
- """ # noqa: D205
471
+ """
472
+
473
+ Examples:
474
+ >>> import pertpy as pt
475
+ >>> haber_cells = pt.dt.haber_2017_regions()
476
+ >>> sccoda = pt.tl.Sccoda()
477
+ >>> mdata = sccoda.load(haber_cells,
478
+ >>> type="cell_level",
479
+ >>> generate_sample_level=True,
480
+ >>> cell_type_identifier="cell_label",
481
+ >>> sample_identifier="batch",
482
+ >>> covariate_obs=["condition"])
483
+ >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
484
+ >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
485
+ >>> sccoda.summary(mdata).
486
+ """ # noqa: D205, D212
522
487
  return super().summary(data, extended, modality_key, *args, **kwargs)
523
488
 
524
489
  summary.__doc__ = CompositionalModel2.summary.__doc__ + summary.__doc__
525
490
 
526
491
  def set_fdr(self, data: AnnData | MuData, est_fdr: float, modality_key: str = "coda", *args, **kwargs):
527
- """Examples:
528
- >>> import pertpy as pt
529
- >>> haber_cells = pt.dt.haber_2017_regions()
530
- >>> sccoda = pt.tl.Sccoda()
531
- >>> mdata = sccoda.load(haber_cells,
532
- >>> type="cell_level",
533
- >>> generate_sample_level=True,
534
- >>> cell_type_identifier="cell_label",
535
- >>> sample_identifier="batch",
536
- >>> covariate_obs=["condition"])
537
- >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
538
- >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
539
- >>> sccoda.set_fdr(mdata, est_fdr=0.4).
540
- """ # noqa: D205
492
+ """
493
+
494
+ Examples:
495
+ >>> import pertpy as pt
496
+ >>> haber_cells = pt.dt.haber_2017_regions()
497
+ >>> sccoda = pt.tl.Sccoda()
498
+ >>> mdata = sccoda.load(haber_cells,
499
+ >>> type="cell_level",
500
+ >>> generate_sample_level=True,
501
+ >>> cell_type_identifier="cell_label",
502
+ >>> sample_identifier="batch",
503
+ >>> covariate_obs=["condition"])
504
+ >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
505
+ >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
506
+ >>> sccoda.set_fdr(mdata, est_fdr=0.4).
507
+ """ # noqa: D205, D212
541
508
  return super().set_fdr(data, est_fdr, modality_key, *args, **kwargs)
542
509
 
543
510
  set_fdr.__doc__ = CompositionalModel2.set_fdr.__doc__ + set_fdr.__doc__