scdesigner 0.0.5__py3-none-any.whl → 0.0.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scdesigner/base/__init__.py +8 -0
- scdesigner/base/copula.py +416 -0
- scdesigner/base/marginal.py +391 -0
- scdesigner/base/simulator.py +59 -0
- scdesigner/copulas/__init__.py +8 -0
- scdesigner/copulas/standard_copula.py +645 -0
- scdesigner/datasets/__init__.py +5 -0
- scdesigner/datasets/pancreas.py +39 -0
- scdesigner/distributions/__init__.py +19 -0
- scdesigner/{minimal → distributions}/bernoulli.py +42 -14
- scdesigner/distributions/gaussian.py +114 -0
- scdesigner/distributions/negbin.py +121 -0
- scdesigner/distributions/negbin_irls.py +72 -0
- scdesigner/distributions/negbin_irls_funs.py +456 -0
- scdesigner/distributions/poisson.py +88 -0
- scdesigner/{minimal → distributions}/zero_inflated_negbin.py +39 -10
- scdesigner/distributions/zero_inflated_poisson.py +103 -0
- scdesigner/simulators/__init__.py +24 -28
- scdesigner/simulators/composite.py +239 -0
- scdesigner/simulators/positive_nonnegative_matrix_factorization.py +477 -0
- scdesigner/simulators/scd3.py +486 -0
- scdesigner/transform/__init__.py +8 -6
- scdesigner/{minimal → transform}/transform.py +1 -1
- scdesigner/{minimal → utils}/kwargs.py +4 -1
- {scdesigner-0.0.5.dist-info → scdesigner-0.0.10.dist-info}/METADATA +1 -1
- scdesigner-0.0.10.dist-info/RECORD +28 -0
- {scdesigner-0.0.5.dist-info → scdesigner-0.0.10.dist-info}/WHEEL +1 -1
- scdesigner/data/__init__.py +0 -16
- scdesigner/data/formula.py +0 -137
- scdesigner/data/group.py +0 -123
- scdesigner/data/sparse.py +0 -39
- scdesigner/diagnose/__init__.py +0 -65
- scdesigner/diagnose/aic_bic.py +0 -119
- scdesigner/diagnose/plot.py +0 -242
- scdesigner/estimators/__init__.py +0 -32
- scdesigner/estimators/bernoulli.py +0 -85
- scdesigner/estimators/gaussian.py +0 -121
- scdesigner/estimators/gaussian_copula_factory.py +0 -367
- scdesigner/estimators/glm_factory.py +0 -75
- scdesigner/estimators/negbin.py +0 -153
- scdesigner/estimators/pnmf.py +0 -160
- scdesigner/estimators/poisson.py +0 -124
- scdesigner/estimators/zero_inflated_negbin.py +0 -195
- scdesigner/estimators/zero_inflated_poisson.py +0 -85
- scdesigner/format/__init__.py +0 -4
- scdesigner/format/format.py +0 -20
- scdesigner/format/print.py +0 -30
- scdesigner/minimal/__init__.py +0 -17
- scdesigner/minimal/composite.py +0 -119
- scdesigner/minimal/copula.py +0 -205
- scdesigner/minimal/formula.py +0 -23
- scdesigner/minimal/gaussian.py +0 -65
- scdesigner/minimal/loader.py +0 -211
- scdesigner/minimal/marginal.py +0 -154
- scdesigner/minimal/negbin.py +0 -73
- scdesigner/minimal/positive_nonnegative_matrix_factorization.py +0 -231
- scdesigner/minimal/scd3.py +0 -96
- scdesigner/minimal/scd3_instances.py +0 -50
- scdesigner/minimal/simulator.py +0 -25
- scdesigner/minimal/standard_copula.py +0 -383
- scdesigner/predictors/__init__.py +0 -15
- scdesigner/predictors/bernoulli.py +0 -9
- scdesigner/predictors/gaussian.py +0 -16
- scdesigner/predictors/negbin.py +0 -17
- scdesigner/predictors/poisson.py +0 -12
- scdesigner/predictors/zero_inflated_negbin.py +0 -18
- scdesigner/predictors/zero_inflated_poisson.py +0 -18
- scdesigner/samplers/__init__.py +0 -23
- scdesigner/samplers/bernoulli.py +0 -27
- scdesigner/samplers/gaussian.py +0 -25
- scdesigner/samplers/glm_factory.py +0 -103
- scdesigner/samplers/negbin.py +0 -25
- scdesigner/samplers/poisson.py +0 -25
- scdesigner/samplers/zero_inflated_negbin.py +0 -40
- scdesigner/samplers/zero_inflated_poisson.py +0 -16
- scdesigner/simulators/composite_regressor.py +0 -72
- scdesigner/simulators/glm_simulator.py +0 -167
- scdesigner/simulators/pnmf_regression.py +0 -61
- scdesigner/transform/amplify.py +0 -14
- scdesigner/transform/mask.py +0 -33
- scdesigner/transform/nullify.py +0 -25
- scdesigner/transform/split.py +0 -23
- scdesigner/transform/substitute.py +0 -14
- scdesigner-0.0.5.dist-info/RECORD +0 -66
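For orientation, the renames in the listing above move the distribution modules out of `scdesigner.minimal` and into `scdesigner.distributions`. A minimal sketch of the corresponding import-path change (module paths only; the symbols each module exports are not shown in this diff):

```python
# scdesigner 0.0.5 (modules under scdesigner.minimal, removed below)
from scdesigner.minimal import bernoulli, zero_inflated_negbin

# scdesigner 0.0.10 (same modules relocated, per the renames listed above)
from scdesigner.distributions import bernoulli, zero_inflated_negbin
```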
scdesigner/minimal/composite.py
DELETED
@@ -1,119 +0,0 @@
-from .loader import obs_loader
-from .scd3 import SCD3Simulator
-from .standard_copula import StandardCopula
-from anndata import AnnData
-from typing import Dict, Optional, List
-import numpy as np
-import torch
-
-class CompositeCopula(SCD3Simulator):
-    def __init__(self, marginals: List,
-                 copula_formula: Optional[str] = None) -> None:
-        self.marginals = marginals
-        self.copula = StandardCopula(copula_formula)
-        self.template = None
-        self.parameters = None
-        self.merged_formula = None
-
-    def fit(
-        self,
-        adata: AnnData,
-        **kwargs):
-        """Fit the simulator"""
-        self.template = adata
-        merged_formula = {}
-
-        # fit each marginal model
-        for m in range(len(self.marginals)):
-            self.marginals[m][1].setup_data(adata[:, self.marginals[m][0]], **kwargs)
-            self.marginals[m][1].setup_optimizer(**kwargs)
-            self.marginals[m][1].fit(**kwargs)
-
-            # prepare formula for copula loader
-            f = self.marginals[m][1].formula
-            prefixed_f = {f"group{m}_{k}": v for k, v in f.items()}
-            merged_formula = merged_formula | prefixed_f
-
-        # copula simulator
-        self.merged_formula = merged_formula
-        self.copula.setup_data(adata, merged_formula, **kwargs)
-        self.copula.fit(self.merged_uniformize, **kwargs)
-        self.parameters = {
-            "marginal": [m[1].parameters for m in self.marginals],
-            "copula": self.copula.parameters
-        }
-
-    def merged_uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor]) -> torch.Tensor:
-        """Produce a merged uniformized matrix for all marginals.
-
-        Delegates to each marginal's `uniformize` method and places the
-        result into the columns of a full matrix according to the variable
-        selection given in `self.marginals[m][0]`.
-        """
-        y_np = y.detach().cpu().numpy()
-        u = np.empty_like(y_np, dtype=float)
-
-        for m in range(len(self.marginals)):
-            sel = self.marginals[m][0]
-            ix = _var_indices(sel, self.template)
-
-            # remove the `group{m}_` prefix we used to distinguish the marginals
-            prefix = f"group{m}_"
-            cur_x = {k.removeprefix(prefix): v if k.startswith(prefix) else v for k, v in x.items()}
-
-            # slice the subset of y for this marginal and call its uniformize
-            y_sub = torch.from_numpy(y_np[:, ix])
-            u[:, ix] = self.marginals[m][1].uniformize(y_sub, cur_x)
-        return torch.from_numpy(u)
-
-    def predict(self, obs=None, batch_size: int = 1000, **kwargs):
-        """Predict from an obs dataframe"""
-        # prepare an internal data loader for this obs
-        if obs is None:
-            obs = self.template.obs
-        loader = obs_loader(
-            obs,
-            self.merged_formula,
-            batch_size=batch_size,
-            **kwargs
-        )
-
-        # prepare per-marginal collectors
-        n_marginals = len(self.marginals)
-        local_pred = [[] for _ in range(n_marginals)]
-
-        # for each batch, call each marginal's predict on its subset of x
-        for _, x_dict in loader:
-            for m in range(n_marginals):
-                prefix = f"group{m}_"
-                # build cur_x where prefixed keys are unprefixed for the marginal
-                cur_x = {k.removeprefix(prefix): v for k, v in x_dict.items()}
-                params = self.marginals[m][1].predict(cur_x)
-                local_pred[m].append(params)
-
-        # merge batch-wise parameter dicts for each marginal and return
-        results = []
-        for m in range(n_marginals):
-            parts = local_pred[m]
-            keys = list(parts[0].keys())
-            results.append({k: torch.cat([d[k] for d in parts]).detach().cpu().numpy() for k in keys})
-
-        return results
-
-
-def _var_indices(sel, adata: AnnData) -> np.ndarray:
-    """Return integer indices of `sel` within `adata.var_names`.
-
-    Expected use: `sel` is a list (or tuple) of variable names (strings).
-    """
-    # If sel is a single string, make it a list so we return consistent shape
-    single_string = False
-    if isinstance(sel, str):
-        sel = [sel]
-        single_string = True
-
-    idx = np.asarray(adata.var_names.get_indexer(sel), dtype=int)
-    if (idx < 0).any():
-        missing = [s for s, i in zip(sel, idx) if i < 0]
-        raise KeyError(f"Variables not found in adata.var_names: {missing}")
-    return idx if not single_string else idx.reshape(-1)
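For context, a hedged usage sketch of the removed `CompositeCopula`, based only on the constructor, `fit`, and `predict` signatures above. The marginal class (the `Gaussian` marginal deleted later in this diff), data shapes, and formulas are illustrative assumptions, and the copula's own formula handling lives in the removed `standard_copula.py`, which is not shown in this hunk.

```python
import numpy as np
import pandas as pd
from anndata import AnnData
from scdesigner.minimal.composite import CompositeCopula  # removed in 0.0.10
from scdesigner.minimal.gaussian import Gaussian          # removed in 0.0.10

# toy expression matrix with two named genes and one categorical covariate
obs = pd.DataFrame({"condition": pd.Categorical(["a", "b"] * 50)})
adata = AnnData(X=np.random.poisson(3.0, size=(100, 2)).astype(np.float32), obs=obs)
adata.var_names = ["gene1", "gene2"]

# each entry pairs a variable selection with a marginal model, as fit() expects
marginals = [
    (["gene1"], Gaussian("~ condition")),
    (["gene2"], Gaussian("~ condition")),
]

sim = CompositeCopula(marginals)   # copula_formula left at its default here
sim.fit(adata)
params = sim.predict(adata.obs, batch_size=500)  # one parameter dict per marginal
```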
scdesigner/minimal/copula.py
DELETED
@@ -1,205 +0,0 @@
-from typing import Dict, Callable, Tuple
-import torch
-from anndata import AnnData
-from .loader import adata_loader
-from abc import ABC, abstractmethod
-import numpy as np
-import pandas as pd
-from typing import Optional, Union
-class Copula(ABC):
-    def __init__(self, formula: str, **kwargs):
-        self.formula = formula
-        self.loader = None
-        self.n_outcomes = None
-        self.parameters = None  # Should be a dictionary of CovarianceStructure objects
-
-    def setup_data(self, adata: AnnData, marginal_formula: Dict[str, str], batch_size: int = 1024, **kwargs):
-        self.adata = adata
-        self.formula = self.formula | marginal_formula
-        self.loader = adata_loader(adata, self.formula, batch_size=batch_size, **kwargs)
-        X_batch, _ = next(iter(self.loader))
-        self.n_outcomes = X_batch.shape[1]
-
-    def decorrelate(self, row_pattern: str, col_pattern: str, group: Union[str, list, None] = None):
-        """Decorrelate the covariance matrix for the given row and column patterns.
-
-        Args:
-            row_pattern (str): The regex pattern for the row names to match.
-            col_pattern (str): The regex pattern for the column names to match.
-            group (Union[str, list, None]): The group or groups to apply the transformation to. If None, the transformation is applied to all groups.
-        """
-        if group is None:
-            for g in self.groups:
-                self.parameters[g].decorrelate(row_pattern, col_pattern)
-        elif isinstance(group, str):
-            self.parameters[group].decorrelate(row_pattern, col_pattern)
-        else:
-            for g in group:
-                self.parameters[g].decorrelate(row_pattern, col_pattern)
-
-    def correlate(self, factor: float, row_pattern: str, col_pattern: str, group: Union[str, list, None] = None):
-        """Multiply selected off-diagonal entries by factor.
-
-        Args:
-            row_pattern (str): The regex pattern for the row names to match.
-            col_pattern (str): The regex pattern for the column names to match.
-            factor (float): The factor to multiply the off-diagonal entries by.
-            group (Union[str, list, None]): The group or groups to apply the transformation to. If None, the transformation is applied to all groups.
-        """
-        if group is None:
-            for g in self.groups:
-                self.parameters[g].correlate(row_pattern, col_pattern, factor)
-        elif isinstance(group, str):
-            self.parameters[group].correlate(row_pattern, col_pattern, factor)
-        else:
-            for g in group:
-                self.parameters[g].correlate(row_pattern, col_pattern, factor)
-
-    @abstractmethod
-    def fit(self, uniformizer: Callable, **kwargs):
-        raise NotImplementedError
-
-    @abstractmethod
-    def pseudo_obs(self, x_dict: Dict):
-        raise NotImplementedError
-
-    @abstractmethod
-    def likelihood(self, uniformizer: Callable, batch: Tuple[torch.Tensor, Dict[str, torch.Tensor]]):
-        raise NotImplementedError
-
-    @abstractmethod
-    def num_params(self, **kwargs):
-        raise NotImplementedError
-
-    # @abstractmethod
-    # def format_parameters(self):
-    #     raise NotImplementedError
-
-class CovarianceStructure:
-    """
-    Efficient storage for covariance matrices in copula-based gene expression modeling.
-
-    This class provides memory-efficient storage for covariance information by storing
-    either a full covariance matrix or a block matrix with diagonal variances for
-    remaining genes. This enables fast copula estimation and sampling for large
-    gene expression datasets.
-
-
-
-    Attributes
-    ----------
-    cov : pd.DataFrame
-        Covariance matrix for modeled genes with gene names as index/columns
-    modeled_indices : np.ndarray
-        Indices of modeled genes in original ordering
-    remaining_var : pd.Series or None
-        Diagonal variances for remaining genes, None if full matrix stored
-    remaining_indices : np.ndarray or None
-        Indices of remaining genes in original ordering
-    num_modeled_genes : int
-        Number of modeled genes
-    num_remaining_genes : int
-        Number of remaining genes (0 if full matrix stored)
-    total_genes : int
-        Total number of genes
-    """
-
-    def __init__(self, cov: np.ndarray,
-                 modeled_names: pd.Index,
-                 modeled_indices: Optional[np.ndarray] = None,
-                 remaining_var: Optional[np.ndarray] = None,
-                 remaining_indices: Optional[np.ndarray] = None,
-                 remaining_names: Optional[pd.Index] = None):
-        """initialize a CovarianceStructure object.
-
-        Args:
-            cov (np.ndarray): Covariance matrix for modeled genes, shape (n_modeled_genes, n_modeled_genes)
-            modeled_names (pd.Index): Gene names for the modeled genes
-            modeled_indices (Optional[np.ndarray], optional): Indices of modeled genes in original ordering. Defaults to sequential indices.
-            remaining_var (Optional[np.ndarray], optional): Diagonal variances for remaining genes, shape (n_remaining_genes,)
-            remaining_indices (Optional[np.ndarray], optional): Indices of remaining genes in original ordering
-            remaining_names (Optional[pd.Index], optional): Gene names for remaining genes
-        """
-        self.cov = pd.DataFrame(cov, index=modeled_names, columns=modeled_names)
-
-        if modeled_indices is not None:
-            self.modeled_indices = modeled_indices
-        else:
-            self.modeled_indices = np.arange(len(modeled_names))
-
-        if remaining_var is not None:
-            self.remaining_var = pd.Series(remaining_var, index=remaining_names)
-        else:
-            self.remaining_var = None
-
-        self.remaining_indices = remaining_indices
-        self.num_modeled_genes = len(modeled_names)
-        self.num_remaining_genes = len(remaining_indices) if remaining_indices is not None else 0
-        self.total_genes = self.num_modeled_genes + self.num_remaining_genes
-
-    def __repr__(self):
-        if self.remaining_var is None:
-            return self.cov.__repr__()
-        else:
-            return f"CovarianceStructure(modeled_genes={self.num_modeled_genes}, \
-                total_genes={self.total_genes})"
-
-    def _repr_html_(self):
-        """Jupyter Notebook display"""
-        if self.remaining_var is None:
-            return self.cov._repr_html_()
-        else:
-            html = f"<b>CovarianceStructure:</b> {self.num_modeled_genes} modeled genes, {self.total_genes} total<br>"
-            html += "<h4>Modeled Covariance Matrix</h4>" + self.cov._repr_html_()
-            html += "<h4>Remaining Gene Variances</h4>" + self.remaining_var.to_frame("variance").T._repr_html_()
-            return html
-
-    def decorrelate(self, row_pattern: str, col_pattern: str):
-        """Decorrelate the covariance matrix for the given row and column patterns.
-        """
-        from .transform import data_frame_mask
-        m1 = data_frame_mask(self.cov, ".", col_pattern)
-        m2 = data_frame_mask(self.cov, row_pattern, ".")
-        mask = (m1 | m2)
-        np.fill_diagonal(mask, False)
-        self.cov.values[mask] = 0
-
-    def correlate(self, row_pattern: str, col_pattern: str, factor: float):
-        """Multiply selected off-diagonal entries by factor.
-
-        Args:
-            row_pattern (str): The regex pattern for the row names to match.
-            col_pattern (str): The regex pattern for the column names to match.
-            factor (float): The factor to multiply the off-diagonal entries by.
-        """
-        from .transform import data_frame_mask
-        m1 = data_frame_mask(self.cov, ".", col_pattern)
-        m2 = data_frame_mask(self.cov, row_pattern, ".")
-        mask = (m1 | m2)
-        np.fill_diagonal(mask, False)
-        self.cov.values[mask] = self.cov.values[mask] * factor
-
-    @property
-    def shape(self):
-        return (self.total_genes, self.total_genes)
-
-    def to_full_matrix(self):
-        """
-        Convert to full covariance matrix for compatibility/debugging.
-        Returns:
-        --------
-        np.ndarray : Full covariance matrix with shape (total_genes, total_genes)
-        """
-        if self.remaining_var is None:
-            return self.cov.values
-        else:
-            full_cov = np.zeros((self.total_genes, self.total_genes))
-
-            # Fill in top-k block
-            ix_modeled = np.ix_(self.modeled_indices, self.modeled_indices)
-            full_cov[ix_modeled] = self.cov.values
-
-            # Fill in diagonal for remaining genes
-            full_cov[self.remaining_indices, self.remaining_indices] = self.remaining_var.values
-
-            return full_cov
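As a concrete reading of the `CovarianceStructure` storage scheme above, the sketch below (toy gene names and values, assuming the removed 0.0.5 module is importable) stores a dense block for two modeled genes plus a diagonal-only variance for one remaining gene, then expands it back with `to_full_matrix`:

```python
import numpy as np
import pandas as pd
from scdesigner.minimal.copula import CovarianceStructure  # removed in 0.0.10

cov_block = np.array([[1.0, 0.3],
                      [0.3, 1.0]])
cs = CovarianceStructure(
    cov_block,
    modeled_names=pd.Index(["g1", "g3"]),
    modeled_indices=np.array([0, 2]),   # positions of g1, g3 in the original ordering
    remaining_var=np.array([0.5]),      # g2 is kept as a variance only
    remaining_indices=np.array([1]),
    remaining_names=pd.Index(["g2"]),
)

cs.shape                  # (3, 3)
full = cs.to_full_matrix()
full[0, 2]                # 0.3 -> covariance from the modeled block
full[1, 1]                # 0.5 -> diagonal-only variance for the remaining gene
```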
scdesigner/minimal/formula.py
DELETED
@@ -1,23 +0,0 @@
-from typing import Union
-import warnings
-
-def standardize_formula(formula: Union[str, dict], allowed_keys = None):
-    # The first element of allowed_keys should be the name of default parameter
-    if allowed_keys is None:
-        raise ValueError("Internal error: allowed_keys must be specified")
-    formula = {allowed_keys[0]: formula} if isinstance(formula, str) else formula
-
-    formula_keys = set(formula.keys())
-    allowed_keys = set(allowed_keys)
-
-    if not formula_keys & allowed_keys:
-        raise ValueError(f"formula must have at least one of the following keys: {allowed_keys}")
-
-    if extra_keys := formula_keys - allowed_keys:
-        warnings.warn(
-            f"Invalid formulas in dictionary will not be used: {extra_keys}",
-            UserWarning,
-        )
-
-    formula.update({k: '~ 1' for k in allowed_keys - formula_keys})
-    return formula
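A quick illustration of `standardize_formula` as defined above, using the `mean`/`sdev` keys of the Gaussian marginal as the allowed keys (the results follow directly from the code):

```python
from scdesigner.minimal.formula import standardize_formula  # removed in 0.0.10

# a bare string becomes the default (first) parameter's formula; others get '~ 1'
standardize_formula("~ condition", allowed_keys=["mean", "sdev"])
# -> {"mean": "~ condition", "sdev": "~ 1"}

# a partial dict is completed with '~ 1' for the missing allowed keys
standardize_formula({"sdev": "~ batch"}, allowed_keys=["mean", "sdev"])
# -> {"sdev": "~ batch", "mean": "~ 1"}

# unknown extra keys only trigger a UserWarning, but no allowed key at all raises
standardize_formula({"scale": "~ batch"}, allowed_keys=["mean", "sdev"])
# -> ValueError
```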
scdesigner/minimal/gaussian.py
DELETED
@@ -1,65 +0,0 @@
-from .formula import standardize_formula
-from .marginal import GLMPredictor, Marginal
-from .loader import _to_numpy
-from typing import Union, Dict, Optional
-import torch
-import numpy as np
-from scipy.stats import norm
-
-class Gaussian(Marginal):
-    """Gaussian marginal estimator"""
-    def __init__(self, formula: Union[Dict, str]):
-        formula = standardize_formula(formula, allowed_keys=['mean', 'sdev'])
-        super().__init__(formula)
-
-    def setup_optimizer(
-        self,
-        optimizer_class: Optional[callable] = torch.optim.Adam,
-        **optimizer_kwargs,
-    ):
-        if self.loader is None:
-            raise RuntimeError("self.loader is not set (call setup_data first)")
-
-        nll = lambda batch: -self.likelihood(batch).sum()
-        link_fns = {"mean": lambda x: x}
-        self.predict = GLMPredictor(
-            n_outcomes=self.n_outcomes,
-            feature_dims=self.feature_dims,
-            link_fns=link_fns,
-            loss_fn=nll,
-            optimizer_class=optimizer_class,
-            optimizer_kwargs=optimizer_kwargs
-        )
-
-    def likelihood(self, batch):
-        """Compute the negative log-likelihood"""
-        y, x = batch
-        params = self.predict(x)
-        mu = params.get("mean")
-        sigma = params.get("sdev")
-
-        # log likelihood for Gaussian
-        log_likelihood = -0.5 * (torch.log(2 * torch.pi * sigma ** 2) + ((y - mu) ** 2) / (sigma ** 2))
-        return log_likelihood
-
-    def invert(self, u: torch.Tensor, x: Dict[str, torch.Tensor]):
-        """Invert pseudoobservations."""
-        mu, sdev, u = self._local_params(x, u)
-        y = norm(loc=mu, scale=sdev).ppf(u)
-        return torch.from_numpy(y).float()
-
-    def uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor], epsilon=1e-6):
-        """Return uniformized pseudo-observations for counts y given covariates x."""
-        # cdf values using scipy's parameterization
-        mu, sdev, y = self._local_params(x, y)
-        u = norm.cdf(y, loc=mu, scale=sdev)
-        u = np.clip(u, epsilon, 1 - epsilon)
-        return torch.from_numpy(u).float()
-
-    def _local_params(self, x, y=None):
-        params = self.predict(x)
-        mu = params.get('mean')
-        sdev = params.get('sdev')
-        if y is None:
-            return _to_numpy(mu, sdev)
-        return _to_numpy(mu, sdev, y)
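The `uniformize`/`invert` pair above is just the Gaussian CDF and quantile function applied with the predicted per-gene mean and standard deviation. A standalone sketch of that round trip with fixed, illustrative parameters (no scdesigner dependency):

```python
import numpy as np
from scipy.stats import norm

mu, sdev, epsilon = 2.0, 1.5, 1e-6
y = np.array([0.5, 2.0, 4.0])

# uniformize: CDF values, clipped away from 0 and 1 as in the removed code
u = np.clip(norm.cdf(y, loc=mu, scale=sdev), epsilon, 1 - epsilon)

# invert: the quantile function maps the pseudo-observations back
y_back = norm(loc=mu, scale=sdev).ppf(u)
np.allclose(y, y_back)   # True, away from the clipping boundaries
```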
scdesigner/minimal/loader.py
DELETED
@@ -1,211 +0,0 @@
-from .kwargs import DEFAULT_ALLOWED_KWARGS, _filter_kwargs
-from anndata import AnnData
-from formulaic import model_matrix
-from torch.utils.data import Dataset, DataLoader
-from typing import Dict
-import numpy as np
-import pandas as pd
-import scipy.sparse
-import torch
-
-def get_device():
-    """Detect and return the best available device (MPS, CUDA, or CPU)."""
-    if torch.backends.mps.is_available():
-        return torch.device("mps")
-    elif torch.cuda.is_available():
-        return torch.device("cuda")
-    else:
-        return torch.device("cpu")
-
-
-class PreloadedDataset(Dataset):
-    """Dataset that assumes x and y are both fully in memory."""
-    def __init__(self, y_tensor, x_tensors, predictor_names):
-        self.y = y_tensor
-        self.x = x_tensors
-        self.predictor_names = predictor_names
-
-    def __len__(self):
-        return len(self.y)
-
-    def __getitem__(self, idx):
-        return self.y[idx], {k: v[idx] for k, v in self.x.items()}
-
-class AnnDataDataset(Dataset):
-    """Simple PyTorch Dataset for AnnData objects.
-
-    Supports optional chunked loading for backed AnnData objects. When
-    `chunk_size` is provided, the dataset will load contiguous slices
-    of rows (of size `chunk_size`) into memory once and serve individual
-    rows from that cached chunk. Chunks are moved to device for faster access.
-    """
-    def __init__(self, adata: AnnData, formula: Dict[str, str], chunk_size: int):
-        self.adata = adata
-        self.formula = formula
-        self.chunk_size = chunk_size
-        self.device = get_device()
-
-        # keeping track of covariate-related information
-        self.obs_levels = categories(self.adata.obs)
-        self.obs_matrices = {}
-        self.predictor_names = None
-
-        # Internal cache for the currently loaded chunk
-        self._chunk: AnnData | None = None
-        self._chunk_X = None
-        self._chunk_start = 0
-
-    def __len__(self):
-        return len(self.adata)
-
-    def __getitem__(self, idx):
-        """Returns (X, obs) for the given index.
-
-        If `chunk_size` was specified the dataset will load a chunk
-        containing `idx` into memory (if not already cached) and
-        index into that chunk.
-        """
-        self._ensure_chunk_loaded(idx)
-        local_idx = idx - self._chunk_start
-
-        # Get obs data from GPU-cached matrices
-        obs_dict = {}
-        for key in self.formula.keys():
-            obs_dict[key] = self.obs_matrices[key][local_idx: local_idx + 1]
-        return self._chunk_X[local_idx], obs_dict
-
-    def _ensure_chunk_loaded(self, idx: int) -> None:
-        """Load the chunk that contains `idx` into the internal cache."""
-        start = (idx // self.chunk_size) * self.chunk_size
-        end = min(start + self.chunk_size, len(self.adata))
-
-        if (self._chunk is None) or not (self._chunk_start <= idx < self._chunk_start + len(self._chunk)):
-            # load the next chunk into memory
-            chunk = self.adata[start:end]
-            if getattr(chunk, 'isbacked', False):
-                chunk = chunk.to_memory()
-            self._chunk = chunk
-            self._chunk_start = start
-
-            # Move chunk to GPU
-            X = chunk.X
-            if hasattr(X, 'toarray'):
-                X = X.toarray()
-            self._chunk_X = torch.tensor(X, dtype=torch.float32).to(self.device)
-
-            # Compute model matrices for this chunk's `obs` and move to GPU
-            obs_coded_chunk = code_levels(self._chunk.obs.copy(), self.obs_levels)
-            self.obs_matrices = {}
-            predictor_names = {}
-            for key, f in self.formula.items():
-                mat = model_matrix(f, obs_coded_chunk)
-                predictor_names[key] = list(mat.columns)
-                self.obs_matrices[key] = torch.tensor(mat.values, dtype=torch.float32).to(self.device)
-
-            # Capture predictor (column) names from the model matrices once.
-            if self.predictor_names is None:
-                self.predictor_names = predictor_names
-
-
-def adata_loader(
-    adata: AnnData,
-    formula: Dict[str, str],
-    chunk_size: int = None,
-    batch_size: int = 1024,
-    shuffle: bool = False,
-    num_workers: int = 0,
-    **kwargs
-) -> DataLoader:
-    """Create a DataLoader from AnnData that returns batches of (X, obs)."""
-    data_kwargs = _filter_kwargs(kwargs, DEFAULT_ALLOWED_KWARGS['data'])
-    device = get_device()
-
-    # separate chunked from non-chunked cases
-    if not getattr(adata, 'isbacked', False):
-        dataset = _preloaded_adata(adata, formula, device)
-    else:
-        dataset = AnnDataDataset(adata, formula, chunk_size or 5000)
-
-    return DataLoader(
-        dataset,
-        batch_size=batch_size,
-        shuffle=shuffle,
-        num_workers=num_workers,
-        collate_fn=dict_collate_fn,
-        **data_kwargs
-    )
-
-def obs_loader(obs: pd.DataFrame, marginal_formula, **kwargs):
-    adata = AnnData(X=np.zeros((len(obs), 1)), obs=obs)
-    return adata_loader(
-        adata,
-        marginal_formula,
-        **kwargs
-    )
-
-################################################################################
-## Extraction of in-memory AnnData to PreloadedDataset
-################################################################################
-
-def _preloaded_adata(adata: AnnData, formula: Dict[str, str], device: torch.device) -> PreloadedDataset:
-    X = adata.X
-    if scipy.sparse.issparse(X):
-        X = X.toarray()
-    y = torch.tensor(X, dtype=torch.float32).to(device)
-
-    obs = code_levels(adata.obs.copy(), categories(adata.obs))
-    x = {
-        k: torch.tensor(model_matrix(f, obs).values, dtype=torch.float32).to(device)
-        for k, f in formula.items()
-    }
-    predictor_names = {k: list(model_matrix(f, obs).columns) for k, f in formula.items()}
-    return PreloadedDataset(y, x, predictor_names)
-
-################################################################################
-## Helper functions
-################################################################################
-
-def dict_collate_fn(batch):
-    """
-    Custom collate function for handling dictionary obs tensors.
-    """
-    X_batch = torch.stack([item[0] for item in batch])
-    obs_batch = [item[1] for item in batch]
-
-    obs_dict = {}
-    for key in obs_batch[0].keys():
-        obs_dict[key] = torch.stack([obs[key] for obs in obs_batch])
-    return X_batch, obs_dict
-
-def to_tensor(X):
-    # If the tensor is 2D with second dim == 1, squeeze only the first
-    # dim when appropriate (e.g. converting a single-row X to 1D samples)
-    t = torch.tensor(X, dtype=torch.float32)
-    if t.dim() == 2 and t.size(1) == 1:
-        if t.size(0) == 1:
-            return t.view(1)
-        return t
-    return t.squeeze()
-
-def categories(obs):
-    levels = {}
-    for k in obs.columns:
-        obs_type = str(obs[k].dtype)
-        if obs_type in ["category", "object"]:
-            levels[k] = obs[k].unique()
-    return levels
-
-
-def code_levels(obs, categories):
-    for k in obs.columns:
-        if str(obs[k].dtype) == "category":
-            obs[k] = obs[k].astype(pd.CategoricalDtype(categories[k]))
-    return obs
-
-###############################################################################
-## Misc. Helper functions
-###############################################################################
-
-def _to_numpy(*tensors):
-    """Convenience helper: detach, move to CPU, and convert tensors to numpy arrays."""
-    return tuple(t.detach().cpu().numpy() for t in tensors)
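To make the loader contract concrete, a hedged sketch of `adata_loader` on an in-memory AnnData (toy shapes and formulas; assumes the removed 0.0.5 module is importable). Batches come out as `(X, obs_dict)` via `dict_collate_fn`, with one design matrix per formula key:

```python
import numpy as np
import pandas as pd
from anndata import AnnData
from scdesigner.minimal.loader import adata_loader  # removed in 0.0.10

obs = pd.DataFrame({"condition": pd.Categorical(["a", "b"] * 8)})
adata = AnnData(X=np.random.poisson(2.0, size=(16, 4)).astype(np.float32), obs=obs)

# one formulaic design matrix per model parameter, built from adata.obs
loader = adata_loader(adata, {"mean": "~ condition", "sdev": "~ 1"}, batch_size=8)

y, x = next(iter(loader))
y.shape            # (8, 4): expression values for this batch
x["mean"].shape    # (8, 2): intercept + condition indicator
x["sdev"].shape    # (8, 1): intercept only
```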