pertpy 0.11.3__py3-none-any.whl → 0.11.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pertpy/__init__.py CHANGED
@@ -2,7 +2,7 @@
 
  __author__ = "Lukas Heumos"
  __email__ = "lukas.heumos@posteo.net"
- __version__ = "0.11.3"
+ __version__ = "0.11.5"
 
  import warnings
 
@@ -16,7 +16,7 @@ from jax import config, random
  from lamin_utils import logger
  from matplotlib import cm, rcParams
  from matplotlib import image as mpimg
- from matplotlib.colors import ListedColormap
+ from matplotlib.colors import Colormap
  from mudata import MuData
  from numpyro.infer import HMC, MCMC, NUTS, initialization
  from rich import box, print
@@ -34,7 +34,6 @@ if TYPE_CHECKING:
  from ete4 import Tree
  from jax._src.typing import Array
  from matplotlib.axes import Axes
- from matplotlib.colors import Colormap
  from matplotlib.figure import Figure
 
  config.update("jax_enable_x64", True)
pertpy/tools/_coda/_base_coda.py CHANGED
@@ -1141,7 +1140,7 @@ class CompositionalModel2(ABC):
  level_names: list[str],
  figsize: tuple[float, float] | None = None,
  dpi: int | None = 100,
- palette: ListedColormap | None = cm.tab20,
+ palette: str | Colormap | None = cm.tab20,
  show_legend: bool | None = True,
  ) -> plt.Axes:
  """Plots a stacked barplot for one (discrete) covariate.
@@ -1156,12 +1155,15 @@ class CompositionalModel2(ABC):
  level_names: Names of the covariate's levels
  figsize: Figure size (matplotlib).
  dpi: Resolution in DPI (matplotlib).
- palette: The color map for the barplot.
+ palette: The color map (name) for the barplot.
  show_legend: If True, adds a legend.
 
  Returns:
  A :class:`~matplotlib.axes.Axes` object
  """
+ if isinstance(palette, str):
+ palette = getattr(cm, palette)
+
  n_bars, n_types = y.shape
 
  figsize = rcParams["figure.figsize"] if figsize is None else figsize
@@ -1202,7 +1204,7 @@ class CompositionalModel2(ABC):
  feature_name: str,
  *,
  modality_key: str = "coda",
- palette: ListedColormap | None = cm.tab20,
+ palette: str | Colormap | None = cm.tab20,
  show_legend: bool | None = True,
  level_order: list[str] = None,
  figsize: tuple[float, float] | None = None,
@@ -1217,7 +1219,7 @@ class CompositionalModel2(ABC):
  modality_key: If data is a MuData object, specify which modality to use.
  figsize: Figure size.
  dpi: Dpi setting.
- palette: The matplotlib color map for the barplot.
+ palette: The matplotlib color map (name) for the barplot.
  show_legend: If True, adds a legend.
  level_order: Custom ordering of bars on the x-axis.
  {common_plot_args}
@@ -1299,7 +1301,7 @@ class CompositionalModel2(ABC):
  plot_facets: bool = True,
  plot_zero_covariate: bool = True,
  plot_zero_cell_type: bool = False,
- palette: str | ListedColormap | None = cm.tab20,
+ palette: str | Colormap | None = cm.tab20,
  level_order: list[str] = None,
  args_barplot: dict | None = None,
  figsize: tuple[float, float] | None = None,
@@ -1321,7 +1323,7 @@ class CompositionalModel2(ABC):
  plot_zero_cell_type: If True, plot cell type that have zero effect. If False, do not plot.
  figsize: Figure size.
  dpi: Figure size.
- palette: The seaborn color map for the barplot.
+ palette: The seaborn color map (name) for the barplot.
  level_order: Custom ordering of bars on the x-axis.
  args_barplot: Arguments passed to sns.barplot.
  {common_plot_args}
@@ -1397,7 +1399,7 @@ class CompositionalModel2(ABC):
 
  # If plot as facets, create a FacetGrid and map barplot to it.
  if plot_facets:
- if isinstance(palette, ListedColormap):
+ if isinstance(palette, Colormap):
  palette = np.array([palette(i % palette.N) for i in range(len(plot_df["Cell Type"].unique()))]).tolist()
  if figsize is not None:
  height = figsize[0]
@@ -1437,7 +1439,7 @@ class CompositionalModel2(ABC):
  else:
  _, ax = plt.subplots(figsize=figsize, dpi=dpi)
  if len(covariate_names) == 1:
- if isinstance(palette, ListedColormap):
+ if isinstance(palette, Colormap):
  palette = np.array(
  [palette(i % palette.N) for i in range(len(plot_df["Cell Type"].unique()))]
  ).tolist()
@@ -1451,7 +1453,7 @@ class CompositionalModel2(ABC):
  )
  ax.set_title(covariate_names[0])
  else:
- if isinstance(palette, ListedColormap):
+ if isinstance(palette, Colormap):
  palette = np.array([palette(i % palette.N) for i in range(len(covariate_names))]).tolist()
  sns.barplot(
  data=plot_df,
@@ -1485,7 +1487,7 @@ class CompositionalModel2(ABC):
  cell_types: list | None = None,
  args_boxplot: dict | None = None,
  args_swarmplot: dict | None = None,
- palette: str | None = "Blues",
+ palette: str | Colormap | None = "Blues",
  show_legend: bool | None = True,
  level_order: list[str] = None,
  figsize: tuple[float, float] | None = None,
@@ -1510,7 +1512,7 @@ class CompositionalModel2(ABC):
  args_swarmplot: Arguments passed to sns.swarmplot.
  figsize: Figure size.
  dpi: Dpi setting.
- palette: The seaborn color map for the barplot.
+ palette: The seaborn color map (name) for the barplot.
  show_legend: If True, adds a legend.
  level_order: Custom ordering of bars on the x-axis.
  {common_plot_args}
@@ -1535,6 +1537,8 @@ class CompositionalModel2(ABC):
  args_swarmplot = {}
  if isinstance(data, MuData):
  data = data[modality_key]
+ if isinstance(palette, Colormap):
+ palette = list(palette(range(len(data.obs[feature_name].unique()))))
 
  # y scale transformations
  if y_scale == "relative":
@@ -2104,7 +2108,7 @@ class CompositionalModel2(ABC):
  modality_key_1: str = "rna",
  modality_key_2: str = "coda",
  color_map: Colormap | str | None = None,
- palette: str | Sequence[str] | None = None,
+ palette: str | Sequence[str] | Colormap | None = None,
  ax: Axes = None,
  return_fig: bool = False,
  **kwargs,
@@ -2122,7 +2126,7 @@ class CompositionalModel2(ABC):
  modality_key_1: Key to the cell-level AnnData in the MuData object.
  modality_key_2: Key to the aggregated sample-level AnnData object in the MuData object.
  color_map: The color map to use for plotting.
- palette: The color palette to use for plotting.
+ palette: The color palette (name) to use for plotting.
  ax: A matplotlib axes object. Only works if plotting a single component.
  {common_plot_args}
  **kwargs: All other keyword arguments are passed to `scanpy.plot.umap()`
@@ -2154,9 +2158,6 @@ class CompositionalModel2(ABC):
  >>> tasccoda_model.run_nuts(
  ... tasccoda_data, modality_key="coda", rng_key=1234, num_samples=10000, num_warmup=1000
  ... )
- >>> tasccoda_model.run_nuts(
- ... tasccoda_data, modality_key="coda", rng_key=1234, num_samples=10000, num_warmup=1000
- ... )
  >>> sc.tl.umap(tasccoda_data["rna"])
  >>> tasccoda_model.plot_effects_umap(tasccoda_data,
  >>> effect_name=["effect_df_condition[T.Salmonella]",
@@ -2173,6 +2174,10 @@ class CompositionalModel2(ABC):
  data_coda = mdata[modality_key_2]
  if isinstance(effect_name, str):
  effect_name = [effect_name]
+ if isinstance(palette, Colormap):
+ palette = {
+ cluster: palette(i % palette.N) for i, cluster in enumerate(data_rna.obs[cluster_key].unique().tolist())
+ }
  for _, effect in enumerate(effect_name):
  data_rna.obs[effect] = [data_coda.varm[effect].loc[f"{c}", "Effect"] for c in data_rna.obs[cluster_key]]
  if kwargs.get("vmin"):
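
The repeated `palette` changes in `_base_coda.py` widen the accepted type from `ListedColormap` to `str | Colormap` and resolve string names up front via `getattr(cm, palette)`. A minimal sketch of that normalization pattern, kept outside pertpy; the helper name `resolve_palette` is hypothetical and not part of the package:

```python
from matplotlib import cm
from matplotlib.colors import Colormap


def resolve_palette(palette: str | Colormap, n_colors: int) -> list[tuple]:
    """Return n_colors RGBA tuples from a colormap name or Colormap object."""
    if isinstance(palette, str):
        # Mirrors the diff: colormap names are looked up as attributes of matplotlib.cm
        palette = getattr(cm, palette)
    return [palette(i % palette.N) for i in range(n_colors)]


colors = resolve_palette("tab20", 25)  # name is resolved; colors cycle after 20 entries
colors = resolve_palette(cm.tab20, 5)  # a Colormap instance is used as-is
```
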
pertpy/tools/_dialogue.py CHANGED
@@ -80,7 +80,7 @@ class Dialogue:
  Returns:
  A Pandas DataFrame of pseudobulk counts
  """
- # TODO: Replace with decoupler's implementation
+ # TODO: Replace with scanpy get implementation
  pseudobulk = {"Genes": adata.var_names.values}
 
  for category in adata.obs.loc[:, groupby].cat.categories:
pertpy/tools/_differential_gene_expression/_base.py CHANGED
@@ -572,9 +572,7 @@ class MethodBase(ABC):
  if any(adata.obs[[groupby, pairedby]].value_counts() > 1):
  logger.info("Performing pseudobulk for paired samples")
  ps = PseudobulkSpace()
- adata = ps.compute(
- adata, target_col=groupby, groups_col=pairedby, layer_key=layer, mode="sum", min_cells=1, min_counts=1
- )
+ adata = ps.compute(adata, target_col=groupby, groups_col=pairedby, layer_key=layer, mode="sum")
 
  X = adata.layers[layer] if layer is not None else adata.X
  with contextlib.suppress(AttributeError):
pertpy/tools/_distances/_distances.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
  from abc import ABC, abstractmethod
  from typing import TYPE_CHECKING, Literal, NamedTuple
 
+ import jax
  import numpy as np
  import pandas as pd
  from numba import jit
@@ -685,6 +686,7 @@ class WassersteinDistance(AbstractDistance):
  def __init__(self) -> None:
  super().__init__()
  self.accepts_precomputed = False
+ self.solver = jax.jit(Sinkhorn())
 
  def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
  X = np.asarray(X, dtype=np.float64)
@@ -699,8 +701,7 @@
 
  def solve_ot_problem(self, geom: Geometry, **kwargs):
  ot_prob = LinearProblem(geom)
- solver = Sinkhorn()
- ot = solver(ot_prob, **kwargs)
+ ot = self.solver(ot_prob, **kwargs)
  cost = float(ot.reg_ot_cost)
 
  # Check for NaN or invalid cost
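
`WassersteinDistance` now builds the Sinkhorn solver once in `__init__` and wraps it in `jax.jit`, so every call to `solve_ot_problem` reuses a single compiled solver instead of re-creating one per comparison. A minimal sketch of that pattern with ott-jax, using a point-cloud geometry; the array sizes and `epsilon` value are illustrative, not pertpy's defaults:

```python
import jax
import numpy as np
from ott.geometry.pointcloud import PointCloud
from ott.problems.linear.linear_problem import LinearProblem
from ott.solvers.linear.sinkhorn import Sinkhorn

# Compile the solver once, reuse it for every OT problem.
solver = jax.jit(Sinkhorn())

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 10))  # cells x features, group 1
Y = rng.normal(size=(120, 10))  # cells x features, group 2

geom = PointCloud(X, Y, epsilon=0.1)  # entropy-regularized squared Euclidean cost
out = solver(LinearProblem(geom))     # later calls with the same shapes skip recompilation
distance = float(out.reg_ot_cost)     # regularized OT cost used as the distance value
```
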
pertpy/tools/_milo.py CHANGED
@@ -364,19 +364,32 @@ class Milo:
  # Set up rpy2 to run edgeR
  edgeR, limma, stats, base = self._setup_rpy2()
 
+ import rpy2.robjects as ro
+ from rpy2.robjects import numpy2ri, pandas2ri
+ from rpy2.robjects.conversion import localconverter
+ from rpy2.robjects.vectors import FloatVector
+
  # Define model matrix
  if not add_intercept or model_contrasts is not None:
  design = design + " + 0"
- model = stats.model_matrix(object=stats.formula(design), data=design_df)
+ design_df = design_df.astype(dict.fromkeys(design_df.select_dtypes(exclude=["number"]).columns, "category"))
+ with localconverter(ro.default_converter + pandas2ri.converter):
+ design_r = pandas2ri.py2rpy(design_df)
+ formula_r = stats.formula(design)
+ model = stats.model_matrix(object=formula_r, data=design_r)
 
  # Fit NB-GLM
- dge = edgeR.DGEList(counts=count_mat[keep_nhoods, :][:, keep_smp], lib_size=lib_size[keep_smp])
+ counts_filtered = count_mat[np.ix_(keep_nhoods, keep_smp)]
+ lib_size_filtered = lib_size[keep_smp]
+ count_mat_r = numpy2ri.py2rpy(counts_filtered)
+ lib_size_r = FloatVector(lib_size_filtered)
+ dge = edgeR.DGEList(counts=count_mat_r, lib_size=lib_size_r)
  dge = edgeR.calcNormFactors(dge, method="TMM")
  dge = edgeR.estimateDisp(dge, model)
  fit = edgeR.glmQLFit(dge, model, robust=True)
-
  # Test
- n_coef = model.shape[1]
+ model_np = np.array(model)
+ n_coef = model_np.shape[1]
  if model_contrasts is not None:
  r_str = """
  get_model_cols <- function(design_df, design){
@@ -387,32 +400,36 @@ class Milo:
  from rpy2.robjects.packages import STAP
 
  get_model_cols = STAP(r_str, "get_model_cols")
- model_mat_cols = get_model_cols.get_model_cols(design_df, design)
- model_df = pd.DataFrame(model)
+ with localconverter(ro.default_converter + numpy2ri.converter + pandas2ri.converter):
+ model_mat_cols = get_model_cols.get_model_cols(design_df, design)
+ with localconverter(ro.default_converter + pandas2ri.converter + numpy2ri.converter):
+ model_df = pandas2ri.rpy2py(model)
+ model_df = pd.DataFrame(model_df)
  model_df.columns = model_mat_cols
  try:
- mod_contrast = limma.makeContrasts(contrasts=model_contrasts, levels=model_df)
+ with localconverter(ro.default_converter + pandas2ri.converter):
+ mod_contrast = limma.makeContrasts(contrasts=model_contrasts, levels=model_df)
  except ValueError:
  logger.error("Model contrasts must be in the form 'A-B' or 'A+B'")
  raise
- res = base.as_data_frame(
- edgeR.topTags(edgeR.glmQLFTest(fit, contrast=mod_contrast), sort_by="none", n=np.inf)
- )
+ with localconverter(ro.default_converter + pandas2ri.converter + numpy2ri.converter):
+ res = base.as_data_frame(
+ edgeR.topTags(edgeR.glmQLFTest(fit, contrast=mod_contrast), sort_by="none", n=np.inf)
+ )
  else:
- res = base.as_data_frame(edgeR.topTags(edgeR.glmQLFTest(fit, coef=n_coef), sort_by="none", n=np.inf))
-
- from rpy2.robjects import conversion
-
- res = conversion.rpy2py(res)
+ with localconverter(ro.default_converter + numpy2ri.converter + pandas2ri.converter):
+ res = base.as_data_frame(
+ edgeR.topTags(edgeR.glmQLFTest(fit, coef=n_coef), sort_by="none", n=np.inf)
+ )
  if not isinstance(res, pd.DataFrame):
  res = pd.DataFrame(res)
-
+ # The columns of res looks like e.g. table.A, table.B, so remove the prefix
+ res.columns = [col.replace("table.", "") for col in res.columns]
  # Save outputs
  res.index = sample_adata.var_names[keep_nhoods] # type: ignore
  if any(col in sample_adata.var.columns for col in res.columns):
  sample_adata.var = sample_adata.var.drop(res.columns, axis=1)
  sample_adata.var = pd.concat([sample_adata.var, res], axis=1)
-
  # Run Graph spatial FDR correction
  self._graph_spatial_fdr(sample_adata, neighbors_key=adata.uns["nhood_neighbors_key"])
 
@@ -657,11 +674,8 @@ class Milo:
  self,
  ):
  """Set up rpy2 to run edgeR."""
- from rpy2.robjects import numpy2ri, pandas2ri
  from rpy2.robjects.packages import importr
 
- numpy2ri.activate()
- pandas2ri.activate()
  edgeR = self._try_import_bioc_library("edgeR")
  limma = self._try_import_bioc_library("limma")
  stats = importr("stats")
@@ -1007,6 +1021,8 @@ class Milo:
  subset_nhoods: list[str] = None,
  log_counts: bool = False,
  return_fig: bool = False,
+ ax=None,
+ show: bool = True,
  ) -> Figure | None:
  """Plot boxplot of cell numbers vs condition of interest.
 
@@ -1036,18 +1052,36 @@ class Milo:
  pl_df = pd.merge(pl_df, nhood_adata.var)
  pl_df["log_n_cells"] = np.log1p(pl_df["n_cells"])
  if not log_counts:
- sns.boxplot(data=pl_df, x=test_var, y="n_cells", color="lightblue")
- sns.stripplot(data=pl_df, x=test_var, y="n_cells", color="black", s=3)
- plt.ylabel("# cells")
+ sns.boxplot(data=pl_df, x=test_var, y="n_cells", color="lightblue", ax=ax)
+ sns.stripplot(data=pl_df, x=test_var, y="n_cells", color="black", s=3, ax=ax)
+ if ax:
+ ax.set_ylabel("# cells")
+ else:
+ plt.ylabel("# cells")
  else:
- sns.boxplot(data=pl_df, x=test_var, y="log_n_cells", color="lightblue")
- sns.stripplot(data=pl_df, x=test_var, y="log_n_cells", color="black", s=3)
- plt.ylabel("log(# cells + 1)")
+ sns.boxplot(data=pl_df, x=test_var, y="log_n_cells", color="lightblue", ax=ax)
+ sns.stripplot(data=pl_df, x=test_var, y="log_n_cells", color="black", s=3, ax=ax)
+ if ax:
+ ax.set_ylabel("log(# cells + 1)")
+ else:
+ plt.ylabel("log(# cells + 1)")
 
- plt.xticks(rotation=90)
- plt.xlabel(test_var)
+ if ax:
+ ax.tick_params(axis="x", rotation=90)
+ ax.set_xlabel(test_var)
+ else:
+ plt.xticks(rotation=90)
+ plt.xlabel(test_var)
 
  if return_fig:
  return plt.gcf()
- plt.show()
+
+ if ax is None:
+ plt.show()
+
+ if return_fig:
+ return plt.gcf()
+ if show:
+ plt.show()
+
  return None
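
The Milo changes drop the global `numpy2ri.activate()` / `pandas2ri.activate()` calls and instead convert objects explicitly inside `localconverter` blocks, which keeps rpy2's conversion rules scoped to the code that needs them. A minimal, self-contained sketch of that pattern (assumes rpy2 ≥ 3 and a working R installation; the data frame is illustrative):

```python
import pandas as pd
import rpy2.robjects as ro
from rpy2.robjects import pandas2ri
from rpy2.robjects.conversion import localconverter

df = pd.DataFrame({"condition": ["A", "A", "B"], "n_cells": [10, 12, 7]})

# Conversion rules apply only inside this block; nothing is activated globally.
with localconverter(ro.default_converter + pandas2ri.converter):
    r_df = pandas2ri.py2rpy(df)       # pandas.DataFrame -> R data.frame
    df_back = pandas2ri.rpy2py(r_df)  # R data.frame -> pandas.DataFrame
```
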
pertpy/tools/_mixscape.py CHANGED
@@ -1,6 +1,7 @@
  from __future__ import annotations
 
  import copy
+ import warnings
  from collections import OrderedDict
  from typing import TYPE_CHECKING, Literal
 
@@ -10,11 +11,12 @@ import pandas as pd
  import scanpy as sc
  import seaborn as sns
  from fast_array_utils.stats import mean, mean_var
+ from pandas.errors import PerformanceWarning
  from scanpy import get
  from scanpy._utils import _check_use_raw, sanitize_anndata
  from scanpy.plotting import _utils
  from scanpy.tools._utils import _choose_representation
- from scipy.sparse import csr_matrix, spmatrix
+ from scipy.sparse import csr_matrix, issparse, spmatrix
  from sklearn.mixture import GaussianMixture
 
  from pertpy._doc import _doc_params, doc_common_plot_args
@@ -103,6 +105,9 @@ class Mixscape:
 
  adata.layers["X_pert"] = adata.X.copy()
 
+ # Work with LIL for efficient indexing but don't store it in AnnData as LIL is not supported anymore
+ X_pert_lil = adata.layers["X_pert"].tolil() if issparse(adata.layers["X_pert"]) else adata.layers["X_pert"]
+
  control_mask = adata.obs[pert_key] == control
 
  if ref_selection_mode == "split_by":
@@ -110,9 +115,8 @@
  split_mask = adata.obs[split_by] == split
  control_mask_group = control_mask & split_mask
  control_mean_expr = mean(adata.X[control_mask_group], axis=0)
- adata.layers["X_pert"][split_mask] = (
- np.repeat(control_mean_expr.reshape(1, -1), split_mask.sum(), axis=0)
- - adata.layers["X_pert"][split_mask]
+ X_pert_lil[split_mask] = (
+ np.repeat(control_mean_expr.reshape(1, -1), split_mask.sum(), axis=0) - X_pert_lil[split_mask]
  )
  else:
  if split_by is None:
@@ -129,49 +133,43 @@ class Mixscape:
 
  for split_mask in split_masks:
  control_mask_split = control_mask & split_mask
-
  R_split = representation[split_mask]
  R_control = representation[np.asarray(control_mask_split)]
-
  eps = kwargs.pop("epsilon", 0.1)
  nn_index = NNDescent(R_control, **kwargs)
  indices, _ = nn_index.query(R_split, k=n_neighbors, epsilon=eps)
-
  X_control = np.expm1(adata.X[np.asarray(control_mask_split)])
-
  n_split = split_mask.sum()
  n_control = X_control.shape[0]
 
  if batch_size is None:
  col_indices = np.ravel(indices)
  row_indices = np.repeat(np.arange(n_split), n_neighbors)
-
  neigh_matrix = csr_matrix(
  (np.ones_like(col_indices, dtype=np.float64), (row_indices, col_indices)),
  shape=(n_split, n_control),
  )
  neigh_matrix /= n_neighbors
- adata.layers["X_pert"][np.asarray(split_mask)] = (
- sc.pp.log1p(neigh_matrix @ X_control) - adata.layers["X_pert"][np.asarray(split_mask)]
+ X_pert_lil[np.asarray(split_mask)] = (
+ sc.pp.log1p(neigh_matrix @ X_control) - X_pert_lil[np.asarray(split_mask)]
  )
  else:
  split_indices = np.where(split_mask)[0]
  for i in range(0, n_split, batch_size):
  size = min(i + batch_size, n_split)
  select = slice(i, size)
-
  batch = np.ravel(indices[select])
  split_batch = split_indices[select]
-
  size = size - i
-
  means_batch = X_control[batch]
  batch_reshaped = means_batch.reshape(size, n_neighbors, -1)
  means_batch, _ = mean_var(batch_reshaped, axis=1)
+ X_pert_lil[split_batch] = np.log1p(means_batch) - X_pert_lil[split_batch]
 
- adata.layers["X_pert"][split_batch] = (
- np.log1p(means_batch) - adata.layers["X_pert"][split_batch]
- )
+ if issparse(X_pert_lil):
+ adata.layers["X_pert"] = X_pert_lil.tocsr()
+ else:
+ adata.layers["X_pert"] = X_pert_lil
 
  if copy:
  return adata
@@ -531,26 +529,29 @@ class Mixscape:
  gene_targets = list(set(adata[split_mask].obs[labels]).difference([control]))
  adata_split = adata[split_mask].copy()
  # find top DE genes between cells with targeting and non-targeting gRNAs
- sc.tl.rank_genes_groups(
- adata_split,
- layer=layer,
- groupby=labels,
- groups=gene_targets,
- reference=control,
- method=test_method,
- use_raw=False,
- )
- # get DE genes for each target gene
- for gene in gene_targets:
- logfc_threshold_mask = (
- np.abs(adata_split.uns["rank_genes_groups"]["logfoldchanges"][gene]) >= logfc_threshold
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", RuntimeWarning)
+ warnings.simplefilter("ignore", PerformanceWarning)
+ sc.tl.rank_genes_groups(
+ adata_split,
+ layer=layer,
+ groupby=labels,
+ groups=gene_targets,
+ reference=control,
+ method=test_method,
+ use_raw=False,
  )
- de_genes = adata_split.uns["rank_genes_groups"]["names"][gene][logfc_threshold_mask]
- pvals_adj = adata_split.uns["rank_genes_groups"]["pvals_adj"][gene][logfc_threshold_mask]
- de_genes = de_genes[pvals_adj < pval_cutoff]
- if len(de_genes) < min_de_genes:
- de_genes = np.array([])
- perturbation_markers[(category, gene)] = de_genes
+ # get DE genes for each target gene
+ for gene in gene_targets:
+ logfc_threshold_mask = (
+ np.abs(adata_split.uns["rank_genes_groups"]["logfoldchanges"][gene]) >= logfc_threshold
+ )
+ de_genes = adata_split.uns["rank_genes_groups"]["names"][gene][logfc_threshold_mask]
+ pvals_adj = adata_split.uns["rank_genes_groups"]["pvals_adj"][gene][logfc_threshold_mask]
+ de_genes = de_genes[pvals_adj < pval_cutoff]
+ if len(de_genes) < min_de_genes:
+ de_genes = np.array([])
+ perturbation_markers[(category, gene)] = de_genes
 
  return perturbation_markers
 
@@ -711,7 +712,10 @@ class Mixscape:
  if "mixscape_class" not in adata.obs:
  raise ValueError("Please run `pt.tl.mixscape` first.")
  adata_subset = adata[(adata.obs[labels] == target_gene) | (adata.obs[labels] == control)].copy()
- sc.tl.rank_genes_groups(adata_subset, layer=layer, groupby=labels, method=method)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", RuntimeWarning)
+ warnings.simplefilter("ignore", PerformanceWarning)
+ sc.tl.rank_genes_groups(adata_subset, layer=layer, groupby=labels, method=method)
  sc.pp.scale(adata_subset, max_value=vmax)
  sc.pp.subsample(adata_subset, n_obs=subsample_number)
 
@@ -998,8 +1002,7 @@ class Mixscape:
  ys = keys
 
  if multi_panel and groupby is None and len(ys) == 1:
- # This is a quick and dirty way for adapting scales across several
- # keys if groupby is None.
+ # This is a quick and dirty way for adapting scales across several keys if groupby is None.
  y = ys[0]
 
  g = sns.catplot(
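
Several Mixscape hunks route the `X_pert` updates through a LIL view of the layer: CSR matrices are slow to modify row-by-row (SciPy emits `SparseEfficiencyWarning` when the sparsity structure changes), whereas LIL supports cheap row assignment, and the result is converted back to CSR before it is stored in the AnnData object. A small SciPy sketch of that round trip, independent of pertpy's data:

```python
import numpy as np
from scipy.sparse import issparse, random as sparse_random

X = sparse_random(5, 4, density=0.3, format="csr", random_state=0)

# LIL is cheap to write row-by-row; CSR is not.
X_lil = X.tolil() if issparse(X) else X
new_row = np.log1p(np.ones((1, 4)))            # dense replacement values for one row
X_lil[2, :] = new_row - X_lil[2, :].toarray()  # row-wise update without CSR warnings

# Convert back before storing, e.g. into adata.layers["X_pert"].
X = X_lil.tocsr()
```
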
pertpy/tools/_perturbation_space/_discriminator_classifiers.py CHANGED
@@ -226,7 +226,7 @@ class MLPClassifierSpace(PerturbationSpace):
  # Fix class unbalance (likely to happen in perturbation datasets)
  # Usually control cells are overrepresented such that predicting control all time would give good results
  # Cells with rare perturbations are sampled more
- train_weights = 1 / (1 + torch.sum(torch.tensor(train_dataset.labels), dim=1))
+ train_weights = 1 / (1 + torch.sum(torch.tensor(train_dataset.labels.to_list()), dim=1))
  train_sampler = WeightedRandomSampler(train_weights, len(train_weights))
 
  self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=4)
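
The one-line fix above converts `train_dataset.labels` (presumably a pandas object in this code path, hence `.to_list()`) before building the tensor of per-sample weights for `WeightedRandomSampler`. For context, a generic inverse-class-frequency sketch of weighted sampling; this is not pertpy's exact weighting formula, just the standard pattern behind it:

```python
import torch
from torch.utils.data import WeightedRandomSampler

# One-hot labels for 6 samples over 3 classes; class 0 is overrepresented.
labels = torch.tensor(
    [[1, 0, 0], [1, 0, 0], [1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]],
    dtype=torch.float,
)

class_counts = labels.sum(dim=0)                     # tensor([4., 1., 1.])
sample_weights = (labels / class_counts).sum(dim=1)  # rare classes get larger weights
sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights))
```
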
pertpy/tools/_perturbation_space/_simple.py CHANGED
@@ -1,21 +1,15 @@
  from __future__ import annotations
 
- from typing import TYPE_CHECKING
+ from typing import Literal
 
- import matplotlib.pyplot as plt
  import numpy as np
+ import scanpy as sc
  from anndata import AnnData
- from decoupler import get_pseudobulk as dc_get_pseudobulk
- from decoupler import plot_psbulk_samples as dc_plot_psbulk_samples
  from sklearn.cluster import DBSCAN, KMeans
 
- from pertpy._doc import _doc_params, doc_common_plot_args
  from pertpy.tools._perturbation_space._clustering import ClusteringSpace
  from pertpy.tools._perturbation_space._perturbation_space import PerturbationSpace
 
- if TYPE_CHECKING:
- from matplotlib.pyplot import Figure
-
 
  class CentroidSpace(PerturbationSpace):
  """Computes the centroids per perturbation of a pre-computed embedding."""
@@ -126,9 +120,9 @@ class PseudobulkSpace(PerturbationSpace):
  groups_col: str = None,
  layer_key: str = None,
  embedding_key: str = None,
- **kwargs,
+ mode: Literal["count_nonzero", "mean", "sum", "var", "median"] = "sum",
  ) -> AnnData: # type: ignore
- """Determines pseudobulks of an AnnData object. It uses Decoupler implementation.
+ """Determines pseudobulks of an AnnData object.
 
  Args:
  adata: Anndata object of size cells x genes
@@ -137,7 +131,7 @@
  The summarized expression per perturbation (target_col) and group (groups_col) is computed.
  layer_key: If specified pseudobulk computation is done by using the specified layer. Otherwise, computation is done with .X
  embedding_key: `obsm` key of the AnnData embedding to use for computation. Defaults to the 'X' matrix otherwise.
- **kwargs: Are passed to decoupler's get_pseuobulk.
+ mode: Pseudobulk aggregation function
 
  Returns:
  AnnData object with one observation per perturbation.
@@ -167,53 +161,16 @@ class PseudobulkSpace(PerturbationSpace):
  adata = adata_emb
 
  adata.obs[target_col] = adata.obs[target_col].astype("category")
- ps_adata = dc_get_pseudobulk(adata, sample_col=target_col, layer=layer_key, groups_col=groups_col, **kwargs) # type: ignore
+ ps_adata = sc.get.aggregate(
+ adata, by=[target_col] if groups_col is None else [target_col, groups_col], func=mode, layer=layer_key
+ )
+ if mode in ps_adata.layers:
+ ps_adata.X = ps_adata.layers[mode]
 
  ps_adata.obs[target_col] = ps_adata.obs[target_col].astype("category")
 
  return ps_adata
 
- @_doc_params(common_plot_args=doc_common_plot_args)
- def plot_psbulk_samples( # pragma: no cover # noqa: D417
- self,
- adata: AnnData,
- groupby: str,
- *,
- return_fig: bool = False,
- **kwargs,
- ) -> Figure | None:
- """Plot the pseudobulk samples of an AnnData object.
-
- Plot the count number vs. the number of cells per pseudobulk sample.
-
- Args:
- adata: Anndata containing pseudobulk samples.
- groupby: `.obs` column to color the samples by.
- {common_plot_args}
- **kwargs: Are passed to decoupler's plot_psbulk_samples.
-
- Returns:
- If `return_fig` is `True`, returns the figure, otherwise `None`.
-
- Examples:
- >>> import pertpy as pt
- >>> adata = pt.dt.zhang_2021()
- >>> ps = pt.tl.PseudobulkSpace()
- >>> pdata = ps.compute(
- ... adata, target_col="Patient", groups_col="Cluster", mode="sum", min_cells=10, min_counts=1000
- ... )
- >>> ps.plot_psbulk_samples(pdata, groupby=["Patient", "Major celltype"], figsize=(12, 4))
-
- Preview:
- .. image:: /_static/docstring_previews/pseudobulk_samples.png
- """
- fig = dc_plot_psbulk_samples(adata, groupby, return_fig=True, **kwargs)
-
- if return_fig:
- return fig
- plt.show()
- return None
-
 
  class KMeansSpace(ClusteringSpace):
  """Computes K-Means clustering of the expression values."""
pertpy/tools/_scgen/_scgen.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
 
  from typing import TYPE_CHECKING, Any
 
+ import anndata as ad
  import jax.numpy as jnp
  import matplotlib.pyplot as plt
  import numpy as np
@@ -248,7 +249,7 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
  temp_cell[batch_ind[study]].X = batch_list[study].X
  shared_ct.append(temp_cell)
 
- all_shared_ann = AnnData.concatenate(*shared_ct, batch_key="concat_batch", index_unique=None)
+ all_shared_ann = ad.concat(shared_ct, label="concat_batch", index_unique=None)
  if "concat_batch" in all_shared_ann.obs.columns:
  del all_shared_ann.obs["concat_batch"]
  if len(not_shared_ct) < 1:
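
`AnnData.concatenate` is deprecated in recent anndata releases, which is presumably the motivation for switching to the functional `anndata.concat` in scGen. A minimal sketch of the replacement call on two synthetic objects (`label="concat_batch"` mirrors the diff):

```python
import anndata as ad
import numpy as np
from anndata import AnnData

a = AnnData(np.ones((3, 4), dtype=np.float32))
a.obs_names = [f"a_{i}" for i in range(3)]
b = AnnData(np.zeros((2, 4), dtype=np.float32))
b.obs_names = [f"b_{i}" for i in range(2)]

# label= records which input each observation came from in obs["concat_batch"];
# index_unique=None leaves the original obs names untouched.
merged = ad.concat([a, b], label="concat_batch", index_unique=None)
```
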
pertpy-0.11.3.dist-info/METADATA → pertpy-0.11.5.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: pertpy
- Version: 0.11.3
+ Version: 0.11.5
  Summary: Perturbation Analysis in the scverse ecosystem.
  Project-URL: Documentation, https://pertpy.readthedocs.io
  Project-URL: Source, https://github.com/scverse/pertpy
@@ -49,7 +49,6 @@ Requires-Python: <3.14,>=3.11
  Requires-Dist: adjusttext
  Requires-Dist: arviz
  Requires-Dist: blitzgsea
- Requires-Dist: decoupler
  Requires-Dist: fast-array-utils
  Requires-Dist: lamin-utils
  Requires-Dist: mudata
@@ -132,6 +131,12 @@ You can install _pertpy_ in less than a minute via [pip] from [PyPI]:
  pip install pertpy
  ```
 
+ or [conda-forge]:
+
+ ```console
+ conda install -c conda-forge pertpy
+ ```
+
  ### Differential gene expression
 
  If you want to use the differential gene expression interface, please install pertpy by running:
@@ -180,6 +185,7 @@ pip install rpy2
  [pip]: https://pip.pypa.io/
  [pypi]: https://pypi.org/
  [api]: https://pertpy.readthedocs.io/en/latest/api.html
+ [conda-forge]: https://anaconda.org/conda-forge/pertpy
  [//]: # "numfocus-fiscal-sponsor-attribution"
 
  pertpy is part of the scverse® project ([website](https://scverse.org), [governance](https://scverse.org/about/roles)) and is fiscally sponsored by [NumFOCUS](https://numfocus.org/).
pertpy-0.11.3.dist-info/RECORD → pertpy-0.11.5.dist-info/RECORD CHANGED
@@ -1,4 +1,4 @@
- pertpy/__init__.py,sha256=nuIzUydNMtNnhu1E4ffxU9gCgIdLi3liUv7dylR-2_I,716
+ pertpy/__init__.py,sha256=KIxMlqyHlppcGM5Uc2HpTwCEtGFavXRPW50dM5dFB7U,716
  pertpy/_doc.py,sha256=j5TMNC-DA9yIMqIIUNpjpcVgWfRqyBBfvbRjnCM_OLs,427
  pertpy/_types.py,sha256=IcHCojCUqx8CapibNkcYf2TUqjBFP2ujeELvn_IBSBQ,154
  pertpy/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -19,18 +19,18 @@ pertpy/preprocessing/_guide_rna_mixture.py,sha256=pT_YkjmN4iEJ-THBROu_dpbr8E6u8G
  pertpy/tools/__init__.py,sha256=xjfw3Dd_KGytjjCID0uEs6Fz7DalH46fCjVL2Zf2kOo,2629
  pertpy/tools/_augur.py,sha256=tc1YKyc0BwzrEGgctsfyy7DsTNKxyvy7ZvWraTWCc1A,55262
  pertpy/tools/_cinemaot.py,sha256=54-rS0AEj31dMe7iU4kEmLoAunq3jNuhsBE3IEp9hrI,38071
- pertpy/tools/_dialogue.py,sha256=cCSwo9ge1pOLoA7QHTPb3b865juCFWUaKX5aD7UoSjo,52355
+ pertpy/tools/_dialogue.py,sha256=mygIZm5i_bnEE37TTQtr1efl_KJq-ejzeL3V1Bmr7Pg,52354
  pertpy/tools/_enrichment.py,sha256=55mwotLH9DXQOhl85MCkxXu-MX0RysLyrPheJysAnF0,21369
- pertpy/tools/_milo.py,sha256=r-kZcpAcoQuhi41AnVuzh-cMIcV3HB3-RGzynHyDc1A,43712
- pertpy/tools/_mixscape.py,sha256=qjXGyH-oeBFte0efuHJfhVEbivnzUVWREwC40ef6Se8,57203
+ pertpy/tools/_milo.py,sha256=zIYG0aP8B39_eiNgpZONhTKmDvcRwCzOLo5FMOTMUms,45530
+ pertpy/tools/_mixscape.py,sha256=HfrpBeRlxHXaOpZkF2FmX7dg35kUB1rL0_-n2aSi2_0,57905
  pertpy/tools/decoupler_LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
  pertpy/tools/transferlearning_MMD_LICENSE,sha256=MUvDA-o_j9htRpI8fStVdCRuyLdPkQUuIH0a_EIc57w,1069
  pertpy/tools/_coda/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- pertpy/tools/_coda/_base_coda.py,sha256=-rpdipPLFd4cFXyLMN7uFgv-pFJseKaqDmyWRBrGfws,111519
+ pertpy/tools/_coda/_base_coda.py,sha256=NjKIQBtTIUENnRmeIC2O8cMdU_9DKaJ5_AHPvFnc8XQ,111744
  pertpy/tools/_coda/_sccoda.py,sha256=0Ret6O56kAfCNOdBvtxqiyuj2rUPp18SV1GVK1AvYGU,22607
  pertpy/tools/_coda/_tasccoda.py,sha256=BTaOAmL458zQ_og3x4ENlDnJHD6_F4YkdCoXWsF4i1U,30465
  pertpy/tools/_differential_gene_expression/__init__.py,sha256=SEydWg0iT3Y1pApjnCAOuHxFeI6xVUfgyBHv2s3LADU,487
- pertpy/tools/_differential_gene_expression/_base.py,sha256=PpfH_RZXsN79Bu0yBFPE9TXEqNsZ4bSzSbhM0wZou2I,38322
+ pertpy/tools/_differential_gene_expression/_base.py,sha256=ELx0e9DChJGO3yRpCLUOykt3oNOyDAPOQZZGSwzBSR0,38265
  pertpy/tools/_differential_gene_expression/_checks.py,sha256=hH_GP0lWGO-5zrCFX4YiIVCZBCuK0ZJ0jFmdlx2Qm4k,1639
  pertpy/tools/_differential_gene_expression/_dge_comparison.py,sha256=LXhp5djKKCAk9VI7OqxOuja849G5lnd8Ehcs9Epk8rg,4159
  pertpy/tools/_differential_gene_expression/_edger.py,sha256=nSHMDA4drGq_sJwUXs5I2EbMHwqjiS08GqOU_1_dXPc,4798
@@ -39,20 +39,20 @@ pertpy/tools/_differential_gene_expression/_simple_tests.py,sha256=SfU8s_P2JzEA1
  pertpy/tools/_differential_gene_expression/_statsmodels.py,sha256=90h9EPuoCtNxAbJ1Xq4j_E4yYJJpk64zTP7GyTdmrxY,2220
  pertpy/tools/_distances/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  pertpy/tools/_distances/_distance_tests.py,sha256=6_nqfHUfKxkI2Yhkzspq3ujMpq56zV_Ddn7bgPzgjyo,13513
- pertpy/tools/_distances/_distances.py,sha256=89d1zShW_9dhphup2oWx5hMOFC7RdogOY56doMuBFts,50473
+ pertpy/tools/_distances/_distances.py,sha256=_XbVU8dlYt_Jl2thYPUWg7HT6OXVe-Ki6qthF566sqQ,50503
  pertpy/tools/_perturbation_space/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  pertpy/tools/_perturbation_space/_clustering.py,sha256=pNx_SpPkZfCbgF7vzHWqAaiiHdbxPaA-L-hTWTbzFhI,3528
  pertpy/tools/_perturbation_space/_comparison.py,sha256=-NzCPRT-IlhJ9hOz7NQLSk0riIzr2C0yZvX6zm3kon4,4291
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py,sha256=gDid9Z1_AAPHPWuNgAkbP7yrgcC0qjjqTuWjTzTAAZo,23373
+ pertpy/tools/_perturbation_space/_discriminator_classifiers.py,sha256=a53-YmUwDHQBCT7ZWe_RH7PZsGXvoSHmJaQyL0CBJng,23383
  pertpy/tools/_perturbation_space/_metrics.py,sha256=y8-baP8WRdB1iDgvP3uuQxSCDxA2lcxvEHHM2C_vWHY,3248
  pertpy/tools/_perturbation_space/_perturbation_space.py,sha256=8RxVUkVEPZj5YZ-C-NP5zO4aYYVD04PzlsYuaIG-wjY,19447
- pertpy/tools/_perturbation_space/_simple.py,sha256=nnagHJ_aPv4ZCqfnVLdVUT_JShtIXg7iEP_sCMD3JLY,14271
+ pertpy/tools/_perturbation_space/_simple.py,sha256=AJlHRaEP-vViBeMDvvMtUnXMuIKqZVc7wggnjsHMfMw,12721
  pertpy/tools/_scgen/__init__.py,sha256=uERFlFyF88TH0uLiwmsUGEfHfLVCiZMFuk8gO5f7164,45
  pertpy/tools/_scgen/_base_components.py,sha256=Qq8myRUm43q9XBrZ9gBggfa2cSV2wbz_KYoLgH7iF1A,3009
- pertpy/tools/_scgen/_scgen.py,sha256=31T8ez0FxABIbunJHCk8xvGulHFb8RHXSsyM_z1WsPY,30850
+ pertpy/tools/_scgen/_scgen.py,sha256=AQNGsDe-9HEqli3oq7UBDg68ofLCoXm-R_jnLFQ-rlc,30856
  pertpy/tools/_scgen/_scgenvae.py,sha256=bPk4v7EdJc7ROdLuDitHiX_Pvwa7Flw2qHRUwBvjLJY,3889
  pertpy/tools/_scgen/_utils.py,sha256=qz5QUn_Bvk2NGyYVzp3jgjWTFOMt1YyHwUo6HWtoThY,2871
- pertpy-0.11.3.dist-info/METADATA,sha256=PM1yN_AADeouMbFT9X2m4Qv4VfPJ2PFgAbC2FNScIXs,8726
- pertpy-0.11.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- pertpy-0.11.3.dist-info/licenses/LICENSE,sha256=XuiT2hxeRInhquEIBKMZ5M21n5syhDQ4XbABoposIAg,1100
- pertpy-0.11.3.dist-info/RECORD,,
+ pertpy-0.11.5.dist-info/METADATA,sha256=YEYgYTHkjmyWyboRL3RhBaSxOw86O5vr0wpXdvaLTGk,8827
+ pertpy-0.11.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ pertpy-0.11.5.dist-info/licenses/LICENSE,sha256=XuiT2hxeRInhquEIBKMZ5M21n5syhDQ4XbABoposIAg,1100
+ pertpy-0.11.5.dist-info/RECORD,,