pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66) hide show
  1. pertpy/__init__.py +4 -2
  2. pertpy/data/__init__.py +66 -1
  3. pertpy/data/_dataloader.py +28 -26
  4. pertpy/data/_datasets.py +261 -92
  5. pertpy/metadata/__init__.py +6 -0
  6. pertpy/metadata/_cell_line.py +795 -0
  7. pertpy/metadata/_compound.py +128 -0
  8. pertpy/metadata/_drug.py +238 -0
  9. pertpy/metadata/_look_up.py +569 -0
  10. pertpy/metadata/_metadata.py +70 -0
  11. pertpy/metadata/_moa.py +125 -0
  12. pertpy/plot/__init__.py +0 -13
  13. pertpy/preprocessing/__init__.py +2 -0
  14. pertpy/preprocessing/_guide_rna.py +89 -6
  15. pertpy/tools/__init__.py +48 -15
  16. pertpy/tools/_augur.py +329 -32
  17. pertpy/tools/_cinemaot.py +145 -6
  18. pertpy/tools/_coda/_base_coda.py +1237 -116
  19. pertpy/tools/_coda/_sccoda.py +66 -36
  20. pertpy/tools/_coda/_tasccoda.py +46 -39
  21. pertpy/tools/_dialogue.py +180 -77
  22. pertpy/tools/_differential_gene_expression/__init__.py +20 -0
  23. pertpy/tools/_differential_gene_expression/_base.py +657 -0
  24. pertpy/tools/_differential_gene_expression/_checks.py +41 -0
  25. pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
  26. pertpy/tools/_differential_gene_expression/_edger.py +125 -0
  27. pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
  28. pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
  29. pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
  30. pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
  31. pertpy/tools/_distances/_distance_tests.py +29 -24
  32. pertpy/tools/_distances/_distances.py +584 -98
  33. pertpy/tools/_enrichment.py +460 -0
  34. pertpy/tools/_kernel_pca.py +1 -1
  35. pertpy/tools/_milo.py +406 -49
  36. pertpy/tools/_mixscape.py +677 -55
  37. pertpy/tools/_perturbation_space/_clustering.py +10 -3
  38. pertpy/tools/_perturbation_space/_comparison.py +112 -0
  39. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
  40. pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
  41. pertpy/tools/_perturbation_space/_simple.py +52 -11
  42. pertpy/tools/_scgen/__init__.py +1 -1
  43. pertpy/tools/_scgen/_base_components.py +2 -3
  44. pertpy/tools/_scgen/_scgen.py +706 -0
  45. pertpy/tools/_scgen/_utils.py +3 -5
  46. pertpy/tools/decoupler_LICENSE +674 -0
  47. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
  48. pertpy-0.8.0.dist-info/RECORD +57 -0
  49. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
  50. pertpy/plot/_augur.py +0 -234
  51. pertpy/plot/_cinemaot.py +0 -81
  52. pertpy/plot/_coda.py +0 -1001
  53. pertpy/plot/_dialogue.py +0 -91
  54. pertpy/plot/_guide_rna.py +0 -82
  55. pertpy/plot/_milopy.py +0 -284
  56. pertpy/plot/_mixscape.py +0 -594
  57. pertpy/plot/_scgen.py +0 -337
  58. pertpy/tools/_differential_gene_expression.py +0 -99
  59. pertpy/tools/_metadata/__init__.py +0 -0
  60. pertpy/tools/_metadata/_cell_line.py +0 -613
  61. pertpy/tools/_metadata/_look_up.py +0 -342
  62. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  63. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  64. pertpy-0.6.0.dist-info/RECORD +0 -50
  65. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  66. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,17 +1,24 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from abc import ABC, abstractmethod
4
- from typing import TYPE_CHECKING
4
+ from pathlib import Path
5
+ from typing import TYPE_CHECKING, Literal, Optional, Union
5
6
 
6
7
  import arviz as az
7
- import ete3 as ete
8
8
  import jax.numpy as jnp
9
+ import matplotlib.pyplot as plt
9
10
  import numpy as np
10
11
  import pandas as pd
11
12
  import patsy as pt
13
+ import scanpy as sc
14
+ import seaborn as sns
15
+ from adjustText import adjust_text
12
16
  from anndata import AnnData
13
- from jax import random
14
- from jax.config import config
17
+ from jax import config, random
18
+ from lamin_utils import logger
19
+ from matplotlib import cm, rcParams
20
+ from matplotlib import image as mpimg
21
+ from matplotlib.colors import ListedColormap
15
22
  from mudata import MuData
16
23
  from numpyro.infer import HMC, MCMC, NUTS, initialization
17
24
  from rich import box, print
@@ -20,10 +27,15 @@ from rich.table import Table
20
27
  from scipy.cluster import hierarchy as sp_hierarchy
21
28
 
22
29
  if TYPE_CHECKING:
30
+ from collections.abc import Sequence
31
+
23
32
  import numpyro as npy
24
33
  import toytree as tt
25
- from jax._src.prng import PRNGKeyArray
34
+ from ete3 import Tree
26
35
  from jax._src.typing import Array
36
+ from matplotlib.axes import Axes
37
+ from matplotlib.colors import Colormap
38
+ from matplotlib.figure import Figure
27
39
 
28
40
  config.update("jax_enable_x64", True)
29
41
 
@@ -99,9 +111,9 @@ class CompositionalModel2(ABC):
99
111
  Categorical covariates are handled automatically, with the covariate value of the first sample being used as the reference category.
100
112
  To set a different level as the base category for a categorical covariate, use "C(<CovariateName>, Treatment('<ReferenceLevelName>'))"
101
113
  reference_cell_type: Column name that sets the reference cell type.
102
- Reference the name of a column. If "automatic", the cell type with the lowest dispersion in relative abundance that is present in at least 90% of samlpes will be chosen. Defaults to "automatic".
114
+ Reference the name of a column. If "automatic", the cell type with the lowest dispersion in relative abundance that is present in at least 90% of samlpes will be chosen.
103
115
  automatic_reference_absence_threshold: If using reference_cell_type = "automatic", determine the maximum fraction of zero entries for a cell type
104
- to be considered as a possible reference cell type. Defaults to 0.05.
116
+ to be considered as a possible reference cell type.
105
117
 
106
118
  Returns:
107
119
  AnnData object that is ready for CODA models.
@@ -137,7 +149,7 @@ class CompositionalModel2(ABC):
137
149
  ref_index = np.where(cell_type_disp == min_var)[0][0]
138
150
 
139
151
  ref_cell_type = cell_types[ref_index]
140
- print(f"[bold blue]Automatic reference selection! Reference cell type set to {ref_cell_type}")
152
+ logger.info(f"Automatic reference selection! Reference cell type set to {ref_cell_type}")
141
153
 
142
154
  # Column name as reference cell type
143
155
  elif reference_cell_type in cell_types:
@@ -149,7 +161,7 @@ class CompositionalModel2(ABC):
149
161
 
150
162
  # Add pseudocount if zeroes are present.
151
163
  if np.count_nonzero(sample_adata.X) != np.size(sample_adata.X):
152
- print("Zero counts encountered in data! Added a pseudocount of 0.5.")
164
+ logger.info("Zero counts encountered in data! Added a pseudocount of 0.5.")
153
165
  sample_adata.X[sample_adata.X == 0] = 0.5
154
166
 
155
167
  sample_adata.obsm["sample_counts"] = np.sum(sample_adata.X, axis=1)
@@ -179,7 +191,7 @@ class CompositionalModel2(ABC):
179
191
  self,
180
192
  sample_adata: AnnData,
181
193
  kernel: npy.infer.mcmc.MCMCKernel,
182
- rng_key: Array | PRNGKeyArray,
194
+ rng_key: Array,
183
195
  copy: bool = False,
184
196
  *args,
185
197
  **kwargs,
@@ -190,7 +202,7 @@ class CompositionalModel2(ABC):
190
202
  sample_adata: anndata object with cell counts as sample_adata.X and covariates saved in sample_adata.obs.
191
203
  kernel: A `numpyro.infer.mcmc.MCMCKernel` object
192
204
  rng_key: The rng state used. If None, a random state will be selected
193
- copy: Return a copy instead of writing to adata. Defaults to False.
205
+ copy: Return a copy instead of writing to adata.
194
206
  args: Passed to `numpyro.infer.mcmc.MCMC`
195
207
  kwargs: Passed to `numpyro.infer.mcmc.MCMC`
196
208
 
@@ -226,13 +238,13 @@ class CompositionalModel2(ABC):
226
238
 
227
239
  acc_rate = np.array(self.mcmc.last_state.mean_accept_prob)
228
240
  if acc_rate < 0.6:
229
- print(
230
- f"[bold red]Acceptance rate unusually low ({acc_rate} < 0.5)! Results might be incorrect! "
241
+ logger.warning(
242
+ f"Acceptance rate unusually low ({acc_rate} < 0.5)! Results might be incorrect! "
231
243
  f"Please check feasibility of results and re-run the sampling step with a different rng_key if necessary."
232
244
  )
233
245
  if acc_rate > 0.95:
234
- print(
235
- f"[bold red]Acceptance rate unusually high ({acc_rate} > 0.95)! Results might be incorrect! "
246
+ logger.warning(
247
+ f"Acceptance rate unusually high ({acc_rate} > 0.95)! Results might be incorrect! "
236
248
  f"Please check feasibility of results and re-run the sampling step with a different rng_key if necessary."
237
249
  )
238
250
 
@@ -275,11 +287,11 @@ class CompositionalModel2(ABC):
275
287
 
276
288
  Args:
277
289
  data: AnnData object or MuData object.
278
- modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
279
- num_samples: Number of sampled values after burn-in. Defaults to 10000.
280
- num_warmup: Number of burn-in (warmup) samples. Defaults to 1000.
281
- rng_key: The rng state used. Defaults to 0.
282
- copy: Return a copy instead of writing to adata. Defaults to False.
290
+ modality_key: If data is a MuData object, specify which modality to use.
291
+ num_samples: Number of sampled values after burn-in.
292
+ num_warmup: Number of burn-in (warmup) samples.
293
+ rng_key: The rng state used.
294
+ copy: Return a copy instead of writing to adata.
283
295
 
284
296
  Returns:
285
297
  Calls `self.__run_mcmc`
@@ -288,14 +300,14 @@ class CompositionalModel2(ABC):
288
300
  try:
289
301
  sample_adata = data[modality_key]
290
302
  except IndexError:
291
- print("When data is a MuData object, modality_key must be specified!")
303
+ logger.error("When data is a MuData object, modality_key must be specified!")
292
304
  raise
293
305
  if isinstance(data, AnnData):
294
306
  sample_adata = data
295
307
  if copy:
296
308
  sample_adata = sample_adata.copy()
297
309
 
298
- rng_key_array = random.PRNGKey(rng_key)
310
+ rng_key_array = random.key(rng_key)
299
311
  sample_adata.uns["scCODA_params"]["mcmc"]["rng_key"] = np.array(rng_key_array)
300
312
 
301
313
  # Set up NUTS kernel
@@ -328,14 +340,13 @@ class CompositionalModel2(ABC):
328
340
 
329
341
  Args:
330
342
  data: AnnData object or MuData object.
331
- modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
332
- num_samples: Number of sampled values after burn-in. Defaults to 20000.
333
- num_warmup: Number of burn-in (warmup) samples. Defaults to 5000.
334
- rng_key: The rng state used. If None, a random state will be selected. Defaults to None.
335
- copy: Return a copy instead of writing to adata. Defaults to False.
343
+ modality_key: If data is a MuData object, specify which modality to use.
344
+ num_samples: Number of sampled values after burn-in.
345
+ num_warmup: Number of burn-in (warmup) samples.
346
+ rng_key: The rng state used. If None, a random state will be selected.
347
+ copy: Return a copy instead of writing to adata.
336
348
 
337
349
  Examples:
338
- Example with scCODA:
339
350
  >>> import pertpy as pt
340
351
  >>> haber_cells = pt.dt.haber_2017_regions()
341
352
  >>> sccoda = pt.tl.Sccoda()
@@ -348,7 +359,7 @@ class CompositionalModel2(ABC):
348
359
  try:
349
360
  sample_adata = data[modality_key]
350
361
  except IndexError:
351
- print("When data is a MuData object, modality_key must be specified!")
362
+ logger.error("When data is a MuData object, modality_key must be specified!")
352
363
  raise
353
364
  if isinstance(data, AnnData):
354
365
  sample_adata = data
@@ -358,10 +369,10 @@ class CompositionalModel2(ABC):
358
369
  # Set rng key if needed
359
370
  if rng_key is None:
360
371
  rng = np.random.default_rng()
361
- rng_key = random.PRNGKey(rng.integers(0, 10000))
372
+ rng_key = random.key(rng.integers(0, 10000))
362
373
  sample_adata.uns["scCODA_params"]["mcmc"]["rng_key"] = rng_key
363
374
  else:
364
- rng_key = random.PRNGKey(rng_key)
375
+ rng_key = random.key(rng_key)
365
376
 
366
377
  # Set up HMC kernel
367
378
  sample_adata = self.set_init_mcmc_states(
@@ -387,7 +398,7 @@ class CompositionalModel2(ABC):
387
398
 
388
399
  Args:
389
400
  sample_adata: Anndata object with cell counts as sample_adata.X and covariates saved in sample_adata.obs.
390
- est_fdr: Desired FDR value. Defaults to 0.05.
401
+ est_fdr: Desired FDR value.
391
402
  args: Passed to ``az.summary``
392
403
  kwargs: Passed to ``az.summary``
393
404
 
@@ -423,7 +434,6 @@ class CompositionalModel2(ABC):
423
434
  - Is credible: Boolean indicator whether effect is credible
424
435
 
425
436
  Examples:
426
- Example with scCODA:
427
437
  >>> import pertpy as pt
428
438
  >>> haber_cells = pt.dt.haber_2017_regions()
429
439
  >>> sccoda = pt.tl.Sccoda()
@@ -433,7 +443,6 @@ class CompositionalModel2(ABC):
433
443
  >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
434
444
  >>> intercept_df, effect_df = sccoda.summary_prepare(mdata["coda"])
435
445
  """
436
- # Get model and effect selection types
437
446
  select_type = sample_adata.uns["scCODA_params"]["select_type"]
438
447
  model_type = sample_adata.uns["scCODA_params"]["model_type"]
439
448
 
@@ -548,7 +557,11 @@ class CompositionalModel2(ABC):
548
557
  intercept_df = intercept_df.loc[:, ["final_parameter", hdis[0], hdis[1], "sd", "expected_sample"]].copy()
549
558
  intercept_df = intercept_df.rename(
550
559
  columns=dict(
551
- zip(intercept_df.columns, ["Final Parameter", hdis_new[0], hdis_new[1], "SD", "Expected Sample"])
560
+ zip(
561
+ intercept_df.columns,
562
+ ["Final Parameter", hdis_new[0], hdis_new[1], "SD", "Expected Sample"],
563
+ strict=False,
564
+ )
552
565
  )
553
566
  )
554
567
 
@@ -561,6 +574,7 @@ class CompositionalModel2(ABC):
561
574
  zip(
562
575
  effect_df.columns,
563
576
  ["Effect", "Median", hdis_new[0], hdis_new[1], "SD", "Expected Sample", "log2-fold change"],
577
+ strict=False,
564
578
  )
565
579
  )
566
580
  )
@@ -581,6 +595,7 @@ class CompositionalModel2(ABC):
581
595
  "Expected Sample",
582
596
  "log2-fold change",
583
597
  ],
598
+ strict=False,
584
599
  )
585
600
  )
586
601
  )
@@ -594,6 +609,7 @@ class CompositionalModel2(ABC):
594
609
  zip(
595
610
  node_df.columns,
596
611
  ["Final Parameter", "Median", hdis_new[0], hdis_new[1], "SD", "Delta", "Is credible"],
612
+ strict=False,
597
613
  )
598
614
  ) # type: ignore
599
615
  ) # type: ignore
@@ -622,8 +638,8 @@ class CompositionalModel2(ABC):
622
638
  effect_df: Effect summary, see ``summary_prepare``
623
639
  model_type: String indicating the model type ("classic" or "tree_agg")
624
640
  select_type: String indicating the type of spike_and_slab selection ("spikeslab" or "sslasso")
625
- target_fdr: Desired FDR value. Defaults to 0.05.
626
- node_df: If using tree aggregation, the node-level effect DataFrame must be passed. Defaults to None.
641
+ target_fdr: Desired FDR value.
642
+ node_df: If using tree aggregation, the node-level effect DataFrame must be passed.
627
643
 
628
644
  Returns:
629
645
  pd.DataFrame: effect DataFrame with inclusion probability, final parameters, expected sample.
@@ -775,13 +791,12 @@ class CompositionalModel2(ABC):
775
791
 
776
792
  Args:
777
793
  data: AnnData object or MuData object.
778
- extended: If True, return the extended summary with additional statistics. Defaults to False.
779
- modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
794
+ extended: If True, return the extended summary with additional statistics.
795
+ modality_key: If data is a MuData object, specify which modality to use.
780
796
  args: Passed to az.summary
781
797
  kwargs: Passed to az.summary
782
798
 
783
799
  Examples:
784
- Example with scCODA:
785
800
  >>> import pertpy as pt
786
801
  >>> haber_cells = pt.dt.haber_2017_regions()
787
802
  >>> sccoda = pt.tl.Sccoda()
@@ -795,11 +810,11 @@ class CompositionalModel2(ABC):
795
810
  try:
796
811
  sample_adata = data[modality_key]
797
812
  except IndexError:
798
- print("[bold red]When data is a MuData object, modality_key must be specified!")
813
+ logger.error("When data is a MuData object, modality_key must be specified!")
799
814
  raise
800
815
  if isinstance(data, AnnData):
801
816
  sample_adata = data
802
- # Get model and effect selection types
817
+
803
818
  select_type = sample_adata.uns["scCODA_params"]["select_type"]
804
819
  model_type = sample_adata.uns["scCODA_params"]["model_type"]
805
820
 
@@ -834,10 +849,10 @@ class CompositionalModel2(ABC):
834
849
  table.add_column("Name", justify="left", style="cyan")
835
850
  table.add_column("Value", justify="left")
836
851
  table.add_row("Data", "Data: %d samples, %d cell types" % data_dims)
837
- table.add_row("Reference cell type", "%s" % str(sample_adata.uns["scCODA_params"]["reference_cell_type"]))
838
- table.add_row("Formula", "%s" % sample_adata.uns["scCODA_params"]["formula"])
852
+ table.add_row("Reference cell type", "{}".format(str(sample_adata.uns["scCODA_params"]["reference_cell_type"])))
853
+ table.add_row("Formula", "{}".format(sample_adata.uns["scCODA_params"]["formula"]))
839
854
  if extended:
840
- table.add_row("Reference index", "%s" % str(sample_adata.uns["scCODA_params"]["reference_index"]))
855
+ table.add_row("Reference index", "{}".format(str(sample_adata.uns["scCODA_params"]["reference_index"])))
841
856
  if select_type == "spikeslab":
842
857
  table.add_row(
843
858
  "Spike-and-slab threshold",
@@ -920,13 +935,12 @@ class CompositionalModel2(ABC):
920
935
 
921
936
  Args:
922
937
  data: AnnData object or MuData object.
923
- modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
938
+ modality_key: If data is a MuData object, specify which modality to use.
924
939
 
925
940
  Returns:
926
941
  pd.DataFrame: Intercept data frame.
927
942
 
928
943
  Examples:
929
- Example with scCODA:
930
944
  >>> import pertpy as pt
931
945
  >>> haber_cells = pt.dt.haber_2017_regions()
932
946
  >>> sccoda = pt.tl.Sccoda()
@@ -936,12 +950,11 @@ class CompositionalModel2(ABC):
936
950
  >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
937
951
  >>> intercepts = sccoda.get_intercept_df(mdata)
938
952
  """
939
-
940
953
  if isinstance(data, MuData):
941
954
  try:
942
955
  sample_adata = data[modality_key]
943
956
  except IndexError:
944
- print("When data is a MuData object, modality_key must be specified!")
957
+ logger.error("When data is a MuData object, modality_key must be specified!")
945
958
  raise
946
959
  if isinstance(data, AnnData):
947
960
  sample_adata = data
@@ -953,13 +966,12 @@ class CompositionalModel2(ABC):
953
966
 
954
967
  Args:
955
968
  data: AnnData object or MuData object.
956
- modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
969
+ modality_key: If data is a MuData object, specify which modality to use.
957
970
 
958
971
  Returns:
959
972
  pd.DataFrame: Effect data frame.
960
973
 
961
974
  Examples:
962
- Example with scCODA:
963
975
  >>> import pertpy as pt
964
976
  >>> haber_cells = pt.dt.haber_2017_regions()
965
977
  >>> sccoda = pt.tl.Sccoda()
@@ -969,12 +981,11 @@ class CompositionalModel2(ABC):
969
981
  >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
970
982
  >>> effects = sccoda.get_effect_df(mdata)
971
983
  """
972
-
973
984
  if isinstance(data, MuData):
974
985
  try:
975
986
  sample_adata = data[modality_key]
976
987
  except IndexError:
977
- print("When data is a MuData object, modality_key must be specified!")
988
+ logger.error("When data is a MuData object, modality_key must be specified!")
978
989
  raise
979
990
  if isinstance(data, AnnData):
980
991
  sample_adata = data
@@ -997,15 +1008,14 @@ class CompositionalModel2(ABC):
997
1008
 
998
1009
  Args:
999
1010
  data: AnnData object or MuData object.
1000
- modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
1011
+ modality_key: If data is a MuData object, specify which modality to use.
1001
1012
 
1002
1013
  Returns:
1003
1014
  pd.DataFrame: Node effect data frame.
1004
1015
 
1005
1016
  Examples:
1006
- Example with tascCODA (works only for model of type tree_agg, i.e. a tascCODA model):
1007
1017
  >>> import pertpy as pt
1008
- >>> adata = pt.dt.smillie()
1018
+ >>> adata = pt.dt.tasccoda_example()
1009
1019
  >>> tasccoda = pt.tl.Tasccoda()
1010
1020
  >>> mdata = tasccoda.load(
1011
1021
  >>> adata, type="sample_level",
@@ -1023,7 +1033,7 @@ class CompositionalModel2(ABC):
1023
1033
  try:
1024
1034
  sample_adata = data[modality_key]
1025
1035
  except IndexError:
1026
- print("When data is a MuData object, modality_key must be specified!")
1036
+ logger.error("When data is a MuData object, modality_key must be specified!")
1027
1037
  raise
1028
1038
  if isinstance(data, AnnData):
1029
1039
  sample_adata = data
@@ -1037,7 +1047,7 @@ class CompositionalModel2(ABC):
1037
1047
  Args:
1038
1048
  data: AnnData object or MuData object.
1039
1049
  est_fdr: Desired FDR value.
1040
- modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
1050
+ modality_key: If data is a MuData object, specify which modality to use.
1041
1051
  args: passed to self.summary_prepare
1042
1052
  kwargs: passed to self.summary_prepare
1043
1053
 
@@ -1048,7 +1058,7 @@ class CompositionalModel2(ABC):
1048
1058
  try:
1049
1059
  sample_adata = data[modality_key]
1050
1060
  except IndexError:
1051
- print("When data is a MuData object, modality_key must be specified!")
1061
+ logger.error("When data is a MuData object, modality_key must be specified!")
1052
1062
  raise
1053
1063
  if isinstance(data, AnnData):
1054
1064
  sample_adata = data
@@ -1071,8 +1081,8 @@ class CompositionalModel2(ABC):
1071
1081
 
1072
1082
  Args:
1073
1083
  data: AnnData object or MuData object.
1074
- modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
1075
- est_fdr: Estimated false discovery rate. Must be between 0 and 1. Defaults to None.
1084
+ modality_key: If data is a MuData object, specify which modality to use.
1085
+ est_fdr: Estimated false discovery rate. Must be between 0 and 1.
1076
1086
 
1077
1087
  Returns:
1078
1088
  pd.Series: Credible effect decision series which includes boolean values indicate whether effects are credible under inc_prob_threshold.
@@ -1081,7 +1091,7 @@ class CompositionalModel2(ABC):
1081
1091
  try:
1082
1092
  sample_adata = data[modality_key]
1083
1093
  except IndexError:
1084
- print("When data is a MuData object, modality_key must be specified!")
1094
+ logger.error("When data is a MuData object, modality_key must be specified!")
1085
1095
  raise
1086
1096
  if isinstance(data, AnnData):
1087
1097
  sample_adata = data
@@ -1113,9 +1123,1120 @@ class CompositionalModel2(ABC):
1113
1123
 
1114
1124
  return out
1115
1125
 
1126
+ def _stackbar( # pragma: no cover
1127
+ self,
1128
+ y: np.ndarray,
1129
+ type_names: list[str],
1130
+ title: str,
1131
+ level_names: list[str],
1132
+ figsize: tuple[float, float] | None = None,
1133
+ dpi: int | None = 100,
1134
+ palette: ListedColormap | None = cm.tab20,
1135
+ show_legend: bool | None = True,
1136
+ ) -> plt.Axes:
1137
+ """Plots a stacked barplot for one (discrete) covariate.
1138
+
1139
+ Typical use (only inside stacked_barplot): plot_one_stackbar(data.X, data.var.index, "xyz", data.obs.index)
1140
+
1141
+ Args:
1142
+ y: The count data, collapsed onto the level of interest. i.e. a binary covariate has two rows,
1143
+ one for each group, containing the count mean of each cell type
1144
+ type_names: The names of all cell types
1145
+ title: Plot title, usually the covariate's name
1146
+ level_names: Names of the covariate's levels
1147
+ figsize: Figure size (matplotlib).
1148
+ dpi: Resolution in DPI (matplotlib).
1149
+ palette: The color map for the barplot.
1150
+ show_legend: If True, adds a legend.
1151
+
1152
+ Returns:
1153
+ A :class:`~matplotlib.axes.Axes` object
1154
+ """
1155
+ n_bars, n_types = y.shape
1156
+
1157
+ figsize = rcParams["figure.figsize"] if figsize is None else figsize
1158
+
1159
+ _, ax = plt.subplots(figsize=figsize, dpi=dpi)
1160
+ r = np.array(range(n_bars))
1161
+ sample_sums = np.sum(y, axis=1)
1162
+
1163
+ barwidth = 0.85
1164
+ cum_bars = np.zeros(n_bars)
1165
+
1166
+ for n in range(n_types):
1167
+ bars = [i / j * 100 for i, j in zip([y[k][n] for k in range(n_bars)], sample_sums, strict=False)]
1168
+ plt.bar(
1169
+ r,
1170
+ bars,
1171
+ bottom=cum_bars,
1172
+ color=palette(n % palette.N),
1173
+ width=barwidth,
1174
+ label=type_names[n],
1175
+ linewidth=0,
1176
+ )
1177
+ cum_bars += bars
1178
+
1179
+ ax.set_title(title)
1180
+ if show_legend:
1181
+ ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1)
1182
+ ax.set_xticks(r)
1183
+ ax.set_xticklabels(level_names, rotation=45, ha="right")
1184
+ ax.set_ylabel("Proportion")
1185
+
1186
+ return ax
1187
+
1188
+ def plot_stacked_barplot( # pragma: no cover
1189
+ self,
1190
+ data: AnnData | MuData,
1191
+ feature_name: str,
1192
+ modality_key: str = "coda",
1193
+ palette: ListedColormap | None = cm.tab20,
1194
+ show_legend: bool | None = True,
1195
+ level_order: list[str] = None,
1196
+ figsize: tuple[float, float] | None = None,
1197
+ dpi: int | None = 100,
1198
+ return_fig: bool | None = None,
1199
+ ax: plt.Axes | None = None,
1200
+ show: bool | None = None,
1201
+ save: str | bool | None = None,
1202
+ **kwargs,
1203
+ ) -> plt.Axes | plt.Figure | None:
1204
+ """Plots a stacked barplot for all levels of a covariate or all samples (if feature_name=="samples").
1205
+
1206
+ Args:
1207
+ data: AnnData object or MuData object.
1208
+ feature_name: The name of the covariate to plot. If feature_name=="samples", one bar for every sample will be plotted
1209
+ modality_key: If data is a MuData object, specify which modality to use.
1210
+ figsize: Figure size.
1211
+ dpi: Dpi setting.
1212
+ palette: The matplotlib color map for the barplot.
1213
+ show_legend: If True, adds a legend.
1214
+ level_order: Custom ordering of bars on the x-axis.
1215
+
1216
+ Returns:
1217
+ A :class:`~matplotlib.axes.Axes` object
1218
+
1219
+ Examples:
1220
+ >>> import pertpy as pt
1221
+ >>> haber_cells = pt.dt.haber_2017_regions()
1222
+ >>> sccoda = pt.tl.Sccoda()
1223
+ >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
1224
+ sample_identifier="batch", covariate_obs=["condition"])
1225
+ >>> sccoda.plot_stacked_barplot(mdata, feature_name="samples")
1226
+
1227
+ Preview:
1228
+ .. image:: /_static/docstring_previews/sccoda_stacked_barplot.png
1229
+ """
1230
+ if isinstance(data, MuData):
1231
+ data = data[modality_key]
1232
+ if isinstance(data, AnnData):
1233
+ data = data
1234
+
1235
+ ct_names = data.var.index
1236
+
1237
+ # option to plot one stacked barplot per sample
1238
+ if feature_name == "samples":
1239
+ if level_order:
1240
+ assert set(level_order) == set(data.obs.index), "level order is inconsistent with levels"
1241
+ data = data[level_order]
1242
+ ax = self._stackbar(
1243
+ data.X,
1244
+ type_names=data.var.index,
1245
+ title="samples",
1246
+ level_names=data.obs.index,
1247
+ figsize=figsize,
1248
+ dpi=dpi,
1249
+ palette=palette,
1250
+ show_legend=show_legend,
1251
+ )
1252
+ else:
1253
+ # Order levels
1254
+ if level_order:
1255
+ assert set(level_order) == set(data.obs[feature_name]), "level order is inconsistent with levels"
1256
+ levels = level_order
1257
+ elif hasattr(data.obs[feature_name], "cat"):
1258
+ levels = data.obs[feature_name].cat.categories.to_list()
1259
+ else:
1260
+ levels = pd.unique(data.obs[feature_name])
1261
+ n_levels = len(levels)
1262
+ feature_totals = np.zeros([n_levels, data.X.shape[1]])
1263
+
1264
+ for level in range(n_levels):
1265
+ l_indices = np.where(data.obs[feature_name] == levels[level])
1266
+ feature_totals[level] = np.sum(data.X[l_indices], axis=0)
1267
+
1268
+ ax = self._stackbar(
1269
+ feature_totals,
1270
+ type_names=ct_names,
1271
+ title=feature_name,
1272
+ level_names=levels,
1273
+ figsize=figsize,
1274
+ dpi=dpi,
1275
+ palette=palette,
1276
+ show_legend=show_legend,
1277
+ )
1278
+
1279
+ if save:
1280
+ plt.savefig(save, bbox_inches="tight")
1281
+ if show:
1282
+ plt.show()
1283
+ if return_fig:
1284
+ return plt.gcf()
1285
+ if not (show or save):
1286
+ return ax
1287
+ return None
1288
+
1289
+ def plot_effects_barplot( # pragma: no cover
1290
+ self,
1291
+ data: AnnData | MuData,
1292
+ modality_key: str = "coda",
1293
+ covariates: str | list | None = None,
1294
+ parameter: Literal["log2-fold change", "Final Parameter", "Expected Sample"] = "log2-fold change",
1295
+ plot_facets: bool = True,
1296
+ plot_zero_covariate: bool = True,
1297
+ plot_zero_cell_type: bool = False,
1298
+ palette: str | ListedColormap | None = cm.tab20,
1299
+ level_order: list[str] = None,
1300
+ args_barplot: dict | None = None,
1301
+ figsize: tuple[float, float] | None = None,
1302
+ dpi: int | None = 100,
1303
+ return_fig: bool | None = None,
1304
+ ax: plt.Axes | None = None,
1305
+ show: bool | None = None,
1306
+ save: str | bool | None = None,
1307
+ ) -> plt.Axes | plt.Figure | sns.axisgrid.FacetGrid | None:
1308
+ """Barplot visualization for effects.
1309
+
1310
+ The effect results for each covariate are shown as a group of barplots, with intra--group separation by cell types.
1311
+ The covariates groups can either be ordered along the x-axis of a single plot (plot_facets=False) or as plot facets (plot_facets=True).
1312
+
1313
+ Args:
1314
+ data: AnnData object or MuData object.
1315
+ modality_key: If data is a MuData object, specify which modality to use.
1316
+ covariates: The name of the covariates in data.obs to plot.
1317
+ parameter: The parameter in effect summary to plot.
1318
+ plot_facets: If False, plot cell types on the x-axis. If True, plot as facets.
1319
+ plot_zero_covariate: If True, plot covariate that have all zero effects. If False, do not plot.
1320
+ plot_zero_cell_type: If True, plot cell type that have zero effect. If False, do not plot.
1321
+ figsize: Figure size.
1322
+ dpi: Figure size.
1323
+ palette: The seaborn color map for the barplot.
1324
+ level_order: Custom ordering of bars on the x-axis.
1325
+ args_barplot: Arguments passed to sns.barplot.
1326
+
1327
+ Returns:
1328
+ Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`)
1329
+ or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object
1330
+
1331
+ Examples:
1332
+ >>> import pertpy as pt
1333
+ >>> haber_cells = pt.dt.haber_2017_regions()
1334
+ >>> sccoda = pt.tl.Sccoda()
1335
+ >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
1336
+ sample_identifier="batch", covariate_obs=["condition"])
1337
+ >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
1338
+ >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
1339
+ >>> sccoda.plot_effects_barplot(mdata)
1340
+
1341
+ Preview:
1342
+ .. image:: /_static/docstring_previews/sccoda_effects_barplot.png
1343
+ """
1344
+ if args_barplot is None:
1345
+ args_barplot = {}
1346
+ if isinstance(data, MuData):
1347
+ data = data[modality_key]
1348
+ if isinstance(data, AnnData):
1349
+ data = data
1350
+ # Get covariate names from adata, partition into those with nonzero effects for min. one cell type/no cell types
1351
+ covariate_names = data.uns["scCODA_params"]["covariate_names"]
1352
+ if covariates is not None:
1353
+ if isinstance(covariates, str):
1354
+ covariates = [covariates]
1355
+ partial_covariate_names = [
1356
+ covariate_name
1357
+ for covariate_name in covariate_names
1358
+ if any(covariate in covariate_name for covariate in covariates)
1359
+ ]
1360
+ covariate_names = partial_covariate_names
1361
+ covariate_names_non_zero = [
1362
+ covariate_name
1363
+ for covariate_name in covariate_names
1364
+ if data.varm[f"effect_df_{covariate_name}"][parameter].any()
1365
+ ]
1366
+ covariate_names_zero = list(set(covariate_names) - set(covariate_names_non_zero))
1367
+ if not plot_zero_covariate:
1368
+ covariate_names = covariate_names_non_zero
1369
+
1370
+ # set up df for plotting
1371
+ plot_df = pd.concat(
1372
+ [data.varm[f"effect_df_{covariate_name}"][parameter] for covariate_name in covariate_names],
1373
+ axis=1,
1374
+ )
1375
+ plot_df.columns = covariate_names
1376
+ plot_df = pd.melt(plot_df, ignore_index=False, var_name="Covariate")
1377
+
1378
+ plot_df = plot_df.reset_index()
1379
+
1380
+ if len(covariate_names_zero) != 0:
1381
+ if plot_facets:
1382
+ if plot_zero_covariate and not plot_zero_cell_type:
1383
+ plot_df = plot_df[plot_df["value"] != 0]
1384
+ for covariate_name_zero in covariate_names_zero:
1385
+ new_row = {
1386
+ "Covariate": covariate_name_zero,
1387
+ "Cell Type": "zero",
1388
+ "value": 0,
1389
+ }
1390
+ plot_df = pd.concat([plot_df, pd.DataFrame([new_row])], ignore_index=True)
1391
+ plot_df["covariate_"] = pd.Categorical(plot_df["Covariate"], covariate_names)
1392
+ plot_df = plot_df.sort_values(["covariate_"])
1393
+ if not plot_zero_cell_type:
1394
+ cell_type_names_zero = [
1395
+ name
1396
+ for name in plot_df["Cell Type"].unique()
1397
+ if (plot_df[plot_df["Cell Type"] == name]["value"] == 0).all()
1398
+ ]
1399
+ plot_df = plot_df[~plot_df["Cell Type"].isin(cell_type_names_zero)]
1400
+
1401
+ # If plot as facets, create a FacetGrid and map barplot to it.
1402
+ if plot_facets:
1403
+ if isinstance(palette, ListedColormap):
1404
+ palette = np.array([palette(i % palette.N) for i in range(len(plot_df["Cell Type"].unique()))]).tolist()
1405
+ if figsize is not None:
1406
+ height = figsize[0]
1407
+ aspect = np.round(figsize[1] / figsize[0], 2)
1408
+ else:
1409
+ height = 3
1410
+ aspect = 2
1411
+
1412
+ g = sns.FacetGrid(
1413
+ plot_df,
1414
+ col="Covariate",
1415
+ sharey=True,
1416
+ sharex=False,
1417
+ height=height,
1418
+ aspect=aspect,
1419
+ )
1420
+
1421
+ g.map(
1422
+ sns.barplot,
1423
+ "Cell Type",
1424
+ "value",
1425
+ palette=palette,
1426
+ order=level_order,
1427
+ **args_barplot,
1428
+ )
1429
+ g.set_xticklabels(rotation=90)
1430
+ g.set(ylabel=parameter)
1431
+ axes = g.axes.flatten()
1432
+ for i, ax in enumerate(axes):
1433
+ ax.set_title(covariate_names[i])
1434
+ if len(ax.get_xticklabels()) < 5:
1435
+ ax.set_aspect(10 / len(ax.get_xticklabels()))
1436
+ if len(ax.get_xticklabels()) == 1:
1437
+ if ax.get_xticklabels()[0]._text == "zero":
1438
+ ax.set_xticks([])
1439
+
1440
+ if save:
1441
+ plt.savefig(save, bbox_inches="tight")
1442
+ if show:
1443
+ plt.show()
1444
+ if return_fig:
1445
+ return plt.gcf()
1446
+ if not (show or save):
1447
+ return g
1448
+ return None
1449
+
1450
+ # If not plot as facets, call barplot to plot cell types on the x-axis.
1451
+ else:
1452
+ _, ax = plt.subplots(figsize=figsize, dpi=dpi)
1453
+ if len(covariate_names) == 1:
1454
+ if isinstance(palette, ListedColormap):
1455
+ palette = np.array(
1456
+ [palette(i % palette.N) for i in range(len(plot_df["Cell Type"].unique()))]
1457
+ ).tolist()
1458
+ sns.barplot(
1459
+ data=plot_df,
1460
+ x="Cell Type",
1461
+ y="value",
1462
+ hue="x",
1463
+ palette=palette,
1464
+ ax=ax,
1465
+ )
1466
+ ax.set_title(covariate_names[0])
1467
+ else:
1468
+ if isinstance(palette, ListedColormap):
1469
+ palette = np.array([palette(i % palette.N) for i in range(len(covariate_names))]).tolist()
1470
+ sns.barplot(
1471
+ data=plot_df,
1472
+ x="Cell Type",
1473
+ y="value",
1474
+ hue="Covariate",
1475
+ palette=palette,
1476
+ ax=ax,
1477
+ )
1478
+ cell_types = pd.unique(plot_df["Cell Type"])
1479
+ ax.set_xticklabels(cell_types, rotation=90)
1480
+
1481
+ if save:
1482
+ plt.savefig(save, bbox_inches="tight")
1483
+ if show:
1484
+ plt.show()
1485
+ if return_fig:
1486
+ return plt.gcf()
1487
+ if not (show or save):
1488
+ return ax
1489
+ return None
1490
+
1491
def plot_boxplots(  # pragma: no cover
    self,
    data: AnnData | MuData,
    feature_name: str,
    modality_key: str = "coda",
    y_scale: Literal["relative", "log", "log10", "count"] = "relative",
    plot_facets: bool = False,
    add_dots: bool = False,
    cell_types: list | None = None,
    args_boxplot: dict | None = None,
    args_swarmplot: dict | None = None,
    palette: str | None = "Blues",
    show_legend: bool | None = True,
    level_order: list[str] | None = None,
    figsize: tuple[float, float] | None = None,
    dpi: int | None = 100,
    return_fig: bool | None = None,
    ax: plt.Axes | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
) -> plt.Axes | plt.Figure | sns.axisgrid.FacetGrid | None:
    """Grouped boxplot visualization.

    The cell counts for each cell type are shown as a group of boxplots
    with intra-group separation by a covariate from data.obs.

    Args:
        data: AnnData object or MuData object.
        feature_name: The name of the feature in data.obs to plot.
        modality_key: If data is a MuData object, specify which modality to use.
        y_scale: Transformation of cell counts. Options: "relative" - Relative abundance, "log" - log(count),
            "log10" - log10(count), "count" - absolute abundance (cell counts).
            For the log scales, zero counts are replaced by a pseudocount of 0.5.
        plot_facets: If False, plot cell types on the x-axis. If True, plot one facet per cell type.
        add_dots: If True, overlay a swarmplot with one dot for each data point.
        cell_types: Subset of cell types that should be plotted.
        args_boxplot: Arguments passed to sns.boxplot. The passed dictionary is not modified.
        args_swarmplot: Arguments passed to sns.swarmplot. The passed dictionary is not modified.
        palette: The seaborn color map for the boxplot.
        show_legend: If True, adds a legend. Only used if `plot_facets` is False.
        level_order: Custom ordering of the levels of `feature_name`.
        figsize: Figure size.
        dpi: Dpi setting. Only used if `plot_facets` is False.
        return_fig: If True, return the matplotlib figure.
        ax: A matplotlib axes object to draw on. Only used if `plot_facets` is False.
            If None, a new figure is created.
        show: If True, show the plot.
        save: If set, the figure is saved to this path.

    Returns:
        Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`)
        or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object.
        If `return_fig` is True, the matplotlib figure is returned instead.
        If the plot is shown or saved (and `return_fig` is not set), None is returned.

    Raises:
        ValueError: If `y_scale` is not one of the supported transformations.

    Examples:
        >>> import pertpy as pt
        >>> haber_cells = pt.dt.haber_2017_regions()
        >>> sccoda = pt.tl.Sccoda()
        >>> mdata = sccoda.load(
        ...     haber_cells,
        ...     type="cell_level",
        ...     generate_sample_level=True,
        ...     cell_type_identifier="cell_label",
        ...     sample_identifier="batch",
        ...     covariate_obs=["condition"],
        ... )
        >>> sccoda.plot_boxplots(mdata, feature_name="condition", add_dots=True)

    Preview:
        .. image:: /_static/docstring_previews/sccoda_boxplots.png
    """
    # Work on copies so that popping "hue" or setting "hue_order" below never modifies the caller's dicts.
    args_boxplot = {} if args_boxplot is None else dict(args_boxplot)
    args_swarmplot = {} if args_swarmplot is None else dict(args_swarmplot)
    if isinstance(data, MuData):
        data = data[modality_key]

    # y scale transformations
    if y_scale == "relative":
        sample_sums = np.sum(data.X, axis=1, keepdims=True)
        X = data.X / sample_sums
        value_name = "Proportion"
    elif y_scale in ("log", "log10"):
        # Add a pseudocount of 0.5 for zero counts before taking the logarithm. Cast to float first:
        # writing 0.5 into an integer count matrix would truncate it back to 0 and yield -inf.
        X = data.X.astype(float)
        X[X == 0] = 0.5
        X = np.log(X) if y_scale == "log" else np.log10(X)
        value_name = f"{y_scale}(count)"
    elif y_scale == "count":
        X = data.X
        value_name = "count"
    else:
        raise ValueError("Invalid y_scale transformation")

    # Long format: one row per (sample, cell type) with the covariate value of the sample.
    count_df = pd.DataFrame(X, columns=data.var.index, index=data.obs.index).merge(
        data.obs[feature_name], left_index=True, right_index=True
    )
    plot_df = pd.melt(count_df, id_vars=feature_name, var_name="Cell type", value_name=value_name)
    if cell_types is not None:
        plot_df = plot_df[plot_df["Cell type"].isin(cell_types)]

    # Currently disabled because the latest statsannotations does not support the latest seaborn.
    # We had to drop the dependency.
    # Get credible effects results from model
    # if draw_effects:
    #     if model is not None:
    #         credible_effects_df = model.credible_effects(data, modality_key).to_frame().reset_index()
    #     else:
    #         print("[bold yellow]Specify a tasCODA model to draw effects")
    #     credible_effects_df[feature_name] = credible_effects_df["Covariate"].str.removeprefix(f"{feature_name}[T.")
    #     credible_effects_df[feature_name] = credible_effects_df[feature_name].str.removesuffix("]")
    #     credible_effects_df = credible_effects_df[credible_effects_df["Final Parameter"]]

    # If plot as facets, create a FacetGrid and map boxplot to it.
    if plot_facets:
        if level_order is None:
            level_order = pd.unique(plot_df[feature_name])

        # Wrap facets into a roughly square grid based on the number of cell types actually plotted.
        n_cell_types = max(plot_df["Cell type"].nunique(), 1)

        # NOTE(review): figsize[0] is used as the facet height and figsize[1] / figsize[0] as the aspect,
        # i.e. the transpose of matplotlib's (width, height) convention — confirm this is intended.
        if figsize is not None:
            height = figsize[0]
            aspect = np.round(figsize[1] / figsize[0], 2)
        else:
            height = 3
            aspect = 2

        g = sns.FacetGrid(
            plot_df,
            col="Cell type",
            sharey=False,
            col_wrap=int(np.floor(np.sqrt(n_cell_types))),
            height=height,
            aspect=aspect,
        )
        g.map(
            sns.boxplot,
            feature_name,
            value_name,
            palette=palette,
            order=level_order,
            **args_boxplot,
        )

        if add_dots:
            hue = args_swarmplot.pop("hue", None)
            if hue is None:
                g.map(
                    sns.swarmplot,
                    feature_name,
                    value_name,
                    color="black",
                    order=level_order,
                    **args_swarmplot,
                ).set_titles("{col_name}")
            else:
                # The third positional column is mapped to seaborn's `hue` semantic by FacetGrid.map.
                g.map(
                    sns.swarmplot,
                    feature_name,
                    value_name,
                    hue,
                    order=level_order,
                    **args_swarmplot,
                ).set_titles("{col_name}")

        if save:
            plt.savefig(save, bbox_inches="tight")
        if show:
            plt.show()
        if return_fig:
            return plt.gcf()
        if not (show or save):
            return g
        return None

    # If not plot as facets, call boxplot to plot cell types on the x-axis.
    if level_order:
        args_boxplot["hue_order"] = level_order
        args_swarmplot["hue_order"] = level_order

    if ax is None:
        _, ax = plt.subplots(figsize=figsize, dpi=dpi)

    ax = sns.boxplot(
        x="Cell type",
        y=value_name,
        hue=feature_name,
        data=plot_df,
        fliersize=1,
        palette=palette,
        ax=ax,
        **args_boxplot,
    )

    # Currently disabled because the latest statsannotations does not support the latest seaborn.
    # We had to drop the dependency.
    # if draw_effects:
    #     pairs = [
    #         [(row["Cell Type"], row[feature_name]), (row["Cell Type"], "Control")]
    #         for _, row in credible_effects_df.iterrows()
    #     ]
    #     annot = Annotator(ax, pairs, data=plot_df, x="Cell type", y=value_name, hue=feature_name)
    #     annot.configure(test=None, loc="outside", color="red", line_height=0, verbose=False)
    #     annot.set_custom_annotations([row[feature_name] for _, row in credible_effects_df.iterrows()])
    #     annot.annotate()

    if add_dots:
        sns.swarmplot(
            x="Cell type",
            y=value_name,
            data=plot_df,
            hue=feature_name,
            ax=ax,
            dodge=True,
            palette="dark:black",
            **args_swarmplot,
        )

    cell_types = pd.unique(plot_df["Cell type"])
    ax.set_xticklabels(cell_types, rotation=90)

    if show_legend:
        # Boxplot and swarmplot both add legend entries per level; keep only the first entry per label.
        handles, labels = ax.get_legend_handles_labels()
        handout = []
        labelout = []
        for handle, label in zip(handles, labels, strict=False):
            if label not in labelout:
                labelout.append(label)
                handout.append(handle)
        ax.legend(
            handout,
            labelout,
            loc="upper left",
            bbox_to_anchor=(1, 1),
            ncol=1,
            title=feature_name,
        )

    # Use the figure owning `ax` so that a user-supplied axes is saved/returned correctly.
    fig = ax.figure
    if save:
        fig.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return fig
    if not (show or save):
        return ax
    return None
1736
+
1737
def plot_rel_abundance_dispersion_plot(  # pragma: no cover
    self,
    data: AnnData | MuData,
    modality_key: str = "coda",
    abundant_threshold: float | None = 0.9,
    default_color: str | None = "Grey",
    abundant_color: str | None = "Red",
    label_cell_types: bool = True,
    figsize: tuple[float, float] | None = None,
    dpi: int | None = 100,
    return_fig: bool | None = None,
    ax: plt.Axes | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
) -> plt.Axes | plt.Figure | None:
    """Plots the dispersion of relative abundance versus the presence of all cell types for determination of a reference cell type.

    For each cell type, the dispersion (variance divided by mean) of its relative abundance across samples is
    plotted against its presence, i.e. the fraction of samples in which its count is larger than 0.
    If the count of a cell type is larger than 0 in more than a fraction `abundant_threshold` of all samples,
    the cell type is marked in a different color.

    Args:
        data: AnnData or MuData object.
        modality_key: If data is a MuData object, specify which modality to use.
        abundant_threshold: Presence threshold (fraction of samples) for abundant cell types.
        default_color: Dot color for all non-abundant cell types.
        abundant_color: Dot color for cell types with presence larger than abundant_threshold.
        label_cell_types: Label the dots of abundant cell types with their names.
        figsize: Figure size.
        dpi: Dpi setting.
        return_fig: If True, return the matplotlib figure.
        ax: A matplotlib axes object to draw on. If None, a new figure is created.
        show: If True, show the plot.
        save: If set, the figure is saved to this path.

    Returns:
        A :class:`~matplotlib.axes.Axes` object, or the figure if `return_fig` is True.
        If the plot is shown or saved (and `return_fig` is not set), None is returned.

    Examples:
        >>> import pertpy as pt
        >>> haber_cells = pt.dt.haber_2017_regions()
        >>> sccoda = pt.tl.Sccoda()
        >>> mdata = sccoda.load(
        ...     haber_cells,
        ...     type="cell_level",
        ...     generate_sample_level=True,
        ...     cell_type_identifier="cell_label",
        ...     sample_identifier="batch",
        ...     covariate_obs=["condition"],
        ... )
        >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
        >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
        >>> sccoda.plot_rel_abundance_dispersion_plot(mdata)

    Preview:
        .. image:: /_static/docstring_previews/sccoda_rel_abundance_dispersion_plot.png
    """
    if isinstance(data, MuData):
        data = data[modality_key]
    if ax is None:
        _, ax = plt.subplots(figsize=figsize, dpi=dpi)

    rel_abun = data.X / np.sum(data.X, axis=1, keepdims=True)

    # Fraction of samples in which each cell type has a count of 0.
    percent_zero = np.sum(data.X == 0, axis=0) / data.X.shape[0]
    # A cell type is abundant if it is present (count > 0) in more than `abundant_threshold` of the samples.
    is_abundant = percent_zero < 1 - abundant_threshold

    # Dispersion (variance / mean) of the relative abundance of each cell type.
    cell_type_disp = np.var(rel_abun, axis=0) / np.mean(rel_abun, axis=0)

    # Scatterplot
    plot_df = pd.DataFrame(
        {
            "Total dispersion": cell_type_disp,
            "Cell type": data.var.index,
            "Presence": 1 - percent_zero,
            "Is abundant": is_abundant,
        }
    )

    # seaborn orders boolean hue levels as [False, True]; only include colors for levels that occur.
    palette = []
    if not plot_df["Is abundant"].all():
        palette.append(default_color)
    if plot_df["Is abundant"].any():
        palette.append(abundant_color)

    ax = sns.scatterplot(
        data=plot_df,
        x="Presence",
        y="Total dispersion",
        hue="Is abundant",
        palette=palette,
        ax=ax,
    )

    # Text labels for abundant cell types, drawn on (and de-overlapped within) the target axes.
    if label_cell_types:
        abundant_df = plot_df.loc[plot_df["Is abundant"], :]
        texts = [
            ax.text(x, y, str(name))
            for x, y, name in zip(
                abundant_df["Presence"],
                abundant_df["Total dispersion"],
                abundant_df["Cell type"],
                strict=True,
            )
        ]
        adjust_text(texts, ax=ax)

    ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1, title="Is abundant")

    # Use the figure owning `ax` so that a user-supplied axes is saved/returned correctly.
    fig = ax.figure
    if save:
        fig.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return fig
    if not (show or save):
        return ax
    return None
1861
+
1862
def plot_draw_tree(  # pragma: no cover
    self,
    data: AnnData | MuData,
    modality_key: str = "coda",
    tree: str = "tree",  # Also type ete3.Tree. Omitted due to import errors
    tight_text: bool | None = False,
    show_scale: bool | None = False,
    units: Literal["px", "mm", "in"] | None = "px",
    figsize: tuple[float, float] | None = (None, None),
    dpi: int | None = 100,
    show: bool | None = True,
    save: str | bool | None = None,
) -> Tree | None:
    """Plot a tree using input ete3 tree object.

    Args:
        data: AnnData object or MuData object.
        modality_key: If data is a MuData object, specify which modality to use.
        tree: A ete3 tree object or a str to indicate the tree stored in `.uns`.
        tight_text: When False, boundaries of the text are approximated according to general font metrics,
            producing slightly worse aligned text faces but improving the performance of tree visualization
            in scenes with a lot of text faces.
        show_scale: Include the scale legend in the tree image or not.
        units: Unit of image sizes. "px": pixels, "mm": millimeters, "in": inches.
        figsize: Figure size as (width, height). None leaves both unset.
        dpi: Dots per inch.
        show: If True, plot the tree inline. If False, return tree and tree_style objects.
        save: Path to the output image file. Valid extensions are .SVG, .PDF, .PNG.
            The image is saved whether show is True or not. Falsy values disable saving.

    Returns:
        Depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`)
        or the inline rendering of the tree (`show = True`).

    Examples:
        >>> import pertpy as pt
        >>> adata = pt.dt.tasccoda_example()
        >>> tasccoda = pt.tl.Tasccoda()
        >>> mdata = tasccoda.load(
        ...     adata,
        ...     type="sample_level",
        ...     levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
        ...     key_added="lineage",
        ...     add_level_name=True,
        ... )
        >>> mdata = tasccoda.prepare(
        ...     mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
        ... )
        >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
        >>> tasccoda.plot_draw_tree(mdata, tree="lineage")

    Preview:
        .. image:: /_static/docstring_previews/tasccoda_draw_tree.png
    """
    try:
        from ete3 import TextFace, TreeStyle, faces
    except ImportError:
        raise ImportError(
            "To use tasccoda please install additional dependencies with `pip install pertpy[coda]`"
        ) from None

    if isinstance(data, MuData):
        data = data[modality_key]
    if isinstance(tree, str):
        tree = data.uns[tree]

    def my_layout(node):
        # Draw each node's name to the right of its branch.
        text_face = TextFace(node.name, tight_text=tight_text)
        faces.add_face_to_node(text_face, node, column=0, position="branch-right")

    tree_style = TreeStyle()
    tree_style.show_leaf_name = False
    tree_style.layout_fn = my_layout
    tree_style.show_scale = show_scale

    # Treat `figsize=None` like the default `(None, None)`, i.e. leave width and height unset.
    width, height = figsize if figsize is not None else (None, None)

    # A falsy `save` (None, False or "") means the tree is not written to disk.
    if save:
        tree.render(save, tree_style=tree_style, units=units, w=width, h=height, dpi=dpi)  # type: ignore
    if show:
        return tree.render("%%inline", tree_style=tree_style, units=units, w=width, h=height, dpi=dpi)  # type: ignore
    return tree, tree_style
1941
+
1942
def plot_draw_effects(  # pragma: no cover
    self,
    data: AnnData | MuData,
    covariate: str,
    modality_key: str = "coda",
    tree: str = "tree",  # Also type ete3.Tree. Omitted due to import errors
    show_legend: bool | None = None,
    show_leaf_effects: bool | None = False,
    tight_text: bool | None = False,
    show_scale: bool | None = False,
    units: Literal["px", "mm", "in"] | None = "px",
    figsize: tuple[float, float] | None = (None, None),
    dpi: int | None = 100,
    show: bool | None = True,
    save: str | None = None,
) -> Tree | None:
    """Plot a tree with colored circles on the nodes indicating significant effects with bar plots which indicate leave-level significant effects.

    Node circles are blue for positive and red for negative effects; their radius is proportional to the
    absolute effect, with the largest absolute node effect drawn at radius 10.

    Args:
        data: AnnData object or MuData object.
        covariate: The covariate, whose effects should be plotted.
        modality_key: If data is a MuData object, specify which modality to use.
        tree: A ete3 tree object or a str to indicate the tree stored in `.uns`.
        show_legend: If show legend of nodes significant effects or not.
            Defaults to False if show_leaf_effects is True.
        show_leaf_effects: If True, plot bar plots which indicate leave-level significant effects.
        tight_text: When False, boundaries of the text are approximated according to general font metrics,
            producing slightly worse aligned text faces but improving the performance of tree visualization
            in scenes with a lot of text faces.
        show_scale: Include the scale legend in the tree image or not.
        units: Unit of image sizes. "px": pixels, "mm": millimeters, "in": inches.
        figsize: Figure size as (width, height) of the tree image. None leaves both unset.
        dpi: Dots per inch of the tree image.
        show: If True, plot the tree inline. If False, return tree and tree_style objects.
            With `show_leaf_effects`, a matplotlib figure combining tree and barplot is created instead.
        save: Path to the output image file. Valid extensions are .SVG, .PDF, .PNG.
            The image is saved whether show is True or not.

    Returns:
        Depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`)
        or the inline rendering of the tree (`show = True`). Returns None if `show_leaf_effects` is True.

    Examples:
        >>> import pertpy as pt
        >>> adata = pt.dt.tasccoda_example()
        >>> tasccoda = pt.tl.Tasccoda()
        >>> mdata = tasccoda.load(
        ...     adata,
        ...     type="sample_level",
        ...     levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
        ...     key_added="lineage",
        ...     add_level_name=True,
        ... )
        >>> mdata = tasccoda.prepare(
        ...     mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
        ... )
        >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
        >>> tasccoda.plot_draw_effects(mdata, covariate="Health[T.Inflamed]", tree="lineage")

    Preview:
        .. image:: /_static/docstring_previews/tasccoda_draw_effects.png
    """
    import tempfile

    try:
        from ete3 import CircleFace, NodeStyle, TextFace, TreeStyle, faces
    except ImportError:
        raise ImportError(
            "To use tasccoda please install additional dependencies as `pip install pertpy[coda]`"
        ) from None

    if isinstance(data, MuData):
        data = data[modality_key]
    if show_legend is None:
        show_legend = not show_leaf_effects
    elif show_legend:
        logger.info("Tree leaves and leaf effect bars won't be aligned when legend is shown!")

    if isinstance(tree, str):
        tree = data.uns[tree]
    # Collapse tree singularities
    tree2 = collapse_singularities_2(tree)

    # Node-level effects of the selected covariate, indexed by node name.
    node_effs = data.uns["scCODA_params"]["node_df"].loc[(covariate + "_node",),].copy()
    node_effs.index = node_effs.index.get_level_values("Node")

    # Leaf-level (cell type) effects of the selected covariate, indexed by cell type.
    covariates = data.uns["scCODA_params"]["covariate_names"]
    effect_dfs = [data.varm[f"effect_df_{cov}"] for cov in covariates]
    eff_df = pd.concat(effect_dfs)
    eff_df.index = pd.MultiIndex.from_product(
        (covariates, data.var.index.tolist()),
        names=["Covariate", "Cell Type"],
    )
    leaf_effs = eff_df.loc[(covariate,),].copy()
    leaf_effs.index = leaf_effs.index.get_level_values("Cell Type")

    # Hide the default node markers and attach node/leaf effects (0 if absent) to every node.
    for n in tree2.traverse():
        nstyle = NodeStyle()
        nstyle["size"] = 0
        n.set_style(nstyle)
        if n.name in node_effs.index:
            n.add_feature("node_effect", node_effs.loc[n.name, "Final Parameter"])
        else:
            n.add_feature("node_effect", 0)
        if n.name in leaf_effs.index:
            n.add_feature("leaf_effect", leaf_effs.loc[n.name, "Effect"])
        else:
            n.add_feature("leaf_effect", 0)

    # Scale effect values to get nice node sizes
    eff_max = np.max([np.abs(n.node_effect) for n in tree2.traverse()])
    leaf_eff_max = np.max([np.abs(n.leaf_effect) for n in tree2.traverse()])

    def my_layout(node):
        text_face = TextFace(node.name, tight_text=tight_text)
        text_face.margin_left = 10
        faces.add_face_to_node(text_face, node, column=0, aligned=True)

        # Circle radius scales linearly with |effect|; the largest absolute node effect gets radius 10.
        size = (np.abs(node.node_effect) * 10 / eff_max) if node.node_effect != 0 else 0
        if np.sign(node.node_effect) == 1:
            color = "blue"
        elif np.sign(node.node_effect) == -1:
            color = "red"
        else:
            color = "cyan"
        if size != 0:
            faces.add_face_to_node(CircleFace(radius=size, color=color), node, column=0)

    tree_style = TreeStyle()
    tree_style.show_leaf_name = False
    tree_style.layout_fn = my_layout
    tree_style.show_scale = show_scale
    tree_style.draw_guiding_lines = True
    tree_style.legend_position = 1

    if show_legend:
        tree_style.legend.add_face(TextFace("Effects"), column=0)
        tree_style.legend.add_face(TextFace(" "), column=1)
        for i in range(4, 0, -1):
            # An effect of i/4 * eff_max is drawn with radius 10 * i / 4 (see my_layout). Computing the radius
            # directly avoids the 0/0 = NaN radius when all node effects are zero.
            radius = 10 * i / 4
            tree_style.legend.add_face(CircleFace(radius, "red"), column=0)
            tree_style.legend.add_face(TextFace(f"{-eff_max * i / 4:.2f} "), column=0)
            tree_style.legend.add_face(CircleFace(radius, "blue"), column=1)
            tree_style.legend.add_face(TextFace(f" {eff_max * i / 4:.2f}"), column=1)

    if show_leaf_effects:
        leaf_name = [node.name for node in tree2.traverse("postorder") if node.is_leaf()]
        leaf_effs = leaf_effs.loc[leaf_name].reset_index()
        palette = ["blue" if effect > 0 else "red" for effect in leaf_effs["Effect"].tolist()]

        # Render the tree to a temporary PNG that is shown next to the barplot. The path is passed as a str
        # (ete3's render treats the file name as a string), and the temporary directory avoids leaving an
        # image file in the current working directory.
        with tempfile.TemporaryDirectory() as tmp_dir:
            tree_img_path = str(Path(tmp_dir) / "tree_effect.png")
            tree2.render(tree_img_path, tree_style=tree_style, units="in")
            img = mpimg.imread(tree_img_path)

        _, ax = plt.subplots(1, 2, figsize=(10, 10))
        sns.barplot(data=leaf_effs, x="Effect", y="Cell Type", palette=palette, ax=ax[1])
        ax[0].imshow(img)
        ax[0].get_xaxis().set_visible(False)
        ax[0].get_yaxis().set_visible(False)
        ax[0].set_frame_on(False)

        ax[1].get_yaxis().set_visible(False)
        ax[1].spines["left"].set_visible(False)
        ax[1].spines["right"].set_visible(False)
        ax[1].spines["top"].set_visible(False)
        ax[1].set_xlim(-leaf_eff_max, leaf_eff_max)
        plt.subplots_adjust(wspace=0)

        if save is not None:
            plt.savefig(save)

    # Treat `figsize=None` like the default `(None, None)`, i.e. leave width and height unset.
    width, height = figsize if figsize is not None else (None, None)

    if save is not None and not show_leaf_effects:
        tree2.render(save, tree_style=tree_style, units=units, w=width, h=height, dpi=dpi)
    if show:
        if not show_leaf_effects:
            return tree2.render("%%inline", tree_style=tree_style, units=units, w=width, h=height, dpi=dpi)
    elif not show_leaf_effects:
        return tree2, tree_style
    return None
2132
+
2133
def plot_effects_umap(  # pragma: no cover
    self,
    mdata: MuData,
    effect_name: str | list | None,
    cluster_key: str,
    modality_key_1: str = "rna",
    modality_key_2: str = "coda",
    color_map: Colormap | str | None = None,
    palette: str | Sequence[str] | None = None,
    return_fig: bool | None = None,
    ax: Axes | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
    **kwargs,
) -> plt.Axes | plt.Figure | None:
    """Plot a UMAP visualization colored by effect strength.

    Effect results in .varm of aggregated sample-level AnnData (default is data['coda']) are assigned to cell-level AnnData
    (default is data['rna']) depending on the cluster they were assigned to.

    Args:
        mdata: MuData object.
        effect_name: The name(s) of the effect results in .varm of aggregated sample-level AnnData to plot.
            The assigned effects are stored in .obs of the cell-level AnnData under the same name(s).
        cluster_key: The cluster information in .obs of cell-level AnnData (default is data['rna']).
            To assign cell types' effects to original cells.
        modality_key_1: Key to the cell-level AnnData in the MuData object.
        modality_key_2: Key to the aggregated sample-level AnnData object in the MuData object.
        color_map: Color map passed to `scanpy.pl.umap()`.
        palette: Palette passed to `scanpy.pl.umap()`.
        return_fig: Passed to `scanpy.pl.umap()`.
        ax: A matplotlib axes object. Only works if plotting a single component.
        show: Whether to display the figure or return axis.
        save: Passed to `scanpy.pl.umap()`.
        **kwargs: All other keyword arguments are passed to `scanpy.pl.umap()`.
            If `vmin`/`vmax` are not given (or None), they default to the minimum/maximum over all plotted effects.

    Returns:
        If `show==False` a :class:`~matplotlib.axes.Axes` or a list of it.

    Examples:
        >>> import pertpy as pt
        >>> import scanpy as sc
        >>> import schist
        >>> adata = pt.dt.haber_2017_regions()
        >>> sc.pp.neighbors(adata)
        >>> schist.inference.nested_model(adata, n_init=100, random_seed=5678)
        >>> tasccoda_model = pt.tl.Tasccoda()
        >>> tasccoda_data = tasccoda_model.load(
        ...     adata,
        ...     type="cell_level",
        ...     cell_type_identifier="nsbm_level_1",
        ...     sample_identifier="batch",
        ...     covariate_obs=["condition"],
        ...     levels_orig=["nsbm_level_4", "nsbm_level_3", "nsbm_level_2", "nsbm_level_1"],
        ...     add_level_name=True,
        ... )
        >>> tasccoda_model.prepare(
        ...     tasccoda_data,
        ...     modality_key="coda",
        ...     reference_cell_type="18",
        ...     formula="condition",
        ...     pen_args={"phi": 0, "lambda_1": 3.5},
        ...     tree_key="tree",
        ... )
        >>> tasccoda_model.run_nuts(
        ...     tasccoda_data, modality_key="coda", rng_key=1234, num_samples=10000, num_warmup=1000
        ... )
        >>> sc.tl.umap(tasccoda_data["rna"])
        >>> tasccoda_model.plot_effects_umap(
        ...     tasccoda_data,
        ...     effect_name=[
        ...         "effect_df_condition[T.Salmonella]",
        ...         "effect_df_condition[T.Hpoly.Day3]",
        ...         "effect_df_condition[T.Hpoly.Day10]",
        ...     ],
        ...     cluster_key="nsbm_level_1",
        ... )

    Preview:
        .. image:: /_static/docstring_previews/tasccoda_effects_umap.png
    """
    # TODO: test the example
    data_rna = mdata[modality_key_1]
    data_coda = mdata[modality_key_2]
    if isinstance(effect_name, str):
        effect_name = [effect_name]
    for effect in effect_name:
        # Map each cell's cluster label (as str) to that cluster's effect; an unknown cluster raises KeyError.
        cluster_effects = data_coda.varm[effect]["Effect"].to_dict()
        data_rna.obs[effect] = [cluster_effects[f"{c}"] for c in data_rna.obs[cluster_key]]

    # Pop the limits so they are never passed to scanpy twice; use `is None` so that an explicit 0 is honored.
    vmin = kwargs.pop("vmin", None)
    if vmin is None:
        vmin = min(data_rna.obs[effect].min() for effect in effect_name)
    vmax = kwargs.pop("vmax", None)
    if vmax is None:
        vmax = max(data_rna.obs[effect].max() for effect in effect_name)

    return sc.pl.umap(
        data_rna,
        color=effect_name,
        vmax=vmax,
        vmin=vmin,
        palette=palette,
        color_map=color_map,
        return_fig=return_fig,
        ax=ax,
        show=show,
        save=save,
        **kwargs,
    )
2236
+
1116
2237
 
1117
2238
  def get_a(
1118
- tree: tt.tree,
2239
+ tree: tt.core.ToyTree,
1119
2240
  ) -> tuple[np.ndarray, int]:
1120
2241
  """Calculate ancestor matrix from a toytree tree
1121
2242
 
@@ -1154,7 +2275,7 @@ def get_a(
1154
2275
  return A, n_nodes - 1
1155
2276
 
1156
2277
 
1157
- def collapse_singularities(tree: tt.tree) -> tt.tree:
2278
+ def collapse_singularities(tree: tt.core.ToyTree) -> tt.core.ToyTree:
1158
2279
  """Collapses (deletes) nodes in a toytree tree that are singularities (have only one child).
1159
2280
 
1160
2281
  Args:
@@ -1242,7 +2363,7 @@ def df2newick(df: pd.DataFrame, levels: list[str], inner_label: bool = True) ->
1242
2363
 
1243
2364
 
1244
2365
  def get_a_2(
1245
- tree: ete.Tree,
2366
+ tree: Tree,
1246
2367
  leaf_order: list[str] = None,
1247
2368
  node_order: list[str] = None,
1248
2369
  ) -> tuple[np.ndarray, int]:
@@ -1263,6 +2384,13 @@ def get_a_2(
1263
2384
  T
1264
2385
  number of nodes in the tree, excluding the root node
1265
2386
  """
2387
+ try:
2388
+ import ete3 as ete
2389
+ except ImportError:
2390
+ raise ImportError(
2391
+ "To use tasccoda please install additional dependencies as `pip install pertpy[coda]`"
2392
+ ) from None
2393
+
1266
2394
  n_tips = len(tree.get_leaves())
1267
2395
  n_nodes = len(tree.get_descendants())
1268
2396
 
@@ -1292,7 +2420,7 @@ def get_a_2(
1292
2420
  return A_, n_nodes
1293
2421
 
1294
2422
 
1295
- def collapse_singularities_2(tree: ete.Tree) -> ete.Tree:
2423
+ def collapse_singularities_2(tree: Tree) -> Tree:
1296
2424
  """Collapses (deletes) nodes in a ete3 tree that are singularities (have only one child).
1297
2425
 
1298
2426
  Args:
@@ -1327,10 +2455,10 @@ def linkage_to_newick(
1327
2455
 
1328
2456
  def build_newick(node, newick, parentdist, leaf_names):
1329
2457
  if node.is_leaf():
1330
- return f"{leaf_names[node.id]}:{(parentdist - node.dist)/2}{newick}"
2458
+ return f"{leaf_names[node.id]}:{(parentdist - node.dist) / 2}{newick}"
1331
2459
  else:
1332
2460
  if len(newick) > 0:
1333
- newick = f"):{(parentdist - node.dist)/2}{newick}"
2461
+ newick = f"):{(parentdist - node.dist) / 2}{newick}"
1334
2462
  else:
1335
2463
  newick = ");"
1336
2464
  newick = build_newick(node.get_left(), newick, node.dist, leaf_names)
@@ -1363,14 +2491,14 @@ def import_tree(
1363
2491
 
1364
2492
  Args:
1365
2493
  data: A tascCODA-compatible data object.
1366
- modality_1: If `data` is MuData, specifiy the modality name to the original cell level anndata object. Defaults to None.
1367
- modality_2: If `data` is MuData, specifiy the modality name to the aggregated level anndata object. Defaults to None.
1368
- dendrogram_key: Key to the scanpy.tl.dendrogram result in `.uns` of original cell level anndata object. Defaults to None.
1369
- levels_orig: List that indicates which columns in `.obs` of the original data correspond to tree levels. The list must begin with the root level, and end with the leaf level. Defaults to None.
1370
- levels_agg: List that indicates which columns in `.var` of the aggregated data correspond to tree levels. The list must begin with the root level, and end with the leaf level. Defaults to None.
1371
- add_level_name: If True, internal nodes in the tree will be named as "{level_name}_{node_name}" instead of just {level_name}. Defaults to True.
1372
- key_added: If not specified, the tree is stored in .uns[‘tree’]. If `data` is AnnData, save tree in `data`. If `data` is MuData, save tree in data[modality_2]. Defaults to "tree".
1373
- copy: Return a copy instead of writing to `data`. Defaults to False.
2494
+ modality_1: If `data` is MuData, specify the modality name to the original cell level anndata object.
2495
+ modality_2: If `data` is MuData, specify the modality name to the aggregated level anndata object.
2496
+ dendrogram_key: Key to the scanpy.tl.dendrogram result in `.uns` of original cell level anndata object.
2497
+ levels_orig: List that indicates which columns in `.obs` of the original data correspond to tree levels. The list must begin with the root level, and end with the leaf level.
2498
+ levels_agg: List that indicates which columns in `.var` of the aggregated data correspond to tree levels. The list must begin with the root level, and end with the leaf level.
2499
+ add_level_name: If True, internal nodes in the tree will be named as "{level_name}_{node_name}" instead of just {level_name}.
2500
+ key_added: If not specified, the tree is stored in .uns[‘tree’]. If `data` is AnnData, save tree in `data`.
2501
+ If `data` is MuData, save tree in data[modality_2].
1374
2502
 
1375
2503
  Returns:
1376
2504
  Updates data with the following:
@@ -1379,15 +2507,22 @@ def import_tree(
1379
2507
 
1380
2508
  tree: A ete3 tree object.
1381
2509
  """
2510
+ try:
2511
+ import ete3 as ete
2512
+ except ImportError:
2513
+ raise ImportError(
2514
+ "To use tasccoda please install additional dependencies as `pip install pertpy[coda]`"
2515
+ ) from None
2516
+
1382
2517
  if isinstance(data, MuData):
1383
2518
  try:
1384
2519
  data_1 = data[modality_1]
1385
2520
  data_2 = data[modality_2]
1386
2521
  except KeyError as name:
1387
- print(f"No {name} slot in MuData")
2522
+ logger.error(f"No {name} slot in MuData")
1388
2523
  raise
1389
2524
  except IndexError:
1390
- print("Please specify modality_1 and modality_2 to indicate modalities in MuData")
2525
+ logger.error("Please specify modality_1 and modality_2 to indicate modalities in MuData")
1391
2526
  raise
1392
2527
  else:
1393
2528
  data_1 = data
@@ -1443,68 +2578,54 @@ def from_scanpy(
1443
2578
 
1444
2579
  The anndata object needs to have a column in adata.obs that contains the cell type assignment.
1445
2580
  Further, it must contain one column or a set of columns (e.g. subject id, treatment, disease status) that uniquely identify each (statistical) sample.
1446
- Further covariates (e.g. subject age) can either be specified via addidional column names in adata.obs, a key in adata.uns, or as a separate DataFrame.
2581
+ Further covariates (e.g. subject age) can either be specified via additional column names in adata.obs, a key in adata.uns, or as a separate DataFrame.
1447
2582
 
1448
- NOTE: The order of samples in the returned dataset is determined by the first occurence of cells from each sample in `adata`
2583
+ NOTE: The order of samples in the returned dataset is determined by the first occurrence of cells from each sample in `adata`
1449
2584
 
1450
2585
  Args:
1451
2586
  adata: An anndata object from scanpy
1452
2587
  cell_type_identifier: column name in adata.obs that specifies the cell types
1453
2588
  sample_identifier: column name or list of column names in adata.obs that uniquely identify each sample
1454
2589
  covariate_uns: key for adata.uns, where covariate values are stored
1455
- covariate_obs: list of column names in adata.obs, where covariate values are stored. Note: If covariate values are not unique for a value of sample_identifier, this covaariate will be skipped.
2590
+ covariate_obs: list of column names in adata.obs, where covariate values are stored.
2591
+ Note: If covariate values are not unique for a value of sample_identifier, this covariate will be skipped.
1456
2592
  covariate_df: DataFrame with covariates
1457
2593
 
1458
2594
  Returns:
1459
2595
  AnnData: A data set with cells aggregated to the (sample x cell type) level
1460
2596
  """
1461
- if isinstance(sample_identifier, str):
1462
- sample_identifier = [sample_identifier]
1463
-
1464
- if covariate_obs:
1465
- covariate_obs += [i for i in sample_identifier if i not in covariate_obs]
1466
- else:
1467
- covariate_obs = sample_identifier # type: ignore
2597
+ sample_identifier = [sample_identifier] if isinstance(sample_identifier, str) else sample_identifier
2598
+ covariate_obs = list(set(covariate_obs or []) | set(sample_identifier))
1468
2599
 
1469
- # join sample identifiers
1470
2600
  if isinstance(sample_identifier, list):
1471
2601
  adata.obs["scCODA_sample_id"] = adata.obs[sample_identifier].agg("-".join, axis=1)
1472
2602
  sample_identifier = "scCODA_sample_id"
1473
2603
 
1474
- # get cell type counts
1475
2604
  groups = adata.obs.value_counts([sample_identifier, cell_type_identifier])
1476
- count_data = groups.unstack(level=cell_type_identifier)
1477
- count_data = count_data.fillna(0)
1478
-
1479
- # get covariates from different sources
1480
- covariate_df_ = pd.DataFrame(index=count_data.index)
1481
-
1482
- if covariate_df is None and covariate_obs is None and covariate_uns is None:
1483
- print("No covariate information specified!")
2605
+ ct_count_data = groups.unstack(level=cell_type_identifier).fillna(0)
2606
+ covariate_df_ = pd.DataFrame(index=ct_count_data.index)
1484
2607
 
1485
2608
  if covariate_uns is not None:
1486
- covariate_df_uns = pd.DataFrame(adata.uns[covariate_uns])
1487
- covariate_df_ = pd.concat((covariate_df_, covariate_df_uns), axis=1)
2609
+ covariate_df_uns = pd.DataFrame(adata.uns[covariate_uns], index=ct_count_data.index)
2610
+ covariate_df_ = pd.concat([covariate_df_, covariate_df_uns], axis=1)
1488
2611
 
1489
- if covariate_obs is not None:
1490
- for c in covariate_obs:
1491
- if any(adata.obs.groupby(sample_identifier).nunique()[c] != 1):
1492
- print(f"Covariate {c} has non-unique values! Skipping...")
2612
+ if covariate_obs:
2613
+ unique_check = adata.obs.groupby(sample_identifier).nunique()
2614
+ for c in covariate_obs.copy():
2615
+ if unique_check[c].max() != 1:
2616
+ logger.warning(f"Covariate {c} has non-unique values for batch! Skipping...")
1493
2617
  covariate_obs.remove(c)
1494
-
1495
- covariate_df_obs = adata.obs.groupby(sample_identifier).first()[covariate_obs]
1496
- covariate_df_ = pd.concat((covariate_df_, covariate_df_obs), axis=1)
2618
+ if covariate_obs:
2619
+ covariate_df_obs = adata.obs.groupby(sample_identifier).first()[covariate_obs]
2620
+ covariate_df_ = pd.concat([covariate_df_, covariate_df_obs], axis=1)
1497
2621
 
1498
2622
  if covariate_df is not None:
1499
- if set(covariate_df.index) != set(count_data.index):
1500
- raise ValueError("anndata sample names and covariate_df index do not have the same elements!")
1501
- covs_ord = covariate_df.reindex(count_data.index)
1502
- covariate_df_ = pd.concat((covariate_df_, covs_ord), axis=1)
2623
+ if set(covariate_df.index) != set(ct_count_data.index):
2624
+ raise ValueError("Mismatch between sample names in anndata and covariate_df!")
2625
+ covariate_df_ = pd.concat([covariate_df_, covariate_df.reindex(ct_count_data.index)], axis=1)
1503
2626
 
1504
- covariate_df_.index = covariate_df_.index.astype(str)
1505
-
1506
- # create var (number of cells for each type as only column)
1507
- var_dat = count_data.sum(axis=0).rename("n_cells").to_frame()
2627
+ var_dat = ct_count_data.sum().rename("n_cells").to_frame()
1508
2628
  var_dat.index = var_dat.index.astype(str)
2629
+ covariate_df_.index = covariate_df_.index.astype(str)
1509
2630
 
1510
- return AnnData(X=count_data.values, var=var_dat, obs=covariate_df_)
2631
+ return AnnData(X=ct_count_data.values, var=var_dat, obs=covariate_df_)