pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (53)
  1. pertpy/__init__.py +3 -2
  2. pertpy/data/__init__.py +5 -1
  3. pertpy/data/_dataloader.py +2 -4
  4. pertpy/data/_datasets.py +203 -92
  5. pertpy/metadata/__init__.py +4 -0
  6. pertpy/metadata/_cell_line.py +826 -0
  7. pertpy/metadata/_compound.py +129 -0
  8. pertpy/metadata/_drug.py +242 -0
  9. pertpy/metadata/_look_up.py +582 -0
  10. pertpy/metadata/_metadata.py +73 -0
  11. pertpy/metadata/_moa.py +129 -0
  12. pertpy/plot/__init__.py +1 -9
  13. pertpy/plot/_augur.py +53 -116
  14. pertpy/plot/_coda.py +277 -677
  15. pertpy/plot/_guide_rna.py +17 -35
  16. pertpy/plot/_milopy.py +59 -134
  17. pertpy/plot/_mixscape.py +152 -391
  18. pertpy/preprocessing/_guide_rna.py +88 -4
  19. pertpy/tools/__init__.py +8 -13
  20. pertpy/tools/_augur.py +315 -17
  21. pertpy/tools/_cinemaot.py +143 -4
  22. pertpy/tools/_coda/_base_coda.py +1210 -65
  23. pertpy/tools/_coda/_sccoda.py +50 -21
  24. pertpy/tools/_coda/_tasccoda.py +27 -19
  25. pertpy/tools/_dialogue.py +164 -56
  26. pertpy/tools/_differential_gene_expression.py +240 -14
  27. pertpy/tools/_distances/_distance_tests.py +8 -8
  28. pertpy/tools/_distances/_distances.py +184 -34
  29. pertpy/tools/_enrichment.py +465 -0
  30. pertpy/tools/_milo.py +345 -11
  31. pertpy/tools/_mixscape.py +668 -50
  32. pertpy/tools/_perturbation_space/_clustering.py +5 -1
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
  35. pertpy/tools/_perturbation_space/_simple.py +51 -10
  36. pertpy/tools/_scgen/__init__.py +1 -1
  37. pertpy/tools/_scgen/_scgen.py +701 -0
  38. pertpy/tools/_scgen/_utils.py +1 -3
  39. pertpy/tools/decoupler_LICENSE +674 -0
  40. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
  41. pertpy-0.7.0.dist-info/RECORD +53 -0
  42. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
  43. pertpy/plot/_cinemaot.py +0 -81
  44. pertpy/plot/_dialogue.py +0 -91
  45. pertpy/plot/_scgen.py +0 -337
  46. pertpy/tools/_metadata/__init__.py +0 -0
  47. pertpy/tools/_metadata/_cell_line.py +0 -613
  48. pertpy/tools/_metadata/_look_up.py +0 -342
  49. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  50. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  51. pertpy-0.6.0.dist-info/RECORD +0 -50
  52. pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  53. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
pertpy/plot/_coda.py CHANGED
@@ -1,21 +1,16 @@
1
- from pathlib import Path
1
+ import warnings
2
2
  from typing import Literal, Optional, Union
3
3
 
4
- import matplotlib.image as mpimg
5
4
  import matplotlib.pyplot as plt
6
5
  import numpy as np
7
- import pandas as pd
8
- import scanpy as sc
9
6
  import seaborn as sns
10
- from adjustText import adjust_text
11
7
  from anndata import AnnData
12
- from ete3 import CircleFace, NodeStyle, TextFace, Tree, TreeStyle, faces
13
8
  from matplotlib import cm, rcParams
14
9
  from matplotlib.axes import Axes
15
10
  from matplotlib.colors import ListedColormap
16
11
  from mudata import MuData
17
12
 
18
- from pertpy.tools._coda._base_coda import CompositionalModel2, collapse_singularities_2
13
+ from pertpy.tools._coda._base_coda import CompositionalModel2
19
14
 
20
15
  sns.set_style("ticks")
21
16
 
@@ -27,10 +22,10 @@ class CodaPlot:
27
22
  type_names: list[str],
28
23
  title: str,
29
24
  level_names: list[str],
30
- figsize: Optional[tuple[float, float]] = None,
31
- dpi: Optional[int] = 100,
32
- cmap: Optional[ListedColormap] = cm.tab20,
33
- show_legend: Optional[bool] = True,
25
+ figsize: tuple[float, float] | None = None,
26
+ dpi: int | None = 100,
27
+ cmap: ListedColormap | None = cm.tab20,
28
+ show_legend: bool | None = True,
34
29
  ) -> plt.Axes:
35
30
  """Plots a stacked barplot for one (discrete) covariate.
36
31
 
@@ -62,7 +57,7 @@ class CodaPlot:
62
57
  cum_bars = np.zeros(n_bars)
63
58
 
64
59
  for n in range(n_types):
65
- bars = [i / j * 100 for i, j in zip([y[k][n] for k in range(n_bars)], sample_sums)]
60
+ bars = [i / j * 100 for i, j in zip([y[k][n] for k in range(n_bars)], sample_sums, strict=False)]
66
61
  plt.bar(
67
62
  r,
68
63
  bars,
@@ -85,13 +80,13 @@ class CodaPlot:
85
80
 
86
81
  @staticmethod
87
82
  def stacked_barplot( # pragma: no cover
88
- data: Union[AnnData, MuData],
83
+ data: AnnData | MuData,
89
84
  feature_name: str,
90
85
  modality_key: str = "coda",
91
- figsize: Optional[tuple[float, float]] = None,
92
- dpi: Optional[int] = 100,
93
- cmap: Optional[ListedColormap] = cm.tab20,
94
- show_legend: Optional[bool] = True,
86
+ figsize: tuple[float, float] | None = None,
87
+ dpi: int | None = 100,
88
+ cmap: ListedColormap | None = cm.tab20,
89
+ show_legend: bool | None = True,
95
90
  level_order: list[str] = None,
96
91
  ) -> plt.Axes:
97
92
  """Plots a stacked barplot for all levels of a covariate or all samples (if feature_name=="samples").
@@ -110,79 +105,49 @@ class CodaPlot:
110
105
  A :class:`~matplotlib.axes.Axes` object
111
106
 
112
107
  Examples:
113
- Example with scCODA:
114
108
  >>> import pertpy as pt
115
109
  >>> haber_cells = pt.dt.haber_2017_regions()
116
110
  >>> sccoda = pt.tl.Sccoda()
117
111
  >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
118
112
  sample_identifier="batch", covariate_obs=["condition"])
119
- >>> pt.pl.coda.stacked_barplot(mdata, feature_name="samples")
113
+ >>> sccoda.plot_stacked_barplot(mdata, feature_name="samples")
120
114
  """
121
- if isinstance(data, MuData):
122
- data = data[modality_key]
123
- if isinstance(data, AnnData):
124
- data = data
125
-
126
- ct_names = data.var.index
127
-
128
- # option to plot one stacked barplot per sample
129
- if feature_name == "samples":
130
- if level_order:
131
- assert set(level_order) == set(data.obs.index), "level order is inconsistent with levels"
132
- data = data[level_order]
133
- ax = CodaPlot.__stackbar(
134
- data.X,
135
- type_names=data.var.index,
136
- title="samples",
137
- level_names=data.obs.index,
138
- figsize=figsize,
139
- dpi=dpi,
140
- cmap=cmap,
141
- show_legend=show_legend,
142
- )
143
- else:
144
- # Order levels
145
- if level_order:
146
- assert set(level_order) == set(data.obs[feature_name]), "level order is inconsistent with levels"
147
- levels = level_order
148
- elif hasattr(data.obs[feature_name], "cat"):
149
- levels = data.obs[feature_name].cat.categories.to_list()
150
- else:
151
- levels = pd.unique(data.obs[feature_name])
152
- n_levels = len(levels)
153
- feature_totals = np.zeros([n_levels, data.X.shape[1]])
154
-
155
- for level in range(n_levels):
156
- l_indices = np.where(data.obs[feature_name] == levels[level])
157
- feature_totals[level] = np.sum(data.X[l_indices], axis=0)
158
-
159
- ax = CodaPlot.__stackbar(
160
- feature_totals,
161
- type_names=ct_names,
162
- title=feature_name,
163
- level_names=levels,
164
- figsize=figsize,
165
- dpi=dpi,
166
- cmap=cmap,
167
- show_legend=show_legend,
168
- )
169
- return ax
115
+ warnings.warn(
116
+ "This function is deprecated and will be removed in pertpy 0.8.0!"
117
+ " Please use the corresponding 'pt.tl' object for plotting function directly.",
118
+ FutureWarning,
119
+ stacklevel=2,
120
+ )
121
+
122
+ from pertpy.tools import Sccoda
123
+
124
+ coda = Sccoda()
125
+ return coda.plot_stacked_barplot(
126
+ data=data,
127
+ feature_name=feature_name,
128
+ modality_key=modality_key,
129
+ figsize=figsize,
130
+ dpi=dpi,
131
+ palette=cmap,
132
+ show_legend=show_legend,
133
+ level_order=level_order,
134
+ )
170
135
 
171
136
  @staticmethod
172
137
  def effects_barplot( # pragma: no cover
173
- data: Union[AnnData, MuData],
138
+ data: AnnData | MuData,
174
139
  modality_key: str = "coda",
175
- covariates: Optional[Union[str, list]] = None,
140
+ covariates: str | list | None = None,
176
141
  parameter: Literal["log2-fold change", "Final Parameter", "Expected Sample"] = "log2-fold change",
177
142
  plot_facets: bool = True,
178
143
  plot_zero_covariate: bool = True,
179
144
  plot_zero_cell_type: bool = False,
180
- figsize: Optional[tuple[float, float]] = None,
181
- dpi: Optional[int] = 100,
182
- cmap: Optional[Union[str, ListedColormap]] = cm.tab20,
145
+ figsize: tuple[float, float] | None = None,
146
+ dpi: int | None = 100,
147
+ cmap: str | ListedColormap | None = cm.tab20,
183
148
  level_order: list[str] = None,
184
- args_barplot: Optional[dict] = None,
185
- ) -> Optional[Union[plt.Axes, sns.axisgrid.FacetGrid]]:
149
+ args_barplot: dict | None = None,
150
+ ) -> plt.Axes | sns.axisgrid.FacetGrid | None:
186
151
  """Barplot visualization for effects.
187
152
 
188
153
  The effect results for each covariate are shown as a group of barplots, with intra--group separation by cell types.
@@ -193,9 +158,12 @@ class CodaPlot:
193
158
  modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
194
159
  covariates: The name of the covariates in data.obs to plot. Defaults to None.
195
160
  parameter: The parameter in effect summary to plot. Defaults to "log2-fold change".
196
- plot_facets: If False, plot cell types on the x-axis. If True, plot as facets. Defaults to True.
197
- plot_zero_covariate: If True, plot covariate that have all zero effects. If False, do not plot. Defaults to True.
198
- plot_zero_cell_type: If True, plot cell type that have zero effect. If False, do not plot. Defaults to False.
161
+ plot_facets: If False, plot cell types on the x-axis. If True, plot as facets.
162
+ Defaults to True.
163
+ plot_zero_covariate: If True, plot covariate that have all zero effects. If False, do not plot.
164
+ Defaults to True.
165
+ plot_zero_cell_type: If True, plot cell type that have zero effect. If False, do not plot.
166
+ Defaults to False.
199
167
  figsize: Figure size. Defaults to None.
200
168
  dpi: Figure size. Defaults to 100.
201
169
  cmap: The seaborn color map for the barplot. Defaults to cm.tab20.
@@ -207,7 +175,6 @@ class CodaPlot:
207
175
  or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object
208
176
 
209
177
  Examples:
210
- Example with scCODA:
211
178
  >>> import pertpy as pt
212
179
  >>> haber_cells = pt.dt.haber_2017_regions()
213
180
  >>> sccoda = pt.tl.Sccoda()
@@ -215,152 +182,50 @@ class CodaPlot:
215
182
  sample_identifier="batch", covariate_obs=["condition"])
216
183
  >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
217
184
  >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
218
- >>> pt.pl.coda.effects_barplot(mdata)
185
+ >>> sccoda.plot_effects_barplot(mdata)
219
186
  """
220
- if args_barplot is None:
221
- args_barplot = {}
222
- if isinstance(data, MuData):
223
- data = data[modality_key]
224
- if isinstance(data, AnnData):
225
- data = data
226
- # Get covariate names from adata, partition into those with nonzero effects for min. one cell type/no cell types
227
- covariate_names = data.uns["scCODA_params"]["covariate_names"]
228
- if covariates is not None:
229
- if isinstance(covariates, str):
230
- covariates = [covariates]
231
- partial_covariate_names = [
232
- covariate_name
233
- for covariate_name in covariate_names
234
- if any(covariate in covariate_name for covariate in covariates)
235
- ]
236
- covariate_names = partial_covariate_names
237
- covariate_names_non_zero = [
238
- covariate_name
239
- for covariate_name in covariate_names
240
- if data.varm[f"effect_df_{covariate_name}"][parameter].any()
241
- ]
242
- covariate_names_zero = list(set(covariate_names) - set(covariate_names_non_zero))
243
- if not plot_zero_covariate:
244
- covariate_names = covariate_names_non_zero
245
-
246
- # set up df for plotting
247
- plot_df = pd.concat(
248
- [data.varm[f"effect_df_{covariate_name}"][parameter] for covariate_name in covariate_names],
249
- axis=1,
187
+ warnings.warn(
188
+ "This function is deprecated and will be removed in pertpy 0.8.0!"
189
+ " Please use the corresponding 'pt.tl' object for plotting function directly.",
190
+ FutureWarning,
191
+ stacklevel=2,
250
192
  )
251
- plot_df.columns = covariate_names
252
- plot_df = pd.melt(plot_df, ignore_index=False, var_name="Covariate")
253
-
254
- plot_df = plot_df.reset_index()
255
-
256
- if len(covariate_names_zero) != 0:
257
- if plot_facets:
258
- if plot_zero_covariate and not plot_zero_cell_type:
259
- plot_df = plot_df[plot_df["value"] != 0]
260
- for covariate_name_zero in covariate_names_zero:
261
- new_row = {
262
- "Covariate": covariate_name_zero,
263
- "Cell Type": "zero",
264
- "value": 0,
265
- }
266
- plot_df = pd.concat([plot_df, pd.DataFrame([new_row])], ignore_index=True)
267
- plot_df["covariate_"] = pd.Categorical(plot_df["Covariate"], covariate_names)
268
- plot_df = plot_df.sort_values(["covariate_"])
269
- if not plot_zero_cell_type:
270
- cell_type_names_zero = [
271
- name
272
- for name in plot_df["Cell Type"].unique()
273
- if (plot_df[plot_df["Cell Type"] == name]["value"] == 0).all()
274
- ]
275
- plot_df = plot_df[~plot_df["Cell Type"].isin(cell_type_names_zero)]
276
-
277
- # If plot as facets, create a FacetGrid and map barplot to it.
278
- if plot_facets:
279
- if isinstance(cmap, ListedColormap):
280
- cmap = np.array([cmap(i % cmap.N) for i in range(len(plot_df["Cell Type"].unique()))])
281
- if figsize is not None:
282
- height = figsize[0]
283
- aspect = np.round(figsize[1] / figsize[0], 2)
284
- else:
285
- height = 3
286
- aspect = 2
287
-
288
- g = sns.FacetGrid(
289
- plot_df,
290
- col="Covariate",
291
- sharey=True,
292
- sharex=False,
293
- height=height,
294
- aspect=aspect,
295
- )
296
193
 
297
- g.map(
298
- sns.barplot,
299
- "Cell Type",
300
- "value",
301
- palette=cmap,
302
- order=level_order,
303
- **args_barplot,
304
- )
305
- g.set_xticklabels(rotation=90)
306
- g.set(ylabel=parameter)
307
- axes = g.axes.flatten()
308
- for i, ax in enumerate(axes):
309
- ax.set_title(covariate_names[i])
310
- if len(ax.get_xticklabels()) < 5:
311
- ax.set_aspect(10 / len(ax.get_xticklabels()))
312
- if len(ax.get_xticklabels()) == 1:
313
- if ax.get_xticklabels()[0]._text == "zero":
314
- ax.set_xticks([])
315
- return g
316
-
317
- # If not plot as facets, call barplot to plot cell types on the x-axis.
318
- else:
319
- _, ax = plt.subplots(figsize=figsize, dpi=dpi)
320
- if len(covariate_names) == 1:
321
- if isinstance(cmap, ListedColormap):
322
- cmap = np.array([cmap(i % cmap.N) for i in range(len(plot_df["Cell Type"].unique()))])
323
- sns.barplot(
324
- data=plot_df,
325
- x="Cell Type",
326
- y="value",
327
- palette=cmap,
328
- ax=ax,
329
- )
330
- ax.set_title(covariate_names[0])
331
- else:
332
- if isinstance(cmap, ListedColormap):
333
- cmap = np.array([cmap(i % cmap.N) for i in range(len(covariate_names))])
334
- sns.barplot(
335
- data=plot_df,
336
- x="Cell Type",
337
- y="value",
338
- hue="Covariate",
339
- palette=cmap,
340
- ax=ax,
341
- )
342
- cell_types = pd.unique(plot_df["Cell Type"])
343
- ax.set_xticklabels(cell_types, rotation=90)
344
- return ax
194
+ from pertpy.tools import Sccoda
195
+
196
+ coda = Sccoda()
197
+ return coda.plot_effects_barplot(
198
+ data=data,
199
+ modality_key=modality_key,
200
+ covariates=covariates,
201
+ parameter=parameter,
202
+ plot_facets=plot_facets,
203
+ plot_zero_covariate=plot_zero_covariate,
204
+ plot_zero_cell_type=plot_zero_cell_type,
205
+ figsize=figsize,
206
+ dpi=dpi,
207
+ palette=cmap,
208
+ level_order=level_order,
209
+ args_barplot=args_barplot,
210
+ )
345
211
 
346
212
  @staticmethod
347
213
  def boxplots( # pragma: no cover
348
- data: Union[AnnData, MuData],
214
+ data: AnnData | MuData,
349
215
  feature_name: str,
350
216
  modality_key: str = "coda",
351
217
  y_scale: Literal["relative", "log", "log10", "count"] = "relative",
352
218
  plot_facets: bool = False,
353
219
  add_dots: bool = False,
354
- model: CompositionalModel2 = None,
355
- cell_types: Optional[list] = None,
356
- args_boxplot: Optional[dict] = None,
357
- args_swarmplot: Optional[dict] = None,
358
- figsize: Optional[tuple[float, float]] = None,
359
- dpi: Optional[int] = 100,
360
- cmap: Optional[str] = "Blues",
361
- show_legend: Optional[bool] = True,
220
+ cell_types: list | None = None,
221
+ args_boxplot: dict | None = None,
222
+ args_swarmplot: dict | None = None,
223
+ figsize: tuple[float, float] | None = None,
224
+ dpi: int | None = 100,
225
+ cmap: str | None = "Blues",
226
+ show_legend: bool | None = True,
362
227
  level_order: list[str] = None,
363
- ) -> Optional[Union[plt.Axes, sns.axisgrid.FacetGrid]]:
228
+ ) -> plt.Axes | sns.axisgrid.FacetGrid | None:
364
229
  """Grouped boxplot visualization. The cell counts for each cell type are shown as a group of boxplots,
365
230
  with intra--group separation by a covariate from data.obs.
366
231
 
@@ -388,196 +253,50 @@ class CodaPlot:
388
253
  or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object
389
254
 
390
255
  Examples:
391
- Example with scCODA:
392
256
  >>> import pertpy as pt
393
257
  >>> haber_cells = pt.dt.haber_2017_regions()
394
258
  >>> sccoda = pt.tl.Sccoda()
395
259
  >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
396
260
  sample_identifier="batch", covariate_obs=["condition"])
397
- >>> pt.pl.coda.boxplots(mdata, feature_name="condition", add_dots=True)
261
+ >>> sccoda.plot_boxplots(mdata, feature_name="condition", add_dots=True)
398
262
  """
399
- if args_boxplot is None:
400
- args_boxplot = {}
401
- if args_swarmplot is None:
402
- args_swarmplot = {}
403
- if isinstance(data, MuData):
404
- data = data[modality_key]
405
- if isinstance(data, AnnData):
406
- data = data
407
- # y scale transformations
408
- if y_scale == "relative":
409
- sample_sums = np.sum(data.X, axis=1, keepdims=True)
410
- X = data.X / sample_sums
411
- value_name = "Proportion"
412
- # add pseudocount 0.5 if using log scale
413
- elif y_scale == "log":
414
- X = data.X.copy()
415
- X[X == 0] = 0.5
416
- X = np.log(X)
417
- value_name = "log(count)"
418
- elif y_scale == "log10":
419
- X = data.X.copy()
420
- X[X == 0] = 0.5
421
- X = np.log(X)
422
- value_name = "log10(count)"
423
- elif y_scale == "count":
424
- X = data.X
425
- value_name = "count"
426
- else:
427
- raise ValueError("Invalid y_scale transformation")
428
-
429
- count_df = pd.DataFrame(X, columns=data.var.index, index=data.obs.index).merge(
430
- data.obs[feature_name], left_index=True, right_index=True
263
+ warnings.warn(
264
+ "This function is deprecated and will be removed in pertpy 0.8.0!"
265
+ " Please use the corresponding 'pt.tl' object for plotting function directly.",
266
+ FutureWarning,
267
+ stacklevel=2,
431
268
  )
432
- plot_df = pd.melt(count_df, id_vars=feature_name, var_name="Cell type", value_name=value_name)
433
- if cell_types is not None:
434
- plot_df = plot_df[plot_df["Cell type"].isin(cell_types)]
435
-
436
- # Currently disabled because the latest statsannotations does not support the latest seaborn.
437
- # We had to drop the dependency.
438
- # Get credible effects results from model
439
- # if draw_effects:
440
- # if model is not None:
441
- # credible_effects_df = model.credible_effects(data, modality_key).to_frame().reset_index()
442
- # else:
443
- # print("[bold yellow]Specify a tasCODA model to draw effects")
444
- # credible_effects_df[feature_name] = credible_effects_df["Covariate"].str.removeprefix(f"{feature_name}[T.")
445
- # credible_effects_df[feature_name] = credible_effects_df[feature_name].str.removesuffix("]")
446
- # credible_effects_df = credible_effects_df[credible_effects_df["Final Parameter"]]
447
-
448
- # If plot as facets, create a FacetGrid and map boxplot to it.
449
- if plot_facets:
450
- if level_order is None:
451
- level_order = pd.unique(plot_df[feature_name])
452
-
453
- K = X.shape[1]
454
-
455
- if figsize is not None:
456
- height = figsize[0]
457
- aspect = np.round(figsize[1] / figsize[0], 2)
458
- else:
459
- height = 3
460
- aspect = 2
461
-
462
- g = sns.FacetGrid(
463
- plot_df,
464
- col="Cell type",
465
- sharey=False,
466
- col_wrap=int(np.floor(np.sqrt(K))),
467
- height=height,
468
- aspect=aspect,
469
- )
470
- g.map(
471
- sns.boxplot,
472
- feature_name,
473
- value_name,
474
- palette=cmap,
475
- order=level_order,
476
- **args_boxplot,
477
- )
478
269
 
479
- if add_dots:
480
- if "hue" in args_swarmplot:
481
- hue = args_swarmplot.pop("hue")
482
- else:
483
- hue = None
484
-
485
- if hue is None:
486
- g.map(
487
- sns.swarmplot,
488
- feature_name,
489
- value_name,
490
- color="black",
491
- order=level_order,
492
- **args_swarmplot,
493
- ).set_titles("{col_name}")
494
- else:
495
- g.map(
496
- sns.swarmplot,
497
- feature_name,
498
- value_name,
499
- hue,
500
- order=level_order,
501
- **args_swarmplot,
502
- ).set_titles("{col_name}")
503
- return g
504
-
505
- # If not plot as facets, call boxplot to plot cell types on the x-axis.
506
- else:
507
- if level_order:
508
- args_boxplot["hue_order"] = level_order
509
- args_swarmplot["hue_order"] = level_order
510
-
511
- _, ax = plt.subplots(figsize=figsize, dpi=dpi)
512
-
513
- ax = sns.boxplot(
514
- x="Cell type",
515
- y=value_name,
516
- hue=feature_name,
517
- data=plot_df,
518
- fliersize=1,
519
- palette=cmap,
520
- ax=ax,
521
- **args_boxplot,
522
- )
523
-
524
- # Currently disabled because the latest statsannotations does not support the latest seaborn.
525
- # We had to drop the dependency.
526
- # if draw_effects:
527
- # pairs = [
528
- # [(row["Cell Type"], row[feature_name]), (row["Cell Type"], "Control")]
529
- # for _, row in credible_effects_df.iterrows()
530
- # ]
531
- # annot = Annotator(ax, pairs, data=plot_df, x="Cell type", y=value_name, hue=feature_name)
532
- # annot.configure(test=None, loc="outside", color="red", line_height=0, verbose=False)
533
- # annot.set_custom_annotations([row[feature_name] for _, row in credible_effects_df.iterrows()])
534
- # annot.annotate()
535
-
536
- if add_dots:
537
- sns.swarmplot(
538
- x="Cell type",
539
- y=value_name,
540
- data=plot_df,
541
- hue=feature_name,
542
- ax=ax,
543
- dodge=True,
544
- color="black",
545
- **args_swarmplot,
546
- )
547
-
548
- cell_types = pd.unique(plot_df["Cell type"])
549
- ax.set_xticklabels(cell_types, rotation=90)
550
-
551
- if show_legend:
552
- handles, labels = ax.get_legend_handles_labels()
553
- handout = []
554
- labelout = []
555
- for h, l in zip(handles, labels):
556
- if l not in labelout:
557
- labelout.append(l)
558
- handout.append(h)
559
- ax.legend(
560
- handout,
561
- labelout,
562
- loc="upper left",
563
- bbox_to_anchor=(1, 1),
564
- ncol=1,
565
- title=feature_name,
566
- )
567
-
568
- plt.tight_layout()
569
- return ax
270
+ from pertpy.tools import Sccoda
271
+
272
+ coda = Sccoda()
273
+ return coda.plot_boxplots(
274
+ data=data,
275
+ feature_name=feature_name,
276
+ modality_key=modality_key,
277
+ y_scale=y_scale,
278
+ plot_facets=plot_facets,
279
+ add_dots=add_dots,
280
+ cell_types=cell_types,
281
+ args_boxplot=args_boxplot,
282
+ args_swarmplot=args_swarmplot,
283
+ figsize=figsize,
284
+ dpi=dpi,
285
+ palette=cmap,
286
+ show_legend=show_legend,
287
+ level_order=level_order,
288
+ )
570
289
 
571
290
  @staticmethod
572
291
  def rel_abundance_dispersion_plot( # pragma: no cover
573
- data: Union[AnnData, MuData],
292
+ data: AnnData | MuData,
574
293
  modality_key: str = "coda",
575
- abundant_threshold: Optional[float] = 0.9,
576
- default_color: Optional[str] = "Grey",
577
- abundant_color: Optional[str] = "Red",
294
+ abundant_threshold: float | None = 0.9,
295
+ default_color: str | None = "Grey",
296
+ abundant_color: str | None = "Red",
578
297
  label_cell_types: bool = True,
579
- figsize: Optional[tuple[float, float]] = None,
580
- dpi: Optional[int] = 100,
298
+ figsize: tuple[float, float] | None = None,
299
+ dpi: int | None = 100,
581
300
  ax: Axes = None,
582
301
  ) -> plt.Axes:
583
302
  """Plots total variance of relative abundance versus minimum relative abundance of all cell types for determination of a reference cell type.
@@ -601,7 +320,6 @@ class CodaPlot:
601
320
  A :class:`~matplotlib.axes.Axes` object
602
321
 
603
322
  Examples:
604
- Example with scCODA:
605
323
  >>> import pertpy as pt
606
324
  >>> haber_cells = pt.dt.haber_2017_regions()
607
325
  >>> sccoda = pt.tl.Sccoda()
@@ -609,119 +327,74 @@ class CodaPlot:
609
327
  sample_identifier="batch", covariate_obs=["condition"])
610
328
  >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
611
329
  >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
612
- >>> pt.pl.coda.rel_abundance_dispersion_plot(mdata)
330
+ >>> sccoda.plot_rel_abundance_dispersion_plot(mdata)
613
331
  """
614
- if isinstance(data, MuData):
615
- data = data[modality_key]
616
- if isinstance(data, AnnData):
617
- data = data
618
- if ax is None:
619
- _, ax = plt.subplots(figsize=figsize, dpi=dpi)
620
-
621
- rel_abun = data.X / np.sum(data.X, axis=1, keepdims=True)
622
-
623
- percent_zero = np.sum(data.X == 0, axis=0) / data.X.shape[0]
624
- nonrare_ct = np.where(percent_zero < 1 - abundant_threshold)[0]
625
-
626
- # select reference
627
- cell_type_disp = np.var(rel_abun, axis=0) / np.mean(rel_abun, axis=0)
628
-
629
- is_abundant = [x in nonrare_ct for x in range(data.X.shape[1])]
630
-
631
- # Scatterplot
632
- plot_df = pd.DataFrame(
633
- {
634
- "Total dispersion": cell_type_disp,
635
- "Cell type": data.var.index,
636
- "Presence": 1 - percent_zero,
637
- "Is abundant": is_abundant,
638
- }
332
+ warnings.warn(
333
+ "This function is deprecated and will be removed in pertpy 0.8.0!"
334
+ " Please use the corresponding 'pt.tl' object for plotting function directly.",
335
+ FutureWarning,
336
+ stacklevel=2,
639
337
  )
640
338
 
641
- if len(np.unique(plot_df["Is abundant"])) > 1:
642
- palette = [default_color, abundant_color]
643
- elif np.unique(plot_df["Is abundant"]) == [False]:
644
- palette = [default_color]
645
- else:
646
- palette = [abundant_color]
647
-
648
- ax = sns.scatterplot(
649
- data=plot_df,
650
- x="Presence",
651
- y="Total dispersion",
652
- hue="Is abundant",
653
- palette=palette,
339
+ from pertpy.tools import Sccoda
340
+
341
+ coda = Sccoda()
342
+ return coda.plot_rel_abundance_dispersion_plot(
343
+ data=data,
344
+ modality_key=modality_key,
345
+ abundant_threshold=abundant_threshold,
346
+ default_color=default_color,
347
+ abundant_color=abundant_color,
348
+ label_cell_types=label_cell_types,
349
+ figsize=figsize,
350
+ dpi=dpi,
654
351
  ax=ax,
655
352
  )
656
353
 
657
- # Text labels for abundant cell types
658
-
659
- abundant_df = plot_df.loc[plot_df["Is abundant"], :]
660
-
661
- def label_point(x, y, val, ax):
662
- a = pd.concat({"x": x, "y": y, "val": val}, axis=1)
663
- texts = [
664
- ax.text(
665
- point["x"],
666
- point["y"],
667
- str(point["val"]),
668
- )
669
- for i, point in a.iterrows()
670
- ]
671
- adjust_text(texts)
672
-
673
- if label_cell_types:
674
- label_point(
675
- abundant_df["Presence"],
676
- abundant_df["Total dispersion"],
677
- abundant_df["Cell type"],
678
- plt.gca(),
679
- )
680
-
681
- ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1, title="Is abundant")
682
-
683
- plt.tight_layout()
684
- return ax
685
-
686
354
  @staticmethod
687
355
  def draw_tree( # pragma: no cover
688
- data: Union[AnnData, MuData],
356
+ data: AnnData | MuData,
689
357
  modality_key: str = "coda",
690
- tree: Union[Tree, str] = "tree",
691
- tight_text: Optional[bool] = False,
692
- show_scale: Optional[bool] = False,
693
- show: Optional[bool] = True,
694
- file_name: Optional[str] = None,
695
- units: Optional[Literal["px", "mm", "in"]] = "px",
696
- h: Optional[float] = None,
697
- w: Optional[float] = None,
698
- dpi: Optional[int] = 90,
358
+ tree: str = "tree", # Also type ete3.Tree. Omitted due to import errors
359
+ tight_text: bool | None = False,
360
+ show_scale: bool | None = False,
361
+ show: bool | None = True,
362
+ save: str | None = None,
363
+ units: Literal["px", "mm", "in"] | None = "px",
364
+ figsize: tuple[float, float] | None = (None, None),
365
+ dpi: int | None = 90,
699
366
  ):
700
367
  """Plot a tree using input ete3 tree object.
701
368
 
702
369
  Args:
703
370
  data: AnnData object or MuData object.
704
- modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
705
- tree: A ete3 tree object or a str to indicate the tree stored in `.uns`. Defaults to "tree".
371
+ modality_key: If data is a MuData object, specify which modality to use.
372
+ Defaults to "coda".
373
+ tree: A ete3 tree object or a str to indicate the tree stored in `.uns`.
374
+ Defaults to "tree".
706
375
  tight_text: When False, boundaries of the text are approximated according to general font metrics,
707
- producing slightly worse aligned text faces but improving the performance of tree visualization in scenes with a lot of text faces.
376
+ producing slightly worse aligned text faces but improving
377
+ the performance of tree visualization in scenes with a lot of text faces.
708
378
  Default to False.
709
- show_scale: Include the scale legend in the tree image or not. Default to False.
710
- show: If True, plot the tree inline. If false, return tree and tree_style objects. Defaults to True.
711
- file_name: Path to the output image file. Valid extensions are .SVG, .PDF, .PNG. Output image can be saved whether show is True or not.
379
+ show_scale: Include the scale legend in the tree image or not.
380
+ Defaults to False.
381
+ show: If True, plot the tree inline. If false, return tree and tree_style objects.
382
+ Defaults to True.
383
+ file_name: Path to the output image file. Valid extensions are .SVG, .PDF, .PNG.
384
+ Output image can be saved whether show is True or not.
712
385
  Defaults to None.
713
- units: Unit of image sizes. “px”: pixels, “mm”: millimeters, “in”: inches. Defaults to "px".
386
+ units: Unit of image sizes. “px”: pixels, “mm”: millimeters, “in”: inches.
387
+ Defaults to "px".
714
388
  h: Height of the image in units. Defaults to None.
715
389
  w: Width of the image in units. Defaults to None.
716
390
  dpi: Dots per inches. Defaults to 90.
717
391
 
718
392
  Returns:
719
- Depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`) or plot the tree inline (`show = False`)
393
+ Depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`) or plot the tree inline (`show = False`)
720
394
 
721
395
  Examples:
722
- Example with tascCODA:
723
396
  >>> import pertpy as pt
724
- >>> adata = pt.dt.smillie()
397
+ >>> adata = pt.dt.tasccoda_example()
725
398
  >>> tasccoda = pt.tl.Tasccoda()
726
399
  >>> mdata = tasccoda.load(
727
400
  >>> adata, type="sample_level",
@@ -732,56 +405,62 @@ class CodaPlot:
732
405
  >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
733
406
  >>> )
734
407
  >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
735
- >>> pt.pl.coda.draw_tree(mdata, tree="lineage")
408
+ >>> tasccoda.plot_draw_tree(mdata, tree="lineage")
409
+
410
+ Preview: #TODO: Add preview
736
411
  """
737
- if isinstance(data, MuData):
738
- data = data[modality_key]
739
- if isinstance(data, AnnData):
740
- data = data
741
- if isinstance(tree, str):
742
- tree = data.uns[tree]
743
-
744
- def my_layout(node):
745
- text_face = TextFace(node.name, tight_text=tight_text)
746
- faces.add_face_to_node(text_face, node, column=0, position="branch-right")
747
-
748
- tree_style = TreeStyle()
749
- tree_style.show_leaf_name = False
750
- tree_style.layout_fn = my_layout
751
- tree_style.show_scale = show_scale
752
- if file_name is not None:
753
- tree.render(file_name, tree_style=tree_style, units=units, w=w, h=h, dpi=dpi)
754
- if show:
755
- return tree.render("%%inline", tree_style=tree_style, units=units, w=w, h=h, dpi=dpi)
756
- else:
757
- return tree, tree_style
412
+ warnings.warn(
413
+ "This function is deprecated and will be removed in pertpy 0.8.0!"
414
+ " Please use the corresponding 'pt.tl' object for plotting function directly.",
415
+ FutureWarning,
416
+ stacklevel=2,
417
+ )
418
+
419
+ from pertpy.tools import Tasccoda
420
+
421
+ coda = Tasccoda()
422
+ return coda.plot_draw_tree(
423
+ data=data,
424
+ modality_key=modality_key,
425
+ tree=tree,
426
+ tight_text=tight_text,
427
+ show_scale=show_scale,
428
+ show=show,
429
+ save=save,
430
+ units=units,
431
+ figsize=figsize,
432
+ dpi=dpi,
433
+ )
758
434
 
759
435
  @staticmethod
760
436
  def draw_effects( # pragma: no cover
761
- data: Union[AnnData, MuData],
437
+ data: AnnData | MuData,
762
438
  covariate: str,
763
439
  modality_key: str = "coda",
764
- tree: Union[Tree, str] = "tree",
765
- show_legend: Optional[bool] = None,
766
- show_leaf_effects: Optional[bool] = False,
767
- tight_text: Optional[bool] = False,
768
- show_scale: Optional[bool] = False,
769
- show: Optional[bool] = True,
770
- file_name: Optional[str] = None,
771
- units: Optional[Literal["px", "mm", "in"]] = "in",
772
- h: Optional[float] = None,
773
- w: Optional[float] = None,
774
- dpi: Optional[int] = 90,
440
+ tree: str = "tree", # Also type ete3.Tree. Omitted due to import errors
441
+ show_legend: bool | None = None,
442
+ show_leaf_effects: bool | None = False,
443
+ tight_text: bool | None = False,
444
+ show_scale: bool | None = False,
445
+ show: bool | None = True,
446
+ save: str | None = None,
447
+ units: Literal["px", "mm", "in"] | None = "in",
448
+ figsize: tuple[float, float] | None = (None, None),
449
+ dpi: int | None = 90,
775
450
  ):
776
451
  """Plot a tree with colored circles on the nodes indicating significant effects with bar plots which indicate leaf-level significant effects.
777
452
 
778
453
  Args:
779
454
  data: AnnData object or MuData object.
780
455
  covariate: The covariate, whose effects should be plotted.
781
- modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
782
- tree: A ete3 tree object or a str to indicate the tree stored in `.uns`. Defaults to "tree".
783
- show_legend: If show legend of nodes significant effects or not. Default is False if show_leaf_effects is True.
784
- show_leaf_effects: If True, plot bar plots which indicate leave-level significant effects. Defaults to False.
456
+ modality_key: If data is a MuData object, specify which modality to use.
457
+ Defaults to "coda".
458
+ tree: A ete3 tree object or a str to indicate the tree stored in `.uns`.
459
+ Defaults to "tree".
460
+ show_legend: Whether to show a legend of the nodes' significant effects.
461
+ Defaults to False if show_leaf_effects is True.
462
+ show_leaf_effects: If True, plot bar plots which indicate leaf-level significant effects.
463
+ Defaults to False.
785
464
  tight_text: When False, boundaries of the text are approximated according to general font metrics,
786
465
  producing slightly worse aligned text faces but improving the performance of tree visualization in scenes with a lot of text faces.
787
466
  Defaults to False.
@@ -799,9 +478,8 @@ class CodaPlot:
799
478
  or plots the tree inline (`show = True`)
800
479
 
801
480
  Examples:
802
- Example with tascCODA:
803
481
  >>> import pertpy as pt
804
- >>> adata = pt.dt.smillie()
482
+ >>> adata = pt.dt.tasccoda_example()
805
483
  >>> tasccoda = pt.tl.Tasccoda()
806
484
  >>> mdata = tasccoda.load(
807
485
  >>> adata, type="sample_level",
@@ -812,138 +490,38 @@ class CodaPlot:
812
490
  >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
813
491
  >>> )
814
492
  >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
815
- >>> pt.pl.coda.draw_effects(mdata, covariate="Health[T.Inflamed]", tree="lineage")
493
+ >>> tasccoda.plot_draw_effects(mdata, covariate="Health[T.Inflamed]", tree="lineage")
816
494
  """
817
- if isinstance(data, MuData):
818
- data = data[modality_key]
819
- if isinstance(data, AnnData):
820
- data = data
821
- if show_legend is None:
822
- show_legend = not show_leaf_effects
823
- elif show_legend:
824
- print("Tree leaves and leaf effect bars won't be aligned when legend is shown!")
825
-
826
- if isinstance(tree, str):
827
- tree = data.uns[tree]
828
- # Collapse tree singularities
829
- tree2 = collapse_singularities_2(tree)
830
-
831
- node_effs = data.uns["scCODA_params"]["node_df"].loc[(covariate + "_node",),].copy()
832
- node_effs.index = node_effs.index.get_level_values("Node")
833
-
834
- covariates = data.uns["scCODA_params"]["covariate_names"]
835
- effect_dfs = [data.varm[f"effect_df_{cov}"] for cov in covariates]
836
- eff_df = pd.concat(effect_dfs)
837
- eff_df.index = pd.MultiIndex.from_product(
838
- (covariates, data.var.index.tolist()),
839
- names=["Covariate", "Cell Type"],
495
+ warnings.warn(
496
+ "This function is deprecated and will be removed in pertpy 0.8.0!"
497
+ " Please use the corresponding 'pt.tl' object",
498
+ FutureWarning,
499
+ stacklevel=2,
840
500
  )
841
- leaf_effs = eff_df.loc[(covariate,),].copy()
842
- leaf_effs.index = leaf_effs.index.get_level_values("Cell Type")
843
-
844
- # Add effect values
845
- for n in tree2.traverse():
846
- nstyle = NodeStyle()
847
- nstyle["size"] = 0
848
- n.set_style(nstyle)
849
- if n.name in node_effs.index:
850
- e = node_effs.loc[n.name, "Final Parameter"]
851
- n.add_feature("node_effect", e)
852
- else:
853
- n.add_feature("node_effect", 0)
854
- if n.name in leaf_effs.index:
855
- e = leaf_effs.loc[n.name, "Effect"]
856
- n.add_feature("leaf_effect", e)
857
- else:
858
- n.add_feature("leaf_effect", 0)
859
-
860
- # Scale effect values to get nice node sizes
861
- eff_max = np.max([np.abs(n.node_effect) for n in tree2.traverse()])
862
- leaf_eff_max = np.max([np.abs(n.leaf_effect) for n in tree2.traverse()])
863
-
864
- def my_layout(node):
865
- text_face = TextFace(node.name, tight_text=tight_text)
866
- text_face.margin_left = 10
867
- faces.add_face_to_node(text_face, node, column=0, aligned=True)
868
-
869
- # if node.is_leaf():
870
- size = (np.abs(node.node_effect) * 10 / eff_max) if node.node_effect != 0 else 0
871
- if np.sign(node.node_effect) == 1:
872
- color = "blue"
873
- elif np.sign(node.node_effect) == -1:
874
- color = "red"
875
- else:
876
- color = "cyan"
877
- if size != 0:
878
- faces.add_face_to_node(CircleFace(radius=size, color=color), node, column=0)
879
-
880
- tree_style = TreeStyle()
881
- tree_style.show_leaf_name = False
882
- tree_style.layout_fn = my_layout
883
- tree_style.show_scale = show_scale
884
- tree_style.draw_guiding_lines = True
885
- tree_style.legend_position = 1
886
501
 
887
- if show_legend:
888
- tree_style.legend.add_face(TextFace("Effects"), column=0)
889
- tree_style.legend.add_face(TextFace(" "), column=1)
890
- for i in range(4, 0, -1):
891
- tree_style.legend.add_face(
892
- CircleFace(
893
- float(f"{np.abs(eff_max) * 10 * i / (eff_max * 4):.2f}"),
894
- "red",
895
- ),
896
- column=0,
897
- )
898
- tree_style.legend.add_face(TextFace(f"{-eff_max * i / 4:.2f} "), column=0)
899
- tree_style.legend.add_face(
900
- CircleFace(
901
- float(f"{np.abs(eff_max) * 10 * i / (eff_max * 4):.2f}"),
902
- "blue",
903
- ),
904
- column=1,
905
- )
906
- tree_style.legend.add_face(TextFace(f" {eff_max * i / 4:.2f}"), column=1)
907
-
908
- if show_leaf_effects:
909
- leaf_name = [node.name for node in tree2.traverse("postorder") if node.is_leaf()]
910
- leaf_effs = leaf_effs.loc[leaf_name].reset_index()
911
- palette = ["blue" if Effect > 0 else "red" for Effect in leaf_effs["Effect"].tolist()]
912
-
913
- dir_path = Path.cwd()
914
- dir_path = Path(dir_path / "tree_effect.png")
915
- tree2.render(dir_path, tree_style=tree_style, units="in")
916
- _, ax = plt.subplots(1, 2, figsize=(10, 10))
917
- sns.barplot(data=leaf_effs, x="Effect", y="Cell Type", palette=palette, ax=ax[1])
918
- img = mpimg.imread(dir_path)
919
- ax[0].imshow(img)
920
- ax[0].get_xaxis().set_visible(False)
921
- ax[0].get_yaxis().set_visible(False)
922
- ax[0].set_frame_on(False)
923
-
924
- ax[1].get_yaxis().set_visible(False)
925
- ax[1].spines["left"].set_visible(False)
926
- ax[1].spines["right"].set_visible(False)
927
- ax[1].spines["top"].set_visible(False)
928
- plt.xlim(-leaf_eff_max, leaf_eff_max)
929
- plt.subplots_adjust(wspace=0)
930
-
931
- if file_name is not None:
932
- plt.savefig(file_name)
933
-
934
- if file_name is not None and not show_leaf_effects:
935
- tree2.render(file_name, tree_style=tree_style, units=units)
936
- if show:
937
- if not show_leaf_effects:
938
- return tree2.render("%%inline", tree_style=tree_style, units=units, w=w, h=h, dpi=dpi)
939
- else:
940
- if not show_leaf_effects:
941
- return tree2, tree_style
502
+ from pertpy.tools import Tasccoda
503
+
504
+ coda = Tasccoda()
505
+ return coda.plot_draw_effects(
506
+ data=data,
507
+ modality_key=modality_key,
508
+ covariate=covariate,
509
+ tree=tree,
510
+ show_legend=show_legend,
511
+ show_leaf_effects=show_leaf_effects,
512
+ tight_text=tight_text,
513
+ show_scale=show_scale,
514
+ show=show,
515
+ save=save,
516
+ units=units,
517
+ figsize=figsize,
518
+ dpi=dpi,
519
+ )
942
520
 
943
521
  @staticmethod
944
522
  def effects_umap( # pragma: no cover
945
523
  data: MuData,
946
- effect_name: Optional[Union[str, list]],
524
+ effect_name: str | list | None,
947
525
  cluster_key: str,
948
526
  modality_key_1: str = "rna",
949
527
  modality_key_2: str = "coda",
@@ -953,49 +531,71 @@ class CodaPlot:
953
531
  ):
954
532
  """Plot a UMAP visualization colored by effect strength.
955
533
 
956
- Effect results in .varm of aggregated sample-level AnnData (default is data['coda']) are assigned to cell-level AnnData (default is data['rna']) depending on the cluster they were assigned to.
534
+ Effect results in .varm of aggregated sample-level AnnData (default is data['coda']) are assigned to cell-level AnnData
535
+ (default is data['rna']) depending on the cluster they were assigned to.
957
536
 
958
537
  Args:
959
538
  data: AnnData object or MuData object.
960
- effect_name: The name of the effect results in .varm of aggregated sample-level AnnData (default is data['coda']) to plot
961
- cluster_key: The cluster information in .obs of cell-level AnnData (default is data['rna']). To assign cell types' effects to original cells.
539
+ effect_name: The name of the effect results in .varm of aggregated sample-level AnnData to plot
540
+ cluster_key: The cluster information in .obs of cell-level AnnData (default is data['rna']).
541
+ Used to assign cell types' effects to the original cells.
962
542
  modality_key_1: Key to the cell-level AnnData in the MuData object. Defaults to "rna".
963
- modality_key_2: Key to the aggregated sample-level AnnData object in the MuData object. Defaults to "coda".
543
+ modality_key_2: Key to the aggregated sample-level AnnData object in the MuData object.
544
+ Defaults to "coda".
964
545
  show: Whether to display the figure or return axis. Defaults to None.
965
- ax: A matplotlib axes object. Only works if plotting a single component. Defaults to None.
546
+ ax: A matplotlib axes object. Only works if plotting a single component.
547
+ Defaults to None.
966
548
  **kwargs: All other keyword arguments are passed to `scanpy.plot.umap()`
967
549
 
968
550
  Returns:
969
551
  If `show==False` a :class:`~matplotlib.axes.Axes` or a list of it.
970
552
 
971
553
  Examples:
972
- Example with scCODA:
973
554
  >>> import pertpy as pt
974
- >>> haber_cells = pt.dt.haber_2017_regions()
975
- >>> sccoda = pt.tl.Sccoda()
976
- >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
977
- sample_identifier="batch", covariate_obs=["condition"])
978
- >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
979
- >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
980
-
981
- >>> pt.pl.coda.effects_umap(mdata, effect_name="", cluster_key="")
982
- #TODO: Add effect_name parameter and cluster_key and test the example
555
+ >>> import schist
556
+ >>> adata = pt.dt.haber_2017_regions()
557
+ >>> schist.inference.nested_model(adata, samples=100, random_seed=5678)
558
+ >>> tasccoda_model = pt.tl.Tasccoda()
559
+ >>> tasccoda_data = tasccoda_model.load(adata, type="cell_level",
560
+ >>> cell_type_identifier="nsbm_level_1",
561
+ >>> sample_identifier="batch", covariate_obs=["condition"],
562
+ >>> levels_orig=["nsbm_level_4", "nsbm_level_3", "nsbm_level_2", "nsbm_level_1"],
563
+ >>> add_level_name=True)
564
+ >>> tasccoda_model.prepare(
565
+ >>> tasccoda_data,
566
+ >>> modality_key="coda",
567
+ >>> reference_cell_type="18",
568
+ >>> formula="condition",
569
+ >>> pen_args={"phi": 0, "lambda_1": 3.5},
570
+ >>> tree_key="tree"
571
+ >>> )
572
+ >>> tasccoda_model.run_nuts(
573
+ ... tasccoda_data, modality_key="coda", rng_key=1234, num_samples=10000, num_warmup=1000
574
+ ... )
575
+ >>> tasccoda_model.plot_effects_umap(tasccoda_data,
576
+ >>> effect_name=["effect_df_condition[T.Salmonella]",
577
+ >>> "effect_df_condition[T.Hpoly.Day3]",
578
+ >>> "effect_df_condition[T.Hpoly.Day10]"],
579
+ >>> cluster_key="nsbm_level_1",
580
+ >>> )
983
581
  """
984
- data_rna = data[modality_key_1]
985
- data_coda = data[modality_key_2]
986
- if isinstance(effect_name, str):
987
- effect_name = [effect_name]
988
- for _, effect in enumerate(effect_name):
989
- data_rna.obs[effect] = [data_coda.varm[effect].loc[f"{c}", "Effect"] for c in data_rna.obs[cluster_key]]
990
- if kwargs.get("vmin"):
991
- vmin = kwargs["vmin"]
992
- kwargs.pop("vmin")
993
- else:
994
- vmin = min(data_rna.obs[effect].min() for _, effect in enumerate(effect_name))
995
- if kwargs.get("vmax"):
996
- vmax = kwargs["vmax"]
997
- kwargs.pop("vmax")
998
- else:
999
- vmax = max(data_rna.obs[effect].max() for _, effect in enumerate(effect_name))
1000
-
1001
- return sc.pl.umap(data_rna, color=effect_name, vmax=vmax, vmin=vmin, ax=ax, show=show, **kwargs)
582
+ warnings.warn(
583
+ "This function is deprecated and will be removed in pertpy 0.8.0!"
584
+ " Please use the corresponding 'pt.tl' object for plotting function directly.",
585
+ FutureWarning,
586
+ stacklevel=2,
587
+ )
588
+
589
+ from pertpy.tools import Tasccoda
590
+
591
+ coda = Tasccoda()
592
+ coda.plot_effects_umap(
593
+ data=data,
594
+ effect_name=effect_name,
595
+ cluster_key=cluster_key,
596
+ modality_key_1=modality_key_1,
597
+ modality_key_2=modality_key_2,
598
+ show=show,
599
+ ax=ax,
600
+ **kwargs,
601
+ )