pertpy 0.10.0__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +5 -1
- pertpy/_doc.py +1 -3
- pertpy/_types.py +6 -0
- pertpy/data/_dataloader.py +68 -24
- pertpy/data/_datasets.py +9 -9
- pertpy/metadata/__init__.py +2 -1
- pertpy/metadata/_cell_line.py +133 -25
- pertpy/metadata/_look_up.py +13 -19
- pertpy/metadata/_moa.py +1 -1
- pertpy/preprocessing/_guide_rna.py +138 -44
- pertpy/preprocessing/_guide_rna_mixture.py +17 -19
- pertpy/tools/__init__.py +1 -1
- pertpy/tools/_augur.py +106 -98
- pertpy/tools/_cinemaot.py +74 -114
- pertpy/tools/_coda/_base_coda.py +129 -145
- pertpy/tools/_coda/_sccoda.py +66 -69
- pertpy/tools/_coda/_tasccoda.py +71 -79
- pertpy/tools/_dialogue.py +48 -40
- pertpy/tools/_differential_gene_expression/_base.py +21 -31
- pertpy/tools/_differential_gene_expression/_checks.py +4 -6
- pertpy/tools/_differential_gene_expression/_dge_comparison.py +5 -6
- pertpy/tools/_differential_gene_expression/_edger.py +6 -10
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +1 -1
- pertpy/tools/_differential_gene_expression/_simple_tests.py +3 -3
- pertpy/tools/_differential_gene_expression/_statsmodels.py +8 -5
- pertpy/tools/_distances/_distance_tests.py +1 -2
- pertpy/tools/_distances/_distances.py +31 -45
- pertpy/tools/_enrichment.py +7 -22
- pertpy/tools/_milo.py +19 -15
- pertpy/tools/_mixscape.py +73 -75
- pertpy/tools/_perturbation_space/_clustering.py +4 -4
- pertpy/tools/_perturbation_space/_comparison.py +4 -4
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +83 -32
- pertpy/tools/_perturbation_space/_perturbation_space.py +10 -10
- pertpy/tools/_perturbation_space/_simple.py +12 -14
- pertpy/tools/_scgen/_scgen.py +16 -17
- pertpy/tools/_scgen/_scgenvae.py +2 -2
- pertpy/tools/_scgen/_utils.py +3 -1
- {pertpy-0.10.0.dist-info → pertpy-0.11.0.dist-info}/METADATA +36 -20
- pertpy-0.11.0.dist-info/RECORD +58 -0
- {pertpy-0.10.0.dist-info → pertpy-0.11.0.dist-info}/licenses/LICENSE +1 -0
- pertpy/tools/_kernel_pca.py +0 -50
- pertpy-0.10.0.dist-info/RECORD +0 -58
- {pertpy-0.10.0.dist-info → pertpy-0.11.0.dist-info}/WHEEL +0 -0
pertpy/tools/_mixscape.py
CHANGED
@@ -9,15 +9,14 @@ import numpy as np
|
|
9
9
|
import pandas as pd
|
10
10
|
import scanpy as sc
|
11
11
|
import seaborn as sns
|
12
|
+
from fast_array_utils.stats import mean, mean_var
|
12
13
|
from scanpy import get
|
13
|
-
from scanpy._settings import settings
|
14
14
|
from scanpy._utils import _check_use_raw, sanitize_anndata
|
15
15
|
from scanpy.plotting import _utils
|
16
16
|
from scanpy.tools._utils import _choose_representation
|
17
|
-
from scipy.sparse import csr_matrix,
|
17
|
+
from scipy.sparse import csr_matrix, spmatrix
|
18
18
|
from sklearn.mixture import GaussianMixture
|
19
19
|
|
20
|
-
import pertpy as pt
|
21
20
|
from pertpy._doc import _doc_params, doc_common_plot_args
|
22
21
|
|
23
22
|
if TYPE_CHECKING:
|
@@ -111,7 +110,7 @@ class Mixscape:
|
|
111
110
|
for split in adata.obs[split_by].unique():
|
112
111
|
split_mask = adata.obs[split_by] == split
|
113
112
|
control_mask_group = control_mask & split_mask
|
114
|
-
control_mean_expr = adata.X[control_mask_group]
|
113
|
+
control_mean_expr = mean(adata.X[control_mask_group], axis=0)
|
115
114
|
adata.layers["X_pert"][split_mask] = (
|
116
115
|
np.repeat(control_mean_expr.reshape(1, -1), split_mask.sum(), axis=0)
|
117
116
|
- adata.layers["X_pert"][split_mask]
|
@@ -127,14 +126,14 @@ class Mixscape:
|
|
127
126
|
if n_dims is not None and n_dims < representation.shape[1]:
|
128
127
|
representation = representation[:, :n_dims]
|
129
128
|
|
129
|
+
from pynndescent import NNDescent
|
130
|
+
|
130
131
|
for split_mask in split_masks:
|
131
132
|
control_mask_split = control_mask & split_mask
|
132
133
|
|
133
134
|
R_split = representation[split_mask]
|
134
135
|
R_control = representation[np.asarray(control_mask_split)]
|
135
136
|
|
136
|
-
from pynndescent import NNDescent
|
137
|
-
|
138
137
|
eps = kwargs.pop("epsilon", 0.1)
|
139
138
|
nn_index = NNDescent(R_control, **kwargs)
|
140
139
|
indices, _ = nn_index.query(R_split, k=n_neighbors, epsilon=eps)
|
@@ -153,11 +152,10 @@ class Mixscape:
|
|
153
152
|
shape=(n_split, n_control),
|
154
153
|
)
|
155
154
|
neigh_matrix /= n_neighbors
|
156
|
-
adata.layers["X_pert"][split_mask] = (
|
157
|
-
|
155
|
+
adata.layers["X_pert"][np.asarray(split_mask)] = (
|
156
|
+
sc.pp.log1p(neigh_matrix @ X_control) - adata.layers["X_pert"][np.asarray(split_mask)]
|
158
157
|
)
|
159
158
|
else:
|
160
|
-
is_sparse = issparse(X_control)
|
161
159
|
split_indices = np.where(split_mask)[0]
|
162
160
|
for i in range(0, n_split, batch_size):
|
163
161
|
size = min(i + batch_size, n_split)
|
@@ -168,10 +166,9 @@ class Mixscape:
|
|
168
166
|
|
169
167
|
size = size - i
|
170
168
|
|
171
|
-
# sparse is very slow
|
172
169
|
means_batch = X_control[batch]
|
173
|
-
|
174
|
-
means_batch =
|
170
|
+
batch_reshaped = means_batch.reshape(size, n_neighbors, -1)
|
171
|
+
means_batch, _ = mean_var(batch_reshaped, axis=1)
|
175
172
|
|
176
173
|
adata.layers["X_pert"][split_batch] = (
|
177
174
|
np.log1p(means_batch) - adata.layers["X_pert"][split_batch]
|
@@ -199,6 +196,7 @@ class Mixscape:
|
|
199
196
|
perturbation_type: str | None = "KO",
|
200
197
|
random_state: int | None = 0,
|
201
198
|
copy: bool | None = False,
|
199
|
+
**gmmkwargs,
|
202
200
|
):
|
203
201
|
"""Identify perturbed and non-perturbed gRNA expressing cells that accounts for multiple treatments/conditions/chemical perturbations.
|
204
202
|
|
@@ -221,6 +219,7 @@ class Mixscape:
|
|
221
219
|
perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications.
|
222
220
|
random_state: Random seed for the GaussianMixture model.
|
223
221
|
copy: Determines whether a copy of the `adata` is returned.
|
222
|
+
**gmmkwargs: Passed to custom implementation of scikit-learn Gaussian Mixture Model.
|
224
223
|
|
225
224
|
Returns:
|
226
225
|
If `copy=True`, returns the copy of `adata` with the classification result in `.obs`.
|
@@ -307,10 +306,9 @@ class Mixscape:
|
|
307
306
|
|
308
307
|
else:
|
309
308
|
de_genes = perturbation_markers[(category, gene)]
|
310
|
-
de_genes_indices =
|
309
|
+
de_genes_indices = np.where(np.isin(adata.var_names, list(de_genes)))[0]
|
311
310
|
|
312
311
|
dat = X[np.asarray(all_cells)][:, de_genes_indices]
|
313
|
-
dat_cells = all_cells[all_cells].index
|
314
312
|
if scale:
|
315
313
|
dat = sc.pp.scale(dat)
|
316
314
|
|
@@ -318,6 +316,9 @@ class Mixscape:
|
|
318
316
|
n_iter = 0
|
319
317
|
old_classes = adata.obs[new_class_name][all_cells]
|
320
318
|
|
319
|
+
nt_cells_dat_idx = all_cells[all_cells].index.get_indexer(nt_cells[nt_cells].index)
|
320
|
+
nt_cells_mean = np.mean(dat[nt_cells_dat_idx], axis=0)
|
321
|
+
|
321
322
|
while not converged and n_iter < iter_num:
|
322
323
|
# Get all cells in current split&Gene
|
323
324
|
guide_cells = (adata.obs[new_class_name] == gene) & split_mask
|
@@ -326,12 +327,12 @@ class Mixscape:
|
|
326
327
|
# all cells in current split&Gene minus all NT cells in current split
|
327
328
|
# Each row is for each cell, each column is for each gene, get mean for each column
|
328
329
|
guide_cells_dat_idx = all_cells[all_cells].index.get_indexer(guide_cells[guide_cells].index)
|
329
|
-
|
330
|
-
vec =
|
330
|
+
guide_cells_mean = np.mean(dat[guide_cells_dat_idx], axis=0)
|
331
|
+
vec = guide_cells_mean - nt_cells_mean
|
331
332
|
|
332
333
|
# project cells onto the perturbation vector
|
333
334
|
if isinstance(dat, spmatrix):
|
334
|
-
pvec =
|
335
|
+
pvec = dat.dot(vec) / np.dot(vec, vec)
|
335
336
|
else:
|
336
337
|
pvec = np.dot(dat, vec) / np.dot(vec, vec)
|
337
338
|
pvec = pd.Series(np.asarray(pvec).flatten(), index=list(all_cells.index[all_cells]))
|
@@ -341,7 +342,7 @@ class Mixscape:
|
|
341
342
|
gv["pvec"] = pvec
|
342
343
|
gv[labels] = control
|
343
344
|
gv.loc[guide_cells, labels] = gene
|
344
|
-
if gene not in gv_list
|
345
|
+
if gene not in gv_list:
|
345
346
|
gv_list[gene] = {}
|
346
347
|
gv_list[gene][category] = gv
|
347
348
|
|
@@ -351,31 +352,30 @@ class Mixscape:
|
|
351
352
|
n_components=2,
|
352
353
|
covariance_type="spherical",
|
353
354
|
means_init=means_init,
|
354
|
-
precisions_init=1 / (std_init
|
355
|
+
precisions_init=1 / (std_init**2),
|
355
356
|
random_state=random_state,
|
356
|
-
max_iter=
|
357
|
+
max_iter=100,
|
357
358
|
fixed_means=[pvec[nt_cells].mean(), None],
|
358
359
|
fixed_covariances=[pvec[nt_cells].std() ** 2, None],
|
360
|
+
**gmmkwargs,
|
359
361
|
).fit(np.asarray(pvec).reshape(-1, 1))
|
360
362
|
probabilities = mm.predict_proba(np.array(pvec[orig_guide_cells_index]).reshape(-1, 1))
|
361
363
|
lik_ratio = probabilities[:, 0] / probabilities[:, 1]
|
362
364
|
post_prob = 1 / (1 + lik_ratio)
|
363
365
|
|
364
366
|
# based on the posterior probability, assign cells to the two classes
|
365
|
-
|
366
|
-
|
367
|
-
] = gene
|
368
|
-
adata.obs.loc[
|
369
|
-
[orig_guide_cells_index[cell] for cell in np.where(post_prob <= 0.5)[0]], new_class_name
|
370
|
-
] = f"{gene} NP"
|
367
|
+
ko_mask = post_prob > 0.5
|
368
|
+
adata.obs.loc[np.array(orig_guide_cells_index)[ko_mask], new_class_name] = gene
|
369
|
+
adata.obs.loc[np.array(orig_guide_cells_index)[~ko_mask], new_class_name] = f"{gene} NP"
|
371
370
|
|
372
371
|
if sum(adata.obs[new_class_name][split_mask] == gene) < min_de_genes:
|
373
372
|
adata.obs.loc[guide_cells, new_class_name] = "NP"
|
374
373
|
converged = True
|
375
|
-
|
374
|
+
current_classes = adata.obs[new_class_name][all_cells]
|
375
|
+
if (current_classes == old_classes).all():
|
376
376
|
converged = True
|
377
|
+
old_classes = current_classes
|
377
378
|
|
378
|
-
old_classes = adata.obs[new_class_name][all_cells]
|
379
379
|
n_iter += 1
|
380
380
|
|
381
381
|
adata.obs.loc[(adata.obs[new_class_name] == gene) & split_mask, new_class_name] = (
|
@@ -414,7 +414,6 @@ class Mixscape:
|
|
414
414
|
control: Control category from the `pert_key` column.
|
415
415
|
mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
|
416
416
|
layer: Layer to use for identifying differentially expressed genes. If `None`, adata.X is used.
|
417
|
-
control: Control category from the `pert_key` column.
|
418
417
|
n_comps: Number of principal components to use.
|
419
418
|
min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
|
420
419
|
logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells.
|
@@ -470,7 +469,8 @@ class Mixscape:
|
|
470
469
|
)
|
471
470
|
adata_subset = adata[
|
472
471
|
(adata.obs[mixscape_class_global] == perturbation_type) | (adata.obs[mixscape_class_global] == control)
|
473
|
-
]
|
472
|
+
]
|
473
|
+
X = adata_subset.X - adata_subset.X.mean(0)
|
474
474
|
projected_pcs: dict[str, np.ndarray] = {}
|
475
475
|
# performs PCA on each mixscape class separately and projects each subspace onto all cells in the data.
|
476
476
|
for _, (key, value) in enumerate(perturbation_markers.items()):
|
@@ -482,16 +482,10 @@ class Mixscape:
|
|
482
482
|
].copy()
|
483
483
|
sc.pp.scale(gene_subset)
|
484
484
|
sc.tl.pca(gene_subset, n_comps=n_comps)
|
485
|
-
|
486
|
-
|
487
|
-
sc.tl.ingest(adata=adata_subset, adata_ref=gene_subset, embedding_method="pca")
|
488
|
-
projected_pcs[key[1]] = adata_subset.obsm["X_pca"]
|
485
|
+
# project cells into PCA space of gene_subset
|
486
|
+
projected_pcs[key[1]] = np.asarray(np.dot(X, gene_subset.varm["PCs"]))
|
489
487
|
# concatenate all pcs into a single matrix.
|
490
|
-
|
491
|
-
if index == 0:
|
492
|
-
projected_pcs_array = value
|
493
|
-
else:
|
494
|
-
projected_pcs_array = np.concatenate((projected_pcs_array, value), axis=1)
|
488
|
+
projected_pcs_array = np.concatenate(list(projected_pcs.values()), axis=1)
|
495
489
|
|
496
490
|
clf = LinearDiscriminantAnalysis(n_components=len(np.unique(adata_subset.obs[labels])) - 1)
|
497
491
|
clf.fit(projected_pcs_array, adata_subset.obs[labels])
|
@@ -514,7 +508,7 @@ class Mixscape:
|
|
514
508
|
logfc_threshold: float,
|
515
509
|
test_method: str,
|
516
510
|
) -> dict[tuple, np.ndarray]:
|
517
|
-
"""Determine gene sets across all splits/groups through differential gene expression
|
511
|
+
"""Determine gene sets across all splits/groups through differential gene expression.
|
518
512
|
|
519
513
|
Args:
|
520
514
|
adata: :class:`~anndata.AnnData` object
|
@@ -549,7 +543,9 @@ class Mixscape:
|
|
549
543
|
)
|
550
544
|
# get DE genes for each target gene
|
551
545
|
for gene in gene_targets:
|
552
|
-
logfc_threshold_mask =
|
546
|
+
logfc_threshold_mask = (
|
547
|
+
np.abs(adata_split.uns["rank_genes_groups"]["logfoldchanges"][gene]) >= logfc_threshold
|
548
|
+
)
|
553
549
|
de_genes = adata_split.uns["rank_genes_groups"]["names"][gene][logfc_threshold_mask]
|
554
550
|
pvals_adj = adata_split.uns["rank_genes_groups"]["pvals_adj"][gene][logfc_threshold_mask]
|
555
551
|
de_genes = de_genes[pvals_adj < pval_cutoff]
|
@@ -559,19 +555,8 @@ class Mixscape:
|
|
559
555
|
|
560
556
|
return perturbation_markers
|
561
557
|
|
562
|
-
def _get_column_indices(self, adata, col_names):
|
563
|
-
if isinstance(col_names, str): # pragma: no cover
|
564
|
-
col_names = [col_names]
|
565
|
-
|
566
|
-
indices = []
|
567
|
-
for idx, col in enumerate(adata.var_names):
|
568
|
-
if col in col_names:
|
569
|
-
indices.append(idx)
|
570
|
-
|
571
|
-
return indices
|
572
|
-
|
573
558
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
574
|
-
def plot_barplot( # pragma: no cover
|
559
|
+
def plot_barplot( # pragma: no cover # noqa: D417
|
575
560
|
self,
|
576
561
|
adata: AnnData,
|
577
562
|
guide_rna_column: str,
|
@@ -678,7 +663,7 @@ class Mixscape:
|
|
678
663
|
return None
|
679
664
|
|
680
665
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
681
|
-
def plot_heatmap( # pragma: no cover
|
666
|
+
def plot_heatmap( # pragma: no cover # noqa: D417
|
682
667
|
self,
|
683
668
|
adata: AnnData,
|
684
669
|
labels: str,
|
@@ -748,7 +733,7 @@ class Mixscape:
|
|
748
733
|
return None
|
749
734
|
|
750
735
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
751
|
-
def plot_perturbscore( # pragma: no cover
|
736
|
+
def plot_perturbscore( # pragma: no cover # noqa: D417
|
752
737
|
self,
|
753
738
|
adata: AnnData,
|
754
739
|
labels: str,
|
@@ -801,7 +786,7 @@ class Mixscape:
|
|
801
786
|
if "mixscape" not in adata.uns:
|
802
787
|
raise ValueError("Please run the `mixscape` function first.")
|
803
788
|
perturbation_score = None
|
804
|
-
for key in adata.uns["mixscape"][target_gene]
|
789
|
+
for key in adata.uns["mixscape"][target_gene]:
|
805
790
|
perturbation_score_temp = adata.uns["mixscape"][target_gene][key]
|
806
791
|
perturbation_score_temp["name"] = key
|
807
792
|
if perturbation_score is None:
|
@@ -914,7 +899,7 @@ class Mixscape:
|
|
914
899
|
return None
|
915
900
|
|
916
901
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
917
|
-
def plot_violin( # pragma: no cover
|
902
|
+
def plot_violin( # pragma: no cover # noqa: D417
|
918
903
|
self,
|
919
904
|
adata: AnnData,
|
920
905
|
target_gene_idents: str | list[str],
|
@@ -994,7 +979,7 @@ class Mixscape:
|
|
994
979
|
if len(ylabel) != 1:
|
995
980
|
raise ValueError(f"Expected number of y-labels to be `1`, found `{len(ylabel)}`.")
|
996
981
|
elif len(ylabel) != len(keys):
|
997
|
-
raise ValueError(f"Expected number of y-labels to be `{len(keys)}`,
|
982
|
+
raise ValueError(f"Expected number of y-labels to be `{len(keys)}`, found `{len(ylabel)}`.")
|
998
983
|
|
999
984
|
if groupby is not None:
|
1000
985
|
if hue is not None:
|
@@ -1047,7 +1032,7 @@ class Mixscape:
|
|
1047
1032
|
g.set(yscale="log")
|
1048
1033
|
g.set_titles(col_template="{col_name}").set_xlabels("")
|
1049
1034
|
if rotation is not None:
|
1050
|
-
for ax in g.axes[0]:
|
1035
|
+
for ax in g.axes[0]: # noqa: PLR1704
|
1051
1036
|
ax.tick_params(axis="x", labelrotation=rotation)
|
1052
1037
|
else:
|
1053
1038
|
# set by default the violin plot cut=0 to limit the extend
|
@@ -1065,7 +1050,7 @@ class Mixscape:
|
|
1065
1050
|
else:
|
1066
1051
|
axs = [ax]
|
1067
1052
|
for ax, y, ylab in zip(axs, ys, ylabel, strict=False):
|
1068
|
-
ax = sns.violinplot(
|
1053
|
+
ax = sns.violinplot( # noqa: PLW2901
|
1069
1054
|
x=x,
|
1070
1055
|
y=y,
|
1071
1056
|
data=obs_tidy,
|
@@ -1079,7 +1064,7 @@ class Mixscape:
|
|
1079
1064
|
# Get the handles and labels.
|
1080
1065
|
handles, labels = ax.get_legend_handles_labels()
|
1081
1066
|
if stripplot:
|
1082
|
-
ax = sns.stripplot(
|
1067
|
+
ax = sns.stripplot( # noqa: PLW2901
|
1083
1068
|
x=x,
|
1084
1069
|
y=y,
|
1085
1070
|
data=obs_tidy,
|
@@ -1116,7 +1101,7 @@ class Mixscape:
|
|
1116
1101
|
return None
|
1117
1102
|
|
1118
1103
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
1119
|
-
def plot_lda( # pragma: no cover
|
1104
|
+
def plot_lda( # pragma: no cover # noqa: D417
|
1120
1105
|
self,
|
1121
1106
|
adata: AnnData,
|
1122
1107
|
control: str,
|
@@ -1135,13 +1120,16 @@ class Mixscape:
|
|
1135
1120
|
"""Visualizing perturbation responses with Linear Discriminant Analysis. Requires `pt.tl.mixscape()` to be run first.
|
1136
1121
|
|
1137
1122
|
Args:
|
1138
|
-
adata: The annotated data
|
1123
|
+
adata: The annotated data objectplot_heatmap.
|
1139
1124
|
control: Control category from the `pert_key` column.
|
1140
1125
|
mixscape_class: The column of `.obs` with the mixscape classification result.
|
1141
1126
|
mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
|
1142
1127
|
perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
|
1143
|
-
lda_key: If not specified, lda looks .uns["mixscape_lda"] for the LDA results.
|
1144
1128
|
n_components: The number of dimensions of the embedding.
|
1129
|
+
lda_key: If not specified, lda looks .uns["mixscape_lda"] for the LDA results.
|
1130
|
+
color_map: Matplotlib color map.
|
1131
|
+
palette: Matplotlib palette.
|
1132
|
+
ax: Matplotlib axes.
|
1145
1133
|
{common_plot_args}
|
1146
1134
|
**kwds: Additional arguments to `scanpy.pl.umap`.
|
1147
1135
|
|
@@ -1186,13 +1174,14 @@ class Mixscape:
|
|
1186
1174
|
plt.show()
|
1187
1175
|
return None
|
1188
1176
|
|
1177
|
+
|
1189
1178
|
class MixscapeGaussianMixture(GaussianMixture):
|
1190
1179
|
def __init__(
|
1191
1180
|
self,
|
1192
1181
|
n_components: int,
|
1193
|
-
fixed_means:
|
1182
|
+
fixed_means: Sequence[float] | None = None,
|
1194
1183
|
fixed_covariances: Sequence[float] | None = None,
|
1195
|
-
**kwargs
|
1184
|
+
**kwargs,
|
1196
1185
|
):
|
1197
1186
|
"""Custom Gaussian Mixture Model where means and covariances can be fixed for specific components.
|
1198
1187
|
|
@@ -1206,19 +1195,28 @@ class MixscapeGaussianMixture(GaussianMixture):
|
|
1206
1195
|
self.fixed_means = fixed_means
|
1207
1196
|
self.fixed_covariances = fixed_covariances
|
1208
1197
|
|
1198
|
+
self.fixed_mean_indices = []
|
1199
|
+
self.fixed_mean_values = []
|
1200
|
+
if fixed_means is not None:
|
1201
|
+
self.fixed_mean_indices = [i for i, m in enumerate(fixed_means) if m is not None]
|
1202
|
+
if self.fixed_mean_indices:
|
1203
|
+
self.fixed_mean_values = np.array([fixed_means[i] for i in self.fixed_mean_indices])
|
1204
|
+
|
1205
|
+
self.fixed_cov_indices = []
|
1206
|
+
self.fixed_cov_values = []
|
1207
|
+
if fixed_covariances is not None:
|
1208
|
+
self.fixed_cov_indices = [i for i, c in enumerate(fixed_covariances) if c is not None]
|
1209
|
+
if self.fixed_cov_indices:
|
1210
|
+
self.fixed_cov_values = np.array([fixed_covariances[i] for i in self.fixed_cov_indices])
|
1211
|
+
|
1209
1212
|
def _m_step(self, X: np.ndarray, log_resp: np.ndarray):
|
1210
1213
|
"""Modified M-step to respect fixed means and covariances."""
|
1211
1214
|
super()._m_step(X, log_resp)
|
1212
1215
|
|
1213
|
-
if self.
|
1214
|
-
|
1215
|
-
if self.fixed_means[i] is not None:
|
1216
|
-
self.means_[i] = self.fixed_means[i]
|
1216
|
+
if self.fixed_mean_indices:
|
1217
|
+
self.means_[self.fixed_mean_indices] = self.fixed_mean_values
|
1217
1218
|
|
1218
|
-
if self.
|
1219
|
-
|
1220
|
-
if self.fixed_covariances[i] is not None:
|
1221
|
-
self.covariances_[i] = self.fixed_covariances[i]
|
1219
|
+
if self.fixed_cov_indices:
|
1220
|
+
self.covariances_[self.fixed_cov_indices] = self.fixed_cov_values
|
1222
1221
|
|
1223
1222
|
return self
|
1224
|
-
|
@@ -76,13 +76,13 @@ class ClusteringSpace(PerturbationSpace):
|
|
76
76
|
if metric == "asw":
|
77
77
|
from pertpy.tools._perturbation_space._metrics import asw
|
78
78
|
|
79
|
-
if "metric" not in kwargs
|
79
|
+
if "metric" not in kwargs:
|
80
80
|
kwargs["metric"] = "euclidean"
|
81
|
-
if "distances" not in kwargs
|
81
|
+
if "distances" not in kwargs:
|
82
82
|
distances = pairwise_distances(self.X, metric=kwargs["metric"])
|
83
|
-
if "sample_size" not in kwargs
|
83
|
+
if "sample_size" not in kwargs:
|
84
84
|
kwargs["sample_size"] = None
|
85
|
-
if "random_state" not in kwargs
|
85
|
+
if "random_state" not in kwargs:
|
86
86
|
kwargs["random_state"] = None
|
87
87
|
|
88
88
|
asw_score = asw(
|
@@ -1,7 +1,6 @@
|
|
1
1
|
from typing import TYPE_CHECKING
|
2
2
|
|
3
3
|
import numpy as np
|
4
|
-
import pynndescent
|
5
4
|
from scipy.sparse import issparse
|
6
5
|
from scipy.sparse import vstack as sp_vstack
|
7
6
|
from sklearn.base import ClassifierMixin
|
@@ -95,7 +94,9 @@ class PerturbationComparison:
|
|
95
94
|
labels[-control.shape[0] :] = "ctrl"
|
96
95
|
label_groups.append("ctrl")
|
97
96
|
|
98
|
-
|
97
|
+
from pynndescent import NNDescent
|
98
|
+
|
99
|
+
index = NNDescent(
|
99
100
|
index_data,
|
100
101
|
n_neighbors=max(50, n_neighbors),
|
101
102
|
random_state=random_state,
|
@@ -106,7 +107,6 @@ class PerturbationComparison:
|
|
106
107
|
uq, uq_counts = np.unique(labels[indices], return_counts=True)
|
107
108
|
uq_counts_norm = uq_counts / uq_counts.sum()
|
108
109
|
counts = dict(zip(label_groups, [0.0] * len(label_groups), strict=False))
|
109
|
-
|
110
|
-
counts[group] = count_norm
|
110
|
+
counts = dict(zip(uq, uq_counts_norm, strict=False))
|
111
111
|
|
112
112
|
return counts
|