DeConveil 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
DeConveil/dds.py ADDED
@@ -0,0 +1,1279 @@
1
+ import sys
2
+ import time
3
+ import warnings
4
+ from typing import List
5
+ from typing import Literal
6
+ from typing import Optional
7
+ from typing import Union
8
+ from typing import cast
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ from scipy.optimize import minimize
13
+ from scipy.special import polygamma # type: ignore
14
+ from scipy.stats import f # type: ignore
15
+ from scipy.stats import trim_mean # type: ignore
16
+
17
+ from deconveil.default_inference import DefInference
18
+ from deconveil.inference import Inference
19
+ from deconveil import utils_CNaware
20
+ from deconveil.utils_CNaware import fit_rough_dispersions
21
+ from deconveil.utils_CNaware import fit_moments_dispersions2
22
+ from deconveil.utils_CNaware import grid_fit_beta
23
+ from deconveil.utils_CNaware import irls_glm
24
+
25
+ from pydeseq2.preprocessing import deseq2_norm_fit
26
+ from pydeseq2.preprocessing import deseq2_norm_transform
27
+ from pydeseq2.utils import build_design_matrix
28
+ from pydeseq2.utils import dispersion_trend
29
+ from pydeseq2.utils import mean_absolute_deviation
30
+ from pydeseq2.utils import n_or_more_replicates
31
+ from pydeseq2.utils import nb_nll
32
+ from pydeseq2.utils import replace_underscores
33
+ from pydeseq2.utils import robust_method_of_moments_disp
34
+ from pydeseq2.utils import test_valid_counts
35
+ from pydeseq2.utils import trimmed_mean
36
+
37
+
38
+ class deconveil_fit:
39
+ r"""A class to implement dispersion and log fold-change (LFC) estimation.
40
+ Dispersions and LFCs are estimated following the DESeq2/PyDESeq2 pipeline.
41
+
42
+ Parameters
43
+ ----------
44
+ counts : pandas.DataFrame
45
+ Raw counts. One column per gene, rows are indexed by sample barcodes.
46
+
47
+ cnv : pandas.DataFrame
48
+ Discrete copy numbers. One column per gene, rows are indexed by sample barcodes.
49
+
51
+ metadata : pandas.DataFrame
52
+ DataFrame containing sample metadata.
53
+ Must be indexed by sample barcodes.
54
+
55
+ design_factors : str or list
56
+ Name of the columns of metadata to be used as design variables.
57
+ (default: ``'condition'``).
58
+
59
+ continuous_factors : list or None
60
+ An optional list of continuous (as opposed to categorical) factors. Any factor
61
+ not in ``continuous_factors`` will be considered categorical (default: ``None``).
62
+
63
+ ref_level : list or None
64
+ An optional list of two strings of the form ``["factor", "test_level"]``
65
+ specifying the factor of interest and the reference (control) level against which
66
+ we're testing, e.g. ``["condition", "A"]``. (default: ``None``).
67
+
68
+ fit_type: str
69
+ Either ``"parametric"`` or ``"mean"`` for the type of fitting of dispersions to
70
+ the mean intensity. ``"parametric"``: fit a dispersion-mean relation via a
71
+ robust gamma-family GLM. ``"mean"``: use the mean of gene-wise dispersion
72
+ estimates. Will set the fit type for the DEA and the vst transformation. If
73
+ needed, it can be set separately for each method. (default: ``"parametric"``).
74
+
75
+ min_mu : float
76
+ Threshold for mean estimates. (default: ``0.5``).
77
+
78
+ min_disp : float
79
+ Lower threshold for dispersion parameters. (default: ``1e-8``).
80
+
81
+ max_disp : float
82
+ Upper threshold for dispersion parameters.
83
+ Note: The threshold that is actually enforced is max(max_disp, len(counts)).
84
+ (default: ``10``).
85
+
86
+ refit_cooks : bool
87
+ Whether to refit cooks outliers. (default: ``True``).
88
+
89
+ min_replicates : int
90
+ Minimum number of replicates a condition should have
91
+ to allow refitting its samples. (default: ``7``).
92
+
93
+ beta_tol : float
94
+ Stopping criterion for IRWLS. (default: ``1e-8``).
95
+
96
+ .. math:: \vert dev_t - dev_{t+1}\vert / (\vert dev \vert + 0.1) < \beta_{tol}.
97
+
98
+ n_cpus : int
99
+ Number of cpus to use. If ``None`` and if ``inference`` is not provided, all
100
+ available cpus will be used by the ``DefaultInference``. If both are specified
101
+ (i.e., ``n_cpus`` and ``inference`` are not ``None``), it will try to override
102
+ the ``n_cpus`` attribute of the ``inference`` object. (default: ``None``).
103
+
104
+ inference : Inference
105
+ Implementation of inference routines object instance.
106
+ (default:
107
+ :class:`DefInference <deconveil.default_inference.DefInference>`).
108
+
109
+ Attributes
110
+ ----------
111
+ n_processes : int
112
+ Number of cpus to use for multiprocessing.
113
+
114
+ non_zero_idx : ndarray
115
+ Indices of genes that have non-uniformly zero counts.
116
+
117
+ non_zero_genes : pandas.Index
118
+ Index of genes that have non-uniformly zero counts.
119
+
120
+ logmeans: numpy.ndarray
121
+ Gene-wise mean log counts, computed in ``preprocessing.deseq2_norm_fit()``.
122
+
123
+ filtered_genes: numpy.ndarray
124
+ Genes whose log means are different from -∞, computed in
125
+ ``preprocessing.deseq2_norm_fit()``.
126
+
127
+ """
128
+
129
+ def __init__(
130
+ self,
131
+ *,
132
+ counts: Optional[pd.DataFrame] = None,
133
+ cnv: Optional[pd.DataFrame] = None,
134
+ metadata: Optional[pd.DataFrame] = None,
135
+ design_factors: Union[str, List[str]] = "condition",
136
+ continuous_factors: Optional[List[str]] = None,
137
+ ref_level: Optional[List[str]] = None,
138
+ fit_type: Literal["parametric", "mean"] = "parametric",
139
+ min_mu: float = 0.5,
140
+ min_disp: float = 1e-8,
141
+ max_disp: float = 10.0,
142
+ refit_cooks: bool = True,
143
+ min_replicates: int = 7,
144
+ beta_tol: float = 1e-8,
145
+ n_cpus: Optional[int] = None,
146
+ inference: Optional[Inference] = None,
147
+ quiet: bool = False,
148
+ ) -> None:
149
+
150
+ """
151
+ Initialize object
152
+ """
153
+ self.data = {}
154
+ self.data["counts"] = counts
155
+ self.data["counts"] = self.data["counts"].astype(int)
156
+ self.data["cnv"] = cnv
157
+
158
+ if self.data["counts"].shape[0] != self.data["cnv"].shape[0] and self.data["counts"].shape[1] != self.data["cnv"].shape[1]:
159
+ raise ValueError("Matrices must have the same dimensions for element-wise operations.")
160
+
161
+ # Test counts before going further
162
+ test_valid_counts(counts)
163
+
164
+ self.metadata = metadata
165
+ self.fit_type = fit_type
166
+
167
+ # Convert design_factors to list if a single string was provided.
168
+ self.design_factors = (
169
+ [design_factors] if isinstance(design_factors, str) else design_factors
170
+ )
171
+
172
+ self.continuous_factors = continuous_factors
173
+
174
+
175
+ # Build the design matrix
176
+ self.design_matrix = build_design_matrix(
177
+ metadata=self.metadata,
178
+ design_factors=self.design_factors,
179
+ continuous_factors=self.continuous_factors,
180
+ ref_level=ref_level,
181
+ expanded=False,
182
+ intercept=True,
183
+ )
184
+
185
+ self.obsm = {}
186
+ self.obsm["design_matrix"] = self.design_matrix
187
+ self.min_mu = min_mu
188
+ self.min_disp = min_disp
189
+ self.n_obs = self.data["counts"].shape[0]
190
+ self.n_vars = self.data["counts"].shape[1]
191
+ self.var_names = self.data["counts"].columns
+ self.obs_names = self.data["counts"].index  # used by vst_fit and _fit_iterate_size_factors
192
+ self.max_disp = np.maximum(max_disp, self.n_obs)
193
+ self.refit_cooks = refit_cooks
194
+ self.ref_level = ref_level
195
+ self.min_replicates = min_replicates
196
+ self.beta_tol = beta_tol
197
+ self.quiet = quiet
198
+ self.logmeans = None
199
+ self.filtered_genes = None
200
+ self.uns = {}
201
+ self.varm = {}
202
+ self.layers = {}
203
+
204
+
205
+ if inference:
206
+ if hasattr(inference, "n_cpus"):
207
+ if n_cpus:
208
+ inference.n_cpus = n_cpus
209
+ else:
210
+ warnings.warn(
211
+ "The provided inference object does not have an n_cpus "
212
+ "attribute, cannot override `n_cpus`.",
213
+ UserWarning,
214
+ stacklevel=2,
215
+ )
216
+ # Initialize the inference object.
217
+ self.inference = inference or DefInference(n_cpus=n_cpus)
218
+
219
+
220
+ def vst(
221
+ self,
222
+ use_design: bool = False,
223
+ fit_type: Optional[Literal["parametric", "mean"]] = None,
224
+ ) -> None:
225
+
226
+ """Fit a variance stabilizing transformation, and apply it to normalized counts.
227
+ Results are stored in ``vst_counts"``.
228
+ """
229
+
230
+ if fit_type is not None:
231
+ self.vst_fit_type = fit_type
232
+ else:
233
+ self.vst_fit_type = self.fit_type
234
+
235
+ print(f"Fit type used for VST : {self.vst_fit_type}")
236
+ self.vst_fit(use_design=use_design)
237
+ self.layers["vst_counts"] = self.vst_transform()
238
+
239
+
240
+ def vst_fit(
241
+ self,
242
+ use_design: bool = False,
243
+ ) -> None:
244
+ """Fit a variance stabilizing transformation.
245
+
246
+ This method should be called before ``vst_transform``.
247
+
248
+ Fitted parameters are stored in ``uns`` and ``varm`` for use by ``vst_transform``.
249
+
250
+ Parameters
251
+ ----------
252
+ use_design : bool
253
+ Whether to use the full design matrix to fit dispersions and the trend curve.
254
+ If False, only an intercept is used.
255
+ Only useful if ``fit_type="parametric"``.
256
+ (default: ``False``).
257
+ """
258
+ # Start by fitting median-of-ratio size factors if not already present,
259
+ # or if they were computed iteratively
260
+ if "size_factors" not in self.obsm or self.logmeans is None:
261
+ self.fit_size_factors() # by default, fit_type != "iterative"
262
+
263
+ if not hasattr(self, "vst_fit_type"):
264
+ self.vst_fit_type = self.fit_type
265
+
266
+ if use_design:
267
+ if self.vst_fit_type == "parametric":
268
+ self._fit_parametric_dispersion_trend(vst=True)
269
+ else:
270
+ warnings.warn(
271
+ "use_design=True is only useful when fit_type='parametric'. ",
272
+ UserWarning,
273
+ stacklevel=2,
274
+ )
275
+ self.fit_genewise_dispersions(vst=True)
276
+
277
+ else:
278
+ # Reduce the design matrix to an intercept and reconstruct at the end
279
+ self.obsm["design_matrix_buffer"] = self.obsm["design_matrix"].copy()
280
+ self.obsm["design_matrix"] = pd.DataFrame(
281
+ 1, index=self.obs_names, columns=[["intercept"]]
282
+ )
283
+ # Fit the trend curve with an intercept design
284
+ self.fit_genewise_dispersions(vst=True)
285
+ if self.vst_fit_type == "parametric":
286
+ self._fit_parametric_dispersion_trend(vst=True)
287
+
288
+ # Restore the design matrix and free buffer
289
+ self.obsm["design_matrix"] = self.obsm["design_matrix_buffer"].copy()
290
+ del self.obsm["design_matrix_buffer"]
291
+
292
+
293
+ def vst_transform(self, counts: Optional[np.ndarray] = None) -> np.ndarray:
294
+
295
+ """Apply the variance stabilizing transformation.
296
+ Uses the results from the ``vst_fit`` method.
297
+ Returns
298
+ -------
299
+ numpy.ndarray
300
+ Variance stabilized counts.
301
+ """
302
+
303
+ if "size_factors" not in self.obsm:
304
+ raise RuntimeError(
305
+ "The vst_fit method should be called prior to vst_transform."
306
+ )
307
+
308
+ if counts is None:
309
+ # the transformed counts will be the current ones
310
+ normed_counts = self.layers["normed_counts"]
311
+ else:
312
+ if self.logmeans is None:
313
+ # the size factors were still computed iteratively
314
+ warnings.warn(
315
+ "The size factors were fitted iteratively. They will "
316
+ "be re-computed with the counts to be transformed. In a train/test "
317
+ "setting with a downstream task, this would result in a leak of "
318
+ "data from test to train set.",
319
+ UserWarning,
320
+ stacklevel=2,
321
+ )
322
+ logmeans, filtered_genes = deseq2_norm_fit(counts)
323
+ else:
324
+ logmeans, filtered_genes = self.logmeans, self.filtered_genes
325
+
326
+ normed_counts, _ = deseq2_norm_transform(counts, logmeans, filtered_genes)
327
+
328
+ if self.vst_fit_type == "parametric":
329
+ if "vst_trend_coeffs" not in self.uns:
330
+ raise RuntimeError("Fit the dispersion curve prior to applying VST.")
331
+
332
+ a0, a1 = self.uns["vst_trend_coeffs"]
333
+ return np.log2(
334
+ (
335
+ 1
336
+ + a1
337
+ + 2 * a0 * normed_counts
338
+ + 2 * np.sqrt(a0 * normed_counts * (1 + a1 + a0 * normed_counts))
339
+ )
340
+ / (4 * a0)
341
+ )
342
+
343
+ elif self.vst_fit_type == "mean":
344
+ gene_dispersions = self.varm["vst_genewise_dispersions"]
345
+ use_for_mean = gene_dispersions > 10 * self.min_disp
346
+ mean_disp = trim_mean(gene_dispersions[use_for_mean], proportiontocut=0.001)
347
+ return (
348
+ 2 * np.arcsinh(np.sqrt(mean_disp * normed_counts))
349
+ - np.log(mean_disp)
350
+ - np.log(4)
351
+ ) / np.log(2)
352
+ else:
353
+ raise NotImplementedError(
354
+ f"Found fit_type '{self.vst_fit_type}'. Expected 'parametric' or 'mean'."
355
+ )
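For reference, the ``"parametric"`` branch above evaluates the DESeq2 closed-form variance-stabilizing transformation of a normalized count :math:`q`, given trend coefficients :math:`(a_0, a_1)`:

.. math:: \operatorname{vst}(q) = \log_2 \frac{1 + a_1 + 2 a_0 q + 2\sqrt{a_0 q (1 + a_1 + a_0 q)}}{4 a_0}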
356
+
357
+
358
+ def deseq2(self, fit_type: Optional[Literal["parametric", "mean"]] = None) -> None:
359
+
360
+ """Perform dispersion and log fold-change (LFC) estimation.
361
+
362
+ Wrapper for the first part of the PyDESeq2 pipeline.
363
+
364
+ Parameters
365
+ ----------
366
+ fit_type : str
367
+ Either None, ``"parametric"`` or ``"mean"`` for the type of fitting of
368
+ dispersions to the mean intensity.``"parametric"``: fit a dispersion-mean
369
+ relation via a robust gamma-family GLM. ``"mean"``: use the mean of
370
+ gene-wise dispersion estimates.
371
+
372
+ If None, the fit_type provided at class initialization is used.
373
+ (default: ``None``).
374
+ """
375
+
376
+ if fit_type is not None:
377
+ self.fit_type = fit_type
378
+ print(f"Using {self.fit_type} fit type.")
379
+ # Compute DESeq2 normalization factors using the Median-of-ratios method
380
+ self.fit_size_factors()
381
+ # Fit an independent negative binomial model per gene
382
+ self.fit_genewise_dispersions()
383
+ # Fit a parameterized trend curve for dispersions, of the form
384
+ # f(\mu) = a_1/\mu + a_0
385
+ self.fit_dispersion_trend()
386
+ # Compute prior dispersion variance
387
+ self.fit_dispersion_prior()
388
+ # Refit genewise dispersions a posteriori (shrinks estimates towards trend curve)
389
+ self.fit_MAP_dispersions()
390
+ # Fit log-fold changes (in natural log scale)
391
+ self.fit_LFC()
392
+ # Compute Cooks distances to find outliers
393
+ self.calculate_cooks()
394
+
395
+ if self.refit_cooks:
396
+ # Replace outlier counts, and refit dispersions and LFCs
397
+ # for genes that had outliers replaced
398
+ self.refit()
399
+
400
+ def fit_size_factors(
401
+ self,
402
+ fit_type: Literal["ratio", "poscounts", "iterative"] = "ratio",
403
+ control_genes: Optional[
404
+ Union[np.ndarray, List[str], List[int], pd.Index]
405
+ ] = None,
406
+ ) -> None:
407
+ """Fit sample-wise deseq2 normalization (size) factors.
408
+ Parameters
409
+ ----------
410
+ fit_type : str
411
+ The normalization method to use: "ratio", "poscounts" or "iterative".
412
+ (default: ``"ratio"``).
413
+ control_genes : ndarray, list, pandas.Index, or None
414
+ Genes to use as control genes for size factor fitting. If None, all genes
415
+ are used. (default: ``None``).
416
+ """
417
+
418
+ if not self.quiet:
419
+ print("Fitting size factors...", file=sys.stderr)
420
+
421
+ start = time.time()
422
+
423
+ # If control genes are provided, set a mask where those genes are True
424
+ if control_genes is not None:
425
+ _control_mask = np.zeros(self.data["counts"].shape[1], dtype=bool)
426
+
427
+ # Resolve bool/int/var_name inputs to positions directly: this class does
428
+ # not inherit AnnData's ``_normalize_indices`` helper
429
+ control_genes = np.asarray(control_genes)
430
+ if control_genes.dtype == bool:
+ _control_mask |= control_genes
+ elif control_genes.dtype.kind in "iu":
+ _control_mask[control_genes] = True
+ else:
431
+ _control_mask[self.var_names.get_indexer(control_genes)] = True
432
+
433
+ # Otherwise mask all genes to be True
434
+ else:
435
+ _control_mask = np.ones(self.data["counts"].shape[1], dtype=bool)
436
+
437
+ if fit_type == "iterative":
438
+ self._fit_iterate_size_factors()
439
+
440
+ elif fit_type == "poscounts":
441
+
442
+ # Calculate logcounts for x > 0 and take the mean for each gene
443
+ counts_arr = self.data["counts"].to_numpy()
444
+ log_counts = np.zeros_like(counts_arr, dtype=float)
+ np.log(counts_arr, out=log_counts, where=counts_arr != 0)
445
+ logmeans = log_counts.mean(0)
446
+
447
+ # Determine which genes are usable (finite logmeans)
448
+ self.filtered_genes = (~np.isinf(logmeans)) & (logmeans > 0)
449
+ _control_mask &= self.filtered_genes
450
+
451
+ # Calculate size factor per sample
452
+ def sizeFactor(x):
453
+ _mask = np.logical_and(_control_mask, x > 0)
454
+ return np.exp(np.median(np.log(x[_mask]) - logmeans[_mask]))
455
+
456
+ sf = np.apply_along_axis(sizeFactor, 1, self.data["counts"])
457
+ del log_counts
458
+
459
+ # Normalize size factors to a geometric mean of 1 to match DESeq
460
+ self.obsm["size_factors"] = sf / (np.exp(np.mean(np.log(sf))))
461
+ self.layers["normed_counts"] = self.data["counts"] / self.obsm["size_factors"][:, None]
462
+ self.logmeans = logmeans
463
+
464
+ # Test whether it is possible to use median-of-ratios.
465
+ elif (self.data["counts"] == 0).any().all():
466
+ # There is at least a zero for each gene
467
+ warnings.warn(
468
+ "Every gene contains at least one zero, "
469
+ "cannot compute log geometric means. Switching to iterative mode.",
470
+ UserWarning,
471
+ stacklevel=2,
472
+ )
473
+ self._fit_iterate_size_factors()
474
+
475
+ else:
476
+ self.logmeans, self.filtered_genes = deseq2_norm_fit(self.data["counts"])
477
+ _control_mask &= self.filtered_genes
478
+
479
+ (
480
+ self.layers["normed_counts"],
481
+ self.obsm["size_factors"],
482
+ ) = deseq2_norm_transform(self.data["counts"], self.logmeans, _control_mask)
483
+
484
+ end = time.time()
485
+ self.varm["_normed_means"] = self.layers["normed_counts"].mean(0)
486
+
487
+ if not self.quiet:
488
+ print(f"... done in {end - start:.2f} seconds.\n", file=sys.stderr)
489
+
490
+
491
+ def fit_genewise_dispersions(self, vst=False) -> None:
492
+
493
+ """Fit gene-wise dispersion estimates.
494
+
495
+ Fits a negative binomial per gene, independently.
496
+
497
+ Parameters
498
+ ----------
499
+ vst : bool
500
+ Whether the dispersion estimates are being fitted as part of the VST
501
+ pipeline. (default: ``False``).
502
+ """
503
+
504
+ # Check that size factors are available. If not, compute them.
505
+ if "size_factors" not in self.obsm:
506
+ self.fit_size_factors()
507
+
508
+ # Exclude genes with all zeroes
509
+ self.varm["non_zero"] = ~(self.data["counts"] == 0).all(axis=0)
510
+ self.non_zero_idx = np.arange(self.n_vars)[self.varm["non_zero"]]
511
+ self.non_zero_genes = self.var_names[self.varm["non_zero"]]
512
+
513
+ #if isinstance(self.non_zero_genes, pd.MultiIndex):
514
+ #raise ValueError("non_zero_genes should not be a MultiIndex")
515
+
516
+ # Fit "method of moments" dispersion estimates
517
+ self._fit_MoM_dispersions()
518
+
519
+ # Convert to numpy for speed
520
+ design_matrix = self.obsm["design_matrix"].values
521
+ counts=self.data["counts"].to_numpy()
522
+ cnv=self.data["cnv"].to_numpy()
523
+
524
+ # Fit initial mean estimates: linear regression if the design is
+ # equivalent to a one-hot encoding, otherwise a GLM (using rough
+ # dispersion estimates)
525
+ if (
526
+ len(self.obsm["design_matrix"].value_counts())
527
+ == self.obsm["design_matrix"].shape[-1]
528
+ ):
529
+ mu_hat_ = self.inference.lin_reg_mu(
530
+ counts=counts[:, self.non_zero_idx],
531
+ size_factors=self.obsm["size_factors"],
532
+ design_matrix=design_matrix,
533
+ min_mu=self.min_mu,
534
+ )
535
+ else:
536
+ _, mu_hat_, _, _ = self.inference.irls_glm(
537
+ counts=counts[:, self.non_zero_idx],
538
+ cnv=cnv[:, self.non_zero_idx],
539
+ size_factors=self.obsm["size_factors"],
540
+ design_matrix=design_matrix,
541
+ disp=self.varm["_MoM_dispersions"][self.non_zero_idx],
542
+ min_mu=self.min_mu,
543
+ beta_tol=self.beta_tol,
544
+ )
545
+ mu_param_name = "_vst_mu_hat" if vst else "_mu_hat"
546
+ disp_param_name = "genewise_dispersions"
547
+
548
+ self.layers[mu_param_name] = np.full((self.n_obs, self.n_vars), np.nan)
549
+ self.layers[mu_param_name][:, self.varm["non_zero"]] = mu_hat_
550
+
551
+ if not self.quiet:
552
+ print("Fitting dispersions...", file=sys.stderr)
553
+ start = time.time()
554
+ dispersions_, l_bfgs_b_converged_ = self.inference.alpha_mle(
555
+ counts=counts[:, self.non_zero_idx],
556
+ design_matrix=design_matrix,
557
+ mu=self.layers[mu_param_name][:, self.non_zero_idx],
558
+ alpha_hat=self.varm["_MoM_dispersions"][self.non_zero_idx],
559
+ min_disp=self.min_disp,
560
+ max_disp=self.max_disp,
561
+ )
562
+ end = time.time()
563
+
564
+ if not self.quiet:
565
+ print(f"... done in {end - start:.2f} seconds.\n", file=sys.stderr)
566
+
567
+ self.varm[disp_param_name] = np.full(self.n_vars, np.nan)
568
+ self.varm[disp_param_name][self.varm["non_zero"]] = np.clip(
569
+ dispersions_, self.min_disp, self.max_disp
570
+ )
571
+
572
+ self.varm["_genewise_converged"] = np.full(self.n_vars, np.nan)
573
+ self.varm["_genewise_converged"][self.varm["non_zero"]] = l_bfgs_b_converged_
574
+
575
+
576
+ def fit_dispersion_trend(self, vst: bool = False) -> None:
577
+
578
+ """Fit the dispersion trend curve.
579
+
580
+ Parameters
581
+ ----------
582
+ vst : bool
583
+ Whether the dispersion trend curve is being fitted as part of the VST
584
+ pipeline. (default: ``False``).
585
+ """
586
+ #disp_param_name = "vst_genewise_dispersions" if vst else "genewise_dispersions"
587
+ disp_param_name = "genewise_dispersions"
588
+ fit_type = self.vst_fit_type if vst else self.fit_type
589
+
590
+ # Check that genewise dispersions are available. If not, compute them.
591
+ if disp_param_name not in self.varm:
592
+ self.fit_genewise_dispersions(vst)
593
+
594
+ if not self.quiet:
595
+ print("Fitting dispersion trend curve...", file=sys.stderr)
596
+ start = time.time()
597
+
598
+ if fit_type == "parametric":
599
+ self._fit_parametric_dispersion_trend(vst)
600
+ elif fit_type == "mean":
601
+ self._fit_mean_dispersion_trend(vst)
602
+ else:
603
+ raise NotImplementedError(
604
+ f"Expected 'parametric' or 'mean' trend curve fit "
605
+ f"types, received {fit_type}"
606
+ )
607
+ end = time.time()
608
+
609
+ if not self.quiet:
610
+ print(f"... done in {end - start:.2f} seconds.\n", file=sys.stderr)
611
+
612
+ def disp_function(self, x):
613
+ """Return the dispersion trend function at x."""
614
+ if self.uns["disp_function_type"] == "parametric":
615
+ return dispersion_trend(x, self.uns["trend_coeffs"])
616
+ elif self.uns["disp_function_type"] == "mean":
617
+ return np.full_like(x, self.uns["mean_disp"])
618
+
619
+
620
+ def fit_dispersion_prior(self) -> None:
621
+ """Fit dispersion variance priors and standard deviation of log-residuals.
622
+
623
+ The computation is based on genes whose dispersions are above 100 * min_disp.
624
+
625
+ Note: when the design matrix has fewer than 3 degrees of freedom, the
626
+ estimate of log dispersions is likely to be imprecise.
627
+ """
628
+
629
+ # Check that the dispersion trend curve was fitted. If not, fit it.
630
+ if "fitted_dispersions" not in self.varm:
631
+ self.fit_dispersion_trend()
632
+
633
+ # Exclude genes with all zeroes
634
+ num_samples = self.n_obs
635
+ num_vars = self.obsm["design_matrix"].shape[-1]
636
+
637
+ # Check the degrees of freedom
638
+ if (num_samples - num_vars) <= 3:
639
+ warnings.warn(
640
+ "As the residual degrees of freedom is less than 3, the distribution "
641
+ "of log dispersions is especially asymmetric and likely to be poorly "
642
+ "estimated by the MAD.",
643
+ UserWarning,
644
+ stacklevel=2,
645
+ )
646
+
647
+ # Fit dispersions to the curve, and compute log residuals
648
+ # Positions of non-zero genes within the full gene axis (enumerating the
649
+ # non-zero subset would wrongly select the first k genes of varm arrays)
650
+ gene_index = self.non_zero_idx
651
+
652
+ disp_residuals = np.log(
653
+ self.varm["genewise_dispersions"][gene_index]
654
+ ) - np.log(self.varm["fitted_dispersions"][gene_index])
655
+
656
+ # Compute squared log-residuals and prior variance based on genes whose
657
+ # dispersions are above 100 * min_disp. This is to reproduce DESeq2's behaviour.
658
+ above_min_disp = self.varm["genewise_dispersions"][gene_index] >= (
659
+ 100 * self.min_disp
660
+ )
661
+
662
+ self.uns["_squared_logres"] = (
663
+ mean_absolute_deviation(disp_residuals[above_min_disp]) ** 2
664
+ )
665
+
666
+ self.uns["prior_disp_var"] = np.maximum(
667
+ self.uns["_squared_logres"] - polygamma(1, (num_samples - num_vars) / 2),
668
+ 0.25,
669
+ )
670
+
671
+ def fit_MAP_dispersions(self) -> None:
672
+ """Fit Maximum a Posteriori dispersion estimates.
673
+
674
+ After MAP dispersions are fit, filter genes for which we don't apply shrinkage.
675
+ """
676
+
677
+ # Check that the dispersion prior variance is available. If not, compute it.
678
+ if "prior_disp_var" not in self.uns:
679
+ self.fit_dispersion_prior()
680
+
681
+ # Convert to numpy for speed
682
+ design_matrix = self.obsm["design_matrix"].values
683
+ counts = self.data["counts"].to_numpy()
684
+
685
+ if not self.quiet:
686
+ print("Fitting MAP dispersions...", file=sys.stderr)
687
+ start = time.time()
688
+ dispersions_, l_bfgs_b_converged_ = self.inference.alpha_mle(
689
+ counts=counts[:, self.non_zero_idx],
690
+ design_matrix=design_matrix,
691
+ mu=self.layers["_mu_hat"][:, self.non_zero_idx],
692
+ alpha_hat=self.varm["fitted_dispersions"][self.non_zero_idx],
693
+ min_disp=self.min_disp,
694
+ max_disp=self.max_disp,
695
+ prior_disp_var=self.uns["prior_disp_var"].item(),
696
+ cr_reg=True,
697
+ prior_reg=True,
698
+ )
699
+ end = time.time()
700
+
701
+ if not self.quiet:
702
+ print(f"... done in {end-start:.2f} seconds.\n", file=sys.stderr)
703
+
704
+ self.varm["MAP_dispersions"] = np.full(self.n_vars, np.nan)
705
+ self.varm["MAP_dispersions"][self.varm["non_zero"]] = np.clip(
706
+ dispersions_, self.min_disp, self.max_disp
707
+ )
708
+
709
+ self.varm["_MAP_converged"] = np.full(self.n_vars, np.nan)
710
+ self.varm["_MAP_converged"][self.varm["non_zero"]] = l_bfgs_b_converged_
711
+
712
+ # Filter outlier genes for which we won't apply shrinkage
713
+ self.varm["dispersions"] = self.varm["MAP_dispersions"].copy()
714
+ self.varm["_outlier_genes"] = np.log(self.varm["genewise_dispersions"]) > np.log(
715
+ self.varm["fitted_dispersions"]
716
+ ) + 2 * np.sqrt(self.uns["_squared_logres"])
717
+
718
+ self.varm["dispersions"][self.varm["_outlier_genes"]] = self.varm["genewise_dispersions"][
719
+ self.varm["_outlier_genes"]
720
+ ]
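Genes are flagged as dispersion outliers, and kept at their gene-wise estimate without shrinkage, when

.. math:: \log \alpha^{gw}_j > \log \alpha^{tr}_j + 2 \sqrt{s^2_{lr}}

where :math:`s^2_{lr}` is the squared MAD of log-residuals computed in ``fit_dispersion_prior``.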
721
+
722
+
723
+ def fit_LFC(self) -> None:
724
+ """Fit log fold change (LFC) coefficients.
725
+
726
+ In the 2-level setting, the intercept corresponds to the base mean,
727
+ while the second is the actual LFC coefficient, in natural log scale.
728
+ """
729
+
730
+ # Check that MAP dispersions are available. If not, compute them.
731
+ if "dispersions" not in self.varm:
732
+ self.fit_MAP_dispersions()
733
+
734
+ # Convert to numpy for speed
735
+ design_matrix = self.obsm["design_matrix"].values
736
+ counts = self.data["counts"].to_numpy()
737
+ cnv = self.data["cnv"].to_numpy()
738
+ # Rescale copy numbers so the diploid state (CN = 2) maps to 1...
+ cnv = cnv / 2
739
+ # ...and add a small pseudo-count so CN = 0 entries keep a non-zero multiplier
+ cnv = cnv + 0.1
740
+
741
+ if not self.quiet:
742
+ print("Fitting LFCs...", file=sys.stderr)
743
+ start = time.time()
744
+ mle_lfcs_, mu_, hat_diagonals_, converged_ = self.inference.irls_glm(
745
+ counts=counts[:, self.non_zero_idx],
746
+ cnv=cnv[:, self.non_zero_idx],
747
+ size_factors=self.obsm["size_factors"],
748
+ design_matrix=design_matrix,
749
+ disp=self.varm["dispersions"][self.non_zero_idx],
750
+ min_mu=self.min_mu,
751
+ beta_tol=self.beta_tol,
752
+ )
753
+ end = time.time()
754
+
755
+ if not self.quiet:
756
+ print(f"... done in {end-start:.2f} seconds.\n", file=sys.stderr)
757
+
758
+ self.varm["LFC"] = pd.DataFrame(
759
+ np.nan,
760
+ index=self.var_names,
761
+ columns=self.obsm["design_matrix"].columns,
762
+ )
763
+
764
+ self.varm["LFC"].update(
765
+ pd.DataFrame(
766
+ mle_lfcs_,
767
+ index=self.non_zero_genes,
768
+ columns=self.obsm["design_matrix"].columns,
769
+ )
770
+ )
771
+
772
+ self.obsm["_mu_LFC"] = mu_
773
+ self.obsm["_hat_diagonals"] = hat_diagonals_
774
+
775
+ self.varm["_LFC_converged"] = np.full(self.n_vars, np.nan)
776
+ self.varm["_LFC_converged"][self.varm["non_zero"]] = converged_
777
+
778
+
779
+ def calculate_cooks(self) -> None:
780
+ """Compute Cook's distance for outlier detection.
781
+
782
+ Measures the contribution of a single entry to the output of LFC estimation.
783
+ """
784
+ # Check that MAP dispersions are available. If not, compute them.
785
+ if "dispersions" not in self.varm:
786
+ self.fit_MAP_dispersions()
787
+
788
+ if not self.quiet:
789
+ print("Calculating cook's distance...", file=sys.stderr)
790
+
791
+ start = time.time()
792
+ num_vars = self.obsm["design_matrix"].shape[-1]
793
+
795
+ non_zero_mask = self.varm["non_zero"].to_numpy()
796
+ self.layers["normed_counts"] = self.layers["normed_counts"].to_numpy()
797
+
798
+ # Calculate dispersion
799
+ dispersions = robust_method_of_moments_disp(
800
+ self.layers["normed_counts"][:, non_zero_mask],
801
+ self.obsm["design_matrix"],
802
+ )
803
+
804
+ self.data["counts"] = self.data["counts"].to_numpy()
805
+
806
+ # Calculate the squared pearson residuals for non-zero features
807
+ squared_pearson_res = self.data["counts"][:, non_zero_mask] - self.obsm["_mu_LFC"]
808
+ squared_pearson_res **= 2
809
+
810
+ # Calculate the overdispersion parameter tau
811
+ V = self.obsm["_mu_LFC"] ** 2
812
+ V *= dispersions[None, :]
813
+ V += self.obsm["_mu_LFC"]
814
+
815
+ # Calculate r^2 / (tau * num_vars)
816
+ squared_pearson_res /= V
817
+ squared_pearson_res /= num_vars
818
+
819
+ del V
820
+
821
+ # Calculate leverage modifier H / (1 - H)^2
822
+ diag_mul = 1 - self.obsm["_hat_diagonals"]
823
+ diag_mul **= 2
824
+ diag_mul = self.obsm["_hat_diagonals"] / diag_mul
825
+
826
+ # Multiply r^2 / (tau * num_vars) by H / (1 - H)^2 to get cook's distance
827
+ squared_pearson_res *= diag_mul
828
+
829
+ del diag_mul
830
+
831
+ self.layers["cooks"] = np.full((self.n_obs, self.n_vars), np.nan)
832
+ self.layers["cooks"][:, non_zero_mask] = squared_pearson_res
833
+
834
+ if not self.quiet:
835
+ print(f"... done in {time.time()-start:.2f} seconds.\n", file=sys.stderr)
836
+
837
+ def refit(self) -> None:
838
+ """Refit Cook outliers.
839
+
840
+ Replace values that are filtered out based on the Cooks distance with imputed
841
+ values, and then re-run the whole DESeq2 pipeline on replaced values.
842
+ """
843
+ # Replace outlier counts
844
+ self._replace_outliers()
845
+ if not self.quiet:
846
+ print(
847
+ f"Replacing {sum(self.varm['replaced']) } outlier genes.\n",
848
+ file=sys.stderr,
849
+ )
850
+
851
+ if sum(self.varm["replaced"]) > 0:
852
+ # Refit dispersions and LFCs for genes that had outliers replaced
853
+ self._refit_without_outliers()
854
+ else:
855
+ # Store the fact that no sample was refitted
856
+ self.varm["refitted"] = np.full(
857
+ self.n_vars,
858
+ False,
859
+ )
860
+
861
+
862
+ def _fit_MoM_dispersions(self) -> None:
863
+
864
+ """Rough method of moments initial dispersions fit.
865
+
866
+ Estimates are the minimum of "robust" and "method of moments" estimates.
867
+ """
868
+ # Check that size_factors are available. If not, compute them.
869
+ if "normed_counts" not in self.layers:
870
+ self.fit_size_factors()
871
+
872
+ normed_counts = self.layers["normed_counts"]
873
+ rde = self.inference.fit_rough_dispersions(
874
+ normed_counts,
875
+ self.obsm["design_matrix"].values,
876
+ )
877
+ mde = self.inference.fit_moments_dispersions2(
878
+ normed_counts,
879
+ self.obsm["size_factors"]
880
+ )
881
+ alpha_hat = np.minimum(rde, mde)
882
+
883
+ self.varm["_MoM_dispersions"] = np.full(self.n_vars, np.nan)
884
+ self.varm["_MoM_dispersions"][self.varm["non_zero"]] = np.clip(
885
+ alpha_hat, self.min_disp, self.max_disp
886
+ )
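A minimal sketch of a method-of-moments dispersion estimate, solving the negative binomial relation :math:`\mathrm{Var}(y) = \mu + \alpha \mu^2` for :math:`\alpha` (the exact variance correction inside ``fit_moments_dispersions2`` may differ):

    import numpy as np

    def moments_dispersion(normed_counts: np.ndarray) -> np.ndarray:
        # Solve Var(y) = mu + alpha * mu**2 for alpha, gene by gene
        mu = normed_counts.mean(axis=0)
        var = normed_counts.var(axis=0, ddof=1)
        return (var - mu) / mu**2  # may be negative; callers clip to min_disp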
887
+
888
+
889
+ def _fit_parametric_dispersion_trend(self, vst: bool = False):
890
+
891
+ r"""Fit the dispersion curve according to a parametric model.
892
+
893
+ :math:`f(\mu) = a_1/\mu + a_0`.
894
+
895
+ Parameters
896
+ ----------
897
+ vst : bool
898
+ Whether the dispersion trend curve is being fitted as part of the VST
899
+ pipeline. (default: ``False``).
900
+ """
901
+ #disp_param_name = "vst_genewise_dispersions" if vst else "genewise_dispersions"
902
+ disp_param_name = "genewise_dispersions"
903
+
904
+ if disp_param_name not in self.varm:
905
+ self.fit_genewise_dispersions(vst)
906
+
908
+
909
+ # Exclude all-zero counts
910
+ targets = pd.Series(
911
+ self.varm[disp_param_name][self.non_zero_idx].copy(),
912
+ index=self.non_zero_genes,
913
+ )
914
+ covariates = pd.Series(
915
+ 1 / self.varm["_normed_means"][self.non_zero_idx],
916
+ index=self.non_zero_genes,
917
+ )
918
+
919
+ for gene in self.non_zero_genes:
920
+ if (
921
+ np.isinf(covariates.loc[gene]).any()
922
+ or np.isnan(covariates.loc[gene]).any()
923
+ ):
924
+ targets.drop(labels=[gene], inplace=True)
925
+ covariates.drop(labels=[gene], inplace=True)
926
+
927
+ # Initialize coefficients
928
+ old_coeffs = pd.Series([0.1, 0.1])
929
+ coeffs = pd.Series([1.0, 1.0])
930
+ while (coeffs > 1e-10).all() and (
931
+ np.log(np.abs(coeffs / old_coeffs)) ** 2
932
+ ).sum() >= 1e-6:
933
+ old_coeffs = coeffs
934
+ coeffs, predictions, converged = self.inference.dispersion_trend_gamma_glm(
935
+ covariates, targets
936
+ )
937
+
938
+ if not converged or (coeffs <= 1e-10).any():
939
+ warnings.warn(
940
+ "The dispersion trend curve fitting did not converge. "
941
+ "Switching to a mean-based dispersion trend.",
942
+ UserWarning,
943
+ stacklevel=2,
944
+ )
945
+
946
+ self._fit_mean_dispersion_trend(vst)
947
+ return
948
+
949
+ # Filter out genes that are too far away from the curve before refitting
950
+ # Positions of the remaining genes within the full gene axis (varm arrays
951
+ # span all genes, not just the non-zero subset)
952
+ covariates_index = self.var_names.get_indexer(covariates.index)
953
+
954
+
955
+ pred_ratios = self.varm[disp_param_name][covariates_index] / predictions
956
+
957
+ targets.drop(
958
+ targets[(pred_ratios < 1e-4) | (pred_ratios >= 15)].index,
959
+ inplace=True,
960
+ )
961
+ covariates.drop(
962
+ covariates[(pred_ratios < 1e-4) | (pred_ratios >= 15)].index,
963
+ inplace=True,
964
+ )
965
+
966
+ if vst:
967
+ self.uns["vst_trend_coeffs"] = pd.Series(coeffs, index=["a0", "a1"])
968
+ else:
969
+ self.uns["trend_coeffs"] = pd.Series(coeffs, index=["a0", "a1"])
970
+
971
+ self.varm["fitted_dispersions"] = np.full(self.n_vars, np.nan)
972
+ self.uns["disp_function_type"] = "parametric"
973
+ self.varm["fitted_dispersions"][self.varm["non_zero"]] = self.disp_function(
974
+ self.varm["_normed_means"][self.varm["non_zero"]]
975
+ )
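With converged coefficients :math:`(a_0, a_1)` stored in ``uns``, the trend evaluated by ``disp_function`` above is

.. math:: \hat{\alpha}_{tr}(\bar{\mu}_j) = a_0 + \frac{a_1}{\bar{\mu}_j}

matching ``pydeseq2.utils.dispersion_trend``.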
976
+
977
+ def _fit_mean_dispersion_trend(self, vst: bool = False):
978
+ """Use the mean of dispersions as trend curve.
979
+
980
+ Parameters
981
+ ----------
982
+ vst : bool
983
+ Whether the dispersion trend curve is being fitted as part of the VST
984
+ pipeline. (default: ``False``).
985
+ """
986
+ #disp_param_name = "vst_genewise_dispersions" if vst else "genewise_dispersions"
987
+ disp_param_name = "genewise_dispersions"
988
+
989
+ self.uns["mean_disp"] = trim_mean(
990
+ self.varm[disp_param_name][self.varm[disp_param_name] > 10 * self.min_disp],
991
+ proportiontocut=0.001,
992
+ )
993
+
994
+ self.uns["disp_function_type"] = "mean"
995
+ self.varm["fitted_dispersions"] = np.full(self.n_vars, self.uns["mean_disp"])
996
+
997
+
998
+ def _replace_outliers(self) -> None:
999
+ """Replace values that are filtered out (based on Cooks) with imputed values."""
1000
+ # Check that cooks distances are available. If not, compute them.
1001
+ if "cooks" not in self.layers:
1002
+ self.calculate_cooks()
1003
+
1004
+ num_samples = self.n_obs
1005
+ num_vars = self.obsm["design_matrix"].shape[1]
1006
+
1007
+ # Check whether cohorts have enough samples to allow refitting
1008
+ self.obsm["replaceable"] = n_or_more_replicates(
1009
+ self.obsm["design_matrix"], self.min_replicates
1010
+ ).values
1011
+
1012
+ if self.obsm["replaceable"].sum() == 0:
1013
+ # No sample can be replaced. Set self.replaced to False and exit.
1014
+ self.varm["replaced"] = np.full(
1015
+ self.n_vars,
1016
+ False,
1017
+ )
1018
+ return
1019
+
1020
+ # Get positions of counts with cooks above threshold
1021
+ cooks_cutoff = f.ppf(0.99, num_vars, num_samples - num_vars)
1022
+ idx = self.layers["cooks"] > cooks_cutoff
1023
+ self.varm["replaced"] = idx.any(axis=0)
1024
+
1025
+ if sum(self.varm["replaced"] > 0):
1026
+ # Compute replacement counts: trimmed means * size_factors
1027
+
1028
+ self.counts_to_refit = pd.DataFrame(
1029
+ self.data["counts"][:, self.varm["replaced"]],
1030
+ columns=self.var_names[self.varm["replaced"]]
1031
+ )
1032
+
1033
+ trim_base_mean = pd.DataFrame(
1034
+ cast(
1035
+ np.ndarray,
1036
+ trimmed_mean(
1037
+ self.counts_to_refit / self.obsm["size_factors"][:, None],
1038
+ trim=0.2,
1039
+ axis=0,
1040
+ ),
1041
+ ),
1042
+ index=self.counts_to_refit.columns # Use .columns instead of .var_names
1043
+ )
1044
+
1045
+ replacement_counts = (
1046
+ pd.DataFrame(
1047
+ (trim_base_mean.values * self.obsm["size_factors"]).T, # Ensure this is 2D
1048
+ index=self.counts_to_refit.index,
1049
+ columns=self.counts_to_refit.columns
1050
+ )
1051
+ .astype(int)
1052
+ )
1053
+
1054
+ replace_mask = self.obsm["replaceable"][:, None] & idx[:, self.varm["replaced"]]
1055
+
1058
+
1059
+ # Ensure that replacement_counts and the indexing result are compatible
1060
+ replacement_counts_trimmed = replacement_counts.loc[replace_mask.any(axis=1)]
1062
+
1063
+ # Apply the mask to refit only the matching rows, ensuring consistent dimensions
1064
+ assert replacement_counts_trimmed.shape == self.counts_to_refit.loc[replace_mask.any(axis=1)].shape, \
1065
+ f"Shape mismatch: replacement_counts_trimmed {replacement_counts_trimmed.shape} vs refit slice {self.counts_to_refit.loc[replace_mask.any(axis=1)].shape}"
1066
+
1067
+ # Replace counts in self.counts_to_refit for the rows matching the mask
1068
+ self.counts_to_refit.loc[replace_mask.any(axis=1)] = replacement_counts_trimmed.values
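A small sketch of the replacement rule used above: a 20% trimmed mean of normalized counts, rescaled by each sample's size factor (``pydeseq2.utils.trimmed_mean`` is assumed here to behave like ``scipy.stats.trim_mean``):

    import numpy as np
    from scipy.stats import trim_mean

    normed = np.array([5.0, 6.0, 5.5, 5.2, 400.0])  # one gene, one outlier sample
    base = trim_mean(normed, proportiontocut=0.2)  # robust to the outlier entry
    size_factors = np.array([0.8, 0.9, 1.0, 1.1, 1.2])
    replacement = (base * size_factors).astype(int)  # imputed counts per sample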
1069
+
1070
+
1071
+ def _refit_without_outliers(
1072
+ self,
1073
+ ) -> None:
1074
+ """Re-run the whole DESeq2 pipeline with replaced outliers."""
1075
+ assert (
1076
+ self.refit_cooks
1077
+ ), "Trying to refit Cooks outliers but the 'refit_cooks' flag is set to False"
1078
+
1079
+ # Check that _replace_outliers() was previously run.
1080
+ if "replaced" not in self.varm:
1081
+ self._replace_outliers()
1082
+
1083
+ # Only refit genes for which replacing outliers hasn't resulted in all zeroes
1084
+ new_all_zeroes = (self.counts_to_refit == 0).all(axis=0)
1085
+ self.new_all_zeroes_genes = self.counts_to_refit.columns[new_all_zeroes]
1086
+
1087
+ self.varm["refitted"] = self.varm["replaced"].copy()
1088
+ # Only replace if genes are not all zeroes after outlier replacement
1089
+ self.varm["refitted"][self.varm["refitted"]] = ~new_all_zeroes
1090
+
1091
+ # Take into account new all-zero genes
1092
+ if new_all_zeroes.sum() > 0:
1093
+ self.varm["_normed_means"][
1094
+ self.var_names.get_indexer(self.new_all_zeroes_genes)
1095
+ ] = 0
1096
+ self.varm["LFC"].loc[self.new_all_zeroes_genes, :] = 0
1097
+
1098
+ if self.varm["refitted"].sum() == 0: # if no gene can be refitted, we can skip
1099
+ return
1100
+
1101
+ self.counts_to_refit = self.counts_to_refit.loc[:, ~new_all_zeroes].copy()
1102
+ if isinstance(self.new_all_zeroes_genes, pd.MultiIndex):
1103
+ raise ValueError
1104
+
1105
+ sub_dds = deconveil_fit(
1106
+ counts=pd.DataFrame(
1107
+ self.counts_to_refit,
1108
+ index=self.counts_to_refit.index,
1109
+ columns=self.counts_to_refit.columns,
1110
+ ),
1111
+ cnv=self.data["cnv"],
1112
+ metadata=self.metadata,
1113
+ design_factors=self.design_factors,
1114
+ continuous_factors=self.continuous_factors,
1115
+ ref_level=self.ref_level,
1116
+ min_mu=self.min_mu,
1117
+ min_disp=self.min_disp,
1118
+ max_disp=self.max_disp,
1119
+ refit_cooks=self.refit_cooks,
1120
+ min_replicates=self.min_replicates,
1121
+ beta_tol=self.beta_tol,
1122
+ inference=self.inference,
1123
+ )
1124
+
1125
+ # Use the same size factors
1126
+ sub_dds.obsm["size_factors"] = self.obsm["size_factors"]
1127
+ sub_dds.layers["normed_counts"] = (
1128
+ sub_dds.data["counts"] / sub_dds.obsm["size_factors"][:, None]
1129
+ )
1130
+
1131
+ # Estimate gene-wise dispersions.
1132
+ sub_dds.fit_genewise_dispersions()
1133
+
1134
+ # Compute trend dispersions.
1135
+ # Note: the trend curve is not refitted.
1136
+ sub_dds.uns["disp_function_type"] = self.uns["disp_function_type"]
1137
+ if sub_dds.uns["disp_function_type"] == "parametric":
1138
+ sub_dds.uns["trend_coeffs"] = self.uns["trend_coeffs"]
1139
+ elif sub_dds.uns["disp_function_type"] == "mean":
1140
+ sub_dds.uns["mean_disp"] = self.uns["mean_disp"]
1141
+ sub_dds.varm["_normed_means"] = sub_dds.layers["normed_counts"].mean(0)
1142
+ # Reshape in case there's a single gene to refit
1143
+ sub_dds.varm["fitted_dispersions"] = sub_dds.disp_function(
1144
+ sub_dds.varm["_normed_means"]
1145
+ )
1146
+
1147
+ # Estimate MAP dispersions.
1148
+ # Note: the prior variance is not recomputed.
1149
+ sub_dds.uns["_squared_logres"] = self.uns["_squared_logres"]
1150
+ sub_dds.uns["prior_disp_var"] = self.uns["prior_disp_var"]
1151
+
1152
+ sub_dds.fit_MAP_dispersions()
1153
+
1154
+ # Estimate log-fold changes (in natural log scale)
1155
+ sub_dds.fit_LFC()
1156
+
1157
+ # Replace values in main object
1158
+ self.varm["_normed_means"][self.varm["refitted"]] = sub_dds.varm["_normed_means"]
1159
+ self.varm["LFC"][self.varm["refitted"]] = sub_dds.varm["LFC"]
1160
+ self.varm["genewise_dispersions"][self.varm["refitted"]] = sub_dds.varm[
1161
+ "genewise_dispersions"
1162
+ ]
1163
+ self.varm["fitted_dispersions"][self.varm["refitted"]] = sub_dds.varm[
1164
+ "fitted_dispersions"
1165
+ ]
1166
+ self.varm["dispersions"][self.varm["refitted"]] = sub_dds.varm["dispersions"]
1167
+
1168
+ replace_cooks = pd.DataFrame(self.layers["cooks"].copy())
1169
+ replace_cooks.loc[self.obsm["replaceable"], self.varm["refitted"]] = 0.0
1170
+
1171
+ self.layers["replace_cooks"] = replace_cooks
1172
+
1173
+
1174
+ def _fit_iterate_size_factors(self, niter: int = 10, quant: float = 0.95) -> None:
1175
+ """
1176
+ Fit size factors using the ``iterative`` method.
1177
+
1178
+ Used when each gene has at least one zero.
1179
+
1180
+ Parameters
1181
+ ----------
1182
+ niter : int
1183
+ Maximum number of iterations to perform (default: ``10``).
1184
+
1185
+ quant : float
1186
+ Quantile value at which negative likelihood is cut in the optimization
1187
+ (default: ``0.95``).
1188
+
1189
+ """
1190
+ # Initialize size factors and normed counts fields
1191
+ self.obsm["size_factors"] = np.ones(self.n_obs)
1192
+ self.layers["normed_counts"] = self.data["counts"]
1193
+
1194
+ # Reduce the design matrix to an intercept and reconstruct at the end
1195
+ self.obsm["design_matrix_buffer"] = self.obsm["design_matrix"].copy()
1196
+ self.obsm["design_matrix"] = pd.DataFrame(
1197
+ 1, index=self.obs_names, columns=["intercept"]
1198
+ )
1199
+
1200
+ # Fit size factors using MLE
1201
+ def objective(p):
1202
+ sf = np.exp(p - np.mean(p))
1203
+ nll = nb_nll(
1204
+ # deconveil_fit does not support AnnData-style slicing; subset the
+ # stored arrays with the non-zero gene positions instead
+ counts=np.asarray(self.data["counts"])[:, self.non_zero_idx],
1205
+ mu=self.layers["_mu_hat"][:, self.non_zero_idx]
1206
+ / self.obsm["size_factors"][:, None]
1207
+ * sf[:, None],
1208
+ alpha=self.varm["dispersions"][self.non_zero_idx],
1209
+ )
1210
+ # Take out the lowest likelihoods (highest neg) from the sum
1211
+ return np.sum(nll[nll < np.quantile(nll, quant)])
1212
+
1213
+ for i in range(niter):
1214
+ # Estimate dispersions based on current size factors
1215
+ self.fit_genewise_dispersions()
1216
+
1217
+ # Use a mean trend curve
1218
+ # Boolean mask over genes (deconveil_fit does not support AnnData-style
1219
+ # slicing by gene labels)
1220
+ use_for_mean = (
+ (self.varm["genewise_dispersions"] > 10 * self.min_disp)
1221
+ & np.asarray(self.varm["non_zero"])
+ )
1222
+
1223
+ if not np.any(use_for_mean):
1224
+ print(
1225
+ "No genes have a dispersion above 10 * min_disp in "
1226
+ "_fit_iterate_size_factors."
1227
+ )
1228
+ break
1229
+
1230
+ mean_disp = trim_mean(
1231
+ self.varm["genewise_dispersions"][use_for_mean],
1232
+ proportiontocut=0.001,
1233
+ )
1234
+
1235
+ self.varm["fitted_dispersions"] = np.ones(self.n_vars) * mean_disp
1236
+ self.fit_dispersion_prior()
1237
+ self.fit_MAP_dispersions()
1238
+ old_sf = self.obsm["size_factors"].copy()
1239
+
1240
+ # Fit size factors using MLE
1241
+ res = minimize(objective, np.log(old_sf), method="Powell")
1242
+
1243
+ self.obsm["size_factors"] = np.exp(res.x - np.mean(res.x))
1244
+
1245
+ if not res.success:
1246
+ print("A size factor fitting iteration failed.", file=sys.stderr)
1247
+ break
1248
+
1249
+ if (i > 1) and np.sum(
1250
+ (np.log(old_sf) - np.log(self.obsm["size_factors"])) ** 2
1251
+ ) < 1e-4:
1252
+ break
1253
+ elif i == niter - 1:
1254
+ print("Iterative size factor fitting did not converge.", file=sys.stderr)
1255
+
1256
+ # Restore the design matrix and free buffer
1257
+ self.obsm["design_matrix"] = self.obsm["design_matrix_buffer"].copy()
1258
+ del self.obsm["design_matrix_buffer"]
1259
+
1260
+ # Store normalized counts
1261
+ self.layers["normed_counts"] = self.data["counts"] / self.obsm["size_factors"][:, None]
1262
+
1263
+
1264
+ def _check_full_rank_design(self):
1265
+ """Check that the design matrix has full column rank."""
1266
+ rank = np.linalg.matrix_rank(self.obsm["design_matrix"])
1267
+ num_vars = self.obsm["design_matrix"].shape[1]
1268
+
1269
+ if rank < num_vars:
1270
+ warnings.warn(
1271
+ "The design matrix is not full rank, so the model cannot be "
1272
+ "fitted, but some operations like design-free VST remain possible. "
1273
+ "To perform differential expression analysis, please remove the design "
1274
+ "variables that are linear combinations of others.",
1275
+ UserWarning,
1276
+ stacklevel=2,
1277
+ )
1278
+
1279
+