scdesigner 0.0.5__py3-none-any.whl → 0.0.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scdesigner/base/__init__.py +8 -0
- scdesigner/base/copula.py +416 -0
- scdesigner/base/marginal.py +391 -0
- scdesigner/base/simulator.py +59 -0
- scdesigner/copulas/__init__.py +8 -0
- scdesigner/copulas/standard_copula.py +645 -0
- scdesigner/datasets/__init__.py +5 -0
- scdesigner/datasets/pancreas.py +39 -0
- scdesigner/distributions/__init__.py +19 -0
- scdesigner/{minimal → distributions}/bernoulli.py +42 -14
- scdesigner/distributions/gaussian.py +114 -0
- scdesigner/distributions/negbin.py +121 -0
- scdesigner/distributions/negbin_irls.py +72 -0
- scdesigner/distributions/negbin_irls_funs.py +456 -0
- scdesigner/distributions/poisson.py +88 -0
- scdesigner/{minimal → distributions}/zero_inflated_negbin.py +39 -10
- scdesigner/distributions/zero_inflated_poisson.py +103 -0
- scdesigner/simulators/__init__.py +24 -28
- scdesigner/simulators/composite.py +239 -0
- scdesigner/simulators/positive_nonnegative_matrix_factorization.py +477 -0
- scdesigner/simulators/scd3.py +486 -0
- scdesigner/transform/__init__.py +8 -6
- scdesigner/{minimal → transform}/transform.py +1 -1
- scdesigner/{minimal → utils}/kwargs.py +4 -1
- {scdesigner-0.0.5.dist-info → scdesigner-0.0.10.dist-info}/METADATA +1 -1
- scdesigner-0.0.10.dist-info/RECORD +28 -0
- {scdesigner-0.0.5.dist-info → scdesigner-0.0.10.dist-info}/WHEEL +1 -1
- scdesigner/data/__init__.py +0 -16
- scdesigner/data/formula.py +0 -137
- scdesigner/data/group.py +0 -123
- scdesigner/data/sparse.py +0 -39
- scdesigner/diagnose/__init__.py +0 -65
- scdesigner/diagnose/aic_bic.py +0 -119
- scdesigner/diagnose/plot.py +0 -242
- scdesigner/estimators/__init__.py +0 -32
- scdesigner/estimators/bernoulli.py +0 -85
- scdesigner/estimators/gaussian.py +0 -121
- scdesigner/estimators/gaussian_copula_factory.py +0 -367
- scdesigner/estimators/glm_factory.py +0 -75
- scdesigner/estimators/negbin.py +0 -153
- scdesigner/estimators/pnmf.py +0 -160
- scdesigner/estimators/poisson.py +0 -124
- scdesigner/estimators/zero_inflated_negbin.py +0 -195
- scdesigner/estimators/zero_inflated_poisson.py +0 -85
- scdesigner/format/__init__.py +0 -4
- scdesigner/format/format.py +0 -20
- scdesigner/format/print.py +0 -30
- scdesigner/minimal/__init__.py +0 -17
- scdesigner/minimal/composite.py +0 -119
- scdesigner/minimal/copula.py +0 -205
- scdesigner/minimal/formula.py +0 -23
- scdesigner/minimal/gaussian.py +0 -65
- scdesigner/minimal/loader.py +0 -211
- scdesigner/minimal/marginal.py +0 -154
- scdesigner/minimal/negbin.py +0 -73
- scdesigner/minimal/positive_nonnegative_matrix_factorization.py +0 -231
- scdesigner/minimal/scd3.py +0 -96
- scdesigner/minimal/scd3_instances.py +0 -50
- scdesigner/minimal/simulator.py +0 -25
- scdesigner/minimal/standard_copula.py +0 -383
- scdesigner/predictors/__init__.py +0 -15
- scdesigner/predictors/bernoulli.py +0 -9
- scdesigner/predictors/gaussian.py +0 -16
- scdesigner/predictors/negbin.py +0 -17
- scdesigner/predictors/poisson.py +0 -12
- scdesigner/predictors/zero_inflated_negbin.py +0 -18
- scdesigner/predictors/zero_inflated_poisson.py +0 -18
- scdesigner/samplers/__init__.py +0 -23
- scdesigner/samplers/bernoulli.py +0 -27
- scdesigner/samplers/gaussian.py +0 -25
- scdesigner/samplers/glm_factory.py +0 -103
- scdesigner/samplers/negbin.py +0 -25
- scdesigner/samplers/poisson.py +0 -25
- scdesigner/samplers/zero_inflated_negbin.py +0 -40
- scdesigner/samplers/zero_inflated_poisson.py +0 -16
- scdesigner/simulators/composite_regressor.py +0 -72
- scdesigner/simulators/glm_simulator.py +0 -167
- scdesigner/simulators/pnmf_regression.py +0 -61
- scdesigner/transform/amplify.py +0 -14
- scdesigner/transform/mask.py +0 -33
- scdesigner/transform/nullify.py +0 -25
- scdesigner/transform/split.py +0 -23
- scdesigner/transform/substitute.py +0 -14
- scdesigner-0.0.5.dist-info/RECORD +0 -66
scdesigner/distributions/__init__.py
@@ -0,0 +1,19 @@
+"""Marginal distribution implementations."""
+
+from .negbin import NegBin
+from .negbin_irls import NegBinIRLS
+from .zero_inflated_negbin import ZeroInflatedNegBin
+from .gaussian import Gaussian
+from .bernoulli import Bernoulli
+from .poisson import Poisson
+from .zero_inflated_poisson import ZeroInflatedPoisson
+
+__all__ = [
+    "NegBin",
+    "NegBinIRLS",
+    "ZeroInflatedNegBin",
+    "Gaussian",
+    "Bernoulli",
+    "Poisson",
+    "ZeroInflatedPoisson",
+]
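
As a quick orientation (not part of the diff itself), the estimators exported by the new `scdesigner.distributions` namespace, taken from the `__all__` list above, can be imported directly:

    # Exported names, copied from the __all__ list in the hunk above.
    from scdesigner.distributions import (
        NegBin, NegBinIRLS, ZeroInflatedNegBin,
        Gaussian, Bernoulli, Poisson, ZeroInflatedPoisson,
    )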

scdesigner/{minimal → distributions}/bernoulli.py
@@ -1,15 +1,42 @@
-from .formula import standardize_formula
-from .marginal import GLMPredictor, Marginal
-from .loader import _to_numpy
-from typing import Union, Dict, Optional
+from ..data.formula import standardize_formula
+from ..base.marginal import GLMPredictor, Marginal
+from ..data.loader import _to_numpy
+from typing import Union, Dict, Optional, Tuple
 import torch
 import numpy as np
-from scipy.stats import
+from scipy.stats import bernoulli
 
-class
-    """
+class Bernoulli(Marginal):
+    """Bernoulli marginal estimator
+
+    This subclass behaves like `Marginal` but assumes each feature follows a
+    Bernoulli distribution with success probability `theta_j(x)` that depends
+    on covariates `x` through the `formula` argument.
+
+    The allowed formula keys are 'mean' (interpreted as the logit of the
+    success probability when used with a GLM link). If a string formula is
+    provided, it is taken to specify the `mean` model.
+
+    Examples
+    --------
+    >>> from scdesigner.distributions import Bernoulli
+    >>> from scdesigner.datasets import pancreas
+    >>>
+    >>> sim = Bernoulli(formula="~ pseudotime")
+    >>> sim.setup_data(pancreas)
+    >>> sim.fit(max_epochs=1, verbose=False)
+    >>>
+    >>> # evaluate p(y | x) and theta(x)
+    >>> y, x = next(iter(sim.loader))
+    >>> l = sim.likelihood((y, x))
+    >>> y_hat = sim.predict(x)
+    >>>
+    >>> # convert to quantiles and back
+    >>> u = sim.uniformize(y, x)
+    >>> x_star = sim.invert(u, x)
+    """
     def __init__(self, formula: Union[Dict, str]):
-        formula = standardize_formula(formula, allowed_keys=['mean'
+        formula = standardize_formula(formula, allowed_keys=['mean'])
         super().__init__(formula)
 
     def setup_optimizer(
@@ -21,7 +48,8 @@ class ZeroInflatedNegBin(Marginal):
             raise RuntimeError("self.loader is not set (call setup_data first)")
 
         link_fns = {"mean": torch.sigmoid}
-        nll
+        def nll(batch):
+            return -self.likelihood(batch).sum()
         self.predict = GLMPredictor(
             n_outcomes=self.n_outcomes,
             feature_dims=self.feature_dims,
@@ -31,20 +59,20 @@ class ZeroInflatedNegBin(Marginal):
             optimizer_kwargs=optimizer_kwargs
         )
 
-    def likelihood(self, batch):
-        """Compute the
+    def likelihood(self, batch) -> torch.Tensor:
+        """Compute the log-likelihood"""
         y, x = batch
         params = self.predict(x)
         theta = params.get("mean")
         return y * torch.log(theta) + (1 - y) * torch.log(1 - theta)
 
-    def invert(self, u: torch.Tensor, x: Dict[str, torch.Tensor]):
+    def invert(self, u: torch.Tensor, x: Dict[str, torch.Tensor]) -> torch.Tensor:
         """Invert pseudoobservations."""
         theta, u = self._local_params(x, u)
         y = bernoulli(theta).ppf(u)
         return torch.from_numpy(y).float()
 
-    def uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor], epsilon=1e-6):
+    def uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor], epsilon=1e-6) -> torch.Tensor:
         """Return uniformized pseudo-observations for counts y given covariates x."""
         theta, y = self._local_params(x, y)
         u1 = bernoulli(theta).cdf(y)
@@ -53,7 +81,7 @@ class ZeroInflatedNegBin(Marginal):
         u = np.clip(v * u1 + (1 - v) * u2, epsilon, 1 - epsilon)
         return torch.from_numpy(u).float()
 
-    def _local_params(self, x, y=None):
+    def _local_params(self, x, y=None) -> Tuple:
         params = self.predict(x)
         theta = params.get('mean')
         if y is None:
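
Note on `uniformize`/`invert`: for a discrete marginal like the Bernoulli, the code mixes the CDF at `y` and at `y - 1` with an independent uniform draw (the distributional transform), so the pseudo-observations are Uniform(0, 1) when the model is correctly specified. A standalone sketch of that transform using only numpy/scipy (the `theta` value is illustrative, not from the package):

    import numpy as np
    from scipy.stats import bernoulli

    rng = np.random.default_rng(0)
    theta = 0.3                                              # illustrative success probability
    y = bernoulli(theta).rvs(size=10_000, random_state=rng)

    u1 = bernoulli(theta).cdf(y)                             # F(y)
    u2 = np.where(y > 0, bernoulli(theta).cdf(y - 1), 0.0)   # F(y - 1)
    v = rng.uniform(size=y.shape)                            # randomize within the point mass
    u = np.clip(v * u1 + (1 - v) * u2, 1e-6, 1 - 1e-6)

    print(u.mean(), u.std())   # ~0.5 and ~0.289 if u is Uniform(0, 1)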

scdesigner/distributions/gaussian.py
@@ -0,0 +1,114 @@
+from ..data.formula import standardize_formula
+from ..base.marginal import GLMPredictor, Marginal
+from ..data.loader import _to_numpy
+from typing import Union, Dict, Optional, Tuple
+import torch
+import numpy as np
+from scipy.stats import norm
+
+class Gaussian(Marginal):
+    """Gaussian marginal estimator
+
+    This subclass behaves like `Marginal` but assuming that each gene follows a
+    normal N(mu[j](x), sigma[j]^2(x)) distribution. The parameters mu[j](x) and
+    sigma[j]^2(x) depend on experimental or biological features x through the
+    formula object.
+
+    The allowed formula keys are 'mean' and 'sdev', defaulting to 'mean' with a
+    fixed standard deviation if only a string formula is passed in.
+
+    Examples
+    --------
+    >>> from scdesigner.distributions import Gaussian
+    >>> from scdesigner.datasets import pancreas
+    >>>
+    >>> sim = Gaussian(formula={"mean": "~ bs(pseudotime, df=5)", "sdev": "~ pseudotime"})
+    >>> sim.setup_data(pancreas)
+    >>> sim.fit(max_epochs=1, verbose=False)
+    >>>
+    >>> # evaluate p(y | x) and mu(x)
+    >>> y, x = next(iter(sim.loader))
+    >>> l = sim.likelihood((y, x))
+    >>> y_hat = sim.predict(x)
+    >>>
+    >>> # convert to quantiles and back
+    >>> u = sim.uniformize(y, x)
+    >>> x_star = sim.invert(u, x)
+    """
+    def __init__(self, formula: Union[Dict, str]):
+        formula = standardize_formula(formula, allowed_keys=['mean', 'sdev'])
+        super().__init__(formula)
+
+    def setup_optimizer(
+        self,
+        optimizer_class: Optional[callable] = torch.optim.Adam,
+        **optimizer_kwargs,
+    ):
+        """
+        Gaussian Model Optimizer
+
+        By default optimization is done using Adam. This can be customized using
+        the `optimizer_class` argument. The link function for the mean is an
+        identity link.
+
+        Parameters
+        ----------
+        optimizer_class : Optional[callable]
+            We optimize the negative log likelihood using the Adam optimizer by
+            default. Alternative torch.optim.* optimizer can be passed in
+            through this argument.
+        **optimizer_kwargs :
+            Arguments that are passed to the optimizer during estimation.
+
+        Returns
+        -------
+        Does not return anything, but modifies the self.predict attribute to
+        refer to the new optimizer object.
+        """
+        if self.loader is None:
+            raise RuntimeError("self.loader is not set (call setup_data first)")
+
+        def nll(batch):
+            return -self.likelihood(batch).sum()
+        link_fns = {"mean": lambda x: x}
+        self.predict = GLMPredictor(
+            n_outcomes=self.n_outcomes,
+            feature_dims=self.feature_dims,
+            link_fns=link_fns,
+            loss_fn=nll,
+            optimizer_class=optimizer_class,
+            optimizer_kwargs=optimizer_kwargs
+        )
+
+    def likelihood(self, batch) -> torch.Tensor:
+        """Compute the log-likelihood"""
+        y, x = batch
+        params = self.predict(x)
+        mu = params.get("mean")
+        sigma = params.get("sdev")
+
+        # log likelihood for Gaussian
+        log_likelihood = -0.5 * (torch.log(2 * torch.pi * sigma ** 2) + ((y - mu) ** 2) / (sigma ** 2))
+        return log_likelihood
+
+    def invert(self, u: torch.Tensor, x: Dict[str, torch.Tensor]) -> torch.Tensor:
+        """Invert pseudoobservations."""
+        mu, sdev, u = self._local_params(x, u)
+        y = norm(loc=mu, scale=sdev).ppf(u)
+        return torch.from_numpy(y).float()
+
+    def uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor], epsilon=1e-6) -> torch.Tensor:
+        """Return uniformized pseudo-observations for counts y given covariates x."""
+        # cdf values using scipy's parameterization
+        mu, sdev, y = self._local_params(x, y)
+        u = norm.cdf(y, loc=mu, scale=sdev)
+        u = np.clip(u, epsilon, 1 - epsilon)
+        return torch.from_numpy(u).float()
+
+    def _local_params(self, x, y=None) -> Tuple:
+        params = self.predict(x)
+        mu = params.get('mean')
+        sdev = params.get('sdev')
+        if y is None:
+            return _to_numpy(mu, sdev)
+        return _to_numpy(mu, sdev, y)
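
The hand-written Gaussian log-density in `likelihood` is the standard log N(y; mu, sigma^2) expression. A quick standalone check (not from the package, values illustrative) against `torch.distributions.Normal`:

    import torch

    mu, sigma, y = torch.tensor(1.5), torch.tensor(0.7), torch.tensor(2.0)
    manual = -0.5 * (torch.log(2 * torch.pi * sigma ** 2) + ((y - mu) ** 2) / sigma ** 2)
    reference = torch.distributions.Normal(mu, sigma).log_prob(y)
    print(torch.allclose(manual, reference))  # True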

scdesigner/distributions/negbin.py
@@ -0,0 +1,121 @@
+from ..base.marginal import GLMPredictor, Marginal
+from ..data.formula import standardize_formula
+from ..data.loader import _to_numpy
+from ..utils.kwargs import _filter_kwargs, DEFAULT_ALLOWED_KWARGS
+from .negbin_irls_funs import initialize_parameters
+from scipy.stats import nbinom
+from typing import Union, Dict, Optional, Tuple
+import numpy as np
+import torch
+
+
+class NegBin(Marginal):
+    """Negative-binomial marginal estimator with poisson initialization
+
+    This subclass behaves like `Marginal` but assumes each gene follows a
+    negative binomial distribution NB(mu_j(x), r_j(x)) parameterized via a mean
+    `mu_j(x)` and dispersion `r_j(x)` that depend on covariates `x` through the
+    provided `formula` object.
+
+    The allowed formula keys are 'mean' and 'dispersion', defaulting to
+    'mean' with a fixed dispersion if only a string formula is passed in.
+
+    Examples
+    --------
+    >>> from scdesigner.distributions import NegBin
+    >>> from scdesigner.datasets import pancreas
+    >>>
+    >>> sim = NegBin(formula={"mean": "~ bs(pseudotime, df=5)", "dispersion": "~ pseudotime"})
+    >>> sim.setup_data(pancreas)
+    >>> sim.fit(max_epochs=1, verbose=False)
+    >>>
+    >>> # evaluate p(y | x) and mu(x)
+    >>> y, x = next(iter(sim.loader))
+    >>> l = sim.likelihood((y, x))
+    >>> y_hat = sim.predict(x)
+    >>>
+    >>> # convert to quantiles and back
+    >>> u = sim.uniformize(y, x)
+    >>> x_star = sim.invert(u, x)
+    """
+    def __init__(self, formula: Union[Dict, str]):
+        formula = standardize_formula(formula, allowed_keys=['mean', 'dispersion'])
+        super().__init__(formula)
+
+    def setup_optimizer(
+        self,
+        optimizer_class: Optional[callable] = torch.optim.AdamW,
+        **optimizer_kwargs,
+    ):
+        if self.loader is None:
+            raise RuntimeError("self.loader is not set (call setup_data first)")
+
+        def nll(batch):
+            return -self.likelihood(batch).sum()
+        self.predict = GLMPredictor(
+            n_outcomes=self.n_outcomes,
+            feature_dims=self.feature_dims,
+            loss_fn=nll,
+            optimizer_class=optimizer_class,
+            optimizer_kwargs=optimizer_kwargs
+        )
+
+    def likelihood(self, batch) -> torch.Tensor:
+        """Compute the log-likelihood"""
+        y, x = batch
+        params = self.predict(x)
+        mu = params.get('mean')
+        r = params.get('dispersion')
+        return (
+            torch.lgamma(y + r)
+            - torch.lgamma(r)
+            - torch.lgamma(y + 1.0)
+            + r * torch.log(r)
+            + y * torch.log(mu)
+            - (r + y) * torch.log(r + mu)
+        )
+
+    def invert(self, u: torch.Tensor, x: Dict[str, torch.Tensor]) -> torch.Tensor:
+        """Invert pseudoobservations."""
+        mu, r, u = self._local_params(x, u)
+        p = r / (r + mu)
+        y = nbinom(n=r, p=p).ppf(u)
+        return torch.from_numpy(y).float()
+
+    def uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor], epsilon=1e-6) -> torch.Tensor:
+        """Return uniformized pseudo-observations for counts y given covariates x."""
+        # cdf values using scipy's parameterization
+        mu, r, y = self._local_params(x, y)
+        p = r / (r + mu)
+        u1 = nbinom(n=r, p=p).cdf(y)
+        u2 = np.where(y > 0, nbinom(n=r, p=p).cdf(y - 1), 0.0)
+
+        # randomize within discrete mass to get uniform(0,1)
+        v = np.random.uniform(size=y.shape)
+        u = np.clip(v * u1 + (1.0 - v) * u2, epsilon, 1.0 - epsilon)
+        return torch.from_numpy(u).float()
+
+    def _local_params(self, x, y=None) -> Tuple:
+        params = self.predict(x)
+        mu = params.get('mean')
+        r = params.get('dispersion')
+        if y is None:
+            return _to_numpy(mu, r)
+        return _to_numpy(mu, r, y)
+
+    def fit(self, max_epochs: int = 100, verbose: bool = True, **kwargs):
+        if self.predict is None:
+            self.setup_optimizer(**kwargs)
+
+        # initialize using a poisson fit
+        initialize_kwargs = _filter_kwargs(kwargs, DEFAULT_ALLOWED_KWARGS['initialize'])
+        beta_init, gamma_init = initialize_parameters(
+            self.loader, self.n_outcomes, self.feature_dims['mean'],
+            self.feature_dims['dispersion'],
+            **initialize_kwargs
+        )
+        with torch.no_grad():
+            self.predict.coefs['mean'].copy_(beta_init)
+            self.predict.coefs['dispersion'].copy_(gamma_init)
+
+        return Marginal.fit(self, max_epochs, verbose, **kwargs)
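
The `likelihood` terms above are the NB(mu, r) log-pmf written out with `lgamma`; under scipy's parameterization this corresponds to `n = r` and `p = r / (r + mu)`, which is also how `invert` and `uniformize` call `nbinom`. A standalone cross-check (values illustrative, not from the package):

    import numpy as np
    import torch
    from scipy.stats import nbinom

    y, mu, r = 3.0, 2.5, 1.8
    manual = (
        torch.lgamma(torch.tensor(y + r))
        - torch.lgamma(torch.tensor(r))
        - torch.lgamma(torch.tensor(y + 1.0))
        + r * np.log(r)
        + y * np.log(mu)
        - (r + y) * np.log(r + mu)
    )
    reference = nbinom(n=r, p=r / (r + mu)).logpmf(y)
    print(np.allclose(manual.item(), reference))  # True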

scdesigner/distributions/negbin_irls.py
@@ -0,0 +1,72 @@
+import torch
+from .negbin import NegBin
+from .negbin_irls_funs import initialize_parameters, step_stochastic_irls
+from ..data.formula import standardize_formula
+from ..utils.kwargs import _filter_kwargs, DEFAULT_ALLOWED_KWARGS
+from typing import Union, Dict
+
+
+class NegBinIRLS(NegBin):
+    """
+    Negative-Binomial Marginal using Stochastic IRLS with
+    active response tracking and log-likelihood convergence.
+    """
+    def __init__(self, formula: Union[Dict, str]):
+        formula = standardize_formula(formula, allowed_keys=['mean', 'dispersion'])
+        super().__init__(formula, device="cpu")
+
+
+    def fit(self, max_epochs=10, tol=1e-4, eta=0.1, verbose=True, **kwargs):
+        if self.predict is None:
+            self.setup_optimizer(**kwargs)
+
+        # 1. Initialization using poisson fit
+        initialize_kwargs = _filter_kwargs(kwargs, DEFAULT_ALLOWED_KWARGS['initialize'])
+        beta_init, gamma_init = initialize_parameters(
+            self.loader, self.n_outcomes, self.feature_dims['mean'],
+            self.feature_dims['dispersion'],
+            **initialize_kwargs
+        )
+
+        with torch.no_grad():
+            self.predict.coefs['mean'].copy_(beta_init)
+            self.predict.coefs['dispersion'].copy_(gamma_init)
+
+        # 2. All genes are active at the start
+        active_mask = torch.ones(self.n_outcomes, dtype=torch.bool)
+        ll_ = - 1e9 * torch.ones(self.n_outcomes, dtype=torch.float32)
+
+        for epoch in range(max_epochs):
+            if not active_mask.any(): break
+            ll, n_batches = 0.0, 0
+
+            with torch.no_grad():
+                for y_batch, x_dict in self.loader:
+
+                    # Slice active genes
+                    idx = torch.where(active_mask)[0]
+                    y_act = y_batch[:, active_mask]
+                    X = x_dict['mean']
+                    Z = x_dict['dispersion']
+
+                    # Fetch current coefficients and update
+                    b_curr = self.predict.coefs['mean'][:, active_mask]
+                    g_curr = self.predict.coefs['dispersion'][:, active_mask]
+                    b_next, g_next, conv_mask, ll_cur = step_stochastic_irls(y_act, X, Z, b_curr, g_curr, eta, tol, ll_[active_mask])
+                    ll_[active_mask] = ll_cur
+
+                    # Update Parameters and de-activate converged genes
+                    with torch.no_grad():
+                        self.predict.coefs['mean'][:, active_mask] = b_next
+                        self.predict.coefs['dispersion'][:, active_mask] = g_next
+                    active_mask[idx[conv_mask]] = False
+
+                    # Accumulate batch log-likelihood using `ll` from the IRLS step
+                    ll += ll_.sum().item()
+                    n_batches += 1
+
+            if verbose and ((epoch + 1) % 10) == 0:
+                print(f"Epoch {epoch+1}/{max_epochs} | Genes remaining: {active_mask.sum().item()} | Loss: {-ll / n_batches:.4f}", end='\r')
+            if not active_mask.any(): break
+
+        self.parameters = self.format_parameters()
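
`step_stochastic_irls` itself lives in `negbin_irls_funs.py`, whose hunk is not shown here. For readers unfamiliar with the approach, a minimal, generic IRLS loop for a log-link GLM is sketched below; it only illustrates the weighted-least-squares update pattern and is not the package's implementation (it uses a Poisson working model and no dispersion update):

    import numpy as np

    def irls_poisson(X, y, n_iter=25, tol=1e-8):
        """Log-link Poisson regression via IRLS (illustrative only)."""
        beta = np.zeros(X.shape[1])
        for _ in range(n_iter):
            eta = X @ beta                  # linear predictor
            mu = np.exp(eta)                # inverse link
            W = mu                          # working weights (Var(y) = mu)
            z = eta + (y - mu) / mu         # working response
            XtW = X.T * W                   # X^T diag(W)
            beta_new = np.linalg.solve(XtW @ X, XtW @ z)
            if np.max(np.abs(beta_new - beta)) < tol:
                return beta_new
            beta = beta_new
        return beta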
|