pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- pertpy/__init__.py +4 -2
- pertpy/data/__init__.py +66 -1
- pertpy/data/_dataloader.py +28 -26
- pertpy/data/_datasets.py +261 -92
- pertpy/metadata/__init__.py +6 -0
- pertpy/metadata/_cell_line.py +795 -0
- pertpy/metadata/_compound.py +128 -0
- pertpy/metadata/_drug.py +238 -0
- pertpy/metadata/_look_up.py +569 -0
- pertpy/metadata/_metadata.py +70 -0
- pertpy/metadata/_moa.py +125 -0
- pertpy/plot/__init__.py +0 -13
- pertpy/preprocessing/__init__.py +2 -0
- pertpy/preprocessing/_guide_rna.py +89 -6
- pertpy/tools/__init__.py +48 -15
- pertpy/tools/_augur.py +329 -32
- pertpy/tools/_cinemaot.py +145 -6
- pertpy/tools/_coda/_base_coda.py +1237 -116
- pertpy/tools/_coda/_sccoda.py +66 -36
- pertpy/tools/_coda/_tasccoda.py +46 -39
- pertpy/tools/_dialogue.py +180 -77
- pertpy/tools/_differential_gene_expression/__init__.py +20 -0
- pertpy/tools/_differential_gene_expression/_base.py +657 -0
- pertpy/tools/_differential_gene_expression/_checks.py +41 -0
- pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
- pertpy/tools/_differential_gene_expression/_edger.py +125 -0
- pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
- pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
- pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
- pertpy/tools/_distances/_distance_tests.py +29 -24
- pertpy/tools/_distances/_distances.py +584 -98
- pertpy/tools/_enrichment.py +460 -0
- pertpy/tools/_kernel_pca.py +1 -1
- pertpy/tools/_milo.py +406 -49
- pertpy/tools/_mixscape.py +677 -55
- pertpy/tools/_perturbation_space/_clustering.py +10 -3
- pertpy/tools/_perturbation_space/_comparison.py +112 -0
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
- pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
- pertpy/tools/_perturbation_space/_simple.py +52 -11
- pertpy/tools/_scgen/__init__.py +1 -1
- pertpy/tools/_scgen/_base_components.py +2 -3
- pertpy/tools/_scgen/_scgen.py +706 -0
- pertpy/tools/_scgen/_utils.py +3 -5
- pertpy/tools/decoupler_LICENSE +674 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
- pertpy-0.8.0.dist-info/RECORD +57 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
- pertpy/plot/_augur.py +0 -234
- pertpy/plot/_cinemaot.py +0 -81
- pertpy/plot/_coda.py +0 -1001
- pertpy/plot/_dialogue.py +0 -91
- pertpy/plot/_guide_rna.py +0 -82
- pertpy/plot/_milopy.py +0 -284
- pertpy/plot/_mixscape.py +0 -594
- pertpy/plot/_scgen.py +0 -337
- pertpy/tools/_differential_gene_expression.py +0 -99
- pertpy/tools/_metadata/__init__.py +0 -0
- pertpy/tools/_metadata/_cell_line.py +0 -613
- pertpy/tools/_metadata/_look_up.py +0 -342
- pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
- pertpy/tools/_scgen/_jax_scgen.py +0 -370
- pertpy-0.6.0.dist-info/RECORD +0 -50
- /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_coda/_base_coda.py
CHANGED
@@ -1,17 +1,24 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
from abc import ABC, abstractmethod
|
4
|
-
from
|
4
|
+
from pathlib import Path
|
5
|
+
from typing import TYPE_CHECKING, Literal, Optional, Union
|
5
6
|
|
6
7
|
import arviz as az
|
7
|
-
import ete3 as ete
|
8
8
|
import jax.numpy as jnp
|
9
|
+
import matplotlib.pyplot as plt
|
9
10
|
import numpy as np
|
10
11
|
import pandas as pd
|
11
12
|
import patsy as pt
|
13
|
+
import scanpy as sc
|
14
|
+
import seaborn as sns
|
15
|
+
from adjustText import adjust_text
|
12
16
|
from anndata import AnnData
|
13
|
-
from jax import random
|
14
|
-
from
|
17
|
+
from jax import config, random
|
18
|
+
from lamin_utils import logger
|
19
|
+
from matplotlib import cm, rcParams
|
20
|
+
from matplotlib import image as mpimg
|
21
|
+
from matplotlib.colors import ListedColormap
|
15
22
|
from mudata import MuData
|
16
23
|
from numpyro.infer import HMC, MCMC, NUTS, initialization
|
17
24
|
from rich import box, print
|
@@ -20,10 +27,15 @@ from rich.table import Table
|
|
20
27
|
from scipy.cluster import hierarchy as sp_hierarchy
|
21
28
|
|
22
29
|
if TYPE_CHECKING:
|
30
|
+
from collections.abc import Sequence
|
31
|
+
|
23
32
|
import numpyro as npy
|
24
33
|
import toytree as tt
|
25
|
-
from
|
34
|
+
from ete3 import Tree
|
26
35
|
from jax._src.typing import Array
|
36
|
+
from matplotlib.axes import Axes
|
37
|
+
from matplotlib.colors import Colormap
|
38
|
+
from matplotlib.figure import Figure
|
27
39
|
|
28
40
|
config.update("jax_enable_x64", True)
|
29
41
|
|
@@ -99,9 +111,9 @@ class CompositionalModel2(ABC):
|
|
99
111
|
Categorical covariates are handled automatically, with the covariate value of the first sample being used as the reference category.
|
100
112
|
To set a different level as the base category for a categorical covariate, use "C(<CovariateName>, Treatment('<ReferenceLevelName>'))"
|
101
113
|
reference_cell_type: Column name that sets the reference cell type.
|
102
|
-
Reference the name of a column. If "automatic", the cell type with the lowest dispersion in relative abundance that is present in at least 90% of samlpes will be chosen.
|
114
|
+
Reference the name of a column. If "automatic", the cell type with the lowest dispersion in relative abundance that is present in at least 90% of samlpes will be chosen.
|
103
115
|
automatic_reference_absence_threshold: If using reference_cell_type = "automatic", determine the maximum fraction of zero entries for a cell type
|
104
|
-
to be considered as a possible reference cell type.
|
116
|
+
to be considered as a possible reference cell type.
|
105
117
|
|
106
118
|
Returns:
|
107
119
|
AnnData object that is ready for CODA models.
|
@@ -137,7 +149,7 @@ class CompositionalModel2(ABC):
|
|
137
149
|
ref_index = np.where(cell_type_disp == min_var)[0][0]
|
138
150
|
|
139
151
|
ref_cell_type = cell_types[ref_index]
|
140
|
-
|
152
|
+
logger.info(f"Automatic reference selection! Reference cell type set to {ref_cell_type}")
|
141
153
|
|
142
154
|
# Column name as reference cell type
|
143
155
|
elif reference_cell_type in cell_types:
|
@@ -149,7 +161,7 @@ class CompositionalModel2(ABC):
|
|
149
161
|
|
150
162
|
# Add pseudocount if zeroes are present.
|
151
163
|
if np.count_nonzero(sample_adata.X) != np.size(sample_adata.X):
|
152
|
-
|
164
|
+
logger.info("Zero counts encountered in data! Added a pseudocount of 0.5.")
|
153
165
|
sample_adata.X[sample_adata.X == 0] = 0.5
|
154
166
|
|
155
167
|
sample_adata.obsm["sample_counts"] = np.sum(sample_adata.X, axis=1)
|
@@ -179,7 +191,7 @@ class CompositionalModel2(ABC):
|
|
179
191
|
self,
|
180
192
|
sample_adata: AnnData,
|
181
193
|
kernel: npy.infer.mcmc.MCMCKernel,
|
182
|
-
rng_key: Array
|
194
|
+
rng_key: Array,
|
183
195
|
copy: bool = False,
|
184
196
|
*args,
|
185
197
|
**kwargs,
|
@@ -190,7 +202,7 @@ class CompositionalModel2(ABC):
|
|
190
202
|
sample_adata: anndata object with cell counts as sample_adata.X and covariates saved in sample_adata.obs.
|
191
203
|
kernel: A `numpyro.infer.mcmc.MCMCKernel` object
|
192
204
|
rng_key: The rng state used. If None, a random state will be selected
|
193
|
-
copy: Return a copy instead of writing to adata.
|
205
|
+
copy: Return a copy instead of writing to adata.
|
194
206
|
args: Passed to `numpyro.infer.mcmc.MCMC`
|
195
207
|
kwargs: Passed to `numpyro.infer.mcmc.MCMC`
|
196
208
|
|
@@ -226,13 +238,13 @@ class CompositionalModel2(ABC):
|
|
226
238
|
|
227
239
|
acc_rate = np.array(self.mcmc.last_state.mean_accept_prob)
|
228
240
|
if acc_rate < 0.6:
|
229
|
-
|
230
|
-
f"
|
241
|
+
logger.warning(
|
242
|
+
f"Acceptance rate unusually low ({acc_rate} < 0.5)! Results might be incorrect! "
|
231
243
|
f"Please check feasibility of results and re-run the sampling step with a different rng_key if necessary."
|
232
244
|
)
|
233
245
|
if acc_rate > 0.95:
|
234
|
-
|
235
|
-
f"
|
246
|
+
logger.warning(
|
247
|
+
f"Acceptance rate unusually high ({acc_rate} > 0.95)! Results might be incorrect! "
|
236
248
|
f"Please check feasibility of results and re-run the sampling step with a different rng_key if necessary."
|
237
249
|
)
|
238
250
|
|
@@ -275,11 +287,11 @@ class CompositionalModel2(ABC):
|
|
275
287
|
|
276
288
|
Args:
|
277
289
|
data: AnnData object or MuData object.
|
278
|
-
modality_key: If data is a MuData object, specify which modality to use.
|
279
|
-
num_samples: Number of sampled values after burn-in.
|
280
|
-
num_warmup: Number of burn-in (warmup) samples.
|
281
|
-
rng_key: The rng state used.
|
282
|
-
copy: Return a copy instead of writing to adata.
|
290
|
+
modality_key: If data is a MuData object, specify which modality to use.
|
291
|
+
num_samples: Number of sampled values after burn-in.
|
292
|
+
num_warmup: Number of burn-in (warmup) samples.
|
293
|
+
rng_key: The rng state used.
|
294
|
+
copy: Return a copy instead of writing to adata.
|
283
295
|
|
284
296
|
Returns:
|
285
297
|
Calls `self.__run_mcmc`
|
@@ -288,14 +300,14 @@ class CompositionalModel2(ABC):
|
|
288
300
|
try:
|
289
301
|
sample_adata = data[modality_key]
|
290
302
|
except IndexError:
|
291
|
-
|
303
|
+
logger.error("When data is a MuData object, modality_key must be specified!")
|
292
304
|
raise
|
293
305
|
if isinstance(data, AnnData):
|
294
306
|
sample_adata = data
|
295
307
|
if copy:
|
296
308
|
sample_adata = sample_adata.copy()
|
297
309
|
|
298
|
-
rng_key_array = random.
|
310
|
+
rng_key_array = random.key(rng_key)
|
299
311
|
sample_adata.uns["scCODA_params"]["mcmc"]["rng_key"] = np.array(rng_key_array)
|
300
312
|
|
301
313
|
# Set up NUTS kernel
|
@@ -328,14 +340,13 @@ class CompositionalModel2(ABC):
|
|
328
340
|
|
329
341
|
Args:
|
330
342
|
data: AnnData object or MuData object.
|
331
|
-
modality_key: If data is a MuData object, specify which modality to use.
|
332
|
-
num_samples: Number of sampled values after burn-in.
|
333
|
-
num_warmup: Number of burn-in (warmup) samples.
|
334
|
-
rng_key: The rng state used. If None, a random state will be selected.
|
335
|
-
copy: Return a copy instead of writing to adata.
|
343
|
+
modality_key: If data is a MuData object, specify which modality to use.
|
344
|
+
num_samples: Number of sampled values after burn-in.
|
345
|
+
num_warmup: Number of burn-in (warmup) samples.
|
346
|
+
rng_key: The rng state used. If None, a random state will be selected.
|
347
|
+
copy: Return a copy instead of writing to adata.
|
336
348
|
|
337
349
|
Examples:
|
338
|
-
Example with scCODA:
|
339
350
|
>>> import pertpy as pt
|
340
351
|
>>> haber_cells = pt.dt.haber_2017_regions()
|
341
352
|
>>> sccoda = pt.tl.Sccoda()
|
@@ -348,7 +359,7 @@ class CompositionalModel2(ABC):
|
|
348
359
|
try:
|
349
360
|
sample_adata = data[modality_key]
|
350
361
|
except IndexError:
|
351
|
-
|
362
|
+
logger.error("When data is a MuData object, modality_key must be specified!")
|
352
363
|
raise
|
353
364
|
if isinstance(data, AnnData):
|
354
365
|
sample_adata = data
|
@@ -358,10 +369,10 @@ class CompositionalModel2(ABC):
|
|
358
369
|
# Set rng key if needed
|
359
370
|
if rng_key is None:
|
360
371
|
rng = np.random.default_rng()
|
361
|
-
rng_key = random.
|
372
|
+
rng_key = random.key(rng.integers(0, 10000))
|
362
373
|
sample_adata.uns["scCODA_params"]["mcmc"]["rng_key"] = rng_key
|
363
374
|
else:
|
364
|
-
rng_key = random.
|
375
|
+
rng_key = random.key(rng_key)
|
365
376
|
|
366
377
|
# Set up HMC kernel
|
367
378
|
sample_adata = self.set_init_mcmc_states(
|
@@ -387,7 +398,7 @@ class CompositionalModel2(ABC):
|
|
387
398
|
|
388
399
|
Args:
|
389
400
|
sample_adata: Anndata object with cell counts as sample_adata.X and covariates saved in sample_adata.obs.
|
390
|
-
est_fdr: Desired FDR value.
|
401
|
+
est_fdr: Desired FDR value.
|
391
402
|
args: Passed to ``az.summary``
|
392
403
|
kwargs: Passed to ``az.summary``
|
393
404
|
|
@@ -423,7 +434,6 @@ class CompositionalModel2(ABC):
|
|
423
434
|
- Is credible: Boolean indicator whether effect is credible
|
424
435
|
|
425
436
|
Examples:
|
426
|
-
Example with scCODA:
|
427
437
|
>>> import pertpy as pt
|
428
438
|
>>> haber_cells = pt.dt.haber_2017_regions()
|
429
439
|
>>> sccoda = pt.tl.Sccoda()
|
@@ -433,7 +443,6 @@ class CompositionalModel2(ABC):
|
|
433
443
|
>>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
|
434
444
|
>>> intercept_df, effect_df = sccoda.summary_prepare(mdata["coda"])
|
435
445
|
"""
|
436
|
-
# Get model and effect selection types
|
437
446
|
select_type = sample_adata.uns["scCODA_params"]["select_type"]
|
438
447
|
model_type = sample_adata.uns["scCODA_params"]["model_type"]
|
439
448
|
|
@@ -548,7 +557,11 @@ class CompositionalModel2(ABC):
|
|
548
557
|
intercept_df = intercept_df.loc[:, ["final_parameter", hdis[0], hdis[1], "sd", "expected_sample"]].copy()
|
549
558
|
intercept_df = intercept_df.rename(
|
550
559
|
columns=dict(
|
551
|
-
zip(
|
560
|
+
zip(
|
561
|
+
intercept_df.columns,
|
562
|
+
["Final Parameter", hdis_new[0], hdis_new[1], "SD", "Expected Sample"],
|
563
|
+
strict=False,
|
564
|
+
)
|
552
565
|
)
|
553
566
|
)
|
554
567
|
|
@@ -561,6 +574,7 @@ class CompositionalModel2(ABC):
|
|
561
574
|
zip(
|
562
575
|
effect_df.columns,
|
563
576
|
["Effect", "Median", hdis_new[0], hdis_new[1], "SD", "Expected Sample", "log2-fold change"],
|
577
|
+
strict=False,
|
564
578
|
)
|
565
579
|
)
|
566
580
|
)
|
@@ -581,6 +595,7 @@ class CompositionalModel2(ABC):
|
|
581
595
|
"Expected Sample",
|
582
596
|
"log2-fold change",
|
583
597
|
],
|
598
|
+
strict=False,
|
584
599
|
)
|
585
600
|
)
|
586
601
|
)
|
@@ -594,6 +609,7 @@ class CompositionalModel2(ABC):
|
|
594
609
|
zip(
|
595
610
|
node_df.columns,
|
596
611
|
["Final Parameter", "Median", hdis_new[0], hdis_new[1], "SD", "Delta", "Is credible"],
|
612
|
+
strict=False,
|
597
613
|
)
|
598
614
|
) # type: ignore
|
599
615
|
) # type: ignore
|
@@ -622,8 +638,8 @@ class CompositionalModel2(ABC):
|
|
622
638
|
effect_df: Effect summary, see ``summary_prepare``
|
623
639
|
model_type: String indicating the model type ("classic" or "tree_agg")
|
624
640
|
select_type: String indicating the type of spike_and_slab selection ("spikeslab" or "sslasso")
|
625
|
-
target_fdr: Desired FDR value.
|
626
|
-
node_df: If using tree aggregation, the node-level effect DataFrame must be passed.
|
641
|
+
target_fdr: Desired FDR value.
|
642
|
+
node_df: If using tree aggregation, the node-level effect DataFrame must be passed.
|
627
643
|
|
628
644
|
Returns:
|
629
645
|
pd.DataFrame: effect DataFrame with inclusion probability, final parameters, expected sample.
|
@@ -775,13 +791,12 @@ class CompositionalModel2(ABC):
|
|
775
791
|
|
776
792
|
Args:
|
777
793
|
data: AnnData object or MuData object.
|
778
|
-
extended: If True, return the extended summary with additional statistics.
|
779
|
-
modality_key: If data is a MuData object, specify which modality to use.
|
794
|
+
extended: If True, return the extended summary with additional statistics.
|
795
|
+
modality_key: If data is a MuData object, specify which modality to use.
|
780
796
|
args: Passed to az.summary
|
781
797
|
kwargs: Passed to az.summary
|
782
798
|
|
783
799
|
Examples:
|
784
|
-
Example with scCODA:
|
785
800
|
>>> import pertpy as pt
|
786
801
|
>>> haber_cells = pt.dt.haber_2017_regions()
|
787
802
|
>>> sccoda = pt.tl.Sccoda()
|
@@ -795,11 +810,11 @@ class CompositionalModel2(ABC):
|
|
795
810
|
try:
|
796
811
|
sample_adata = data[modality_key]
|
797
812
|
except IndexError:
|
798
|
-
|
813
|
+
logger.error("When data is a MuData object, modality_key must be specified!")
|
799
814
|
raise
|
800
815
|
if isinstance(data, AnnData):
|
801
816
|
sample_adata = data
|
802
|
-
|
817
|
+
|
803
818
|
select_type = sample_adata.uns["scCODA_params"]["select_type"]
|
804
819
|
model_type = sample_adata.uns["scCODA_params"]["model_type"]
|
805
820
|
|
@@ -834,10 +849,10 @@ class CompositionalModel2(ABC):
|
|
834
849
|
table.add_column("Name", justify="left", style="cyan")
|
835
850
|
table.add_column("Value", justify="left")
|
836
851
|
table.add_row("Data", "Data: %d samples, %d cell types" % data_dims)
|
837
|
-
table.add_row("Reference cell type", "
|
838
|
-
table.add_row("Formula", "
|
852
|
+
table.add_row("Reference cell type", "{}".format(str(sample_adata.uns["scCODA_params"]["reference_cell_type"])))
|
853
|
+
table.add_row("Formula", "{}".format(sample_adata.uns["scCODA_params"]["formula"]))
|
839
854
|
if extended:
|
840
|
-
table.add_row("Reference index", "
|
855
|
+
table.add_row("Reference index", "{}".format(str(sample_adata.uns["scCODA_params"]["reference_index"])))
|
841
856
|
if select_type == "spikeslab":
|
842
857
|
table.add_row(
|
843
858
|
"Spike-and-slab threshold",
|
@@ -920,13 +935,12 @@ class CompositionalModel2(ABC):
|
|
920
935
|
|
921
936
|
Args:
|
922
937
|
data: AnnData object or MuData object.
|
923
|
-
modality_key: If data is a MuData object, specify which modality to use.
|
938
|
+
modality_key: If data is a MuData object, specify which modality to use.
|
924
939
|
|
925
940
|
Returns:
|
926
941
|
pd.DataFrame: Intercept data frame.
|
927
942
|
|
928
943
|
Examples:
|
929
|
-
Example with scCODA:
|
930
944
|
>>> import pertpy as pt
|
931
945
|
>>> haber_cells = pt.dt.haber_2017_regions()
|
932
946
|
>>> sccoda = pt.tl.Sccoda()
|
@@ -936,12 +950,11 @@ class CompositionalModel2(ABC):
|
|
936
950
|
>>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
|
937
951
|
>>> intercepts = sccoda.get_intercept_df(mdata)
|
938
952
|
"""
|
939
|
-
|
940
953
|
if isinstance(data, MuData):
|
941
954
|
try:
|
942
955
|
sample_adata = data[modality_key]
|
943
956
|
except IndexError:
|
944
|
-
|
957
|
+
logger.error("When data is a MuData object, modality_key must be specified!")
|
945
958
|
raise
|
946
959
|
if isinstance(data, AnnData):
|
947
960
|
sample_adata = data
|
@@ -953,13 +966,12 @@ class CompositionalModel2(ABC):
|
|
953
966
|
|
954
967
|
Args:
|
955
968
|
data: AnnData object or MuData object.
|
956
|
-
modality_key: If data is a MuData object, specify which modality to use.
|
969
|
+
modality_key: If data is a MuData object, specify which modality to use.
|
957
970
|
|
958
971
|
Returns:
|
959
972
|
pd.DataFrame: Effect data frame.
|
960
973
|
|
961
974
|
Examples:
|
962
|
-
Example with scCODA:
|
963
975
|
>>> import pertpy as pt
|
964
976
|
>>> haber_cells = pt.dt.haber_2017_regions()
|
965
977
|
>>> sccoda = pt.tl.Sccoda()
|
@@ -969,12 +981,11 @@ class CompositionalModel2(ABC):
|
|
969
981
|
>>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
|
970
982
|
>>> effects = sccoda.get_effect_df(mdata)
|
971
983
|
"""
|
972
|
-
|
973
984
|
if isinstance(data, MuData):
|
974
985
|
try:
|
975
986
|
sample_adata = data[modality_key]
|
976
987
|
except IndexError:
|
977
|
-
|
988
|
+
logger.error("When data is a MuData object, modality_key must be specified!")
|
978
989
|
raise
|
979
990
|
if isinstance(data, AnnData):
|
980
991
|
sample_adata = data
|
@@ -997,15 +1008,14 @@ class CompositionalModel2(ABC):
|
|
997
1008
|
|
998
1009
|
Args:
|
999
1010
|
data: AnnData object or MuData object.
|
1000
|
-
modality_key: If data is a MuData object, specify which modality to use.
|
1011
|
+
modality_key: If data is a MuData object, specify which modality to use.
|
1001
1012
|
|
1002
1013
|
Returns:
|
1003
1014
|
pd.DataFrame: Node effect data frame.
|
1004
1015
|
|
1005
1016
|
Examples:
|
1006
|
-
Example with tascCODA (works only for model of type tree_agg, i.e. a tascCODA model):
|
1007
1017
|
>>> import pertpy as pt
|
1008
|
-
>>> adata = pt.dt.
|
1018
|
+
>>> adata = pt.dt.tasccoda_example()
|
1009
1019
|
>>> tasccoda = pt.tl.Tasccoda()
|
1010
1020
|
>>> mdata = tasccoda.load(
|
1011
1021
|
>>> adata, type="sample_level",
|
@@ -1023,7 +1033,7 @@ class CompositionalModel2(ABC):
|
|
1023
1033
|
try:
|
1024
1034
|
sample_adata = data[modality_key]
|
1025
1035
|
except IndexError:
|
1026
|
-
|
1036
|
+
logger.error("When data is a MuData object, modality_key must be specified!")
|
1027
1037
|
raise
|
1028
1038
|
if isinstance(data, AnnData):
|
1029
1039
|
sample_adata = data
|
@@ -1037,7 +1047,7 @@ class CompositionalModel2(ABC):
|
|
1037
1047
|
Args:
|
1038
1048
|
data: AnnData object or MuData object.
|
1039
1049
|
est_fdr: Desired FDR value.
|
1040
|
-
modality_key: If data is a MuData object, specify which modality to use.
|
1050
|
+
modality_key: If data is a MuData object, specify which modality to use.
|
1041
1051
|
args: passed to self.summary_prepare
|
1042
1052
|
kwargs: passed to self.summary_prepare
|
1043
1053
|
|
@@ -1048,7 +1058,7 @@ class CompositionalModel2(ABC):
|
|
1048
1058
|
try:
|
1049
1059
|
sample_adata = data[modality_key]
|
1050
1060
|
except IndexError:
|
1051
|
-
|
1061
|
+
logger.error("When data is a MuData object, modality_key must be specified!")
|
1052
1062
|
raise
|
1053
1063
|
if isinstance(data, AnnData):
|
1054
1064
|
sample_adata = data
|
@@ -1071,8 +1081,8 @@ class CompositionalModel2(ABC):
|
|
1071
1081
|
|
1072
1082
|
Args:
|
1073
1083
|
data: AnnData object or MuData object.
|
1074
|
-
modality_key: If data is a MuData object, specify which modality to use.
|
1075
|
-
est_fdr: Estimated false discovery rate. Must be between 0 and 1.
|
1084
|
+
modality_key: If data is a MuData object, specify which modality to use.
|
1085
|
+
est_fdr: Estimated false discovery rate. Must be between 0 and 1.
|
1076
1086
|
|
1077
1087
|
Returns:
|
1078
1088
|
pd.Series: Credible effect decision series which includes boolean values indicate whether effects are credible under inc_prob_threshold.
|
@@ -1081,7 +1091,7 @@ class CompositionalModel2(ABC):
|
|
1081
1091
|
try:
|
1082
1092
|
sample_adata = data[modality_key]
|
1083
1093
|
except IndexError:
|
1084
|
-
|
1094
|
+
logger.error("When data is a MuData object, modality_key must be specified!")
|
1085
1095
|
raise
|
1086
1096
|
if isinstance(data, AnnData):
|
1087
1097
|
sample_adata = data
|
@@ -1113,9 +1123,1120 @@ class CompositionalModel2(ABC):
|
|
1113
1123
|
|
1114
1124
|
return out
|
1115
1125
|
|
1126
|
+
def _stackbar( # pragma: no cover
|
1127
|
+
self,
|
1128
|
+
y: np.ndarray,
|
1129
|
+
type_names: list[str],
|
1130
|
+
title: str,
|
1131
|
+
level_names: list[str],
|
1132
|
+
figsize: tuple[float, float] | None = None,
|
1133
|
+
dpi: int | None = 100,
|
1134
|
+
palette: ListedColormap | None = cm.tab20,
|
1135
|
+
show_legend: bool | None = True,
|
1136
|
+
) -> plt.Axes:
|
1137
|
+
"""Plots a stacked barplot for one (discrete) covariate.
|
1138
|
+
|
1139
|
+
Typical use (only inside stacked_barplot): plot_one_stackbar(data.X, data.var.index, "xyz", data.obs.index)
|
1140
|
+
|
1141
|
+
Args:
|
1142
|
+
y: The count data, collapsed onto the level of interest. i.e. a binary covariate has two rows,
|
1143
|
+
one for each group, containing the count mean of each cell type
|
1144
|
+
type_names: The names of all cell types
|
1145
|
+
title: Plot title, usually the covariate's name
|
1146
|
+
level_names: Names of the covariate's levels
|
1147
|
+
figsize: Figure size (matplotlib).
|
1148
|
+
dpi: Resolution in DPI (matplotlib).
|
1149
|
+
palette: The color map for the barplot.
|
1150
|
+
show_legend: If True, adds a legend.
|
1151
|
+
|
1152
|
+
Returns:
|
1153
|
+
A :class:`~matplotlib.axes.Axes` object
|
1154
|
+
"""
|
1155
|
+
n_bars, n_types = y.shape
|
1156
|
+
|
1157
|
+
figsize = rcParams["figure.figsize"] if figsize is None else figsize
|
1158
|
+
|
1159
|
+
_, ax = plt.subplots(figsize=figsize, dpi=dpi)
|
1160
|
+
r = np.array(range(n_bars))
|
1161
|
+
sample_sums = np.sum(y, axis=1)
|
1162
|
+
|
1163
|
+
barwidth = 0.85
|
1164
|
+
cum_bars = np.zeros(n_bars)
|
1165
|
+
|
1166
|
+
for n in range(n_types):
|
1167
|
+
bars = [i / j * 100 for i, j in zip([y[k][n] for k in range(n_bars)], sample_sums, strict=False)]
|
1168
|
+
plt.bar(
|
1169
|
+
r,
|
1170
|
+
bars,
|
1171
|
+
bottom=cum_bars,
|
1172
|
+
color=palette(n % palette.N),
|
1173
|
+
width=barwidth,
|
1174
|
+
label=type_names[n],
|
1175
|
+
linewidth=0,
|
1176
|
+
)
|
1177
|
+
cum_bars += bars
|
1178
|
+
|
1179
|
+
ax.set_title(title)
|
1180
|
+
if show_legend:
|
1181
|
+
ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1)
|
1182
|
+
ax.set_xticks(r)
|
1183
|
+
ax.set_xticklabels(level_names, rotation=45, ha="right")
|
1184
|
+
ax.set_ylabel("Proportion")
|
1185
|
+
|
1186
|
+
return ax
|
1187
|
+
|
1188
|
+
def plot_stacked_barplot( # pragma: no cover
|
1189
|
+
self,
|
1190
|
+
data: AnnData | MuData,
|
1191
|
+
feature_name: str,
|
1192
|
+
modality_key: str = "coda",
|
1193
|
+
palette: ListedColormap | None = cm.tab20,
|
1194
|
+
show_legend: bool | None = True,
|
1195
|
+
level_order: list[str] = None,
|
1196
|
+
figsize: tuple[float, float] | None = None,
|
1197
|
+
dpi: int | None = 100,
|
1198
|
+
return_fig: bool | None = None,
|
1199
|
+
ax: plt.Axes | None = None,
|
1200
|
+
show: bool | None = None,
|
1201
|
+
save: str | bool | None = None,
|
1202
|
+
**kwargs,
|
1203
|
+
) -> plt.Axes | plt.Figure | None:
|
1204
|
+
"""Plots a stacked barplot for all levels of a covariate or all samples (if feature_name=="samples").
|
1205
|
+
|
1206
|
+
Args:
|
1207
|
+
data: AnnData object or MuData object.
|
1208
|
+
feature_name: The name of the covariate to plot. If feature_name=="samples", one bar for every sample will be plotted
|
1209
|
+
modality_key: If data is a MuData object, specify which modality to use.
|
1210
|
+
figsize: Figure size.
|
1211
|
+
dpi: Dpi setting.
|
1212
|
+
palette: The matplotlib color map for the barplot.
|
1213
|
+
show_legend: If True, adds a legend.
|
1214
|
+
level_order: Custom ordering of bars on the x-axis.
|
1215
|
+
|
1216
|
+
Returns:
|
1217
|
+
A :class:`~matplotlib.axes.Axes` object
|
1218
|
+
|
1219
|
+
Examples:
|
1220
|
+
>>> import pertpy as pt
|
1221
|
+
>>> haber_cells = pt.dt.haber_2017_regions()
|
1222
|
+
>>> sccoda = pt.tl.Sccoda()
|
1223
|
+
>>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
|
1224
|
+
sample_identifier="batch", covariate_obs=["condition"])
|
1225
|
+
>>> sccoda.plot_stacked_barplot(mdata, feature_name="samples")
|
1226
|
+
|
1227
|
+
Preview:
|
1228
|
+
.. image:: /_static/docstring_previews/sccoda_stacked_barplot.png
|
1229
|
+
"""
|
1230
|
+
if isinstance(data, MuData):
|
1231
|
+
data = data[modality_key]
|
1232
|
+
if isinstance(data, AnnData):
|
1233
|
+
data = data
|
1234
|
+
|
1235
|
+
ct_names = data.var.index
|
1236
|
+
|
1237
|
+
# option to plot one stacked barplot per sample
|
1238
|
+
if feature_name == "samples":
|
1239
|
+
if level_order:
|
1240
|
+
assert set(level_order) == set(data.obs.index), "level order is inconsistent with levels"
|
1241
|
+
data = data[level_order]
|
1242
|
+
ax = self._stackbar(
|
1243
|
+
data.X,
|
1244
|
+
type_names=data.var.index,
|
1245
|
+
title="samples",
|
1246
|
+
level_names=data.obs.index,
|
1247
|
+
figsize=figsize,
|
1248
|
+
dpi=dpi,
|
1249
|
+
palette=palette,
|
1250
|
+
show_legend=show_legend,
|
1251
|
+
)
|
1252
|
+
else:
|
1253
|
+
# Order levels
|
1254
|
+
if level_order:
|
1255
|
+
assert set(level_order) == set(data.obs[feature_name]), "level order is inconsistent with levels"
|
1256
|
+
levels = level_order
|
1257
|
+
elif hasattr(data.obs[feature_name], "cat"):
|
1258
|
+
levels = data.obs[feature_name].cat.categories.to_list()
|
1259
|
+
else:
|
1260
|
+
levels = pd.unique(data.obs[feature_name])
|
1261
|
+
n_levels = len(levels)
|
1262
|
+
feature_totals = np.zeros([n_levels, data.X.shape[1]])
|
1263
|
+
|
1264
|
+
for level in range(n_levels):
|
1265
|
+
l_indices = np.where(data.obs[feature_name] == levels[level])
|
1266
|
+
feature_totals[level] = np.sum(data.X[l_indices], axis=0)
|
1267
|
+
|
1268
|
+
ax = self._stackbar(
|
1269
|
+
feature_totals,
|
1270
|
+
type_names=ct_names,
|
1271
|
+
title=feature_name,
|
1272
|
+
level_names=levels,
|
1273
|
+
figsize=figsize,
|
1274
|
+
dpi=dpi,
|
1275
|
+
palette=palette,
|
1276
|
+
show_legend=show_legend,
|
1277
|
+
)
|
1278
|
+
|
1279
|
+
if save:
|
1280
|
+
plt.savefig(save, bbox_inches="tight")
|
1281
|
+
if show:
|
1282
|
+
plt.show()
|
1283
|
+
if return_fig:
|
1284
|
+
return plt.gcf()
|
1285
|
+
if not (show or save):
|
1286
|
+
return ax
|
1287
|
+
return None
|
1288
|
+
|
1289
|
+
def plot_effects_barplot( # pragma: no cover
|
1290
|
+
self,
|
1291
|
+
data: AnnData | MuData,
|
1292
|
+
modality_key: str = "coda",
|
1293
|
+
covariates: str | list | None = None,
|
1294
|
+
parameter: Literal["log2-fold change", "Final Parameter", "Expected Sample"] = "log2-fold change",
|
1295
|
+
plot_facets: bool = True,
|
1296
|
+
plot_zero_covariate: bool = True,
|
1297
|
+
plot_zero_cell_type: bool = False,
|
1298
|
+
palette: str | ListedColormap | None = cm.tab20,
|
1299
|
+
level_order: list[str] = None,
|
1300
|
+
args_barplot: dict | None = None,
|
1301
|
+
figsize: tuple[float, float] | None = None,
|
1302
|
+
dpi: int | None = 100,
|
1303
|
+
return_fig: bool | None = None,
|
1304
|
+
ax: plt.Axes | None = None,
|
1305
|
+
show: bool | None = None,
|
1306
|
+
save: str | bool | None = None,
|
1307
|
+
) -> plt.Axes | plt.Figure | sns.axisgrid.FacetGrid | None:
|
1308
|
+
"""Barplot visualization for effects.
|
1309
|
+
|
1310
|
+
The effect results for each covariate are shown as a group of barplots, with intra--group separation by cell types.
|
1311
|
+
The covariates groups can either be ordered along the x-axis of a single plot (plot_facets=False) or as plot facets (plot_facets=True).
|
1312
|
+
|
1313
|
+
Args:
|
1314
|
+
data: AnnData object or MuData object.
|
1315
|
+
modality_key: If data is a MuData object, specify which modality to use.
|
1316
|
+
covariates: The name of the covariates in data.obs to plot.
|
1317
|
+
parameter: The parameter in effect summary to plot.
|
1318
|
+
plot_facets: If False, plot cell types on the x-axis. If True, plot as facets.
|
1319
|
+
plot_zero_covariate: If True, plot covariate that have all zero effects. If False, do not plot.
|
1320
|
+
plot_zero_cell_type: If True, plot cell type that have zero effect. If False, do not plot.
|
1321
|
+
figsize: Figure size.
|
1322
|
+
dpi: Figure size.
|
1323
|
+
palette: The seaborn color map for the barplot.
|
1324
|
+
level_order: Custom ordering of bars on the x-axis.
|
1325
|
+
args_barplot: Arguments passed to sns.barplot.
|
1326
|
+
|
1327
|
+
Returns:
|
1328
|
+
Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`)
|
1329
|
+
or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object
|
1330
|
+
|
1331
|
+
Examples:
|
1332
|
+
>>> import pertpy as pt
|
1333
|
+
>>> haber_cells = pt.dt.haber_2017_regions()
|
1334
|
+
>>> sccoda = pt.tl.Sccoda()
|
1335
|
+
>>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
|
1336
|
+
sample_identifier="batch", covariate_obs=["condition"])
|
1337
|
+
>>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
|
1338
|
+
>>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
|
1339
|
+
>>> sccoda.plot_effects_barplot(mdata)
|
1340
|
+
|
1341
|
+
Preview:
|
1342
|
+
.. image:: /_static/docstring_previews/sccoda_effects_barplot.png
|
1343
|
+
"""
|
1344
|
+
if args_barplot is None:
|
1345
|
+
args_barplot = {}
|
1346
|
+
if isinstance(data, MuData):
|
1347
|
+
data = data[modality_key]
|
1348
|
+
if isinstance(data, AnnData):
|
1349
|
+
data = data
|
1350
|
+
# Get covariate names from adata, partition into those with nonzero effects for min. one cell type/no cell types
|
1351
|
+
covariate_names = data.uns["scCODA_params"]["covariate_names"]
|
1352
|
+
if covariates is not None:
|
1353
|
+
if isinstance(covariates, str):
|
1354
|
+
covariates = [covariates]
|
1355
|
+
partial_covariate_names = [
|
1356
|
+
covariate_name
|
1357
|
+
for covariate_name in covariate_names
|
1358
|
+
if any(covariate in covariate_name for covariate in covariates)
|
1359
|
+
]
|
1360
|
+
covariate_names = partial_covariate_names
|
1361
|
+
covariate_names_non_zero = [
|
1362
|
+
covariate_name
|
1363
|
+
for covariate_name in covariate_names
|
1364
|
+
if data.varm[f"effect_df_{covariate_name}"][parameter].any()
|
1365
|
+
]
|
1366
|
+
covariate_names_zero = list(set(covariate_names) - set(covariate_names_non_zero))
|
1367
|
+
if not plot_zero_covariate:
|
1368
|
+
covariate_names = covariate_names_non_zero
|
1369
|
+
|
1370
|
+
# set up df for plotting
|
1371
|
+
plot_df = pd.concat(
|
1372
|
+
[data.varm[f"effect_df_{covariate_name}"][parameter] for covariate_name in covariate_names],
|
1373
|
+
axis=1,
|
1374
|
+
)
|
1375
|
+
plot_df.columns = covariate_names
|
1376
|
+
plot_df = pd.melt(plot_df, ignore_index=False, var_name="Covariate")
|
1377
|
+
|
1378
|
+
plot_df = plot_df.reset_index()
|
1379
|
+
|
1380
|
+
if len(covariate_names_zero) != 0:
|
1381
|
+
if plot_facets:
|
1382
|
+
if plot_zero_covariate and not plot_zero_cell_type:
|
1383
|
+
plot_df = plot_df[plot_df["value"] != 0]
|
1384
|
+
for covariate_name_zero in covariate_names_zero:
|
1385
|
+
new_row = {
|
1386
|
+
"Covariate": covariate_name_zero,
|
1387
|
+
"Cell Type": "zero",
|
1388
|
+
"value": 0,
|
1389
|
+
}
|
1390
|
+
plot_df = pd.concat([plot_df, pd.DataFrame([new_row])], ignore_index=True)
|
1391
|
+
plot_df["covariate_"] = pd.Categorical(plot_df["Covariate"], covariate_names)
|
1392
|
+
plot_df = plot_df.sort_values(["covariate_"])
|
1393
|
+
if not plot_zero_cell_type:
|
1394
|
+
cell_type_names_zero = [
|
1395
|
+
name
|
1396
|
+
for name in plot_df["Cell Type"].unique()
|
1397
|
+
if (plot_df[plot_df["Cell Type"] == name]["value"] == 0).all()
|
1398
|
+
]
|
1399
|
+
plot_df = plot_df[~plot_df["Cell Type"].isin(cell_type_names_zero)]
|
1400
|
+
|
1401
|
+
# If plot as facets, create a FacetGrid and map barplot to it.
|
1402
|
+
if plot_facets:
|
1403
|
+
if isinstance(palette, ListedColormap):
|
1404
|
+
palette = np.array([palette(i % palette.N) for i in range(len(plot_df["Cell Type"].unique()))]).tolist()
|
1405
|
+
if figsize is not None:
|
1406
|
+
height = figsize[0]
|
1407
|
+
aspect = np.round(figsize[1] / figsize[0], 2)
|
1408
|
+
else:
|
1409
|
+
height = 3
|
1410
|
+
aspect = 2
|
1411
|
+
|
1412
|
+
g = sns.FacetGrid(
|
1413
|
+
plot_df,
|
1414
|
+
col="Covariate",
|
1415
|
+
sharey=True,
|
1416
|
+
sharex=False,
|
1417
|
+
height=height,
|
1418
|
+
aspect=aspect,
|
1419
|
+
)
|
1420
|
+
|
1421
|
+
g.map(
|
1422
|
+
sns.barplot,
|
1423
|
+
"Cell Type",
|
1424
|
+
"value",
|
1425
|
+
palette=palette,
|
1426
|
+
order=level_order,
|
1427
|
+
**args_barplot,
|
1428
|
+
)
|
1429
|
+
g.set_xticklabels(rotation=90)
|
1430
|
+
g.set(ylabel=parameter)
|
1431
|
+
axes = g.axes.flatten()
|
1432
|
+
for i, ax in enumerate(axes):
|
1433
|
+
ax.set_title(covariate_names[i])
|
1434
|
+
if len(ax.get_xticklabels()) < 5:
|
1435
|
+
ax.set_aspect(10 / len(ax.get_xticklabels()))
|
1436
|
+
if len(ax.get_xticklabels()) == 1:
|
1437
|
+
if ax.get_xticklabels()[0]._text == "zero":
|
1438
|
+
ax.set_xticks([])
|
1439
|
+
|
1440
|
+
if save:
|
1441
|
+
plt.savefig(save, bbox_inches="tight")
|
1442
|
+
if show:
|
1443
|
+
plt.show()
|
1444
|
+
if return_fig:
|
1445
|
+
return plt.gcf()
|
1446
|
+
if not (show or save):
|
1447
|
+
return g
|
1448
|
+
return None
|
1449
|
+
|
1450
|
+
# If not plot as facets, call barplot to plot cell types on the x-axis.
|
1451
|
+
else:
|
1452
|
+
_, ax = plt.subplots(figsize=figsize, dpi=dpi)
|
1453
|
+
if len(covariate_names) == 1:
|
1454
|
+
if isinstance(palette, ListedColormap):
|
1455
|
+
palette = np.array(
|
1456
|
+
[palette(i % palette.N) for i in range(len(plot_df["Cell Type"].unique()))]
|
1457
|
+
).tolist()
|
1458
|
+
sns.barplot(
|
1459
|
+
data=plot_df,
|
1460
|
+
x="Cell Type",
|
1461
|
+
y="value",
|
1462
|
+
hue="x",
|
1463
|
+
palette=palette,
|
1464
|
+
ax=ax,
|
1465
|
+
)
|
1466
|
+
ax.set_title(covariate_names[0])
|
1467
|
+
else:
|
1468
|
+
if isinstance(palette, ListedColormap):
|
1469
|
+
palette = np.array([palette(i % palette.N) for i in range(len(covariate_names))]).tolist()
|
1470
|
+
sns.barplot(
|
1471
|
+
data=plot_df,
|
1472
|
+
x="Cell Type",
|
1473
|
+
y="value",
|
1474
|
+
hue="Covariate",
|
1475
|
+
palette=palette,
|
1476
|
+
ax=ax,
|
1477
|
+
)
|
1478
|
+
cell_types = pd.unique(plot_df["Cell Type"])
|
1479
|
+
ax.set_xticklabels(cell_types, rotation=90)
|
1480
|
+
|
1481
|
+
if save:
|
1482
|
+
plt.savefig(save, bbox_inches="tight")
|
1483
|
+
if show:
|
1484
|
+
plt.show()
|
1485
|
+
if return_fig:
|
1486
|
+
return plt.gcf()
|
1487
|
+
if not (show or save):
|
1488
|
+
return ax
|
1489
|
+
return None
|
1490
|
+
|
1491
|
+
def plot_boxplots( # pragma: no cover
|
1492
|
+
self,
|
1493
|
+
data: AnnData | MuData,
|
1494
|
+
feature_name: str,
|
1495
|
+
modality_key: str = "coda",
|
1496
|
+
y_scale: Literal["relative", "log", "log10", "count"] = "relative",
|
1497
|
+
plot_facets: bool = False,
|
1498
|
+
add_dots: bool = False,
|
1499
|
+
cell_types: list | None = None,
|
1500
|
+
args_boxplot: dict | None = None,
|
1501
|
+
args_swarmplot: dict | None = None,
|
1502
|
+
palette: str | None = "Blues",
|
1503
|
+
show_legend: bool | None = True,
|
1504
|
+
level_order: list[str] = None,
|
1505
|
+
figsize: tuple[float, float] | None = None,
|
1506
|
+
dpi: int | None = 100,
|
1507
|
+
return_fig: bool | None = None,
|
1508
|
+
ax: plt.Axes | None = None,
|
1509
|
+
show: bool | None = None,
|
1510
|
+
save: str | bool | None = None,
|
1511
|
+
) -> plt.Axes | plt.Figure | sns.axisgrid.FacetGrid | None:
|
1512
|
+
"""Grouped boxplot visualization.
|
1513
|
+
|
1514
|
+
The cell counts for each cell type are shown as a group of boxplots
|
1515
|
+
with intra--group separation by a covariate from data.obs.
|
1516
|
+
|
1517
|
+
Args:
|
1518
|
+
data: AnnData object or MuData object
|
1519
|
+
feature_name: The name of the feature in data.obs to plot
|
1520
|
+
modality_key: If data is a MuData object, specify which modality to use.
|
1521
|
+
y_scale: Transformation to of cell counts. Options: "relative" - Relative abundance, "log" - log(count),
|
1522
|
+
"log10" - log10(count), "count" - absolute abundance (cell counts).
|
1523
|
+
plot_facets: If False, plot cell types on the x-axis. If True, plot as facets.
|
1524
|
+
add_dots: If True, overlay a scatterplot with one dot for each data point.
|
1525
|
+
cell_types: Subset of cell types that should be plotted.
|
1526
|
+
args_boxplot: Arguments passed to sns.boxplot.
|
1527
|
+
args_swarmplot: Arguments passed to sns.swarmplot.
|
1528
|
+
figsize: Figure size.
|
1529
|
+
dpi: Dpi setting.
|
1530
|
+
palette: The seaborn color map for the barplot.
|
1531
|
+
show_legend: If True, adds a legend.
|
1532
|
+
level_order: Custom ordering of bars on the x-axis.
|
1533
|
+
|
1534
|
+
Returns:
|
1535
|
+
Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`)
|
1536
|
+
or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object
|
1537
|
+
|
1538
|
+
Examples:
|
1539
|
+
>>> import pertpy as pt
|
1540
|
+
>>> haber_cells = pt.dt.haber_2017_regions()
|
1541
|
+
>>> sccoda = pt.tl.Sccoda()
|
1542
|
+
>>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
|
1543
|
+
sample_identifier="batch", covariate_obs=["condition"])
|
1544
|
+
>>> sccoda.plot_boxplots(mdata, feature_name="condition", add_dots=True)
|
1545
|
+
|
1546
|
+
Preview:
|
1547
|
+
.. image:: /_static/docstring_previews/sccoda_boxplots.png
|
1548
|
+
"""
|
1549
|
+
if args_boxplot is None:
|
1550
|
+
args_boxplot = {}
|
1551
|
+
if args_swarmplot is None:
|
1552
|
+
args_swarmplot = {}
|
1553
|
+
if isinstance(data, MuData):
|
1554
|
+
data = data[modality_key]
|
1555
|
+
if isinstance(data, AnnData):
|
1556
|
+
data = data
|
1557
|
+
# y scale transformations
|
1558
|
+
if y_scale == "relative":
|
1559
|
+
sample_sums = np.sum(data.X, axis=1, keepdims=True)
|
1560
|
+
X = data.X / sample_sums
|
1561
|
+
value_name = "Proportion"
|
1562
|
+
# add pseudocount 0.5 if using log scale
|
1563
|
+
elif y_scale == "log":
|
1564
|
+
X = data.X.copy()
|
1565
|
+
X[X == 0] = 0.5
|
1566
|
+
X = np.log(X)
|
1567
|
+
value_name = "log(count)"
|
1568
|
+
elif y_scale == "log10":
|
1569
|
+
X = data.X.copy()
|
1570
|
+
X[X == 0] = 0.5
|
1571
|
+
X = np.log(X)
|
1572
|
+
value_name = "log10(count)"
|
1573
|
+
elif y_scale == "count":
|
1574
|
+
X = data.X
|
1575
|
+
value_name = "count"
|
1576
|
+
else:
|
1577
|
+
raise ValueError("Invalid y_scale transformation")
|
1578
|
+
|
1579
|
+
count_df = pd.DataFrame(X, columns=data.var.index, index=data.obs.index).merge(
|
1580
|
+
data.obs[feature_name], left_index=True, right_index=True
|
1581
|
+
)
|
1582
|
+
plot_df = pd.melt(count_df, id_vars=feature_name, var_name="Cell type", value_name=value_name)
|
1583
|
+
if cell_types is not None:
|
1584
|
+
plot_df = plot_df[plot_df["Cell type"].isin(cell_types)]
|
1585
|
+
|
1586
|
+
# Currently disabled because the latest statsannotations does not support the latest seaborn.
|
1587
|
+
# We had to drop the dependency.
|
1588
|
+
# Get credible effects results from model
|
1589
|
+
# if draw_effects:
|
1590
|
+
# if model is not None:
|
1591
|
+
# credible_effects_df = model.credible_effects(data, modality_key).to_frame().reset_index()
|
1592
|
+
# else:
|
1593
|
+
# print("[bold yellow]Specify a tasCODA model to draw effects")
|
1594
|
+
# credible_effects_df[feature_name] = credible_effects_df["Covariate"].str.removeprefix(f"{feature_name}[T.")
|
1595
|
+
# credible_effects_df[feature_name] = credible_effects_df[feature_name].str.removesuffix("]")
|
1596
|
+
# credible_effects_df = credible_effects_df[credible_effects_df["Final Parameter"]]
|
1597
|
+
|
1598
|
+
# If plot as facets, create a FacetGrid and map boxplot to it.
|
1599
|
+
if plot_facets:
|
1600
|
+
if level_order is None:
|
1601
|
+
level_order = pd.unique(plot_df[feature_name])
|
1602
|
+
|
1603
|
+
K = X.shape[1]
|
1604
|
+
|
1605
|
+
if figsize is not None:
|
1606
|
+
height = figsize[0]
|
1607
|
+
aspect = np.round(figsize[1] / figsize[0], 2)
|
1608
|
+
else:
|
1609
|
+
height = 3
|
1610
|
+
aspect = 2
|
1611
|
+
|
1612
|
+
g = sns.FacetGrid(
|
1613
|
+
plot_df,
|
1614
|
+
col="Cell type",
|
1615
|
+
sharey=False,
|
1616
|
+
col_wrap=int(np.floor(np.sqrt(K))),
|
1617
|
+
height=height,
|
1618
|
+
aspect=aspect,
|
1619
|
+
)
|
1620
|
+
g.map(
|
1621
|
+
sns.boxplot,
|
1622
|
+
feature_name,
|
1623
|
+
value_name,
|
1624
|
+
palette=palette,
|
1625
|
+
order=level_order,
|
1626
|
+
**args_boxplot,
|
1627
|
+
)
|
1628
|
+
|
1629
|
+
if add_dots:
|
1630
|
+
if "hue" in args_swarmplot:
|
1631
|
+
hue = args_swarmplot.pop("hue")
|
1632
|
+
else:
|
1633
|
+
hue = None
|
1634
|
+
|
1635
|
+
if hue is None:
|
1636
|
+
g.map(
|
1637
|
+
sns.swarmplot,
|
1638
|
+
feature_name,
|
1639
|
+
value_name,
|
1640
|
+
color="black",
|
1641
|
+
order=level_order,
|
1642
|
+
**args_swarmplot,
|
1643
|
+
).set_titles("{col_name}")
|
1644
|
+
else:
|
1645
|
+
g.map(
|
1646
|
+
sns.swarmplot,
|
1647
|
+
feature_name,
|
1648
|
+
value_name,
|
1649
|
+
hue,
|
1650
|
+
order=level_order,
|
1651
|
+
**args_swarmplot,
|
1652
|
+
).set_titles("{col_name}")
|
1653
|
+
|
1654
|
+
if save:
|
1655
|
+
plt.savefig(save, bbox_inches="tight")
|
1656
|
+
if show:
|
1657
|
+
plt.show()
|
1658
|
+
if return_fig:
|
1659
|
+
return plt.gcf()
|
1660
|
+
if not (show or save):
|
1661
|
+
return g
|
1662
|
+
return None
|
1663
|
+
|
1664
|
+
# If not plot as facets, call boxplot to plot cell types on the x-axis.
|
1665
|
+
else:
|
1666
|
+
if level_order:
|
1667
|
+
args_boxplot["hue_order"] = level_order
|
1668
|
+
args_swarmplot["hue_order"] = level_order
|
1669
|
+
|
1670
|
+
_, ax = plt.subplots(figsize=figsize, dpi=dpi)
|
1671
|
+
|
1672
|
+
ax = sns.boxplot(
|
1673
|
+
x="Cell type",
|
1674
|
+
y=value_name,
|
1675
|
+
hue=feature_name,
|
1676
|
+
data=plot_df,
|
1677
|
+
fliersize=1,
|
1678
|
+
palette=palette,
|
1679
|
+
ax=ax,
|
1680
|
+
**args_boxplot,
|
1681
|
+
)
|
1682
|
+
|
1683
|
+
# Currently disabled because the latest statsannotations does not support the latest seaborn.
|
1684
|
+
# We had to drop the dependency.
|
1685
|
+
# if draw_effects:
|
1686
|
+
# pairs = [
|
1687
|
+
# [(row["Cell Type"], row[feature_name]), (row["Cell Type"], "Control")]
|
1688
|
+
# for _, row in credible_effects_df.iterrows()
|
1689
|
+
# ]
|
1690
|
+
# annot = Annotator(ax, pairs, data=plot_df, x="Cell type", y=value_name, hue=feature_name)
|
1691
|
+
# annot.configure(test=None, loc="outside", color="red", line_height=0, verbose=False)
|
1692
|
+
# annot.set_custom_annotations([row[feature_name] for _, row in credible_effects_df.iterrows()])
|
1693
|
+
# annot.annotate()
|
1694
|
+
|
1695
|
+
if add_dots:
|
1696
|
+
sns.swarmplot(
|
1697
|
+
x="Cell type",
|
1698
|
+
y=value_name,
|
1699
|
+
data=plot_df,
|
1700
|
+
hue=feature_name,
|
1701
|
+
ax=ax,
|
1702
|
+
dodge=True,
|
1703
|
+
palette="dark:black",
|
1704
|
+
**args_swarmplot,
|
1705
|
+
)
|
1706
|
+
|
1707
|
+
cell_types = pd.unique(plot_df["Cell type"])
|
1708
|
+
ax.set_xticklabels(cell_types, rotation=90)
|
1709
|
+
|
1710
|
+
if show_legend:
|
1711
|
+
handles, labels = ax.get_legend_handles_labels()
|
1712
|
+
handout = []
|
1713
|
+
labelout = []
|
1714
|
+
for h, l in zip(handles, labels, strict=False):
|
1715
|
+
if l not in labelout:
|
1716
|
+
labelout.append(l)
|
1717
|
+
handout.append(h)
|
1718
|
+
ax.legend(
|
1719
|
+
handout,
|
1720
|
+
labelout,
|
1721
|
+
loc="upper left",
|
1722
|
+
bbox_to_anchor=(1, 1),
|
1723
|
+
ncol=1,
|
1724
|
+
title=feature_name,
|
1725
|
+
)
|
1726
|
+
|
1727
|
+
if save:
|
1728
|
+
plt.savefig(save, bbox_inches="tight")
|
1729
|
+
if show:
|
1730
|
+
plt.show()
|
1731
|
+
if return_fig:
|
1732
|
+
return plt.gcf()
|
1733
|
+
if not (show or save):
|
1734
|
+
return ax
|
1735
|
+
return None
|
1736
|
+
|
1737
|
+
def plot_rel_abundance_dispersion_plot( # pragma: no cover
|
1738
|
+
self,
|
1739
|
+
data: AnnData | MuData,
|
1740
|
+
modality_key: str = "coda",
|
1741
|
+
abundant_threshold: float | None = 0.9,
|
1742
|
+
default_color: str | None = "Grey",
|
1743
|
+
abundant_color: str | None = "Red",
|
1744
|
+
label_cell_types: bool = True,
|
1745
|
+
figsize: tuple[float, float] | None = None,
|
1746
|
+
dpi: int | None = 100,
|
1747
|
+
return_fig: bool | None = None,
|
1748
|
+
ax: plt.Axes | None = None,
|
1749
|
+
show: bool | None = None,
|
1750
|
+
save: str | bool | None = None,
|
1751
|
+
) -> plt.Axes | plt.Figure | None:
|
1752
|
+
"""Plots total variance of relative abundance versus minimum relative abundance of all cell types for determination of a reference cell type.
|
1753
|
+
|
1754
|
+
If the count of the cell type is larger than 0 in more than abundant_threshold percent of all samples, the cell type will be marked in a different color.
|
1755
|
+
|
1756
|
+
Args:
|
1757
|
+
data: AnnData or MuData object.
|
1758
|
+
modality_key: If data is a MuData object, specify which modality to use.
|
1759
|
+
abundant_threshold: Presence threshold for abundant cell types.
|
1760
|
+
default_color: Bar color for all non-minimal cell types.
|
1761
|
+
abundant_color: Bar color for cell types with abundant percentage larger than abundant_threshold.
|
1762
|
+
label_cell_types: Label dots with cell type names.
|
1763
|
+
figsize: Figure size.
|
1764
|
+
dpi: Dpi setting.
|
1765
|
+
ax: A matplotlib axes object. Only works if plotting a single component.
|
1766
|
+
|
1767
|
+
Returns:
|
1768
|
+
A :class:`~matplotlib.axes.Axes` object
|
1769
|
+
|
1770
|
+
Examples:
|
1771
|
+
>>> import pertpy as pt
|
1772
|
+
>>> haber_cells = pt.dt.haber_2017_regions()
|
1773
|
+
>>> sccoda = pt.tl.Sccoda()
|
1774
|
+
>>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
|
1775
|
+
sample_identifier="batch", covariate_obs=["condition"])
|
1776
|
+
>>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
|
1777
|
+
>>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
|
1778
|
+
>>> sccoda.plot_rel_abundance_dispersion_plot(mdata)
|
1779
|
+
|
1780
|
+
Preview:
|
1781
|
+
.. image:: /_static/docstring_previews/sccoda_rel_abundance_dispersion_plot.png
|
1782
|
+
"""
|
1783
|
+
if isinstance(data, MuData):
|
1784
|
+
data = data[modality_key]
|
1785
|
+
if isinstance(data, AnnData):
|
1786
|
+
data = data
|
1787
|
+
if ax is None:
|
1788
|
+
_, ax = plt.subplots(figsize=figsize, dpi=dpi)
|
1789
|
+
|
1790
|
+
rel_abun = data.X / np.sum(data.X, axis=1, keepdims=True)
|
1791
|
+
|
1792
|
+
percent_zero = np.sum(data.X == 0, axis=0) / data.X.shape[0]
|
1793
|
+
nonrare_ct = np.where(percent_zero < 1 - abundant_threshold)[0]
|
1794
|
+
|
1795
|
+
# select reference
|
1796
|
+
cell_type_disp = np.var(rel_abun, axis=0) / np.mean(rel_abun, axis=0)
|
1797
|
+
|
1798
|
+
is_abundant = [x in nonrare_ct for x in range(data.X.shape[1])]
|
1799
|
+
|
1800
|
+
# Scatterplot
|
1801
|
+
plot_df = pd.DataFrame(
|
1802
|
+
{
|
1803
|
+
"Total dispersion": cell_type_disp,
|
1804
|
+
"Cell type": data.var.index,
|
1805
|
+
"Presence": 1 - percent_zero,
|
1806
|
+
"Is abundant": is_abundant,
|
1807
|
+
}
|
1808
|
+
)
|
1809
|
+
|
1810
|
+
if len(np.unique(plot_df["Is abundant"])) > 1:
|
1811
|
+
palette = [default_color, abundant_color]
|
1812
|
+
elif np.unique(plot_df["Is abundant"]) == [False]:
|
1813
|
+
palette = [default_color]
|
1814
|
+
else:
|
1815
|
+
palette = [abundant_color]
|
1816
|
+
|
1817
|
+
ax = sns.scatterplot(
|
1818
|
+
data=plot_df,
|
1819
|
+
x="Presence",
|
1820
|
+
y="Total dispersion",
|
1821
|
+
hue="Is abundant",
|
1822
|
+
palette=palette,
|
1823
|
+
ax=ax,
|
1824
|
+
)
|
1825
|
+
|
1826
|
+
# Text labels for abundant cell types
|
1827
|
+
|
1828
|
+
abundant_df = plot_df.loc[plot_df["Is abundant"], :]
|
1829
|
+
|
1830
|
+
def label_point(x, y, val, ax):
|
1831
|
+
a = pd.concat({"x": x, "y": y, "val": val}, axis=1)
|
1832
|
+
texts = [
|
1833
|
+
ax.text(
|
1834
|
+
point["x"],
|
1835
|
+
point["y"],
|
1836
|
+
str(point["val"]),
|
1837
|
+
)
|
1838
|
+
for i, point in a.iterrows()
|
1839
|
+
]
|
1840
|
+
adjust_text(texts)
|
1841
|
+
|
1842
|
+
if label_cell_types:
|
1843
|
+
label_point(
|
1844
|
+
abundant_df["Presence"],
|
1845
|
+
abundant_df["Total dispersion"],
|
1846
|
+
abundant_df["Cell type"],
|
1847
|
+
plt.gca(),
|
1848
|
+
)
|
1849
|
+
|
1850
|
+
ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1, title="Is abundant")
|
1851
|
+
|
1852
|
+
if save:
|
1853
|
+
plt.savefig(save, bbox_inches="tight")
|
1854
|
+
if show:
|
1855
|
+
plt.show()
|
1856
|
+
if return_fig:
|
1857
|
+
return plt.gcf()
|
1858
|
+
if not (show or save):
|
1859
|
+
return ax
|
1860
|
+
return None
|
1861
|
+
|
1862
|
+
def plot_draw_tree( # pragma: no cover
|
1863
|
+
self,
|
1864
|
+
data: AnnData | MuData,
|
1865
|
+
modality_key: str = "coda",
|
1866
|
+
tree: str = "tree", # Also type ete3.Tree. Omitted due to import errors
|
1867
|
+
tight_text: bool | None = False,
|
1868
|
+
show_scale: bool | None = False,
|
1869
|
+
units: Literal["px", "mm", "in"] | None = "px",
|
1870
|
+
figsize: tuple[float, float] | None = (None, None),
|
1871
|
+
dpi: int | None = 100,
|
1872
|
+
show: bool | None = True,
|
1873
|
+
save: str | bool | None = None,
|
1874
|
+
) -> Tree | None:
|
1875
|
+
"""Plot a tree using input ete3 tree object.
|
1876
|
+
|
1877
|
+
Args:
|
1878
|
+
data: AnnData object or MuData object.
|
1879
|
+
modality_key: If data is a MuData object, specify which modality to use.
|
1880
|
+
tree: A ete3 tree object or a str to indicate the tree stored in `.uns`.
|
1881
|
+
tight_text: When False, boundaries of the text are approximated according to general font metrics,
|
1882
|
+
producing slightly worse aligned text faces but improving the performance of tree visualization in scenes with a lot of text faces.
|
1883
|
+
show_scale: Include the scale legend in the tree image or not.
|
1884
|
+
show: If True, plot the tree inline. If false, return tree and tree_style objects.
|
1885
|
+
file_name: Path to the output image file. Valid extensions are .SVG, .PDF, .PNG.
|
1886
|
+
Output image can be saved whether show is True or not.
|
1887
|
+
units: Unit of image sizes. “px”: pixels, “mm”: millimeters, “in”: inches.
|
1888
|
+
figsize: Figure size.
|
1889
|
+
dpi: Dots per inches.
|
1890
|
+
|
1891
|
+
Returns:
|
1892
|
+
Depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`) or plot the tree inline (`show = False`)
|
1893
|
+
|
1894
|
+
Examples:
|
1895
|
+
>>> import pertpy as pt
|
1896
|
+
>>> adata = pt.dt.tasccoda_example()
|
1897
|
+
>>> tasccoda = pt.tl.Tasccoda()
|
1898
|
+
>>> mdata = tasccoda.load(
|
1899
|
+
>>> adata, type="sample_level",
|
1900
|
+
>>> levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
|
1901
|
+
>>> key_added="lineage", add_level_name=True
|
1902
|
+
>>> )
|
1903
|
+
>>> mdata = tasccoda.prepare(
|
1904
|
+
>>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
|
1905
|
+
>>> )
|
1906
|
+
>>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
|
1907
|
+
>>> tasccoda.plot_draw_tree(mdata, tree="lineage")
|
1908
|
+
|
1909
|
+
Preview:
|
1910
|
+
.. image:: /_static/docstring_previews/tasccoda_draw_tree.png
|
1911
|
+
"""
|
1912
|
+
try:
|
1913
|
+
from ete3 import CircleFace, NodeStyle, TextFace, Tree, TreeStyle, faces
|
1914
|
+
except ImportError:
|
1915
|
+
raise ImportError(
|
1916
|
+
"To use tasccoda please install additional dependencies with `pip install pertpy[coda]`"
|
1917
|
+
) from None
|
1918
|
+
|
1919
|
+
if isinstance(data, MuData):
|
1920
|
+
data = data[modality_key]
|
1921
|
+
if isinstance(data, AnnData):
|
1922
|
+
data = data
|
1923
|
+
if isinstance(tree, str):
|
1924
|
+
tree = data.uns[tree]
|
1925
|
+
|
1926
|
+
def my_layout(node):
|
1927
|
+
text_face = TextFace(node.name, tight_text=tight_text)
|
1928
|
+
faces.add_face_to_node(text_face, node, column=0, position="branch-right")
|
1929
|
+
|
1930
|
+
tree_style = TreeStyle()
|
1931
|
+
tree_style.show_leaf_name = False
|
1932
|
+
tree_style.layout_fn = my_layout
|
1933
|
+
tree_style.show_scale = show_scale
|
1934
|
+
|
1935
|
+
if save is not None:
|
1936
|
+
tree.render(save, tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi) # type: ignore
|
1937
|
+
if show:
|
1938
|
+
return tree.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi) # type: ignore
|
1939
|
+
else:
|
1940
|
+
return tree, tree_style
|
1941
|
+
|
1942
|
+
def plot_draw_effects( # pragma: no cover
|
1943
|
+
self,
|
1944
|
+
data: AnnData | MuData,
|
1945
|
+
covariate: str,
|
1946
|
+
modality_key: str = "coda",
|
1947
|
+
tree: str = "tree", # Also type ete3.Tree. Omitted due to import errors
|
1948
|
+
show_legend: bool | None = None,
|
1949
|
+
show_leaf_effects: bool | None = False,
|
1950
|
+
tight_text: bool | None = False,
|
1951
|
+
show_scale: bool | None = False,
|
1952
|
+
units: Literal["px", "mm", "in"] | None = "px",
|
1953
|
+
figsize: tuple[float, float] | None = (None, None),
|
1954
|
+
dpi: int | None = 100,
|
1955
|
+
show: bool | None = True,
|
1956
|
+
save: str | None = None,
|
1957
|
+
) -> Tree | None:
+        """Plot a tree with colored circles on the nodes indicating significant effects, with optional bar plots indicating leaf-level significant effects.
+
+        Args:
+            data: AnnData object or MuData object.
+            covariate: The covariate whose effects should be plotted.
+            modality_key: If data is a MuData object, specify which modality to use.
+            tree: An ete3 tree object or a str to indicate the tree stored in `.uns`.
+            show_legend: Whether to show the legend for significant node effects.
+                Defaults to False if show_leaf_effects is True.
+            show_leaf_effects: If True, plot bar plots which indicate leaf-level significant effects.
+            tight_text: When False, boundaries of the text are approximated according to general font metrics,
+                producing slightly worse aligned text faces but improving the performance of tree visualization in scenes with a lot of text faces.
+            show_scale: Include the scale legend in the tree image or not.
+            show: If True, plot the tree inline. If False, return tree and tree_style objects.
+            save: Path to the output image file. Valid extensions are .SVG, .PDF, .PNG. The output image is saved whether show is True or not.
+            units: Unit of image sizes. "px": pixels, "mm": millimeters, "in": inches.
+            figsize: Figure size.
+            dpi: Dots per inch.
+
+        Returns:
+            Depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`)
+            or plots the tree inline (`show = True`).
+
+        Examples:
+            >>> import pertpy as pt
+            >>> adata = pt.dt.tasccoda_example()
+            >>> tasccoda = pt.tl.Tasccoda()
+            >>> mdata = tasccoda.load(
+            >>> adata, type="sample_level",
+            >>> levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
+            >>> key_added="lineage", add_level_name=True
+            >>> )
+            >>> mdata = tasccoda.prepare(
+            >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
+            >>> )
+            >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
+            >>> tasccoda.plot_draw_effects(mdata, covariate="Health[T.Inflamed]", tree="lineage")
+
+        Preview:
+            .. image:: /_static/docstring_previews/tasccoda_draw_effects.png
+        """
+        try:
+            from ete3 import CircleFace, NodeStyle, TextFace, Tree, TreeStyle, faces
+        except ImportError:
+            raise ImportError(
+                "To use tasccoda please install additional dependencies as `pip install pertpy[coda]`"
+            ) from None
+
+        if isinstance(data, MuData):
+            data = data[modality_key]
+        if isinstance(data, AnnData):
+            data = data
+        if show_legend is None:
+            show_legend = not show_leaf_effects
+        elif show_legend:
+            logger.info("Tree leaves and leaf effect bars won't be aligned when legend is shown!")
+
+        if isinstance(tree, str):
+            tree = data.uns[tree]
+        # Collapse tree singularities
+        tree2 = collapse_singularities_2(tree)
+
+        node_effs = data.uns["scCODA_params"]["node_df"].loc[(covariate + "_node",),].copy()
+        node_effs.index = node_effs.index.get_level_values("Node")
+
+        covariates = data.uns["scCODA_params"]["covariate_names"]
+        effect_dfs = [data.varm[f"effect_df_{cov}"] for cov in covariates]
+        eff_df = pd.concat(effect_dfs)
+        eff_df.index = pd.MultiIndex.from_product(
+            (covariates, data.var.index.tolist()),
+            names=["Covariate", "Cell Type"],
+        )
+        leaf_effs = eff_df.loc[(covariate,),].copy()
+        leaf_effs.index = leaf_effs.index.get_level_values("Cell Type")
+
+        # Add effect values
+        for n in tree2.traverse():
+            nstyle = NodeStyle()
+            nstyle["size"] = 0
+            n.set_style(nstyle)
+            if n.name in node_effs.index:
+                e = node_effs.loc[n.name, "Final Parameter"]
+                n.add_feature("node_effect", e)
+            else:
+                n.add_feature("node_effect", 0)
+            if n.name in leaf_effs.index:
+                e = leaf_effs.loc[n.name, "Effect"]
+                n.add_feature("leaf_effect", e)
+            else:
+                n.add_feature("leaf_effect", 0)
+
+        # Scale effect values to get nice node sizes
+        eff_max = np.max([np.abs(n.node_effect) for n in tree2.traverse()])
+        leaf_eff_max = np.max([np.abs(n.leaf_effect) for n in tree2.traverse()])
+
+        def my_layout(node):
+            text_face = TextFace(node.name, tight_text=tight_text)
+            text_face.margin_left = 10
+            faces.add_face_to_node(text_face, node, column=0, aligned=True)
+
+            # if node.is_leaf():
+            size = (np.abs(node.node_effect) * 10 / eff_max) if node.node_effect != 0 else 0
+            if np.sign(node.node_effect) == 1:
+                color = "blue"
+            elif np.sign(node.node_effect) == -1:
+                color = "red"
+            else:
+                color = "cyan"
+            if size != 0:
+                faces.add_face_to_node(CircleFace(radius=size, color=color), node, column=0)
+
+        tree_style = TreeStyle()
+        tree_style.show_leaf_name = False
+        tree_style.layout_fn = my_layout
+        tree_style.show_scale = show_scale
+        tree_style.draw_guiding_lines = True
+        tree_style.legend_position = 1
+
+        if show_legend:
+            tree_style.legend.add_face(TextFace("Effects"), column=0)
+            tree_style.legend.add_face(TextFace("  "), column=1)
+            for i in range(4, 0, -1):
+                tree_style.legend.add_face(
+                    CircleFace(
+                        float(f"{np.abs(eff_max) * 10 * i / (eff_max * 4):.2f}"),
+                        "red",
+                    ),
+                    column=0,
+                )
+                tree_style.legend.add_face(TextFace(f"{-eff_max * i / 4:.2f} "), column=0)
+                tree_style.legend.add_face(
+                    CircleFace(
+                        float(f"{np.abs(eff_max) * 10 * i / (eff_max * 4):.2f}"),
+                        "blue",
+                    ),
+                    column=1,
+                )
+                tree_style.legend.add_face(TextFace(f" {eff_max * i / 4:.2f}"), column=1)
+
+        if show_leaf_effects:
+            leaf_name = [node.name for node in tree2.traverse("postorder") if node.is_leaf()]
+            leaf_effs = leaf_effs.loc[leaf_name].reset_index()
+            palette = ["blue" if Effect > 0 else "red" for Effect in leaf_effs["Effect"].tolist()]
+
+            dir_path = Path.cwd()
+            dir_path = Path(dir_path / "tree_effect.png")
+            tree2.render(dir_path, tree_style=tree_style, units="in")
+            _, ax = plt.subplots(1, 2, figsize=(10, 10))
+            sns.barplot(data=leaf_effs, x="Effect", y="Cell Type", palette=palette, ax=ax[1])
+            img = mpimg.imread(dir_path)
+            ax[0].imshow(img)
+            ax[0].get_xaxis().set_visible(False)
+            ax[0].get_yaxis().set_visible(False)
+            ax[0].set_frame_on(False)
+
+            ax[1].get_yaxis().set_visible(False)
+            ax[1].spines["left"].set_visible(False)
+            ax[1].spines["right"].set_visible(False)
+            ax[1].spines["top"].set_visible(False)
+            plt.xlim(-leaf_eff_max, leaf_eff_max)
+            plt.subplots_adjust(wspace=0)
+
+            if save is not None:
+                plt.savefig(save)
+
+        if save is not None and not show_leaf_effects:
+            tree2.render(save, tree_style=tree_style, units=units)
+        if show:
+            if not show_leaf_effects:
+                return tree2.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi)
+        else:
+            if not show_leaf_effects:
+                return tree2, tree_style
+        return None
+
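The node circles drawn by `my_layout` above are scaled linearly: the largest absolute node effect gets radius 10, blue marks positive effects and red negative ones. A small standalone sketch of that mapping, using made-up effect values:

node_effects = {"Major_l1_T": 1.8, "Major_l2_B": -0.9, "Cluster_7": 0.0}  # hypothetical node effects
eff_max = max(abs(e) for e in node_effects.values())

for name, effect in node_effects.items():
    # Same scaling and color rule as in my_layout above.
    radius = abs(effect) * 10 / eff_max if effect != 0 else 0
    color = "blue" if effect > 0 else "red" if effect < 0 else "cyan"
    print(f"{name}: radius={radius:.1f}, color={color}")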
+    def plot_effects_umap(  # pragma: no cover
+        self,
+        mdata: MuData,
+        effect_name: str | list | None,
+        cluster_key: str,
+        modality_key_1: str = "rna",
+        modality_key_2: str = "coda",
+        color_map: Colormap | str | None = None,
+        palette: str | Sequence[str] | None = None,
+        return_fig: bool | None = None,
+        ax: Axes = None,
+        show: bool = None,
+        save: str | bool | None = None,
+        **kwargs,
+    ) -> plt.Axes | plt.Figure | None:
+        """Plot a UMAP visualization colored by effect strength.
+
+        Effect results in .varm of the aggregated sample-level AnnData (default is data['coda']) are assigned to the cell-level AnnData
+        (default is data['rna']) based on the cluster each cell was assigned to.
+
+        Args:
+            mdata: MuData object.
+            effect_name: The name of the effect results in .varm of the aggregated sample-level AnnData to plot.
+            cluster_key: The cluster information in .obs of the cell-level AnnData (default is data['rna']),
+                used to assign cell types' effects to the original cells.
+            modality_key_1: Key to the cell-level AnnData in the MuData object.
+            modality_key_2: Key to the aggregated sample-level AnnData object in the MuData object.
+            show: Whether to display the figure or return axis.
+            ax: A matplotlib axes object. Only works if plotting a single component.
+            **kwargs: All other keyword arguments are passed to `scanpy.plot.umap()`.
+
+        Returns:
+            If `show==False`, a :class:`~matplotlib.axes.Axes` or a list of it.
+
+        Examples:
+            >>> import pertpy as pt
+            >>> import scanpy as sc
+            >>> import schist
+            >>> adata = pt.dt.haber_2017_regions()
+            >>> sc.pp.neighbors(adata)
+            >>> schist.inference.nested_model(adata, n_init=100, random_seed=5678)
+            >>> tasccoda_model = pt.tl.Tasccoda()
+            >>> tasccoda_data = tasccoda_model.load(adata, type="cell_level",
+            >>> cell_type_identifier="nsbm_level_1",
+            >>> sample_identifier="batch", covariate_obs=["condition"],
+            >>> levels_orig=["nsbm_level_4", "nsbm_level_3", "nsbm_level_2", "nsbm_level_1"],
+            >>> add_level_name=True)
+            >>> tasccoda_model.prepare(
+            >>> tasccoda_data,
+            >>> modality_key="coda",
+            >>> reference_cell_type="18",
+            >>> formula="condition",
+            >>> pen_args={"phi": 0, "lambda_1": 3.5},
+            >>> tree_key="tree"
+            >>> )
+            >>> tasccoda_model.run_nuts(
+            ...     tasccoda_data, modality_key="coda", rng_key=1234, num_samples=10000, num_warmup=1000
+            ... )
+            >>> sc.tl.umap(tasccoda_data["rna"])
+            >>> tasccoda_model.plot_effects_umap(tasccoda_data,
+            >>> effect_name=["effect_df_condition[T.Salmonella]",
+            >>> "effect_df_condition[T.Hpoly.Day3]",
+            >>> "effect_df_condition[T.Hpoly.Day10]"],
+            >>> cluster_key="nsbm_level_1",
+            >>> )
+
+        Preview:
+            .. image:: /_static/docstring_previews/tasccoda_effects_umap.png
+        """
+        # TODO: Add effect_name parameter and cluster_key and test the example
+        data_rna = mdata[modality_key_1]
+        data_coda = mdata[modality_key_2]
+        if isinstance(effect_name, str):
+            effect_name = [effect_name]
+        for _, effect in enumerate(effect_name):
+            data_rna.obs[effect] = [data_coda.varm[effect].loc[f"{c}", "Effect"] for c in data_rna.obs[cluster_key]]
+        if kwargs.get("vmin"):
+            vmin = kwargs["vmin"]
+            kwargs.pop("vmin")
+        else:
+            vmin = min(data_rna.obs[effect].min() for _, effect in enumerate(effect_name))
+        if kwargs.get("vmax"):
+            vmax = kwargs["vmax"]
+            kwargs.pop("vmax")
+        else:
+            vmax = max(data_rna.obs[effect].max() for _, effect in enumerate(effect_name))
+
+        return sc.pl.umap(
+            data_rna,
+            color=effect_name,
+            vmax=vmax,
+            vmin=vmin,
+            palette=palette,
+            color_map=color_map,
+            return_fig=return_fig,
+            ax=ax,
+            show=show,
+            save=save,
+            **kwargs,
+        )
+
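The core step of `plot_effects_umap` is the broadcast from cell-type-level effects (rows of `data_coda.varm[effect]`) to individual cells via `cluster_key`. A minimal pandas sketch of that lookup, with hypothetical cluster labels and effect values:

import pandas as pd

# Hypothetical per-cell-type effect table, mimicking data_coda.varm["effect_df_condition[T.Salmonella]"]
effect_df = pd.DataFrame({"Effect": [0.4, -1.1, 0.0]}, index=["0", "1", "2"])

# Hypothetical per-cell cluster assignment, mimicking data_rna.obs[cluster_key]
cell_clusters = pd.Series(["0", "2", "1", "1", "0"], name="nsbm_level_1")

# Same lookup as in the method body: each cell inherits the effect of its cluster.
per_cell_effect = [effect_df.loc[f"{c}", "Effect"] for c in cell_clusters]
print(per_cell_effect)  # [0.4, 0.0, -1.1, -1.1, 0.4]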

 def get_a(
-    tree: tt.
+    tree: tt.core.ToyTree,
 ) -> tuple[np.ndarray, int]:
     """Calculate ancestor matrix from a toytree tree

@@ -1154,7 +2275,7 @@ def get_a(
     return A, n_nodes - 1


-def collapse_singularities(tree: tt.
+def collapse_singularities(tree: tt.core.ToyTree) -> tt.core.ToyTree:
     """Collapses (deletes) nodes in a toytree tree that are singularities (have only one child).

     Args:
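`get_a` builds the ancestor matrix that tascCODA uses to propagate node-level effects to the leaves: roughly, entry (i, j) is 1 if node j lies on the path from leaf (cell type) i to the root, and the root itself is excluded, matching the `n_nodes - 1` return value above. A hand-built sketch for a three-leaf tree ((A,B),C); the exact node ordering in pertpy may differ:

import numpy as np

# Tree: ((A,B)internal,C)root, i.e. leaves A, B, C and one internal node grouping A and B.
# Columns: [A, B, C, internal]; the root column is dropped.
A = np.array(
    [
        [1, 0, 0, 1],  # leaf A: A -> internal -> root
        [0, 1, 0, 1],  # leaf B: B -> internal -> root
        [0, 0, 1, 0],  # leaf C: C -> root
    ]
)
print(A.shape)  # (3 leaves, 4 non-root nodes)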
@@ -1242,7 +2363,7 @@ def df2newick(df: pd.DataFrame, levels: list[str], inner_label: bool = True) ->


 def get_a_2(
-    tree:
+    tree: Tree,
     leaf_order: list[str] = None,
     node_order: list[str] = None,
 ) -> tuple[np.ndarray, int]:
@@ -1263,6 +2384,13 @@ def get_a_2(
         T
             number of nodes in the tree, excluding the root node
     """
+    try:
+        import ete3 as ete
+    except ImportError:
+        raise ImportError(
+            "To use tasccoda please install additional dependencies as `pip install pertpy[coda]`"
+        ) from None
+
     n_tips = len(tree.get_leaves())
     n_nodes = len(tree.get_descendants())

@@ -1292,7 +2420,7 @@ def get_a_2(
     return A_, n_nodes


-def collapse_singularities_2(tree:
+def collapse_singularities_2(tree: Tree) -> Tree:
     """Collapses (deletes) nodes in a ete3 tree that are singularities (have only one child).

     Args:
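`collapse_singularities_2` removes internal nodes with exactly one child, which appear when consecutive levels of the aggregation hierarchy coincide. A minimal ete3 sketch of the idea (not the exact pertpy implementation), assuming the `coda` extra is installed:

from ete3 import Tree

# "B" has a single child, e.g. because two consecutive hierarchy levels coincide.
t = Tree("((A)B,C)root;", format=1)
singularities = [n for n in t.traverse() if not n.is_root() and not n.is_leaf() and len(n.children) == 1]
for node in singularities:
    node.delete()  # splice the single-child node out, reconnecting its child to its parent
print(t.write(format=1))  # roughly "(A,C)root;", the singularity is gone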
@@ -1327,10 +2455,10 @@ def linkage_to_newick(

     def build_newick(node, newick, parentdist, leaf_names):
         if node.is_leaf():
-            return f"{leaf_names[node.id]}:{(parentdist - node.dist)/2}{newick}"
+            return f"{leaf_names[node.id]}:{(parentdist - node.dist) / 2}{newick}"
         else:
             if len(newick) > 0:
-                newick = f"):{(parentdist - node.dist)/2}{newick}"
+                newick = f"):{(parentdist - node.dist) / 2}{newick}"
             else:
                 newick = ");"
             newick = build_newick(node.get_left(), newick, node.dist, leaf_names)
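`build_newick` recurses over a scipy hierarchy (the `ClusterNode` objects returned by `scipy.cluster.hierarchy.to_tree`) and emits a Newick string, halving distance differences to obtain branch lengths. A self-contained sketch of the same recursion on a toy linkage; the wrapper name here is illustrative and not necessarily identical to pertpy's `linkage_to_newick`:

import numpy as np
from scipy.cluster.hierarchy import linkage, to_tree

def linkage_to_newick_sketch(Z, leaf_names):
    tree = to_tree(Z, rd=False)

    def build_newick(node, newick, parentdist, leaf_names):
        if node.is_leaf():
            return f"{leaf_names[node.id]}:{(parentdist - node.dist) / 2}{newick}"
        if len(newick) > 0:
            newick = f"):{(parentdist - node.dist) / 2}{newick}"
        else:
            newick = ");"
        newick = build_newick(node.get_left(), newick, node.dist, leaf_names)
        newick = build_newick(node.get_right(), f",{newick}", node.dist, leaf_names)
        return "(" + newick

    return build_newick(tree, "", tree.dist, leaf_names)

X = np.array([[0.0], [0.4], [3.0], [3.5]])  # toy one-dimensional observations
print(linkage_to_newick_sketch(linkage(X, method="average"), ["A", "B", "C", "D"]))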
@@ -1363,14 +2491,14 @@ def import_tree(

     Args:
         data: A tascCODA-compatible data object.
-        modality_1: If `data` is MuData,
-        modality_2: If `data` is MuData,
-        dendrogram_key: Key to the scanpy.tl.dendrogram result in `.uns` of original cell level anndata object.
-        levels_orig: List that indicates which columns in `.obs` of the original data correspond to tree levels. The list must begin with the root level, and end with the leaf level.
-        levels_agg: List that indicates which columns in `.var` of the aggregated data correspond to tree levels. The list must begin with the root level, and end with the leaf level.
-        add_level_name: If True, internal nodes in the tree will be named as "{level_name}_{node_name}" instead of just {level_name}.
-        key_added: If not specified, the tree is stored in .uns['tree']. If `data` is AnnData, save tree in `data`.
-
+        modality_1: If `data` is MuData, specify the modality name to the original cell level anndata object.
+        modality_2: If `data` is MuData, specify the modality name to the aggregated level anndata object.
+        dendrogram_key: Key to the scanpy.tl.dendrogram result in `.uns` of original cell level anndata object.
+        levels_orig: List that indicates which columns in `.obs` of the original data correspond to tree levels. The list must begin with the root level, and end with the leaf level.
+        levels_agg: List that indicates which columns in `.var` of the aggregated data correspond to tree levels. The list must begin with the root level, and end with the leaf level.
+        add_level_name: If True, internal nodes in the tree will be named as "{level_name}_{node_name}" instead of just {level_name}.
+        key_added: If not specified, the tree is stored in .uns['tree']. If `data` is AnnData, save tree in `data`.
+            If `data` is MuData, save tree in data[modality_2].

     Returns:
         Updates data with the following:
@@ -1379,15 +2507,22 @@ def import_tree(

         tree: A ete3 tree object.
     """
+    try:
+        import ete3 as ete
+    except ImportError:
+        raise ImportError(
+            "To use tasccoda please install additional dependencies as `pip install pertpy[coda]`"
+        ) from None
+
     if isinstance(data, MuData):
         try:
             data_1 = data[modality_1]
             data_2 = data[modality_2]
         except KeyError as name:
-
+            logger.error(f"No {name} slot in MuData")
             raise
         except IndexError:
-
+            logger.error("Please specify modality_1 and modality_2 to indicate modalities in MuData")
             raise
     else:
         data_1 = data
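`import_tree` turns hierarchy columns such as `levels_agg=["Major_l1", ..., "Cluster"]` into an ete3 tree and stores it under `key_added` (in `.uns` of the aggregated modality when `data` is MuData). A toy illustration, with hypothetical labels, of the kind of tree such level columns describe:

from ete3 import Tree

# Hypothetical two-level hierarchy: Major_l1 = {Lymphoid, Myeloid}, Cluster = {T, B, Mono}.
newick = "((T,B)Lymphoid,(Mono)Myeloid)root;"
tree = Tree(newick, format=1)  # format=1 keeps internal node names
print(tree.get_ascii(show_internal=True))
# Note: "Myeloid" has a single child here; this is exactly the kind of
# singularity that collapse_singularities_2 above removes before plotting.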
@@ -1443,68 +2578,54 @@ def from_scanpy(

     The anndata object needs to have a column in adata.obs that contains the cell type assignment.
     Further, it must contain one column or a set of columns (e.g. subject id, treatment, disease status) that uniquely identify each (statistical) sample.
-    Further covariates (e.g. subject age) can either be specified via
+    Further covariates (e.g. subject age) can either be specified via additional column names in adata.obs, a key in adata.uns, or as a separate DataFrame.

-    NOTE: The order of samples in the returned dataset is determined by the first
+    NOTE: The order of samples in the returned dataset is determined by the first occurrence of cells from each sample in `adata`

     Args:
         adata: An anndata object from scanpy
         cell_type_identifier: column name in adata.obs that specifies the cell types
         sample_identifier: column name or list of column names in adata.obs that uniquely identify each sample
         covariate_uns: key for adata.uns, where covariate values are stored
-        covariate_obs: list of column names in adata.obs, where covariate values are stored.
+        covariate_obs: list of column names in adata.obs, where covariate values are stored.
+            Note: If covariate values are not unique for a value of sample_identifier, this covariate will be skipped.
         covariate_df: DataFrame with covariates

     Returns:
         AnnData: A data set with cells aggregated to the (sample x cell type) level
     """
-    if isinstance(sample_identifier, str)
-
-
-    if covariate_obs:
-        covariate_obs += [i for i in sample_identifier if i not in covariate_obs]
-    else:
-        covariate_obs = sample_identifier  # type: ignore
+    sample_identifier = [sample_identifier] if isinstance(sample_identifier, str) else sample_identifier
+    covariate_obs = list(set(covariate_obs or []) | set(sample_identifier))

-    # join sample identifiers
     if isinstance(sample_identifier, list):
         adata.obs["scCODA_sample_id"] = adata.obs[sample_identifier].agg("-".join, axis=1)
         sample_identifier = "scCODA_sample_id"

-    # get cell type counts
     groups = adata.obs.value_counts([sample_identifier, cell_type_identifier])
-
-
-
-    # get covariates from different sources
-    covariate_df_ = pd.DataFrame(index=count_data.index)
-
-    if covariate_df is None and covariate_obs is None and covariate_uns is None:
-        print("No covariate information specified!")
+    ct_count_data = groups.unstack(level=cell_type_identifier).fillna(0)
+    covariate_df_ = pd.DataFrame(index=ct_count_data.index)

     if covariate_uns is not None:
-        covariate_df_uns = pd.DataFrame(adata.uns[covariate_uns])
-        covariate_df_ = pd.concat(
+        covariate_df_uns = pd.DataFrame(adata.uns[covariate_uns], index=ct_count_data.index)
+        covariate_df_ = pd.concat([covariate_df_, covariate_df_uns], axis=1)

-    if covariate_obs
-
-
-
+    if covariate_obs:
+        unique_check = adata.obs.groupby(sample_identifier).nunique()
+        for c in covariate_obs.copy():
+            if unique_check[c].max() != 1:
+                logger.warning(f"Covariate {c} has non-unique values for batch! Skipping...")
                 covariate_obs.remove(c)
-
-
-
+        if covariate_obs:
+            covariate_df_obs = adata.obs.groupby(sample_identifier).first()[covariate_obs]
+            covariate_df_ = pd.concat([covariate_df_, covariate_df_obs], axis=1)

     if covariate_df is not None:
-        if set(covariate_df.index) != set(
-            raise ValueError("
-
-        covariate_df_ = pd.concat((covariate_df_, covs_ord), axis=1)
+        if set(covariate_df.index) != set(ct_count_data.index):
+            raise ValueError("Mismatch between sample names in anndata and covariate_df!")
+        covariate_df_ = pd.concat([covariate_df_, covariate_df.reindex(ct_count_data.index)], axis=1)

-
-
-    # create var (number of cells for each type as only column)
-    var_dat = count_data.sum(axis=0).rename("n_cells").to_frame()
+    var_dat = ct_count_data.sum().rename("n_cells").to_frame()
     var_dat.index = var_dat.index.astype(str)
+    covariate_df_.index = covariate_df_.index.astype(str)

-    return AnnData(X=
+    return AnnData(X=ct_count_data.values, var=var_dat, obs=covariate_df_)