pertpy-0.10.0-py3-none-any.whl → pertpy-0.11.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. pertpy/__init__.py +5 -1
  2. pertpy/_doc.py +1 -3
  3. pertpy/_types.py +6 -0
  4. pertpy/data/_dataloader.py +68 -24
  5. pertpy/data/_datasets.py +9 -9
  6. pertpy/metadata/__init__.py +2 -1
  7. pertpy/metadata/_cell_line.py +133 -25
  8. pertpy/metadata/_look_up.py +13 -19
  9. pertpy/metadata/_moa.py +1 -1
  10. pertpy/preprocessing/_guide_rna.py +138 -44
  11. pertpy/preprocessing/_guide_rna_mixture.py +17 -19
  12. pertpy/tools/__init__.py +1 -1
  13. pertpy/tools/_augur.py +106 -98
  14. pertpy/tools/_cinemaot.py +74 -114
  15. pertpy/tools/_coda/_base_coda.py +129 -145
  16. pertpy/tools/_coda/_sccoda.py +66 -69
  17. pertpy/tools/_coda/_tasccoda.py +71 -79
  18. pertpy/tools/_dialogue.py +48 -40
  19. pertpy/tools/_differential_gene_expression/_base.py +21 -31
  20. pertpy/tools/_differential_gene_expression/_checks.py +4 -6
  21. pertpy/tools/_differential_gene_expression/_dge_comparison.py +5 -6
  22. pertpy/tools/_differential_gene_expression/_edger.py +6 -10
  23. pertpy/tools/_differential_gene_expression/_pydeseq2.py +1 -1
  24. pertpy/tools/_differential_gene_expression/_simple_tests.py +3 -3
  25. pertpy/tools/_differential_gene_expression/_statsmodels.py +8 -5
  26. pertpy/tools/_distances/_distance_tests.py +1 -2
  27. pertpy/tools/_distances/_distances.py +31 -45
  28. pertpy/tools/_enrichment.py +7 -22
  29. pertpy/tools/_milo.py +19 -15
  30. pertpy/tools/_mixscape.py +73 -75
  31. pertpy/tools/_perturbation_space/_clustering.py +4 -4
  32. pertpy/tools/_perturbation_space/_comparison.py +4 -4
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +83 -32
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +10 -10
  35. pertpy/tools/_perturbation_space/_simple.py +12 -14
  36. pertpy/tools/_scgen/_scgen.py +16 -17
  37. pertpy/tools/_scgen/_scgenvae.py +2 -2
  38. pertpy/tools/_scgen/_utils.py +3 -1
  39. {pertpy-0.10.0.dist-info → pertpy-0.11.0.dist-info}/METADATA +36 -20
  40. pertpy-0.11.0.dist-info/RECORD +58 -0
  41. {pertpy-0.10.0.dist-info → pertpy-0.11.0.dist-info}/licenses/LICENSE +1 -0
  42. pertpy/tools/_kernel_pca.py +0 -50
  43. pertpy-0.10.0.dist-info/RECORD +0 -58
  44. {pertpy-0.10.0.dist-info → pertpy-0.11.0.dist-info}/WHEEL +0 -0
pertpy/tools/_mixscape.py CHANGED
@@ -9,15 +9,14 @@ import numpy as np
9
9
  import pandas as pd
10
10
  import scanpy as sc
11
11
  import seaborn as sns
12
+ from fast_array_utils.stats import mean, mean_var
12
13
  from scanpy import get
13
- from scanpy._settings import settings
14
14
  from scanpy._utils import _check_use_raw, sanitize_anndata
15
15
  from scanpy.plotting import _utils
16
16
  from scanpy.tools._utils import _choose_representation
17
- from scipy.sparse import csr_matrix, issparse, spmatrix
17
+ from scipy.sparse import csr_matrix, spmatrix
18
18
  from sklearn.mixture import GaussianMixture
19
19
 
20
- import pertpy as pt
21
20
  from pertpy._doc import _doc_params, doc_common_plot_args
22
21
 
23
22
  if TYPE_CHECKING:
@@ -111,7 +110,7 @@ class Mixscape:
111
110
  for split in adata.obs[split_by].unique():
112
111
  split_mask = adata.obs[split_by] == split
113
112
  control_mask_group = control_mask & split_mask
114
- control_mean_expr = adata.X[control_mask_group].mean(0)
113
+ control_mean_expr = mean(adata.X[control_mask_group], axis=0)
115
114
  adata.layers["X_pert"][split_mask] = (
116
115
  np.repeat(control_mean_expr.reshape(1, -1), split_mask.sum(), axis=0)
117
116
  - adata.layers["X_pert"][split_mask]
@@ -127,14 +126,14 @@ class Mixscape:
127
126
  if n_dims is not None and n_dims < representation.shape[1]:
128
127
  representation = representation[:, :n_dims]
129
128
 
129
+ from pynndescent import NNDescent
130
+
130
131
  for split_mask in split_masks:
131
132
  control_mask_split = control_mask & split_mask
132
133
 
133
134
  R_split = representation[split_mask]
134
135
  R_control = representation[np.asarray(control_mask_split)]
135
136
 
136
- from pynndescent import NNDescent
137
-
138
137
  eps = kwargs.pop("epsilon", 0.1)
139
138
  nn_index = NNDescent(R_control, **kwargs)
140
139
  indices, _ = nn_index.query(R_split, k=n_neighbors, epsilon=eps)
@@ -153,11 +152,10 @@ class Mixscape:
153
152
  shape=(n_split, n_control),
154
153
  )
155
154
  neigh_matrix /= n_neighbors
156
- adata.layers["X_pert"][split_mask] = (
157
- np.log1p(neigh_matrix @ X_control) - adata.layers["X_pert"][split_mask]
155
+ adata.layers["X_pert"][np.asarray(split_mask)] = (
156
+ sc.pp.log1p(neigh_matrix @ X_control) - adata.layers["X_pert"][np.asarray(split_mask)]
158
157
  )
159
158
  else:
160
- is_sparse = issparse(X_control)
161
159
  split_indices = np.where(split_mask)[0]
162
160
  for i in range(0, n_split, batch_size):
163
161
  size = min(i + batch_size, n_split)
@@ -168,10 +166,9 @@ class Mixscape:
168
166
 
169
167
  size = size - i
170
168
 
171
- # sparse is very slow
172
169
  means_batch = X_control[batch]
173
- means_batch = means_batch.toarray() if is_sparse else means_batch
174
- means_batch = means_batch.reshape(size, n_neighbors, -1).mean(1)
170
+ batch_reshaped = means_batch.reshape(size, n_neighbors, -1)
171
+ means_batch, _ = mean_var(batch_reshaped, axis=1)
175
172
 
176
173
  adata.layers["X_pert"][split_batch] = (
177
174
  np.log1p(means_batch) - adata.layers["X_pert"][split_batch]
@@ -199,6 +196,7 @@ class Mixscape:
199
196
  perturbation_type: str | None = "KO",
200
197
  random_state: int | None = 0,
201
198
  copy: bool | None = False,
199
+ **gmmkwargs,
202
200
  ):
203
201
  """Identify perturbed and non-perturbed gRNA expressing cells that accounts for multiple treatments/conditions/chemical perturbations.
204
202
 
@@ -221,6 +219,7 @@ class Mixscape:
221
219
  perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications.
222
220
  random_state: Random seed for the GaussianMixture model.
223
221
  copy: Determines whether a copy of the `adata` is returned.
222
+ **gmmkwargs: Passed to custom implementation of scikit-learn Gaussian Mixture Model.
224
223
 
225
224
  Returns:
226
225
  If `copy=True`, returns the copy of `adata` with the classification result in `.obs`.
@@ -307,10 +306,9 @@ class Mixscape:
307
306
 
308
307
  else:
309
308
  de_genes = perturbation_markers[(category, gene)]
310
- de_genes_indices = self._get_column_indices(adata, list(de_genes))
309
+ de_genes_indices = np.where(np.isin(adata.var_names, list(de_genes)))[0]
311
310
 
312
311
  dat = X[np.asarray(all_cells)][:, de_genes_indices]
313
- dat_cells = all_cells[all_cells].index
314
312
  if scale:
315
313
  dat = sc.pp.scale(dat)
316
314
 
@@ -318,6 +316,9 @@ class Mixscape:
318
316
  n_iter = 0
319
317
  old_classes = adata.obs[new_class_name][all_cells]
320
318
 
319
+ nt_cells_dat_idx = all_cells[all_cells].index.get_indexer(nt_cells[nt_cells].index)
320
+ nt_cells_mean = np.mean(dat[nt_cells_dat_idx], axis=0)
321
+
321
322
  while not converged and n_iter < iter_num:
322
323
  # Get all cells in current split&Gene
323
324
  guide_cells = (adata.obs[new_class_name] == gene) & split_mask
@@ -326,12 +327,12 @@ class Mixscape:
326
327
  # all cells in current split&Gene minus all NT cells in current split
327
328
  # Each row is for each cell, each column is for each gene, get mean for each column
328
329
  guide_cells_dat_idx = all_cells[all_cells].index.get_indexer(guide_cells[guide_cells].index)
329
- nt_cells_dat_idx = all_cells[all_cells].index.get_indexer(nt_cells[nt_cells].index)
330
- vec = np.mean(dat[guide_cells_dat_idx], axis=0) - np.mean(dat[nt_cells_dat_idx], axis=0)
330
+ guide_cells_mean = np.mean(dat[guide_cells_dat_idx], axis=0)
331
+ vec = guide_cells_mean - nt_cells_mean
331
332
 
332
333
  # project cells onto the perturbation vector
333
334
  if isinstance(dat, spmatrix):
334
- pvec = np.dot(dat.toarray(), vec) / np.dot(vec, vec)
335
+ pvec = dat.dot(vec) / np.dot(vec, vec)
335
336
  else:
336
337
  pvec = np.dot(dat, vec) / np.dot(vec, vec)
337
338
  pvec = pd.Series(np.asarray(pvec).flatten(), index=list(all_cells.index[all_cells]))
@@ -341,7 +342,7 @@ class Mixscape:
341
342
  gv["pvec"] = pvec
342
343
  gv[labels] = control
343
344
  gv.loc[guide_cells, labels] = gene
344
- if gene not in gv_list.keys():
345
+ if gene not in gv_list:
345
346
  gv_list[gene] = {}
346
347
  gv_list[gene][category] = gv
347
348
 
@@ -351,31 +352,30 @@ class Mixscape:
351
352
  n_components=2,
352
353
  covariance_type="spherical",
353
354
  means_init=means_init,
354
- precisions_init=1 / (std_init ** 2),
355
+ precisions_init=1 / (std_init**2),
355
356
  random_state=random_state,
356
- max_iter=5000,
357
+ max_iter=100,
357
358
  fixed_means=[pvec[nt_cells].mean(), None],
358
359
  fixed_covariances=[pvec[nt_cells].std() ** 2, None],
360
+ **gmmkwargs,
359
361
  ).fit(np.asarray(pvec).reshape(-1, 1))
360
362
  probabilities = mm.predict_proba(np.array(pvec[orig_guide_cells_index]).reshape(-1, 1))
361
363
  lik_ratio = probabilities[:, 0] / probabilities[:, 1]
362
364
  post_prob = 1 / (1 + lik_ratio)
363
365
 
364
366
  # based on the posterior probability, assign cells to the two classes
365
- adata.obs.loc[
366
- [orig_guide_cells_index[cell] for cell in np.where(post_prob > 0.5)[0]], new_class_name
367
- ] = gene
368
- adata.obs.loc[
369
- [orig_guide_cells_index[cell] for cell in np.where(post_prob <= 0.5)[0]], new_class_name
370
- ] = f"{gene} NP"
367
+ ko_mask = post_prob > 0.5
368
+ adata.obs.loc[np.array(orig_guide_cells_index)[ko_mask], new_class_name] = gene
369
+ adata.obs.loc[np.array(orig_guide_cells_index)[~ko_mask], new_class_name] = f"{gene} NP"
371
370
 
372
371
  if sum(adata.obs[new_class_name][split_mask] == gene) < min_de_genes:
373
372
  adata.obs.loc[guide_cells, new_class_name] = "NP"
374
373
  converged = True
375
- if adata.obs[new_class_name][all_cells].equals(old_classes):
374
+ current_classes = adata.obs[new_class_name][all_cells]
375
+ if (current_classes == old_classes).all():
376
376
  converged = True
377
+ old_classes = current_classes
377
378
 
378
- old_classes = adata.obs[new_class_name][all_cells]
379
379
  n_iter += 1
380
380
 
381
381
  adata.obs.loc[(adata.obs[new_class_name] == gene) & split_mask, new_class_name] = (
@@ -414,7 +414,6 @@ class Mixscape:
414
414
  control: Control category from the `pert_key` column.
415
415
  mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
416
416
  layer: Layer to use for identifying differentially expressed genes. If `None`, adata.X is used.
417
- control: Control category from the `pert_key` column.
418
417
  n_comps: Number of principal components to use.
419
418
  min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
420
419
  logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells.
@@ -470,7 +469,8 @@ class Mixscape:
470
469
  )
471
470
  adata_subset = adata[
472
471
  (adata.obs[mixscape_class_global] == perturbation_type) | (adata.obs[mixscape_class_global] == control)
473
- ].copy()
472
+ ]
473
+ X = adata_subset.X - adata_subset.X.mean(0)
474
474
  projected_pcs: dict[str, np.ndarray] = {}
475
475
  # performs PCA on each mixscape class separately and projects each subspace onto all cells in the data.
476
476
  for _, (key, value) in enumerate(perturbation_markers.items()):
@@ -482,16 +482,10 @@ class Mixscape:
482
482
  ].copy()
483
483
  sc.pp.scale(gene_subset)
484
484
  sc.tl.pca(gene_subset, n_comps=n_comps)
485
- sc.pp.neighbors(gene_subset)
486
- # projects each subspace onto all cells in the data.
487
- sc.tl.ingest(adata=adata_subset, adata_ref=gene_subset, embedding_method="pca")
488
- projected_pcs[key[1]] = adata_subset.obsm["X_pca"]
485
+ # project cells into PCA space of gene_subset
486
+ projected_pcs[key[1]] = np.asarray(np.dot(X, gene_subset.varm["PCs"]))
489
487
  # concatenate all pcs into a single matrix.
490
- for index, (_, value) in enumerate(projected_pcs.items()):
491
- if index == 0:
492
- projected_pcs_array = value
493
- else:
494
- projected_pcs_array = np.concatenate((projected_pcs_array, value), axis=1)
488
+ projected_pcs_array = np.concatenate(list(projected_pcs.values()), axis=1)
495
489
 
496
490
  clf = LinearDiscriminantAnalysis(n_components=len(np.unique(adata_subset.obs[labels])) - 1)
497
491
  clf.fit(projected_pcs_array, adata_subset.obs[labels])
@@ -514,7 +508,7 @@ class Mixscape:
514
508
  logfc_threshold: float,
515
509
  test_method: str,
516
510
  ) -> dict[tuple, np.ndarray]:
517
- """Determine gene sets across all splits/groups through differential gene expression
511
+ """Determine gene sets across all splits/groups through differential gene expression.
518
512
 
519
513
  Args:
520
514
  adata: :class:`~anndata.AnnData` object
@@ -549,7 +543,9 @@ class Mixscape:
549
543
  )
550
544
  # get DE genes for each target gene
551
545
  for gene in gene_targets:
552
- logfc_threshold_mask = np.abs(adata_split.uns["rank_genes_groups"]["logfoldchanges"][gene]) >= logfc_threshold
546
+ logfc_threshold_mask = (
547
+ np.abs(adata_split.uns["rank_genes_groups"]["logfoldchanges"][gene]) >= logfc_threshold
548
+ )
553
549
  de_genes = adata_split.uns["rank_genes_groups"]["names"][gene][logfc_threshold_mask]
554
550
  pvals_adj = adata_split.uns["rank_genes_groups"]["pvals_adj"][gene][logfc_threshold_mask]
555
551
  de_genes = de_genes[pvals_adj < pval_cutoff]
@@ -559,19 +555,8 @@ class Mixscape:
559
555
 
560
556
  return perturbation_markers
561
557
 
562
- def _get_column_indices(self, adata, col_names):
563
- if isinstance(col_names, str): # pragma: no cover
564
- col_names = [col_names]
565
-
566
- indices = []
567
- for idx, col in enumerate(adata.var_names):
568
- if col in col_names:
569
- indices.append(idx)
570
-
571
- return indices
572
-
573
558
  @_doc_params(common_plot_args=doc_common_plot_args)
574
- def plot_barplot( # pragma: no cover
559
+ def plot_barplot( # pragma: no cover # noqa: D417
575
560
  self,
576
561
  adata: AnnData,
577
562
  guide_rna_column: str,
@@ -678,7 +663,7 @@ class Mixscape:
678
663
  return None
679
664
 
680
665
  @_doc_params(common_plot_args=doc_common_plot_args)
681
- def plot_heatmap( # pragma: no cover
666
+ def plot_heatmap( # pragma: no cover # noqa: D417
682
667
  self,
683
668
  adata: AnnData,
684
669
  labels: str,
@@ -748,7 +733,7 @@ class Mixscape:
748
733
  return None
749
734
 
750
735
  @_doc_params(common_plot_args=doc_common_plot_args)
751
- def plot_perturbscore( # pragma: no cover
736
+ def plot_perturbscore( # pragma: no cover # noqa: D417
752
737
  self,
753
738
  adata: AnnData,
754
739
  labels: str,
@@ -801,7 +786,7 @@ class Mixscape:
801
786
  if "mixscape" not in adata.uns:
802
787
  raise ValueError("Please run the `mixscape` function first.")
803
788
  perturbation_score = None
804
- for key in adata.uns["mixscape"][target_gene].keys():
789
+ for key in adata.uns["mixscape"][target_gene]:
805
790
  perturbation_score_temp = adata.uns["mixscape"][target_gene][key]
806
791
  perturbation_score_temp["name"] = key
807
792
  if perturbation_score is None:
@@ -914,7 +899,7 @@ class Mixscape:
914
899
  return None
915
900
 
916
901
  @_doc_params(common_plot_args=doc_common_plot_args)
917
- def plot_violin( # pragma: no cover
902
+ def plot_violin( # pragma: no cover # noqa: D417
918
903
  self,
919
904
  adata: AnnData,
920
905
  target_gene_idents: str | list[str],
@@ -994,7 +979,7 @@ class Mixscape:
994
979
  if len(ylabel) != 1:
995
980
  raise ValueError(f"Expected number of y-labels to be `1`, found `{len(ylabel)}`.")
996
981
  elif len(ylabel) != len(keys):
997
- raise ValueError(f"Expected number of y-labels to be `{len(keys)}`, " f"found `{len(ylabel)}`.")
982
+ raise ValueError(f"Expected number of y-labels to be `{len(keys)}`, found `{len(ylabel)}`.")
998
983
 
999
984
  if groupby is not None:
1000
985
  if hue is not None:
@@ -1047,7 +1032,7 @@ class Mixscape:
1047
1032
  g.set(yscale="log")
1048
1033
  g.set_titles(col_template="{col_name}").set_xlabels("")
1049
1034
  if rotation is not None:
1050
- for ax in g.axes[0]:
1035
+ for ax in g.axes[0]: # noqa: PLR1704
1051
1036
  ax.tick_params(axis="x", labelrotation=rotation)
1052
1037
  else:
1053
1038
  # set by default the violin plot cut=0 to limit the extend
@@ -1065,7 +1050,7 @@ class Mixscape:
1065
1050
  else:
1066
1051
  axs = [ax]
1067
1052
  for ax, y, ylab in zip(axs, ys, ylabel, strict=False):
1068
- ax = sns.violinplot(
1053
+ ax = sns.violinplot( # noqa: PLW2901
1069
1054
  x=x,
1070
1055
  y=y,
1071
1056
  data=obs_tidy,
@@ -1079,7 +1064,7 @@ class Mixscape:
1079
1064
  # Get the handles and labels.
1080
1065
  handles, labels = ax.get_legend_handles_labels()
1081
1066
  if stripplot:
1082
- ax = sns.stripplot(
1067
+ ax = sns.stripplot( # noqa: PLW2901
1083
1068
  x=x,
1084
1069
  y=y,
1085
1070
  data=obs_tidy,
@@ -1116,7 +1101,7 @@ class Mixscape:
1116
1101
  return None
1117
1102
 
1118
1103
  @_doc_params(common_plot_args=doc_common_plot_args)
1119
- def plot_lda( # pragma: no cover
1104
+ def plot_lda( # pragma: no cover # noqa: D417
1120
1105
  self,
1121
1106
  adata: AnnData,
1122
1107
  control: str,
@@ -1135,13 +1120,16 @@ class Mixscape:
1135
1120
  """Visualizing perturbation responses with Linear Discriminant Analysis. Requires `pt.tl.mixscape()` to be run first.
1136
1121
 
1137
1122
  Args:
1138
- adata: The annotated data object.
1123
+ adata: The annotated data objectplot_heatmap.
1139
1124
  control: Control category from the `pert_key` column.
1140
1125
  mixscape_class: The column of `.obs` with the mixscape classification result.
1141
1126
  mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
1142
1127
  perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
1143
- lda_key: If not specified, lda looks .uns["mixscape_lda"] for the LDA results.
1144
1128
  n_components: The number of dimensions of the embedding.
1129
+ lda_key: If not specified, lda looks .uns["mixscape_lda"] for the LDA results.
1130
+ color_map: Matplotlib color map.
1131
+ palette: Matplotlib palette.
1132
+ ax: Matplotlib axes.
1145
1133
  {common_plot_args}
1146
1134
  **kwds: Additional arguments to `scanpy.pl.umap`.
1147
1135
 
@@ -1186,13 +1174,14 @@ class Mixscape:
1186
1174
  plt.show()
1187
1175
  return None
1188
1176
 
1177
+
1189
1178
  class MixscapeGaussianMixture(GaussianMixture):
1190
1179
  def __init__(
1191
1180
  self,
1192
1181
  n_components: int,
1193
- fixed_means: Sequence[float] | None = None,
1182
+ fixed_means: Sequence[float] | None = None,
1194
1183
  fixed_covariances: Sequence[float] | None = None,
1195
- **kwargs
1184
+ **kwargs,
1196
1185
  ):
1197
1186
  """Custom Gaussian Mixture Model where means and covariances can be fixed for specific components.
1198
1187
 
@@ -1206,19 +1195,28 @@ class MixscapeGaussianMixture(GaussianMixture):
1206
1195
  self.fixed_means = fixed_means
1207
1196
  self.fixed_covariances = fixed_covariances
1208
1197
 
1198
+ self.fixed_mean_indices = []
1199
+ self.fixed_mean_values = []
1200
+ if fixed_means is not None:
1201
+ self.fixed_mean_indices = [i for i, m in enumerate(fixed_means) if m is not None]
1202
+ if self.fixed_mean_indices:
1203
+ self.fixed_mean_values = np.array([fixed_means[i] for i in self.fixed_mean_indices])
1204
+
1205
+ self.fixed_cov_indices = []
1206
+ self.fixed_cov_values = []
1207
+ if fixed_covariances is not None:
1208
+ self.fixed_cov_indices = [i for i, c in enumerate(fixed_covariances) if c is not None]
1209
+ if self.fixed_cov_indices:
1210
+ self.fixed_cov_values = np.array([fixed_covariances[i] for i in self.fixed_cov_indices])
1211
+
1209
1212
  def _m_step(self, X: np.ndarray, log_resp: np.ndarray):
1210
1213
  """Modified M-step to respect fixed means and covariances."""
1211
1214
  super()._m_step(X, log_resp)
1212
1215
 
1213
- if self.fixed_means is not None:
1214
- for i in range(self.n_components):
1215
- if self.fixed_means[i] is not None:
1216
- self.means_[i] = self.fixed_means[i]
1216
+ if self.fixed_mean_indices:
1217
+ self.means_[self.fixed_mean_indices] = self.fixed_mean_values
1217
1218
 
1218
- if self.fixed_covariances is not None:
1219
- for i in range(self.n_components):
1220
- if self.fixed_covariances[i] is not None:
1221
- self.covariances_[i] = self.fixed_covariances[i]
1219
+ if self.fixed_cov_indices:
1220
+ self.covariances_[self.fixed_cov_indices] = self.fixed_cov_values
1222
1221
 
1223
1222
  return self
1224
-
@@ -76,13 +76,13 @@ class ClusteringSpace(PerturbationSpace):
76
76
  if metric == "asw":
77
77
  from pertpy.tools._perturbation_space._metrics import asw
78
78
 
79
- if "metric" not in kwargs.keys():
79
+ if "metric" not in kwargs:
80
80
  kwargs["metric"] = "euclidean"
81
- if "distances" not in kwargs.keys():
81
+ if "distances" not in kwargs:
82
82
  distances = pairwise_distances(self.X, metric=kwargs["metric"])
83
- if "sample_size" not in kwargs.keys():
83
+ if "sample_size" not in kwargs:
84
84
  kwargs["sample_size"] = None
85
- if "random_state" not in kwargs.keys():
85
+ if "random_state" not in kwargs:
86
86
  kwargs["random_state"] = None
87
87
 
88
88
  asw_score = asw(
@@ -1,7 +1,6 @@
1
1
  from typing import TYPE_CHECKING
2
2
 
3
3
  import numpy as np
4
- import pynndescent
5
4
  from scipy.sparse import issparse
6
5
  from scipy.sparse import vstack as sp_vstack
7
6
  from sklearn.base import ClassifierMixin
@@ -95,7 +94,9 @@ class PerturbationComparison:
95
94
  labels[-control.shape[0] :] = "ctrl"
96
95
  label_groups.append("ctrl")
97
96
 
98
- index = pynndescent.NNDescent(
97
+ from pynndescent import NNDescent
98
+
99
+ index = NNDescent(
99
100
  index_data,
100
101
  n_neighbors=max(50, n_neighbors),
101
102
  random_state=random_state,
@@ -106,7 +107,6 @@ class PerturbationComparison:
106
107
  uq, uq_counts = np.unique(labels[indices], return_counts=True)
107
108
  uq_counts_norm = uq_counts / uq_counts.sum()
108
109
  counts = dict(zip(label_groups, [0.0] * len(label_groups), strict=False))
109
- for group, count_norm in zip(uq, uq_counts_norm, strict=False):
110
- counts[group] = count_norm
110
+ counts = dict(zip(uq, uq_counts_norm, strict=False))
111
111
 
112
112
  return counts