pertpy 0.11.4__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
pertpy/__init__.py CHANGED
@@ -2,10 +2,11 @@
 
  __author__ = "Lukas Heumos"
  __email__ = "lukas.heumos@posteo.net"
- __version__ = "0.11.4"
+ __version__ = "1.0.0"
 
  import warnings
 
+ from anndata._core.aligned_df import ImplicitModificationWarning
  from matplotlib import MatplotlibDeprecationWarning
  from numba import NumbaDeprecationWarning
 
@@ -13,6 +14,8 @@ warnings.filterwarnings("ignore", category=NumbaDeprecationWarning)
  warnings.filterwarnings("ignore", category=MatplotlibDeprecationWarning)
  warnings.filterwarnings("ignore", category=SyntaxWarning)
  warnings.filterwarnings("ignore", category=UserWarning, module="scvi._settings")
+ warnings.filterwarnings("ignore", message="Environment variable.*redefined by R")
+ warnings.filterwarnings("ignore", message="Transforming to str index.", category=ImplicitModificationWarning)
 
  import mudata
 
@@ -1538,7 +1538,7 @@ class CompositionalModel2(ABC):
  if isinstance(data, MuData):
  data = data[modality_key]
  if isinstance(palette, Colormap):
- palette = palette(range(2))
+ palette = list(palette(range(len(data.obs[feature_name].unique()))))
 
  # y scale transformations
  if y_scale == "relative":
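
Note: the new palette line relies on matplotlib returning one RGBA tuple per integer when a Colormap is called with a range. A minimal sketch of that behavior; the "tab10" colormap and the "condition" column are illustrative, not pertpy defaults.

```python
# Minimal sketch of the Colormap-call behavior the new line relies on.
import pandas as pd
from matplotlib import colormaps

obs = pd.DataFrame({"condition": ["ctrl", "stim", "combo", "ctrl"]})
cmap = colormaps["tab10"]
palette = list(cmap(range(obs["condition"].nunique())))  # one RGBA tuple per category
print(len(palette))  # 3
```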
@@ -3,6 +3,7 @@ from __future__ import annotations
  from abc import ABC, abstractmethod
  from typing import TYPE_CHECKING, Literal, NamedTuple
 
+ import jax
  import numpy as np
  import pandas as pd
  from numba import jit
@@ -685,6 +686,7 @@ class WassersteinDistance(AbstractDistance):
  def __init__(self) -> None:
  super().__init__()
  self.accepts_precomputed = False
+ self.solver = jax.jit(Sinkhorn())
 
  def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
  X = np.asarray(X, dtype=np.float64)
@@ -699,8 +701,7 @@ class WassersteinDistance(AbstractDistance):
 
  def solve_ot_problem(self, geom: Geometry, **kwargs):
  ot_prob = LinearProblem(geom)
- solver = Sinkhorn()
- ot = solver(ot_prob, **kwargs)
+ ot = self.solver(ot_prob, **kwargs)
  cost = float(ot.reg_ot_cost)
 
  # Check for NaN or invalid cost
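
Note: the Sinkhorn solver is now built and jit-compiled once in `__init__` and reused for every distance call, instead of being recreated inside `solve_ot_problem`. A hedged sketch of the pattern, assuming the ott-jax import paths (`Sinkhorn`, `LinearProblem`, a point-cloud geometry); this is not the exact pertpy code.

```python
# Sketch of the compile-once pattern, assuming ott-jax import paths.
import jax
import numpy as np
from ott.geometry.pointcloud import PointCloud
from ott.problems.linear.linear_problem import LinearProblem
from ott.solvers.linear.sinkhorn import Sinkhorn

solver = jax.jit(Sinkhorn())  # compiled once, reused across calls

def wasserstein_cost(X: np.ndarray, Y: np.ndarray) -> float:
    geom = PointCloud(X, Y)            # pairwise point-cloud geometry
    out = solver(LinearProblem(geom))  # entropy-regularized OT
    return float(out.reg_ot_cost)
```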
pertpy/tools/_milo.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
 
  import random
  import re
+ from importlib.util import find_spec
  from typing import TYPE_CHECKING, Literal
 
  import matplotlib.pyplot as plt
@@ -29,18 +30,6 @@ from sklearn.metrics.pairwise import euclidean_distances
  class Milo:
  """Python implementation of Milo."""
 
- def __init__(self):
- try:
- from rpy2.robjects import conversion, numpy2ri, pandas2ri
- from rpy2.robjects.packages import STAP, PackageNotInstalledError, importr
- except ModuleNotFoundError:
- raise ImportError("milo requires rpy2 to be installed.") from None
-
- try:
- importr("edgeR")
- except ImportError as e:
- raise ImportError("milo requires a valid R installation with edger installed:\n") from e
-
  def load(
  self,
  input: AnnData,
@@ -266,7 +255,7 @@
  subset_samples: list[str] | None = None,
  add_intercept: bool = True,
  feature_key: str | None = "rna",
- solver: Literal["edger", "batchglm"] = "edger",
+ solver: Literal["edger", "pydeseq2"] = "edger",
  ):
  """Performs differential abundance testing on neighbourhoods using QLF test implementation as implemented in edgeR.
 
@@ -279,7 +268,9 @@
  subset_samples: subset of samples (obs in `milo_mdata['milo']`) to use for the test.
  add_intercept: whether to include an intercept in the model. If False, this is equivalent to adding + 0 in the design formula. When model_contrasts is specified, this is set to False by default.
  feature_key: If input data is MuData, specify key to cell-level AnnData object.
- solver: The solver to fit the model to. One of "edger" (requires R, rpy2 and edgeR to be installed) or "batchglm"
+ solver: The solver to fit the model to.
+ The "edger" solver requires R, rpy2 and edgeR to be installed and is the closest to the R implementation.
+ The "pydeseq2" requires pydeseq2 to be installed. It is still very comparable to the "edger" solver but might be a bit slower.
 
  Returns:
  None, modifies `milo_mdata['milo']` in place, adding the results of the DA test to `.var`:
@@ -298,7 +289,6 @@
  >>> milo.make_nhoods(mdata["rna"])
  >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident")
  >>> milo.da_nhoods(mdata, design="~label")
-
  """
  try:
  sample_adata = mdata["milo"]
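
Note: a hedged usage sketch of the new solver option, following the docstring example above; only the `solver` argument changes, and pydeseq2 must be installed (e.g. via the `de` extra). The surrounding setup (`adata`, neighbors, etc.) is assumed to be prepared as in the docstring.

```python
# Hedged usage sketch; `adata` is assumed to be a prepared single-cell AnnData.
import pertpy as pt

milo = pt.tl.Milo()
mdata = milo.load(adata)
milo.make_nhoods(mdata["rna"])
mdata = milo.count_nhoods(mdata, sample_col="orig.ident")
milo.da_nhoods(mdata, design="~label", solver="pydeseq2")  # pydeseq2 backend instead of edgeR
```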
@@ -364,19 +354,32 @@
  # Set up rpy2 to run edgeR
  edgeR, limma, stats, base = self._setup_rpy2()
 
+ import rpy2.robjects as ro
+ from rpy2.robjects import numpy2ri, pandas2ri
+ from rpy2.robjects.conversion import localconverter
+ from rpy2.robjects.vectors import FloatVector
+
  # Define model matrix
  if not add_intercept or model_contrasts is not None:
  design = design + " + 0"
- model = stats.model_matrix(object=stats.formula(design), data=design_df)
+ design_df = design_df.astype(dict.fromkeys(design_df.select_dtypes(exclude=["number"]).columns, "category"))
+ with localconverter(ro.default_converter + pandas2ri.converter):
+ design_r = pandas2ri.py2rpy(design_df)
+ formula_r = stats.formula(design)
+ model = stats.model_matrix(object=formula_r, data=design_r)
 
  # Fit NB-GLM
- dge = edgeR.DGEList(counts=count_mat[keep_nhoods, :][:, keep_smp], lib_size=lib_size[keep_smp])
+ counts_filtered = count_mat[np.ix_(keep_nhoods, keep_smp)]
+ lib_size_filtered = lib_size[keep_smp]
+ count_mat_r = numpy2ri.py2rpy(counts_filtered)
+ lib_size_r = FloatVector(lib_size_filtered)
+ dge = edgeR.DGEList(counts=count_mat_r, lib_size=lib_size_r)
  dge = edgeR.calcNormFactors(dge, method="TMM")
  dge = edgeR.estimateDisp(dge, model)
  fit = edgeR.glmQLFit(dge, model, robust=True)
-
  # Test
- n_coef = model.shape[1]
+ model_np = np.array(model)
+ n_coef = model_np.shape[1]
  if model_contrasts is not None:
  r_str = """
  get_model_cols <- function(design_df, design){
@@ -387,34 +390,90 @@
  from rpy2.robjects.packages import STAP
 
  get_model_cols = STAP(r_str, "get_model_cols")
- model_mat_cols = get_model_cols.get_model_cols(design_df, design)
- model_df = pd.DataFrame(model)
+ with localconverter(ro.default_converter + numpy2ri.converter + pandas2ri.converter):
+ model_mat_cols = get_model_cols.get_model_cols(design_df, design)
+ with localconverter(ro.default_converter + pandas2ri.converter + numpy2ri.converter):
+ model_df = pandas2ri.rpy2py(model)
+ model_df = pd.DataFrame(model_df)
  model_df.columns = model_mat_cols
  try:
- mod_contrast = limma.makeContrasts(contrasts=model_contrasts, levels=model_df)
+ with localconverter(ro.default_converter + pandas2ri.converter):
+ mod_contrast = limma.makeContrasts(contrasts=model_contrasts, levels=model_df)
  except ValueError:
  logger.error("Model contrasts must be in the form 'A-B' or 'A+B'")
  raise
- res = base.as_data_frame(
- edgeR.topTags(edgeR.glmQLFTest(fit, contrast=mod_contrast), sort_by="none", n=np.inf)
- )
+ with localconverter(ro.default_converter + pandas2ri.converter + numpy2ri.converter):
+ res = base.as_data_frame(
+ edgeR.topTags(edgeR.glmQLFTest(fit, contrast=mod_contrast), sort_by="none", n=np.inf)
+ )
  else:
- res = base.as_data_frame(edgeR.topTags(edgeR.glmQLFTest(fit, coef=n_coef), sort_by="none", n=np.inf))
-
- from rpy2.robjects import conversion
-
- res = conversion.rpy2py(res)
+ with localconverter(ro.default_converter + numpy2ri.converter + pandas2ri.converter):
+ res = base.as_data_frame(
+ edgeR.topTags(edgeR.glmQLFTest(fit, coef=n_coef), sort_by="none", n=np.inf)
+ )
  if not isinstance(res, pd.DataFrame):
  res = pd.DataFrame(res)
+ # The columns of res looks like e.g. table.A, table.B, so remove the prefix
+ res.columns = [col.replace("table.", "") for col in res.columns]
+ elif solver == "pydeseq2":
+ if find_spec("pydeseq2") is None:
+ raise ImportError("pydeseq2 is required but not installed. Install with: pip install pydeseq2")
+
+ from pydeseq2.dds import DeseqDataSet
+ from pydeseq2.ds import DeseqStats
+
+ counts_filtered = count_mat[np.ix_(keep_nhoods, keep_smp)]
+ design_df_filtered = design_df.iloc[keep_smp].copy()
+
+ design_df_filtered = design_df_filtered.astype(
+ dict.fromkeys(design_df_filtered.select_dtypes(exclude=["number"]).columns, "category")
+ )
+
+ design_clean = design if design.startswith("~") else f"~{design}"
+
+ dds = DeseqDataSet(
+ counts=pd.DataFrame(counts_filtered.T, index=design_df_filtered.index),
+ metadata=design_df_filtered,
+ design=design_clean,
+ refit_cooks=True,
+ )
+
+ dds.deseq2()
+
+ if model_contrasts is not None and "-" in model_contrasts:
+ if "(" in model_contrasts or "+" in model_contrasts.split("-")[1]:
+ raise ValueError(
+ f"Complex contrasts like '{model_contrasts}' are not supported by pydeseq2. "
+ "Use simple pairwise contrasts (e.g., 'GroupA-GroupB') or switch to solver='edger'."
+ )
+
+ parts = model_contrasts.split("-")
+ factor_name = design_clean.replace("~", "").split("+")[-1].strip()
+ group1 = parts[0].replace(factor_name, "").strip()
+ group2 = parts[1].replace(factor_name, "").strip()
+ stat_res = DeseqStats(dds, contrast=[factor_name, group1, group2])
+ else:
+ factor_name = design_clean.replace("~", "").split("+")[-1].strip()
+ if not isinstance(design_df_filtered[factor_name], pd.CategoricalDtype):
+ design_df_filtered[factor_name] = design_df_filtered[factor_name].astype("category")
+ categories = design_df_filtered[factor_name].cat.categories
+ stat_res = DeseqStats(dds, contrast=[factor_name, categories[-1], categories[0]])
+
+ stat_res.summary()
+ res = stat_res.results_df
+
+ res = res.rename(
+ columns={"baseMean": "logCPM", "log2FoldChange": "logFC", "pvalue": "PValue", "padj": "FDR"}
+ )
+
+ res = res[["logCPM", "logFC", "PValue", "FDR"]]
 
- # Save outputs
  res.index = sample_adata.var_names[keep_nhoods] # type: ignore
  if any(col in sample_adata.var.columns for col in res.columns):
  sample_adata.var = sample_adata.var.drop(res.columns, axis=1)
  sample_adata.var = pd.concat([sample_adata.var, res], axis=1)
 
- # Run Graph spatial FDR correction
- self._graph_spatial_fdr(sample_adata, neighbors_key=adata.uns["nhood_neighbors_key"])
+ self._graph_spatial_fdr(sample_adata)
 
  def annotate_nhoods(
  self,
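
Note: for simple pairwise contrasts, the pydeseq2 branch above translates an edgeR-style contrast string into a `[factor, group1, group2]` triple for `DeseqStats`. A small illustration of that translation with made-up design and contrast strings.

```python
# Illustration of the contrast translation performed by the pydeseq2 branch.
design_clean = "~label"
model_contrasts = "labelstim-labelctrl"

factor_name = design_clean.replace("~", "").split("+")[-1].strip()  # "label"
group1, group2 = (p.replace(factor_name, "").strip() for p in model_contrasts.split("-"))
print([factor_name, group1, group2])  # ['label', 'stim', 'ctrl'], passed as the DeseqStats contrast
```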
@@ -657,11 +716,19 @@
  self,
  ):
  """Set up rpy2 to run edgeR."""
- from rpy2.robjects import numpy2ri, pandas2ri
+ try:
+ from rpy2.robjects import conversion, numpy2ri, pandas2ri
+ from rpy2.robjects.packages import STAP, PackageNotInstalledError, importr
+ except ModuleNotFoundError:
+ raise ImportError("milo requires rpy2 to be installed.") from None
+
+ try:
+ importr("edgeR")
+ except ImportError as e:
+ raise ImportError("milo requires a valid R installation with edger installed.") from e
+
  from rpy2.robjects.packages import importr
 
- numpy2ri.activate()
- pandas2ri.activate()
  edgeR = self._try_import_bioc_library("edgeR")
  limma = self._try_import_bioc_library("limma")
  stats = importr("stats")
@@ -671,26 +738,27 @@
 
  def _try_import_bioc_library(
  self,
- name: str,
+ r_package: str,
  ):
  """Import R packages.
 
  Args:
- name (str): R packages name
+ r_package: R packages name
  """
  from rpy2.robjects.packages import PackageNotInstalledError, importr
 
  try:
- _r_lib = importr(name)
+ _r_lib = importr(r_package)
  return _r_lib
  except PackageNotInstalledError:
- logger.error(f"Install Bioconductor library `{name!r}` first as `BiocManager::install({name!r}).`")
+ logger.error(
+ f"Install Bioconductor library `{r_package!r}` first as `BiocManager::install({r_package!r}).`"
+ )
  raise
 
  def _graph_spatial_fdr(
  self,
  sample_adata: AnnData,
- neighbors_key: str | None = None,
  ):
  """FDR correction weighted on inverse of connectivity of neighbourhoods.
 
@@ -698,7 +766,6 @@
 
  Args:
  sample_adata: Sample-level AnnData.
- neighbors_key: The key in `adata.obsp` to use as KNN graph.
  """
  # use 1/connectivity as the weighting for the weighted BH adjustment from Cydar
  w = 1 / sample_adata.var["kth_distance"]
@@ -1007,6 +1074,8 @@
  subset_nhoods: list[str] = None,
  log_counts: bool = False,
  return_fig: bool = False,
+ ax=None,
+ show: bool = True,
  ) -> Figure | None:
  """Plot boxplot of cell numbers vs condition of interest.
 
@@ -1036,18 +1105,36 @@
  pl_df = pd.merge(pl_df, nhood_adata.var)
  pl_df["log_n_cells"] = np.log1p(pl_df["n_cells"])
  if not log_counts:
- sns.boxplot(data=pl_df, x=test_var, y="n_cells", color="lightblue")
- sns.stripplot(data=pl_df, x=test_var, y="n_cells", color="black", s=3)
- plt.ylabel("# cells")
+ sns.boxplot(data=pl_df, x=test_var, y="n_cells", color="lightblue", ax=ax)
+ sns.stripplot(data=pl_df, x=test_var, y="n_cells", color="black", s=3, ax=ax)
+ if ax:
+ ax.set_ylabel("# cells")
+ else:
+ plt.ylabel("# cells")
+ else:
+ sns.boxplot(data=pl_df, x=test_var, y="log_n_cells", color="lightblue", ax=ax)
+ sns.stripplot(data=pl_df, x=test_var, y="log_n_cells", color="black", s=3, ax=ax)
+ if ax:
+ ax.set_ylabel("log(# cells + 1)")
+ else:
+ plt.ylabel("log(# cells + 1)")
+
+ if ax:
+ ax.tick_params(axis="x", rotation=90)
+ ax.set_xlabel(test_var)
  else:
- sns.boxplot(data=pl_df, x=test_var, y="log_n_cells", color="lightblue")
- sns.stripplot(data=pl_df, x=test_var, y="log_n_cells", color="black", s=3)
- plt.ylabel("log(# cells + 1)")
+ plt.xticks(rotation=90)
+ plt.xlabel(test_var)
 
- plt.xticks(rotation=90)
- plt.xlabel(test_var)
+ if return_fig:
+ return plt.gcf()
+
+ if ax is None:
+ plt.show()
 
  if return_fig:
  return plt.gcf()
- plt.show()
+ if show:
+ plt.show()
+
  return None
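
Note: with the new `ax` and `show` parameters, this plot can be drawn into an existing Axes as part of a larger figure. A hedged sketch follows; the method name `plot_nhood_counts_by_cond` and the prepared `mdata` are assumptions, since the full signature is not visible in this hunk.

```python
# Hedged sketch of the new ax/show usage; method name and mdata are assumed.
import matplotlib.pyplot as plt
import pertpy as pt

milo = pt.tl.Milo()
fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(8, 3))
milo.plot_nhood_counts_by_cond(mdata, test_var="label", ax=ax0, show=False)
milo.plot_nhood_counts_by_cond(mdata, test_var="label", log_counts=True, ax=ax1, show=False)
fig.tight_layout()
plt.show()
```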
pertpy/tools/_mixscape.py CHANGED
@@ -1,6 +1,7 @@
  from __future__ import annotations
 
  import copy
+ import warnings
  from collections import OrderedDict
  from typing import TYPE_CHECKING, Literal
 
@@ -10,11 +11,12 @@ import pandas as pd
  import scanpy as sc
  import seaborn as sns
  from fast_array_utils.stats import mean, mean_var
+ from pandas.errors import PerformanceWarning
  from scanpy import get
  from scanpy._utils import _check_use_raw, sanitize_anndata
  from scanpy.plotting import _utils
  from scanpy.tools._utils import _choose_representation
- from scipy.sparse import csr_matrix, spmatrix
+ from scipy.sparse import csr_matrix, issparse, spmatrix
  from sklearn.mixture import GaussianMixture
 
  from pertpy._doc import _doc_params, doc_common_plot_args
@@ -103,6 +105,9 @@ class Mixscape:
 
  adata.layers["X_pert"] = adata.X.copy()
 
+ # Work with LIL for efficient indexing but don't store it in AnnData as LIL is not supported anymore
+ X_pert_lil = adata.layers["X_pert"].tolil() if issparse(adata.layers["X_pert"]) else adata.layers["X_pert"]
+
  control_mask = adata.obs[pert_key] == control
 
  if ref_selection_mode == "split_by":
@@ -110,9 +115,8 @@
  split_mask = adata.obs[split_by] == split
  control_mask_group = control_mask & split_mask
  control_mean_expr = mean(adata.X[control_mask_group], axis=0)
- adata.layers["X_pert"][split_mask] = (
- np.repeat(control_mean_expr.reshape(1, -1), split_mask.sum(), axis=0)
- - adata.layers["X_pert"][split_mask]
+ X_pert_lil[split_mask] = (
+ np.repeat(control_mean_expr.reshape(1, -1), split_mask.sum(), axis=0) - X_pert_lil[split_mask]
  )
  else:
  if split_by is None:
@@ -129,49 +133,43 @@
 
  for split_mask in split_masks:
  control_mask_split = control_mask & split_mask
-
  R_split = representation[split_mask]
  R_control = representation[np.asarray(control_mask_split)]
-
  eps = kwargs.pop("epsilon", 0.1)
  nn_index = NNDescent(R_control, **kwargs)
  indices, _ = nn_index.query(R_split, k=n_neighbors, epsilon=eps)
-
  X_control = np.expm1(adata.X[np.asarray(control_mask_split)])
-
  n_split = split_mask.sum()
  n_control = X_control.shape[0]
 
  if batch_size is None:
  col_indices = np.ravel(indices)
  row_indices = np.repeat(np.arange(n_split), n_neighbors)
-
  neigh_matrix = csr_matrix(
  (np.ones_like(col_indices, dtype=np.float64), (row_indices, col_indices)),
  shape=(n_split, n_control),
  )
  neigh_matrix /= n_neighbors
- adata.layers["X_pert"][np.asarray(split_mask)] = (
- sc.pp.log1p(neigh_matrix @ X_control) - adata.layers["X_pert"][np.asarray(split_mask)]
+ X_pert_lil[np.asarray(split_mask)] = (
+ sc.pp.log1p(neigh_matrix @ X_control) - X_pert_lil[np.asarray(split_mask)]
  )
  else:
  split_indices = np.where(split_mask)[0]
  for i in range(0, n_split, batch_size):
  size = min(i + batch_size, n_split)
  select = slice(i, size)
-
  batch = np.ravel(indices[select])
  split_batch = split_indices[select]
-
  size = size - i
-
  means_batch = X_control[batch]
  batch_reshaped = means_batch.reshape(size, n_neighbors, -1)
  means_batch, _ = mean_var(batch_reshaped, axis=1)
+ X_pert_lil[split_batch] = np.log1p(means_batch) - X_pert_lil[split_batch]
 
- adata.layers["X_pert"][split_batch] = (
- np.log1p(means_batch) - adata.layers["X_pert"][split_batch]
- )
+ if issparse(X_pert_lil):
+ adata.layers["X_pert"] = X_pert_lil.tocsr()
+ else:
+ adata.layers["X_pert"] = X_pert_lil
 
  if copy:
  return adata
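
Note: the perturbation signature is now computed on a LIL copy of the layer, which supports cheap row assignment, and converted back to CSR before storing, since AnnData layers should not hold LIL matrices. A minimal sketch of the round-trip with toy data.

```python
# Minimal sketch of the LIL round-trip used above; the matrix and mask are toy data.
import numpy as np
from scipy.sparse import csr_matrix, issparse

X = csr_matrix(np.arange(12, dtype=float).reshape(4, 3))
X_lil = X.tolil() if issparse(X) else X                 # LIL supports efficient row updates

mask = np.array([True, False, True, False])
X_lil[mask] = np.ones((mask.sum(), 3)) - X_lil[mask]    # row-wise difference, as in the code above

X_csr = X_lil.tocsr()                                   # convert back before storing in a layer
```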
@@ -531,26 +529,29 @@
  gene_targets = list(set(adata[split_mask].obs[labels]).difference([control]))
  adata_split = adata[split_mask].copy()
  # find top DE genes between cells with targeting and non-targeting gRNAs
- sc.tl.rank_genes_groups(
- adata_split,
- layer=layer,
- groupby=labels,
- groups=gene_targets,
- reference=control,
- method=test_method,
- use_raw=False,
- )
- # get DE genes for each target gene
- for gene in gene_targets:
- logfc_threshold_mask = (
- np.abs(adata_split.uns["rank_genes_groups"]["logfoldchanges"][gene]) >= logfc_threshold
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", RuntimeWarning)
+ warnings.simplefilter("ignore", PerformanceWarning)
+ sc.tl.rank_genes_groups(
+ adata_split,
+ layer=layer,
+ groupby=labels,
+ groups=gene_targets,
+ reference=control,
+ method=test_method,
+ use_raw=False,
  )
- de_genes = adata_split.uns["rank_genes_groups"]["names"][gene][logfc_threshold_mask]
- pvals_adj = adata_split.uns["rank_genes_groups"]["pvals_adj"][gene][logfc_threshold_mask]
- de_genes = de_genes[pvals_adj < pval_cutoff]
- if len(de_genes) < min_de_genes:
- de_genes = np.array([])
- perturbation_markers[(category, gene)] = de_genes
+ # get DE genes for each target gene
+ for gene in gene_targets:
+ logfc_threshold_mask = (
+ np.abs(adata_split.uns["rank_genes_groups"]["logfoldchanges"][gene]) >= logfc_threshold
+ )
+ de_genes = adata_split.uns["rank_genes_groups"]["names"][gene][logfc_threshold_mask]
+ pvals_adj = adata_split.uns["rank_genes_groups"]["pvals_adj"][gene][logfc_threshold_mask]
+ de_genes = de_genes[pvals_adj < pval_cutoff]
+ if len(de_genes) < min_de_genes:
+ de_genes = np.array([])
+ perturbation_markers[(category, gene)] = de_genes
 
  return perturbation_markers
 
@@ -711,7 +712,10 @@
  if "mixscape_class" not in adata.obs:
  raise ValueError("Please run `pt.tl.mixscape` first.")
  adata_subset = adata[(adata.obs[labels] == target_gene) | (adata.obs[labels] == control)].copy()
- sc.tl.rank_genes_groups(adata_subset, layer=layer, groupby=labels, method=method)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", RuntimeWarning)
+ warnings.simplefilter("ignore", PerformanceWarning)
+ sc.tl.rank_genes_groups(adata_subset, layer=layer, groupby=labels, method=method)
  sc.pp.scale(adata_subset, max_value=vmax)
  sc.pp.subsample(adata_subset, n_obs=subsample_number)
 
@@ -998,8 +1002,7 @@
  ys = keys
 
  if multi_panel and groupby is None and len(ys) == 1:
- # This is a quick and dirty way for adapting scales across several
- # keys if groupby is None.
+ # This is a quick and dirty way for adapting scales across several keys if groupby is None.
  y = ys[0]
 
  g = sns.catplot(
@@ -226,7 +226,7 @@ class MLPClassifierSpace(PerturbationSpace):
  # Fix class unbalance (likely to happen in perturbation datasets)
  # Usually control cells are overrepresented such that predicting control all time would give good results
  # Cells with rare perturbations are sampled more
- train_weights = 1 / (1 + torch.sum(torch.tensor(train_dataset.labels), dim=1))
+ train_weights = 1 / (1 + torch.sum(torch.tensor(train_dataset.labels.to_list()), dim=1))
  train_sampler = WeightedRandomSampler(train_weights, len(train_weights))
 
  self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=4)
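
Note: these weights feed a `WeightedRandomSampler` so cells with rare perturbations are drawn more often during training. A generic sketch of inverse-frequency sampling; the toy labels and dataset below are illustrative, not pertpy internals.

```python
# Generic sketch of inverse-frequency sampling with WeightedRandomSampler.
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

y = torch.tensor([0, 0, 0, 0, 1, 2])            # class 0 heavily overrepresented
weights = 1.0 / torch.bincount(y).float()[y]    # rare classes get larger sampling weight
sampler = WeightedRandomSampler(weights, num_samples=len(weights))

dataset = TensorDataset(torch.randn(6, 5), y)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)
```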
@@ -80,7 +80,7 @@ class PerturbationSpace:
  group_masks = (
  [(adata.obs[group_col] == sample) for sample in adata.obs[group_col].unique()]
  if group_col
- else [[True] * adata.n_obs]
+ else [np.array([True] * adata.n_obs)]
  )
 
  if layer_key:
@@ -2,6 +2,7 @@ from __future__ import annotations
 
  from typing import TYPE_CHECKING, Any
 
+ import anndata as ad
  import jax.numpy as jnp
  import matplotlib.pyplot as plt
  import numpy as np
@@ -248,7 +249,7 @@ class Scgen(JaxTrainingMixin, BaseModelClass):
  temp_cell[batch_ind[study]].X = batch_list[study].X
  shared_ct.append(temp_cell)
 
- all_shared_ann = AnnData.concatenate(*shared_ct, batch_key="concat_batch", index_unique=None)
+ all_shared_ann = ad.concat(shared_ct, label="concat_batch", index_unique=None)
  if "concat_batch" in all_shared_ann.obs.columns:
  del all_shared_ann.obs["concat_batch"]
  if len(not_shared_ct) < 1:
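
Note: `AnnData.concatenate` is deprecated; `anndata.concat` is the replacement used here, with `label=` naming the obs column that records which input each cell came from. A minimal sketch with toy AnnData objects.

```python
# Minimal sketch of the anndata.concat replacement; the toy AnnData objects are illustrative.
import anndata as ad
import numpy as np

a = ad.AnnData(np.ones((2, 3)))
b = ad.AnnData(np.zeros((3, 3)))

merged = ad.concat([a, b], label="concat_batch", index_unique=None)
print(merged.obs["concat_batch"].tolist())  # ['0', '0', '1', '1', '1']
```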
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: pertpy
- Version: 0.11.4
+ Version: 1.0.0
  Summary: Perturbation Analysis in the scverse ecosystem.
  Project-URL: Documentation, https://pertpy.readthedocs.io
  Project-URL: Source, https://github.com/scverse/pertpy
@@ -131,6 +131,12 @@ You can install _pertpy_ in less than a minute via [pip] from [PyPI]:
  pip install pertpy
  ```
 
+ or [conda-forge]:
+
+ ```console
+ conda install -c conda-forge pertpy
+ ```
+
 
  ### Differential gene expression
  If you want to use the differential gene expression interface, please install pertpy by running:
@@ -149,7 +155,13 @@ pip install 'pertpy[tcoda]'
 
  ### milo
 
- milo further requires edger, statmod, and rpy2 to be installed:
+ milo requires either the "de" extra for the "pydeseq2" solver:
+
+ ```console
+ pip install 'pertpy[de]'
+ ```
+
+ or, edger, statmod, and rpy2 for the "edger" solver:
 
  ```R
  BiocManager::install("edgeR")
@@ -179,6 +191,7 @@ pip install rpy2
  [pip]: https://pip.pypa.io/
  [pypi]: https://pypi.org/
  [api]: https://pertpy.readthedocs.io/en/latest/api.html
+ [conda-forge]: https://anaconda.org/conda-forge/pertpy
  [//]: # "numfocus-fiscal-sponsor-attribution"
 
  pertpy is part of the scverse® project ([website](https://scverse.org), [governance](https://scverse.org/about/roles)) and is fiscally sponsored by [NumFOCUS](https://numfocus.org/).
@@ -1,4 +1,4 @@
- pertpy/__init__.py,sha256=fJegZfFrqw0e5er2WVo0NzDOgeJ7DZD9M_rflPLoizQ,716
+ pertpy/__init__.py,sha256=cZHJ7PIOhtLkxJMlHbJ2rzei5xhLB4vg0c8AaIShfzc,972
  pertpy/_doc.py,sha256=j5TMNC-DA9yIMqIIUNpjpcVgWfRqyBBfvbRjnCM_OLs,427
  pertpy/_types.py,sha256=IcHCojCUqx8CapibNkcYf2TUqjBFP2ujeELvn_IBSBQ,154
  pertpy/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -21,12 +21,12 @@ pertpy/tools/_augur.py,sha256=tc1YKyc0BwzrEGgctsfyy7DsTNKxyvy7ZvWraTWCc1A,55262
  pertpy/tools/_cinemaot.py,sha256=54-rS0AEj31dMe7iU4kEmLoAunq3jNuhsBE3IEp9hrI,38071
  pertpy/tools/_dialogue.py,sha256=mygIZm5i_bnEE37TTQtr1efl_KJq-ejzeL3V1Bmr7Pg,52354
  pertpy/tools/_enrichment.py,sha256=55mwotLH9DXQOhl85MCkxXu-MX0RysLyrPheJysAnF0,21369
- pertpy/tools/_milo.py,sha256=r-kZcpAcoQuhi41AnVuzh-cMIcV3HB3-RGzynHyDc1A,43712
- pertpy/tools/_mixscape.py,sha256=qjXGyH-oeBFte0efuHJfhVEbivnzUVWREwC40ef6Se8,57203
+ pertpy/tools/_milo.py,sha256=9yoB9gkBNujqYDTKOlH2v3wiWhs5PdCuB8RgZ3xVI0Y,48049
+ pertpy/tools/_mixscape.py,sha256=HfrpBeRlxHXaOpZkF2FmX7dg35kUB1rL0_-n2aSi2_0,57905
  pertpy/tools/decoupler_LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
  pertpy/tools/transferlearning_MMD_LICENSE,sha256=MUvDA-o_j9htRpI8fStVdCRuyLdPkQUuIH0a_EIc57w,1069
  pertpy/tools/_coda/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- pertpy/tools/_coda/_base_coda.py,sha256=aw_aSB_NIUL0yQw2t-MUysxoXt1xdUDLK-pItRGUW3s,111703
+ pertpy/tools/_coda/_base_coda.py,sha256=NjKIQBtTIUENnRmeIC2O8cMdU_9DKaJ5_AHPvFnc8XQ,111744
  pertpy/tools/_coda/_sccoda.py,sha256=0Ret6O56kAfCNOdBvtxqiyuj2rUPp18SV1GVK1AvYGU,22607
  pertpy/tools/_coda/_tasccoda.py,sha256=BTaOAmL458zQ_og3x4ENlDnJHD6_F4YkdCoXWsF4i1U,30465
  pertpy/tools/_differential_gene_expression/__init__.py,sha256=SEydWg0iT3Y1pApjnCAOuHxFeI6xVUfgyBHv2s3LADU,487
@@ -39,20 +39,20 @@ pertpy/tools/_differential_gene_expression/_simple_tests.py,sha256=SfU8s_P2JzEA1
  pertpy/tools/_differential_gene_expression/_statsmodels.py,sha256=90h9EPuoCtNxAbJ1Xq4j_E4yYJJpk64zTP7GyTdmrxY,2220
  pertpy/tools/_distances/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  pertpy/tools/_distances/_distance_tests.py,sha256=6_nqfHUfKxkI2Yhkzspq3ujMpq56zV_Ddn7bgPzgjyo,13513
- pertpy/tools/_distances/_distances.py,sha256=89d1zShW_9dhphup2oWx5hMOFC7RdogOY56doMuBFts,50473
+ pertpy/tools/_distances/_distances.py,sha256=_XbVU8dlYt_Jl2thYPUWg7HT6OXVe-Ki6qthF566sqQ,50503
  pertpy/tools/_perturbation_space/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  pertpy/tools/_perturbation_space/_clustering.py,sha256=pNx_SpPkZfCbgF7vzHWqAaiiHdbxPaA-L-hTWTbzFhI,3528
  pertpy/tools/_perturbation_space/_comparison.py,sha256=-NzCPRT-IlhJ9hOz7NQLSk0riIzr2C0yZvX6zm3kon4,4291
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py,sha256=gDid9Z1_AAPHPWuNgAkbP7yrgcC0qjjqTuWjTzTAAZo,23373
+ pertpy/tools/_perturbation_space/_discriminator_classifiers.py,sha256=a53-YmUwDHQBCT7ZWe_RH7PZsGXvoSHmJaQyL0CBJng,23383
  pertpy/tools/_perturbation_space/_metrics.py,sha256=y8-baP8WRdB1iDgvP3uuQxSCDxA2lcxvEHHM2C_vWHY,3248
- pertpy/tools/_perturbation_space/_perturbation_space.py,sha256=8RxVUkVEPZj5YZ-C-NP5zO4aYYVD04PzlsYuaIG-wjY,19447
+ pertpy/tools/_perturbation_space/_perturbation_space.py,sha256=Vyh15wWw9dcu2YUWhziQd2mA9-4IY8EC5dzkBT9HaIo,19457
  pertpy/tools/_perturbation_space/_simple.py,sha256=AJlHRaEP-vViBeMDvvMtUnXMuIKqZVc7wggnjsHMfMw,12721
  pertpy/tools/_scgen/__init__.py,sha256=uERFlFyF88TH0uLiwmsUGEfHfLVCiZMFuk8gO5f7164,45
  pertpy/tools/_scgen/_base_components.py,sha256=Qq8myRUm43q9XBrZ9gBggfa2cSV2wbz_KYoLgH7iF1A,3009
- pertpy/tools/_scgen/_scgen.py,sha256=31T8ez0FxABIbunJHCk8xvGulHFb8RHXSsyM_z1WsPY,30850
+ pertpy/tools/_scgen/_scgen.py,sha256=AQNGsDe-9HEqli3oq7UBDg68ofLCoXm-R_jnLFQ-rlc,30856
  pertpy/tools/_scgen/_scgenvae.py,sha256=bPk4v7EdJc7ROdLuDitHiX_Pvwa7Flw2qHRUwBvjLJY,3889
  pertpy/tools/_scgen/_utils.py,sha256=qz5QUn_Bvk2NGyYVzp3jgjWTFOMt1YyHwUo6HWtoThY,2871
- pertpy-0.11.4.dist-info/METADATA,sha256=Ox3dUh5YA5_a72GAOjCUj-l4Xc2vqz8sEZlhNlfEykY,8701
- pertpy-0.11.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- pertpy-0.11.4.dist-info/licenses/LICENSE,sha256=XuiT2hxeRInhquEIBKMZ5M21n5syhDQ4XbABoposIAg,1100
- pertpy-0.11.4.dist-info/RECORD,,
+ pertpy-1.0.0.dist-info/METADATA,sha256=PnK9O-MyIPzSy5DNOqMN7G6zcxZ2ZTJnMFB5cEr5XJQ,8920
+ pertpy-1.0.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ pertpy-1.0.0.dist-info/licenses/LICENSE,sha256=XuiT2hxeRInhquEIBKMZ5M21n5syhDQ4XbABoposIAg,1100
+ pertpy-1.0.0.dist-info/RECORD,,