pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +4 -2
- pertpy/data/__init__.py +66 -1
- pertpy/data/_dataloader.py +28 -26
- pertpy/data/_datasets.py +261 -92
- pertpy/metadata/__init__.py +6 -0
- pertpy/metadata/_cell_line.py +795 -0
- pertpy/metadata/_compound.py +128 -0
- pertpy/metadata/_drug.py +238 -0
- pertpy/metadata/_look_up.py +569 -0
- pertpy/metadata/_metadata.py +70 -0
- pertpy/metadata/_moa.py +125 -0
- pertpy/plot/__init__.py +0 -13
- pertpy/preprocessing/__init__.py +2 -0
- pertpy/preprocessing/_guide_rna.py +89 -6
- pertpy/tools/__init__.py +48 -15
- pertpy/tools/_augur.py +329 -32
- pertpy/tools/_cinemaot.py +145 -6
- pertpy/tools/_coda/_base_coda.py +1237 -116
- pertpy/tools/_coda/_sccoda.py +66 -36
- pertpy/tools/_coda/_tasccoda.py +46 -39
- pertpy/tools/_dialogue.py +180 -77
- pertpy/tools/_differential_gene_expression/__init__.py +20 -0
- pertpy/tools/_differential_gene_expression/_base.py +657 -0
- pertpy/tools/_differential_gene_expression/_checks.py +41 -0
- pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
- pertpy/tools/_differential_gene_expression/_edger.py +125 -0
- pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
- pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
- pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
- pertpy/tools/_distances/_distance_tests.py +29 -24
- pertpy/tools/_distances/_distances.py +584 -98
- pertpy/tools/_enrichment.py +460 -0
- pertpy/tools/_kernel_pca.py +1 -1
- pertpy/tools/_milo.py +406 -49
- pertpy/tools/_mixscape.py +677 -55
- pertpy/tools/_perturbation_space/_clustering.py +10 -3
- pertpy/tools/_perturbation_space/_comparison.py +112 -0
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
- pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
- pertpy/tools/_perturbation_space/_simple.py +52 -11
- pertpy/tools/_scgen/__init__.py +1 -1
- pertpy/tools/_scgen/_base_components.py +2 -3
- pertpy/tools/_scgen/_scgen.py +706 -0
- pertpy/tools/_scgen/_utils.py +3 -5
- pertpy/tools/decoupler_LICENSE +674 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
- pertpy-0.8.0.dist-info/RECORD +57 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
- pertpy/plot/_augur.py +0 -234
- pertpy/plot/_cinemaot.py +0 -81
- pertpy/plot/_coda.py +0 -1001
- pertpy/plot/_dialogue.py +0 -91
- pertpy/plot/_guide_rna.py +0 -82
- pertpy/plot/_milopy.py +0 -284
- pertpy/plot/_mixscape.py +0 -594
- pertpy/plot/_scgen.py +0 -337
- pertpy/tools/_differential_gene_expression.py +0 -99
- pertpy/tools/_metadata/__init__.py +0 -0
- pertpy/tools/_metadata/_cell_line.py +0 -613
- pertpy/tools/_metadata/_look_up.py +0 -342
- pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
- pertpy/tools/_scgen/_jax_scgen.py +0 -370
- pertpy-0.6.0.dist-info/RECORD +0 -50
- /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
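The file-level summary above already captures the biggest user-facing changes in this range: the standalone `pertpy/plot/` modules are removed and plotting moves onto the tool classes as `plot_*` methods (visible in the `_base_coda.py` diff below), cell line and drug annotation moves from `pertpy/tools/_metadata/` into a new top-level `pertpy/metadata/` package, and a `_differential_gene_expression` subpackage replaces the single-module implementation. As a rough sketch only, assembled from the docstring examples added in the diff below rather than from a migration guide, the scCODA workflow on the 0.8.0 API looks like this:

>>> import pertpy as pt
>>> haber_cells = pt.dt.haber_2017_regions()
>>> sccoda = pt.tl.Sccoda()
>>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
        sample_identifier="batch", covariate_obs=["condition"])
>>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
>>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
>>> sccoda.plot_effects_barplot(mdata)  # plotting now lives on the tool class, not in the removed pertpy.plot module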
pertpy/tools/_coda/_base_coda.py
CHANGED
@@ -1,17 +1,24 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from
+from pathlib import Path
+from typing import TYPE_CHECKING, Literal, Optional, Union
 
 import arviz as az
-import ete3 as ete
 import jax.numpy as jnp
+import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 import patsy as pt
+import scanpy as sc
+import seaborn as sns
+from adjustText import adjust_text
 from anndata import AnnData
-from jax import random
-from
+from jax import config, random
+from lamin_utils import logger
+from matplotlib import cm, rcParams
+from matplotlib import image as mpimg
+from matplotlib.colors import ListedColormap
 from mudata import MuData
 from numpyro.infer import HMC, MCMC, NUTS, initialization
 from rich import box, print
@@ -20,10 +27,15 @@ from rich.table import Table
 from scipy.cluster import hierarchy as sp_hierarchy
 
 if TYPE_CHECKING:
+    from collections.abc import Sequence
+
     import numpyro as npy
     import toytree as tt
-    from
+    from ete3 import Tree
     from jax._src.typing import Array
+    from matplotlib.axes import Axes
+    from matplotlib.colors import Colormap
+    from matplotlib.figure import Figure
 
 config.update("jax_enable_x64", True)
 
@@ -99,9 +111,9 @@ class CompositionalModel2(ABC):
                 Categorical covariates are handled automatically, with the covariate value of the first sample being used as the reference category.
                 To set a different level as the base category for a categorical covariate, use "C(<CovariateName>, Treatment('<ReferenceLevelName>'))"
             reference_cell_type: Column name that sets the reference cell type.
-                Reference the name of a column. If "automatic", the cell type with the lowest dispersion in relative abundance that is present in at least 90% of samlpes will be chosen.
+                Reference the name of a column. If "automatic", the cell type with the lowest dispersion in relative abundance that is present in at least 90% of samlpes will be chosen.
             automatic_reference_absence_threshold: If using reference_cell_type = "automatic", determine the maximum fraction of zero entries for a cell type
-                to be considered as a possible reference cell type.
+                to be considered as a possible reference cell type.
 
         Returns:
             AnnData object that is ready for CODA models.
@@ -137,7 +149,7 @@ class CompositionalModel2(ABC):
             ref_index = np.where(cell_type_disp == min_var)[0][0]
 
             ref_cell_type = cell_types[ref_index]
-
+            logger.info(f"Automatic reference selection! Reference cell type set to {ref_cell_type}")
 
         # Column name as reference cell type
         elif reference_cell_type in cell_types:
@@ -149,7 +161,7 @@ class CompositionalModel2(ABC):
 
         # Add pseudocount if zeroes are present.
         if np.count_nonzero(sample_adata.X) != np.size(sample_adata.X):
-
+            logger.info("Zero counts encountered in data! Added a pseudocount of 0.5.")
             sample_adata.X[sample_adata.X == 0] = 0.5
 
         sample_adata.obsm["sample_counts"] = np.sum(sample_adata.X, axis=1)
@@ -179,7 +191,7 @@ class CompositionalModel2(ABC):
         self,
         sample_adata: AnnData,
         kernel: npy.infer.mcmc.MCMCKernel,
-        rng_key: Array
+        rng_key: Array,
        copy: bool = False,
         *args,
         **kwargs,
@@ -190,7 +202,7 @@ class CompositionalModel2(ABC):
             sample_adata: anndata object with cell counts as sample_adata.X and covariates saved in sample_adata.obs.
             kernel: A `numpyro.infer.mcmc.MCMCKernel` object
             rng_key: The rng state used. If None, a random state will be selected
-            copy: Return a copy instead of writing to adata.
+            copy: Return a copy instead of writing to adata.
             args: Passed to `numpyro.infer.mcmc.MCMC`
             kwargs: Passed to `numpyro.infer.mcmc.MCMC`
 
@@ -226,13 +238,13 @@ class CompositionalModel2(ABC):
 
         acc_rate = np.array(self.mcmc.last_state.mean_accept_prob)
         if acc_rate < 0.6:
-
-                f"
+            logger.warning(
+                f"Acceptance rate unusually low ({acc_rate} < 0.5)! Results might be incorrect! "
                 f"Please check feasibility of results and re-run the sampling step with a different rng_key if necessary."
             )
         if acc_rate > 0.95:
-
-                f"
+            logger.warning(
+                f"Acceptance rate unusually high ({acc_rate} > 0.95)! Results might be incorrect! "
                 f"Please check feasibility of results and re-run the sampling step with a different rng_key if necessary."
             )
 
@@ -275,11 +287,11 @@ class CompositionalModel2(ABC):
 
         Args:
             data: AnnData object or MuData object.
-            modality_key: If data is a MuData object, specify which modality to use.
-            num_samples: Number of sampled values after burn-in.
-            num_warmup: Number of burn-in (warmup) samples.
-            rng_key: The rng state used.
-            copy: Return a copy instead of writing to adata.
+            modality_key: If data is a MuData object, specify which modality to use.
+            num_samples: Number of sampled values after burn-in.
+            num_warmup: Number of burn-in (warmup) samples.
+            rng_key: The rng state used.
+            copy: Return a copy instead of writing to adata.
 
         Returns:
             Calls `self.__run_mcmc`
@@ -288,14 +300,14 @@ class CompositionalModel2(ABC):
             try:
                 sample_adata = data[modality_key]
             except IndexError:
-
+                logger.error("When data is a MuData object, modality_key must be specified!")
                 raise
         if isinstance(data, AnnData):
             sample_adata = data
         if copy:
             sample_adata = sample_adata.copy()
 
-        rng_key_array = random.
+        rng_key_array = random.key(rng_key)
         sample_adata.uns["scCODA_params"]["mcmc"]["rng_key"] = np.array(rng_key_array)
 
         # Set up NUTS kernel
@@ -328,14 +340,13 @@ class CompositionalModel2(ABC):
 
         Args:
             data: AnnData object or MuData object.
-            modality_key: If data is a MuData object, specify which modality to use.
-            num_samples: Number of sampled values after burn-in.
-            num_warmup: Number of burn-in (warmup) samples.
-            rng_key: The rng state used. If None, a random state will be selected.
-            copy: Return a copy instead of writing to adata.
+            modality_key: If data is a MuData object, specify which modality to use.
+            num_samples: Number of sampled values after burn-in.
+            num_warmup: Number of burn-in (warmup) samples.
+            rng_key: The rng state used. If None, a random state will be selected.
+            copy: Return a copy instead of writing to adata.
 
         Examples:
-            Example with scCODA:
             >>> import pertpy as pt
             >>> haber_cells = pt.dt.haber_2017_regions()
             >>> sccoda = pt.tl.Sccoda()
@@ -348,7 +359,7 @@ class CompositionalModel2(ABC):
             try:
                 sample_adata = data[modality_key]
             except IndexError:
-
+                logger.error("When data is a MuData object, modality_key must be specified!")
                 raise
         if isinstance(data, AnnData):
             sample_adata = data
@@ -358,10 +369,10 @@ class CompositionalModel2(ABC):
         # Set rng key if needed
         if rng_key is None:
             rng = np.random.default_rng()
-            rng_key = random.
+            rng_key = random.key(rng.integers(0, 10000))
             sample_adata.uns["scCODA_params"]["mcmc"]["rng_key"] = rng_key
         else:
-            rng_key = random.
+            rng_key = random.key(rng_key)
 
         # Set up HMC kernel
         sample_adata = self.set_init_mcmc_states(
@@ -387,7 +398,7 @@ class CompositionalModel2(ABC):
 
         Args:
             sample_adata: Anndata object with cell counts as sample_adata.X and covariates saved in sample_adata.obs.
-            est_fdr: Desired FDR value.
+            est_fdr: Desired FDR value.
             args: Passed to ``az.summary``
             kwargs: Passed to ``az.summary``
 
@@ -423,7 +434,6 @@ class CompositionalModel2(ABC):
             - Is credible: Boolean indicator whether effect is credible
 
         Examples:
-            Example with scCODA:
             >>> import pertpy as pt
             >>> haber_cells = pt.dt.haber_2017_regions()
             >>> sccoda = pt.tl.Sccoda()
@@ -433,7 +443,6 @@ class CompositionalModel2(ABC):
             >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
             >>> intercept_df, effect_df = sccoda.summary_prepare(mdata["coda"])
         """
-        # Get model and effect selection types
         select_type = sample_adata.uns["scCODA_params"]["select_type"]
         model_type = sample_adata.uns["scCODA_params"]["model_type"]
 
@@ -548,7 +557,11 @@ class CompositionalModel2(ABC):
         intercept_df = intercept_df.loc[:, ["final_parameter", hdis[0], hdis[1], "sd", "expected_sample"]].copy()
         intercept_df = intercept_df.rename(
             columns=dict(
-                zip(
+                zip(
+                    intercept_df.columns,
+                    ["Final Parameter", hdis_new[0], hdis_new[1], "SD", "Expected Sample"],
+                    strict=False,
+                )
             )
         )
 
@@ -561,6 +574,7 @@ class CompositionalModel2(ABC):
                 zip(
                     effect_df.columns,
                     ["Effect", "Median", hdis_new[0], hdis_new[1], "SD", "Expected Sample", "log2-fold change"],
+                    strict=False,
                 )
             )
         )
@@ -581,6 +595,7 @@ class CompositionalModel2(ABC):
                         "Expected Sample",
                         "log2-fold change",
                     ],
+                    strict=False,
                 )
             )
         )
@@ -594,6 +609,7 @@ class CompositionalModel2(ABC):
                 zip(
                     node_df.columns,
                     ["Final Parameter", "Median", hdis_new[0], hdis_new[1], "SD", "Delta", "Is credible"],
+                    strict=False,
                 )
             )  # type: ignore
         )  # type: ignore
@@ -622,8 +638,8 @@ class CompositionalModel2(ABC):
             effect_df: Effect summary, see ``summary_prepare``
             model_type: String indicating the model type ("classic" or "tree_agg")
             select_type: String indicating the type of spike_and_slab selection ("spikeslab" or "sslasso")
-            target_fdr: Desired FDR value.
-            node_df: If using tree aggregation, the node-level effect DataFrame must be passed.
+            target_fdr: Desired FDR value.
+            node_df: If using tree aggregation, the node-level effect DataFrame must be passed.
 
         Returns:
             pd.DataFrame: effect DataFrame with inclusion probability, final parameters, expected sample.
@@ -775,13 +791,12 @@ class CompositionalModel2(ABC):
 
         Args:
             data: AnnData object or MuData object.
-            extended: If True, return the extended summary with additional statistics.
-            modality_key: If data is a MuData object, specify which modality to use.
+            extended: If True, return the extended summary with additional statistics.
+            modality_key: If data is a MuData object, specify which modality to use.
             args: Passed to az.summary
             kwargs: Passed to az.summary
 
         Examples:
-            Example with scCODA:
             >>> import pertpy as pt
             >>> haber_cells = pt.dt.haber_2017_regions()
             >>> sccoda = pt.tl.Sccoda()
@@ -795,11 +810,11 @@ class CompositionalModel2(ABC):
             try:
                 sample_adata = data[modality_key]
             except IndexError:
-
+                logger.error("When data is a MuData object, modality_key must be specified!")
                 raise
         if isinstance(data, AnnData):
             sample_adata = data
-
+
         select_type = sample_adata.uns["scCODA_params"]["select_type"]
         model_type = sample_adata.uns["scCODA_params"]["model_type"]
 
@@ -834,10 +849,10 @@ class CompositionalModel2(ABC):
         table.add_column("Name", justify="left", style="cyan")
         table.add_column("Value", justify="left")
         table.add_row("Data", "Data: %d samples, %d cell types" % data_dims)
-        table.add_row("Reference cell type", "
-        table.add_row("Formula", "
+        table.add_row("Reference cell type", "{}".format(str(sample_adata.uns["scCODA_params"]["reference_cell_type"])))
+        table.add_row("Formula", "{}".format(sample_adata.uns["scCODA_params"]["formula"]))
         if extended:
-            table.add_row("Reference index", "
+            table.add_row("Reference index", "{}".format(str(sample_adata.uns["scCODA_params"]["reference_index"])))
             if select_type == "spikeslab":
                 table.add_row(
                     "Spike-and-slab threshold",
@@ -920,13 +935,12 @@ class CompositionalModel2(ABC):
 
         Args:
             data: AnnData object or MuData object.
-            modality_key: If data is a MuData object, specify which modality to use.
+            modality_key: If data is a MuData object, specify which modality to use.
 
         Returns:
             pd.DataFrame: Intercept data frame.
 
         Examples:
-            Example with scCODA:
             >>> import pertpy as pt
             >>> haber_cells = pt.dt.haber_2017_regions()
             >>> sccoda = pt.tl.Sccoda()
@@ -936,12 +950,11 @@ class CompositionalModel2(ABC):
             >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
             >>> intercepts = sccoda.get_intercept_df(mdata)
         """
-
         if isinstance(data, MuData):
             try:
                 sample_adata = data[modality_key]
             except IndexError:
-
+                logger.error("When data is a MuData object, modality_key must be specified!")
                 raise
         if isinstance(data, AnnData):
             sample_adata = data
@@ -953,13 +966,12 @@ class CompositionalModel2(ABC):
 
         Args:
             data: AnnData object or MuData object.
-            modality_key: If data is a MuData object, specify which modality to use.
+            modality_key: If data is a MuData object, specify which modality to use.
 
         Returns:
             pd.DataFrame: Effect data frame.
 
         Examples:
-            Example with scCODA:
             >>> import pertpy as pt
             >>> haber_cells = pt.dt.haber_2017_regions()
             >>> sccoda = pt.tl.Sccoda()
@@ -969,12 +981,11 @@ class CompositionalModel2(ABC):
             >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
             >>> effects = sccoda.get_effect_df(mdata)
         """
-
         if isinstance(data, MuData):
             try:
                 sample_adata = data[modality_key]
             except IndexError:
-
+                logger.error("When data is a MuData object, modality_key must be specified!")
                 raise
         if isinstance(data, AnnData):
             sample_adata = data
@@ -997,15 +1008,14 @@ class CompositionalModel2(ABC):
 
         Args:
             data: AnnData object or MuData object.
-            modality_key: If data is a MuData object, specify which modality to use.
+            modality_key: If data is a MuData object, specify which modality to use.
 
         Returns:
             pd.DataFrame: Node effect data frame.
 
         Examples:
-            Example with tascCODA (works only for model of type tree_agg, i.e. a tascCODA model):
             >>> import pertpy as pt
-            >>> adata = pt.dt.
+            >>> adata = pt.dt.tasccoda_example()
             >>> tasccoda = pt.tl.Tasccoda()
             >>> mdata = tasccoda.load(
             >>> adata, type="sample_level",
@@ -1023,7 +1033,7 @@ class CompositionalModel2(ABC):
             try:
                 sample_adata = data[modality_key]
             except IndexError:
-
+                logger.error("When data is a MuData object, modality_key must be specified!")
                 raise
         if isinstance(data, AnnData):
             sample_adata = data
@@ -1037,7 +1047,7 @@ class CompositionalModel2(ABC):
         Args:
             data: AnnData object or MuData object.
             est_fdr: Desired FDR value.
-            modality_key: If data is a MuData object, specify which modality to use.
+            modality_key: If data is a MuData object, specify which modality to use.
             args: passed to self.summary_prepare
             kwargs: passed to self.summary_prepare
 
@@ -1048,7 +1058,7 @@ class CompositionalModel2(ABC):
             try:
                 sample_adata = data[modality_key]
             except IndexError:
-
+                logger.error("When data is a MuData object, modality_key must be specified!")
                 raise
         if isinstance(data, AnnData):
             sample_adata = data
@@ -1071,8 +1081,8 @@ class CompositionalModel2(ABC):
 
         Args:
             data: AnnData object or MuData object.
-            modality_key: If data is a MuData object, specify which modality to use.
-            est_fdr: Estimated false discovery rate. Must be between 0 and 1.
+            modality_key: If data is a MuData object, specify which modality to use.
+            est_fdr: Estimated false discovery rate. Must be between 0 and 1.
 
         Returns:
             pd.Series: Credible effect decision series which includes boolean values indicate whether effects are credible under inc_prob_threshold.
@@ -1081,7 +1091,7 @@ class CompositionalModel2(ABC):
             try:
                 sample_adata = data[modality_key]
             except IndexError:
-
+                logger.error("When data is a MuData object, modality_key must be specified!")
                 raise
         if isinstance(data, AnnData):
             sample_adata = data
@@ -1113,9 +1123,1120 @@ class CompositionalModel2(ABC):
 
         return out
 
+    def _stackbar(  # pragma: no cover
+        self,
+        y: np.ndarray,
+        type_names: list[str],
+        title: str,
+        level_names: list[str],
+        figsize: tuple[float, float] | None = None,
+        dpi: int | None = 100,
+        palette: ListedColormap | None = cm.tab20,
+        show_legend: bool | None = True,
+    ) -> plt.Axes:
+        """Plots a stacked barplot for one (discrete) covariate.
+
+        Typical use (only inside stacked_barplot): plot_one_stackbar(data.X, data.var.index, "xyz", data.obs.index)
+
+        Args:
+            y: The count data, collapsed onto the level of interest. i.e. a binary covariate has two rows,
+                one for each group, containing the count mean of each cell type
+            type_names: The names of all cell types
+            title: Plot title, usually the covariate's name
+            level_names: Names of the covariate's levels
+            figsize: Figure size (matplotlib).
+            dpi: Resolution in DPI (matplotlib).
+            palette: The color map for the barplot.
+            show_legend: If True, adds a legend.
+
+        Returns:
+            A :class:`~matplotlib.axes.Axes` object
+        """
+        n_bars, n_types = y.shape
+
+        figsize = rcParams["figure.figsize"] if figsize is None else figsize
+
+        _, ax = plt.subplots(figsize=figsize, dpi=dpi)
+        r = np.array(range(n_bars))
+        sample_sums = np.sum(y, axis=1)
+
+        barwidth = 0.85
+        cum_bars = np.zeros(n_bars)
+
+        for n in range(n_types):
+            bars = [i / j * 100 for i, j in zip([y[k][n] for k in range(n_bars)], sample_sums, strict=False)]
+            plt.bar(
+                r,
+                bars,
+                bottom=cum_bars,
+                color=palette(n % palette.N),
+                width=barwidth,
+                label=type_names[n],
+                linewidth=0,
+            )
+            cum_bars += bars
+
+        ax.set_title(title)
+        if show_legend:
+            ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1)
+        ax.set_xticks(r)
+        ax.set_xticklabels(level_names, rotation=45, ha="right")
+        ax.set_ylabel("Proportion")
+
+        return ax
+
+    def plot_stacked_barplot(  # pragma: no cover
+        self,
+        data: AnnData | MuData,
+        feature_name: str,
+        modality_key: str = "coda",
+        palette: ListedColormap | None = cm.tab20,
+        show_legend: bool | None = True,
+        level_order: list[str] = None,
+        figsize: tuple[float, float] | None = None,
+        dpi: int | None = 100,
+        return_fig: bool | None = None,
+        ax: plt.Axes | None = None,
+        show: bool | None = None,
+        save: str | bool | None = None,
+        **kwargs,
+    ) -> plt.Axes | plt.Figure | None:
+        """Plots a stacked barplot for all levels of a covariate or all samples (if feature_name=="samples").
+
+        Args:
+            data: AnnData object or MuData object.
+            feature_name: The name of the covariate to plot. If feature_name=="samples", one bar for every sample will be plotted
+            modality_key: If data is a MuData object, specify which modality to use.
+            figsize: Figure size.
+            dpi: Dpi setting.
+            palette: The matplotlib color map for the barplot.
+            show_legend: If True, adds a legend.
+            level_order: Custom ordering of bars on the x-axis.
+
+        Returns:
+            A :class:`~matplotlib.axes.Axes` object
+
+        Examples:
+            >>> import pertpy as pt
+            >>> haber_cells = pt.dt.haber_2017_regions()
+            >>> sccoda = pt.tl.Sccoda()
+            >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
+                sample_identifier="batch", covariate_obs=["condition"])
+            >>> sccoda.plot_stacked_barplot(mdata, feature_name="samples")
+
+        Preview:
+            .. image:: /_static/docstring_previews/sccoda_stacked_barplot.png
+        """
+        if isinstance(data, MuData):
+            data = data[modality_key]
+        if isinstance(data, AnnData):
+            data = data
+
+        ct_names = data.var.index
+
+        # option to plot one stacked barplot per sample
+        if feature_name == "samples":
+            if level_order:
+                assert set(level_order) == set(data.obs.index), "level order is inconsistent with levels"
+                data = data[level_order]
+            ax = self._stackbar(
+                data.X,
+                type_names=data.var.index,
+                title="samples",
+                level_names=data.obs.index,
+                figsize=figsize,
+                dpi=dpi,
+                palette=palette,
+                show_legend=show_legend,
+            )
+        else:
+            # Order levels
+            if level_order:
+                assert set(level_order) == set(data.obs[feature_name]), "level order is inconsistent with levels"
+                levels = level_order
+            elif hasattr(data.obs[feature_name], "cat"):
+                levels = data.obs[feature_name].cat.categories.to_list()
+            else:
+                levels = pd.unique(data.obs[feature_name])
+            n_levels = len(levels)
+            feature_totals = np.zeros([n_levels, data.X.shape[1]])
+
+            for level in range(n_levels):
+                l_indices = np.where(data.obs[feature_name] == levels[level])
+                feature_totals[level] = np.sum(data.X[l_indices], axis=0)
+
+            ax = self._stackbar(
+                feature_totals,
+                type_names=ct_names,
+                title=feature_name,
+                level_names=levels,
+                figsize=figsize,
+                dpi=dpi,
+                palette=palette,
+                show_legend=show_legend,
+            )
+
+        if save:
+            plt.savefig(save, bbox_inches="tight")
+        if show:
+            plt.show()
+        if return_fig:
+            return plt.gcf()
+        if not (show or save):
+            return ax
+        return None
+
+    def plot_effects_barplot(  # pragma: no cover
+        self,
+        data: AnnData | MuData,
+        modality_key: str = "coda",
+        covariates: str | list | None = None,
+        parameter: Literal["log2-fold change", "Final Parameter", "Expected Sample"] = "log2-fold change",
+        plot_facets: bool = True,
+        plot_zero_covariate: bool = True,
+        plot_zero_cell_type: bool = False,
+        palette: str | ListedColormap | None = cm.tab20,
+        level_order: list[str] = None,
+        args_barplot: dict | None = None,
+        figsize: tuple[float, float] | None = None,
+        dpi: int | None = 100,
+        return_fig: bool | None = None,
+        ax: plt.Axes | None = None,
+        show: bool | None = None,
+        save: str | bool | None = None,
+    ) -> plt.Axes | plt.Figure | sns.axisgrid.FacetGrid | None:
+        """Barplot visualization for effects.
+
+        The effect results for each covariate are shown as a group of barplots, with intra--group separation by cell types.
+        The covariates groups can either be ordered along the x-axis of a single plot (plot_facets=False) or as plot facets (plot_facets=True).
+
+        Args:
+            data: AnnData object or MuData object.
+            modality_key: If data is a MuData object, specify which modality to use.
+            covariates: The name of the covariates in data.obs to plot.
+            parameter: The parameter in effect summary to plot.
+            plot_facets: If False, plot cell types on the x-axis. If True, plot as facets.
+            plot_zero_covariate: If True, plot covariate that have all zero effects. If False, do not plot.
+            plot_zero_cell_type: If True, plot cell type that have zero effect. If False, do not plot.
+            figsize: Figure size.
+            dpi: Figure size.
+            palette: The seaborn color map for the barplot.
+            level_order: Custom ordering of bars on the x-axis.
+            args_barplot: Arguments passed to sns.barplot.
+
+        Returns:
+            Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`)
+            or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object
+
+        Examples:
+            >>> import pertpy as pt
+            >>> haber_cells = pt.dt.haber_2017_regions()
+            >>> sccoda = pt.tl.Sccoda()
+            >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
+                sample_identifier="batch", covariate_obs=["condition"])
+            >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
+            >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
+            >>> sccoda.plot_effects_barplot(mdata)
+
+        Preview:
+            .. image:: /_static/docstring_previews/sccoda_effects_barplot.png
+        """
+        if args_barplot is None:
+            args_barplot = {}
+        if isinstance(data, MuData):
+            data = data[modality_key]
+        if isinstance(data, AnnData):
+            data = data
+        # Get covariate names from adata, partition into those with nonzero effects for min. one cell type/no cell types
+        covariate_names = data.uns["scCODA_params"]["covariate_names"]
+        if covariates is not None:
+            if isinstance(covariates, str):
+                covariates = [covariates]
+            partial_covariate_names = [
+                covariate_name
+                for covariate_name in covariate_names
+                if any(covariate in covariate_name for covariate in covariates)
+            ]
+            covariate_names = partial_covariate_names
+        covariate_names_non_zero = [
+            covariate_name
+            for covariate_name in covariate_names
+            if data.varm[f"effect_df_{covariate_name}"][parameter].any()
+        ]
+        covariate_names_zero = list(set(covariate_names) - set(covariate_names_non_zero))
+        if not plot_zero_covariate:
+            covariate_names = covariate_names_non_zero
+
+        # set up df for plotting
+        plot_df = pd.concat(
+            [data.varm[f"effect_df_{covariate_name}"][parameter] for covariate_name in covariate_names],
+            axis=1,
+        )
+        plot_df.columns = covariate_names
+        plot_df = pd.melt(plot_df, ignore_index=False, var_name="Covariate")
+
+        plot_df = plot_df.reset_index()
+
+        if len(covariate_names_zero) != 0:
+            if plot_facets:
+                if plot_zero_covariate and not plot_zero_cell_type:
+                    plot_df = plot_df[plot_df["value"] != 0]
+                    for covariate_name_zero in covariate_names_zero:
+                        new_row = {
+                            "Covariate": covariate_name_zero,
+                            "Cell Type": "zero",
+                            "value": 0,
+                        }
+                        plot_df = pd.concat([plot_df, pd.DataFrame([new_row])], ignore_index=True)
+                    plot_df["covariate_"] = pd.Categorical(plot_df["Covariate"], covariate_names)
+                    plot_df = plot_df.sort_values(["covariate_"])
+        if not plot_zero_cell_type:
+            cell_type_names_zero = [
+                name
+                for name in plot_df["Cell Type"].unique()
+                if (plot_df[plot_df["Cell Type"] == name]["value"] == 0).all()
+            ]
+            plot_df = plot_df[~plot_df["Cell Type"].isin(cell_type_names_zero)]
+
+        # If plot as facets, create a FacetGrid and map barplot to it.
+        if plot_facets:
+            if isinstance(palette, ListedColormap):
+                palette = np.array([palette(i % palette.N) for i in range(len(plot_df["Cell Type"].unique()))]).tolist()
+            if figsize is not None:
+                height = figsize[0]
+                aspect = np.round(figsize[1] / figsize[0], 2)
+            else:
+                height = 3
+                aspect = 2
+
+            g = sns.FacetGrid(
+                plot_df,
+                col="Covariate",
+                sharey=True,
+                sharex=False,
+                height=height,
+                aspect=aspect,
+            )
+
+            g.map(
+                sns.barplot,
+                "Cell Type",
+                "value",
+                palette=palette,
+                order=level_order,
+                **args_barplot,
+            )
+            g.set_xticklabels(rotation=90)
+            g.set(ylabel=parameter)
+            axes = g.axes.flatten()
+            for i, ax in enumerate(axes):
+                ax.set_title(covariate_names[i])
+                if len(ax.get_xticklabels()) < 5:
+                    ax.set_aspect(10 / len(ax.get_xticklabels()))
+                    if len(ax.get_xticklabels()) == 1:
+                        if ax.get_xticklabels()[0]._text == "zero":
+                            ax.set_xticks([])
+
+            if save:
+                plt.savefig(save, bbox_inches="tight")
+            if show:
+                plt.show()
+            if return_fig:
+                return plt.gcf()
+            if not (show or save):
+                return g
+            return None
+
+        # If not plot as facets, call barplot to plot cell types on the x-axis.
+        else:
+            _, ax = plt.subplots(figsize=figsize, dpi=dpi)
+            if len(covariate_names) == 1:
+                if isinstance(palette, ListedColormap):
+                    palette = np.array(
+                        [palette(i % palette.N) for i in range(len(plot_df["Cell Type"].unique()))]
+                    ).tolist()
+                sns.barplot(
+                    data=plot_df,
+                    x="Cell Type",
+                    y="value",
+                    hue="x",
+                    palette=palette,
+                    ax=ax,
+                )
+                ax.set_title(covariate_names[0])
+            else:
+                if isinstance(palette, ListedColormap):
+                    palette = np.array([palette(i % palette.N) for i in range(len(covariate_names))]).tolist()
+                sns.barplot(
+                    data=plot_df,
+                    x="Cell Type",
+                    y="value",
+                    hue="Covariate",
+                    palette=palette,
+                    ax=ax,
+                )
+            cell_types = pd.unique(plot_df["Cell Type"])
+            ax.set_xticklabels(cell_types, rotation=90)
+
+            if save:
+                plt.savefig(save, bbox_inches="tight")
+            if show:
+                plt.show()
+            if return_fig:
+                return plt.gcf()
+            if not (show or save):
+                return ax
+            return None
+
+    def plot_boxplots(  # pragma: no cover
+        self,
+        data: AnnData | MuData,
+        feature_name: str,
+        modality_key: str = "coda",
+        y_scale: Literal["relative", "log", "log10", "count"] = "relative",
+        plot_facets: bool = False,
+        add_dots: bool = False,
+        cell_types: list | None = None,
+        args_boxplot: dict | None = None,
+        args_swarmplot: dict | None = None,
+        palette: str | None = "Blues",
+        show_legend: bool | None = True,
+        level_order: list[str] = None,
+        figsize: tuple[float, float] | None = None,
+        dpi: int | None = 100,
+        return_fig: bool | None = None,
+        ax: plt.Axes | None = None,
+        show: bool | None = None,
+        save: str | bool | None = None,
+    ) -> plt.Axes | plt.Figure | sns.axisgrid.FacetGrid | None:
+        """Grouped boxplot visualization.
+
+        The cell counts for each cell type are shown as a group of boxplots
+        with intra--group separation by a covariate from data.obs.
+
+        Args:
+            data: AnnData object or MuData object
+            feature_name: The name of the feature in data.obs to plot
+            modality_key: If data is a MuData object, specify which modality to use.
+            y_scale: Transformation to of cell counts. Options: "relative" - Relative abundance, "log" - log(count),
+                "log10" - log10(count), "count" - absolute abundance (cell counts).
+            plot_facets: If False, plot cell types on the x-axis. If True, plot as facets.
+            add_dots: If True, overlay a scatterplot with one dot for each data point.
+            cell_types: Subset of cell types that should be plotted.
+            args_boxplot: Arguments passed to sns.boxplot.
+            args_swarmplot: Arguments passed to sns.swarmplot.
+            figsize: Figure size.
+            dpi: Dpi setting.
+            palette: The seaborn color map for the barplot.
+            show_legend: If True, adds a legend.
+            level_order: Custom ordering of bars on the x-axis.
+
+        Returns:
+            Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`)
+            or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object
+
+        Examples:
+            >>> import pertpy as pt
+            >>> haber_cells = pt.dt.haber_2017_regions()
+            >>> sccoda = pt.tl.Sccoda()
+            >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
+                sample_identifier="batch", covariate_obs=["condition"])
+            >>> sccoda.plot_boxplots(mdata, feature_name="condition", add_dots=True)
+
+        Preview:
+            .. image:: /_static/docstring_previews/sccoda_boxplots.png
+        """
+        if args_boxplot is None:
+            args_boxplot = {}
+        if args_swarmplot is None:
+            args_swarmplot = {}
+        if isinstance(data, MuData):
+            data = data[modality_key]
+        if isinstance(data, AnnData):
+            data = data
+        # y scale transformations
+        if y_scale == "relative":
+            sample_sums = np.sum(data.X, axis=1, keepdims=True)
+            X = data.X / sample_sums
+            value_name = "Proportion"
+        # add pseudocount 0.5 if using log scale
+        elif y_scale == "log":
+            X = data.X.copy()
+            X[X == 0] = 0.5
+            X = np.log(X)
+            value_name = "log(count)"
+        elif y_scale == "log10":
+            X = data.X.copy()
+            X[X == 0] = 0.5
+            X = np.log(X)
+            value_name = "log10(count)"
+        elif y_scale == "count":
+            X = data.X
+            value_name = "count"
+        else:
+            raise ValueError("Invalid y_scale transformation")
+
+        count_df = pd.DataFrame(X, columns=data.var.index, index=data.obs.index).merge(
+            data.obs[feature_name], left_index=True, right_index=True
+        )
+        plot_df = pd.melt(count_df, id_vars=feature_name, var_name="Cell type", value_name=value_name)
+        if cell_types is not None:
+            plot_df = plot_df[plot_df["Cell type"].isin(cell_types)]
+
+        # Currently disabled because the latest statsannotations does not support the latest seaborn.
+        # We had to drop the dependency.
+        # Get credible effects results from model
+        # if draw_effects:
+        # if model is not None:
+        # credible_effects_df = model.credible_effects(data, modality_key).to_frame().reset_index()
+        # else:
+        # print("[bold yellow]Specify a tasCODA model to draw effects")
+        # credible_effects_df[feature_name] = credible_effects_df["Covariate"].str.removeprefix(f"{feature_name}[T.")
+        # credible_effects_df[feature_name] = credible_effects_df[feature_name].str.removesuffix("]")
+        # credible_effects_df = credible_effects_df[credible_effects_df["Final Parameter"]]
+
+        # If plot as facets, create a FacetGrid and map boxplot to it.
+        if plot_facets:
+            if level_order is None:
+                level_order = pd.unique(plot_df[feature_name])
+
+            K = X.shape[1]
+
+            if figsize is not None:
+                height = figsize[0]
+                aspect = np.round(figsize[1] / figsize[0], 2)
+            else:
+                height = 3
+                aspect = 2
+
+            g = sns.FacetGrid(
+                plot_df,
+                col="Cell type",
+                sharey=False,
+                col_wrap=int(np.floor(np.sqrt(K))),
+                height=height,
+                aspect=aspect,
+            )
+            g.map(
+                sns.boxplot,
+                feature_name,
+                value_name,
+                palette=palette,
+                order=level_order,
+                **args_boxplot,
+            )
+
+            if add_dots:
+                if "hue" in args_swarmplot:
+                    hue = args_swarmplot.pop("hue")
+                else:
+                    hue = None
+
+                if hue is None:
+                    g.map(
+                        sns.swarmplot,
+                        feature_name,
+                        value_name,
+                        color="black",
+                        order=level_order,
+                        **args_swarmplot,
+                    ).set_titles("{col_name}")
+                else:
+                    g.map(
+                        sns.swarmplot,
+                        feature_name,
+                        value_name,
+                        hue,
+                        order=level_order,
+                        **args_swarmplot,
+                    ).set_titles("{col_name}")
+
+            if save:
+                plt.savefig(save, bbox_inches="tight")
+            if show:
+                plt.show()
+            if return_fig:
+                return plt.gcf()
+            if not (show or save):
+                return g
+            return None
+
+        # If not plot as facets, call boxplot to plot cell types on the x-axis.
+        else:
+            if level_order:
+                args_boxplot["hue_order"] = level_order
+                args_swarmplot["hue_order"] = level_order
+
+            _, ax = plt.subplots(figsize=figsize, dpi=dpi)
+
+            ax = sns.boxplot(
+                x="Cell type",
+                y=value_name,
+                hue=feature_name,
+                data=plot_df,
+                fliersize=1,
+                palette=palette,
+                ax=ax,
+                **args_boxplot,
+            )
+
+            # Currently disabled because the latest statsannotations does not support the latest seaborn.
+            # We had to drop the dependency.
+            # if draw_effects:
+            # pairs = [
+            # [(row["Cell Type"], row[feature_name]), (row["Cell Type"], "Control")]
+            # for _, row in credible_effects_df.iterrows()
+            # ]
+            # annot = Annotator(ax, pairs, data=plot_df, x="Cell type", y=value_name, hue=feature_name)
+            # annot.configure(test=None, loc="outside", color="red", line_height=0, verbose=False)
+            # annot.set_custom_annotations([row[feature_name] for _, row in credible_effects_df.iterrows()])
+            # annot.annotate()
+
+            if add_dots:
+                sns.swarmplot(
+                    x="Cell type",
+                    y=value_name,
+                    data=plot_df,
+                    hue=feature_name,
+                    ax=ax,
+                    dodge=True,
+                    palette="dark:black",
+                    **args_swarmplot,
+                )
+
+            cell_types = pd.unique(plot_df["Cell type"])
+            ax.set_xticklabels(cell_types, rotation=90)
+
+            if show_legend:
+                handles, labels = ax.get_legend_handles_labels()
+                handout = []
+                labelout = []
+                for h, l in zip(handles, labels, strict=False):
+                    if l not in labelout:
+                        labelout.append(l)
+                        handout.append(h)
+                ax.legend(
+                    handout,
+                    labelout,
+                    loc="upper left",
+                    bbox_to_anchor=(1, 1),
+                    ncol=1,
+                    title=feature_name,
+                )
+
+            if save:
+                plt.savefig(save, bbox_inches="tight")
+            if show:
+                plt.show()
+            if return_fig:
+                return plt.gcf()
+            if not (show or save):
+                return ax
+            return None
+
+    def plot_rel_abundance_dispersion_plot(  # pragma: no cover
+        self,
+        data: AnnData | MuData,
+        modality_key: str = "coda",
+        abundant_threshold: float | None = 0.9,
+        default_color: str | None = "Grey",
+        abundant_color: str | None = "Red",
+        label_cell_types: bool = True,
+        figsize: tuple[float, float] | None = None,
+        dpi: int | None = 100,
+        return_fig: bool | None = None,
+        ax: plt.Axes | None = None,
+        show: bool | None = None,
+        save: str | bool | None = None,
+    ) -> plt.Axes | plt.Figure | None:
+        """Plots total variance of relative abundance versus minimum relative abundance of all cell types for determination of a reference cell type.
+
+        If the count of the cell type is larger than 0 in more than abundant_threshold percent of all samples, the cell type will be marked in a different color.
+
+        Args:
+            data: AnnData or MuData object.
+            modality_key: If data is a MuData object, specify which modality to use.
+            abundant_threshold: Presence threshold for abundant cell types.
+            default_color: Bar color for all non-minimal cell types.
+            abundant_color: Bar color for cell types with abundant percentage larger than abundant_threshold.
+            label_cell_types: Label dots with cell type names.
+            figsize: Figure size.
+            dpi: Dpi setting.
+            ax: A matplotlib axes object. Only works if plotting a single component.
+
+        Returns:
+            A :class:`~matplotlib.axes.Axes` object
+
+        Examples:
+            >>> import pertpy as pt
+            >>> haber_cells = pt.dt.haber_2017_regions()
+            >>> sccoda = pt.tl.Sccoda()
+            >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
+                sample_identifier="batch", covariate_obs=["condition"])
+            >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
+            >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
+            >>> sccoda.plot_rel_abundance_dispersion_plot(mdata)
+
+        Preview:
+            .. image:: /_static/docstring_previews/sccoda_rel_abundance_dispersion_plot.png
+        """
+        if isinstance(data, MuData):
+            data = data[modality_key]
+        if isinstance(data, AnnData):
+            data = data
+        if ax is None:
+            _, ax = plt.subplots(figsize=figsize, dpi=dpi)
+
+        rel_abun = data.X / np.sum(data.X, axis=1, keepdims=True)
+
+        percent_zero = np.sum(data.X == 0, axis=0) / data.X.shape[0]
+        nonrare_ct = np.where(percent_zero < 1 - abundant_threshold)[0]
+
+        # select reference
+        cell_type_disp = np.var(rel_abun, axis=0) / np.mean(rel_abun, axis=0)
+
+        is_abundant = [x in nonrare_ct for x in range(data.X.shape[1])]
+
+        # Scatterplot
+        plot_df = pd.DataFrame(
+            {
+                "Total dispersion": cell_type_disp,
+                "Cell type": data.var.index,
+                "Presence": 1 - percent_zero,
+                "Is abundant": is_abundant,
+            }
+        )
+
+        if len(np.unique(plot_df["Is abundant"])) > 1:
+            palette = [default_color, abundant_color]
+        elif np.unique(plot_df["Is abundant"]) == [False]:
+            palette = [default_color]
+        else:
+            palette = [abundant_color]
+
+        ax = sns.scatterplot(
+            data=plot_df,
+            x="Presence",
+            y="Total dispersion",
+            hue="Is abundant",
+            palette=palette,
+            ax=ax,
+        )
+
+        # Text labels for abundant cell types
+
+        abundant_df = plot_df.loc[plot_df["Is abundant"], :]
+
+        def label_point(x, y, val, ax):
+            a = pd.concat({"x": x, "y": y, "val": val}, axis=1)
+            texts = [
+                ax.text(
+                    point["x"],
+                    point["y"],
+                    str(point["val"]),
+                )
+                for i, point in a.iterrows()
+            ]
+            adjust_text(texts)
+
+        if label_cell_types:
+            label_point(
+                abundant_df["Presence"],
+                abundant_df["Total dispersion"],
+                abundant_df["Cell type"],
+                plt.gca(),
+            )
+
+        ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1, title="Is abundant")
+
+        if save:
+            plt.savefig(save, bbox_inches="tight")
+        if show:
+            plt.show()
+        if return_fig:
+            return plt.gcf()
+        if not (show or save):
+            return ax
+        return None
+
+    def plot_draw_tree(  # pragma: no cover
+        self,
+        data: AnnData | MuData,
+        modality_key: str = "coda",
+        tree: str = "tree",  # Also type ete3.Tree. Omitted due to import errors
+        tight_text: bool | None = False,
+        show_scale: bool | None = False,
+        units: Literal["px", "mm", "in"] | None = "px",
+        figsize: tuple[float, float] | None = (None, None),
+        dpi: int | None = 100,
+        show: bool | None = True,
+        save: str | bool | None = None,
+    ) -> Tree | None:
+        """Plot a tree using input ete3 tree object.
+
+        Args:
+            data: AnnData object or MuData object.
+            modality_key: If data is a MuData object, specify which modality to use.
+            tree: A ete3 tree object or a str to indicate the tree stored in `.uns`.
+            tight_text: When False, boundaries of the text are approximated according to general font metrics,
+                producing slightly worse aligned text faces but improving the performance of tree visualization in scenes with a lot of text faces.
+            show_scale: Include the scale legend in the tree image or not.
+            show: If True, plot the tree inline. If false, return tree and tree_style objects.
+            file_name: Path to the output image file. Valid extensions are .SVG, .PDF, .PNG.
+                Output image can be saved whether show is True or not.
+            units: Unit of image sizes. “px”: pixels, “mm”: millimeters, “in”: inches.
+            figsize: Figure size.
+            dpi: Dots per inches.
+
+        Returns:
+            Depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`) or plot the tree inline (`show = False`)
+
+        Examples:
+            >>> import pertpy as pt
+            >>> adata = pt.dt.tasccoda_example()
+            >>> tasccoda = pt.tl.Tasccoda()
+            >>> mdata = tasccoda.load(
+            >>> adata, type="sample_level",
+            >>> levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
+            >>> key_added="lineage", add_level_name=True
+            >>> )
+            >>> mdata = tasccoda.prepare(
+            >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
+            >>> )
+            >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
+            >>> tasccoda.plot_draw_tree(mdata, tree="lineage")
+
+        Preview:
+            .. image:: /_static/docstring_previews/tasccoda_draw_tree.png
+        """
+        try:
+            from ete3 import CircleFace, NodeStyle, TextFace, Tree, TreeStyle, faces
+        except ImportError:
+            raise ImportError(
+                "To use tasccoda please install additional dependencies with `pip install pertpy[coda]`"
+            ) from None
+
+        if isinstance(data, MuData):
+            data = data[modality_key]
+        if isinstance(data, AnnData):
+            data = data
+        if isinstance(tree, str):
+            tree = data.uns[tree]
+
+        def my_layout(node):
+            text_face = TextFace(node.name, tight_text=tight_text)
+            faces.add_face_to_node(text_face, node, column=0, position="branch-right")
+
+        tree_style = TreeStyle()
+        tree_style.show_leaf_name = False
+        tree_style.layout_fn = my_layout
+        tree_style.show_scale = show_scale
+
+        if save is not None:
+            tree.render(save, tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi)  # type: ignore
+        if show:
+            return tree.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi)  # type: ignore
+        else:
+            return tree, tree_style
+
+    def plot_draw_effects(  # pragma: no cover
+        self,
+        data: AnnData | MuData,
+        covariate: str,
+        modality_key: str = "coda",
+        tree: str = "tree",  # Also type ete3.Tree. Omitted due to import errors
+        show_legend: bool | None = None,
+        show_leaf_effects: bool | None = False,
+        tight_text: bool | None = False,
+        show_scale: bool | None = False,
+        units: Literal["px", "mm", "in"] | None = "px",
+        figsize: tuple[float, float] | None = (None, None),
+        dpi: int | None = 100,
+        show: bool | None = True,
+        save: str | None = None,
+    ) -> Tree | None:
+        """Plot a tree with colored circles on the nodes indicating significant effects, together with bar plots indicating leaf-level significant effects.
+
+        Args:
+            data: AnnData object or MuData object.
+            covariate: The covariate whose effects should be plotted.
+            modality_key: If data is a MuData object, specify which modality to use.
+            tree: An ete3 tree object or a str to indicate the tree stored in `.uns`.
+            show_legend: Whether to show the legend for significant node effects.
+                Defaults to False if show_leaf_effects is True.
+            show_leaf_effects: If True, plot bar plots which indicate leaf-level significant effects.
+            tight_text: When False, boundaries of the text are approximated according to general font metrics,
+                producing slightly worse aligned text faces but improving the performance of tree visualization in scenes with a lot of text faces.
+            show_scale: Include the scale legend in the tree image or not.
+            show: If True, plot the tree inline. If False, return tree and tree_style objects.
+            save: Path to the output image file. Valid extensions are .SVG, .PDF, .PNG. The image can be saved whether show is True or not.
+            units: Unit of image sizes. "px": pixels, "mm": millimeters, "in": inches.
+            figsize: Figure size.
+            dpi: Dots per inch.
+
+        Returns:
+            Depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`)
+            or plots the tree inline (`show = True`).
+
+        Examples:
+            >>> import pertpy as pt
+            >>> adata = pt.dt.tasccoda_example()
+            >>> tasccoda = pt.tl.Tasccoda()
+            >>> mdata = tasccoda.load(
+            >>> adata, type="sample_level",
+            >>> levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
+            >>> key_added="lineage", add_level_name=True
+            >>> )
+            >>> mdata = tasccoda.prepare(
+            >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
+            >>> )
+            >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
+            >>> tasccoda.plot_draw_effects(mdata, covariate="Health[T.Inflamed]", tree="lineage")
+
+        Preview:
+            .. image:: /_static/docstring_previews/tasccoda_draw_effects.png
+        """
+        try:
+            from ete3 import CircleFace, NodeStyle, TextFace, Tree, TreeStyle, faces
+        except ImportError:
+            raise ImportError(
+                "To use tasccoda please install additional dependencies as `pip install pertpy[coda]`"
+            ) from None
+
+        if isinstance(data, MuData):
+            data = data[modality_key]
+        if isinstance(data, AnnData):
+            data = data
+        if show_legend is None:
+            show_legend = not show_leaf_effects
+        elif show_legend:
+            logger.info("Tree leaves and leaf effect bars won't be aligned when legend is shown!")
+
+        if isinstance(tree, str):
+            tree = data.uns[tree]
+        # Collapse tree singularities
+        tree2 = collapse_singularities_2(tree)
+
+        node_effs = data.uns["scCODA_params"]["node_df"].loc[(covariate + "_node",),].copy()
+        node_effs.index = node_effs.index.get_level_values("Node")
+
+        covariates = data.uns["scCODA_params"]["covariate_names"]
+        effect_dfs = [data.varm[f"effect_df_{cov}"] for cov in covariates]
+        eff_df = pd.concat(effect_dfs)
+        eff_df.index = pd.MultiIndex.from_product(
+            (covariates, data.var.index.tolist()),
+            names=["Covariate", "Cell Type"],
+        )
+        leaf_effs = eff_df.loc[(covariate,),].copy()
+        leaf_effs.index = leaf_effs.index.get_level_values("Cell Type")
+
+        # Add effect values
+        for n in tree2.traverse():
+            nstyle = NodeStyle()
+            nstyle["size"] = 0
+            n.set_style(nstyle)
+            if n.name in node_effs.index:
+                e = node_effs.loc[n.name, "Final Parameter"]
+                n.add_feature("node_effect", e)
+            else:
+                n.add_feature("node_effect", 0)
+            if n.name in leaf_effs.index:
+                e = leaf_effs.loc[n.name, "Effect"]
+                n.add_feature("leaf_effect", e)
+            else:
+                n.add_feature("leaf_effect", 0)
+
+        # Scale effect values to get nice node sizes
+        eff_max = np.max([np.abs(n.node_effect) for n in tree2.traverse()])
+        leaf_eff_max = np.max([np.abs(n.leaf_effect) for n in tree2.traverse()])
+
+        def my_layout(node):
+            text_face = TextFace(node.name, tight_text=tight_text)
+            text_face.margin_left = 10
+            faces.add_face_to_node(text_face, node, column=0, aligned=True)
+
+            # if node.is_leaf():
+            size = (np.abs(node.node_effect) * 10 / eff_max) if node.node_effect != 0 else 0
+            if np.sign(node.node_effect) == 1:
+                color = "blue"
+            elif np.sign(node.node_effect) == -1:
+                color = "red"
+            else:
+                color = "cyan"
+            if size != 0:
+                faces.add_face_to_node(CircleFace(radius=size, color=color), node, column=0)
+
+        tree_style = TreeStyle()
+        tree_style.show_leaf_name = False
+        tree_style.layout_fn = my_layout
+        tree_style.show_scale = show_scale
+        tree_style.draw_guiding_lines = True
+        tree_style.legend_position = 1
+
+        if show_legend:
+            tree_style.legend.add_face(TextFace("Effects"), column=0)
+            tree_style.legend.add_face(TextFace(" "), column=1)
+            for i in range(4, 0, -1):
+                tree_style.legend.add_face(
+                    CircleFace(
+                        float(f"{np.abs(eff_max) * 10 * i / (eff_max * 4):.2f}"),
+                        "red",
+                    ),
+                    column=0,
+                )
+                tree_style.legend.add_face(TextFace(f"{-eff_max * i / 4:.2f} "), column=0)
+                tree_style.legend.add_face(
+                    CircleFace(
+                        float(f"{np.abs(eff_max) * 10 * i / (eff_max * 4):.2f}"),
+                        "blue",
+                    ),
+                    column=1,
+                )
+                tree_style.legend.add_face(TextFace(f" {eff_max * i / 4:.2f}"), column=1)
+
+        if show_leaf_effects:
+            leaf_name = [node.name for node in tree2.traverse("postorder") if node.is_leaf()]
+            leaf_effs = leaf_effs.loc[leaf_name].reset_index()
+            palette = ["blue" if Effect > 0 else "red" for Effect in leaf_effs["Effect"].tolist()]
+
+            dir_path = Path.cwd()
+            dir_path = Path(dir_path / "tree_effect.png")
+            tree2.render(dir_path, tree_style=tree_style, units="in")
+            _, ax = plt.subplots(1, 2, figsize=(10, 10))
+            sns.barplot(data=leaf_effs, x="Effect", y="Cell Type", palette=palette, ax=ax[1])
+            img = mpimg.imread(dir_path)
+            ax[0].imshow(img)
+            ax[0].get_xaxis().set_visible(False)
+            ax[0].get_yaxis().set_visible(False)
+            ax[0].set_frame_on(False)
+
+            ax[1].get_yaxis().set_visible(False)
+            ax[1].spines["left"].set_visible(False)
+            ax[1].spines["right"].set_visible(False)
+            ax[1].spines["top"].set_visible(False)
+            plt.xlim(-leaf_eff_max, leaf_eff_max)
+            plt.subplots_adjust(wspace=0)
+
+            if save is not None:
+                plt.savefig(save)
+
+        if save is not None and not show_leaf_effects:
+            tree2.render(save, tree_style=tree_style, units=units)
+        if show:
+            if not show_leaf_effects:
+                return tree2.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi)
+        else:
+            if not show_leaf_effects:
+                return tree2, tree_style
+        return None
+
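With `show_leaf_effects=True`, the method renders the tree to a temporary PNG and places it next to a seaborn bar plot of the per-cell-type effects; passing `save` then writes the combined figure. A short usage sketch (not part of the package; `mdata` and the `"Health[T.Inflamed]"` effect name are assumptions carried over from the docstring example above):

```python
# Sketch only: assumes mdata holds fitted tascCODA results and a tree under the "lineage" key
import pertpy as pt

tasccoda = pt.tl.Tasccoda()

# Node-level effect circles on the tree plus leaf-level effect bars, written to disk
tasccoda.plot_draw_effects(
    mdata,
    covariate="Health[T.Inflamed]",  # effect name from the fitted model (assumed)
    tree="lineage",
    show_leaf_effects=True,
    save="tree_with_effects.png",
)
```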
+    def plot_effects_umap(  # pragma: no cover
+        self,
+        mdata: MuData,
+        effect_name: str | list | None,
+        cluster_key: str,
+        modality_key_1: str = "rna",
+        modality_key_2: str = "coda",
+        color_map: Colormap | str | None = None,
+        palette: str | Sequence[str] | None = None,
+        return_fig: bool | None = None,
+        ax: Axes = None,
+        show: bool = None,
+        save: str | bool | None = None,
+        **kwargs,
+    ) -> plt.Axes | plt.Figure | None:
+        """Plot a UMAP visualization colored by effect strength.
+
+        Effect results in .varm of the aggregated sample-level AnnData (default is data['coda']) are assigned to the cell-level AnnData
+        (default is data['rna']) depending on the cluster they were assigned to.
+
+        Args:
+            mdata: MuData object.
+            effect_name: The name of the effect results in .varm of the aggregated sample-level AnnData to plot.
+            cluster_key: The cluster information in .obs of the cell-level AnnData (default is data['rna']),
+                used to assign cell types' effects to the original cells.
+            modality_key_1: Key to the cell-level AnnData in the MuData object.
+            modality_key_2: Key to the aggregated sample-level AnnData object in the MuData object.
+            show: Whether to display the figure or return axis.
+            ax: A matplotlib axes object. Only works if plotting a single component.
+            **kwargs: All other keyword arguments are passed to `scanpy.plot.umap()`.
+
+        Returns:
+            If `show==False`, a :class:`~matplotlib.axes.Axes` or a list of it.
+
+        Examples:
+            >>> import pertpy as pt
+            >>> import scanpy as sc
+            >>> import schist
+            >>> adata = pt.dt.haber_2017_regions()
+            >>> sc.pp.neighbors(adata)
+            >>> schist.inference.nested_model(adata, n_init=100, random_seed=5678)
+            >>> tasccoda_model = pt.tl.Tasccoda()
+            >>> tasccoda_data = tasccoda_model.load(adata, type="cell_level",
+            >>> cell_type_identifier="nsbm_level_1",
+            >>> sample_identifier="batch", covariate_obs=["condition"],
+            >>> levels_orig=["nsbm_level_4", "nsbm_level_3", "nsbm_level_2", "nsbm_level_1"],
+            >>> add_level_name=True)
+            >>> tasccoda_model.prepare(
+            >>> tasccoda_data,
+            >>> modality_key="coda",
+            >>> reference_cell_type="18",
+            >>> formula="condition",
+            >>> pen_args={"phi": 0, "lambda_1": 3.5},
+            >>> tree_key="tree"
+            >>> )
+            >>> tasccoda_model.run_nuts(
+            ...     tasccoda_data, modality_key="coda", rng_key=1234, num_samples=10000, num_warmup=1000
+            ... )
+            >>> sc.tl.umap(tasccoda_data["rna"])
+            >>> tasccoda_model.plot_effects_umap(tasccoda_data,
+            >>> effect_name=["effect_df_condition[T.Salmonella]",
+            >>> "effect_df_condition[T.Hpoly.Day3]",
+            >>> "effect_df_condition[T.Hpoly.Day10]"],
+            >>> cluster_key="nsbm_level_1",
+            >>> )
+
+        Preview:
+            .. image:: /_static/docstring_previews/tasccoda_effects_umap.png
+        """
+        # TODO: Add effect_name parameter and cluster_key and test the example
+        data_rna = mdata[modality_key_1]
+        data_coda = mdata[modality_key_2]
+        if isinstance(effect_name, str):
+            effect_name = [effect_name]
+        for _, effect in enumerate(effect_name):
+            data_rna.obs[effect] = [data_coda.varm[effect].loc[f"{c}", "Effect"] for c in data_rna.obs[cluster_key]]
+        if kwargs.get("vmin"):
+            vmin = kwargs["vmin"]
+            kwargs.pop("vmin")
+        else:
+            vmin = min(data_rna.obs[effect].min() for _, effect in enumerate(effect_name))
+        if kwargs.get("vmax"):
+            vmax = kwargs["vmax"]
+            kwargs.pop("vmax")
+        else:
+            vmax = max(data_rna.obs[effect].max() for _, effect in enumerate(effect_name))
+
+        return sc.pl.umap(
+            data_rna,
+            color=effect_name,
+            vmax=vmax,
+            vmin=vmin,
+            palette=palette,
+            color_map=color_map,
+            return_fig=return_fig,
+            ax=ax,
+            show=show,
+            save=save,
+            **kwargs,
+        )
+
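The core of `plot_effects_umap` is a broadcast of cell-type-level effect sizes onto individual cells via `cluster_key`, after which the UMAP is simply colored by the new `.obs` columns. A minimal sketch of that mapping (not part of the package; the modality names, cluster key, and effect name are the assumptions used in the docstring example above):

```python
# Sketch only: assumes mdata["coda"].varm holds an effect DataFrame with an "Effect" column per cell type
import scanpy as sc

effect = "effect_df_condition[T.Salmonella]"  # assumed effect name
rna, coda = mdata["rna"], mdata["coda"]

# Broadcast the per-cell-type effect onto each cell through its cluster label
rna.obs[effect] = [coda.varm[effect].loc[str(ct), "Effect"] for ct in rna.obs["nsbm_level_1"]]

sc.pl.umap(rna, color=effect, color_map="coolwarm")
```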
 
 def get_a(
-    tree: tt.
+    tree: tt.core.ToyTree,
 ) -> tuple[np.ndarray, int]:
     """Calculate ancestor matrix from a toytree tree
 
@@ -1154,7 +2275,7 @@ def get_a(
     return A, n_nodes - 1
 
 
-def collapse_singularities(tree: tt.
+def collapse_singularities(tree: tt.core.ToyTree) -> tt.core.ToyTree:
     """Collapses (deletes) nodes in a toytree tree that are singularities (have only one child).
 
     Args:
@@ -1242,7 +2363,7 @@ def df2newick(df: pd.DataFrame, levels: list[str], inner_label: bool = True) ->
 
 
 def get_a_2(
-    tree:
+    tree: Tree,
     leaf_order: list[str] = None,
     node_order: list[str] = None,
 ) -> tuple[np.ndarray, int]:
@@ -1263,6 +2384,13 @@ def get_a_2(
         T
             number of nodes in the tree, excluding the root node
     """
+    try:
+        import ete3 as ete
+    except ImportError:
+        raise ImportError(
+            "To use tasccoda please install additional dependencies as `pip install pertpy[coda]`"
+        ) from None
+
     n_tips = len(tree.get_leaves())
     n_nodes = len(tree.get_descendants())
 
@@ -1292,7 +2420,7 @@ def get_a_2(
     return A_, n_nodes
 
 
-def collapse_singularities_2(tree:
+def collapse_singularities_2(tree: Tree) -> Tree:
     """Collapses (deletes) nodes in a ete3 tree that are singularities (have only one child).
 
     Args:
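Both `get_a` and `get_a_2` build the tip-by-node ancestor indicator matrix that tascCODA uses to spread node-level effects onto leaf cell types. A rough standalone illustration of that construction for a small ete3 tree (not pertpy's implementation; the toy tree and the orderings are assumptions):

```python
# Sketch only: build an (n_tips x n_nodes) ancestor indicator matrix from a small ete3 tree
import numpy as np
from ete3 import Tree

tree = Tree("((A,B)AB,(C,D)CD)root;", format=1)

leaves = [leaf.name for leaf in tree.get_leaves()]
nodes = [node.name for node in tree.get_descendants()]  # every node except the root

A = np.zeros((len(leaves), len(nodes)))
for j, name in enumerate(nodes):
    node = tree.search_nodes(name=name)[0]
    for leaf in node.get_leaves():  # a node is an ancestor of every leaf below it (a leaf counts for itself)
        A[leaves.index(leaf.name), j] = 1

print(A.astype(int))  # rows follow get_leaves() order, columns follow get_descendants() order
```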
@@ -1327,10 +2455,10 @@ def linkage_to_newick(
 
     def build_newick(node, newick, parentdist, leaf_names):
         if node.is_leaf():
-            return f"{leaf_names[node.id]}:{(parentdist - node.dist)/2}{newick}"
+            return f"{leaf_names[node.id]}:{(parentdist - node.dist) / 2}{newick}"
         else:
             if len(newick) > 0:
-                newick = f"):{(parentdist - node.dist)/2}{newick}"
+                newick = f"):{(parentdist - node.dist) / 2}{newick}"
             else:
                 newick = ");"
             newick = build_newick(node.get_left(), newick, node.dist, leaf_names)
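The `build_newick` helper above walks a `scipy.cluster.hierarchy` tree and emits a Newick string, assigning each branch half of the parent-child distance. A self-contained sketch of the same recursion, independent of pertpy and using only scipy:

```python
# Sketch only: convert a scipy linkage matrix into a Newick string, mirroring the recursion shown above
import numpy as np
from scipy.cluster.hierarchy import linkage, to_tree


def linkage_to_newick_sketch(Z: np.ndarray, labels: list[str]) -> str:
    root = to_tree(Z, rd=False)

    def build_newick(node, newick, parentdist, leaf_names):
        if node.is_leaf():
            return f"{leaf_names[node.id]}:{(parentdist - node.dist) / 2}{newick}"
        if len(newick) > 0:
            newick = f"):{(parentdist - node.dist) / 2}{newick}"
        else:
            newick = ");"
        newick = build_newick(node.get_left(), newick, node.dist, leaf_names)
        newick = build_newick(node.get_right(), f",{newick}", node.dist, leaf_names)
        return "(" + newick

    return build_newick(root, "", root.dist, labels)


Z = linkage(np.random.default_rng(0).normal(size=(5, 3)), method="average")
print(linkage_to_newick_sketch(Z, [f"ct_{i}" for i in range(5)]))
```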
@@ -1363,14 +2491,14 @@ def import_tree(
 
     Args:
         data: A tascCODA-compatible data object.
-        modality_1: If `data` is MuData,
-        modality_2: If `data` is MuData,
-        dendrogram_key: Key to the scanpy.tl.dendrogram result in `.uns` of original cell level anndata object.
-        levels_orig: List that indicates which columns in `.obs` of the original data correspond to tree levels. The list must begin with the root level, and end with the leaf level.
-        levels_agg: List that indicates which columns in `.var` of the aggregated data correspond to tree levels. The list must begin with the root level, and end with the leaf level.
-        add_level_name: If True, internal nodes in the tree will be named as "{level_name}_{node_name}" instead of just {level_name}.
-        key_added: If not specified, the tree is stored in .uns[‘tree’]. If `data` is AnnData, save tree in `data`.
-
+        modality_1: If `data` is MuData, specify the modality name to the original cell level anndata object.
+        modality_2: If `data` is MuData, specify the modality name to the aggregated level anndata object.
+        dendrogram_key: Key to the scanpy.tl.dendrogram result in `.uns` of original cell level anndata object.
+        levels_orig: List that indicates which columns in `.obs` of the original data correspond to tree levels. The list must begin with the root level, and end with the leaf level.
+        levels_agg: List that indicates which columns in `.var` of the aggregated data correspond to tree levels. The list must begin with the root level, and end with the leaf level.
+        add_level_name: If True, internal nodes in the tree will be named as "{level_name}_{node_name}" instead of just {level_name}.
+        key_added: If not specified, the tree is stored in .uns[‘tree’]. If `data` is AnnData, save tree in `data`.
+            If `data` is MuData, save tree in data[modality_2].
 
     Returns:
         Updates data with the following:
@@ -1379,15 +2507,22 @@ def import_tree(
 
     tree: A ete3 tree object.
     """
+    try:
+        import ete3 as ete
+    except ImportError:
+        raise ImportError(
+            "To use tasccoda please install additional dependencies as `pip install pertpy[coda]`"
+        ) from None
+
     if isinstance(data, MuData):
         try:
             data_1 = data[modality_1]
             data_2 = data[modality_2]
         except KeyError as name:
-
+            logger.error(f"No {name} slot in MuData")
             raise
         except IndexError:
-
+            logger.error("Please specify modality_1 and modality_2 to indicate modalities in MuData")
             raise
     else:
         data_1 = data
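`import_tree` assembles its ete3 tree from the level columns listed in `levels_orig`/`levels_agg`, going through a Newick string (the `df2newick` helper referenced in one of the hunk headers above). A rough, self-contained sketch of that level-columns-to-Newick idea (this is not pertpy's `df2newick`; the column names and the helper are illustrative assumptions):

```python
# Sketch only: collapse hierarchical level columns into a Newick string with named internal nodes
import pandas as pd

df = pd.DataFrame(
    {
        "Major_l1": ["Immune", "Immune", "Epithelial"],
        "Major_l2": ["Lymphoid", "Myeloid", "Secretory"],
        "Cluster": ["T", "Mono", "Goblet"],
    }
)


def levels_to_newick(df: pd.DataFrame, levels: list[str]) -> str:
    def recurse(sub: pd.DataFrame, depth: int) -> str:
        if depth == len(levels) - 1:
            return ",".join(sub[levels[depth]].unique())
        children = [
            f"({recurse(group, depth + 1)}){name}"
            for name, group in sub.groupby(levels[depth], sort=False)
        ]
        return ",".join(children)

    return f"({recurse(df, 0)});"


print(levels_to_newick(df, ["Major_l1", "Major_l2", "Cluster"]))
# (((T)Lymphoid,(Mono)Myeloid)Immune,((Goblet)Secretory)Epithelial);
```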
@@ -1443,68 +2578,54 @@ def from_scanpy(
 
     The anndata object needs to have a column in adata.obs that contains the cell type assignment.
     Further, it must contain one column or a set of columns (e.g. subject id, treatment, disease status) that uniquely identify each (statistical) sample.
-    Further covariates (e.g. subject age) can either be specified via
+    Further covariates (e.g. subject age) can either be specified via additional column names in adata.obs, a key in adata.uns, or as a separate DataFrame.
 
-    NOTE: The order of samples in the returned dataset is determined by the first
+    NOTE: The order of samples in the returned dataset is determined by the first occurrence of cells from each sample in `adata`
 
     Args:
         adata: An anndata object from scanpy
         cell_type_identifier: column name in adata.obs that specifies the cell types
         sample_identifier: column name or list of column names in adata.obs that uniquely identify each sample
        covariate_uns: key for adata.uns, where covariate values are stored
-        covariate_obs: list of column names in adata.obs, where covariate values are stored.
+        covariate_obs: list of column names in adata.obs, where covariate values are stored.
+            Note: If covariate values are not unique for a value of sample_identifier, this covariate will be skipped.
        covariate_df: DataFrame with covariates
 
     Returns:
         AnnData: A data set with cells aggregated to the (sample x cell type) level
    """
-    if isinstance(sample_identifier, str)
-
-
-    if covariate_obs:
-        covariate_obs += [i for i in sample_identifier if i not in covariate_obs]
-    else:
-        covariate_obs = sample_identifier  # type: ignore
+    sample_identifier = [sample_identifier] if isinstance(sample_identifier, str) else sample_identifier
+    covariate_obs = list(set(covariate_obs or []) | set(sample_identifier))
 
-    # join sample identifiers
    if isinstance(sample_identifier, list):
        adata.obs["scCODA_sample_id"] = adata.obs[sample_identifier].agg("-".join, axis=1)
        sample_identifier = "scCODA_sample_id"
 
-    # get cell type counts
    groups = adata.obs.value_counts([sample_identifier, cell_type_identifier])
-
-
-
-    # get covariates from different sources
-    covariate_df_ = pd.DataFrame(index=count_data.index)
-
-    if covariate_df is None and covariate_obs is None and covariate_uns is None:
-        print("No covariate information specified!")
+    ct_count_data = groups.unstack(level=cell_type_identifier).fillna(0)
+    covariate_df_ = pd.DataFrame(index=ct_count_data.index)
 
    if covariate_uns is not None:
-        covariate_df_uns = pd.DataFrame(adata.uns[covariate_uns])
-        covariate_df_ = pd.concat(
+        covariate_df_uns = pd.DataFrame(adata.uns[covariate_uns], index=ct_count_data.index)
+        covariate_df_ = pd.concat([covariate_df_, covariate_df_uns], axis=1)
 
-    if covariate_obs
-
-
-
+    if covariate_obs:
+        unique_check = adata.obs.groupby(sample_identifier).nunique()
+        for c in covariate_obs.copy():
+            if unique_check[c].max() != 1:
+                logger.warning(f"Covariate {c} has non-unique values for batch! Skipping...")
                 covariate_obs.remove(c)
-
-
-
+        if covariate_obs:
+            covariate_df_obs = adata.obs.groupby(sample_identifier).first()[covariate_obs]
+            covariate_df_ = pd.concat([covariate_df_, covariate_df_obs], axis=1)
 
    if covariate_df is not None:
-        if set(covariate_df.index) != set(
-            raise ValueError("
-
-        covariate_df_ = pd.concat((covariate_df_, covs_ord), axis=1)
+        if set(covariate_df.index) != set(ct_count_data.index):
+            raise ValueError("Mismatch between sample names in anndata and covariate_df!")
+        covariate_df_ = pd.concat([covariate_df_, covariate_df.reindex(ct_count_data.index)], axis=1)
 
-
-
-    # create var (number of cells for each type as only column)
-    var_dat = count_data.sum(axis=0).rename("n_cells").to_frame()
+    var_dat = ct_count_data.sum().rename("n_cells").to_frame()
    var_dat.index = var_dat.index.astype(str)
+    covariate_df_.index = covariate_df_.index.astype(str)
 
-    return AnnData(X=
+    return AnnData(X=ct_count_data.values, var=var_dat, obs=covariate_df_)