pertpy 1.0.2__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +1 -1
- pertpy/data/_dataloader.py +2 -2
- pertpy/data/_datasets.py +62 -62
- pertpy/metadata/_drug.py +4 -2
- pertpy/preprocessing/_guide_rna.py +17 -10
- pertpy/preprocessing/_guide_rna_mixture.py +9 -3
- pertpy/tools/__init__.py +12 -2
- pertpy/tools/_augur.py +37 -14
- pertpy/tools/_coda/_sccoda.py +0 -19
- pertpy/tools/_coda/_tasccoda.py +12 -24
- pertpy/tools/_mixscape.py +48 -39
- pertpy/tools/_perturbation_space/_comparison.py +3 -3
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +261 -353
- pertpy/tools/_perturbation_space/_perturbation_space.py +22 -14
- pertpy/tools/_perturbation_space/_simple.py +12 -6
- pertpy/tools/_scgen/_scgenvae.py +2 -1
- pertpy/tools/core.py +18 -0
- {pertpy-1.0.2.dist-info → pertpy-1.0.3.dist-info}/METADATA +84 -51
- {pertpy-1.0.2.dist-info → pertpy-1.0.3.dist-info}/RECORD +21 -20
- {pertpy-1.0.2.dist-info → pertpy-1.0.3.dist-info}/WHEEL +1 -1
- {pertpy-1.0.2.dist-info → pertpy-1.0.3.dist-info}/licenses/LICENSE +0 -0
@@ -8,7 +8,7 @@ import numpy as np
|
|
8
8
|
from jax.random import PRNGKey
|
9
9
|
from jax.scipy.special import logsumexp
|
10
10
|
from numpyro import factor, plate, sample
|
11
|
-
from numpyro.distributions import Dirichlet, Exponential, HalfNormal, Normal, Poisson
|
11
|
+
from numpyro.distributions import Categorical, Dirichlet, Exponential, HalfNormal, Normal, Poisson
|
12
12
|
from numpyro.infer import MCMC, NUTS
|
13
13
|
|
14
14
|
ParamsDict = Mapping[str, jnp.ndarray]
|
@@ -102,8 +102,14 @@ class MixtureModel(ABC):
|
|
102
102
|
|
103
103
|
with plate("data", data.shape[0]):
|
104
104
|
log_likelihoods = self.log_likelihood(data, params)
|
105
|
-
|
106
|
-
sample("
|
105
|
+
mixture_probs = jnp.exp(log_likelihoods - logsumexp(log_likelihoods, axis=-1, keepdims=True))
|
106
|
+
z = sample("z", Categorical(mixture_probs), infer={"enumerate": "parallel"})
|
107
|
+
|
108
|
+
# Observe under selected component
|
109
|
+
poisson_ll = Poisson(params["poisson_rate"]).log_prob(data)
|
110
|
+
gaussian_ll = Normal(params["gaussian_mean"], params["gaussian_std"]).log_prob(data)
|
111
|
+
obs_ll = jnp.where(z == 0, poisson_ll, gaussian_ll)
|
112
|
+
factor("obs", obs_ll)
|
107
113
|
|
108
114
|
def assignment(self, samples: ParamsDict, data: jnp.ndarray) -> np.ndarray:
|
109
115
|
"""Assign data points to mixture components.
|
pertpy/tools/__init__.py
CHANGED
@@ -21,7 +21,6 @@ from pertpy.tools._perturbation_space._simple import (
|
|
21
21
|
KMeansSpace,
|
22
22
|
PseudobulkSpace,
|
23
23
|
)
|
24
|
-
from pertpy.tools._scgen import Scgen
|
25
24
|
|
26
25
|
|
27
26
|
def __getattr__(name: str):
|
@@ -35,14 +34,25 @@ def __getattr__(name: str):
|
|
35
34
|
raise ImportError(
|
36
35
|
"Extra dependencies required: toytree, ete4. Please install with: pip install toytree ete4"
|
37
36
|
) from None
|
38
|
-
|
39
37
|
elif name in ["EdgeR", "PyDESeq2", "Statsmodels", "TTest", "WilcoxonTest"]:
|
40
38
|
module = import_module("pertpy.tools._differential_gene_expression")
|
41
39
|
return getattr(module, name)
|
40
|
+
elif name == "Scgen":
|
41
|
+
try:
|
42
|
+
module = import_module("pertpy.tools._scgen")
|
43
|
+
return module.Scgen
|
44
|
+
except ImportError:
|
45
|
+
raise ImportError(
|
46
|
+
"Scgen requires scvi-tools to be installed. Please install with: pip install scvi-tools"
|
47
|
+
) from None
|
42
48
|
|
43
49
|
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
44
50
|
|
45
51
|
|
52
|
+
def __dir__():
|
53
|
+
return __all__
|
54
|
+
|
55
|
+
|
46
56
|
__all__ = [
|
47
57
|
"Augur",
|
48
58
|
"Cinemaot",
|
pertpy/tools/_augur.py
CHANGED
@@ -36,6 +36,7 @@ from statsmodels.api import OLS
|
|
36
36
|
from statsmodels.stats.multitest import fdrcorrection
|
37
37
|
|
38
38
|
from pertpy._doc import _doc_params, doc_common_plot_args
|
39
|
+
from pertpy.tools.core import _is_raw_counts
|
39
40
|
|
40
41
|
if TYPE_CHECKING:
|
41
42
|
from matplotlib.axes import Axes
|
@@ -87,6 +88,7 @@ class Augur:
|
|
87
88
|
self,
|
88
89
|
input: AnnData | pd.DataFrame,
|
89
90
|
*,
|
91
|
+
layer: str | None = None,
|
90
92
|
meta: pd.DataFrame | None = None,
|
91
93
|
label_col: str = "label_col",
|
92
94
|
cell_type_col: str = "cell_type_col",
|
@@ -98,6 +100,7 @@ class Augur:
|
|
98
100
|
Args:
|
99
101
|
input: Anndata or matrix containing gene expression values (genes in rows, cells in columns)
|
100
102
|
and optionally meta data about each cell.
|
103
|
+
layer: Layer in AnnData to use for expression data. If None, uses .X
|
101
104
|
meta: Optional Pandas DataFrame containing meta data about each cell.
|
102
105
|
label_col: column of the meta DataFrame or the Anndata or matrix containing the condition labels for each cell
|
103
106
|
in the cell-by-gene expression matrix
|
@@ -114,11 +117,11 @@ class Augur:
|
|
114
117
|
>>> import pertpy as pt
|
115
118
|
>>> adata = pt.dt.sc_sim_augur()
|
116
119
|
>>> ag_rfc = pt.tl.Augur("random_forest_classifier")
|
117
|
-
>>>
|
120
|
+
>>> augur_adata = ag_rfc.load(adata)
|
118
121
|
"""
|
119
122
|
if isinstance(input, AnnData):
|
120
|
-
input.obs = input.obs.rename(columns={cell_type_col: "cell_type", label_col: "label"})
|
121
123
|
adata = input
|
124
|
+
obs_renamed = adata.obs.rename(columns={cell_type_col: "cell_type", label_col: "label"})
|
122
125
|
|
123
126
|
elif isinstance(input, pd.DataFrame):
|
124
127
|
if meta is None:
|
@@ -130,27 +133,47 @@ class Augur:
|
|
130
133
|
|
131
134
|
label = input[label_col] if meta is None else meta[label_col]
|
132
135
|
cell_type = input[cell_type_col] if meta is None else meta[cell_type_col]
|
133
|
-
|
134
|
-
adata = AnnData(X=
|
136
|
+
X = input.drop([label_col, cell_type_col], axis=1) if meta is None else input
|
137
|
+
adata = AnnData(X=X, obs=pd.DataFrame({"cell_type": cell_type, "label": label}))
|
138
|
+
obs_renamed = adata.obs
|
135
139
|
|
136
|
-
if len(
|
140
|
+
if len(obs_renamed["label"].unique()) < 2:
|
137
141
|
raise ValueError("Less than two unique labels in dataset. At least two are needed for the analysis.")
|
142
|
+
|
143
|
+
if isinstance(input, AnnData):
|
144
|
+
final_adata = AnnData(X=adata.X, obs=obs_renamed, var=adata.var, layers=adata.layers)
|
145
|
+
else:
|
146
|
+
final_adata = adata
|
147
|
+
|
138
148
|
# dummy variables for categorical data
|
139
|
-
if
|
140
|
-
|
149
|
+
if final_adata.obs["label"].dtype.name == "category":
|
150
|
+
label_encoder = LabelEncoder()
|
151
|
+
final_adata.obs["y_"] = label_encoder.fit_transform(final_adata.obs["label"])
|
152
|
+
|
141
153
|
if condition_label is not None and treatment_label is not None:
|
142
154
|
logger.info(f"Filtering samples with {condition_label} and {treatment_label} labels.")
|
143
|
-
|
144
|
-
[
|
155
|
+
final_adata = ad.concat(
|
156
|
+
[
|
157
|
+
final_adata[final_adata.obs["label"] == condition_label],
|
158
|
+
final_adata[final_adata.obs["label"] == treatment_label],
|
159
|
+
]
|
145
160
|
)
|
146
|
-
label_encoder = LabelEncoder()
|
147
|
-
adata.obs["y_"] = label_encoder.fit_transform(adata.obs["label"])
|
148
161
|
else:
|
149
|
-
y =
|
162
|
+
y = final_adata.obs["label"].to_frame()
|
150
163
|
y = y.rename(columns={"label": "y_"})
|
151
|
-
|
164
|
+
final_adata.obs = pd.concat([final_adata.obs, y], axis=1)
|
152
165
|
|
153
|
-
|
166
|
+
if layer is not None:
|
167
|
+
if layer not in final_adata.layers:
|
168
|
+
raise ValueError(f"Layer '{layer}' not found in AnnData object")
|
169
|
+
X = final_adata.layers[layer]
|
170
|
+
else:
|
171
|
+
X = final_adata.X
|
172
|
+
|
173
|
+
if not _is_raw_counts(X):
|
174
|
+
logger.warning("Data does not appear to be raw counts. Augur developers recommend using raw counts.")
|
175
|
+
|
176
|
+
return final_adata
|
154
177
|
|
155
178
|
def create_estimator(
|
156
179
|
self,
|
pertpy/tools/_coda/_sccoda.py
CHANGED
@@ -11,7 +11,6 @@ from jax import config, random
|
|
11
11
|
from lamin_utils import logger
|
12
12
|
from mudata import MuData
|
13
13
|
from numpyro.infer import Predictive
|
14
|
-
from rich import print
|
15
14
|
|
16
15
|
from pertpy.tools._coda._base_coda import CompositionalModel2, from_scanpy
|
17
16
|
|
@@ -25,24 +24,6 @@ config.update("jax_enable_x64", True)
|
|
25
24
|
class Sccoda(CompositionalModel2):
|
26
25
|
r"""Statistical model for single-cell differential composition analysis with specification of a reference cell type.
|
27
26
|
|
28
|
-
This is the standard scCODA model and recommended for all uses.
|
29
|
-
|
30
|
-
The hierarchical formulation of the model for one sample is:
|
31
|
-
|
32
|
-
.. math::
|
33
|
-
y|x &\\sim DirMult(\\phi, \\bar{y}) \\\\
|
34
|
-
\\log(\\phi) &= \\alpha + x \\beta \\\\
|
35
|
-
\\alpha_k &\\sim N(0, 5) \\quad &\\forall k \\in [K] \\\\
|
36
|
-
\\beta_{m, \\hat{k}} &= 0 &\\forall m \\in [M]\\\\
|
37
|
-
\\beta_{m, k} &= \\tau_{m, k} \\tilde{\\beta}_{m, k} \\quad &\\forall m \\in [M], k \\in \\{[K] \\smallsetminus \\hat{k}\\} \\\\
|
38
|
-
\\tau_{m, k} &= \\frac{\\exp(t_{m, k})}{1+ \\exp(t_{m, k})} \\quad &\\forall m \\in [M], k \\in \\{[K] \\smallsetminus \\hat{k}\\} \\\\
|
39
|
-
\\frac{t_{m, k}}{50} &\\sim N(0, 1) \\quad &\\forall m \\in [M], k \\in \\{[K] \\smallsetminus \\hat{k}\\} \\\\
|
40
|
-
\\tilde{\\beta}_{m, k} &= \\sigma_m^2 \\cdot \\gamma_{m, k} \\quad &\\forall m \\in [M], k \\in \\{[K] \\smallsetminus \\hat{k}\\} \\\\
|
41
|
-
\\sigma_m^2 &\\sim HC(0, 1) \\quad &\\forall m \\in [M] \\\\
|
42
|
-
\\gamma_{m, k} &\\sim N(0,1) \\quad &\\forall m \\in [M], k \\in \\{[K] \\smallsetminus \\hat{k}\\} \\\\
|
43
|
-
|
44
|
-
with y being the cell counts and x the covariates.
|
45
|
-
|
46
27
|
For further information, see `scCODA is a Bayesian model for compositional single-cell data analysis`
|
47
28
|
(Büttner, Ostner et al., NatComms, 2021)
|
48
29
|
"""
|
pertpy/tools/_coda/_tasccoda.py
CHANGED
@@ -33,24 +33,6 @@ config.update("jax_enable_x64", True)
|
|
33
33
|
class Tasccoda(CompositionalModel2):
|
34
34
|
r"""Statistical model for tree-aggregated differential composition analysis (tascCODA, Ostner et al., 2021).
|
35
35
|
|
36
|
-
The hierarchical formulation of the model for one sample is:
|
37
|
-
|
38
|
-
.. math::
|
39
|
-
\\begin{align*}
|
40
|
-
Y_i &\\sim \\textrm{DirMult}(\\bar{Y}_i, \\textbf{a}(\\textbf{x})_i)\\\\
|
41
|
-
\\log(\\textbf{a}(X))_i &= \\alpha + X_{i, \\cdot} \\beta\\\\
|
42
|
-
\\alpha_j &\\sim \\mathcal{N}(0, 10) & \\forall j\\in[p]\\\\
|
43
|
-
\\beta &= \\hat{\\beta} A^T \\\\
|
44
|
-
\\hat{\\beta}_{l, k} &= 0 & \\forall k \\in \\hat{v}, l \\in [d]\\\\
|
45
|
-
\\hat{\\beta}_{l, k} &= \\theta \\tilde{\\beta}_{1, l, k} + (1- \\theta) \\tilde{\\beta}_{0, l, k} \\quad & \\forall k\\in\\{[v] \\smallsetminus \\hat{v}\\}, l \\in [d]\\\\
|
46
|
-
\\tilde{\\beta}_{m, l, k} &= \\sigma_{m, l, k} * b_{m, l, k} \\quad & \\forall k\\in\\{[v] \\smallsetminus \\hat{v}\\}, m \\in \\{0, 1\\}, l \\in [d]\\\\
|
47
|
-
\\sigma_{m, l, k} &\\sim \\textrm{Exp}(\\lambda_{m, l, k}^2/2) \\quad & \\forall k\\in\\{[v] \\smallsetminus \\hat{v}\\}, l \\in \\{0, 1\\}, l \\in [d]\\\\
|
48
|
-
b_{m, l, k} &\\sim N(0,1) \\quad & \\forall k\\in\\{[v] \\smallsetminus \\hat{v}\\}, l \\in \\{0, 1\\}, l \\in [d]\\\\
|
49
|
-
\\theta &\\sim \\textrm{Beta}(1, \\frac{1}{|\\{[v] \\smallsetminus \\hat{v}\\}|})
|
50
|
-
\\end{align*}
|
51
|
-
|
52
|
-
with Y being the cell counts, X the covariates, and v the set of nodes of the underlying tree structure.
|
53
|
-
|
54
36
|
For further information, see `tascCODA: Bayesian Tree-Aggregated Analysis of Compositional Amplicon and Single-Cell Data`
|
55
37
|
(Ostner et al., 2021)
|
56
38
|
"""
|
@@ -75,11 +57,14 @@ class Tasccoda(CompositionalModel2):
|
|
75
57
|
modality_key_1: str = "rna",
|
76
58
|
modality_key_2: str = "coda",
|
77
59
|
) -> MuData:
|
78
|
-
"""Prepare a MuData object for subsequent processing.
|
60
|
+
"""Prepare a MuData object for subsequent processing.
|
61
|
+
|
62
|
+
If type is "cell_level", then create a compositional analysis dataset from the input adata.
|
63
|
+
If type is "sample_level", generate ete tree for tascCODA models from dendrogram information or cell-level observations.
|
79
64
|
|
80
|
-
When using
|
65
|
+
When using `type="cell_level"`, `adata` needs to have a column in `adata.obs` that contains the cell type assignment.
|
81
66
|
Further, it must contain one column or a set of columns (e.g. subject id, treatment, disease status) that uniquely identify each (statistical) sample.
|
82
|
-
Further covariates (e.g. subject age) can either be specified via addidional column names in
|
67
|
+
Further covariates (e.g. subject age) can either be specified via addidional column names in `adata.obs`, a key in `adata.uns`, or as a separate DataFrame.
|
83
68
|
|
84
69
|
Args:
|
85
70
|
adata: AnnData object.
|
@@ -90,10 +75,13 @@ class Tasccoda(CompositionalModel2):
|
|
90
75
|
covariate_obs: If type is "cell_level", specify list of keys for adata.obs, where covariate values are stored.
|
91
76
|
covariate_df: If type is "cell_level", specify dataFrame with covariates.
|
92
77
|
dendrogram_key: Key to the scanpy.tl.dendrogram result in `.uns` of original cell level anndata object.
|
93
|
-
levels_orig: List that indicates which columns in `.obs` of the original data correspond to tree levels.
|
94
|
-
|
78
|
+
levels_orig: List that indicates which columns in `.obs` of the original data correspond to tree levels.
|
79
|
+
The list must begin with the root level, and end with the leaf level.
|
80
|
+
levels_agg: List that indicates which columns in `.var` of the aggregated data correspond to tree levels.
|
81
|
+
The list must begin with the root level, and end with the leaf level.
|
95
82
|
add_level_name: If True, internal nodes in the tree will be named as "{level_name}_{node_name}" instead of just {level_name}.
|
96
|
-
key_added: If not specified, the tree is stored in
|
83
|
+
key_added: If not specified, the tree is stored in `.uns['tree']`.
|
84
|
+
If `data` is AnnData, save tree in `data`. If `data` is MuData, save tree in data[modality_2].
|
97
85
|
modality_key_1: Key to the cell-level AnnData in the MuData object.
|
98
86
|
modality_key_2: Key to the aggregated sample-level AnnData object in the MuData object.
|
99
87
|
|
pertpy/tools/_mixscape.py
CHANGED
@@ -177,7 +177,7 @@ class Mixscape:
|
|
177
177
|
def mixscape(
|
178
178
|
self,
|
179
179
|
adata: AnnData,
|
180
|
-
|
180
|
+
pert_key: str,
|
181
181
|
control: str,
|
182
182
|
*,
|
183
183
|
new_class_name: str | None = "mixscape_class",
|
@@ -201,12 +201,12 @@ class Mixscape:
|
|
201
201
|
|
202
202
|
Args:
|
203
203
|
adata: The annotated data object.
|
204
|
-
|
204
|
+
pert_key: The column of `.obs` with target gene labels.
|
205
205
|
control: Control category from the `labels` column.
|
206
206
|
new_class_name: Name of mixscape classification to be stored in `.obs`.
|
207
207
|
layer: Key from adata.layers whose value will be used to perform tests on. Default is using `.layers["X_pert"]`.
|
208
208
|
min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
|
209
|
-
logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells
|
209
|
+
logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells.
|
210
210
|
de_layer: Layer to use for identifying differentially expressed genes. If `None`, adata.X is used.
|
211
211
|
test_method: Method to use for differential expression testing.
|
212
212
|
iter_num: Number of normalmixEM iterations to run if convergence does not occur.
|
@@ -256,7 +256,7 @@ class Mixscape:
|
|
256
256
|
adata=adata,
|
257
257
|
split_masks=split_masks,
|
258
258
|
categories=categories,
|
259
|
-
|
259
|
+
pert_key=pert_key,
|
260
260
|
control=control,
|
261
261
|
layer=de_layer,
|
262
262
|
pval_cutoff=pval_cutoff,
|
@@ -278,7 +278,7 @@ class Mixscape:
|
|
278
278
|
|
279
279
|
# initialize return variables
|
280
280
|
adata.obs[f"{new_class_name}_p_{perturbation_type.lower()}"] = 0
|
281
|
-
adata.obs[new_class_name] = adata.obs[
|
281
|
+
adata.obs[new_class_name] = adata.obs[pert_key].astype(str)
|
282
282
|
adata.obs[f"{new_class_name}_global"] = np.empty(
|
283
283
|
[
|
284
284
|
adata.n_obs,
|
@@ -290,12 +290,12 @@ class Mixscape:
|
|
290
290
|
adata.obs[f"{new_class_name}_p_{perturbation_type.lower()}"] = 0.0
|
291
291
|
for split, split_mask in enumerate(split_masks):
|
292
292
|
category = categories[split]
|
293
|
-
gene_targets = list(set(adata[split_mask].obs[
|
293
|
+
gene_targets = list(set(adata[split_mask].obs[pert_key]).difference([control]))
|
294
294
|
for gene in gene_targets:
|
295
295
|
post_prob = 0
|
296
|
-
orig_guide_cells = (adata.obs[
|
296
|
+
orig_guide_cells = (adata.obs[pert_key] == gene) & split_mask
|
297
297
|
orig_guide_cells_index = list(orig_guide_cells.index[orig_guide_cells])
|
298
|
-
nt_cells = (adata.obs[
|
298
|
+
nt_cells = (adata.obs[pert_key] == control) & split_mask
|
299
299
|
all_cells = orig_guide_cells | nt_cells
|
300
300
|
|
301
301
|
if len(perturbation_markers[(category, gene)]) == 0:
|
@@ -307,7 +307,11 @@ class Mixscape:
|
|
307
307
|
|
308
308
|
dat = X[np.asarray(all_cells)][:, de_genes_indices]
|
309
309
|
if scale:
|
310
|
-
|
310
|
+
with warnings.catch_warnings():
|
311
|
+
warnings.filterwarnings(
|
312
|
+
"ignore", message="zero-centering a sparse array/matrix densifies it."
|
313
|
+
)
|
314
|
+
dat = sc.pp.scale(dat)
|
311
315
|
|
312
316
|
converged = False
|
313
317
|
n_iter = 0
|
@@ -335,10 +339,10 @@ class Mixscape:
|
|
335
339
|
pvec = pd.Series(np.asarray(pvec).flatten(), index=list(all_cells.index[all_cells]))
|
336
340
|
|
337
341
|
if n_iter == 0:
|
338
|
-
gv = pd.DataFrame(columns=["pvec",
|
342
|
+
gv = pd.DataFrame(columns=["pvec", pert_key])
|
339
343
|
gv["pvec"] = pvec
|
340
|
-
gv[
|
341
|
-
gv.loc[guide_cells,
|
344
|
+
gv[pert_key] = control
|
345
|
+
gv.loc[guide_cells, pert_key] = gene
|
342
346
|
if gene not in gv_list:
|
343
347
|
gv_list[gene] = {}
|
344
348
|
gv_list[gene][category] = gv
|
@@ -389,7 +393,7 @@ class Mixscape:
|
|
389
393
|
def lda(
|
390
394
|
self,
|
391
395
|
adata: AnnData,
|
392
|
-
|
396
|
+
pert_key: str,
|
393
397
|
control: str,
|
394
398
|
*,
|
395
399
|
mixscape_class_global: str | None = "mixscape_class_global",
|
@@ -407,7 +411,7 @@ class Mixscape:
|
|
407
411
|
|
408
412
|
Args:
|
409
413
|
adata: The annotated data object.
|
410
|
-
|
414
|
+
pert_key: The column of `.obs` with target gene labels.
|
411
415
|
control: Control category from the `pert_key` column.
|
412
416
|
mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
|
413
417
|
layer: Layer to use for identifying differentially expressed genes. If `None`, adata.X is used.
|
@@ -456,7 +460,7 @@ class Mixscape:
|
|
456
460
|
adata=adata,
|
457
461
|
split_masks=split_masks,
|
458
462
|
categories=categories,
|
459
|
-
|
463
|
+
pert_key=pert_key,
|
460
464
|
control=control,
|
461
465
|
layer=layer,
|
462
466
|
pval_cutoff=pval_cutoff,
|
@@ -475,17 +479,19 @@ class Mixscape:
|
|
475
479
|
continue
|
476
480
|
else:
|
477
481
|
gene_subset = adata_subset[
|
478
|
-
(adata_subset.obs[
|
482
|
+
(adata_subset.obs[pert_key] == key[1]) | (adata_subset.obs[pert_key] == control)
|
479
483
|
].copy()
|
480
|
-
|
484
|
+
with warnings.catch_warnings():
|
485
|
+
warnings.simplefilter("ignore", UserWarning)
|
486
|
+
sc.pp.scale(gene_subset)
|
481
487
|
sc.tl.pca(gene_subset, n_comps=n_comps)
|
482
488
|
# project cells into PCA space of gene_subset
|
483
489
|
projected_pcs[key[1]] = np.asarray(np.dot(X, gene_subset.varm["PCs"]))
|
484
490
|
# concatenate all pcs into a single matrix.
|
485
491
|
projected_pcs_array = np.concatenate(list(projected_pcs.values()), axis=1)
|
486
492
|
|
487
|
-
clf = LinearDiscriminantAnalysis(n_components=len(np.unique(adata_subset.obs[
|
488
|
-
clf.fit(projected_pcs_array, adata_subset.obs[
|
493
|
+
clf = LinearDiscriminantAnalysis(n_components=len(np.unique(adata_subset.obs[pert_key])) - 1)
|
494
|
+
clf.fit(projected_pcs_array, adata_subset.obs[pert_key])
|
489
495
|
cell_embeddings = clf.transform(projected_pcs_array)
|
490
496
|
adata.uns["mixscape_lda"] = cell_embeddings
|
491
497
|
|
@@ -495,9 +501,10 @@ class Mixscape:
|
|
495
501
|
def _get_perturbation_markers(
|
496
502
|
self,
|
497
503
|
adata: AnnData,
|
504
|
+
*,
|
498
505
|
split_masks: list[np.ndarray],
|
499
506
|
categories: list[str],
|
500
|
-
|
507
|
+
pert_key: str,
|
501
508
|
control: str,
|
502
509
|
layer: str,
|
503
510
|
pval_cutoff: float,
|
@@ -511,7 +518,7 @@ class Mixscape:
|
|
511
518
|
adata: :class:`~anndata.AnnData` object
|
512
519
|
split_masks: List of boolean masks for each split/group.
|
513
520
|
categories: List of split/group names.
|
514
|
-
|
521
|
+
pert_key: The column of `.obs` with target gene labels.
|
515
522
|
control: Control category from the `labels` column.
|
516
523
|
layer: Key from adata.layers whose value will be used to compare gene expression.
|
517
524
|
pval_cutoff: P-value cut-off for selection of significantly DE genes.
|
@@ -526,7 +533,7 @@ class Mixscape:
|
|
526
533
|
for split, split_mask in enumerate(split_masks):
|
527
534
|
category = categories[split]
|
528
535
|
# get gene sets for each split
|
529
|
-
gene_targets = list(set(adata[split_mask].obs[
|
536
|
+
gene_targets = list(set(adata[split_mask].obs[pert_key]).difference([control]))
|
530
537
|
adata_split = adata[split_mask].copy()
|
531
538
|
# find top DE genes between cells with targeting and non-targeting gRNAs
|
532
539
|
with warnings.catch_warnings():
|
@@ -535,7 +542,7 @@ class Mixscape:
|
|
535
542
|
sc.tl.rank_genes_groups(
|
536
543
|
adata_split,
|
537
544
|
layer=layer,
|
538
|
-
groupby=
|
545
|
+
groupby=pert_key,
|
539
546
|
groups=gene_targets,
|
540
547
|
reference=control,
|
541
548
|
method=test_method,
|
@@ -666,7 +673,7 @@ class Mixscape:
|
|
666
673
|
def plot_heatmap( # pragma: no cover # noqa: D417
|
667
674
|
self,
|
668
675
|
adata: AnnData,
|
669
|
-
|
676
|
+
pert_key: str,
|
670
677
|
target_gene: str,
|
671
678
|
control: str,
|
672
679
|
*,
|
@@ -682,7 +689,7 @@ class Mixscape:
|
|
682
689
|
|
683
690
|
Args:
|
684
691
|
adata: The annotated data object.
|
685
|
-
|
692
|
+
pert_key: The column of `.obs` with target gene labels.
|
686
693
|
target_gene: Target gene name to visualize heatmap for.
|
687
694
|
control: Control category from the `pert_key` column.
|
688
695
|
layer: Key from `adata.layers` whose value will be used to perform tests on.
|
@@ -711,12 +718,13 @@ class Mixscape:
|
|
711
718
|
"""
|
712
719
|
if "mixscape_class" not in adata.obs:
|
713
720
|
raise ValueError("Please run `pt.tl.mixscape` first.")
|
714
|
-
adata_subset = adata[(adata.obs[
|
721
|
+
adata_subset = adata[(adata.obs[pert_key] == target_gene) | (adata.obs[pert_key] == control)].copy()
|
715
722
|
with warnings.catch_warnings():
|
716
723
|
warnings.simplefilter("ignore", RuntimeWarning)
|
717
724
|
warnings.simplefilter("ignore", PerformanceWarning)
|
718
|
-
|
719
|
-
|
725
|
+
warnings.simplefilter("ignore", UserWarning)
|
726
|
+
sc.tl.rank_genes_groups(adata_subset, layer=layer, groupby=pert_key, method=method)
|
727
|
+
sc.pp.scale(adata_subset, max_value=vmax)
|
720
728
|
sc.pp.subsample(adata_subset, n_obs=subsample_number)
|
721
729
|
|
722
730
|
fig = sc.pl.rank_genes_groups_heatmap(
|
@@ -739,7 +747,7 @@ class Mixscape:
|
|
739
747
|
def plot_perturbscore( # pragma: no cover # noqa: D417
|
740
748
|
self,
|
741
749
|
adata: AnnData,
|
742
|
-
|
750
|
+
pert_key: str,
|
743
751
|
target_gene: str,
|
744
752
|
*,
|
745
753
|
mixscape_class: str = "mixscape_class",
|
@@ -758,7 +766,7 @@ class Mixscape:
|
|
758
766
|
|
759
767
|
Args:
|
760
768
|
adata: The annotated data object.
|
761
|
-
|
769
|
+
pert_key: The column of `.obs` with target gene labels.
|
762
770
|
target_gene: Target gene name to visualize perturbation scores for.
|
763
771
|
mixscape_class: The column of `.obs` with mixscape classifications.
|
764
772
|
color: Specify color of target gene class or knockout cell class. For control non-targeting and non-perturbed cells, colors are set to different shades of grey.
|
@@ -797,21 +805,21 @@ class Mixscape:
|
|
797
805
|
else:
|
798
806
|
perturbation_score = pd.concat([perturbation_score, perturbation_score_temp])
|
799
807
|
perturbation_score["mix"] = adata.obs[mixscape_class][perturbation_score.index]
|
800
|
-
gd = list(set(perturbation_score[
|
808
|
+
gd = list(set(perturbation_score[pert_key]).difference({target_gene}))[0]
|
801
809
|
|
802
810
|
# If before_mixscape is True, split densities based on original target gene classification
|
803
811
|
if before_mixscape is True:
|
804
812
|
palette = {gd: "#7d7d7d", target_gene: color}
|
805
|
-
plot_dens = sns.kdeplot(data=perturbation_score, x="pvec", hue=
|
813
|
+
plot_dens = sns.kdeplot(data=perturbation_score, x="pvec", hue=pert_key, fill=False, common_norm=False)
|
806
814
|
top_r = max(plot_dens.get_lines()[cond].get_data()[1].max() for cond in range(len(plot_dens.get_lines())))
|
807
815
|
plt.close()
|
808
816
|
perturbation_score["y_jitter"] = perturbation_score["pvec"]
|
809
817
|
rng = np.random.default_rng()
|
810
|
-
perturbation_score.loc[perturbation_score[
|
811
|
-
low=0.001, high=top_r / 10, size=sum(perturbation_score[
|
818
|
+
perturbation_score.loc[perturbation_score[pert_key] == gd, "y_jitter"] = rng.uniform(
|
819
|
+
low=0.001, high=top_r / 10, size=sum(perturbation_score[pert_key] == gd)
|
812
820
|
)
|
813
|
-
perturbation_score.loc[perturbation_score[
|
814
|
-
low=-top_r / 10, high=0, size=sum(perturbation_score[
|
821
|
+
perturbation_score.loc[perturbation_score[pert_key] == target_gene, "y_jitter"] = rng.uniform(
|
822
|
+
low=-top_r / 10, high=0, size=sum(perturbation_score[pert_key] == target_gene)
|
815
823
|
)
|
816
824
|
# If split_by is provided, split densities based on the split_by
|
817
825
|
if split_by is not None:
|
@@ -844,7 +852,7 @@ class Mixscape:
|
|
844
852
|
else:
|
845
853
|
if palette is None:
|
846
854
|
palette = {gd: "#7d7d7d", f"{target_gene} NP": "#c9c9c9", f"{target_gene} {perturbation_type}": color}
|
847
|
-
plot_dens = sns.kdeplot(data=perturbation_score, x="pvec", hue=
|
855
|
+
plot_dens = sns.kdeplot(data=perturbation_score, x="pvec", hue=pert_key, fill=False, common_norm=False)
|
848
856
|
top_r = max(plot_dens.get_lines()[i].get_data()[1].max() for i in range(len(plot_dens.get_lines())))
|
849
857
|
plt.close()
|
850
858
|
perturbation_score["y_jitter"] = perturbation_score["pvec"]
|
@@ -899,6 +907,7 @@ class Mixscape:
|
|
899
907
|
if return_fig:
|
900
908
|
return plt.gcf()
|
901
909
|
plt.show()
|
910
|
+
|
902
911
|
return None
|
903
912
|
|
904
913
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
@@ -1058,7 +1067,7 @@ class Mixscape:
|
|
1058
1067
|
data=obs_tidy,
|
1059
1068
|
order=order,
|
1060
1069
|
orient="vertical",
|
1061
|
-
|
1070
|
+
density_norm=scale,
|
1062
1071
|
ax=ax,
|
1063
1072
|
hue=hue,
|
1064
1073
|
**kwargs,
|
@@ -1072,7 +1081,7 @@ class Mixscape:
|
|
1072
1081
|
data=obs_tidy,
|
1073
1082
|
order=order,
|
1074
1083
|
jitter=jitter,
|
1075
|
-
|
1084
|
+
palette="dark:black",
|
1076
1085
|
size=size,
|
1077
1086
|
ax=ax,
|
1078
1087
|
hue=hue,
|
@@ -22,7 +22,7 @@ class PerturbationComparison:
|
|
22
22
|
) -> float:
|
23
23
|
"""Compare classification accuracy between real and simulated perturbations.
|
24
24
|
|
25
|
-
Trains a classifier on the real perturbation data
|
25
|
+
Trains a classifier on the real perturbation data & the control data and reports a normalized
|
26
26
|
classification accuracy on the simulated perturbation.
|
27
27
|
|
28
28
|
Args:
|
@@ -64,8 +64,8 @@ class PerturbationComparison:
|
|
64
64
|
real: Real perturbed data.
|
65
65
|
simulated: Simulated perturbed data.
|
66
66
|
control: Control data
|
67
|
-
use_simulated_for_knn: Include simulted perturbed data (`simulated`) into the knn graph.
|
68
|
-
control (`control`) is provided.
|
67
|
+
use_simulated_for_knn: Include simulted perturbed data (`simulated`) into the knn graph.
|
68
|
+
Only valid when control (`control`) is provided.
|
69
69
|
n_neighbors: Number of neighbors to use in k-neighbor graph.
|
70
70
|
random_state: Random state used for k-neighbor graph construction.
|
71
71
|
n_jobs: Number of cores to use. Defaults to -1 (all).
|