pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (53) hide show
  1. pertpy/__init__.py +3 -2
  2. pertpy/data/__init__.py +5 -1
  3. pertpy/data/_dataloader.py +2 -4
  4. pertpy/data/_datasets.py +203 -92
  5. pertpy/metadata/__init__.py +4 -0
  6. pertpy/metadata/_cell_line.py +826 -0
  7. pertpy/metadata/_compound.py +129 -0
  8. pertpy/metadata/_drug.py +242 -0
  9. pertpy/metadata/_look_up.py +582 -0
  10. pertpy/metadata/_metadata.py +73 -0
  11. pertpy/metadata/_moa.py +129 -0
  12. pertpy/plot/__init__.py +1 -9
  13. pertpy/plot/_augur.py +53 -116
  14. pertpy/plot/_coda.py +277 -677
  15. pertpy/plot/_guide_rna.py +17 -35
  16. pertpy/plot/_milopy.py +59 -134
  17. pertpy/plot/_mixscape.py +152 -391
  18. pertpy/preprocessing/_guide_rna.py +88 -4
  19. pertpy/tools/__init__.py +8 -13
  20. pertpy/tools/_augur.py +315 -17
  21. pertpy/tools/_cinemaot.py +143 -4
  22. pertpy/tools/_coda/_base_coda.py +1210 -65
  23. pertpy/tools/_coda/_sccoda.py +50 -21
  24. pertpy/tools/_coda/_tasccoda.py +27 -19
  25. pertpy/tools/_dialogue.py +164 -56
  26. pertpy/tools/_differential_gene_expression.py +240 -14
  27. pertpy/tools/_distances/_distance_tests.py +8 -8
  28. pertpy/tools/_distances/_distances.py +184 -34
  29. pertpy/tools/_enrichment.py +465 -0
  30. pertpy/tools/_milo.py +345 -11
  31. pertpy/tools/_mixscape.py +668 -50
  32. pertpy/tools/_perturbation_space/_clustering.py +5 -1
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
  35. pertpy/tools/_perturbation_space/_simple.py +51 -10
  36. pertpy/tools/_scgen/__init__.py +1 -1
  37. pertpy/tools/_scgen/_scgen.py +701 -0
  38. pertpy/tools/_scgen/_utils.py +1 -3
  39. pertpy/tools/decoupler_LICENSE +674 -0
  40. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
  41. pertpy-0.7.0.dist-info/RECORD +53 -0
  42. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
  43. pertpy/plot/_cinemaot.py +0 -81
  44. pertpy/plot/_dialogue.py +0 -91
  45. pertpy/plot/_scgen.py +0 -337
  46. pertpy/tools/_metadata/__init__.py +0 -0
  47. pertpy/tools/_metadata/_cell_line.py +0 -613
  48. pertpy/tools/_metadata/_look_up.py +0 -342
  49. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  50. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  51. pertpy-0.6.0.dist-info/RECORD +0 -50
  52. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  53. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,17 +1,23 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from abc import ABC, abstractmethod
4
- from typing import TYPE_CHECKING
4
+ from pathlib import Path
5
+ from typing import TYPE_CHECKING, Literal, Optional, Union
5
6
 
6
7
  import arviz as az
7
- import ete3 as ete
8
8
  import jax.numpy as jnp
9
+ import matplotlib.pyplot as plt
9
10
  import numpy as np
10
11
  import pandas as pd
11
12
  import patsy as pt
13
+ import scanpy as sc
14
+ import seaborn as sns
15
+ from adjustText import adjust_text
12
16
  from anndata import AnnData
13
- from jax import random
14
- from jax.config import config
17
+ from jax import config, random
18
+ from matplotlib import cm, rcParams
19
+ from matplotlib import image as mpimg
20
+ from matplotlib.colors import ListedColormap
15
21
  from mudata import MuData
16
22
  from numpyro.infer import HMC, MCMC, NUTS, initialization
17
23
  from rich import box, print
@@ -20,10 +26,15 @@ from rich.table import Table
20
26
  from scipy.cluster import hierarchy as sp_hierarchy
21
27
 
22
28
  if TYPE_CHECKING:
29
+ from collections.abc import Sequence
30
+
23
31
  import numpyro as npy
24
32
  import toytree as tt
25
- from jax._src.prng import PRNGKeyArray
33
+ from ete3 import Tree
26
34
  from jax._src.typing import Array
35
+ from matplotlib.axes import Axes
36
+ from matplotlib.colors import Colormap
37
+ from matplotlib.figure import Figure
27
38
 
28
39
  config.update("jax_enable_x64", True)
29
40
 
@@ -179,7 +190,7 @@ class CompositionalModel2(ABC):
179
190
  self,
180
191
  sample_adata: AnnData,
181
192
  kernel: npy.infer.mcmc.MCMCKernel,
182
- rng_key: Array | PRNGKeyArray,
193
+ rng_key: Array,
183
194
  copy: bool = False,
184
195
  *args,
185
196
  **kwargs,
@@ -295,7 +306,7 @@ class CompositionalModel2(ABC):
295
306
  if copy:
296
307
  sample_adata = sample_adata.copy()
297
308
 
298
- rng_key_array = random.PRNGKey(rng_key)
309
+ rng_key_array = random.key(rng_key)
299
310
  sample_adata.uns["scCODA_params"]["mcmc"]["rng_key"] = np.array(rng_key_array)
300
311
 
301
312
  # Set up NUTS kernel
@@ -335,7 +346,6 @@ class CompositionalModel2(ABC):
335
346
  copy: Return a copy instead of writing to adata. Defaults to False.
336
347
 
337
348
  Examples:
338
- Example with scCODA:
339
349
  >>> import pertpy as pt
340
350
  >>> haber_cells = pt.dt.haber_2017_regions()
341
351
  >>> sccoda = pt.tl.Sccoda()
@@ -358,10 +368,10 @@ class CompositionalModel2(ABC):
358
368
  # Set rng key if needed
359
369
  if rng_key is None:
360
370
  rng = np.random.default_rng()
361
- rng_key = random.PRNGKey(rng.integers(0, 10000))
371
+ rng_key = random.key(rng.integers(0, 10000))
362
372
  sample_adata.uns["scCODA_params"]["mcmc"]["rng_key"] = rng_key
363
373
  else:
364
- rng_key = random.PRNGKey(rng_key)
374
+ rng_key = random.key(rng_key)
365
375
 
366
376
  # Set up HMC kernel
367
377
  sample_adata = self.set_init_mcmc_states(
@@ -423,7 +433,6 @@ class CompositionalModel2(ABC):
423
433
  - Is credible: Boolean indicator whether effect is credible
424
434
 
425
435
  Examples:
426
- Example with scCODA:
427
436
  >>> import pertpy as pt
428
437
  >>> haber_cells = pt.dt.haber_2017_regions()
429
438
  >>> sccoda = pt.tl.Sccoda()
@@ -433,7 +442,6 @@ class CompositionalModel2(ABC):
433
442
  >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
434
443
  >>> intercept_df, effect_df = sccoda.summary_prepare(mdata["coda"])
435
444
  """
436
- # Get model and effect selection types
437
445
  select_type = sample_adata.uns["scCODA_params"]["select_type"]
438
446
  model_type = sample_adata.uns["scCODA_params"]["model_type"]
439
447
 
@@ -548,7 +556,11 @@ class CompositionalModel2(ABC):
548
556
  intercept_df = intercept_df.loc[:, ["final_parameter", hdis[0], hdis[1], "sd", "expected_sample"]].copy()
549
557
  intercept_df = intercept_df.rename(
550
558
  columns=dict(
551
- zip(intercept_df.columns, ["Final Parameter", hdis_new[0], hdis_new[1], "SD", "Expected Sample"])
559
+ zip(
560
+ intercept_df.columns,
561
+ ["Final Parameter", hdis_new[0], hdis_new[1], "SD", "Expected Sample"],
562
+ strict=False,
563
+ )
552
564
  )
553
565
  )
554
566
 
@@ -561,6 +573,7 @@ class CompositionalModel2(ABC):
561
573
  zip(
562
574
  effect_df.columns,
563
575
  ["Effect", "Median", hdis_new[0], hdis_new[1], "SD", "Expected Sample", "log2-fold change"],
576
+ strict=False,
564
577
  )
565
578
  )
566
579
  )
@@ -581,6 +594,7 @@ class CompositionalModel2(ABC):
581
594
  "Expected Sample",
582
595
  "log2-fold change",
583
596
  ],
597
+ strict=False,
584
598
  )
585
599
  )
586
600
  )
@@ -594,6 +608,7 @@ class CompositionalModel2(ABC):
594
608
  zip(
595
609
  node_df.columns,
596
610
  ["Final Parameter", "Median", hdis_new[0], hdis_new[1], "SD", "Delta", "Is credible"],
611
+ strict=False,
597
612
  )
598
613
  ) # type: ignore
599
614
  ) # type: ignore
@@ -781,7 +796,6 @@ class CompositionalModel2(ABC):
781
796
  kwargs: Passed to az.summary
782
797
 
783
798
  Examples:
784
- Example with scCODA:
785
799
  >>> import pertpy as pt
786
800
  >>> haber_cells = pt.dt.haber_2017_regions()
787
801
  >>> sccoda = pt.tl.Sccoda()
@@ -799,7 +813,7 @@ class CompositionalModel2(ABC):
799
813
  raise
800
814
  if isinstance(data, AnnData):
801
815
  sample_adata = data
802
- # Get model and effect selection types
816
+
803
817
  select_type = sample_adata.uns["scCODA_params"]["select_type"]
804
818
  model_type = sample_adata.uns["scCODA_params"]["model_type"]
805
819
 
@@ -926,7 +940,6 @@ class CompositionalModel2(ABC):
926
940
  pd.DataFrame: Intercept data frame.
927
941
 
928
942
  Examples:
929
- Example with scCODA:
930
943
  >>> import pertpy as pt
931
944
  >>> haber_cells = pt.dt.haber_2017_regions()
932
945
  >>> sccoda = pt.tl.Sccoda()
@@ -936,7 +949,6 @@ class CompositionalModel2(ABC):
936
949
  >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
937
950
  >>> intercepts = sccoda.get_intercept_df(mdata)
938
951
  """
939
-
940
952
  if isinstance(data, MuData):
941
953
  try:
942
954
  sample_adata = data[modality_key]
@@ -959,7 +971,6 @@ class CompositionalModel2(ABC):
959
971
  pd.DataFrame: Effect data frame.
960
972
 
961
973
  Examples:
962
- Example with scCODA:
963
974
  >>> import pertpy as pt
964
975
  >>> haber_cells = pt.dt.haber_2017_regions()
965
976
  >>> sccoda = pt.tl.Sccoda()
@@ -969,7 +980,6 @@ class CompositionalModel2(ABC):
969
980
  >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
970
981
  >>> effects = sccoda.get_effect_df(mdata)
971
982
  """
972
-
973
983
  if isinstance(data, MuData):
974
984
  try:
975
985
  sample_adata = data[modality_key]
@@ -1003,9 +1013,8 @@ class CompositionalModel2(ABC):
1003
1013
  pd.DataFrame: Node effect data frame.
1004
1014
 
1005
1015
  Examples:
1006
- Example with tascCODA (works only for model of type tree_agg, i.e. a tascCODA model):
1007
1016
  >>> import pertpy as pt
1008
- >>> adata = pt.dt.smillie()
1017
+ >>> adata = pt.dt.tasccoda_example()
1009
1018
  >>> tasccoda = pt.tl.Tasccoda()
1010
1019
  >>> mdata = tasccoda.load(
1011
1020
  >>> adata, type="sample_level",
@@ -1113,6 +1122,1136 @@ class CompositionalModel2(ABC):
1113
1122
 
1114
1123
  return out
1115
1124
 
1125
+ def _stackbar( # pragma: no cover
1126
+ self,
1127
+ y: np.ndarray,
1128
+ type_names: list[str],
1129
+ title: str,
1130
+ level_names: list[str],
1131
+ figsize: tuple[float, float] | None = None,
1132
+ dpi: int | None = 100,
1133
+ palette: ListedColormap | None = cm.tab20,
1134
+ show_legend: bool | None = True,
1135
+ ) -> plt.Axes:
1136
+ """Plots a stacked barplot for one (discrete) covariate.
1137
+
1138
+ Typical use (only inside plot_stacked_barplot): self._stackbar(data.X, data.var.index, "xyz", data.obs.index)
1139
+
1140
+ Args:
1141
+ y: The count data, collapsed onto the level of interest. i.e. a binary covariate has two rows,
1142
+ one for each group, containing the count mean of each cell type
1143
+ type_names: The names of all cell types
1144
+ title: Plot title, usually the covariate's name
1145
+ level_names: Names of the covariate's levels
1146
+ figsize: Figure size. Defaults to None.
1147
+ dpi: Dpi setting. Defaults to 100.
1148
+ palette: The color map for the barplot. Defaults to cm.tab20.
1149
+ show_legend: If True, adds a legend. Defaults to True.
1150
+
1151
+ Returns:
1152
+ A :class:`~matplotlib.axes.Axes` object
1153
+ """
1154
+ n_bars, n_types = y.shape
1155
+
1156
+ figsize = rcParams["figure.figsize"] if figsize is None else figsize
1157
+
1158
+ _, ax = plt.subplots(figsize=figsize, dpi=dpi)
1159
+ r = np.array(range(n_bars))
1160
+ sample_sums = np.sum(y, axis=1)
1161
+
1162
+ barwidth = 0.85
1163
+ cum_bars = np.zeros(n_bars)
1164
+
1165
+ for n in range(n_types):
1166
+ bars = [i / j * 100 for i, j in zip([y[k][n] for k in range(n_bars)], sample_sums, strict=False)]
1167
+ plt.bar(
1168
+ r,
1169
+ bars,
1170
+ bottom=cum_bars,
1171
+ color=palette(n % palette.N),
1172
+ width=barwidth,
1173
+ label=type_names[n],
1174
+ linewidth=0,
1175
+ )
1176
+ cum_bars += bars
1177
+
1178
+ ax.set_title(title)
1179
+ if show_legend:
1180
+ ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1)
1181
+ ax.set_xticks(r)
1182
+ ax.set_xticklabels(level_names, rotation=45, ha="right")
1183
+ ax.set_ylabel("Proportion")
1184
+
1185
+ return ax
1186
+
1187
+ def plot_stacked_barplot( # pragma: no cover
1188
+ self,
1189
+ data: AnnData | MuData,
1190
+ feature_name: str,
1191
+ modality_key: str = "coda",
1192
+ palette: ListedColormap | None = cm.tab20,
1193
+ show_legend: bool | None = True,
1194
+ level_order: list[str] = None,
1195
+ figsize: tuple[float, float] | None = None,
1196
+ dpi: int | None = 100,
1197
+ return_fig: bool | None = None,
1198
+ ax: plt.Axes | None = None,
1199
+ show: bool | None = None,
1200
+ save: str | bool | None = None,
1201
+ **kwargs,
1202
+ ) -> plt.Axes | plt.Figure | None:
1203
+ """Plots a stacked barplot for all levels of a covariate or all samples (if feature_name=="samples").
1204
+
1205
+ Args:
1206
+ data: AnnData object or MuData object.
1207
+ feature_name: The name of the covariate to plot. If feature_name=="samples", one bar for every sample will be plotted
1208
+ modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
1209
+ figsize: Figure size. Defaults to None.
1210
+ dpi: Dpi setting. Defaults to 100.
1211
+ palette: The matplotlib color map for the barplot. Defaults to cm.tab20.
1212
+ show_legend: If True, adds a legend. Defaults to True.
1213
+ level_order: Custom ordering of bars on the x-axis. Defaults to None.
1214
+
1215
+ Returns:
1216
+ A :class:`~matplotlib.axes.Axes` object
1217
+
1218
+ Examples:
1219
+ >>> import pertpy as pt
1220
+ >>> haber_cells = pt.dt.haber_2017_regions()
1221
+ >>> sccoda = pt.tl.Sccoda()
1222
+ >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
1223
+ sample_identifier="batch", covariate_obs=["condition"])
1224
+ >>> sccoda.plot_stacked_barplot(mdata, feature_name="samples")
1225
+
1226
+ Preview:
1227
+ .. image:: /_static/docstring_previews/sccoda_stacked_barplot.png
1228
+ """
1229
+ if isinstance(data, MuData):
1230
+ data = data[modality_key]
1231
+ if isinstance(data, AnnData):
1232
+ data = data
1233
+
1234
+ ct_names = data.var.index
1235
+
1236
+ # option to plot one stacked barplot per sample
1237
+ if feature_name == "samples":
1238
+ if level_order:
1239
+ assert set(level_order) == set(data.obs.index), "level order is inconsistent with levels"
1240
+ data = data[level_order]
1241
+ ax = self._stackbar(
1242
+ data.X,
1243
+ type_names=data.var.index,
1244
+ title="samples",
1245
+ level_names=data.obs.index,
1246
+ figsize=figsize,
1247
+ dpi=dpi,
1248
+ palette=palette,
1249
+ show_legend=show_legend,
1250
+ )
1251
+ else:
1252
+ # Order levels
1253
+ if level_order:
1254
+ assert set(level_order) == set(data.obs[feature_name]), "level order is inconsistent with levels"
1255
+ levels = level_order
1256
+ elif hasattr(data.obs[feature_name], "cat"):
1257
+ levels = data.obs[feature_name].cat.categories.to_list()
1258
+ else:
1259
+ levels = pd.unique(data.obs[feature_name])
1260
+ n_levels = len(levels)
1261
+ feature_totals = np.zeros([n_levels, data.X.shape[1]])
1262
+
1263
+ for level in range(n_levels):
1264
+ l_indices = np.where(data.obs[feature_name] == levels[level])
1265
+ feature_totals[level] = np.sum(data.X[l_indices], axis=0)
1266
+
1267
+ ax = self._stackbar(
1268
+ feature_totals,
1269
+ type_names=ct_names,
1270
+ title=feature_name,
1271
+ level_names=levels,
1272
+ figsize=figsize,
1273
+ dpi=dpi,
1274
+ palette=palette,
1275
+ show_legend=show_legend,
1276
+ )
1277
+
1278
+ if save:
1279
+ plt.savefig(save, bbox_inches="tight")
1280
+ if show:
1281
+ plt.show()
1282
+ if return_fig:
1283
+ return plt.gcf()
1284
+ if not (show or save):
1285
+ return ax
1286
+ return None
1287
+
1288
+ def plot_effects_barplot( # pragma: no cover
1289
+ self,
1290
+ data: AnnData | MuData,
1291
+ modality_key: str = "coda",
1292
+ covariates: str | list | None = None,
1293
+ parameter: Literal["log2-fold change", "Final Parameter", "Expected Sample"] = "log2-fold change",
1294
+ plot_facets: bool = True,
1295
+ plot_zero_covariate: bool = True,
1296
+ plot_zero_cell_type: bool = False,
1297
+ palette: str | ListedColormap | None = cm.tab20,
1298
+ level_order: list[str] = None,
1299
+ args_barplot: dict | None = None,
1300
+ figsize: tuple[float, float] | None = None,
1301
+ dpi: int | None = 100,
1302
+ return_fig: bool | None = None,
1303
+ ax: plt.Axes | None = None,
1304
+ show: bool | None = None,
1305
+ save: str | bool | None = None,
1306
+ ) -> plt.Axes | plt.Figure | sns.axisgrid.FacetGrid | None:
1307
+ """Barplot visualization for effects.
1308
+
1309
+ The effect results for each covariate are shown as a group of barplots, with intra-group separation by cell types.
1310
+ The covariates groups can either be ordered along the x-axis of a single plot (plot_facets=False) or as plot facets (plot_facets=True).
1311
+
1312
+ Args:
1313
+ data: AnnData object or MuData object.
1314
+ modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
1315
+ covariates: The name of the covariates in data.obs to plot. Defaults to None.
1316
+ parameter: The parameter in effect summary to plot. Defaults to "log2-fold change".
1317
+ plot_facets: If False, plot cell types on the x-axis. If True, plot as facets.
1318
+ Defaults to True.
1319
+ plot_zero_covariate: If True, plot covariates that have all zero effects. If False, do not plot.
1320
+ Defaults to True.
1321
+ plot_zero_cell_type: If True, plot cell types that have zero effect. If False, do not plot.
1322
+ Defaults to False.
1323
+ figsize: Figure size. Defaults to None.
1324
+ dpi: Dpi setting. Defaults to 100.
1325
+ palette: The seaborn color map for the barplot. Defaults to cm.tab20.
1326
+ level_order: Custom ordering of bars on the x-axis. Defaults to None.
1327
+ args_barplot: Arguments passed to sns.barplot. Defaults to None.
1328
+
1329
+ Returns:
1330
+ Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`)
1331
+ or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object
1332
+
1333
+ Examples:
1334
+ >>> import pertpy as pt
1335
+ >>> haber_cells = pt.dt.haber_2017_regions()
1336
+ >>> sccoda = pt.tl.Sccoda()
1337
+ >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
1338
+ sample_identifier="batch", covariate_obs=["condition"])
1339
+ >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
1340
+ >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
1341
+ >>> sccoda.plot_effects_barplot(mdata)
1342
+
1343
+ Preview:
1344
+ .. image:: /_static/docstring_previews/sccoda_effects_barplot.png
1345
+ """
1346
+ if args_barplot is None:
1347
+ args_barplot = {}
1348
+ if isinstance(data, MuData):
1349
+ data = data[modality_key]
1350
+ if isinstance(data, AnnData):
1351
+ data = data
1352
+ # Get covariate names from adata, partition into those with nonzero effects for min. one cell type/no cell types
1353
+ covariate_names = data.uns["scCODA_params"]["covariate_names"]
1354
+ if covariates is not None:
1355
+ if isinstance(covariates, str):
1356
+ covariates = [covariates]
1357
+ partial_covariate_names = [
1358
+ covariate_name
1359
+ for covariate_name in covariate_names
1360
+ if any(covariate in covariate_name for covariate in covariates)
1361
+ ]
1362
+ covariate_names = partial_covariate_names
1363
+ covariate_names_non_zero = [
1364
+ covariate_name
1365
+ for covariate_name in covariate_names
1366
+ if data.varm[f"effect_df_{covariate_name}"][parameter].any()
1367
+ ]
1368
+ covariate_names_zero = list(set(covariate_names) - set(covariate_names_non_zero))
1369
+ if not plot_zero_covariate:
1370
+ covariate_names = covariate_names_non_zero
1371
+
1372
+ # set up df for plotting
1373
+ plot_df = pd.concat(
1374
+ [data.varm[f"effect_df_{covariate_name}"][parameter] for covariate_name in covariate_names],
1375
+ axis=1,
1376
+ )
1377
+ plot_df.columns = covariate_names
1378
+ plot_df = pd.melt(plot_df, ignore_index=False, var_name="Covariate")
1379
+
1380
+ plot_df = plot_df.reset_index()
1381
+
1382
+ if len(covariate_names_zero) != 0:
1383
+ if plot_facets:
1384
+ if plot_zero_covariate and not plot_zero_cell_type:
1385
+ plot_df = plot_df[plot_df["value"] != 0]
1386
+ for covariate_name_zero in covariate_names_zero:
1387
+ new_row = {
1388
+ "Covariate": covariate_name_zero,
1389
+ "Cell Type": "zero",
1390
+ "value": 0,
1391
+ }
1392
+ plot_df = pd.concat([plot_df, pd.DataFrame([new_row])], ignore_index=True)
1393
+ plot_df["covariate_"] = pd.Categorical(plot_df["Covariate"], covariate_names)
1394
+ plot_df = plot_df.sort_values(["covariate_"])
1395
+ if not plot_zero_cell_type:
1396
+ cell_type_names_zero = [
1397
+ name
1398
+ for name in plot_df["Cell Type"].unique()
1399
+ if (plot_df[plot_df["Cell Type"] == name]["value"] == 0).all()
1400
+ ]
1401
+ plot_df = plot_df[~plot_df["Cell Type"].isin(cell_type_names_zero)]
1402
+
1403
+ # If plot as facets, create a FacetGrid and map barplot to it.
1404
+ if plot_facets:
1405
+ if isinstance(palette, ListedColormap):
1406
+ palette = np.array([palette(i % palette.N) for i in range(len(plot_df["Cell Type"].unique()))]).tolist()
1407
+ if figsize is not None:
1408
+ height = figsize[0]
1409
+ aspect = np.round(figsize[1] / figsize[0], 2)
1410
+ else:
1411
+ height = 3
1412
+ aspect = 2
1413
+
1414
+ g = sns.FacetGrid(
1415
+ plot_df,
1416
+ col="Covariate",
1417
+ sharey=True,
1418
+ sharex=False,
1419
+ height=height,
1420
+ aspect=aspect,
1421
+ )
1422
+
1423
+ g.map(
1424
+ sns.barplot,
1425
+ "Cell Type",
1426
+ "value",
1427
+ palette=palette,
1428
+ order=level_order,
1429
+ **args_barplot,
1430
+ )
1431
+ g.set_xticklabels(rotation=90)
1432
+ g.set(ylabel=parameter)
1433
+ axes = g.axes.flatten()
1434
+ for i, ax in enumerate(axes):
1435
+ ax.set_title(covariate_names[i])
1436
+ if len(ax.get_xticklabels()) < 5:
1437
+ ax.set_aspect(10 / len(ax.get_xticklabels()))
1438
+ if len(ax.get_xticklabels()) == 1:
1439
+ if ax.get_xticklabels()[0]._text == "zero":
1440
+ ax.set_xticks([])
1441
+
1442
+ if save:
1443
+ plt.savefig(save, bbox_inches="tight")
1444
+ if show:
1445
+ plt.show()
1446
+ if return_fig:
1447
+ return plt.gcf()
1448
+ if not (show or save):
1449
+ return g
1450
+ return None
1451
+
1452
+ # If not plot as facets, call barplot to plot cell types on the x-axis.
1453
+ else:
1454
+ _, ax = plt.subplots(figsize=figsize, dpi=dpi)
1455
+ if len(covariate_names) == 1:
1456
+ if isinstance(palette, ListedColormap):
1457
+ palette = np.array(
1458
+ [palette(i % palette.N) for i in range(len(plot_df["Cell Type"].unique()))]
1459
+ ).tolist()
1460
+ sns.barplot(
1461
+ data=plot_df,
1462
+ x="Cell Type",
1463
+ y="value",
1464
+ hue="x",
1465
+ palette=palette,
1466
+ ax=ax,
1467
+ )
1468
+ ax.set_title(covariate_names[0])
1469
+ else:
1470
+ if isinstance(palette, ListedColormap):
1471
+ palette = np.array([palette(i % palette.N) for i in range(len(covariate_names))]).tolist()
1472
+ sns.barplot(
1473
+ data=plot_df,
1474
+ x="Cell Type",
1475
+ y="value",
1476
+ hue="Covariate",
1477
+ palette=palette,
1478
+ ax=ax,
1479
+ )
1480
+ cell_types = pd.unique(plot_df["Cell Type"])
1481
+ ax.set_xticklabels(cell_types, rotation=90)
1482
+
1483
+ if save:
1484
+ plt.savefig(save, bbox_inches="tight")
1485
+ if show:
1486
+ plt.show()
1487
+ if return_fig:
1488
+ return plt.gcf()
1489
+ if not (show or save):
1490
+ return ax
1491
+ return None
1492
+
1493
+ def plot_boxplots( # pragma: no cover
1494
+ self,
1495
+ data: AnnData | MuData,
1496
+ feature_name: str,
1497
+ modality_key: str = "coda",
1498
+ y_scale: Literal["relative", "log", "log10", "count"] = "relative",
1499
+ plot_facets: bool = False,
1500
+ add_dots: bool = False,
1501
+ cell_types: list | None = None,
1502
+ args_boxplot: dict | None = None,
1503
+ args_swarmplot: dict | None = None,
1504
+ palette: str | None = "Blues",
1505
+ show_legend: bool | None = True,
1506
+ level_order: list[str] = None,
1507
+ figsize: tuple[float, float] | None = None,
1508
+ dpi: int | None = 100,
1509
+ return_fig: bool | None = None,
1510
+ ax: plt.Axes | None = None,
1511
+ show: bool | None = None,
1512
+ save: str | bool | None = None,
1513
+ ) -> plt.Axes | plt.Figure | sns.axisgrid.FacetGrid | None:
1514
+ """Grouped boxplot visualization.
1515
+
1516
+ The cell counts for each cell type are shown as a group of boxplots
1517
+ with intra-group separation by a covariate from data.obs.
1518
+
1519
+ Args:
1520
+ data: AnnData object or MuData object
1521
+ feature_name: The name of the feature in data.obs to plot
1522
+ modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
1523
+ y_scale: Transformation of cell counts. Options: "relative" - Relative abundance, "log" - log(count),
1524
+ "log10" - log10(count), "count" - absolute abundance (cell counts).
1525
+ Defaults to "relative".
1526
+ plot_facets: If False, plot cell types on the x-axis. If True, plot as facets. Defaults to False.
1527
+ add_dots: If True, overlay a scatterplot with one dot for each data point. Defaults to False.
1528
+ cell_types: Subset of cell types that should be plotted. Defaults to None.
1529
+ args_boxplot: Arguments passed to sns.boxplot. Defaults to {}.
1530
+ args_swarmplot: Arguments passed to sns.swarmplot. Defaults to {}.
1531
+ figsize: Figure size. Defaults to None.
1532
+ dpi: Dpi setting. Defaults to 100.
1533
+ palette: The seaborn color map for the boxplot. Defaults to "Blues".
1534
+ show_legend: If True, adds a legend. Defaults to True.
1535
+ level_order: Custom ordering of bars on the x-axis. Defaults to None.
1536
+
1537
+ Returns:
1538
+ Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`)
1539
+ or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object
1540
+
1541
+ Examples:
1542
+ >>> import pertpy as pt
1543
+ >>> haber_cells = pt.dt.haber_2017_regions()
1544
+ >>> sccoda = pt.tl.Sccoda()
1545
+ >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
1546
+ sample_identifier="batch", covariate_obs=["condition"])
1547
+ >>> sccoda.plot_boxplots(mdata, feature_name="condition", add_dots=True)
1548
+
1549
+ Preview:
1550
+ .. image:: /_static/docstring_previews/sccoda_boxplots.png
1551
+ """
1552
+ if args_boxplot is None:
1553
+ args_boxplot = {}
1554
+ if args_swarmplot is None:
1555
+ args_swarmplot = {}
1556
+ if isinstance(data, MuData):
1557
+ data = data[modality_key]
1558
+ if isinstance(data, AnnData):
1559
+ data = data
1560
+ # y scale transformations
1561
+ if y_scale == "relative":
1562
+ sample_sums = np.sum(data.X, axis=1, keepdims=True)
1563
+ X = data.X / sample_sums
1564
+ value_name = "Proportion"
1565
+ # add pseudocount 0.5 if using log scale
1566
+ elif y_scale == "log":
1567
+ X = data.X.copy()
1568
+ X[X == 0] = 0.5
1569
+ X = np.log(X)
1570
+ value_name = "log(count)"
1571
+ elif y_scale == "log10":
1572
+ X = data.X.copy()
1573
+ X[X == 0] = 0.5
1574
+ X = np.log(X)
1575
+ value_name = "log10(count)"
1576
+ elif y_scale == "count":
1577
+ X = data.X
1578
+ value_name = "count"
1579
+ else:
1580
+ raise ValueError("Invalid y_scale transformation")
1581
+
1582
+ count_df = pd.DataFrame(X, columns=data.var.index, index=data.obs.index).merge(
1583
+ data.obs[feature_name], left_index=True, right_index=True
1584
+ )
1585
+ plot_df = pd.melt(count_df, id_vars=feature_name, var_name="Cell type", value_name=value_name)
1586
+ if cell_types is not None:
1587
+ plot_df = plot_df[plot_df["Cell type"].isin(cell_types)]
1588
+
1589
+ # Currently disabled because the latest statsannotations does not support the latest seaborn.
1590
+ # We had to drop the dependency.
1591
+ # Get credible effects results from model
1592
+ # if draw_effects:
1593
+ # if model is not None:
1594
+ # credible_effects_df = model.credible_effects(data, modality_key).to_frame().reset_index()
1595
+ # else:
1596
+ # print("[bold yellow]Specify a tasCODA model to draw effects")
1597
+ # credible_effects_df[feature_name] = credible_effects_df["Covariate"].str.removeprefix(f"{feature_name}[T.")
1598
+ # credible_effects_df[feature_name] = credible_effects_df[feature_name].str.removesuffix("]")
1599
+ # credible_effects_df = credible_effects_df[credible_effects_df["Final Parameter"]]
1600
+
1601
+ # If plot as facets, create a FacetGrid and map boxplot to it.
1602
+ if plot_facets:
1603
+ if level_order is None:
1604
+ level_order = pd.unique(plot_df[feature_name])
1605
+
1606
+ K = X.shape[1]
1607
+
1608
+ if figsize is not None:
1609
+ height = figsize[0]
1610
+ aspect = np.round(figsize[1] / figsize[0], 2)
1611
+ else:
1612
+ height = 3
1613
+ aspect = 2
1614
+
1615
+ g = sns.FacetGrid(
1616
+ plot_df,
1617
+ col="Cell type",
1618
+ sharey=False,
1619
+ col_wrap=int(np.floor(np.sqrt(K))),
1620
+ height=height,
1621
+ aspect=aspect,
1622
+ )
1623
+ g.map(
1624
+ sns.boxplot,
1625
+ feature_name,
1626
+ value_name,
1627
+ palette=palette,
1628
+ order=level_order,
1629
+ **args_boxplot,
1630
+ )
1631
+
1632
+ if add_dots:
1633
+ if "hue" in args_swarmplot:
1634
+ hue = args_swarmplot.pop("hue")
1635
+ else:
1636
+ hue = None
1637
+
1638
+ if hue is None:
1639
+ g.map(
1640
+ sns.swarmplot,
1641
+ feature_name,
1642
+ value_name,
1643
+ color="black",
1644
+ order=level_order,
1645
+ **args_swarmplot,
1646
+ ).set_titles("{col_name}")
1647
+ else:
1648
+ g.map(
1649
+ sns.swarmplot,
1650
+ feature_name,
1651
+ value_name,
1652
+ hue,
1653
+ order=level_order,
1654
+ **args_swarmplot,
1655
+ ).set_titles("{col_name}")
1656
+
1657
+ if save:
1658
+ plt.savefig(save, bbox_inches="tight")
1659
+ if show:
1660
+ plt.show()
1661
+ if return_fig:
1662
+ return plt.gcf()
1663
+ if not (show or save):
1664
+ return g
1665
+ return None
1666
+
1667
+ # If not plot as facets, call boxplot to plot cell types on the x-axis.
1668
+ else:
1669
+ if level_order:
1670
+ args_boxplot["hue_order"] = level_order
1671
+ args_swarmplot["hue_order"] = level_order
1672
+
1673
+ _, ax = plt.subplots(figsize=figsize, dpi=dpi)
1674
+
1675
+ ax = sns.boxplot(
1676
+ x="Cell type",
1677
+ y=value_name,
1678
+ hue=feature_name,
1679
+ data=plot_df,
1680
+ fliersize=1,
1681
+ palette=palette,
1682
+ ax=ax,
1683
+ **args_boxplot,
1684
+ )
1685
+
1686
+ # Currently disabled because the latest statsannotations does not support the latest seaborn.
1687
+ # We had to drop the dependency.
1688
+ # if draw_effects:
1689
+ # pairs = [
1690
+ # [(row["Cell Type"], row[feature_name]), (row["Cell Type"], "Control")]
1691
+ # for _, row in credible_effects_df.iterrows()
1692
+ # ]
1693
+ # annot = Annotator(ax, pairs, data=plot_df, x="Cell type", y=value_name, hue=feature_name)
1694
+ # annot.configure(test=None, loc="outside", color="red", line_height=0, verbose=False)
1695
+ # annot.set_custom_annotations([row[feature_name] for _, row in credible_effects_df.iterrows()])
1696
+ # annot.annotate()
1697
+
1698
+ if add_dots:
1699
+ sns.swarmplot(
1700
+ x="Cell type",
1701
+ y=value_name,
1702
+ data=plot_df,
1703
+ hue=feature_name,
1704
+ ax=ax,
1705
+ dodge=True,
1706
+ palette="dark:black",
1707
+ **args_swarmplot,
1708
+ )
1709
+
1710
+ cell_types = pd.unique(plot_df["Cell type"])
1711
+ ax.set_xticklabels(cell_types, rotation=90)
1712
+
1713
+ if show_legend:
1714
+ handles, labels = ax.get_legend_handles_labels()
1715
+ handout = []
1716
+ labelout = []
1717
+ for h, l in zip(handles, labels, strict=False):
1718
+ if l not in labelout:
1719
+ labelout.append(l)
1720
+ handout.append(h)
1721
+ ax.legend(
1722
+ handout,
1723
+ labelout,
1724
+ loc="upper left",
1725
+ bbox_to_anchor=(1, 1),
1726
+ ncol=1,
1727
+ title=feature_name,
1728
+ )
1729
+
1730
+ if save:
1731
+ plt.savefig(save, bbox_inches="tight")
1732
+ if show:
1733
+ plt.show()
1734
+ if return_fig:
1735
+ return plt.gcf()
1736
+ if not (show or save):
1737
+ return ax
1738
+ return None
1739
+
1740
def plot_rel_abundance_dispersion_plot(  # pragma: no cover
    self,
    data: AnnData | MuData,
    modality_key: str = "coda",
    abundant_threshold: float | None = 0.9,
    default_color: str | None = "Grey",
    abundant_color: str | None = "Red",
    label_cell_types: bool = True,
    figsize: tuple[float, float] | None = None,
    dpi: int | None = 100,
    return_fig: bool | None = None,
    ax: plt.Axes | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
) -> plt.Axes | plt.Figure | None:
    """Plots the dispersion of relative abundance against the presence of every cell type.

    The plot is meant to help with choosing a reference cell type. For each cell type, the x-axis shows
    the fraction of samples with a non-zero count ("Presence") and the y-axis shows the variance of the
    relative abundance divided by its mean ("Total dispersion").
    Cell types whose presence is larger than `abundant_threshold` are drawn in a different color.

    Args:
        data: AnnData or MuData object.
        modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
        abundant_threshold: Presence threshold for abundant cell types. Defaults to 0.9.
        default_color: Dot color for cell types that are not abundant. Defaults to "Grey".
        abundant_color: Dot color for cell types whose presence is larger than abundant_threshold.
                        Defaults to "Red".
        label_cell_types: Label the dots of abundant cell types with their names. Defaults to True.
        figsize: Figure size. Only used if `ax` is None. Defaults to None.
        dpi: Dpi setting. Only used if `ax` is None. Defaults to 100.
        return_fig: If True, return the figure that holds the plot. Defaults to None.
        ax: A matplotlib axes object to draw on. A new figure is created if None. Defaults to None.
        show: If True, call `plt.show()`. Defaults to None.
        save: If set, the figure is saved to this path. Defaults to None.

    Returns:
        The :class:`~matplotlib.figure.Figure` if `return_fig` is set,
        otherwise the :class:`~matplotlib.axes.Axes` if neither `show` nor `save` is set, otherwise None.

    Examples:
        >>> import pertpy as pt
        >>> haber_cells = pt.dt.haber_2017_regions()
        >>> sccoda = pt.tl.Sccoda()
        >>> mdata = sccoda.load(
        ...     haber_cells,
        ...     type="cell_level",
        ...     generate_sample_level=True,
        ...     cell_type_identifier="cell_label",
        ...     sample_identifier="batch",
        ...     covariate_obs=["condition"],
        ... )
        >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
        >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
        >>> sccoda.plot_rel_abundance_dispersion_plot(mdata)

    Preview:
        .. image:: /_static/docstring_previews/sccoda_rel_abundance_dispersion_plot.png
    """
    if isinstance(data, MuData):
        data = data[modality_key]
    if ax is None:
        _, ax = plt.subplots(figsize=figsize, dpi=dpi)

    # Per-sample relative abundance of every cell type (rows are normalized to sum to 1).
    rel_abun = data.X / np.sum(data.X, axis=1, keepdims=True)

    # Fraction of samples in which a cell type has a zero count.
    percent_zero = np.sum(data.X == 0, axis=0) / data.X.shape[0]
    nonrare_ct = np.where(percent_zero < 1 - abundant_threshold)[0]

    cell_type_disp = np.var(rel_abun, axis=0) / np.mean(rel_abun, axis=0)

    is_abundant = [x in nonrare_ct for x in range(data.X.shape[1])]

    plot_df = pd.DataFrame(
        {
            "Total dispersion": cell_type_disp,
            "Cell type": data.var.index,
            "Presence": 1 - percent_zero,
            "Is abundant": is_abundant,
        }
    )

    # The palette needs exactly one color per hue level that is actually present.
    if any(is_abundant) and not all(is_abundant):
        palette = [default_color, abundant_color]
    elif not any(is_abundant):
        palette = [default_color]
    else:
        palette = [abundant_color]

    ax = sns.scatterplot(
        data=plot_df,
        x="Presence",
        y="Total dispersion",
        hue="Is abundant",
        palette=palette,
        ax=ax,
    )

    if label_cell_types:
        # Only abundant cell types are labeled. The labels are placed on `ax` itself, which is
        # not necessarily the current pyplot axes when the caller passed their own axes.
        abundant_df = plot_df.loc[plot_df["Is abundant"], :]
        texts = [
            ax.text(presence, dispersion, str(cell_type))
            for presence, dispersion, cell_type in zip(
                abundant_df["Presence"],
                abundant_df["Total dispersion"],
                abundant_df["Cell type"],
                strict=False,
            )
        ]
        adjust_text(texts, ax=ax)

    ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1, title="Is abundant")

    # Save/return the figure that owns `ax` instead of whatever figure pyplot considers current.
    fig = ax.figure
    if save:
        fig.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return fig
    if not (show or save):
        return ax
    return None
1866
+
1867
def plot_draw_tree(  # pragma: no cover
    self,
    data: AnnData | MuData,
    modality_key: str = "coda",
    tree: str = "tree",  # Also type ete3.Tree. Omitted due to import errors
    tight_text: bool | None = False,
    show_scale: bool | None = False,
    units: Literal["px", "mm", "in"] | None = "px",
    figsize: tuple[float, float] | None = (None, None),
    dpi: int | None = 100,
    show: bool | None = True,
    save: str | bool | None = None,
) -> Tree | None:
    """Plot a tree using input ete3 tree object.

    Args:
        data: AnnData object or MuData object.
        modality_key: If data is a MuData object, specify which modality to use.
                      Defaults to "coda".
        tree: A ete3 tree object or a str to indicate the tree stored in `.uns`.
              Defaults to "tree".
        tight_text: When False, boundaries of the text are approximated according to general font metrics,
                    producing slightly worse aligned text faces but improving the performance of tree visualization in scenes with a lot of text faces.
                    Defaults to False.
        show_scale: Include the scale legend in the tree image or not.
                    Defaults to False.
        units: Unit of image sizes. "px": pixels, "mm": millimeters, "in": inches. Defaults to "px".
        figsize: Width and height of the rendered image, passed to ete3 as `w` and `h`.
                 `None` is treated like `(None, None)`. Defaults to (None, None).
        dpi: Dots per inches. Defaults to 100.
        show: If True, plot the tree inline. If False, return tree and tree_style objects.
              Defaults to True.
        save: Path to the output image file. Valid extensions are .SVG, .PDF, .PNG.
              Output image can be saved whether show is True or not.
              Nothing is saved if this is None or False. Defaults to None.

    Returns:
        Depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`)
        or plots the tree inline (`show = True`).

    Raises:
        ImportError: If ete3 is not installed.

    Examples:
        >>> import pertpy as pt
        >>> adata = pt.dt.tasccoda_example()
        >>> tasccoda = pt.tl.Tasccoda()
        >>> mdata = tasccoda.load(
        ...     adata, type="sample_level",
        ...     levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
        ...     key_added="lineage", add_level_name=True
        ... )
        >>> mdata = tasccoda.prepare(
        ...     mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
        ... )
        >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
        >>> tasccoda.plot_draw_tree(mdata, tree="lineage")

    Preview:
        .. image:: /_static/docstring_previews/tasccoda_draw_tree.png
    """
    try:
        from ete3 import TextFace, TreeStyle, faces
    except ImportError:
        raise ImportError(
            "To use tasccoda please install additional dependencies with `pip install pertpy[coda]`"
        ) from None

    if isinstance(data, MuData):
        data = data[modality_key]
    if isinstance(tree, str):
        tree = data.uns[tree]

    # The annotation allows `figsize=None`; treat it as "no explicit size".
    width, height = figsize if figsize is not None else (None, None)

    def my_layout(node):
        # Draw every node name to the right of its branch.
        text_face = TextFace(node.name, tight_text=tight_text)
        faces.add_face_to_node(text_face, node, column=0, position="branch-right")

    tree_style = TreeStyle()
    tree_style.show_leaf_name = False  # names are drawn by my_layout instead
    tree_style.layout_fn = my_layout
    tree_style.show_scale = show_scale

    # `save=False` must not be handed to ete3 as a file name.
    if save:
        tree.render(save, tree_style=tree_style, units=units, w=width, h=height, dpi=dpi)  # type: ignore
    if show:
        return tree.render("%%inline", tree_style=tree_style, units=units, w=width, h=height, dpi=dpi)  # type: ignore
    return tree, tree_style
1952
+
1953
def plot_draw_effects(  # pragma: no cover
    self,
    data: AnnData | MuData,
    covariate: str,
    modality_key: str = "coda",
    tree: str = "tree",  # Also type ete3.Tree. Omitted due to import errors
    show_legend: bool | None = None,
    show_leaf_effects: bool | None = False,
    tight_text: bool | None = False,
    show_scale: bool | None = False,
    units: Literal["px", "mm", "in"] | None = "px",
    figsize: tuple[float, float] | None = (None, None),
    dpi: int | None = 100,
    show: bool | None = True,
    save: str | None = None,
) -> Tree | None:
    """Plot a tree with colored circles on the nodes indicating significant effects with bar plots which indicate leave-level significant effects.

    Args:
        data: AnnData object or MuData object.
        covariate: The covariate, whose effects should be plotted.
        modality_key: If data is a MuData object, specify which modality to use.
                      Defaults to "coda".
        tree: A ete3 tree object or a str to indicate the tree stored in `.uns`.
              Defaults to "tree".
        show_legend: If show legend of nodes significant effects or not.
                     Defaults to False if show_leaf_effects is True.
        show_leaf_effects: If True, plot bar plots which indicate leave-level significant effects.
                           Defaults to False.
        tight_text: When False, boundaries of the text are approximated according to general font metrics,
                    producing slightly worse aligned text faces but improving the performance of tree visualization in scenes with a lot of text faces.
                    Defaults to False.
        show_scale: Include the scale legend in the tree image or not. Defaults to False.
        units: Unit of image sizes. "px": pixels, "mm": millimeters, "in": inches. Defaults to "px".
        figsize: Width and height of the rendered tree image, passed to ete3 as `w` and `h`.
                 `None` is treated like `(None, None)`. Not used if `show_leaf_effects` is True.
                 Defaults to (None, None).
        dpi: Dots per inches. Not used if `show_leaf_effects` is True. Defaults to 100.
        show: If True, plot the tree inline. If False, return tree and tree_style objects. Defaults to True.
        save: Path to the output image file. Valid extensions are .SVG, .PDF, .PNG.
              Output image can be saved whether show is True or not. Defaults to None.

    Returns:
        None if `show_leaf_effects` is True (the combined figure is drawn with matplotlib).
        Otherwise, depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`)
        or plots the tree inline (`show = True`).

    Raises:
        ImportError: If ete3 is not installed.

    Examples:
        >>> import pertpy as pt
        >>> adata = pt.dt.tasccoda_example()
        >>> tasccoda = pt.tl.Tasccoda()
        >>> mdata = tasccoda.load(
        ...     adata, type="sample_level",
        ...     levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
        ...     key_added="lineage", add_level_name=True
        ... )
        >>> mdata = tasccoda.prepare(
        ...     mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
        ... )
        >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
        >>> tasccoda.plot_draw_effects(mdata, covariate="Health[T.Inflamed]", tree="lineage")

    Preview:
        .. image:: /_static/docstring_previews/tasccoda_draw_effects.png
    """
    try:
        from ete3 import CircleFace, NodeStyle, TextFace, TreeStyle, faces
    except ImportError:
        raise ImportError(
            "To use tasccoda please install additional dependencies as `pip install pertpy[coda]`"
        ) from None

    if isinstance(data, MuData):
        data = data[modality_key]
    if show_legend is None:
        show_legend = not show_leaf_effects
    elif show_legend and show_leaf_effects:
        # The warning is only relevant when leaf effect bars are drawn next to the tree.
        print("Tree leaves and leaf effect bars won't be aligned when legend is shown!")

    if isinstance(tree, str):
        tree = data.uns[tree]
    # Collapse tree singularities
    tree2 = collapse_singularities_2(tree)

    # The annotation allows `figsize=None`; treat it as "no explicit size".
    width, height = figsize if figsize is not None else (None, None)

    # Node-level effects of the selected covariate, indexed by node name.
    node_effs = data.uns["scCODA_params"]["node_df"].loc[(covariate + "_node",),].copy()
    node_effs.index = node_effs.index.get_level_values("Node")

    # Leaf-level (cell type) effects of the selected covariate, indexed by cell type.
    covariates = data.uns["scCODA_params"]["covariate_names"]
    effect_dfs = [data.varm[f"effect_df_{cov}"] for cov in covariates]
    eff_df = pd.concat(effect_dfs)
    eff_df.index = pd.MultiIndex.from_product(
        (covariates, data.var.index.tolist()),
        names=["Covariate", "Cell Type"],
    )
    leaf_effs = eff_df.loc[(covariate,),].copy()
    leaf_effs.index = leaf_effs.index.get_level_values("Cell Type")

    # Attach effect values to the tree nodes; nodes without an entry get an effect of 0.
    for n in tree2.traverse():
        nstyle = NodeStyle()
        nstyle["size"] = 0
        n.set_style(nstyle)
        if n.name in node_effs.index:
            n.add_feature("node_effect", node_effs.loc[n.name, "Final Parameter"])
        else:
            n.add_feature("node_effect", 0)
        if n.name in leaf_effs.index:
            n.add_feature("leaf_effect", leaf_effs.loc[n.name, "Effect"])
        else:
            n.add_feature("leaf_effect", 0)

    # Largest absolute effects, used to scale circle sizes and the bar plot axis.
    eff_max = np.max([np.abs(n.node_effect) for n in tree2.traverse()])
    leaf_eff_max = np.max([np.abs(n.leaf_effect) for n in tree2.traverse()])

    def my_layout(node):
        text_face = TextFace(node.name, tight_text=tight_text)
        text_face.margin_left = 10
        faces.add_face_to_node(text_face, node, column=0, aligned=True)

        # Circle radius is proportional to the absolute node effect, with the largest effect at 10.
        size = (np.abs(node.node_effect) * 10 / eff_max) if node.node_effect != 0 else 0
        if np.sign(node.node_effect) == 1:
            color = "blue"
        elif np.sign(node.node_effect) == -1:
            color = "red"
        else:
            color = "cyan"
        if size != 0:
            faces.add_face_to_node(CircleFace(radius=size, color=color), node, column=0)

    tree_style = TreeStyle()
    tree_style.show_leaf_name = False  # names are drawn by my_layout instead
    tree_style.layout_fn = my_layout
    tree_style.show_scale = show_scale
    tree_style.draw_guiding_lines = True
    tree_style.legend_position = 1

    if show_legend:
        tree_style.legend.add_face(TextFace("Effects"), column=0)
        tree_style.legend.add_face(TextFace(" "), column=1)
        for i in range(4, 0, -1):
            # Legend circles use the same scale as the nodes (largest effect -> radius 10).
            # The radius does not depend on eff_max, so it stays finite when all effects are 0.
            radius = 10 * i / 4
            tree_style.legend.add_face(CircleFace(radius, "red"), column=0)
            tree_style.legend.add_face(TextFace(f"{-eff_max * i / 4:.2f} "), column=0)
            tree_style.legend.add_face(CircleFace(radius, "blue"), column=1)
            tree_style.legend.add_face(TextFace(f" {eff_max * i / 4:.2f}"), column=1)

    if show_leaf_effects:
        import tempfile

        leaf_names = [node.name for node in tree2.traverse("postorder") if node.is_leaf()]
        leaf_effs = leaf_effs.loc[leaf_names].reset_index()
        palette = ["blue" if effect > 0 else "red" for effect in leaf_effs["Effect"].tolist()]

        # Render the tree to a throwaway image so that no file is left in the working directory.
        with tempfile.TemporaryDirectory() as tmp_dir:
            img_path = str(Path(tmp_dir) / "tree_effect.png")
            tree2.render(img_path, tree_style=tree_style, units="in")
            img = mpimg.imread(img_path)

        # Left panel: tree image; right panel: leaf effects as horizontal bars.
        fig, axes = plt.subplots(1, 2, figsize=(10, 10))
        sns.barplot(data=leaf_effs, x="Effect", y="Cell Type", palette=palette, ax=axes[1])
        axes[0].imshow(img)
        axes[0].get_xaxis().set_visible(False)
        axes[0].get_yaxis().set_visible(False)
        axes[0].set_frame_on(False)

        axes[1].get_yaxis().set_visible(False)
        axes[1].spines["left"].set_visible(False)
        axes[1].spines["right"].set_visible(False)
        axes[1].spines["top"].set_visible(False)
        axes[1].set_xlim(-leaf_eff_max, leaf_eff_max)
        fig.subplots_adjust(wspace=0)

        if save is not None:
            fig.savefig(save)
        return None

    if save is not None:
        # Use the same size settings as the inline rendering below.
        tree2.render(save, tree_style=tree_style, units=units, w=width, h=height, dpi=dpi)
    if show:
        return tree2.render("%%inline", tree_style=tree_style, units=units, w=width, h=height, dpi=dpi)
    return tree2, tree_style
2148
+
2149
def plot_effects_umap(  # pragma: no cover
    self,
    mdata: MuData,
    effect_name: str | list | None,
    cluster_key: str,
    modality_key_1: str = "rna",
    modality_key_2: str = "coda",
    color_map: Colormap | str | None = None,
    palette: str | Sequence[str] | None = None,
    return_fig: bool | None = None,
    ax: Axes = None,
    show: bool = None,
    save: str | bool | None = None,
    **kwargs,
) -> plt.Axes | plt.Figure | None:
    """Plot a UMAP visualization colored by effect strength.

    Effect results in .varm of aggregated sample-level AnnData (default is data['coda']) are assigned to cell-level AnnData
    (default is data['rna']) depending on the cluster they were assigned to.
    For every effect, a column with the name of the effect is written to `.obs` of the cell-level AnnData.

    Args:
        mdata: MuData object.
        effect_name: The name (or list of names) of the effect results in .varm of aggregated sample-level AnnData to plot.
        cluster_key: The cluster information in .obs of cell-level AnnData (default is data['rna']).
                     To assign cell types' effects to original cells.
        modality_key_1: Key to the cell-level AnnData in the MuData object. Defaults to "rna".
        modality_key_2: Key to the aggregated sample-level AnnData object in the MuData object.
                        Defaults to "coda".
        color_map: Passed to `scanpy.plot.umap()` as `color_map`. Defaults to None.
        palette: Passed to `scanpy.plot.umap()` as `palette`. Defaults to None.
        return_fig: Passed to `scanpy.plot.umap()` as `return_fig`. Defaults to None.
        ax: A matplotlib axes object. Only works if plotting a single component.
            Defaults to None.
        show: Whether to display the figure or return axis. Defaults to None.
        save: Passed to `scanpy.plot.umap()` as `save`. Defaults to None.
        **kwargs: All other keyword arguments are passed to `scanpy.plot.umap()`.
                  If `vmin` or `vmax` are not given (or are None), the smallest/largest effect value
                  over all plotted effects is used, so that all panels share one color scale.

    Returns:
        The return value of `scanpy.plot.umap()`.
        If `show==False` a :class:`~matplotlib.axes.Axes` or a list of it.

    Examples:
        >>> import pertpy as pt
        >>> import scanpy as sc
        >>> import schist
        >>> adata = pt.dt.haber_2017_regions()
        >>> sc.pp.neighbors(adata)
        >>> schist.inference.nested_model(adata, n_init=100, random_seed=5678)
        >>> tasccoda_model = pt.tl.Tasccoda()
        >>> tasccoda_data = tasccoda_model.load(
        ...     adata,
        ...     type="cell_level",
        ...     cell_type_identifier="nsbm_level_1",
        ...     sample_identifier="batch",
        ...     covariate_obs=["condition"],
        ...     levels_orig=["nsbm_level_4", "nsbm_level_3", "nsbm_level_2", "nsbm_level_1"],
        ...     add_level_name=True,
        ... )
        >>> tasccoda_model.prepare(
        ...     tasccoda_data,
        ...     modality_key="coda",
        ...     reference_cell_type="18",
        ...     formula="condition",
        ...     pen_args={"phi": 0, "lambda_1": 3.5},
        ...     tree_key="tree",
        ... )
        >>> tasccoda_model.run_nuts(
        ...     tasccoda_data, modality_key="coda", rng_key=1234, num_samples=10000, num_warmup=1000
        ... )
        >>> sc.tl.umap(tasccoda_data["rna"])
        >>> tasccoda_model.plot_effects_umap(
        ...     tasccoda_data,
        ...     effect_name=[
        ...         "effect_df_condition[T.Salmonella]",
        ...         "effect_df_condition[T.Hpoly.Day3]",
        ...         "effect_df_condition[T.Hpoly.Day10]",
        ...     ],
        ...     cluster_key="nsbm_level_1",
        ... )

    Preview:
        .. image:: /_static/docstring_previews/tasccoda_effects_umap.png
    """
    # TODO: Test the docstring example
    data_rna = mdata[modality_key_1]
    data_coda = mdata[modality_key_2]
    if isinstance(effect_name, str):
        effect_name = [effect_name]

    # Give every cell the effect of the cluster it belongs to.
    for effect in effect_name:
        data_rna.obs[effect] = [data_coda.varm[effect].loc[f"{c}", "Effect"] for c in data_rna.obs[cluster_key]]

    # Always remove vmin/vmax from kwargs: they are passed explicitly below, and a legitimate
    # value of 0 must not be mistaken for "not given".
    vmin = kwargs.pop("vmin", None)
    if vmin is None:
        vmin = min(data_rna.obs[effect].min() for effect in effect_name)
    vmax = kwargs.pop("vmax", None)
    if vmax is None:
        vmax = max(data_rna.obs[effect].max() for effect in effect_name)

    return sc.pl.umap(
        data_rna,
        color=effect_name,
        vmax=vmax,
        vmin=vmin,
        palette=palette,
        color_map=color_map,
        return_fig=return_fig,
        ax=ax,
        show=show,
        save=save,
        **kwargs,
    )
2254
+
1116
2255
 
1117
2256
  def get_a(
1118
2257
  tree: tt.tree,
@@ -1242,7 +2381,7 @@ def df2newick(df: pd.DataFrame, levels: list[str], inner_label: bool = True) ->
1242
2381
 
1243
2382
 
1244
2383
  def get_a_2(
1245
- tree: ete.Tree,
2384
+ tree: Tree,
1246
2385
  leaf_order: list[str] = None,
1247
2386
  node_order: list[str] = None,
1248
2387
  ) -> tuple[np.ndarray, int]:
@@ -1263,6 +2402,13 @@ def get_a_2(
1263
2402
  T
1264
2403
  number of nodes in the tree, excluding the root node
1265
2404
  """
2405
+ try:
2406
+ import ete3 as ete
2407
+ except ImportError:
2408
+ raise ImportError(
2409
+ "To use tasccoda please install additional dependencies as `pip install pertpy[coda]`"
2410
+ ) from None
2411
+
1266
2412
  n_tips = len(tree.get_leaves())
1267
2413
  n_nodes = len(tree.get_descendants())
1268
2414
 
@@ -1292,7 +2438,7 @@ def get_a_2(
1292
2438
  return A_, n_nodes
1293
2439
 
1294
2440
 
1295
- def collapse_singularities_2(tree: ete.Tree) -> ete.Tree:
2441
+ def collapse_singularities_2(tree: Tree) -> Tree:
1296
2442
  """Collapses (deletes) nodes in a ete3 tree that are singularities (have only one child).
1297
2443
 
1298
2444
  Args:
@@ -1327,10 +2473,10 @@ def linkage_to_newick(
1327
2473
 
1328
2474
  def build_newick(node, newick, parentdist, leaf_names):
1329
2475
  if node.is_leaf():
1330
- return f"{leaf_names[node.id]}:{(parentdist - node.dist)/2}{newick}"
2476
+ return f"{leaf_names[node.id]}:{(parentdist - node.dist) / 2}{newick}"
1331
2477
  else:
1332
2478
  if len(newick) > 0:
1333
- newick = f"):{(parentdist - node.dist)/2}{newick}"
2479
+ newick = f"):{(parentdist - node.dist) / 2}{newick}"
1334
2480
  else:
1335
2481
  newick = ");"
1336
2482
  newick = build_newick(node.get_left(), newick, node.dist, leaf_names)
@@ -1363,14 +2509,15 @@ def import_tree(
1363
2509
 
1364
2510
  Args:
1365
2511
  data: A tascCODA-compatible data object.
1366
- modality_1: If `data` is MuData, specifiy the modality name to the original cell level anndata object. Defaults to None.
1367
- modality_2: If `data` is MuData, specifiy the modality name to the aggregated level anndata object. Defaults to None.
2512
+ modality_1: If `data` is MuData, specify the modality name to the original cell level anndata object. Defaults to None.
2513
+ modality_2: If `data` is MuData, specify the modality name to the aggregated level anndata object. Defaults to None.
1368
2514
  dendrogram_key: Key to the scanpy.tl.dendrogram result in `.uns` of original cell level anndata object. Defaults to None.
1369
2515
  levels_orig: List that indicates which columns in `.obs` of the original data correspond to tree levels. The list must begin with the root level, and end with the leaf level. Defaults to None.
1370
2516
  levels_agg: List that indicates which columns in `.var` of the aggregated data correspond to tree levels. The list must begin with the root level, and end with the leaf level. Defaults to None.
1371
- add_level_name: If True, internal nodes in the tree will be named as "{level_name}_{node_name}" instead of just {level_name}. Defaults to True.
1372
- key_added: If not specified, the tree is stored in .uns[‘tree’]. If `data` is AnnData, save tree in `data`. If `data` is MuData, save tree in data[modality_2]. Defaults to "tree".
1373
- copy: Return a copy instead of writing to `data`. Defaults to False.
2517
+ add_level_name: If True, internal nodes in the tree will be named as "{level_name}_{node_name}" instead of just {level_name}.
2518
+ Defaults to True.
2519
+ key_added: If not specified, the tree is stored in .uns[‘tree’]. If `data` is AnnData, save tree in `data`.
2520
+ If `data` is MuData, save tree in data[modality_2]. Defaults to "tree".
1374
2521
 
1375
2522
  Returns:
1376
2523
  Updates data with the following:
@@ -1379,6 +2526,13 @@ def import_tree(
1379
2526
 
1380
2527
  tree: A ete3 tree object.
1381
2528
  """
2529
+ try:
2530
+ import ete3 as ete
2531
+ except ImportError:
2532
+ raise ImportError(
2533
+ "To use tasccoda please install additional dependencies as `pip install pertpy[coda]`"
2534
+ ) from None
2535
+
1382
2536
  if isinstance(data, MuData):
1383
2537
  try:
1384
2538
  data_1 = data[modality_1]
@@ -1443,16 +2597,17 @@ def from_scanpy(
1443
2597
 
1444
2598
  The anndata object needs to have a column in adata.obs that contains the cell type assignment.
1445
2599
  Further, it must contain one column or a set of columns (e.g. subject id, treatment, disease status) that uniquely identify each (statistical) sample.
1446
- Further covariates (e.g. subject age) can either be specified via addidional column names in adata.obs, a key in adata.uns, or as a separate DataFrame.
2600
+ Further covariates (e.g. subject age) can either be specified via additional column names in adata.obs, a key in adata.uns, or as a separate DataFrame.
1447
2601
 
1448
- NOTE: The order of samples in the returned dataset is determined by the first occurence of cells from each sample in `adata`
2602
+ NOTE: The order of samples in the returned dataset is determined by the first occurrence of cells from each sample in `adata`
1449
2603
 
1450
2604
  Args:
1451
2605
  adata: An anndata object from scanpy
1452
2606
  cell_type_identifier: column name in adata.obs that specifies the cell types
1453
2607
  sample_identifier: column name or list of column names in adata.obs that uniquely identify each sample
1454
2608
  covariate_uns: key for adata.uns, where covariate values are stored
1455
- covariate_obs: list of column names in adata.obs, where covariate values are stored. Note: If covariate values are not unique for a value of sample_identifier, this covaariate will be skipped.
2609
+ covariate_obs: list of column names in adata.obs, where covariate values are stored.
2610
+ Note: If covariate values are not unique for a value of sample_identifier, this covariate will be skipped.
1456
2611
  covariate_df: DataFrame with covariates
1457
2612
 
1458
2613
  Returns:
@@ -1461,50 +2616,40 @@ def from_scanpy(
1461
2616
  if isinstance(sample_identifier, str):
1462
2617
  sample_identifier = [sample_identifier]
1463
2618
 
1464
- if covariate_obs:
1465
- covariate_obs += [i for i in sample_identifier if i not in covariate_obs]
1466
- else:
1467
- covariate_obs = sample_identifier # type: ignore
1468
-
1469
- # join sample identifiers
1470
- if isinstance(sample_identifier, list):
2619
+ if len(sample_identifier) > 1:
1471
2620
  adata.obs["scCODA_sample_id"] = adata.obs[sample_identifier].agg("-".join, axis=1)
1472
2621
  sample_identifier = "scCODA_sample_id"
2622
+ else:
2623
+ sample_identifier = sample_identifier[0]
1473
2624
 
1474
2625
  # get cell type counts
1475
- groups = adata.obs.value_counts([sample_identifier, cell_type_identifier])
1476
- count_data = groups.unstack(level=cell_type_identifier)
1477
- count_data = count_data.fillna(0)
2626
+ ct_count_data = pd.crosstab(adata.obs[sample_identifier], adata.obs[cell_type_identifier])
2627
+ ct_count_data = ct_count_data.fillna(0)
1478
2628
 
1479
2629
  # get covariates from different sources
1480
- covariate_df_ = pd.DataFrame(index=count_data.index)
1481
-
1482
- if covariate_df is None and covariate_obs is None and covariate_uns is None:
1483
- print("No covariate information specified!")
2630
+ covariate_df_ = pd.DataFrame(index=ct_count_data.index)
1484
2631
 
1485
2632
  if covariate_uns is not None:
1486
- covariate_df_uns = pd.DataFrame(adata.uns[covariate_uns])
1487
- covariate_df_ = pd.concat((covariate_df_, covariate_df_uns), axis=1)
2633
+ covariate_df_uns = pd.DataFrame(adata.uns[covariate_uns], index=ct_count_data.index)
2634
+ covariate_df_ = covariate_df_.join(covariate_df_uns, how="left")
1488
2635
 
1489
- if covariate_obs is not None:
1490
- for c in covariate_obs:
1491
- if any(adata.obs.groupby(sample_identifier).nunique()[c] != 1):
1492
- print(f"Covariate {c} has non-unique values! Skipping...")
1493
- covariate_obs.remove(c)
2636
+ if covariate_obs:
2637
+ is_unique = adata.obs.groupby(sample_identifier, observed=True).transform(lambda x: x.nunique() == 1)
2638
+ unique_covariates = is_unique.columns[is_unique.all()].tolist()
1494
2639
 
1495
- covariate_df_obs = adata.obs.groupby(sample_identifier).first()[covariate_obs]
1496
- covariate_df_ = pd.concat((covariate_df_, covariate_df_obs), axis=1)
2640
+ if len(unique_covariates) < len(covariate_obs):
2641
+ skipped = set(covariate_obs) - set(unique_covariates)
2642
+ print(f"[bold yellow]Covariates {skipped} have non-unique values! Skipping...")
2643
+ if unique_covariates:
2644
+ covariate_df_obs = adata.obs.groupby(sample_identifier, observed=True).first()[unique_covariates]
2645
+ covariate_df_ = covariate_df_.join(covariate_df_obs, how="left")
1497
2646
 
1498
2647
  if covariate_df is not None:
1499
- if set(covariate_df.index) != set(count_data.index):
1500
- raise ValueError("anndata sample names and covariate_df index do not have the same elements!")
1501
- covs_ord = covariate_df.reindex(count_data.index)
1502
- covariate_df_ = pd.concat((covariate_df_, covs_ord), axis=1)
1503
-
1504
- covariate_df_.index = covariate_df_.index.astype(str)
2648
+ if not covariate_df.index.equals(ct_count_data.index):
2649
+ raise ValueError("AnnData sample names and covariate_df index do not have the same elements!")
2650
+ covariate_df_ = covariate_df_.join(covariate_df, how="left")
1505
2651
 
1506
- # create var (number of cells for each type as only column)
1507
- var_dat = count_data.sum(axis=0).rename("n_cells").to_frame()
2652
+ var_dat = ct_count_data.sum(axis=0).rename("n_cells").to_frame()
1508
2653
  var_dat.index = var_dat.index.astype(str)
1509
2654
 
1510
- return AnnData(X=count_data.values, var=var_dat, obs=covariate_df_)
2655
+ return AnnData(X=ct_count_data.values, var=var_dat, obs=covariate_df_)