pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (66)
  1. pertpy/__init__.py +4 -2
  2. pertpy/data/__init__.py +66 -1
  3. pertpy/data/_dataloader.py +28 -26
  4. pertpy/data/_datasets.py +261 -92
  5. pertpy/metadata/__init__.py +6 -0
  6. pertpy/metadata/_cell_line.py +795 -0
  7. pertpy/metadata/_compound.py +128 -0
  8. pertpy/metadata/_drug.py +238 -0
  9. pertpy/metadata/_look_up.py +569 -0
  10. pertpy/metadata/_metadata.py +70 -0
  11. pertpy/metadata/_moa.py +125 -0
  12. pertpy/plot/__init__.py +0 -13
  13. pertpy/preprocessing/__init__.py +2 -0
  14. pertpy/preprocessing/_guide_rna.py +89 -6
  15. pertpy/tools/__init__.py +48 -15
  16. pertpy/tools/_augur.py +329 -32
  17. pertpy/tools/_cinemaot.py +145 -6
  18. pertpy/tools/_coda/_base_coda.py +1237 -116
  19. pertpy/tools/_coda/_sccoda.py +66 -36
  20. pertpy/tools/_coda/_tasccoda.py +46 -39
  21. pertpy/tools/_dialogue.py +180 -77
  22. pertpy/tools/_differential_gene_expression/__init__.py +20 -0
  23. pertpy/tools/_differential_gene_expression/_base.py +657 -0
  24. pertpy/tools/_differential_gene_expression/_checks.py +41 -0
  25. pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
  26. pertpy/tools/_differential_gene_expression/_edger.py +125 -0
  27. pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
  28. pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
  29. pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
  30. pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
  31. pertpy/tools/_distances/_distance_tests.py +29 -24
  32. pertpy/tools/_distances/_distances.py +584 -98
  33. pertpy/tools/_enrichment.py +460 -0
  34. pertpy/tools/_kernel_pca.py +1 -1
  35. pertpy/tools/_milo.py +406 -49
  36. pertpy/tools/_mixscape.py +677 -55
  37. pertpy/tools/_perturbation_space/_clustering.py +10 -3
  38. pertpy/tools/_perturbation_space/_comparison.py +112 -0
  39. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
  40. pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
  41. pertpy/tools/_perturbation_space/_simple.py +52 -11
  42. pertpy/tools/_scgen/__init__.py +1 -1
  43. pertpy/tools/_scgen/_base_components.py +2 -3
  44. pertpy/tools/_scgen/_scgen.py +706 -0
  45. pertpy/tools/_scgen/_utils.py +3 -5
  46. pertpy/tools/decoupler_LICENSE +674 -0
  47. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
  48. pertpy-0.8.0.dist-info/RECORD +57 -0
  49. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
  50. pertpy/plot/_augur.py +0 -234
  51. pertpy/plot/_cinemaot.py +0 -81
  52. pertpy/plot/_coda.py +0 -1001
  53. pertpy/plot/_dialogue.py +0 -91
  54. pertpy/plot/_guide_rna.py +0 -82
  55. pertpy/plot/_milopy.py +0 -284
  56. pertpy/plot/_mixscape.py +0 -594
  57. pertpy/plot/_scgen.py +0 -337
  58. pertpy/tools/_differential_gene_expression.py +0 -99
  59. pertpy/tools/_metadata/__init__.py +0 -0
  60. pertpy/tools/_metadata/_cell_line.py +0 -613
  61. pertpy/tools/_metadata/_look_up.py +0 -342
  62. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  63. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  64. pertpy-0.6.0.dist-info/RECORD +0 -50
  65. pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  66. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
pertpy/plot/_coda.py DELETED
@@ -1,1001 +0,0 @@
1
- from pathlib import Path
2
- from typing import Literal, Optional, Union
3
-
4
- import matplotlib.image as mpimg
5
- import matplotlib.pyplot as plt
6
- import numpy as np
7
- import pandas as pd
8
- import scanpy as sc
9
- import seaborn as sns
10
- from adjustText import adjust_text
11
- from anndata import AnnData
12
- from ete3 import CircleFace, NodeStyle, TextFace, Tree, TreeStyle, faces
13
- from matplotlib import cm, rcParams
14
- from matplotlib.axes import Axes
15
- from matplotlib.colors import ListedColormap
16
- from mudata import MuData
17
-
18
- from pertpy.tools._coda._base_coda import CompositionalModel2, collapse_singularities_2
19
-
20
- sns.set_style("ticks")
21
-
22
-
23
- class CodaPlot:
@staticmethod
def __stackbar(  # pragma: no cover
    y: np.ndarray,
    type_names: list[str],
    title: str,
    level_names: list[str],
    figsize: Optional[tuple[float, float]] = None,
    dpi: Optional[int] = 100,
    cmap: Optional[ListedColormap] = cm.tab20,
    show_legend: Optional[bool] = True,
) -> plt.Axes:
    """Plots a stacked barplot for one (discrete) covariate.

    Each row of ``y`` becomes one bar. Each column (cell type) becomes one stacked segment
    whose height is that cell type's share of the row total, in percent.

    Typical use (only inside stacked_barplot):
    ``CodaPlot.__stackbar(data.X, data.var.index, "xyz", data.obs.index)``

    Args:
        y: The count data, collapsed onto the level of interest, i.e. a binary covariate has two rows,
            one for each group, containing the count mean of each cell type.
        type_names: The names of all cell types.
        title: Plot title, usually the covariate's name.
        level_names: Names of the covariate's levels.
        figsize: Figure size. Defaults to None (uses ``rcParams["figure.figsize"]``).
        dpi: Dpi setting. Defaults to 100.
        cmap: The color map for the barplot. Defaults to cm.tab20.
        show_legend: If True, adds a legend. Defaults to True.

    Returns:
        A :class:`~matplotlib.axes.Axes` object
    """
    counts = np.asarray(y, dtype=float)
    n_bars, n_types = counts.shape

    figsize = rcParams["figure.figsize"] if figsize is None else figsize
    _, ax = plt.subplots(figsize=figsize, dpi=dpi)

    # Convert every row to percentages. Rows summing to zero get 0% instead of NaN from 0/0.
    row_totals = counts.sum(axis=1, keepdims=True)
    percentages = np.divide(counts, row_totals, out=np.zeros_like(counts), where=row_totals != 0) * 100

    positions = np.arange(n_bars)
    bottoms = np.zeros(n_bars)
    for n in range(n_types):
        # Draw on the axes created above rather than on pyplot's implicit "current" axes.
        ax.bar(
            positions,
            percentages[:, n],
            bottom=bottoms,
            color=cmap(n % cmap.N),
            width=0.85,
            label=type_names[n],
            linewidth=0,
        )
        bottoms += percentages[:, n]

    ax.set_title(title)
    if show_legend:
        ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1)
    ax.set_xticks(positions)
    ax.set_xticklabels(level_names, rotation=45, ha="right")
    ax.set_ylabel("Proportion")

    return ax
@staticmethod
def stacked_barplot(  # pragma: no cover
    data: Union[AnnData, MuData],
    feature_name: str,
    modality_key: str = "coda",
    figsize: Optional[tuple[float, float]] = None,
    dpi: Optional[int] = 100,
    cmap: Optional[ListedColormap] = cm.tab20,
    show_legend: Optional[bool] = True,
    level_order: Optional[list[str]] = None,
) -> plt.Axes:
    """Plots a stacked barplot for all levels of a covariate or all samples (if feature_name=="samples").

    Args:
        data: AnnData object or MuData object.
        feature_name: The name of the covariate to plot. If feature_name=="samples", one bar for every sample will be plotted.
        modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
        figsize: Figure size. Defaults to None.
        dpi: Dpi setting. Defaults to 100.
        cmap: The matplotlib color map for the barplot. Defaults to cm.tab20.
        show_legend: If True, adds a legend. Defaults to True.
        level_order: Custom ordering of bars on the x-axis. Must contain exactly the covariate's levels
            (or the sample names if feature_name=="samples"). Defaults to None.

    Returns:
        A :class:`~matplotlib.axes.Axes` object

    Raises:
        ValueError: If ``level_order`` is inconsistent with the available levels.

    Examples:
        Example with scCODA:
        >>> import pertpy as pt
        >>> haber_cells = pt.dt.haber_2017_regions()
        >>> sccoda = pt.tl.Sccoda()
        >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
            sample_identifier="batch", covariate_obs=["condition"])
        >>> pt.pl.coda.stacked_barplot(mdata, feature_name="samples")
    """
    if isinstance(data, MuData):
        data = data[modality_key]

    # Option to plot one stacked bar per sample.
    if feature_name == "samples":
        if level_order:
            # Raise instead of assert: asserts are stripped when Python runs with -O.
            if set(level_order) != set(data.obs.index):
                raise ValueError("level order is inconsistent with levels")
            data = data[level_order]
        return CodaPlot.__stackbar(
            data.X,
            type_names=data.var.index,
            title="samples",
            level_names=data.obs.index,
            figsize=figsize,
            dpi=dpi,
            cmap=cmap,
            show_legend=show_legend,
        )

    feature = data.obs[feature_name]

    # Order levels: explicit order, then categorical order, then order of first appearance.
    if level_order:
        if set(level_order) != set(feature):
            raise ValueError("level order is inconsistent with levels")
        levels = level_order
    elif hasattr(feature, "cat"):
        levels = feature.cat.categories.to_list()
    else:
        levels = pd.unique(feature)

    # Sum the counts of all samples sharing a level: one row per level, one column per cell type.
    feature_totals = np.zeros([len(levels), data.X.shape[1]])
    for row, level in enumerate(levels):
        feature_totals[row] = np.sum(data.X[(feature == level).to_numpy()], axis=0)

    return CodaPlot.__stackbar(
        feature_totals,
        type_names=data.var.index,
        title=feature_name,
        level_names=levels,
        figsize=figsize,
        dpi=dpi,
        cmap=cmap,
        show_legend=show_legend,
    )
@staticmethod
def effects_barplot(  # pragma: no cover
    data: Union[AnnData, MuData],
    modality_key: str = "coda",
    covariates: Optional[Union[str, list]] = None,
    parameter: Literal["log2-fold change", "Final Parameter", "Expected Sample"] = "log2-fold change",
    plot_facets: bool = True,
    plot_zero_covariate: bool = True,
    plot_zero_cell_type: bool = False,
    figsize: Optional[tuple[float, float]] = None,
    dpi: Optional[int] = 100,
    cmap: Optional[Union[str, ListedColormap]] = cm.tab20,
    level_order: Optional[list[str]] = None,
    args_barplot: Optional[dict] = None,
) -> Optional[Union[plt.Axes, sns.axisgrid.FacetGrid]]:
    """Barplot visualization for effects.

    The effect results for each covariate are shown as a group of barplots, with intra-group separation by cell types.
    The covariate groups can either be ordered along the x-axis of a single plot (plot_facets=False) or as plot facets (plot_facets=True).

    Args:
        data: AnnData object or MuData object.
        modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
        covariates: The name of the covariates in data.obs to plot. A covariate is selected if any of the given
            strings is a substring of its name. Defaults to None (all covariates).
        parameter: The parameter in effect summary to plot. Defaults to "log2-fold change".
        plot_facets: If False, plot cell types on the x-axis. If True, plot as facets. Defaults to True.
        plot_zero_covariate: If True, plot covariates that have all zero effects. If False, do not plot. Defaults to True.
        plot_zero_cell_type: If True, plot cell types that have zero effect. If False, do not plot. Defaults to False.
        figsize: Figure size. Defaults to None.
        dpi: Dpi setting. Only used when plot_facets is False. Defaults to 100.
        cmap: The seaborn color map for the barplot. Defaults to cm.tab20.
        level_order: Custom ordering of bars on the x-axis. Defaults to None.
        args_barplot: Arguments passed to sns.barplot. Defaults to None.

    Returns:
        Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`)
        or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object

    Raises:
        ValueError: If no covariate is left to plot.

    Examples:
        Example with scCODA:
        >>> import pertpy as pt
        >>> haber_cells = pt.dt.haber_2017_regions()
        >>> sccoda = pt.tl.Sccoda()
        >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
            sample_identifier="batch", covariate_obs=["condition"])
        >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
        >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
        >>> pt.pl.coda.effects_barplot(mdata)
    """
    # Name of the placeholder "cell type" that keeps facets for all-zero covariates visible.
    zero_placeholder = "zero"

    if args_barplot is None:
        args_barplot = {}
    if isinstance(data, MuData):
        data = data[modality_key]

    # Get covariate names from the model, optionally restricted by substring match.
    covariate_names = data.uns["scCODA_params"]["covariate_names"]
    if covariates is not None:
        if isinstance(covariates, str):
            covariates = [covariates]
        covariate_names = [name for name in covariate_names if any(cov in name for cov in covariates)]

    # Partition into covariates with a nonzero effect for at least one cell type / for no cell type.
    # A list (not a set difference) keeps the original covariate order deterministic.
    covariate_names_non_zero = [
        name for name in covariate_names if data.varm[f"effect_df_{name}"][parameter].any()
    ]
    covariate_names_zero = [name for name in covariate_names if name not in covariate_names_non_zero]
    if not plot_zero_covariate:
        covariate_names = covariate_names_non_zero
    if not covariate_names:
        raise ValueError("No covariates to plot.")

    # Long-format frame: one row per (cell type, covariate) with the chosen parameter in "value".
    plot_df = pd.concat(
        [data.varm[f"effect_df_{name}"][parameter] for name in covariate_names],
        axis=1,
    )
    plot_df.columns = covariate_names
    plot_df = pd.melt(plot_df, ignore_index=False, var_name="Covariate").reset_index()

    if covariate_names_zero and plot_facets and plot_zero_covariate and not plot_zero_cell_type:
        # Drop zero effects, but add a single placeholder row per all-zero covariate so its facet still exists.
        plot_df = plot_df[plot_df["value"] != 0]
        placeholders = pd.DataFrame(
            {"Covariate": covariate_names_zero, "Cell Type": zero_placeholder, "value": 0}
        )
        plot_df = pd.concat([plot_df, placeholders], ignore_index=True)
        plot_df["covariate_"] = pd.Categorical(plot_df["Covariate"], covariate_names)
        plot_df = plot_df.sort_values(["covariate_"], kind="stable")

    if not plot_zero_cell_type:
        # Remove cell types whose effect is zero for every plotted covariate. The placeholder is
        # exempt: it is all-zero by construction and would otherwise erase the facets it was added for.
        cell_type_names_zero = [
            name
            for name in plot_df["Cell Type"].unique()
            if name != zero_placeholder and (plot_df.loc[plot_df["Cell Type"] == name, "value"] == 0).all()
        ]
        plot_df = plot_df[~plot_df["Cell Type"].isin(cell_type_names_zero)]

    # If plot as facets, create a FacetGrid and map barplot to it.
    if plot_facets:
        if isinstance(cmap, ListedColormap):
            cmap = np.array([cmap(i % cmap.N) for i in range(len(plot_df["Cell Type"].unique()))])
        # NOTE(review): figsize[0] is used as the facet height and figsize[1] / figsize[0] as the facet
        # aspect (width / height), while matplotlib's figsize is (width, height) — confirm this is intended.
        if figsize is not None:
            height = figsize[0]
            aspect = np.round(figsize[1] / figsize[0], 2)
        else:
            height = 3
            aspect = 2

        # Fix the facet order explicitly so each panel's title matches its data.
        present = set(plot_df["Covariate"])
        col_order = [name for name in covariate_names if name in present]

        g = sns.FacetGrid(
            plot_df,
            col="Covariate",
            col_order=col_order,
            sharey=True,
            sharex=False,
            height=height,
            aspect=aspect,
        )
        g.map(
            sns.barplot,
            "Cell Type",
            "value",
            palette=cmap,
            order=level_order,
            **args_barplot,
        )
        g.set_xticklabels(rotation=90)
        g.set(ylabel=parameter)
        for ax, name in zip(g.axes.flatten(), col_order):
            ax.set_title(name)
            tick_labels = ax.get_xticklabels()
            # Squeeze facets with few bars; guard against empty facets (division by zero).
            if 0 < len(tick_labels) < 5:
                ax.set_aspect(10 / len(tick_labels))
            # Hide the tick of a facet that only holds the placeholder bar.
            if len(tick_labels) == 1 and tick_labels[0].get_text() == zero_placeholder:
                ax.set_xticks([])
        return g

    # If not plot as facets, call barplot to plot cell types on the x-axis.
    _, ax = plt.subplots(figsize=figsize, dpi=dpi)
    if len(covariate_names) == 1:
        if isinstance(cmap, ListedColormap):
            cmap = np.array([cmap(i % cmap.N) for i in range(len(plot_df["Cell Type"].unique()))])
        sns.barplot(
            data=plot_df,
            x="Cell Type",
            y="value",
            palette=cmap,
            order=level_order,
            ax=ax,
            **args_barplot,
        )
        ax.set_title(covariate_names[0])
    else:
        if isinstance(cmap, ListedColormap):
            cmap = np.array([cmap(i % cmap.N) for i in range(len(covariate_names))])
        sns.barplot(
            data=plot_df,
            x="Cell Type",
            y="value",
            hue="Covariate",
            palette=cmap,
            order=level_order,
            ax=ax,
            **args_barplot,
        )
    # Rotate the labels seaborn already placed instead of overwriting them, so they stay correct with level_order.
    ax.tick_params(axis="x", labelrotation=90)
    return ax
@staticmethod
def boxplots(  # pragma: no cover
    data: Union[AnnData, MuData],
    feature_name: str,
    modality_key: str = "coda",
    y_scale: Literal["relative", "log", "log10", "count"] = "relative",
    plot_facets: bool = False,
    add_dots: bool = False,
    model: CompositionalModel2 = None,
    cell_types: Optional[list] = None,
    args_boxplot: Optional[dict] = None,
    args_swarmplot: Optional[dict] = None,
    figsize: Optional[tuple[float, float]] = None,
    dpi: Optional[int] = 100,
    cmap: Optional[str] = "Blues",
    show_legend: Optional[bool] = True,
    level_order: Optional[list[str]] = None,
) -> Optional[Union[plt.Axes, sns.axisgrid.FacetGrid]]:
    """Grouped boxplot visualization.

    The cell counts for each cell type are shown as a group of boxplots, with intra-group separation
    by a covariate from data.obs.

    Args:
        data: AnnData object or MuData object
        feature_name: The name of the feature in data.obs to plot
        modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
        y_scale: Transformation of cell counts. Options: "relative" - Relative abundance, "log" - log(count),
            "log10" - log10(count), "count" - absolute abundance (cell counts). The log scales replace zero
            counts with a pseudocount of 0.5. Defaults to "relative".
        plot_facets: If False, plot cell types on the x-axis. If True, plot as facets. Defaults to False.
        add_dots: If True, overlay a scatterplot with one dot for each data point. Defaults to False.
        model: A tasCODA model for drawing effects. Currently unused (effect drawing is disabled).
        cell_types: Subset of cell types that should be plotted. Defaults to None.
        args_boxplot: Arguments passed to sns.boxplot. Not modified. Defaults to {}.
        args_swarmplot: Arguments passed to sns.swarmplot. Not modified. Defaults to {}.
        figsize: Figure size. Defaults to None.
        dpi: Dpi setting. Defaults to 100.
        cmap: The seaborn color map for the barplot. Defaults to "Blues".
        show_legend: If True, adds a legend. Defaults to True.
        level_order: Custom ordering of bars on the x-axis. Defaults to None.

    Returns:
        Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`)
        or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object

    Raises:
        ValueError: If ``y_scale`` is not one of the supported options.

    Examples:
        Example with scCODA:
        >>> import pertpy as pt
        >>> haber_cells = pt.dt.haber_2017_regions()
        >>> sccoda = pt.tl.Sccoda()
        >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
            sample_identifier="batch", covariate_obs=["condition"])
        >>> pt.pl.coda.boxplots(mdata, feature_name="condition", add_dots=True)
    """
    # Work on copies: keys are popped/added below and must not leak back into the caller's dicts.
    args_boxplot = {} if args_boxplot is None else dict(args_boxplot)
    args_swarmplot = {} if args_swarmplot is None else dict(args_swarmplot)
    if isinstance(data, MuData):
        data = data[modality_key]

    # y scale transformations
    if y_scale == "relative":
        X = data.X / np.sum(data.X, axis=1, keepdims=True)
        value_name = "Proportion"
    elif y_scale in ("log", "log10"):
        # Cast to float (which also copies): on integer counts the 0.5 pseudocount would be truncated to 0.
        X = np.array(data.X, dtype=float)
        X[X == 0] = 0.5
        if y_scale == "log":
            X = np.log(X)
            value_name = "log(count)"
        else:
            X = np.log10(X)
            value_name = "log10(count)"
    elif y_scale == "count":
        X = data.X
        value_name = "count"
    else:
        raise ValueError("Invalid y_scale transformation")

    # Long format: one row per (sample, cell type) with the covariate level attached.
    count_df = pd.DataFrame(X, columns=data.var.index, index=data.obs.index).merge(
        data.obs[feature_name], left_index=True, right_index=True
    )
    plot_df = pd.melt(count_df, id_vars=feature_name, var_name="Cell type", value_name=value_name)
    if cell_types is not None:
        plot_df = plot_df[plot_df["Cell type"].isin(cell_types)]

    # Drawing credible effects from `model` is currently disabled because the latest statannotations
    # does not support the latest seaborn, so that dependency was dropped.

    # If plot as facets, create a FacetGrid and map boxplot to it.
    if plot_facets:
        if level_order is None:
            level_order = pd.unique(plot_df[feature_name])

        # Wrap the facets of the plotted cell types into a roughly square grid.
        n_cell_types = plot_df["Cell type"].nunique()
        col_wrap = max(1, int(np.floor(np.sqrt(n_cell_types))))

        if figsize is not None:
            height = figsize[0]
            aspect = np.round(figsize[1] / figsize[0], 2)
        else:
            height = 3
            aspect = 2

        g = sns.FacetGrid(
            plot_df,
            col="Cell type",
            sharey=False,
            col_wrap=col_wrap,
            height=height,
            aspect=aspect,
        )
        g.map(
            sns.boxplot,
            feature_name,
            value_name,
            palette=cmap,
            order=level_order,
            **args_boxplot,
        )

        if add_dots:
            hue = args_swarmplot.pop("hue", None)
            if hue is None:
                g.map(
                    sns.swarmplot,
                    feature_name,
                    value_name,
                    color="black",
                    order=level_order,
                    **args_swarmplot,
                ).set_titles("{col_name}")
            else:
                g.map(
                    sns.swarmplot,
                    feature_name,
                    value_name,
                    hue,
                    order=level_order,
                    **args_swarmplot,
                ).set_titles("{col_name}")
        return g

    # If not plot as facets, call boxplot to plot cell types on the x-axis.
    if level_order:
        args_boxplot["hue_order"] = level_order
        args_swarmplot["hue_order"] = level_order

    _, ax = plt.subplots(figsize=figsize, dpi=dpi)

    ax = sns.boxplot(
        x="Cell type",
        y=value_name,
        hue=feature_name,
        data=plot_df,
        fliersize=1,
        palette=cmap,
        ax=ax,
        **args_boxplot,
    )

    if add_dots:
        sns.swarmplot(
            x="Cell type",
            y=value_name,
            data=plot_df,
            hue=feature_name,
            ax=ax,
            dodge=True,
            color="black",
            **args_swarmplot,
        )

    cell_types = pd.unique(plot_df["Cell type"])
    ax.set_xticklabels(cell_types, rotation=90)

    if show_legend:
        # Boxplot and swarmplot both add legend entries; keep only the first handle per label.
        handles, labels = ax.get_legend_handles_labels()
        unique_entries = {}
        for handle, label in zip(handles, labels):
            unique_entries.setdefault(label, handle)
        ax.legend(
            list(unique_entries.values()),
            list(unique_entries.keys()),
            loc="upper left",
            bbox_to_anchor=(1, 1),
            ncol=1,
            title=feature_name,
        )

    plt.tight_layout()
    return ax
@staticmethod
def rel_abundance_dispersion_plot(  # pragma: no cover
    data: Union[AnnData, MuData],
    modality_key: str = "coda",
    abundant_threshold: Optional[float] = 0.9,
    default_color: Optional[str] = "Grey",
    abundant_color: Optional[str] = "Red",
    label_cell_types: bool = True,
    figsize: Optional[tuple[float, float]] = None,
    dpi: Optional[int] = 100,
    ax: Axes = None,
) -> plt.Axes:
    """Plots total variance of relative abundance versus minimum relative abundance of all cell types for determination of a reference cell type.

    If the count of the cell type is larger than 0 in more than abundant_threshold percent of all samples,
    the cell type will be marked in a different color.

    Args:
        data: AnnData object or MuData object.
        modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
        abundant_threshold: Presence threshold for abundant cell types. Defaults to 0.9.
        default_color: Dot color for all non-abundant cell types. Defaults to "Grey".
        abundant_color: Dot color for cell types with abundant percentage larger than abundant_threshold.
            Defaults to "Red".
        label_cell_types: Label abundant cell types' dots with their names. Defaults to True.
        figsize: Figure size. Only used when ``ax`` is None. Defaults to None.
        dpi: Dpi setting. Only used when ``ax`` is None. Defaults to 100.
        ax: A matplotlib axes object to draw into. Defaults to None (a new figure is created).

    Returns:
        A :class:`~matplotlib.axes.Axes` object

    Examples:
        Example with scCODA:
        >>> import pertpy as pt
        >>> haber_cells = pt.dt.haber_2017_regions()
        >>> sccoda = pt.tl.Sccoda()
        >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
            sample_identifier="batch", covariate_obs=["condition"])
        >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
        >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
        >>> pt.pl.coda.rel_abundance_dispersion_plot(mdata)
    """
    if isinstance(data, MuData):
        data = data[modality_key]
    if ax is None:
        _, ax = plt.subplots(figsize=figsize, dpi=dpi)

    counts = data.X
    rel_abun = counts / np.sum(counts, axis=1, keepdims=True)

    # Fraction of samples in which each cell type has a count of zero; "abundant" = present often enough.
    percent_zero = np.sum(counts == 0, axis=0) / counts.shape[0]
    is_abundant = percent_zero < 1 - abundant_threshold

    # Dispersion (variance / mean) of the relative abundance, per cell type.
    cell_type_disp = np.var(rel_abun, axis=0) / np.mean(rel_abun, axis=0)

    plot_df = pd.DataFrame(
        {
            "Total dispersion": cell_type_disp,
            "Cell type": data.var.index,
            "Presence": 1 - percent_zero,
            "Is abundant": is_abundant,
        }
    )

    # One color per hue level present, in sorted order (False before True).
    present_levels = sorted(set(plot_df["Is abundant"]))
    palette = [abundant_color if level else default_color for level in present_levels]

    ax = sns.scatterplot(
        data=plot_df,
        x="Presence",
        y="Total dispersion",
        hue="Is abundant",
        palette=palette,
        ax=ax,
    )

    # Text labels for abundant cell types, drawn on (and adjusted within) the target axes.
    if label_cell_types:
        abundant_df = plot_df.loc[plot_df["Is abundant"], :]
        texts = [
            ax.text(x, y, str(name))
            for x, y, name in zip(
                abundant_df["Presence"],
                abundant_df["Total dispersion"],
                abundant_df["Cell type"],
            )
        ]
        if texts:
            adjust_text(texts, ax=ax)

    ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1, title="Is abundant")

    plt.tight_layout()
    return ax
@staticmethod
def draw_tree(  # pragma: no cover
    data: Union[AnnData, MuData],
    modality_key: str = "coda",
    tree: Union[Tree, str] = "tree",
    tight_text: Optional[bool] = False,
    show_scale: Optional[bool] = False,
    show: Optional[bool] = True,
    file_name: Optional[str] = None,
    units: Optional[Literal["px", "mm", "in"]] = "px",
    h: Optional[float] = None,
    w: Optional[float] = None,
    dpi: Optional[int] = 90,
):
    """Plot a tree using input ete3 tree object.

    Args:
        data: AnnData object or MuData object.
        modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
        tree: A ete3 tree object or a str to indicate the tree stored in `.uns`. Defaults to "tree".
        tight_text: When False, boundaries of the text are approximated according to general font metrics,
            producing slightly worse aligned text faces but improving the performance of tree visualization in scenes with a lot of text faces.
            Defaults to False.
        show_scale: Include the scale legend in the tree image or not. Defaults to False.
        show: If True, plot the tree inline. If False, return tree and tree_style objects. Defaults to True.
        file_name: Path to the output image file. Valid extensions are .SVG, .PDF, .PNG. Output image can be saved whether show is True or not.
            Defaults to None.
        units: Unit of image sizes. "px": pixels, "mm": millimeters, "in": inches. Defaults to "px".
        h: Height of the image in units. Defaults to None.
        w: Width of the image in units. Defaults to None.
        dpi: Dots per inches. Defaults to 90.

    Returns:
        If `show = True`, the tree is rendered inline and the result of that render call is returned.
        If `show = False`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle`.

    Examples:
        Example with tascCODA:
        >>> import pertpy as pt
        >>> adata = pt.dt.smillie()
        >>> tasccoda = pt.tl.Tasccoda()
        >>> mdata = tasccoda.load(
        ...     adata, type="sample_level",
        ...     levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
        ...     key_added="lineage", add_level_name=True
        ... )
        >>> mdata = tasccoda.prepare(
        ...     mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
        ... )
        >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
        >>> pt.pl.coda.draw_tree(mdata, tree="lineage")
    """
    if isinstance(data, MuData):
        data = data[modality_key]
    if isinstance(tree, str):
        tree = data.uns[tree]

    def my_layout(node):
        # Write every node's name to the right of its branch (leaf names are hidden below to avoid duplicates).
        text_face = TextFace(node.name, tight_text=tight_text)
        faces.add_face_to_node(text_face, node, column=0, position="branch-right")

    tree_style = TreeStyle()
    tree_style.show_leaf_name = False
    tree_style.layout_fn = my_layout
    tree_style.show_scale = show_scale

    # Same rendering options for the file output and the inline output.
    render_kwargs = {"tree_style": tree_style, "units": units, "w": w, "h": h, "dpi": dpi}
    if file_name is not None:
        tree.render(file_name, **render_kwargs)
    if show:
        return tree.render("%%inline", **render_kwargs)
    return tree, tree_style
759
@staticmethod
def draw_effects(  # pragma: no cover
    data: Union[AnnData, MuData],
    covariate: str,
    modality_key: str = "coda",
    tree: Union[Tree, str] = "tree",
    show_legend: Optional[bool] = None,
    show_leaf_effects: Optional[bool] = False,
    tight_text: Optional[bool] = False,
    show_scale: Optional[bool] = False,
    show: Optional[bool] = True,
    file_name: Optional[str] = None,
    units: Optional[Literal["px", "mm", "in"]] = "in",
    h: Optional[float] = None,
    w: Optional[float] = None,
    dpi: Optional[int] = 90,
):
    """Plot a tree with colored circles on the nodes indicating significant effects with bar plots which indicate leaf-level significant effects.

    Args:
        data: AnnData object or MuData object.
        covariate: The covariate, whose effects should be plotted.
        modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
        tree: A ete3 tree object or a str to indicate the tree stored in `.uns`. Defaults to "tree".
        show_legend: If show legend of nodes significant effects or not. Default is False if show_leaf_effects is True.
        show_leaf_effects: If True, plot bar plots which indicate leaf-level significant effects. Defaults to False.
        tight_text: When False, boundaries of the text are approximated according to general font metrics,
                    producing slightly worse aligned text faces but improving the performance of tree visualization in scenes with a lot of text faces.
                    Defaults to False.
        show_scale: Include the scale legend in the tree image or not. Defaults to False.
        show: If True, plot the tree inline. If False, return tree and tree_style objects. Defaults to True.
        file_name: Path to the output image file. Valid extensions are .SVG, .PDF, .PNG. Output image can be saved whether show is True or not.
                   When `show_leaf_effects` is True, the combined matplotlib figure (tree + bar plot) is saved instead.
                   Defaults to None.
        units: Unit of image sizes. “px”: pixels, “mm”: millimeters, “in”: inches. Defaults to "in".
               Only used for the ete3 rendering when `show_leaf_effects` is False.
        h: Height of the image in units. Defaults to None.
        w: Width of the image in units. Defaults to None.
        dpi: Dots per inches. Defaults to 90.

    Returns:
        If `show_leaf_effects` is False: plots the tree inline (`show = True`) or returns
        :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`).
        If `show_leaf_effects` is True: draws a matplotlib figure with the tree and the leaf-effect bar plot and returns None.

    Examples:
        Example with tascCODA:

        >>> import pertpy as pt
        >>> adata = pt.dt.smillie()
        >>> tasccoda = pt.tl.Tasccoda()
        >>> mdata = tasccoda.load(
        >>>     adata, type="sample_level",
        >>>     levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
        >>>     key_added="lineage", add_level_name=True
        >>> )
        >>> mdata = tasccoda.prepare(
        >>>     mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
        >>> )
        >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
        >>> pt.pl.coda.draw_effects(mdata, covariate="Health[T.Inflamed]", tree="lineage")
    """
    if isinstance(data, MuData):
        data = data[modality_key]
    if show_legend is None:
        show_legend = not show_leaf_effects
    elif show_legend:
        print("Tree leaves and leaf effect bars won't be aligned when legend is shown!")

    if isinstance(tree, str):
        tree = data.uns[tree]
    # Collapse tree singularities
    tree2 = collapse_singularities_2(tree)

    node_effs = data.uns["scCODA_params"]["node_df"].loc[(covariate + "_node",),].copy()
    node_effs.index = node_effs.index.get_level_values("Node")

    covariates = data.uns["scCODA_params"]["covariate_names"]
    effect_dfs = [data.varm[f"effect_df_{cov}"] for cov in covariates]
    eff_df = pd.concat(effect_dfs)
    eff_df.index = pd.MultiIndex.from_product(
        (covariates, data.var.index.tolist()),
        names=["Covariate", "Cell Type"],
    )
    leaf_effs = eff_df.loc[(covariate,),].copy()
    leaf_effs.index = leaf_effs.index.get_level_values("Cell Type")

    # Attach node/leaf effect values as node features; nodes without an effect get 0.
    # The default node marker is hidden (size 0) so only the effect circles are drawn.
    for n in tree2.traverse():
        nstyle = NodeStyle()
        nstyle["size"] = 0
        n.set_style(nstyle)
        n.add_feature("node_effect", node_effs.loc[n.name, "Final Parameter"] if n.name in node_effs.index else 0)
        n.add_feature("leaf_effect", leaf_effs.loc[n.name, "Effect"] if n.name in leaf_effs.index else 0)

    # Scale effect values to get nice node sizes: the largest |effect| maps to radius 10.
    eff_max = np.max([np.abs(n.node_effect) for n in tree2.traverse()])
    leaf_eff_max = np.max([np.abs(n.leaf_effect) for n in tree2.traverse()])

    def my_layout(node):
        text_face = TextFace(node.name, tight_text=tight_text)
        text_face.margin_left = 10
        faces.add_face_to_node(text_face, node, column=0, aligned=True)

        size = (np.abs(node.node_effect) * 10 / eff_max) if node.node_effect != 0 else 0
        if np.sign(node.node_effect) == 1:
            color = "blue"
        elif np.sign(node.node_effect) == -1:
            color = "red"
        else:
            color = "cyan"
        if size != 0:
            faces.add_face_to_node(CircleFace(radius=size, color=color), node, column=0)

    tree_style = TreeStyle()
    tree_style.show_leaf_name = False
    tree_style.layout_fn = my_layout
    tree_style.show_scale = show_scale
    tree_style.draw_guiding_lines = True
    tree_style.legend_position = 1

    if show_legend:
        tree_style.legend.add_face(TextFace("Effects"), column=0)
        tree_style.legend.add_face(TextFace(" "), column=1)
        for i in range(4, 0, -1):
            # Legend circles show 4/4 .. 1/4 of the maximal effect; since eff_max maps to radius 10,
            # the radius is 10 * i / 4 (computed directly so an all-zero effect tree does not yield NaN).
            radius = float(f"{10 * i / 4:.2f}")
            tree_style.legend.add_face(CircleFace(radius, "red"), column=0)
            tree_style.legend.add_face(TextFace(f"{-eff_max * i / 4:.2f} "), column=0)
            tree_style.legend.add_face(CircleFace(radius, "blue"), column=1)
            tree_style.legend.add_face(TextFace(f" {eff_max * i / 4:.2f}"), column=1)

    if show_leaf_effects:
        import tempfile

        leaf_name = [node.name for node in tree2.traverse("postorder") if node.is_leaf()]
        leaf_effs = leaf_effs.loc[leaf_name].reset_index()
        palette = ["blue" if effect > 0 else "red" for effect in leaf_effs["Effect"].tolist()]

        # Render the tree into a temporary PNG so it can be embedded next to the bar plot,
        # instead of writing (and leaving behind / overwriting) "tree_effect.png" in the CWD.
        with tempfile.TemporaryDirectory() as tmp_dir:
            tree_img_path = Path(tmp_dir) / "tree_effect.png"
            tree2.render(str(tree_img_path), tree_style=tree_style, units="in")
            img = mpimg.imread(tree_img_path)

        _, ax = plt.subplots(1, 2, figsize=(10, 10))
        sns.barplot(data=leaf_effs, x="Effect", y="Cell Type", palette=palette, ax=ax[1])
        ax[0].imshow(img)
        ax[0].get_xaxis().set_visible(False)
        ax[0].get_yaxis().set_visible(False)
        ax[0].set_frame_on(False)

        ax[1].get_yaxis().set_visible(False)
        ax[1].spines["left"].set_visible(False)
        ax[1].spines["right"].set_visible(False)
        ax[1].spines["top"].set_visible(False)
        ax[1].set_xlim(-leaf_eff_max, leaf_eff_max)
        plt.subplots_adjust(wspace=0)

        if file_name is not None:
            plt.savefig(file_name)
        return None

    if file_name is not None:
        # Honour the requested image size/resolution for the saved file as well as for the inline render.
        tree2.render(file_name, tree_style=tree_style, units=units, w=w, h=h, dpi=dpi)
    if show:
        return tree2.render("%%inline", tree_style=tree_style, units=units, w=w, h=h, dpi=dpi)
    return tree2, tree_style
942
-
943
@staticmethod
def effects_umap(  # pragma: no cover
    data: MuData,
    effect_name: Optional[Union[str, list]],
    cluster_key: str,
    modality_key_1: str = "rna",
    modality_key_2: str = "coda",
    show: bool = None,
    ax: Axes = None,
    **kwargs,
):
    """Plot a UMAP visualization colored by effect strength.

    Effect results in .varm of aggregated sample-level AnnData (default is data['coda']) are assigned to cell-level AnnData
    (default is data['rna']) depending on the cluster they were assigned to.
    One column per effect name is written to ``.obs`` of the cell-level AnnData as a side effect.

    Args:
        data: MuData object holding both the cell-level and the aggregated sample-level AnnData.
        effect_name: The name of the effect results in .varm of aggregated sample-level AnnData (default is data['coda']) to plot.
                     A single name or a list of names.
        cluster_key: The cluster information in .obs of cell-level AnnData (default is data['rna']). To assign cell types' effects to original cells.
        modality_key_1: Key to the cell-level AnnData in the MuData object. Defaults to "rna".
        modality_key_2: Key to the aggregated sample-level AnnData object in the MuData object. Defaults to "coda".
        show: Whether to display the figure or return axis. Defaults to None.
        ax: A matplotlib axes object. Only works if plotting a single component. Defaults to None.
        **kwargs: All other keyword arguments are passed to `scanpy.plot.umap()`.
                  If `vmin`/`vmax` are not given (or None), they default to the minimum/maximum over all plotted effects.

    Returns:
        If `show==False` a :class:`~matplotlib.axes.Axes` or a list of it.

    Examples:
        Example with scCODA:

        >>> import pertpy as pt
        >>> haber_cells = pt.dt.haber_2017_regions()
        >>> sccoda = pt.tl.Sccoda()
        >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label",
        ...     sample_identifier="batch", covariate_obs=["condition"])
        >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
        >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)

        >>> pt.pl.coda.effects_umap(mdata, effect_name="", cluster_key="")
        #TODO: Add effect_name parameter and cluster_key and test the example
    """
    data_rna = data[modality_key_1]
    data_coda = data[modality_key_2]
    if isinstance(effect_name, str):
        effect_name = [effect_name]
    for effect in effect_name:
        data_rna.obs[effect] = [data_coda.varm[effect].loc[f"{c}", "Effect"] for c in data_rna.obs[cluster_key]]

    # Always pop the limits so they are never passed twice to scanpy, and test against None
    # (not truthiness) so a user-supplied limit of 0 is honoured.
    vmin = kwargs.pop("vmin", None)
    if vmin is None:
        vmin = min(data_rna.obs[effect].min() for effect in effect_name)
    vmax = kwargs.pop("vmax", None)
    if vmax is None:
        vmax = max(data_rna.obs[effect].max() for effect in effect_name)

    return sc.pl.umap(data_rna, color=effect_name, vmax=vmax, vmin=vmin, ax=ax, show=show, **kwargs)