pertpy 0.11.4__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +4 -1
- pertpy/tools/_coda/_base_coda.py +1 -1
- pertpy/tools/_distances/_distances.py +3 -2
- pertpy/tools/_milo.py +138 -51
- pertpy/tools/_mixscape.py +42 -39
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +1 -1
- pertpy/tools/_perturbation_space/_perturbation_space.py +1 -1
- pertpy/tools/_scgen/_scgen.py +2 -1
- {pertpy-0.11.4.dist-info → pertpy-1.0.0.dist-info}/METADATA +15 -2
- {pertpy-0.11.4.dist-info → pertpy-1.0.0.dist-info}/RECORD +12 -12
- {pertpy-0.11.4.dist-info → pertpy-1.0.0.dist-info}/WHEEL +0 -0
- {pertpy-0.11.4.dist-info → pertpy-1.0.0.dist-info}/licenses/LICENSE +0 -0
pertpy/__init__.py
CHANGED
@@ -2,10 +2,11 @@
|
|
2
2
|
|
3
3
|
__author__ = "Lukas Heumos"
|
4
4
|
__email__ = "lukas.heumos@posteo.net"
|
5
|
-
__version__ = "0.
|
5
|
+
__version__ = "1.0.0"
|
6
6
|
|
7
7
|
import warnings
|
8
8
|
|
9
|
+
from anndata._core.aligned_df import ImplicitModificationWarning
|
9
10
|
from matplotlib import MatplotlibDeprecationWarning
|
10
11
|
from numba import NumbaDeprecationWarning
|
11
12
|
|
@@ -13,6 +14,8 @@ warnings.filterwarnings("ignore", category=NumbaDeprecationWarning)
|
|
13
14
|
warnings.filterwarnings("ignore", category=MatplotlibDeprecationWarning)
|
14
15
|
warnings.filterwarnings("ignore", category=SyntaxWarning)
|
15
16
|
warnings.filterwarnings("ignore", category=UserWarning, module="scvi._settings")
|
17
|
+
warnings.filterwarnings("ignore", message="Environment variable.*redefined by R")
|
18
|
+
warnings.filterwarnings("ignore", message="Transforming to str index.", category=ImplicitModificationWarning)
|
16
19
|
|
17
20
|
import mudata
|
18
21
|
|
pertpy/tools/_coda/_base_coda.py
CHANGED
@@ -1538,7 +1538,7 @@ class CompositionalModel2(ABC):
|
|
1538
1538
|
if isinstance(data, MuData):
|
1539
1539
|
data = data[modality_key]
|
1540
1540
|
if isinstance(palette, Colormap):
|
1541
|
-
palette = palette(range(
|
1541
|
+
palette = list(palette(range(len(data.obs[feature_name].unique()))))
|
1542
1542
|
|
1543
1543
|
# y scale transformations
|
1544
1544
|
if y_scale == "relative":
|
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|
3
3
|
from abc import ABC, abstractmethod
|
4
4
|
from typing import TYPE_CHECKING, Literal, NamedTuple
|
5
5
|
|
6
|
+
import jax
|
6
7
|
import numpy as np
|
7
8
|
import pandas as pd
|
8
9
|
from numba import jit
|
@@ -685,6 +686,7 @@ class WassersteinDistance(AbstractDistance):
|
|
685
686
|
def __init__(self) -> None:
|
686
687
|
super().__init__()
|
687
688
|
self.accepts_precomputed = False
|
689
|
+
self.solver = jax.jit(Sinkhorn())
|
688
690
|
|
689
691
|
def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
|
690
692
|
X = np.asarray(X, dtype=np.float64)
|
@@ -699,8 +701,7 @@ class WassersteinDistance(AbstractDistance):
|
|
699
701
|
|
700
702
|
def solve_ot_problem(self, geom: Geometry, **kwargs):
|
701
703
|
ot_prob = LinearProblem(geom)
|
702
|
-
|
703
|
-
ot = solver(ot_prob, **kwargs)
|
704
|
+
ot = self.solver(ot_prob, **kwargs)
|
704
705
|
cost = float(ot.reg_ot_cost)
|
705
706
|
|
706
707
|
# Check for NaN or invalid cost
|
pertpy/tools/_milo.py
CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
import random
|
4
4
|
import re
|
5
|
+
from importlib.util import find_spec
|
5
6
|
from typing import TYPE_CHECKING, Literal
|
6
7
|
|
7
8
|
import matplotlib.pyplot as plt
|
@@ -29,18 +30,6 @@ from sklearn.metrics.pairwise import euclidean_distances
|
|
29
30
|
class Milo:
|
30
31
|
"""Python implementation of Milo."""
|
31
32
|
|
32
|
-
def __init__(self):
|
33
|
-
try:
|
34
|
-
from rpy2.robjects import conversion, numpy2ri, pandas2ri
|
35
|
-
from rpy2.robjects.packages import STAP, PackageNotInstalledError, importr
|
36
|
-
except ModuleNotFoundError:
|
37
|
-
raise ImportError("milo requires rpy2 to be installed.") from None
|
38
|
-
|
39
|
-
try:
|
40
|
-
importr("edgeR")
|
41
|
-
except ImportError as e:
|
42
|
-
raise ImportError("milo requires a valid R installation with edger installed:\n") from e
|
43
|
-
|
44
33
|
def load(
|
45
34
|
self,
|
46
35
|
input: AnnData,
|
@@ -266,7 +255,7 @@ class Milo:
|
|
266
255
|
subset_samples: list[str] | None = None,
|
267
256
|
add_intercept: bool = True,
|
268
257
|
feature_key: str | None = "rna",
|
269
|
-
solver: Literal["edger", "
|
258
|
+
solver: Literal["edger", "pydeseq2"] = "edger",
|
270
259
|
):
|
271
260
|
"""Performs differential abundance testing on neighbourhoods using QLF test implementation as implemented in edgeR.
|
272
261
|
|
@@ -279,7 +268,9 @@ class Milo:
|
|
279
268
|
subset_samples: subset of samples (obs in `milo_mdata['milo']`) to use for the test.
|
280
269
|
add_intercept: whether to include an intercept in the model. If False, this is equivalent to adding + 0 in the design formula. When model_contrasts is specified, this is set to False by default.
|
281
270
|
feature_key: If input data is MuData, specify key to cell-level AnnData object.
|
282
|
-
solver: The solver to fit the model to.
|
271
|
+
solver: The solver to fit the model to.
|
272
|
+
The "edger" solver requires R, rpy2 and edgeR to be installed and is the closest to the R implementation.
|
273
|
+
The "pydeseq2" requires pydeseq2 to be installed. It is still very comparable to the "edger" solver but might be a bit slower.
|
283
274
|
|
284
275
|
Returns:
|
285
276
|
None, modifies `milo_mdata['milo']` in place, adding the results of the DA test to `.var`:
|
@@ -298,7 +289,6 @@ class Milo:
|
|
298
289
|
>>> milo.make_nhoods(mdata["rna"])
|
299
290
|
>>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident")
|
300
291
|
>>> milo.da_nhoods(mdata, design="~label")
|
301
|
-
|
302
292
|
"""
|
303
293
|
try:
|
304
294
|
sample_adata = mdata["milo"]
|
@@ -364,19 +354,32 @@ class Milo:
|
|
364
354
|
# Set up rpy2 to run edgeR
|
365
355
|
edgeR, limma, stats, base = self._setup_rpy2()
|
366
356
|
|
357
|
+
import rpy2.robjects as ro
|
358
|
+
from rpy2.robjects import numpy2ri, pandas2ri
|
359
|
+
from rpy2.robjects.conversion import localconverter
|
360
|
+
from rpy2.robjects.vectors import FloatVector
|
361
|
+
|
367
362
|
# Define model matrix
|
368
363
|
if not add_intercept or model_contrasts is not None:
|
369
364
|
design = design + " + 0"
|
370
|
-
|
365
|
+
design_df = design_df.astype(dict.fromkeys(design_df.select_dtypes(exclude=["number"]).columns, "category"))
|
366
|
+
with localconverter(ro.default_converter + pandas2ri.converter):
|
367
|
+
design_r = pandas2ri.py2rpy(design_df)
|
368
|
+
formula_r = stats.formula(design)
|
369
|
+
model = stats.model_matrix(object=formula_r, data=design_r)
|
371
370
|
|
372
371
|
# Fit NB-GLM
|
373
|
-
|
372
|
+
counts_filtered = count_mat[np.ix_(keep_nhoods, keep_smp)]
|
373
|
+
lib_size_filtered = lib_size[keep_smp]
|
374
|
+
count_mat_r = numpy2ri.py2rpy(counts_filtered)
|
375
|
+
lib_size_r = FloatVector(lib_size_filtered)
|
376
|
+
dge = edgeR.DGEList(counts=count_mat_r, lib_size=lib_size_r)
|
374
377
|
dge = edgeR.calcNormFactors(dge, method="TMM")
|
375
378
|
dge = edgeR.estimateDisp(dge, model)
|
376
379
|
fit = edgeR.glmQLFit(dge, model, robust=True)
|
377
|
-
|
378
380
|
# Test
|
379
|
-
|
381
|
+
model_np = np.array(model)
|
382
|
+
n_coef = model_np.shape[1]
|
380
383
|
if model_contrasts is not None:
|
381
384
|
r_str = """
|
382
385
|
get_model_cols <- function(design_df, design){
|
@@ -387,34 +390,90 @@ class Milo:
|
|
387
390
|
from rpy2.robjects.packages import STAP
|
388
391
|
|
389
392
|
get_model_cols = STAP(r_str, "get_model_cols")
|
390
|
-
|
391
|
-
|
393
|
+
with localconverter(ro.default_converter + numpy2ri.converter + pandas2ri.converter):
|
394
|
+
model_mat_cols = get_model_cols.get_model_cols(design_df, design)
|
395
|
+
with localconverter(ro.default_converter + pandas2ri.converter + numpy2ri.converter):
|
396
|
+
model_df = pandas2ri.rpy2py(model)
|
397
|
+
model_df = pd.DataFrame(model_df)
|
392
398
|
model_df.columns = model_mat_cols
|
393
399
|
try:
|
394
|
-
|
400
|
+
with localconverter(ro.default_converter + pandas2ri.converter):
|
401
|
+
mod_contrast = limma.makeContrasts(contrasts=model_contrasts, levels=model_df)
|
395
402
|
except ValueError:
|
396
403
|
logger.error("Model contrasts must be in the form 'A-B' or 'A+B'")
|
397
404
|
raise
|
398
|
-
|
399
|
-
|
400
|
-
|
405
|
+
with localconverter(ro.default_converter + pandas2ri.converter + numpy2ri.converter):
|
406
|
+
res = base.as_data_frame(
|
407
|
+
edgeR.topTags(edgeR.glmQLFTest(fit, contrast=mod_contrast), sort_by="none", n=np.inf)
|
408
|
+
)
|
401
409
|
else:
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
res = conversion.rpy2py(res)
|
410
|
+
with localconverter(ro.default_converter + numpy2ri.converter + pandas2ri.converter):
|
411
|
+
res = base.as_data_frame(
|
412
|
+
edgeR.topTags(edgeR.glmQLFTest(fit, coef=n_coef), sort_by="none", n=np.inf)
|
413
|
+
)
|
407
414
|
if not isinstance(res, pd.DataFrame):
|
408
415
|
res = pd.DataFrame(res)
|
416
|
+
# The columns of res looks like e.g. table.A, table.B, so remove the prefix
|
417
|
+
res.columns = [col.replace("table.", "") for col in res.columns]
|
418
|
+
elif solver == "pydeseq2":
|
419
|
+
if find_spec("pydeseq2") is None:
|
420
|
+
raise ImportError("pydeseq2 is required but not installed. Install with: pip install pydeseq2")
|
421
|
+
|
422
|
+
from pydeseq2.dds import DeseqDataSet
|
423
|
+
from pydeseq2.ds import DeseqStats
|
424
|
+
|
425
|
+
counts_filtered = count_mat[np.ix_(keep_nhoods, keep_smp)]
|
426
|
+
design_df_filtered = design_df.iloc[keep_smp].copy()
|
427
|
+
|
428
|
+
design_df_filtered = design_df_filtered.astype(
|
429
|
+
dict.fromkeys(design_df_filtered.select_dtypes(exclude=["number"]).columns, "category")
|
430
|
+
)
|
431
|
+
|
432
|
+
design_clean = design if design.startswith("~") else f"~{design}"
|
433
|
+
|
434
|
+
dds = DeseqDataSet(
|
435
|
+
counts=pd.DataFrame(counts_filtered.T, index=design_df_filtered.index),
|
436
|
+
metadata=design_df_filtered,
|
437
|
+
design=design_clean,
|
438
|
+
refit_cooks=True,
|
439
|
+
)
|
440
|
+
|
441
|
+
dds.deseq2()
|
442
|
+
|
443
|
+
if model_contrasts is not None and "-" in model_contrasts:
|
444
|
+
if "(" in model_contrasts or "+" in model_contrasts.split("-")[1]:
|
445
|
+
raise ValueError(
|
446
|
+
f"Complex contrasts like '{model_contrasts}' are not supported by pydeseq2. "
|
447
|
+
"Use simple pairwise contrasts (e.g., 'GroupA-GroupB') or switch to solver='edger'."
|
448
|
+
)
|
449
|
+
|
450
|
+
parts = model_contrasts.split("-")
|
451
|
+
factor_name = design_clean.replace("~", "").split("+")[-1].strip()
|
452
|
+
group1 = parts[0].replace(factor_name, "").strip()
|
453
|
+
group2 = parts[1].replace(factor_name, "").strip()
|
454
|
+
stat_res = DeseqStats(dds, contrast=[factor_name, group1, group2])
|
455
|
+
else:
|
456
|
+
factor_name = design_clean.replace("~", "").split("+")[-1].strip()
|
457
|
+
if not isinstance(design_df_filtered[factor_name], pd.CategoricalDtype):
|
458
|
+
design_df_filtered[factor_name] = design_df_filtered[factor_name].astype("category")
|
459
|
+
categories = design_df_filtered[factor_name].cat.categories
|
460
|
+
stat_res = DeseqStats(dds, contrast=[factor_name, categories[-1], categories[0]])
|
461
|
+
|
462
|
+
stat_res.summary()
|
463
|
+
res = stat_res.results_df
|
464
|
+
|
465
|
+
res = res.rename(
|
466
|
+
columns={"baseMean": "logCPM", "log2FoldChange": "logFC", "pvalue": "PValue", "padj": "FDR"}
|
467
|
+
)
|
468
|
+
|
469
|
+
res = res[["logCPM", "logFC", "PValue", "FDR"]]
|
409
470
|
|
410
|
-
# Save outputs
|
411
471
|
res.index = sample_adata.var_names[keep_nhoods] # type: ignore
|
412
472
|
if any(col in sample_adata.var.columns for col in res.columns):
|
413
473
|
sample_adata.var = sample_adata.var.drop(res.columns, axis=1)
|
414
474
|
sample_adata.var = pd.concat([sample_adata.var, res], axis=1)
|
415
475
|
|
416
|
-
|
417
|
-
self._graph_spatial_fdr(sample_adata, neighbors_key=adata.uns["nhood_neighbors_key"])
|
476
|
+
self._graph_spatial_fdr(sample_adata)
|
418
477
|
|
419
478
|
def annotate_nhoods(
|
420
479
|
self,
|
@@ -657,11 +716,19 @@ class Milo:
|
|
657
716
|
self,
|
658
717
|
):
|
659
718
|
"""Set up rpy2 to run edgeR."""
|
660
|
-
|
719
|
+
try:
|
720
|
+
from rpy2.robjects import conversion, numpy2ri, pandas2ri
|
721
|
+
from rpy2.robjects.packages import STAP, PackageNotInstalledError, importr
|
722
|
+
except ModuleNotFoundError:
|
723
|
+
raise ImportError("milo requires rpy2 to be installed.") from None
|
724
|
+
|
725
|
+
try:
|
726
|
+
importr("edgeR")
|
727
|
+
except ImportError as e:
|
728
|
+
raise ImportError("milo requires a valid R installation with edger installed.") from e
|
729
|
+
|
661
730
|
from rpy2.robjects.packages import importr
|
662
731
|
|
663
|
-
numpy2ri.activate()
|
664
|
-
pandas2ri.activate()
|
665
732
|
edgeR = self._try_import_bioc_library("edgeR")
|
666
733
|
limma = self._try_import_bioc_library("limma")
|
667
734
|
stats = importr("stats")
|
@@ -671,26 +738,27 @@ class Milo:
|
|
671
738
|
|
672
739
|
def _try_import_bioc_library(
|
673
740
|
self,
|
674
|
-
|
741
|
+
r_package: str,
|
675
742
|
):
|
676
743
|
"""Import R packages.
|
677
744
|
|
678
745
|
Args:
|
679
|
-
|
746
|
+
r_package: R packages name
|
680
747
|
"""
|
681
748
|
from rpy2.robjects.packages import PackageNotInstalledError, importr
|
682
749
|
|
683
750
|
try:
|
684
|
-
_r_lib = importr(
|
751
|
+
_r_lib = importr(r_package)
|
685
752
|
return _r_lib
|
686
753
|
except PackageNotInstalledError:
|
687
|
-
logger.error(
|
754
|
+
logger.error(
|
755
|
+
f"Install Bioconductor library `{r_package!r}` first as `BiocManager::install({r_package!r}).`"
|
756
|
+
)
|
688
757
|
raise
|
689
758
|
|
690
759
|
def _graph_spatial_fdr(
|
691
760
|
self,
|
692
761
|
sample_adata: AnnData,
|
693
|
-
neighbors_key: str | None = None,
|
694
762
|
):
|
695
763
|
"""FDR correction weighted on inverse of connectivity of neighbourhoods.
|
696
764
|
|
@@ -698,7 +766,6 @@ class Milo:
|
|
698
766
|
|
699
767
|
Args:
|
700
768
|
sample_adata: Sample-level AnnData.
|
701
|
-
neighbors_key: The key in `adata.obsp` to use as KNN graph.
|
702
769
|
"""
|
703
770
|
# use 1/connectivity as the weighting for the weighted BH adjustment from Cydar
|
704
771
|
w = 1 / sample_adata.var["kth_distance"]
|
@@ -1007,6 +1074,8 @@ class Milo:
|
|
1007
1074
|
subset_nhoods: list[str] = None,
|
1008
1075
|
log_counts: bool = False,
|
1009
1076
|
return_fig: bool = False,
|
1077
|
+
ax=None,
|
1078
|
+
show: bool = True,
|
1010
1079
|
) -> Figure | None:
|
1011
1080
|
"""Plot boxplot of cell numbers vs condition of interest.
|
1012
1081
|
|
@@ -1036,18 +1105,36 @@ class Milo:
|
|
1036
1105
|
pl_df = pd.merge(pl_df, nhood_adata.var)
|
1037
1106
|
pl_df["log_n_cells"] = np.log1p(pl_df["n_cells"])
|
1038
1107
|
if not log_counts:
|
1039
|
-
sns.boxplot(data=pl_df, x=test_var, y="n_cells", color="lightblue")
|
1040
|
-
sns.stripplot(data=pl_df, x=test_var, y="n_cells", color="black", s=3)
|
1041
|
-
|
1108
|
+
sns.boxplot(data=pl_df, x=test_var, y="n_cells", color="lightblue", ax=ax)
|
1109
|
+
sns.stripplot(data=pl_df, x=test_var, y="n_cells", color="black", s=3, ax=ax)
|
1110
|
+
if ax:
|
1111
|
+
ax.set_ylabel("# cells")
|
1112
|
+
else:
|
1113
|
+
plt.ylabel("# cells")
|
1114
|
+
else:
|
1115
|
+
sns.boxplot(data=pl_df, x=test_var, y="log_n_cells", color="lightblue", ax=ax)
|
1116
|
+
sns.stripplot(data=pl_df, x=test_var, y="log_n_cells", color="black", s=3, ax=ax)
|
1117
|
+
if ax:
|
1118
|
+
ax.set_ylabel("log(# cells + 1)")
|
1119
|
+
else:
|
1120
|
+
plt.ylabel("log(# cells + 1)")
|
1121
|
+
|
1122
|
+
if ax:
|
1123
|
+
ax.tick_params(axis="x", rotation=90)
|
1124
|
+
ax.set_xlabel(test_var)
|
1042
1125
|
else:
|
1043
|
-
|
1044
|
-
|
1045
|
-
plt.ylabel("log(# cells + 1)")
|
1126
|
+
plt.xticks(rotation=90)
|
1127
|
+
plt.xlabel(test_var)
|
1046
1128
|
|
1047
|
-
|
1048
|
-
|
1129
|
+
if return_fig:
|
1130
|
+
return plt.gcf()
|
1131
|
+
|
1132
|
+
if ax is None:
|
1133
|
+
plt.show()
|
1049
1134
|
|
1050
1135
|
if return_fig:
|
1051
1136
|
return plt.gcf()
|
1052
|
-
|
1137
|
+
if show:
|
1138
|
+
plt.show()
|
1139
|
+
|
1053
1140
|
return None
|
pertpy/tools/_mixscape.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import copy
|
4
|
+
import warnings
|
4
5
|
from collections import OrderedDict
|
5
6
|
from typing import TYPE_CHECKING, Literal
|
6
7
|
|
@@ -10,11 +11,12 @@ import pandas as pd
|
|
10
11
|
import scanpy as sc
|
11
12
|
import seaborn as sns
|
12
13
|
from fast_array_utils.stats import mean, mean_var
|
14
|
+
from pandas.errors import PerformanceWarning
|
13
15
|
from scanpy import get
|
14
16
|
from scanpy._utils import _check_use_raw, sanitize_anndata
|
15
17
|
from scanpy.plotting import _utils
|
16
18
|
from scanpy.tools._utils import _choose_representation
|
17
|
-
from scipy.sparse import csr_matrix, spmatrix
|
19
|
+
from scipy.sparse import csr_matrix, issparse, spmatrix
|
18
20
|
from sklearn.mixture import GaussianMixture
|
19
21
|
|
20
22
|
from pertpy._doc import _doc_params, doc_common_plot_args
|
@@ -103,6 +105,9 @@ class Mixscape:
|
|
103
105
|
|
104
106
|
adata.layers["X_pert"] = adata.X.copy()
|
105
107
|
|
108
|
+
# Work with LIL for efficient indexing but don't store it in AnnData as LIL is not supported anymore
|
109
|
+
X_pert_lil = adata.layers["X_pert"].tolil() if issparse(adata.layers["X_pert"]) else adata.layers["X_pert"]
|
110
|
+
|
106
111
|
control_mask = adata.obs[pert_key] == control
|
107
112
|
|
108
113
|
if ref_selection_mode == "split_by":
|
@@ -110,9 +115,8 @@ class Mixscape:
|
|
110
115
|
split_mask = adata.obs[split_by] == split
|
111
116
|
control_mask_group = control_mask & split_mask
|
112
117
|
control_mean_expr = mean(adata.X[control_mask_group], axis=0)
|
113
|
-
|
114
|
-
np.repeat(control_mean_expr.reshape(1, -1), split_mask.sum(), axis=0)
|
115
|
-
- adata.layers["X_pert"][split_mask]
|
118
|
+
X_pert_lil[split_mask] = (
|
119
|
+
np.repeat(control_mean_expr.reshape(1, -1), split_mask.sum(), axis=0) - X_pert_lil[split_mask]
|
116
120
|
)
|
117
121
|
else:
|
118
122
|
if split_by is None:
|
@@ -129,49 +133,43 @@ class Mixscape:
|
|
129
133
|
|
130
134
|
for split_mask in split_masks:
|
131
135
|
control_mask_split = control_mask & split_mask
|
132
|
-
|
133
136
|
R_split = representation[split_mask]
|
134
137
|
R_control = representation[np.asarray(control_mask_split)]
|
135
|
-
|
136
138
|
eps = kwargs.pop("epsilon", 0.1)
|
137
139
|
nn_index = NNDescent(R_control, **kwargs)
|
138
140
|
indices, _ = nn_index.query(R_split, k=n_neighbors, epsilon=eps)
|
139
|
-
|
140
141
|
X_control = np.expm1(adata.X[np.asarray(control_mask_split)])
|
141
|
-
|
142
142
|
n_split = split_mask.sum()
|
143
143
|
n_control = X_control.shape[0]
|
144
144
|
|
145
145
|
if batch_size is None:
|
146
146
|
col_indices = np.ravel(indices)
|
147
147
|
row_indices = np.repeat(np.arange(n_split), n_neighbors)
|
148
|
-
|
149
148
|
neigh_matrix = csr_matrix(
|
150
149
|
(np.ones_like(col_indices, dtype=np.float64), (row_indices, col_indices)),
|
151
150
|
shape=(n_split, n_control),
|
152
151
|
)
|
153
152
|
neigh_matrix /= n_neighbors
|
154
|
-
|
155
|
-
sc.pp.log1p(neigh_matrix @ X_control) -
|
153
|
+
X_pert_lil[np.asarray(split_mask)] = (
|
154
|
+
sc.pp.log1p(neigh_matrix @ X_control) - X_pert_lil[np.asarray(split_mask)]
|
156
155
|
)
|
157
156
|
else:
|
158
157
|
split_indices = np.where(split_mask)[0]
|
159
158
|
for i in range(0, n_split, batch_size):
|
160
159
|
size = min(i + batch_size, n_split)
|
161
160
|
select = slice(i, size)
|
162
|
-
|
163
161
|
batch = np.ravel(indices[select])
|
164
162
|
split_batch = split_indices[select]
|
165
|
-
|
166
163
|
size = size - i
|
167
|
-
|
168
164
|
means_batch = X_control[batch]
|
169
165
|
batch_reshaped = means_batch.reshape(size, n_neighbors, -1)
|
170
166
|
means_batch, _ = mean_var(batch_reshaped, axis=1)
|
167
|
+
X_pert_lil[split_batch] = np.log1p(means_batch) - X_pert_lil[split_batch]
|
171
168
|
|
172
|
-
|
173
|
-
|
174
|
-
|
169
|
+
if issparse(X_pert_lil):
|
170
|
+
adata.layers["X_pert"] = X_pert_lil.tocsr()
|
171
|
+
else:
|
172
|
+
adata.layers["X_pert"] = X_pert_lil
|
175
173
|
|
176
174
|
if copy:
|
177
175
|
return adata
|
@@ -531,26 +529,29 @@ class Mixscape:
|
|
531
529
|
gene_targets = list(set(adata[split_mask].obs[labels]).difference([control]))
|
532
530
|
adata_split = adata[split_mask].copy()
|
533
531
|
# find top DE genes between cells with targeting and non-targeting gRNAs
|
534
|
-
|
535
|
-
|
536
|
-
|
537
|
-
|
538
|
-
|
539
|
-
|
540
|
-
|
541
|
-
|
542
|
-
|
543
|
-
|
544
|
-
|
545
|
-
logfc_threshold_mask = (
|
546
|
-
np.abs(adata_split.uns["rank_genes_groups"]["logfoldchanges"][gene]) >= logfc_threshold
|
532
|
+
with warnings.catch_warnings():
|
533
|
+
warnings.simplefilter("ignore", RuntimeWarning)
|
534
|
+
warnings.simplefilter("ignore", PerformanceWarning)
|
535
|
+
sc.tl.rank_genes_groups(
|
536
|
+
adata_split,
|
537
|
+
layer=layer,
|
538
|
+
groupby=labels,
|
539
|
+
groups=gene_targets,
|
540
|
+
reference=control,
|
541
|
+
method=test_method,
|
542
|
+
use_raw=False,
|
547
543
|
)
|
548
|
-
|
549
|
-
|
550
|
-
|
551
|
-
|
552
|
-
|
553
|
-
|
544
|
+
# get DE genes for each target gene
|
545
|
+
for gene in gene_targets:
|
546
|
+
logfc_threshold_mask = (
|
547
|
+
np.abs(adata_split.uns["rank_genes_groups"]["logfoldchanges"][gene]) >= logfc_threshold
|
548
|
+
)
|
549
|
+
de_genes = adata_split.uns["rank_genes_groups"]["names"][gene][logfc_threshold_mask]
|
550
|
+
pvals_adj = adata_split.uns["rank_genes_groups"]["pvals_adj"][gene][logfc_threshold_mask]
|
551
|
+
de_genes = de_genes[pvals_adj < pval_cutoff]
|
552
|
+
if len(de_genes) < min_de_genes:
|
553
|
+
de_genes = np.array([])
|
554
|
+
perturbation_markers[(category, gene)] = de_genes
|
554
555
|
|
555
556
|
return perturbation_markers
|
556
557
|
|
@@ -711,7 +712,10 @@ class Mixscape:
|
|
711
712
|
if "mixscape_class" not in adata.obs:
|
712
713
|
raise ValueError("Please run `pt.tl.mixscape` first.")
|
713
714
|
adata_subset = adata[(adata.obs[labels] == target_gene) | (adata.obs[labels] == control)].copy()
|
714
|
-
|
715
|
+
with warnings.catch_warnings():
|
716
|
+
warnings.simplefilter("ignore", RuntimeWarning)
|
717
|
+
warnings.simplefilter("ignore", PerformanceWarning)
|
718
|
+
sc.tl.rank_genes_groups(adata_subset, layer=layer, groupby=labels, method=method)
|
715
719
|
sc.pp.scale(adata_subset, max_value=vmax)
|
716
720
|
sc.pp.subsample(adata_subset, n_obs=subsample_number)
|
717
721
|
|
@@ -998,8 +1002,7 @@ class Mixscape:
|
|
998
1002
|
ys = keys
|
999
1003
|
|
1000
1004
|
if multi_panel and groupby is None and len(ys) == 1:
|
1001
|
-
# This is a quick and dirty way for adapting scales across several
|
1002
|
-
# keys if groupby is None.
|
1005
|
+
# This is a quick and dirty way for adapting scales across several keys if groupby is None.
|
1003
1006
|
y = ys[0]
|
1004
1007
|
|
1005
1008
|
g = sns.catplot(
|
@@ -226,7 +226,7 @@ class MLPClassifierSpace(PerturbationSpace):
|
|
226
226
|
# Fix class unbalance (likely to happen in perturbation datasets)
|
227
227
|
# Usually control cells are overrepresented such that predicting control all time would give good results
|
228
228
|
# Cells with rare perturbations are sampled more
|
229
|
-
train_weights = 1 / (1 + torch.sum(torch.tensor(train_dataset.labels), dim=1))
|
229
|
+
train_weights = 1 / (1 + torch.sum(torch.tensor(train_dataset.labels.to_list()), dim=1))
|
230
230
|
train_sampler = WeightedRandomSampler(train_weights, len(train_weights))
|
231
231
|
|
232
232
|
self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=4)
|
pertpy/tools/_scgen/_scgen.py
CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
from typing import TYPE_CHECKING, Any
|
4
4
|
|
5
|
+
import anndata as ad
|
5
6
|
import jax.numpy as jnp
|
6
7
|
import matplotlib.pyplot as plt
|
7
8
|
import numpy as np
|
@@ -248,7 +249,7 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
|
|
248
249
|
temp_cell[batch_ind[study]].X = batch_list[study].X
|
249
250
|
shared_ct.append(temp_cell)
|
250
251
|
|
251
|
-
all_shared_ann =
|
252
|
+
all_shared_ann = ad.concat(shared_ct, label="concat_batch", index_unique=None)
|
252
253
|
if "concat_batch" in all_shared_ann.obs.columns:
|
253
254
|
del all_shared_ann.obs["concat_batch"]
|
254
255
|
if len(not_shared_ct) < 1:
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: pertpy
|
3
|
-
Version: 0.
|
3
|
+
Version: 1.0.0
|
4
4
|
Summary: Perturbation Analysis in the scverse ecosystem.
|
5
5
|
Project-URL: Documentation, https://pertpy.readthedocs.io
|
6
6
|
Project-URL: Source, https://github.com/scverse/pertpy
|
@@ -131,6 +131,12 @@ You can install _pertpy_ in less than a minute via [pip] from [PyPI]:
|
|
131
131
|
pip install pertpy
|
132
132
|
```
|
133
133
|
|
134
|
+
or [conda-forge]:
|
135
|
+
|
136
|
+
```console
|
137
|
+
conda install -c conda-forge pertpy
|
138
|
+
```
|
139
|
+
|
134
140
|
### Differential gene expression
|
135
141
|
|
136
142
|
If you want to use the differential gene expression interface, please install pertpy by running:
|
@@ -149,7 +155,13 @@ pip install 'pertpy[tcoda]'
|
|
149
155
|
|
150
156
|
### milo
|
151
157
|
|
152
|
-
milo
|
158
|
+
milo requires either the "de" extra for the "pydeseq2" solver:
|
159
|
+
|
160
|
+
```console
|
161
|
+
pip install 'pertpy[de]'
|
162
|
+
```
|
163
|
+
|
164
|
+
or, edger, statmod, and rpy2 for the "edger" solver:
|
153
165
|
|
154
166
|
```R
|
155
167
|
BiocManager::install("edgeR")
|
@@ -179,6 +191,7 @@ pip install rpy2
|
|
179
191
|
[pip]: https://pip.pypa.io/
|
180
192
|
[pypi]: https://pypi.org/
|
181
193
|
[api]: https://pertpy.readthedocs.io/en/latest/api.html
|
194
|
+
[conda-forge]: https://anaconda.org/conda-forge/pertpy
|
182
195
|
[//]: # "numfocus-fiscal-sponsor-attribution"
|
183
196
|
|
184
197
|
pertpy is part of the scverse® project ([website](https://scverse.org), [governance](https://scverse.org/about/roles)) and is fiscally sponsored by [NumFOCUS](https://numfocus.org/).
|
@@ -1,4 +1,4 @@
|
|
1
|
-
pertpy/__init__.py,sha256=
|
1
|
+
pertpy/__init__.py,sha256=cZHJ7PIOhtLkxJMlHbJ2rzei5xhLB4vg0c8AaIShfzc,972
|
2
2
|
pertpy/_doc.py,sha256=j5TMNC-DA9yIMqIIUNpjpcVgWfRqyBBfvbRjnCM_OLs,427
|
3
3
|
pertpy/_types.py,sha256=IcHCojCUqx8CapibNkcYf2TUqjBFP2ujeELvn_IBSBQ,154
|
4
4
|
pertpy/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -21,12 +21,12 @@ pertpy/tools/_augur.py,sha256=tc1YKyc0BwzrEGgctsfyy7DsTNKxyvy7ZvWraTWCc1A,55262
|
|
21
21
|
pertpy/tools/_cinemaot.py,sha256=54-rS0AEj31dMe7iU4kEmLoAunq3jNuhsBE3IEp9hrI,38071
|
22
22
|
pertpy/tools/_dialogue.py,sha256=mygIZm5i_bnEE37TTQtr1efl_KJq-ejzeL3V1Bmr7Pg,52354
|
23
23
|
pertpy/tools/_enrichment.py,sha256=55mwotLH9DXQOhl85MCkxXu-MX0RysLyrPheJysAnF0,21369
|
24
|
-
pertpy/tools/_milo.py,sha256=
|
25
|
-
pertpy/tools/_mixscape.py,sha256=
|
24
|
+
pertpy/tools/_milo.py,sha256=9yoB9gkBNujqYDTKOlH2v3wiWhs5PdCuB8RgZ3xVI0Y,48049
|
25
|
+
pertpy/tools/_mixscape.py,sha256=HfrpBeRlxHXaOpZkF2FmX7dg35kUB1rL0_-n2aSi2_0,57905
|
26
26
|
pertpy/tools/decoupler_LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
|
27
27
|
pertpy/tools/transferlearning_MMD_LICENSE,sha256=MUvDA-o_j9htRpI8fStVdCRuyLdPkQUuIH0a_EIc57w,1069
|
28
28
|
pertpy/tools/_coda/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
29
|
-
pertpy/tools/_coda/_base_coda.py,sha256=
|
29
|
+
pertpy/tools/_coda/_base_coda.py,sha256=NjKIQBtTIUENnRmeIC2O8cMdU_9DKaJ5_AHPvFnc8XQ,111744
|
30
30
|
pertpy/tools/_coda/_sccoda.py,sha256=0Ret6O56kAfCNOdBvtxqiyuj2rUPp18SV1GVK1AvYGU,22607
|
31
31
|
pertpy/tools/_coda/_tasccoda.py,sha256=BTaOAmL458zQ_og3x4ENlDnJHD6_F4YkdCoXWsF4i1U,30465
|
32
32
|
pertpy/tools/_differential_gene_expression/__init__.py,sha256=SEydWg0iT3Y1pApjnCAOuHxFeI6xVUfgyBHv2s3LADU,487
|
@@ -39,20 +39,20 @@ pertpy/tools/_differential_gene_expression/_simple_tests.py,sha256=SfU8s_P2JzEA1
|
|
39
39
|
pertpy/tools/_differential_gene_expression/_statsmodels.py,sha256=90h9EPuoCtNxAbJ1Xq4j_E4yYJJpk64zTP7GyTdmrxY,2220
|
40
40
|
pertpy/tools/_distances/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
41
41
|
pertpy/tools/_distances/_distance_tests.py,sha256=6_nqfHUfKxkI2Yhkzspq3ujMpq56zV_Ddn7bgPzgjyo,13513
|
42
|
-
pertpy/tools/_distances/_distances.py,sha256=
|
42
|
+
pertpy/tools/_distances/_distances.py,sha256=_XbVU8dlYt_Jl2thYPUWg7HT6OXVe-Ki6qthF566sqQ,50503
|
43
43
|
pertpy/tools/_perturbation_space/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
44
44
|
pertpy/tools/_perturbation_space/_clustering.py,sha256=pNx_SpPkZfCbgF7vzHWqAaiiHdbxPaA-L-hTWTbzFhI,3528
|
45
45
|
pertpy/tools/_perturbation_space/_comparison.py,sha256=-NzCPRT-IlhJ9hOz7NQLSk0riIzr2C0yZvX6zm3kon4,4291
|
46
|
-
pertpy/tools/_perturbation_space/_discriminator_classifiers.py,sha256=
|
46
|
+
pertpy/tools/_perturbation_space/_discriminator_classifiers.py,sha256=a53-YmUwDHQBCT7ZWe_RH7PZsGXvoSHmJaQyL0CBJng,23383
|
47
47
|
pertpy/tools/_perturbation_space/_metrics.py,sha256=y8-baP8WRdB1iDgvP3uuQxSCDxA2lcxvEHHM2C_vWHY,3248
|
48
|
-
pertpy/tools/_perturbation_space/_perturbation_space.py,sha256=
|
48
|
+
pertpy/tools/_perturbation_space/_perturbation_space.py,sha256=Vyh15wWw9dcu2YUWhziQd2mA9-4IY8EC5dzkBT9HaIo,19457
|
49
49
|
pertpy/tools/_perturbation_space/_simple.py,sha256=AJlHRaEP-vViBeMDvvMtUnXMuIKqZVc7wggnjsHMfMw,12721
|
50
50
|
pertpy/tools/_scgen/__init__.py,sha256=uERFlFyF88TH0uLiwmsUGEfHfLVCiZMFuk8gO5f7164,45
|
51
51
|
pertpy/tools/_scgen/_base_components.py,sha256=Qq8myRUm43q9XBrZ9gBggfa2cSV2wbz_KYoLgH7iF1A,3009
|
52
|
-
pertpy/tools/_scgen/_scgen.py,sha256=
|
52
|
+
pertpy/tools/_scgen/_scgen.py,sha256=AQNGsDe-9HEqli3oq7UBDg68ofLCoXm-R_jnLFQ-rlc,30856
|
53
53
|
pertpy/tools/_scgen/_scgenvae.py,sha256=bPk4v7EdJc7ROdLuDitHiX_Pvwa7Flw2qHRUwBvjLJY,3889
|
54
54
|
pertpy/tools/_scgen/_utils.py,sha256=qz5QUn_Bvk2NGyYVzp3jgjWTFOMt1YyHwUo6HWtoThY,2871
|
55
|
-
pertpy-0.
|
56
|
-
pertpy-0.
|
57
|
-
pertpy-0.
|
58
|
-
pertpy-0.
|
55
|
+
pertpy-1.0.0.dist-info/METADATA,sha256=PnK9O-MyIPzSy5DNOqMN7G6zcxZ2ZTJnMFB5cEr5XJQ,8920
|
56
|
+
pertpy-1.0.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
57
|
+
pertpy-1.0.0.dist-info/licenses/LICENSE,sha256=XuiT2hxeRInhquEIBKMZ5M21n5syhDQ4XbABoposIAg,1100
|
58
|
+
pertpy-1.0.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|