scdesigner-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of scdesigner might be problematic.
- scdesigner/__init__.py +0 -0
- scdesigner/data/__init__.py +16 -0
- scdesigner/data/formula.py +137 -0
- scdesigner/data/group.py +123 -0
- scdesigner/data/sparse.py +39 -0
- scdesigner/diagnose/__init__.py +65 -0
- scdesigner/diagnose/aic_bic.py +119 -0
- scdesigner/diagnose/plot.py +242 -0
- scdesigner/estimators/__init__.py +27 -0
- scdesigner/estimators/bernoulli.py +85 -0
- scdesigner/estimators/gaussian.py +121 -0
- scdesigner/estimators/gaussian_copula_factory.py +152 -0
- scdesigner/estimators/glm_factory.py +75 -0
- scdesigner/estimators/negbin.py +129 -0
- scdesigner/estimators/pnmf.py +160 -0
- scdesigner/estimators/poisson.py +100 -0
- scdesigner/estimators/zero_inflated_negbin.py +195 -0
- scdesigner/estimators/zero_inflated_poisson.py +85 -0
- scdesigner/format/__init__.py +4 -0
- scdesigner/format/format.py +20 -0
- scdesigner/format/print.py +30 -0
- scdesigner/minimal/__init__.py +17 -0
- scdesigner/minimal/bernoulli.py +61 -0
- scdesigner/minimal/composite.py +119 -0
- scdesigner/minimal/copula.py +33 -0
- scdesigner/minimal/formula.py +23 -0
- scdesigner/minimal/gaussian.py +65 -0
- scdesigner/minimal/kwargs.py +24 -0
- scdesigner/minimal/loader.py +166 -0
- scdesigner/minimal/marginal.py +140 -0
- scdesigner/minimal/negbin.py +73 -0
- scdesigner/minimal/positive_nonnegative_matrix_factorization.py +231 -0
- scdesigner/minimal/scd3.py +95 -0
- scdesigner/minimal/scd3_instances.py +50 -0
- scdesigner/minimal/simulator.py +25 -0
- scdesigner/minimal/standard_covariance.py +124 -0
- scdesigner/minimal/transform.py +145 -0
- scdesigner/minimal/zero_inflated_negbin.py +86 -0
- scdesigner/predictors/__init__.py +15 -0
- scdesigner/predictors/bernoulli.py +9 -0
- scdesigner/predictors/gaussian.py +16 -0
- scdesigner/predictors/negbin.py +17 -0
- scdesigner/predictors/poisson.py +12 -0
- scdesigner/predictors/zero_inflated_negbin.py +18 -0
- scdesigner/predictors/zero_inflated_poisson.py +18 -0
- scdesigner/samplers/__init__.py +23 -0
- scdesigner/samplers/bernoulli.py +27 -0
- scdesigner/samplers/gaussian.py +25 -0
- scdesigner/samplers/glm_factory.py +41 -0
- scdesigner/samplers/negbin.py +25 -0
- scdesigner/samplers/poisson.py +25 -0
- scdesigner/samplers/zero_inflated_negbin.py +40 -0
- scdesigner/samplers/zero_inflated_poisson.py +16 -0
- scdesigner/simulators/__init__.py +31 -0
- scdesigner/simulators/composite_regressor.py +72 -0
- scdesigner/simulators/glm_simulator.py +167 -0
- scdesigner/simulators/pnmf_regression.py +61 -0
- scdesigner/transform/__init__.py +7 -0
- scdesigner/transform/amplify.py +14 -0
- scdesigner/transform/mask.py +33 -0
- scdesigner/transform/nullify.py +25 -0
- scdesigner/transform/split.py +23 -0
- scdesigner/transform/substitute.py +14 -0
- scdesigner-0.0.1.dist-info/METADATA +23 -0
- scdesigner-0.0.1.dist-info/RECORD +66 -0
- scdesigner-0.0.1.dist-info/WHEEL +4 -0
scdesigner/format/print.py
@@ -0,0 +1,30 @@
+import rich
+import rich.table
+
+
+def shorten_names(features, max_features=5):
+    if len(features) > max_features:
+        features = features[: (int(max_features - 1))] + ["..."] + features[-1:]
+    return f"""[{', '.join(features)}]"""
+
+
+def print_simulator(margins, copula):
+    table = rich.table.Table(
+        title="[bold magenta]Simulation Plan[/bold magenta]", title_justify="left"
+    )
+    table.add_column("formula")
+    table.add_column("distribution")
+    table.add_column("features")
+
+    i = 1
+    for m in margins:
+        features, margin = m
+        tup = tuple(margin.to_df().iloc[0, :]) + (shorten_names(features),)
+        table.add_row(*tup)
+        i += 1
+
+    rich.print(table)
+    if copula is None:
+        rich.print("Marginal models without copula.")
+
+    return ""
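As a quick illustration of the truncation rule in `shorten_names` (assuming the wheel is installed; importing the submodule directly is an illustrative choice, not documented API):

from scdesigner.format.print import shorten_names

print(shorten_names(["g1", "g2", "g3"]))            # [g1, g2, g3]
print(shorten_names([f"g{i}" for i in range(10)]))  # [g0, g1, g2, g3, ..., g9]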
scdesigner/minimal/__init__.py
@@ -0,0 +1,17 @@
+from .scd3_instances import (
+    BernoulliCopula,
+    GaussianCopula,
+    NegBinCopula,
+    ZeroInflatedNegBinCopula
+)
+from .composite import CompositeCopula
+from .positive_nonnegative_matrix_factorization import PositiveNMF
+
+__all__ = [
+    "BernoulliCopula",
+    "CompositeCopula",
+    "GaussianCopula",
+    "NegBinCopula",
+    "PositiveNMF",
+    "ZeroInflatedNegBinCopula"
+]
scdesigner/minimal/bernoulli.py
@@ -0,0 +1,61 @@
+from .formula import standardize_formula
+from .marginal import GLMPredictor, Marginal
+from .loader import _to_numpy
+from typing import Union, Dict, Optional
+import torch
+import numpy as np
+from scipy.stats import bernoulli
+
+class Bernoulli(Marginal):
+    """Bernoulli marginal estimator"""
+    def __init__(self, formula: Union[Dict, str]):
+        formula = standardize_formula(formula, allowed_keys=['mean'])
+        super().__init__(formula)
+
+    def setup_optimizer(
+        self,
+        optimizer_class: Optional[callable] = torch.optim.Adam,
+        **optimizer_kwargs,
+    ):
+        if self.loader is None:
+            raise RuntimeError("self.loader is not set (call setup_data first)")
+
+        link_fns = {"mean": torch.sigmoid}
+        nll = lambda batch: -self.likelihood(batch).sum()
+        self.predict = GLMPredictor(
+            n_outcomes=self.n_outcomes,
+            feature_dims=self.feature_dims,
+            link_fns=link_fns,
+            loss_fn=nll,
+            optimizer_class=optimizer_class,
+            optimizer_kwargs=optimizer_kwargs
+        )
+
+    def likelihood(self, batch):
+        """Compute the per-element Bernoulli log-likelihood"""
+        y, x = batch
+        params = self.predict(x)
+        theta = params.get("mean")
+        return y * torch.log(theta) + (1 - y) * torch.log(1 - theta)
+
+    def invert(self, u: torch.Tensor, x: Dict[str, torch.Tensor]):
+        """Invert pseudo-observations."""
+        theta, u = self._local_params(x, u)
+        y = bernoulli(theta).ppf(u)
+        return torch.from_numpy(y).float()
+
+    def uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor], epsilon=1e-6):
+        """Return uniformized pseudo-observations for binary y given covariates x."""
+        theta, y = self._local_params(x, y)
+        u1 = bernoulli(theta).cdf(y)
+        u2 = np.where(y > 0, bernoulli(theta).cdf(y - 1), 0)
+        v = np.random.uniform(size=y.shape)
+        u = np.clip(v * u1 + (1 - v) * u2, epsilon, 1 - epsilon)
+        return torch.from_numpy(u).float()
+
+    def _local_params(self, x, y=None):
+        params = self.predict(x)
+        theta = params.get('mean')
+        if y is None:
+            return _to_numpy(theta)
+        return _to_numpy(theta, y)
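`uniformize` uses the distributional transform for a discrete outcome: because the Bernoulli CDF is a step function, a uniform draw `v` interpolates between the CDF at `y` and at `y - 1`, which makes the pseudo-observations uniform under the fitted model. A standalone check of that property with scipy (fixed illustrative success probability `p`, no package import):

import numpy as np
from scipy.stats import bernoulli, kstest

rng = np.random.default_rng(0)
p = 0.3
y = bernoulli(p).rvs(size=10_000, random_state=rng)

# randomized (distributional) transform: v * F(y) + (1 - v) * F(y - 1)
u1 = bernoulli(p).cdf(y)
u2 = np.where(y > 0, bernoulli(p).cdf(y - 1), 0.0)
v = rng.uniform(size=y.shape)
u = v * u1 + (1 - v) * u2

print(kstest(u, "uniform").pvalue)  # large p-value: u is close to Uniform(0, 1)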
scdesigner/minimal/composite.py
@@ -0,0 +1,119 @@
+from .loader import obs_loader
+from .scd3 import SCD3Simulator
+from .standard_covariance import StandardCovariance
+from anndata import AnnData
+from typing import Dict, Optional, List
+import numpy as np
+import torch
+
+class CompositeCopula(SCD3Simulator):
+    def __init__(self, marginals: List,
+                 copula_formula: Optional[str] = None) -> None:
+        self.marginals = marginals
+        self.copula = StandardCovariance(copula_formula)
+        self.template = None
+        self.parameters = None
+        self.merged_formula = None
+
+    def fit(
+        self,
+        adata: AnnData,
+        **kwargs):
+        """Fit the simulator"""
+        self.template = adata
+        merged_formula = {}
+
+        # fit each marginal model
+        for m in range(len(self.marginals)):
+            self.marginals[m][1].setup_data(adata[:, self.marginals[m][0]], **kwargs)
+            self.marginals[m][1].setup_optimizer(**kwargs)
+            self.marginals[m][1].fit(**kwargs)
+
+            # prepare formula for copula loader
+            f = self.marginals[m][1].formula
+            prefixed_f = {f"group{m}_{k}": v for k, v in f.items()}
+            merged_formula = merged_formula | prefixed_f
+
+        # copula simulator
+        self.merged_formula = merged_formula
+        self.copula.setup_data(adata, merged_formula, **kwargs)
+        self.copula.fit(self.merged_uniformize, **kwargs)
+        self.parameters = {
+            "marginal": [m[1].parameters for m in self.marginals],
+            "copula": self.copula.parameters
+        }
+
+    def merged_uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor]) -> torch.Tensor:
+        """Produce a merged uniformized matrix for all marginals.
+
+        Delegates to each marginal's `uniformize` method and places the
+        result into the columns of a full matrix according to the variable
+        selection given in `self.marginals[m][0]`.
+        """
+        y_np = y.detach().cpu().numpy()
+        u = np.empty_like(y_np, dtype=float)
+
+        for m in range(len(self.marginals)):
+            sel = self.marginals[m][0]
+            ix = _var_indices(sel, self.template)
+
+            # remove the `group{m}_` prefix we used to distinguish the marginals
+            prefix = f"group{m}_"
+            cur_x = {k.removeprefix(prefix): v for k, v in x.items()}
+
+            # slice the subset of y for this marginal and call its uniformize
+            y_sub = torch.from_numpy(y_np[:, ix])
+            u[:, ix] = self.marginals[m][1].uniformize(y_sub, cur_x)
+        return torch.from_numpy(u)
+
+    def predict(self, obs=None, batch_size: int = 1000, **kwargs):
+        """Predict from an obs dataframe"""
+        # prepare an internal data loader for this obs
+        if obs is None:
+            obs = self.template.obs
+        loader = obs_loader(
+            obs,
+            self.merged_formula,
+            batch_size=batch_size,
+            **kwargs
+        )
+
+        # prepare per-marginal collectors
+        n_marginals = len(self.marginals)
+        local_pred = [[] for _ in range(n_marginals)]
+
+        # for each batch, call each marginal's predict on its subset of x
+        for _, x_dict in loader:
+            for m in range(n_marginals):
+                prefix = f"group{m}_"
+                # build cur_x where prefixed keys are unprefixed for the marginal
+                cur_x = {k.removeprefix(prefix): v for k, v in x_dict.items()}
+                params = self.marginals[m][1].predict(cur_x)
+                local_pred[m].append(params)
+
+        # merge batch-wise parameter dicts for each marginal and return
+        results = []
+        for m in range(n_marginals):
+            parts = local_pred[m]
+            keys = list(parts[0].keys())
+            results.append({k: torch.cat([d[k] for d in parts]).detach().cpu().numpy() for k in keys})
+
+        return results
+
+
+def _var_indices(sel, adata: AnnData) -> np.ndarray:
+    """Return integer indices of `sel` within `adata.var_names`.
+
+    Expected use: `sel` is a list (or tuple) of variable names (strings).
+    """
+    # If sel is a single string, make it a list so we return a consistent shape
+    single_string = False
+    if isinstance(sel, str):
+        sel = [sel]
+        single_string = True
+
+    idx = np.asarray(adata.var_names.get_indexer(sel), dtype=int)
+    if (idx < 0).any():
+        missing = [s for s, i in zip(sel, idx) if i < 0]
+        raise KeyError(f"Variables not found in adata.var_names: {missing}")
+    return idx if not single_string else idx.reshape(-1)
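A hedged end-to-end sketch of how `CompositeCopula` appears to be driven, based only on the methods shown above. The marginal choices, column names, `max_epochs`, and the `copula_formula` string are illustrative; whether this runs as written also depends on `scd3.py` and `standard_covariance.py`, which are not reproduced in this excerpt.

import numpy as np
import pandas as pd
from anndata import AnnData
from scdesigner.minimal import CompositeCopula
from scdesigner.minimal.gaussian import Gaussian

rng = np.random.default_rng(0)
obs = pd.DataFrame({"condition": rng.choice(["ctrl", "treat"], size=200)})
adata = AnnData(X=rng.normal(size=(200, 4)), obs=obs)
adata.var_names = ["g1", "g2", "g3", "g4"]

# each entry pairs a variable selection with the marginal model fitted to it
marginals = [
    (["g1", "g2"], Gaussian("~ condition")),
    (["g3", "g4"], Gaussian("~ condition")),
]
sim = CompositeCopula(marginals, copula_formula="~ condition")
sim.fit(adata, max_epochs=5)   # extra kwargs such as max_epochs flow to pl.Trainer
params = sim.predict()         # list of per-marginal parameter dicts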
scdesigner/minimal/copula.py
@@ -0,0 +1,33 @@
+from typing import Dict, Callable, Tuple
+import torch
+from anndata import AnnData
+from .loader import adata_loader
+
+class Copula:
+    def __init__(self, formula: str, **kwargs):
+        self.formula = formula
+        self.loader = None
+        self.n_outcomes = None
+        self.parameters = None
+
+    def setup_data(self, adata: AnnData, marginal_formula: Dict[str, str], batch_size: int = 1024, **kwargs):
+        self.adata = adata
+        self.formula = self.formula | marginal_formula
+        self.loader = adata_loader(adata, self.formula, batch_size=batch_size, **kwargs)
+        X_batch, _ = next(iter(self.loader))
+        self.n_outcomes = X_batch.shape[1]
+
+    def fit(self, uniformizer: Callable, **kwargs):
+        raise NotImplementedError
+
+    def pseudo_obs(self, x_dict: Dict):
+        raise NotImplementedError
+
+    def likelihood(self, uniformizer: Callable, batch: Tuple[torch.Tensor, Dict[str, torch.Tensor]]):
+        raise NotImplementedError
+
+    def num_params(self, **kwargs):
+        raise NotImplementedError
+
+    def format_parameters(self):
+        raise NotImplementedError
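The concrete covariance model (`StandardCovariance`, in minimal/standard_covariance.py) is not part of this excerpt. For orientation only, the normal-scores step that a Gaussian-copula `fit` typically performs looks like the following standalone sketch; it uses random placeholder pseudo-observations and is not the package's implementation.

import numpy as np
from scipy.stats import norm

# u: uniformized pseudo-observations (cells x genes), as produced by a
# Marginal.uniformize call; random placeholder values are used here.
u = np.random.default_rng(0).uniform(0.01, 0.99, size=(500, 3))
z = norm.ppf(u)                              # map to normal scores
correlation = np.corrcoef(z, rowvar=False)   # copula correlation estimate
print(correlation.shape)                     # (3, 3)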
scdesigner/minimal/formula.py
@@ -0,0 +1,23 @@
+from typing import Union
+import warnings
+
+def standardize_formula(formula: Union[str, dict], allowed_keys = None):
+    # The first element of allowed_keys is the name of the default parameter
+    if allowed_keys is None:
+        raise ValueError("Internal error: allowed_keys must be specified")
+    formula = {allowed_keys[0]: formula} if isinstance(formula, str) else formula
+
+    formula_keys = set(formula.keys())
+    allowed_keys = set(allowed_keys)
+
+    if not formula_keys & allowed_keys:
+        raise ValueError(f"formula must have at least one of the following keys: {allowed_keys}")
+
+    if extra_keys := formula_keys - allowed_keys:
+        warnings.warn(
+            f"Invalid formulas in dictionary will not be used: {extra_keys}",
+            UserWarning,
+        )
+
+    formula.update({k: '~ 1' for k in allowed_keys - formula_keys})
+    return formula
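For example, assuming the wheel is installed, `standardize_formula` promotes a bare string to the default key and fills the remaining keys with an intercept-only formula:

from scdesigner.minimal.formula import standardize_formula

print(standardize_formula("~ condition", allowed_keys=["mean", "sdev"]))
# {'mean': '~ condition', 'sdev': '~ 1'}

print(standardize_formula({"mean": "~ batch", "typo": "~ 1"}, allowed_keys=["mean", "sdev"]))
# warns about the unused 'typo' key; result: {'mean': '~ batch', 'typo': '~ 1', 'sdev': '~ 1'}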
scdesigner/minimal/gaussian.py
@@ -0,0 +1,65 @@
+from .formula import standardize_formula
+from .marginal import GLMPredictor, Marginal
+from .loader import _to_numpy
+from typing import Union, Dict, Optional
+import torch
+import numpy as np
+from scipy.stats import norm
+
+class Gaussian(Marginal):
+    """Gaussian marginal estimator"""
+    def __init__(self, formula: Union[Dict, str]):
+        formula = standardize_formula(formula, allowed_keys=['mean', 'sdev'])
+        super().__init__(formula)
+
+    def setup_optimizer(
+        self,
+        optimizer_class: Optional[callable] = torch.optim.Adam,
+        **optimizer_kwargs,
+    ):
+        if self.loader is None:
+            raise RuntimeError("self.loader is not set (call setup_data first)")
+
+        nll = lambda batch: -self.likelihood(batch).sum()
+        link_fns = {"mean": lambda x: x}
+        self.predict = GLMPredictor(
+            n_outcomes=self.n_outcomes,
+            feature_dims=self.feature_dims,
+            link_fns=link_fns,
+            loss_fn=nll,
+            optimizer_class=optimizer_class,
+            optimizer_kwargs=optimizer_kwargs
+        )
+
+    def likelihood(self, batch):
+        """Compute the per-element Gaussian log-likelihood"""
+        y, x = batch
+        params = self.predict(x)
+        mu = params.get("mean")
+        sigma = params.get("sdev")
+
+        # log likelihood for Gaussian
+        log_likelihood = -0.5 * (torch.log(2 * torch.pi * sigma ** 2) + ((y - mu) ** 2) / (sigma ** 2))
+        return log_likelihood
+
+    def invert(self, u: torch.Tensor, x: Dict[str, torch.Tensor]):
+        """Invert pseudo-observations."""
+        mu, sdev, u = self._local_params(x, u)
+        y = norm(loc=mu, scale=sdev).ppf(u)
+        return torch.from_numpy(y).float()
+
+    def uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor], epsilon=1e-6):
+        """Return uniformized pseudo-observations for y given covariates x."""
+        # cdf values using scipy's parameterization
+        mu, sdev, y = self._local_params(x, y)
+        u = norm.cdf(y, loc=mu, scale=sdev)
+        u = np.clip(u, epsilon, 1 - epsilon)
+        return torch.from_numpy(u).float()
+
+    def _local_params(self, x, y=None):
+        params = self.predict(x)
+        mu = params.get('mean')
+        sdev = params.get('sdev')
+        if y is None:
+            return _to_numpy(mu, sdev)
+        return _to_numpy(mu, sdev, y)
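For the Gaussian marginal, `uniformize` and `invert` are exact inverses through the normal CDF. A standalone check of that round trip with scipy (fixed illustrative parameters, no package import):

import numpy as np
from scipy.stats import norm

mu, sdev = 1.5, 0.7
y = np.array([0.2, 1.5, 3.1])

u = np.clip(norm.cdf(y, loc=mu, scale=sdev), 1e-6, 1 - 1e-6)   # uniformize
y_back = norm(loc=mu, scale=sdev).ppf(u)                        # invert
print(np.allclose(y, y_back))                                   # True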
scdesigner/minimal/kwargs.py
@@ -0,0 +1,24 @@
+def _filter_kwargs(kwargs: dict, allowed: set) -> dict:
+    """Return a new dict containing only keys from kwargs that are in allowed."""
+    return {k: v for k, v in kwargs.items() if k in allowed}
+
+
+# allowed kwargs for common components
+DEFAULT_ALLOWED_KWARGS = {
+    'trainer': {
+        'max_epochs', 'gpus', 'devices', 'accelerator', 'precision',
+        'logger', 'callbacks', 'strategy', 'num_nodes', 'limit_train_batches',
+        'log_every_n_steps', 'accumulate_grad_batches'
+    },
+    'data': {
+        'chunk_size', 'batch_size', 'shuffle', 'num_workers'
+    },
+    'optimizer': {
+        'lr', 'learning_rate', 'momentum', 'weight_decay', 'eps', 'betas',
+        'amsgrad', 'dampening', 'nesterov', 'alpha',
+        'T_max', 'eta_min', 'step_size', 'gamma', 'milestones', 'last_epoch',
+        'verbose', 'patience', 'threshold', 'cooldown',
+        'optimizer_class', 'optimizer', 'scheduler_class', 'scheduler',
+        'monitor', 'interval', 'frequency'
+    }
+}
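`_filter_kwargs` is a simple whitelist that routes a shared kwargs dict to the right component; for instance (importing the private helper directly, assuming the wheel is installed):

from scdesigner.minimal.kwargs import DEFAULT_ALLOWED_KWARGS, _filter_kwargs

kwargs = {"lr": 1e-2, "max_epochs": 10, "chunk_size": 5000}
print(_filter_kwargs(kwargs, DEFAULT_ALLOWED_KWARGS["optimizer"]))   # {'lr': 0.01}
print(_filter_kwargs(kwargs, DEFAULT_ALLOWED_KWARGS["trainer"]))     # {'max_epochs': 10}
print(_filter_kwargs(kwargs, DEFAULT_ALLOWED_KWARGS["data"]))        # {'chunk_size': 5000}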
scdesigner/minimal/loader.py
@@ -0,0 +1,166 @@
+from .kwargs import DEFAULT_ALLOWED_KWARGS, _filter_kwargs
+from anndata import AnnData
+from formulaic import model_matrix
+from torch.utils.data import Dataset, DataLoader
+from typing import Dict
+import numpy as np
+import pandas as pd
+import torch
+
+class AnnDataDataset(Dataset):
+    """Simple PyTorch Dataset for AnnData objects.
+
+    Supports optional chunked loading for backed AnnData objects. When
+    `chunk_size` is provided, the dataset will load contiguous slices
+    of rows (of size `chunk_size`) into memory once and serve individual
+    rows from that cached chunk. This avoids calling `to_memory()` on
+    a per-row basis, which is expensive for large backed files.
+    """
+    def __init__(self, adata: AnnData, formula: Dict[str, str], chunk_size: int):
+        self.adata = adata
+        self.formula = formula
+        self.chunk_size = chunk_size
+
+        # keep track of covariate-related information
+        self.obs_levels = categories(self.adata.obs)
+        self.obs_matrices = {}
+        self.predictor_names = None
+
+        # Internal cache for the currently loaded chunk
+        self._chunk: AnnData | None = None
+        self._chunk_start = 0
+
+    def __len__(self):
+        return len(self.adata)
+
+    def __getitem__(self, idx):
+        """Returns (X, obs) for the given index.
+
+        If `chunk_size` was specified the dataset will load a chunk
+        containing `idx` into memory (if not already cached) and
+        index into that chunk.
+        """
+        self._ensure_chunk_loaded(idx)
+        local_idx = idx - self._chunk_start
+        adata_slice = self._chunk[local_idx]
+
+        # Get X data, accounting for potential sparse matrices
+        X = adata_slice.X
+        if hasattr(X, 'toarray'):
+            X = X.toarray()
+
+        # Get obs data
+        obs_dict = {}
+        for key in self.formula.keys():
+            mat = self.obs_matrices.get(key)
+            obs_dict[key] = to_tensor(mat.values[local_idx: local_idx + 1])
+        return to_tensor(X), obs_dict
+
+    def _ensure_chunk_loaded(self, idx: int) -> None:
+        """Load the chunk that contains `idx` into the internal cache."""
+        start = (idx // self.chunk_size) * self.chunk_size
+        end = min(start + self.chunk_size, len(self.adata))
+
+        if (self._chunk is None) or not (self._chunk_start <= idx < self._chunk_start + len(self._chunk)):
+            # load the next chunk into memory
+            chunk = self.adata[start:end]
+            if getattr(chunk, 'isbacked', False):
+                chunk = chunk.to_memory()
+            self._chunk = chunk
+            self._chunk_start = start
+
+            # Compute model matrices for this chunk's `obs` so we don't need
+            # to keep the full obs data model matrices in memory.
+            obs_coded_chunk = code_levels(self._chunk.obs.copy(), self.obs_levels)
+            self.obs_matrices = {}
+            for key, f in self.formula.items():
+                self.obs_matrices[key] = model_matrix(f, obs_coded_chunk)
+
+            # Capture predictor (column) names from the model matrices once.
+            if self.predictor_names is None:
+                self.predictor_names = {k: list(v.columns) for k, v in self.obs_matrices.items()}
+
+
+def adata_loader(adata: AnnData,
+                 formula: Dict[str, str],
+                 chunk_size: int = None,
+                 batch_size: int = 1024,
+                 shuffle: bool = False,
+                 num_workers: int = 0,
+                 **kwargs) -> DataLoader:
+    """
+    Create a DataLoader from AnnData that returns batches of (X, obs).
+    """
+    data_kwargs = _filter_kwargs(kwargs, DEFAULT_ALLOWED_KWARGS['data'])
+    if chunk_size is None:
+        if getattr(adata, 'isbacked', False):
+            chunk_size = 5000
+        else:
+            chunk_size = len(adata)
+
+    dataset = AnnDataDataset(adata, formula, chunk_size)
+    return DataLoader(
+        dataset,
+        batch_size=batch_size,
+        shuffle=shuffle,
+        num_workers=num_workers,
+        collate_fn=dict_collate_fn,
+        **data_kwargs
+    )
+
+def obs_loader(obs: pd.DataFrame, marginal_formula, **kwargs):
+    adata = AnnData(X=np.zeros((len(obs), 1)), obs=obs)
+    return adata_loader(
+        adata,
+        marginal_formula,
+        **kwargs
+    )
+
+################################################################################
+## Helper functions
+################################################################################
+
+def dict_collate_fn(batch):
+    """
+    Custom collate function for handling dictionary obs tensors.
+    """
+    X_batch = torch.stack([item[0] for item in batch])
+    obs_batch = [item[1] for item in batch]
+
+    obs_dict = {}
+    for key in obs_batch[0].keys():
+        obs_dict[key] = torch.stack([obs[key] for obs in obs_batch])
+    return X_batch, obs_dict
+
+def to_tensor(X):
+    # Keep a single-column matrix 2D unless it is a single row; otherwise
+    # drop size-1 dimensions (a single-row X becomes a 1D sample)
+    t = torch.tensor(X, dtype=torch.float32)
+    if t.dim() == 2 and t.size(1) == 1:
+        if t.size(0) == 1:
+            return t.view(1)
+        return t
+    return t.squeeze()
+
+def categories(obs):
+    levels = {}
+    for k in obs.columns:
+        obs_type = str(obs[k].dtype)
+        if obs_type in ["category", "object"]:
+            levels[k] = obs[k].unique()
+    return levels
+
+
+def code_levels(obs, categories):
+    for k in obs.columns:
+        if str(obs[k].dtype) == "category":
+            obs[k] = obs[k].astype(pd.CategoricalDtype(categories[k]))
+    return obs
+
+###############################################################################
+## Misc. Helper functions
+###############################################################################
+
+def _to_numpy(*tensors):
+    """Convenience helper: detach, move to CPU, and convert tensors to numpy arrays."""
+    return tuple(t.detach().cpu().numpy() for t in tensors)
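A hedged usage sketch of `adata_loader` on a small in-memory AnnData (assumes anndata, formulaic, and torch are installed along with the wheel; the column name, sizes, and formula are illustrative):

import numpy as np
import pandas as pd
from anndata import AnnData
from scdesigner.minimal.loader import adata_loader

rng = np.random.default_rng(0)
obs = pd.DataFrame({"condition": pd.Categorical(rng.choice(["ctrl", "treat"], size=100))})
adata = AnnData(X=rng.poisson(2.0, size=(100, 5)).astype(np.float32), obs=obs)

loader = adata_loader(adata, {"mean": "~ condition"}, batch_size=32)
X_batch, obs_batch = next(iter(loader))
print(X_batch.shape)              # torch.Size([32, 5])
print(obs_batch["mean"].shape)    # torch.Size([32, 2]): intercept plus the treatment contrast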
scdesigner/minimal/marginal.py
@@ -0,0 +1,140 @@
+from .kwargs import DEFAULT_ALLOWED_KWARGS, _filter_kwargs
+from .loader import adata_loader
+from anndata import AnnData
+from typing import Union, Dict, Optional, Tuple
+import pandas as pd
+import pytorch_lightning as pl
+import torch
+import torch.nn as nn
+
+
+class Marginal:
+    def __init__(self, formula: Union[Dict, str]):
+        self.formula = formula
+        self.feature_dims = None
+        self.loader = None
+        self.n_outcomes = None
+        self.predict = None
+        self.predictor_names = None
+        self.parameters = None
+
+    def setup_data(self, adata: AnnData, batch_size: int = 1024, **kwargs):
+        """Set up the dataloader for the AnnData object."""
+        # keep a reference to the AnnData for later use (e.g., var_names)
+        self.adata = adata
+        self.loader = adata_loader(adata, self.formula, batch_size=batch_size, **kwargs)
+        X_batch, obs_batch = next(iter(self.loader))
+        self.n_outcomes = X_batch.shape[1]
+        self.feature_dims = {k: v.shape[1] for k, v in obs_batch.items()}
+        self.predictor_names = self.loader.dataset.predictor_names
+
+    def fit(self, **kwargs):
+        """Fit the marginal predictor"""
+        if self.predict is None:
+            self.setup_optimizer(**kwargs)
+        trainer_kwargs = _filter_kwargs(kwargs, DEFAULT_ALLOWED_KWARGS['trainer'])
+        trainer = pl.Trainer(**trainer_kwargs)
+        trainer.fit(self.predict, train_dataloaders=self.loader)
+        self.parameters = self.format_parameters()
+
+    def setup_optimizer(self, **kwargs):
+        raise NotImplementedError
+
+    def likelihood(self, batch: Tuple[torch.Tensor, Dict[str, torch.Tensor]]):
+        """Compute the log-likelihood (or loss) for a batch.
+        """
+        raise NotImplementedError
+
+    def invert(self, u: torch.Tensor, x: Dict[str, torch.Tensor]):
+        """Invert pseudo-observations."""
+        raise NotImplementedError
+
+    def uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor]):
+        """Uniformize using the learned CDF.
+        """
+        raise NotImplementedError
+
+    def format_parameters(self):
+        """Convert fitted coefficient tensors into pandas DataFrames.
+
+        Returns:
+            dict: mapping from parameter name -> pandas.DataFrame with rows
+                corresponding to predictor column names (from
+                `self.predictor_names[param]`) and columns corresponding to
+                `self.adata.var_names` (gene names). The values are moved to
+                CPU and converted to numpy floats.
+        """
+        var_names = list(self.adata.var_names)
+
+        dfs = {}
+        for param, tensor in self.predict.coefs.items():
+            coef_np = tensor.detach().cpu().numpy()
+            row_names = list(self.predictor_names[param])
+            dfs[param] = pd.DataFrame(coef_np, index=row_names, columns=var_names)
+        return dfs
+
+    def num_params(self):
+        """Return the number of parameters."""
+        if self.predict is None:
+            return 0
+        return sum(p.numel() for p in self.predict.parameters() if p.requires_grad)
+
+
+class GLMPredictor(pl.LightningModule):
+    """GLM-style predictor with arbitrary named parameters.
+
+    Args:
+        n_outcomes: number of model outputs (e.g. genes)
+        feature_dims: mapping from param name -> number of covariate features
+        link_fns: optional mapping from param name -> callable (link) applied to the linear predictor
+
+    The module will create one coefficient matrix per named parameter with shape
+    (n_features_for_param, n_outcomes) and expose them as Parameters under
+    `self.coefs[param_name]`.
+    """
+    def __init__(
+        self,
+        n_outcomes: int,
+        feature_dims: Dict[str, int],
+        link_fns: Dict[str, callable] = None,
+        loss_fn: Optional[callable] = None,
+        optimizer_class: Optional[callable] = torch.optim.Adam,
+        optimizer_kwargs: Optional[Dict] = None,
+    ):
+        super().__init__()
+        self.n_outcomes = int(n_outcomes)
+        self.feature_dims = dict(feature_dims)
+        self.param_names = list(self.feature_dims.keys())
+
+        # create default link functions and parameter matrices
+        self.link_fns = link_fns or {k: torch.exp for k in self.param_names}
+        self.coefs = nn.ParameterDict()
+        for key, dim in self.feature_dims.items():
+            self.coefs[key] = nn.Parameter(torch.zeros(dim, self.n_outcomes))
+
+        # optimization parameters
+        self.reset_parameters()
+        self.loss_fn = loss_fn
+        self.optimizer_class = optimizer_class
+        self.optimizer_kwargs = optimizer_kwargs
+
+    def reset_parameters(self):
+        for p in self.coefs.values():
+            nn.init.normal_(p, mean=0.0, std=1e-2)
+
+    def forward(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        out = {}
+        for name in self.param_names:
+            x_beta = obs_dict[name] @ self.coefs[name]
+            link = self.link_fns.get(name, torch.exp)
+            out[name] = link(x_beta)
+        return out
+
+    def training_step(self, batch):
+        loss = self.loss_fn(batch)
+        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
+        return loss
+
+    def configure_optimizers(self):
+        optimizer_kwargs = _filter_kwargs(self.optimizer_kwargs or {}, DEFAULT_ALLOWED_KWARGS['optimizer'])
+        return self.optimizer_class(self.parameters(), **optimizer_kwargs)
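`GLMPredictor` can also be exercised on its own; a small hedged sketch (assuming the wheel and pytorch-lightning are installed; the sizes and the sigmoid link are illustrative) showing the shape contract of `forward`:

import torch
from scdesigner.minimal.marginal import GLMPredictor

# 3 outcomes (genes), one named parameter "mean" with 2 covariate columns
predictor = GLMPredictor(n_outcomes=3, feature_dims={"mean": 2},
                         link_fns={"mean": torch.sigmoid})

obs_dict = {"mean": torch.randn(8, 2)}      # a batch of 8 design-matrix rows
out = predictor(obs_dict)
print(out["mean"].shape)                    # torch.Size([8, 3]), values in (0, 1)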