pertpy-0.6.0-py3-none-any.whl → pertpy-0.8.0-py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (66)
  1. pertpy/__init__.py +4 -2
  2. pertpy/data/__init__.py +66 -1
  3. pertpy/data/_dataloader.py +28 -26
  4. pertpy/data/_datasets.py +261 -92
  5. pertpy/metadata/__init__.py +6 -0
  6. pertpy/metadata/_cell_line.py +795 -0
  7. pertpy/metadata/_compound.py +128 -0
  8. pertpy/metadata/_drug.py +238 -0
  9. pertpy/metadata/_look_up.py +569 -0
  10. pertpy/metadata/_metadata.py +70 -0
  11. pertpy/metadata/_moa.py +125 -0
  12. pertpy/plot/__init__.py +0 -13
  13. pertpy/preprocessing/__init__.py +2 -0
  14. pertpy/preprocessing/_guide_rna.py +89 -6
  15. pertpy/tools/__init__.py +48 -15
  16. pertpy/tools/_augur.py +329 -32
  17. pertpy/tools/_cinemaot.py +145 -6
  18. pertpy/tools/_coda/_base_coda.py +1237 -116
  19. pertpy/tools/_coda/_sccoda.py +66 -36
  20. pertpy/tools/_coda/_tasccoda.py +46 -39
  21. pertpy/tools/_dialogue.py +180 -77
  22. pertpy/tools/_differential_gene_expression/__init__.py +20 -0
  23. pertpy/tools/_differential_gene_expression/_base.py +657 -0
  24. pertpy/tools/_differential_gene_expression/_checks.py +41 -0
  25. pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
  26. pertpy/tools/_differential_gene_expression/_edger.py +125 -0
  27. pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
  28. pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
  29. pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
  30. pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
  31. pertpy/tools/_distances/_distance_tests.py +29 -24
  32. pertpy/tools/_distances/_distances.py +584 -98
  33. pertpy/tools/_enrichment.py +460 -0
  34. pertpy/tools/_kernel_pca.py +1 -1
  35. pertpy/tools/_milo.py +406 -49
  36. pertpy/tools/_mixscape.py +677 -55
  37. pertpy/tools/_perturbation_space/_clustering.py +10 -3
  38. pertpy/tools/_perturbation_space/_comparison.py +112 -0
  39. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
  40. pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
  41. pertpy/tools/_perturbation_space/_simple.py +52 -11
  42. pertpy/tools/_scgen/__init__.py +1 -1
  43. pertpy/tools/_scgen/_base_components.py +2 -3
  44. pertpy/tools/_scgen/_scgen.py +706 -0
  45. pertpy/tools/_scgen/_utils.py +3 -5
  46. pertpy/tools/decoupler_LICENSE +674 -0
  47. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
  48. pertpy-0.8.0.dist-info/RECORD +57 -0
  49. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
  50. pertpy/plot/_augur.py +0 -234
  51. pertpy/plot/_cinemaot.py +0 -81
  52. pertpy/plot/_coda.py +0 -1001
  53. pertpy/plot/_dialogue.py +0 -91
  54. pertpy/plot/_guide_rna.py +0 -82
  55. pertpy/plot/_milopy.py +0 -284
  56. pertpy/plot/_mixscape.py +0 -594
  57. pertpy/plot/_scgen.py +0 -337
  58. pertpy/tools/_differential_gene_expression.py +0 -99
  59. pertpy/tools/_metadata/__init__.py +0 -0
  60. pertpy/tools/_metadata/_cell_line.py +0 -613
  61. pertpy/tools/_metadata/_look_up.py +0 -342
  62. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  63. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  64. pertpy-0.6.0.dist-info/RECORD +0 -50
  65. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  66. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,17 +1,24 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from abc import ABC, abstractmethod
4
- from typing import TYPE_CHECKING
4
+ from pathlib import Path
5
+ from typing import TYPE_CHECKING, Literal, Optional, Union
5
6
 
6
7
  import arviz as az
7
- import ete3 as ete
8
8
  import jax.numpy as jnp
9
+ import matplotlib.pyplot as plt
9
10
  import numpy as np
10
11
  import pandas as pd
11
12
  import patsy as pt
13
+ import scanpy as sc
14
+ import seaborn as sns
15
+ from adjustText import adjust_text
12
16
  from anndata import AnnData
13
- from jax import random
14
- from jax.config import config
17
+ from jax import config, random
18
+ from lamin_utils import logger
19
+ from matplotlib import cm, rcParams
20
+ from matplotlib import image as mpimg
21
+ from matplotlib.colors import ListedColormap
15
22
  from mudata import MuData
16
23
  from numpyro.infer import HMC, MCMC, NUTS, initialization
17
24
  from rich import box, print
@@ -20,10 +27,15 @@ from rich.table import Table
20
27
  from scipy.cluster import hierarchy as sp_hierarchy
21
28
 
22
29
  if TYPE_CHECKING:
30
+ from collections.abc import Sequence
31
+
23
32
  import numpyro as npy
24
33
  import toytree as tt
25
- from jax._src.prng import PRNGKeyArray
34
+ from ete3 import Tree
26
35
  from jax._src.typing import Array
36
+ from matplotlib.axes import Axes
37
+ from matplotlib.colors import Colormap
38
+ from matplotlib.figure import Figure
27
39
 
28
40
  config.update("jax_enable_x64", True)
29
41
 
@@ -99,9 +111,9 @@ class CompositionalModel2(ABC):
99
111
  Categorical covariates are handled automatically, with the covariate value of the first sample being used as the reference category.
100
112
  To set a different level as the base category for a categorical covariate, use "C(<CovariateName>, Treatment('<ReferenceLevelName>'))"
101
113
  reference_cell_type: Column name that sets the reference cell type.
102
- Reference the name of a column. If "automatic", the cell type with the lowest dispersion in relative abundance that is present in at least 90% of samlpes will be chosen. Defaults to "automatic".
114
+ Reference the name of a column. If "automatic", the cell type with the lowest dispersion in relative abundance that is present in at least 90% of samlpes will be chosen.
103
115
  automatic_reference_absence_threshold: If using reference_cell_type = "automatic", determine the maximum fraction of zero entries for a cell type
104
- to be considered as a possible reference cell type. Defaults to 0.05.
116
+ to be considered as a possible reference cell type.
105
117
 
106
118
  Returns:
107
119
  AnnData object that is ready for CODA models.
@@ -137,7 +149,7 @@ class CompositionalModel2(ABC):
137
149
  ref_index = np.where(cell_type_disp == min_var)[0][0]
138
150
 
139
151
  ref_cell_type = cell_types[ref_index]
140
- print(f"[bold blue]Automatic reference selection! Reference cell type set to {ref_cell_type}")
152
+ logger.info(f"Automatic reference selection! Reference cell type set to {ref_cell_type}")
141
153
 
142
154
  # Column name as reference cell type
143
155
  elif reference_cell_type in cell_types:
@@ -149,7 +161,7 @@ class CompositionalModel2(ABC):
149
161
 
150
162
  # Add pseudocount if zeroes are present.
151
163
  if np.count_nonzero(sample_adata.X) != np.size(sample_adata.X):
152
- print("Zero counts encountered in data! Added a pseudocount of 0.5.")
164
+ logger.info("Zero counts encountered in data! Added a pseudocount of 0.5.")
153
165
  sample_adata.X[sample_adata.X == 0] = 0.5
154
166
 
155
167
  sample_adata.obsm["sample_counts"] = np.sum(sample_adata.X, axis=1)
@@ -179,7 +191,7 @@ class CompositionalModel2(ABC):
179
191
  self,
180
192
  sample_adata: AnnData,
181
193
  kernel: npy.infer.mcmc.MCMCKernel,
182
- rng_key: Array | PRNGKeyArray,
194
+ rng_key: Array,
183
195
  copy: bool = False,
184
196
  *args,
185
197
  **kwargs,
@@ -190,7 +202,7 @@ class CompositionalModel2(ABC):
190
202
  sample_adata: anndata object with cell counts as sample_adata.X and covariates saved in sample_adata.obs.
191
203
  kernel: A `numpyro.infer.mcmc.MCMCKernel` object
192
204
  rng_key: The rng state used. If None, a random state will be selected
193
- copy: Return a copy instead of writing to adata. Defaults to False.
205
+ copy: Return a copy instead of writing to adata.
194
206
  args: Passed to `numpyro.infer.mcmc.MCMC`
195
207
  kwargs: Passed to `numpyro.infer.mcmc.MCMC`
196
208
 
@@ -226,13 +238,13 @@ class CompositionalModel2(ABC):
226
238
 
227
239
  acc_rate = np.array(self.mcmc.last_state.mean_accept_prob)
228
240
  if acc_rate < 0.6:
229
- print(
230
- f"[bold red]Acceptance rate unusually low ({acc_rate} < 0.5)! Results might be incorrect! "
241
+ logger.warning(
242
+ f"Acceptance rate unusually low ({acc_rate} < 0.5)! Results might be incorrect! "
231
243
  f"Please check feasibility of results and re-run the sampling step with a different rng_key if necessary."
232
244
  )
233
245
  if acc_rate > 0.95:
234
- print(
235
- f"[bold red]Acceptance rate unusually high ({acc_rate} > 0.95)! Results might be incorrect! "
246
+ logger.warning(
247
+ f"Acceptance rate unusually high ({acc_rate} > 0.95)! Results might be incorrect! "
236
248
  f"Please check feasibility of results and re-run the sampling step with a different rng_key if necessary."
237
249
  )
238
250
 
@@ -275,11 +287,11 @@ class CompositionalModel2(ABC):
275
287
 
276
288
  Args:
277
289
  data: AnnData object or MuData object.
278
- modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
279
- num_samples: Number of sampled values after burn-in. Defaults to 10000.
280
- num_warmup: Number of burn-in (warmup) samples. Defaults to 1000.
281
- rng_key: The rng state used. Defaults to 0.
282
- copy: Return a copy instead of writing to adata. Defaults to False.
290
+ modality_key: If data is a MuData object, specify which modality to use.
291
+ num_samples: Number of sampled values after burn-in.
292
+ num_warmup: Number of burn-in (warmup) samples.
293
+ rng_key: The rng state used.
294
+ copy: Return a copy instead of writing to adata.
283
295
 
284
296
  Returns:
285
297
  Calls `self.__run_mcmc`
@@ -288,14 +300,14 @@ class CompositionalModel2(ABC):
288
300
  try:
289
301
  sample_adata = data[modality_key]
290
302
  except IndexError:
291
- print("When data is a MuData object, modality_key must be specified!")
303
+ logger.error("When data is a MuData object, modality_key must be specified!")
292
304
  raise
293
305
  if isinstance(data, AnnData):
294
306
  sample_adata = data
295
307
  if copy:
296
308
  sample_adata = sample_adata.copy()
297
309
 
298
- rng_key_array = random.PRNGKey(rng_key)
310
+ rng_key_array = random.key(rng_key)
299
311
  sample_adata.uns["scCODA_params"]["mcmc"]["rng_key"] = np.array(rng_key_array)
300
312
 
301
313
  # Set up NUTS kernel
@@ -328,14 +340,13 @@ class CompositionalModel2(ABC):
328
340
 
329
341
  Args:
330
342
  data: AnnData object or MuData object.
331
- modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
332
- num_samples: Number of sampled values after burn-in. Defaults to 20000.
333
- num_warmup: Number of burn-in (warmup) samples. Defaults to 5000.
334
- rng_key: The rng state used. If None, a random state will be selected. Defaults to None.
335
- copy: Return a copy instead of writing to adata. Defaults to False.
343
+ modality_key: If data is a MuData object, specify which modality to use.
344
+ num_samples: Number of sampled values after burn-in.
345
+ num_warmup: Number of burn-in (warmup) samples.
346
+ rng_key: The rng state used. If None, a random state will be selected.
347
+ copy: Return a copy instead of writing to adata.
336
348
 
337
349
  Examples:
338
- Example with scCODA:
339
350
  >>> import pertpy as pt
340
351
  >>> haber_cells = pt.dt.haber_2017_regions()
341
352
  >>> sccoda = pt.tl.Sccoda()
@@ -348,7 +359,7 @@ class CompositionalModel2(ABC):
348
359
  try:
349
360
  sample_adata = data[modality_key]
350
361
  except IndexError:
351
- print("When data is a MuData object, modality_key must be specified!")
362
+ logger.error("When data is a MuData object, modality_key must be specified!")
352
363
  raise
353
364
  if isinstance(data, AnnData):
354
365
  sample_adata = data
@@ -358,10 +369,10 @@ class CompositionalModel2(ABC):
358
369
  # Set rng key if needed
359
370
  if rng_key is None:
360
371
  rng = np.random.default_rng()
361
- rng_key = random.PRNGKey(rng.integers(0, 10000))
372
+ rng_key = random.key(rng.integers(0, 10000))
362
373
  sample_adata.uns["scCODA_params"]["mcmc"]["rng_key"] = rng_key
363
374
  else:
364
- rng_key = random.PRNGKey(rng_key)
375
+ rng_key = random.key(rng_key)
365
376
 
366
377
  # Set up HMC kernel
367
378
  sample_adata = self.set_init_mcmc_states(
@@ -387,7 +398,7 @@ class CompositionalModel2(ABC):
387
398
 
388
399
  Args:
389
400
  sample_adata: Anndata object with cell counts as sample_adata.X and covariates saved in sample_adata.obs.
390
- est_fdr: Desired FDR value. Defaults to 0.05.
401
+ est_fdr: Desired FDR value.
391
402
  args: Passed to ``az.summary``
392
403
  kwargs: Passed to ``az.summary``
393
404
 
@@ -423,7 +434,6 @@ class CompositionalModel2(ABC):
423
434
  - Is credible: Boolean indicator whether effect is credible
424
435
 
425
436
  Examples:
426
- Example with scCODA:
427
437
  >>> import pertpy as pt
428
438
  >>> haber_cells = pt.dt.haber_2017_regions()
429
439
  >>> sccoda = pt.tl.Sccoda()
@@ -433,7 +443,6 @@ class CompositionalModel2(ABC):
433
443
  >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
434
444
  >>> intercept_df, effect_df = sccoda.summary_prepare(mdata["coda"])
435
445
  """
436
- # Get model and effect selection types
437
446
  select_type = sample_adata.uns["scCODA_params"]["select_type"]
438
447
  model_type = sample_adata.uns["scCODA_params"]["model_type"]
439
448
 
@@ -548,7 +557,11 @@ class CompositionalModel2(ABC):
548
557
  intercept_df = intercept_df.loc[:, ["final_parameter", hdis[0], hdis[1], "sd", "expected_sample"]].copy()
549
558
  intercept_df = intercept_df.rename(
550
559
  columns=dict(
551
- zip(intercept_df.columns, ["Final Parameter", hdis_new[0], hdis_new[1], "SD", "Expected Sample"])
560
+ zip(
561
+ intercept_df.columns,
562
+ ["Final Parameter", hdis_new[0], hdis_new[1], "SD", "Expected Sample"],
563
+ strict=False,
564
+ )
552
565
  )
553
566
  )
554
567
 
@@ -561,6 +574,7 @@ class CompositionalModel2(ABC):
561
574
  zip(
562
575
  effect_df.columns,
563
576
  ["Effect", "Median", hdis_new[0], hdis_new[1], "SD", "Expected Sample", "log2-fold change"],
577
+ strict=False,
564
578
  )
565
579
  )
566
580
  )
@@ -581,6 +595,7 @@ class CompositionalModel2(ABC):
581
595
  "Expected Sample",
582
596
  "log2-fold change",
583
597
  ],
598
+ strict=False,
584
599
  )
585
600
  )
586
601
  )
@@ -594,6 +609,7 @@ class CompositionalModel2(ABC):
594
609
  zip(
595
610
  node_df.columns,
596
611
  ["Final Parameter", "Median", hdis_new[0], hdis_new[1], "SD", "Delta", "Is credible"],
612
+ strict=False,
597
613
  )
598
614
  ) # type: ignore
599
615
  ) # type: ignore
@@ -622,8 +638,8 @@ class CompositionalModel2(ABC):
622
638
  effect_df: Effect summary, see ``summary_prepare``
623
639
  model_type: String indicating the model type ("classic" or "tree_agg")
624
640
  select_type: String indicating the type of spike_and_slab selection ("spikeslab" or "sslasso")
625
- target_fdr: Desired FDR value. Defaults to 0.05.
626
- node_df: If using tree aggregation, the node-level effect DataFrame must be passed. Defaults to None.
641
+ target_fdr: Desired FDR value.
642
+ node_df: If using tree aggregation, the node-level effect DataFrame must be passed.
627
643
 
628
644
  Returns:
629
645
  pd.DataFrame: effect DataFrame with inclusion probability, final parameters, expected sample.
@@ -775,13 +791,12 @@ class CompositionalModel2(ABC):
775
791
 
776
792
  Args:
777
793
  data: AnnData object or MuData object.
778
- extended: If True, return the extended summary with additional statistics. Defaults to False.
779
- modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
794
+ extended: If True, return the extended summary with additional statistics.
795
+ modality_key: If data is a MuData object, specify which modality to use.
780
796
  args: Passed to az.summary
781
797
  kwargs: Passed to az.summary
782
798
 
783
799
  Examples:
784
- Example with scCODA:
785
800
  >>> import pertpy as pt
786
801
  >>> haber_cells = pt.dt.haber_2017_regions()
787
802
  >>> sccoda = pt.tl.Sccoda()
@@ -795,11 +810,11 @@ class CompositionalModel2(ABC):
795
810
  try:
796
811
  sample_adata = data[modality_key]
797
812
  except IndexError:
798
- print("[bold red]When data is a MuData object, modality_key must be specified!")
813
+ logger.error("When data is a MuData object, modality_key must be specified!")
799
814
  raise
800
815
  if isinstance(data, AnnData):
801
816
  sample_adata = data
802
- # Get model and effect selection types
817
+
803
818
  select_type = sample_adata.uns["scCODA_params"]["select_type"]
804
819
  model_type = sample_adata.uns["scCODA_params"]["model_type"]
805
820
 
@@ -834,10 +849,10 @@ class CompositionalModel2(ABC):
834
849
  table.add_column("Name", justify="left", style="cyan")
835
850
  table.add_column("Value", justify="left")
836
851
  table.add_row("Data", "Data: %d samples, %d cell types" % data_dims)
837
- table.add_row("Reference cell type", "%s" % str(sample_adata.uns["scCODA_params"]["reference_cell_type"]))
838
- table.add_row("Formula", "%s" % sample_adata.uns["scCODA_params"]["formula"])
852
+ table.add_row("Reference cell type", "{}".format(str(sample_adata.uns["scCODA_params"]["reference_cell_type"])))
853
+ table.add_row("Formula", "{}".format(sample_adata.uns["scCODA_params"]["formula"]))
839
854
  if extended:
840
- table.add_row("Reference index", "%s" % str(sample_adata.uns["scCODA_params"]["reference_index"]))
855
+ table.add_row("Reference index", "{}".format(str(sample_adata.uns["scCODA_params"]["reference_index"])))
841
856
  if select_type == "spikeslab":
842
857
  table.add_row(
843
858
  "Spike-and-slab threshold",
@@ -920,13 +935,12 @@ class CompositionalModel2(ABC):
920
935
 
921
936
  Args:
922
937
  data: AnnData object or MuData object.
923
- modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
938
+ modality_key: If data is a MuData object, specify which modality to use.
924
939
 
925
940
  Returns:
926
941
  pd.DataFrame: Intercept data frame.
927
942
 
928
943
  Examples:
929
- Example with scCODA:
930
944
  >>> import pertpy as pt
931
945
  >>> haber_cells = pt.dt.haber_2017_regions()
932
946
  >>> sccoda = pt.tl.Sccoda()
@@ -936,12 +950,11 @@ class CompositionalModel2(ABC):
936
950
  >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
937
951
  >>> intercepts = sccoda.get_intercept_df(mdata)
938
952
  """
939
-
940
953
  if isinstance(data, MuData):
941
954
  try:
942
955
  sample_adata = data[modality_key]
943
956
  except IndexError:
944
- print("When data is a MuData object, modality_key must be specified!")
957
+ logger.error("When data is a MuData object, modality_key must be specified!")
945
958
  raise
946
959
  if isinstance(data, AnnData):
947
960
  sample_adata = data
@@ -953,13 +966,12 @@ class CompositionalModel2(ABC):
953
966
 
954
967
  Args:
955
968
  data: AnnData object or MuData object.
956
- modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
969
+ modality_key: If data is a MuData object, specify which modality to use.
957
970
 
958
971
  Returns:
959
972
  pd.DataFrame: Effect data frame.
960
973
 
961
974
  Examples:
962
- Example with scCODA:
963
975
  >>> import pertpy as pt
964
976
  >>> haber_cells = pt.dt.haber_2017_regions()
965
977
  >>> sccoda = pt.tl.Sccoda()
@@ -969,12 +981,11 @@ class CompositionalModel2(ABC):
969
981
  >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
970
982
  >>> effects = sccoda.get_effect_df(mdata)
971
983
  """
972
-
973
984
  if isinstance(data, MuData):
974
985
  try:
975
986
  sample_adata = data[modality_key]
976
987
  except IndexError:
977
- print("When data is a MuData object, modality_key must be specified!")
988
+ logger.error("When data is a MuData object, modality_key must be specified!")
978
989
  raise
979
990
  if isinstance(data, AnnData):
980
991
  sample_adata = data
@@ -997,15 +1008,14 @@ class CompositionalModel2(ABC):
997
1008
 
998
1009
  Args:
999
1010
  data: AnnData object or MuData object.
1000
- modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
1011
+ modality_key: If data is a MuData object, specify which modality to use.
1001
1012
 
1002
1013
  Returns:
1003
1014
  pd.DataFrame: Node effect data frame.
1004
1015
 
1005
1016
  Examples:
1006
- Example with tascCODA (works only for model of type tree_agg, i.e. a tascCODA model):
1007
1017
  >>> import pertpy as pt
1008
- >>> adata = pt.dt.smillie()
1018
+ >>> adata = pt.dt.tasccoda_example()
1009
1019
  >>> tasccoda = pt.tl.Tasccoda()
1010
1020
  >>> mdata = tasccoda.load(
1011
1021
  >>> adata, type="sample_level",
@@ -1023,7 +1033,7 @@ class CompositionalModel2(ABC):
1023
1033
  try:
1024
1034
  sample_adata = data[modality_key]
1025
1035
  except IndexError:
1026
- print("When data is a MuData object, modality_key must be specified!")
1036
+ logger.error("When data is a MuData object, modality_key must be specified!")
1027
1037
  raise
1028
1038
  if isinstance(data, AnnData):
1029
1039
  sample_adata = data
@@ -1037,7 +1047,7 @@ class CompositionalModel2(ABC):
1037
1047
  Args:
1038
1048
  data: AnnData object or MuData object.
1039
1049
  est_fdr: Desired FDR value.
1040
- modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
1050
+ modality_key: If data is a MuData object, specify which modality to use.
1041
1051
  args: passed to self.summary_prepare
1042
1052
  kwargs: passed to self.summary_prepare
1043
1053
 
@@ -1048,7 +1058,7 @@ class CompositionalModel2(ABC):
1048
1058
  try:
1049
1059
  sample_adata = data[modality_key]
1050
1060
  except IndexError:
1051
- print("When data is a MuData object, modality_key must be specified!")
1061
+ logger.error("When data is a MuData object, modality_key must be specified!")
1052
1062
  raise
1053
1063
  if isinstance(data, AnnData):
1054
1064
  sample_adata = data
@@ -1071,8 +1081,8 @@ class CompositionalModel2(ABC):
1071
1081
 
1072
1082
  Args:
1073
1083
  data: AnnData object or MuData object.
1074
- modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
1075
- est_fdr: Estimated false discovery rate. Must be between 0 and 1. Defaults to None.
1084
+ modality_key: If data is a MuData object, specify which modality to use.
1085
+ est_fdr: Estimated false discovery rate. Must be between 0 and 1.
1076
1086
 
1077
1087
  Returns:
1078
1088
  pd.Series: Credible effect decision series which includes boolean values indicate whether effects are credible under inc_prob_threshold.
@@ -1081,7 +1091,7 @@ class CompositionalModel2(ABC):
1081
1091
  try:
1082
1092
  sample_adata = data[modality_key]
1083
1093
  except IndexError:
1084
- print("When data is a MuData object, modality_key must be specified!")
1094
+ logger.error("When data is a MuData object, modality_key must be specified!")
1085
1095
  raise
1086
1096
  if isinstance(data, AnnData):
1087
1097
  sample_adata = data
@@ -1113,9 +1123,1120 @@ class CompositionalModel2(ABC):
1113
1123
 
1114
1124
  return out
1115
1125
 
1126
+ def _stackbar( # pragma: no cover
1127
+ self,
1128
+ y: np.ndarray,
1129
+ type_names: list[str],
1130
+ title: str,
1131
+ level_names: list[str],
1132
+ figsize: tuple[float, float] | None = None,
1133
+ dpi: int | None = 100,
1134
+ palette: ListedColormap | None = cm.tab20,
1135
+ show_legend: bool | None = True,
1136
+ ) -> plt.Axes:
1137
+ """Plots a stacked barplot for one (discrete) covariate.
1138
+
1139
+ Typical use (only inside stacked_barplot): plot_one_stackbar(data.X, data.var.index, "xyz", data.obs.index)
1140
+
1141
+ Args:
1142
+ y: The count data, collapsed onto the level of interest. i.e. a binary covariate has two rows,
1143
+ one for each group, containing the count mean of each cell type
1144
+ type_names: The names of all cell types
1145
+ title: Plot title, usually the covariate's name
1146
+ level_names: Names of the covariate's levels
1147
+ figsize: Figure size (matplotlib).
1148
+ dpi: Resolution in DPI (matplotlib).
1149
+ palette: The color map for the barplot.
1150
+ show_legend: If True, adds a legend.
1151
+
1152
+ Returns:
1153
+ A :class:`~matplotlib.axes.Axes` object
1154
+ """
1155
+ n_bars, n_types = y.shape
1156
+
1157
+ figsize = rcParams["figure.figsize"] if figsize is None else figsize
1158
+
1159
+ _, ax = plt.subplots(figsize=figsize, dpi=dpi)
1160
+ r = np.array(range(n_bars))
1161
+ sample_sums = np.sum(y, axis=1)
1162
+
1163
+ barwidth = 0.85
1164
+ cum_bars = np.zeros(n_bars)
1165
+
1166
+ for n in range(n_types):
1167
+ bars = [i / j * 100 for i, j in zip([y[k][n] for k in range(n_bars)], sample_sums, strict=False)]
1168
+ plt.bar(
1169
+ r,
1170
+ bars,
1171
+ bottom=cum_bars,
1172
+ color=palette(n % palette.N),
1173
+ width=barwidth,
1174
+ label=type_names[n],
1175
+ linewidth=0,
1176
+ )
1177
+ cum_bars += bars
1178
+
1179
+ ax.set_title(title)
1180
+ if show_legend:
1181
+ ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1)
1182
+ ax.set_xticks(r)
1183
+ ax.set_xticklabels(level_names, rotation=45, ha="right")
1184
+ ax.set_ylabel("Proportion")
1185
+
1186
+ return ax
1187
+
1188
+ def plot_stacked_barplot( # pragma: no cover
1189
+ self,
1190
+ data: AnnData | MuData,
1191
+ feature_name: str,
1192
+ modality_key: str = "coda",
1193
+ palette: ListedColormap | None = cm.tab20,
1194
+ show_legend: bool | None = True,
1195
+ level_order: list[str] = None,
1196
+ figsize: tuple[float, float] | None = None,
1197
+ dpi: int | None = 100,
1198
+ return_fig: bool | None = None,
1199
+ ax: plt.Axes | None = None,
1200
+ show: bool | None = None,
1201
+ save: str | bool | None = None,
1202
+ **kwargs,
1203
+ ) -> plt.Axes | plt.Figure | None:
1204
+ """Plots a stacked barplot for all levels of a covariate or all samples (if feature_name=="samples").
1205
+
1206
+ Args:
1207
+ data: AnnData object or MuData object.
1208
+ feature_name: The name of the covariate to plot. If feature_name=="samples", one bar for every sample will be plotted
1209
+ modality_key: If data is a MuData object, specify which modality to use.
1210
+ figsize: Figure size.
1211
+ dpi: Dpi setting.
1212
+ palette: The matplotlib color map for the barplot.
1213
+ show_legend: If True, adds a legend.
1214
+ level_order: Custom ordering of bars on the x-axis.
1215
+
1216
+ Returns:
1217
+ A :class:`~matplotlib.axes.Axes` object
1218
+
1219
+ Examples:
1220
+ >>> import pertpy as pt
1221
+ >>> haber_cells = pt.dt.haber_2017_regions()
1222
+ >>> sccoda = pt.tl.Sccoda()
1223
+ >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
1224
+ sample_identifier="batch", covariate_obs=["condition"])
1225
+ >>> sccoda.plot_stacked_barplot(mdata, feature_name="samples")
1226
+
1227
+ Preview:
1228
+ .. image:: /_static/docstring_previews/sccoda_stacked_barplot.png
1229
+ """
1230
+ if isinstance(data, MuData):
1231
+ data = data[modality_key]
1232
+ if isinstance(data, AnnData):
1233
+ data = data
1234
+
1235
+ ct_names = data.var.index
1236
+
1237
+ # option to plot one stacked barplot per sample
1238
+ if feature_name == "samples":
1239
+ if level_order:
1240
+ assert set(level_order) == set(data.obs.index), "level order is inconsistent with levels"
1241
+ data = data[level_order]
1242
+ ax = self._stackbar(
1243
+ data.X,
1244
+ type_names=data.var.index,
1245
+ title="samples",
1246
+ level_names=data.obs.index,
1247
+ figsize=figsize,
1248
+ dpi=dpi,
1249
+ palette=palette,
1250
+ show_legend=show_legend,
1251
+ )
1252
+ else:
1253
+ # Order levels
1254
+ if level_order:
1255
+ assert set(level_order) == set(data.obs[feature_name]), "level order is inconsistent with levels"
1256
+ levels = level_order
1257
+ elif hasattr(data.obs[feature_name], "cat"):
1258
+ levels = data.obs[feature_name].cat.categories.to_list()
1259
+ else:
1260
+ levels = pd.unique(data.obs[feature_name])
1261
+ n_levels = len(levels)
1262
+ feature_totals = np.zeros([n_levels, data.X.shape[1]])
1263
+
1264
+ for level in range(n_levels):
1265
+ l_indices = np.where(data.obs[feature_name] == levels[level])
1266
+ feature_totals[level] = np.sum(data.X[l_indices], axis=0)
1267
+
1268
+ ax = self._stackbar(
1269
+ feature_totals,
1270
+ type_names=ct_names,
1271
+ title=feature_name,
1272
+ level_names=levels,
1273
+ figsize=figsize,
1274
+ dpi=dpi,
1275
+ palette=palette,
1276
+ show_legend=show_legend,
1277
+ )
1278
+
1279
+ if save:
1280
+ plt.savefig(save, bbox_inches="tight")
1281
+ if show:
1282
+ plt.show()
1283
+ if return_fig:
1284
+ return plt.gcf()
1285
+ if not (show or save):
1286
+ return ax
1287
+ return None
1288
+
1289
+ def plot_effects_barplot( # pragma: no cover
1290
+ self,
1291
+ data: AnnData | MuData,
1292
+ modality_key: str = "coda",
1293
+ covariates: str | list | None = None,
1294
+ parameter: Literal["log2-fold change", "Final Parameter", "Expected Sample"] = "log2-fold change",
1295
+ plot_facets: bool = True,
1296
+ plot_zero_covariate: bool = True,
1297
+ plot_zero_cell_type: bool = False,
1298
+ palette: str | ListedColormap | None = cm.tab20,
1299
+ level_order: list[str] = None,
1300
+ args_barplot: dict | None = None,
1301
+ figsize: tuple[float, float] | None = None,
1302
+ dpi: int | None = 100,
1303
+ return_fig: bool | None = None,
1304
+ ax: plt.Axes | None = None,
1305
+ show: bool | None = None,
1306
+ save: str | bool | None = None,
1307
+ ) -> plt.Axes | plt.Figure | sns.axisgrid.FacetGrid | None:
1308
+ """Barplot visualization for effects.
1309
+
1310
+ The effect results for each covariate are shown as a group of barplots, with intra--group separation by cell types.
1311
+ The covariates groups can either be ordered along the x-axis of a single plot (plot_facets=False) or as plot facets (plot_facets=True).
1312
+
1313
+ Args:
1314
+ data: AnnData object or MuData object.
1315
+ modality_key: If data is a MuData object, specify which modality to use.
1316
+ covariates: The name of the covariates in data.obs to plot.
1317
+ parameter: The parameter in effect summary to plot.
1318
+ plot_facets: If False, plot cell types on the x-axis. If True, plot as facets.
1319
+ plot_zero_covariate: If True, plot covariate that have all zero effects. If False, do not plot.
1320
+ plot_zero_cell_type: If True, plot cell type that have zero effect. If False, do not plot.
1321
+ figsize: Figure size.
1322
+ dpi: Figure size.
1323
+ palette: The seaborn color map for the barplot.
1324
+ level_order: Custom ordering of bars on the x-axis.
1325
+ args_barplot: Arguments passed to sns.barplot.
1326
+
1327
+ Returns:
1328
+ Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`)
1329
+ or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object
1330
+
1331
+ Examples:
1332
+ >>> import pertpy as pt
1333
+ >>> haber_cells = pt.dt.haber_2017_regions()
1334
+ >>> sccoda = pt.tl.Sccoda()
1335
+ >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
1336
+ sample_identifier="batch", covariate_obs=["condition"])
1337
+ >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
1338
+ >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
1339
+ >>> sccoda.plot_effects_barplot(mdata)
1340
+
1341
+ Preview:
1342
+ .. image:: /_static/docstring_previews/sccoda_effects_barplot.png
1343
+ """
1344
+ if args_barplot is None:
1345
+ args_barplot = {}
1346
+ if isinstance(data, MuData):
1347
+ data = data[modality_key]
1348
+ if isinstance(data, AnnData):
1349
+ data = data
1350
+ # Get covariate names from adata, partition into those with nonzero effects for min. one cell type/no cell types
1351
+ covariate_names = data.uns["scCODA_params"]["covariate_names"]
1352
+ if covariates is not None:
1353
+ if isinstance(covariates, str):
1354
+ covariates = [covariates]
1355
+ partial_covariate_names = [
1356
+ covariate_name
1357
+ for covariate_name in covariate_names
1358
+ if any(covariate in covariate_name for covariate in covariates)
1359
+ ]
1360
+ covariate_names = partial_covariate_names
1361
+ covariate_names_non_zero = [
1362
+ covariate_name
1363
+ for covariate_name in covariate_names
1364
+ if data.varm[f"effect_df_{covariate_name}"][parameter].any()
1365
+ ]
1366
+ covariate_names_zero = list(set(covariate_names) - set(covariate_names_non_zero))
1367
+ if not plot_zero_covariate:
1368
+ covariate_names = covariate_names_non_zero
1369
+
1370
+ # set up df for plotting
1371
+ plot_df = pd.concat(
1372
+ [data.varm[f"effect_df_{covariate_name}"][parameter] for covariate_name in covariate_names],
1373
+ axis=1,
1374
+ )
1375
+ plot_df.columns = covariate_names
1376
+ plot_df = pd.melt(plot_df, ignore_index=False, var_name="Covariate")
1377
+
1378
+ plot_df = plot_df.reset_index()
1379
+
1380
+ if len(covariate_names_zero) != 0:
1381
+ if plot_facets:
1382
+ if plot_zero_covariate and not plot_zero_cell_type:
1383
+ plot_df = plot_df[plot_df["value"] != 0]
1384
+ for covariate_name_zero in covariate_names_zero:
1385
+ new_row = {
1386
+ "Covariate": covariate_name_zero,
1387
+ "Cell Type": "zero",
1388
+ "value": 0,
1389
+ }
1390
+ plot_df = pd.concat([plot_df, pd.DataFrame([new_row])], ignore_index=True)
1391
+ plot_df["covariate_"] = pd.Categorical(plot_df["Covariate"], covariate_names)
1392
+ plot_df = plot_df.sort_values(["covariate_"])
1393
+ if not plot_zero_cell_type:
1394
+ cell_type_names_zero = [
1395
+ name
1396
+ for name in plot_df["Cell Type"].unique()
1397
+ if (plot_df[plot_df["Cell Type"] == name]["value"] == 0).all()
1398
+ ]
1399
+ plot_df = plot_df[~plot_df["Cell Type"].isin(cell_type_names_zero)]
1400
+
1401
+ # If plot as facets, create a FacetGrid and map barplot to it.
1402
+ if plot_facets:
1403
+ if isinstance(palette, ListedColormap):
1404
+ palette = np.array([palette(i % palette.N) for i in range(len(plot_df["Cell Type"].unique()))]).tolist()
1405
+ if figsize is not None:
1406
+ height = figsize[0]
1407
+ aspect = np.round(figsize[1] / figsize[0], 2)
1408
+ else:
1409
+ height = 3
1410
+ aspect = 2
1411
+
1412
+ g = sns.FacetGrid(
1413
+ plot_df,
1414
+ col="Covariate",
1415
+ sharey=True,
1416
+ sharex=False,
1417
+ height=height,
1418
+ aspect=aspect,
1419
+ )
1420
+
1421
+ g.map(
1422
+ sns.barplot,
1423
+ "Cell Type",
1424
+ "value",
1425
+ palette=palette,
1426
+ order=level_order,
1427
+ **args_barplot,
1428
+ )
1429
+ g.set_xticklabels(rotation=90)
1430
+ g.set(ylabel=parameter)
1431
+ axes = g.axes.flatten()
1432
+ for i, ax in enumerate(axes):
1433
+ ax.set_title(covariate_names[i])
1434
+ if len(ax.get_xticklabels()) < 5:
1435
+ ax.set_aspect(10 / len(ax.get_xticklabels()))
1436
+ if len(ax.get_xticklabels()) == 1:
1437
+ if ax.get_xticklabels()[0]._text == "zero":
1438
+ ax.set_xticks([])
1439
+
1440
+ if save:
1441
+ plt.savefig(save, bbox_inches="tight")
1442
+ if show:
1443
+ plt.show()
1444
+ if return_fig:
1445
+ return plt.gcf()
1446
+ if not (show or save):
1447
+ return g
1448
+ return None
1449
+
1450
+ # If not plot as facets, call barplot to plot cell types on the x-axis.
1451
+ else:
1452
+ _, ax = plt.subplots(figsize=figsize, dpi=dpi)
1453
+ if len(covariate_names) == 1:
1454
+ if isinstance(palette, ListedColormap):
1455
+ palette = np.array(
1456
+ [palette(i % palette.N) for i in range(len(plot_df["Cell Type"].unique()))]
1457
+ ).tolist()
1458
+ sns.barplot(
1459
+ data=plot_df,
1460
+ x="Cell Type",
1461
+ y="value",
1462
+ hue="x",
1463
+ palette=palette,
1464
+ ax=ax,
1465
+ )
1466
+ ax.set_title(covariate_names[0])
1467
+ else:
1468
+ if isinstance(palette, ListedColormap):
1469
+ palette = np.array([palette(i % palette.N) for i in range(len(covariate_names))]).tolist()
1470
+ sns.barplot(
1471
+ data=plot_df,
1472
+ x="Cell Type",
1473
+ y="value",
1474
+ hue="Covariate",
1475
+ palette=palette,
1476
+ ax=ax,
1477
+ )
1478
+ cell_types = pd.unique(plot_df["Cell Type"])
1479
+ ax.set_xticklabels(cell_types, rotation=90)
1480
+
1481
+ if save:
1482
+ plt.savefig(save, bbox_inches="tight")
1483
+ if show:
1484
+ plt.show()
1485
+ if return_fig:
1486
+ return plt.gcf()
1487
+ if not (show or save):
1488
+ return ax
1489
+ return None
1490
+
1491
def plot_boxplots(  # pragma: no cover
    self,
    data: AnnData | MuData,
    feature_name: str,
    modality_key: str = "coda",
    y_scale: Literal["relative", "log", "log10", "count"] = "relative",
    plot_facets: bool = False,
    add_dots: bool = False,
    cell_types: list | None = None,
    args_boxplot: dict | None = None,
    args_swarmplot: dict | None = None,
    palette: str | None = "Blues",
    show_legend: bool | None = True,
    level_order: list[str] | None = None,
    figsize: tuple[float, float] | None = None,
    dpi: int | None = 100,
    return_fig: bool | None = None,
    ax: plt.Axes | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
) -> plt.Axes | plt.Figure | sns.axisgrid.FacetGrid | None:
    """Grouped boxplot visualization.

    The cell counts for each cell type are shown as a group of boxplots
    with intra-group separation by a covariate from data.obs.

    Args:
        data: AnnData object or MuData object.
        feature_name: The name of the feature in data.obs to plot.
        modality_key: If data is a MuData object, specify which modality to use.
        y_scale: Transformation of cell counts. Options: "relative" - relative abundance, "log" - log(count),
            "log10" - log10(count), "count" - absolute abundance (cell counts).
            For the log scales, zero counts are replaced by a pseudocount of 0.5 before taking the logarithm.
        plot_facets: If False, plot cell types on the x-axis. If True, plot one facet per cell type.
        add_dots: If True, overlay a swarmplot with one dot for each data point.
        cell_types: Subset of cell types that should be plotted.
        args_boxplot: Arguments passed to sns.boxplot. The passed dict is not modified.
        args_swarmplot: Arguments passed to sns.swarmplot. The passed dict is not modified.
        palette: The seaborn color map for the boxplot.
        show_legend: If True, adds a legend. Only used when `plot_facets` is False.
        level_order: Custom ordering of the levels of `feature_name`.
        figsize: Figure size. With `plot_facets=True` it is interpreted as (facet height, facet width).
        dpi: Dpi setting (only used when `plot_facets` is False).
        return_fig: If True, return the matplotlib figure.
        ax: Axes to draw on when `plot_facets` is False. If None, a new figure is created.
        show: If True, show the figure.
        save: If given, path the figure is saved to.

    Returns:
        Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`)
        or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object

    Examples:
        >>> import pertpy as pt
        >>> haber_cells = pt.dt.haber_2017_regions()
        >>> sccoda = pt.tl.Sccoda()
        >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
                sample_identifier="batch", covariate_obs=["condition"])
        >>> sccoda.plot_boxplots(mdata, feature_name="condition", add_dots=True)

    Preview:
        .. image:: /_static/docstring_previews/sccoda_boxplots.png
    """
    # Work on copies so the caller's dicts are never mutated (we pop "hue" and add "hue_order" below).
    args_boxplot = {} if args_boxplot is None else dict(args_boxplot)
    args_swarmplot = {} if args_swarmplot is None else dict(args_swarmplot)
    if isinstance(data, MuData):
        data = data[modality_key]

    # y scale transformations
    if y_scale == "relative":
        sample_sums = np.sum(data.X, axis=1, keepdims=True)
        X = data.X / sample_sums
        value_name = "Proportion"
    elif y_scale in ("log", "log10"):
        # Copy as float: writing the 0.5 pseudocount into an integer count matrix
        # would silently truncate it back to 0 and produce -inf after the log.
        X = np.array(data.X, dtype=float)
        X[X == 0] = 0.5
        if y_scale == "log":
            X = np.log(X)
            value_name = "log(count)"
        else:
            X = np.log10(X)
            value_name = "log10(count)"
    elif y_scale == "count":
        X = data.X
        value_name = "count"
    else:
        raise ValueError("Invalid y_scale transformation")

    count_df = pd.DataFrame(X, columns=data.var.index, index=data.obs.index).merge(
        data.obs[feature_name], left_index=True, right_index=True
    )
    plot_df = pd.melt(count_df, id_vars=feature_name, var_name="Cell type", value_name=value_name)
    if cell_types is not None:
        plot_df = plot_df[plot_df["Cell type"].isin(cell_types)]

    # NOTE: annotating credible effects (via statannotations) is currently disabled because the
    # latest statannotations does not support the latest seaborn, so the dependency was dropped.

    # If plot as facets, create a FacetGrid and map boxplot to it.
    if plot_facets:
        if level_order is None:
            level_order = pd.unique(plot_df[feature_name])

        # Wrap based on the number of facets actually drawn (respects the `cell_types` subset).
        n_facets = max(plot_df["Cell type"].nunique(), 1)

        if figsize is not None:
            height = figsize[0]
            aspect = np.round(figsize[1] / figsize[0], 2)
        else:
            height = 3
            aspect = 2

        g = sns.FacetGrid(
            plot_df,
            col="Cell type",
            sharey=False,
            col_wrap=max(int(np.floor(np.sqrt(n_facets))), 1),
            height=height,
            aspect=aspect,
        )
        g.map(
            sns.boxplot,
            feature_name,
            value_name,
            palette=palette,
            order=level_order,
            **args_boxplot,
        )

        if add_dots:
            hue = args_swarmplot.pop("hue", None)
            if hue is None:
                g.map(
                    sns.swarmplot,
                    feature_name,
                    value_name,
                    color="black",
                    order=level_order,
                    **args_swarmplot,
                ).set_titles("{col_name}")
            else:
                g.map(
                    sns.swarmplot,
                    feature_name,
                    value_name,
                    hue,
                    order=level_order,
                    **args_swarmplot,
                ).set_titles("{col_name}")

        if save:
            plt.savefig(save, bbox_inches="tight")
        if show:
            plt.show()
        if return_fig:
            return plt.gcf()
        if not (show or save):
            return g
        return None

    # If not plot as facets, call boxplot to plot cell types on the x-axis.
    if level_order is not None and len(level_order) > 0:
        args_boxplot["hue_order"] = level_order
        args_swarmplot["hue_order"] = level_order

    # Honour a caller-supplied axes; only create a new figure when none was given.
    if ax is None:
        _, ax = plt.subplots(figsize=figsize, dpi=dpi)

    ax = sns.boxplot(
        x="Cell type",
        y=value_name,
        hue=feature_name,
        data=plot_df,
        fliersize=1,
        palette=palette,
        ax=ax,
        **args_boxplot,
    )

    if add_dots:
        sns.swarmplot(
            x="Cell type",
            y=value_name,
            data=plot_df,
            hue=feature_name,
            ax=ax,
            dodge=True,
            palette="dark:black",
            **args_swarmplot,
        )

    plotted_cell_types = pd.unique(plot_df["Cell type"])
    ax.set_xticklabels(plotted_cell_types, rotation=90)

    if show_legend:
        # Boxplot and swarmplot both add entries per level; keep the first handle for each label.
        handles, labels = ax.get_legend_handles_labels()
        unique_handles: dict = {}
        for handle, label in zip(handles, labels, strict=False):
            unique_handles.setdefault(label, handle)
        ax.legend(
            list(unique_handles.values()),
            list(unique_handles.keys()),
            loc="upper left",
            bbox_to_anchor=(1, 1),
            ncol=1,
            title=feature_name,
        )

    # Use the figure owning `ax` (it may not be the current pyplot figure if `ax` was passed in).
    fig = ax.figure
    if save:
        fig.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return fig
    if not (show or save):
        return ax
    return None
1736
+
1737
def plot_rel_abundance_dispersion_plot(  # pragma: no cover
    self,
    data: AnnData | MuData,
    modality_key: str = "coda",
    abundant_threshold: float | None = 0.9,
    default_color: str | None = "Grey",
    abundant_color: str | None = "Red",
    label_cell_types: bool = True,
    figsize: tuple[float, float] | None = None,
    dpi: int | None = 100,
    return_fig: bool | None = None,
    ax: plt.Axes | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
) -> plt.Axes | plt.Figure | None:
    """Plots total variance of relative abundance versus minimum relative abundance of all cell types for determination of a reference cell type.

    If the count of the cell type is larger than 0 in more than abundant_threshold percent of all samples, the cell type will be marked in a different color.

    Args:
        data: AnnData or MuData object.
        modality_key: If data is a MuData object, specify which modality to use.
        abundant_threshold: Presence threshold for abundant cell types.
        default_color: Dot color for cell types that are not abundant.
        abundant_color: Dot color for cell types present in more than `abundant_threshold` of all samples.
        label_cell_types: Label abundant cell type dots with their names.
        figsize: Figure size.
        dpi: Dpi setting.
        return_fig: If True, return the matplotlib figure.
        ax: A matplotlib axes object to draw on. If None, a new figure is created.
        show: If True, show the figure.
        save: If given, path the figure is saved to.

    Returns:
        A :class:`~matplotlib.axes.Axes` object

    Examples:
        >>> import pertpy as pt
        >>> haber_cells = pt.dt.haber_2017_regions()
        >>> sccoda = pt.tl.Sccoda()
        >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
                sample_identifier="batch", covariate_obs=["condition"])
        >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
        >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
        >>> sccoda.plot_rel_abundance_dispersion_plot(mdata)

    Preview:
        .. image:: /_static/docstring_previews/sccoda_rel_abundance_dispersion_plot.png
    """
    if isinstance(data, MuData):
        data = data[modality_key]
    if ax is None:
        _, ax = plt.subplots(figsize=figsize, dpi=dpi)

    counts = data.X
    rel_abun = counts / np.sum(counts, axis=1, keepdims=True)

    # Fraction of samples in which each cell type has a zero count.
    percent_zero = np.sum(counts == 0, axis=0) / counts.shape[0]
    # Abundant = present (non-zero) in more than `abundant_threshold` of all samples.
    is_abundant = percent_zero < 1 - abundant_threshold

    # select reference
    cell_type_disp = np.var(rel_abun, axis=0) / np.mean(rel_abun, axis=0)

    # Scatterplot
    plot_df = pd.DataFrame(
        {
            "Total dispersion": cell_type_disp,
            "Cell type": data.var.index,
            "Presence": 1 - percent_zero,
            "Is abundant": is_abundant,
        }
    )

    # seaborn maps hue levels in sorted order (False, True), so the palette must match the levels present.
    if plot_df["Is abundant"].nunique() > 1:
        palette = [default_color, abundant_color]
    elif plot_df["Is abundant"].any():
        palette = [abundant_color]
    else:
        palette = [default_color]

    ax = sns.scatterplot(
        data=plot_df,
        x="Presence",
        y="Total dispersion",
        hue="Is abundant",
        palette=palette,
        ax=ax,
    )

    # Text labels for abundant cell types, drawn on (and de-overlapped within) the target axes.
    if label_cell_types:
        abundant_df = plot_df.loc[plot_df["Is abundant"], :]
        texts = [
            ax.text(row["Presence"], row["Total dispersion"], str(row["Cell type"]))
            for _, row in abundant_df.iterrows()
        ]
        if texts:
            adjust_text(texts, ax=ax)

    ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1, title="Is abundant")

    fig = ax.figure
    if save:
        fig.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return fig
    if not (show or save):
        return ax
    return None
1861
+
1862
def plot_draw_tree(  # pragma: no cover
    self,
    data: AnnData | MuData,
    modality_key: str = "coda",
    tree: str = "tree",  # Also type ete3.Tree. Omitted due to import errors
    tight_text: bool | None = False,
    show_scale: bool | None = False,
    units: Literal["px", "mm", "in"] | None = "px",
    figsize: tuple[float, float] | None = (None, None),
    dpi: int | None = 100,
    show: bool | None = True,
    save: str | bool | None = None,
) -> Tree | None:
    """Plot a tree using input ete3 tree object.

    Args:
        data: AnnData object or MuData object.
        modality_key: If data is a MuData object, specify which modality to use.
        tree: A ete3 tree object or a str to indicate the tree stored in `.uns`.
        tight_text: When False, boundaries of the text are approximated according to general font metrics,
            producing slightly worse aligned text faces but improving the performance of tree visualization in scenes with a lot of text faces.
        show_scale: Include the scale legend in the tree image or not.
        units: Unit of image sizes. "px": pixels, "mm": millimeters, "in": inches.
        figsize: Figure size as (width, height). None means the ete3 default.
        dpi: Dots per inches.
        show: If True, plot the tree inline. If False, return tree and tree_style objects.
        save: Path to the output image file. Valid extensions are .SVG, .PDF, .PNG.
            The image is saved whether `show` is True or not.

    Returns:
        Depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`)
        or the inline rendering of the tree (`show = True`).

    Examples:
        >>> import pertpy as pt
        >>> adata = pt.dt.tasccoda_example()
        >>> tasccoda = pt.tl.Tasccoda()
        >>> mdata = tasccoda.load(
        >>>     adata, type="sample_level",
        >>>     levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
        >>>     key_added="lineage", add_level_name=True
        >>> )
        >>> mdata = tasccoda.prepare(
        >>>     mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
        >>> )
        >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
        >>> tasccoda.plot_draw_tree(mdata, tree="lineage")

    Preview:
        .. image:: /_static/docstring_previews/tasccoda_draw_tree.png
    """
    try:
        from ete3 import TextFace, TreeStyle, faces
    except ImportError:
        raise ImportError(
            "To use tasccoda please install additional dependencies with `pip install pertpy[coda]`"
        ) from None

    if isinstance(data, MuData):
        data = data[modality_key]
    if isinstance(tree, str):
        tree = data.uns[tree]

    # Accept figsize=None as "no explicit size" instead of failing on indexing.
    width, height = figsize if figsize is not None else (None, None)

    def my_layout(node):
        text_face = TextFace(node.name, tight_text=tight_text)
        faces.add_face_to_node(text_face, node, column=0, position="branch-right")

    tree_style = TreeStyle()
    tree_style.show_leaf_name = False
    tree_style.layout_fn = my_layout
    tree_style.show_scale = show_scale

    # Only render to a file for a real path (save=False / "" means "do not save").
    if save:
        tree.render(save, tree_style=tree_style, units=units, w=width, h=height, dpi=dpi)  # type: ignore
    if show:
        return tree.render("%%inline", tree_style=tree_style, units=units, w=width, h=height, dpi=dpi)  # type: ignore
    return tree, tree_style
1941
+
1942
def plot_draw_effects(  # pragma: no cover
    self,
    data: AnnData | MuData,
    covariate: str,
    modality_key: str = "coda",
    tree: str = "tree",  # Also type ete3.Tree. Omitted due to import errors
    show_legend: bool | None = None,
    show_leaf_effects: bool | None = False,
    tight_text: bool | None = False,
    show_scale: bool | None = False,
    units: Literal["px", "mm", "in"] | None = "px",
    figsize: tuple[float, float] | None = (None, None),
    dpi: int | None = 100,
    show: bool | None = True,
    save: str | None = None,
) -> Tree | None:
    """Plot a tree with colored circles on the nodes indicating significant effects with bar plots which indicate leave-level significant effects.

    Args:
        data: AnnData object or MuData object.
        covariate: The covariate, whose effects should be plotted.
        modality_key: If data is a MuData object, specify which modality to use.
        tree: A ete3 tree object or a str to indicate the tree stored in `.uns`.
        show_legend: If show legend of nodes significant effects or not.
            Defaults to False if show_leaf_effects is True.
        show_leaf_effects: If True, plot bar plots which indicate leave-level significant effects.
        tight_text: When False, boundaries of the text are approximated according to general font metrics,
            producing slightly worse aligned text faces but improving the performance of tree visualization in scenes with a lot of text faces.
        show_scale: Include the scale legend in the tree image or not.
        units: Unit of image sizes. "px": pixels, "mm": millimeters, "in": inches.
        figsize: Figure size as (width, height) of the tree rendering. None means the ete3 default.
        dpi: Dots per inches.
        show: If True, plot the tree inline. If False, return tree and tree_style objects.
            Ignored when `show_leaf_effects` is True (a matplotlib figure is drawn instead).
        save: Path to the output image file. Valid extensions are .SVG, .PDF, .PNG.
            The image is saved whether `show` is True or not.

    Returns:
        If `show_leaf_effects` is True, None (a matplotlib figure is drawn). Otherwise, depending on `show`,
        returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`)
        or the inline rendering of the tree (`show = True`).

    Examples:
        >>> import pertpy as pt
        >>> adata = pt.dt.tasccoda_example()
        >>> tasccoda = pt.tl.Tasccoda()
        >>> mdata = tasccoda.load(
        >>>     adata, type="sample_level",
        >>>     levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
        >>>     key_added="lineage", add_level_name=True
        >>> )
        >>> mdata = tasccoda.prepare(
        >>>     mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
        >>> )
        >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
        >>> tasccoda.plot_draw_effects(mdata, covariate="Health[T.Inflamed]", tree="lineage")

    Preview:
        .. image:: /_static/docstring_previews/tasccoda_draw_effects.png
    """
    import tempfile

    try:
        from ete3 import CircleFace, NodeStyle, TextFace, TreeStyle, faces
    except ImportError:
        raise ImportError(
            "To use tasccoda please install additional dependencies as `pip install pertpy[coda]`"
        ) from None

    if isinstance(data, MuData):
        data = data[modality_key]
    if show_legend is None:
        show_legend = not show_leaf_effects
    elif show_legend and show_leaf_effects:
        # Misalignment only happens when the leaf-effect bar plot is drawn next to the tree.
        logger.info("Tree leaves and leaf effect bars won't be aligned when legend is shown!")

    if isinstance(tree, str):
        tree = data.uns[tree]
    # Collapse tree singularities
    tree2 = collapse_singularities_2(tree)

    node_effs = data.uns["scCODA_params"]["node_df"].loc[(covariate + "_node",),].copy()
    node_effs.index = node_effs.index.get_level_values("Node")

    covariates = data.uns["scCODA_params"]["covariate_names"]
    effect_dfs = [data.varm[f"effect_df_{cov}"] for cov in covariates]
    eff_df = pd.concat(effect_dfs)
    eff_df.index = pd.MultiIndex.from_product(
        (covariates, data.var.index.tolist()),
        names=["Covariate", "Cell Type"],
    )
    leaf_effs = eff_df.loc[(covariate,),].copy()
    leaf_effs.index = leaf_effs.index.get_level_values("Cell Type")

    # Attach node-level and leaf-level effects to every node (0 when the node has none).
    for n in tree2.traverse():
        nstyle = NodeStyle()
        nstyle["size"] = 0
        n.set_style(nstyle)
        node_effect = node_effs.loc[n.name, "Final Parameter"] if n.name in node_effs.index else 0
        n.add_feature("node_effect", node_effect)
        leaf_effect = leaf_effs.loc[n.name, "Effect"] if n.name in leaf_effs.index else 0
        n.add_feature("leaf_effect", leaf_effect)

    # Scale effect values to get nice node sizes
    eff_max = np.max([np.abs(n.node_effect) for n in tree2.traverse()])
    leaf_eff_max = np.max([np.abs(n.leaf_effect) for n in tree2.traverse()])

    def my_layout(node):
        text_face = TextFace(node.name, tight_text=tight_text)
        text_face.margin_left = 10
        faces.add_face_to_node(text_face, node, column=0, aligned=True)

        # Circle radius is proportional to |effect|, the largest effect getting radius 10.
        size = (np.abs(node.node_effect) * 10 / eff_max) if node.node_effect != 0 else 0
        if np.sign(node.node_effect) == 1:
            color = "blue"
        elif np.sign(node.node_effect) == -1:
            color = "red"
        else:
            color = "cyan"
        if size != 0:
            faces.add_face_to_node(CircleFace(radius=size, color=color), node, column=0)

    tree_style = TreeStyle()
    tree_style.show_leaf_name = False
    tree_style.layout_fn = my_layout
    tree_style.show_scale = show_scale
    tree_style.draw_guiding_lines = True
    tree_style.legend_position = 1

    if show_legend:
        tree_style.legend.add_face(TextFace("Effects"), column=0)
        tree_style.legend.add_face(TextFace(" "), column=1)
        for i in range(4, 0, -1):
            # Legend circles for i/4 of the maximum effect, scaled like my_layout (max radius 10).
            # Written without dividing by eff_max so an all-zero effect tree does not produce NaN radii.
            radius = float(f"{10 * i / 4:.2f}")
            tree_style.legend.add_face(CircleFace(radius, "red"), column=0)
            tree_style.legend.add_face(TextFace(f"{-eff_max * i / 4:.2f} "), column=0)
            tree_style.legend.add_face(CircleFace(radius, "blue"), column=1)
            tree_style.legend.add_face(TextFace(f" {eff_max * i / 4:.2f}"), column=1)

    width, height = figsize if figsize is not None else (None, None)

    if not show_leaf_effects:
        if save:
            tree2.render(save, tree_style=tree_style, units=units, w=width, h=height, dpi=dpi)
        if show:
            return tree2.render("%%inline", tree_style=tree_style, units=units, w=width, h=height, dpi=dpi)
        return tree2, tree_style

    leaf_name = [node.name for node in tree2.traverse("postorder") if node.is_leaf()]
    leaf_effs = leaf_effs.loc[leaf_name].reset_index()
    palette = ["blue" if effect > 0 else "red" for effect in leaf_effs["Effect"].tolist()]

    # Render the tree to a temporary PNG so it can be shown next to the bar plot without
    # leaving a stray file behind in the current working directory. ete3 expects a str path.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tree_img_path = str(Path(tmp_dir) / "tree_effect.png")
        tree2.render(tree_img_path, tree_style=tree_style, units="in")
        img = mpimg.imread(tree_img_path)

    _, ax = plt.subplots(1, 2, figsize=(10, 10))
    sns.barplot(data=leaf_effs, x="Effect", y="Cell Type", palette=palette, ax=ax[1])
    ax[0].imshow(img)
    ax[0].get_xaxis().set_visible(False)
    ax[0].get_yaxis().set_visible(False)
    ax[0].set_frame_on(False)

    ax[1].get_yaxis().set_visible(False)
    ax[1].spines["left"].set_visible(False)
    ax[1].spines["right"].set_visible(False)
    ax[1].spines["top"].set_visible(False)
    ax[1].set_xlim(-leaf_eff_max, leaf_eff_max)
    plt.subplots_adjust(wspace=0)

    if save:
        plt.savefig(save)
    return None
2132
+
2133
def plot_effects_umap(  # pragma: no cover
    self,
    mdata: MuData,
    effect_name: str | list | None,
    cluster_key: str,
    modality_key_1: str = "rna",
    modality_key_2: str = "coda",
    color_map: Colormap | str | None = None,
    palette: str | Sequence[str] | None = None,
    return_fig: bool | None = None,
    ax: Axes | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
    **kwargs,
) -> plt.Axes | plt.Figure | None:
    """Plot a UMAP visualization colored by effect strength.

    Effect results in .varm of aggregated sample-level AnnData (default is data['coda']) are assigned to cell-level AnnData
    (default is data['rna']) depending on the cluster they were assigned to.

    Args:
        mdata: MuData object.
        effect_name: The name(s) of the effect results in .varm of aggregated sample-level AnnData to plot.
        cluster_key: The cluster information in .obs of cell-level AnnData (default is data['rna']).
            To assign cell types' effects to original cells.
        modality_key_1: Key to the cell-level AnnData in the MuData object.
        modality_key_2: Key to the aggregated sample-level AnnData object in the MuData object.
        color_map: Colormap passed to `scanpy.pl.umap()`.
        palette: Palette passed to `scanpy.pl.umap()`.
        return_fig: Passed to `scanpy.pl.umap()`.
        ax: A matplotlib axes object. Only works if plotting a single component.
        show: Whether to display the figure or return axis.
        save: Passed to `scanpy.pl.umap()`.
        **kwargs: All other keyword arguments are passed to `scanpy.plot.umap()`.
            If `vmin`/`vmax` are not given, the minimum/maximum over all plotted effects is used.

    Returns:
        If `show==False` a :class:`~matplotlib.axes.Axes` or a list of it.

    Raises:
        KeyError: If a cluster in `cluster_key` has no entry in the effect results.

    Examples:
        >>> import pertpy as pt
        >>> import scanpy as sc
        >>> import schist
        >>> adata = pt.dt.haber_2017_regions()
        >>> sc.pp.neighbors(adata)
        >>> schist.inference.nested_model(adata, n_init=100, random_seed=5678)
        >>> tasccoda_model = pt.tl.Tasccoda()
        >>> tasccoda_data = tasccoda_model.load(adata, type="cell_level",
        >>>                 cell_type_identifier="nsbm_level_1",
        >>>                 sample_identifier="batch", covariate_obs=["condition"],
        >>>                 levels_orig=["nsbm_level_4", "nsbm_level_3", "nsbm_level_2", "nsbm_level_1"],
        >>>                 add_level_name=True)
        >>> tasccoda_model.prepare(
        >>>     tasccoda_data,
        >>>     modality_key="coda",
        >>>     reference_cell_type="18",
        >>>     formula="condition",
        >>>     pen_args={"phi": 0, "lambda_1": 3.5},
        >>>     tree_key="tree"
        >>> )
        >>> tasccoda_model.run_nuts(
        ...     tasccoda_data, modality_key="coda", rng_key=1234, num_samples=10000, num_warmup=1000
        ... )
        >>> sc.tl.umap(tasccoda_data["rna"])
        >>> tasccoda_model.plot_effects_umap(tasccoda_data,
        >>>                         effect_name=["effect_df_condition[T.Salmonella]",
        >>>                                      "effect_df_condition[T.Hpoly.Day3]",
        >>>                                      "effect_df_condition[T.Hpoly.Day10]"],
        >>>                         cluster_key="nsbm_level_1",
        >>> )

    Preview:
        .. image:: /_static/docstring_previews/tasccoda_effects_umap.png
    """
    # TODO: test the docstring example
    data_rna = mdata[modality_key_1]
    data_coda = mdata[modality_key_2]
    if isinstance(effect_name, str):
        effect_name = [effect_name]

    # Effect tables are indexed by the string form of the cluster labels.
    clusters = data_rna.obs[cluster_key].astype(str)
    for effect in effect_name:
        effect_values = data_coda.varm[effect]["Effect"]
        unknown = set(clusters) - set(effect_values.index)
        if unknown:
            raise KeyError(f"Clusters {sorted(unknown)} of '{cluster_key}' have no entry in varm['{effect}'].")
        # Vectorised lookup instead of one .loc call per cell.
        data_rna.obs[effect] = clusters.map(effect_values)

    # Pop explicitly so a user-supplied value (even a falsy one like 0) is used exactly once
    # and never forwarded a second time through **kwargs.
    vmin = kwargs.pop("vmin", None)
    if vmin is None:
        vmin = min(data_rna.obs[effect].min() for effect in effect_name)
    vmax = kwargs.pop("vmax", None)
    if vmax is None:
        vmax = max(data_rna.obs[effect].max() for effect in effect_name)

    return sc.pl.umap(
        data_rna,
        color=effect_name,
        vmax=vmax,
        vmin=vmin,
        palette=palette,
        color_map=color_map,
        return_fig=return_fig,
        ax=ax,
        show=show,
        save=save,
        **kwargs,
    )
2236
+
1116
2237
 
1117
2238
  def get_a(
1118
- tree: tt.tree,
2239
+ tree: tt.core.ToyTree,
1119
2240
  ) -> tuple[np.ndarray, int]:
1120
2241
  """Calculate ancestor matrix from a toytree tree
1121
2242
 
@@ -1154,7 +2275,7 @@ def get_a(
1154
2275
  return A, n_nodes - 1
1155
2276
 
1156
2277
 
1157
- def collapse_singularities(tree: tt.tree) -> tt.tree:
2278
+ def collapse_singularities(tree: tt.core.ToyTree) -> tt.core.ToyTree:
1158
2279
  """Collapses (deletes) nodes in a toytree tree that are singularities (have only one child).
1159
2280
 
1160
2281
  Args:
@@ -1242,7 +2363,7 @@ def df2newick(df: pd.DataFrame, levels: list[str], inner_label: bool = True) ->
1242
2363
 
1243
2364
 
1244
2365
  def get_a_2(
1245
- tree: ete.Tree,
2366
+ tree: Tree,
1246
2367
  leaf_order: list[str] = None,
1247
2368
  node_order: list[str] = None,
1248
2369
  ) -> tuple[np.ndarray, int]:
@@ -1263,6 +2384,13 @@ def get_a_2(
1263
2384
  T
1264
2385
  number of nodes in the tree, excluding the root node
1265
2386
  """
2387
+ try:
2388
+ import ete3 as ete
2389
+ except ImportError:
2390
+ raise ImportError(
2391
+ "To use tasccoda please install additional dependencies as `pip install pertpy[coda]`"
2392
+ ) from None
2393
+
1266
2394
  n_tips = len(tree.get_leaves())
1267
2395
  n_nodes = len(tree.get_descendants())
1268
2396
 
@@ -1292,7 +2420,7 @@ def get_a_2(
1292
2420
  return A_, n_nodes
1293
2421
 
1294
2422
 
1295
- def collapse_singularities_2(tree: ete.Tree) -> ete.Tree:
2423
+ def collapse_singularities_2(tree: Tree) -> Tree:
1296
2424
  """Collapses (deletes) nodes in a ete3 tree that are singularities (have only one child).
1297
2425
 
1298
2426
  Args:
@@ -1327,10 +2455,10 @@ def linkage_to_newick(
1327
2455
 
1328
2456
  def build_newick(node, newick, parentdist, leaf_names):
1329
2457
  if node.is_leaf():
1330
- return f"{leaf_names[node.id]}:{(parentdist - node.dist)/2}{newick}"
2458
+ return f"{leaf_names[node.id]}:{(parentdist - node.dist) / 2}{newick}"
1331
2459
  else:
1332
2460
  if len(newick) > 0:
1333
- newick = f"):{(parentdist - node.dist)/2}{newick}"
2461
+ newick = f"):{(parentdist - node.dist) / 2}{newick}"
1334
2462
  else:
1335
2463
  newick = ");"
1336
2464
  newick = build_newick(node.get_left(), newick, node.dist, leaf_names)
@@ -1363,14 +2491,14 @@ def import_tree(
1363
2491
 
1364
2492
  Args:
1365
2493
  data: A tascCODA-compatible data object.
1366
- modality_1: If `data` is MuData, specifiy the modality name to the original cell level anndata object. Defaults to None.
1367
- modality_2: If `data` is MuData, specifiy the modality name to the aggregated level anndata object. Defaults to None.
1368
- dendrogram_key: Key to the scanpy.tl.dendrogram result in `.uns` of original cell level anndata object. Defaults to None.
1369
- levels_orig: List that indicates which columns in `.obs` of the original data correspond to tree levels. The list must begin with the root level, and end with the leaf level. Defaults to None.
1370
- levels_agg: List that indicates which columns in `.var` of the aggregated data correspond to tree levels. The list must begin with the root level, and end with the leaf level. Defaults to None.
1371
- add_level_name: If True, internal nodes in the tree will be named as "{level_name}_{node_name}" instead of just {level_name}. Defaults to True.
1372
- key_added: If not specified, the tree is stored in .uns[‘tree’]. If `data` is AnnData, save tree in `data`. If `data` is MuData, save tree in data[modality_2]. Defaults to "tree".
1373
- copy: Return a copy instead of writing to `data`. Defaults to False.
2494
+ modality_1: If `data` is MuData, specify the modality name to the original cell level anndata object.
2495
+ modality_2: If `data` is MuData, specify the modality name to the aggregated level anndata object.
2496
+ dendrogram_key: Key to the scanpy.tl.dendrogram result in `.uns` of original cell level anndata object.
2497
+ levels_orig: List that indicates which columns in `.obs` of the original data correspond to tree levels. The list must begin with the root level, and end with the leaf level.
2498
+ levels_agg: List that indicates which columns in `.var` of the aggregated data correspond to tree levels. The list must begin with the root level, and end with the leaf level.
2499
+ add_level_name: If True, internal nodes in the tree will be named as "{level_name}_{node_name}" instead of just {level_name}.
2500
+ key_added: If not specified, the tree is stored in .uns[‘tree’]. If `data` is AnnData, save tree in `data`.
2501
+ If `data` is MuData, save tree in data[modality_2].
1374
2502
 
1375
2503
  Returns:
1376
2504
  Updates data with the following:
@@ -1379,15 +2507,22 @@ def import_tree(
1379
2507
 
1380
2508
  tree: A ete3 tree object.
1381
2509
  """
2510
+ try:
2511
+ import ete3 as ete
2512
+ except ImportError:
2513
+ raise ImportError(
2514
+ "To use tasccoda please install additional dependencies as `pip install pertpy[coda]`"
2515
+ ) from None
2516
+
1382
2517
  if isinstance(data, MuData):
1383
2518
  try:
1384
2519
  data_1 = data[modality_1]
1385
2520
  data_2 = data[modality_2]
1386
2521
  except KeyError as name:
1387
- print(f"No {name} slot in MuData")
2522
+ logger.error(f"No {name} slot in MuData")
1388
2523
  raise
1389
2524
  except IndexError:
1390
- print("Please specify modality_1 and modality_2 to indicate modalities in MuData")
2525
+ logger.error("Please specify modality_1 and modality_2 to indicate modalities in MuData")
1391
2526
  raise
1392
2527
  else:
1393
2528
  data_1 = data
@@ -1443,68 +2578,54 @@ def from_scanpy(
1443
2578
 
1444
2579
  The anndata object needs to have a column in adata.obs that contains the cell type assignment.
1445
2580
  Further, it must contain one column or a set of columns (e.g. subject id, treatment, disease status) that uniquely identify each (statistical) sample.
1446
- Further covariates (e.g. subject age) can either be specified via addidional column names in adata.obs, a key in adata.uns, or as a separate DataFrame.
2581
+ Further covariates (e.g. subject age) can either be specified via additional column names in adata.obs, a key in adata.uns, or as a separate DataFrame.
1447
2582
 
1448
- NOTE: The order of samples in the returned dataset is determined by the first occurence of cells from each sample in `adata`
2583
+ NOTE: The order of samples in the returned dataset is determined by the first occurrence of cells from each sample in `adata`
1449
2584
 
1450
2585
  Args:
1451
2586
  adata: An anndata object from scanpy
1452
2587
  cell_type_identifier: column name in adata.obs that specifies the cell types
1453
2588
  sample_identifier: column name or list of column names in adata.obs that uniquely identify each sample
1454
2589
  covariate_uns: key for adata.uns, where covariate values are stored
1455
- covariate_obs: list of column names in adata.obs, where covariate values are stored. Note: If covariate values are not unique for a value of sample_identifier, this covaariate will be skipped.
2590
+ covariate_obs: list of column names in adata.obs, where covariate values are stored.
2591
+ Note: If covariate values are not unique for a value of sample_identifier, this covariate will be skipped.
1456
2592
  covariate_df: DataFrame with covariates
1457
2593
 
1458
2594
  Returns:
1459
2595
  AnnData: A data set with cells aggregated to the (sample x cell type) level
1460
2596
  """
1461
- if isinstance(sample_identifier, str):
1462
- sample_identifier = [sample_identifier]
1463
-
1464
- if covariate_obs:
1465
- covariate_obs += [i for i in sample_identifier if i not in covariate_obs]
1466
- else:
1467
- covariate_obs = sample_identifier # type: ignore
2597
+ sample_identifier = [sample_identifier] if isinstance(sample_identifier, str) else sample_identifier
2598
+ covariate_obs = list(set(covariate_obs or []) | set(sample_identifier))
1468
2599
 
1469
- # join sample identifiers
1470
2600
  if isinstance(sample_identifier, list):
1471
2601
  adata.obs["scCODA_sample_id"] = adata.obs[sample_identifier].agg("-".join, axis=1)
1472
2602
  sample_identifier = "scCODA_sample_id"
1473
2603
 
1474
- # get cell type counts
1475
2604
  groups = adata.obs.value_counts([sample_identifier, cell_type_identifier])
1476
- count_data = groups.unstack(level=cell_type_identifier)
1477
- count_data = count_data.fillna(0)
1478
-
1479
- # get covariates from different sources
1480
- covariate_df_ = pd.DataFrame(index=count_data.index)
1481
-
1482
- if covariate_df is None and covariate_obs is None and covariate_uns is None:
1483
- print("No covariate information specified!")
2605
+ ct_count_data = groups.unstack(level=cell_type_identifier).fillna(0)
2606
+ covariate_df_ = pd.DataFrame(index=ct_count_data.index)
1484
2607
 
1485
2608
  if covariate_uns is not None:
1486
- covariate_df_uns = pd.DataFrame(adata.uns[covariate_uns])
1487
- covariate_df_ = pd.concat((covariate_df_, covariate_df_uns), axis=1)
2609
+ covariate_df_uns = pd.DataFrame(adata.uns[covariate_uns], index=ct_count_data.index)
2610
+ covariate_df_ = pd.concat([covariate_df_, covariate_df_uns], axis=1)
1488
2611
 
1489
- if covariate_obs is not None:
1490
- for c in covariate_obs:
1491
- if any(adata.obs.groupby(sample_identifier).nunique()[c] != 1):
1492
- print(f"Covariate {c} has non-unique values! Skipping...")
2612
+ if covariate_obs:
2613
+ unique_check = adata.obs.groupby(sample_identifier).nunique()
2614
+ for c in covariate_obs.copy():
2615
+ if unique_check[c].max() != 1:
2616
+ logger.warning(f"Covariate {c} has non-unique values for batch! Skipping...")
1493
2617
  covariate_obs.remove(c)
1494
-
1495
- covariate_df_obs = adata.obs.groupby(sample_identifier).first()[covariate_obs]
1496
- covariate_df_ = pd.concat((covariate_df_, covariate_df_obs), axis=1)
2618
+ if covariate_obs:
2619
+ covariate_df_obs = adata.obs.groupby(sample_identifier).first()[covariate_obs]
2620
+ covariate_df_ = pd.concat([covariate_df_, covariate_df_obs], axis=1)
1497
2621
 
1498
2622
  if covariate_df is not None:
1499
- if set(covariate_df.index) != set(count_data.index):
1500
- raise ValueError("anndata sample names and covariate_df index do not have the same elements!")
1501
- covs_ord = covariate_df.reindex(count_data.index)
1502
- covariate_df_ = pd.concat((covariate_df_, covs_ord), axis=1)
2623
+ if set(covariate_df.index) != set(ct_count_data.index):
2624
+ raise ValueError("Mismatch between sample names in anndata and covariate_df!")
2625
+ covariate_df_ = pd.concat([covariate_df_, covariate_df.reindex(ct_count_data.index)], axis=1)
1503
2626
 
1504
- covariate_df_.index = covariate_df_.index.astype(str)
1505
-
1506
- # create var (number of cells for each type as only column)
1507
- var_dat = count_data.sum(axis=0).rename("n_cells").to_frame()
2627
+ var_dat = ct_count_data.sum().rename("n_cells").to_frame()
1508
2628
  var_dat.index = var_dat.index.astype(str)
2629
+ covariate_df_.index = covariate_df_.index.astype(str)
1509
2630
 
1510
- return AnnData(X=count_data.values, var=var_dat, obs=covariate_df_)
2631
+ return AnnData(X=ct_count_data.values, var=var_dat, obs=covariate_df_)