pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66) hide show
  1. pertpy/__init__.py +4 -2
  2. pertpy/data/__init__.py +66 -1
  3. pertpy/data/_dataloader.py +28 -26
  4. pertpy/data/_datasets.py +261 -92
  5. pertpy/metadata/__init__.py +6 -0
  6. pertpy/metadata/_cell_line.py +795 -0
  7. pertpy/metadata/_compound.py +128 -0
  8. pertpy/metadata/_drug.py +238 -0
  9. pertpy/metadata/_look_up.py +569 -0
  10. pertpy/metadata/_metadata.py +70 -0
  11. pertpy/metadata/_moa.py +125 -0
  12. pertpy/plot/__init__.py +0 -13
  13. pertpy/preprocessing/__init__.py +2 -0
  14. pertpy/preprocessing/_guide_rna.py +89 -6
  15. pertpy/tools/__init__.py +48 -15
  16. pertpy/tools/_augur.py +329 -32
  17. pertpy/tools/_cinemaot.py +145 -6
  18. pertpy/tools/_coda/_base_coda.py +1237 -116
  19. pertpy/tools/_coda/_sccoda.py +66 -36
  20. pertpy/tools/_coda/_tasccoda.py +46 -39
  21. pertpy/tools/_dialogue.py +180 -77
  22. pertpy/tools/_differential_gene_expression/__init__.py +20 -0
  23. pertpy/tools/_differential_gene_expression/_base.py +657 -0
  24. pertpy/tools/_differential_gene_expression/_checks.py +41 -0
  25. pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
  26. pertpy/tools/_differential_gene_expression/_edger.py +125 -0
  27. pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
  28. pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
  29. pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
  30. pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
  31. pertpy/tools/_distances/_distance_tests.py +29 -24
  32. pertpy/tools/_distances/_distances.py +584 -98
  33. pertpy/tools/_enrichment.py +460 -0
  34. pertpy/tools/_kernel_pca.py +1 -1
  35. pertpy/tools/_milo.py +406 -49
  36. pertpy/tools/_mixscape.py +677 -55
  37. pertpy/tools/_perturbation_space/_clustering.py +10 -3
  38. pertpy/tools/_perturbation_space/_comparison.py +112 -0
  39. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
  40. pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
  41. pertpy/tools/_perturbation_space/_simple.py +52 -11
  42. pertpy/tools/_scgen/__init__.py +1 -1
  43. pertpy/tools/_scgen/_base_components.py +2 -3
  44. pertpy/tools/_scgen/_scgen.py +706 -0
  45. pertpy/tools/_scgen/_utils.py +3 -5
  46. pertpy/tools/decoupler_LICENSE +674 -0
  47. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
  48. pertpy-0.8.0.dist-info/RECORD +57 -0
  49. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
  50. pertpy/plot/_augur.py +0 -234
  51. pertpy/plot/_cinemaot.py +0 -81
  52. pertpy/plot/_coda.py +0 -1001
  53. pertpy/plot/_dialogue.py +0 -91
  54. pertpy/plot/_guide_rna.py +0 -82
  55. pertpy/plot/_milopy.py +0 -284
  56. pertpy/plot/_mixscape.py +0 -594
  57. pertpy/plot/_scgen.py +0 -337
  58. pertpy/tools/_differential_gene_expression.py +0 -99
  59. pertpy/tools/_metadata/__init__.py +0 -0
  60. pertpy/tools/_metadata/_cell_line.py +0 -613
  61. pertpy/tools/_metadata/_look_up.py +0 -342
  62. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  63. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  64. pertpy-0.6.0.dist-info/RECORD +0 -50
  65. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  66. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
pertpy/plot/_coda.py DELETED
@@ -1,1001 +0,0 @@
1
- from pathlib import Path
2
- from typing import Literal, Optional, Union
3
-
4
- import matplotlib.image as mpimg
5
- import matplotlib.pyplot as plt
6
- import numpy as np
7
- import pandas as pd
8
- import scanpy as sc
9
- import seaborn as sns
10
- from adjustText import adjust_text
11
- from anndata import AnnData
12
- from ete3 import CircleFace, NodeStyle, TextFace, Tree, TreeStyle, faces
13
- from matplotlib import cm, rcParams
14
- from matplotlib.axes import Axes
15
- from matplotlib.colors import ListedColormap
16
- from mudata import MuData
17
-
18
- from pertpy.tools._coda._base_coda import CompositionalModel2, collapse_singularities_2
19
-
20
- sns.set_style("ticks")
21
-
22
-
23
- class CodaPlot:
24
- @staticmethod
25
- def __stackbar( # pragma: no cover
26
- y: np.ndarray,
27
- type_names: list[str],
28
- title: str,
29
- level_names: list[str],
30
- figsize: Optional[tuple[float, float]] = None,
31
- dpi: Optional[int] = 100,
32
- cmap: Optional[ListedColormap] = cm.tab20,
33
- show_legend: Optional[bool] = True,
34
- ) -> plt.Axes:
35
- """Plots a stacked barplot for one (discrete) covariate.
36
-
37
- Typical use (only inside stacked_barplot): plot_one_stackbar(data.X, data.var.index, "xyz", data.obs.index)
38
-
39
- Args:
40
- y: The count data, collapsed onto the level of interest. i.e. a binary covariate has two rows,
41
- one for each group, containing the count mean of each cell type
42
- type_names: The names of all cell types
43
- title: Plot title, usually the covariate's name
44
- level_names: Names of the covariate's levels
45
- figsize: Figure size. Defaults to None.
46
- dpi: Dpi setting. Defaults to 100.
47
- cmap: The color map for the barplot. Defaults to cm.tab20.
48
- show_legend: If True, adds a legend. Defaults to True.
49
-
50
- Returns:
51
- A :class:`~matplotlib.axes.Axes` object
52
- """
53
- n_bars, n_types = y.shape
54
-
55
- figsize = rcParams["figure.figsize"] if figsize is None else figsize
56
-
57
- _, ax = plt.subplots(figsize=figsize, dpi=dpi)
58
- r = np.array(range(n_bars))
59
- sample_sums = np.sum(y, axis=1)
60
-
61
- barwidth = 0.85
62
- cum_bars = np.zeros(n_bars)
63
-
64
- for n in range(n_types):
65
- bars = [i / j * 100 for i, j in zip([y[k][n] for k in range(n_bars)], sample_sums)]
66
- plt.bar(
67
- r,
68
- bars,
69
- bottom=cum_bars,
70
- color=cmap(n % cmap.N),
71
- width=barwidth,
72
- label=type_names[n],
73
- linewidth=0,
74
- )
75
- cum_bars += bars
76
-
77
- ax.set_title(title)
78
- if show_legend:
79
- ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1)
80
- ax.set_xticks(r)
81
- ax.set_xticklabels(level_names, rotation=45, ha="right")
82
- ax.set_ylabel("Proportion")
83
-
84
- return ax
85
-
86
- @staticmethod
87
- def stacked_barplot( # pragma: no cover
88
- data: Union[AnnData, MuData],
89
- feature_name: str,
90
- modality_key: str = "coda",
91
- figsize: Optional[tuple[float, float]] = None,
92
- dpi: Optional[int] = 100,
93
- cmap: Optional[ListedColormap] = cm.tab20,
94
- show_legend: Optional[bool] = True,
95
- level_order: list[str] = None,
96
- ) -> plt.Axes:
97
- """Plots a stacked barplot for all levels of a covariate or all samples (if feature_name=="samples").
98
-
99
- Args:
100
- data: AnnData object or MuData object.
101
- feature_name: The name of the covariate to plot. If feature_name=="samples", one bar for every sample will be plotted
102
- modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
103
- figsize: Figure size. Defaults to None.
104
- dpi: Dpi setting. Defaults to 100.
105
- cmap: The matplotlib color map for the barplot. Defaults to cm.tab20.
106
- show_legend: If True, adds a legend. Defaults to True.
107
- level_order: Custom ordering of bars on the x-axis. Defaults to None.
108
-
109
- Returns:
110
- A :class:`~matplotlib.axes.Axes` object
111
-
112
- Examples:
113
- Example with scCODA:
114
- >>> import pertpy as pt
115
- >>> haber_cells = pt.dt.haber_2017_regions()
116
- >>> sccoda = pt.tl.Sccoda()
117
- >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
118
- sample_identifier="batch", covariate_obs=["condition"])
119
- >>> pt.pl.coda.stacked_barplot(mdata, feature_name="samples")
120
- """
121
- if isinstance(data, MuData):
122
- data = data[modality_key]
123
- if isinstance(data, AnnData):
124
- data = data
125
-
126
- ct_names = data.var.index
127
-
128
- # option to plot one stacked barplot per sample
129
- if feature_name == "samples":
130
- if level_order:
131
- assert set(level_order) == set(data.obs.index), "level order is inconsistent with levels"
132
- data = data[level_order]
133
- ax = CodaPlot.__stackbar(
134
- data.X,
135
- type_names=data.var.index,
136
- title="samples",
137
- level_names=data.obs.index,
138
- figsize=figsize,
139
- dpi=dpi,
140
- cmap=cmap,
141
- show_legend=show_legend,
142
- )
143
- else:
144
- # Order levels
145
- if level_order:
146
- assert set(level_order) == set(data.obs[feature_name]), "level order is inconsistent with levels"
147
- levels = level_order
148
- elif hasattr(data.obs[feature_name], "cat"):
149
- levels = data.obs[feature_name].cat.categories.to_list()
150
- else:
151
- levels = pd.unique(data.obs[feature_name])
152
- n_levels = len(levels)
153
- feature_totals = np.zeros([n_levels, data.X.shape[1]])
154
-
155
- for level in range(n_levels):
156
- l_indices = np.where(data.obs[feature_name] == levels[level])
157
- feature_totals[level] = np.sum(data.X[l_indices], axis=0)
158
-
159
- ax = CodaPlot.__stackbar(
160
- feature_totals,
161
- type_names=ct_names,
162
- title=feature_name,
163
- level_names=levels,
164
- figsize=figsize,
165
- dpi=dpi,
166
- cmap=cmap,
167
- show_legend=show_legend,
168
- )
169
- return ax
170
-
171
- @staticmethod
172
- def effects_barplot( # pragma: no cover
173
- data: Union[AnnData, MuData],
174
- modality_key: str = "coda",
175
- covariates: Optional[Union[str, list]] = None,
176
- parameter: Literal["log2-fold change", "Final Parameter", "Expected Sample"] = "log2-fold change",
177
- plot_facets: bool = True,
178
- plot_zero_covariate: bool = True,
179
- plot_zero_cell_type: bool = False,
180
- figsize: Optional[tuple[float, float]] = None,
181
- dpi: Optional[int] = 100,
182
- cmap: Optional[Union[str, ListedColormap]] = cm.tab20,
183
- level_order: list[str] = None,
184
- args_barplot: Optional[dict] = None,
185
- ) -> Optional[Union[plt.Axes, sns.axisgrid.FacetGrid]]:
186
- """Barplot visualization for effects.
187
-
188
- The effect results for each covariate are shown as a group of barplots, with intra--group separation by cell types.
189
- The covariates groups can either be ordered along the x-axis of a single plot (plot_facets=False) or as plot facets (plot_facets=True).
190
-
191
- Args:
192
- data: AnnData object or MuData object.
193
- modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
194
- covariates: The name of the covariates in data.obs to plot. Defaults to None.
195
- parameter: The parameter in effect summary to plot. Defaults to "log2-fold change".
196
- plot_facets: If False, plot cell types on the x-axis. If True, plot as facets. Defaults to True.
197
- plot_zero_covariate: If True, plot covariate that have all zero effects. If False, do not plot. Defaults to True.
198
- plot_zero_cell_type: If True, plot cell type that have zero effect. If False, do not plot. Defaults to False.
199
- figsize: Figure size. Defaults to None.
200
- dpi: Figure size. Defaults to 100.
201
- cmap: The seaborn color map for the barplot. Defaults to cm.tab20.
202
- level_order: Custom ordering of bars on the x-axis. Defaults to None.
203
- args_barplot: Arguments passed to sns.barplot. Defaults to None.
204
-
205
- Returns:
206
- Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`)
207
- or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object
208
-
209
- Examples:
210
- Example with scCODA:
211
- >>> import pertpy as pt
212
- >>> haber_cells = pt.dt.haber_2017_regions()
213
- >>> sccoda = pt.tl.Sccoda()
214
- >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
215
- sample_identifier="batch", covariate_obs=["condition"])
216
- >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
217
- >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
218
- >>> pt.pl.coda.effects_barplot(mdata)
219
- """
220
- if args_barplot is None:
221
- args_barplot = {}
222
- if isinstance(data, MuData):
223
- data = data[modality_key]
224
- if isinstance(data, AnnData):
225
- data = data
226
- # Get covariate names from adata, partition into those with nonzero effects for min. one cell type/no cell types
227
- covariate_names = data.uns["scCODA_params"]["covariate_names"]
228
- if covariates is not None:
229
- if isinstance(covariates, str):
230
- covariates = [covariates]
231
- partial_covariate_names = [
232
- covariate_name
233
- for covariate_name in covariate_names
234
- if any(covariate in covariate_name for covariate in covariates)
235
- ]
236
- covariate_names = partial_covariate_names
237
- covariate_names_non_zero = [
238
- covariate_name
239
- for covariate_name in covariate_names
240
- if data.varm[f"effect_df_{covariate_name}"][parameter].any()
241
- ]
242
- covariate_names_zero = list(set(covariate_names) - set(covariate_names_non_zero))
243
- if not plot_zero_covariate:
244
- covariate_names = covariate_names_non_zero
245
-
246
- # set up df for plotting
247
- plot_df = pd.concat(
248
- [data.varm[f"effect_df_{covariate_name}"][parameter] for covariate_name in covariate_names],
249
- axis=1,
250
- )
251
- plot_df.columns = covariate_names
252
- plot_df = pd.melt(plot_df, ignore_index=False, var_name="Covariate")
253
-
254
- plot_df = plot_df.reset_index()
255
-
256
- if len(covariate_names_zero) != 0:
257
- if plot_facets:
258
- if plot_zero_covariate and not plot_zero_cell_type:
259
- plot_df = plot_df[plot_df["value"] != 0]
260
- for covariate_name_zero in covariate_names_zero:
261
- new_row = {
262
- "Covariate": covariate_name_zero,
263
- "Cell Type": "zero",
264
- "value": 0,
265
- }
266
- plot_df = pd.concat([plot_df, pd.DataFrame([new_row])], ignore_index=True)
267
- plot_df["covariate_"] = pd.Categorical(plot_df["Covariate"], covariate_names)
268
- plot_df = plot_df.sort_values(["covariate_"])
269
- if not plot_zero_cell_type:
270
- cell_type_names_zero = [
271
- name
272
- for name in plot_df["Cell Type"].unique()
273
- if (plot_df[plot_df["Cell Type"] == name]["value"] == 0).all()
274
- ]
275
- plot_df = plot_df[~plot_df["Cell Type"].isin(cell_type_names_zero)]
276
-
277
- # If plot as facets, create a FacetGrid and map barplot to it.
278
- if plot_facets:
279
- if isinstance(cmap, ListedColormap):
280
- cmap = np.array([cmap(i % cmap.N) for i in range(len(plot_df["Cell Type"].unique()))])
281
- if figsize is not None:
282
- height = figsize[0]
283
- aspect = np.round(figsize[1] / figsize[0], 2)
284
- else:
285
- height = 3
286
- aspect = 2
287
-
288
- g = sns.FacetGrid(
289
- plot_df,
290
- col="Covariate",
291
- sharey=True,
292
- sharex=False,
293
- height=height,
294
- aspect=aspect,
295
- )
296
-
297
- g.map(
298
- sns.barplot,
299
- "Cell Type",
300
- "value",
301
- palette=cmap,
302
- order=level_order,
303
- **args_barplot,
304
- )
305
- g.set_xticklabels(rotation=90)
306
- g.set(ylabel=parameter)
307
- axes = g.axes.flatten()
308
- for i, ax in enumerate(axes):
309
- ax.set_title(covariate_names[i])
310
- if len(ax.get_xticklabels()) < 5:
311
- ax.set_aspect(10 / len(ax.get_xticklabels()))
312
- if len(ax.get_xticklabels()) == 1:
313
- if ax.get_xticklabels()[0]._text == "zero":
314
- ax.set_xticks([])
315
- return g
316
-
317
- # If not plot as facets, call barplot to plot cell types on the x-axis.
318
- else:
319
- _, ax = plt.subplots(figsize=figsize, dpi=dpi)
320
- if len(covariate_names) == 1:
321
- if isinstance(cmap, ListedColormap):
322
- cmap = np.array([cmap(i % cmap.N) for i in range(len(plot_df["Cell Type"].unique()))])
323
- sns.barplot(
324
- data=plot_df,
325
- x="Cell Type",
326
- y="value",
327
- palette=cmap,
328
- ax=ax,
329
- )
330
- ax.set_title(covariate_names[0])
331
- else:
332
- if isinstance(cmap, ListedColormap):
333
- cmap = np.array([cmap(i % cmap.N) for i in range(len(covariate_names))])
334
- sns.barplot(
335
- data=plot_df,
336
- x="Cell Type",
337
- y="value",
338
- hue="Covariate",
339
- palette=cmap,
340
- ax=ax,
341
- )
342
- cell_types = pd.unique(plot_df["Cell Type"])
343
- ax.set_xticklabels(cell_types, rotation=90)
344
- return ax
345
-
346
- @staticmethod
347
- def boxplots( # pragma: no cover
348
- data: Union[AnnData, MuData],
349
- feature_name: str,
350
- modality_key: str = "coda",
351
- y_scale: Literal["relative", "log", "log10", "count"] = "relative",
352
- plot_facets: bool = False,
353
- add_dots: bool = False,
354
- model: CompositionalModel2 = None,
355
- cell_types: Optional[list] = None,
356
- args_boxplot: Optional[dict] = None,
357
- args_swarmplot: Optional[dict] = None,
358
- figsize: Optional[tuple[float, float]] = None,
359
- dpi: Optional[int] = 100,
360
- cmap: Optional[str] = "Blues",
361
- show_legend: Optional[bool] = True,
362
- level_order: list[str] = None,
363
- ) -> Optional[Union[plt.Axes, sns.axisgrid.FacetGrid]]:
364
- """Grouped boxplot visualization. The cell counts for each cell type are shown as a group of boxplots,
365
- with intra--group separation by a covariate from data.obs.
366
-
367
- Args:
368
- data: AnnData object or MuData object
369
- feature_name: The name of the feature in data.obs to plot
370
- modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
371
- y_scale: Transformation to of cell counts. Options: "relative" - Relative abundance, "log" - log(count),
372
- "log10" - log10(count), "count" - absolute abundance (cell counts).
373
- Defaults to "relative".
374
- plot_facets: If False, plot cell types on the x-axis. If True, plot as facets. Defaults to False.
375
- add_dots: If True, overlay a scatterplot with one dot for each data point. Defaults to False.
376
- model: When draw_effects, specify a tasCODA model
377
- cell_types: Subset of cell types that should be plotted. Defaults to None.
378
- args_boxplot: Arguments passed to sns.boxplot. Defaults to {}.
379
- args_swarmplot: Arguments passed to sns.swarmplot. Defaults to {}.
380
- figsize: Figure size. Defaults to None.
381
- dpi: Dpi setting. Defaults to 100.
382
- cmap: The seaborn color map for the barplot. Defaults to "Blues".
383
- show_legend: If True, adds a legend. Defaults to True.
384
- level_order: Custom ordering of bars on the x-axis. Defaults to None.
385
-
386
- Returns:
387
- Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`)
388
- or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object
389
-
390
- Examples:
391
- Example with scCODA:
392
- >>> import pertpy as pt
393
- >>> haber_cells = pt.dt.haber_2017_regions()
394
- >>> sccoda = pt.tl.Sccoda()
395
- >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
396
- sample_identifier="batch", covariate_obs=["condition"])
397
- >>> pt.pl.coda.boxplots(mdata, feature_name="condition", add_dots=True)
398
- """
399
- if args_boxplot is None:
400
- args_boxplot = {}
401
- if args_swarmplot is None:
402
- args_swarmplot = {}
403
- if isinstance(data, MuData):
404
- data = data[modality_key]
405
- if isinstance(data, AnnData):
406
- data = data
407
- # y scale transformations
408
- if y_scale == "relative":
409
- sample_sums = np.sum(data.X, axis=1, keepdims=True)
410
- X = data.X / sample_sums
411
- value_name = "Proportion"
412
- # add pseudocount 0.5 if using log scale
413
- elif y_scale == "log":
414
- X = data.X.copy()
415
- X[X == 0] = 0.5
416
- X = np.log(X)
417
- value_name = "log(count)"
418
- elif y_scale == "log10":
419
- X = data.X.copy()
420
- X[X == 0] = 0.5
421
- X = np.log(X)
422
- value_name = "log10(count)"
423
- elif y_scale == "count":
424
- X = data.X
425
- value_name = "count"
426
- else:
427
- raise ValueError("Invalid y_scale transformation")
428
-
429
- count_df = pd.DataFrame(X, columns=data.var.index, index=data.obs.index).merge(
430
- data.obs[feature_name], left_index=True, right_index=True
431
- )
432
- plot_df = pd.melt(count_df, id_vars=feature_name, var_name="Cell type", value_name=value_name)
433
- if cell_types is not None:
434
- plot_df = plot_df[plot_df["Cell type"].isin(cell_types)]
435
-
436
- # Currently disabled because the latest statsannotations does not support the latest seaborn.
437
- # We had to drop the dependency.
438
- # Get credible effects results from model
439
- # if draw_effects:
440
- # if model is not None:
441
- # credible_effects_df = model.credible_effects(data, modality_key).to_frame().reset_index()
442
- # else:
443
- # print("[bold yellow]Specify a tasCODA model to draw effects")
444
- # credible_effects_df[feature_name] = credible_effects_df["Covariate"].str.removeprefix(f"{feature_name}[T.")
445
- # credible_effects_df[feature_name] = credible_effects_df[feature_name].str.removesuffix("]")
446
- # credible_effects_df = credible_effects_df[credible_effects_df["Final Parameter"]]
447
-
448
- # If plot as facets, create a FacetGrid and map boxplot to it.
449
- if plot_facets:
450
- if level_order is None:
451
- level_order = pd.unique(plot_df[feature_name])
452
-
453
- K = X.shape[1]
454
-
455
- if figsize is not None:
456
- height = figsize[0]
457
- aspect = np.round(figsize[1] / figsize[0], 2)
458
- else:
459
- height = 3
460
- aspect = 2
461
-
462
- g = sns.FacetGrid(
463
- plot_df,
464
- col="Cell type",
465
- sharey=False,
466
- col_wrap=int(np.floor(np.sqrt(K))),
467
- height=height,
468
- aspect=aspect,
469
- )
470
- g.map(
471
- sns.boxplot,
472
- feature_name,
473
- value_name,
474
- palette=cmap,
475
- order=level_order,
476
- **args_boxplot,
477
- )
478
-
479
- if add_dots:
480
- if "hue" in args_swarmplot:
481
- hue = args_swarmplot.pop("hue")
482
- else:
483
- hue = None
484
-
485
- if hue is None:
486
- g.map(
487
- sns.swarmplot,
488
- feature_name,
489
- value_name,
490
- color="black",
491
- order=level_order,
492
- **args_swarmplot,
493
- ).set_titles("{col_name}")
494
- else:
495
- g.map(
496
- sns.swarmplot,
497
- feature_name,
498
- value_name,
499
- hue,
500
- order=level_order,
501
- **args_swarmplot,
502
- ).set_titles("{col_name}")
503
- return g
504
-
505
- # If not plot as facets, call boxplot to plot cell types on the x-axis.
506
- else:
507
- if level_order:
508
- args_boxplot["hue_order"] = level_order
509
- args_swarmplot["hue_order"] = level_order
510
-
511
- _, ax = plt.subplots(figsize=figsize, dpi=dpi)
512
-
513
- ax = sns.boxplot(
514
- x="Cell type",
515
- y=value_name,
516
- hue=feature_name,
517
- data=plot_df,
518
- fliersize=1,
519
- palette=cmap,
520
- ax=ax,
521
- **args_boxplot,
522
- )
523
-
524
- # Currently disabled because the latest statsannotations does not support the latest seaborn.
525
- # We had to drop the dependency.
526
- # if draw_effects:
527
- # pairs = [
528
- # [(row["Cell Type"], row[feature_name]), (row["Cell Type"], "Control")]
529
- # for _, row in credible_effects_df.iterrows()
530
- # ]
531
- # annot = Annotator(ax, pairs, data=plot_df, x="Cell type", y=value_name, hue=feature_name)
532
- # annot.configure(test=None, loc="outside", color="red", line_height=0, verbose=False)
533
- # annot.set_custom_annotations([row[feature_name] for _, row in credible_effects_df.iterrows()])
534
- # annot.annotate()
535
-
536
- if add_dots:
537
- sns.swarmplot(
538
- x="Cell type",
539
- y=value_name,
540
- data=plot_df,
541
- hue=feature_name,
542
- ax=ax,
543
- dodge=True,
544
- color="black",
545
- **args_swarmplot,
546
- )
547
-
548
- cell_types = pd.unique(plot_df["Cell type"])
549
- ax.set_xticklabels(cell_types, rotation=90)
550
-
551
- if show_legend:
552
- handles, labels = ax.get_legend_handles_labels()
553
- handout = []
554
- labelout = []
555
- for h, l in zip(handles, labels):
556
- if l not in labelout:
557
- labelout.append(l)
558
- handout.append(h)
559
- ax.legend(
560
- handout,
561
- labelout,
562
- loc="upper left",
563
- bbox_to_anchor=(1, 1),
564
- ncol=1,
565
- title=feature_name,
566
- )
567
-
568
- plt.tight_layout()
569
- return ax
570
-
571
- @staticmethod
572
- def rel_abundance_dispersion_plot( # pragma: no cover
573
- data: Union[AnnData, MuData],
574
- modality_key: str = "coda",
575
- abundant_threshold: Optional[float] = 0.9,
576
- default_color: Optional[str] = "Grey",
577
- abundant_color: Optional[str] = "Red",
578
- label_cell_types: bool = True,
579
- figsize: Optional[tuple[float, float]] = None,
580
- dpi: Optional[int] = 100,
581
- ax: Axes = None,
582
- ) -> plt.Axes:
583
- """Plots total variance of relative abundance versus minimum relative abundance of all cell types for determination of a reference cell type.
584
-
585
- If the count of the cell type is larger than 0 in more than abundant_threshold percent of all samples, the cell type will be marked in a different color.
586
-
587
- Args:
588
- data: AnnData object or MuData object.
589
- modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
590
- Defaults to "coda".
591
- abundant_threshold: Presence threshold for abundant cell types. Defaults to 0.9.
592
- default_color: Bar color for all non-minimal cell types. Defaults to "Grey".
593
- abundant_color: Bar color for cell types with abundant percentage larger than abundant_threshold.
594
- Defaults to "Red".
595
- label_cell_types: Label dots with cell type names. Defaults to True.
596
- figsize: Figure size. Defaults to None.
597
- dpi: Dpi setting. Defaults to 100.
598
- ax: A matplotlib axes object. Only works if plotting a single component. Defaults to None.
599
-
600
- Returns:
601
- A :class:`~matplotlib.axes.Axes` object
602
-
603
- Examples:
604
- Example with scCODA:
605
- >>> import pertpy as pt
606
- >>> haber_cells = pt.dt.haber_2017_regions()
607
- >>> sccoda = pt.tl.Sccoda()
608
- >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
609
- sample_identifier="batch", covariate_obs=["condition"])
610
- >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
611
- >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
612
- >>> pt.pl.coda.rel_abundance_dispersion_plot(mdata)
613
- """
614
- if isinstance(data, MuData):
615
- data = data[modality_key]
616
- if isinstance(data, AnnData):
617
- data = data
618
- if ax is None:
619
- _, ax = plt.subplots(figsize=figsize, dpi=dpi)
620
-
621
- rel_abun = data.X / np.sum(data.X, axis=1, keepdims=True)
622
-
623
- percent_zero = np.sum(data.X == 0, axis=0) / data.X.shape[0]
624
- nonrare_ct = np.where(percent_zero < 1 - abundant_threshold)[0]
625
-
626
- # select reference
627
- cell_type_disp = np.var(rel_abun, axis=0) / np.mean(rel_abun, axis=0)
628
-
629
- is_abundant = [x in nonrare_ct for x in range(data.X.shape[1])]
630
-
631
- # Scatterplot
632
- plot_df = pd.DataFrame(
633
- {
634
- "Total dispersion": cell_type_disp,
635
- "Cell type": data.var.index,
636
- "Presence": 1 - percent_zero,
637
- "Is abundant": is_abundant,
638
- }
639
- )
640
-
641
- if len(np.unique(plot_df["Is abundant"])) > 1:
642
- palette = [default_color, abundant_color]
643
- elif np.unique(plot_df["Is abundant"]) == [False]:
644
- palette = [default_color]
645
- else:
646
- palette = [abundant_color]
647
-
648
- ax = sns.scatterplot(
649
- data=plot_df,
650
- x="Presence",
651
- y="Total dispersion",
652
- hue="Is abundant",
653
- palette=palette,
654
- ax=ax,
655
- )
656
-
657
- # Text labels for abundant cell types
658
-
659
- abundant_df = plot_df.loc[plot_df["Is abundant"], :]
660
-
661
- def label_point(x, y, val, ax):
662
- a = pd.concat({"x": x, "y": y, "val": val}, axis=1)
663
- texts = [
664
- ax.text(
665
- point["x"],
666
- point["y"],
667
- str(point["val"]),
668
- )
669
- for i, point in a.iterrows()
670
- ]
671
- adjust_text(texts)
672
-
673
- if label_cell_types:
674
- label_point(
675
- abundant_df["Presence"],
676
- abundant_df["Total dispersion"],
677
- abundant_df["Cell type"],
678
- plt.gca(),
679
- )
680
-
681
- ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1, title="Is abundant")
682
-
683
- plt.tight_layout()
684
- return ax
685
-
686
- @staticmethod
687
- def draw_tree( # pragma: no cover
688
- data: Union[AnnData, MuData],
689
- modality_key: str = "coda",
690
- tree: Union[Tree, str] = "tree",
691
- tight_text: Optional[bool] = False,
692
- show_scale: Optional[bool] = False,
693
- show: Optional[bool] = True,
694
- file_name: Optional[str] = None,
695
- units: Optional[Literal["px", "mm", "in"]] = "px",
696
- h: Optional[float] = None,
697
- w: Optional[float] = None,
698
- dpi: Optional[int] = 90,
699
- ):
700
- """Plot a tree using input ete3 tree object.
701
-
702
- Args:
703
- data: AnnData object or MuData object.
704
- modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
705
- tree: A ete3 tree object or a str to indicate the tree stored in `.uns`. Defaults to "tree".
706
- tight_text: When False, boundaries of the text are approximated according to general font metrics,
707
- producing slightly worse aligned text faces but improving the performance of tree visualization in scenes with a lot of text faces.
708
- Default to False.
709
- show_scale: Include the scale legend in the tree image or not. Default to False.
710
- show: If True, plot the tree inline. If false, return tree and tree_style objects. Defaults to True.
711
- file_name: Path to the output image file. Valid extensions are .SVG, .PDF, .PNG. Output image can be saved whether show is True or not.
712
- Defaults to None.
713
- units: Unit of image sizes. “px”: pixels, “mm”: millimeters, “in”: inches. Defaults to "px".
714
- h: Height of the image in units. Defaults to None.
715
- w: Width of the image in units. Defaults to None.
716
- dpi: Dots per inches. Defaults to 90.
717
-
718
- Returns:
719
- Depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`) or plot the tree inline (`show = False`)
720
-
721
- Examples:
722
- Example with tascCODA:
723
- >>> import pertpy as pt
724
- >>> adata = pt.dt.smillie()
725
- >>> tasccoda = pt.tl.Tasccoda()
726
- >>> mdata = tasccoda.load(
727
- >>> adata, type="sample_level",
728
- >>> levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
729
- >>> key_added="lineage", add_level_name=True
730
- >>> )
731
- >>> mdata = tasccoda.prepare(
732
- >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
733
- >>> )
734
- >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
735
- >>> pt.pl.coda.draw_tree(mdata, tree="lineage")
736
- """
737
- if isinstance(data, MuData):
738
- data = data[modality_key]
739
- if isinstance(data, AnnData):
740
- data = data
741
- if isinstance(tree, str):
742
- tree = data.uns[tree]
743
-
744
- def my_layout(node):
745
- text_face = TextFace(node.name, tight_text=tight_text)
746
- faces.add_face_to_node(text_face, node, column=0, position="branch-right")
747
-
748
- tree_style = TreeStyle()
749
- tree_style.show_leaf_name = False
750
- tree_style.layout_fn = my_layout
751
- tree_style.show_scale = show_scale
752
- if file_name is not None:
753
- tree.render(file_name, tree_style=tree_style, units=units, w=w, h=h, dpi=dpi)
754
- if show:
755
- return tree.render("%%inline", tree_style=tree_style, units=units, w=w, h=h, dpi=dpi)
756
- else:
757
- return tree, tree_style
758
-
759
@staticmethod
def draw_effects(  # pragma: no cover
    data: Union[AnnData, MuData],
    covariate: str,
    modality_key: str = "coda",
    tree: Union[Tree, str] = "tree",
    show_legend: Optional[bool] = None,
    show_leaf_effects: Optional[bool] = False,
    tight_text: Optional[bool] = False,
    show_scale: Optional[bool] = False,
    show: Optional[bool] = True,
    file_name: Optional[str] = None,
    units: Optional[Literal["px", "mm", "in"]] = "in",
    h: Optional[float] = None,
    w: Optional[float] = None,
    dpi: Optional[int] = 90,
):
    """Plot a tree with colored circles on the nodes indicating significant effects with bar plots which indicate leaf-level significant effects.

    Node circles are blue for positive and red for negative effects. Their radius is proportional to the
    absolute node effect, and the largest absolute node effect gets radius 10.

    Args:
        data: AnnData object or MuData object.
        covariate: The covariate, whose effects should be plotted.
        modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
        tree: A ete3 tree object or a str to indicate the tree stored in `.uns`. Defaults to "tree".
        show_legend: Whether to show a legend of the node effect sizes. Defaults to None, which means
            True unless `show_leaf_effects` is True.
        show_leaf_effects: If True, plot bar plots which indicate leaf-level significant effects. Defaults to False.
        tight_text: When False, boundaries of the text are approximated according to general font metrics,
            producing slightly worse aligned text faces but improving the performance of tree visualization in scenes with a lot of text faces.
            Defaults to False.
        show_scale: Include the scale legend in the tree image or not. Defaults to False.
        show: If True, plot the tree inline. If False, return tree and tree_style objects. Defaults to True.
        file_name: Path to the output image file. Valid extensions are .SVG, .PDF, .PNG. Output image can be saved whether show is True or not.
            Defaults to None.
        units: Unit of image sizes. "px": pixels, "mm": millimeters, "in": inches. Defaults to "in".
            Not used when `show_leaf_effects` is True.
        h: Height of the image in units. Defaults to None. Not used when `show_leaf_effects` is True.
        w: Width of the image in units. Defaults to None. Not used when `show_leaf_effects` is True.
        dpi: Dots per inches. Defaults to 90. Not used when `show_leaf_effects` is True.

    Returns:
        If `show_leaf_effects` is False: plots the tree inline (`show = True`) or returns
        :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`).
        If `show_leaf_effects` is True: draws a matplotlib figure with the tree image next to a bar plot of
        the leaf effects (saved to `file_name` if given) and returns None.

    Examples:
        Example with tascCODA:

        >>> import pertpy as pt
        >>> adata = pt.dt.smillie()
        >>> tasccoda = pt.tl.Tasccoda()
        >>> mdata = tasccoda.load(
        >>>     adata, type="sample_level",
        >>>     levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
        >>>     key_added="lineage", add_level_name=True
        >>> )
        >>> mdata = tasccoda.prepare(
        >>>     mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
        >>> )
        >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
        >>> pt.pl.coda.draw_effects(mdata, covariate="Health[T.Inflamed]", tree="lineage")
    """
    import tempfile

    if isinstance(data, MuData):
        data = data[modality_key]
    if show_legend is None:
        show_legend = not show_leaf_effects
    elif show_legend and show_leaf_effects:
        # The warning is about aligning the tree image with the leaf-effect bars, so only emit it
        # when those bars are actually drawn.
        print("Tree leaves and leaf effect bars won't be aligned when legend is shown!")

    if isinstance(tree, str):
        tree = data.uns[tree]
    # Collapse tree singularities
    tree2 = collapse_singularities_2(tree)

    # Node-level effects of the requested covariate, re-indexed by node name.
    node_effs = data.uns["scCODA_params"]["node_df"].loc[(covariate + "_node",),].copy()
    node_effs.index = node_effs.index.get_level_values("Node")

    # Leaf-level (cell type) effects of the requested covariate, re-indexed by cell type.
    covariates = data.uns["scCODA_params"]["covariate_names"]
    effect_dfs = [data.varm[f"effect_df_{cov}"] for cov in covariates]
    eff_df = pd.concat(effect_dfs)
    eff_df.index = pd.MultiIndex.from_product(
        (covariates, data.var.index.tolist()),
        names=["Covariate", "Cell Type"],
    )
    leaf_effs = eff_df.loc[(covariate,),].copy()
    leaf_effs.index = leaf_effs.index.get_level_values("Cell Type")

    # Attach effect values to every node (0 when none is recorded) and hide the default node markers.
    for n in tree2.traverse():
        nstyle = NodeStyle()
        nstyle["size"] = 0
        n.set_style(nstyle)
        node_effect = node_effs.loc[n.name, "Final Parameter"] if n.name in node_effs.index else 0
        n.add_feature("node_effect", node_effect)
        leaf_effect = leaf_effs.loc[n.name, "Effect"] if n.name in leaf_effs.index else 0
        n.add_feature("leaf_effect", leaf_effect)

    # Largest absolute effects: used to scale circle radii and the bar-plot x-range.
    eff_max = np.max([np.abs(n.node_effect) for n in tree2.traverse()])
    leaf_eff_max = np.max([np.abs(n.leaf_effect) for n in tree2.traverse()])

    def my_layout(node):
        text_face = TextFace(node.name, tight_text=tight_text)
        text_face.margin_left = 10
        faces.add_face_to_node(text_face, node, column=0, aligned=True)

        # Radius proportional to |effect|; the largest node effect gets radius 10.
        size = (np.abs(node.node_effect) * 10 / eff_max) if node.node_effect != 0 else 0
        if np.sign(node.node_effect) == 1:
            color = "blue"
        elif np.sign(node.node_effect) == -1:
            color = "red"
        else:
            color = "cyan"
        if size != 0:
            faces.add_face_to_node(CircleFace(radius=size, color=color), node, column=0)

    tree_style = TreeStyle()
    tree_style.show_leaf_name = False
    tree_style.layout_fn = my_layout
    tree_style.show_scale = show_scale
    tree_style.draw_guiding_lines = True
    tree_style.legend_position = 1

    if show_legend:
        tree_style.legend.add_face(TextFace("Effects"), column=0)
        tree_style.legend.add_face(TextFace(" "), column=1)
        for i in range(4, 0, -1):
            # A circle for an effect of i/4 * eff_max has radius 10 * i / 4 under my_layout's scaling.
            # Computing it without dividing by eff_max avoids NaN radii when every node effect is 0.
            radius = 10 * i / 4
            tree_style.legend.add_face(CircleFace(radius, "red"), column=0)
            tree_style.legend.add_face(TextFace(f"{-eff_max * i / 4:.2f} "), column=0)
            tree_style.legend.add_face(CircleFace(radius, "blue"), column=1)
            tree_style.legend.add_face(TextFace(f" {eff_max * i / 4:.2f}"), column=1)

    if show_leaf_effects:
        leaf_name = [node.name for node in tree2.traverse("postorder") if node.is_leaf()]
        leaf_effs = leaf_effs.loc[leaf_name].reset_index()
        palette = ["blue" if effect > 0 else "red" for effect in leaf_effs["Effect"].tolist()]

        # Render the tree to a temporary PNG rather than leaving "tree_effect.png" in the working
        # directory. A plain str is passed because ete3's render works with the file name as text
        # (compare the "%%inline" target used below).
        with tempfile.TemporaryDirectory() as tmp_dir:
            img_path = Path(tmp_dir) / "tree_effect.png"
            tree2.render(str(img_path), tree_style=tree_style, units="in")
            img = mpimg.imread(img_path)

        fig, ax = plt.subplots(1, 2, figsize=(10, 10))
        sns.barplot(data=leaf_effs, x="Effect", y="Cell Type", palette=palette, ax=ax[1])
        ax[0].imshow(img)
        ax[0].get_xaxis().set_visible(False)
        ax[0].get_yaxis().set_visible(False)
        ax[0].set_frame_on(False)

        ax[1].get_yaxis().set_visible(False)
        ax[1].spines["left"].set_visible(False)
        ax[1].spines["right"].set_visible(False)
        ax[1].spines["top"].set_visible(False)
        # Symmetric x-range around 0; skipped when all leaf effects are 0 (identical limits are singular).
        if leaf_eff_max > 0:
            ax[1].set_xlim(-leaf_eff_max, leaf_eff_max)
        fig.subplots_adjust(wspace=0)

        if file_name is not None:
            fig.savefig(file_name)
        return None

    if file_name is not None:
        # Honour the documented image size/resolution parameters, as the inline rendering does.
        tree2.render(file_name, tree_style=tree_style, units=units, w=w, h=h, dpi=dpi)
    if show:
        return tree2.render("%%inline", tree_style=tree_style, units=units, w=w, h=h, dpi=dpi)
    return tree2, tree_style
942
-
943
@staticmethod
def effects_umap(  # pragma: no cover
    data: MuData,
    effect_name: Optional[Union[str, list]],
    cluster_key: str,
    modality_key_1: str = "rna",
    modality_key_2: str = "coda",
    show: Optional[bool] = None,
    ax: Optional[Axes] = None,
    **kwargs,
):
    """Plot a UMAP visualization colored by effect strength.

    Effect results in .varm of the aggregated sample-level AnnData (default is data['coda']) are assigned to the
    cell-level AnnData (default is data['rna']) depending on the cluster each cell was assigned to.
    One .obs column per effect, named after the effect, is written to the cell-level AnnData.

    Args:
        data: MuData object holding both the cell-level and the aggregated sample-level modality.
        effect_name: The name (or list of names) of the effect results in .varm of the aggregated sample-level
            AnnData (default is data['coda']) to plot. Each entry must be a table with an "Effect" column indexed
            by cluster name.
        cluster_key: The cluster information in .obs of the cell-level AnnData (default is data['rna']), used to
            assign cell types' effects to the original cells.
        modality_key_1: Key to the cell-level AnnData in the MuData object. Defaults to "rna".
        modality_key_2: Key to the aggregated sample-level AnnData object in the MuData object. Defaults to "coda".
        show: Whether to display the figure or return axis. Defaults to None.
        ax: A matplotlib axes object. Only works if plotting a single component. Defaults to None.
        **kwargs: All other keyword arguments are passed to `scanpy.plot.umap()`. If `vmin`/`vmax` are not given
            (or are None), they default to the minimum/maximum over all plotted effects.

    Returns:
        If `show==False` a :class:`~matplotlib.axes.Axes` or a list of it.

    Raises:
        KeyError: If a cluster in `cluster_key` has no entry in an effect table.

    Examples:
        Example with scCODA:

        >>> import pertpy as pt
        >>> haber_cells = pt.dt.haber_2017_regions()
        >>> sccoda = pt.tl.Sccoda()
        >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label",
        >>>     sample_identifier="batch", covariate_obs=["condition"])
        >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
        >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
        >>> pt.pl.coda.effects_umap(mdata, effect_name="effect_df_condition[T.Salmonella]", cluster_key="cell_label")

        TODO: verify this example end to end (scanpy's umap plot needs a UMAP embedding in mdata["rna"]).
    """
    data_rna = data[modality_key_1]
    data_coda = data[modality_key_2]
    if isinstance(effect_name, str):
        effect_name = [effect_name]
    for effect in effect_name:
        # One lookup table per effect instead of a DataFrame .loc call per cell. Cluster labels are
        # stringified to match the effect table's index; unknown clusters still raise KeyError.
        effect_by_cluster = data_coda.varm[effect]["Effect"].to_dict()
        data_rna.obs[effect] = [effect_by_cluster[f"{c}"] for c in data_rna.obs[cluster_key]]

    # Always pop explicit limits so they are not passed to scanpy twice. A missing or None limit falls
    # back to the range over all plotted effects; falsy values such as 0 are honoured.
    vmin = kwargs.pop("vmin", None)
    if vmin is None:
        vmin = min(data_rna.obs[effect].min() for effect in effect_name)
    vmax = kwargs.pop("vmax", None)
    if vmax is None:
        vmax = max(data_rna.obs[effect].max() for effect in effect_name)

    return sc.pl.umap(data_rna, color=effect_name, vmax=vmax, vmin=vmin, ax=ax, show=show, **kwargs)