pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. pertpy/__init__.py +3 -2
  2. pertpy/data/__init__.py +5 -1
  3. pertpy/data/_dataloader.py +2 -4
  4. pertpy/data/_datasets.py +203 -92
  5. pertpy/metadata/__init__.py +4 -0
  6. pertpy/metadata/_cell_line.py +826 -0
  7. pertpy/metadata/_compound.py +129 -0
  8. pertpy/metadata/_drug.py +242 -0
  9. pertpy/metadata/_look_up.py +582 -0
  10. pertpy/metadata/_metadata.py +73 -0
  11. pertpy/metadata/_moa.py +129 -0
  12. pertpy/plot/__init__.py +1 -9
  13. pertpy/plot/_augur.py +53 -116
  14. pertpy/plot/_coda.py +277 -677
  15. pertpy/plot/_guide_rna.py +17 -35
  16. pertpy/plot/_milopy.py +59 -134
  17. pertpy/plot/_mixscape.py +152 -391
  18. pertpy/preprocessing/_guide_rna.py +88 -4
  19. pertpy/tools/__init__.py +8 -13
  20. pertpy/tools/_augur.py +315 -17
  21. pertpy/tools/_cinemaot.py +143 -4
  22. pertpy/tools/_coda/_base_coda.py +1210 -65
  23. pertpy/tools/_coda/_sccoda.py +50 -21
  24. pertpy/tools/_coda/_tasccoda.py +27 -19
  25. pertpy/tools/_dialogue.py +164 -56
  26. pertpy/tools/_differential_gene_expression.py +240 -14
  27. pertpy/tools/_distances/_distance_tests.py +8 -8
  28. pertpy/tools/_distances/_distances.py +184 -34
  29. pertpy/tools/_enrichment.py +465 -0
  30. pertpy/tools/_milo.py +345 -11
  31. pertpy/tools/_mixscape.py +668 -50
  32. pertpy/tools/_perturbation_space/_clustering.py +5 -1
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
  35. pertpy/tools/_perturbation_space/_simple.py +51 -10
  36. pertpy/tools/_scgen/__init__.py +1 -1
  37. pertpy/tools/_scgen/_scgen.py +701 -0
  38. pertpy/tools/_scgen/_utils.py +1 -3
  39. pertpy/tools/decoupler_LICENSE +674 -0
  40. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
  41. pertpy-0.7.0.dist-info/RECORD +53 -0
  42. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
  43. pertpy/plot/_cinemaot.py +0 -81
  44. pertpy/plot/_dialogue.py +0 -91
  45. pertpy/plot/_scgen.py +0 -337
  46. pertpy/tools/_metadata/__init__.py +0 -0
  47. pertpy/tools/_metadata/_cell_line.py +0 -613
  48. pertpy/tools/_metadata/_look_up.py +0 -342
  49. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  50. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  51. pertpy-0.6.0.dist-info/RECORD +0 -50
  52. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  53. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,17 +1,23 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from abc import ABC, abstractmethod
4
- from typing import TYPE_CHECKING
4
+ from pathlib import Path
5
+ from typing import TYPE_CHECKING, Literal, Optional, Union
5
6
 
6
7
  import arviz as az
7
- import ete3 as ete
8
8
  import jax.numpy as jnp
9
+ import matplotlib.pyplot as plt
9
10
  import numpy as np
10
11
  import pandas as pd
11
12
  import patsy as pt
13
+ import scanpy as sc
14
+ import seaborn as sns
15
+ from adjustText import adjust_text
12
16
  from anndata import AnnData
13
- from jax import random
14
- from jax.config import config
17
+ from jax import config, random
18
+ from matplotlib import cm, rcParams
19
+ from matplotlib import image as mpimg
20
+ from matplotlib.colors import ListedColormap
15
21
  from mudata import MuData
16
22
  from numpyro.infer import HMC, MCMC, NUTS, initialization
17
23
  from rich import box, print
@@ -20,10 +26,15 @@ from rich.table import Table
20
26
  from scipy.cluster import hierarchy as sp_hierarchy
21
27
 
22
28
  if TYPE_CHECKING:
29
+ from collections.abc import Sequence
30
+
23
31
  import numpyro as npy
24
32
  import toytree as tt
25
- from jax._src.prng import PRNGKeyArray
33
+ from ete3 import Tree
26
34
  from jax._src.typing import Array
35
+ from matplotlib.axes import Axes
36
+ from matplotlib.colors import Colormap
37
+ from matplotlib.figure import Figure
27
38
 
28
39
  config.update("jax_enable_x64", True)
29
40
 
@@ -179,7 +190,7 @@ class CompositionalModel2(ABC):
179
190
  self,
180
191
  sample_adata: AnnData,
181
192
  kernel: npy.infer.mcmc.MCMCKernel,
182
- rng_key: Array | PRNGKeyArray,
193
+ rng_key: Array,
183
194
  copy: bool = False,
184
195
  *args,
185
196
  **kwargs,
@@ -295,7 +306,7 @@ class CompositionalModel2(ABC):
295
306
  if copy:
296
307
  sample_adata = sample_adata.copy()
297
308
 
298
- rng_key_array = random.PRNGKey(rng_key)
309
+ rng_key_array = random.key(rng_key)
299
310
  sample_adata.uns["scCODA_params"]["mcmc"]["rng_key"] = np.array(rng_key_array)
300
311
 
301
312
  # Set up NUTS kernel
@@ -335,7 +346,6 @@ class CompositionalModel2(ABC):
335
346
  copy: Return a copy instead of writing to adata. Defaults to False.
336
347
 
337
348
  Examples:
338
- Example with scCODA:
339
349
  >>> import pertpy as pt
340
350
  >>> haber_cells = pt.dt.haber_2017_regions()
341
351
  >>> sccoda = pt.tl.Sccoda()
@@ -358,10 +368,10 @@ class CompositionalModel2(ABC):
358
368
  # Set rng key if needed
359
369
  if rng_key is None:
360
370
  rng = np.random.default_rng()
361
- rng_key = random.PRNGKey(rng.integers(0, 10000))
371
+ rng_key = random.key(rng.integers(0, 10000))
362
372
  sample_adata.uns["scCODA_params"]["mcmc"]["rng_key"] = rng_key
363
373
  else:
364
- rng_key = random.PRNGKey(rng_key)
374
+ rng_key = random.key(rng_key)
365
375
 
366
376
  # Set up HMC kernel
367
377
  sample_adata = self.set_init_mcmc_states(
@@ -423,7 +433,6 @@ class CompositionalModel2(ABC):
423
433
  - Is credible: Boolean indicator whether effect is credible
424
434
 
425
435
  Examples:
426
- Example with scCODA:
427
436
  >>> import pertpy as pt
428
437
  >>> haber_cells = pt.dt.haber_2017_regions()
429
438
  >>> sccoda = pt.tl.Sccoda()
@@ -433,7 +442,6 @@ class CompositionalModel2(ABC):
433
442
  >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
434
443
  >>> intercept_df, effect_df = sccoda.summary_prepare(mdata["coda"])
435
444
  """
436
- # Get model and effect selection types
437
445
  select_type = sample_adata.uns["scCODA_params"]["select_type"]
438
446
  model_type = sample_adata.uns["scCODA_params"]["model_type"]
439
447
 
@@ -548,7 +556,11 @@ class CompositionalModel2(ABC):
548
556
  intercept_df = intercept_df.loc[:, ["final_parameter", hdis[0], hdis[1], "sd", "expected_sample"]].copy()
549
557
  intercept_df = intercept_df.rename(
550
558
  columns=dict(
551
- zip(intercept_df.columns, ["Final Parameter", hdis_new[0], hdis_new[1], "SD", "Expected Sample"])
559
+ zip(
560
+ intercept_df.columns,
561
+ ["Final Parameter", hdis_new[0], hdis_new[1], "SD", "Expected Sample"],
562
+ strict=False,
563
+ )
552
564
  )
553
565
  )
554
566
 
@@ -561,6 +573,7 @@ class CompositionalModel2(ABC):
561
573
  zip(
562
574
  effect_df.columns,
563
575
  ["Effect", "Median", hdis_new[0], hdis_new[1], "SD", "Expected Sample", "log2-fold change"],
576
+ strict=False,
564
577
  )
565
578
  )
566
579
  )
@@ -581,6 +594,7 @@ class CompositionalModel2(ABC):
581
594
  "Expected Sample",
582
595
  "log2-fold change",
583
596
  ],
597
+ strict=False,
584
598
  )
585
599
  )
586
600
  )
@@ -594,6 +608,7 @@ class CompositionalModel2(ABC):
594
608
  zip(
595
609
  node_df.columns,
596
610
  ["Final Parameter", "Median", hdis_new[0], hdis_new[1], "SD", "Delta", "Is credible"],
611
+ strict=False,
597
612
  )
598
613
  ) # type: ignore
599
614
  ) # type: ignore
@@ -781,7 +796,6 @@ class CompositionalModel2(ABC):
781
796
  kwargs: Passed to az.summary
782
797
 
783
798
  Examples:
784
- Example with scCODA:
785
799
  >>> import pertpy as pt
786
800
  >>> haber_cells = pt.dt.haber_2017_regions()
787
801
  >>> sccoda = pt.tl.Sccoda()
@@ -799,7 +813,7 @@ class CompositionalModel2(ABC):
799
813
  raise
800
814
  if isinstance(data, AnnData):
801
815
  sample_adata = data
802
- # Get model and effect selection types
816
+
803
817
  select_type = sample_adata.uns["scCODA_params"]["select_type"]
804
818
  model_type = sample_adata.uns["scCODA_params"]["model_type"]
805
819
 
@@ -926,7 +940,6 @@ class CompositionalModel2(ABC):
926
940
  pd.DataFrame: Intercept data frame.
927
941
 
928
942
  Examples:
929
- Example with scCODA:
930
943
  >>> import pertpy as pt
931
944
  >>> haber_cells = pt.dt.haber_2017_regions()
932
945
  >>> sccoda = pt.tl.Sccoda()
@@ -936,7 +949,6 @@ class CompositionalModel2(ABC):
936
949
  >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
937
950
  >>> intercepts = sccoda.get_intercept_df(mdata)
938
951
  """
939
-
940
952
  if isinstance(data, MuData):
941
953
  try:
942
954
  sample_adata = data[modality_key]
@@ -959,7 +971,6 @@ class CompositionalModel2(ABC):
959
971
  pd.DataFrame: Effect data frame.
960
972
 
961
973
  Examples:
962
- Example with scCODA:
963
974
  >>> import pertpy as pt
964
975
  >>> haber_cells = pt.dt.haber_2017_regions()
965
976
  >>> sccoda = pt.tl.Sccoda()
@@ -969,7 +980,6 @@ class CompositionalModel2(ABC):
969
980
  >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
970
981
  >>> effects = sccoda.get_effect_df(mdata)
971
982
  """
972
-
973
983
  if isinstance(data, MuData):
974
984
  try:
975
985
  sample_adata = data[modality_key]
@@ -1003,9 +1013,8 @@ class CompositionalModel2(ABC):
1003
1013
  pd.DataFrame: Node effect data frame.
1004
1014
 
1005
1015
  Examples:
1006
- Example with tascCODA (works only for model of type tree_agg, i.e. a tascCODA model):
1007
1016
  >>> import pertpy as pt
1008
- >>> adata = pt.dt.smillie()
1017
+ >>> adata = pt.dt.tasccoda_example()
1009
1018
  >>> tasccoda = pt.tl.Tasccoda()
1010
1019
  >>> mdata = tasccoda.load(
1011
1020
  >>> adata, type="sample_level",
@@ -1113,6 +1122,1136 @@ class CompositionalModel2(ABC):
1113
1122
 
1114
1123
  return out
1115
1124
 
1125
def _stackbar(  # pragma: no cover
    self,
    y: np.ndarray,
    type_names: list[str],
    title: str,
    level_names: list[str],
    figsize: tuple[float, float] | None = None,
    dpi: int | None = 100,
    palette: ListedColormap | None = cm.tab20,
    show_legend: bool | None = True,
) -> plt.Axes:
    """Plots a stacked barplot for one (discrete) covariate.

    Every row of ``y`` becomes one bar. Its segments are the per-column shares of that row's total,
    expressed in percent, so each non-empty bar stacks up to 100.

    Typical use (only inside plot_stacked_barplot): ``self._stackbar(data.X, data.var.index, "xyz", data.obs.index)``

    Args:
        y: The count data, collapsed onto the level of interest, with one row per bar and one column
            per cell type. For example, a binary covariate has two rows, one for each group.
        type_names: The names of all cell types, one per column of ``y``.
        title: Plot title, usually the covariate's name.
        level_names: Names of the covariate's levels, one per row of ``y``; used as x tick labels.
        figsize: Figure size. Defaults to None, which falls back to ``rcParams["figure.figsize"]``.
        dpi: Dpi setting. Defaults to 100.
        palette: The color map for the barplot. Colors wrap around after ``palette.N`` cell types.
            Defaults to cm.tab20; None also falls back to cm.tab20.
        show_legend: If True, adds a legend. Defaults to True.

    Returns:
        A :class:`~matplotlib.axes.Axes` object
    """
    if palette is None:
        palette = cm.tab20

    counts = np.asarray(y, dtype=float)
    n_bars, n_types = counts.shape

    figsize = rcParams["figure.figsize"] if figsize is None else figsize
    _, ax = plt.subplots(figsize=figsize, dpi=dpi)

    positions = np.arange(n_bars)
    row_totals = counts.sum(axis=1, keepdims=True)
    # Percent of each row's total. Rows with a zero total give empty bars instead of NaN heights.
    percentages = np.divide(counts, row_totals, out=np.zeros_like(counts), where=row_totals != 0) * 100

    barwidth = 0.85
    cum_bars = np.zeros(n_bars)
    for n in range(n_types):
        bars = percentages[:, n]
        # Draw on the Axes created above rather than on pyplot's implicit "current" Axes.
        ax.bar(
            positions,
            bars,
            bottom=cum_bars,
            color=palette(n % palette.N),
            width=barwidth,
            label=type_names[n],
            linewidth=0,
        )
        cum_bars = cum_bars + bars

    ax.set_title(title)
    if show_legend:
        ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1)
    ax.set_xticks(positions)
    ax.set_xticklabels(level_names, rotation=45, ha="right")
    ax.set_ylabel("Proportion")

    return ax
1186
+
1187
def plot_stacked_barplot(  # pragma: no cover
    self,
    data: AnnData | MuData,
    feature_name: str,
    modality_key: str = "coda",
    palette: ListedColormap | None = cm.tab20,
    show_legend: bool | None = True,
    level_order: list[str] = None,
    figsize: tuple[float, float] | None = None,
    dpi: int | None = 100,
    return_fig: bool | None = None,
    ax: plt.Axes | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
    **kwargs,
) -> plt.Axes | plt.Figure | None:
    """Plots a stacked barplot for all levels of a covariate or all samples (if feature_name=="samples").

    Args:
        data: AnnData object or MuData object.
        feature_name: The name of the covariate to plot. If feature_name=="samples", one bar for every sample will be plotted.
        modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
        palette: The matplotlib color map for the barplot. Defaults to cm.tab20.
        show_legend: If True, adds a legend. Defaults to True.
        level_order: Custom ordering of bars on the x-axis. It must contain exactly the same set of levels
            (or sample names) as the data. Defaults to None.
        figsize: Figure size. Defaults to None.
        dpi: Dpi setting. Defaults to 100.
        return_fig: If True, return the current matplotlib figure. Defaults to None.
        ax: Currently unused. A new figure is always created. Defaults to None.
        show: If True, call ``plt.show()``. Defaults to None.
        save: If set, passed as the file name to ``plt.savefig``. Defaults to None.
        kwargs: Currently unused.

    Returns:
        The figure if ``return_fig`` is set. Otherwise the :class:`~matplotlib.axes.Axes` object
        if neither ``show`` nor ``save`` is set, and None in all remaining cases.

    Raises:
        ValueError: If ``level_order`` does not contain exactly the levels present in the data.

    Examples:
        >>> import pertpy as pt
        >>> haber_cells = pt.dt.haber_2017_regions()
        >>> sccoda = pt.tl.Sccoda()
        >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
                sample_identifier="batch", covariate_obs=["condition"])
        >>> sccoda.plot_stacked_barplot(mdata, feature_name="samples")

    Preview:
        .. image:: /_static/docstring_previews/sccoda_stacked_barplot.png
    """
    if isinstance(data, MuData):
        data = data[modality_key]

    ct_names = data.var.index

    if feature_name == "samples":
        # Option to plot one stacked bar per sample.
        if level_order:
            # Explicit check instead of assert: asserts are stripped under `python -O`.
            if set(level_order) != set(data.obs.index):
                raise ValueError("level order is inconsistent with levels")
            data = data[level_order]
        ax = self._stackbar(
            data.X,
            type_names=ct_names,
            title="samples",
            level_names=data.obs.index,
            figsize=figsize,
            dpi=dpi,
            palette=palette,
            show_legend=show_legend,
        )
    else:
        # Order levels: explicit order > categorical order > order of appearance.
        if level_order:
            if set(level_order) != set(data.obs[feature_name]):
                raise ValueError("level order is inconsistent with levels")
            levels = level_order
        elif hasattr(data.obs[feature_name], "cat"):
            levels = data.obs[feature_name].cat.categories.to_list()
        else:
            levels = pd.unique(data.obs[feature_name])

        # Sum the counts of all samples that belong to each level (one row per level).
        feature_totals = np.zeros([len(levels), data.X.shape[1]])
        for idx, level in enumerate(levels):
            level_indices = np.where(data.obs[feature_name] == level)
            feature_totals[idx] = np.sum(data.X[level_indices], axis=0)

        ax = self._stackbar(
            feature_totals,
            type_names=ct_names,
            title=feature_name,
            level_names=levels,
            figsize=figsize,
            dpi=dpi,
            palette=palette,
            show_legend=show_legend,
        )

    if save:
        plt.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return plt.gcf()
    if not (show or save):
        return ax
    return None
1287
+
1288
def plot_effects_barplot(  # pragma: no cover
    self,
    data: AnnData | MuData,
    modality_key: str = "coda",
    covariates: str | list | None = None,
    parameter: Literal["log2-fold change", "Final Parameter", "Expected Sample"] = "log2-fold change",
    plot_facets: bool = True,
    plot_zero_covariate: bool = True,
    plot_zero_cell_type: bool = False,
    palette: str | ListedColormap | None = cm.tab20,
    level_order: list[str] = None,
    args_barplot: dict | None = None,
    figsize: tuple[float, float] | None = None,
    dpi: int | None = 100,
    return_fig: bool | None = None,
    ax: plt.Axes | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
) -> plt.Axes | plt.Figure | sns.axisgrid.FacetGrid | None:
    """Barplot visualization for effects.

    The effect results for each covariate are shown as a group of barplots, with intra-group separation by cell types.
    The covariate groups can either be ordered along the x-axis of a single plot (plot_facets=False) or as plot facets (plot_facets=True).

    Args:
        data: AnnData object or MuData object.
        modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
        covariates: Covariate name(s) to plot. A model covariate is kept if any of the given strings is a
            substring of its name. Defaults to None (all covariates).
        parameter: The parameter in effect summary to plot. Defaults to "log2-fold change".
        plot_facets: If False, plot cell types on the x-axis. If True, plot as facets.
            Defaults to True.
        plot_zero_covariate: If True, plot covariates that have all zero effects. If False, do not plot.
            Defaults to True.
        plot_zero_cell_type: If True, plot cell types that have zero effect. If False, do not plot.
            Defaults to False.
        palette: The seaborn color map for the barplot. Defaults to cm.tab20.
        level_order: Custom ordering of bars on the x-axis (facet plots only). Defaults to None.
        args_barplot: Arguments passed to sns.barplot. Defaults to None.
        figsize: Figure size. Defaults to None.
        dpi: Dpi setting. Defaults to 100.
        return_fig: If True, return the figure. Defaults to None.
        ax: Axes to draw on when ``plot_facets`` is False. If None, a new figure is created.
            Ignored for facet plots. Defaults to None.
        show: If True, call ``plt.show()``. Defaults to None.
        save: If set, the file name the figure is saved to. Defaults to None.

    Returns:
        Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`)
        or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object.
        Returns the figure instead if ``return_fig`` is set, and None if ``show`` or ``save`` is set.

    Raises:
        ValueError: If no covariate is left to plot after applying ``covariates`` and ``plot_zero_covariate``.

    Examples:
        >>> import pertpy as pt
        >>> haber_cells = pt.dt.haber_2017_regions()
        >>> sccoda = pt.tl.Sccoda()
        >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
                sample_identifier="batch", covariate_obs=["condition"])
        >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
        >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
        >>> sccoda.plot_effects_barplot(mdata)

    Preview:
        .. image:: /_static/docstring_previews/sccoda_effects_barplot.png
    """
    if args_barplot is None:
        args_barplot = {}
    if isinstance(data, MuData):
        data = data[modality_key]

    # Get covariate names from adata and partition them into those with a nonzero effect
    # for at least one cell type and those with no effect at all.
    covariate_names = data.uns["scCODA_params"]["covariate_names"]
    if covariates is not None:
        if isinstance(covariates, str):
            covariates = [covariates]
        covariate_names = [
            covariate_name
            for covariate_name in covariate_names
            if any(covariate in covariate_name for covariate in covariates)
        ]
    covariate_names_non_zero = [
        covariate_name
        for covariate_name in covariate_names
        if data.varm[f"effect_df_{covariate_name}"][parameter].any()
    ]
    # A list rather than a set difference keeps the order deterministic.
    covariate_names_zero = [name for name in covariate_names if name not in covariate_names_non_zero]
    if not plot_zero_covariate:
        covariate_names = covariate_names_non_zero
    if len(covariate_names) == 0:
        raise ValueError("No covariates left to plot; check `covariates` and `plot_zero_covariate`.")

    # Long format: one row per (cell type, covariate) pair, with the chosen parameter in "value".
    plot_df = pd.concat(
        [data.varm[f"effect_df_{covariate_name}"][parameter] for covariate_name in covariate_names],
        axis=1,
    )
    plot_df.columns = covariate_names
    plot_df = pd.melt(plot_df, ignore_index=False, var_name="Covariate")
    plot_df = plot_df.reset_index()

    # In facet mode, all-zero covariates get a single "zero" placeholder row, so that their facet
    # is still drawn (empty, without ticks) instead of disappearing.
    has_zero_placeholder = False
    if covariate_names_zero and plot_facets and plot_zero_covariate and not plot_zero_cell_type:
        plot_df = plot_df[plot_df["value"] != 0]
        for covariate_name_zero in covariate_names_zero:
            new_row = {
                "Covariate": covariate_name_zero,
                "Cell Type": "zero",
                "value": 0,
            }
            plot_df = pd.concat([plot_df, pd.DataFrame([new_row])], ignore_index=True)
        plot_df["covariate_"] = pd.Categorical(plot_df["Covariate"], covariate_names)
        plot_df = plot_df.sort_values(["covariate_"])
        has_zero_placeholder = True

    if not plot_zero_cell_type:
        # Drop cell types whose effect is zero everywhere. The placeholder is exempt; otherwise this
        # filter would delete it again and the zero-covariate facets would vanish.
        cell_type_names_zero = [
            name
            for name in plot_df["Cell Type"].unique()
            if not (has_zero_placeholder and name == "zero")
            and (plot_df[plot_df["Cell Type"] == name]["value"] == 0).all()
        ]
        plot_df = plot_df[~plot_df["Cell Type"].isin(cell_type_names_zero)]

    # If plotting as facets, create a FacetGrid and map barplot onto it.
    if plot_facets:
        if isinstance(palette, ListedColormap):
            palette = np.array([palette(i % palette.N) for i in range(len(plot_df["Cell Type"].unique()))]).tolist()
        if figsize is not None:
            # NOTE(review): figsize[0] is used as the per-facet height and figsize[1] / figsize[0] as
            # the aspect. This is reversed relative to matplotlib's (width, height) convention and is
            # kept unchanged for backward compatibility.
            height = figsize[0]
            aspect = np.round(figsize[1] / figsize[0], 2)
        else:
            height = 3
            aspect = 2

        g = sns.FacetGrid(
            plot_df,
            col="Covariate",
            sharey=True,
            sharex=False,
            height=height,
            aspect=aspect,
        )

        g.map(
            sns.barplot,
            "Cell Type",
            "value",
            palette=palette,
            order=level_order,
            **args_barplot,
        )
        g.set_xticklabels(rotation=90)
        g.set(ylabel=parameter)
        # Title each facet with the covariate it actually shows. Indexing covariate_names
        # positionally misaligns titles whenever a covariate has no facet.
        g.set_titles("{col_name}")
        for facet_ax in g.axes.flatten():
            tick_labels = facet_ax.get_xticklabels()
            if 0 < len(tick_labels) < 5:
                facet_ax.set_aspect(10 / len(tick_labels))
            # Hide the tick of the "zero" placeholder bar.
            if len(tick_labels) == 1 and tick_labels[0].get_text() == "zero":
                facet_ax.set_xticks([])

        if save:
            plt.savefig(save, bbox_inches="tight")
        if show:
            plt.show()
        if return_fig:
            return plt.gcf()
        if not (show or save):
            return g
        return None

    # If not plotting as facets, call barplot with cell types on the x-axis.
    if ax is None:
        _, ax = plt.subplots(figsize=figsize, dpi=dpi)
    if len(covariate_names) == 1:
        if isinstance(palette, ListedColormap):
            palette = np.array([palette(i % palette.N) for i in range(len(plot_df["Cell Type"].unique()))]).tolist()
        # Color bars by cell type. The previous hue="x" referenced a column that does not exist.
        sns.barplot(
            data=plot_df,
            x="Cell Type",
            y="value",
            hue="Cell Type",
            dodge=False,
            palette=palette,
            ax=ax,
            **args_barplot,
        )
        # The hue legend would only repeat the x tick labels.
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()
        ax.set_title(covariate_names[0])
    else:
        if isinstance(palette, ListedColormap):
            palette = np.array([palette(i % palette.N) for i in range(len(covariate_names))]).tolist()
        sns.barplot(
            data=plot_df,
            x="Cell Type",
            y="value",
            hue="Covariate",
            palette=palette,
            ax=ax,
            **args_barplot,
        )
    plotted_cell_types = pd.unique(plot_df["Cell Type"])
    ax.set_xticklabels(plotted_cell_types, rotation=90)

    if save:
        ax.figure.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return ax.figure
    if not (show or save):
        return ax
    return None
1492
+
1493
def plot_boxplots(  # pragma: no cover
    self,
    data: AnnData | MuData,
    feature_name: str,
    modality_key: str = "coda",
    y_scale: Literal["relative", "log", "log10", "count"] = "relative",
    plot_facets: bool = False,
    add_dots: bool = False,
    cell_types: list | None = None,
    args_boxplot: dict | None = None,
    args_swarmplot: dict | None = None,
    palette: str | None = "Blues",
    show_legend: bool | None = True,
    level_order: list[str] = None,
    figsize: tuple[float, float] | None = None,
    dpi: int | None = 100,
    return_fig: bool | None = None,
    ax: plt.Axes | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
) -> plt.Axes | plt.Figure | sns.axisgrid.FacetGrid | None:
    """Grouped boxplot visualization.

    The cell counts for each cell type are shown as a group of boxplots
    with intra-group separation by a covariate from data.obs.

    Args:
        data: AnnData object or MuData object.
        feature_name: The name of the feature in data.obs to plot.
        modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
        y_scale: Transformation of the cell counts. Options: "relative" - relative abundance, "log" - log(count),
            "log10" - log10(count), "count" - absolute abundance (cell counts). Both log scales replace zero
            counts with a pseudocount of 0.5. Defaults to "relative".
        plot_facets: If False, plot cell types on the x-axis. If True, plot as facets. Defaults to False.
        add_dots: If True, overlay a swarmplot with one dot for each data point. Defaults to False.
        cell_types: Subset of cell types that should be plotted. Defaults to None.
        args_boxplot: Arguments passed to sns.boxplot. The caller's dict is not modified. Defaults to {}.
        args_swarmplot: Arguments passed to sns.swarmplot. The caller's dict is not modified. Defaults to {}.
        palette: The seaborn color map for the boxplot. Defaults to "Blues".
        show_legend: If True, adds a legend (non-facet plots only). Defaults to True.
        level_order: Custom ordering of the feature levels. Defaults to None.
        figsize: Figure size. Defaults to None.
        dpi: Dpi setting. Defaults to 100.
        return_fig: If True, return the figure. Defaults to None.
        ax: Axes to draw on when ``plot_facets`` is False. If None, a new figure is created.
            Ignored for facet plots. Defaults to None.
        show: If True, call ``plt.show()``. Defaults to None.
        save: If set, the file name the figure is saved to. Defaults to None.

    Returns:
        Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`)
        or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object.
        Returns the figure instead if ``return_fig`` is set, and None if ``show`` or ``save`` is set.

    Raises:
        ValueError: If ``y_scale`` is not one of the supported transformations.

    Examples:
        >>> import pertpy as pt
        >>> haber_cells = pt.dt.haber_2017_regions()
        >>> sccoda = pt.tl.Sccoda()
        >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
                sample_identifier="batch", covariate_obs=["condition"])
        >>> sccoda.plot_boxplots(mdata, feature_name="condition", add_dots=True)

    Preview:
        .. image:: /_static/docstring_previews/sccoda_boxplots.png
    """
    # Work on copies: the code below pops "hue" and sets "hue_order", and must not
    # mutate dicts owned by the caller.
    args_boxplot = {} if args_boxplot is None else dict(args_boxplot)
    args_swarmplot = {} if args_swarmplot is None else dict(args_swarmplot)
    if isinstance(data, MuData):
        data = data[modality_key]

    # y scale transformations
    if y_scale == "relative":
        sample_sums = np.sum(data.X, axis=1, keepdims=True)
        X = data.X / sample_sums
        value_name = "Proportion"
    elif y_scale in ("log", "log10"):
        # Float copy: writing the 0.5 pseudocount into an integer count matrix
        # would truncate it back to 0 and produce -inf after the log.
        X = np.array(data.X, dtype=float)
        X[X == 0] = 0.5
        if y_scale == "log":
            X = np.log(X)
            value_name = "log(count)"
        else:
            X = np.log10(X)
            value_name = "log10(count)"
    elif y_scale == "count":
        X = data.X
        value_name = "count"
    else:
        raise ValueError("Invalid y_scale transformation")

    # Long format: one row per (sample, cell type) with the feature level attached.
    count_df = pd.DataFrame(X, columns=data.var.index, index=data.obs.index).merge(
        data.obs[feature_name], left_index=True, right_index=True
    )
    plot_df = pd.melt(count_df, id_vars=feature_name, var_name="Cell type", value_name=value_name)
    if cell_types is not None:
        plot_df = plot_df[plot_df["Cell type"].isin(cell_types)]

    # Currently disabled because the latest statsannotations does not support the latest seaborn.
    # We had to drop the dependency.
    # Get credible effects results from model
    # if draw_effects:
    #     if model is not None:
    #         credible_effects_df = model.credible_effects(data, modality_key).to_frame().reset_index()
    #     else:
    #         print("[bold yellow]Specify a tasCODA model to draw effects")
    #     credible_effects_df[feature_name] = credible_effects_df["Covariate"].str.removeprefix(f"{feature_name}[T.")
    #     credible_effects_df[feature_name] = credible_effects_df[feature_name].str.removesuffix("]")
    #     credible_effects_df = credible_effects_df[credible_effects_df["Final Parameter"]]

    # If plotting as facets, create a FacetGrid and map boxplot onto it.
    if plot_facets:
        if level_order is None:
            level_order = pd.unique(plot_df[feature_name])

        # Base the grid layout on the facets actually drawn (cell_types may subset the columns of X),
        # and never let col_wrap drop to 0.
        n_facets = plot_df["Cell type"].nunique()
        col_wrap = max(1, int(np.floor(np.sqrt(n_facets))))

        if figsize is not None:
            # NOTE(review): figsize[0] is used as the per-facet height and figsize[1] / figsize[0] as
            # the aspect. This is reversed relative to matplotlib's (width, height) convention and is
            # kept unchanged for backward compatibility.
            height = figsize[0]
            aspect = np.round(figsize[1] / figsize[0], 2)
        else:
            height = 3
            aspect = 2

        g = sns.FacetGrid(
            plot_df,
            col="Cell type",
            sharey=False,
            col_wrap=col_wrap,
            height=height,
            aspect=aspect,
        )
        g.map(
            sns.boxplot,
            feature_name,
            value_name,
            palette=palette,
            order=level_order,
            **args_boxplot,
        )

        if add_dots:
            hue = args_swarmplot.pop("hue", None)
            if hue is None:
                g.map(
                    sns.swarmplot,
                    feature_name,
                    value_name,
                    color="black",
                    order=level_order,
                    **args_swarmplot,
                ).set_titles("{col_name}")
            else:
                g.map(
                    sns.swarmplot,
                    feature_name,
                    value_name,
                    hue,
                    order=level_order,
                    **args_swarmplot,
                ).set_titles("{col_name}")

        if save:
            plt.savefig(save, bbox_inches="tight")
        if show:
            plt.show()
        if return_fig:
            return plt.gcf()
        if not (show or save):
            return g
        return None

    # If not plotting as facets, call boxplot with cell types on the x-axis.
    if level_order:
        args_boxplot["hue_order"] = level_order
        args_swarmplot["hue_order"] = level_order

    if ax is None:
        _, ax = plt.subplots(figsize=figsize, dpi=dpi)

    ax = sns.boxplot(
        x="Cell type",
        y=value_name,
        hue=feature_name,
        data=plot_df,
        fliersize=1,
        palette=palette,
        ax=ax,
        **args_boxplot,
    )

    # Currently disabled because the latest statsannotations does not support the latest seaborn.
    # We had to drop the dependency.
    # if draw_effects:
    #     pairs = [
    #         [(row["Cell Type"], row[feature_name]), (row["Cell Type"], "Control")]
    #         for _, row in credible_effects_df.iterrows()
    #     ]
    #     annot = Annotator(ax, pairs, data=plot_df, x="Cell type", y=value_name, hue=feature_name)
    #     annot.configure(test=None, loc="outside", color="red", line_height=0, verbose=False)
    #     annot.set_custom_annotations([row[feature_name] for _, row in credible_effects_df.iterrows()])
    #     annot.annotate()

    if add_dots:
        sns.swarmplot(
            x="Cell type",
            y=value_name,
            data=plot_df,
            hue=feature_name,
            ax=ax,
            dodge=True,
            palette="dark:black",
            **args_swarmplot,
        )

    plotted_cell_types = pd.unique(plot_df["Cell type"])
    ax.set_xticklabels(plotted_cell_types, rotation=90)

    if show_legend:
        # Box and swarm layers both register the same hue labels; keep the first handle per label.
        handles, labels = ax.get_legend_handles_labels()
        first_handle_per_label: dict = {}
        for handle, label in zip(handles, labels, strict=False):
            first_handle_per_label.setdefault(label, handle)
        ax.legend(
            list(first_handle_per_label.values()),
            list(first_handle_per_label.keys()),
            loc="upper left",
            bbox_to_anchor=(1, 1),
            ncol=1,
            title=feature_name,
        )

    if save:
        ax.figure.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return ax.figure
    if not (show or save):
        return ax
    return None
1739
+
1740
+ def plot_rel_abundance_dispersion_plot( # pragma: no cover
1741
+ self,
1742
+ data: AnnData | MuData,
1743
+ modality_key: str = "coda",
1744
+ abundant_threshold: float | None = 0.9,
1745
+ default_color: str | None = "Grey",
1746
+ abundant_color: str | None = "Red",
1747
+ label_cell_types: bool = True,
1748
+ figsize: tuple[float, float] | None = None,
1749
+ dpi: int | None = 100,
1750
+ return_fig: bool | None = None,
1751
+ ax: plt.Axes | None = None,
1752
+ show: bool | None = None,
1753
+ save: str | bool | None = None,
1754
+ ) -> plt.Axes | plt.Figure | None:
1755
+ """Plots total variance of relative abundance versus minimum relative abundance of all cell types for determination of a reference cell type.
1756
+
1757
+ If the count of the cell type is larger than 0 in more than abundant_threshold percent of all samples, the cell type will be marked in a different color.
1758
+
1759
+ Args:
1760
+ data: AnnData or MuData object.
1761
+ modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
1762
+ Defaults to "coda".
1763
+ abundant_threshold: Presence threshold for abundant cell types. Defaults to 0.9.
1764
+ default_color: Bar color for all non-minimal cell types. Defaults to "Grey".
1765
+ abundant_color: Bar color for cell types with abundant percentage larger than abundant_threshold.
1766
+ Defaults to "Red".
1767
+ label_cell_types: Label dots with cell type names. Defaults to True.
1768
+ figsize: Figure size. Defaults to None.
1769
+ dpi: Dpi setting. Defaults to 100.
1770
+ ax: A matplotlib axes object. Only works if plotting a single component. Defaults to None.
1771
+
1772
+ Returns:
1773
+ A :class:`~matplotlib.axes.Axes` object
1774
+
1775
+ Examples:
1776
+ >>> import pertpy as pt
1777
+ >>> haber_cells = pt.dt.haber_2017_regions()
1778
+ >>> sccoda = pt.tl.Sccoda()
1779
+ >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
1780
+ sample_identifier="batch", covariate_obs=["condition"])
1781
+ >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
1782
+ >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
1783
+ >>> sccoda.plot_rel_abundance_dispersion_plot(mdata)
1784
+
1785
+ Preview:
1786
+ .. image:: /_static/docstring_previews/sccoda_rel_abundance_dispersion_plot.png
1787
+ """
1788
+ if isinstance(data, MuData):
1789
+ data = data[modality_key]
1790
+ if isinstance(data, AnnData):
1791
+ data = data
1792
+ if ax is None:
1793
+ _, ax = plt.subplots(figsize=figsize, dpi=dpi)
1794
+
1795
+ rel_abun = data.X / np.sum(data.X, axis=1, keepdims=True)
1796
+
1797
+ percent_zero = np.sum(data.X == 0, axis=0) / data.X.shape[0]
1798
+ nonrare_ct = np.where(percent_zero < 1 - abundant_threshold)[0]
1799
+
1800
+ # select reference
1801
+ cell_type_disp = np.var(rel_abun, axis=0) / np.mean(rel_abun, axis=0)
1802
+
1803
+ is_abundant = [x in nonrare_ct for x in range(data.X.shape[1])]
1804
+
1805
+ # Scatterplot
1806
+ plot_df = pd.DataFrame(
1807
+ {
1808
+ "Total dispersion": cell_type_disp,
1809
+ "Cell type": data.var.index,
1810
+ "Presence": 1 - percent_zero,
1811
+ "Is abundant": is_abundant,
1812
+ }
1813
+ )
1814
+
1815
+ if len(np.unique(plot_df["Is abundant"])) > 1:
1816
+ palette = [default_color, abundant_color]
1817
+ elif np.unique(plot_df["Is abundant"]) == [False]:
1818
+ palette = [default_color]
1819
+ else:
1820
+ palette = [abundant_color]
1821
+
1822
+ ax = sns.scatterplot(
1823
+ data=plot_df,
1824
+ x="Presence",
1825
+ y="Total dispersion",
1826
+ hue="Is abundant",
1827
+ palette=palette,
1828
+ ax=ax,
1829
+ )
1830
+
1831
+ # Text labels for abundant cell types
1832
+
1833
+ abundant_df = plot_df.loc[plot_df["Is abundant"], :]
1834
+
1835
+ def label_point(x, y, val, ax):
1836
+ a = pd.concat({"x": x, "y": y, "val": val}, axis=1)
1837
+ texts = [
1838
+ ax.text(
1839
+ point["x"],
1840
+ point["y"],
1841
+ str(point["val"]),
1842
+ )
1843
+ for i, point in a.iterrows()
1844
+ ]
1845
+ adjust_text(texts)
1846
+
1847
+ if label_cell_types:
1848
+ label_point(
1849
+ abundant_df["Presence"],
1850
+ abundant_df["Total dispersion"],
1851
+ abundant_df["Cell type"],
1852
+ plt.gca(),
1853
+ )
1854
+
1855
+ ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1, title="Is abundant")
1856
+
1857
+ if save:
1858
+ plt.savefig(save, bbox_inches="tight")
1859
+ if show:
1860
+ plt.show()
1861
+ if return_fig:
1862
+ return plt.gcf()
1863
+ if not (show or save):
1864
+ return ax
1865
+ return None
1866
+
1867
+ def plot_draw_tree( # pragma: no cover
1868
+ self,
1869
+ data: AnnData | MuData,
1870
+ modality_key: str = "coda",
1871
+ tree: str = "tree", # Also type ete3.Tree. Omitted due to import errors
1872
+ tight_text: bool | None = False,
1873
+ show_scale: bool | None = False,
1874
+ units: Literal["px", "mm", "in"] | None = "px",
1875
+ figsize: tuple[float, float] | None = (None, None),
1876
+ dpi: int | None = 100,
1877
+ show: bool | None = True,
1878
+ save: str | bool | None = None,
1879
+ ) -> Tree | None:
1880
+ """Plot a tree using input ete3 tree object.
1881
+
1882
+ Args:
1883
+ data: AnnData object or MuData object.
1884
+ modality_key: If data is a MuData object, specify which modality to use.
1885
+ Defaults to "coda".
1886
+ tree: A ete3 tree object or a str to indicate the tree stored in `.uns`.
1887
+ Defaults to "tree".
1888
+ tight_text: When False, boundaries of the text are approximated according to general font metrics,
1889
+ producing slightly worse aligned text faces but improving the performance of tree visualization in scenes with a lot of text faces.
1890
+ Default to False.
1891
+ show_scale: Include the scale legend in the tree image or not.
1892
+ Defaults to False.
1893
+ show: If True, plot the tree inline. If false, return tree and tree_style objects.
1894
+ Defaults to True.
1895
+ file_name: Path to the output image file. Valid extensions are .SVG, .PDF, .PNG.
1896
+ Output image can be saved whether show is True or not.
1897
+ Defaults to None.
1898
+ units: Unit of image sizes. “px”: pixels, “mm”: millimeters, “in”: inches. Defaults to "px".
1899
+ figsize: Figure size. Defaults to None.
1900
+ dpi: Dots per inches. Defaults to 100.
1901
+
1902
+ Returns:
1903
+ Depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`) or plot the tree inline (`show = False`)
1904
+
1905
+ Examples:
1906
+ >>> import pertpy as pt
1907
+ >>> adata = pt.dt.tasccoda_example()
1908
+ >>> tasccoda = pt.tl.Tasccoda()
1909
+ >>> mdata = tasccoda.load(
1910
+ >>> adata, type="sample_level",
1911
+ >>> levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
1912
+ >>> key_added="lineage", add_level_name=True
1913
+ >>> )
1914
+ >>> mdata = tasccoda.prepare(
1915
+ >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
1916
+ >>> )
1917
+ >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
1918
+ >>> tasccoda.plot_draw_tree(mdata, tree="lineage")
1919
+
1920
+ Preview:
1921
+ .. image:: /_static/docstring_previews/tasccoda_draw_tree.png
1922
+ """
1923
+ try:
1924
+ from ete3 import CircleFace, NodeStyle, TextFace, Tree, TreeStyle, faces
1925
+ except ImportError:
1926
+ raise ImportError(
1927
+ "To use tasccoda please install additional dependencies with `pip install pertpy[coda]`"
1928
+ ) from None
1929
+
1930
+ if isinstance(data, MuData):
1931
+ data = data[modality_key]
1932
+ if isinstance(data, AnnData):
1933
+ data = data
1934
+ if isinstance(tree, str):
1935
+ tree = data.uns[tree]
1936
+
1937
+ def my_layout(node):
1938
+ text_face = TextFace(node.name, tight_text=tight_text)
1939
+ faces.add_face_to_node(text_face, node, column=0, position="branch-right")
1940
+
1941
+ tree_style = TreeStyle()
1942
+ tree_style.show_leaf_name = False
1943
+ tree_style.layout_fn = my_layout
1944
+ tree_style.show_scale = show_scale
1945
+
1946
+ if save is not None:
1947
+ tree.render(save, tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi) # type: ignore
1948
+ if show:
1949
+ return tree.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi) # type: ignore
1950
+ else:
1951
+ return tree, tree_style
1952
+
1953
+ def plot_draw_effects( # pragma: no cover
1954
+ self,
1955
+ data: AnnData | MuData,
1956
+ covariate: str,
1957
+ modality_key: str = "coda",
1958
+ tree: str = "tree", # Also type ete3.Tree. Omitted due to import errors
1959
+ show_legend: bool | None = None,
1960
+ show_leaf_effects: bool | None = False,
1961
+ tight_text: bool | None = False,
1962
+ show_scale: bool | None = False,
1963
+ units: Literal["px", "mm", "in"] | None = "px",
1964
+ figsize: tuple[float, float] | None = (None, None),
1965
+ dpi: int | None = 100,
1966
+ show: bool | None = True,
1967
+ save: str | None = None,
1968
+ ) -> Tree | None:
1969
+ """Plot a tree with colored circles on the nodes indicating significant effects with bar plots which indicate leave-level significant effects.
1970
+
1971
+ Args:
1972
+ data: AnnData object or MuData object.
1973
+ covariate: The covariate, whose effects should be plotted.
1974
+ modality_key: If data is a MuData object, specify which modality to use.
1975
+ Defaults to "coda".
1976
+ tree: A ete3 tree object or a str to indicate the tree stored in `.uns`.
1977
+ Defaults to "tree".
1978
+ show_legend: If show legend of nodes significant effects or not.
1979
+ Defaults to False if show_leaf_effects is True.
1980
+ show_leaf_effects: If True, plot bar plots which indicate leave-level significant effects.
1981
+ Defaults to False.
1982
+ tight_text: When False, boundaries of the text are approximated according to general font metrics,
1983
+ producing slightly worse aligned text faces but improving the performance of tree visualization in scenes with a lot of text faces.
1984
+ Defaults to False.
1985
+ show_scale: Include the scale legend in the tree image or not. Defaults to False.
1986
+ show: If True, plot the tree inline. If false, return tree and tree_style objects. Defaults to True.
1987
+ file_name: Path to the output image file. valid extensions are .SVG, .PDF, .PNG. Output image can be saved whether show is True or not.
1988
+ Defaults to None.
1989
+ units: Unit of image sizes. “px”: pixels, “mm”: millimeters, “in”: inches. Defaults to "px".
1990
+ figsize: Figure size. Defaults to None.
1991
+ dpi: Dots per inches. Defaults to 100.
1992
+
1993
+ Returns:
1994
+ Depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`)
1995
+ or plot the tree inline (`show = False`)
1996
+
1997
+ Examples:
1998
+ >>> import pertpy as pt
1999
+ >>> adata = pt.dt.tasccoda_example()
2000
+ >>> tasccoda = pt.tl.Tasccoda()
2001
+ >>> mdata = tasccoda.load(
2002
+ >>> adata, type="sample_level",
2003
+ >>> levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
2004
+ >>> key_added="lineage", add_level_name=True
2005
+ >>> )
2006
+ >>> mdata = tasccoda.prepare(
2007
+ >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
2008
+ >>> )
2009
+ >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
2010
+ >>> tasccoda.plot_draw_effects(mdata, covariate="Health[T.Inflamed]", tree="lineage")
2011
+
2012
+ Preview:
2013
+ .. image:: /_static/docstring_previews/tasccoda_draw_effects.png
2014
+ """
2015
+ try:
2016
+ from ete3 import CircleFace, NodeStyle, TextFace, Tree, TreeStyle, faces
2017
+ except ImportError:
2018
+ raise ImportError(
2019
+ "To use tasccoda please install additional dependencies as `pip install pertpy[coda]`"
2020
+ ) from None
2021
+
2022
+ if isinstance(data, MuData):
2023
+ data = data[modality_key]
2024
+ if isinstance(data, AnnData):
2025
+ data = data
2026
+ if show_legend is None:
2027
+ show_legend = not show_leaf_effects
2028
+ elif show_legend:
2029
+ print("Tree leaves and leaf effect bars won't be aligned when legend is shown!")
2030
+
2031
+ if isinstance(tree, str):
2032
+ tree = data.uns[tree]
2033
+ # Collapse tree singularities
2034
+ tree2 = collapse_singularities_2(tree)
2035
+
2036
+ node_effs = data.uns["scCODA_params"]["node_df"].loc[(covariate + "_node",),].copy()
2037
+ node_effs.index = node_effs.index.get_level_values("Node")
2038
+
2039
+ covariates = data.uns["scCODA_params"]["covariate_names"]
2040
+ effect_dfs = [data.varm[f"effect_df_{cov}"] for cov in covariates]
2041
+ eff_df = pd.concat(effect_dfs)
2042
+ eff_df.index = pd.MultiIndex.from_product(
2043
+ (covariates, data.var.index.tolist()),
2044
+ names=["Covariate", "Cell Type"],
2045
+ )
2046
+ leaf_effs = eff_df.loc[(covariate,),].copy()
2047
+ leaf_effs.index = leaf_effs.index.get_level_values("Cell Type")
2048
+
2049
+ # Add effect values
2050
+ for n in tree2.traverse():
2051
+ nstyle = NodeStyle()
2052
+ nstyle["size"] = 0
2053
+ n.set_style(nstyle)
2054
+ if n.name in node_effs.index:
2055
+ e = node_effs.loc[n.name, "Final Parameter"]
2056
+ n.add_feature("node_effect", e)
2057
+ else:
2058
+ n.add_feature("node_effect", 0)
2059
+ if n.name in leaf_effs.index:
2060
+ e = leaf_effs.loc[n.name, "Effect"]
2061
+ n.add_feature("leaf_effect", e)
2062
+ else:
2063
+ n.add_feature("leaf_effect", 0)
2064
+
2065
+ # Scale effect values to get nice node sizes
2066
+ eff_max = np.max([np.abs(n.node_effect) for n in tree2.traverse()])
2067
+ leaf_eff_max = np.max([np.abs(n.leaf_effect) for n in tree2.traverse()])
2068
+
2069
+ def my_layout(node):
2070
+ text_face = TextFace(node.name, tight_text=tight_text)
2071
+ text_face.margin_left = 10
2072
+ faces.add_face_to_node(text_face, node, column=0, aligned=True)
2073
+
2074
+ # if node.is_leaf():
2075
+ size = (np.abs(node.node_effect) * 10 / eff_max) if node.node_effect != 0 else 0
2076
+ if np.sign(node.node_effect) == 1:
2077
+ color = "blue"
2078
+ elif np.sign(node.node_effect) == -1:
2079
+ color = "red"
2080
+ else:
2081
+ color = "cyan"
2082
+ if size != 0:
2083
+ faces.add_face_to_node(CircleFace(radius=size, color=color), node, column=0)
2084
+
2085
+ tree_style = TreeStyle()
2086
+ tree_style.show_leaf_name = False
2087
+ tree_style.layout_fn = my_layout
2088
+ tree_style.show_scale = show_scale
2089
+ tree_style.draw_guiding_lines = True
2090
+ tree_style.legend_position = 1
2091
+
2092
+ if show_legend:
2093
+ tree_style.legend.add_face(TextFace("Effects"), column=0)
2094
+ tree_style.legend.add_face(TextFace(" "), column=1)
2095
+ for i in range(4, 0, -1):
2096
+ tree_style.legend.add_face(
2097
+ CircleFace(
2098
+ float(f"{np.abs(eff_max) * 10 * i / (eff_max * 4):.2f}"),
2099
+ "red",
2100
+ ),
2101
+ column=0,
2102
+ )
2103
+ tree_style.legend.add_face(TextFace(f"{-eff_max * i / 4:.2f} "), column=0)
2104
+ tree_style.legend.add_face(
2105
+ CircleFace(
2106
+ float(f"{np.abs(eff_max) * 10 * i / (eff_max * 4):.2f}"),
2107
+ "blue",
2108
+ ),
2109
+ column=1,
2110
+ )
2111
+ tree_style.legend.add_face(TextFace(f" {eff_max * i / 4:.2f}"), column=1)
2112
+
2113
+ if show_leaf_effects:
2114
+ leaf_name = [node.name for node in tree2.traverse("postorder") if node.is_leaf()]
2115
+ leaf_effs = leaf_effs.loc[leaf_name].reset_index()
2116
+ palette = ["blue" if Effect > 0 else "red" for Effect in leaf_effs["Effect"].tolist()]
2117
+
2118
+ dir_path = Path.cwd()
2119
+ dir_path = Path(dir_path / "tree_effect.png")
2120
+ tree2.render(dir_path, tree_style=tree_style, units="in")
2121
+ _, ax = plt.subplots(1, 2, figsize=(10, 10))
2122
+ sns.barplot(data=leaf_effs, x="Effect", y="Cell Type", palette=palette, ax=ax[1])
2123
+ img = mpimg.imread(dir_path)
2124
+ ax[0].imshow(img)
2125
+ ax[0].get_xaxis().set_visible(False)
2126
+ ax[0].get_yaxis().set_visible(False)
2127
+ ax[0].set_frame_on(False)
2128
+
2129
+ ax[1].get_yaxis().set_visible(False)
2130
+ ax[1].spines["left"].set_visible(False)
2131
+ ax[1].spines["right"].set_visible(False)
2132
+ ax[1].spines["top"].set_visible(False)
2133
+ plt.xlim(-leaf_eff_max, leaf_eff_max)
2134
+ plt.subplots_adjust(wspace=0)
2135
+
2136
+ if save is not None:
2137
+ plt.savefig(save)
2138
+
2139
+ if save is not None and not show_leaf_effects:
2140
+ tree2.render(save, tree_style=tree_style, units=units)
2141
+ if show:
2142
+ if not show_leaf_effects:
2143
+ return tree2.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi)
2144
+ else:
2145
+ if not show_leaf_effects:
2146
+ return tree2, tree_style
2147
+ return None
2148
+
2149
+ def plot_effects_umap( # pragma: no cover
2150
+ self,
2151
+ mdata: MuData,
2152
+ effect_name: str | list | None,
2153
+ cluster_key: str,
2154
+ modality_key_1: str = "rna",
2155
+ modality_key_2: str = "coda",
2156
+ color_map: Colormap | str | None = None,
2157
+ palette: str | Sequence[str] | None = None,
2158
+ return_fig: bool | None = None,
2159
+ ax: Axes = None,
2160
+ show: bool = None,
2161
+ save: str | bool | None = None,
2162
+ **kwargs,
2163
+ ) -> plt.Axes | plt.Figure | None:
2164
+ """Plot a UMAP visualization colored by effect strength.
2165
+
2166
+ Effect results in .varm of aggregated sample-level AnnData (default is data['coda']) are assigned to cell-level AnnData
2167
+ (default is data['rna']) depending on the cluster they were assigned to.
2168
+
2169
+ Args:
2170
+ mudata: MuData object.
2171
+ effect_name: The name of the effect results in .varm of aggregated sample-level AnnData to plot
2172
+ cluster_key: The cluster information in .obs of cell-level AnnData (default is data['rna']).
2173
+ To assign cell types' effects to original cells.
2174
+ modality_key_1: Key to the cell-level AnnData in the MuData object. Defaults to "rna".
2175
+ modality_key_2: Key to the aggregated sample-level AnnData object in the MuData object.
2176
+ Defaults to "coda".
2177
+ show: Whether to display the figure or return axis. Defaults to None.
2178
+ ax: A matplotlib axes object. Only works if plotting a single component.
2179
+ Defaults to None.
2180
+ **kwargs: All other keyword arguments are passed to `scanpy.plot.umap()`
2181
+
2182
+ Returns:
2183
+ If `show==False` a :class:`~matplotlib.axes.Axes` or a list of it.
2184
+
2185
+ Examples:
2186
+ >>> import pertpy as pt
2187
+ >>> import scanpy as sc
2188
+ >>> import schist
2189
+ >>> adata = pt.dt.haber_2017_regions()
2190
+ >>> sc.pp.neighbors(adata)
2191
+ >>> schist.inference.nested_model(adata, n_init=100, random_seed=5678)
2192
+ >>> tasccoda_model = pt.tl.Tasccoda()
2193
+ >>> tasccoda_data = tasccoda_model.load(adata, type="cell_level",
2194
+ >>> cell_type_identifier="nsbm_level_1",
2195
+ >>> sample_identifier="batch", covariate_obs=["condition"],
2196
+ >>> levels_orig=["nsbm_level_4", "nsbm_level_3", "nsbm_level_2", "nsbm_level_1"],
2197
+ >>> add_level_name=True)
2198
+ >>> tasccoda_model.prepare(
2199
+ >>> tasccoda_data,
2200
+ >>> modality_key="coda",
2201
+ >>> reference_cell_type="18",
2202
+ >>> formula="condition",
2203
+ >>> pen_args={"phi": 0, "lambda_1": 3.5},
2204
+ >>> tree_key="tree"
2205
+ >>> )
2206
+ >>> tasccoda_model.run_nuts(
2207
+ ... tasccoda_data, modality_key="coda", rng_key=1234, num_samples=10000, num_warmup=1000
2208
+ ... )
2209
+ >>> tasccoda_model.run_nuts(
2210
+ ... tasccoda_data, modality_key="coda", rng_key=1234, num_samples=10000, num_warmup=1000
2211
+ ... )
2212
+ >>> sc.tl.umap(tasccoda_data["rna"])
2213
+ >>> tasccoda_model.plot_effects_umap(tasccoda_data,
2214
+ >>> effect_name=["effect_df_condition[T.Salmonella]",
2215
+ >>> "effect_df_condition[T.Hpoly.Day3]",
2216
+ >>> "effect_df_condition[T.Hpoly.Day10]"],
2217
+ >>> cluster_key="nsbm_level_1",
2218
+ >>> )
2219
+
2220
+ Preview:
2221
+ .. image:: /_static/docstring_previews/tasccoda_effects_umap.png
2222
+ """
2223
+ # TODO: Add effect_name parameter and cluster_key and test the example
2224
+ data_rna = mdata[modality_key_1]
2225
+ data_coda = mdata[modality_key_2]
2226
+ if isinstance(effect_name, str):
2227
+ effect_name = [effect_name]
2228
+ for _, effect in enumerate(effect_name):
2229
+ data_rna.obs[effect] = [data_coda.varm[effect].loc[f"{c}", "Effect"] for c in data_rna.obs[cluster_key]]
2230
+ if kwargs.get("vmin"):
2231
+ vmin = kwargs["vmin"]
2232
+ kwargs.pop("vmin")
2233
+ else:
2234
+ vmin = min(data_rna.obs[effect].min() for _, effect in enumerate(effect_name))
2235
+ if kwargs.get("vmax"):
2236
+ vmax = kwargs["vmax"]
2237
+ kwargs.pop("vmax")
2238
+ else:
2239
+ vmax = max(data_rna.obs[effect].max() for _, effect in enumerate(effect_name))
2240
+
2241
+ return sc.pl.umap(
2242
+ data_rna,
2243
+ color=effect_name,
2244
+ vmax=vmax,
2245
+ vmin=vmin,
2246
+ palette=palette,
2247
+ color_map=color_map,
2248
+ return_fig=return_fig,
2249
+ ax=ax,
2250
+ show=show,
2251
+ save=save,
2252
+ **kwargs,
2253
+ )
2254
+
1116
2255
 
1117
2256
  def get_a(
1118
2257
  tree: tt.tree,
@@ -1242,7 +2381,7 @@ def df2newick(df: pd.DataFrame, levels: list[str], inner_label: bool = True) ->
1242
2381
 
1243
2382
 
1244
2383
  def get_a_2(
1245
- tree: ete.Tree,
2384
+ tree: Tree,
1246
2385
  leaf_order: list[str] = None,
1247
2386
  node_order: list[str] = None,
1248
2387
  ) -> tuple[np.ndarray, int]:
@@ -1263,6 +2402,13 @@ def get_a_2(
1263
2402
  T
1264
2403
  number of nodes in the tree, excluding the root node
1265
2404
  """
2405
+ try:
2406
+ import ete3 as ete
2407
+ except ImportError:
2408
+ raise ImportError(
2409
+ "To use tasccoda please install additional dependencies as `pip install pertpy[coda]`"
2410
+ ) from None
2411
+
1266
2412
  n_tips = len(tree.get_leaves())
1267
2413
  n_nodes = len(tree.get_descendants())
1268
2414
 
@@ -1292,7 +2438,7 @@ def get_a_2(
1292
2438
  return A_, n_nodes
1293
2439
 
1294
2440
 
1295
- def collapse_singularities_2(tree: ete.Tree) -> ete.Tree:
2441
+ def collapse_singularities_2(tree: Tree) -> Tree:
1296
2442
  """Collapses (deletes) nodes in a ete3 tree that are singularities (have only one child).
1297
2443
 
1298
2444
  Args:
@@ -1327,10 +2473,10 @@ def linkage_to_newick(
1327
2473
 
1328
2474
  def build_newick(node, newick, parentdist, leaf_names):
1329
2475
  if node.is_leaf():
1330
- return f"{leaf_names[node.id]}:{(parentdist - node.dist)/2}{newick}"
2476
+ return f"{leaf_names[node.id]}:{(parentdist - node.dist) / 2}{newick}"
1331
2477
  else:
1332
2478
  if len(newick) > 0:
1333
- newick = f"):{(parentdist - node.dist)/2}{newick}"
2479
+ newick = f"):{(parentdist - node.dist) / 2}{newick}"
1334
2480
  else:
1335
2481
  newick = ");"
1336
2482
  newick = build_newick(node.get_left(), newick, node.dist, leaf_names)
@@ -1363,14 +2509,15 @@ def import_tree(
1363
2509
 
1364
2510
  Args:
1365
2511
  data: A tascCODA-compatible data object.
1366
- modality_1: If `data` is MuData, specifiy the modality name to the original cell level anndata object. Defaults to None.
1367
- modality_2: If `data` is MuData, specifiy the modality name to the aggregated level anndata object. Defaults to None.
2512
+ modality_1: If `data` is MuData, specify the modality name to the original cell level anndata object. Defaults to None.
2513
+ modality_2: If `data` is MuData, specify the modality name to the aggregated level anndata object. Defaults to None.
1368
2514
  dendrogram_key: Key to the scanpy.tl.dendrogram result in `.uns` of original cell level anndata object. Defaults to None.
1369
2515
  levels_orig: List that indicates which columns in `.obs` of the original data correspond to tree levels. The list must begin with the root level, and end with the leaf level. Defaults to None.
1370
2516
  levels_agg: List that indicates which columns in `.var` of the aggregated data correspond to tree levels. The list must begin with the root level, and end with the leaf level. Defaults to None.
1371
- add_level_name: If True, internal nodes in the tree will be named as "{level_name}_{node_name}" instead of just {level_name}. Defaults to True.
1372
- key_added: If not specified, the tree is stored in .uns[‘tree’]. If `data` is AnnData, save tree in `data`. If `data` is MuData, save tree in data[modality_2]. Defaults to "tree".
1373
- copy: Return a copy instead of writing to `data`. Defaults to False.
2517
+ add_level_name: If True, internal nodes in the tree will be named as "{level_name}_{node_name}" instead of just {level_name}.
2518
+ Defaults to True.
2519
+ key_added: If not specified, the tree is stored in .uns[‘tree’]. If `data` is AnnData, save tree in `data`.
2520
+ If `data` is MuData, save tree in data[modality_2]. Defaults to "tree".
1374
2521
 
1375
2522
  Returns:
1376
2523
  Updates data with the following:
@@ -1379,6 +2526,13 @@ def import_tree(
1379
2526
 
1380
2527
  tree: A ete3 tree object.
1381
2528
  """
2529
+ try:
2530
+ import ete3 as ete
2531
+ except ImportError:
2532
+ raise ImportError(
2533
+ "To use tasccoda please install additional dependencies as `pip install pertpy[coda]`"
2534
+ ) from None
2535
+
1382
2536
  if isinstance(data, MuData):
1383
2537
  try:
1384
2538
  data_1 = data[modality_1]
@@ -1443,16 +2597,17 @@ def from_scanpy(
1443
2597
 
1444
2598
  The anndata object needs to have a column in adata.obs that contains the cell type assignment.
1445
2599
  Further, it must contain one column or a set of columns (e.g. subject id, treatment, disease status) that uniquely identify each (statistical) sample.
1446
- Further covariates (e.g. subject age) can either be specified via addidional column names in adata.obs, a key in adata.uns, or as a separate DataFrame.
2600
+ Further covariates (e.g. subject age) can either be specified via additional column names in adata.obs, a key in adata.uns, or as a separate DataFrame.
1447
2601
 
1448
- NOTE: The order of samples in the returned dataset is determined by the first occurence of cells from each sample in `adata`
2602
+ NOTE: The order of samples in the returned dataset is determined by the first occurrence of cells from each sample in `adata`
1449
2603
 
1450
2604
  Args:
1451
2605
  adata: An anndata object from scanpy
1452
2606
  cell_type_identifier: column name in adata.obs that specifies the cell types
1453
2607
  sample_identifier: column name or list of column names in adata.obs that uniquely identify each sample
1454
2608
  covariate_uns: key for adata.uns, where covariate values are stored
1455
- covariate_obs: list of column names in adata.obs, where covariate values are stored. Note: If covariate values are not unique for a value of sample_identifier, this covaariate will be skipped.
2609
+ covariate_obs: list of column names in adata.obs, where covariate values are stored.
2610
+ Note: If covariate values are not unique for a value of sample_identifier, this covariate will be skipped.
1456
2611
  covariate_df: DataFrame with covariates
1457
2612
 
1458
2613
  Returns:
@@ -1461,50 +2616,40 @@ def from_scanpy(
1461
2616
  if isinstance(sample_identifier, str):
1462
2617
  sample_identifier = [sample_identifier]
1463
2618
 
1464
- if covariate_obs:
1465
- covariate_obs += [i for i in sample_identifier if i not in covariate_obs]
1466
- else:
1467
- covariate_obs = sample_identifier # type: ignore
1468
-
1469
- # join sample identifiers
1470
- if isinstance(sample_identifier, list):
2619
+ if len(sample_identifier) > 1:
1471
2620
  adata.obs["scCODA_sample_id"] = adata.obs[sample_identifier].agg("-".join, axis=1)
1472
2621
  sample_identifier = "scCODA_sample_id"
2622
+ else:
2623
+ sample_identifier = sample_identifier[0]
1473
2624
 
1474
2625
  # get cell type counts
1475
- groups = adata.obs.value_counts([sample_identifier, cell_type_identifier])
1476
- count_data = groups.unstack(level=cell_type_identifier)
1477
- count_data = count_data.fillna(0)
2626
+ ct_count_data = pd.crosstab(adata.obs[sample_identifier], adata.obs[cell_type_identifier])
2627
+ ct_count_data = ct_count_data.fillna(0)
1478
2628
 
1479
2629
  # get covariates from different sources
1480
- covariate_df_ = pd.DataFrame(index=count_data.index)
1481
-
1482
- if covariate_df is None and covariate_obs is None and covariate_uns is None:
1483
- print("No covariate information specified!")
2630
+ covariate_df_ = pd.DataFrame(index=ct_count_data.index)
1484
2631
 
1485
2632
  if covariate_uns is not None:
1486
- covariate_df_uns = pd.DataFrame(adata.uns[covariate_uns])
1487
- covariate_df_ = pd.concat((covariate_df_, covariate_df_uns), axis=1)
2633
+ covariate_df_uns = pd.DataFrame(adata.uns[covariate_uns], index=ct_count_data.index)
2634
+ covariate_df_ = covariate_df_.join(covariate_df_uns, how="left")
1488
2635
 
1489
- if covariate_obs is not None:
1490
- for c in covariate_obs:
1491
- if any(adata.obs.groupby(sample_identifier).nunique()[c] != 1):
1492
- print(f"Covariate {c} has non-unique values! Skipping...")
1493
- covariate_obs.remove(c)
2636
+ if covariate_obs:
2637
+ is_unique = adata.obs.groupby(sample_identifier, observed=True).transform(lambda x: x.nunique() == 1)
2638
+ unique_covariates = is_unique.columns[is_unique.all()].tolist()
1494
2639
 
1495
- covariate_df_obs = adata.obs.groupby(sample_identifier).first()[covariate_obs]
1496
- covariate_df_ = pd.concat((covariate_df_, covariate_df_obs), axis=1)
2640
+ if len(unique_covariates) < len(covariate_obs):
2641
+ skipped = set(covariate_obs) - set(unique_covariates)
2642
+ print(f"[bold yellow]Covariates {skipped} have non-unique values! Skipping...")
2643
+ if unique_covariates:
2644
+ covariate_df_obs = adata.obs.groupby(sample_identifier, observed=True).first()[unique_covariates]
2645
+ covariate_df_ = covariate_df_.join(covariate_df_obs, how="left")
1497
2646
 
1498
2647
  if covariate_df is not None:
1499
- if set(covariate_df.index) != set(count_data.index):
1500
- raise ValueError("anndata sample names and covariate_df index do not have the same elements!")
1501
- covs_ord = covariate_df.reindex(count_data.index)
1502
- covariate_df_ = pd.concat((covariate_df_, covs_ord), axis=1)
1503
-
1504
- covariate_df_.index = covariate_df_.index.astype(str)
2648
+ if not covariate_df.index.equals(ct_count_data.index):
2649
+ raise ValueError("AnnData sample names and covariate_df index do not have the same elements!")
2650
+ covariate_df_ = covariate_df_.join(covariate_df, how="left")
1505
2651
 
1506
- # create var (number of cells for each type as only column)
1507
- var_dat = count_data.sum(axis=0).rename("n_cells").to_frame()
2652
+ var_dat = ct_count_data.sum(axis=0).rename("n_cells").to_frame()
1508
2653
  var_dat.index = var_dat.index.astype(str)
1509
2654
 
1510
- return AnnData(X=count_data.values, var=var_dat, obs=covariate_df_)
2655
+ return AnnData(X=ct_count_data.values, var=var_dat, obs=covariate_df_)