pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53) hide show
  1. pertpy/__init__.py +3 -2
  2. pertpy/data/__init__.py +5 -1
  3. pertpy/data/_dataloader.py +2 -4
  4. pertpy/data/_datasets.py +203 -92
  5. pertpy/metadata/__init__.py +4 -0
  6. pertpy/metadata/_cell_line.py +826 -0
  7. pertpy/metadata/_compound.py +129 -0
  8. pertpy/metadata/_drug.py +242 -0
  9. pertpy/metadata/_look_up.py +582 -0
  10. pertpy/metadata/_metadata.py +73 -0
  11. pertpy/metadata/_moa.py +129 -0
  12. pertpy/plot/__init__.py +1 -9
  13. pertpy/plot/_augur.py +53 -116
  14. pertpy/plot/_coda.py +277 -677
  15. pertpy/plot/_guide_rna.py +17 -35
  16. pertpy/plot/_milopy.py +59 -134
  17. pertpy/plot/_mixscape.py +152 -391
  18. pertpy/preprocessing/_guide_rna.py +88 -4
  19. pertpy/tools/__init__.py +8 -13
  20. pertpy/tools/_augur.py +315 -17
  21. pertpy/tools/_cinemaot.py +143 -4
  22. pertpy/tools/_coda/_base_coda.py +1210 -65
  23. pertpy/tools/_coda/_sccoda.py +50 -21
  24. pertpy/tools/_coda/_tasccoda.py +27 -19
  25. pertpy/tools/_dialogue.py +164 -56
  26. pertpy/tools/_differential_gene_expression.py +240 -14
  27. pertpy/tools/_distances/_distance_tests.py +8 -8
  28. pertpy/tools/_distances/_distances.py +184 -34
  29. pertpy/tools/_enrichment.py +465 -0
  30. pertpy/tools/_milo.py +345 -11
  31. pertpy/tools/_mixscape.py +668 -50
  32. pertpy/tools/_perturbation_space/_clustering.py +5 -1
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
  35. pertpy/tools/_perturbation_space/_simple.py +51 -10
  36. pertpy/tools/_scgen/__init__.py +1 -1
  37. pertpy/tools/_scgen/_scgen.py +701 -0
  38. pertpy/tools/_scgen/_utils.py +1 -3
  39. pertpy/tools/decoupler_LICENSE +674 -0
  40. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
  41. pertpy-0.7.0.dist-info/RECORD +53 -0
  42. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
  43. pertpy/plot/_cinemaot.py +0 -81
  44. pertpy/plot/_dialogue.py +0 -91
  45. pertpy/plot/_scgen.py +0 -337
  46. pertpy/tools/_metadata/__init__.py +0 -0
  47. pertpy/tools/_metadata/_cell_line.py +0 -613
  48. pertpy/tools/_metadata/_look_up.py +0 -342
  49. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  50. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  51. pertpy-0.6.0.dist-info/RECORD +0 -50
  52. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  53. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
5
5
  import numpy as np
6
6
  import pandas as pd
7
7
  from anndata import AnnData
8
+ from pynndescent import NNDescent
8
9
  from rich import print
9
10
 
10
11
  if TYPE_CHECKING:
@@ -15,7 +16,7 @@ class PerturbationSpace:
15
16
  """Implements various ways of interacting with PerturbationSpaces.
16
17
 
17
18
  We differentiate between a cell space and a perturbation space.
18
- Visually speaking, in cell spaces single dota points in an embeddings summarize a cell,
19
+ Visually speaking, in cell spaces single data points in an embedding summarize a cell,
19
20
  whereas in a perturbation space, data points summarize whole perturbations.
20
21
  """
21
22
 
@@ -25,7 +26,8 @@ class PerturbationSpace:
25
26
  def compute_control_diff( # type: ignore
26
27
  self,
27
28
  adata: AnnData,
28
- target_col: str = "perturbations",
29
+ target_col: str = "perturbation",
30
+ group_col: str = None,
29
31
  reference_key: str = "control",
30
32
  layer_key: str = None,
31
33
  new_layer_key: str = "control_diff",
@@ -33,26 +35,31 @@ class PerturbationSpace:
33
35
  new_embedding_key: str = "control_diff",
34
36
  all_data: bool = False,
35
37
  copy: bool = False,
36
- ):
38
+ ) -> AnnData:
37
39
  """Subtract mean of the control from the perturbation.
38
40
 
39
41
  Args:
40
42
  adata: Anndata object of size cells x genes.
41
43
  target_col: .obs column name that stores the label of the perturbation applied to each cell. Defaults to 'perturbation'.
44
+ group_col: .obs column name that stores the label of the group of each cell. If None, ignore groups. Defaults to None.
42
45
  reference_key: The key of the control values. Defaults to 'control'.
43
46
  layer_key: Key of the AnnData layer to use for computation. Defaults to the `X` matrix otherwise.
44
- new_layer_key: the results are stored in the given layer. Defaults to 'differential diff'.
47
+ new_layer_key: the results are stored in the given layer. Defaults to 'control_diff'.
45
48
  embedding_key: `obsm` key of the AnnData embedding to use for computation. Defaults to the 'X' matrix otherwise.
46
- new_embedding_key: Results are stored in a new embedding in `obsm` with this key. Defaults to 'control diff'.
49
+ new_embedding_key: Results are stored in a new embedding in `obsm` with this key. Defaults to 'control_diff'.
47
50
  all_data: if True, do the computation in all data representations (X, all layers and all embeddings)
48
51
  copy: If True returns a new Anndata of same size with the new column; otherwise it updates the initial AnnData object.
49
52
 
53
+ Returns:
54
+ Updated AnnData object.
55
+
50
56
  Examples:
51
57
  Example usage with PseudobulkSpace:
58
+
52
59
  >>> import pertpy as pt
53
60
  >>> mdata = pt.dt.papalexi_2021()
54
61
  >>> ps = pt.tl.PseudobulkSpace()
55
- >>> diff_adata = ps.compute_control_diff(mdata["rna"], target_col="gene_target", reference_key='NT')
62
+ >>> diff_adata = ps.compute_control_diff(mdata["rna"], target_col="gene_target", reference_key="NT")
56
63
  """
57
64
  if reference_key not in adata.obs[target_col].unique():
58
65
  raise ValueError(
@@ -69,48 +76,67 @@ class PerturbationSpace:
69
76
  adata = adata.copy()
70
77
 
71
78
  control_mask = adata.obs[target_col] == reference_key
72
- num_control = control_mask.sum()
79
+ group_masks = (
80
+ [(adata.obs[group_col] == sample) for sample in adata.obs[group_col].unique()]
81
+ if group_col
82
+ else [[True] * adata.n_obs]
83
+ )
73
84
 
74
85
  if layer_key:
75
- if num_control == 1:
76
- control_expression = adata.layers[layer_key][control_mask, :]
77
- else:
78
- control_expression = np.mean(adata.layers[layer_key][control_mask, :], axis=0)
79
- diff_matrix = adata.layers[layer_key] - control_expression
80
- adata.layers[new_layer_key] = diff_matrix
86
+ adata.layers[new_layer_key] = np.zeros((adata.n_obs, adata.n_vars))
87
+ for mask in group_masks:
88
+ num_control = (control_mask & mask).sum()
89
+ if num_control == 1:
90
+ control_expression = adata.layers[layer_key][(control_mask & mask), :]
91
+ elif num_control > 1:
92
+ control_expression = np.mean(adata.layers[layer_key][(control_mask & mask), :], axis=0)
93
+ else:
94
+ control_expression = np.zeros((1, adata.n_vars))
95
+ adata.layers[new_layer_key][mask, :] = adata.layers[layer_key][mask, :] - control_expression
81
96
 
82
97
  if embedding_key:
83
- if num_control == 1:
84
- control_expression = adata.obsm[embedding_key][control_mask, :]
85
- else:
86
- control_expression = np.mean(adata.obsm[embedding_key][control_mask, :], axis=0)
87
- diff_matrix = adata.obsm[embedding_key] - control_expression
88
- adata.obsm[new_embedding_key] = diff_matrix
98
+ adata.obsm[new_embedding_key] = np.zeros(adata.obsm[embedding_key].shape)
99
+ for mask in group_masks:
100
+ num_control = (control_mask & mask).sum()
101
+ if num_control == 1:
102
+ control_expression = adata.obsm[embedding_key][(control_mask & mask), :]
103
+ elif num_control > 1:
104
+ control_expression = np.mean(adata.obsm[embedding_key][(control_mask & mask), :], axis=0)
105
+ else:
106
+ control_expression = np.zeros((1, adata.n_vars))
107
+ adata.obsm[new_embedding_key][mask, :] = adata.obsm[embedding_key][mask, :] - control_expression
89
108
 
90
109
  if (not layer_key and not embedding_key) or all_data:
91
- if num_control == 1:
92
- control_expression = adata.X[control_mask, :]
93
- else:
94
- control_expression = np.mean(adata.X[control_mask, :], axis=0)
95
- diff_matrix = adata.X - control_expression
96
- adata.X = diff_matrix
110
+ adata_x = np.zeros((adata.n_obs, adata.n_vars))
111
+ for mask in group_masks:
112
+ num_control = (control_mask & mask).sum()
113
+ if num_control == 1:
114
+ control_expression = adata.X[(control_mask & mask), :]
115
+ elif num_control > 1:
116
+ control_expression = np.mean(adata.X[(control_mask & mask), :], axis=0)
117
+ else:
118
+ control_expression = np.zeros((1, adata.n_vars))
119
+ adata_x[mask, :] = adata.X[mask, :] - control_expression
120
+ adata.X = adata_x
97
121
 
98
122
  if all_data:
99
123
  layers_keys = list(adata.layers.keys())
100
124
  for local_layer_key in layers_keys:
101
125
  if local_layer_key != layer_key and local_layer_key != new_layer_key:
102
- diff_matrix = adata.layers[local_layer_key] - np.mean(
103
- adata.layers[local_layer_key][control_mask, :], axis=0
104
- )
105
- adata.layers[local_layer_key + "_control_diff"] = diff_matrix
126
+ adata.layers[local_layer_key + "_control_diff"] = np.zeros((adata.n_obs, adata.n_vars))
127
+ for mask in group_masks:
128
+ adata.layers[local_layer_key + "_control_diff"][mask, :] = adata.layers[local_layer_key][
129
+ mask, :
130
+ ] - np.mean(adata.layers[local_layer_key][(control_mask & mask), :], axis=0)
106
131
 
107
132
  embedding_keys = list(adata.obsm_keys())
108
133
  for local_embedding_key in embedding_keys:
109
134
  if local_embedding_key != embedding_key and local_embedding_key != new_embedding_key:
110
- diff_matrix = adata.obsm[local_embedding_key] - np.mean(
111
- adata.obsm[local_embedding_key][control_mask, :], axis=0
112
- )
113
- adata.obsm[local_embedding_key + "_control_diff"] = diff_matrix
135
+ adata.obsm[local_embedding_key + "_control_diff"] = np.zeros(adata.obsm[local_embedding_key].shape)
136
+ for mask in group_masks:
137
+ adata.obsm[local_embedding_key + "_control_diff"][mask, :] = adata.obsm[local_embedding_key][
138
+ mask, :
139
+ ] - np.mean(adata.obsm[local_embedding_key][(control_mask & mask), :], axis=0)
114
140
 
115
141
  self.control_diff_computed = True
116
142
 
@@ -122,24 +148,30 @@ class PerturbationSpace:
122
148
  perturbations: Iterable[str],
123
149
  reference_key: str = "control",
124
150
  ensure_consistency: bool = False,
125
- target_col: str = "perturbations",
126
- ):
151
+ target_col: str = "perturbation",
152
+ ) -> tuple[AnnData, AnnData] | AnnData:
127
153
  """Add perturbations linearly. Assumes input of size n_perts x dimensionality
128
154
 
129
155
  Args:
130
156
  adata: Anndata object of size n_perts x dim.
131
157
  perturbations: Perturbations to add.
132
- reference_key: perturbation source from which the perturbation summation starts.
158
+ reference_key: perturbation source from which the perturbation summation starts. Defaults to 'control'.
133
159
  ensure_consistency: If True, runs differential expression on all data matrices to ensure consistency of linear space.
134
- target_col: .obs column name that stores the label of the perturbation applied to each cell. Defaults to 'perturbations'.
160
+ target_col: .obs column name that stores the label of the perturbation applied to each cell. Defaults to 'perturbation'.
161
+
162
+ Returns:
163
+ Anndata object of size (n_perts+1) x dim, where the last row is the addition of the specified perturbations.
164
+ If ensure_consistency is True, returns a tuple of (new_perturbation, adata) where adata is the AnnData object
165
+ provided as input but updated using compute_control_diff.
135
166
 
136
167
  Examples:
137
168
  Example usage with PseudobulkSpace:
169
+
138
170
  >>> import pertpy as pt
139
171
  >>> mdata = pt.dt.papalexi_2021()
140
172
  >>> ps = pt.tl.PseudobulkSpace()
141
173
  >>> ps_adata = ps.compute(mdata["rna"], target_col="gene_target", groups_col="gene_target")
142
- >>> new_perturbation = ps.add(ps_adata, perturbations=["ATF2", "CD86"], reference_key='NT')
174
+ >>> new_perturbation = ps.add(ps_adata, perturbations=["ATF2", "CD86"], reference_key="NT")
143
175
  """
144
176
  new_pert_name = ""
145
177
  for perturbation in perturbations:
@@ -212,6 +244,8 @@ class PerturbationSpace:
212
244
  key_name = key.removesuffix("_control_diff")
213
245
  new_perturbation.obsm[key_name] = data["embeddings"][key]
214
246
 
247
+ new_perturbation.obs[target_col] = new_perturbation.obs_names.astype("category")
248
+
215
249
  if ensure_consistency:
216
250
  return new_perturbation, adata
217
251
 
@@ -223,24 +257,30 @@ class PerturbationSpace:
223
257
  perturbations: Iterable[str],
224
258
  reference_key: str = "control",
225
259
  ensure_consistency: bool = False,
226
- target_col: str = "perturbations",
227
- ):
260
+ target_col: str = "perturbation",
261
+ ) -> tuple[AnnData, AnnData] | AnnData:
228
262
  """Subtract perturbations linearly. Assumes input of size n_perts x dimensionality
229
263
 
230
264
  Args:
231
265
  adata: Anndata object of size n_perts x dim.
232
- perturbations: Perturbations to subtract,
233
- reference_key: Perturbation source from which the perturbation subtraction starts
266
+ perturbations: Perturbations to subtract.
267
+ reference_key: Perturbation source from which the perturbation subtraction starts. Defaults to 'control'.
234
268
  ensure_consistency: If True, runs differential expression on all data matrices to ensure consistency of linear space.
235
269
  target_col: .obs column name that stores the label of the perturbation applied to each cell. Defaults to 'perturbation'.
236
270
 
271
+ Returns:
272
+ Anndata object of size (n_perts+1) x dim, where the last row is the subtraction of the specified perturbations.
273
+ If ensure_consistency is True, returns a tuple of (new_perturbation, adata) where adata is the AnnData object
274
+ provided as input but updated using compute_control_diff.
275
+
237
276
  Examples:
238
277
  Example usage with PseudobulkSpace:
278
+
239
279
  >>> import pertpy as pt
240
280
  >>> mdata = pt.dt.papalexi_2021()
241
281
  >>> ps = pt.tl.PseudobulkSpace()
242
282
  >>> ps_adata = ps.compute(mdata["rna"], target_col="gene_target", groups_col="gene_target")
243
- >>> new_perturbation = ps.add(ps_adata, reference_key="ATF2", perturbations=["BRD4", "CUL3"])
283
+ >>> new_perturbation = ps.subtract(ps_adata, reference_key="ATF2", perturbations=["BRD4", "CUL3"])
244
284
  """
245
285
  new_pert_name = reference_key + "-"
246
286
  for perturbation in perturbations:
@@ -313,7 +353,59 @@ class PerturbationSpace:
313
353
  key_name = key.removesuffix("_control_diff")
314
354
  new_perturbation.obsm[key_name] = data["embeddings"][key]
315
355
 
356
+ new_perturbation.obs[target_col] = new_perturbation.obs_names.astype("category")
357
+
316
358
  if ensure_consistency:
317
359
  return new_perturbation, adata
318
360
 
319
361
  return new_perturbation
362
+
363
+ def label_transfer(
364
+ self,
365
+ adata: AnnData,
366
+ column: str = "perturbation",
367
+ target_val: str = "unknown",
368
+ n_neighbors: int = 5,
369
+ use_rep: str = "X_umap",
370
+ ) -> None:
371
+ """Impute missing values in the specified column using KNN imputation in the space defined by `use_rep`.
372
+
373
+ Args:
374
+ adata: The AnnData object containing single-cell data.
375
+ column: The column name in AnnData object to perform imputation on. Defaults to "perturbation".
376
+ target_val: The target value to impute. Defaults to "unknown".
377
+ n_neighbors: Number of neighbors to use for imputation. Defaults to 5.
378
+ use_rep: The key in `adata.obsm` where the embedding (UMAP, PCA, etc.) is stored. Defaults to 'X_umap'.
379
+
380
+ Examples:
381
+ >>> import pertpy as pt
382
+ >>> import scanpy as sc
383
+ >>> import numpy as np
384
+ >>> adata = sc.datasets.pbmc68k_reduced()
385
+ >>> rng = np.random.default_rng()
386
+ >>> adata.obs["perturbation"] = rng.choice(
387
+ ... ["A", "B", "C", "unknown"], size=adata.n_obs, p=[0.33, 0.33, 0.33, 0.01]
388
+ ... )
389
+ >>> sc.pp.neighbors(adata)
390
+ >>> sc.tl.umap(adata)
391
+ >>> ps = pt.tl.PseudobulkSpace()
392
+ >>> ps.label_transfer(adata, n_neighbors=5, use_rep="X_umap")
393
+ """
394
+ if use_rep not in adata.obsm:
395
+ raise ValueError(f"Representation {use_rep} not found in the AnnData object.")
396
+
397
+ embedding = adata.obsm[use_rep]
398
+
399
+ nnd = NNDescent(embedding, n_neighbors=n_neighbors)
400
+ indices, _ = nnd.query(embedding, k=n_neighbors)
401
+
402
+ perturbations = np.array(adata.obs[column])
403
+ missing_mask = perturbations == target_val
404
+
405
+ for idx in np.where(missing_mask)[0]:
406
+ neighbor_indices = indices[idx]
407
+ neighbor_categories = perturbations[neighbor_indices]
408
+ most_common = pd.Series(neighbor_categories).mode()[0]
409
+ perturbations[idx] = most_common
410
+
411
+ adata.obs[column] = perturbations
@@ -15,9 +15,10 @@ class CentroidSpace(PerturbationSpace):
15
15
  def compute(
16
16
  self,
17
17
  adata: AnnData,
18
- target_col: str = "perturbations",
18
+ target_col: str = "perturbation",
19
19
  layer_key: str = None,
20
20
  embedding_key: str = "X_umap",
21
+ keep_obs: bool = True,
21
22
  ) -> AnnData: # type: ignore
22
23
  """Computes the centroids of a pre-computed embedding such as UMAP.
23
24
 
@@ -26,6 +27,12 @@ class CentroidSpace(PerturbationSpace):
26
27
  target_col: .obs column that stores the label of the perturbation applied to each cell.
27
28
  layer_key: If specified pseudobulk computation is done by using the specified layer. Otherwise, computation is done with .X
28
29
  embedding_key: `obsm` key of the AnnData embedding to use for computation. Defaults to the 'X' matrix otherwise.
30
+ keep_obs: Whether .obs columns in the input AnnData should be kept in the output pseudobulk AnnData. Only .obs columns with the same value for
31
+ each cell of one perturbation are kept. Defaults to True.
32
+
33
+ Returns:
34
+ AnnData object with one observation per perturbation, storing the embedding data of the
35
+ centroid of the respective perturbation.
29
36
 
30
37
  Examples:
31
38
  Compute the centroids of a UMAP embedding of the papalexi_2021 dataset:
@@ -34,7 +41,7 @@ class CentroidSpace(PerturbationSpace):
34
41
  >>> import scanpy as sc
35
42
  >>> mdata = pt.dt.papalexi_2021()
36
43
  >>> sc.pp.pca(mdata["rna"])
37
- >>> sc.pp.neighbors(mdata['rna'])
44
+ >>> sc.pp.neighbors(mdata["rna"])
38
45
  >>> sc.tl.umap(mdata["rna"])
39
46
  >>> cs = pt.tl.CentroidSpace()
40
47
  >>> cs_adata = cs.compute(mdata["rna"], target_col="gene_target")
@@ -84,6 +91,22 @@ class CentroidSpace(PerturbationSpace):
84
91
 
85
92
  ps_adata = AnnData(X=X)
86
93
  ps_adata.obs_names = index
94
+ ps_adata.obs[target_col] = index
95
+
96
+ if embedding_key is not None:
97
+ ps_adata.obsm[embedding_key] = X
98
+
99
+ if keep_obs: # Save the values of the obs columns of interest in the ps_adata object
100
+ obs_df = adata.obs
101
+ obs_df = obs_df.groupby(target_col).agg(
102
+ lambda pert_group: np.nan if len(set(pert_group)) != 1 else list(set(pert_group))[0]
103
+ )
104
+ for obs_name in obs_df.columns:
105
+ if not obs_df[obs_name].isnull().values.any():
106
+ mapping = {pert: obs_df.loc[pert][obs_name] for pert in index}
107
+ ps_adata.obs[obs_name] = ps_adata.obs[target_col].map(mapping)
108
+
109
+ ps_adata.obs[target_col] = ps_adata.obs[target_col].astype("category")
87
110
 
88
111
  return ps_adata
89
112
 
@@ -94,7 +117,8 @@ class PseudobulkSpace(PerturbationSpace):
94
117
  def compute(
95
118
  self,
96
119
  adata: AnnData,
97
- target_col: str = "perturbations",
120
+ target_col: str = "perturbation",
121
+ groups_col: str = None,
98
122
  layer_key: str = None,
99
123
  embedding_key: str = None,
100
124
  **kwargs,
@@ -104,19 +128,21 @@ class PseudobulkSpace(PerturbationSpace):
104
128
  Args:
105
129
  adata: Anndata object of size cells x genes
106
130
  target_col: .obs column that stores the label of the perturbation applied to each cell.
131
+ groups_col: Optional .obs column that stores a grouping label to consider for pseudobulk computation.
132
+ The summarized expression per perturbation (target_col) and group (groups_col) is computed. Defaults to None.
107
133
  layer_key: If specified pseudobulk computation is done by using the specified layer. Otherwise, computation is done with .X
108
134
  embedding_key: `obsm` key of the AnnData embedding to use for computation. Defaults to the 'X' matrix otherwise.
109
135
  **kwargs: Are passed to decoupler's get_pseudobulk.
110
136
 
137
+ Returns:
138
+ AnnData object with one observation per perturbation.
139
+
111
140
  Examples:
112
- >>> import pertpy as pp
141
+ >>> import pertpy as pt
113
142
  >>> mdata = pt.dt.papalexi_2021()
114
143
  >>> ps = pt.tl.PseudobulkSpace()
115
- >>> ps_adata = ps.compute(mdata["rna"], target_col="gene_target", groups_col="gene_target")
144
+ >>> ps_adata = ps.compute(mdata["rna"], target_col="gene_target")
116
145
  """
117
- if "groups_col" not in kwargs:
118
- kwargs["groups_col"] = "perturbations"
119
-
120
146
  if layer_key is not None and embedding_key is not None:
121
147
  raise ValueError("Please, select just either layer or embedding for computation.")
122
148
 
@@ -135,7 +161,10 @@ class PseudobulkSpace(PerturbationSpace):
135
161
  adata_emb.obs = adata.obs
136
162
  adata = adata_emb
137
163
 
138
- ps_adata = dc.get_pseudobulk(adata, sample_col=target_col, layer=layer_key, **kwargs) # type: ignore
164
+ adata.obs[target_col] = adata.obs[target_col].astype("category")
165
+ ps_adata = dc.get_pseudobulk(adata, sample_col=target_col, layer=layer_key, groups_col=groups_col, **kwargs) # type: ignore
166
+
167
+ ps_adata.obs[target_col] = ps_adata.obs[target_col].astype("category")
139
168
 
140
169
  return ps_adata
141
170
 
@@ -164,6 +193,11 @@ class KMeansSpace(ClusteringSpace):
164
193
  return_object: if True returns the clustering object
165
194
  **kwargs: Are passed to sklearn's KMeans.
166
195
 
196
+ Returns:
197
+ If return_object is True, the adata and the clustering object are returned.
198
+ Otherwise, only the adata is returned. The adata is updated with a new .obs column as specified in cluster_key,
199
+ that stores the cluster labels.
200
+
167
201
  Examples:
168
202
  >>> import pertpy as pt
169
203
  >>> mdata = pt.dt.papalexi_2021()
@@ -193,6 +227,7 @@ class KMeansSpace(ClusteringSpace):
193
227
 
194
228
  clustering = KMeans(**kwargs).fit(self.X)
195
229
  adata.obs[cluster_key] = clustering.labels_
230
+ adata.obs[cluster_key] = adata.obs[cluster_key].astype("category")
196
231
 
197
232
  if return_object:
198
233
  return adata, clustering
@@ -212,7 +247,7 @@ class DBSCANSpace(ClusteringSpace):
212
247
  copy: bool = True,
213
248
  return_object: bool = False,
214
249
  **kwargs,
215
- ) -> tuple[AnnData, object | AnnData]:
250
+ ) -> tuple[AnnData, object] | AnnData:
216
251
  """Computes a clustering using Density-based spatial clustering of applications (DBSCAN).
217
252
 
218
253
  Args:
@@ -224,6 +259,11 @@ class DBSCANSpace(ClusteringSpace):
224
259
  return_object: if True returns the clustering object
225
260
  **kwargs: Are passed to sklearn's DBSCAN.
226
261
 
262
+ Returns:
263
+ If return_object is True, the adata and the clustering object are returned.
264
+ Otherwise, only the adata is returned. The adata is updated with a new .obs column as specified in cluster_key,
265
+ that stores the cluster labels.
266
+
227
267
  Examples:
228
268
  >>> import pertpy as pt
229
269
  >>> mdata = pt.dt.papalexi_2021()
@@ -250,6 +290,7 @@ class DBSCANSpace(ClusteringSpace):
250
290
 
251
291
  clustering = DBSCAN(**kwargs).fit(self.X)
252
292
  adata.obs[cluster_key] = clustering.labels_
293
+ adata.obs[cluster_key] = adata.obs[cluster_key].astype("category")
253
294
 
254
295
  if return_object:
255
296
  return adata, clustering
@@ -1 +1 @@
1
- from pertpy.tools._scgen._jax_scgen import SCGEN
1
+ from pertpy.tools._scgen._scgen import SCGEN