scdesigner 0.0.5__py3-none-any.whl → 0.0.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. scdesigner/base/__init__.py +8 -0
  2. scdesigner/base/copula.py +416 -0
  3. scdesigner/base/marginal.py +391 -0
  4. scdesigner/base/simulator.py +59 -0
  5. scdesigner/copulas/__init__.py +8 -0
  6. scdesigner/copulas/standard_copula.py +645 -0
  7. scdesigner/datasets/__init__.py +5 -0
  8. scdesigner/datasets/pancreas.py +39 -0
  9. scdesigner/distributions/__init__.py +19 -0
  10. scdesigner/{minimal → distributions}/bernoulli.py +42 -14
  11. scdesigner/distributions/gaussian.py +114 -0
  12. scdesigner/distributions/negbin.py +121 -0
  13. scdesigner/distributions/negbin_irls.py +72 -0
  14. scdesigner/distributions/negbin_irls_funs.py +456 -0
  15. scdesigner/distributions/poisson.py +88 -0
  16. scdesigner/{minimal → distributions}/zero_inflated_negbin.py +39 -10
  17. scdesigner/distributions/zero_inflated_poisson.py +103 -0
  18. scdesigner/simulators/__init__.py +24 -28
  19. scdesigner/simulators/composite.py +239 -0
  20. scdesigner/simulators/positive_nonnegative_matrix_factorization.py +477 -0
  21. scdesigner/simulators/scd3.py +486 -0
  22. scdesigner/transform/__init__.py +8 -6
  23. scdesigner/{minimal → transform}/transform.py +1 -1
  24. scdesigner/{minimal → utils}/kwargs.py +4 -1
  25. {scdesigner-0.0.5.dist-info → scdesigner-0.0.10.dist-info}/METADATA +1 -1
  26. scdesigner-0.0.10.dist-info/RECORD +28 -0
  27. {scdesigner-0.0.5.dist-info → scdesigner-0.0.10.dist-info}/WHEEL +1 -1
  28. scdesigner/data/__init__.py +0 -16
  29. scdesigner/data/formula.py +0 -137
  30. scdesigner/data/group.py +0 -123
  31. scdesigner/data/sparse.py +0 -39
  32. scdesigner/diagnose/__init__.py +0 -65
  33. scdesigner/diagnose/aic_bic.py +0 -119
  34. scdesigner/diagnose/plot.py +0 -242
  35. scdesigner/estimators/__init__.py +0 -32
  36. scdesigner/estimators/bernoulli.py +0 -85
  37. scdesigner/estimators/gaussian.py +0 -121
  38. scdesigner/estimators/gaussian_copula_factory.py +0 -367
  39. scdesigner/estimators/glm_factory.py +0 -75
  40. scdesigner/estimators/negbin.py +0 -153
  41. scdesigner/estimators/pnmf.py +0 -160
  42. scdesigner/estimators/poisson.py +0 -124
  43. scdesigner/estimators/zero_inflated_negbin.py +0 -195
  44. scdesigner/estimators/zero_inflated_poisson.py +0 -85
  45. scdesigner/format/__init__.py +0 -4
  46. scdesigner/format/format.py +0 -20
  47. scdesigner/format/print.py +0 -30
  48. scdesigner/minimal/__init__.py +0 -17
  49. scdesigner/minimal/composite.py +0 -119
  50. scdesigner/minimal/copula.py +0 -205
  51. scdesigner/minimal/formula.py +0 -23
  52. scdesigner/minimal/gaussian.py +0 -65
  53. scdesigner/minimal/loader.py +0 -211
  54. scdesigner/minimal/marginal.py +0 -154
  55. scdesigner/minimal/negbin.py +0 -73
  56. scdesigner/minimal/positive_nonnegative_matrix_factorization.py +0 -231
  57. scdesigner/minimal/scd3.py +0 -96
  58. scdesigner/minimal/scd3_instances.py +0 -50
  59. scdesigner/minimal/simulator.py +0 -25
  60. scdesigner/minimal/standard_copula.py +0 -383
  61. scdesigner/predictors/__init__.py +0 -15
  62. scdesigner/predictors/bernoulli.py +0 -9
  63. scdesigner/predictors/gaussian.py +0 -16
  64. scdesigner/predictors/negbin.py +0 -17
  65. scdesigner/predictors/poisson.py +0 -12
  66. scdesigner/predictors/zero_inflated_negbin.py +0 -18
  67. scdesigner/predictors/zero_inflated_poisson.py +0 -18
  68. scdesigner/samplers/__init__.py +0 -23
  69. scdesigner/samplers/bernoulli.py +0 -27
  70. scdesigner/samplers/gaussian.py +0 -25
  71. scdesigner/samplers/glm_factory.py +0 -103
  72. scdesigner/samplers/negbin.py +0 -25
  73. scdesigner/samplers/poisson.py +0 -25
  74. scdesigner/samplers/zero_inflated_negbin.py +0 -40
  75. scdesigner/samplers/zero_inflated_poisson.py +0 -16
  76. scdesigner/simulators/composite_regressor.py +0 -72
  77. scdesigner/simulators/glm_simulator.py +0 -167
  78. scdesigner/simulators/pnmf_regression.py +0 -61
  79. scdesigner/transform/amplify.py +0 -14
  80. scdesigner/transform/mask.py +0 -33
  81. scdesigner/transform/nullify.py +0 -25
  82. scdesigner/transform/split.py +0 -23
  83. scdesigner/transform/substitute.py +0 -14
  84. scdesigner-0.0.5.dist-info/RECORD +0 -66
@@ -0,0 +1,19 @@
1
+ """Marginal distribution implementations."""
2
+
3
+ from .negbin import NegBin
4
+ from .negbin_irls import NegBinIRLS
5
+ from .zero_inflated_negbin import ZeroInflatedNegBin
6
+ from .gaussian import Gaussian
7
+ from .bernoulli import Bernoulli
8
+ from .poisson import Poisson
9
+ from .zero_inflated_poisson import ZeroInflatedPoisson
10
+
11
+ __all__ = [
12
+ "NegBin",
13
+ "NegBinIRLS",
14
+ "ZeroInflatedNegBin",
15
+ "Gaussian",
16
+ "Bernoulli",
17
+ "Poisson",
18
+ "ZeroInflatedPoisson",
19
+ ]
@@ -1,15 +1,42 @@
1
- from .formula import standardize_formula
2
- from .marginal import GLMPredictor, Marginal
3
- from .loader import _to_numpy
4
- from typing import Union, Dict, Optional
1
+ from ..data.formula import standardize_formula
2
+ from ..base.marginal import GLMPredictor, Marginal
3
+ from ..data.loader import _to_numpy
4
+ from typing import Union, Dict, Optional, Tuple
5
5
  import torch
6
6
  import numpy as np
7
- from scipy.stats import nbinom, bernoulli
7
+ from scipy.stats import bernoulli
8
8
 
9
- class ZeroInflatedNegBin(Marginal):
10
- """Zero-inflated negative-binomial marginal estimator"""
9
+ class Bernoulli(Marginal):
10
+ """Bernoulli marginal estimator
11
+
12
+ This subclass behaves like `Marginal` but assumes each feature follows a
13
+ Bernoulli distribution with success probability `theta_j(x)` that depends
14
+ on covariates `x` through the `formula` argument.
15
+
16
+ The allowed formula keys are 'mean' (interpreted as the logit of the
17
+ success probability when used with a GLM link). If a string formula is
18
+ provided, it is taken to specify the `mean` model.
19
+
20
+ Examples
21
+ --------
22
+ >>> from scdesigner.distributions import Bernoulli
23
+ >>> from scdesigner.datasets import pancreas
24
+ >>>
25
+ >>> sim = Bernoulli(formula="~ pseudotime")
26
+ >>> sim.setup_data(pancreas)
27
+ >>> sim.fit(max_epochs=1, verbose=False)
28
+ >>>
29
+ >>> # evaluate p(y | x) and theta(x)
30
+ >>> y, x = next(iter(sim.loader))
31
+ >>> l = sim.likelihood((y, x))
32
+ >>> y_hat = sim.predict(x)
33
+ >>>
34
+ >>> # convert to quantiles and back
35
+ >>> u = sim.uniformize(y, x)
36
+ >>> x_star = sim.invert(u, x)
37
+ """
11
38
  def __init__(self, formula: Union[Dict, str]):
12
- formula = standardize_formula(formula, allowed_keys=['mean', 'dispersion', 'zero_inflation'])
39
+ formula = standardize_formula(formula, allowed_keys=['mean'])
13
40
  super().__init__(formula)
14
41
 
15
42
  def setup_optimizer(
@@ -21,7 +48,8 @@ class ZeroInflatedNegBin(Marginal):
21
48
  raise RuntimeError("self.loader is not set (call setup_data first)")
22
49
 
23
50
  link_fns = {"mean": torch.sigmoid}
24
- nll = lambda batch: -self.likelihood(batch).sum()
51
+ def nll(batch):
52
+ return -self.likelihood(batch).sum()
25
53
  self.predict = GLMPredictor(
26
54
  n_outcomes=self.n_outcomes,
27
55
  feature_dims=self.feature_dims,
@@ -31,20 +59,20 @@ class ZeroInflatedNegBin(Marginal):
31
59
  optimizer_kwargs=optimizer_kwargs
32
60
  )
33
61
 
34
- def likelihood(self, batch):
35
- """Compute the negative log-likelihood"""
62
+ def likelihood(self, batch) -> torch.Tensor:
63
+ """Compute the log-likelihood"""
36
64
  y, x = batch
37
65
  params = self.predict(x)
38
66
  theta = params.get("mean")
39
67
  return y * torch.log(theta) + (1 - y) * torch.log(1 - theta)
40
68
 
41
- def invert(self, u: torch.Tensor, x: Dict[str, torch.Tensor]):
69
+ def invert(self, u: torch.Tensor, x: Dict[str, torch.Tensor]) -> torch.Tensor:
42
70
  """Invert pseudoobservations."""
43
71
  theta, u = self._local_params(x, u)
44
72
  y = bernoulli(theta).ppf(u)
45
73
  return torch.from_numpy(y).float()
46
74
 
47
- def uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor], epsilon=1e-6):
75
+ def uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor], epsilon=1e-6) -> torch.Tensor:
48
76
  """Return uniformized pseudo-observations for counts y given covariates x."""
49
77
  theta, y = self._local_params(x, y)
50
78
  u1 = bernoulli(theta).cdf(y)
@@ -53,7 +81,7 @@ class ZeroInflatedNegBin(Marginal):
53
81
  u = np.clip(v * u1 + (1 - v) * u2, epsilon, 1 - epsilon)
54
82
  return torch.from_numpy(u).float()
55
83
 
56
- def _local_params(self, x, y=None):
84
+ def _local_params(self, x, y=None) -> Tuple:
57
85
  params = self.predict(x)
58
86
  theta = params.get('mean')
59
87
  if y is None:
@@ -0,0 +1,114 @@
1
+ from ..data.formula import standardize_formula
2
+ from ..base.marginal import GLMPredictor, Marginal
3
+ from ..data.loader import _to_numpy
4
+ from typing import Union, Dict, Optional, Tuple
5
+ import torch
6
+ import numpy as np
7
+ from scipy.stats import norm
8
+
9
class Gaussian(Marginal):
    """Gaussian marginal estimator

    This subclass behaves like `Marginal` but assumes that each gene follows a
    normal N(mu[j](x), sigma[j]^2(x)) distribution. The parameters mu[j](x) and
    sigma[j](x) depend on experimental or biological features x through the
    formula object.

    The allowed formula keys are 'mean' and 'sdev', defaulting to 'mean' with a
    fixed standard deviation if only a string formula is passed in.

    Examples
    --------
    >>> from scdesigner.distributions import Gaussian
    >>> from scdesigner.datasets import pancreas
    >>>
    >>> sim = Gaussian(formula={"mean": "~ bs(pseudotime, df=5)", "sdev": "~ pseudotime"})
    >>> sim.setup_data(pancreas)
    >>> sim.fit(max_epochs=1, verbose=False)
    >>>
    >>> # evaluate p(y | x) and mu(x)
    >>> y, x = next(iter(sim.loader))
    >>> l = sim.likelihood((y, x))
    >>> y_hat = sim.predict(x)
    >>>
    >>> # convert to quantiles and back
    >>> u = sim.uniformize(y, x)
    >>> x_star = sim.invert(u, x)
    """

    def __init__(self, formula: Union[Dict, str]):
        formula = standardize_formula(formula, allowed_keys=['mean', 'sdev'])
        super().__init__(formula)

    def setup_optimizer(
        self,
        optimizer_class: Optional[callable] = torch.optim.Adam,
        **optimizer_kwargs,
    ):
        """
        Gaussian Model Optimizer

        By default optimization is done using Adam. This can be customized using
        the `optimizer_class` argument. The link function for the mean is an
        identity link. No link is set here for 'sdev', so GLMPredictor's default
        applies to it.

        Parameters
        ----------
        optimizer_class : Optional[callable]
            We optimize the negative log likelihood using the Adam optimizer by
            default. Alternative torch.optim.* optimizer can be passed in
            through this argument.
        **optimizer_kwargs :
            Arguments that are passed to the optimizer during estimation.

        Returns
        -------
        Does not return anything, but replaces the self.predict attribute with
        a new GLMPredictor.

        Raises
        ------
        RuntimeError
            If `setup_data` has not been called yet (self.loader is None).
        """
        if self.loader is None:
            raise RuntimeError("self.loader is not set (call setup_data first)")

        def nll(batch):
            return -self.likelihood(batch).sum()

        link_fns = {"mean": lambda x: x}
        self.predict = GLMPredictor(
            n_outcomes=self.n_outcomes,
            feature_dims=self.feature_dims,
            link_fns=link_fns,
            loss_fn=nll,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs
        )

    def likelihood(self, batch) -> torch.Tensor:
        """Compute the elementwise log-likelihood log N(y | mu(x), sdev(x)^2).

        Evaluated as -log|sdev| - 0.5*log(2*pi) - 0.5*((y - mu) / sdev)^2. This
        is mathematically equal to -0.5*(log(2*pi*sdev^2) + (y - mu)^2 / sdev^2),
        but it never squares `sdev`. Squaring a very small float32 standard
        deviation underflows to 0, which would turn the log term into -inf and
        the quadratic term into inf.
        """
        y, x = batch
        params = self.predict(x)
        mu = params.get("mean")
        sigma = params.get("sdev")

        half_log_2pi = 0.5 * float(np.log(2.0 * np.pi))
        # abs() keeps exact agreement with log(sigma**2) / 2 for either sign.
        z = (y - mu) / sigma
        return -torch.log(torch.abs(sigma)) - half_log_2pi - 0.5 * z ** 2

    def invert(self, u: torch.Tensor, x: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Map uniform pseudo-observations `u` back to Gaussian values via the normal ppf."""
        mu, sdev, u = self._local_params(x, u)
        y = norm(loc=mu, scale=sdev).ppf(u)
        return torch.from_numpy(y).float()

    def uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor], epsilon=1e-6) -> torch.Tensor:
        """Return uniformized pseudo-observations for Gaussian values y given covariates x.

        The normal CDF is evaluated at y and clipped to [epsilon, 1 - epsilon].
        """
        # cdf values using scipy's parameterization
        mu, sdev, y = self._local_params(x, y)
        u = norm.cdf(y, loc=mu, scale=sdev)
        u = np.clip(u, epsilon, 1 - epsilon)
        return torch.from_numpy(u).float()

    def _local_params(self, x, y=None) -> Tuple:
        """Return (mu, sdev) or (mu, sdev, y) converted by `_to_numpy`."""
        params = self.predict(x)
        mu = params.get('mean')
        sdev = params.get('sdev')
        if y is None:
            return _to_numpy(mu, sdev)
        return _to_numpy(mu, sdev, y)
@@ -0,0 +1,121 @@
1
+ from ..base.marginal import GLMPredictor, Marginal
2
+ from ..data.formula import standardize_formula
3
+ from ..data.loader import _to_numpy
4
+ from ..utils.kwargs import _filter_kwargs, DEFAULT_ALLOWED_KWARGS
5
+ from .negbin_irls_funs import initialize_parameters
6
+ from scipy.stats import nbinom
7
+ from typing import Union, Dict, Optional, Tuple
8
+ import numpy as np
9
+ import torch
10
+
11
+
12
class NegBin(Marginal):
    """Negative-binomial marginal estimator with poisson initialization

    Each gene is modeled as NB(mu_j(x), r_j(x)): `mu_j(x)` is the mean and
    `r_j(x)` the dispersion (size) parameter. Both may depend on covariates `x`
    through the supplied `formula` object.

    Accepted formula keys are 'mean' and 'dispersion'. A plain string formula
    specifies the 'mean' model, with a fixed dispersion.

    Examples
    --------
    >>> from scdesigner.distributions import NegBin
    >>> from scdesigner.datasets import pancreas
    >>>
    >>> sim = NegBin(formula={"mean": "~ bs(pseudotime, df=5)", "dispersion": "~ pseudotime"})
    >>> sim.setup_data(pancreas)
    >>> sim.fit(max_epochs=1, verbose=False)
    >>>
    >>> # evaluate p(y | x) and mu(x)
    >>> y, x = next(iter(sim.loader))
    >>> l = sim.likelihood((y, x))
    >>> y_hat = sim.predict(x)
    >>>
    >>> # convert to quantiles and back
    >>> u = sim.uniformize(y, x)
    >>> x_star = sim.invert(u, x)
    """

    def __init__(self, formula: Union[Dict, str]):
        super().__init__(
            standardize_formula(formula, allowed_keys=['mean', 'dispersion'])
        )

    def setup_optimizer(
        self,
        optimizer_class: Optional[callable] = torch.optim.AdamW,
        **optimizer_kwargs,
    ):
        """Build the GLMPredictor used as `self.predict`.

        The loss is the summed negative log-likelihood. No `link_fns` are
        passed, so GLMPredictor's default links apply to both parameters.

        Parameters
        ----------
        optimizer_class : Optional[callable]
            torch optimizer class; AdamW by default.
        **optimizer_kwargs :
            Forwarded to GLMPredictor as `optimizer_kwargs`.

        Raises
        ------
        RuntimeError
            If `setup_data` has not been called yet.
        """
        if self.loader is None:
            raise RuntimeError("self.loader is not set (call setup_data first)")

        def negative_log_likelihood(batch):
            return -self.likelihood(batch).sum()

        self.predict = GLMPredictor(
            n_outcomes=self.n_outcomes,
            feature_dims=self.feature_dims,
            loss_fn=negative_log_likelihood,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
        )

    def likelihood(self, batch) -> torch.Tensor:
        """Elementwise NB log-pmf of y under mean `mu` and dispersion `r`.

        log p(y) = lgamma(y + r) - lgamma(r) - lgamma(y + 1)
                   + r log r + y log mu - (r + y) log(r + mu)
        """
        y, x = batch
        params = self.predict(x)
        mu, r = params.get('mean'), params.get('dispersion')
        log_norm = torch.lgamma(y + r) - torch.lgamma(r) - torch.lgamma(y + 1.0)
        return log_norm + r * torch.log(r) + y * torch.log(mu) - (r + y) * torch.log(r + mu)

    def invert(self, u: torch.Tensor, x: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Map uniform pseudo-observations back to counts via the NB quantile function."""
        mu, r, u = self._local_params(x, u)
        # scipy's nbinom uses (n, p) with p = r / (r + mu) for mean mu
        dist = nbinom(n=r, p=r / (r + mu))
        return torch.from_numpy(dist.ppf(u)).float()

    def uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor], epsilon=1e-6) -> torch.Tensor:
        """Randomized PIT of counts y given covariates x.

        A uniform draw v interpolates between F(y - 1) and F(y). The result is
        clipped to [epsilon, 1 - epsilon].
        """
        mu, r, y = self._local_params(x, y)
        dist = nbinom(n=r, p=r / (r + mu))
        upper = dist.cdf(y)
        lower = np.where(y > 0, dist.cdf(y - 1), 0.0)

        jitter = np.random.uniform(size=y.shape)
        u = np.clip(jitter * upper + (1.0 - jitter) * lower, epsilon, 1.0 - epsilon)
        return torch.from_numpy(u).float()

    def _local_params(self, x, y=None) -> Tuple:
        """Return (mu, r) or (mu, r, y) converted by `_to_numpy`."""
        params = self.predict(x)
        tensors = [params.get('mean'), params.get('dispersion')]
        if y is not None:
            tensors.append(y)
        return _to_numpy(*tensors)

    def fit(self, max_epochs: int = 100, verbose: bool = True, **kwargs):
        """Initialize coefficients from a Poisson-based fit, then run `Marginal.fit`.

        The initialization runs on every call and overwrites any previously
        fitted coefficients. The kwargs allowed by
        DEFAULT_ALLOWED_KWARGS['initialize'] go to `initialize_parameters`. All
        kwargs go to `setup_optimizer` (only when no predictor exists yet) and
        to `Marginal.fit`.
        """
        if self.predict is None:
            self.setup_optimizer(**kwargs)

        init_kwargs = _filter_kwargs(kwargs, DEFAULT_ALLOWED_KWARGS['initialize'])
        beta_init, gamma_init = initialize_parameters(
            self.loader,
            self.n_outcomes,
            self.feature_dims['mean'],
            self.feature_dims['dispersion'],
            **init_kwargs,
        )

        coefs = self.predict.coefs
        with torch.no_grad():
            coefs['mean'].copy_(beta_init)
            coefs['dispersion'].copy_(gamma_init)

        return Marginal.fit(self, max_epochs, verbose, **kwargs)
@@ -0,0 +1,72 @@
1
+ import torch
2
+ from .negbin import NegBin
3
+ from .negbin_irls_funs import initialize_parameters, step_stochastic_irls
4
+ from ..data.formula import standardize_formula
5
+ from ..utils.kwargs import _filter_kwargs, DEFAULT_ALLOWED_KWARGS
6
+ from typing import Union, Dict
7
+
8
+
9
class NegBinIRLS(NegBin):
    """
    Negative-Binomial Marginal using Stochastic IRLS with
    active response tracking and log-likelihood convergence.

    Coefficients are first initialized with `initialize_parameters` and then
    refined batch by batch with `step_stochastic_irls`. A gene (response
    column) whose step reports convergence is removed from the active set and
    is no longer updated.
    """

    def __init__(self, formula: Union[Dict, str]):
        formula = standardize_formula(formula, allowed_keys=['mean', 'dispersion'])
        # NegBin.__init__ accepts only `formula`. Forwarding `device="cpu"`
        # made every construction raise TypeError.
        # NOTE(review): if CPU pinning is required, it has to be plumbed
        # through NegBin.__init__ first.
        super().__init__(formula)

    def fit(self, max_epochs=10, tol=1e-4, eta=0.1, verbose=True, **kwargs):
        """Fit the model with stochastic IRLS.

        Parameters
        ----------
        max_epochs : int
            Maximum number of passes over `self.loader`.
        tol : float
            Convergence tolerance forwarded to `step_stochastic_irls`.
        eta : float
            Step size forwarded to `step_stochastic_irls`.
        verbose : bool
            If True, print progress every 10 epochs.
        **kwargs :
            Passed to `setup_optimizer` when no predictor exists yet. The subset
            allowed by DEFAULT_ALLOWED_KWARGS['initialize'] also goes to
            `initialize_parameters`.

        On return, `self.parameters` holds `self.format_parameters()`.
        """
        if self.predict is None:
            self.setup_optimizer(**kwargs)

        # 1. Initialization using poisson fit
        initialize_kwargs = _filter_kwargs(kwargs, DEFAULT_ALLOWED_KWARGS['initialize'])
        beta_init, gamma_init = initialize_parameters(
            self.loader, self.n_outcomes, self.feature_dims['mean'],
            self.feature_dims['dispersion'],
            **initialize_kwargs
        )

        coefs = self.predict.coefs
        with torch.no_grad():
            coefs['mean'].copy_(beta_init)
            coefs['dispersion'].copy_(gamma_init)

        # 2. All genes are active at the start. ll_ stores the latest per-gene
        #    log-likelihood, seeded with a very low value.
        active_mask = torch.ones(self.n_outcomes, dtype=torch.bool)
        ll_ = torch.full((self.n_outcomes,), -1e9, dtype=torch.float32)

        for epoch in range(max_epochs):
            if not active_mask.any():
                break
            ll, n_batches = 0.0, 0

            with torch.no_grad():
                for y_batch, x_dict in self.loader:
                    idx = torch.where(active_mask)[0]
                    if idx.numel() == 0:
                        # Every gene converged mid-epoch; nothing left to update.
                        break

                    # Slice active genes
                    y_act = y_batch[:, active_mask]
                    X = x_dict['mean']
                    Z = x_dict['dispersion']

                    # Fetch current coefficients and take one IRLS step
                    b_curr = coefs['mean'][:, active_mask]
                    g_curr = coefs['dispersion'][:, active_mask]
                    b_next, g_next, conv_mask, ll_cur = step_stochastic_irls(
                        y_act, X, Z, b_curr, g_curr, eta, tol, ll_[active_mask]
                    )
                    ll_[active_mask] = ll_cur

                    # Update parameters, then de-activate converged genes
                    coefs['mean'][:, active_mask] = b_next
                    coefs['dispersion'][:, active_mask] = g_next
                    active_mask[idx[conv_mask]] = False

                    # Accumulate batch log-likelihood using `ll` from the IRLS step
                    ll += ll_.sum().item()
                    n_batches += 1

            # Guard against an empty loader (n_batches == 0) before dividing.
            if verbose and n_batches > 0 and ((epoch + 1) % 10) == 0:
                print(f"Epoch {epoch+1}/{max_epochs} | Genes remaining: {active_mask.sum().item()} | Loss: {-ll / n_batches:.4f}", end='\r')

        self.parameters = self.format_parameters()