pertpy 0.9.4__py3-none-any.whl → 0.10.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
pertpy/tools/_mixscape.py CHANGED
@@ -18,6 +18,7 @@ from scipy.sparse import csr_matrix, issparse, spmatrix
18
18
  from sklearn.mixture import GaussianMixture
19
19
 
20
20
  import pertpy as pt
21
+ from pertpy._doc import _doc_params, doc_common_plot_args
21
22
 
22
23
  if TYPE_CHECKING:
23
24
  from collections.abc import Sequence
@@ -25,6 +26,7 @@ if TYPE_CHECKING:
25
26
  from anndata import AnnData
26
27
  from matplotlib.axes import Axes
27
28
  from matplotlib.colors import Colormap
29
+ from matplotlib.pyplot import Figure
28
30
  from scipy import sparse
29
31
 
30
32
 
@@ -39,9 +41,12 @@ class Mixscape:
39
41
  adata: AnnData,
40
42
  pert_key: str,
41
43
  control: str,
44
+ *,
45
+ ref_selection_mode: Literal["nn", "split_by"] = "nn",
42
46
  split_by: str | None = None,
43
47
  n_neighbors: int = 20,
44
48
  use_rep: str | None = None,
49
+ n_dims: int | None = 15,
45
50
  n_pcs: int | None = None,
46
51
  batch_size: int | None = None,
47
52
  copy: bool = False,
@@ -49,14 +54,18 @@ class Mixscape:
49
54
  ):
50
55
  """Calculate perturbation signature.
51
56
 
52
- For each cell, we identify `n_neighbors` cells from the control pool with the most similar mRNA expression profiles.
53
- The perturbation signature is calculated by subtracting the averaged mRNA expression profile of the control
54
- neighbors from the mRNA expression profile of each cell.
57
+ The perturbation signature is calculated by subtracting the mRNA expression profile of each cell from the averaged
58
+ mRNA expression profile of the control cells (selected according to `ref_selection_mode`).
59
+ The implementation resembles https://satijalab.org/seurat/reference/runmixscape. Note that in the original implementation, the
60
+ perturbation signature is calculated on unscaled data by default, and we therefore recommend doing the same.
55
61
 
56
62
  Args:
57
63
  adata: The annotated data object.
58
64
  pert_key: The column of `.obs` with perturbation categories, should also contain `control`.
59
- control: Control category from the `pert_key` column.
65
+ control: Name of the control category from the `pert_key` column.
66
+ ref_selection_mode: Method to select reference cells for the perturbation signature calculation. If `nn`,
67
+ the `n_neighbors` cells from the control pool with the most similar mRNA expression profiles are selected. If `split_by`,
68
+ the control cells from the same split in `split_by` (e.g. indicating biological replicates) are used to calculate the perturbation signature.
60
69
  split_by: Provide the `.obs` column if multiple biological replicates exist to calculate
61
70
  the perturbation signature for every replicate separately.
62
71
  n_neighbors: Number of neighbors from the control to use for the perturbation signature.
@@ -64,7 +73,10 @@ class Mixscape:
64
73
  If `None`, the representation is chosen automatically:
65
74
  For `.n_vars` < 50, `.X` is used, otherwise 'X_pca' is used.
66
75
  If 'X_pca' is not present, it’s computed with default parameters.
67
- n_pcs: Use this many PCs. If `n_pcs==0` use `.X` if `use_rep is None`.
76
+ n_dims: Number of dimensions to use from the representation to calculate the perturbation signature.
77
+ If `None`, use all dimensions.
78
+ n_pcs: If PCA representation is used, the number of principal components to compute.
79
+ If `n_pcs==0` use `.X` if `use_rep is None`.
68
80
  batch_size: Size of batch to calculate the perturbation signature.
69
81
  If 'None', the perturbation signature is calculated in the full mode, requiring more memory.
70
82
  The batched mode is very inefficient for sparse data.
@@ -81,8 +93,13 @@ class Mixscape:
81
93
  >>> import pertpy as pt
82
94
  >>> mdata = pt.dt.papalexi_2021()
83
95
  >>> ms_pt = pt.tl.Mixscape()
84
- >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
96
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
85
97
  """
98
+ if ref_selection_mode not in ["nn", "split_by"]:
99
+ raise ValueError("ref_selection_mode must be either 'nn' or 'split_by'.")
100
+ if ref_selection_mode == "split_by" and split_by is None:
101
+ raise ValueError("split_by must be provided if ref_selection_mode is 'split_by'.")
102
+
86
103
  if copy:
87
104
  adata = adata.copy()
88
105
 
@@ -90,59 +107,75 @@ class Mixscape:
90
107
 
91
108
  control_mask = adata.obs[pert_key] == control
92
109
 
93
- if split_by is None:
94
- split_masks = [np.full(adata.n_obs, True, dtype=bool)]
110
+ if ref_selection_mode == "split_by":
111
+ for split in adata.obs[split_by].unique():
112
+ split_mask = adata.obs[split_by] == split
113
+ control_mask_group = control_mask & split_mask
114
+ control_mean_expr = adata.X[control_mask_group].mean(0)
115
+ adata.layers["X_pert"][split_mask] = (
116
+ np.repeat(control_mean_expr.reshape(1, -1), split_mask.sum(), axis=0)
117
+ - adata.layers["X_pert"][split_mask]
118
+ )
95
119
  else:
96
- split_obs = adata.obs[split_by]
97
- split_masks = [split_obs == cat for cat in split_obs.unique()]
120
+ if split_by is None:
121
+ split_masks = [np.full(adata.n_obs, True, dtype=bool)]
122
+ else:
123
+ split_obs = adata.obs[split_by]
124
+ split_masks = [split_obs == cat for cat in split_obs.unique()]
98
125
 
99
- representation = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs)
126
+ representation = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs)
127
+ if n_dims is not None and n_dims < representation.shape[1]:
128
+ representation = representation[:, :n_dims]
100
129
 
101
- for split_mask in split_masks:
102
- control_mask_split = control_mask & split_mask
130
+ for split_mask in split_masks:
131
+ control_mask_split = control_mask & split_mask
103
132
 
104
- R_split = representation[split_mask]
105
- R_control = representation[control_mask_split]
133
+ R_split = representation[split_mask]
134
+ R_control = representation[np.asarray(control_mask_split)]
106
135
 
107
- from pynndescent import NNDescent
136
+ from pynndescent import NNDescent
108
137
 
109
- eps = kwargs.pop("epsilon", 0.1)
110
- nn_index = NNDescent(R_control, **kwargs)
111
- indices, _ = nn_index.query(R_split, k=n_neighbors, epsilon=eps)
138
+ eps = kwargs.pop("epsilon", 0.1)
139
+ nn_index = NNDescent(R_control, **kwargs)
140
+ indices, _ = nn_index.query(R_split, k=n_neighbors, epsilon=eps)
112
141
 
113
- X_control = np.expm1(adata.X[control_mask_split])
142
+ X_control = np.expm1(adata.X[np.asarray(control_mask_split)])
114
143
 
115
- n_split = split_mask.sum()
116
- n_control = X_control.shape[0]
144
+ n_split = split_mask.sum()
145
+ n_control = X_control.shape[0]
117
146
 
118
- if batch_size is None:
119
- col_indices = np.ravel(indices)
120
- row_indices = np.repeat(np.arange(n_split), n_neighbors)
147
+ if batch_size is None:
148
+ col_indices = np.ravel(indices)
149
+ row_indices = np.repeat(np.arange(n_split), n_neighbors)
121
150
 
122
- neigh_matrix = csr_matrix(
123
- (np.ones_like(col_indices, dtype=np.float64), (row_indices, col_indices)),
124
- shape=(n_split, n_control),
125
- )
126
- neigh_matrix /= n_neighbors
127
- adata.layers["X_pert"][split_mask] -= np.log1p(neigh_matrix @ X_control)
128
- else:
129
- is_sparse = issparse(X_control)
130
- split_indices = np.where(split_mask)[0]
131
- for i in range(0, n_split, batch_size):
132
- size = min(i + batch_size, n_split)
133
- select = slice(i, size)
151
+ neigh_matrix = csr_matrix(
152
+ (np.ones_like(col_indices, dtype=np.float64), (row_indices, col_indices)),
153
+ shape=(n_split, n_control),
154
+ )
155
+ neigh_matrix /= n_neighbors
156
+ adata.layers["X_pert"][split_mask] = (
157
+ np.log1p(neigh_matrix @ X_control) - adata.layers["X_pert"][split_mask]
158
+ )
159
+ else:
160
+ is_sparse = issparse(X_control)
161
+ split_indices = np.where(split_mask)[0]
162
+ for i in range(0, n_split, batch_size):
163
+ size = min(i + batch_size, n_split)
164
+ select = slice(i, size)
134
165
 
135
- batch = np.ravel(indices[select])
136
- split_batch = split_indices[select]
166
+ batch = np.ravel(indices[select])
167
+ split_batch = split_indices[select]
137
168
 
138
- size = size - i
169
+ size = size - i
139
170
 
140
- # sparse is very slow
141
- means_batch = X_control[batch]
142
- means_batch = means_batch.toarray() if is_sparse else means_batch
143
- means_batch = means_batch.reshape(size, n_neighbors, -1).mean(1)
171
+ # sparse is very slow
172
+ means_batch = X_control[batch]
173
+ means_batch = means_batch.toarray() if is_sparse else means_batch
174
+ means_batch = means_batch.reshape(size, n_neighbors, -1).mean(1)
144
175
 
145
- adata.layers["X_pert"][split_batch] -= np.log1p(means_batch)
176
+ adata.layers["X_pert"][split_batch] = (
177
+ np.log1p(means_batch) - adata.layers["X_pert"][split_batch]
178
+ )
146
179
 
147
180
  if copy:
148
181
  return adata
@@ -152,33 +185,41 @@ class Mixscape:
152
185
  adata: AnnData,
153
186
  labels: str,
154
187
  control: str,
188
+ *,
155
189
  new_class_name: str | None = "mixscape_class",
156
- min_de_genes: int | None = 5,
157
190
  layer: str | None = None,
191
+ min_de_genes: int | None = 5,
158
192
  logfc_threshold: float | None = 0.25,
193
+ de_layer: str | None = None,
194
+ test_method: str | None = "wilcoxon",
159
195
  iter_num: int | None = 10,
196
+ scale: bool | None = True,
160
197
  split_by: str | None = None,
161
198
  pval_cutoff: float | None = 5e-2,
162
199
  perturbation_type: str | None = "KO",
200
+ random_state: int | None = 0,
163
201
  copy: bool | None = False,
164
202
  ):
165
203
  """Identify perturbed and non-perturbed gRNA expressing cells that accounts for multiple treatments/conditions/chemical perturbations.
166
204
 
167
- The implementation resembles https://satijalab.org/seurat/reference/runmixscape
205
+ The implementation resembles https://satijalab.org/seurat/reference/runmixscape.
168
206
 
169
207
  Args:
170
208
  adata: The annotated data object.
171
209
  labels: The column of `.obs` with target gene labels.
172
- control: Control category from the `pert_key` column.
210
+ control: Control category from the `labels` column.
173
211
  new_class_name: Name of mixscape classification to be stored in `.obs`.
174
- min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
175
212
  layer: Key from adata.layers whose value will be used to perform tests on. Default is using `.layers["X_pert"]`.
213
+ min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
176
214
  logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells (default: 0.25).
215
+ de_layer: Layer to use for identifying differentially expressed genes. If `None`, adata.X is used.
216
+ test_method: Method to use for differential expression testing.
177
217
  iter_num: Number of normalmixEM iterations to run if convergence does not occur.
178
- split_by: Provide the column `.obs` if multiple biological replicates exist to calculate
179
- the perturbation signature for every replicate separately.
218
+ scale: Scale the data specified in `layer` before running the GaussianMixture model on it.
219
+ split_by: Provide `.obs` column with experimental condition/cell type annotation, if perturbations are condition/cell type-specific.
180
220
  pval_cutoff: P-value cut-off for selection of significantly DE genes.
181
221
  perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications.
222
+ random_state: Random seed for the GaussianMixture model.
182
223
  copy: Determines whether a copy of the `adata` is returned.
183
224
 
184
225
  Returns:
@@ -201,8 +242,8 @@ class Mixscape:
201
242
  >>> import pertpy as pt
202
243
  >>> mdata = pt.dt.papalexi_2021()
203
244
  >>> ms_pt = pt.tl.Mixscape()
204
- >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
205
- >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
245
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
246
+ >>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
206
247
  """
207
248
  if copy:
208
249
  adata = adata.copy()
@@ -216,7 +257,16 @@ class Mixscape:
216
257
  split_masks = [split_obs == category for category in categories]
217
258
 
218
259
  perturbation_markers = self._get_perturbation_markers(
219
- adata, split_masks, categories, labels, control, layer, pval_cutoff, min_de_genes, logfc_threshold
260
+ adata=adata,
261
+ split_masks=split_masks,
262
+ categories=categories,
263
+ labels=labels,
264
+ control=control,
265
+ layer=de_layer,
266
+ pval_cutoff=pval_cutoff,
267
+ min_de_genes=min_de_genes,
268
+ logfc_threshold=logfc_threshold,
269
+ test_method=test_method,
220
270
  )
221
271
 
222
272
  adata_comp = adata
@@ -229,6 +279,7 @@ class Mixscape:
229
279
  raise KeyError(
230
280
  "No 'X_pert' found in .layers! Please run perturbation_signature first to calculate perturbation signature!"
231
281
  ) from None
282
+
232
283
  # initialize return variables
233
284
  adata.obs[f"{new_class_name}_p_{perturbation_type.lower()}"] = 0
234
285
  adata.obs[new_class_name] = adata.obs[labels].astype(str)
@@ -239,10 +290,12 @@ class Mixscape:
239
290
  dtype=np.object_,
240
291
  )
241
292
  gv_list: dict[str, dict] = {}
293
+
294
+ adata.obs[f"{new_class_name}_p_{perturbation_type.lower()}"] = 0.0
242
295
  for split, split_mask in enumerate(split_masks):
243
296
  category = categories[split]
244
- genes = list(set(adata[split_mask].obs[labels]).difference([control]))
245
- for gene in genes:
297
+ gene_targets = list(set(adata[split_mask].obs[labels]).difference([control]))
298
+ for gene in gene_targets:
246
299
  post_prob = 0
247
300
  orig_guide_cells = (adata.obs[labels] == gene) & split_mask
248
301
  orig_guide_cells_index = list(orig_guide_cells.index[orig_guide_cells])
@@ -251,28 +304,38 @@ class Mixscape:
251
304
 
252
305
  if len(perturbation_markers[(category, gene)]) == 0:
253
306
  adata.obs.loc[orig_guide_cells, new_class_name] = f"{gene} NP"
307
+
254
308
  else:
255
309
  de_genes = perturbation_markers[(category, gene)]
256
310
  de_genes_indices = self._get_column_indices(adata, list(de_genes))
257
- dat = X[all_cells][:, de_genes_indices]
311
+
312
+ dat = X[np.asarray(all_cells)][:, de_genes_indices]
313
+ dat_cells = all_cells[all_cells].index
314
+ if scale:
315
+ dat = sc.pp.scale(dat)
316
+
258
317
  converged = False
259
318
  n_iter = 0
260
- old_classes = adata.obs[labels][all_cells]
319
+ old_classes = adata.obs[new_class_name][all_cells]
320
+
261
321
  while not converged and n_iter < iter_num:
262
322
  # Get all cells in current split&Gene
263
- guide_cells = (adata.obs[labels] == gene) & split_mask
323
+ guide_cells = (adata.obs[new_class_name] == gene) & split_mask
324
+
264
325
  # get average value for each gene over all selected cells
265
326
  # all cells in current split&Gene minus all NT cells in current split
266
327
  # Each row is for each cell, each column is for each gene, get mean for each column
267
- vec = np.mean(X[guide_cells][:, de_genes_indices], axis=0) - np.mean(
268
- X[nt_cells][:, de_genes_indices], axis=0
269
- )
328
+ guide_cells_dat_idx = all_cells[all_cells].index.get_indexer(guide_cells[guide_cells].index)
329
+ nt_cells_dat_idx = all_cells[all_cells].index.get_indexer(nt_cells[nt_cells].index)
330
+ vec = np.mean(dat[guide_cells_dat_idx], axis=0) - np.mean(dat[nt_cells_dat_idx], axis=0)
331
+
270
332
  # project cells onto the perturbation vector
271
333
  if isinstance(dat, spmatrix):
272
- pvec = np.sum(np.multiply(dat.toarray(), vec), axis=1) / np.sum(np.multiply(vec, vec))
334
+ pvec = np.dot(dat.toarray(), vec) / np.dot(vec, vec)
273
335
  else:
274
- pvec = np.sum(np.multiply(dat, vec), axis=1) / np.sum(np.multiply(vec, vec))
336
+ pvec = np.dot(dat, vec) / np.dot(vec, vec)
275
337
  pvec = pd.Series(np.asarray(pvec).flatten(), index=list(all_cells.index[all_cells]))
338
+
276
339
  if n_iter == 0:
277
340
  gv = pd.DataFrame(columns=["pvec", labels])
278
341
  gv["pvec"] = pvec
@@ -282,19 +345,22 @@ class Mixscape:
282
345
  gv_list[gene] = {}
283
346
  gv_list[gene][category] = gv
284
347
 
285
- guide_norm = self._define_normal_mixscape(pvec[guide_cells])
286
- nt_norm = self._define_normal_mixscape(pvec[nt_cells])
287
- means_init = np.array([[nt_norm[0]], [guide_norm[0]]])
288
- precisions_init = np.array([nt_norm[1], guide_norm[1]])
289
- mm = GaussianMixture(
348
+ means_init = np.array([[pvec[nt_cells].mean()], [pvec[guide_cells].mean()]])
349
+ std_init = np.array([pvec[nt_cells].std(), pvec[guide_cells].std()])
350
+ mm = MixscapeGaussianMixture(
290
351
  n_components=2,
291
352
  covariance_type="spherical",
292
353
  means_init=means_init,
293
- precisions_init=precisions_init,
354
+ precisions_init=1 / (std_init ** 2),
355
+ random_state=random_state,
356
+ max_iter=5000,
357
+ fixed_means=[pvec[nt_cells].mean(), None],
358
+ fixed_covariances=[pvec[nt_cells].std() ** 2, None],
294
359
  ).fit(np.asarray(pvec).reshape(-1, 1))
295
360
  probabilities = mm.predict_proba(np.array(pvec[orig_guide_cells_index]).reshape(-1, 1))
296
361
  lik_ratio = probabilities[:, 0] / probabilities[:, 1]
297
362
  post_prob = 1 / (1 + lik_ratio)
363
+
298
364
  # based on the posterior probability, assign cells to the two classes
299
365
  adata.obs.loc[
300
366
  [orig_guide_cells_index[cell] for cell in np.where(post_prob > 0.5)[0]], new_class_name
@@ -302,11 +368,13 @@ class Mixscape:
302
368
  adata.obs.loc[
303
369
  [orig_guide_cells_index[cell] for cell in np.where(post_prob <= 0.5)[0]], new_class_name
304
370
  ] = f"{gene} NP"
371
+
305
372
  if sum(adata.obs[new_class_name][split_mask] == gene) < min_de_genes:
306
373
  adata.obs.loc[guide_cells, new_class_name] = "NP"
307
374
  converged = True
308
375
  if adata.obs[new_class_name][all_cells].equals(old_classes):
309
376
  converged = True
377
+
310
378
  old_classes = adata.obs[new_class_name][all_cells]
311
379
  n_iter += 1
312
380
 
@@ -315,9 +383,7 @@ class Mixscape:
315
383
  )
316
384
 
317
385
  adata.obs[f"{new_class_name}_global"] = [a.split(" ")[-1] for a in adata.obs[new_class_name]]
318
- adata.obs.loc[orig_guide_cells_index, f"{new_class_name}_p_{perturbation_type.lower()}"] = np.round(
319
- post_prob
320
- ).astype("int64")
386
+ adata.obs.loc[orig_guide_cells_index, f"{new_class_name}_p_{perturbation_type.lower()}"] = post_prob
321
387
  adata.uns["mixscape"] = gv_list
322
388
 
323
389
  if copy:
@@ -328,11 +394,13 @@ class Mixscape:
328
394
  adata: AnnData,
329
395
  labels: str,
330
396
  control: str,
397
+ *,
331
398
  mixscape_class_global: str | None = "mixscape_class_global",
332
399
  layer: str | None = None,
333
400
  n_comps: int | None = 10,
334
401
  min_de_genes: int | None = 5,
335
402
  logfc_threshold: float | None = 0.25,
403
+ test_method: str | None = "wilcoxon",
336
404
  split_by: str | None = None,
337
405
  pval_cutoff: float | None = 5e-2,
338
406
  perturbation_type: str | None = "KO",
@@ -345,12 +413,13 @@ class Mixscape:
345
413
  labels: The column of `.obs` with target gene labels.
346
414
  control: Control category from the `labels` column.
347
415
  mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
348
- layer: Key from `adata.layers` whose value will be used to perform tests on.
416
+ layer: Layer to use for identifying differentially expressed genes. If `None`, adata.X is used.
349
417
  control: Control category from the `labels` column.
350
418
  n_comps: Number of principal components to use.
351
419
  min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
352
420
  logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells.
353
- split_by: Provide the column `.obs` if multiple biological replicates exist to calculate
421
+ test_method: Method to use for differential expression testing.
422
+ split_by: Provide `.obs` column with experimental condition/cell type annotation, if perturbations are condition/cell type-specific.
354
423
  pval_cutoff: P-value cut-off for selection of significantly DE genes.
355
424
  perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
356
425
  copy: Determines whether a copy of the `adata` is returned.
@@ -368,9 +437,9 @@ class Mixscape:
368
437
  >>> import pertpy as pt
369
438
  >>> mdata = pt.dt.papalexi_2021()
370
439
  >>> ms_pt = pt.tl.Mixscape()
371
- >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
372
- >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
373
- >>> ms_pt.lda(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
440
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
441
+ >>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
442
+ >>> ms_pt.lda(mdata["rna"], "gene_target", "NT")
374
443
  """
375
444
  if copy:
376
445
  adata = adata.copy()
@@ -386,9 +455,8 @@ class Mixscape:
386
455
  categories = split_obs.unique()
387
456
  split_masks = [split_obs == category for category in categories]
388
457
 
389
- mixscape_identifier = pt.tl.Mixscape()
390
458
  # determine gene sets across all splits/groups through differential gene expression
391
- perturbation_markers = mixscape_identifier._get_perturbation_markers(
459
+ perturbation_markers = self._get_perturbation_markers(
392
460
  adata=adata,
393
461
  split_masks=split_masks,
394
462
  categories=categories,
@@ -398,6 +466,7 @@ class Mixscape:
398
466
  pval_cutoff=pval_cutoff,
399
467
  min_de_genes=min_de_genes,
400
468
  logfc_threshold=logfc_threshold,
469
+ test_method=test_method,
401
470
  )
402
471
  adata_subset = adata[
403
472
  (adata.obs[mixscape_class_global] == perturbation_type) | (adata.obs[mixscape_class_global] == control)
@@ -443,12 +512,21 @@ class Mixscape:
443
512
  pval_cutoff: float,
444
513
  min_de_genes: float,
445
514
  logfc_threshold: float,
515
+ test_method: str,
446
516
  ) -> dict[tuple, np.ndarray]:
447
517
  """Determine gene sets across all splits/groups through differential gene expression
448
518
 
449
519
  Args:
450
520
  adata: :class:`~anndata.AnnData` object
451
- col_names: Column names to extract the indices for
521
+ split_masks: List of boolean masks for each split/group.
522
+ categories: List of split/group names.
523
+ labels: The column of `.obs` with target gene labels.
524
+ control: Control category from the `labels` column.
525
+ layer: Key from adata.layers whose value will be used to compare gene expression.
526
+ pval_cutoff: P-value cut-off for selection of significantly DE genes.
527
+ min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
528
+ logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells.
529
+ test_method: Method to use for differential expression testing.
452
530
 
453
531
  Returns:
454
532
  Dictionary mapping each (category, target gene) pair to an array of its differentially expressed gene names.
@@ -457,21 +535,21 @@ class Mixscape:
457
535
  for split, split_mask in enumerate(split_masks):
458
536
  category = categories[split]
459
537
  # get gene sets for each split
460
- genes = list(set(adata[split_mask].obs[labels]).difference([control]))
538
+ gene_targets = list(set(adata[split_mask].obs[labels]).difference([control]))
461
539
  adata_split = adata[split_mask].copy()
462
540
  # find top DE genes between cells with targeting and non-targeting gRNAs
463
541
  sc.tl.rank_genes_groups(
464
542
  adata_split,
465
543
  layer=layer,
466
544
  groupby=labels,
467
- groups=genes,
545
+ groups=gene_targets,
468
546
  reference=control,
469
- method="t-test",
547
+ method=test_method,
470
548
  use_raw=False,
471
549
  )
472
- # get DE genes for each gene
473
- for gene in genes:
474
- logfc_threshold_mask = adata_split.uns["rank_genes_groups"]["logfoldchanges"][gene] >= logfc_threshold
550
+ # get DE genes for each target gene
551
+ for gene in gene_targets:
552
+ logfc_threshold_mask = np.abs(adata_split.uns["rank_genes_groups"]["logfoldchanges"][gene]) >= logfc_threshold
475
553
  de_genes = adata_split.uns["rank_genes_groups"]["names"][gene][logfc_threshold_mask]
476
554
  pvals_adj = adata_split.uns["rank_genes_groups"]["pvals_adj"][gene][logfc_threshold_mask]
477
555
  de_genes = de_genes[pvals_adj < pval_cutoff]
@@ -492,35 +570,22 @@ class Mixscape:
492
570
 
493
571
  return indices
494
572
 
495
- def _define_normal_mixscape(self, X: np.ndarray | sparse.spmatrix | pd.DataFrame | None) -> list[float]:
496
- """Calculates the mean and standard deviation of a matrix.
497
-
498
- Args:
499
- X: The matrix to calculate the properties for.
500
-
501
- Returns:
502
- Mean and standard deviation of the matrix.
503
- """
504
- mu = X.mean()
505
- sd = X.std()
506
-
507
- return [mu, sd]
508
-
573
+ @_doc_params(common_plot_args=doc_common_plot_args)
509
574
  def plot_barplot( # pragma: no cover
510
575
  self,
511
576
  adata: AnnData,
512
577
  guide_rna_column: str,
578
+ *,
513
579
  mixscape_class_global: str = "mixscape_class_global",
514
580
  axis_text_x_size: int = 8,
515
581
  axis_text_y_size: int = 6,
516
582
  axis_title_size: int = 8,
517
583
  legend_title_size: int = 8,
518
584
  legend_text_size: int = 8,
519
- return_fig: bool | None = None,
520
- ax: Axes | None = None,
521
- show: bool | None = None,
522
- save: bool | str | None = None,
523
- ):
585
+ legend_bbox_to_anchor: tuple[float, float] = None,
586
+ figsize: tuple[float, float] = (25, 25),
587
+ return_fig: bool = False,
588
+ ) -> Figure | None:
524
589
  """Barplot to visualize perturbation scores calculated by the `mixscape` function.
525
590
 
526
591
  Args:
@@ -528,19 +593,24 @@ class Mixscape:
528
593
  guide_rna_column: The column of `.obs` with guide RNA labels. The target gene labels.
529
594
  The format must be <gene_target>g<#>. Examples are 'STAT2g1' and 'ATF2g1'.
530
595
  mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
531
- show: Show the plot, do not return axis.
532
- save: If True or a str, save the figure. A string is appended to the default filename.
533
- Infer the filetype if ending on {'.pdf', '.png', '.svg'}.
596
+ axis_text_x_size: Size of the x-axis text.
597
+ axis_text_y_size: Size of the y-axis text.
598
+ axis_title_size: Size of the axis title.
599
+ legend_title_size: Size of the legend title.
600
+ legend_text_size: Size of the legend text.
601
+ legend_bbox_to_anchor: The bbox to which the legend will be anchored.
602
+ figsize: The size of the figure.
603
+ {common_plot_args}
534
604
 
535
605
  Returns:
536
- If `show==False`, return a :class:`~matplotlib.axes.Axes.
606
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
537
607
 
538
608
  Examples:
539
609
  >>> import pertpy as pt
540
610
  >>> mdata = pt.dt.papalexi_2021()
541
611
  >>> ms_pt = pt.tl.Mixscape()
542
- >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
543
- >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
612
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
613
+ >>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
544
614
  >>> ms_pt.plot_barplot(mdata["rna"], guide_rna_column="NT")
545
615
 
546
616
  Preview:
@@ -565,63 +635,64 @@ class Mixscape:
565
635
  all_cells_percentage["guide_number"] = "g" + all_cells_percentage["guide_number"]
566
636
  NP_KO_cells = all_cells_percentage[all_cells_percentage["gene"] != "NT"]
567
637
 
568
- if show:
569
- color_mapping = {"KO": "salmon", "NP": "lightgray", "NT": "grey"}
570
- unique_genes = NP_KO_cells["gene"].unique()
571
- fig, axs = plt.subplots(int(len(unique_genes) / 5), 5, figsize=(25, 25), sharey=True)
572
- for i, gene in enumerate(unique_genes):
573
- ax = axs[int(i / 5), i % 5]
574
- grouped_df = (
575
- NP_KO_cells[NP_KO_cells["gene"] == gene]
576
- .groupby(["guide_number", "mixscape_class_global"], observed=False)["value"]
577
- .sum()
578
- .unstack()
579
- )
580
- grouped_df.plot(
581
- kind="bar",
582
- stacked=True,
583
- color=[color_mapping[col] for col in grouped_df.columns],
584
- ax=ax,
585
- width=0.8,
586
- legend=False,
587
- )
588
- ax.set_title(
589
- gene, bbox={"facecolor": "white", "edgecolor": "black", "pad": 1}, fontsize=axis_title_size
590
- )
591
- ax.set(xlabel="sgRNA", ylabel="% of cells")
592
- sns.despine(ax=ax, top=True, right=True, left=False, bottom=False)
593
- ax.set_xticklabels(ax.get_xticklabels(), rotation=0, ha="right", fontsize=axis_text_x_size)
594
- ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=axis_text_y_size)
595
- fig.subplots_adjust(right=0.8)
596
- fig.subplots_adjust(hspace=0.5, wspace=0.5)
638
+ color_mapping = {"KO": "salmon", "NP": "lightgray", "NT": "grey"}
639
+ unique_genes = NP_KO_cells["gene"].unique()
640
+ fig, axs = plt.subplots(int(len(unique_genes) / 5), 5, figsize=figsize, sharey=True)
641
+ for i, gene in enumerate(unique_genes):
642
+ ax = axs[int(i / 5), i % 5]
643
+ grouped_df = (
644
+ NP_KO_cells[NP_KO_cells["gene"] == gene]
645
+ .groupby(["guide_number", "mixscape_class_global"], observed=False)["value"]
646
+ .sum()
647
+ .unstack()
648
+ )
649
+ grouped_df.plot(
650
+ kind="bar",
651
+ stacked=True,
652
+ color=[color_mapping[col] for col in grouped_df.columns],
653
+ ax=ax,
654
+ width=0.8,
655
+ legend=False,
656
+ )
657
+ ax.set_title(gene, bbox={"facecolor": "white", "edgecolor": "black", "pad": 1}, fontsize=axis_title_size)
658
+ ax.set(xlabel="sgRNA", ylabel="% of cells")
659
+ sns.despine(ax=ax, top=True, right=True, left=False, bottom=False)
660
+ ax.set_xticks(ax.get_xticks(), ax.get_xticklabels(), rotation=0, ha="right", fontsize=axis_text_x_size)
661
+ ax.set_yticks(ax.get_yticks(), ax.get_yticklabels(), rotation=0, fontsize=axis_text_y_size)
597
662
  ax.legend(
598
- title="mixscape_class_global",
663
+ title="Mixscape Class",
599
664
  loc="center right",
600
- bbox_to_anchor=(2.2, 3.5),
665
+ bbox_to_anchor=legend_bbox_to_anchor,
601
666
  frameon=True,
602
667
  fontsize=legend_text_size,
603
668
  title_fontsize=legend_title_size,
604
669
  )
605
670
 
671
+ fig.subplots_adjust(right=0.8)
672
+ fig.subplots_adjust(hspace=0.5, wspace=0.5)
606
673
  plt.tight_layout()
607
- _utils.savefig_or_show("mixscape_barplot", show=show, save=save)
608
674
 
675
+ if return_fig:
676
+ return fig
677
+ plt.show()
678
+ return None
679
+
680
+ @_doc_params(common_plot_args=doc_common_plot_args)
609
681
  def plot_heatmap( # pragma: no cover
610
682
  self,
611
683
  adata: AnnData,
612
684
  labels: str,
613
685
  target_gene: str,
614
686
  control: str,
687
+ *,
615
688
  layer: str | None = None,
616
689
  method: str | None = "wilcoxon",
617
690
  subsample_number: int | None = 900,
618
691
  vmin: float | None = -2,
619
692
  vmax: float | None = 2,
620
- return_fig: bool | None = None,
621
- show: bool | None = None,
622
- save: bool | str | None = None,
693
+ return_fig: bool = False,
623
694
  **kwds,
624
- ) -> Axes | None:
695
+ ) -> Figure | None:
625
696
  """Heatmap plot using mixscape results. Requires `pt.tl.mixscape()` to be run first.
626
697
 
627
698
  Args:
@@ -634,21 +705,18 @@ class Mixscape:
634
705
  subsample_number: Subsample to this number of observations.
635
706
  vmin: The value representing the lower limit of the color scale. Values smaller than vmin are plotted with the same color as vmin.
636
707
  vmax: The value representing the upper limit of the color scale. Values larger than vmax are plotted with the same color as vmax.
637
- show: Show the plot, do not return axis.
638
- save: If `True` or a `str`, save the figure. A string is appended to the default filename.
639
- Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
640
- ax: A matplotlib axes object. Only works if plotting a single component.
708
+ {common_plot_args}
641
709
  **kwds: Additional arguments to `scanpy.pl.rank_genes_groups_heatmap`.
642
710
 
643
711
  Returns:
644
- If `show==False`, return a :class:`~matplotlib.axes.Axes`.
712
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
645
713
 
646
714
  Examples:
647
715
  >>> import pertpy as pt
648
716
  >>> mdata = pt.dt.papalexi_2021()
649
717
  >>> ms_pt = pt.tl.Mixscape()
650
- >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
651
- >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
718
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
719
+ >>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
652
720
  >>> ms_pt.plot_heatmap(
653
721
  ... adata=mdata["rna"], labels="gene_target", target_gene="IFNGR2", layer="X_pert", control="NT"
654
722
  ... )
@@ -663,35 +731,37 @@ class Mixscape:
663
731
  sc.pp.scale(adata_subset, max_value=vmax)
664
732
  sc.pp.subsample(adata_subset, n_obs=subsample_number)
665
733
 
666
- return sc.pl.rank_genes_groups_heatmap(
734
+ fig = sc.pl.rank_genes_groups_heatmap(
667
735
  adata_subset,
668
736
  groupby="mixscape_class",
669
737
  vmin=vmin,
670
738
  vmax=vmax,
671
739
  n_genes=20,
672
740
  groups=["NT"],
673
- return_fig=return_fig,
674
- show=show,
675
- save=save,
741
+ show=False,
676
742
  **kwds,
677
743
  )
678
744
 
745
+ if return_fig:
746
+ return fig
747
+ plt.show()
748
+ return None
749
+
750
+ @_doc_params(common_plot_args=doc_common_plot_args)
679
751
  def plot_perturbscore( # pragma: no cover
680
752
  self,
681
753
  adata: AnnData,
682
754
  labels: str,
683
755
  target_gene: str,
756
+ *,
684
757
  mixscape_class: str = "mixscape_class",
685
758
  color: str = "orange",
686
759
  palette: dict[str, str] = None,
687
760
  split_by: str = None,
688
761
  before_mixscape: bool = False,
689
762
  perturbation_type: str = "KO",
690
- return_fig: bool | None = None,
691
- ax: Axes | None = None,
692
- show: bool | None = None,
693
- save: bool | str | None = None,
694
- ) -> None:
763
+ return_fig: bool = False,
764
+ ) -> Figure | None:
695
765
  """Density plots to visualize perturbation scores calculated by the `pt.tl.mixscape` function.
696
766
 
697
767
  Requires `pt.tl.mixscape` to be run first.
@@ -710,6 +780,10 @@ class Mixscape:
710
780
  before_mixscape: Option to split densities based on mixscape classification (default) or original target gene classification.
711
781
  Default is set to NULL and plots cells by original class ID.
712
782
  perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
783
+ {common_plot_args}
784
+
785
+ Returns:
786
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
713
787
 
714
788
  Examples:
715
789
  Visualizing the perturbation scores for the cells in a dataset:
@@ -717,8 +791,8 @@ class Mixscape:
717
791
  >>> import pertpy as pt
718
792
  >>> mdata = pt.dt.papalexi_2021()
719
793
  >>> ms_pt = pt.tl.Mixscape()
720
- >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
721
- >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
794
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
795
+ >>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
722
796
  >>> ms_pt.plot_perturbscore(adata=mdata["rna"], labels="gene_target", target_gene="IFNGR2", color="orange")
723
797
 
724
798
  Preview:
@@ -778,15 +852,6 @@ class Mixscape:
778
852
  plt.legend(title="gene_target", title_fontsize=14, fontsize=12)
779
853
  sns.despine()
780
854
 
781
- if save:
782
- plt.savefig(save, bbox_inches="tight")
783
- if show:
784
- plt.show()
785
- if return_fig:
786
- return plt.gcf()
787
- if not (show or save):
788
- return plt.gca()
789
-
790
855
  # If before_mixscape is False, split densities based on mixscape classifications
791
856
  else:
792
857
  if palette is None:
@@ -843,19 +908,17 @@ class Mixscape:
843
908
  plt.legend(title="mixscape class", title_fontsize=14, fontsize=12)
844
909
  sns.despine()
845
910
 
846
- if save:
847
- plt.savefig(save, bbox_inches="tight")
848
- if show:
849
- plt.show()
850
- if return_fig:
851
- return plt.gcf()
852
- if not (show or save):
853
- return plt.gca()
911
+ if return_fig:
912
+ return plt.gcf()
913
+ plt.show()
914
+ return None
854
915
 
916
+ @_doc_params(common_plot_args=doc_common_plot_args)
855
917
  def plot_violin( # pragma: no cover
856
918
  self,
857
919
  adata: AnnData,
858
920
  target_gene_idents: str | list[str],
921
+ *,
859
922
  keys: str | Sequence[str] = "mixscape_class_p_ko",
860
923
  groupby: str | None = "mixscape_class",
861
924
  log: bool = False,
@@ -872,10 +935,9 @@ class Mixscape:
872
935
  ylabel: str | Sequence[str] | None = None,
873
936
  rotation: float | None = None,
874
937
  ax: Axes | None = None,
875
- show: bool | None = None,
876
- save: bool | str | None = None,
938
+ return_fig: bool = False,
877
939
  **kwargs,
878
- ):
940
+ ) -> Axes | Figure | None:
879
941
  """Violin plot using mixscape results.
880
942
 
881
943
  Requires `pt.tl.mixscape` to be run first.
@@ -892,21 +954,19 @@ class Mixscape:
892
954
  xlabel: Label of the x-axis. Defaults to `groupby` if `rotation` is `None`, otherwise, no label is shown.
893
955
  ylabel: Label of the y-axis. If `None` and `groupby` is `None`, defaults to `'value'`.
894
956
  If `None` and `groubpy` is not `None`, defaults to `keys`.
895
- show: Show the plot, do not return axis.
896
- save: If `True` or a `str`, save the figure. A string is appended to the default filename.
897
- Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
898
957
  ax: A matplotlib axes object. Only works if plotting a single component.
958
+ {common_plot_args}
899
959
  **kwargs: Additional arguments to `seaborn.violinplot`.
900
960
 
901
961
  Returns:
902
- A :class:`~matplotlib.axes.Axes` object if `ax` is `None` else `None`.
962
+ If `return_fig` is `True`, returns the figure (as Axes list if it's a multi-panel plot), otherwise `None`.
903
963
 
904
964
  Examples:
905
965
  >>> import pertpy as pt
906
966
  >>> mdata = pt.dt.papalexi_2021()
907
967
  >>> ms_pt = pt.tl.Mixscape()
908
- >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
909
- >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
968
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
969
+ >>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
910
970
  >>> ms_pt.plot_violin(
911
971
  ... adata=mdata["rna"], target_gene_idents=["NT", "IFNGR2 NP", "IFNGR2 KO"], groupby="mixscape_class"
912
972
  ... )
@@ -1042,23 +1102,25 @@ class Mixscape:
1042
1102
  if rotation is not None:
1043
1103
  ax.tick_params(axis="x", labelrotation=rotation)
1044
1104
 
1045
- show = settings.autoshow if show is None else show
1046
1105
  if hue is not None and stripplot is True:
1047
1106
  plt.legend(handles, labels)
1048
- _utils.savefig_or_show("mixscape_violin", show=show, save=save)
1049
1107
 
1050
- if not show:
1108
+ if return_fig:
1051
1109
  if multi_panel and groupby is None and len(ys) == 1:
1052
1110
  return g
1053
1111
  elif len(axs) == 1:
1054
1112
  return axs[0]
1055
1113
  else:
1056
1114
  return axs
1115
+ plt.show()
1116
+ return None
1057
1117
 
1118
+ @_doc_params(common_plot_args=doc_common_plot_args)
1058
1119
  def plot_lda( # pragma: no cover
1059
1120
  self,
1060
1121
  adata: AnnData,
1061
1122
  control: str,
1123
+ *,
1062
1124
  mixscape_class: str = "mixscape_class",
1063
1125
  mixscape_class_global: str = "mixscape_class_global",
1064
1126
  perturbation_type: str | None = "KO",
@@ -1066,12 +1128,10 @@ class Mixscape:
1066
1128
  n_components: int | None = None,
1067
1129
  color_map: Colormap | str | None = None,
1068
1130
  palette: str | Sequence[str] | None = None,
1069
- return_fig: bool | None = None,
1070
1131
  ax: Axes | None = None,
1071
- show: bool | None = None,
1072
- save: bool | str | None = None,
1132
+ return_fig: bool = False,
1073
1133
  **kwds,
1074
- ) -> None:
1134
+ ) -> Figure | None:
1075
1135
  """Visualizing perturbation responses with Linear Discriminant Analysis. Requires `pt.tl.mixscape()` to be run first.
1076
1136
 
1077
1137
  Args:
@@ -1082,18 +1142,16 @@ class Mixscape:
1082
1142
  perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
1083
1143
  lda_key: If not specified, lda looks .uns["mixscape_lda"] for the LDA results.
1084
1144
  n_components: The number of dimensions of the embedding.
1085
- show: Show the plot, do not return axis.
1086
- save: If `True` or a `str`, save the figure. A string is appended to the default filename.
1087
- Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
1145
+ {common_plot_args}
1088
1146
  **kwds: Additional arguments to `scanpy.pl.umap`.
1089
1147
 
1090
1148
  Examples:
1091
1149
  >>> import pertpy as pt
1092
1150
  >>> mdata = pt.dt.papalexi_2021()
1093
1151
  >>> ms_pt = pt.tl.Mixscape()
1094
- >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
1095
- >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
1096
- >>> ms_pt.lda(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert")
1152
+ >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
1153
+ >>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert")
1154
+ >>> ms_pt.lda(mdata["rna"], "gene_target", "NT", split_by="replicate")
1097
1155
  >>> ms_pt.plot_lda(adata=mdata["rna"], control="NT")
1098
1156
 
1099
1157
  Preview:
@@ -1112,14 +1170,55 @@ class Mixscape:
1112
1170
  n_components = adata_subset.uns[lda_key].shape[1]
1113
1171
  sc.pp.neighbors(adata_subset, use_rep=lda_key)
1114
1172
  sc.tl.umap(adata_subset, n_components=n_components)
1115
- sc.pl.umap(
1173
+ fig = sc.pl.umap(
1116
1174
  adata_subset,
1117
1175
  color=mixscape_class,
1118
1176
  palette=palette,
1119
1177
  color_map=color_map,
1120
1178
  return_fig=return_fig,
1121
- show=show,
1122
- save=save,
1179
+ show=False,
1123
1180
  ax=ax,
1124
1181
  **kwds,
1125
1182
  )
1183
+
1184
+ if return_fig:
1185
+ return fig
1186
+ plt.show()
1187
+ return None
1188
+
1189
class MixscapeGaussianMixture(GaussianMixture):
    """Gaussian mixture model in which selected component parameters stay pinned.

    After every regular M-step, the means and/or covariances of the components
    listed in `fixed_means` / `fixed_covariances` are reset to the supplied
    values; components whose entry is `None` keep the freshly estimated value.
    """

    def __init__(
        self,
        n_components: int,
        fixed_means: Sequence[float | None] | None = None,
        fixed_covariances: Sequence[float | None] | None = None,
        **kwargs,
    ):
        """Custom Gaussian Mixture Model where means and covariances can be fixed for specific components.

        Args:
            n_components: Number of Gaussian components
            fixed_means: Means to fix, one entry per component (use None for those that should be estimated)
            fixed_covariances: Covariances to fix, one entry per component (use None for those that should be estimated)
            kwargs: Additional arguments passed to scikit-learn's GaussianMixture
        """
        super().__init__(n_components=n_components, **kwargs)
        self.fixed_means = fixed_means
        self.fixed_covariances = fixed_covariances

    def _m_step(self, X: np.ndarray, log_resp: np.ndarray):
        """Modified M-step to respect fixed means and covariances.

        Runs the regular M-step first and then overwrites the parameters of the
        pinned components.

        Args:
            X: Data passed through to `GaussianMixture._m_step`.
            log_resp: Log responsibilities passed through to `GaussianMixture._m_step`.

        Returns:
            The estimator itself.
        """
        super()._m_step(X, log_resp)

        if self.fixed_means is not None:
            for i in range(self.n_components):
                if self.fixed_means[i] is not None:
                    self.means_[i] = self.fixed_means[i]

        if self.fixed_covariances is not None:
            covariance_pinned = False
            for i in range(self.n_components):
                if self.fixed_covariances[i] is not None:
                    self.covariances_[i] = self.fixed_covariances[i]
                    covariance_pinned = True

            if covariance_pinned:
                # The parent M-step derived `precisions_cholesky_` from the
                # covariances *before* they were pinned. scikit-learn evaluates
                # the component densities (E-step, predict, score) from
                # `precisions_cholesky_`, so it has to be refreshed here or the
                # pinned covariances would never influence the fit.
                from sklearn.mixture._gaussian_mixture import _compute_precision_cholesky

                self.precisions_cholesky_ = _compute_precision_cholesky(
                    self.covariances_, self.covariance_type
                )

        return self
1224
+