scdesigner 0.0.5__py3-none-any.whl → 0.0.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scdesigner/base/__init__.py +8 -0
- scdesigner/base/copula.py +416 -0
- scdesigner/base/marginal.py +391 -0
- scdesigner/base/simulator.py +59 -0
- scdesigner/copulas/__init__.py +8 -0
- scdesigner/copulas/standard_copula.py +645 -0
- scdesigner/datasets/__init__.py +5 -0
- scdesigner/datasets/pancreas.py +39 -0
- scdesigner/distributions/__init__.py +19 -0
- scdesigner/{minimal → distributions}/bernoulli.py +42 -14
- scdesigner/distributions/gaussian.py +114 -0
- scdesigner/distributions/negbin.py +121 -0
- scdesigner/distributions/negbin_irls.py +72 -0
- scdesigner/distributions/negbin_irls_funs.py +456 -0
- scdesigner/distributions/poisson.py +88 -0
- scdesigner/{minimal → distributions}/zero_inflated_negbin.py +39 -10
- scdesigner/distributions/zero_inflated_poisson.py +103 -0
- scdesigner/simulators/__init__.py +24 -28
- scdesigner/simulators/composite.py +239 -0
- scdesigner/simulators/positive_nonnegative_matrix_factorization.py +477 -0
- scdesigner/simulators/scd3.py +486 -0
- scdesigner/transform/__init__.py +8 -6
- scdesigner/{minimal → transform}/transform.py +1 -1
- scdesigner/{minimal → utils}/kwargs.py +4 -1
- {scdesigner-0.0.5.dist-info → scdesigner-0.0.10.dist-info}/METADATA +1 -1
- scdesigner-0.0.10.dist-info/RECORD +28 -0
- {scdesigner-0.0.5.dist-info → scdesigner-0.0.10.dist-info}/WHEEL +1 -1
- scdesigner/data/__init__.py +0 -16
- scdesigner/data/formula.py +0 -137
- scdesigner/data/group.py +0 -123
- scdesigner/data/sparse.py +0 -39
- scdesigner/diagnose/__init__.py +0 -65
- scdesigner/diagnose/aic_bic.py +0 -119
- scdesigner/diagnose/plot.py +0 -242
- scdesigner/estimators/__init__.py +0 -32
- scdesigner/estimators/bernoulli.py +0 -85
- scdesigner/estimators/gaussian.py +0 -121
- scdesigner/estimators/gaussian_copula_factory.py +0 -367
- scdesigner/estimators/glm_factory.py +0 -75
- scdesigner/estimators/negbin.py +0 -153
- scdesigner/estimators/pnmf.py +0 -160
- scdesigner/estimators/poisson.py +0 -124
- scdesigner/estimators/zero_inflated_negbin.py +0 -195
- scdesigner/estimators/zero_inflated_poisson.py +0 -85
- scdesigner/format/__init__.py +0 -4
- scdesigner/format/format.py +0 -20
- scdesigner/format/print.py +0 -30
- scdesigner/minimal/__init__.py +0 -17
- scdesigner/minimal/composite.py +0 -119
- scdesigner/minimal/copula.py +0 -205
- scdesigner/minimal/formula.py +0 -23
- scdesigner/minimal/gaussian.py +0 -65
- scdesigner/minimal/loader.py +0 -211
- scdesigner/minimal/marginal.py +0 -154
- scdesigner/minimal/negbin.py +0 -73
- scdesigner/minimal/positive_nonnegative_matrix_factorization.py +0 -231
- scdesigner/minimal/scd3.py +0 -96
- scdesigner/minimal/scd3_instances.py +0 -50
- scdesigner/minimal/simulator.py +0 -25
- scdesigner/minimal/standard_copula.py +0 -383
- scdesigner/predictors/__init__.py +0 -15
- scdesigner/predictors/bernoulli.py +0 -9
- scdesigner/predictors/gaussian.py +0 -16
- scdesigner/predictors/negbin.py +0 -17
- scdesigner/predictors/poisson.py +0 -12
- scdesigner/predictors/zero_inflated_negbin.py +0 -18
- scdesigner/predictors/zero_inflated_poisson.py +0 -18
- scdesigner/samplers/__init__.py +0 -23
- scdesigner/samplers/bernoulli.py +0 -27
- scdesigner/samplers/gaussian.py +0 -25
- scdesigner/samplers/glm_factory.py +0 -103
- scdesigner/samplers/negbin.py +0 -25
- scdesigner/samplers/poisson.py +0 -25
- scdesigner/samplers/zero_inflated_negbin.py +0 -40
- scdesigner/samplers/zero_inflated_poisson.py +0 -16
- scdesigner/simulators/composite_regressor.py +0 -72
- scdesigner/simulators/glm_simulator.py +0 -167
- scdesigner/simulators/pnmf_regression.py +0 -61
- scdesigner/transform/amplify.py +0 -14
- scdesigner/transform/mask.py +0 -33
- scdesigner/transform/nullify.py +0 -25
- scdesigner/transform/split.py +0 -23
- scdesigner/transform/substitute.py +0 -14
- scdesigner-0.0.5.dist-info/RECORD +0 -66
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
from ..data.formula import standardize_formula
|
|
2
|
+
from ..base.marginal import GLMPredictor, Marginal
|
|
3
|
+
from ..data.loader import _to_numpy
|
|
4
|
+
from typing import Union, Dict, Optional, Tuple
|
|
5
|
+
import torch
|
|
6
|
+
import numpy as np
|
|
7
|
+
from scipy.stats import poisson, bernoulli
|
|
8
|
+
|
|
9
|
+
class ZeroInflatedPoisson(Marginal):
    """Zero-Inflated Poisson marginal estimator

    This subclass models counts with an explicit zero-inflation component.
    For each feature j the observation follows a mixture: with probability
    `pi_j(x)` the value is an extra zero, otherwise the count is drawn from
    a Poisson distribution with mean `mu_j(x)`. Both `mu_j(x)` and the
    inflation probability `pi_j(x)` may depend on covariates `x` through the
    `formula` argument.

    The allowed formula keys are 'mean' and 'zero_inflation'. If a string
    formula is supplied it is taken to specify the `mean` by default.

    Examples
    --------
    >>> from scdesigner.distributions import ZeroInflatedPoisson
    >>> from scdesigner.datasets import pancreas
    >>>
    >>> sim = ZeroInflatedPoisson(formula={"mean": "~ pseudotime", "zero_inflation": "~ pseudotime"})
    >>> sim.setup_data(pancreas)
    >>> sim.fit(max_epochs=1, verbose=False)
    >>>
    >>> # evaluate p(y | x) and model parameters
    >>> y, x = next(iter(sim.loader))
    >>> l = sim.likelihood((y, x))
    >>> y_hat = sim.predict(x)
    >>>
    >>> # convert to quantiles and back
    >>> u = sim.uniformize(y, x)
    >>> x_star = sim.invert(u, x)
    """
    def __init__(self, formula: Union[Dict, str]):
        """Normalize ``formula`` to the 'mean' / 'zero_inflation' keys and init the base marginal."""
        formula = standardize_formula(formula, allowed_keys=['mean', 'zero_inflation'])
        super().__init__(formula)

    def setup_optimizer(
        self,
        optimizer_class: Optional[callable] = torch.optim.Adam,
        **optimizer_kwargs,
    ):
        """Create the GLM predictor (and its optimizer) for both ZIP parameters.

        The mean uses an exponential link and the zero-inflation probability a
        sigmoid link. The training loss is the negative summed log-likelihood
        returned by :meth:`likelihood`.

        Parameters
        ----------
        optimizer_class : callable, optional
            Optimizer class handed to ``GLMPredictor`` (default ``torch.optim.Adam``).
        **optimizer_kwargs
            Extra keyword arguments handed to ``GLMPredictor`` as ``optimizer_kwargs``.

        Raises
        ------
        RuntimeError
            If ``self.loader`` is None, i.e. ``setup_data`` was not called.
        """
        if self.loader is None:
            raise RuntimeError("self.loader is not set (call setup_data first)")

        link_funs = {
            "mean": torch.exp,
            "zero_inflation": torch.sigmoid,
        }

        def nll(batch):
            return -self.likelihood(batch).sum()

        self.predict = GLMPredictor(
            n_outcomes=self.n_outcomes,
            feature_dims=self.feature_dims,
            link_fns=link_funs,
            loss_fn=nll,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
        )

    def likelihood(self, batch) -> torch.Tensor:
        """Compute the elementwise ZIP log-likelihood of a ``(y, x)`` batch.

        A zero can come from either mixture component, so its probability is
        ``pi + (1 - pi) * exp(-mu)``. A positive count can only come from the
        Poisson component, so its log-probability is
        ``log(1 - pi) + log Poisson(y; mu)``.

        The positive branch is evaluated in log space. Exponentiating the
        Poisson log-pmf first, as a naive mixture formula would, underflows
        for large counts and pins the value (and its gradient) at ``log(1e-10)``.
        """
        y, x = batch
        params = self.predict(x)
        mu = params.get("mean")
        pi = params.get("zero_inflation")

        poisson_loglikelihood = y * torch.log(mu + 1e-10) - mu - torch.lgamma(y + 1)
        # exp(-mu) is the Poisson pmf at 0; the epsilons keep both branches finite
        # so torch.where never back-propagates inf/nan from the unselected side.
        log_zero = torch.log(pi + (1 - pi) * torch.exp(-mu) + 1e-10)
        log_positive = torch.log(1 - pi + 1e-10) + poisson_loglikelihood
        return torch.where(y == 0, log_zero, log_positive)

    def invert(self, u: torch.Tensor, x: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Invert pseudo-observations ``u`` into counts given covariates ``x``.

        This is the generalized inverse of the ZIP CDF
        ``G(y) = pi + (1 - pi) * F_pois(y)`` used by :meth:`uniformize`:
        ``u <= G(0)`` maps to 0, otherwise the Poisson quantile is taken at the
        rescaled level ``(u - pi) / (1 - pi)``.
        """
        mu, pi, u = self._local_params(x, u)
        # Guard 1 - pi == 0 (sigmoid can saturate); then every u < 1 maps to 0.
        q = np.clip((u - pi) / np.clip(1.0 - pi, 1e-12, None), 0.0, 1.0)
        # scipy's discrete ppf(0) returns -1, so floor the result at 0.
        y = np.maximum(poisson(mu).ppf(q), 0.0)
        return torch.from_numpy(y).float()

    def uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor], epsilon=1e-6) -> torch.Tensor:
        """Return uniformized pseudo-observations for counts y given covariates x.

        Uses the randomized probability integral transform. A uniform draw is
        placed between ``G(y - 1)`` and ``G(y)``, where ``G`` is the ZIP CDF
        (``G(-1) = 0``). The result is clipped to ``[epsilon, 1 - epsilon]``.
        """
        # cdf values using scipy's parameterization
        mu, pi, y = self._local_params(x, y)
        pois_distn = poisson(mu)
        u1 = pi + (1 - pi) * pois_distn.cdf(y)
        u2 = np.where(y > 0, pi + (1 - pi) * pois_distn.cdf(y - 1), 0)
        v = np.random.uniform(size=y.shape)
        u = np.clip(v * u1 + (1 - v) * u2, epsilon, 1 - epsilon)
        return torch.from_numpy(u).float()

    def _local_params(self, x, y=None) -> Tuple:
        """Predict (mu, pi) for ``x`` and convert them (and ``y`` if given) via ``_to_numpy``."""
        params = self.predict(x)
        mu = params.get('mean')
        pi = params.get('zero_inflation')
        if y is None:
            return _to_numpy(mu, pi)
        return _to_numpy(mu, pi, y)
|
|
@@ -1,31 +1,27 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
ZeroInflatedNegBinCopulaSimulator,
|
|
12
|
-
ZeroInflatedNegBinRegressionSimulator,
|
|
13
|
-
ZeroInflatedPoissonRegressionSimulator,
|
|
1
|
+
"""Simulator classes"""
|
|
2
|
+
|
|
3
|
+
from .scd3 import (
|
|
4
|
+
BernoulliCopula,
|
|
5
|
+
GaussianCopula,
|
|
6
|
+
NegBinCopula,
|
|
7
|
+
NegBinIRLSCopula,
|
|
8
|
+
PoissonCopula,
|
|
9
|
+
ZeroInflatedNegBinCopula,
|
|
10
|
+
ZeroInflatedPoissonCopula
|
|
14
11
|
)
|
|
15
|
-
from .
|
|
12
|
+
from .composite import CompositeCopula
|
|
13
|
+
from .positive_nonnegative_matrix_factorization import PositiveNMF
|
|
16
14
|
|
|
17
15
|
# Public API of scdesigner.simulators. Every entry must be imported above;
# a name listed here but not bound in the module makes
# ``from scdesigner.simulators import *`` raise AttributeError.
__all__ = [
    "BernoulliCopula",
    "CompositeCopula",
    "GaussianCopula",
    "NegBinCopula",
    "NegBinIRLSCopula",
    "PoissonCopula",
    "PositiveNMF",
    "ZeroInflatedNegBinCopula",
    "ZeroInflatedPoissonCopula",
]
|
|
@@ -0,0 +1,239 @@
|
|
|
1
|
+
"""Composite simulator that combines multiple marginals with a Gaussian copula.
|
|
2
|
+
|
|
3
|
+
This module provides :class:`CompositeCopula`, a simulator that fits several
|
|
4
|
+
marginal models and then couples their dependence structure with a
|
|
5
|
+
:class:`~scdesigner.copulas.standard_copula.StandardCopula`.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from ..data.loader import obs_loader
|
|
9
|
+
from .scd3 import SCD3Simulator
|
|
10
|
+
from ..copulas.standard_copula import StandardCopula
|
|
11
|
+
from anndata import AnnData
|
|
12
|
+
from typing import Dict, Optional, List
|
|
13
|
+
import numpy as np
|
|
14
|
+
import torch
|
|
15
|
+
|
|
16
|
+
class CompositeCopula(SCD3Simulator):
    """
    Composite simulator: multiple marginals + a shared Gaussian copula.

    The composite simulator fits each marginal model independently on a
    (potentially different) subset of variables, and then fits a Gaussian
    copula on the *merged* uniformized outputs from all marginals to capture
    cross-feature dependence.

    Each marginal is provided as a pair ``(sel, marginal)`` where:

    - ``sel`` selects which variables in ``adata`` the marginal is responsible
      for (e.g. a list of gene names, a single gene name).
    - ``marginal`` is an object implementing the marginal simulator interface.

    Parameters
    ----------
    marginals : list
        List of ``(sel, marginal)`` pairs.
    copula_formula : str, optional
        Formula passed to :class:`~scdesigner.copulas.standard_copula.StandardCopula`
        to determine copula grouping structure (e.g. ``"group ~ 1"``). If
        ``None``, uses the copula's default.

    Attributes
    ----------
    marginals : list
        The provided marginal specifications.
    copula : StandardCopula
        The fitted copula component.
    template : AnnData or None
        Training dataset (set during :meth:`fit`).
    parameters : dict or None
        Fitted parameters, with keys ``"marginal"`` and ``"copula"``.
    merged_formula : dict or None
        Merged (prefixed) formula dictionary used to construct the copula data loader.

    Examples
    --------
    Fit two marginal models on disjoint gene sets and then fit a copula:

    >>> import numpy as np
    >>> import pandas as pd
    >>> from anndata import AnnData
    >>> from scdesigner.simulators import CompositeCopula
    >>> from scdesigner.distributions import NegBin, Poisson
    >>>
    >>> X = np.random.poisson(1.0, size=(100, 10)).astype(float)
    >>> obs = pd.DataFrame({"cell_type": np.random.choice(["A", "B"], size=100)})
    >>> adata = AnnData(X=X, obs=obs)
    >>> adata.var_names = [f"g{i}" for i in range(adata.n_vars)]
    >>>
    >>> # Example selectors: first 5 genes vs last 5 genes
    >>> sel1 = adata.var_names[:5].tolist()
    >>> sel2 = adata.var_names[5:].tolist()
    >>> m1 = NegBin(formula={"mean": "~ cell_type", "dispersion": "~ 1"})
    >>> m2 = Poisson(formula={"mean": "~ cell_type"})
    >>>
    >>> composite = CompositeCopula([(sel1, m1), (sel2, m2)])
    >>> composite.fit(adata, batch_size=256, verbose=False)
    >>> params = composite.predict(adata.obs.iloc[:3], batch_size=3)
    """
    def __init__(self, marginals: List,
                 copula_formula: Optional[str] = None) -> None:
        """Create a composite simulator.

        Parameters
        ----------
        marginals : list
            List of ``(sel, marginal)`` pairs. See class docstring for details.
        copula_formula : str, optional
            Copula grouping formula passed to :class:`StandardCopula`.
        """
        self.marginals = marginals
        self.copula = StandardCopula(copula_formula) if copula_formula is not None else StandardCopula()
        self.template = None
        self.parameters = None
        self.merged_formula = None

    @staticmethod
    def _marginal_inputs(x: Dict[str, torch.Tensor], m: int) -> Dict[str, torch.Tensor]:
        """Select marginal ``m``'s covariates from a merged batch and strip their ``group{m}_`` prefix."""
        prefix = f"group{m}_"
        return {k.removeprefix(prefix): v for k, v in x.items() if k.startswith(prefix)}

    def fit(
        self,
        adata: AnnData,
        verbose: bool = True,
        **kwargs,):
        """Fit all marginals and then fit the copula on merged uniforms.

        Parameters
        ----------
        adata : AnnData
            Training dataset.
        **kwargs
            Additional keyword arguments forwarded to marginal setup/fit methods
            and to the copula's ``setup_data`` / ``fit`` calls (e.g.
            ``batch_size``).
        verbose : bool, optional
            Whether to print verbose output from the marginal fits.
        """
        self.template = adata
        merged_formula = {}

        # fit each marginal model on its own subset of variables
        for m, spec in enumerate(self.marginals):
            sel, marginal = spec[0], spec[1]
            marginal.setup_data(adata[:, sel], **kwargs)
            # NOTE(review): the same kwargs (e.g. batch_size) also reach
            # setup_optimizer; confirm the marginal drops non-optimizer keys.
            marginal.setup_optimizer(**kwargs)
            marginal.fit(**kwargs, verbose=verbose)

            # prefix formula keys so all marginals can share one copula loader
            prefix = f"group{m}_"
            merged_formula |= {f"{prefix}{k}": v for k, v in marginal.formula.items()}

        # copula simulator
        self.merged_formula = merged_formula
        self.copula.setup_data(adata, merged_formula, **kwargs)
        self.copula.fit(self.merged_uniformize, **kwargs)
        self.parameters = {
            "marginal": [m[1].parameters for m in self.marginals],
            "copula": self.copula.parameters
        }

    def merged_uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Produce a merged uniformized matrix for all marginals.

        Delegates to each marginal's `uniformize` method and places the
        result into the columns of a full matrix according to the variable
        selection given in `self.marginals[m][0]`. Each marginal only receives
        its own (unprefixed) covariates.

        NOTE(review): columns not covered by any selection are left
        uninitialized (``np.empty_like``); selections are expected to cover
        every variable.
        """
        y_np = y.detach().cpu().numpy()
        u = np.empty_like(y_np, dtype=float)

        for m, spec in enumerate(self.marginals):
            ix = _var_indices(spec[0], self.template)
            cur_x = self._marginal_inputs(x, m)

            # slice the subset of y for this marginal and call its uniformize
            y_sub = torch.from_numpy(y_np[:, ix])
            u[:, ix] = spec[1].uniformize(y_sub, cur_x)
        return torch.from_numpy(u)

    def predict(self, obs=None, batch_size: int = 1000, **kwargs):
        """Predict marginal parameters for observations (batched).

        This method constructs an internal loader for ``obs`` using the merged
        (prefixed) formula dictionary, then dispatches per-marginal ``predict``
        calls on each batch after stripping the prefixes. Predictions are made
        without gradient tracking.

        Parameters
        ----------
        obs : pandas.DataFrame, optional
            Observation metadata. Defaults to ``self.template.obs``.
        batch_size : int, optional
            Batch size for the internal observation loader.
        **kwargs
            Forwarded to :func:`~scdesigner.data.loader.obs_loader`.

        Returns
        -------
        list[dict[str, np.ndarray]]
            List with one element per marginal. Each element is a dict mapping
            parameter names to numpy arrays, concatenated across batches.

        Raises
        ------
        RuntimeError
            If called before :meth:`fit`.
        """
        if self.merged_formula is None or self.template is None:
            raise RuntimeError("CompositeCopula is not fitted (call fit first)")

        # prepare an internal data loader for this obs
        if obs is None:
            obs = self.template.obs
        loader = obs_loader(
            obs,
            self.merged_formula,
            batch_size=batch_size,
            **kwargs
        )

        # one collector per marginal; each receives a list of per-batch dicts
        local_pred = [[] for _ in self.marginals]

        # outputs are converted to numpy anyway, so skip building autograd graphs
        with torch.no_grad():
            for _, x_dict in loader:
                for m, spec in enumerate(self.marginals):
                    local_pred[m].append(spec[1].predict(self._marginal_inputs(x_dict, m)))

        # merge batch-wise parameter dicts for each marginal and return
        return [
            {k: torch.cat([d[k] for d in parts]).detach().cpu().numpy() for k in parts[0]}
            for parts in local_pred
        ]
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def _var_indices(sel, adata: "AnnData") -> np.ndarray:
    """Return integer column positions of the variables selected by ``sel``.

    Parameters
    ----------
    sel : str, sequence of str, slice, boolean mask, or sequence of int
        Variable selector. Names are looked up in ``adata.var_names``. A
        slice, a boolean mask of length ``n_vars``, or integer positions are
        also accepted. This way the common selector forms that work with
        ``adata[:, sel]`` also work here.
    adata : AnnData
        Object whose ``var_names`` define the column order.

    Returns
    -------
    np.ndarray
        1-D integer array of non-negative column positions. A single name
        yields shape ``(1,)``.

    Raises
    ------
    KeyError
        If any requested variable name is not in ``adata.var_names``.
    IndexError
        If a boolean mask has the wrong length, or an integer position is out
        of range.
    """
    n_vars = len(adata.var_names)
    if isinstance(sel, slice):
        return np.arange(n_vars)[sel]

    # Wrap a single name so the result always has shape (n_selected,).
    if isinstance(sel, str):
        sel = [sel]

    sel_arr = np.asarray(sel)
    if sel_arr.dtype == bool:
        if sel_arr.shape != (n_vars,):
            raise IndexError(
                f"Boolean selector has shape {sel_arr.shape}; expected ({n_vars},)"
            )
        return np.flatnonzero(sel_arr)

    if np.issubdtype(sel_arr.dtype, np.integer):
        idx = sel_arr.astype(int).reshape(-1)
        if ((idx < -n_vars) | (idx >= n_vars)).any():
            raise IndexError(
                f"Variable positions out of range for {n_vars} variables: {idx.tolist()}"
            )
        # Normalize negative positions so callers get canonical indices.
        return np.where(idx < 0, idx + n_vars, idx)

    # Otherwise treat sel as variable names; get_indexer marks misses with -1.
    idx = np.asarray(adata.var_names.get_indexer(sel), dtype=int)
    if (idx < 0).any():
        missing = [s for s, i in zip(sel, idx) if i < 0]
        raise KeyError(f"Variables not found in adata.var_names: {missing}")
    return idx
|