pertpy 0.9.5__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. pertpy/__init__.py +5 -1
  2. pertpy/_doc.py +2 -5
  3. pertpy/_types.py +6 -0
  4. pertpy/data/_dataloader.py +68 -24
  5. pertpy/data/_datasets.py +9 -9
  6. pertpy/metadata/__init__.py +2 -1
  7. pertpy/metadata/_cell_line.py +136 -30
  8. pertpy/metadata/_look_up.py +13 -19
  9. pertpy/metadata/_moa.py +1 -1
  10. pertpy/preprocessing/_guide_rna.py +221 -39
  11. pertpy/preprocessing/_guide_rna_mixture.py +177 -0
  12. pertpy/tools/__init__.py +1 -1
  13. pertpy/tools/_augur.py +138 -142
  14. pertpy/tools/_cinemaot.py +75 -117
  15. pertpy/tools/_coda/_base_coda.py +150 -174
  16. pertpy/tools/_coda/_sccoda.py +66 -69
  17. pertpy/tools/_coda/_tasccoda.py +71 -79
  18. pertpy/tools/_dialogue.py +60 -56
  19. pertpy/tools/_differential_gene_expression/_base.py +25 -43
  20. pertpy/tools/_differential_gene_expression/_checks.py +4 -6
  21. pertpy/tools/_differential_gene_expression/_dge_comparison.py +5 -6
  22. pertpy/tools/_differential_gene_expression/_edger.py +6 -10
  23. pertpy/tools/_differential_gene_expression/_pydeseq2.py +1 -1
  24. pertpy/tools/_differential_gene_expression/_simple_tests.py +3 -3
  25. pertpy/tools/_differential_gene_expression/_statsmodels.py +8 -5
  26. pertpy/tools/_distances/_distance_tests.py +1 -2
  27. pertpy/tools/_distances/_distances.py +86 -92
  28. pertpy/tools/_enrichment.py +8 -25
  29. pertpy/tools/_milo.py +23 -27
  30. pertpy/tools/_mixscape.py +261 -175
  31. pertpy/tools/_perturbation_space/_clustering.py +4 -4
  32. pertpy/tools/_perturbation_space/_comparison.py +4 -4
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +83 -32
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +10 -10
  35. pertpy/tools/_perturbation_space/_simple.py +13 -17
  36. pertpy/tools/_scgen/_scgen.py +17 -20
  37. pertpy/tools/_scgen/_scgenvae.py +2 -2
  38. pertpy/tools/_scgen/_utils.py +3 -1
  39. {pertpy-0.9.5.dist-info → pertpy-0.11.0.dist-info}/METADATA +37 -21
  40. pertpy-0.11.0.dist-info/RECORD +58 -0
  41. {pertpy-0.9.5.dist-info → pertpy-0.11.0.dist-info}/licenses/LICENSE +1 -0
  42. pertpy/tools/_kernel_pca.py +0 -50
  43. pertpy-0.9.5.dist-info/RECORD +0 -57
  44. {pertpy-0.9.5.dist-info → pertpy-0.11.0.dist-info}/WHEEL +0 -0
pertpy/tools/_mixscape.py CHANGED
@@ -9,15 +9,14 @@ import numpy as np
9
9
  import pandas as pd
10
10
  import scanpy as sc
11
11
  import seaborn as sns
12
+ from fast_array_utils.stats import mean, mean_var
12
13
  from scanpy import get
13
- from scanpy._settings import settings
14
14
  from scanpy._utils import _check_use_raw, sanitize_anndata
15
15
  from scanpy.plotting import _utils
16
16
  from scanpy.tools._utils import _choose_representation
17
- from scipy.sparse import csr_matrix, issparse, spmatrix
17
+ from scipy.sparse import csr_matrix, spmatrix
18
18
  from sklearn.mixture import GaussianMixture
19
19
 
20
- import pertpy as pt
21
20
  from pertpy._doc import _doc_params, doc_common_plot_args
22
21
 
23
22
  if TYPE_CHECKING:
@@ -41,9 +40,12 @@ class Mixscape:
41
40
  adata: AnnData,
42
41
  pert_key: str,
43
42
  control: str,
43
+ *,
44
+ ref_selection_mode: Literal["nn", "split_by"] = "nn",
44
45
  split_by: str | None = None,
45
46
  n_neighbors: int = 20,
46
47
  use_rep: str | None = None,
48
+ n_dims: int | None = 15,
47
49
  n_pcs: int | None = None,
48
50
  batch_size: int | None = None,
49
51
  copy: bool = False,
@@ -51,14 +53,18 @@ class Mixscape:
51
53
  ):
52
54
  """Calculate perturbation signature.
53
55
 
54
- For each cell, we identify `n_neighbors` cells from the control pool with the most similar mRNA expression profiles.
55
- The perturbation signature is calculated by subtracting the averaged mRNA expression profile of the control
56
- neighbors from the mRNA expression profile of each cell.
56
+ The perturbation signature is calculated by subtracting the mRNA expression profile of each cell from the averaged
57
+ mRNA expression profile of the control cells (selected according to `ref_selection_mode`).
58
+ The implementation resembles https://satijalab.org/seurat/reference/runmixscape. Note that in the original implementation, the
59
+ perturbation signature is calculated on unscaled data by default, and we therefore recommend doing the same.
57
60
 
58
61
  Args:
59
62
  adata: The annotated data object.
60
63
  pert_key: The column of `.obs` with perturbation categories, should also contain `control`.
61
- control: Control category from the `pert_key` column.
64
+ control: Name of the control category from the `pert_key` column.
65
+ ref_selection_mode: Method to select reference cells for the perturbation signature calculation. If `nn`,
66
+ the `n_neighbors` cells from the control pool with the most similar mRNA expression profiles are selected. If `split_by`,
67
+ the control cells from the same split in `split_by` (e.g. indicating biological replicates) are used to calculate the perturbation signature.
62
68
  split_by: Provide the column `.obs` if multiple biological replicates exist to calculate
63
69
  the perturbation signature for every replicate separately.
64
70
  n_neighbors: Number of neighbors from the control to use for the perturbation signature.
@@ -66,7 +72,10 @@ class Mixscape:
66
72
  If `None`, the representation is chosen automatically:
67
73
  For `.n_vars` < 50, `.X` is used, otherwise 'X_pca' is used.
68
74
  If 'X_pca' is not present, it’s computed with default parameters.
69
- n_pcs: Use this many PCs. If `n_pcs==0` use `.X` if `use_rep is None`.
75
+ n_dims: Number of dimensions to use from the representation to calculate the perturbation signature.
76
+ If `None`, use all dimensions.
77
+ n_pcs: If PCA representation is used, the number of principal components to compute.
78
+ If `n_pcs==0` use `.X` if `use_rep is None`.
70
79
  batch_size: Size of batch to calculate the perturbation signature.
71
80
  If 'None', the perturbation signature is calculated in the full mode, requiring more memory.
72
81
  The batched mode is very inefficient for sparse data.
@@ -83,8 +92,13 @@ class Mixscape:
83
92
  >>> import pertpy as pt
84
93
  >>> mdata = pt.dt.papalexi_2021()
85
94
  >>> ms_pt = pt.tl.Mixscape()
86
- >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
95
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
87
96
  """
97
+ if ref_selection_mode not in ["nn", "split_by"]:
98
+ raise ValueError("ref_selection_mode must be either 'nn' or 'split_by'.")
99
+ if ref_selection_mode == "split_by" and split_by is None:
100
+ raise ValueError("split_by must be provided if ref_selection_mode is 'split_by'.")
101
+
88
102
  if copy:
89
103
  adata = adata.copy()
90
104
 
@@ -92,59 +106,73 @@ class Mixscape:
92
106
 
93
107
  control_mask = adata.obs[pert_key] == control
94
108
 
95
- if split_by is None:
96
- split_masks = [np.full(adata.n_obs, True, dtype=bool)]
109
+ if ref_selection_mode == "split_by":
110
+ for split in adata.obs[split_by].unique():
111
+ split_mask = adata.obs[split_by] == split
112
+ control_mask_group = control_mask & split_mask
113
+ control_mean_expr = mean(adata.X[control_mask_group], axis=0)
114
+ adata.layers["X_pert"][split_mask] = (
115
+ np.repeat(control_mean_expr.reshape(1, -1), split_mask.sum(), axis=0)
116
+ - adata.layers["X_pert"][split_mask]
117
+ )
97
118
  else:
98
- split_obs = adata.obs[split_by]
99
- split_masks = [split_obs == cat for cat in split_obs.unique()]
119
+ if split_by is None:
120
+ split_masks = [np.full(adata.n_obs, True, dtype=bool)]
121
+ else:
122
+ split_obs = adata.obs[split_by]
123
+ split_masks = [split_obs == cat for cat in split_obs.unique()]
100
124
 
101
- representation = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs)
125
+ representation = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs)
126
+ if n_dims is not None and n_dims < representation.shape[1]:
127
+ representation = representation[:, :n_dims]
102
128
 
103
- for split_mask in split_masks:
104
- control_mask_split = control_mask & split_mask
129
+ from pynndescent import NNDescent
105
130
 
106
- R_split = representation[split_mask]
107
- R_control = representation[np.asarray(control_mask_split)]
131
+ for split_mask in split_masks:
132
+ control_mask_split = control_mask & split_mask
108
133
 
109
- from pynndescent import NNDescent
134
+ R_split = representation[split_mask]
135
+ R_control = representation[np.asarray(control_mask_split)]
110
136
 
111
- eps = kwargs.pop("epsilon", 0.1)
112
- nn_index = NNDescent(R_control, **kwargs)
113
- indices, _ = nn_index.query(R_split, k=n_neighbors, epsilon=eps)
137
+ eps = kwargs.pop("epsilon", 0.1)
138
+ nn_index = NNDescent(R_control, **kwargs)
139
+ indices, _ = nn_index.query(R_split, k=n_neighbors, epsilon=eps)
114
140
 
115
- X_control = np.expm1(adata.X[np.asarray(control_mask_split)])
141
+ X_control = np.expm1(adata.X[np.asarray(control_mask_split)])
116
142
 
117
- n_split = split_mask.sum()
118
- n_control = X_control.shape[0]
143
+ n_split = split_mask.sum()
144
+ n_control = X_control.shape[0]
119
145
 
120
- if batch_size is None:
121
- col_indices = np.ravel(indices)
122
- row_indices = np.repeat(np.arange(n_split), n_neighbors)
146
+ if batch_size is None:
147
+ col_indices = np.ravel(indices)
148
+ row_indices = np.repeat(np.arange(n_split), n_neighbors)
123
149
 
124
- neigh_matrix = csr_matrix(
125
- (np.ones_like(col_indices, dtype=np.float64), (row_indices, col_indices)),
126
- shape=(n_split, n_control),
127
- )
128
- neigh_matrix /= n_neighbors
129
- adata.layers["X_pert"][split_mask] -= np.log1p(neigh_matrix @ X_control)
130
- else:
131
- is_sparse = issparse(X_control)
132
- split_indices = np.where(split_mask)[0]
133
- for i in range(0, n_split, batch_size):
134
- size = min(i + batch_size, n_split)
135
- select = slice(i, size)
150
+ neigh_matrix = csr_matrix(
151
+ (np.ones_like(col_indices, dtype=np.float64), (row_indices, col_indices)),
152
+ shape=(n_split, n_control),
153
+ )
154
+ neigh_matrix /= n_neighbors
155
+ adata.layers["X_pert"][np.asarray(split_mask)] = (
156
+ sc.pp.log1p(neigh_matrix @ X_control) - adata.layers["X_pert"][np.asarray(split_mask)]
157
+ )
158
+ else:
159
+ split_indices = np.where(split_mask)[0]
160
+ for i in range(0, n_split, batch_size):
161
+ size = min(i + batch_size, n_split)
162
+ select = slice(i, size)
136
163
 
137
- batch = np.ravel(indices[select])
138
- split_batch = split_indices[select]
164
+ batch = np.ravel(indices[select])
165
+ split_batch = split_indices[select]
139
166
 
140
- size = size - i
167
+ size = size - i
141
168
 
142
- # sparse is very slow
143
- means_batch = X_control[batch]
144
- means_batch = means_batch.toarray() if is_sparse else means_batch
145
- means_batch = means_batch.reshape(size, n_neighbors, -1).mean(1)
169
+ means_batch = X_control[batch]
170
+ batch_reshaped = means_batch.reshape(size, n_neighbors, -1)
171
+ means_batch, _ = mean_var(batch_reshaped, axis=1)
146
172
 
147
- adata.layers["X_pert"][split_batch] -= np.log1p(means_batch)
173
+ adata.layers["X_pert"][split_batch] = (
174
+ np.log1p(means_batch) - adata.layers["X_pert"][split_batch]
175
+ )
148
176
 
149
177
  if copy:
150
178
  return adata
@@ -154,34 +182,44 @@ class Mixscape:
154
182
  adata: AnnData,
155
183
  labels: str,
156
184
  control: str,
185
+ *,
157
186
  new_class_name: str | None = "mixscape_class",
158
- min_de_genes: int | None = 5,
159
187
  layer: str | None = None,
188
+ min_de_genes: int | None = 5,
160
189
  logfc_threshold: float | None = 0.25,
190
+ de_layer: str | None = None,
191
+ test_method: str | None = "wilcoxon",
161
192
  iter_num: int | None = 10,
193
+ scale: bool | None = True,
162
194
  split_by: str | None = None,
163
195
  pval_cutoff: float | None = 5e-2,
164
196
  perturbation_type: str | None = "KO",
197
+ random_state: int | None = 0,
165
198
  copy: bool | None = False,
199
+ **gmmkwargs,
166
200
  ):
167
201
  """Identify perturbed and non-perturbed gRNA expressing cells that accounts for multiple treatments/conditions/chemical perturbations.
168
202
 
169
- The implementation resembles https://satijalab.org/seurat/reference/runmixscape
203
+ The implementation resembles https://satijalab.org/seurat/reference/runmixscape.
170
204
 
171
205
  Args:
172
206
  adata: The annotated data object.
173
207
  labels: The column of `.obs` with target gene labels.
174
- control: Control category from the `pert_key` column.
208
+ control: Control category from the `labels` column.
175
209
  new_class_name: Name of mixscape classification to be stored in `.obs`.
176
- min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
177
210
  layer: Key from adata.layers whose value will be used to perform tests on. Default is using `.layers["X_pert"]`.
211
+ min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
178
212
  logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells (default: 0.25).
213
+ de_layer: Layer to use for identifying differentially expressed genes. If `None`, adata.X is used.
214
+ test_method: Method to use for differential expression testing.
179
215
  iter_num: Number of normalmixEM iterations to run if convergence does not occur.
180
- split_by: Provide the column `.obs` if multiple biological replicates exist to calculate
181
- the perturbation signature for every replicate separately.
216
+ scale: Scale the data specified in `layer` before running the GaussianMixture model on it.
217
+ split_by: Provide `.obs` column with experimental condition/cell type annotation, if perturbations are condition/cell type-specific.
182
218
  pval_cutoff: P-value cut-off for selection of significantly DE genes.
183
219
  perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications.
220
+ random_state: Random seed for the GaussianMixture model.
184
221
  copy: Determines whether a copy of the `adata` is returned.
222
+ **gmmkwargs: Passed to custom implementation of scikit-learn Gaussian Mixture Model.
185
223
 
186
224
  Returns:
187
225
  If `copy=True`, returns the copy of `adata` with the classification result in `.obs`.
@@ -203,8 +241,8 @@ class Mixscape:
203
241
  >>> import pertpy as pt
204
242
  >>> mdata = pt.dt.papalexi_2021()
205
243
  >>> ms_pt = pt.tl.Mixscape()
206
- >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
207
- >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
244
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
245
+ >>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
208
246
  """
209
247
  if copy:
210
248
  adata = adata.copy()
@@ -218,7 +256,16 @@ class Mixscape:
218
256
  split_masks = [split_obs == category for category in categories]
219
257
 
220
258
  perturbation_markers = self._get_perturbation_markers(
221
- adata, split_masks, categories, labels, control, layer, pval_cutoff, min_de_genes, logfc_threshold
259
+ adata=adata,
260
+ split_masks=split_masks,
261
+ categories=categories,
262
+ labels=labels,
263
+ control=control,
264
+ layer=de_layer,
265
+ pval_cutoff=pval_cutoff,
266
+ min_de_genes=min_de_genes,
267
+ logfc_threshold=logfc_threshold,
268
+ test_method=test_method,
222
269
  )
223
270
 
224
271
  adata_comp = adata
@@ -231,6 +278,7 @@ class Mixscape:
231
278
  raise KeyError(
232
279
  "No 'X_pert' found in .layers! Please run perturbation_signature first to calculate perturbation signature!"
233
280
  ) from None
281
+
234
282
  # initialize return variables
235
283
  adata.obs[f"{new_class_name}_p_{perturbation_type.lower()}"] = 0
236
284
  adata.obs[new_class_name] = adata.obs[labels].astype(str)
@@ -241,10 +289,12 @@ class Mixscape:
241
289
  dtype=np.object_,
242
290
  )
243
291
  gv_list: dict[str, dict] = {}
292
+
293
+ adata.obs[f"{new_class_name}_p_{perturbation_type.lower()}"] = 0.0
244
294
  for split, split_mask in enumerate(split_masks):
245
295
  category = categories[split]
246
- genes = list(set(adata[split_mask].obs[labels]).difference([control]))
247
- for gene in genes:
296
+ gene_targets = list(set(adata[split_mask].obs[labels]).difference([control]))
297
+ for gene in gene_targets:
248
298
  post_prob = 0
249
299
  orig_guide_cells = (adata.obs[labels] == gene) & split_mask
250
300
  orig_guide_cells_index = list(orig_guide_cells.index[orig_guide_cells])
@@ -253,63 +303,79 @@ class Mixscape:
253
303
 
254
304
  if len(perturbation_markers[(category, gene)]) == 0:
255
305
  adata.obs.loc[orig_guide_cells, new_class_name] = f"{gene} NP"
306
+
256
307
  else:
257
308
  de_genes = perturbation_markers[(category, gene)]
258
- de_genes_indices = self._get_column_indices(adata, list(de_genes))
309
+ de_genes_indices = np.where(np.isin(adata.var_names, list(de_genes)))[0]
310
+
259
311
  dat = X[np.asarray(all_cells)][:, de_genes_indices]
312
+ if scale:
313
+ dat = sc.pp.scale(dat)
314
+
260
315
  converged = False
261
316
  n_iter = 0
262
- old_classes = adata.obs[labels][all_cells]
317
+ old_classes = adata.obs[new_class_name][all_cells]
318
+
319
+ nt_cells_dat_idx = all_cells[all_cells].index.get_indexer(nt_cells[nt_cells].index)
320
+ nt_cells_mean = np.mean(dat[nt_cells_dat_idx], axis=0)
321
+
263
322
  while not converged and n_iter < iter_num:
264
323
  # Get all cells in current split&Gene
265
- guide_cells = (adata.obs[labels] == gene) & split_mask
324
+ guide_cells = (adata.obs[new_class_name] == gene) & split_mask
325
+
266
326
  # get average value for each gene over all selected cells
267
327
  # all cells in current split&Gene minus all NT cells in current split
268
328
  # Each row is for each cell, each column is for each gene, get mean for each column
269
- vec = np.mean(X[np.asarray(guide_cells)][:, de_genes_indices], axis=0) - np.mean(
270
- X[np.asarray(nt_cells)][:, de_genes_indices], axis=0
271
- )
329
+ guide_cells_dat_idx = all_cells[all_cells].index.get_indexer(guide_cells[guide_cells].index)
330
+ guide_cells_mean = np.mean(dat[guide_cells_dat_idx], axis=0)
331
+ vec = guide_cells_mean - nt_cells_mean
332
+
272
333
  # project cells onto the perturbation vector
273
334
  if isinstance(dat, spmatrix):
274
- pvec = np.sum(np.multiply(dat.toarray(), vec), axis=1) / np.sum(np.multiply(vec, vec))
335
+ pvec = dat.dot(vec) / np.dot(vec, vec)
275
336
  else:
276
- pvec = np.sum(np.multiply(dat, vec), axis=1) / np.sum(np.multiply(vec, vec))
337
+ pvec = np.dot(dat, vec) / np.dot(vec, vec)
277
338
  pvec = pd.Series(np.asarray(pvec).flatten(), index=list(all_cells.index[all_cells]))
339
+
278
340
  if n_iter == 0:
279
341
  gv = pd.DataFrame(columns=["pvec", labels])
280
342
  gv["pvec"] = pvec
281
343
  gv[labels] = control
282
344
  gv.loc[guide_cells, labels] = gene
283
- if gene not in gv_list.keys():
345
+ if gene not in gv_list:
284
346
  gv_list[gene] = {}
285
347
  gv_list[gene][category] = gv
286
348
 
287
- guide_norm = self._define_normal_mixscape(pvec[guide_cells])
288
- nt_norm = self._define_normal_mixscape(pvec[nt_cells])
289
- means_init = np.array([[nt_norm[0]], [guide_norm[0]]])
290
- precisions_init = np.array([nt_norm[1], guide_norm[1]])
291
- mm = GaussianMixture(
349
+ means_init = np.array([[pvec[nt_cells].mean()], [pvec[guide_cells].mean()]])
350
+ std_init = np.array([pvec[nt_cells].std(), pvec[guide_cells].std()])
351
+ mm = MixscapeGaussianMixture(
292
352
  n_components=2,
293
353
  covariance_type="spherical",
294
354
  means_init=means_init,
295
- precisions_init=precisions_init,
355
+ precisions_init=1 / (std_init**2),
356
+ random_state=random_state,
357
+ max_iter=100,
358
+ fixed_means=[pvec[nt_cells].mean(), None],
359
+ fixed_covariances=[pvec[nt_cells].std() ** 2, None],
360
+ **gmmkwargs,
296
361
  ).fit(np.asarray(pvec).reshape(-1, 1))
297
362
  probabilities = mm.predict_proba(np.array(pvec[orig_guide_cells_index]).reshape(-1, 1))
298
363
  lik_ratio = probabilities[:, 0] / probabilities[:, 1]
299
364
  post_prob = 1 / (1 + lik_ratio)
365
+
300
366
  # based on the posterior probability, assign cells to the two classes
301
- adata.obs.loc[
302
- [orig_guide_cells_index[cell] for cell in np.where(post_prob > 0.5)[0]], new_class_name
303
- ] = gene
304
- adata.obs.loc[
305
- [orig_guide_cells_index[cell] for cell in np.where(post_prob <= 0.5)[0]], new_class_name
306
- ] = f"{gene} NP"
367
+ ko_mask = post_prob > 0.5
368
+ adata.obs.loc[np.array(orig_guide_cells_index)[ko_mask], new_class_name] = gene
369
+ adata.obs.loc[np.array(orig_guide_cells_index)[~ko_mask], new_class_name] = f"{gene} NP"
370
+
307
371
  if sum(adata.obs[new_class_name][split_mask] == gene) < min_de_genes:
308
372
  adata.obs.loc[guide_cells, new_class_name] = "NP"
309
373
  converged = True
310
- if adata.obs[new_class_name][all_cells].equals(old_classes):
374
+ current_classes = adata.obs[new_class_name][all_cells]
375
+ if (current_classes == old_classes).all():
311
376
  converged = True
312
- old_classes = adata.obs[new_class_name][all_cells]
377
+ old_classes = current_classes
378
+
313
379
  n_iter += 1
314
380
 
315
381
  adata.obs.loc[(adata.obs[new_class_name] == gene) & split_mask, new_class_name] = (
@@ -317,9 +383,7 @@ class Mixscape:
317
383
  )
318
384
 
319
385
  adata.obs[f"{new_class_name}_global"] = [a.split(" ")[-1] for a in adata.obs[new_class_name]]
320
- adata.obs.loc[orig_guide_cells_index, f"{new_class_name}_p_{perturbation_type.lower()}"] = np.round(
321
- post_prob
322
- ).astype("int64")
386
+ adata.obs.loc[orig_guide_cells_index, f"{new_class_name}_p_{perturbation_type.lower()}"] = post_prob
323
387
  adata.uns["mixscape"] = gv_list
324
388
 
325
389
  if copy:
@@ -330,11 +394,13 @@ class Mixscape:
330
394
  adata: AnnData,
331
395
  labels: str,
332
396
  control: str,
397
+ *,
333
398
  mixscape_class_global: str | None = "mixscape_class_global",
334
399
  layer: str | None = None,
335
400
  n_comps: int | None = 10,
336
401
  min_de_genes: int | None = 5,
337
402
  logfc_threshold: float | None = 0.25,
403
+ test_method: str | None = "wilcoxon",
338
404
  split_by: str | None = None,
339
405
  pval_cutoff: float | None = 5e-2,
340
406
  perturbation_type: str | None = "KO",
@@ -347,12 +413,12 @@ class Mixscape:
347
413
  labels: The column of `.obs` with target gene labels.
348
414
  control: Control category from the `labels` column.
349
415
  mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
350
- layer: Key from `adata.layers` whose value will be used to perform tests on.
351
- control: Control category from the `pert_key` column.
416
+ layer: Layer to use for identifying differentially expressed genes. If `None`, adata.X is used.
352
417
  n_comps: Number of principal components to use.
353
418
  min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
354
419
  logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells.
355
- split_by: Provide the column `.obs` if multiple biological replicates exist to calculate
420
+ test_method: Method to use for differential expression testing.
421
+ split_by: Provide `.obs` column with experimental condition/cell type annotation, if perturbations are condition/cell type-specific.
356
422
  pval_cutoff: P-value cut-off for selection of significantly DE genes.
357
423
  perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
358
424
  copy: Determines whether a copy of the `adata` is returned.
@@ -370,9 +436,9 @@ class Mixscape:
370
436
  >>> import pertpy as pt
371
437
  >>> mdata = pt.dt.papalexi_2021()
372
438
  >>> ms_pt = pt.tl.Mixscape()
373
- >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
374
- >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
375
- >>> ms_pt.lda(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
439
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
440
+ >>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
441
+ >>> ms_pt.lda(mdata["rna"], "gene_target", "NT")
376
442
  """
377
443
  if copy:
378
444
  adata = adata.copy()
@@ -388,9 +454,8 @@ class Mixscape:
388
454
  categories = split_obs.unique()
389
455
  split_masks = [split_obs == category for category in categories]
390
456
 
391
- mixscape_identifier = pt.tl.Mixscape()
392
457
  # determine gene sets across all splits/groups through differential gene expression
393
- perturbation_markers = mixscape_identifier._get_perturbation_markers(
458
+ perturbation_markers = self._get_perturbation_markers(
394
459
  adata=adata,
395
460
  split_masks=split_masks,
396
461
  categories=categories,
@@ -400,10 +465,12 @@ class Mixscape:
400
465
  pval_cutoff=pval_cutoff,
401
466
  min_de_genes=min_de_genes,
402
467
  logfc_threshold=logfc_threshold,
468
+ test_method=test_method,
403
469
  )
404
470
  adata_subset = adata[
405
471
  (adata.obs[mixscape_class_global] == perturbation_type) | (adata.obs[mixscape_class_global] == control)
406
- ].copy()
472
+ ]
473
+ X = adata_subset.X - adata_subset.X.mean(0)
407
474
  projected_pcs: dict[str, np.ndarray] = {}
408
475
  # performs PCA on each mixscape class separately and projects each subspace onto all cells in the data.
409
476
  for _, (key, value) in enumerate(perturbation_markers.items()):
@@ -415,16 +482,10 @@ class Mixscape:
415
482
  ].copy()
416
483
  sc.pp.scale(gene_subset)
417
484
  sc.tl.pca(gene_subset, n_comps=n_comps)
418
- sc.pp.neighbors(gene_subset)
419
- # projects each subspace onto all cells in the data.
420
- sc.tl.ingest(adata=adata_subset, adata_ref=gene_subset, embedding_method="pca")
421
- projected_pcs[key[1]] = adata_subset.obsm["X_pca"]
485
+ # project cells into PCA space of gene_subset
486
+ projected_pcs[key[1]] = np.asarray(np.dot(X, gene_subset.varm["PCs"]))
422
487
  # concatenate all pcs into a single matrix.
423
- for index, (_, value) in enumerate(projected_pcs.items()):
424
- if index == 0:
425
- projected_pcs_array = value
426
- else:
427
- projected_pcs_array = np.concatenate((projected_pcs_array, value), axis=1)
488
+ projected_pcs_array = np.concatenate(list(projected_pcs.values()), axis=1)
428
489
 
429
490
  clf = LinearDiscriminantAnalysis(n_components=len(np.unique(adata_subset.obs[labels])) - 1)
430
491
  clf.fit(projected_pcs_array, adata_subset.obs[labels])
@@ -445,12 +506,21 @@ class Mixscape:
445
506
  pval_cutoff: float,
446
507
  min_de_genes: float,
447
508
  logfc_threshold: float,
509
+ test_method: str,
448
510
  ) -> dict[tuple, np.ndarray]:
449
- """Determine gene sets across all splits/groups through differential gene expression
511
+ """Determine gene sets across all splits/groups through differential gene expression.
450
512
 
451
513
  Args:
452
514
  adata: :class:`~anndata.AnnData` object
453
- col_names: Column names to extract the indices for
515
+ split_masks: List of boolean masks for each split/group.
516
+ categories: List of split/group names.
517
+ labels: The column of `.obs` with target gene labels.
518
+ control: Control category from the `labels` column.
519
+ layer: Key from adata.layers whose value will be used to compare gene expression.
520
+ pval_cutoff: P-value cut-off for selection of significantly DE genes.
521
+ min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
522
+ logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells.
523
+ test_method: Method to use for differential expression testing.
454
524
 
455
525
  Returns:
456
526
  Dictionary keyed by `(category, target gene)` holding the differentially expressed genes found for that pair.
@@ -459,21 +529,23 @@ class Mixscape:
459
529
  for split, split_mask in enumerate(split_masks):
460
530
  category = categories[split]
461
531
  # get gene sets for each split
462
- genes = list(set(adata[split_mask].obs[labels]).difference([control]))
532
+ gene_targets = list(set(adata[split_mask].obs[labels]).difference([control]))
463
533
  adata_split = adata[split_mask].copy()
464
534
  # find top DE genes between cells with targeting and non-targeting gRNAs
465
535
  sc.tl.rank_genes_groups(
466
536
  adata_split,
467
537
  layer=layer,
468
538
  groupby=labels,
469
- groups=genes,
539
+ groups=gene_targets,
470
540
  reference=control,
471
- method="t-test",
541
+ method=test_method,
472
542
  use_raw=False,
473
543
  )
474
- # get DE genes for each gene
475
- for gene in genes:
476
- logfc_threshold_mask = adata_split.uns["rank_genes_groups"]["logfoldchanges"][gene] >= logfc_threshold
544
+ # get DE genes for each target gene
545
+ for gene in gene_targets:
546
+ logfc_threshold_mask = (
547
+ np.abs(adata_split.uns["rank_genes_groups"]["logfoldchanges"][gene]) >= logfc_threshold
548
+ )
477
549
  de_genes = adata_split.uns["rank_genes_groups"]["names"][gene][logfc_threshold_mask]
478
550
  pvals_adj = adata_split.uns["rank_genes_groups"]["pvals_adj"][gene][logfc_threshold_mask]
479
551
  de_genes = de_genes[pvals_adj < pval_cutoff]
@@ -483,33 +555,8 @@ class Mixscape:
483
555
 
484
556
  return perturbation_markers
485
557
 
486
- def _get_column_indices(self, adata, col_names):
487
- if isinstance(col_names, str): # pragma: no cover
488
- col_names = [col_names]
489
-
490
- indices = []
491
- for idx, col in enumerate(adata.var_names):
492
- if col in col_names:
493
- indices.append(idx)
494
-
495
- return indices
496
-
497
- def _define_normal_mixscape(self, X: np.ndarray | sparse.spmatrix | pd.DataFrame | None) -> list[float]:
498
- """Calculates the mean and standard deviation of a matrix.
499
-
500
- Args:
501
- X: The matrix to calculate the properties for.
502
-
503
- Returns:
504
- Mean and standard deviation of the matrix.
505
- """
506
- mu = X.mean()
507
- sd = X.std()
508
-
509
- return [mu, sd]
510
-
511
558
  @_doc_params(common_plot_args=doc_common_plot_args)
512
- def plot_barplot( # pragma: no cover
559
+ def plot_barplot( # pragma: no cover # noqa: D417
513
560
  self,
514
561
  adata: AnnData,
515
562
  guide_rna_column: str,
@@ -522,7 +569,6 @@ class Mixscape:
522
569
  legend_text_size: int = 8,
523
570
  legend_bbox_to_anchor: tuple[float, float] = None,
524
571
  figsize: tuple[float, float] = (25, 25),
525
- show: bool = True,
526
572
  return_fig: bool = False,
527
573
  ) -> Figure | None:
528
574
  """Barplot to visualize perturbation scores calculated by the `mixscape` function.
@@ -548,8 +594,8 @@ class Mixscape:
548
594
  >>> import pertpy as pt
549
595
  >>> mdata = pt.dt.papalexi_2021()
550
596
  >>> ms_pt = pt.tl.Mixscape()
551
- >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
552
- >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
597
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
598
+ >>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
553
599
  >>> ms_pt.plot_barplot(mdata["rna"], guide_rna_column="NT")
554
600
 
555
601
  Preview:
@@ -611,14 +657,13 @@ class Mixscape:
611
657
  fig.subplots_adjust(hspace=0.5, wspace=0.5)
612
658
  plt.tight_layout()
613
659
 
614
- if show:
615
- plt.show()
616
660
  if return_fig:
617
661
  return fig
662
+ plt.show()
618
663
  return None
619
664
 
620
665
  @_doc_params(common_plot_args=doc_common_plot_args)
621
- def plot_heatmap( # pragma: no cover
666
+ def plot_heatmap( # pragma: no cover # noqa: D417
622
667
  self,
623
668
  adata: AnnData,
624
669
  labels: str,
@@ -630,7 +675,6 @@ class Mixscape:
630
675
  subsample_number: int | None = 900,
631
676
  vmin: float | None = -2,
632
677
  vmax: float | None = 2,
633
- show: bool = True,
634
678
  return_fig: bool = False,
635
679
  **kwds,
636
680
  ) -> Figure | None:
@@ -656,8 +700,8 @@ class Mixscape:
656
700
  >>> import pertpy as pt
657
701
  >>> mdata = pt.dt.papalexi_2021()
658
702
  >>> ms_pt = pt.tl.Mixscape()
659
- >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
660
- >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
703
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
704
+ >>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
661
705
  >>> ms_pt.plot_heatmap(
662
706
  ... adata=mdata["rna"], labels="gene_target", target_gene="IFNGR2", layer="X_pert", control="NT"
663
707
  ... )
@@ -683,14 +727,13 @@ class Mixscape:
683
727
  **kwds,
684
728
  )
685
729
 
686
- if show:
687
- plt.show()
688
730
  if return_fig:
689
731
  return fig
732
+ plt.show()
690
733
  return None
691
734
 
692
735
  @_doc_params(common_plot_args=doc_common_plot_args)
693
- def plot_perturbscore( # pragma: no cover
736
+ def plot_perturbscore( # pragma: no cover # noqa: D417
694
737
  self,
695
738
  adata: AnnData,
696
739
  labels: str,
@@ -702,7 +745,6 @@ class Mixscape:
702
745
  split_by: str = None,
703
746
  before_mixscape: bool = False,
704
747
  perturbation_type: str = "KO",
705
- show: bool = True,
706
748
  return_fig: bool = False,
707
749
  ) -> Figure | None:
708
750
  """Density plots to visualize perturbation scores calculated by the `pt.tl.mixscape` function.
@@ -734,8 +776,8 @@ class Mixscape:
734
776
  >>> import pertpy as pt
735
777
  >>> mdata = pt.dt.papalexi_2021()
736
778
  >>> ms_pt = pt.tl.Mixscape()
737
- >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
738
- >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
779
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
780
+ >>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
739
781
  >>> ms_pt.plot_perturbscore(adata=mdata["rna"], labels="gene_target", target_gene="IFNGR2", color="orange")
740
782
 
741
783
  Preview:
@@ -744,7 +786,7 @@ class Mixscape:
744
786
  if "mixscape" not in adata.uns:
745
787
  raise ValueError("Please run the `mixscape` function first.")
746
788
  perturbation_score = None
747
- for key in adata.uns["mixscape"][target_gene].keys():
789
+ for key in adata.uns["mixscape"][target_gene]:
748
790
  perturbation_score_temp = adata.uns["mixscape"][target_gene][key]
749
791
  perturbation_score_temp["name"] = key
750
792
  if perturbation_score is None:
@@ -851,14 +893,13 @@ class Mixscape:
851
893
  plt.legend(title="mixscape class", title_fontsize=14, fontsize=12)
852
894
  sns.despine()
853
895
 
854
- if show:
855
- plt.show()
856
896
  if return_fig:
857
897
  return plt.gcf()
898
+ plt.show()
858
899
  return None
859
900
 
860
901
  @_doc_params(common_plot_args=doc_common_plot_args)
861
- def plot_violin( # pragma: no cover
902
+ def plot_violin( # pragma: no cover # noqa: D417
862
903
  self,
863
904
  adata: AnnData,
864
905
  target_gene_idents: str | list[str],
@@ -879,7 +920,6 @@ class Mixscape:
879
920
  ylabel: str | Sequence[str] | None = None,
880
921
  rotation: float | None = None,
881
922
  ax: Axes | None = None,
882
- show: bool = True,
883
923
  return_fig: bool = False,
884
924
  **kwargs,
885
925
  ) -> Axes | Figure | None:
@@ -910,8 +950,8 @@ class Mixscape:
910
950
  >>> import pertpy as pt
911
951
  >>> mdata = pt.dt.papalexi_2021()
912
952
  >>> ms_pt = pt.tl.Mixscape()
913
- >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
914
- >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
953
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
954
+ >>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
915
955
  >>> ms_pt.plot_violin(
916
956
  ... adata=mdata["rna"], target_gene_idents=["NT", "IFNGR2 NP", "IFNGR2 KO"], groupby="mixscape_class"
917
957
  ... )
@@ -939,7 +979,7 @@ class Mixscape:
939
979
  if len(ylabel) != 1:
940
980
  raise ValueError(f"Expected number of y-labels to be `1`, found `{len(ylabel)}`.")
941
981
  elif len(ylabel) != len(keys):
942
- raise ValueError(f"Expected number of y-labels to be `{len(keys)}`, " f"found `{len(ylabel)}`.")
982
+ raise ValueError(f"Expected number of y-labels to be `{len(keys)}`, found `{len(ylabel)}`.")
943
983
 
944
984
  if groupby is not None:
945
985
  if hue is not None:
@@ -992,7 +1032,7 @@ class Mixscape:
992
1032
  g.set(yscale="log")
993
1033
  g.set_titles(col_template="{col_name}").set_xlabels("")
994
1034
  if rotation is not None:
995
- for ax in g.axes[0]:
1035
+ for ax in g.axes[0]: # noqa: PLR1704
996
1036
  ax.tick_params(axis="x", labelrotation=rotation)
997
1037
  else:
998
1038
  # set by default the violin plot cut=0 to limit the extent
@@ -1010,7 +1050,7 @@ class Mixscape:
1010
1050
  else:
1011
1051
  axs = [ax]
1012
1052
  for ax, y, ylab in zip(axs, ys, ylabel, strict=False):
1013
- ax = sns.violinplot(
1053
+ ax = sns.violinplot( # noqa: PLW2901
1014
1054
  x=x,
1015
1055
  y=y,
1016
1056
  data=obs_tidy,
@@ -1024,7 +1064,7 @@ class Mixscape:
1024
1064
  # Get the handles and labels.
1025
1065
  handles, labels = ax.get_legend_handles_labels()
1026
1066
  if stripplot:
1027
- ax = sns.stripplot(
1067
+ ax = sns.stripplot( # noqa: PLW2901
1028
1068
  x=x,
1029
1069
  y=y,
1030
1070
  data=obs_tidy,
@@ -1047,12 +1087,9 @@ class Mixscape:
1047
1087
  if rotation is not None:
1048
1088
  ax.tick_params(axis="x", labelrotation=rotation)
1049
1089
 
1050
- show = settings.autoshow if show is None else show
1051
1090
  if hue is not None and stripplot is True:
1052
1091
  plt.legend(handles, labels)
1053
1092
 
1054
- if show:
1055
- plt.show()
1056
1093
  if return_fig:
1057
1094
  if multi_panel and groupby is None and len(ys) == 1:
1058
1095
  return g
@@ -1060,10 +1097,11 @@ class Mixscape:
1060
1097
  return axs[0]
1061
1098
  else:
1062
1099
  return axs
1100
+ plt.show()
1063
1101
  return None
1064
1102
 
1065
1103
  @_doc_params(common_plot_args=doc_common_plot_args)
1066
- def plot_lda( # pragma: no cover
1104
+ def plot_lda( # pragma: no cover # noqa: D417
1067
1105
  self,
1068
1106
  adata: AnnData,
1069
1107
  control: str,
@@ -1076,20 +1114,22 @@ class Mixscape:
1076
1114
  color_map: Colormap | str | None = None,
1077
1115
  palette: str | Sequence[str] | None = None,
1078
1116
  ax: Axes | None = None,
1079
- show: bool = True,
1080
1117
  return_fig: bool = False,
1081
1118
  **kwds,
1082
1119
  ) -> Figure | None:
1083
1120
  """Visualizing perturbation responses with Linear Discriminant Analysis. Requires `pt.tl.mixscape()` to be run first.
1084
1121
 
1085
1122
  Args:
1086
- adata: The annotated data object.
1123
+ adata: The annotated data object.
1087
1124
  control: Control category from the `pert_key` column.
1088
1125
  mixscape_class: The column of `.obs` with the mixscape classification result.
1089
1126
  mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
1090
1127
  perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
1091
- lda_key: If not specified, lda looks .uns["mixscape_lda"] for the LDA results.
1092
1128
  n_components: The number of dimensions of the embedding.
1129
+ lda_key: If not specified, lda looks in .uns["mixscape_lda"] for the LDA results.
1130
+ color_map: Matplotlib color map.
1131
+ palette: Matplotlib palette.
1132
+ ax: Matplotlib axes.
1093
1133
  {common_plot_args}
1094
1134
  **kwds: Additional arguments to `scanpy.pl.umap`.
1095
1135
 
@@ -1097,9 +1137,9 @@ class Mixscape:
1097
1137
  >>> import pertpy as pt
1098
1138
  >>> mdata = pt.dt.papalexi_2021()
1099
1139
  >>> ms_pt = pt.tl.Mixscape()
1100
- >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
1101
- >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
1102
- >>> ms_pt.lda(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
1140
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
1141
+ >>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
1142
+ >>> ms_pt.lda(mdata["rna"], "gene_target", "NT", split_by="replicate")
1103
1143
  >>> ms_pt.plot_lda(adata=mdata["rna"], control="NT")
1104
1144
 
1105
1145
  Preview:
@@ -1129,8 +1169,54 @@ class Mixscape:
1129
1169
  **kwds,
1130
1170
  )
1131
1171
 
1132
- if show:
1133
- plt.show()
1134
1172
  if return_fig:
1135
1173
  return fig
1174
+ plt.show()
1136
1175
  return None
1176
+
1177
+
1178
class MixscapeGaussianMixture(GaussianMixture):
    """Gaussian mixture whose selected components keep user-supplied means and/or covariances.

    After every regular M-step the parameters of the fixed components are reset to the
    requested values, so only the remaining components (and all weights) are estimated.
    """

    def __init__(
        self,
        n_components: int,
        fixed_means: Sequence[float] | None = None,
        fixed_covariances: Sequence[float] | None = None,
        **kwargs,
    ):
        """Custom Gaussian Mixture Model where means and covariances can be fixed for specific components.

        Args:
            n_components: Number of Gaussian components
            fixed_means: Means to fix (use None for those that should be estimated)
            fixed_covariances: Covariances to fix (use None for those that should be estimated)
            kwargs: Additional arguments passed to scikit-learn's GaussianMixture
        """
        super().__init__(n_components=n_components, **kwargs)
        self.fixed_means = fixed_means
        self.fixed_covariances = fixed_covariances

        # Positions of the components to pin and the values to pin them to.
        self.fixed_mean_indices, self.fixed_mean_values = self._split_fixed(fixed_means)
        self.fixed_cov_indices, self.fixed_cov_values = self._split_fixed(fixed_covariances)

    @staticmethod
    def _split_fixed(values):
        """Return (indices, values) of the non-None entries; two empty lists if there are none."""
        if values is None:
            return [], []
        indices = [i for i, value in enumerate(values) if value is not None]
        if not indices:
            return [], []
        return indices, np.array([values[i] for i in indices])

    @staticmethod
    def _per_component(values, target_ndim: int) -> np.ndarray:
        """Shape a 1-D array of per-component scalars so it broadcasts over the trailing axes.

        A flat array of ``k`` scalars only broadcasts into ``k`` selected rows when ``k == 1``.
        Adding singleton trailing axes makes it work for any ``k``. Arrays that already carry
        trailing axes are passed through unchanged.
        """
        values = np.asarray(values)
        if values.ndim == 1:
            values = values.reshape((-1,) + (1,) * (target_ndim - 1))
        return values

    def _m_step(self, X: np.ndarray, log_resp: np.ndarray):
        """Modified M-step to respect fixed means and covariances.

        Args:
            X: Data passed through to scikit-learn's M-step.
            log_resp: Log responsibilities passed through to scikit-learn's M-step.

        Returns:
            The estimator itself.
        """
        super()._m_step(X, log_resp)

        if self.fixed_mean_indices:
            self.means_[self.fixed_mean_indices] = self._per_component(self.fixed_mean_values, self.means_.ndim)

        if self.fixed_cov_indices:
            self.covariances_[self.fixed_cov_indices] = self._per_component(
                self.fixed_cov_values, self.covariances_.ndim
            )
            # The parent M-step derived `precisions_cholesky_` from the *estimated* covariances.
            # scikit-learn evaluates component densities from `precisions_cholesky_`, so it has to
            # be recomputed here; otherwise the fixed covariances would never influence the fit.
            from sklearn.mixture._gaussian_mixture import _compute_precision_cholesky

            self.precisions_cholesky_ = _compute_precision_cholesky(self.covariances_, self.covariance_type)

        return self