pertpy 1.0.2__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -8,7 +8,7 @@ import numpy as np
 from jax.random import PRNGKey
 from jax.scipy.special import logsumexp
 from numpyro import factor, plate, sample
-from numpyro.distributions import Dirichlet, Exponential, HalfNormal, Normal, Poisson
+from numpyro.distributions import Categorical, Dirichlet, Exponential, HalfNormal, Normal, Poisson
 from numpyro.infer import MCMC, NUTS

 ParamsDict = Mapping[str, jnp.ndarray]
@@ -102,8 +102,14 @@ class MixtureModel(ABC):

         with plate("data", data.shape[0]):
             log_likelihoods = self.log_likelihood(data, params)
-            log_mixture_likelihood = logsumexp(log_likelihoods, axis=-1)
-            sample("obs", Normal(log_mixture_likelihood, 1.0), obs=data)
+            mixture_probs = jnp.exp(log_likelihoods - logsumexp(log_likelihoods, axis=-1, keepdims=True))
+            z = sample("z", Categorical(mixture_probs), infer={"enumerate": "parallel"})
+
+            # Observe under selected component
+            poisson_ll = Poisson(params["poisson_rate"]).log_prob(data)
+            gaussian_ll = Normal(params["gaussian_mean"], params["gaussian_std"]).log_prob(data)
+            obs_ll = jnp.where(z == 0, poisson_ll, gaussian_ll)
+            factor("obs", obs_ll)

     def assignment(self, samples: ParamsDict, data: jnp.ndarray) -> np.ndarray:
         """Assign data points to mixture components.
pertpy/tools/__init__.py CHANGED
@@ -21,7 +21,6 @@ from pertpy.tools._perturbation_space._simple import (
     KMeansSpace,
     PseudobulkSpace,
 )
-from pertpy.tools._scgen import Scgen


 def __getattr__(name: str):
@@ -35,14 +34,25 @@ def __getattr__(name: str):
             raise ImportError(
                 "Extra dependencies required: toytree, ete4. Please install with: pip install toytree ete4"
             ) from None
-
     elif name in ["EdgeR", "PyDESeq2", "Statsmodels", "TTest", "WilcoxonTest"]:
         module = import_module("pertpy.tools._differential_gene_expression")
         return getattr(module, name)
+    elif name == "Scgen":
+        try:
+            module = import_module("pertpy.tools._scgen")
+            return module.Scgen
+        except ImportError:
+            raise ImportError(
+                "Scgen requires scvi-tools to be installed. Please install with: pip install scvi-tools"
+            ) from None

     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


+def __dir__():
+    return __all__
+
+
 __all__ = [
     "Augur",
     "Cinemaot",
pertpy/tools/_augur.py CHANGED
@@ -36,6 +36,7 @@ from statsmodels.api import OLS
 from statsmodels.stats.multitest import fdrcorrection

 from pertpy._doc import _doc_params, doc_common_plot_args
+from pertpy.tools.core import _is_raw_counts

 if TYPE_CHECKING:
     from matplotlib.axes import Axes
@@ -87,6 +88,7 @@ class Augur:
         self,
         input: AnnData | pd.DataFrame,
         *,
+        layer: str | None = None,
         meta: pd.DataFrame | None = None,
         label_col: str = "label_col",
         cell_type_col: str = "cell_type_col",
@@ -98,6 +100,7 @@ class Augur:
         Args:
             input: Anndata or matrix containing gene expression values (genes in rows, cells in columns)
                 and optionally meta data about each cell.
+            layer: Layer in AnnData to use for expression data. If None, uses .X
             meta: Optional Pandas DataFrame containing meta data about each cell.
             label_col: column of the meta DataFrame or the Anndata or matrix containing the condition labels for each cell
                 in the cell-by-gene expression matrix
@@ -114,11 +117,11 @@ class Augur:
             >>> import pertpy as pt
             >>> adata = pt.dt.sc_sim_augur()
             >>> ag_rfc = pt.tl.Augur("random_forest_classifier")
-            >>> loaded_data = ag_rfc.load(adata)
+            >>> augur_adata = ag_rfc.load(adata)
         """
         if isinstance(input, AnnData):
-            input.obs = input.obs.rename(columns={cell_type_col: "cell_type", label_col: "label"})
             adata = input
+            obs_renamed = adata.obs.rename(columns={cell_type_col: "cell_type", label_col: "label"})

         elif isinstance(input, pd.DataFrame):
             if meta is None:
@@ -130,27 +133,47 @@ class Augur:

             label = input[label_col] if meta is None else meta[label_col]
             cell_type = input[cell_type_col] if meta is None else meta[cell_type_col]
-            x = input.drop([label_col, cell_type_col], axis=1) if meta is None else input
-            adata = AnnData(X=x, obs=pd.DataFrame({"cell_type": cell_type, "label": label}))
+            X = input.drop([label_col, cell_type_col], axis=1) if meta is None else input
+            adata = AnnData(X=X, obs=pd.DataFrame({"cell_type": cell_type, "label": label}))
+            obs_renamed = adata.obs

-        if len(adata.obs["label"].unique()) < 2:
+        if len(obs_renamed["label"].unique()) < 2:
             raise ValueError("Less than two unique labels in dataset. At least two are needed for the analysis.")
+
+        if isinstance(input, AnnData):
+            final_adata = AnnData(X=adata.X, obs=obs_renamed, var=adata.var, layers=adata.layers)
+        else:
+            final_adata = adata
+
         # dummy variables for categorical data
-        if adata.obs["label"].dtype.name == "category":
-            # filter samples according to label
+        if final_adata.obs["label"].dtype.name == "category":
+            label_encoder = LabelEncoder()
+            final_adata.obs["y_"] = label_encoder.fit_transform(final_adata.obs["label"])
+
             if condition_label is not None and treatment_label is not None:
                 logger.info(f"Filtering samples with {condition_label} and {treatment_label} labels.")
-                adata = ad.concat(
-                    [adata[adata.obs["label"] == condition_label], adata[adata.obs["label"] == treatment_label]]
+                final_adata = ad.concat(
+                    [
+                        final_adata[final_adata.obs["label"] == condition_label],
+                        final_adata[final_adata.obs["label"] == treatment_label],
+                    ]
                 )
-            label_encoder = LabelEncoder()
-            adata.obs["y_"] = label_encoder.fit_transform(adata.obs["label"])
         else:
-            y = adata.obs["label"].to_frame()
+            y = final_adata.obs["label"].to_frame()
             y = y.rename(columns={"label": "y_"})
-            adata.obs = pd.concat([adata.obs, y], axis=1)
+            final_adata.obs = pd.concat([final_adata.obs, y], axis=1)

-        return adata
+        if layer is not None:
+            if layer not in final_adata.layers:
+                raise ValueError(f"Layer '{layer}' not found in AnnData object")
+            X = final_adata.layers[layer]
+        else:
+            X = final_adata.X
+
+        if not _is_raw_counts(X):
+            logger.warning("Data does not appear to be raw counts. Augur developers recommend using raw counts.")
+
+        return final_adata

     def create_estimator(
         self,
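
The load signature gains a layer keyword and now warns when the selected matrix does not look like raw counts. A short usage sketch based on the docstring example above (the "counts" layer name is a hypothetical example, not guaranteed to exist in this dataset):

import pertpy as pt

adata = pt.dt.sc_sim_augur()
ag_rfc = pt.tl.Augur("random_forest_classifier")

# Default: expression is taken from .X.
augur_adata = ag_rfc.load(adata)

# Alternatively, point at a named layer; "counts" is a hypothetical layer name.
# augur_adata = ag_rfc.load(adata, layer="counts")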
@@ -11,7 +11,6 @@ from jax import config, random
 from lamin_utils import logger
 from mudata import MuData
 from numpyro.infer import Predictive
-from rich import print

 from pertpy.tools._coda._base_coda import CompositionalModel2, from_scanpy

@@ -25,24 +24,6 @@ config.update("jax_enable_x64", True)
 class Sccoda(CompositionalModel2):
     r"""Statistical model for single-cell differential composition analysis with specification of a reference cell type.

-    This is the standard scCODA model and recommended for all uses.
-
-    The hierarchical formulation of the model for one sample is:
-
-    .. math::
-         y|x &\\sim DirMult(\\phi, \\bar{y}) \\\\
-         \\log(\\phi) &= \\alpha + x \\beta \\\\
-         \\alpha_k &\\sim N(0, 5) \\quad &\\forall k \\in [K] \\\\
-         \\beta_{m, \\hat{k}} &= 0 &\\forall m \\in [M]\\\\
-         \\beta_{m, k} &= \\tau_{m, k} \\tilde{\\beta}_{m, k} \\quad &\\forall m \\in [M], k \\in \\{[K] \\smallsetminus \\hat{k}\\} \\\\
-         \\tau_{m, k} &= \\frac{\\exp(t_{m, k})}{1+ \\exp(t_{m, k})} \\quad &\\forall m \\in [M], k \\in \\{[K] \\smallsetminus \\hat{k}\\} \\\\
-         \\frac{t_{m, k}}{50} &\\sim N(0, 1) \\quad &\\forall m \\in [M], k \\in \\{[K] \\smallsetminus \\hat{k}\\} \\\\
-         \\tilde{\\beta}_{m, k} &= \\sigma_m^2 \\cdot \\gamma_{m, k} \\quad &\\forall m \\in [M], k \\in \\{[K] \\smallsetminus \\hat{k}\\} \\\\
-         \\sigma_m^2 &\\sim HC(0, 1) \\quad &\\forall m \\in [M] \\\\
-         \\gamma_{m, k} &\\sim N(0,1) \\quad &\\forall m \\in [M], k \\in \\{[K] \\smallsetminus \\hat{k}\\} \\\\
-
-    with y being the cell counts and x the covariates.
-
     For further information, see `scCODA is a Bayesian model for compositional single-cell data analysis`
     (Büttner, Ostner et al., NatComms, 2021)
     """
@@ -33,24 +33,6 @@ config.update("jax_enable_x64", True)
 class Tasccoda(CompositionalModel2):
     r"""Statistical model for tree-aggregated differential composition analysis (tascCODA, Ostner et al., 2021).

-    The hierarchical formulation of the model for one sample is:
-
-    .. math::
-         \\begin{align*}
-             Y_i &\\sim \\textrm{DirMult}(\\bar{Y}_i, \\textbf{a}(\\textbf{x})_i)\\\\
-             \\log(\\textbf{a}(X))_i &= \\alpha + X_{i, \\cdot} \\beta\\\\
-             \\alpha_j &\\sim \\mathcal{N}(0, 10) & \\forall j\\in[p]\\\\
-             \\beta &= \\hat{\\beta} A^T \\\\
-             \\hat{\\beta}_{l, k} &= 0 & \\forall k \\in \\hat{v}, l \\in [d]\\\\
-             \\hat{\\beta}_{l, k} &= \\theta \\tilde{\\beta}_{1, l, k} + (1- \\theta) \\tilde{\\beta}_{0, l, k} \\quad & \\forall k\\in\\{[v] \\smallsetminus \\hat{v}\\}, l \\in [d]\\\\
-             \\tilde{\\beta}_{m, l, k} &= \\sigma_{m, l, k} * b_{m, l, k} \\quad & \\forall k\\in\\{[v] \\smallsetminus \\hat{v}\\}, m \\in \\{0, 1\\}, l \\in [d]\\\\
-             \\sigma_{m, l, k} &\\sim \\textrm{Exp}(\\lambda_{m, l, k}^2/2) \\quad & \\forall k\\in\\{[v] \\smallsetminus \\hat{v}\\}, l \\in \\{0, 1\\}, l \\in [d]\\\\
-             b_{m, l, k} &\\sim N(0,1) \\quad & \\forall k\\in\\{[v] \\smallsetminus \\hat{v}\\}, l \\in \\{0, 1\\}, l \\in [d]\\\\
-             \\theta &\\sim \\textrm{Beta}(1, \\frac{1}{|\\{[v] \\smallsetminus \\hat{v}\\}|})
-         \\end{align*}
-
-    with Y being the cell counts, X the covariates, and v the set of nodes of the underlying tree structure.
-
     For further information, see `tascCODA: Bayesian Tree-Aggregated Analysis of Compositional Amplicon and Single-Cell Data`
     (Ostner et al., 2021)
     """
@@ -75,11 +57,14 @@ class Tasccoda(CompositionalModel2):
         modality_key_1: str = "rna",
         modality_key_2: str = "coda",
     ) -> MuData:
-        """Prepare a MuData object for subsequent processing. If type is "cell_level", then create a compositional analysis dataset from the input adata. If type is "sample_level", generate ete tree for tascCODA models from dendrogram information or cell-level observations.
+        """Prepare a MuData object for subsequent processing.
+
+        If type is "cell_level", then create a compositional analysis dataset from the input adata.
+        If type is "sample_level", generate ete tree for tascCODA models from dendrogram information or cell-level observations.

-        When using ``type="cell_level"``, ``adata`` needs to have a column in ``adata.obs`` that contains the cell type assignment.
+        When using `type="cell_level"`, `adata` needs to have a column in `adata.obs` that contains the cell type assignment.
         Further, it must contain one column or a set of columns (e.g. subject id, treatment, disease status) that uniquely identify each (statistical) sample.
-        Further covariates (e.g. subject age) can either be specified via addidional column names in ``adata.obs``, a key in ``adata.uns``, or as a separate DataFrame.
+        Further covariates (e.g. subject age) can either be specified via addidional column names in `adata.obs`, a key in `adata.uns`, or as a separate DataFrame.

         Args:
             adata: AnnData object.
@@ -90,10 +75,13 @@ class Tasccoda(CompositionalModel2):
             covariate_obs: If type is "cell_level", specify list of keys for adata.obs, where covariate values are stored.
             covariate_df: If type is "cell_level", specify dataFrame with covariates.
             dendrogram_key: Key to the scanpy.tl.dendrogram result in `.uns` of original cell level anndata object.
-            levels_orig: List that indicates which columns in `.obs` of the original data correspond to tree levels. The list must begin with the root level, and end with the leaf level.
-            levels_agg: List that indicates which columns in `.var` of the aggregated data correspond to tree levels. The list must begin with the root level, and end with the leaf level.
+            levels_orig: List that indicates which columns in `.obs` of the original data correspond to tree levels.
+                The list must begin with the root level, and end with the leaf level.
+            levels_agg: List that indicates which columns in `.var` of the aggregated data correspond to tree levels.
+                The list must begin with the root level, and end with the leaf level.
             add_level_name: If True, internal nodes in the tree will be named as "{level_name}_{node_name}" instead of just {level_name}.
-            key_added: If not specified, the tree is stored in .uns[tree]. If `data` is AnnData, save tree in `data`. If `data` is MuData, save tree in data[modality_2].
+            key_added: If not specified, the tree is stored in `.uns['tree']`.
+                If `data` is AnnData, save tree in `data`. If `data` is MuData, save tree in data[modality_2].
             modality_key_1: Key to the cell-level AnnData in the MuData object.
             modality_key_2: Key to the aggregated sample-level AnnData object in the MuData object.

pertpy/tools/_mixscape.py CHANGED
@@ -177,7 +177,7 @@ class Mixscape:
     def mixscape(
         self,
         adata: AnnData,
-        labels: str,
+        pert_key: str,
        control: str,
        *,
        new_class_name: str | None = "mixscape_class",
@@ -201,12 +201,12 @@ class Mixscape:

        Args:
            adata: The annotated data object.
-            labels: The column of `.obs` with target gene labels.
+            pert_key: The column of `.obs` with target gene labels.
            control: Control category from the `labels` column.
            new_class_name: Name of mixscape classification to be stored in `.obs`.
            layer: Key from adata.layers whose value will be used to perform tests on. Default is using `.layers["X_pert"]`.
            min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
-            logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells (default: 0.25).
+            logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells.
            de_layer: Layer to use for identifying differentially expressed genes. If `None`, adata.X is used.
            test_method: Method to use for differential expression testing.
            iter_num: Number of normalmixEM iterations to run if convergence does not occur.
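
Throughout this file, the labels argument of mixscape, lda, plot_heatmap, and plot_perturbscore is renamed to pert_key, so 1.0.2-style calls need a one-word update. A short sketch of the call-site change (the column name "gene_target" and the control value "NT" are assumptions for illustration, not part of the diff):

import pertpy as pt

ms = pt.tl.Mixscape()

# pertpy 1.0.2 spelling:
#   ms.mixscape(adata, labels="gene_target", control="NT")
# pertpy 1.0.3 spelling (lda, plot_heatmap, and plot_perturbscore follow the same rename):
#   ms.mixscape(adata, pert_key="gene_target", control="NT")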
@@ -256,7 +256,7 @@ class Mixscape:
             adata=adata,
             split_masks=split_masks,
             categories=categories,
-            labels=labels,
+            pert_key=pert_key,
             control=control,
             layer=de_layer,
             pval_cutoff=pval_cutoff,
@@ -278,7 +278,7 @@ class Mixscape:

         # initialize return variables
         adata.obs[f"{new_class_name}_p_{perturbation_type.lower()}"] = 0
-        adata.obs[new_class_name] = adata.obs[labels].astype(str)
+        adata.obs[new_class_name] = adata.obs[pert_key].astype(str)
         adata.obs[f"{new_class_name}_global"] = np.empty(
             [
                 adata.n_obs,
@@ -290,12 +290,12 @@ class Mixscape:
         adata.obs[f"{new_class_name}_p_{perturbation_type.lower()}"] = 0.0
         for split, split_mask in enumerate(split_masks):
             category = categories[split]
-            gene_targets = list(set(adata[split_mask].obs[labels]).difference([control]))
+            gene_targets = list(set(adata[split_mask].obs[pert_key]).difference([control]))
             for gene in gene_targets:
                 post_prob = 0
-                orig_guide_cells = (adata.obs[labels] == gene) & split_mask
+                orig_guide_cells = (adata.obs[pert_key] == gene) & split_mask
                 orig_guide_cells_index = list(orig_guide_cells.index[orig_guide_cells])
-                nt_cells = (adata.obs[labels] == control) & split_mask
+                nt_cells = (adata.obs[pert_key] == control) & split_mask
                 all_cells = orig_guide_cells | nt_cells

                 if len(perturbation_markers[(category, gene)]) == 0:
@@ -307,7 +307,11 @@ class Mixscape:

                     dat = X[np.asarray(all_cells)][:, de_genes_indices]
                     if scale:
-                        dat = sc.pp.scale(dat)
+                        with warnings.catch_warnings():
+                            warnings.filterwarnings(
+                                "ignore", message="zero-centering a sparse array/matrix densifies it."
+                            )
+                            dat = sc.pp.scale(dat)

                     converged = False
                     n_iter = 0
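
The filter above relies on standard-library behavior: the message argument is a regular expression matched against the start of the warning text, so only the scanpy densification warning is silenced. A tiny standalone illustration (the second warning string is made up):

import warnings

with warnings.catch_warnings():
    # Silence only warnings whose message starts with this pattern.
    warnings.filterwarnings("ignore", message="zero-centering a sparse array/matrix densifies it.")
    warnings.warn("zero-centering a sparse array/matrix densifies it.")  # suppressed
    warnings.warn("something else went slightly wrong")  # still emitted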
@@ -335,10 +339,10 @@ class Mixscape:
                     pvec = pd.Series(np.asarray(pvec).flatten(), index=list(all_cells.index[all_cells]))

                     if n_iter == 0:
-                        gv = pd.DataFrame(columns=["pvec", labels])
+                        gv = pd.DataFrame(columns=["pvec", pert_key])
                         gv["pvec"] = pvec
-                        gv[labels] = control
-                        gv.loc[guide_cells, labels] = gene
+                        gv[pert_key] = control
+                        gv.loc[guide_cells, pert_key] = gene
                         if gene not in gv_list:
                             gv_list[gene] = {}
                         gv_list[gene][category] = gv
@@ -389,7 +393,7 @@ class Mixscape:
     def lda(
         self,
         adata: AnnData,
-        labels: str,
+        pert_key: str,
         control: str,
         *,
         mixscape_class_global: str | None = "mixscape_class_global",
@@ -407,7 +411,7 @@ class Mixscape:

         Args:
             adata: The annotated data object.
-            labels: The column of `.obs` with target gene labels.
+            pert_key: The column of `.obs` with target gene labels.
             control: Control category from the `pert_key` column.
             mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
             layer: Layer to use for identifying differentially expressed genes. If `None`, adata.X is used.
@@ -456,7 +460,7 @@ class Mixscape:
             adata=adata,
             split_masks=split_masks,
             categories=categories,
-            labels=labels,
+            pert_key=pert_key,
             control=control,
             layer=layer,
             pval_cutoff=pval_cutoff,
@@ -475,17 +479,19 @@ class Mixscape:
                 continue
             else:
                 gene_subset = adata_subset[
-                    (adata_subset.obs[labels] == key[1]) | (adata_subset.obs[labels] == control)
+                    (adata_subset.obs[pert_key] == key[1]) | (adata_subset.obs[pert_key] == control)
                 ].copy()
-            sc.pp.scale(gene_subset)
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore", UserWarning)
+                sc.pp.scale(gene_subset)
             sc.tl.pca(gene_subset, n_comps=n_comps)
             # project cells into PCA space of gene_subset
             projected_pcs[key[1]] = np.asarray(np.dot(X, gene_subset.varm["PCs"]))
         # concatenate all pcs into a single matrix.
         projected_pcs_array = np.concatenate(list(projected_pcs.values()), axis=1)

-        clf = LinearDiscriminantAnalysis(n_components=len(np.unique(adata_subset.obs[labels])) - 1)
-        clf.fit(projected_pcs_array, adata_subset.obs[labels])
+        clf = LinearDiscriminantAnalysis(n_components=len(np.unique(adata_subset.obs[pert_key])) - 1)
+        clf.fit(projected_pcs_array, adata_subset.obs[pert_key])
         cell_embeddings = clf.transform(projected_pcs_array)
         adata.uns["mixscape_lda"] = cell_embeddings

@@ -495,9 +501,10 @@ class Mixscape:
     def _get_perturbation_markers(
         self,
         adata: AnnData,
+        *,
         split_masks: list[np.ndarray],
         categories: list[str],
-        labels: str,
+        pert_key: str,
         control: str,
         layer: str,
         pval_cutoff: float,
@@ -511,7 +518,7 @@ class Mixscape:
             adata: :class:`~anndata.AnnData` object
             split_masks: List of boolean masks for each split/group.
             categories: List of split/group names.
-            labels: The column of `.obs` with target gene labels.
+            pert_key: The column of `.obs` with target gene labels.
             control: Control category from the `labels` column.
             layer: Key from adata.layers whose value will be used to compare gene expression.
             pval_cutoff: P-value cut-off for selection of significantly DE genes.
@@ -526,7 +533,7 @@ class Mixscape:
         for split, split_mask in enumerate(split_masks):
             category = categories[split]
             # get gene sets for each split
-            gene_targets = list(set(adata[split_mask].obs[labels]).difference([control]))
+            gene_targets = list(set(adata[split_mask].obs[pert_key]).difference([control]))
             adata_split = adata[split_mask].copy()
             # find top DE genes between cells with targeting and non-targeting gRNAs
             with warnings.catch_warnings():
@@ -535,7 +542,7 @@ class Mixscape:
                 sc.tl.rank_genes_groups(
                     adata_split,
                     layer=layer,
-                    groupby=labels,
+                    groupby=pert_key,
                     groups=gene_targets,
                     reference=control,
                     method=test_method,
@@ -666,7 +673,7 @@ class Mixscape:
     def plot_heatmap(  # pragma: no cover # noqa: D417
         self,
         adata: AnnData,
-        labels: str,
+        pert_key: str,
         target_gene: str,
         control: str,
         *,
@@ -682,7 +689,7 @@ class Mixscape:

         Args:
             adata: The annotated data object.
-            labels: The column of `.obs` with target gene labels.
+            pert_key: The column of `.obs` with target gene labels.
             target_gene: Target gene name to visualize heatmap for.
             control: Control category from the `pert_key` column.
             layer: Key from `adata.layers` whose value will be used to perform tests on.
@@ -711,12 +718,13 @@ class Mixscape:
         """
         if "mixscape_class" not in adata.obs:
             raise ValueError("Please run `pt.tl.mixscape` first.")
-        adata_subset = adata[(adata.obs[labels] == target_gene) | (adata.obs[labels] == control)].copy()
+        adata_subset = adata[(adata.obs[pert_key] == target_gene) | (adata.obs[pert_key] == control)].copy()
         with warnings.catch_warnings():
             warnings.simplefilter("ignore", RuntimeWarning)
             warnings.simplefilter("ignore", PerformanceWarning)
-            sc.tl.rank_genes_groups(adata_subset, layer=layer, groupby=labels, method=method)
-        sc.pp.scale(adata_subset, max_value=vmax)
+            warnings.simplefilter("ignore", UserWarning)
+            sc.tl.rank_genes_groups(adata_subset, layer=layer, groupby=pert_key, method=method)
+            sc.pp.scale(adata_subset, max_value=vmax)
         sc.pp.subsample(adata_subset, n_obs=subsample_number)

         fig = sc.pl.rank_genes_groups_heatmap(
@@ -739,7 +747,7 @@ class Mixscape:
     def plot_perturbscore(  # pragma: no cover # noqa: D417
         self,
         adata: AnnData,
-        labels: str,
+        pert_key: str,
         target_gene: str,
         *,
         mixscape_class: str = "mixscape_class",
@@ -758,7 +766,7 @@ class Mixscape:

         Args:
             adata: The annotated data object.
-            labels: The column of `.obs` with target gene labels.
+            pert_key: The column of `.obs` with target gene labels.
             target_gene: Target gene name to visualize perturbation scores for.
             mixscape_class: The column of `.obs` with mixscape classifications.
             color: Specify color of target gene class or knockout cell class. For control non-targeting and non-perturbed cells, colors are set to different shades of grey.
@@ -797,21 +805,21 @@ class Mixscape:
             else:
                 perturbation_score = pd.concat([perturbation_score, perturbation_score_temp])
         perturbation_score["mix"] = adata.obs[mixscape_class][perturbation_score.index]
-        gd = list(set(perturbation_score[labels]).difference({target_gene}))[0]
+        gd = list(set(perturbation_score[pert_key]).difference({target_gene}))[0]

         # If before_mixscape is True, split densities based on original target gene classification
         if before_mixscape is True:
             palette = {gd: "#7d7d7d", target_gene: color}
-            plot_dens = sns.kdeplot(data=perturbation_score, x="pvec", hue=labels, fill=False, common_norm=False)
+            plot_dens = sns.kdeplot(data=perturbation_score, x="pvec", hue=pert_key, fill=False, common_norm=False)
             top_r = max(plot_dens.get_lines()[cond].get_data()[1].max() for cond in range(len(plot_dens.get_lines())))
             plt.close()
             perturbation_score["y_jitter"] = perturbation_score["pvec"]
             rng = np.random.default_rng()
-            perturbation_score.loc[perturbation_score[labels] == gd, "y_jitter"] = rng.uniform(
-                low=0.001, high=top_r / 10, size=sum(perturbation_score[labels] == gd)
+            perturbation_score.loc[perturbation_score[pert_key] == gd, "y_jitter"] = rng.uniform(
+                low=0.001, high=top_r / 10, size=sum(perturbation_score[pert_key] == gd)
             )
-            perturbation_score.loc[perturbation_score[labels] == target_gene, "y_jitter"] = rng.uniform(
-                low=-top_r / 10, high=0, size=sum(perturbation_score[labels] == target_gene)
+            perturbation_score.loc[perturbation_score[pert_key] == target_gene, "y_jitter"] = rng.uniform(
+                low=-top_r / 10, high=0, size=sum(perturbation_score[pert_key] == target_gene)
             )
             # If split_by is provided, split densities based on the split_by
             if split_by is not None:
@@ -844,7 +852,7 @@ class Mixscape:
         else:
             if palette is None:
                 palette = {gd: "#7d7d7d", f"{target_gene} NP": "#c9c9c9", f"{target_gene} {perturbation_type}": color}
-            plot_dens = sns.kdeplot(data=perturbation_score, x="pvec", hue=labels, fill=False, common_norm=False)
+            plot_dens = sns.kdeplot(data=perturbation_score, x="pvec", hue=pert_key, fill=False, common_norm=False)
             top_r = max(plot_dens.get_lines()[i].get_data()[1].max() for i in range(len(plot_dens.get_lines())))
             plt.close()
             perturbation_score["y_jitter"] = perturbation_score["pvec"]
@@ -899,6 +907,7 @@ class Mixscape:
         if return_fig:
             return plt.gcf()
         plt.show()
+
         return None

     @_doc_params(common_plot_args=doc_common_plot_args)
@@ -1058,7 +1067,7 @@ class Mixscape:
                 data=obs_tidy,
                 order=order,
                 orient="vertical",
-                scale=scale,
+                density_norm=scale,
                 ax=ax,
                 hue=hue,
                 **kwargs,
@@ -1072,7 +1081,7 @@ class Mixscape:
                 data=obs_tidy,
                 order=order,
                 jitter=jitter,
-                color="black",
+                palette="dark:black",
                 size=size,
                 ax=ax,
                 hue=hue,
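
The last two hunks track newer seaborn releases: violinplot's scale parameter was renamed to density_norm in seaborn 0.13, and the stripplot now passes a "dark:black" palette instead of a single color, which works alongside hue. A minimal sketch with made-up data and column names (not pertpy code):

import pandas as pd
import seaborn as sns

df = pd.DataFrame({"group": ["a", "a", "b", "b"], "value": [1.0, 2.0, 2.5, 3.5]})
# density_norm replaces the old scale= keyword; "width" is one of its accepted values.
ax = sns.violinplot(x="group", y="value", data=df, density_norm="width", hue="group")
sns.stripplot(x="group", y="value", data=df, palette="dark:black", hue="group", ax=ax)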
@@ -22,7 +22,7 @@ class PerturbationComparison:
     ) -> float:
         """Compare classification accuracy between real and simulated perturbations.

-        Trains a classifier on the real perturbation data + the control data and reports a normalized
+        Trains a classifier on the real perturbation data & the control data and reports a normalized
         classification accuracy on the simulated perturbation.

         Args:
@@ -64,8 +64,8 @@ class PerturbationComparison:
             real: Real perturbed data.
             simulated: Simulated perturbed data.
             control: Control data
-            use_simulated_for_knn: Include simulted perturbed data (`simulated`) into the knn graph. Only valid when
-                control (`control`) is provided.
+            use_simulated_for_knn: Include simulted perturbed data (`simulated`) into the knn graph.
+                Only valid when control (`control`) is provided.
             n_neighbors: Number of neighbors to use in k-neighbor graph.
             random_state: Random state used for k-neighbor graph construction.
             n_jobs: Number of cores to use. Defaults to -1 (all).