pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. pertpy/__init__.py +3 -2
  2. pertpy/data/__init__.py +5 -1
  3. pertpy/data/_dataloader.py +2 -4
  4. pertpy/data/_datasets.py +203 -92
  5. pertpy/metadata/__init__.py +4 -0
  6. pertpy/metadata/_cell_line.py +826 -0
  7. pertpy/metadata/_compound.py +129 -0
  8. pertpy/metadata/_drug.py +242 -0
  9. pertpy/metadata/_look_up.py +582 -0
  10. pertpy/metadata/_metadata.py +73 -0
  11. pertpy/metadata/_moa.py +129 -0
  12. pertpy/plot/__init__.py +1 -9
  13. pertpy/plot/_augur.py +53 -116
  14. pertpy/plot/_coda.py +277 -677
  15. pertpy/plot/_guide_rna.py +17 -35
  16. pertpy/plot/_milopy.py +59 -134
  17. pertpy/plot/_mixscape.py +152 -391
  18. pertpy/preprocessing/_guide_rna.py +88 -4
  19. pertpy/tools/__init__.py +8 -13
  20. pertpy/tools/_augur.py +315 -17
  21. pertpy/tools/_cinemaot.py +143 -4
  22. pertpy/tools/_coda/_base_coda.py +1210 -65
  23. pertpy/tools/_coda/_sccoda.py +50 -21
  24. pertpy/tools/_coda/_tasccoda.py +27 -19
  25. pertpy/tools/_dialogue.py +164 -56
  26. pertpy/tools/_differential_gene_expression.py +240 -14
  27. pertpy/tools/_distances/_distance_tests.py +8 -8
  28. pertpy/tools/_distances/_distances.py +184 -34
  29. pertpy/tools/_enrichment.py +465 -0
  30. pertpy/tools/_milo.py +345 -11
  31. pertpy/tools/_mixscape.py +668 -50
  32. pertpy/tools/_perturbation_space/_clustering.py +5 -1
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
  35. pertpy/tools/_perturbation_space/_simple.py +51 -10
  36. pertpy/tools/_scgen/__init__.py +1 -1
  37. pertpy/tools/_scgen/_scgen.py +701 -0
  38. pertpy/tools/_scgen/_utils.py +1 -3
  39. pertpy/tools/decoupler_LICENSE +674 -0
  40. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
  41. pertpy-0.7.0.dist-info/RECORD +53 -0
  42. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
  43. pertpy/plot/_cinemaot.py +0 -81
  44. pertpy/plot/_dialogue.py +0 -91
  45. pertpy/plot/_scgen.py +0 -337
  46. pertpy/tools/_metadata/__init__.py +0 -0
  47. pertpy/tools/_metadata/_cell_line.py +0 -613
  48. pertpy/tools/_metadata/_look_up.py +0 -342
  49. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  50. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  51. pertpy-0.6.0.dist-info/RECORD +0 -50
  52. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  53. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
pertpy/plot/_guide_rna.py CHANGED
@@ -1,13 +1,11 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import warnings
3
4
  from typing import TYPE_CHECKING
4
5
 
5
- import numpy as np
6
- import scanpy as sc
7
-
8
6
  if TYPE_CHECKING:
7
+ import numpy as np
9
8
  from anndata import AnnData
10
- from matplotlib.axes import Axes
11
9
 
12
10
 
13
11
  class GuideRnaPlot:
@@ -17,8 +15,8 @@ class GuideRnaPlot:
17
15
  layer: str | None = None,
18
16
  order_by: np.ndarray | str | None = None,
19
17
  key_to_save_order: str = None,
20
- **kwds,
21
- ) -> list[Axes]:
18
+ **kwargs,
19
+ ):
22
20
  """Heatmap plotting of guide RNA expression matrix.
23
21
 
24
22
  Assuming guides have sparse expression, this function reorders cells
@@ -36,7 +34,7 @@ class GuideRnaPlot:
36
34
  If a string is provided, adata.obs[order_by] will be used as the order.
37
35
  If a numpy array is provided, the array will be used for ordering.
38
36
  key_to_save_order: The obs key to save cell orders in the current plot. Only saves if not None.
39
- kwds: Are passed to sc.pl.heatmap.
37
+ kwargs: Are passed to sc.pl.heatmap.
40
38
 
41
39
  Returns:
42
40
  List of Axes. Alternatively you can pass save or show parameters as they will be passed to sc.pl.heatmap.
@@ -47,36 +45,20 @@ class GuideRnaPlot:
47
45
  visualized using a heatmap.
48
46
 
49
47
  >>> import pertpy as pt
50
- >>> mdata = pt.data.papalexi_2021()
51
- >>> gdo = mdata.mod['gdo']
48
+ >>> mdata = pt.dt.papalexi_2021()
49
+ >>> gdo = mdata.mod["gdo"]
52
50
  >>> ga = pt.pp.GuideAssignment()
53
51
  >>> ga.assign_by_threshold(gdo, assignment_threshold=5)
54
- >>> pt.pl.guide.heatmap(gdo)
52
+ >>> ga.plot_heatmap(gdo)
55
53
  """
56
- data = adata.X if layer is None else adata.layers[layer]
54
+ warnings.warn(
55
+ "This function is deprecated and will be removed in pertpy 0.8.0!"
56
+ " Please use the corresponding 'pt.tl' object",
57
+ FutureWarning,
58
+ stacklevel=2,
59
+ )
57
60
 
58
- if order_by is None:
59
- max_guide_index = np.where(
60
- np.array(data.max(axis=1)).squeeze() != data.min(), np.array(data.argmax(axis=1)).squeeze(), -1
61
- )
62
- order = np.argsort(max_guide_index)
63
- elif isinstance(order_by, str):
64
- order = adata.obs[order_by]
65
- else:
66
- order = order_by
61
+ from pertpy.preprocessing import GuideAssignment
67
62
 
68
- adata.obs["_tmp_pertpy_grna_plot_dummy_group"] = ""
69
- if key_to_save_order is not None:
70
- adata.obs[key_to_save_order] = order
71
- axis_group = sc.pl.heatmap(
72
- adata[order],
73
- adata.var.index.tolist(),
74
- groupby="_tmp_pertpy_grna_plot_dummy_group",
75
- cmap="viridis",
76
- use_raw=False,
77
- dendrogram=False,
78
- layer=layer,
79
- **kwds,
80
- )
81
- del adata.obs["_tmp_pertpy_grna_plot_dummy_group"]
82
- return axis_group
63
+ ga = GuideAssignment()
64
+ ga.plot_heatmap(adata=adata, layer=layer, order_by=order_by, key_to_save_order=key_to_save_order, kwargs=kwargs)
pertpy/plot/_milopy.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import warnings
3
4
  from typing import TYPE_CHECKING
4
5
 
5
6
  import matplotlib.pyplot as plt
@@ -52,44 +53,29 @@ class MilopyPlot:
52
53
  >>> sc.tl.umap(mdata["rna"])
53
54
  >>> milo.make_nhoods(mdata["rna"])
54
55
  >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident")
55
- >>> milo.da_nhoods(mdata, design="~label")
56
+ >>> milo.da_nhoods(mdata,
57
+ >>> design='~label',
58
+ >>> model_contrasts='labelwithdraw_15d_Cocaine-labelwithdraw_48h_Cocaine')
56
59
  >>> milo.build_nhood_graph(mdata)
57
- >>> pt.pl.milo.nhood_graph(mdata)
58
- # TODO: If necessary adjust after fixing StopIteration error, which is currently thrown
60
+ >>> milo.plot_nhood_graph(mdata)
59
61
  """
60
- nhood_adata = mdata["milo"].T.copy()
61
-
62
- if "Nhood_size" not in nhood_adata.obs.columns:
63
- raise KeyError(
64
- 'Cannot find "Nhood_size" column in adata.uns["nhood_adata"].obs -- \
65
- please run milopy.utils.build_nhood_graph(adata)'
66
- )
67
-
68
- nhood_adata.obs["graph_color"] = nhood_adata.obs["logFC"]
69
- nhood_adata.obs.loc[nhood_adata.obs["SpatialFDR"] > alpha, "graph_color"] = np.nan
70
- nhood_adata.obs["abs_logFC"] = abs(nhood_adata.obs["logFC"])
71
- nhood_adata.obs.loc[nhood_adata.obs["abs_logFC"] < min_logFC, "graph_color"] = np.nan
62
+ warnings.warn(
63
+ "This function is deprecated and will be removed in pertpy 0.8.0!"
64
+ " Please use the corresponding 'pt.tl' object",
65
+ FutureWarning,
66
+ stacklevel=2,
67
+ )
72
68
 
73
- # Plotting order - extreme logFC on top
74
- nhood_adata.obs.loc[nhood_adata.obs["graph_color"].isna(), "abs_logFC"] = np.nan
75
- ordered = nhood_adata.obs.sort_values("abs_logFC", na_position="first").index
76
- nhood_adata = nhood_adata[ordered]
69
+ from pertpy.tools import Milo
77
70
 
78
- vmax = np.max([nhood_adata.obs["graph_color"].max(), abs(nhood_adata.obs["graph_color"].min())])
79
- vmin = -vmax
71
+ milo = Milo()
80
72
 
81
- sc.pl.embedding(
82
- nhood_adata,
83
- "X_milo_graph",
84
- color="graph_color",
85
- cmap="RdBu_r",
86
- size=nhood_adata.obs["Nhood_size"] * min_size,
87
- edges=plot_edges,
88
- neighbors_key="nhood",
89
- sort_order=False,
90
- frameon=False,
91
- vmax=vmax,
92
- vmin=vmin,
73
+ return milo.plot_nhood_graph(
74
+ madata=mdata,
75
+ alpha=alpha,
76
+ min_logFC=min_logFC,
77
+ min_size=min_size,
78
+ plot_edges=plot_edges,
93
79
  title=title,
94
80
  show=show,
95
81
  save=save,
@@ -127,12 +113,19 @@ class MilopyPlot:
127
113
  >>> milo.make_nhoods(mdata["rna"])
128
114
  >>> pt.pl.milo.nhood(mdata, ix=0)
129
115
  """
130
-
131
- mdata[feature_key].obs["Nhood"] = mdata[feature_key].obsm["nhoods"][:, ix].toarray().ravel()
132
- sc.pl.embedding(
133
- mdata[feature_key], basis, color="Nhood", size=30, title="Nhood" + str(ix), show=show, save=save, **kwargs
116
+ warnings.warn(
117
+ "This function is deprecated and will be removed in pertpy 0.8.0!"
118
+ " Please use the corresponding 'pt.tl' object",
119
+ FutureWarning,
120
+ stacklevel=2,
134
121
  )
135
122
 
123
+ from pertpy.tools import Milo
124
+
125
+ milo = Milo()
126
+
127
+ milo.plot_nhood(mdata=mdata, ix=ix, feature_key=feature_key, basis=basis, show=show, save=save, **kwargs)
128
+
136
129
  @staticmethod
137
130
  def da_beeswarm(
138
131
  mdata: MuData,
@@ -146,7 +139,7 @@ class MilopyPlot:
146
139
 
147
140
  Args:
148
141
  mdata: MuData object
149
- anno_col: Column in adata.uns['nhood_adata'].obs to use as annotation. (default: 'nhood_annotation'.)
142
+ anno_col: Column in mdata['milo'].var to use as annotation. (default: 'nhood_annotation'.)
150
143
  alpha: Significance threshold. (default: 0.1)
151
144
  subset_nhoods: List of nhoods to plot. If None, plot all nhoods. (default: None)
152
145
  palette: Name of Seaborn color palette for violinplots.
@@ -162,84 +155,28 @@ class MilopyPlot:
162
155
  >>> milo.make_nhoods(mdata["rna"])
163
156
  >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident")
164
157
  >>> milo.da_nhoods(mdata, design="~label")
165
- >>> milo.annotate_nhoods(mdata, anno_col='cell_type')
166
- >>> pt.pl.milo.da_beeswarm(mdata)
158
+ >>> milo.annotate_nhoods(mdata, anno_col="cell_type")
159
+ >>> milo.plot_da_beeswarm(mdata)
167
160
  """
168
- try:
169
- nhood_adata = mdata["milo"].T.copy()
170
- except KeyError:
171
- raise RuntimeError(
172
- "mdata should be a MuData object with two slots: feature_key and 'milo'. Run 'milopy.count_nhoods(adata)' first."
173
- ) from None
174
-
175
- if subset_nhoods is not None:
176
- nhood_adata = nhood_adata[subset_nhoods]
177
-
178
- try:
179
- nhood_adata.obs[anno_col]
180
- except KeyError:
181
- raise RuntimeError(
182
- f"Unable to find {anno_col} in mdata.uns['nhood_adata']. Run 'milopy.utils.annotate_nhoods(adata, anno_col)' first"
183
- ) from None
184
-
185
- try:
186
- nhood_adata.obs["logFC"]
187
- except KeyError:
188
- raise RuntimeError(
189
- "Unable to find 'logFC' in mdata.uns['nhood_adata'].obs. Run 'core.da_nhoods(adata)' first."
190
- ) from None
191
-
192
- sorted_annos = (
193
- nhood_adata.obs[[anno_col, "logFC"]].groupby(anno_col).median().sort_values("logFC", ascending=True).index
161
+ warnings.warn(
162
+ "This function is deprecated and will be removed in pertpy 0.8.0!"
163
+ " Please use the corresponding 'pt.tl' object",
164
+ FutureWarning,
165
+ stacklevel=2,
194
166
  )
195
167
 
196
- anno_df = nhood_adata.obs[[anno_col, "logFC", "SpatialFDR"]].copy()
197
- anno_df["is_signif"] = anno_df["SpatialFDR"] < alpha
198
- anno_df = anno_df[anno_df[anno_col] != "nan"]
168
+ from pertpy.tools import Milo
199
169
 
200
- try:
201
- obs_col = nhood_adata.uns["annotation_obs"]
202
- if palette is None:
203
- palette = dict(
204
- zip(mdata[feature_key].obs[obs_col].cat.categories, mdata[feature_key].uns[f"{obs_col}_colors"])
205
- )
206
- sns.violinplot(
207
- data=anno_df,
208
- y=anno_col,
209
- x="logFC",
210
- order=sorted_annos,
211
- size=190,
212
- inner=None,
213
- orient="h",
214
- palette=palette,
215
- linewidth=0,
216
- scale="width",
217
- )
218
- except BaseException: # noqa: BLE001
219
- sns.violinplot(
220
- data=anno_df,
221
- y=anno_col,
222
- x="logFC",
223
- order=sorted_annos,
224
- size=190,
225
- inner=None,
226
- orient="h",
227
- linewidth=0,
228
- scale="width",
229
- )
230
- sns.stripplot(
231
- data=anno_df,
232
- y=anno_col,
233
- x="logFC",
234
- order=sorted_annos,
235
- size=2,
236
- hue="is_signif",
237
- palette=["grey", "black"],
238
- orient="h",
239
- alpha=0.5,
170
+ milo = Milo()
171
+
172
+ milo.plot_da_beeswarm(
173
+ mdata=mdata,
174
+ feature_key=feature_key,
175
+ anno_col=anno_col,
176
+ alpha=alpha,
177
+ subset_nhoods=subset_nhoods,
178
+ palette=palette,
240
179
  )
241
- plt.legend(loc="upper left", title=f"< {int(alpha * 100)}% SpatialFDR", bbox_to_anchor=(1, 1), frameon=False)
242
- plt.axvline(x=0, ymin=0, ymax=1, color="black", linestyle="--")
243
180
 
244
181
  @staticmethod
245
182
  def nhood_counts_by_cond(
@@ -256,29 +193,17 @@ class MilopyPlot:
256
193
  subset_nhoods: List of obs_names for neighbourhoods to include in plot. If None, plot all nhoods. (default: None)
257
194
  log_counts: Whether to plot log1p of cell counts. (default: False)
258
195
  """
259
- try:
260
- nhood_adata = mdata["milo"].T.copy()
261
- except KeyError:
262
- raise RuntimeError(
263
- "mdata should be a MuData object with two slots: feature_key and 'milo'. Run milopy.count_nhoods(mdata) first"
264
- ) from None
196
+ warnings.warn(
197
+ "This function is deprecated and will be removed in pertpy 0.8.0!"
198
+ " Please use the corresponding 'pt.tl' object",
199
+ FutureWarning,
200
+ stacklevel=2,
201
+ )
265
202
 
266
- if subset_nhoods is None:
267
- subset_nhoods = nhood_adata.obs_names
203
+ from pertpy.tools import Milo
268
204
 
269
- pl_df = pd.DataFrame(nhood_adata[subset_nhoods].X.A, columns=nhood_adata.var_names).melt(
270
- var_name=nhood_adata.uns["sample_col"], value_name="n_cells"
271
- )
272
- pl_df = pd.merge(pl_df, nhood_adata.var)
273
- pl_df["log_n_cells"] = np.log1p(pl_df["n_cells"])
274
- if not log_counts:
275
- sns.boxplot(data=pl_df, x=test_var, y="n_cells", color="lightblue")
276
- sns.stripplot(data=pl_df, x=test_var, y="n_cells", color="black", s=3)
277
- plt.ylabel("# cells")
278
- else:
279
- sns.boxplot(data=pl_df, x=test_var, y="log_n_cells", color="lightblue")
280
- sns.stripplot(data=pl_df, x=test_var, y="log_n_cells", color="black", s=3)
281
- plt.ylabel("log(# cells + 1)")
205
+ milo = Milo()
282
206
 
283
- plt.xticks(rotation=90)
284
- plt.xlabel(test_var)
207
+ milo.plot_nhood_counts_by_cond(
208
+ mdata=mdata, test_var=test_var, subset_nhoods=subset_nhoods, log_counts=log_counts
209
+ )