pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (53) hide show
  1. pertpy/__init__.py +3 -2
  2. pertpy/data/__init__.py +5 -1
  3. pertpy/data/_dataloader.py +2 -4
  4. pertpy/data/_datasets.py +203 -92
  5. pertpy/metadata/__init__.py +4 -0
  6. pertpy/metadata/_cell_line.py +826 -0
  7. pertpy/metadata/_compound.py +129 -0
  8. pertpy/metadata/_drug.py +242 -0
  9. pertpy/metadata/_look_up.py +582 -0
  10. pertpy/metadata/_metadata.py +73 -0
  11. pertpy/metadata/_moa.py +129 -0
  12. pertpy/plot/__init__.py +1 -9
  13. pertpy/plot/_augur.py +53 -116
  14. pertpy/plot/_coda.py +277 -677
  15. pertpy/plot/_guide_rna.py +17 -35
  16. pertpy/plot/_milopy.py +59 -134
  17. pertpy/plot/_mixscape.py +152 -391
  18. pertpy/preprocessing/_guide_rna.py +88 -4
  19. pertpy/tools/__init__.py +8 -13
  20. pertpy/tools/_augur.py +315 -17
  21. pertpy/tools/_cinemaot.py +143 -4
  22. pertpy/tools/_coda/_base_coda.py +1210 -65
  23. pertpy/tools/_coda/_sccoda.py +50 -21
  24. pertpy/tools/_coda/_tasccoda.py +27 -19
  25. pertpy/tools/_dialogue.py +164 -56
  26. pertpy/tools/_differential_gene_expression.py +240 -14
  27. pertpy/tools/_distances/_distance_tests.py +8 -8
  28. pertpy/tools/_distances/_distances.py +184 -34
  29. pertpy/tools/_enrichment.py +465 -0
  30. pertpy/tools/_milo.py +345 -11
  31. pertpy/tools/_mixscape.py +668 -50
  32. pertpy/tools/_perturbation_space/_clustering.py +5 -1
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
  35. pertpy/tools/_perturbation_space/_simple.py +51 -10
  36. pertpy/tools/_scgen/__init__.py +1 -1
  37. pertpy/tools/_scgen/_scgen.py +701 -0
  38. pertpy/tools/_scgen/_utils.py +1 -3
  39. pertpy/tools/decoupler_LICENSE +674 -0
  40. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
  41. pertpy-0.7.0.dist-info/RECORD +53 -0
  42. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
  43. pertpy/plot/_cinemaot.py +0 -81
  44. pertpy/plot/_dialogue.py +0 -91
  45. pertpy/plot/_scgen.py +0 -337
  46. pertpy/tools/_metadata/__init__.py +0 -0
  47. pertpy/tools/_metadata/_cell_line.py +0 -613
  48. pertpy/tools/_metadata/_look_up.py +0 -342
  49. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  50. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  51. pertpy-0.6.0.dist-info/RECORD +0 -50
  52. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  53. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
5
5
  import numpy as np
6
6
  import pandas as pd
7
7
  from anndata import AnnData
8
+ from pynndescent import NNDescent
8
9
  from rich import print
9
10
 
10
11
  if TYPE_CHECKING:
@@ -15,7 +16,7 @@ class PerturbationSpace:
15
16
  """Implements various ways of interacting with PerturbationSpaces.
16
17
 
17
18
  We differentiate between a cell space and a perturbation space.
18
- Visually speaking, in cell spaces single dota points in an embeddings summarize a cell,
19
  + Visually speaking, in cell spaces single data points in an embedding summarize a cell,
19
20
  whereas in a perturbation space, data points summarize whole perturbations.
20
21
  """
21
22
 
@@ -25,7 +26,8 @@ class PerturbationSpace:
25
26
  def compute_control_diff( # type: ignore
26
27
  self,
27
28
  adata: AnnData,
28
- target_col: str = "perturbations",
29
+ target_col: str = "perturbation",
30
+ group_col: str = None,
29
31
  reference_key: str = "control",
30
32
  layer_key: str = None,
31
33
  new_layer_key: str = "control_diff",
@@ -33,26 +35,31 @@ class PerturbationSpace:
33
35
  new_embedding_key: str = "control_diff",
34
36
  all_data: bool = False,
35
37
  copy: bool = False,
36
- ):
38
+ ) -> AnnData:
37
39
  """Subtract mean of the control from the perturbation.
38
40
 
39
41
  Args:
40
42
  adata: Anndata object of size cells x genes.
41
43
  target_col: .obs column name that stores the label of the perturbation applied to each cell. Defaults to 'perturbations'.
44
  + group_col: .obs column name that stores the label of the group of each cell. If None, ignore groups. Defaults to None.
42
45
  reference_key: The key of the control values. Defaults to 'control'.
43
46
  layer_key: Key of the AnnData layer to use for computation. Defaults to the `X` matrix otherwise.
44
- new_layer_key: the results are stored in the given layer. Defaults to 'differential diff'.
47
+ new_layer_key: the results are stored in the given layer. Defaults to 'control_diff'.
45
48
  embedding_key: `obsm` key of the AnnData embedding to use for computation. Defaults to the 'X' matrix otherwise.
46
- new_embedding_key: Results are stored in a new embedding in `obsm` with this key. Defaults to 'control diff'.
49
+ new_embedding_key: Results are stored in a new embedding in `obsm` with this key. Defaults to 'control_diff'.
47
50
  all_data: if True, do the computation in all data representations (X, all layers and all embeddings)
48
51
  copy: If True returns a new Anndata of same size with the new column; otherwise it updates the initial AnnData object.
49
52
 
53
+ Returns:
54
+ Updated AnnData object.
55
+
50
56
  Examples:
51
57
  Example usage with PseudobulkSpace:
58
+
52
59
  >>> import pertpy as pt
53
60
  >>> mdata = pt.dt.papalexi_2021()
54
61
  >>> ps = pt.tl.PseudobulkSpace()
55
- >>> diff_adata = ps.compute_control_diff(mdata["rna"], target_col="gene_target", reference_key='NT')
62
+ >>> diff_adata = ps.compute_control_diff(mdata["rna"], target_col="gene_target", reference_key="NT")
56
63
  """
57
64
  if reference_key not in adata.obs[target_col].unique():
58
65
  raise ValueError(
@@ -69,48 +76,67 @@ class PerturbationSpace:
69
76
  adata = adata.copy()
70
77
 
71
78
  control_mask = adata.obs[target_col] == reference_key
72
- num_control = control_mask.sum()
79
+ group_masks = (
80
+ [(adata.obs[group_col] == sample) for sample in adata.obs[group_col].unique()]
81
+ if group_col
82
+ else [[True] * adata.n_obs]
83
+ )
73
84
 
74
85
  if layer_key:
75
- if num_control == 1:
76
- control_expression = adata.layers[layer_key][control_mask, :]
77
- else:
78
- control_expression = np.mean(adata.layers[layer_key][control_mask, :], axis=0)
79
- diff_matrix = adata.layers[layer_key] - control_expression
80
- adata.layers[new_layer_key] = diff_matrix
86
+ adata.layers[new_layer_key] = np.zeros((adata.n_obs, adata.n_vars))
87
+ for mask in group_masks:
88
+ num_control = (control_mask & mask).sum()
89
+ if num_control == 1:
90
+ control_expression = adata.layers[layer_key][(control_mask & mask), :]
91
+ elif num_control > 1:
92
+ control_expression = np.mean(adata.layers[layer_key][(control_mask & mask), :], axis=0)
93
+ else:
94
+ control_expression = np.zeros((1, adata.n_vars))
95
+ adata.layers[new_layer_key][mask, :] = adata.layers[layer_key][mask, :] - control_expression
81
96
 
82
97
  if embedding_key:
83
- if num_control == 1:
84
- control_expression = adata.obsm[embedding_key][control_mask, :]
85
- else:
86
- control_expression = np.mean(adata.obsm[embedding_key][control_mask, :], axis=0)
87
- diff_matrix = adata.obsm[embedding_key] - control_expression
88
- adata.obsm[new_embedding_key] = diff_matrix
98
+ adata.obsm[new_embedding_key] = np.zeros(adata.obsm[embedding_key].shape)
99
+ for mask in group_masks:
100
+ num_control = (control_mask & mask).sum()
101
+ if num_control == 1:
102
+ control_expression = adata.obsm[embedding_key][(control_mask & mask), :]
103
+ elif num_control > 1:
104
+ control_expression = np.mean(adata.obsm[embedding_key][(control_mask & mask), :], axis=0)
105
+ else:
106
+ control_expression = np.zeros((1, adata.n_vars))
107
+ adata.obsm[new_embedding_key][mask, :] = adata.obsm[embedding_key][mask, :] - control_expression
89
108
 
90
109
  if (not layer_key and not embedding_key) or all_data:
91
- if num_control == 1:
92
- control_expression = adata.X[control_mask, :]
93
- else:
94
- control_expression = np.mean(adata.X[control_mask, :], axis=0)
95
- diff_matrix = adata.X - control_expression
96
- adata.X = diff_matrix
110
+ adata_x = np.zeros((adata.n_obs, adata.n_vars))
111
+ for mask in group_masks:
112
+ num_control = (control_mask & mask).sum()
113
+ if num_control == 1:
114
+ control_expression = adata.X[(control_mask & mask), :]
115
+ elif num_control > 1:
116
+ control_expression = np.mean(adata.X[(control_mask & mask), :], axis=0)
117
+ else:
118
+ control_expression = np.zeros((1, adata.n_vars))
119
+ adata_x[mask, :] = adata.X[mask, :] - control_expression
120
+ adata.X = adata_x
97
121
 
98
122
  if all_data:
99
123
  layers_keys = list(adata.layers.keys())
100
124
  for local_layer_key in layers_keys:
101
125
  if local_layer_key != layer_key and local_layer_key != new_layer_key:
102
- diff_matrix = adata.layers[local_layer_key] - np.mean(
103
- adata.layers[local_layer_key][control_mask, :], axis=0
104
- )
105
- adata.layers[local_layer_key + "_control_diff"] = diff_matrix
126
+ adata.layers[local_layer_key + "_control_diff"] = np.zeros((adata.n_obs, adata.n_vars))
127
+ for mask in group_masks:
128
+ adata.layers[local_layer_key + "_control_diff"][mask, :] = adata.layers[local_layer_key][
129
+ mask, :
130
+ ] - np.mean(adata.layers[local_layer_key][(control_mask & mask), :], axis=0)
106
131
 
107
132
  embedding_keys = list(adata.obsm_keys())
108
133
  for local_embedding_key in embedding_keys:
109
134
  if local_embedding_key != embedding_key and local_embedding_key != new_embedding_key:
110
- diff_matrix = adata.obsm[local_embedding_key] - np.mean(
111
- adata.obsm[local_embedding_key][control_mask, :], axis=0
112
- )
113
- adata.obsm[local_embedding_key + "_control_diff"] = diff_matrix
135
+ adata.obsm[local_embedding_key + "_control_diff"] = np.zeros(adata.obsm[local_embedding_key].shape)
136
+ for mask in group_masks:
137
+ adata.obsm[local_embedding_key + "_control_diff"][mask, :] = adata.obsm[local_embedding_key][
138
+ mask, :
139
+ ] - np.mean(adata.obsm[local_embedding_key][(control_mask & mask), :], axis=0)
114
140
 
115
141
  self.control_diff_computed = True
116
142
 
@@ -122,24 +148,30 @@ class PerturbationSpace:
122
148
  perturbations: Iterable[str],
123
149
  reference_key: str = "control",
124
150
  ensure_consistency: bool = False,
125
- target_col: str = "perturbations",
126
- ):
151
+ target_col: str = "perturbation",
152
+ ) -> tuple[AnnData, AnnData] | AnnData:
127
153
  """Add perturbations linearly. Assumes input of size n_perts x dimensionality
128
154
 
129
155
  Args:
130
156
  adata: Anndata object of size n_perts x dim.
131
157
  perturbations: Perturbations to add.
132
- reference_key: perturbation source from which the perturbation summation starts.
158
+ reference_key: perturbation source from which the perturbation summation starts. Defaults to 'control'.
133
159
  ensure_consistency: If True, runs differential expression on all data matrices to ensure consistency of linear space.
134
- target_col: .obs column name that stores the label of the perturbation applied to each cell. Defaults to 'perturbations'.
160
+ target_col: .obs column name that stores the label of the perturbation applied to each cell. Defaults to 'perturbation'.
161
+
162
+ Returns:
163
+ Anndata object of size (n_perts+1) x dim, where the last row is the addition of the specified perturbations.
164
+ If ensure_consistency is True, returns a tuple of (new_perturbation, adata) where adata is the AnnData object
165
+ provided as input but updated using compute_control_diff.
135
166
 
136
167
  Examples:
137
168
  Example usage with PseudobulkSpace:
169
+
138
170
  >>> import pertpy as pt
139
171
  >>> mdata = pt.dt.papalexi_2021()
140
172
  >>> ps = pt.tl.PseudobulkSpace()
141
173
  >>> ps_adata = ps.compute(mdata["rna"], target_col="gene_target", groups_col="gene_target")
142
- >>> new_perturbation = ps.add(ps_adata, perturbations=["ATF2", "CD86"], reference_key='NT')
174
+ >>> new_perturbation = ps.add(ps_adata, perturbations=["ATF2", "CD86"], reference_key="NT")
143
175
  """
144
176
  new_pert_name = ""
145
177
  for perturbation in perturbations:
@@ -212,6 +244,8 @@ class PerturbationSpace:
212
244
  key_name = key.removesuffix("_control_diff")
213
245
  new_perturbation.obsm[key_name] = data["embeddings"][key]
214
246
 
247
+ new_perturbation.obs[target_col] = new_perturbation.obs_names.astype("category")
248
+
215
249
  if ensure_consistency:
216
250
  return new_perturbation, adata
217
251
 
@@ -223,24 +257,30 @@ class PerturbationSpace:
223
257
  perturbations: Iterable[str],
224
258
  reference_key: str = "control",
225
259
  ensure_consistency: bool = False,
226
- target_col: str = "perturbations",
227
- ):
260
+ target_col: str = "perturbation",
261
+ ) -> tuple[AnnData, AnnData] | AnnData:
228
262
  """Subtract perturbations linearly. Assumes input of size n_perts x dimensionality
229
263
 
230
264
  Args:
231
265
  adata: Anndata object of size n_perts x dim.
232
- perturbations: Perturbations to subtract,
233
- reference_key: Perturbation source from which the perturbation subtraction starts
266
+ perturbations: Perturbations to subtract.
267
+ reference_key: Perturbation source from which the perturbation subtraction starts. Defaults to 'control'.
234
268
  ensure_consistency: If True, runs differential expression on all data matrices to ensure consistency of linear space.
235
269
  target_col: .obs column name that stores the label of the perturbation applied to each cell. Defaults to 'perturbations'.
236
270
 
271
+ Returns:
272
+ Anndata object of size (n_perts+1) x dim, where the last row is the subtraction of the specified perturbations.
273
+ If ensure_consistency is True, returns a tuple of (new_perturbation, adata) where adata is the AnnData object
274
+ provided as input but updated using compute_control_diff.
275
+
237
276
  Examples:
238
277
  Example usage with PseudobulkSpace:
278
+
239
279
  >>> import pertpy as pt
240
280
  >>> mdata = pt.dt.papalexi_2021()
241
281
  >>> ps = pt.tl.PseudobulkSpace()
242
282
  >>> ps_adata = ps.compute(mdata["rna"], target_col="gene_target", groups_col="gene_target")
243
- >>> new_perturbation = ps.add(ps_adata, reference_key="ATF2", perturbations=["BRD4", "CUL3"])
283
+ >>> new_perturbation = ps.subtract(ps_adata, reference_key="ATF2", perturbations=["BRD4", "CUL3"])
244
284
  """
245
285
  new_pert_name = reference_key + "-"
246
286
  for perturbation in perturbations:
@@ -313,7 +353,59 @@ class PerturbationSpace:
313
353
  key_name = key.removesuffix("_control_diff")
314
354
  new_perturbation.obsm[key_name] = data["embeddings"][key]
315
355
 
356
+ new_perturbation.obs[target_col] = new_perturbation.obs_names.astype("category")
357
+
316
358
  if ensure_consistency:
317
359
  return new_perturbation, adata
318
360
 
319
361
  return new_perturbation
362
+
363
+ def label_transfer(
364
+ self,
365
+ adata: AnnData,
366
+ column: str = "perturbation",
367
+ target_val: str = "unknown",
368
+ n_neighbors: int = 5,
369
+ use_rep: str = "X_umap",
370
+ ) -> None:
371
+ """Impute missing values in the specified column using KNN imputation in the space defined by `use_rep`.
372
+
373
+ Args:
374
+ adata: The AnnData object containing single-cell data.
375
+ column: The column name in AnnData object to perform imputation on. Defaults to "perturbation".
376
+ target_val: The target value to impute. Defaults to "unknown".
377
+ n_neighbors: Number of neighbors to use for imputation. Defaults to 5.
378
+ use_rep: The key in `adata.obsm` where the embedding (UMAP, PCA, etc.) is stored. Defaults to 'X_umap'.
379
+
380
+ Examples:
381
+ >>> import pertpy as pt
382
+ >>> import scanpy as sc
383
+ >>> import numpy as np
384
+ >>> adata = sc.datasets.pbmc68k_reduced()
385
+ >>> rng = np.random.default_rng()
386
+ >>> adata.obs["perturbation"] = rng.choice(
387
+ ... ["A", "B", "C", "unknown"], size=adata.n_obs, p=[0.33, 0.33, 0.33, 0.01]
388
+ ... )
389
+ >>> sc.pp.neighbors(adata)
390
+ >>> sc.tl.umap(adata)
391
+ >>> ps = pt.tl.PseudobulkSpace()
392
+ >>> ps.label_transfer(adata, n_neighbors=5, use_rep="X_umap")
393
+ """
394
+ if use_rep not in adata.obsm:
395
+ raise ValueError(f"Representation {use_rep} not found in the AnnData object.")
396
+
397
+ embedding = adata.obsm[use_rep]
398
+
399
+ nnd = NNDescent(embedding, n_neighbors=n_neighbors)
400
+ indices, _ = nnd.query(embedding, k=n_neighbors)
401
+
402
+ perturbations = np.array(adata.obs[column])
403
+ missing_mask = perturbations == target_val
404
+
405
+ for idx in np.where(missing_mask)[0]:
406
+ neighbor_indices = indices[idx]
407
+ neighbor_categories = perturbations[neighbor_indices]
408
+ most_common = pd.Series(neighbor_categories).mode()[0]
409
+ perturbations[idx] = most_common
410
+
411
+ adata.obs[column] = perturbations
@@ -15,9 +15,10 @@ class CentroidSpace(PerturbationSpace):
15
15
  def compute(
16
16
  self,
17
17
  adata: AnnData,
18
- target_col: str = "perturbations",
18
+ target_col: str = "perturbation",
19
19
  layer_key: str = None,
20
20
  embedding_key: str = "X_umap",
21
+ keep_obs: bool = True,
21
22
  ) -> AnnData: # type: ignore
22
23
  """Computes the centroids of a pre-computed embedding such as UMAP.
23
24
 
@@ -26,6 +27,12 @@ class CentroidSpace(PerturbationSpace):
26
27
  target_col: .obs column that stores the label of the perturbation applied to each cell.
27
28
  layer_key: If specified pseudobulk computation is done by using the specified layer. Otherwise, computation is done with .X
28
29
  embedding_key: `obsm` key of the AnnData embedding to use for computation. Defaults to the 'X' matrix otherwise.
30
+ keep_obs: Whether .obs columns in the input AnnData should be kept in the output pseudobulk AnnData. Only .obs columns with the same value for
31
+ each cell of one perturbation are kept. Defaults to True.
32
+
33
+ Returns:
34
+ AnnData object with one observation per perturbation, storing the embedding data of the
35
+ centroid of the respective perturbation.
29
36
 
30
37
  Examples:
31
38
  Compute the centroids of a UMAP embedding of the papalexi_2021 dataset:
@@ -34,7 +41,7 @@ class CentroidSpace(PerturbationSpace):
34
41
  >>> import scanpy as sc
35
42
  >>> mdata = pt.dt.papalexi_2021()
36
43
  >>> sc.pp.pca(mdata["rna"])
37
- >>> sc.pp.neighbors(mdata['rna'])
44
+ >>> sc.pp.neighbors(mdata["rna"])
38
45
  >>> sc.tl.umap(mdata["rna"])
39
46
  >>> cs = pt.tl.CentroidSpace()
40
47
  >>> cs_adata = cs.compute(mdata["rna"], target_col="gene_target")
@@ -84,6 +91,22 @@ class CentroidSpace(PerturbationSpace):
84
91
 
85
92
  ps_adata = AnnData(X=X)
86
93
  ps_adata.obs_names = index
94
+ ps_adata.obs[target_col] = index
95
+
96
+ if embedding_key is not None:
97
+ ps_adata.obsm[embedding_key] = X
98
+
99
+ if keep_obs: # Save the values of the obs columns of interest in the ps_adata object
100
+ obs_df = adata.obs
101
+ obs_df = obs_df.groupby(target_col).agg(
102
+ lambda pert_group: np.nan if len(set(pert_group)) != 1 else list(set(pert_group))[0]
103
+ )
104
+ for obs_name in obs_df.columns:
105
+ if not obs_df[obs_name].isnull().values.any():
106
+ mapping = {pert: obs_df.loc[pert][obs_name] for pert in index}
107
+ ps_adata.obs[obs_name] = ps_adata.obs[target_col].map(mapping)
108
+
109
+ ps_adata.obs[target_col] = ps_adata.obs[target_col].astype("category")
87
110
 
88
111
  return ps_adata
89
112
 
@@ -94,7 +117,8 @@ class PseudobulkSpace(PerturbationSpace):
94
117
  def compute(
95
118
  self,
96
119
  adata: AnnData,
97
- target_col: str = "perturbations",
120
+ target_col: str = "perturbation",
121
+ groups_col: str = None,
98
122
  layer_key: str = None,
99
123
  embedding_key: str = None,
100
124
  **kwargs,
@@ -104,19 +128,21 @@ class PseudobulkSpace(PerturbationSpace):
104
128
  Args:
105
129
  adata: Anndata object of size cells x genes
106
130
  target_col: .obs column that stores the label of the perturbation applied to each cell.
131
+ groups_col: Optional .obs column that stores a grouping label to consider for pseudobulk computation.
132
+ The summarized expression per perturbation (target_col) and group (groups_col) is computed. Defaults to None.
107
133
  layer_key: If specified pseudobulk computation is done by using the specified layer. Otherwise, computation is done with .X
108
134
  embedding_key: `obsm` key of the AnnData embedding to use for computation. Defaults to the 'X' matrix otherwise.
109
135
  **kwargs: Are passed to decoupler's get_pseuobulk.
110
136
 
137
+ Returns:
138
+ AnnData object with one observation per perturbation.
139
+
111
140
  Examples:
112
- >>> import pertpy as pp
141
+ >>> import pertpy as pt
113
142
  >>> mdata = pt.dt.papalexi_2021()
114
143
  >>> ps = pt.tl.PseudobulkSpace()
115
- >>> ps_adata = ps.compute(mdata["rna"], target_col="gene_target", groups_col="gene_target")
144
+ >>> ps_adata = ps.compute(mdata["rna"], target_col="gene_target")
116
145
  """
117
- if "groups_col" not in kwargs:
118
- kwargs["groups_col"] = "perturbations"
119
-
120
146
  if layer_key is not None and embedding_key is not None:
121
147
  raise ValueError("Please, select just either layer or embedding for computation.")
122
148
 
@@ -135,7 +161,10 @@ class PseudobulkSpace(PerturbationSpace):
135
161
  adata_emb.obs = adata.obs
136
162
  adata = adata_emb
137
163
 
138
- ps_adata = dc.get_pseudobulk(adata, sample_col=target_col, layer=layer_key, **kwargs) # type: ignore
164
+ adata.obs[target_col] = adata.obs[target_col].astype("category")
165
+ ps_adata = dc.get_pseudobulk(adata, sample_col=target_col, layer=layer_key, groups_col=groups_col, **kwargs) # type: ignore
166
+
167
+ ps_adata.obs[target_col] = ps_adata.obs[target_col].astype("category")
139
168
 
140
169
  return ps_adata
141
170
 
@@ -164,6 +193,11 @@ class KMeansSpace(ClusteringSpace):
164
193
  return_object: if True returns the clustering object
165
194
  **kwargs: Are passed to sklearn's KMeans.
166
195
 
196
+ Returns:
197
  + If return_object is True, the adata and the clustering object are returned.
198
+ Otherwise, only the adata is returned. The adata is updated with a new .obs column as specified in cluster_key,
199
+ that stores the cluster labels.
200
+
167
201
  Examples:
168
202
  >>> import pertpy as pt
169
203
  >>> mdata = pt.dt.papalexi_2021()
@@ -193,6 +227,7 @@ class KMeansSpace(ClusteringSpace):
193
227
 
194
228
  clustering = KMeans(**kwargs).fit(self.X)
195
229
  adata.obs[cluster_key] = clustering.labels_
230
+ adata.obs[cluster_key] = adata.obs[cluster_key].astype("category")
196
231
 
197
232
  if return_object:
198
233
  return adata, clustering
@@ -212,7 +247,7 @@ class DBSCANSpace(ClusteringSpace):
212
247
  copy: bool = True,
213
248
  return_object: bool = False,
214
249
  **kwargs,
215
- ) -> tuple[AnnData, object | AnnData]:
250
+ ) -> tuple[AnnData, object] | AnnData:
216
251
  """Computes a clustering using Density-based spatial clustering of applications (DBSCAN).
217
252
 
218
253
  Args:
@@ -224,6 +259,11 @@ class DBSCANSpace(ClusteringSpace):
224
259
  return_object: if True returns the clustering object
225
260
  **kwargs: Are passed to sklearn's DBSCAN.
226
261
 
262
+ Returns:
263
+ If return_object is True, the adata and the clustering object is returned.
264
+ Otherwise, only the adata is returned. The adata is updated with a new .obs column as specified in cluster_key,
265
+ that stores the cluster labels.
266
+
227
267
  Examples:
228
268
  >>> import pertpy as pt
229
269
  >>> mdata = pt.dt.papalexi_2021()
@@ -250,6 +290,7 @@ class DBSCANSpace(ClusteringSpace):
250
290
 
251
291
  clustering = DBSCAN(**kwargs).fit(self.X)
252
292
  adata.obs[cluster_key] = clustering.labels_
293
+ adata.obs[cluster_key] = adata.obs[cluster_key].astype("category")
253
294
 
254
295
  if return_object:
255
296
  return adata, clustering
@@ -1 +1 @@
1
- from pertpy.tools._scgen._jax_scgen import SCGEN
1
+ from pertpy.tools._scgen._scgen import SCGEN